In [1]:
import torch as th 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import  random_split, Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import pytorch_lightning as pl 
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
import torchdyn 
from torchdyn.core import NeuralODE
from torchgan.losses import  MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
from ignite.metrics import FID

import os 
import cv2 
import numpy as np 
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [2]:
class GenerativeData(Dataset):
    """Dataset of paired resized copies of every readable image in ``path``.

    Each image is read once and preprocessed twice: resized to ``inpshape`` for
    the ``"inputs"`` entry and to ``outshape`` for the ``"outputs"`` entry.
    All images are loaded eagerly in ``__init__`` and held in memory.

    Args:
        path: directory containing the image files.
        inpshape: (height, width) of the ``"inputs"`` tensors.
        outshape: (height, width) of the ``"outputs"`` tensors.
    """
    def __init__(self, path, inpshape=(28, 36), outshape=(224, 288)):
        super(GenerativeData, self).__init__()
        self.path = path
        self.inpshape = inpshape
        self.outshape = outshape
        self.inpimg, self.outimg = self.datareader()

    def preprocess(self, image, imagesize):
        """Turn an (H, W, C) image array into a normalized (C, *imagesize) float tensor.

        ``T.ToTensor`` rescales to [0, 1] only when it is given ``uint8`` data and
        leaves float arrays untouched, so 8-bit images are handed to it as-is
        instead of being cast to float first. With ``Normalize(mean=0.5, std=1)``
        an 8-bit image therefore ends up in [-0.5, 0.5].
        """
        array = np.asarray(image)
        if array.dtype != np.uint8:
            array = array.astype(np.float32)
        process = T.Compose([T.ToTensor(),
                             T.Resize(imagesize),
                             T.Normalize(mean=(0.5, 0.5, 0.5), std=(1, 1, 1))])
        return process(array)

    def datareader(self):
        """Load every readable image in ``self.path``.

        Returns:
            (inputs, outputs): two lists of tensors, resized to ``inpshape`` and
            ``outshape`` respectively, in sorted file-name order.
        """
        X = []
        Y = []
        # Sorted so the sample order (and hence random_split results under a seed)
        # does not depend on the filesystem's listing order.
        for file in sorted(os.listdir(self.path)):
            # os.path.join works whether or not ``path`` ends with a separator.
            image = cv2.imread(os.path.join(self.path, file))
            if image is None:
                # cv2.imread returns None for directories and non-image files.
                continue
            # cv2 loads channels as BGR; convert so tensors are in RGB order.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            X.append(self.preprocess(image, self.inpshape))
            Y.append(self.preprocess(image, self.outshape))

        return X, Y

    def __len__(self):
        return len(self.inpimg)

    def __getitem__(self, idx):
        """Return ``{"inputs": small_tensor, "outputs": large_tensor}`` for sample ``idx``."""
        # Stored items are already tensors; as_tensor avoids the copy (and the
        # warning) that th.tensor(<tensor>) produces while still forcing float32.
        inpimg = th.as_tensor(self.inpimg[idx], dtype=th.float32)
        outimg = th.as_tensor(self.outimg[idx], dtype=th.float32)
        return {"inputs": inpimg, "outputs": outimg}

