Importing the required packages

In [None]:
import torch
import torch.nn as n
import torch.nn.functional as f
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchsummary import summary
import cv2

Defining the celebrity dataset folder

In [None]:
# Folder holding the aligned CelebA images; the path is relative to the
# notebook's working directory.
INPUT_DATA_DIR = "celeba-dataset/img_align_celeba/img_align_celeba/"

Checking whether cuda is available or not 

In [None]:
# Bare expression: as the cell's last value, the notebook displays whether
# PyTorch can use a CUDA device.
torch.cuda.is_available()

Assigning the device as cuda if cuda is available else cpu

In [None]:
# Device used for the models and training tensors. Despite the name, `cuda`
# is the CPU device when CUDA is not available.
cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Defining some of the hyperparameters

In [None]:
# Learning rates for the discriminator (D) and generator (G) optimizers.
LR_D = 2e-4
LR_G = 2e-4

# Number of images per training batch and number of passes over the dataset.
BATCH_SIZE = 128
EPOCHS = 10000

# First beta coefficient passed to the Adam optimizers.
BETA1 = 0.5

# Standard deviation used by weights_init when drawing initial weights.
WEIGHT_INIT_STDDEV = 2e-2

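The DCGAN paper specifies that the weights be randomly initialized from a normal distribution with a mean of 0 and a standard deviation of 0.02. The function below does this for the convolutional layers; for batch-norm layers it draws the scale from a normal distribution with mean 1 and the same standard deviation, and sets the bias to 0.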

In [None]:
def weights_init(m):
    """Initialise one layer in place; meant to be passed to ``Module.apply``.

    The layer kind is decided by substring match on the class name:

    * names containing ``Conv``: weights drawn from
      N(0, WEIGHT_INIT_STDDEV);
    * names containing ``BatchNorm``: weights drawn from
      N(1, WEIGHT_INIT_STDDEV) and bias set to 0;
    * anything else is left untouched.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        n.init.normal_(m.weight.data, mean=0.0, std=WEIGHT_INIT_STDDEV)
        return
    if 'BatchNorm' in layer_kind:
        n.init.normal_(m.weight.data, mean=1.0, std=WEIGHT_INIT_STDDEV)
        n.init.constant_(m.bias.data, 0)

Defining the Generator as given in the DCGAN paper

In [None]:
class Generator(n.Module):
    """DCGAN generator: maps a latent tensor to a 3-channel image.

    A (N, 100, 1, 1) input is upsampled by five bias-free transposed
    convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per side). The first
    four are each followed by batch norm and ReLU; the last is followed by
    tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = n.ConvTranspose2d(in_channels=100, out_channels=512,
                                       kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = n.BatchNorm2d(512)
        self.conv2 = n.ConvTranspose2d(in_channels=512, out_channels=256,
                                       kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = n.BatchNorm2d(256)
        self.conv3 = n.ConvTranspose2d(in_channels=256, out_channels=128,
                                       kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = n.BatchNorm2d(128)
        self.conv4 = n.ConvTranspose2d(in_channels=128, out_channels=64,
                                       kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = n.BatchNorm2d(64)
        self.conv5 = n.ConvTranspose2d(in_channels=64, out_channels=3,
                                       kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, input):
        """Return the generated images for the latent batch ``input``."""
        hidden = input
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for upsample, norm in stages:
            hidden = f.relu(norm(upsample(hidden)))
        return torch.tanh(self.conv5(hidden))

Assigning the generator to cuda and initialising the model weights using above mentioned function

In [None]:
# Build the generator, move it to the selected device as float32, then apply
# the DCGAN weight initialisation to every sub-module.
gen = Generator()
gen = gen.to(cuda).float()
gen.apply(weights_init)

# Print a per-layer summary for a single 100x1x1 latent input.
summary(gen, (100, 1, 1))

Defining the discriminator architecture as given in the reference paper

In [None]:
class Discriminator(n.Module):
    """DCGAN discriminator: maps a 3-channel image to a real/fake probability.

    Four bias-free stride-2 convolutions halve the spatial size each time and
    a final 4x4 convolution reduces the map to one value, so a
    (N, 3, 64, 64) input yields a (N, 1, 1, 1) output in (0, 1). Batch norm
    is applied after every convolution except the first and the last.
    """

    # Negative slope of the LeakyReLU activations. The DCGAN paper uses 0.2;
    # torch's default of 0.01 was previously used by omission.
    LEAK = 0.2

    def __init__(self):
        super().__init__()
        self.conv1 = n.Conv2d(3,64,4,2,1,bias=False)
        self.conv2 = n.Conv2d(64,128,4,2,1,bias=False)
        self.bn2 = n.BatchNorm2d(128)
        self.conv3 = n.Conv2d(128,256,4,2,1,bias=False)
        self.bn3 = n.BatchNorm2d(256)
        self.conv4 = n.Conv2d(256,512,4,2,1,bias=False)
        self.bn4 = n.BatchNorm2d(512)
        self.conv5 = n.Conv2d(512,1,4,1,0,bias=False)
        self.drop = n.Dropout2d(0.2)

    def forward(self,input):
        """Return the per-image probability that ``input`` is a real image."""
        conv = f.leaky_relu(self.conv1(input), self.LEAK)
        conv = f.leaky_relu(self.bn2(self.conv2(conv)), self.LEAK)
        conv = f.leaky_relu(self.bn3(self.conv3(conv)), self.LEAK)
        conv = f.leaky_relu(self.bn4(self.conv4(conv)), self.LEAK)
        # NOTE(review): dropout acts on the single-channel logit, so in
        # training mode a dropped sample is forced to sigmoid(0) = 0.5. This
        # is not part of the DCGAN paper; kept as-is, confirm it is intended.
        out = torch.sigmoid(self.drop(self.conv5(conv)))

        return out
        

Assigning the discriminator to cuda and initialising the model weights using above mentioned function

In [None]:
# Build the discriminator, move it to the selected device as float32, then
# apply the DCGAN weight initialisation to every sub-module.
disc = Discriminator()
disc = disc.to(cuda).float()
disc.apply(weights_init)

# Print a per-layer summary for a single 3x64x64 image.
summary(disc, (3, 64, 64))

Defining the binary cross entropy loss, creating the fixed noise to check the output while training and defining the optimizers for generator and discriminator

In [None]:
# Binary cross-entropy between the discriminator's probabilities and the
# real (1) / fake (0) labels.
criterion = n.BCELoss()

# Two fixed latent vectors, reused for every preview so that successive
# checkpoints can be compared on the same inputs.
fixed_noise = torch.randn(2,100,1,1)

# One Adam optimizer per network. The discriminator uses LR_D (it was
# previously built with LR_G, which left LR_D without any effect).
gen_optimizer = optim.Adam(gen.parameters(),lr=LR_G,betas=(BETA1,0.999))
disc_optimizer = optim.Adam(disc.parameters(),lr=LR_D,betas=(BETA1,0.999))

Utility for displaying a batch of images side by side using matplotlib

In [None]:
def show_samples(sample_images):
    """Display the images in ``sample_images`` side by side in one row.

    Parameters
    ----------
    sample_images : array-like
        Indexed along its first axis; each entry is passed to ``imshow``.
        The number of columns is ``sample_images.shape[0]``.
    """
    # squeeze=False always returns a 2-D array of axes, so a batch holding a
    # single image works too (plt.subplots would otherwise return a bare
    # Axes object that cannot be iterated).
    figure, axes = plt.subplots(1, sample_images.shape[0], figsize = (6,6), squeeze=False)
    for axis, image_array in zip(axes[0], sample_images):
        axis.axis('off')
        axis.imshow(image_array)
    plt.show()
    # Close this specific figure so repeated previews do not accumulate.
    plt.close(figure)

Generating fake images from a noise tensor and a saved generator checkpoint

In [None]:
def imagePostProcess(noise,modelPath):
    """Generate images from ``noise`` with the generator stored at ``modelPath``.

    Parameters
    ----------
    noise : torch.Tensor
        Latent batch fed to the reloaded generator (cast to float32).
    modelPath : str
        Path of a checkpoint readable by ``load_checkpoint``.

    Returns
    -------
    numpy.ndarray
        Channel-last images of shape (N, H, W, C), clipped to [0, 1], ready
        for ``imshow``.
    """
    model = load_checkpoint(modelPath)
    with torch.no_grad():
        out_tensor = model(noise.float())
    # permute moves the channel axis last while keeping every pixel in place;
    # a plain reshape would only reinterpret the memory and scramble pixels.
    out = out_tensor.permute(0, 2, 3, 1).cpu().numpy()

    # Training images are scaled to [0, 1] by loadImages, so values outside
    # that range are clipped for display.
    return np.clip(out, 0, 1)


def load_checkpoint(filepath):
    """Rebuild a frozen, eval-mode model from a checkpoint file.

    The checkpoint is a dict with a ``'model'`` entry (a module instance) and
    a ``'state_dict'`` entry holding its weights. All parameters of the
    returned model have ``requires_grad`` set to False.

    NOTE(review): ``torch.load`` unpickles the file, so only load trusted
    checkpoints. Newer PyTorch releases appear to default to
    ``weights_only=True``, which would reject the pickled module stored
    under ``'model'`` -- confirm against the installed version.
    """
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()

    return model


def loadImages(imageList,path,resize=False):
    """Load images into a channel-first float batch scaled to [0, 1].

    Parameters
    ----------
    imageList : iterable of str
        File names relative to ``path``.
    path : str
        Directory containing the images.
    resize : bool, optional
        When true, every image is resized to 64x64 before conversion.

    Returns
    -------
    numpy.ndarray
        Array of shape (N, C, H, W) in RGB channel order, values in [0, 1].

    Raises
    ------
    ValueError
        If a file cannot be decoded as an image.
    """
    images=[]
    for image in imageList:
        image_path = os.path.join(path,image)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError("Could not read image: " + image_path)
        if resize:
            img = cv2.resize(img,(64,64))
        # cv2 returns (H, W, C) in BGR order: flip the channels to RGB so
        # previews render with correct colours, then transpose (not reshape)
        # to (C, H, W) so each channel plane holds the right pixels.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        images.append(img)
    return np.array(images)/255.0

In [None]:
# os.listdir returns entries in arbitrary order and may include non-image
# entries (e.g. notebook checkpoint folders) that cv2.imread cannot decode.
# Keep only image files and sort them so the listing is reproducible.
data = sorted(
    name for name in os.listdir(INPUT_DATA_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png", ".bmp"))
)

In [None]:
# Checkpoints are written to ./DCGAN_weights. exist_ok=True creates the folder
# when it is missing and is a no-op when it already exists, without a
# separate existence check.
weight_folder = os.path.join(os.getcwd(),"DCGAN_weights")
os.makedirs(weight_folder, exist_ok=True)

Defining the losses for generator and discriminator and training both of them

In [None]:
# Only full batches are used (the trailing partial batch is dropped), so the
# constant label tensors can be built once and reused for every step.
batch_count = len(data)//BATCH_SIZE
real_label = torch.ones(BATCH_SIZE,1,device=cuda,dtype=torch.float)
fake_label = torch.zeros(BATCH_SIZE,1,device=cuda,dtype=torch.float)

for epoch in range(EPOCHS):
    Gloss=[]
    Dloss=[]
    # Visit the images in a fresh random order every epoch.
    order = np.random.permutation(len(data))
    for batch in tqdm(range(batch_count)):
        batch_ids = order[batch*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        data_list = [data[i] for i in batch_ids]
        data_batch = loadImages(data_list,INPUT_DATA_DIR,True)
        real_image = torch.from_numpy(data_batch).to(cuda).float()

        # ---- Discriminator step: real images -> 1, generated images -> 0.
        disc.zero_grad()
        disc_out = disc(real_image).reshape((BATCH_SIZE,1))
        Dloss_real = criterion(disc_out,real_label)
        Dloss_real.backward()

        noise = torch.randn(BATCH_SIZE, 100, 1, 1, device=cuda)
        fake_image = gen(noise)
        # detach(): the discriminator update must not backpropagate into the
        # generator, and the generator graph is still needed below.
        fake_out = disc(fake_image.detach()).reshape((BATCH_SIZE,1))
        Dloss_fake = criterion(fake_out,fake_label)
        Dloss_fake.backward()
        Dloss.append(Dloss_real.item()+Dloss_fake.item())

        disc_optimizer.step()

        # ---- Generator step: make the *updated* discriminator output 1.
        # The fake batch is scored again after disc_optimizer.step();
        # backpropagating through the pre-step graph would use weights the
        # optimizer has since modified in place, which autograd rejects on
        # current PyTorch and which would otherwise give stale gradients.
        gen.zero_grad()
        gen_out = disc(fake_image).reshape((BATCH_SIZE,1))
        gloss = criterion(gen_out,real_label)
        gloss.backward()
        Gloss.append(gloss.item())
        n.utils.clip_grad_norm_(gen.parameters(),1e3)

        gen_optimizer.step()

    if(epoch%5==0):
        # Every fifth epoch: save the generator and preview it on fixed_noise.
        checkpoint_path = os.path.join(weight_folder,"DCGAN"+str(epoch+1)+".pth")
        # 'model' holds a fresh Generator instance that load_checkpoint fills
        # with 'state_dict'.
        checkpoint = {'model': Generator(),
              'input_size': 64,
              'output_size': 64,
              'state_dict': gen.state_dict()}
        torch.save(checkpoint,checkpoint_path)

        out_images = imagePostProcess(fixed_noise,checkpoint_path)
        show_samples(out_images)
    print("Epoch ::::  "+str(epoch+1)+" GenLoss ==>> "+str(np.mean(Gloss))+" DiscLoss ==>> "+str(np.mean(Dloss)))
