<a href="https://colab.research.google.com/github/prikshit-2000/WGan-GP-Celeba-pytorch/blob/main/Wgan_GP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [109]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


In [110]:
!mkdir data_faces && wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip 

mkdir: cannot create directory ‘data_faces’: File exists


In [111]:
import os
import zipfile

# Extracting ~200k images is slow, so only do it when the image folder is missing.
# NOTE: an interrupted extraction leaves a partial folder; delete it to force a re-extract.
if not os.path.isdir("data_faces/img_align_celeba"):
  with zipfile.ZipFile("celeba.zip", "r") as zip_ref:
    zip_ref.extractall("data_faces/")

In [112]:
import os

# Sanity check on the extraction: count the entries in the CelebA image folder.
root = 'data_faces/img_align_celeba'
img_list = os.listdir(path=root)
print(len(img_list))

202599


In [161]:

# Fall back to CPU when CUDA is unavailable so the notebook still runs (slowly) without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 0
torch.manual_seed(SEED)  # seeds weight init, latent noise and DataLoader shuffling
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMBDA_GP = 10  # weight of the gradient penalty term in the critic loss
# WEIGHT_CLIP = 0.01  (unused: the gradient penalty replaces WGAN weight clipping)

In [162]:
my_transforms = transforms.Compose([
    # Scale the shorter side to IMAGE_SIZE and crop the centre square, so non-square
    # source images keep their aspect ratio instead of being squashed.
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    # Map [0, 1] pixels to [-1, 1] to match the generator's Tanh output range.
    transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
])
# ImageFolder treats each sub-directory of ./data_faces as a class; the training loop
# ignores the labels.
dataset = datasets.ImageFolder(root="./data_faces", transform=my_transforms)
# drop_last keeps every batch at exactly BATCH_SIZE, so real and generated batches always
# match in size and BatchNorm never sees a small trailing batch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [163]:

# for batch_idx,(real,_) in enumerate(dataset):
#   if real.shape!=torch.Size([3, 64, 64]):
#     print(batch_idx,real.shape)


In [164]:
class Discriminator(nn.Module):
  """Convolutional WGAN critic: maps an image batch to one unbounded score per image.

  Four stride-2 Conv -> InstanceNorm -> LeakyReLU(0.2) stages widen the feature
  maps (features_d, 2x, 4x, 8x) while halving the resolution. A final 4x4 conv
  without padding reduces the map to a single channel, giving an (N, 1, 1, 1)
  output for 64x64 inputs. There is no sigmoid at the end.

  Args:
    channels_img: number of input image channels.
    features_d: width of the first conv stage.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    layers = []
    in_channels = channels_img
    for width_mult in (1, 2, 4, 8):
      out_channels = features_d * width_mult
      layers.extend([
          nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
          nn.InstanceNorm2d(out_channels),
          nn.LeakyReLU(0.2),
      ])
      in_channels = out_channels
    # Scoring head (4x4 -> 1x1 for 64x64 inputs).
    layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=2, padding=0))
    # A flat Sequential keeps the parameter names net.0, net.3, ..., net.12.
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    return self.net(x)

In [165]:
class Generator(nn.Module):
  """DCGAN-style generator: maps latent tensors to images in [-1, 1].

  In this notebook the input is noise shaped (N, channels_noise, 1, 1). The
  first transposed conv (stride 1, no padding) expands 1x1 to 4x4 with
  features_g*16 channels. Three stride-2 ConvTranspose -> BatchNorm -> ReLU
  stages double the resolution while halving the width. A final stride-2
  transposed conv plus Tanh emits channels_img channels, for a 64x64 output.

  Args:
    channels_noise: number of latent channels.
    channels_img: number of output image channels.
    features_g: base width; the first stage uses 16x this.
  """

  def __init__(self, channels_noise, channels_img, features_g):
    super().__init__()
    widths = [features_g * mult for mult in (16, 8, 4, 2)]
    layers = []
    in_channels = channels_noise
    for stage, out_channels in enumerate(widths):
      is_stem = stage == 0
      layers.extend([
          nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                             stride=1 if is_stem else 2,
                             padding=0 if is_stem else 1),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(),
      ])
      in_channels = out_channels
    layers.extend([
        nn.ConvTranspose2d(in_channels, channels_img, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    ])
    # A flat Sequential keeps the parameter names net.0 ... net.12.
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    return self.net(x)

In [166]:
def initialize_weights(model):
  """Apply DCGAN-style initialisation, in place, to every submodule of ``model``.

  * Conv2d / ConvTranspose2d weights ~ N(0, 0.02).
  * BatchNorm2d scale ~ N(1, 0.02) and shift = 0. Centring the scale at 1
    (not 0) keeps normalised activations near unit scale at the start instead
    of multiplying them by ~0.
  Conv biases, InstanceNorm and all other modules keep PyTorch's default init.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
      # affine=False BatchNorm has no learnable scale/shift to initialise.
      if m.weight is not None:
        nn.init.normal_(m.weight, 1.0, 0.02)
      if m.bias is not None:
        nn.init.zeros_(m.bias)

