In [1]:
import torch
import numpy as np
import cv2
import os
import random as rand
import torchvision
import pandas as pd
from tqdm import tqdm
from torch import nn, Tensor
import matplotlib.pyplot as plt
from typing import Optional
from torch.nn import functional as F
from torchvision.transforms import v2 as T
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from math import ceil

In [2]:
class CIFAR(Dataset):
    """Unlabelled image dataset that reads image files from a single directory.

    Each item is loaded with OpenCV (BGR), converted to RGB, resized to 32x32,
    scaled to [0, 1] and then normalised with mean/std 0.5 to the range [-1, 1]
    (the same range as the generators' Tanh output).

    Args:
        path: directory containing the image files.
        dataset: optional explicit list of file names (relative to ``path``).
            When omitted, every entry of ``path`` is used, in sorted order.
    """

    def __init__(self, path="/scratch/s25090/archive/cifar-10/train", dataset: Optional[list] = None):
        super().__init__()
        self.path = path
        # os.listdir order is filesystem-dependent; sort so idx -> file is reproducible.
        self.files = sorted(os.listdir(self.path)) if dataset is None else dataset
        self.T = T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
            T.Resize((32, 32)),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return the normalised image tensor for ``self.files[idx]``.

        Raises:
            RuntimeError: if OpenCV cannot decode the file (``cv2.imread``
                signals this by returning ``None`` instead of raising).
        """
        file = self.files[idx]
        img_path = os.path.join(self.path, file)
        img = cv2.imread(img_path)
        if img is None:
            raise RuntimeError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.T(img)

In [3]:
class GenBlock(nn.Module):
    """Generator stage: conv3x3 -> BN -> LeakyReLU -> conv3x3 [-> BN -> LeakyReLU] -> 2x upsample.

    The hidden width is the integer mean of ``in_channel`` and ``out_channel``.
    When ``is_final`` is true the second conv's output is left without
    BatchNorm/activation. The spatial size of the output is twice the input's.
    """

    def __init__(self, in_channel, out_channel, is_final):
        super().__init__()
        hidden = (in_channel + out_channel) // 2

        head = [
            nn.Conv2d(in_channel, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden, out_channel, 3, 1, 1),
        ]
        tail = [] if is_final else [nn.BatchNorm2d(out_channel), nn.LeakyReLU(0.2, inplace=True)]
        upsample = [nn.UpsamplingNearest2d(scale_factor=2)]

        # The attribute name and element order define state_dict keys
        # ("layer.<i>.<param>"), so they are kept stable for saved checkpoints.
        self.layer = nn.Sequential(*(head + tail + upsample))

    def forward(self, x):
        out = self.layer(x)
        return out

class DisBlock(nn.Module):
    """Critic stage: two (conv3x3 -> BatchNorm -> LeakyReLU(0.2)) units, then 2x2 max-pool.

    The first conv maps to the integer mean of ``in_channel`` and
    ``out_channel``; the max-pool halves height and width.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        hidden = (out_channel + in_channel) // 2

        modules = []
        for c_in, c_out in ((in_channel, hidden), (hidden, out_channel)):
            modules += [nn.Conv2d(c_in, c_out, 3, 1, 1), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2)]
        modules.append(nn.MaxPool2d(2, 2))

        # Indices 0..6 match the checkpoint layout ("layer.<i>.*").
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        out = self.layer(x)
        return out

