In [None]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import os
from database2 import DehazingDataset
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
from torch.nn.modules import padding
from torch.nn.modules.batchnorm import BatchNorm2d

# Generator

In [None]:
class Generator(nn.Module):
    """Plain encoder-decoder generator (no skip connections).

    The encoder applies three stride-2 ``Conv2d`` stages (3 -> 64 -> 128 -> 256
    channels), each followed by BatchNorm and ReLU. The decoder mirrors it with
    three stride-2 ``ConvTranspose2d`` stages (256 -> 128 -> 64 -> 3 channels).
    The last stage has no BatchNorm and ends in ``Tanh``, so outputs lie in
    ``[-1, 1]``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        down_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            down_layers += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.encoder = nn.Sequential(*down_layers)

        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Output head: back to 3 channels, squashed into [-1, 1].
        up_layers += [nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Tanh()]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

# U-Net generator used in the GitHub repo (this redefines `Generator`, replacing the simple version above)

In [None]:

class Block(nn.Module):
    """Conv -> BatchNorm -> activation unit used by the U-Net generator.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        down: ``True`` selects a stride-2 ``Conv2d`` (kernel 4, padding 1,
            reflect padding); ``False`` selects a stride-2 ``ConvTranspose2d``
            (kernel 4, padding 1). Neither convolution has a bias, because
            BatchNorm follows.
        act: ``'relu'`` selects ``ReLU``; any other value selects ``LeakyReLU(0.2)``.
        use_dropout: when true, ``Dropout(0.5)`` is applied to the block output.
    """

    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super(Block, self).__init__()
        if down:
            resample = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2,
                                 padding=1, bias=False, padding_mode='reflect')
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                                          stride=2, padding=1, bias=False)
        self.conv = nn.Sequential(
            resample,
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == 'relu' else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        # Always created so the module layout does not depend on use_dropout;
        # forward() only applies it when use_dropout is true.
        self.dropout = nn.Dropout(0.5)
        # Kept as an attribute; forward() does not read it.
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class Generator(nn.Module):
    """U-Net generator: a stem, six down ``Block``s, a bottleneck, seven up ``Block``s and a head.

    Every encoder stage is a stride-2 convolution (kernel 4, padding 1). Every
    decoder stage is a stride-2 transposed convolution. From ``up2`` onwards,
    each decoder stage receives the previous decoder output concatenated
    (``dim=1``) with the encoder output of the matching depth, which is why
    those stages take twice as many input channels. The head ends in ``Tanh``
    and emits ``in_channels`` channels.

    NOTE(review): with eight stride-2 stages, spatial sizes that are not
    multiples of 256 are expected to break the skip concatenations (``test()``
    uses 256x256). Confirm before feeding other sizes.

    Args:
        in_channels: channels of the input and of the generated output.
        features: base width; deeper stages use 2x, 4x and 8x this value.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()
        f = features

        # Stem: conv + LeakyReLU, no BatchNorm.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )

        # Encoder: LeakyReLU blocks without dropout; width saturates at 8*f.
        self.down1 = Block(f, f * 2, down=True, act='leaky', use_dropout=False)
        self.down2 = Block(f * 2, f * 4, down=True, act='leaky', use_dropout=False)
        self.down3 = Block(f * 4, f * 8, down=True, act='leaky', use_dropout=False)
        self.down4 = Block(f * 8, f * 8, down=True, act='leaky', use_dropout=False)
        self.down5 = Block(f * 8, f * 8, down=True, act='leaky', use_dropout=False)
        self.down6 = Block(f * 8, f * 8, down=True, act='leaky', use_dropout=False)

        # Bottleneck: one more stride-2 conv, ReLU, no BatchNorm.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, 4, 2, 1, padding_mode='reflect'),
            nn.ReLU(),
        )

        # Decoder: ReLU blocks; dropout only in the three deepest stages.
        # Inputs of up2..up7 are [previous output, skip] concatenations.
        self.up1 = Block(f * 8, f * 8, down=False, act='relu', use_dropout=True)
        self.up2 = Block(f * 8 * 2, f * 8, down=False, act='relu', use_dropout=True)
        self.up3 = Block(f * 8 * 2, f * 8, down=False, act='relu', use_dropout=True)
        self.up4 = Block(f * 8 * 2, f * 8, down=False, act='relu', use_dropout=False)
        self.up5 = Block(f * 8 * 2, f * 4, down=False, act='relu', use_dropout=False)
        self.up6 = Block(f * 4 * 2, f * 2, down=False, act='relu', use_dropout=False)
        self.up7 = Block(f * 2 * 2, f, down=False, act='relu', use_dropout=False)

        # Head: [up7 output, stem output] -> in_channels, squashed to [-1, 1].
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoder = (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6)
        decoder = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        # skips[0] is the stem output (d1); skips[1:] are d2 .. d7.
        skips = [self.initial_down(x)]
        for stage in encoder:
            skips.append(stage(skips[-1]))

        h = self.up1(self.bottleneck(skips[-1]))
        # up2..up7 consume the skips deepest-first: d7, d6, ..., d2.
        for stage, skip in zip(decoder, reversed(skips[1:])):
            h = stage(torch.cat([h, skip], dim=1))
        return self.final_up(torch.cat([h, skips[0]], dim=1))