In [167]:
def gradient_penalty(critic,real,fake,device='cuda'):
  """WGAN-GP gradient penalty on random interpolates between real and fake samples.

  Each sample gets one mixing weight eps ~ U[0, 1), and the critic is evaluated
  at ``real * eps + fake * (1 - eps)``. The penalty is the mean over the batch of
  ``(||d critic / d input||_2 - 1) ** 2``.

  Args:
    critic: module mapping a batch of samples to scores.
    real: batch of real samples; the first dimension is the batch.
    fake: batch of generated samples with the same shape as ``real``.
    device: device the mixing weights are moved to; must match ``real``/``fake``.

  Returns:
    A scalar tensor. It is built with ``create_graph=True`` so it can be
    back-propagated into the critic's parameters.
  """
  batch_size = real.shape[0]
  # One weight per sample, broadcast over all remaining dimensions.
  epsilon = torch.rand((batch_size,) + (1,) * (real.dim() - 1)).to(device)
  interpolated_images = real * epsilon + fake * (1 - epsilon)
  if not interpolated_images.requires_grad:
    # Neither input carried autograd history (e.g. ``fake`` was detached). The
    # result is therefore a leaf, so we can ask autograd to track it directly.
    interpolated_images.requires_grad_(True)

  mixed_scores = critic(interpolated_images)

  gradient = torch.autograd.grad(
      inputs=interpolated_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True,
  )[0]

  gradient = gradient.reshape(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim=1)
  return torch.mean((gradient_norm - 1) ** 2)


In [None]:
# Build both networks and move them to `device`, then replace PyTorch's default
# parameter init with `initialize_weights` (G first, then C).
G = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
C = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
for _model in (G, C):
  initialize_weights(_model)
del _model

In [None]:
# One Adam optimizer per network, both with lr=LEARNING_RATE and betas=(0.0, 0.9).
opt_gen, opt_critic = (
    optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
    for net in (G, C)
)

In [None]:
# A fixed batch of 32 latent vectors; reusing the same inputs across logging steps
# makes generator snapshots comparable over time.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
# Separate TensorBoard runs for the real and the generated image grids.
writer_real, writer_fake = (SummaryWriter(log_dir) for log_dir in ('logs/real', 'logs/fake'))
step = 0  # TensorBoard global step, advanced each time images are logged

In [None]:
# Put both networks in training mode (this matters for the BatchNorm layers in G).
# The training loop also switches them to eval and back around image logging.
G.train()
C.train()

In [None]:
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(dataloader):
    real = real.to(device)
    # The last batch of an epoch can be smaller than BATCH_SIZE. Size the noise
    # from the real batch so real and fake always line up in the gradient penalty.
    cur_batch_size = real.shape[0]

    # Critic: minimise -(E[C(real)] - E[C(fake)]) + LAMBDA_GP * GP.
    for _ in range(CRITIC_ITERATIONS):
      noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
      fake = G(noise)
      critic_real = C(real).reshape(-1)
      critic_fake = C(fake).reshape(-1)
      gp = gradient_penalty(C, real, fake, device=device)
      loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
      opt_critic.zero_grad()
      # retain_graph: the generator step below back-propagates through G's graph for `fake` again.
      loss_critic.backward(retain_graph=True)
      opt_critic.step()

    # Generator: maximise E[C(G(z))], i.e. minimise -E[C(G(z))].
    # (The previous +mean sign trained G to make its samples look *more* fake.)
    output = C(fake).reshape(-1)
    loss_gen = -torch.mean(output)
    # Also clears the gradients the critic loss pushed into G through `fake`.
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    if batch_idx % 100 == 0 and batch_idx > 0:
      G.eval()
      C.eval()
      print(
          f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
          f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
      )

      with torch.no_grad():
        # Use the fixed latent batch so successive snapshots show the same inputs.
        fake = G(fixed_noise)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

      step += 1
      G.train()
      C.train()

            
            
            
            
        
        
        

In [None]:
# Load TensorBoard's notebook extension so its dashboard can be shown inline.
%load_ext tensorboard

In [None]:
tensorboard --logdir /content/logs