class ResGenBlock(nn.Module):
    """Pre-activation residual up-sampling block (output spatial size = 2x input).

    Main path: BN -> ReLU -> 2x nearest upsample -> conv3x3 -> BN -> ReLU -> conv3x3.
    Skip path: 2x nearest upsample -> conv1x1. All convs are bias-free.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        main_path = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
        ]
        self.conv_block = nn.Sequential(*main_path)

        skip_path = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
        ]
        self.shortcut = nn.Sequential(*skip_path)

    def forward(self, x):
        # The in-place ReLUs act on BatchNorm outputs, never on x itself,
        # so x is still intact for the skip path.
        main = self.conv_block(x)
        skip = self.shortcut(x)
        return main + skip

In [None]:
class Generator(nn.Module):
    """Plain (non-residual) generator built from ``GenBlock`` up-sampling stages.

    A linear layer maps ``z`` to a 1024x4x4 feature map. Each ``GenBlock``
    doubles the spatial size, and a final 3x3 conv + Tanh produces a 3-channel
    image in [-1, 1].

    Args:
        z_dim: latent dimension.
        img_size: output side length; must equal 4 * 2**k with k >= 1.
            Defaults to 32, which matches ``Discriminator`` (three 2x max-pools
            down to a 4x4 map) and the 32x32 ``CIFAR`` images. Previously the
            block count was fixed at five, giving 128x128 outputs that the
            critic could not consume. ``img_size=128`` rebuilds that old
            architecture exactly (same module layout / state_dict keys).

    Raises:
        ValueError: if ``img_size`` is not of the form 4 * 2**k, k >= 1.
    """

    def __init__(self, z_dim=100, img_size=32):
        super().__init__()
        ratio, rem = divmod(img_size, 4)
        if rem or ratio < 2 or ratio & (ratio - 1):
            raise ValueError(f"img_size must be 4 * 2**k with k >= 1, got {img_size}")
        n_up = ratio.bit_length() - 1  # number of 2x up-sampling GenBlocks

        # Channel widths halve per block starting at 1024, floored at 64;
        # n_up=5 gives [1024, 512, 256, 128, 64, 64], the original layout.
        widths = [max(1024 >> i, 64) for i in range(n_up + 1)]

        self.initial_linear = nn.Linear(z_dim, 1024 * 4 * 4)

        self.net = nn.Sequential(*[
            GenBlock(c_in, c_out, is_final=(i == n_up - 1))
            for i, (c_in, c_out) in enumerate(zip(widths, widths[1:]))
        ])

        self.final_layer = nn.Sequential(
            nn.Conv2d(widths[-1], 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Map latent codes ``z`` (flattened to (batch, z_dim) if needed) to images."""
        if z.dim() > 2:
            z = z.view(z.size(0), -1)

        x = self.initial_linear(z)
        x = x.view(-1, 1024, 4, 4)
        x = self.net(x)
        return self.final_layer(x)

class Discriminator(nn.Module):
    """WGAN critic: three ``DisBlock`` stages, each halving H and W, plus a linear head.

    Produces one unbounded score per image (no sigmoid). ``WGANModel`` feeds
    these scores directly into its mean-based Wasserstein losses.

    Args:
        img_size: side length of the square input images (default 32, which
            keeps the original ``Linear(128 * 4 * 4, 1)`` head). After three
            floor-halving max-pools the feature map is ``img_size // 8`` wide.

    Raises:
        ValueError: if ``img_size`` < 8 (the feature map would vanish).
    """

    def __init__(self, img_size=32):
        super().__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be >= 8, got {img_size}")

        self.net = nn.Sequential(
            DisBlock(3, 32),
            DisBlock(32, 64),
            DisBlock(64, 128),
        )

        feat = img_size // 8  # spatial size after three 2x2 max-pools
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feat * feat, 1)
        )

    def forward(self, x):
        """Return critic scores of shape (batch, 1) for images ``x``."""
        x = self.net(x)
        return self.classifier(x)

