In [1]:
# Deep Convolutional GANs

# Importing the libraries
from __future__ import print_function

import os

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


In [2]:
batchSize = 64 # Number of images per mini-batch (passed to the DataLoader as batch_size).
imageSize = 64 # Size the input images are resized to by transforms.Resize; the discriminator's layers reduce a 64x64 input to a single value.

In [4]:
# Preprocessing applied to every input image, in order.
transform = transforms.Compose([
    transforms.Resize(imageSize),                            # scale the image to imageSize
    transforms.ToTensor(),                                   # convert the image to a torch tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per channel: (x - 0.5) / 0.5
])


In [5]:
# CIFAR-10 training set, downloaded into ./data when missing; every image goes through `transform`.
dataset = dset.CIFAR10(
    root = './data',
    download = True,
    transform = transform,
)
# Iterator that serves the dataset batch by batch, shuffled, using two loader workers.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size = batchSize,
    shuffle = True,
    num_workers = 2,
)


Files already downloaded and verified


In [6]:
# Defining the weights_init function, which takes a single module m (it is called once per submodule through net.apply) and initializes the weights of the convolution and batch-norm layers.
def weights_init(m):
    """Initialize one module in place; meant to be passed to ``net.apply``.

    The module is dispatched on its class name:

    * name contains ``'Conv'``: weight drawn from a normal distribution with
      mean 0.0 and std 0.02 (a conv bias, if any, is left untouched);
    * name contains ``'BatchNorm'``: weight drawn from a normal distribution
      with mean 1.0 and std 0.02, bias filled with 0;
    * anything else: left unchanged.

    Modules whose ``weight`` / ``bias`` attribute is missing or ``None``
    (e.g. a batch-norm layer built with ``affine=False``) are skipped
    instead of raising ``AttributeError``.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__
    # Parameters can legitimately be absent (bias=False, affine=False), so
    # look them up defensively rather than dereferencing m.weight directly.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)


In [7]:
class G(nn.Module):
    """DCGAN generator: turns a batch of noise vectors into images.

    The network is a stack of five transposed convolutions. The first one
    expands a ``(N, nz, 1, 1)`` input to a 4x4 feature map and each of the
    next four doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64), so the
    output has shape ``(N, nc, 64, 64)``. The final ``Tanh`` bounds every
    output value to ``[-1, 1]``.

    Args:
        nz: number of channels of the input noise (default 100).
        ngf: base number of feature maps; the hidden layers use
            ``ngf * 8``, ``ngf * 4``, ``ngf * 2`` and ``ngf`` channels
            (default 64, i.e. 512/256/128/64).
        nc: number of channels of the generated image (default 3).
    """

    def __init__(self, nz = 100, ngf = 64, nc = 3):
        super(G, self).__init__()
        # Conv layers carry no bias: each one (except the last) is followed
        # by a BatchNorm layer, which has its own bias term.
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias = False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias = False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias = False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias = False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (nc, 64, 64), squashed to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias = False),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from noise.

        Args:
            input: noise tensor of shape ``(N, nz, 1, 1)``.

        Returns:
            Tensor of shape ``(N, nc, 64, 64)`` with values in ``[-1, 1]``.
        """
        output = self.main(input)
        return output

