In [None]:
import torch as th 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import  random_split, Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import pytorch_lightning as pl 
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
import torchdyn 
from torchdyn.core import NeuralODE
from torchgan.losses import  MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
# from ignite.metrics import FID

import os 
import cv2 
import numpy as np 
import matplotlib.pyplot as plt

import warnings
# Suppresses every warning for the rest of the kernel session.
# NOTE(review): this also hides legitimate UserWarnings raised by the libraries used
# below; consider a narrower filter (category=/module=) instead of a blanket "ignore".
warnings.filterwarnings("ignore")

In [None]:
class GenerativeData(Dataset):
    """Paired low-/high-resolution image dataset read from a flat directory.

    Every readable image in ``path`` is decoded once with OpenCV (BGR channel order)
    and stored twice: resized to ``inpshape`` (generator input) and to ``outshape``
    (real sample for the discriminator). All tensors are kept in memory.

    Args:
        path: directory containing the images; joined with each file name via
            ``os.path.join``, so a trailing separator is optional.
        inpshape: ``(height, width)`` passed to ``T.Resize`` for the input image.
        outshape: ``(height, width)`` passed to ``T.Resize`` for the target image.
        datasize: maximum number of images to load.
        device: device every sample is moved to in ``__getitem__``. ``None`` picks
            ``"mps"`` when available and falls back to ``"cpu"`` otherwise.
    """

    def __init__(self, path, inpshape=(14, 18), outshape=(56, 72), datasize:int=500, device=None):
        super(GenerativeData, self).__init__()
        self.path = path
        self.inpshape = inpshape
        self.outshape = outshape
        self.datasize = datasize
        self.inpimg, self.outimg = self.datareader()
        if device is None:
            device = "mps" if th.backends.mps.is_available() else "cpu"
        self.device = th.device(device)

    def preprocess(self, image, imagesize):
        """Convert an ``H x W x C`` image array into a resized, centred ``C x H' x W'`` tensor.

        ``T.ToTensor`` only rescales ``uint8`` arrays to ``[0, 1]``; float arrays pass
        through unscaled. Dividing by 255 here keeps pixels in ``[0, 1]`` so that
        ``Normalize(mean=0.5, std=1)`` centres them around zero (roughly ``[-0.5, 0.5]``).
        """
        process = T.Compose([T.ToTensor(),
                             T.Resize(imagesize),
                             T.Normalize(mean=(0.5, 0.5, 0.5), std=(1, 1, 1))])
        return process(np.asarray(image, dtype=np.float32) / 255.0)

    def datareader(self):
        """Load at most ``datasize`` images from ``path``.

        Files are visited in sorted order so the selected subset is deterministic.
        Entries OpenCV cannot decode (``cv2.imread`` returns ``None``, e.g. hidden
        system files or sub-directories) are skipped instead of crashing.

        Returns:
            ``(X, Y)``: equal-length lists of input- and target-resolution tensors.
        """
        X = []
        Y = []
        for file in sorted(os.listdir(self.path)):
            if len(X) >= self.datasize:
                break
            image = cv2.imread(os.path.join(self.path, file))
            if image is None:
                continue
            X.append(self.preprocess(image, self.inpshape))
            Y.append(self.preprocess(image, self.outshape))
        return X, Y

    def __len__(self):
        return len(self.inpimg)

    def __getitem__(self, idx):
        """Return ``{"inputs": ..., "outputs": ...}`` as float32 tensors on ``self.device``."""
        # The stored items are already tensors; as_tensor converts dtype/device without
        # the redundant copy (and warning) that th.tensor(tensor) produces.
        inpimg = th.as_tensor(self.inpimg[idx], dtype=th.float32, device=self.device)
        outimg = th.as_tensor(self.outimg[idx], dtype=th.float32, device=self.device)
        return {"inputs": inpimg, "outputs": outimg}