class ResNetGenerator(nn.Module):
    """Residual generator.

    The pipeline is: linear projection of ``z`` to a (base_channels, 4, 4) map,
    then three up-sampling ``ResGenBlock`` stages (4 -> 8 -> 16 -> 32), then a
    BN -> ReLU -> conv3x3 -> Tanh head emitting 3 channels in [-1, 1].

    Channel widths go base -> base -> base/2 -> base/4.
    """

    def __init__(self, z_dim=100, base_channels=256):
        super().__init__()
        c = base_channels
        self.linear = nn.Linear(z_dim, 4 * 4 * c)
        self.base_channels = c

        stage_widths = [(c, c), (c, c // 2), (c // 2, c // 4)]
        self.blocks = nn.Sequential(*(ResGenBlock(w_in, w_out) for w_in, w_out in stage_widths))

        last = c // 4
        self.final_layer = nn.Sequential(
            nn.BatchNorm2d(last),
            nn.ReLU(inplace=True),
            nn.Conv2d(last, 3, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z):
        flat = z if z.ndim <= 2 else z.view(z.size(0), -1)
        feats = self.linear(flat).view(-1, self.base_channels, 4, 4)
        return self.final_layer(self.blocks(feats))

class WGANModel(nn.Module):
    """Generator + critic pair with Wasserstein losses (critic is unbounded, no sigmoid).

    Args:
        z_dim: latent dimension, forwarded to whichever generator is built
            (previously ``ResNetGenerator`` silently ignored it and used 100).
        is_res: build ``ResNetGenerator`` when True, otherwise ``Generator``.
    """

    def __init__(self, z_dim=100, is_res=True):
        super().__init__()
        self.generator = ResNetGenerator(z_dim) if is_res else Generator(z_dim)
        self.discriminator = Discriminator()
        self.z_dim = z_dim

    def forward(self, z):
        """Generate images from latent codes ``z``."""
        return self.generator(z)

    def compute_discriminator_loss(self, real_imgs, z):
        """Critic loss ``-(mean D(real) - mean D(G(z)))``.

        Minimising it pushes real scores up and fake scores down. The
        generator runs under ``no_grad``, so no generator graph is built and
        only the critic receives gradients. Note: if the generator is in train
        mode, its BatchNorm running statistics are still updated by this call.
        """
        with torch.no_grad():
            fake_imgs = self.generator(z)

        real_logits = self.discriminator(real_imgs)
        fake_logits = self.discriminator(fake_imgs)

        d_loss = -(torch.mean(real_logits) - torch.mean(fake_logits))
        return d_loss

    def compute_generator_loss(self, z):
        """Generator loss ``-mean D(G(z))``; also returns the generated images."""
        fake_imgs = self.generator(z)
        fake_logits = self.discriminator(fake_imgs)
        g_loss = -torch.mean(fake_logits)
        return g_loss, fake_imgs

In [None]:
# ---------------- configuration ----------------
LEARNING_RATE = 5e-5
WEIGHT_CLIP = 0.01
N_CRITIC = 5            # critic updates per generator update
BATCH_SIZE = 128
NUM_WORKERS = 4         # cv2 decoding in worker processes; set 0 if workers cannot be started
SAVE_EVERY = 10         # epochs between checkpoints / sample grids
SEED = 0
epochs = 300

OUT_DIR = "/scratch/s25090/wgan_outputs"
WEIGHTS_DIR = os.path.join(OUT_DIR, "weights", "Experiment1")
PLOTS_DIR = os.path.join(OUT_DIR, "plots", "Experiment1")
os.makedirs(WEIGHTS_DIR, exist_ok=True)  # torch.save / savefig do not create directories
os.makedirs(PLOTS_DIR, exist_ok=True)

# Prefer GPU 3 of the shared node, but fall back to the last visible GPU instead of crashing.
if torch.cuda.is_available():
    DEVICE = f"cuda:{min(3, torch.cuda.device_count() - 1)}"
else:
    DEVICE = "cpu"

rand.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---------------- model / data ----------------
gan_model = WGANModel()
gan_model = gan_model.to(DEVICE)
Z_DIM = gan_model.z_dim  # sample noise with the model's own latent size

opt_gen = torch.optim.RMSprop(gan_model.generator.parameters(), lr=LEARNING_RATE)
opt_dis = torch.optim.RMSprop(gan_model.discriminator.parameters(), lr=LEARNING_RATE)

train_dataset = CIFAR()
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=DEVICE.startswith("cuda"),
)

# ---------------- training ----------------
gen_loss_list = []
dis_loss_list = []
critic_steps = 0  # global count, so the N_CRITIC cadence continues across epochs

for epoch in range(epochs):
    gan_model.train()
    tqdm_data = tqdm(train_loader, desc=f"Epoch-{epoch+1}/{epochs}")

    epoch_gen_loss = 0.0
    epoch_dis_loss = 0.0
    n_gen_steps = 0
    last_gen_loss = float("nan")

    for batch_idx, real_img in enumerate(tqdm_data):
        real_img = real_img.to(DEVICE, non_blocking=True)
        bs = real_img.size(0)

        # --- critic step (every batch) ---
        gan_model.discriminator.requires_grad_(True)
        z_dis = torch.randn(bs, Z_DIM, device=DEVICE)
        opt_dis.zero_grad()
        dis_loss = gan_model.compute_discriminator_loss(real_img, z_dis)
        dis_loss.backward()
        opt_dis.step()

        # Weight clipping (the WGAN Lipschitz constraint); applies to all critic params.
        with torch.no_grad():
            for p in gan_model.discriminator.parameters():
                p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        epoch_dis_loss += dis_loss.item()
        critic_steps += 1

        # --- generator step after every N_CRITIC critic steps ---
        if critic_steps % N_CRITIC == 0:
            gan_model.discriminator.requires_grad_(False)  # critic grads not needed here
            z = torch.randn(bs, Z_DIM, device=DEVICE)
            opt_gen.zero_grad()
            gen_loss, fake_img = gan_model.compute_generator_loss(z)
            gen_loss.backward()
            opt_gen.step()

            last_gen_loss = gen_loss.item()
            epoch_gen_loss += last_gen_loss
            n_gen_steps += 1

        tqdm_data.set_postfix({
            "GenLoss": last_gen_loss,
            "DisLoss": dis_loss.item()
        })

    # Average over the steps actually taken (not len(loader) / N_CRITIC).
    avg_gen_loss = epoch_gen_loss / n_gen_steps if n_gen_steps else float("nan")
    avg_dis_loss = epoch_dis_loss / len(train_loader)

    gen_loss_list.append(avg_gen_loss)
    dis_loss_list.append(avg_dis_loss)

    print(f"Generator Loss: {avg_gen_loss:.4f}\nDiscriminator Loss: {avg_dis_loss:.4f}")

    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save(gan_model.state_dict(), os.path.join(WEIGHTS_DIR, f"gan_epoch_{epoch+1}.pth"))
        gan_model.eval()
        with torch.no_grad():
            test_z = torch.randn(8, Z_DIM, device=DEVICE)
            gan_image = gan_model.generator(test_z)

        # Top row: last real batch; bottom row: samples. Both live in [-1, 1], so map that
        # fixed range to [0, 1] instead of stretching by the grid's own min/max.
        comparison = torch.cat([real_img[:8], gan_image[:8]], dim=0)
        grid = make_grid(comparison.cpu(), nrow=8, padding=2, normalize=True, value_range=(-1, 1))

        fig, ax = plt.subplots(figsize=(12, 4))
        ax.imshow(grid.permute(1, 2, 0))
        ax.axis('off')
        ax.set_title(f'Top: Original | Bottom: Generated Image (Epoch {epoch+1})')
        fig.savefig(os.path.join(PLOTS_DIR, f"Epoch-{epoch+1}.png"))
        plt.close(fig)

# ---------------- loss curves ----------------
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator vs Discriminator Loss (WGAN)")
ax.plot(gen_loss_list, label="Generator")
ax.plot(dis_loss_list, label="Discriminator")
ax.set_xlabel("Epochs")
ax.set_ylabel("Wasserstein Loss")
ax.legend()  # must precede savefig, otherwise the saved PNG has no legend
fig.savefig(os.path.join(OUT_DIR, "plots", "Experiment1_loss.png"))
plt.show()

Epoch-1/300: 100%|██████████| 391/391 [02:41<00:00,  2.43it/s, GenLoss=0.0106, DisLoss=-8.13e-6] 


Generator Loss: 0.0107
Discriminator Loss: 0.0005


Epoch-2/300: 100%|██████████| 391/391 [01:17<00:00,  5.06it/s, GenLoss=0.0112, DisLoss=-8.43e-5] 


Generator Loss: 0.0109
Discriminator Loss: -0.0000


Epoch-3/300: 100%|██████████| 391/391 [00:49<00:00,  7.92it/s, GenLoss=0.114, DisLoss=-0.0601]    


Generator Loss: 0.0347
Discriminator Loss: -0.0172


Epoch-4/300: 100%|██████████| 391/391 [00:31<00:00, 12.26it/s, GenLoss=0.159, DisLoss=-0.113]   


Generator Loss: 0.1103
Discriminator Loss: -0.0748


Epoch-5/300: 100%|██████████| 391/391 [00:31<00:00, 12.48it/s, GenLoss=0.0595, DisLoss=-0.11]  


Generator Loss: 0.1091
Discriminator Loss: -0.0966


Epoch-6/300: 100%|██████████| 391/391 [00:31<00:00, 12.58it/s, GenLoss=0.12, DisLoss=-0.148]   


Generator Loss: 0.1016
Discriminator Loss: -0.1105


Epoch-7/300: 100%|██████████| 391/391 [00:31<00:00, 12.43it/s, GenLoss=0.151, DisLoss=-0.129]  


Generator Loss: 0.1026
Discriminator Loss: -0.1174


Epoch-8/300: 100%|██████████| 391/391 [00:30<00:00, 12.92it/s, GenLoss=0.0947, DisLoss=-0.141] 


Generator Loss: 0.1086
Discriminator Loss: -0.1213


Epoch-9/300: 100%|██████████| 391/391 [01:53<00:00,  3.44it/s, GenLoss=0.103, DisLoss=-0.127]  


Generator Loss: 0.1042
Discriminator Loss: -0.1142


Epoch-10/300: 100%|██████████| 391/391 [02:25<00:00,  2.69it/s, GenLoss=0.0807, DisLoss=-0.124] 


Generator Loss: 0.0981
Discriminator Loss: -0.1098


Epoch-11/300: 100%|██████████| 391/391 [03:16<00:00,  1.99it/s, GenLoss=0.00874, DisLoss=-0.115]


Generator Loss: 0.0924
Discriminator Loss: -0.1054


Epoch-12/300: 100%|██████████| 391/391 [01:55<00:00,  3.37it/s, GenLoss=0.14, DisLoss=-0.0857]    


Generator Loss: 0.0858
Discriminator Loss: -0.0960


Epoch-13/300: 100%|██████████| 391/391 [01:40<00:00,  3.89it/s, GenLoss=0.143, DisLoss=-0.102]     


Generator Loss: 0.0702
Discriminator Loss: -0.0846


Epoch-14/300: 100%|██████████| 391/391 [01:24<00:00,  4.62it/s, GenLoss=-0.0353, DisLoss=-0.0845]  


Generator Loss: 0.0634
Discriminator Loss: -0.0838


Epoch-15/300: 100%|██████████| 391/391 [01:43<00:00,  3.79it/s, GenLoss=0.136, DisLoss=-0.0758]   


Generator Loss: 0.0624
Discriminator Loss: -0.0816


Epoch-16/300: 100%|██████████| 391/391 [01:47<00:00,  3.63it/s, GenLoss=0.00366, DisLoss=-0.0857] 


Generator Loss: 0.0650
Discriminator Loss: -0.0768


Epoch-17/300: 100%|██████████| 391/391 [01:12<00:00,  5.36it/s, GenLoss=0.00555, DisLoss=-0.0602]  


Generator Loss: 0.0642
Discriminator Loss: -0.0779


Epoch-18/300: 100%|██████████| 391/391 [00:56<00:00,  6.87it/s, GenLoss=0.129, DisLoss=-0.085]    


Generator Loss: 0.0638
Discriminator Loss: -0.0782


Epoch-19/300: 100%|██████████| 391/391 [02:09<00:00,  3.01it/s, GenLoss=0.0943, DisLoss=-0.104]    


Generator Loss: 0.0610
Discriminator Loss: -0.0786


Epoch-20/300: 100%|██████████| 391/391 [02:53<00:00,  2.25it/s, GenLoss=-0.0307, DisLoss=-0.0828]  


Generator Loss: 0.0545
Discriminator Loss: -0.0786


Epoch-21/300: 100%|██████████| 391/391 [01:03<00:00,  6.20it/s, GenLoss=0.118, DisLoss=-0.0741]   


Generator Loss: 0.0521
Discriminator Loss: -0.0791


Epoch-22/300: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, GenLoss=0.116, DisLoss=-0.0607]    


Generator Loss: 0.0563
Discriminator Loss: -0.0791


Epoch-23/300: 100%|██████████| 391/391 [00:26<00:00, 14.59it/s, GenLoss=0.107, DisLoss=-0.106]    


Generator Loss: 0.0531
Discriminator Loss: -0.0801


Epoch-24/300: 100%|██████████| 391/391 [00:30<00:00, 12.91it/s, GenLoss=0.112, DisLoss=-0.0879]    


Generator Loss: 0.0508
Discriminator Loss: -0.0789


Epoch-25/300: 100%|██████████| 391/391 [00:28<00:00, 13.58it/s, GenLoss=0.0959, DisLoss=-0.0884]  


Generator Loss: 0.0483
Discriminator Loss: -0.0784


Epoch-26/300: 100%|██████████| 391/391 [00:28<00:00, 13.53it/s, GenLoss=0.121, DisLoss=-0.0802]   


Generator Loss: 0.0433
Discriminator Loss: -0.0770


Epoch-27/300: 100%|██████████| 391/391 [00:28<00:00, 13.94it/s, GenLoss=-0.00523, DisLoss=-0.0862]


Generator Loss: 0.0360
Discriminator Loss: -0.0761


Epoch-28/300: 100%|██████████| 391/391 [00:28<00:00, 13.73it/s, GenLoss=-0.012, DisLoss=-0.0815] 


Generator Loss: 0.0321
Discriminator Loss: -0.0761


Epoch-29/300: 100%|██████████| 391/391 [00:28<00:00, 13.87it/s, GenLoss=0.0975, DisLoss=-0.0716]  


Generator Loss: 0.0371
Discriminator Loss: -0.0751


Epoch-30/300: 100%|██████████| 391/391 [00:28<00:00, 13.86it/s, GenLoss=0.124, DisLoss=-0.0664]  


Generator Loss: 0.0386
Discriminator Loss: -0.0755


Epoch-31/300: 100%|██████████| 391/391 [00:29<00:00, 13.40it/s, GenLoss=-0.0329, DisLoss=-0.0744] 


Generator Loss: 0.0339
Discriminator Loss: -0.0758


Epoch-32/300: 100%|██████████| 391/391 [00:30<00:00, 12.90it/s, GenLoss=0.121, DisLoss=-0.0678]   


Generator Loss: 0.0423
Discriminator Loss: -0.0766


Epoch-33/300: 100%|██████████| 391/391 [00:29<00:00, 13.16it/s, GenLoss=0.0758, DisLoss=-0.0913] 


Generator Loss: 0.0365
Discriminator Loss: -0.0761


Epoch-34/300: 100%|██████████| 391/391 [00:29<00:00, 13.13it/s, GenLoss=0.0781, DisLoss=-0.0686] 


Generator Loss: 0.0417
Discriminator Loss: -0.0748


Epoch-35/300: 100%|██████████| 391/391 [00:30<00:00, 12.97it/s, GenLoss=0.126, DisLoss=-0.0895]   


Generator Loss: 0.0431
Discriminator Loss: -0.0743


Epoch-36/300: 100%|██████████| 391/391 [00:29<00:00, 13.30it/s, GenLoss=-0.0351, DisLoss=-0.0762] 


Generator Loss: 0.0415
Discriminator Loss: -0.0741


Epoch-37/300: 100%|██████████| 391/391 [00:28<00:00, 13.96it/s, GenLoss=-0.0217, DisLoss=-0.075]  


Generator Loss: 0.0473
Discriminator Loss: -0.0741


Epoch-38/300: 100%|██████████| 391/391 [00:29<00:00, 13.10it/s, GenLoss=0.125, DisLoss=-0.0774]  


Generator Loss: 0.0415
Discriminator Loss: -0.0738


Epoch-39/300: 100%|██████████| 391/391 [00:30<00:00, 12.68it/s, GenLoss=-0.0182, DisLoss=-0.0897] 


Generator Loss: 0.0409
Discriminator Loss: -0.0723


Epoch-40/300: 100%|██████████| 391/391 [00:29<00:00, 13.41it/s, GenLoss=0.121, DisLoss=-0.0824]   


Generator Loss: 0.0376
Discriminator Loss: -0.0729


Epoch-41/300: 100%|██████████| 391/391 [00:32<00:00, 12.06it/s, GenLoss=0.114, DisLoss=-0.0826]   


Generator Loss: 0.0418
Discriminator Loss: -0.0719


Epoch-42/300: 100%|██████████| 391/391 [00:31<00:00, 12.40it/s, GenLoss=0.129, DisLoss=-0.0657]   


Generator Loss: 0.0434
Discriminator Loss: -0.0711


Epoch-43/300: 100%|██████████| 391/391 [00:31<00:00, 12.57it/s, GenLoss=0.00302, DisLoss=-0.0682] 


Generator Loss: 0.0406
Discriminator Loss: -0.0712


Epoch-44/300: 100%|██████████| 391/391 [00:31<00:00, 12.54it/s, GenLoss=-0.0257, DisLoss=-0.0799] 


Generator Loss: 0.0402
Discriminator Loss: -0.0702


Epoch-45/300: 100%|██████████| 391/391 [00:32<00:00, 12.12it/s, GenLoss=-0.00119, DisLoss=-0.0818]


Generator Loss: 0.0439
Discriminator Loss: -0.0699


Epoch-46/300: 100%|██████████| 391/391 [00:30<00:00, 12.94it/s, GenLoss=0.0837, DisLoss=-0.0715]  


Generator Loss: 0.0402
Discriminator Loss: -0.0692


Epoch-47/300: 100%|██████████| 391/391 [00:32<00:00, 11.97it/s, GenLoss=-0.0258, DisLoss=-0.0543] 


Generator Loss: 0.0385
Discriminator Loss: -0.0683


Epoch-48/300: 100%|██████████| 391/391 [00:33<00:00, 11.72it/s, GenLoss=0.0957, DisLoss=-0.063]  


Generator Loss: 0.0371
Discriminator Loss: -0.0675


Epoch-49/300: 100%|██████████| 391/391 [00:30<00:00, 12.67it/s, GenLoss=0.0957, DisLoss=-0.0601] 


Generator Loss: 0.0371
Discriminator Loss: -0.0674


Epoch-50/300: 100%|██████████| 391/391 [00:31<00:00, 12.27it/s, GenLoss=-0.00226, DisLoss=-0.0691] 


Generator Loss: 0.0398
Discriminator Loss: -0.0663


Epoch-51/300: 100%|██████████| 391/391 [00:32<00:00, 11.97it/s, GenLoss=0.0285, DisLoss=-0.0875]  


Generator Loss: 0.0343
Discriminator Loss: -0.0653


Epoch-52/300: 100%|██████████| 391/391 [00:33<00:00, 11.84it/s, GenLoss=-0.00583, DisLoss=-0.0774]


Generator Loss: 0.0358
Discriminator Loss: -0.0650


Epoch-53/300: 100%|██████████| 391/391 [00:31<00:00, 12.54it/s, GenLoss=-0.0359, DisLoss=-0.0729] 


Generator Loss: 0.0369
Discriminator Loss: -0.0640


Epoch-54/300:  29%|██▉       | 115/391 [00:46<26:40,  5.80s/it, GenLoss=0.00851, DisLoss=-0.063] 