In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torchvision.datasets as dset
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import time
from torchvision.transforms import Lambda
from torch.optim import lr_scheduler
import cv2

In [None]:

# Colab-only: mount Google Drive so the dataset paths under /content/drive resolve.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


# Generator (U-Net architecture)

In [None]:
def down_conv(in_channels, out_channels, kernel_size, stride, padding):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) block.

    Whether the block downsamples depends on ``stride``/``padding``; with
    stride 1 and "same" padding it keeps the spatial size.

    Returns:
        nn.Sequential whose children are indexed 0 (conv), 1 (norm), 2 (activation).
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.BatchNorm2d(out_channels),
        # In-place activation overwrites the BatchNorm output to save memory.
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

def up_conv(in_channels, out_channels, kernel_size, stride, padding):
    """Build a ConvTranspose2d -> BatchNorm2d -> ReLU upsampling block.

    Returns:
        nn.Sequential whose children are indexed 0 (transposed conv), 1 (norm),
        2 (activation).
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, norm, activation)

class Generator(nn.Module):
    """U-Net-style generator mapping a 3-channel image to a 3-channel image.

    Three encoder stages feed a two-stage decoder. Each decoder stage upsamples
    with a transposed convolution, concatenates the encoder feature map of the
    same resolution (skip connection) and fuses the result with a 3x3
    convolution. A final 1x1 convolution projects to 3 channels with no output
    activation, so the output range is unbounded.

    Input height and width must be divisible by 4 for the skip concatenations
    to line up. Shape comments below assume a (N, 3, 256, 256) input.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Construction order is kept fixed because it determines the
        # order in which parameters draw from the random initialiser.
        self.down_conv_1 = down_conv(3, 64, 5, 1, 2)      # keeps H, W
        self.down_conv_2 = down_conv(64, 128, 4, 2, 1)    # halves H, W
        self.down_conv_3 = down_conv(128, 256, 4, 2, 1)   # halves H, W
        # NOTE(review): down_conv_4 is never used in forward(); it only adds
        # parameters to the state_dict. Kept so existing checkpoints still load.
        self.down_conv_4 = down_conv(256, 512, 4, 2, 1)

        # Decoder: upsample (doubles H, W), concatenate skip, fuse channels.
        self.up_trans_3 = up_conv(256, 128, 4, 2, 1)
        self.up_conv_3 = down_conv(256, 128, 3, 1, 1)
        self.up_trans_4 = up_conv(128, 64, 4, 2, 1)
        self.up_conv_4 = down_conv(128, 64, 3, 1, 1)

        self.out = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, image):
        """Run the encoder-decoder on ``image`` and return the 3-channel output."""
        skip_full = self.down_conv_1(image)        # (N, 64, 256, 256)
        skip_half = self.down_conv_2(skip_full)    # (N, 128, 128, 128)
        bottleneck = self.down_conv_3(skip_half)   # (N, 256, 64, 64)

        # (N, 128, 128, 128) after upsampling, (N, 256, ...) after the concat, back to 128 after fusing.
        decoded = self.up_conv_3(torch.cat([self.up_trans_3(bottleneck), skip_half], 1))
        # (N, 64, 256, 256) after upsampling, (N, 128, ...) after the concat, back to 64 after fusing.
        decoded = self.up_conv_4(torch.cat([self.up_trans_4(decoded), skip_full], 1))

        return self.out(decoded)                   # (N, 3, 256, 256)

In [None]:
# Smoke test: the generator must map an image batch to an output of the same shape.
# A small batch under no_grad keeps this cheap. At batch 32, a single 64-channel
# 256x256 float feature map is already ~512 MiB, and autograd keeps many of them.
image = torch.rand((2, 3, 256, 256))
model = Generator()
with torch.no_grad():
    output = model(image)
print(output.size())
assert output.shape == image.shape, f"unexpected generator output shape {tuple(output.shape)}"

# Discriminator (conditional, patch-based; adapted from an external repository)

