<a href="https://colab.research.google.com/github/saquibali7/GANs/blob/main/GAN3_WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
  def __init__(self, channels_img, features_d):
    super(Discriminator, self).__init__()

    self.disc = nn.Sequential(
        nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d*2, 4,2,1),
        self._block(features_d*2, features_d*4, 4,2,1),
        self._block(features_d*4, features_d*8, 4,2,1),
        nn.Conv2d(features_d*8,1,kernel_size=4,stride=2,padding = 0),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
      return nn.Sequential(
          nn.Conv2d(in_channels,out_channels, kernel_size, stride, padding, bias=False),
          nn.InstanceNorm2d(out_channels, affine=True),
          nn.LeakyReLU(0.2),
      )

  def forward(self, x):
        return self.disc(x)

In [3]:


class Generator(nn.Module):
  def __init__(self, channels_noise, channels_img, features_g):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
        self._block(channels_noise, features_g*16, 4, 1, 0),
        self._block(features_g*16, features_g*8, 4,2,1),
        self._block(features_g*8, features_g*4, 4,2,1),
        self._block(features_g*4, features_g*2, 4,2,1),
        nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
      return nn.Sequential(
          nn.ConvTranspose2d(
                    in_channels, 
                    out_channels, 
                    kernel_size, 
                    stride, 
                    padding, 
                    bias=False),
          nn.BatchNorm2d(out_channels), ## for WGAN with gradient penalty use nn.InstanceNorm2d
          nn.ReLU(),          
                  
      )

  def forward(self, x):
        return self.gen(x)

In [4]:
def initialise_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal(m.weight.data, 0.0, 0.02)

In [5]:
def test():
  N, in_channels, H, W = 8, 3, 64, 64
  noise_dim = 100
  x = torch.randn(N, in_channels, H, W)
  disc = Discriminator(in_channels, 8)
  initialise_weights(disc)
  assert disc(x).shape == (N,1,1,1), "fail"
  gen = Generator(noise_dim, in_channels, 8)
  initialise_weights(gen)
  z = torch.randn(N, noise_dim, 1, 1)
  assert gen(z).shape == (N, in_channels, H, W), "fail"

In [6]:
test()

  after removing the cwd from sys.path.


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE= 5e-5 ## form WGAN with gradient penalty lr = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 3
NOISE_DIM = 128
EPOCHS = 250
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
WEIGHTS_CLIP = 0.01

In [8]:
transforms = transforms.Compose(
    [
     transforms.Resize(IMAGE_SIZE),
     transforms.ToTensor(),
     transforms.Normalize(
         [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
     ),
    ]
)

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
import os

path = '/content/drive/MyDrive/Morula/'
file = os.listdir(path)
print(len(file))

1478


In [11]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [12]:
train_transforms = A.Compose(
    [   
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(
         [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
        ToTensorV2(),
    ]
)

In [13]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.datasets as datasets
from PIL import Image

In [14]:
class ImageDataset(Dataset):
  def __init__(self, root_dir, transform=None):
    self.root_dir = root_dir
    self.files = os.listdir(root_dir)
    self.transform = transform

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    img_file = self.files[idx]
    img_path = os.path.join(self.root_dir, img_file)
    img = np.array(Image.open(img_path))
    if self.transform:
      img = self.transform(image=img)['image']
    return img

In [15]:
dataset = ImageDataset(root_dir=path, transform=train_transforms)
loader = DataLoader(dataset, batch_size = BATCH_SIZE)

gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

In [16]:
initialise_weights(critic)
initialise_weights(gen)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)




  after removing the cwd from sys.path.


In [17]:
fixed_noise = torch.randn(32, NOISE_DIM, 1,1).to(device)


In [18]:
!pip install wandb
import wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.1-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 25.5 MB/s 
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.9.2-py2.py3-none-any.whl (157 kB)
[K     |████████████████████████████████| 157 kB 54.9 MB/s 
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 73.0 MB/s 
[?25hCollecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting setproctitle
  Downloading setproctitle-1.3.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-

In [19]:
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    mixed_scores = critic(interpolated_images)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty



In [20]:
from torchvision.utils import make_grid

In [21]:
step=0
critic.train()
gen.train()

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(128, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 3, 

In [None]:
wandb.init("task-01")

for epoch in range(EPOCHS):
  for batch_idx, (real) in enumerate(loader):
    real = real.to(device)
    curr_batch_size = real.shape[0]

    for _ in range(CRITIC_ITERATIONS):
      noise = torch.randn(curr_batch_size, NOISE_DIM, 1, 1).to(device)
      fake = gen(noise)
      critic_real = critic(real).reshape(-1)
      critic_fake = critic(fake).reshape(-1)
      gp = gradient_penalty(critic, real, fake, device=device)
      critic_loss = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + 10 * gp)
      critic.zero_grad()
      critic_loss.backward(retain_graph=True)
      opt_critic.step()  

    ## Training of Generator
    
    gen_fake = critic(fake).reshape(-1)
    loss_gen = -torch.mean(gen_fake)
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    
  wandb.log({"Generartor  Loss": loss_gen, 'epoch':epoch })
  wandb.log({"Discriminator Loss": critic_loss, 'epoch':epoch })

  with torch.inference_mode():
   for batch_idx, real_img in enumerate(loader):
     noise = torch.randn(curr_batch_size, NOISE_DIM, 1, 1).to(device)
     fake_img = gen(noise)
     grid1 = make_grid(real_img)
     grid2 = make_grid(fake_img)
     grid1 = wandb.Image(grid1, caption="Original Image")
     grid2 = wandb.Image(grid2, caption="Fake Image")
     wandb.log({"Original Image ": grid1})
     wandb.log({"Fake Image": grid2})

    

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
## for WGAN


In [None]:
ran  = torch.randn((1,128,1,1)).to(device)
x = gen(ran)

In [None]:
import matplotlib.pyplot as plt

In [None]:
x.shape

torch.Size([1, 3, 64, 64])

In [None]:
x = x.detach().

In [None]:
plt.imshow(x[0].permute(1,2,0))

AttributeError: ignored