<a href="https://colab.research.google.com/github/thotakuria/surgical-tool-segmentation-using-deep-learning-techniques/blob/main/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytorch-lightning

In [None]:
# Mount Google Drive at /content/drive. The dataset is read later from the
# relative path "drive/MyDrive/dataset/", which assumes the notebook's working
# directory is /content (Colab's default — verify if running elsewhere).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

import matplotlib.pyplot as plt

import pytorch_lightning as pl


# Seed torch's global RNG so weight init, noise sampling and shuffling repeat.
random_seed = 6
torch.manual_seed(random_seed)

BATCH_SIZE = 32
# Use at most one GPU; 0 when CUDA is unavailable.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Half the CPU cores for data loading. os.cpu_count() can return None, in
# which case fall back to 0 workers (load in the main process).
NUM_WORKERS = (os.cpu_count() or 0) // 2

In [6]:
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets

class MyNormalise:
    """Rescale values from the [0, 1] range produced by ToTensor to [-1, 1]."""

    def __call__(self, x):
        return 2.0 * x - 1
# Resize to 28x28, convert to a tensor in [0, 1], then rescale to [-1, 1] so
# real images share the range of the generator's tanh output.
data_transforms = transforms.Compose([
    transforms.Resize([28, 28]),
    transforms.ToTensor(),
    MyNormalise(),
])
image_datasets = torchvision.datasets.ImageFolder(
    root="drive/MyDrive/dataset/", transform=data_transforms)