In [None]:
class CNNBlock(nn.Module):
    """Discriminator stage: 4x4 Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride. With the default padding, stride 2 halves an
            even spatial size and stride 1 reduces it by one.
        padding: reflect padding added on each side. Previously no padding was
            passed, so the default of 0 made ``padding_mode='reflect'`` a no-op.
            Pass ``padding=0`` to reproduce that old behaviour.
    """

    def __init__(self, in_channels, out_channels, stride=2, padding=1):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            # bias=False: the following BatchNorm2d supplies its own learnable shift.
            nn.Conv2d(in_channels, out_channels, 4, stride, padding=padding, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply conv -> norm -> activation to ``x``."""
        return self.conv(x)


class Discriminator(nn.Module):
    """Conditional, fully convolutional discriminator.

    The conditioning image ``x`` and the candidate image ``y`` are concatenated
    along the channel axis, so the first convolution sees ``2 * in_channels``
    channels. The output is a 1-channel map of raw logits (no sigmoid), which
    is what ``nn.BCEWithLogitsLoss`` expects.

    Args:
        in_channels: channels of each of ``x`` and ``y``.
        features: output channels of successive stages. ``features[0]`` is the
            initial conv (no BatchNorm). Each later entry becomes a CNNBlock with
            stride 2, except the final entry, which uses stride 1. A tuple default
            avoids the shared-mutable-default pitfall; any sequence is accepted.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )  # according to the paper, the first 64-channel layer has no BatchNorm2d

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Choose the stride by position, not by value. Otherwise repeated channel
            # counts (e.g. [64, 512, 512]) would make several blocks stride 1.
            layers.append(CNNBlock(prev_channels, feature, stride=1 if index == last_index else 2))
            prev_channels = feature

        # Final projection to one logit per spatial location.
        layers.append(
            nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score ``y`` conditioned on ``x``.

        For 4-D (N, C, H, W) inputs this returns an (N, 1, H', W') logit map.
        """
        xy = torch.cat([x, y], dim=1)
        return self.model(self.initial(xy))



In [None]:
def test():
    """Smoke-test the Discriminator on one random input pair.

    Prints the module tree and the shape of the returned logit map.
    """
    cond_img = torch.randn((1, 3, 256, 256))
    target_img = torch.randn((1, 3, 256, 256))
    disc = Discriminator(in_channels=3)
    logits = disc(cond_img, target_img)
    print(disc)
    print(logits.shape)

test()

In [None]:
class DehazingDataset(Dataset):
    """Paired hazy / ground-truth image dataset.

    Expects ``root_dir/hazy`` and ``root_dir/GT`` to contain image files
    (.jpg, .jpeg or .png, in any letter case). Pairs are formed purely by sorted
    path order: the i-th hazy file is matched with the i-th GT file.

    Args:
        root_dir: directory holding the ``hazy`` and ``GT`` sub-directories.
        transform: optional callable applied separately to each RGB PIL image.
            It is called independently on both images, so it should be
            deterministic (e.g. a fixed Resize) or the pair will stop lining up.
            With ``None`` the PIL images are returned unchanged.

    Raises:
        ValueError: if the two folders contain different numbers of images.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.hazy_images = self._list_images(os.path.join(root_dir, 'hazy'))
        self.clean_images = self._list_images(os.path.join(root_dir, 'GT'))

        # Pairing is positional. A count mismatch would either raise IndexError
        # in __getitem__ or silently pair images from different scenes.
        if len(self.hazy_images) != len(self.clean_images):
            raise ValueError(
                f"Found {len(self.hazy_images)} hazy images but "
                f"{len(self.clean_images)} GT images under {root_dir!r}"
            )

        self.size = len(self.hazy_images)
        self.transform = transform

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted full paths of image files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __getitem__(self, index):
        """Return the ``(hazy, clean)`` pair at ``index``, transformed if a transform was given."""
        hazy_img = self.rgb_loader(self.hazy_images[index])
        clean_img = self.rgb_loader(self.clean_images[index])
        if self.transform is not None:
            hazy_img = self.transform(hazy_img)
            clean_img = self.transform(clean_img)
        return hazy_img, clean_img

    def rgb_loader(self, path):
        """Open the image at ``path`` and return it converted to RGB.

        ``convert`` runs inside the ``with`` block, so the pixel data is read
        before the file handle is closed.
        """
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __len__(self):
        """Number of (hazy, clean) pairs."""
        return self.size

# Load the dataset

In [None]:
# Dataset location on the mounted Google Drive; train/ and val/ are sub-folders.
root_dir = "/content/drive/My Drive/Task2Dataset"
train_dir = os.path.join(root_dir, 'train')
val_dir = os.path.join(root_dir, 'val')


def scale_by_max_abs(tensor, eps=1e-8):
    """Divide ``tensor`` by its largest absolute value, so that value becomes 1.

    ``eps`` guards against division by zero: the previous
    ``x / torch.max(x.abs())`` produced NaNs for an all-zero (black) image.
    """
    return tensor / tensor.abs().max().clamp(min=eps)


transform = transforms.Compose([
    transforms.Resize((256, 256)),  # resize to the 256x256 size the models are built around
    transforms.ToTensor(),          # PIL image -> float tensor (C, H, W)
    Lambda(scale_by_max_abs),       # per-image rescale so the brightest value is 1
])


In [None]:
# Only the training split is loaded; build DehazingDataset(val_dir, transform) the same way to validate.
train_dataset = DehazingDataset(root_dir=train_dir, transform=transform)

In [None]:
# Shuffled mini-batches of 16 (hazy, clean) pairs, loaded by two worker processes.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)

# Helper to visualise hazy, clean and generated images

In [None]:
def show_images(hazy_imgs, clean_imgs, generated_imgs, num_images=5):
    """Plot hazy, clean and generated images in a 3-row grid.

    Row 0 shows hazy inputs, row 1 ground truth and row 2 generator outputs,
    with one column per sample.

    Args:
        hazy_imgs, clean_imgs, generated_imgs: batched image tensors, presumably
            (N, 3, H, W) with values intended to lie in [0, 1] -- TODO confirm.
        num_images: requested number of columns. It is capped at the smallest
            batch size, so a short final batch no longer raises IndexError.
    """
    rows = (
        ("Hazy Image", hazy_imgs),
        ("Clean Image", clean_imgs),
        ("Generated Image", generated_imgs),
    )
    num_images = min(num_images, *(len(batch) for _, batch in rows))
    if num_images <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(len(rows), num_images, figsize=(3 * num_images, 10), squeeze=False)
    for row, (title, batch) in enumerate(rows):
        for col in range(num_images):
            # (C, H, W) -> (H, W, C) for imshow. Clip to the [0, 1] float range
            # imshow expects, since generated images are not guaranteed to lie in it.
            image = batch[col].detach().permute(1, 2, 0).cpu().numpy().clip(0, 1)
            ax = axes[row, col]
            ax.imshow(image)
            ax.axis('off')
            ax.set_title(title)

    plt.tight_layout()
    plt.show()
    # Release the figure; this is called repeatedly from the training loop.
    plt.close(fig)

# Initialise device, models, losses, optimisers and LR schedulers

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print('GPU:', torch.cuda.get_device_name(0))  # index 0 = first visible GPU
else:
    print('Using CPU for computations')

In [None]:
# Build both networks (generator first) and move their parameters to `device`.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [None]:
# Adversarial loss on raw discriminator logits, plus an L1 reconstruction loss.
bce_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

# Both networks use Adam with identical hyper-parameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Multiply both learning rates by `lr_gamma` every `lr_step_size` calls to scheduler.step().
lr_step_size = 2
lr_gamma = 0.5

schedulerG, schedulerD = (
    lr_scheduler.StepLR(opt, step_size=lr_step_size, gamma=lr_gamma)
    for opt in (optimizer_G, optimizer_D)
)

# Training loop

In [None]:
# Training loop
num_epochs = 10
# Weight of the L1 reconstruction term relative to the adversarial term.
# NOTE(review): 1.0 keeps the previous objective; the pix2pix paper uses 100.
l1_lambda = 1.0

for epoch in range(num_epochs):
    # Re-enter training mode in case an earlier cell switched a model to eval().
    generator.train()
    discriminator.train()

    batch_no = 0
    # Running sums of per-batch losses for this epoch.
    g_total_loss = 0
    d_total_loss = 0

    for hazy_imgs, clean_imgs in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        hazy_imgs = hazy_imgs.to(device)
        clean_imgs = clean_imgs.to(device)
        real_imgs = clean_imgs

        # The generator takes hazy images as input. One forward pass serves both
        # updates: detached for the discriminator step, attached for the generator step.
        fake_imgs = generator(hazy_imgs)

        # ---- Discriminator step ----
        optimizer_D.zero_grad()
        real_outputs = discriminator(hazy_imgs, real_imgs)
        fake_outputs = discriminator(hazy_imgs, fake_imgs.detach())

        # ones_like/zeros_like already allocate on the outputs' device.
        real_labels = torch.ones_like(real_outputs)
        fake_labels = torch.zeros_like(fake_outputs)

        d_loss_real = bce_loss(real_outputs, real_labels)
        d_loss_fake = bce_loss(fake_outputs, fake_labels)
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()
        d_total_loss += d_loss.item()

        # ---- Generator step ----
        optimizer_G.zero_grad()
        # Re-score the fakes with the updated discriminator; the generator wants them labelled real.
        fake_outputs = discriminator(hazy_imgs, fake_imgs)
        g_loss = bce_loss(fake_outputs, real_labels)
        g_res_loss = l1_loss(fake_imgs, clean_imgs)
        g_complete_loss = g_loss + l1_lambda * g_res_loss
        g_complete_loss.backward()
        optimizer_G.step()
        g_total_loss += g_complete_loss.item()

        batch_no += 1

        # Display images every 20 batches (capped at the batch size).
        if batch_no % 20 == 0:
            show_images(hazy_imgs, clean_imgs, fake_imgs, num_images=min(5, hazy_imgs.size(0)))

    # Step the LR schedulers once per epoch, AFTER the optimizer updates. Stepping
    # them at the start of the epoch (as before) made the LR decay one epoch early.
    schedulerG.step()
    schedulerD.step()

    # Print average losses for the epoch (guarded against an empty loader).
    n_batches = max(batch_no, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Generator Loss: {g_total_loss / n_batches:.4f}, Discriminator Loss: {d_total_loss / n_batches:.4f}")

    # Save the trained models after each epoch
    torch.save(generator.state_dict(), f'generator_epoch_{epoch + 1}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch + 1}.pth')

# Save the trained models
torch.save(generator.state_dict(), 'generator_l1_cgan.pth')
torch.save(discriminator.state_dict(), 'discriminator_l1_cgan.pth')
