<a href="https://colab.research.google.com/github/yashpandey474/CSF425-Deep-Learning-Project-Task-2/blob/master/TASK2_TRAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import os
from database2 import DehazingDataset
import matplotlib.pyplot as plt

In [None]:
# Generator
class Generator(nn.Module):
    """Encoder-decoder generator mapping a 3-channel image to a 3-channel image.

    The encoder has three stride-2 convolutions (3 -> 64 -> 128 -> 256
    channels), each halving height and width. The decoder has three stride-2
    transposed convolutions (256 -> 128 -> 64 -> 3), each doubling them back.
    Every hidden layer is followed by BatchNorm + ReLU. The output goes
    through Tanh, so values lie in [-1, 1]. Height and width should be
    divisible by 8 for the output to have the same size as the input.

    The submodule names and Sequential indices (``encoder.0`` ... ``encoder.8``,
    ``decoder.0`` ... ``decoder.7``) define the saved ``state_dict`` keys and
    must stay stable.
    """

    def __init__(self):
        super().__init__()

        down_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            down_layers += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.encoder = nn.Sequential(*down_layers)

        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Final projection back to RGB; Tanh matches the [-1, 1] input range.
        up_layers += [nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Tanh()]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)


# Critic for the W-loss GAN. It returns an unbounded real-valued score per
# image, so unlike a standard GAN discriminator there is no final sigmoid.
class Discriminator(nn.Module):
    """Convolutional critic producing one score per input image.

    Four stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 1 channels) with
    LeakyReLU(0.2) activations (BatchNorm on the 128- and 256-channel layers)
    produce a 1-channel score map. Its spatial mean is the score.

    The ``model.<i>`` Sequential indices define the saved ``state_dict`` keys
    and must stay stable.
    """

    def __init__(self):
        super().__init__()
        layers = [nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2)]
        for c_in, c_out in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        # Last conv collapses to a single channel: the raw (unbounded) scores.
        layers.append(nn.Conv2d(256, 1, 4, stride=2, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.model(x)
        # Global average pooling over height and width -> shape (batch, 1).
        return score_map.mean(dim=(2, 3))




In [None]:
# Dataset layout: <root_dir>/train and <root_dir>/val.
root_dir = 'Task2Dataset'
train_dir, val_dir = (os.path.join(root_dir, split) for split in ('train', 'val'))

# ToTensor, then rescale each channel with (x - 0.5) / 0.5. For images
# ToTensor maps into [0, 1] (presumably PIL images from DehazingDataset;
# confirm), this gives [-1, 1], the range of the generator's Tanh output.
# No resizing is applied: the images are assumed to already be 256x256.
# Height and width must be divisible by 8 for the generator to reproduce
# the input size.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

train_dataset = DehazingDataset(train_dir, transform)
val_dataset = DehazingDataset(val_dir, transform)

# Training batches are reshuffled every epoch; validation order is fixed.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

In [None]:
# Build freshly initialised generator and critic networks (generator first,
# so random initialisation happens in the same order as before).
generator, discriminator = Generator(), Discriminator()

In [None]:

# NOTE: `criterion` is binary cross-entropy on logits, i.e. the standard
# min-max GAN loss, NOT the W-loss. The W-loss training loop below computes
# its losses directly from the critic scores and never uses it. It is kept
# only so any other notebook code referring to `criterion` keeps working.
criterion = nn.BCEWithLogitsLoss()

# RMSprop with a small learning rate, as recommended for weight-clipping WGANs.
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=0.00005)
optimizer_G = optim.RMSprop(generator.parameters(), lr=0.00005)

# Critic (discriminator) updates per generator update. WGANs usually update
# the critic more often than the generator (the WGAN paper uses 5). With the
# current value of 1, both networks are updated equally often.
n_critic = 1

In [None]:
num_epochs = 10
epochs = []       # fractional epoch position of every training step
g_losses = []     # mean critic score on generated images, per step
d_losses = []     # critic's estimate of mean(D(real)) - mean(D(fake)), per step
num_samples = 5   # hazy/generated pairs shown in each preview


def _to_displayable(img):
    """Turn one normalized CHW image tensor into an HWC tensor for ``imshow``.

    The input transform maps pixels to [-1, 1] (Normalize with mean=std=0.5),
    and the generator ends in Tanh. Values are therefore mapped back with
    ``x * 0.5 + 0.5`` and clamped to [0, 1]. Passing the raw [-1, 1] tensor
    to ``imshow`` clips every negative value and distorts the preview.
    """
    return (img.detach().cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)


# Run on whichever device the models live on (CPU unless they were moved).
# Batches are moved to match, so GPU-resident models also work.
device = next(generator.parameters()).device
steps_per_epoch = len(train_dataloader)

for epoch in range(num_epochs):
    batches = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}')
    for batch_no, (hazy_imgs, clean_imgs) in enumerate(batches):
        hazy_imgs = hazy_imgs.to(device)
        real_imgs = clean_imgs.to(device)

        # ---- Critic (discriminator) update ----
        # The fakes are only critic inputs here. Skipping the generator's
        # autograd graph gives identical values with less memory.
        with torch.no_grad():
            fake_imgs = generator(hazy_imgs)

        for _ in range(n_critic):
            real_outputs = discriminator(real_imgs)
            fake_outputs = discriminator(fake_imgs)

            optimizer_D.zero_grad()
            # The critic maximizes mean(D(real) - D(fake)); minimize the negation.
            d_loss = -torch.mean(real_outputs - fake_outputs)
            d_loss.backward()
            optimizer_D.step()

            # Weight clipping, as in the original (clipping-based) WGAN.
            for param in discriminator.parameters():
                param.data.clamp_(-0.01, 0.01)

        # ---- Generator update ----
        optimizer_G.zero_grad()
        # Regenerate with gradients enabled so the W-loss reaches the generator.
        fake_imgs = generator(hazy_imgs)
        fake_outputs = discriminator(fake_imgs)
        g_loss = -torch.mean(fake_outputs)
        g_loss.backward()
        optimizer_G.step()

        epochs.append(epoch + batch_no / steps_per_epoch)
        # Logged negated, i.e. as the quantities each network maximizes.
        g_losses.append(-g_loss.item())
        d_losses.append(-d_loss.item())

        if batch_no % 30 == 0:
            # A batch (e.g. the last one of an epoch) may hold fewer than
            # num_samples images.
            n_show = min(num_samples, hazy_imgs.size(0))

            # Preview in eval mode without gradients. This neither builds an
            # autograd graph nor updates the BatchNorm running statistics.
            was_training = generator.training
            generator.eval()
            with torch.no_grad():
                generated_images = generator(hazy_imgs[:n_show]).cpu()
            generator.train(was_training)

            plt.figure(figsize=(10, 4))
            for i in range(n_show):
                plt.subplot(2, num_samples, i + 1)
                plt.imshow(_to_displayable(hazy_imgs[i]))
                plt.title('Hazy Image')
                plt.axis('off')

                plt.subplot(2, num_samples, num_samples + i + 1)
                plt.imshow(_to_displayable(generated_images[i]))
                plt.title('Generated Image')
                plt.axis('off')

            plt.tight_layout()
            plt.show()

# Save the trained models
torch.save(generator.state_dict(), 'generator_wloss.pth')
torch.save(discriminator.state_dict(), 'discriminator_wloss.pth')


In [None]:
# Shape of the critic scores for the last real batch seen during training.
real_outputs.size()

In [None]:
# NOTE: nothing in this notebook defines `real_labels`, so `real_labels.shape`
# raised NameError. It looks like a leftover from a BCE/min-max version of
# the loop. The W-loss critic needs no labels; instead check that it
# returns exactly one score per real image.
real_outputs.shape[0] == real_imgs.shape[0]

In [None]:
# Shape of the critic scores for the last generated batch (from the
# generator update step).
fake_outputs.size()

In [None]:
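# Note: the critic's last conv layer outputs a 16x16 score map for each
# 256x256 image, i.e. 256 values per image (256*32 for a batch of 32)
# instead of one score per image. The global average pooling in
# Discriminator.forward (torch.mean over dims 2 and 3) reduces this to
# shape (batch, 1).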