<a href="https://colab.research.google.com/github/satani99/generative_deep_learning/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import auth

# Authenticate this Colab session with the user's Google account. Presumably this is
# needed so the `gsutil cp` download cell can reach Google Cloud Storage
# (TODO confirm: the quickdraw_dataset bucket may be readable without it).
auth.authenticate_user()

In [2]:
# !curl https://sdk.cloud.google.com | bash

In [3]:
# !gcloud init --skip-diagnostics

In [4]:
# Copy the QuickDraw "camel" numpy_bitmap array into the current working directory
# (presumably /content on Colab, which matches the path GANDataset is given later).
!gsutil cp gs://quickdraw_dataset/full/numpy_bitmap/camel.npy .


Copying gs://quickdraw_dataset/full/numpy_bitmap/camel.npy...
- [1 files][ 90.8 MiB/ 90.8 MiB]                                                
Operation completed over 1 objects/90.8 MiB.                                     


In [53]:
import torch.nn as nn 
import torch 
import torchvision.utils as vutils

In [28]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and all CUDA devices). Weight init, latent noise, dropout and
# DataLoader shuffling all draw from them, so runs become repeatable under
# Restart & Run All (modulo nondeterministic cuDNN kernels on GPU).
SEED = 999
torch.manual_seed(SEED);

In [42]:
# THE DISCRIMINATOR
class Discriminator(nn.Module):
  """Convolutional classifier scoring each image in a batch as real (near 1) or fake (near 0).

  Four 3x3 convolutions, each followed by ReLU and dropout, then a linear layer
  and a sigmoid produce one probability per image, shaped (N, 1).

  The flatten size of 2048 = 128 * 4 * 4 assumes 1-channel 28x28 inputs:
  the spatial size goes 28 -> 14 -> 7 -> 4 (stride 2) -> 4 (stride 1).
  """

  def __init__(self):
    super().__init__()
    # Creation order is kept as-is: it fixes both the RNG draw order for weight
    # init and the state_dict key order.
    self.conv0 = nn.Conv2d(1, 64, 3, stride=2, padding=1)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=0.5)  # p=0.5 is nn.Dropout's default
    self.conv1 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
    self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
    self.conv3 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
    self.linear = nn.Linear(128 * 4 * 4, 1)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    # ReLU and dropout modules are stateless, so one instance is shared by every conv.
    for conv in (self.conv0, self.conv1, self.conv2, self.conv3):
      x = self.dropout(self.relu(conv(x)))
    flat = x.view(-1, 128 * 4 * 4)
    return self.sigmoid(self.linear(flat))





In [7]:
# THE GENERATOR 
class Generator(nn.Module):
  """Maps a batch of 100-dim latent vectors, shaped (N, 100), to images shaped (N, 1, 28, 28) in [-1, 1].

  A linear projection to 64 x 7 x 7 is followed by two 2x bilinear upsamplings
  (7 -> 14 -> 28). Conv + BatchNorm + ReLU stages sit between and after them,
  and a final 1-channel conv goes through tanh.
  """

  def __init__(self):
    super().__init__()
    # Registration order is unchanged so seeded init and state_dict keys match.
    self.linear = nn.Linear(100, 64 * 7 * 7)
    self.batchnorm0 = nn.BatchNorm1d(64 * 7 * 7)
    self.relu = nn.ReLU()
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    self.conv0 = nn.Conv2d(64, 128, 3, padding=1)
    self.batchnorm1 = nn.BatchNorm2d(128)
    self.conv1 = nn.Conv2d(128, 64, 3, padding=1)
    self.batchnorm2 = nn.BatchNorm2d(64)
    self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
    self.batchnorm3 = nn.BatchNorm2d(64)
    self.conv3 = nn.Conv2d(64, 1, 3, padding=1)
    self.tanh = nn.Tanh()

  def forward(self, x):
    h = self.relu(self.batchnorm0(self.linear(x)))
    h = self.upsample(h.view(-1, 64, 7, 7))         # 7x7 -> 14x14
    h = self.relu(self.batchnorm1(self.conv0(h)))
    h = self.upsample(h)                            # 14x14 -> 28x28
    for conv, norm in ((self.conv1, self.batchnorm2), (self.conv2, self.batchnorm3)):
      h = self.relu(norm(conv(h)))
    return self.tanh(self.conv3(h))






In [44]:
# Build both networks on the selected device. The generator is constructed first
# (tuple elements evaluate left to right), so seeded weight init is unchanged.
netG, netD = Generator().to(device), Discriminator().to(device)

