In [1]:
import os
import time
import datetime
import time
import numpy as np
import matplotlib.pyplot as plt
import random

from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import torchvision
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from torchvision.models import vgg16, vgg19
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

from LabDataset import PairedImageDataset 
from Pairing_Images import PairFinder

In [2]:
# Image geometry. In the visible cells only IMG_SHAPE[0] is read: it is passed
# as image_size when the dataset is constructed.
# NOTE(review): presumably (height, width, channels) - confirm.
IMG_SHAPE = (256,256,3)
# Not referenced in the visible cells.
TARGET_SHAPE = (256,256,3)
# Mini-batch size used by the full-dataset DataLoader.
BATCH_SIZE = 1
# Dataset hyperparameters
# Passed to the dataset as subset_name.
subset = "agri"
# NOTE(review): this is the string "True", not the boolean True. Any non-empty
# string (including "False") is truthy, so confirm what the dataset class
# expects for save_dataframe.
save_dataframe = "True"
# Directories passed to the dataset as s1_dir and s2_dir respectively.
s1_image_path = "Dataset/agri/s1/"
s2_image_path = "Dataset/agri/s2/"

In [3]:
def _leading_subset(dataset, limit):
    """Return ``(indices, Subset)`` covering the first ``min(limit, len(dataset))`` samples."""
    indices = list(range(min(limit, len(dataset))))
    return indices, Subset(dataset, indices)


image_dataset = PairedImageDataset(
    s1_dir=s1_image_path,
    s2_dir=s2_image_path,
    subset_name=subset,
    save_dataframe=save_dataframe,
    image_size=IMG_SHAPE[0],
)

# 1. Loader over the whole dataset, shuffled.
dataloader = DataLoader(image_dataset, shuffle=True, batch_size=BATCH_SIZE)
print(f"Total Instances = {len(image_dataset)}")

# 2. Fixed-order loader over the leading samples (at most 1000).
subset_indices, subset_dataset = _leading_subset(image_dataset, 1000)
subset_loader = DataLoader(subset_dataset, shuffle=False, batch_size=1)
print(f"Subset Instances (1000 max) = {len(subset_dataset)}")

# 3. Fixed-order loader over the leading samples (at most 10), kept small for plotting.
plot_indices, plot_dataset = _leading_subset(image_dataset, 10)
plot_loader = DataLoader(plot_dataset, shuffle=False, batch_size=1)
print(f"Plot Instances (10 max) = {len(plot_dataset)}")


Total Instances = 4000
Subset Instances (1000 max) = 1000
Plot Instances (10 max) = 10


In [4]:
# Peek at the first sample only and report the shapes of its two tensors.
first_pair = next(iter(subset_dataset), None)
if first_pair is not None:
    sample_input, sample_target = first_pair
    print(sample_input.shape,sample_target.shape)

torch.Size([1, 256, 256]) torch.Size([2, 256, 256])


In [5]:
import torch
import torch.nn as nn

class UNetBlockDown(nn.Module):
    """Encoder stage: 4x4 stride-2 convolution, optional batch norm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 the convolution maps a spatial size
    ``H`` to ``floor(H / 2)``. The convolution never has a bias term, whether
    or not batch norm follows it.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        apply_batchnorm: When True, insert ``BatchNorm2d`` between the
            convolution and the activation.
    """

    def __init__(self, in_channels, out_channels, apply_batchnorm=True):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        # Zero or one normalisation layers, spliced between conv and activation
        # so the Sequential indices match the conditional-append construction.
        norm = (nn.BatchNorm2d(out_channels),) if apply_batchnorm else ()
        self.block = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        """Apply the stage to ``x`` and return the downsampled feature map."""
        return self.block(x)

