In [None]:
import sys
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.autograd.variable import Variable
from torchvision import datasets, transforms
from PIL import Image

In [None]:
print(sys.version)

# Seed torch's RNGs (CPU and all CUDA devices) so weight init, noise sampling
# and DataLoader shuffling repeat across runs.
SEED = 42
torch.manual_seed(SEED)

train_on_gpu = torch.cuda.is_available()
# Fall back to CPU instead of crashing later on machines without CUDA.
device = 'cuda' if train_on_gpu else 'cpu'

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')


In [None]:
 # glob() returns files in arbitrary filesystem order; sort so the training
 # subset taken below is the same on every run and machine.
 image_glob = "../input/celeba-dataset/img_align_celeba/img_align_celeba/*.jpg"
 folder_data = sorted(glob.glob(image_glob))
 len_data = len(folder_data)
 print(len_data)
 # Fail early with a clear message instead of a cryptic DataLoader error later.
 assert len_data > 0, "no images matched " + image_glob

 train_image_paths = folder_data[0:200000]

 class TrainDataset(Dataset):
   """Map-style dataset returning each image as a normalized 3x64x64 tensor.

   Args:
     image_paths: indexable sequence of image file paths.
     train: accepted for API compatibility; this class does not use it.
   """

   def __init__(self, image_paths, train=True):
     self.image_paths = image_paths
     self.transforms = transforms.Compose([
         transforms.Resize(64),
         transforms.CenterCrop(64),
         transforms.ToTensor(),
         # (x - 0.5) / 0.5 rescales ToTensor's [0, 1] range to [-1, 1],
         # matching the generator's Tanh output range.
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     ])

   def __getitem__(self, index):
     # The context manager closes the file as soon as the pixels are loaded
     # instead of leaving it to garbage collection; convert('RGB') forces that
     # load and guarantees the 3 channels Normalize expects.
     with Image.open(self.image_paths[index]) as img:
       rgb = img.convert('RGB')
     return self.transforms(rgb)

   def __len__(self):
     return len(self.image_paths)


In [None]:
# Wrap the path list in the dataset, then serve shuffled mini-batches of 128
# images loaded by two background worker processes.
train_dataset = TrainDataset(train_image_paths, train=True)
print(len(train_dataset))
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Any module whose class name contains 'Conv' (this includes ConvTranspose)
    gets weights ~ N(0, 0.02). A module whose class name contains 'BatchNorm'
    gets weights ~ N(1, 0.02) and a zero bias. Everything else is untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)


In [None]:
class GeneratorNet(torch.nn.Module):
  """Maps latent noise (N, latent_dim, 1, 1) to images (N, 3, 64, 64) in [-1, 1].

  The first ConvTranspose2d (kernel 4, stride 1, padding 0) expands the 1x1
  input to 4x4; each following one (kernel 4, stride 2, padding 1) doubles the
  spatial size: 4 -> 8 -> 16 -> 32 -> 64. The final Tanh bounds outputs.

  Args:
    latent_dim: channel count of the input noise tensor. Defaults to 100.
  """

  def __init__(self, latent_dim=100):
    super(GeneratorNet, self).__init__()
    # Layer order fixes the state_dict keys (main.0 ... main.13); keep it stable.
    self.main = nn.Sequential(
        # (latent_dim, 1, 1) -> (1024, 4, 4)
        nn.ConvTranspose2d(latent_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(1024),
        nn.ReLU(inplace=True),

        # -> (512, 8, 8)
        nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(inplace=True),

        # -> (256, 16, 16)
        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),

        # -> (128, 32, 32)
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),

        # -> (3, 64, 64), squashed to [-1, 1]
        nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
        nn.Tanh(),
    )

  def forward(self, x):
    """Return generated images for a latent batch ``x``."""
    return self.main(x)

# Build the generator as float32 on the selected device, then re-initialise its
# conv / batch-norm weights with weights_init.
generator = GeneratorNet().float().to(device)
generator.apply(weights_init)
print(generator)

In [None]:
class DiscriminatorNet(torch.nn.Module):
  """Scores (N, 3, 64, 64) images, returning one raw logit each as (N, 1, 1, 1).

  Four Conv2d stages (kernel 5, stride 2, padding 2) halve the spatial size
  64 -> 32 -> 16 -> 8 -> 4 while widening channels 3 -> 128 -> 256 -> 512 ->
  1024; a final 4x4 convolution collapses the map to a single value. There is
  no sigmoid, so the output must be paired with a logits-based loss.
  """

  def __init__(self):
    super().__init__()
    # Stem: no BatchNorm on the first stage.
    stages = [
      nn.Conv2d(3, 128, kernel_size=5, stride=2, padding=2, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
    ]
    # Downsampling stages. Keep this exact order: checkpoints are keyed by
    # Sequential index (main.0 ... main.11).
    for c_in, c_out in ((128, 256), (256, 512), (512, 1024)):
      stages.extend([
        nn.Conv2d(c_in, c_out, kernel_size=5, stride=2, padding=2, bias=False),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(0.2, inplace=True),
      ])
    # Head: 4x4 feature map -> one logit per image.
    stages.append(nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False))
    self.main = nn.Sequential(*stages)

  def forward(self, x):
    return self.main(x)

# Build the discriminator as float32 on the selected device, then re-initialise
# its conv / batch-norm weights with weights_init.
discriminator = DiscriminatorNet().float().to(device)
discriminator.apply(weights_init)
print(discriminator)