In [27]:
import torch.optim as optim

# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# 64 latent vectors held fixed for the whole run and fed to the generator for the
# periodic snapshots in the training loop, so successive image grids are comparable.
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

# BCELoss targets: 1 for real drawings, 0 for generated ones.
real_label, fake_label = 1., 0.

# NOTE: run this cell only after netD/netG are in their final form. An optimizer keeps
# references to the parameter tensors it was given, so re-running the model cell
# afterwards leaves these optimizers stepping stale modules.
LEARNING_RATE = 0.0008
optimizerD = optim.Adam(netD.parameters(), lr=LEARNING_RATE)
optimizerG = optim.Adam(netG.parameters(), lr=LEARNING_RATE)

In [33]:
import numpy as np 
from torch.utils.data import Dataset, DataLoader 
 
class GANDataset(Dataset):
  def __init__(self, np_file_path):
    self.files = np_file_path 
  
  def __getitem__(self, index):
    x = np.load(self.files)[index]
    x = torch.from_numpy(x).float()
    return x 

  def __len__(self):
    return len(self.files)

In [34]:
dataset = GANDataset('/content/camel.npy')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1)

In [55]:
img_list = []
G_losses = []
D_losses = [] 
iters = 0

print("Starting Training Loop...")

for epoch in range(2000):
  for i, data in enumerate(dataloader):
    netD.zero_grad() 

    real_cpu = data.to(device)
    b_size = real_cpu.size(0)
    label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

    output = netD(real_cpu.view(18, 1, 28, 28)).view(-1)

    errD_real = criterion(output, label)

    errD_real.backward()
    D_x = output.mean().item() 

    noise = torch.randn(b_size, 100, 1, 1, device=device)

    fake = netG(noise.view(-1, 100))
    label.fill_(fake_label)

    output = netD(fake.detach()).view(-1)

    errD_fake = criterion(output, label) 

    errD_fake.backward()
    D_G_z1 = output.mean().item() 

    errD = errD_real + errD_fake 

    optimizerD.step()

    netG.zero_grad()
    label.fill_(real_label)

    output = netD(fake).view(-1)

    errG = criterion(output, label)

    errG.backward()
    D_G_z2 = output.mean().item()

    optimizerG.step() 

    if i % 50 == 0:
      print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, 5000, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 

    G_losses.append(errG.item())
    D_losses.append(errD.item())

    if (iters % 500 == 0) or ((epoch == 5000-1) and (i == len(dataloader)-1)):
      with torch.no_grad():
        fake = netG(fixed_noise.view(-1, 100)).detach().cpu()
      img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    iters += 1
    

Starting Training Loop...
[0/5000][0/1]	Loss_D: 7.7089	Loss_G: 0.7139	D(x): 0.0175	D(G(z)): 0.4868 / 0.4898
[1/5000][0/1]	Loss_D: 7.5955	Loss_G: 0.7152	D(x): 0.0552	D(G(z)): 0.4902 / 0.4891
[2/5000][0/1]	Loss_D: 7.1456	Loss_G: 0.7171	D(x): 0.0913	D(G(z)): 0.4895 / 0.4882
[3/5000][0/1]	Loss_D: 6.7572	Loss_G: 0.7102	D(x): 0.1510	D(G(z)): 0.4912 / 0.4916
[4/5000][0/1]	Loss_D: 6.2997	Loss_G: 0.7133	D(x): 0.1323	D(G(z)): 0.4898 / 0.4901
[5/5000][0/1]	Loss_D: 8.4730	Loss_G: 0.7158	D(x): 0.0376	D(G(z)): 0.4892 / 0.4888
[6/5000][0/1]	Loss_D: 7.8517	Loss_G: 0.7146	D(x): 0.0520	D(G(z)): 0.4874 / 0.4894
[7/5000][0/1]	Loss_D: 7.1426	Loss_G: 0.7132	D(x): 0.0632	D(G(z)): 0.4918 / 0.4901
[8/5000][0/1]	Loss_D: 8.3667	Loss_G: 0.7164	D(x): 0.0795	D(G(z)): 0.4894 / 0.4886
[9/5000][0/1]	Loss_D: 7.7053	Loss_G: 0.7106	D(x): 0.1188	D(G(z)): 0.4888 / 0.4914
[10/5000][0/1]	Loss_D: 6.7334	Loss_G: 0.7121	D(x): 0.0848	D(G(z)): 0.4913 / 0.4906
[11/5000][0/1]	Loss_D: 6.7259	Loss_G: 0.7143	D(x): 0.0671	D(G(z)): 0.48