In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from IPython.display import HTML

In [None]:

# Training hyper-parameters.
batch = 32         # mini-batch size for the DataLoader
img_size = 64      # images are resized/center-cropped to img_size x img_size; both nets assume 64
num_epochs = 50
lr = 0.0002        # Adam learning rate for both generator and discriminator

# Seed every RNG in use so weight init, noise and shuffling repeat on Restart & Run All
# (cuDNN kernels may still introduce small nondeterminism on GPU).
seed = 999
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)



In [None]:
# Preprocessing: resize + center-crop to img_size, convert to a [0, 1] tensor,
# then map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
channel_stats = (0.5, 0.5, 0.5)
preprocessing_steps = [
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# CIFAR-10 training split, downloaded into ./data if not already present.
trainset = dset.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Sanity check: shape of the image tensor in the first batch.
first_images = next(iter(trainloader))[0]
print(first_images.shape)

In [None]:

def initialize(m):
    """DCGAN-style weight initialisation, meant to be passed to ``nn.Module.apply``.

    Any module whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...) gets
    weights drawn from N(0, 0.02). Any module whose class name contains 'BatchNorm'
    gets weights from N(1, 0.02) and a zero bias. All other modules are left untouched.

    Args:
        m: the module currently being visited by ``apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname and m.weight is not None:
        # BatchNorm with affine=False has no weight/bias; skip it instead of crashing.
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# Also expose the function under the name used in `net.apply(weights_init)` calls.
weights_init = initialize

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator for 3-channel 64x64 images.

    Four 4x4 stride-2 convolutions widen the channels 3 -> 64 -> 128 -> 256 -> 512
    and halve the spatial size 64 -> 32 -> 16 -> 8 -> 4. Each is followed by
    LeakyReLU(0.2), with BatchNorm on every stage but the first. A 4x4 valid
    convolution then reduces to one channel, and a sigmoid turns it into a
    probability. A (N, 3, 64, 64) input yields (N, 1, 1, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        width = 64

        # Stem: no BatchNorm on the first layer.
        layers = [
            nn.Conv2d(3, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each downsampling stage doubles the channel count.
        for mult_in, mult_out in ((1, 2), (2, 4), (4, 8)):
            layers += [
                nn.Conv2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: 4x4 map -> single score -> probability.
        layers += [
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Use the first GPU when one is available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [None]:

# Build the discriminator, wrap it in DataParallel when running on CUDA,
# and apply `initialize` (N(0, 0.02) conv weights; N(1, 0.02) weight / zero bias for BatchNorm).
netD = Discriminator().to(device)

if device.type == 'cuda':
    netD = nn.DataParallel(netD)

netD.apply(initialize)

print(netD)

In [None]:

class Generator(nn.Module):
    """DCGAN generator mapping (N, 100, 1, 1) latent noise to (N, 3, 64, 64) images.

    A 4x4 valid transposed convolution projects the latent vector to a 4x4 map
    with 512 channels. Four 4x4 stride-2 transposed convolutions then double the
    spatial size (4 -> 8 -> 16 -> 32 -> 64) while narrowing the channels
    (512 -> 256 -> 128 -> 64 -> 3). Every hidden layer uses BatchNorm + ReLU.
    The output goes through tanh, so values lie in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        width = 64

        # Project the 100-d latent vector to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(100, width * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.ReLU(True),
        ]
        # Each upsampling stage halves the channel count.
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.ReLU(True),
            ]
        # Final upsample to RGB; tanh matches the [-1, 1] range produced by the Normalize transform.
        layers += [
            nn.ConvTranspose2d(width, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:

# Build the generator, wrap it in DataParallel when running on CUDA,
# and apply `initialize` (N(0, 0.02) conv weights; N(1, 0.02) weight / zero bias for BatchNorm).
netG = Generator().to(device)

if device.type == 'cuda':
    netG = nn.DataParallel(netG)

netG.apply(initialize)

print(netG)

In [None]:
# Target values for BCE: the discriminator should output 1 for real images and 0 for fakes.
real_label, fake_label = 1, 0


In [None]:

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors, reused to visualise generator progress.
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

# Adam with beta1 lowered to 0.5 (Adam's default is 0.9), shared by both networks.
adam_betas = (0.5, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

In [None]:

# Training history: per-iteration losses plus periodic image snapshots.
img_list = []   # make_grid of netG(fixed_noise), one entry per snapshot
G_losses = []
D_losses = []
iters = 0
real_img = []   # make_grid of the real batch seen at each snapshot (stored on CPU)

for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        # --- (1) Discriminator step: real batch -> real_label, generated batch -> fake_label.
        netD.zero_grad()

        real = data[0].to(device)
        b_size = real.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        output = netD(real).view(-1)
        errD_r = criterion(output, label)
        errD_r.backward()

        # 100 = latent size; must match the generator's first ConvTranspose2d input channels.
        noise = torch.randn(b_size, 100, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)

        # detach() keeps the discriminator loss from back-propagating into netG.
        output = netD(fake.detach()).view(-1)
        errD_f = criterion(output, label)
        errD_f.backward()  # accumulates onto the gradients from errD_r

        errD = errD_r + errD_f  # logging only; each term was already back-propagated
        optimizerD.step()

        # --- (2) Generator step: push D(G(z)) towards real_label.
        netG.zero_grad()
        label.fill_(real_label)
        # Re-score the same fakes with the just-updated D, this time without detach.
        # This also writes gradients into netD; they are cleared by netD.zero_grad() next iteration.
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        loss_d = errD.item()
        loss_g = errG.item()
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
                  % (epoch, num_epochs, i, len(trainloader), loss_d, loss_g))

        G_losses.append(loss_g)
        D_losses.append(loss_d)

        # Snapshot every 500 iterations and once more on the very last batch.
        is_last_batch = (epoch == num_epochs - 1) and (i == len(trainloader) - 1)
        if iters % 500 == 0 or is_last_batch:
            with torch.no_grad():
                fixed_fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))
            # Move the real batch to CPU first so stored snapshots don't pin GPU memory.
            real_img.append(vutils.make_grid(real.detach().cpu(), padding=2, normalize=True))

        iters += 1

In [None]:
# Per-iteration generator and discriminator losses recorded during training.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="G")
ax.plot(D_losses, label="D")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Show the first fixed-noise snapshot (taken after the first training iteration)
# as an image instead of dumping the raw tensor repr.
first_grid = img_list[0]
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Generator output on fixed_noise: first snapshot")
plt.imshow(first_grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
plt.show()

In [None]:
# Animate the generator's fixed-noise snapshots over training, one frame per snapshot.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
frames = [[plt.imshow(grid.permute(1, 2, 0).numpy(), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
# Close the figure so the notebook shows only the HTML animation, not a stray static frame.
plt.close(fig)

HTML(ani.to_jshtml())