In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import ParameterGrid
from tqdm.notebook import tqdm
from IPython.display import Image
# Run on the GPU when one is available; fall back to CPU so the notebook still
# executes (slowly) on machines without CUDA instead of crashing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Hyperparameter search space; ParameterGrid expands it into every combination
# (2 ** 5 = 32 configurations).
param_grid = dict(
    lr_g=[0.0001, 0.0002],    # generator Adam learning rate
    lr_d=[0.0001, 0.0002],    # discriminator Adam learning rate
    num_epochs=[4, 5],
    beta1=[0.5, 0.9],         # Adam beta1 (shared by both optimizers)
    beta2=[0.999, 0.99],      # Adam beta2 (shared by both optimizers)
)

In [3]:
class Generator(nn.Module):
    """Map a latent noise vector to a 3x32x32 image with values in [-1, 1].

    Pipeline: Linear -> reshape to a (128, 8, 8) feature map -> two x2 upsampling
    conv blocks (8 -> 16 -> 32) -> 3-channel 3x3 conv -> Tanh.

    Args:
        latent_dim: length of each input noise vector ``z``.
    """

    def __init__(self, latent_dim):
        super().__init__()

        def upsample_conv(c_in, c_out):
            # Upsample x2, 3x3 "same" conv, batch norm, ReLU.
            # NOTE(review): PyTorch BatchNorm momentum weights the *new* batch statistic
            # (running = (1 - m) * running + m * batch), so 0.78 is very aggressive. If the
            # value was ported from Keras (where momentum=0.8 weights the *old* value),
            # the PyTorch equivalent would be ~0.2 -- confirm intent.
            return [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out, momentum=0.78),
                nn.ReLU(),
            ]

        # Layers are created strictly in forward order so parameter initialisation
        # and the Sequential indices (state_dict keys) stay stable.
        layers = [
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
        ]
        layers += upsample_conv(128, 128)
        layers += upsample_conv(128, 64)
        layers += [
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return generated images for a noise batch ``z`` of shape (batch, latent_dim)."""
        return self.model(z)

In [4]:
class Discriminator(nn.Module):
    """Score 3x32x32 images with a probability of being real (Sigmoid output).

    Spatial size shrinks 32 -> 16 -> 8 -> 4 through three stride-2 convs; the final
    stride-1 block keeps 4x4, which matches the 256 * 4 * 4 input of the Linear head.
    """

    def __init__(self):
        super().__init__()

        def downsample_conv(c_in, c_out, stride):
            # 3x3 conv, batch norm, LeakyReLU(0.25), dropout.
            # NOTE(review): momentum uses PyTorch's convention (weight of the newest
            # batch), so 0.82 is unusually high -- confirm intent.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out, momentum=0.82),
                nn.LeakyReLU(0.25),
                nn.Dropout(0.25),
            ]

        # Stem: no batch norm, and a 0.2 negative slope (the blocks below use 0.25).
        layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
        ]
        for c_in, c_out, stride in ((32, 64, 2), (64, 128, 2), (128, 256, 1)):
            layers.extend(downsample_conv(c_in, c_out, stride))
        layers.extend([
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1),
            nn.Sigmoid(),  # probabilities, as required by the BCELoss used for training
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a (batch, 1) tensor of real-image probabilities for ``img``."""
        return self.model(img)


In [None]:
# CIFAR-10 training split. ToTensor gives [0, 1]; Normalize with mean=std=0.5 maps it
# to [-1, 1], the same range as the Generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)

In [6]:
# Shared training state consumed by the grid-search cell.
latent_dim = 100                 # length of the noise vector z fed to Generator
adversarial_loss = nn.BCELoss()  # expects probabilities, matching the Discriminator's Sigmoid
results = []                     # one metrics dict per (configuration, epoch)

In [None]:
grid = ParameterGrid(param_grid)

results = []


def _save_sample_grid(gen, plot_path, n_samples=16, nrow=4):
    """Sample ``n_samples`` images from ``gen`` and write them to ``plot_path`` as one PNG grid.

    Uses the notebook-level ``latent_dim`` and ``device``. Creates the parent directory
    if it does not exist. Returns ``plot_path`` so the caller can display the file.
    """
    # NOTE(review): ``gen`` stays in train mode, so its BatchNorm layers normalise with
    # (and update running stats from) this small sample batch -- same as before.
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim, device=device)
        generated = gen(z).detach().cpu()
    # normalize=True rescales pixel values into [0, 1], which imsave requires for floats.
    image_grid = torchvision.utils.make_grid(generated, nrow=nrow, normalize=True)
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C).
    grid_np = image_grid.numpy().transpose(1, 2, 0)
    os.makedirs(os.path.dirname(plot_path) or ".", exist_ok=True)
    plt.imsave(plot_path, grid_np)
    return plot_path


