# GOAL - WRITE MODULAR CODE TO TRAIN MODELS
1. Use different transforms for clean images, hazy images, and test images.
2. Train through a function that takes the batch size (and other settings) as arguments.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import os
from torch.optim import lr_scheduler
from database2 import DehazingDataset
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torchvision import models

In [None]:

class Block(nn.Module):
    """Encoder/decoder stage used by Generator.

    A 4x4, stride-2 convolution (``Conv2d`` with reflect padding when ``down``,
    otherwise ``ConvTranspose2d``) followed by BatchNorm and an activation,
    with optional Dropout(0.5) on the output.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the stage.
        down: True for a downsampling ``Conv2d``; False for an upsampling
            ``ConvTranspose2d``.
        act: ``'relu'`` selects ReLU; any other value selects LeakyReLU(0.2).
        use_dropout: apply Dropout(0.5) to the stage output.
    """

    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super(Block, self).__init__()
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1, bias=False)
        if down:
            resample = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **conv_kwargs)
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, **conv_kwargs)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU() if act == 'relu' else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(resample, norm, activation)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        # Stored for introspection only; forward() does not read it.
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class Generator(nn.Module):
    """U-Net-shaped encoder/decoder generator with skip connections.

    Encoder: an initial stride-2 conv + LeakyReLU, six downsampling ``Block``
    stages and a stride-2 bottleneck conv. Decoder: seven upsampling ``Block``
    stages and a final transposed conv + tanh. Every decoder stage after the
    first concatenates (along channels) the encoder activation of matching
    resolution.

    Args:
        in_channels: channels of both the input and the output image.
        features: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()
        f = features

        def encoder(cin, cout):
            # Downsampling stage: strided conv, BatchNorm, LeakyReLU, no dropout.
            return Block(cin, cout, down=True, act='leaky', use_dropout=False)

        def decoder(cin, cout, dropout=False):
            # Upsampling stage: transposed conv, BatchNorm, ReLU, optional dropout.
            return Block(cin, cout, down=False, act='relu', use_dropout=dropout)

        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(f, 2 * f)
        self.down2 = encoder(2 * f, 4 * f)
        self.down3 = encoder(4 * f, 8 * f)
        self.down4 = encoder(8 * f, 8 * f)
        self.down5 = encoder(8 * f, 8 * f)
        self.down6 = encoder(8 * f, 8 * f)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(8 * f, 8 * f, 4, 2, 1, padding_mode='reflect'),
            nn.ReLU(),
        )

        # Input widths of up2..up7 are doubled by the skip concatenation.
        self.up1 = decoder(8 * f, 8 * f, dropout=True)
        self.up2 = decoder(16 * f, 8 * f, dropout=True)
        self.up3 = decoder(16 * f, 8 * f, dropout=True)
        self.up4 = decoder(16 * f, 8 * f)
        self.up5 = decoder(16 * f, 4 * f)
        self.up6 = decoder(8 * f, 2 * f)
        self.up7 = decoder(4 * f, f)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(2 * f, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # skips[0] is the initial_down output; skips[k] is the output of down{k}.
        skips = [self.initial_down(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(stage(skips[-1]))

        y = self.up1(self.bottleneck(skips[-1]))
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        # Pair up2 with down6's output, up3 with down5's, ..., up7 with down1's.
        for stage, skip in zip(decoders, reversed(skips[1:])):
            y = stage(torch.cat([y, skip], dim=1))
        return self.final_up(torch.cat([y, skips[0]], dim=1))

def test():
    """Smoke test: push one random 1x3x256x256 batch through a Generator and print the output shape."""
    dummy_batch = torch.randn((1, 3, 256, 256))
    generator_under_test = Generator(in_channels=3, features=64)
    output = generator_under_test(dummy_batch)
    print(output.shape)




In [None]:
 class CNNBlock(nn.Module):
    def __init__(self,in_channels, out_channels, stride = 2):
        super(CNNBlock,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,4,stride,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self,x):
        return self.conv(x)


class Discriminator(nn.Module):
    """Conditional discriminator that scores an (input, target) image pair.

    The two images are concatenated along channels and passed through an
    initial stride-2 conv + LeakyReLU (no BatchNorm). A stack of ``CNNBlock``
    stages follows, then a final conv to a single-channel score map.

    Args:
        in_channels: channels of EACH image; the first conv sees ``2 * in_channels``.
        features: output channels of the initial conv and of each following
            stage. Every stage uses stride 2 except the last, which uses stride 1.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )  # according to the paper, the first (64-channel) layer has no BatchNorm2d
        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        # Choose the stride by position, not by value, so repeated channel widths
        # (e.g. features=(64, 128, 128)) don't put stride 1 on the wrong stage.
        for index, feature in enumerate(features[1:], start=1):
            layers.append(CNNBlock(in_channels, feature, stride=1 if index == last_index else 2))
            in_channels = feature

        layers.append(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

def test():
    """Smoke test: score a random (input, target) pair with a Discriminator; print the model and the output shape."""
    first = torch.randn((1, 3, 256, 256))
    second = torch.randn((1, 3, 256, 256))
    critic = Discriminator(in_channels=3)
    scores = critic(first, second)
    for item in (critic, scores.shape):
        print(item)

# FUNCTIONS TO CREATE DATALOADERS

In [None]:
# PUT IN A SEPARATE FILE IF POSSIBLE
# All three pipelines end with ToTensor + Normalize(0.5, 0.5), mapping pixel
# values from [0, 1] to [-1, 1], the same range as the generator's tanh output.

# Hazy (input) images.
# No random flip here: DehazingDataset receives separate hazy/clean transforms,
# so a random flip in only one of them would presumably flip the two images of a
# pair independently and misalign them. Geometric augmentation has to be applied
# jointly to both images.
# NOTE(review): ColorJitter() with all-default arguments does not change the
# image, so this RandomApply is currently a no-op; pass brightness/contrast/...
# to enable real jitter.
transform_hazy = transforms.Compose([
    transforms.RandomApply([transforms.ColorJitter()], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Clean (target) images: conversion and normalization only.
transform_clean = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Test images: deterministic resize + normalization. Random augmentation is
# deliberately absent so evaluation results are reproducible.
transform_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# FUNCTIONS TO CREATE DATALOADERS
def create_dataloader(directory, batch_size=32, mean=0.5, std=0.5, transform_hazy=None, transform_clean=None,
                      *, shuffle=True, num_workers=2):
    """Build a DataLoader over a ``DehazingDataset`` rooted at ``directory``.

    Args:
        directory: dataset directory passed to ``DehazingDataset``.
        batch_size: samples per batch.
        mean, std: accepted for backward compatibility but NOT used; any
            normalization must be done by ``transform_hazy`` / ``transform_clean``.
        transform_hazy, transform_clean: transforms forwarded to ``DehazingDataset``.
        shuffle: reshuffle the data every epoch (keep True for training).
        num_workers: number of loader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    dataset = DehazingDataset(directory, transform_hazy=transform_hazy, transform_clean=transform_clean)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


def create_train_val_dataloaders(root_dir, train_batch_size=32, val_batch_size=32, mean=0.5, std=0.5):
    """Create the training and validation loaders from ``root_dir/train`` and ``root_dir/val``.

    Both splits use the module-level ``transform_clean`` for the hazy AND the
    clean images, i.e. no augmentation. The training loader is shuffled. The
    validation loader is not, so its batches are identical across epochs.
    ``mean``/``std`` are forwarded but unused (see ``create_dataloader``).

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    train_dir = os.path.join(root_dir, 'train')
    val_dir = os.path.join(root_dir, 'val')

    train_dataloader = create_dataloader(train_dir, batch_size=train_batch_size, mean=mean, std=std,
                                         transform_hazy=transform_clean, transform_clean=transform_clean)
    val_dataloader = create_dataloader(val_dir, batch_size=val_batch_size, mean=mean, std=std,
                                       transform_hazy=transform_clean, transform_clean=transform_clean,
                                       shuffle=False)

    return train_dataloader, val_dataloader

In [None]:
# Dataset root; create_train_val_dataloaders reads its 'train' and 'val' subfolders.
root_dir = '../Task2Dataset'
train_dataloader, val_dataloader = create_train_val_dataloaders(root_dir=root_dir)

In [None]:

class Trainer:
    """Train a conditional GAN that maps hazy images to clean images.

    The discriminator scores (hazy, candidate-clean) pairs. Two modes:

    * standard GAN (``wgan=False``): Adam optimizers and a BCE-with-logits
      adversarial loss. The generator loss adds an L1 term, an MSE term and a
      total-variation term.
    * WGAN (``wgan=True``): RMSprop optimizers and the Wasserstein critic loss
      with weight clipping, using ``n_critic`` critic updates per generator update.

    Models, and every batch they see, are placed on CUDA when available,
    otherwise on the CPU.

    Args:
        generator: image-to-image ``nn.Module`` called as ``generator(hazy)``.
        discriminator: ``nn.Module`` called as ``discriminator(hazy, img)``.
        train_dataloader: iterable of ``(hazy_imgs, clean_imgs)`` batches with a ``len``.
        lr_step_size, lr_gamma: ``StepLR`` parameters for both optimizers,
            stepped once per epoch.
        lambda_adv, lambda_res, lambda_per, lambda_reg: weights of the
            adversarial, L1, MSE and total-variation generator terms.
        num_epochs: number of passes over ``train_dataloader``.
        wgan: use WGAN training instead of the standard GAN objective.
        n_critic: critic updates per generator update (WGAN only, must be >= 1).
        use_l1_loss: include the L1 term (standard GAN only).
        use_adversarial_loss: include the adversarial term in the generator
            loss (standard GAN only).
        use_perceptual_loss: stored and reported only; no perceptual loss is
            computed yet.

    Raises:
        ValueError: if ``wgan`` is True and ``n_critic < 1``.
    """

    def __init__(self, generator, discriminator, train_dataloader, lr_step_size, lr_gamma, lambda_adv=1, lambda_res=150,
                 lambda_per=150, lambda_reg=0.00001, num_epochs=10, wgan=False, n_critic=1, use_l1_loss=True,
                 use_adversarial_loss=True, use_perceptual_loss=True):
        # Fail fast: with zero critic steps the WGAN step never computes a critic loss.
        if wgan and n_critic < 1:
            raise ValueError(f"n_critic must be >= 1 for WGAN training, got {n_critic}")

        # Resolve the device first so the models live there before optimizers are built.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)
        self.train_dataloader = train_dataloader
        if wgan:
            self.optimizer_G = optim.RMSprop(self.generator.parameters(), lr=0.00005)
            self.optimizer_D = optim.RMSprop(self.discriminator.parameters(), lr=0.00005)
        else:
            self.optimizer_G = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
            self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.scheduler_G = lr_scheduler.StepLR(self.optimizer_G, lr_step_size, lr_gamma)
        self.scheduler_D = lr_scheduler.StepLR(self.optimizer_D, lr_step_size, lr_gamma)
        self.criterion_G = nn.BCEWithLogitsLoss()
        self.criterion_D = nn.BCEWithLogitsLoss()
        self.num_epochs = num_epochs
        self.n_critic = n_critic if wgan else 1
        self.use_l1_loss = use_l1_loss
        self.use_adversarial_loss = use_adversarial_loss
        self.use_perceptual_loss = use_perceptual_loss
        self.is_wgan = wgan
        # NOTE(review): this VGG19 feature extractor is not used by any loss yet;
        # lambda_per currently weights the pixel-wise MSE term. It is frozen so
        # it can never be updated by accident.
        self.perceptual_loss_net = models.vgg19(pretrained=True).features[:18].eval()
        self.perceptual_loss_net.requires_grad_(False)
        self.lambda_adv = lambda_adv
        self.lambda_res = lambda_res
        self.lambda_per = lambda_per
        self.lambda_reg = lambda_reg
        self.l1loss = nn.L1Loss()
        self.l2loss = nn.MSELoss()

    def print_summary(self):
        """Print optimizers, schedulers, criteria and training flags."""
        print("Optimizer Summary:")
        print(f"Generator Optimizer: {self.optimizer_G}")
        print(f"Discriminator Optimizer: {self.optimizer_D}")
        print("Scheduler Summary:")
        print(f"Generator Scheduler: {self.scheduler_G}")
        print(f"Discriminator Scheduler: {self.scheduler_D}")
        print("Criterion Summary:")
        print(f"Generator Criterion: {self.criterion_G}")
        print(f"Discriminator Criterion: {self.criterion_D}")
        print("Other Parameters:")
        print(f"Number of Epochs: {self.num_epochs}")
        print(f"Number of Critic Updates per Generator Update: {self.n_critic}")
        print(f"Use L1 Loss: {self.use_l1_loss}")
        print(f"Use Adversarial Loss: {self.use_adversarial_loss}")
        print(f"Use Perceptual Loss: {self.use_perceptual_loss}")
        print(f"Is WGAN: {self.is_wgan}")
        print(f"Device: {self.device}")

    def train(self):
        """Run the full training loop.

        Every 10 batches, example images are shown; every 20 batches, the loss
        curves are plotted. Every 3 epochs, samples are saved. Both models are
        saved at the end.
        """
        self.print_summary()
        self.generator.train()
        self.discriminator.train()
        num_batches = len(self.train_dataloader)
        step_fn = self.train_step_wgan if self.is_wgan else self.train_step_nonwgan
        # WGAN losses are negated for plotting because the critic objective is maximized.
        plot_sign = -1 if self.is_wgan else 1
        epochs = []
        g_losses = []
        d_losses = []
        for epoch in range(self.num_epochs):
            g_total_loss = 0.0
            d_total_loss = 0.0

            batches = tqdm(self.train_dataloader, desc=f'Epoch {epoch + 1}/{self.num_epochs}')
            for batch_no, (hazy_imgs, clean_imgs) in enumerate(batches, start=1):
                g_complete_loss, d_loss, fake_imgs = step_fn(hazy_imgs, clean_imgs)
                g_value = g_complete_loss.item()
                d_value = d_loss.item()
                g_total_loss += g_value
                d_total_loss += d_value

                epochs.append(epoch + batch_no / num_batches)
                g_losses.append(plot_sign * g_value)
                d_losses.append(plot_sign * d_value)

                if batch_no % 10 == 0:
                    self.show_images(hazy_imgs, clean_imgs, fake_imgs, num_images=5)
                if batch_no % 20 == 0:
                    self.plot_losses(epochs, g_losses, d_losses)

            # Step the LR schedulers once per epoch, AFTER the optimizer updates.
            # Stepping them at the start of the epoch made every decay happen
            # one epoch early.
            self.scheduler_G.step()
            self.scheduler_D.step()

            denominator = max(num_batches, 1)  # avoid ZeroDivisionError on an empty loader
            g_avg_loss = g_total_loss / denominator
            d_avg_loss = d_total_loss / denominator
            print(f"Epoch [{epoch + 1}/{self.num_epochs}], Generator Avg. Loss: {g_avg_loss:.4f}, Discriminator Avg. Loss: {d_avg_loss:.4f}")

            if (epoch + 1) % 3 == 0:
                self.save_samples(epoch)

        self.save_models()

    def plot_losses(self, epochs, g_losses, d_losses):
        """Plot generator and discriminator loss histories against fractional epoch in a new figure."""
        plt.figure()
        plt.plot(epochs, g_losses, label='Generator Loss')
        plt.plot(epochs, d_losses, label='Discriminator Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Epoch vs. Losses')
        plt.legend()
        plt.grid(True)
        plt.show()

    def train_step_wgan(self, hazy_imgs, clean_imgs):
        """Run one WGAN iteration: ``n_critic`` critic updates, then one generator update.

        Critic weights are clipped after every critic step.

        Returns:
            tuple: ``(g_loss, d_loss, fake_imgs)``. These are the generator
            loss ``-mean(D(hazy, G(hazy)))``, the last critic loss, and the
            images generated for the generator update.
        """
        hazy_imgs = hazy_imgs.to(self.device)
        real_imgs = clean_imgs.to(self.device)

        # The critic only consumes these fakes as constants, so skip building a graph.
        with torch.no_grad():
            fake_imgs = self.generator(hazy_imgs)

        for _ in range(self.n_critic):
            real_outputs = self.discriminator(hazy_imgs, real_imgs)
            fake_outputs = self.discriminator(hazy_imgs, fake_imgs)

            self.optimizer_D.zero_grad()
            # The critic maximizes mean(D(real)) - mean(D(fake)); minimize the negation.
            d_loss = -torch.mean(real_outputs - fake_outputs)
            d_loss.backward()
            self.optimizer_D.step()
            self.clamp_discriminator_parameters()

        self.optimizer_G.zero_grad()
        # Regenerate WITH gradients so the generator is updated through the critic.
        fake_imgs = self.generator(hazy_imgs)
        fake_outputs = self.discriminator(hazy_imgs, fake_imgs)
        g_loss = -torch.mean(fake_outputs)
        g_loss.backward()
        self.optimizer_G.step()

        return g_loss, d_loss, fake_imgs

    def train_step_nonwgan(self, hazy_imgs, clean_imgs):
        """Run one standard-GAN iteration: one discriminator update, then one generator update.

        Discriminator loss: mean of the BCE-with-logits losses on real
        (label 1) and fake (label 0) pairs.

        Generator loss, in this order:
            lambda_adv * adversarial (if use_adversarial_loss)
            + lambda_res * L1 (if use_l1_loss)
            + lambda_per * MSE
            + lambda_reg * total variation of the generated images.

        Returns:
            tuple: ``(g_complete_loss, d_loss, fake_imgs)``, where ``fake_imgs``
            are the images the generator loss was computed on.
        """
        hazy_imgs = hazy_imgs.to(self.device)
        clean_imgs = clean_imgs.to(self.device)
        real_imgs = clean_imgs

        # ---- Discriminator update ----
        self.optimizer_D.zero_grad()
        # Fakes for the discriminator are constants, so no generator graph is needed.
        with torch.no_grad():
            fake_imgs = self.generator(hazy_imgs)
        real_outputs = self.discriminator(hazy_imgs, real_imgs)
        fake_outputs = self.discriminator(hazy_imgs, fake_imgs)

        d_loss_real = self.criterion_D(real_outputs, torch.ones_like(real_outputs))
        d_loss_fake = self.criterion_D(fake_outputs, torch.zeros_like(fake_outputs))
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        self.optimizer_D.step()

        # ---- Generator update ----
        self.optimizer_G.zero_grad()
        fake_imgs = self.generator(hazy_imgs)

        g_loss = 0
        if self.use_adversarial_loss:
            fake_outputs = self.discriminator(hazy_imgs, fake_imgs)
            # The generator wants its fakes classified as real (label 1).
            g_loss = self.criterion_G(fake_outputs, torch.ones_like(fake_outputs))

        g_res_loss = 0
        if self.use_l1_loss:
            g_res_loss = self.l1loss(fake_imgs, clean_imgs)

        # Anisotropic total variation: summed absolute differences between
        # vertically and horizontally adjacent pixels.
        g_reg_loss = self.lambda_reg * (
            torch.sum(torch.abs(fake_imgs[:, :, :-1, :] - fake_imgs[:, :, 1:, :])) +  # along height
            torch.sum(torch.abs(fake_imgs[:, :, :, :-1] - fake_imgs[:, :, :, 1:]))    # along width
        )

        g_l2_loss = self.l2loss(fake_imgs, clean_imgs)

        g_complete_loss = (self.lambda_adv * g_loss + self.lambda_res * g_res_loss
                           + self.lambda_per * g_l2_loss + g_reg_loss)
        g_complete_loss.backward()
        self.optimizer_G.step()

        return g_complete_loss, d_loss, fake_imgs

    def clamp_discriminator_parameters(self, clip_value=0.01):
        """Clip every discriminator parameter in place to ``[-clip_value, clip_value]`` (WGAN weight clipping)."""
        for param in self.discriminator.parameters():
            param.data.clamp_(-clip_value, clip_value)

    def save_samples(self, epoch):
        """Save de-normalized generator outputs for EVERY training batch.

        Writes one ``sample_{epoch}_batch_{i}.png`` per batch to the working
        directory. The generator is in eval mode while sampling and is always
        put back into train mode afterwards.
        """
        self.generator.eval()
        try:
            with torch.no_grad():
                for i, (hazy_imgs, _clean_imgs) in enumerate(self.train_dataloader):
                    fake_imgs = self.generator(hazy_imgs.to(self.device))
                    # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
                    fake_imgs = fake_imgs * 0.5 + 0.5
                    save_image(fake_imgs, f"sample_{epoch}_batch_{i}.png")
        finally:
            self.generator.train()

    def show_images(self, hazy_imgs, clean_imgs, generated_imgs, num_images=5):
        """Show hazy / clean / generated images in a 3-row grid.

        Shows the first ``num_images`` items of the batch, or fewer if the batch
        is smaller. Images are de-normalized from [-1, 1] and clipped to [0, 1]
        for display.
        """
        num_images = min(num_images, len(hazy_imgs), len(clean_imgs), len(generated_imgs))
        if num_images < 1:
            return
        # squeeze=False keeps `axes` 2-D even when only one column is shown.
        fig, axes = plt.subplots(3, num_images, figsize=(15, 10), squeeze=False)
        rows = (
            ("Hazy Image", hazy_imgs),
            ("Clean Image", clean_imgs),
            ("Generated Image", generated_imgs),
        )
        for row, (title, batch) in enumerate(rows):
            for col in range(num_images):
                image = batch[col].detach().permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
                ax = axes[row, col]
                ax.imshow(np.clip(image, 0.0, 1.0))
                ax.axis('off')
                ax.set_title(title)

        plt.tight_layout()
        plt.show()

    def save_models(self):
        """Save generator and discriminator state dicts to 'generator.pth' / 'discriminator.pth'."""
        torch.save(self.generator.state_dict(), 'generator.pth')
        torch.save(self.discriminator.state_dict(), 'discriminator.pth')



In [None]:
# Build the generator and discriminator with default widths (same order as before).
generator, discriminator = Generator(), Discriminator()

# StepLR settings: both learning rates are multiplied by lr_gamma every lr_step_size epochs.
lr_step_size, lr_gamma = 2, 0.5
trainer = Trainer(generator, discriminator, train_dataloader,
                  lr_step_size=lr_step_size, lr_gamma=lr_gamma)
trainer.train()

# Other Trainer options that can be passed as keywords:
# lambda_adv=1, lambda_res=150,
# lambda_per=150, num_epochs=10, wgan=False, n_critic=0, use_l1_loss=True, use_adversarial_loss=True
# NOTE: with wgan=True, n_critic must be >= 1 (the critic loop runs range(n_critic) times).

In [None]:
# Persist only the generator weights (state_dict) for later inference.
# NOTE(review): the run above appears to use the Trainer default wgan=False,
# so "wgan" in this filename may be misleading; confirm before relying on it.
torch.save(obj=generator.state_dict(), f='generator_cgan_wgan.pth')


# ADDING PERCEPTUAL LOSS