In [9]:
import torch as th 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as T
import pytorch_lightning as pl 
from pytorch_lightning import Trainer, LightningDataModule, LightningModule
import torchdyn 
from torchdyn.core import NeuralODE

import os 
import cv2 
import numpy as np 
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [10]:
class GenerativeData(Dataset):
    """Paired (small input, large target) image dataset read from one folder.

    Every readable image file directly inside ``path`` is loaded eagerly in
    ``__init__`` and preprocessed twice: resized to ``inpshape`` (model input)
    and to ``outshape`` (target). Both lists of tensors are kept in memory.

    Args:
        path: directory containing the image files.
        inpshape: (height, width) passed to ``T.Resize`` for the inputs.
        outshape: (height, width) passed to ``T.Resize`` for the targets.
    """

    def __init__(self, path, inpshape=(28, 36), outshape=(224, 288)):
        super(GenerativeData, self).__init__()
        self.path = path
        self.inpshape = inpshape
        self.outshape = outshape
        self.inpimg, self.outimg = self.datareader()

    def preprocess(self, image, imagesize):
        """Convert an HxWxC uint8 array into a resized, mean-shifted float tensor.

        ``ToTensor`` scales pixels to [0, 1]; ``Normalize`` then subtracts 0.5 per
        channel with std 1, so values end up in [-0.5, 0.5].
        """
        # NOTE(review): std=(1, 1, 1) yields [-0.5, 0.5]; GANs usually use std=0.5
        # for a [-1, 1] range — confirm this is intended.
        process = T.Compose([T.ToTensor(),
                             T.Resize(imagesize),
                             T.Normalize(mean=(0.5, 0.5, 0.5), std=(1, 1, 1))])
        return process(image)

    @staticmethod
    def _image_files(path):
        """Return full paths of the regular files directly in ``path``, sorted by name."""
        # os.listdir order is arbitrary; sorting keeps dataset indices reproducible.
        # os.path.join works whether or not ``path`` ends with a separator.
        candidates = (os.path.join(path, name) for name in sorted(os.listdir(path)))
        return [candidate for candidate in candidates if os.path.isfile(candidate)]

    def datareader(self):
        """Load and preprocess every readable image under ``self.path``.

        Returns:
            (inputs, targets): two equal-length lists of tensors, resized to
            ``self.inpshape`` and ``self.outshape`` respectively.
        """
        X, Y = [], []
        for file in self._image_files(self.path):
            image = cv2.imread(file)
            if image is None:
                # cv2.imread returns None (rather than raising) for files it cannot
                # decode, e.g. stray non-image files; skip them.
                continue
            # OpenCV decodes to BGR channel order; convert to the conventional RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            X.append(self.preprocess(image, self.inpshape))
            Y.append(self.preprocess(image, self.outshape))
        return X, Y

    def __len__(self):
        return len(self.inpimg)

    def __getitem__(self, idx):
        """Return ``{"inputs": small image tensor, "outputs": large image tensor}``."""
        inpimg = self.inpimg[idx]
        outimg = self.outimg[idx]
        return {"inputs": inpimg, "outputs": outimg}

In [11]:
# Image folder for the dataset. Set the SEAGAN_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific default path.
path = os.environ.get(
    "SEAGAN_DATA_DIR",
    r'C:\Users\suyash\Downloads\SeaGAN\seacreature_images_transformed\seacreature_images_transformed/',
)
# Guarantee a trailing separator (no-op if already present) so file names can be
# appended to the directory safely.
path = os.path.join(path, "")
dataset = GenerativeData(path)

