In [1]:
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.distributions.normal import Normal
from torchvision.utils import save_image

batchSize = 64 
image_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
transform = transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])
dataset = dset.CIFAR10(root = './data', download = True, transform = transform) 
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2) 

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Files already downloaded and verified


In [49]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.BatchNorm2d(3),
            nn.Tanh(),
            
        )

    def forward(self, input):
        output = self.main(input)
        return output

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)


In [50]:
netG = Generator().to(device)
netG.apply(weights_init)
netD = Discriminator().to(device)
netD.apply(weights_init)
Tensor = torch.FloatTensor

In [64]:
criterion = nn.BCELoss()

#fixed_noise = np.random.normal(loc=0.0,scale=1, size = (batchSize, 100, 1, 1))
#print(fixed_noise)
#fixed_noise = torch.FloatTensor(fixed_noise).to(device)


fixed_noise = torch.randn(batchSize, 100, 1, 1).to(device)
#print(fixed_noise)
real_label = 1
fake_label = 0


optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))




In [67]:
def train_model(dataloader,netD,netG,criterion, fixed_noise, real_label, fake_label):
    #D_x = 0
    #D_G_z1 = 0
    #D_G_z2 = 0
    #corrects = 0
    for epoch in range(50):
        D_x = 0
        D_G_z1 = 0
        D_G_z2 = 0
        corrects = 0
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label).to(device)

            output = netD(real_cpu)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x += output.mean().item()
            
            
            

            # train with fake
            noise = torch.randn(batch_size, 100, 1, 1).to(device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 += output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label).to(device)  # fake labels are real for generator cost
            output = netD(fake)
            corrects += (output == label).sum().item()
            
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 += output.mean().item()
            optimizerG.step()

            

            if i % 100 == 0:
                save_image(real_cpu,
                        '../results/val/real_samples{}_{}.png'.format(epoch,i),
                        normalize=True)
                fake = netG(fixed_noise)
                save_image(fake.detach(),
                        '../results/val/fake_samples_epoch{}_{}.png'.format(epoch, i),
                        normalize=True)
        print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, 25,
                     errD.item(), errG.item(), D_x/len(dataloader), D_G_z1/len(dataloader), D_G_z2/len(dataloader)))         

In [68]:
train_model(dataloader,netD,netG,criterion, fixed_noise, real_label, fake_label)




[0/25] Loss_D: 0.5533 Loss_G: 1.8133 D(x): 0.9211 D(G(z)): 0.0784 / 0.0382
[1/25] Loss_D: 0.2995 Loss_G: 3.5792 D(x): 0.9293 D(G(z)): 0.0707 / 0.0340
[2/25] Loss_D: 0.0583 Loss_G: 5.6495 D(x): 0.9370 D(G(z)): 0.0630 / 0.0295
[3/25] Loss_D: 0.1039 Loss_G: 5.8079 D(x): 0.9193 D(G(z)): 0.0806 / 0.0410
[4/25] Loss_D: 0.2752 Loss_G: 8.1090 D(x): 0.9257 D(G(z)): 0.0745 / 0.0365
[5/25] Loss_D: 0.1766 Loss_G: 8.1754 D(x): 0.9313 D(G(z)): 0.0686 / 0.0333
[6/25] Loss_D: 0.2472 Loss_G: 5.3084 D(x): 0.9343 D(G(z)): 0.0653 / 0.0320
[7/25] Loss_D: 0.0778 Loss_G: 8.6404 D(x): 0.9450 D(G(z)): 0.0549 / 0.0243
[8/25] Loss_D: 0.1056 Loss_G: 6.0366 D(x): 0.9242 D(G(z)): 0.0757 / 0.0366
[9/25] Loss_D: 1.3089 Loss_G: 22.5661 D(x): 0.9529 D(G(z)): 0.0473 / 0.0203
[10/25] Loss_D: 0.0191 Loss_G: 6.6133 D(x): 0.9226 D(G(z)): 0.0771 / 0.0380
[11/25] Loss_D: 0.1874 Loss_G: 4.8512 D(x): 0.9310 D(G(z)): 0.0689 / 0.0342
[12/25] Loss_D: 0.3060 Loss_G: 5.2900 D(x): 0.9342 D(G(z)): 0.0657 / 0.0319
[13/25] Loss_D: 0.320