In [8]:
class D(nn.Module):
    """DCGAN discriminator: scores how likely each image is to be real.

    Four stride-2 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4) and a final 4x4 convolution collapses the
    4x4 map to a single value per image, which ``Sigmoid`` maps into
    ``(0, 1)``.

    Args:
        nc: number of channels of the input images (default 3).
        ndf: base number of feature maps; the hidden layers use ``ndf``,
            ``ndf * 2``, ``ndf * 4`` and ``ndf * 8`` channels
            (default 64, i.e. 64/128/256/512).
    """

    def __init__(self, nc = 3, ndf = 64):
        super(D, self).__init__()
        self.main = nn.Sequential(
            # (nc, 64, 64) -> (ndf, 32, 32); no batch norm on the first layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias = False),
            nn.LeakyReLU(0.2, inplace = True),
            # -> (ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias = False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace = True),
            # -> (ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias = False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace = True),
            # -> (ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias = False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace = True),
            # -> (1, 1, 1), squashed to a probability
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias = False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: image tensor of shape ``(N, nc, 64, 64)``.

        Returns:
            1-D tensor with one value in ``(0, 1)`` per image. The result is
            flattened with ``view(-1)``, so an input that is not 64x64 would
            yield more than one value per image.
        """
        output = self.main(input)
        return output.view(-1)


In [9]:
# Creating the generator and the discriminator, and initializing the weights of each of their submodules
netG = G()
netG.apply(weights_init)
netD = D()
netD.apply(weights_init)

# Training the DCGANs

NUM_EPOCHS = 25            # number of passes over the training set
NOISE_DIM = 100            # channels of the noise fed to the generator (its first layer expects 100 by default)
RESULTS_DIR = "./results"  # folder receiving the grids of real / generated samples

# save_image does not create missing folders, so make sure the target exists before training starts.
if not os.path.isdir(RESULTS_DIR):
    os.makedirs(RESULTS_DIR)

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

for epoch in range(NUM_EPOCHS):

    for i, data in enumerate(dataloader, 0):

        # 1st Step: Updating the weights of the neural network of the discriminator

        netD.zero_grad()

        # The labels of the dataset are not used. The batch size is read from the
        # batch itself because the last batch of an epoch can be smaller.
        real, _ = data
        batch_size = real.size(0)
        real_target = torch.ones(batch_size)   # label 1 = "real"
        fake_target = torch.zeros(batch_size)  # label 0 = "fake"

        # Training the discriminator with real images of the dataset
        output = netD(real)
        errD_real = criterion(output, real_target)

        # Training the discriminator with fake images generated by the generator.
        # detach() cuts the graph so this backward pass does not reach the generator.
        noise = torch.randn(batch_size, NOISE_DIM, 1, 1)
        fake = netG(noise)
        output = netD(fake.detach())
        errD_fake = criterion(output, fake_target)

        # Backpropagating the total error
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # 2nd Step: Updating the weights of the neural network of the generator

        # The generator is trained to make the (just updated) discriminator
        # answer "real" on the fake images, hence the target of ones.
        netG.zero_grad()
        output = netD(fake)
        errG = criterion(output, real_target)
        errG.backward()
        optimizerG.step()

        # 3rd Step: Printing the losses at every step, and saving the real images and the generated images of the minibatch every 100 steps

        # .item() extracts the Python number from the 0-dim loss tensors
        # (indexing them with [0] raises an IndexError on current PyTorch).
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, NUM_EPOCHS, i, len(dataloader), errD.item(), errG.item()))
        if i % 100 == 0:
            vutils.save_image(real, '%s/real_samples.png' % RESULTS_DIR, normalize = True)
            # Re-run the updated generator on the same noise; no gradients are needed for a snapshot.
            with torch.no_grad():
                fake = netG(noise)
            vutils.save_image(fake, '%s/fake_samples_epoch_%03d.png' % (RESULTS_DIR, epoch), normalize = True)



[0/25][0/782] Loss_D: 1.7074 Loss_G: 6.1108
[0/25][1/782] Loss_D: 1.0308 Loss_G: 5.0058
[0/25][2/782] Loss_D: 1.1606 Loss_G: 6.2429
[0/25][3/782] Loss_D: 0.8380 Loss_G: 6.1915
[0/25][4/782] Loss_D: 0.7631 Loss_G: 6.0950
[0/25][5/782] Loss_D: 0.9401 Loss_G: 7.4754
[0/25][6/782] Loss_D: 0.7324 Loss_G: 7.4171
[0/25][7/782] Loss_D: 0.8342 Loss_G: 8.2096
[0/25][8/782] Loss_D: 0.9868 Loss_G: 8.0933
[0/25][9/782] Loss_D: 0.7768 Loss_G: 10.0024
[0/25][10/782] Loss_D: 0.4764 Loss_G: 7.7819
[0/25][11/782] Loss_D: 0.5769 Loss_G: 11.4204
[0/25][12/782] Loss_D: 0.1907 Loss_G: 9.5035
[0/25][13/782] Loss_D: 0.5956 Loss_G: 8.5178
[0/25][14/782] Loss_D: 1.0823 Loss_G: 15.3645
[0/25][15/782] Loss_D: 0.5958 Loss_G: 12.5578
[0/25][16/782] Loss_D: 0.4848 Loss_G: 6.2974
[0/25][17/782] Loss_D: 3.7559 Loss_G: 16.2150
[0/25][18/782] Loss_D: 0.4636 Loss_G: 16.2617
[0/25][19/782] Loss_D: 0.5944 Loss_G: 9.3837
[0/25][20/782] Loss_D: 1.7930 Loss_G: 15.0584
[0/25][21/782] Loss_D: 0.2055 Loss_G: 13.9449
[0/25][22/78

[0/25][180/782] Loss_D: 0.6356 Loss_G: 3.9650
[0/25][181/782] Loss_D: 0.4209 Loss_G: 7.7804
[0/25][182/782] Loss_D: 0.0853 Loss_G: 7.4761
[0/25][183/782] Loss_D: 0.1713 Loss_G: 5.3782
[0/25][184/782] Loss_D: 0.2322 Loss_G: 6.0610
[0/25][185/782] Loss_D: 0.1860 Loss_G: 6.1565
[0/25][186/782] Loss_D: 0.2238 Loss_G: 4.8479
[0/25][187/782] Loss_D: 0.2833 Loss_G: 5.9690
[0/25][188/782] Loss_D: 0.3604 Loss_G: 5.6354
[0/25][189/782] Loss_D: 0.2149 Loss_G: 4.9929
[0/25][190/782] Loss_D: 0.1741 Loss_G: 6.5946
[0/25][191/782] Loss_D: 0.1385 Loss_G: 5.9625
[0/25][192/782] Loss_D: 0.1774 Loss_G: 5.3845
[0/25][193/782] Loss_D: 0.1705 Loss_G: 6.2054
[0/25][194/782] Loss_D: 0.2273 Loss_G: 5.0159
[0/25][195/782] Loss_D: 0.1667 Loss_G: 5.8277
[0/25][196/782] Loss_D: 0.1579 Loss_G: 5.4027
[0/25][197/782] Loss_D: 0.1173 Loss_G: 5.9153
[0/25][198/782] Loss_D: 0.3730 Loss_G: 3.3799
[0/25][199/782] Loss_D: 0.7212 Loss_G: 13.0564
[0/25][200/782] Loss_D: 0.4221 Loss_G: 13.4035
[0/25][201/782] Loss_D: 0.1768 L

[0/25][358/782] Loss_D: 0.5443 Loss_G: 4.2787
[0/25][359/782] Loss_D: 0.5009 Loss_G: 4.4525
[0/25][360/782] Loss_D: 0.3814 Loss_G: 3.5443
[0/25][361/782] Loss_D: 0.4181 Loss_G: 3.8727
[0/25][362/782] Loss_D: 0.5967 Loss_G: 3.3447
[0/25][363/782] Loss_D: 0.4635 Loss_G: 4.5098
[0/25][364/782] Loss_D: 0.4514 Loss_G: 2.9265
[0/25][365/782] Loss_D: 0.6495 Loss_G: 3.1787
[0/25][366/782] Loss_D: 0.7026 Loss_G: 5.7500
[0/25][367/782] Loss_D: 1.0087 Loss_G: 1.2922
[0/25][368/782] Loss_D: 1.2531 Loss_G: 6.8810
[0/25][369/782] Loss_D: 2.0431 Loss_G: 1.1995
[0/25][370/782] Loss_D: 1.6348 Loss_G: 7.2373
[0/25][371/782] Loss_D: 1.7596 Loss_G: 2.9046
[0/25][372/782] Loss_D: 0.8533 Loss_G: 3.7891
[0/25][373/782] Loss_D: 0.7132 Loss_G: 4.3623
[0/25][374/782] Loss_D: 1.2245 Loss_G: 1.4532
[0/25][375/782] Loss_D: 1.6104 Loss_G: 5.6113
[0/25][376/782] Loss_D: 1.5382 Loss_G: 2.3985
[0/25][377/782] Loss_D: 0.5029 Loss_G: 3.4303
[0/25][378/782] Loss_D: 0.4816 Loss_G: 4.3919
[0/25][379/782] Loss_D: 0.9291 Los

KeyboardInterrupt: 