In [5]:
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.distributions.normal import Normal
from torchvision.utils import save_image
from torch.nn import functional as F

batchSize = 64 
image_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
transform = transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])
dataset = dset.CIFAR10(root = './data', download = True, transform = transform) 
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2) 

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Files already downloaded and verified


In [49]:
class generator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(generator, self).__init__()
        self.deconv1_1 = nn.ConvTranspose2d(100, d*2, 4, 1, 0)
        self.deconv1_1_bn = nn.BatchNorm2d(d*2)
        self.deconv1_2 = nn.ConvTranspose2d(10, d*2, 4, 1, 0)
        self.deconv1_2_bn = nn.BatchNorm2d(d*2)
        self.deconv2 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d*2)
        self.deconv3 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, 3, 4, 2, 1)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        x = F.relu(self.deconv1_1_bn(self.deconv1_1(input)))
        y = F.relu(self.deconv1_2_bn(self.deconv1_2(label)))
        #print(x.shape)
        #print(y.shape)
        x = torch.cat([x, y], 1)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.tanh(self.deconv4(x))
        # x = F.relu(self.deconv4_bn(self.deconv4(x)))
        # x = F.tanh(self.deconv5(x))

        return x

class discriminator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(discriminator, self).__init__()
        self.conv1_1 = nn.Conv2d(3, d, 4, 2, 1)
        self.conv1_2 = nn.Conv2d(10, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d*2, d*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d * 4, 1, 4, 1, 0)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        #print(input.shape)
        x = F.leaky_relu(self.conv1_1(input), 0.2)
        y = F.leaky_relu(self.conv1_2(label), 0.2)
        x = torch.cat([x, y], 1)
        #print(x.shape)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.sigmoid(self.conv4(x))

        return x

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()


In [50]:
# fixed noise & label
temp_z_ = torch.randn(10, 100)
fixed_z_ = temp_z_
fixed_y_ = torch.zeros(10, 1)
for i in range(9):
    fixed_z_ = torch.cat([fixed_z_, temp_z_], 0)
    temp = torch.ones(10, 1) + i
    fixed_y_ = torch.cat([fixed_y_, temp], 0)

fixed_z_ = fixed_z_.view(-1, 100, 1, 1)
fixed_y_label_ = torch.zeros(100, 10)
fixed_y_label_.scatter_(1, fixed_y_.type(torch.LongTensor), 1)
fixed_y_label_ = fixed_y_label_.view(-1, 10, 1, 1)
fixed_z_, fixed_y_label_ = Variable(fixed_z_.cuda(), volatile=True), Variable(fixed_y_label_.cuda(), volatile=True)

  


In [51]:
G = generator(128)
D = discriminator(128)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
G.cuda()
D.cuda()

discriminator(
  (conv1_1): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv1_2): Conv2d(10, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
)

In [52]:
criterion = nn.BCELoss()

#fixed_noise = np.random.normal(loc=0.0,scale=1, size = (batchSize, 100, 1, 1))
#print(fixed_noise)
#fixed_noise = torch.FloatTensor(fixed_noise).to(device)


fixed_noise = torch.randn(batchSize, 100, 1, 1).to(device)
#print(fixed_noise)
real_label = 1
fake_label = 0


optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))




In [53]:
# label preprocess
onehot = torch.zeros(10, 10)
onehot = onehot.scatter_(1, torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).view(10,1), 1).view(10, 10, 1, 1)
fill = torch.zeros([10, 10, image_size, image_size])
for i in range(10):
    fill[i, i, :, :] = 1


