<a href="https://colab.research.google.com/github/yashpandey474/CSF425-Deep-Learning-Project-Task-2/blob/master/TASK2_TRAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import os
from google.colab import drive

In [3]:
drive.mount('/content/drive')

Mounted at /content/drive


In [20]:


# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Discriminator
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, 4, stride=2, padding=1)
        )

    def forward(self, x):
        x = self.model(x)
        # Perform global average pooling
        x = torch.mean(x, dim=(2, 3))
        return torch.sigmoid(x)



In [11]:


class DehazingDataset(data.Dataset):
    def __init__(self, root_dir, transform=None):
        #Get the images
        self.root_dir = root_dir
        hazy_images_path = os.path.join(root_dir, 'hazy')
        clean_images_path = os.path.join(root_dir, 'GT')


        self.hazy_images = [os.path.join(hazy_images_path,f) for f in os.listdir(hazy_images_path) if  f.endswith('.jpg') or f.endswith('.png') or f.endswith('.jpeg')]
        self.clean_images = [os.path.join(clean_images_path, f) for f in os.listdir(clean_images_path) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.jpeg')]

        #Filter the images to ensure they are counterparts of the same scene
        self.size = len(self.hazy_images)
        self.transform=transform

    def __getitem__(self, index):
        hazy_img = self.rgb_loader(self.hazy_images[index])
        clean_img = self.rgb_loader(self.clean_images[index])
        hazy_img = self.transform(hazy_img)
        clean_img = self.transform(clean_img)
        return hazy_img, clean_img

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __len__(self):
        return self.size



In [12]:
root_dir = '/content/drive/MyDrive/Task2/Dataset'
train_dir = os.path.join(root_dir, 'train')
val_dir = os.path.join(root_dir, 'val')
transform = transforms.Compose([
                                #  transforms.Resize((224, 224)), # ASSUMING NO NEED FOR RESIZING AS ALL IMAGES ARE ALREADY 256*256
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
                                 ])

train_dataset = DehazingDataset(train_dir, transform)
val_dataset = DehazingDataset(val_dir, transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

In [21]:
# Define the generator and discriminator
generator = Generator()
discriminator = Discriminator()


In [22]:

# Define the loss function and optimizer
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))



In [None]:

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Training the generator and discriminator
    for hazy_imgs, clean_imgs in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        # Training the discriminator
        discriminator.zero_grad()
        real_imgs = clean_imgs

        # GENERATOR TAKES HAZY IMAGES AS INPUT
        fake_imgs = generator(hazy_imgs)

        real_labels = torch.ones(real_imgs.size(0), 1)
        fake_labels = torch.zeros(fake_imgs.size(0), 1)

        # PREDICTIONS OF DISCRIMINATOR FOR REAL IMAGES
        real_outputs = discriminator(real_imgs)

        # PREDICTIONS OF DISCRIMINATOR FOR FAKE IMAGES
        fake_outputs = discriminator(fake_imgs.detach())


        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Training the generator
        generator.zero_grad()
        fake_imgs = generator(hazy_imgs)
        fake_outputs = discriminator(fake_imgs)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    # Print losses
    print(f"Epoch [{epoch + 1}/{num_epochs}], Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")

# Save the trained models
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

  self.pid = os.fork()
Epoch 1/10:   3%|▎         | 8/239 [06:04<2:44:39, 42.77s/it]

In [17]:
real_outputs.shape

torch.Size([8192, 1])

In [18]:
real_labels.shape

torch.Size([32, 1])

In [19]:
fake_outputs.shape

torch.Size([8192, 1])

In [None]:
#output is 256*32 instead of 32 outputs