In [3]:
class DiscriminatorBlock(nn.Module):
    """Downsampling stage: ``num`` x (zero-pad -> Conv2d -> BatchNorm2d -> LeakyReLU), then 2x2 max-pool.

    The first conv maps ``infilter`` to ``outfilter`` channels and the remaining
    convs keep ``outfilter``. Zero padding of ``(kernel - 1) // 2`` on every side
    keeps H and W unchanged for odd kernels, so only the final pool halves them.

    Args:
        num: number of conv layers (values below 1 still build a single layer).
        infilter: input channel count.
        outfilter: output channel count of every conv.
        kernel: square kernel size.
        moment: BatchNorm ``momentum``. In PyTorch this is the weight given to the
            current batch statistic, so 0.9 makes running stats follow recent
            batches closely. NOTE(review): Keras uses the opposite convention;
            confirm 0.9 (rather than 0.1) is intended.
        alpha: negative slope of the LeakyReLU.
    """

    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super().__init__()
        # Input channel count seen by each conv: infilter for the first, outfilter afterwards.
        channels_in = [infilter] + [outfilter] * (num - 1)
        self.conv = nn.ModuleList([nn.Conv2d(c_in, outfilter, kernel) for c_in in channels_in])
        self.norm = nn.ModuleList([nn.BatchNorm2d(outfilter, momentum=moment) for _ in channels_in])
        self.pad = nn.ZeroPad2d(int((kernel - 1) // 2))
        self.act = nn.LeakyReLU(alpha)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        out = x
        for conv_layer, norm_layer in zip(self.conv, self.norm):
            padded = self.pad(out)
            out = self.act(norm_layer(conv_layer(padded)))
        return self.pool(out)

class Discriminator(nn.Module):
    """Convolutional real/fake classifier that returns one raw logit per image.

    Layout: a ``start_kernel`` DiscriminatorBlock lifting 3 channels to
    ``filter``, then one DiscriminatorBlock per entry of ``num`` (each multiplying
    the channel count by ``gf``), global average pooling, a ReLU dense layer
    followed by dropout, and a final linear layer.

    Args:
        num: number of conv layers in each widening block.
        filter: channel count after the first block.
        start_kernel: kernel size of the first block.
        kernel: kernel size of the remaining blocks.
        moment: BatchNorm momentum forwarded to every block.
        alpha: LeakyReLU negative slope forwarded to every block.
        dense: width of the hidden dense layer.
        gf: channel growth factor per block.
        drop: dropout probability applied before the final layer.

    ``forward`` returns a (batch, 1) tensor of unnormalized logits; apply
    ``torch.sigmoid`` if probabilities are needed.
    """
    def __init__(self, num:list=[2, 2, 2, 2], filter:int=64, start_kernel:int=7, kernel:int=3, 
                 moment:float=0.9, alpha:float=0.1, dense:int=128, gf:float=2.0, drop:float=0.2):
        super(Discriminator, self).__init__()
        # ``num`` is only iterated, never mutated, so the list default is not shared state.
        self.convblock = nn.ModuleList(modules=[DiscriminatorBlock(1, 3, filter, start_kernel, moment, alpha)])
        for n in num:
            self.convblock.append(DiscriminatorBlock(n, filter, int(filter*gf), kernel, moment, alpha))
            filter = int(filter*gf)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.drop = nn.Dropout(drop)
        self.dense = nn.Linear(filter, dense)
        self.final = nn.Linear(dense, 1)

    def forward(self, x):
        for convblock in self.convblock:
            x = convblock(x)
        # Adaptive pooling makes the head independent of the input's spatial size.
        x = self.flat(self.pool(x))
        x = self.drop(F.relu(self.dense(x)))
        # Softmax over a single logit is identically 1.0 (no signal, no gradient).
        # Return the raw logit instead: BCE-with-logits style losses such as
        # torchgan's minimax losses expect exactly this.
        return self.final(x)

In [4]:
class GeneratorBlock(nn.Module):
    """Upsampling stage: a stride-2 transposed conv, then ``num - 1`` stride-1 transposed convs.

    Every transposed conv is followed by BatchNorm2d and LeakyReLU. The stride-2
    layer uses ``output_padding`` so that it maps (H, W) to exactly (2H, 2W) for
    any kernel size; the stride-1 layers keep the size for odd kernels.

    Args:
        num: number of transposed-conv layers (values below 1 still build one layer).
        infilter: input channel count.
        outfilter: output channel count of every layer.
        kernel: square kernel size.
        moment: BatchNorm ``momentum`` (PyTorch convention: weight of the current
            batch statistic). NOTE(review): 0.9 may be a Keras-style value; confirm.
        alpha: negative slope of the LeakyReLU.
    """
    def __init__(self, num:int, infilter:int, outfilter:int, kernel:int, moment:float=0.9, alpha:float=0.1):
        super(GeneratorBlock, self).__init__()
        pad  = int((kernel-1)//2)
        # ConvTranspose2d: H_out = (H - 1) * stride - 2 * pad + kernel + output_padding.
        # With stride 2 this equals 2H when output_padding = 2 + 2 * pad - kernel,
        # i.e. 1 for odd kernels and 0 for even ones (always < stride, as required).
        upsample_extra = 2 + 2 * pad - kernel
        self.conv = nn.ModuleList(modules=[nn.ConvTranspose2d(infilter, outfilter, kernel, stride=2, padding=pad,
                                                              output_padding=upsample_extra)])
        self.norm = nn.ModuleList(modules=[nn.BatchNorm2d(outfilter, momentum=moment)])
        self.act = nn.LeakyReLU(alpha)
        
        for _ in range(num-1):
            self.conv.append(nn.ConvTranspose2d(outfilter, outfilter, kernel, padding=pad))
            self.norm.append(nn.BatchNorm2d(outfilter, momentum=moment))

    def forward(self, x):
        for conv, norm in zip(self.conv, self.norm):
            x = self.act(norm(conv(x)))     
        return x

class Generator(nn.Module):
    """Chain of GeneratorBlock stages mapping a 3-channel image to a 3-channel image.

    Stages: one ``start_kernel`` block lifting 3 channels to ``filter``, one block
    per entry of ``num`` (each multiplying the width by ``gf``), and a final
    block projecting back to 3 channels. Every GeneratorBlock starts with a
    stride-2 transposed conv, so there are ``len(num) + 2`` upsampling stages.

    NOTE(review): with the default ``num`` that is 5 stride-2 stages (~32x),
    while GenerativeData's default input/target sizes differ by 8x; the
    Discriminator's adaptive pooling hides the mismatch. Also the final 3-channel
    stage still applies BatchNorm + LeakyReLU, so outputs are not bounded to the
    data range. Confirm both are intended.
    """
    def __init__(self, num:list=[3, 3, 3], filter:int=64, start_kernel:int=7, kernel:int=3, 
                 moment:float=0.9, alpha:float=0.1, gf:float=2.0):
        super().__init__()
        stages = [GeneratorBlock(1, 3, filter, start_kernel, moment, alpha)]
        width = filter
        for depth in num:
            widened = int(width * gf)
            stages.append(GeneratorBlock(depth, width, widened, kernel, moment, alpha))
            width = widened
        stages.append(GeneratorBlock(1, width, 3, kernel, moment, alpha))
        self.convblock = nn.ModuleList(modules=stages)

    def forward(self, x):
        out = x
        for stage in self.convblock:
            out = stage(out)
        return out

In [8]:
class GAN(LightningModule):
    """Lightning module that trains ``Generator`` against ``Discriminator``.

    Written for the PyTorch Lightning 1.x automatic-optimization API with two
    optimizers. ``configure_optimizers`` returns the generator optimizer first,
    so ``training_step`` receives ``optimizer_idx == 0`` for the generator update
    and ``1`` for the discriminator update.

    Args:
        dataset: dataset whose items are dicts with ``"inputs"`` (fed to the
            generator) and ``"outputs"`` (real samples fed to the discriminator).
        batchsize: training batch size.
        validbatchsize: validation batch size.
        split: fraction of ``dataset`` held out for validation; the rest is used for training.
        lr: initial learning rate of both Adam optimizers.
    """
    def __init__(self, dataset:Dataset, batchsize:int=25, validbatchsize:int=10, split:float=0.3, lr:float=1e-4):
        super(GAN, self).__init__()
        self.discriminator = Discriminator()
        self.generator = Generator()
        # torchgan losses are nn.Module classes configured by their constructor
        # (reduction, label smoothing, ...). Build them once, then call them with
        # discriminator outputs to obtain loss tensors.
        self.gloss_fn = MinimaxGeneratorLoss()
        self.dloss_fn = MinimaxDiscriminatorLoss()
        self.batchsize = batchsize
        self.validbatchsize = validbatchsize
        nvalid = int(len(dataset) * split)
        self.traindata, self.validdata = random_split(dataset, [len(dataset) - nvalid, nvalid])
        self.lr = lr

    def forward(self, gen, dis):
        """Score real samples ``dis`` and generated samples ``generator(gen)``.

        Returns:
            (fake, true): discriminator outputs for the generated and real batches.
        """
        true = self.discriminator(dis)
        gen = self.generator(gen)
        fake = self.discriminator(gen)
        return fake, true

    def _losses(self, batch):
        """Return (generator_loss, discriminator_loss) for one batch."""
        fake, true = self(batch["inputs"], batch["outputs"])
        return self.gloss_fn(fake), self.dloss_fn(true, fake)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the loss for the optimizer selected by ``optimizer_idx`` (0: generator, 1: discriminator)."""
        gloss, dloss = self._losses(batch)
        cur_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log("lr", cur_lr, prog_bar=True, on_step=True)
        self.log_dict({"generator_loss": gloss, "discriminator_loss": dloss}, on_step=True, on_epoch=True, prog_bar=True, logger=False)
        return gloss if optimizer_idx == 0 else dloss

    def validation_step(self, batch, batch_idx):
        """Log generator/discriminator losses for a validation batch."""
        gloss, dloss = self._losses(batch)
        # ``val_`` prefix: reusing the training names would make validation and
        # training metrics collide in the progress bar and callback metrics.
        self.log_dict({"val_generator_loss": gloss, "val_discriminator_loss": dloss}, on_step=True, on_epoch=True, prog_bar=True, logger=False)

    def configure_optimizers(self):
        """Adam + StepLR (x0.1 every 10 epochs) for each network; generator first."""
        goptim = optim.Adam(self.generator.parameters(), lr=self.lr)
        doptim = optim.Adam(self.discriminator.parameters(), lr=self.lr)
        gsch = optim.lr_scheduler.StepLR(goptim, step_size=10, gamma=0.1)
        dsch = optim.lr_scheduler.StepLR(doptim, step_size=10, gamma=0.1)
        # Lightning only recognises the exact keys "lr_scheduler" / "scheduler";
        # StepLR steps on a schedule, so no "monitor" metric is needed.
        return [{"optimizer": goptim, "lr_scheduler": {"scheduler": gsch, "interval": "epoch"}},
                {"optimizer": doptim, "lr_scheduler": {"scheduler": dsch, "interval": "epoch"}}]

    def train_dataloader(self):
        return DataLoader(self.traindata, batch_size=self.batchsize, shuffle=True)

    def val_dataloader(self):
        # Validation order does not affect the metrics, so no shuffling.
        return DataLoader(self.validdata, batch_size=self.validbatchsize, shuffle=False)

In [9]:
# Image folder. Set the SEAGAN_DATA_DIR environment variable to point elsewhere
# instead of editing this absolute, machine-specific default.
DATA_DIR = os.environ.get(
    "SEAGAN_DATA_DIR",
    r'C:\Users\suyash\Downloads\SeaGAN\seacreature_images_transformed\seacreature_images_transformed',
)
# Keep a trailing separator so file names can be appended by either
# os.path.join or plain string concatenation.
path = os.path.join(DATA_DIR, '')
dataset = GenerativeData(path)

progress_bar = RichProgressBar(theme=RichProgressBarTheme(description="blue", progress_bar="green_yellow", progress_bar_finished="green1",
        progress_bar_pulse="#6206E0", batch_progress="blue", time="black", processing_speed="black", metrics="black"))


In [10]:
# Seeds torch/numpy/random so the train/validation split and weight init are reproducible.
SEED = 42
pl.seed_everything(SEED)
learn = GAN(dataset)
trainer = pl.Trainer(min_epochs=30, max_epochs=30, callbacks=[progress_bar])
trainer.fit(learn)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: c:\Users\suyash\Desktop\KACHRA\laohub\Smile_in_Pain\Ajgar_Ke_Jalve\Artificial_Intelligence\Neural_Networks\Unsupervised_Learning\Generative_Nets\IMAGE\lightning_logs


Output()