In [None]:
class DiscriminatorBlock(nn.Module):
    """``num`` x (zero-pad -> conv -> batch-norm -> LeakyReLU), then 2x2 max-pooling.

    The first convolution maps ``infilter`` -> ``outfilter`` channels; the remaining
    ``num - 1`` keep ``outfilter`` channels. Padding of ``(kernel - 1) // 2`` keeps the
    spatial size for odd ``kernel``, so only the final pool halves height and width.

    NOTE(review): PyTorch's BatchNorm ``momentum`` is the weight of the *new* batch
    statistic (PyTorch default 0.1); 0.9 matches the Keras convention - confirm intent.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super(DiscriminatorBlock, self).__init__()
        # Input channel count of each conv layer: first from infilter, then outfilter.
        layer_inputs = [infilter] + [outfilter] * (num - 1)
        self.conv = nn.ModuleList(modules=[nn.Conv2d(channels, outfilter, kernel) for channels in layer_inputs])
        self.norm = nn.ModuleList(modules=[nn.BatchNorm2d(outfilter, momentum=moment) for _ in layer_inputs])
        pad_width = int((kernel - 1) // 2)
        self.pad = nn.ZeroPad2d(pad_width)
        self.act = nn.LeakyReLU(alpha)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        for layer in range(len(self.conv)):
            x = self.pad(x)
            x = self.conv[layer](x)
            x = self.norm[layer](x)
            x = self.act(x)
        return self.pool(x)

class Discriminator(nn.Module):
    """Convolutional real/fake scorer returning one unnormalized logit per image.

    Layout: a stem ``DiscriminatorBlock`` (3 -> ``filter`` channels, ``start_kernel``),
    one ``DiscriminatorBlock`` per entry of ``num`` (the entry is that block's conv-layer
    count; channels are multiplied by ``gf``), global average pooling, a ReLU hidden
    layer of width ``dense`` followed by dropout ``drop``, and a 1-unit output layer.
    The adaptive pooling makes the head independent of the input's spatial size.
    """

    def __init__(self, num:list=[1, 1 ], filter:int=16, start_kernel:int=7, kernel:int=3, 
                 moment:float=0.9, alpha:float=0.1, dense:int=64, gf:float=2.0, drop:float=0.2):
        super(Discriminator, self).__init__()
        self.convblock = nn.ModuleList(modules=[DiscriminatorBlock(1, 3, filter, start_kernel, moment, alpha)])
        for n in num:
            self.convblock.append(DiscriminatorBlock(n, filter, int(filter*gf), kernel, moment, alpha))
            filter = int(filter*gf)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.drop = nn.Dropout(drop)
        self.dense = nn.Linear(filter, dense)
        self.final = nn.Linear(dense, 1)

    def forward(self, x):
        """Return logits of shape ``(batch, 1)`` for a 3-channel image batch ``x``."""
        for convblock in self.convblock:
            x = convblock(x)
        x = self.flat(self.pool(x))
        # Apply the configured dropout to the hidden features (``drop`` was otherwise unused).
        x = self.drop(F.relu(self.dense(x)))
        # Return the raw logit: softmax over a dimension of size 1 is identically 1.0
        # (no signal, no gradient). torchgan's Minimax losses apply
        # binary_cross_entropy_with_logits themselves.
        return self.final(x)

In [None]:
class GeneratorBlock(nn.Module):
    """Upsampling stack: one stride-2 transposed conv, then ``num - 1`` stride-1 ones.

    Every transposed conv is followed by batch-norm and LeakyReLU. With
    ``padding = (kernel - 1) // 2`` and ``output_padding = kernel % 2`` the first layer
    maps ``H x W`` to exactly ``2H x 2W`` for any kernel size (without the output
    padding an odd kernel yields ``2H - 1``, e.g. 14 -> 27). The stride-1 layers keep
    the spatial size for odd ``kernel``.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super(GeneratorBlock, self).__init__()
        pad  = int((kernel-1)//2)
        upsample = nn.ConvTranspose2d(infilter, outfilter, kernel, stride=2, padding=pad,
                                      output_padding=int(kernel % 2))
        self.conv = nn.ModuleList(modules=[upsample])
        self.norm = nn.ModuleList(modules=[nn.BatchNorm2d(outfilter, momentum=moment)])
        self.act = nn.LeakyReLU(alpha)
        
        for _ in range(num-1):
            self.conv.append(nn.ConvTranspose2d(outfilter, outfilter, kernel, padding=pad))
            self.norm.append(nn.BatchNorm2d(outfilter, momentum=moment))

    def forward(self, x):
        for conv, norm in zip(self.conv, self.norm):
            x = self.act(norm(conv(x)))
        return x

class Generator(nn.Module):
    """Chain of ``GeneratorBlock``s mapping a 3-channel image to a larger 3-channel image.

    Blocks, in order: a stem (3 -> ``filter`` channels, ``start_kernel``), one block per
    entry of ``num`` (the entry is that block's layer count; channels multiplied by
    ``gf``), and a final single-layer block back to 3 channels. Each block starts with a
    stride-2 transposed convolution, so there are ``len(num) + 2`` upsampling stages.

    NOTE(review): with the defaults that is 4 stages (about 16x), while
    ``GenerativeData`` defaults pair 14x18 inputs with 56x72 targets (4x) - confirm the
    intended depth. The last block also ends in BatchNorm + LeakyReLU, so outputs are
    not squashed into the target's value range.
    """

    def __init__(self, num:list=[1, 1], filter:int=16, start_kernel:int=7, kernel:int=3, 
                 moment:float=0.9, alpha:float=0.1, gf:float=2.0):
        super(Generator, self).__init__()
        stages = [GeneratorBlock(1, 3, filter, start_kernel, moment, alpha)]
        channels = filter
        for depth in num:
            widened = int(channels * gf)
            stages.append(GeneratorBlock(depth, channels, widened, kernel, moment, alpha))
            channels = widened
        stages.append(GeneratorBlock(1, channels, 3, kernel, moment, alpha))
        self.convblock = nn.ModuleList(modules=stages)

    def forward(self, x):
        out = x
        for stage in self.convblock:
            out = stage(out)
        return out

In [None]:
class GAN(nn.Module):
    """Container pairing a generator with a discriminator.

    Args:
        discriminator: module that scores images; ``None`` builds ``Discriminator()``.
        generator: module that maps generator inputs to images; ``None`` builds
            ``Generator()``. Passing prebuilt modules allows non-default hyperparameters.
    """

    def __init__(self, discriminator=None, generator=None):
        super(GAN, self).__init__()
        # Same construction order as before (discriminator first) so seeded weight
        # initialisation is unchanged when both defaults are used.
        self.discriminator = Discriminator() if discriminator is None else discriminator
        self.generator = Generator() if generator is None else generator

    def forward(self, gen, dis):
        """Score generated and real images.

        Args:
            gen: generator input batch.
            dis: batch of real images.

        Returns:
            ``(fake, true)``: discriminator outputs for ``generator(gen)`` and for ``dis``.
        """
        fake = self.discriminator(self.generator(gen))
        true = self.discriminator(dis)
        return fake, true
    

In [None]:
def datasplit(dataset, trainbatch, validbatch, split, seed=None):
    """Randomly split ``dataset`` into a shuffled training loader and an ordered validation loader.

    Args:
        dataset: any sized, indexable dataset.
        trainbatch: batch size of the training loader (reshuffled every epoch).
        validbatch: batch size of the validation loader (fixed order).
        split: fraction in ``[0, 1]`` of samples assigned to training; the count is
            truncated with ``int``.
        seed: optional seed for the split permutation; ``None`` uses torch's global RNG.

    Returns:
        ``(train_loader, valid_loader)``.

    Raises:
        ValueError: if ``split`` is outside ``[0, 1]``.
    """
    if not 0 <= split <= 1:
        raise ValueError(f"split must be in [0, 1], got {split}")
    total = len(dataset)
    ntrain = int(total * split)
    split_kwargs = {} if seed is None else {"generator": th.Generator().manual_seed(seed)}
    trainset, validset = random_split(dataset, [ntrain, total - ntrain], **split_kwargs)
    traindata = DataLoader(trainset, batch_size=trainbatch, shuffle=True)
    validdata = DataLoader(validset, batch_size=validbatch)
    return traindata, validdata

In [None]:
# Prefer Apple-silicon MPS when present; fall back to CPU so the notebook also runs elsewhere.
device = th.device("mps" if th.backends.mps.is_available() else "cpu")

In [None]:
# Image directory: set SEACREATURE_DATA_DIR to override instead of editing the notebook.
# os.path.join(..., "") guarantees the trailing separator the dataset loader relies on.
DATA_DIR = os.environ.get(
    "SEACREATURE_DATA_DIR",
    r'/Users/suyashsachdeva/Desktop/GyanBhandar/sea/seacreature_images_transformed/seacreature_images_transformed/',
)
path = os.path.join(DATA_DIR, "")

SEED = 0
th.manual_seed(SEED)  # reproducible train/valid split and weight initialisation

traindata, validdata = datasplit(GenerativeData(path), 10, 10 , 0.9)
gan = GAN()
gan = gan.to(device)

In [None]:
# Training configuration.
epochs = 3
learning_rate = 1e-4  # base learning rate used in epoch 0
decay = 1e-2          # inverse-time decay: lr(epoch) = learning_rate / (1 + epoch * decay)

# Optimizers and losses are built once: recreating Adam every epoch would discard its
# running moment estimates. Only the learning rate is adjusted per epoch below.
goptim = optim.Adam(gan.generator.parameters(), lr=learning_rate)
doptim = optim.Adam(gan.discriminator.parameters(), lr=learning_rate)
generator_loss = MinimaxGeneratorLoss()
discriminator_loss = MinimaxDiscriminatorLoss()

for epoch in range(epochs):
    print(f"Epoch: {epoch+1}/{epochs}")
    # Decay from the *base* rate so the schedule does not compound across epochs.
    epoch_lr = learning_rate / (1 + epoch * decay)
    for optimizer in (goptim, doptim):
        for group in optimizer.param_groups:
            group["lr"] = epoch_lr

    # ---- training ----
    gan.train()
    glss = 0.0
    dlss = 0.0
    for batch in traindata:
        gen = batch["inputs"]
        dis = batch["outputs"]
        generated = gan.generator(gen)

        # Discriminator step: the generated batch is detached so this loss only
        # produces gradients for the discriminator.
        doptim.zero_grad()
        dloss = discriminator_loss(gan.discriminator(dis), gan.discriminator(generated.detach()))
        dloss.backward()
        doptim.step()

        # Generator step: gradients also reach the discriminator's parameters here, but
        # only goptim steps, and doptim.zero_grad() clears them before the next D update.
        goptim.zero_grad()
        gloss = generator_loss(gan.discriminator(generated))
        gloss.backward()
        goptim.step()

        # .item() accumulates plain floats so no autograd graph is kept alive.
        glss += gloss.item()
        dlss += dloss.item()

    # ---- validation ----
    gan.eval()
    gvls = 0.0
    dvls = 0.0
    with th.no_grad():
        for valid in validdata:
            fake, true = gan(valid["inputs"], valid["outputs"])
            gvls += generator_loss(fake).item()
            dvls += discriminator_loss(true, fake).item()

    ntrain = max(len(traindata), 1)  # guard against an empty loader
    nvalid = max(len(validdata), 1)
    print(f"\tTrain Generator Loss: {glss/ntrain} || Train Discriminator Loss: {dlss/ntrain} || Valid Generator Loss: {gvls/nvalid} || Valid Discriminator Loss: {dvls/nvalid}")


In [None]:
# Peek at the first input image of one training batch. A DataLoader is iterable but not
# subscriptable, so indexing it directly (traindata["inputs"]) raises a TypeError.
next(iter(traindata))["inputs"][0]