In [None]:
class DiscriminatorBlock(nn.Module):
    """``num`` x (zero-pad -> Conv2d -> BatchNorm2d -> LeakyReLU), then 2x2 max-pooling.

    The first convolution maps ``infilter`` -> ``outfilter`` channels; the remaining
    ``num - 1`` keep ``outfilter``. Zero padding of ``(kernel - 1) // 2`` preserves
    H and W for odd kernels, so only the final pooling halves the spatial size.

    Args:
        num: number of conv/norm/activation stages (>= 1).
        infilter: input channel count (float values are truncated to int).
        outfilter: output channel count (float values are truncated to int).
        kernel: square convolution kernel size.
        moment: ``momentum`` passed to ``nn.BatchNorm2d``.
        alpha: negative slope of the LeakyReLU.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super(DiscriminatorBlock, self).__init__()
        # Conv2d/BatchNorm2d reject float channel counts (e.g. from a float growth factor).
        infilter, outfilter = int(infilter), int(outfilter)
        # NOTE(review): PyTorch's BatchNorm momentum weights the *new* batch, so 0.9 makes
        # running stats track almost only the latest batch (Keras 0.9 ~ PyTorch 0.1) — confirm.
        self.conv = nn.ModuleList([nn.Conv2d(infilter, outfilter, kernel)])
        self.norm = nn.ModuleList([nn.BatchNorm2d(outfilter, momentum=moment)])
        self.pad  = nn.ZeroPad2d(int((kernel-1)//2))
        self.act = nn.LeakyReLU(alpha)
        self.pooling = nn.MaxPool2d(2, 2)

        for _ in range(num-1):
            self.conv.append(nn.Conv2d(outfilter, outfilter, kernel))
            self.norm.append(nn.BatchNorm2d(outfilter, momentum=moment))

    def forward(self, x):
        for conv, norm in zip(self.conv, self.norm):
            x = self.act(norm(conv(self.pad(x))))
        return self.pooling(x)

class Discriminator(nn.Module):
    """Convolutional real/fake classifier returning one probability per image.

    A stem block (``start_kernel``) maps ``inchannels`` -> ``filter*gf`` channels, then
    one block per entry of ``num`` multiplies the channel count by ``gf``. Every block
    halves H and W. Global average pooling, a ReLU dense layer (with dropout), and a
    single-logit head follow.

    Args:
        num: conv stages per block after the stem.
        filter: base channel count (the stem outputs ``filter * gf``).
        start_kernel: kernel size of the stem block.
        kernel: kernel size of the remaining blocks.
        moment: BatchNorm momentum.
        alpha: LeakyReLU negative slope.
        dense: width of the hidden dense layer.
        gf: channel growth factor per block.
        drop: dropout probability applied to the dense features.
        inchannels: channels of the input images (3 for RGB).
    """

    def __init__(self, num:tuple=(2, 2, 2, 2), filter:int=64, start_kernel:int=7, kernel:int=3,
                 moment:float=0.9, alpha:float=0.1, dense:int=256, gf:float=2.0, drop:float=0.2,
                 inchannels:int=3):
        super(Discriminator, self).__init__()
        # The stem consumes the image itself, so its input width is ``inchannels``,
        # not ``filter``.
        self.convblock = nn.ModuleList([DiscriminatorBlock(1, inchannels, int(filter*gf), start_kernel, moment, alpha)])
        for n in num:
            filter = filter*gf
            # int(): a float ``gf`` makes the channel counts floats.
            self.convblock.append(DiscriminatorBlock(n, int(filter), int(filter*gf), kernel, moment, alpha))

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.drop = nn.Dropout(drop)
        self.dense = nn.Linear(int(filter*gf), dense)
        self.final = nn.Linear(dense, 1)

    def forward(self, x):
        """Map an image batch (B, inchannels, H, W) to probabilities of shape (B, 1)."""
        for convblock in self.convblock:
            x = convblock(x)
        x = self.flat(self.pool(x))
        x = self.drop(F.relu(self.dense(x)))
        # Single output unit: sigmoid gives P(real). Softmax over a size-1 dim is always 1.0.
        # NOTE(review): returning the raw logit with BCEWithLogitsLoss is numerically safer.
        return th.sigmoid(self.final(x))


In [None]:
class GeneratorBlock(nn.Module):
    """``num`` x (zero-pad -> Conv2d -> BatchNorm2d -> LeakyReLU), then 2x upsampling.

    The first convolution maps ``infilter`` -> ``outfilter`` channels; the remaining
    ``num - 1`` keep ``outfilter``. Padding of ``(kernel - 1) // 2`` preserves H and W
    for odd kernels. When ``end`` is True the final upsampling is skipped.

    Args:
        num: number of conv/norm/activation stages (>= 1).
        infilter: input channel count (float values are truncated to int).
        outfilter: output channel count (float values are truncated to int).
        kernel: square convolution kernel size.
        moment: ``momentum`` passed to ``nn.BatchNorm2d``.
        alpha: negative slope of the LeakyReLU.
        end: if True, return features without upsampling.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1, end=False):
        super(GeneratorBlock, self).__init__()
        # Conv2d/BatchNorm2d reject float channel counts (e.g. from a float growth factor).
        infilter, outfilter = int(infilter), int(outfilter)
        self.conv = nn.ModuleList([nn.Conv2d(infilter, outfilter, kernel)])
        self.norm = nn.ModuleList([nn.BatchNorm2d(outfilter, momentum=moment)])
        self.pad  = nn.ZeroPad2d(int((kernel-1)//2))
        self.act = nn.LeakyReLU(alpha)
        # nn.MaxUnpool2d cannot be used here: its forward requires the indices returned by
        # a matching MaxPool2d, which this block never has. Plain 2x upsampling instead.
        self.pooling = nn.Upsample(scale_factor=2, mode='nearest')
        self.end = end

        for _ in range(num-1):
            self.conv.append(nn.Conv2d(outfilter, outfilter, kernel))
            self.norm.append(nn.BatchNorm2d(outfilter, momentum=moment))

    def forward(self, x):
        for conv, norm in zip(self.conv, self.norm):
            x = self.act(norm(conv(self.pad(x))))
        if not self.end:
            return self.pooling(x)
        else:
            return x

class Generator(nn.Module):
    """Upsampling generator: (B, inchannels, H, W) -> (B, 3, H * 2**len(num), W * 2**len(num)).

    Each entry of ``num`` adds a block that doubles H and W. The i-th block outputs
    ``filter * gf**(i+2)`` channels, and a final 3x3 block projects back to 3 channels.

    Args:
        num: conv stages per upsampling block.
        filter: base channel count.
        start_kernel, dense, drop: accepted for signature parity with
            ``Discriminator`` but not used by this module.
        kernel: kernel size of the upsampling blocks.
        moment: BatchNorm momentum.
        alpha: LeakyReLU negative slope.
        gf: channel growth factor per block.
        inchannels: channels of the input images (3 for RGB).
    """

    def __init__(self, num:tuple=(3, 3, 3), filter:int=64, start_kernel:int=7, kernel:int=3,
                 moment:float=0.9, alpha:float=0.1, dense:int=256, gf:float=2.0, drop:float=0.2,
                 inchannels:int=3):
        super(Generator, self).__init__()
        self.convblock = nn.ModuleList([])
        # The first block consumes the image itself; later blocks take the previous output width.
        prev = inchannels
        for n in num:
            filter = filter*gf
            width = int(filter*gf)  # int(): a float ``gf`` makes the channel counts floats
            self.convblock.append(GeneratorBlock(n, prev, width, kernel, moment, alpha))
            prev = width

        # NOTE(review): the image head ends in BatchNorm + LeakyReLU, so its output is not
        # bounded to the target range — consider a plain conv (+ tanh) head.
        self.genfinal = GeneratorBlock(1, prev, 3, 3, moment, alpha, end=True)

    def forward(self, x):
        for convblock in self.convblock:
            x = convblock(x)
        return self.genfinal(x)

In [None]:
class GAN(nn.LightningModule):
    def __init__(self, datatset, batchsize):
        super(GAN, self).__init__()
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.batchsize = batchsize
        self.dataset = dataset