In [None]:
!pip install scikit-learn
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import random
import numpy as np
import matplotlib.pyplot as plt
# import wandb
import os
import pickle
from torch import nn
from tqdm.notebook import tqdm
import json
from PIL import Image
from torchvision import transforms, utils
from torchvision import models
import torch.optim as optim
from copy import deepcopy
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.express as px

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed every RNG in use and force deterministic cuDNN kernels
# (recipe adapted from the wandb docs).
SEED = 3407
torch.backends.cudnn.deterministic = True
# cuDNN benchmark mode may select different algorithms from run to run; PyTorch's
# reproducibility notes recommend disabling it together with `deterministic`.
torch.backends.cudnn.benchmark = False
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
class LabelEmbeds(nn.Module):
    """Turn a scalar class label into a non-negative embedding vector.

    Implemented as ``Linear(1, embed_len)`` followed by ``ReLU``. Other networks
    read ``embed_out`` to size the layers that consume this embedding.

    Args:
        embed_len: Length of each produced embedding vector.
    """

    def __init__(self, embed_len):
        super().__init__()
        # Public so downstream networks can size their projection layers.
        self.embed_out = embed_len
        # The attribute name and Sequential indices form the state_dict keys
        # ("embed_layer.0.weight", "embed_layer.0.bias"), so keep them stable.
        projection = nn.Linear(in_features=1, out_features=embed_len)
        self.embed_layer = nn.Sequential(projection, nn.ReLU())

    def forward(self, labels):
        """Embed ``labels`` (last dimension must be 1); returns ``(..., embed_len)``."""
        embedded = self.embed_layer(labels)
        return embedded

In [None]:
class GeneratorNetwork(nn.Module):
    """Label-conditioned generator that maps a latent vector to a 3 x 64 x 64 image.

    The ``len_input`` latent channels consist of Gaussian noise (``noise_len``
    channels) followed by a label condition (``len_input // 5`` channels)
    projected from ``embed_network``'s output. The first transposed convolution
    (kernel 4, stride 1, no padding) maps 1x1 -> 4x4; each of the remaining four
    (kernel 4, stride 2, padding 1) doubles height and width:
    4 -> 8 -> 16 -> 32 -> 64. The final ``Tanh`` bounds pixel values to [-1, 1].

    Args:
        embed_network: Label-embedding module; must expose ``embed_out``, the
            length of the vectors it returns.
        len_input: Total number of latent channels fed to the first layer.
        layer_channels: Base width; the hidden layers use 8x, 4x, 2x and 1x of it.
    """

    def __init__(self, embed_network, len_input, layer_channels):
        super().__init__()

        self.len_input = len_input
        # The label condition takes len_input // 5 channels and the noise fills the
        # remainder, so the two always add up to gen_layer_1's in_channels, even
        # when len_input is not a multiple of 5.
        self.cond_len = len_input // 5
        self.noise_len = len_input - self.cond_len

        # Submodule attribute names below form the state_dict keys; keep them stable
        # so saved checkpoints still load.
        self.embed_network = embed_network

        self.class_embed_layer = nn.Sequential(
            nn.Linear(in_features=self.embed_network.embed_out, out_features=self.cond_len),
            nn.ReLU()
        )

        # 1x1 -> 4x4
        self.gen_layer_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=len_input, out_channels=layer_channels*8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(layer_channels*8),
            nn.ReLU(inplace=True)
        )

        # 4x4 -> 8x8
        self.gen_layer_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=layer_channels*8, out_channels=layer_channels*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(layer_channels*4),
            nn.ReLU(inplace=True)
        )

        # 8x8 -> 16x16
        self.gen_layer_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=layer_channels*4, out_channels=layer_channels*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(layer_channels*2),
            nn.ReLU(inplace=True)
        )

        # 16x16 -> 32x32
        self.gen_layer_4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=layer_channels*2, out_channels=layer_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(layer_channels),
            nn.ReLU(inplace=True)
        )

        # 32x32 -> 64x64, 3 output channels squashed to [-1, 1]
        self.gen_layer_5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=layer_channels, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        self.gan_all_layers = nn.Sequential(
            self.gen_layer_1,
            self.gen_layer_2,
            self.gen_layer_3,
            self.gen_layer_4,
            self.gen_layer_5
        )

    def forward(self, class_labels, input_vec=None):
        """Generate one image per label.

        Args:
            class_labels: Labels accepted by ``embed_network``; dimension 0 is the batch.
            input_vec: Optional noise tensor of shape ``(batch, noise_len, 1, 1)``.
                When omitted, standard-normal noise is sampled and moved to
                ``class_labels``'s device.

        Returns:
            Tensor of shape ``(batch, 3, 64, 64)`` with values in [-1, 1].
        """
        batch_size = class_labels.shape[0]
        if input_vec is None:
            # Follow the labels' device instead of a notebook-global `device`.
            input_vec = torch.randn(batch_size, self.noise_len, 1, 1).to(class_labels.device)

        embeds = self.embed_network(class_labels)
        condition_input = self.class_embed_layer(embeds).reshape(batch_size, -1, 1, 1)

        # Noise channels first, label-condition channels last.
        gan_input = torch.cat([input_vec, condition_input], dim=1)
        return self.gan_all_layers(gan_input)


In [None]:
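# Tensors are laid out as (batch, channels, height, width); the discriminator
# appends the class label to each image as one extra channel before its convolutions.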
class DiscriminatorNetwork(nn.Module):
    """Conditional discriminator scoring (image, label) pairs with a probability.

    The label embedding is projected to ``len_input`` values and broadcast into a
    ``len_input x len_input`` map whose row ``i`` holds the ``i``-th projected
    value. That map is appended to the images as an extra channel, which is why
    the first convolution takes 4 input channels (3 image channels + 1 label
    channel). Four stride-2 convolutions each halve height and width, and a
    lazily-sized linear head plus sigmoid produces the score.

    Args:
        embed_network: Label-embedding module exposing ``embed_out``.
        len_input: Side length of the label map; images must have this height
            and width so they can be concatenated with it.
        layer_channels: Base width; the conv blocks use 1x, 2x, 4x and 8x of it.
    """

    def __init__(self, embed_network, len_input, layer_channels):
        super().__init__()
        self.len_input = len_input
        self.embed_network = embed_network

        # One value per label-map row.
        self.class_embed_layer = nn.Sequential(
            nn.Linear(in_features=embed_network.embed_out, out_features=len_input),
            nn.ReLU()
        )

        widths = [layer_channels * factor for factor in (1, 2, 4, 8)]
        # Construction order (and therefore RNG use and state_dict keys) matches
        # the layer numbering below.
        # (4, h, w) -> (lc, h/2, w/2); the first block has no BatchNorm.
        self.des_layer_1 = self._down_block(4, widths[0], batch_norm=False)
        # (lc, h/2, w/2) -> (2lc, h/4, w/4)
        self.des_layer_2 = self._down_block(widths[0], widths[1], batch_norm=True)
        # (2lc, h/4, w/4) -> (4lc, h/8, w/8)
        self.des_layer_3 = self._down_block(widths[1], widths[2], batch_norm=True)
        # (4lc, h/8, w/8) -> (8lc, h/16, w/16)
        self.des_layer_4 = self._down_block(widths[2], widths[3], batch_norm=True)

        # LazyLinear infers in_features from the flattened conv output on first use.
        self.classification_head = nn.Sequential(
            nn.LazyLinear(out_features=1),
            nn.Sigmoid()
        )

        self.des_conv_layers = nn.Sequential(
            self.des_layer_1, self.des_layer_2, self.des_layer_3, self.des_layer_4,
        )

    @staticmethod
    def _down_block(in_ch, out_ch, batch_norm):
        """Conv(k=4, s=2, p=1) [-> BatchNorm] -> LeakyReLU(0.2), halving height/width."""
        stages = [nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=4, stride=2, padding=1, bias=True)]
        if batch_norm:
            stages.append(nn.BatchNorm2d(num_features=out_ch))
        stages.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*stages)

    def forward(self, images, class_labels):
        """Return a ``(batch, 1)`` tensor of probabilities in (0, 1).

        ``images`` is assumed to be ``(batch, 3, len_input, len_input)``;
        ``class_labels`` is whatever ``embed_network`` accepts for the same batch.
        """
        batch = class_labels.shape[0]
        side = self.len_input
        # (batch, embed_out) -> (batch, side)
        row_values = self.class_embed_layer(self.embed_network(class_labels))
        # Repeat each row value across the width: (batch, 1, side, side).
        label_map = row_values.reshape(batch, side, 1).repeat(1, 1, side).reshape(batch, 1, side, side)
        stacked = torch.cat([images, label_map], dim=1)
        features = self.des_conv_layers(stacked).reshape(batch, -1)
        return self.classification_head(features)

In [None]:
# Hyper-parameters and runtime settings used by the cells below.
Config = {
    'batch_size':128,
    'device': device,
    'input_len': 100,       # latent size; make() sizes the label embedding as input_len // 2
    'dis_input_len':64,     # side length of the discriminator's square label map
    'learning_rate': 2e-4,
    'epochs': 120,
    'image_size': (64, 64),
    'last_epoch': 70        # NOTE(review): not read in the cells shown here; confirm its purpose
}

# The discriminator concatenates a dis_input_len x dis_input_len label map to each
# image along the channel axis, so images must be resized to exactly that size.
assert tuple(Config["image_size"]) == (Config["dis_input_len"], Config["dis_input_len"]), \
    "Config['image_size'] must equal (dis_input_len, dis_input_len)"

# PIL image -> float tensor in [0, 1] -> resized -> normalised to [-1, 1],
# the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(Config["image_size"]),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


In [None]:
def make(config):
    """Build the conditional GAN: a generator and a discriminator sharing one label embedder.

    Args:
        config: Mapping with "device", "input_len" and "dis_input_len". The optional
            keys "gen_layer_channels" (default 96) and "dis_layer_channels"
            (default 64) set each network's base convolution width.

    Returns:
        Tuple ``(gen_model, dis_model)``, both moved to ``config["device"]``.
    """
    target_device = config["device"]
    # A single LabelEmbeds instance is handed to both networks, so its weights are shared.
    embed_network = LabelEmbeds(config["input_len"] // 2).to(target_device)
    # Size the generator from config rather than a literal so it tracks "input_len".
    gen_model = GeneratorNetwork(
        embed_network,
        len_input=config["input_len"],
        layer_channels=config.get("gen_layer_channels", 96),
    ).to(target_device)
    dis_model = DiscriminatorNetwork(
        embed_network,
        len_input=config["dis_input_len"],
        layer_channels=config.get("dis_layer_channels", 64),
    ).to(target_device)

    return gen_model, dis_model

In [None]:
# Instantiate fresh networks from Config; their weights are restored from checkpoints below.
gen_model, dis_model = make(config=Config)

In [None]:
# Restore trained weights (presumably the final epoch of a 300-epoch run, per the path).
# NOTE(review): absolute Google Drive (Colab) path; adjust CHECKPOINT_DIR elsewhere.
CHECKPOINT_DIR = '/content/drive/MyDrive/DL-trained-models/DL-CSE641/300-Epochs/trained-models'
# map_location lets checkpoints saved on CUDA load when `device` fell back to "cpu".
# torch.load unpickles the file, so only load checkpoints from a trusted source.
gen_model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, '299_gen.pth'), map_location=device))
dis_model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, '299_dis.pth'), map_location=device))

In [None]:
def gen_image(gen_model, config, class_nums=None):
    """Generate one image per class label and show them in a 6-column grid.

    Args:
        gen_model: Conditional generator called as ``gen_model(labels)``; its output
            is treated as ``(n, 3, H, W)`` images in [-1, 1].
        config: Run configuration. Not read here; kept so existing calls still work.
        class_nums: Optional label tensor with one row per image (e.g. shape
            ``(n, 1)``). Defaults to the labels 1..102 on the model's device.

    Note:
        ``gen_model`` is used in whatever train/eval mode the caller left it in. If
        it contains BatchNorm layers (as GeneratorNetwork does), train mode
        normalises with batch statistics and updates the running statistics.
    """
    if class_nums is None:
        model_device = next(gen_model.parameters()).device
        class_nums = torch.arange(1, 103, dtype=torch.float32).reshape(-1, 1).to(model_device)

    n_images = class_nums.shape[0]
    n_cols = 6
    n_rows = max(1, (n_images + n_cols - 1) // n_cols)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        gen_images = gen_model(class_nums.float())
    # Map tanh output from [-1, 1] to integer RGB values in [0, 255].
    gen_images = ((gen_images + 1) * 127.5).long()

    # About 2.35 inches per row (a 20 x 40 figure for the default 17 x 6 grid).
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 40 * n_rows / 17), squeeze=False)
    axs = axs.flatten()

    for ax, image, label in zip(axs, gen_images, class_nums):
        # imshow expects channels-last (H, W, 3).
        ax.imshow(image.permute(1, 2, 0).cpu().numpy())
        # Title with the label that produced the image rather than its grid index.
        ax.set_title(f"{label.flatten()[0].item():g}")
    # Blank out grid cells left over when n_images is not a multiple of n_cols.
    for ax in axs[n_images:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()
