In [72]:
import torch as th 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl 
import lightning as L
import torchvision.transforms as T
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from torchgan.losses import  MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
from ignite.metrics import FID

import os 
import cv2 
import numpy as np 
import matplotlib.pyplot as plt

import warnings 
warnings.filterwarnings("ignore")

In [73]:
class GenerativeData(Dataset):
    """Paired low-/high-resolution image dataset loaded eagerly from one directory.

    Every readable image found in ``path`` is decoded once and stored twice:
    resized to ``inpshape`` (returned under ``"inputs"``) and resized to
    ``outshape`` (returned under ``"outputs"``).

    Args:
        path: Directory containing the image files. A trailing separator is
            optional.
        inpshape: ``(height, width)`` of the low-resolution copy.
        outshape: ``(height, width)`` of the high-resolution copy.
        datasize: Maximum number of images to load.
    """

    def __init__(self, path, inpshape=(28, 36), outshape=(224, 288), datasize:int=1000):
        super(GenerativeData, self).__init__()
        self.path = path
        self.inpshape = inpshape
        self.outshape = outshape
        self.datasize = datasize
        self.inpimg, self.outimg = self.datareader()

    def preprocess(self, image, imagesize):
        """Convert an HxWxC image array into a resized, centred CHW float tensor.

        Args:
            image: Image array with 8-bit pixel values (0-255), e.g. the result
                of ``cv2.imread``.
            imagesize: Target ``(height, width)`` passed to ``T.Resize``.

        Returns:
            Float tensor with values in roughly ``[-0.5, 0.5]``.
        """
        # ToTensor only rescales uint8/PIL input to [0, 1]; a float array passes
        # through unscaled. Scale explicitly so Normalize(mean=0.5) actually
        # centres the data instead of shifting 0..255 values by 0.5.
        scaled = np.asarray(image, dtype=np.float32) / 255.0
        process = T.Compose([T.ToTensor(),
                             T.Resize(imagesize),
                             T.Normalize(mean=(0.5, 0.5, 0.5), std=(1, 1, 1))])
        return process(scaled)

    def datareader(self):
        """Read at most ``datasize`` images and return ``(inputs, outputs)`` lists."""
        X = []
        Y = []
        # Sorted so the selected subset does not depend on directory listing order.
        for file in sorted(os.listdir(self.path)):
            if len(X) >= self.datasize:
                break
            image = cv2.imread(os.path.join(self.path, file))
            if image is None:
                # cv2.imread signals "not a readable image" with None; skip it.
                continue
            X.append(self.preprocess(image, self.inpshape))
            Y.append(self.preprocess(image, self.outshape))

        return X, Y

    def __len__(self):
        return len(self.inpimg)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``{"inputs": ..., "outputs": ...}``."""
        # The stored items are already tensors; as_tensor avoids re-copying them.
        inpimg = th.as_tensor(self.inpimg[idx], dtype=th.float32)
        outimg = th.as_tensor(self.outimg[idx], dtype=th.float32)
        return {"inputs": inpimg, "outputs": outimg}

In [74]:
class DiscriminatorBlock(nn.Module):
    """Stack of zero-padded conv -> batch-norm -> LeakyReLU layers, then 2x2 max-pool.

    Args:
        num: Number of convolution layers in the block (at least one is built).
        infilter: Channels entering the first convolution.
        outfilter: Channels produced by every convolution.
        kernel: Square kernel size shared by all convolutions.
        moment: ``momentum`` passed to each ``BatchNorm2d``.
        alpha: Negative slope of the ``LeakyReLU``.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super().__init__()
        # Only the first convolution sees `infilter`; the rest map outfilter -> outfilter.
        in_channels = [infilter] + [outfilter] * (num - 1)
        self.conv = nn.ModuleList([nn.Conv2d(channels, outfilter, kernel) for channels in in_channels])
        self.norm = nn.ModuleList([nn.BatchNorm2d(outfilter, momentum=moment) for _ in in_channels])
        # Explicit padding of (kernel-1)//2 applied before every convolution.
        self.pad = nn.ZeroPad2d(int((kernel - 1) // 2))
        self.act = nn.LeakyReLU(alpha)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Run every padded conv/norm/activation layer, then halve the spatial size."""
        for layer in range(len(self.conv)):
            padded = self.pad(x)
            features = self.conv[layer](padded)
            x = self.act(self.norm[layer](features))
        return self.pool(x)

class Discriminator(nn.Module):
    """Convolutional real/fake classifier that outputs one probability per image.

    A stem ``DiscriminatorBlock`` (3 -> ``filter`` channels, ``start_kernel``) is
    followed by one block per entry of ``num``; each of those multiplies the
    channel count by ``gf``. Features are globally average-pooled, passed
    through a ReLU dense layer with dropout, and mapped to a single sigmoid
    probability.

    Args:
        num: Number of conv layers in each block after the stem.
        filter: Channel count produced by the stem block.
        start_kernel: Kernel size of the stem block.
        kernel: Kernel size of the remaining blocks.
        moment: BatchNorm momentum forwarded to every block.
        alpha: LeakyReLU slope forwarded to every block.
        dense: Width of the hidden dense layer.
        gf: Channel growth factor applied per block.
        drop: Dropout probability applied after the hidden dense layer.
    """

    def __init__(self, num:list=[1, 1, 1, 1], filter:int=32, start_kernel:int=5, kernel:int=3,
                 moment:float=0.9, alpha:float=0.1, dense:int=64, gf:float=2.0, drop:float=0.2):
        super(Discriminator, self).__init__()
        self.convblock = nn.ModuleList(modules=[DiscriminatorBlock(1, 3, filter, start_kernel, moment, alpha)])
        for n in num:
            self.convblock.append(DiscriminatorBlock(n, filter, int(filter*gf), kernel, moment, alpha))
            filter = int(filter*gf)

        # Global average pooling makes the head independent of the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.drop = nn.Dropout(drop)
        self.dense = nn.Linear(filter, dense)
        self.final = nn.Linear(dense, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of probabilities in ``(0, 1)``."""
        for convblock in self.convblock:
            x = convblock(x)
        x = self.flat(self.pool(x))
        # The dropout layer was constructed but never applied; use it on the hidden features.
        x = self.drop(F.relu(self.dense(x)))
        # Softmax over a single logit is always exactly 1.0, which makes the
        # output constant; a single-unit binary classifier needs a sigmoid.
        return th.sigmoid(self.final(x))

In [75]:
class GeneratorBlock(nn.Module):
    """Upsampling block: transposed conv -> batch-norm -> LeakyReLU, repeated.

    The first transposed convolution uses ``stride=2`` and changes the channel
    count from ``infilter`` to ``outfilter``; any further layers keep the
    default stride and the channel count. All layers use padding
    ``(kernel-1)//2``.

    Args:
        num: Number of transposed-conv layers (at least one is built).
        infilter: Channels entering the block.
        outfilter: Channels leaving the block.
        kernel: Square kernel size shared by all layers.
        moment: ``momentum`` passed to each ``BatchNorm2d``.
        alpha: Negative slope of the ``LeakyReLU``.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super().__init__()
        pad = int((kernel - 1) // 2)
        upsample = nn.ConvTranspose2d(infilter, outfilter, kernel, stride=2, padding=pad)
        refine = [nn.ConvTranspose2d(outfilter, outfilter, kernel, padding=pad) for _ in range(num - 1)]
        layers = [upsample] + refine
        self.conv = nn.ModuleList(layers)
        # One normalisation layer per transposed convolution.
        self.norm = nn.ModuleList([nn.BatchNorm2d(outfilter, momentum=moment) for _ in layers])
        self.act = nn.LeakyReLU(alpha)

    def forward(self, x):
        """Apply every conv/norm/activation layer in order and return the result."""
        for layer in range(len(self.conv)):
            features = self.conv[layer](x)
            x = self.act(self.norm[layer](features))
        return x

class Generator(nn.Module):
    """Chain of ``GeneratorBlock`` stages mapping a 3-channel image to a 3-channel image.

    The stages are: a stem block (3 -> ``filter`` channels, ``start_kernel``),
    one block per entry of ``num`` (each multiplying the channel count by
    ``gf``), and a final single-layer block that projects back to 3 channels.

    Args:
        num: Number of layers in each intermediate block.
        filter: Channel count produced by the stem block.
        start_kernel: Kernel size of the stem block.
        kernel: Kernel size of all later blocks.
        moment: BatchNorm momentum forwarded to every block.
        alpha: LeakyReLU slope forwarded to every block.
        gf: Channel growth factor applied per intermediate block.
    """

    def __init__(self, num:list=[2, 2, 2], filter:int=32, start_kernel:int=5, kernel:int=3,
                 moment:float=0.9, alpha:float=0.1, gf:float=2.0):
        super().__init__()
        stages = [GeneratorBlock(1, 3, filter, start_kernel, moment, alpha)]
        channels = filter
        for depth in num:
            widened = int(channels * gf)
            stages.append(GeneratorBlock(depth, channels, widened, kernel, moment, alpha))
            channels = widened

        # Final stage projects the widest feature map back to 3 image channels.
        stages.append(GeneratorBlock(1, channels, 3, kernel, moment, alpha))
        self.convblock = nn.ModuleList(modules=stages)

    def forward(self, x):
        """Pass ``x`` through every stage in order."""
        out = x
        for stage in self.convblock:
            out = stage(out)
        return out

In [76]:
class GAN(L.LightningModule):
    """Lightning module that trains a generator/discriminator pair with manual optimization.

    Each batch is a dict with ``"inputs"`` (fed to the generator) and
    ``"outputs"`` (shown to the discriminator as real samples).

    Args:
        dataset: Dataset yielding ``{"inputs": ..., "outputs": ...}`` items.
        generator: Network mapping ``inputs`` to generated images.
        discriminator: Network mapping an image batch to ``(batch, 1)``
            probabilities, as required by ``binary_cross_entropy``.
        batch: Training batch size.
        valid: Validation batch size (stored; no validation loader is active).
        split: Fraction of ``dataset`` assigned to the training split.
        lr: Learning rate for both Adam optimizers.
    """

    def __init__(self, dataset:Dataset, generator:nn.Module, discriminator:nn.Module, batch:int=10, valid:int=10, split:float=0.7, lr:float=1e-4):
        super().__init__()
        # Two optimizers are stepped by hand inside training_step.
        self.automatic_optimization = False
        self.discriminator = discriminator
        self.generator = generator
        self.batch = batch
        self.valid = valid
        self.lr = lr
        ntrain = int(len(dataset)*split)
        self.traindata, self.validdata = random_split(dataset, [ntrain, len(dataset)-ntrain])

    def forward(self, x):
        """Generate images from ``x`` with the generator."""
        return self.generator(x)

    def loss(self, ypred, ytrue):
        """Binary cross-entropy between predicted probabilities and 0/1 targets."""
        return F.binary_cross_entropy(ypred, ytrue)

    def training_step(self, batch, batch_idx=None):
        """Run one generator update followed by one discriminator update."""
        gen = batch["inputs"]
        dis = batch["outputs"]

        # Unpacking order must match the list returned by configure_optimizers.
        goptim, doptim = self.optimizers()

        # Targets are created on the same device/dtype as the batch.
        valid = th.ones(gen.size(0), 1).type_as(gen)
        fake = th.zeros(gen.size(0), 1).type_as(gen)

        # Generator update: push the discriminator towards "real" on generated images.
        self.toggle_optimizer(goptim)
        generated = self(gen)
        gloss = self.loss(self.discriminator(generated), valid)
        self.log("g_loss", gloss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
        self.manual_backward(gloss)
        goptim.step()
        goptim.zero_grad()
        self.untoggle_optimizer(goptim)

        # Keep the latest samples without their autograd graph; they are reused
        # below so the generator is run only once per step.
        self.genimage = generated.detach()

        # Discriminator update: real samples -> 1, generated samples -> 0.
        self.toggle_optimizer(doptim)
        real_loss = self.loss(self.discriminator(dis), valid)
        fake_loss = self.loss(self.discriminator(self.genimage), fake)

        dloss = (fake_loss+real_loss)/2
        self.log("d_loss", dloss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
        self.manual_backward(dloss)
        doptim.step()
        doptim.zero_grad()
        self.untoggle_optimizer(doptim)

    def configure_optimizers(self):
        """Return ``[generator_optimizer, discriminator_optimizer]`` and no schedulers."""
        goptim = optim.Adam(self.generator.parameters(), lr=self.lr)
        doptim = optim.Adam(self.discriminator.parameters(), lr=self.lr)
        # Generator first: training_step unpacks `goptim, doptim`. The previous
        # [doptim, goptim] order made each loss step the other network's optimizer.
        return [goptim, doptim], []

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.traindata, batch_size=self.batch, shuffle=True)

    # def val_dataloader(self):
    #     return DataLoader(self.validdata, batch_size=self.valid, shuffle=True)

In [78]:
# Dataset location: set the SEAGAN_DATA_DIR environment variable to override the
# original machine-specific default.
DATA_DIR = os.environ.get(
    "SEAGAN_DATA_DIR",
    r'C:\Users\suyash\Downloads\SeaGAN\seacreature_images_transformed\seacreature_images_transformed/',
)
# Guarantee a trailing separator so the path also works when file names are
# appended to it by plain string concatenation.
path = os.path.join(DATA_DIR, "")

# Fix the seed so weight init, the train/valid split and shuffling are repeatable.
SEED = 42
th.manual_seed(SEED)

dataset = GenerativeData(path, datasize=100)
dis = Discriminator()
gen = Generator()

# NOTE(review): progress_bar is never passed to the Trainer below. It is built
# from pytorch_lightning while the Trainer comes from lightning - confirm the
# two interoperate before wiring it in via callbacks=[progress_bar].
progress_bar = RichProgressBar(theme=RichProgressBarTheme(description="blue",progress_bar="green_yellow",progress_bar_finished="green1",
        progress_bar_pulse="#6206E0",batch_progress="blue",  time="black",processing_speed="black",metrics="black", ),)
learn = GAN(dataset, gen, dis, batch=2, valid=2)
trainer = L.Trainer(min_epochs=30, max_epochs=30)
trainer.fit(learn)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type          | Params
------------------------------------------------
0 | discriminator | Discriminator | 1.6 M 
1 | generator     | Generator     | 1.2 M 
------------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.113    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]