# Batch size comes from the shared BATCH_SIZE constant instead of a literal.
dataloaders = torch.utils.data.DataLoader(
    image_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

In [7]:

class Discriminator(nn.Module):
    """Small CNN that scores 3x28x28 images; output is a sigmoid probability.

    Spatial trace for a 28x28 input: conv1(k3) -> 26, pool -> 13,
    conv2(k2) -> 12, pool -> 6, conv3(k2) -> 5, so fc1 sees 64*5*5 = 1600
    features per sample.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=2)
        self.conv3 = nn.Conv2d(16, 64, kernel_size=2)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities in (0, 1)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 1))  # kernel 1: identity pool

        # Flatten per sample. view(-1, 1600) would silently fold a mis-sized
        # feature map into extra "batch" rows; this makes fc1 raise instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return torch.sigmoid(x)

In [8]:

class Generator(nn.Module):
    """Map a latent vector to a 3x28x28 image with values in [-1, 1].

    Shape trace: Linear -> 64x7x7, ConvT(k4, s2) -> 32x16x16,
    ConvT(k4, s2) -> 16x34x34, Conv(k7) -> 3x28x28, then tanh.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Layers are registered in this order on purpose: it determines the
        # seeded initialisation sequence and the state_dict key order.
        self.lin1 = nn.Linear(latent_dim, 7 * 7 * 64)
        self.ct1 = nn.ConvTranspose2d(64, 32, 4, stride=2)
        self.ct2 = nn.ConvTranspose2d(32, 16, 4, stride=2)
        self.conv = nn.Conv2d(16, 3, kernel_size=7)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        h = F.relu(self.lin1(x)).view(-1, 64, 7, 7)
        h = F.relu(self.ct1(h))
        h = self.bn(F.relu(self.ct2(h)))
        return torch.tanh(self.conv(h))

In [10]:
import matplotlib.pyplot as plt

class GAN(pl.LightningModule):
    """Lightning module training ``Generator`` against ``Discriminator``.

    Relies on Lightning's automatic multi-optimizer loop (``optimizer_idx``
    in ``training_step``), which exists only in PyTorch Lightning < 2.0.

    Args:
        latent_dim: length of the noise vector fed to the generator.
        lr: stored in ``hparams`` but not used by ``configure_optimizers``;
            kept so existing calls and checkpoints keep working.
        g_lr: Adam learning rate for the generator.
        d_lr: Adam learning rate for the discriminator.
    """

    def __init__(self, latent_dim=10, lr=0.002, g_lr=0.00002, d_lr=0.0002):
        super().__init__()
        self.save_hyperparameters()
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator()

        # Fixed noise (a single sample) reused by plot_imgs so plots from
        # different points in training are comparable.
        m = torch.zeros(1, self.hparams.latent_dim)
        std = torch.ones(1, self.hparams.latent_dim)
        self.validation_z = torch.normal(m, std)

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy on the discriminator's sigmoid outputs."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator or discriminator update.

        ``configure_optimizers`` returns ``[opt_g, opt_g, opt_d]``, so
        indices 0 and 1 both update the generator (two G steps per batch)
        and index 2 updates the discriminator.
        """
        real_imgs, _ = batch
        batch_size = real_imgs.size(0)
        z = torch.randn(batch_size, self.hparams.latent_dim,
                        device=real_imgs.device, dtype=real_imgs.dtype)

        if optimizer_idx != 2:
            y_hat = self.discriminator(self(z))
            # Generator targets are smoothed to 0.8 rather than 1.0.
            y = torch.full((batch_size, 1), 0.8).type_as(real_imgs)
            g_loss = self.adversarial_loss(y_hat, y)
            # Returned "log"/"progress_bar" dict keys are ignored by modern
            # Lightning; self.log is what actually records the metric.
            self.log("g_loss", g_loss, prog_bar=True)
            return {"loss": g_loss}

        y_hat_real = self.discriminator(real_imgs)
        y_real = torch.ones(batch_size, 1).type_as(real_imgs)
        real_loss = self.adversarial_loss(y_hat_real, y_real)

        # Detach so the discriminator step does not backprop into G.
        y_hat_fake = self.discriminator(self(z).detach())
        y_fake = torch.zeros(batch_size, 1).type_as(real_imgs)
        fake_loss = self.adversarial_loss(y_hat_fake, y_fake)

        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        return {"loss": d_loss}

    def configure_optimizers(self):
        """Return ``[opt_g, opt_g, opt_d]`` (generator listed twice) and no schedulers."""
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.g_lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.d_lr)
        return [opt_g, opt_g, opt_d], []

    def plot_imgs(self):
        """Show generator outputs for ``validation_z`` in a 3x4 grid (max 12)."""
        z = self.validation_z.type_as(self.generator.lin1.weight)
        # no_grad avoids building a graph. If the generator is in train mode,
        # its BatchNorm running statistics are still updated by this pass.
        with torch.no_grad():
            sample_imgs = self(z).cpu()

        print('epoch', self.current_epoch)
        plt.figure()
        n_show = min(sample_imgs.size(0), 12)  # the grid has 3 * 4 cells
        for i in range(n_show):
            plt.subplot(3, 4, i + 1)
            # CHW in [-1, 1] (tanh output) -> HWC in [0, 1] for imshow.
            sample_img = sample_imgs[i].permute(1, 2, 0).numpy() / 2.0 + 0.5
            plt.imshow(sample_img)
            plt.title("Generated Data")
            plt.xticks([])
            plt.yticks([])
            plt.axis('off')
        plt.tight_layout()
        plt.show()

    def on_train_epoch_end(self):
        """Epoch-end hook; currently a no-op (call ``plot_imgs`` manually)."""
        return


In [None]:
 # Helper for plotting one image from fresh random noise with a GAN model.
def plot_random_sample(gan):
    """Sample one latent vector, run it through ``gan`` and show the image.

    Usage: ``plot_random_sample(model)`` after ``model`` has been created.

    Args:
        gan: a ``GAN`` instance; its ``hparams.latent_dim``, ``generator``
            and ``current_epoch`` are used.
    """
    z = torch.randn(1, gan.hparams.latent_dim).type_as(gan.generator.lin1.weight)
    with torch.no_grad():
        sample_imgs = gan(z).cpu()
    print('epoch', gan.current_epoch)
    # CHW in [-1, 1] (tanh output) -> HWC in [0, 1] for imshow.
    sample_img = sample_imgs[0].permute(1, 2, 0).numpy() / 2.0 + 0.5
    plt.imshow(sample_img)
    plt.xticks([])
    plt.yticks([])
    plt.axis('off')

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
def imshow(img):
    """Display a CHW image in [-1, 1], undoing MyNormalise (x / 2 + 0.5) first."""
    hwc = np.transpose(0.5 + img / 2, (1, 2, 0))
    plt.imshow(hwc)
# Preview up to 20 real images from the first batch, in two rows.
dataiter = iter(dataloaders)
# DataLoader iterators no longer provide .next(); use the builtin next().
images, labels = next(dataiter)
images = images.numpy()
n_show = min(20, len(images))  # a batch may hold fewer than 20 images
n_rows = 2
n_cols = max(1, (n_show + n_rows - 1) // n_rows)  # add_subplot needs ints
fig = plt.figure(figsize=(25, 4))
for idx in range(n_show):
    ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])

In [None]:
# NOTE(review): `model` is only created in the training cell below; run that
# cell first, otherwise this line raises NameError.
model.plot_imgs()

In [None]:
# Train a fresh GAN. `dm` is the DataLoader built above (passed as the train
# dataloader, not a LightningDataModule). AVAIL_GPUS (0 or 1) selects the
# device; without an explicit accelerator Lightning 1.x trains on CPU.
dm = dataloaders
model = GAN()
trainer = pl.Trainer(
    max_epochs=7000,
    log_every_n_steps=2,
    gradient_clip_val=5.0,
    accelerator="gpu" if AVAIL_GPUS else "cpu",
    devices=1,
)
trainer.fit(model, dm)

In [None]:
# Fit the existing `model` again with a new Trainer for up to 10 epochs.
# No ckpt_path is passed, so this is a new fit loop on the current weights.
trainer = pl.Trainer(
    max_epochs=10,
    log_every_n_steps=2,
    gradient_clip_val=5.0,
    accelerator="gpu" if AVAIL_GPUS else "cpu",
    devices=1,
)
trainer.fit(model, dm)