In [1]:
import torch
from torch import nn
import torch.utils.data
import torchvision.datasets as data_set
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
import numpy as np
import torchvision.utils as vutils



In [2]:
lr=0.0002
epochs=100
batch_size=100

In [3]:
class Generator(nn.Module):

    def __init__(self,in_channels):
        super(Generator,self).__init__()
        self.fc1=nn.Linear(in_channels,384)

        self.t1=nn.Sequential(
            nn.ConvTranspose2d(in_channels=384,out_channels=192,kernel_size=(4,4),stride=1,padding=0),
            nn.BatchNorm2d(192),
            nn.ReLU()
        )
        self.t2=nn.Sequential(
            nn.ConvTranspose2d(in_channels=192,out_channels=96,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU()
        )
        self.t3=nn.Sequential(
            nn.ConvTranspose2d(in_channels=96,out_channels=48,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(48),
            nn.ReLU()
        )
        self.t4=nn.Sequential(
            nn.ConvTranspose2d(in_channels=48,out_channels=3,kernel_size=(4,4),stride=2,padding=1),
            nn.Tanh()
        )
    
    def forward(self,x):
        x=x.view(-1,110)
        x=self.fc1(x)
        x=x.view(-1,384,1,1)
        x=self.t1(x)
        x=self.t2(x)
        x=self.t3(x)
        x=self.t4(x)
        return x 

In [4]:
class Discriminator(nn.Module):
    
    def __init__(self,classes=10):
        
        super(Discriminator,self).__init__()
        self.c1=nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=16,kernel_size=(3,3),stride=2,padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
            )
        self.c2=nn.Sequential(
            nn.Conv2d(in_channels=16,out_channels=32,kernel_size=(3,3),stride=1,padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
            )
        self.c3=nn.Sequential(
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size=(3,3),stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
            )
        self.c4=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(3,3),stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
            )
        self.c5=nn.Sequential(
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(3,3),stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
            )
        self.c6=nn.Sequential(
            nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(3,3),stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
            )
        self.fc_source=nn.Linear(4*4*512,1)
        self.fc_class=nn.Linear(4*4*512,classes)
        self.sig=nn.Sigmoid()
        self.soft=nn.Softmax()

    def forward(self,x):

        x=self.c1(x)
        x=self.c2(x)
        x=self.c3(x)
        x=self.c4(x)
        x=self.c5(x)
        x=self.c6(x)
        x=x.view(-1,4*4*512)
        F=self.sig(self.fc_source(x))
        C=self.soft(self.fc_class(x))
        
        return F,C 

In [5]:
real_label = torch.FloatTensor(batch_size).cuda()
real_label.fill_(1)
real_label = real_label.unsqueeze(1)

In [6]:
fake_label = torch.FloatTensor(batch_size).cuda()
fake_label.fill_(0)
fake_label = fake_label.unsqueeze(1)


In [7]:

e_noise = torch.FloatTensor(batch_size, 110, 1, 1).normal_(0, 1)
e_noise_ = np.random.normal(0, 1, (batch_size, 110))
e_label = np.random.randint(0, 10, batch_size)
e_onehot = np.zeros((batch_size, 10))
e_onehot[np.arange(batch_size), e_label] = 1
e_noise_[np.arange(batch_size), :10] = e_onehot[np.arange(batch_size)]
e_noise_ = (torch.from_numpy(e_noise_))
e_noise.data.copy_(e_noise_.view(batch_size, 110, 1, 1))
e_noise=e_noise.cuda()

In [8]:
def calc_acc(preds, labels):
    corr = 0
    pred = preds.data.max(1)[1]
    corr = pred.eq(labels.data).cpu().sum()
    acc = float(corr) / float(len(labels.data)) * 100.0
    return acc

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

In [None]:
dataset = data_set.CIFAR10(root='data/', download=True,transform=transforms.Compose([transforms.Scale((32,32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))

trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
gen=Generator(110).cuda()
disc=Discriminator().cuda()

gen.apply(weights_init)

optimD=optim.Adam(disc.parameters(),lr)
optimG=optim.Adam(gen.parameters(),lr)

source_obj=nn.BCELoss()

class_obj=nn.NLLLoss()

In [None]:




for epoch in range(epochs):
    for i,data in enumerate(dataloader,0):

        optimD.zero_grad()

        image,label=data
        image,label=image.cuda(),label.cuda()

        source_,class_=disc(image)

        source_error=source_obj(source_,real_label)
        class_error=class_obj(class_,label)
        error_real=source_error+class_error
        error_real.backward()
        optimD.step()


        accuracy=calc_acc(class_,label)

 


        noise_ = np.random.normal(0, 1, (batch_size, 110))

        label=np.random.randint(0,10,batch_size)
        noise=((torch.from_numpy(noise_)).float())
  

        label=((torch.from_numpy(label)).long())
        label=label.cuda()

        noise_image=gen(noise)


        source_,class_=disc(noise_image.detach())
  
        source_error=source_obj(source_,fake_label)
        class_error=class_obj(class_,label)
        error_fake=source_error+class_error
        error_fake.backward()
        optimD.step()




        gen.zero_grad()
        source_,class_=disc(noise_image)
        source_error=source_obj(source_,real_label)
        class_error=class_obj(class_,label)
        error_gen=source_error+class_error
        error_gen.backward()
        optimG.step()
        iteration_now = epoch * len(trainloader) + i

        if i % 100 == 0:
            print("Epoch--[{} / {}], Loss_Discriminator--[{}], Loss_Generator--[{}],Accuracy--[{}]".format(epoch,epochs,error_fake,error_gen,accuracy))
            constructed = gen(e_noise)
            vutils.save_image(
                constructed.data,
                '%s/results_epoch_%03d.png' % ('images/', epoch))







Files already downloaded and verified


  C=self.soft(self.fc_class(x))


Epoch--[0 / 100], Loss_Discriminator--[1.3686448335647583], Loss_Generator--[0.11144110560417175],Accuracy--[13.0]
Epoch--[0 / 100], Loss_Discriminator--[0.22113250195980072], Loss_Generator--[3.429192304611206],Accuracy--[28.999999999999996]
Epoch--[0 / 100], Loss_Discriminator--[-0.03880634158849716], Loss_Generator--[5.358049392700195],Accuracy--[17.0]
Epoch--[0 / 100], Loss_Discriminator--[-0.04957768693566322], Loss_Generator--[4.824713706970215],Accuracy--[28.000000000000004]
Epoch--[0 / 100], Loss_Discriminator--[-0.08053228259086609], Loss_Generator--[6.095004558563232],Accuracy--[27.0]
Epoch--[1 / 100], Loss_Discriminator--[-0.07236895710229874], Loss_Generator--[7.97749662399292],Accuracy--[19.0]
Epoch--[1 / 100], Loss_Discriminator--[0.06197747588157654], Loss_Generator--[4.30155611038208],Accuracy--[26.0]
Epoch--[1 / 100], Loss_Discriminator--[-0.1107843816280365], Loss_Generator--[9.493159294128418],Accuracy--[26.0]
Epoch--[1 / 100], Loss_Discriminator--[0.1174000054597854