class UNetBlockUp(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv, batch norm, ReLU, optional dropout.

    With kernel 4, stride 2 and padding 1 the transposed convolution doubles
    the spatial size. ``forward`` then concatenates the result with a skip
    tensor on the channel axis, so the returned tensor has
    ``out_channels + skip_input.shape[1]`` channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the transposed convolution.
        apply_dropout: When True, append ``Dropout(0.5)`` after the ReLU.
    """

    def __init__(self, in_channels, out_channels, apply_dropout=False):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        stages = [upsample, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        if apply_dropout:
            stages += [nn.Dropout(0.5)]
        self.block = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Upsample ``x`` and return it stacked in front of ``skip_input`` (dim=1)."""
        upsampled = self.block(x)
        return torch.cat([upsampled, skip_input], dim=1)

class UNetGeneratorLAB(nn.Module):
    """U-Net generator with eight downsampling and eight upsampling stages.

    Every encoder stage halves the spatial size and every decoder stage doubles
    it, so the output has the same height and width as the input. Because there
    are eight stride-2 encoders and each decoder output is concatenated with the
    mirrored encoder output, the input height and width must be multiples of
    256 for the skip tensors to line up. The final layer is a transposed
    convolution followed by ``Tanh``, so outputs lie in [-1, 1].

    Args:
        in_channels: Channels of the input tensor (default 1).
        out_channels: Channels of the generated tensor (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):  # L -> ab
        super().__init__()

        # Encoder. The first and last stages skip batch norm.
        self.down1 = UNetBlockDown(in_channels, 64, apply_batchnorm=False)
        self.down2 = UNetBlockDown(64, 128)
        self.down3 = UNetBlockDown(128, 256)
        self.down4 = UNetBlockDown(256, 512)
        self.down5 = UNetBlockDown(512, 512)
        self.down6 = UNetBlockDown(512, 512)
        self.down7 = UNetBlockDown(512, 512)
        self.down8 = UNetBlockDown(512, 512, apply_batchnorm=False)

        # Decoder. From up2 onwards the input width is the previous stage's
        # output plus the skip tensor it was concatenated with. The first
        # three stages use dropout.
        self.up1 = UNetBlockUp(512, 512, apply_dropout=True)
        self.up2 = UNetBlockUp(1024, 512, apply_dropout=True)
        self.up3 = UNetBlockUp(1024, 512, apply_dropout=True)
        self.up4 = UNetBlockUp(1024, 512)
        self.up5 = UNetBlockUp(1024, 256)
        self.up6 = UNetBlockUp(512, 128)
        self.up7 = UNetBlockUp(256, 64)

        # up7 yields 64 + 64 channels (its own output plus down1's skip).
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # squashes the generated channels into [-1, 1]
        )

    def forward(self, x):
        """Run the encoder, then the decoder with mirrored skip connections."""
        encoders = (
            self.down1, self.down2, self.down3, self.down4,
            self.down5, self.down6, self.down7, self.down8,
        )
        decoders = (
            self.up1, self.up2, self.up3, self.up4,
            self.up5, self.up6, self.up7,
        )

        # Record every encoder activation; they are consumed deepest-first.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)

        # The deepest activation is the bottleneck, not a skip connection.
        features = skips.pop()
        for decode in decoders:
            features = decode(features, skips.pop())

        return self.final(features)


In [6]:
import torch
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    """
    PatchGAN-style discriminator for LAB image colorization.

    ``forward`` stacks its two inputs on the channel axis and returns a
    one-channel map of raw scores (no sigmoid is applied inside the model).
    - input_l: L channel (1 channel)
    - target_ab: ab channels (2 channels)
    """

    def __init__(self, in_channels=3):  # L + ab = 1 + 2
        super().__init__()
        stages = [
            # Three stride-2 stages; the first one has no batch norm.
            self._block(in_channels, 64, norm=False),  # [L|ab] = 3 channels
            self._block(64, 128),
            self._block(128, 256),
            # Two stride-1 stages, each explicitly zero-padded by one pixel.
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ZeroPad2d(1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
        ]
        self.model = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, norm=True):
        """Build a 4x4 stride-2 conv stage with optional batch norm and LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        middle = [nn.BatchNorm2d(out_channels)] if norm else []
        return nn.Sequential(conv, *middle, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, input_l, target_ab):
        """
        input_l: Tensor [B, 1, H, W] — grayscale input (L channel)
        target_ab: Tensor [B, 2, H, W] — real or generated color (ab channels)
        """
        # Channel-wise concat -> [B, 3, H, W]
        stacked = torch.cat((input_l, target_ab), dim=1)
        return self.model(stacked)


In [7]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Generator: 1 input channel -> 2 output channels, moved to the selected device.
generator = UNetGeneratorLAB(in_channels=1, out_channels=2).to(device)
# Discriminator: 3 input channels, because its forward concatenates the
# 1-channel and 2-channel tensors on the channel axis.
discriminator = PatchDiscriminator(in_channels=3).to(device)

In [17]:
# NOTE(review): this cell carries execution count 17 while its neighbours carry
# 7 and 8, so it was last run out of order. generator_loss and
# discriminator_loss read these names - confirm the notebook still passes
# Restart Kernel -> Run All.
# Weight of the L1 term in generator_loss.
LAMBDA_L1 = 100
# Weight of the perceptual term in generator_loss (used only when requested).
LAMBDA_PERC = 0.01
# Adversarial criterion: binary cross-entropy computed on raw logits.
loss_object = nn.BCEWithLogitsLoss()
# Mean absolute error between generated and target tensors.
l1_loss_fn = nn.L1Loss()


In [8]:
def generator_loss(disc_generated_output, gen_output, target, include_perceptual):
    """Compute the generator objective.

    The objective is the adversarial loss plus ``LAMBDA_L1`` times the L1
    distance to the target, plus ``LAMBDA_PERC`` times the perceptual loss
    when ``include_perceptual`` is truthy.

    Args:
        disc_generated_output: Discriminator logits for the generated sample.
        gen_output: Tensor produced by the generator.
        target: Ground-truth tensor compared against ``gen_output``.
        include_perceptual: Whether to add the perceptual term.

    Returns:
        ``(total, gan, l1)`` without the perceptual term, or
        ``(total, gan, l1, perceptual)`` with it. The tuple length depends on
        ``include_perceptual``.
    """
    # The generator is rewarded when the discriminator labels its output as real.
    adversarial = loss_object(disc_generated_output, torch.ones_like(disc_generated_output))
    reconstruction = l1_loss_fn(gen_output, target)
    base_total = adversarial + (LAMBDA_L1 * reconstruction)

    if not include_perceptual:
        return base_total, adversarial, reconstruction

    perceptual = perceptual_loss(target, gen_output)
    return base_total + (LAMBDA_PERC * perceptual), adversarial, reconstruction, perceptual

In [9]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator objective.

    Args:
        disc_real_output: Discriminator logits for the real pair; scored
            against an all-ones target.
        disc_generated_output: Discriminator logits for the generated pair;
            scored against an all-zeros target.

    Returns:
        The sum of the two losses (not averaged).
    """
    loss_on_real = loss_object(disc_real_output, torch.ones_like(disc_real_output))
    loss_on_fake = loss_object(disc_generated_output, torch.zeros_like(disc_generated_output))
    return loss_on_real + loss_on_fake


In [10]:
# Adam hyperparameters for the two optimizers.
# Learning rate for the generator's optimizer.
GEN_LR = 0.0002
# Learning rate for the discriminator's optimizer.
DISC_LR = 0.0002
# Adam betas, shared by both optimizers.
BETA_1 = 0.5
BETA_2 = 0.999


In [11]:
# One Adam optimizer per network; each gets its own learning rate and the
# shared (BETA_1, BETA_2) pair.
generator_optimizer = Adam(generator.parameters(), lr=GEN_LR, betas=(BETA_1, BETA_2))
discriminator_optimizer = Adam(discriminator.parameters(), lr=DISC_LR, betas=(BETA_1, BETA_2))

# Perceptual Loss