def test():
    """Smoke-test ``Generator``: print the output shape for one random 1x3x256x256 input.

    NOTE: another ``test()`` is defined in the Discriminator cell and replaces
    this one once that cell has run.
    """
    sample = torch.randn((1, 3, 256, 256))
    net = Generator(in_channels=3, features=64)
    out = net(sample)
    print(out.shape)


# Discriminator

In [None]:
class CNNBlock(nn.Module):
    """Conv -> BatchNorm -> LeakyReLU(0.2) unit used by the PatchGAN discriminator.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: convolution stride (kernel size is fixed at 4).
        padding: reflect padding added on each side (default 1).
    """

    def __init__(self, in_channels, out_channels, stride=2, padding=1):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            # padding was previously omitted, so it defaulted to 0: the
            # requested reflect padding never happened and every block shrank
            # its feature map more than intended.
            nn.Conv2d(in_channels, out_channels, 4, stride, padding,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator.

    ``forward(x, y)`` concatenates the condition ``x`` and the candidate ``y``
    along the channel axis (hence ``in_channels * 2`` input channels). It
    returns a 1-channel map of raw scores; no sigmoid is applied.

    Args:
        in_channels: channels of each of ``x`` and ``y``.
        features: output widths of the successive stages. The first stage is a
            plain stride-2 conv + LeakyReLU without BatchNorm. Every later
            stage is a ``CNNBlock`` with stride 2, except the last, which uses
            stride 1. A final stride-1 conv maps to 1 channel.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy to a list so any
        # sequence is accepted.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )  # according to the paper, the first (64-channel) layer has no BatchNorm2d
        layers = []
        channels = features[0]
        last_idx = len(features) - 1
        for idx in range(1, len(features)):
            # Pick the stride by position: comparing the value against
            # features[-1] also hit earlier stages that share the last width.
            stride = 1 if idx == last_idx else 2
            layers.append(CNNBlock(channels, features[idx], stride=stride))
            channels = features[idx]

        layers.append(
            nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        return self.model(x)

def test():
    """Smoke-test ``Discriminator``: print the module and the shape of its output for two random 1x3x256x256 inputs."""
    condition = torch.randn((1, 3, 256, 256))
    candidate = torch.randn((1, 3, 256, 256))
    disc = Discriminator(in_channels=3)
    patch_scores = disc(condition, candidate)
    print(disc)
    print(patch_scores.shape)

# Function for viewing images

In [None]:
def view_image(image_tensor):
    """Show a single normalised image tensor with matplotlib, without axes.

    The tensor is detached, permuted from channel-first to channel-last, moved
    to the CPU and converted to NumPy. It is then mapped with ``x * 0.5 + 0.5``,
    which undoes ``Normalize(mean=0.5, std=0.5)``.
    NOTE(review): assumes a 3-D (C, H, W) tensor in roughly [-1, 1] — confirm for new callers.
    """
    hwc = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
    plt.imshow(hwc * 0.5 + 0.5)
    plt.axis('off')
    plt.show()

In [None]:
# `clean_imgs` is only created by the training loop further down, so on a
# fresh kernel (Restart & Run All) this cell used to raise NameError.
if 'clean_imgs' in globals():
    view_image(clean_imgs[1])
else:
    print("clean_imgs is not defined yet - run the training loop cell first.")

# Function for showing a set of images

In [None]:
def show_images(hazy_imgs, clean_imgs, generated_imgs, num_images=5):
    """Plot hazy, clean and generated images in a 3-row grid (one column per sample).

    Row 0 shows ``hazy_imgs``, row 1 ``clean_imgs`` and row 2 ``generated_imgs``.
    Each image is detached, permuted from channel-first to channel-last,
    mapped with ``x * 0.5 + 0.5`` and clipped to ``[0, 1]`` before plotting.
    NOTE(review): assumes each element is a (C, H, W) tensor normalised to
    about [-1, 1] — confirm for new callers.

    Args:
        hazy_imgs, clean_imgs, generated_imgs: indexable batches of image tensors.
        num_images: requested number of columns. It is clamped to the smallest
            batch length, so a short final batch no longer raises IndexError.
            Nothing is plotted when it is 0 or less.
    """
    num_images = min(num_images, len(hazy_imgs), len(clean_imgs), len(generated_imgs))
    if num_images <= 0:
        return

    # squeeze=False keeps `axes` 2-D; with one column, plt.subplots would
    # otherwise return a 1-D array and axes[row, i] would fail.
    fig, axes = plt.subplots(3, num_images, figsize=(15, 10), squeeze=False)
    rows = (
        ("Hazy Image", hazy_imgs),
        ("Clean Image", clean_imgs),
        ("Generated Image", generated_imgs),
    )
    for i in range(num_images):
        for row, (title, batch) in enumerate(rows):
            image = batch[i].detach().permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
            ax = axes[row, i]
            # Clip float round-off so imshow does not warn about out-of-range data.
            ax.imshow(image.clip(0.0, 1.0))
            ax.axis('off')
            ax.set_title(title)

    plt.tight_layout()
    plt.show()
    # This is called repeatedly from the training loop; release the figure.
    plt.close(fig)

In [None]:
# These batches are produced by the training loop further down; on a fresh
# kernel (Restart & Run All) this cell used to raise NameError.
if all(name in globals() for name in ('hazy_imgs', 'clean_imgs', 'fake_imgs')):
    show_images(hazy_imgs, clean_imgs, fake_imgs)
else:
    print("hazy_imgs / clean_imgs / fake_imgs are not defined yet - run the training loop cell first.")

# Create dataloaders

In [None]:
# Input-only augmentation pipeline, as used in the referenced GitHub repo
# (the dataloaders below build their own transforms).
# ToTensor must run before Normalize: Normalize only accepts tensors, so the
# previous order (Normalize first) raised TypeError on PIL images.
transform_only_input = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# For later use: separate transforms for the clean and hazy images

In [None]:
root_dir = 'Task2Dataset'
train_dir, val_dir = (os.path.join(root_dir, split) for split in ('train', 'val'))

# Training pipeline: random flip and colour jitter, then ToTensor (values in
# [0, 1]) and Normalize(0.5, 0.5), which maps each channel to [-1, 1].
# NOTE(review): if DehazingDataset applies this transform separately to the
# hazy and the clean image, the random flip/jitter will differ between the
# two and the pair will no longer line up — confirm in database2.
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Validation pipeline: deterministic, same normalisation as training.
transform_val = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

train_dataset = DehazingDataset(train_dir, transform_train)
val_dataset = DehazingDataset(val_dir, transform_val)

# Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Standard pipeline (no augmentation): baseline for comparison with the L1-loss + scheduler variant (running it overwrites the dataloaders above)

In [None]:
# One deterministic pipeline shared by both splits: ToTensor (values in
# [0, 1]) followed by Normalize(0.5, 0.5), which maps each channel to [-1, 1].
# Running this cell rebinds train_dataset/val_dataset and both dataloaders.
root_dir = 'Task2Dataset'
train_dir, val_dir = [os.path.join(root_dir, name) for name in ('train', 'val')]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_dataset, val_dataset = (DehazingDataset(split_dir, transform) for split_dir in (train_dir, val_dir))

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Initialise

In [None]:
# Seed the Python, NumPy and torch RNGs before any weights are initialised,
# so that reruns of the notebook start from the same random state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Define the generator and discriminator
generator = Generator()
discriminator = Discriminator()

In [None]:
# Adversarial objective computed on raw discriminator outputs (the loss
# applies the sigmoid internally). Both networks use Adam with lr 2e-4 and
# betas (0.5, 0.999).
criterion = nn.BCEWithLogitsLoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Code adapted from / referenced in this repo: https://github.com/AquibPy/Pix2Pix-Conditional-GANs/tree/main

In [None]:
# Training loop.
# NOTE(review): only the adversarial (BCE) loss is optimised here; no L1
# reconstruction term is added in this cell.
# NOTE(review): models and batches are never moved to a GPU, so this runs on
# the default (CPU) device.
num_epochs = 10
for epoch in range(num_epochs):
    # Running sums so the epoch summary reports mean losses rather than
    # whatever the last (possibly short) batch produced.
    g_loss_total = 0.0
    d_loss_total = 0.0
    batch_no = 0

    for hazy_imgs, clean_imgs in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        batch_no += 1
        real_imgs = clean_imgs

        # ---- Discriminator step ----
        optimizer_D.zero_grad()

        # The generator takes the hazy images (x) as input.
        fake_imgs = generator(hazy_imgs)

        # The discriminator scores (hazy, clean) pairs against (hazy, generated)
        # pairs; detach() keeps the D loss from back-propagating into G.
        real_outputs = discriminator(hazy_imgs, real_imgs)
        fake_outputs = discriminator(hazy_imgs, fake_imgs.detach())

        real_labels = torch.ones_like(real_outputs)
        fake_labels = torch.zeros_like(fake_outputs)

        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        # Reuse fake_imgs: its graph through the generator is still intact.
        # A second generator forward pass doubled the cost and updated the
        # generator's BatchNorm running statistics twice per batch.
        optimizer_G.zero_grad()
        fake_outputs = discriminator(hazy_imgs, fake_imgs)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

        g_loss_total += g_loss.item()
        d_loss_total += d_loss.item()

        # Periodic visual check of the current generator output.
        if batch_no % 20 == 0:
            show_images(hazy_imgs, clean_imgs, fake_imgs, num_images=5)

    if batch_no == 0:
        # Without this guard an empty dataloader raised NameError on g_loss
        # (or reported stale losses from an earlier run).
        print(f"Epoch [{epoch + 1}/{num_epochs}]: train_dataloader yielded no batches")
        continue

    # Mean losses over the epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Generator Loss: {g_loss_total / batch_no:.4f}, Discriminator Loss: {d_loss_total / batch_no:.4f}")

# Save the trained models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')