for params in tqdm(grid, desc="Grid Search Progress", position=0, total=len(grid), leave=True):

    # Fresh models and optimizers for every configuration.
    generator = Generator(latent_dim=latent_dim).to(device)
    discriminator = Discriminator().to(device)

    betas = (params['beta1'], params['beta2'])
    optimizer_G = optim.Adam(generator.parameters(), lr=params['lr_g'], betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=params['lr_d'], betas=betas)

    num_epochs = params['num_epochs']
    # Unique per configuration so sample images from different runs do not overwrite each other.
    run_tag = "_".join(f"{key}-{params[key]}" for key in sorted(params))

    for epoch in tqdm(range(num_epochs), total=num_epochs, desc="Epoch", position=1, leave=False):

        running_g_loss = 0.0
        running_d_loss = 0.0
        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}", position=2, leave=False) as pbar:
            for i, batch in pbar:
                real_images = batch[0].to(device)
                batch_size = real_images.size(0)

                # Discriminator targets: 1 for real images, 0 for fake ones.
                valid = torch.ones(batch_size, 1, device=device)
                fake = torch.zeros(batch_size, 1, device=device)

                # Train Discriminator
                # ---------------------
                optimizer_D.zero_grad()
                z = torch.randn(batch_size, latent_dim, device=device)
                fake_images = generator(z)

                real_loss = adversarial_loss(discriminator(real_images), valid)
                # detach(): the discriminator step must not push gradients into the generator.
                fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()

                # Train Generator
                # ---------------------
                # The generator is rewarded when the discriminator labels its output as real.
                optimizer_G.zero_grad()
                g_loss = adversarial_loss(discriminator(fake_images), valid)
                g_loss.backward()
                optimizer_G.step()

                running_g_loss += g_loss.item()
                running_d_loss += d_loss.item()

                pbar.set_postfix(g_loss=running_g_loss / (i + 1), d_loss=running_d_loss / (i + 1))

        # Guard against an empty dataloader (ZeroDivisionError otherwise).
        num_batches = max(len(dataloader), 1)
        # Record every hyperparameter so runs differing only in betas / num_epochs stay distinguishable.
        results.append({
            'lr_g': params['lr_g'],
            'lr_d': params['lr_d'],
            'beta1': params['beta1'],
            'beta2': params['beta2'],
            'num_epochs': num_epochs,
            'epoch': epoch + 1,
            'd_loss': running_d_loss / num_batches,
            'g_loss': running_g_loss / num_batches,
        })

        # Save and show a sample grid after the final epoch of this configuration.
        if epoch + 1 == num_epochs:
            plot_path = _save_sample_grid(
                generator, f"./plots/{run_tag}_generated_epoch_{epoch + 1}.png"
            )
            display(Image(filename=plot_path))

display(pd.DataFrame(results))


In [None]:
def plot_loss_heatmap(results_df, value_col, title):
    """Draw an annotated (lr_g x lr_d) heatmap of ``value_col`` and return the pivot table.

    Rows of ``results_df`` that share the same (lr_g, lr_d) pair -- e.g. different
    epochs or other hyperparameters -- are averaged, so each cell is a mean, not a
    single run's loss.
    """
    heatmap = results_df.pivot_table(index='lr_g', columns='lr_d', values=value_col, aggfunc='mean')

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(heatmap, annot=True, cmap='viridis', fmt='.4f', cbar=True, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Learning Rate for Discriminator')
    ax.set_ylabel('Learning Rate for Generator')
    plt.show()
    return heatmap


results_df = pd.DataFrame(results)

heatmap_data = plot_loss_heatmap(results_df, 'd_loss', 'Discriminator Loss for Different Learning Rates')
heatmap_data_gen = plot_loss_heatmap(results_df, 'g_loss', 'Generator Loss for Different Learning Rates')

In [None]:
# Persist the models from the *last* grid-search configuration only.
# Each model needs its own file: writing both to the same path keeps only the last one,
# which would silently replace the generator with the discriminator.
# torch.save on a whole nn.Module pickles it, so loading needs the class definitions above.
torch.save(generator, "./genV1")
torch.save(discriminator, "./discV1")