In [None]:
# The discriminator outputs raw logits (no final sigmoid), so use the
# logits-aware BCE. Each network gets its own Adam optimizer with shared settings.
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)

criterion = nn.BCEWithLogitsLoss()
optimizerG = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizerD = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [None]:
def noise(size, latent_dim=100):
  """Return ``size`` standard-normal latent vectors of shape (size, latent_dim, 1, 1).

  The tensor is created on the module-level ``device``.

  Args:
    size: number of latent vectors (batch size).
    latent_dim: channels per vector; defaults to 100.
  """
  # Sample directly on the target device rather than on the CPU followed by a
  # copy on every call. The old Variable wrapper is a deprecated no-op.
  return torch.randn(size, latent_dim, 1, 1, device=device)

samples = 16
# Fixed latent batch, reused after every epoch so sample grids are comparable.
fixed_noise = noise(samples)

In [None]:
# Adversarial training. Per batch: one discriminator update on real images and
# detached fakes, then one generator update that pushes D to label the same
# fakes as real. Losses are logged as per-epoch means.
lossesD = []  # mean discriminator loss (real + fake terms) per epoch, as floats
lossesG = []  # mean generator loss per epoch, as floats

num_epochs = 30
for epoch in range(num_epochs):
  discriminator.train()
  generator.train()
  lossD = 0.0
  lossG = 0.0
  for real_batch in train_loader:
    x_real = real_batch.to(device)
    batch_size = x_real.size(0)
    real_labels = torch.ones((batch_size, 1), device=device)
    fake_labels = torch.zeros((batch_size, 1), device=device)

    # ---- Discriminator step ----
    optimizerD.zero_grad()
    pred_real = discriminator(x_real)
    loss_real = criterion(pred_real.view(-1, 1), real_labels)
    loss_real.backward()
    z = noise(batch_size)
    x_fake = generator(z)
    # detach(): D's loss must not backpropagate into (and free) G's graph.
    pred_fake = discriminator(x_fake.detach())
    loss_fake = criterion(pred_fake.view(-1, 1), fake_labels)
    loss_fake.backward()
    optimizerD.step()
    # .item() keeps the running sum a plain float; summing the loss tensors
    # would keep every iteration's autograd history alive and leak memory.
    lossD += loss_real.item() + loss_fake.item()

    # ---- Generator step ----
    # Reuse x_fake (its graph through G is still intact) instead of a second
    # generator forward pass, which would also update G's BatchNorm running
    # statistics twice per batch.
    optimizerG.zero_grad()
    fake_pred = discriminator(x_fake)
    loss_gen = criterion(fake_pred.view(-1, 1), real_labels)
    loss_gen.backward()
    optimizerG.step()
    lossG += loss_gen.item()

  lossesD.append(lossD / len(train_loader))
  lossesG.append(lossG / len(train_loader))
  print("Epoch No. = " + str(epoch + 1))
  print("Discriminator Loss = " + str(lossesD[epoch]), "Generator Loss = " + str(lossesG[epoch]))
  if (epoch + 1) % 5 == 0 or (epoch + 1) > 25:
    torch.save(generator.state_dict(), 'g_epoch-{}.pth'.format(epoch + 1))
    torch.save(discriminator.state_dict(), 'd_epoch-{}.pth'.format(epoch + 1))

  # Visual progress check on the same fixed latent batch every epoch.
  with torch.no_grad():
    generated_images = generator(fixed_noise).cpu()
  # Tanh output lies in [-1, 1]; imshow clips float RGB to [0, 1], so rescale.
  generated_images = (generated_images + 1) / 2
  for i in range(16):
    plt.subplot(4, 4, 1 + i)
    plt.axis('off')
    plt.imshow(np.transpose(generated_images[i].numpy(), (1, 2, 0)))
  plt.show()

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(lossesG, label="G")
plt.plot(lossesD, label="D")
plt.xlabel("epoch")  # one logged point per epoch
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Draw four fresh samples from the generator and show them as a 2x2 grid.
test_noise = noise(4)
with torch.no_grad():
  test_images = generator(test_noise).cpu()
# Rescale Tanh output from [-1, 1] to [0, 1]; imshow would otherwise clip it.
test_images = (test_images + 1) / 2
for i in range(4):
  plt.subplot(2, 2, 1 + i)
  plt.axis('off')
  plt.imshow(np.transpose(test_images[i].numpy(), (1, 2, 0)))
plt.show()

In [None]:
# Another independent 2x2 grid of samples, for a second look at sample variety.
test_noise = noise(4)
with torch.no_grad():
  test_images = generator(test_noise).cpu()
# Rescale Tanh output from [-1, 1] to [0, 1]; imshow would otherwise clip it.
test_images = (test_images + 1) / 2
for i in range(4):
  plt.subplot(2, 2, 1 + i)
  plt.axis('off')
  plt.imshow(np.transpose(test_images[i].numpy(), (1, 2, 0)))
plt.show()

In [None]:
# A third independent 2x2 grid of samples.
test_noise = noise(4)
with torch.no_grad():
  test_images = generator(test_noise).cpu()
# Rescale Tanh output from [-1, 1] to [0, 1]; imshow would otherwise clip it.
test_images = (test_images + 1) / 2
for i in range(4):
  plt.subplot(2, 2, 1 + i)
  plt.axis('off')
  plt.imshow(np.transpose(test_images[i].numpy(), (1, 2, 0)))
plt.show()