In [54]:
def train_model(dataloader,netD,netG,criterion, optimizerD, optimizerG):
    #D_x = 0
    #D_G_z1 = 0
    #D_G_z2 = 0
    #corrects = 0
    saved_model_G = netG
    saved_model_D = netD
    for epoch in range(25):
        D_x = 0
        D_G_z1 = 0
        D_G_z2 = 0
        corrects = 0
        for i, (images,labels) in enumerate(dataloader, 0):
            real_cpu = images.to(device)
            batch_size = real_cpu.size(0)
            labels = labels.to(device)
            # labels(will only be used in loss function)
            y_real_ = (torch.ones(batch_size)).to(device)
            y_fake_ = (torch.zeros(batch_size)).to(device)
            
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            
            mod_labels = fill[labels].to(device)
            output = netD(real_cpu, mod_labels)
            
            errD_real = criterion(output, y_real_)
            errD_real.backward()
            D_x += output.mean().item()
            
            
            

            # train with fake
            noise = torch.randn(batch_size, 100, 1, 1).to(device)
            y_ = (torch.rand(batch_size, 1) * 10).type(torch.LongTensor).squeeze()
            y_label_ = onehot[y_].to(device)
            fake = netG(noise, y_label_)
            
            y_fill_ = fill[y_].to(device)
            output = netD(fake.detach(), y_fill_)
            
            errD_fake = criterion(output, y_fake_)
            errD_fake.backward()
            D_G_z1 += output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            #label.fill_(real_label).to(device)  # fake labels are real for generator cost
            output = netD(fake, y_fill_)
            #corrects += (output == torch.FloatTensor(labels)).sum().item()
            
            errG = criterion(output, y_real_)
            errG.backward()
            D_G_z2 += output.mean().item()
            optimizerG.step()

            

            if i % len(dataloader) == 0:
                save_image(real_cpu,
                        '../results/val/real_samples{}_{}.png'.format(epoch,i),
                        normalize=True)
                #print("hmm")
                fake = netG(noise, y_label_)
                save_image(fake.detach(),
                        '../results/val/fake_samples_epoch{}_{}.png'.format(epoch, i),
                        normalize=True)
        
            if(round(errD.item(),1) == round(errG.item(),1)):
                print(epoch)
                print(errD.item(),    round(errD.item(),1),  errG.item(),   round(errG.item(),1))
                saved_model_G = netG
                saved_model_D = netD
            
        print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, 25,
                     errD.item(), errG.item(), D_x/len(dataloader), D_G_z1/len(dataloader), D_G_z2/len(dataloader)))
    return saved_model_G,saved_model_D

In [55]:
model_G, model_D = train_model(dataloader,D,G,criterion, optimizerD, optimizerG)

[0/25] Loss_D: 0.4538 Loss_G: 4.0499 D(x): 0.8080 D(G(z)): 0.1869 / 0.0523


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


1
1.1773673295974731 1.2 1.249106526374817 1.2
1
1.249679446220398 1.2 1.1887118816375732 1.2
1
1.2380160093307495 1.2 1.2494089603424072 1.2
1
1.1872479915618896 1.2 1.1932779550552368 1.2
[1/25] Loss_D: 1.0930 Loss_G: 5.1519 D(x): 0.7671 D(G(z)): 0.2309 / 0.0873
[2/25] Loss_D: 0.4309 Loss_G: 3.1673 D(x): 0.7951 D(G(z)): 0.2029 / 0.0837
3
0.8409655690193176 0.8 0.8085958957672119 0.8
[3/25] Loss_D: 0.2985 Loss_G: 3.9847 D(x): 0.8031 D(G(z)): 0.1962 / 0.0854
[4/25] Loss_D: 0.3141 Loss_G: 4.6108 D(x): 0.8069 D(G(z)): 0.1915 / 0.0879
[5/25] Loss_D: 0.3666 Loss_G: 4.8858 D(x): 0.7960 D(G(z)): 0.2027 / 0.0950
6
1.0633553266525269 1.1 1.0673272609710693 1.1
6
1.1936947107315063 1.2 1.168980360031128 1.2
6
0.9267253875732422 0.9 0.8709442615509033 0.9
[6/25] Loss_D: 1.0706 Loss_G: 5.3575 D(x): 0.7520 D(G(z)): 0.2468 / 0.1351
7
1.1575927734375 1.2 1.155476689338684 1.2
7
0.9099842309951782 0.9 0.9391071796417236 0.9
7
0.9707447290420532 1.0 0.9877685904502869 1.0
7
1.126444935798645 1.1 1.108