In [1]:
import os
import time
import datetime
import time
import numpy as np
import matplotlib.pyplot as plt
import random

from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import torchvision
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from torchvision.models import vgg16, vgg19
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

from utils.Create_Dataset import PairedImageDataset
from utils.Pairing_Images import PairFinder

In [3]:
IMG_SHAPE = (256,256,3)
TARGET_SHAPE = (256,256,3)
BATCH_SIZE = 1
base_dir = "/home/nvlabs/.cache/kagglehub/datasets/requiemonk/sentinel12-image-pairs-segregated-by-terrain/versions/1/v_2"
# Dataset Hyper Parameters
subset = "agri"
save_dataframe = "True"
s1_image_path = os.path.join(base_dir,subset + "/s1")
s2_image_path = os.path.join(base_dir,subset + "/s2")

In [4]:
image_dataset = PairedImageDataset(s1_dir=s1_image_path, s2_dir=s2_image_path, subset_name=subset, save_dataframe=save_dataframe)
dataloader = DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=True)
print(f"Total Instances = {len(image_dataset)}")

subset_indices = list(range(min(1000, len(image_dataset))))
subset_dataset = Subset(image_dataset, subset_indices)
subset_loader = DataLoader(subset_dataset, batch_size=1, shuffle=False)  # <--- Add this line

print(f"Sample Instances = {len(subset_dataset)}")


Total Instances = 4000
Sample Instances = 1000


In [5]:
for i,j in subset_dataset:
    print(i.shape,j.shape)
    break

torch.Size([3, 256, 256]) torch.Size([3, 256, 256])


# Model Instances

Provide to GPU

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UNetGenerator(in_channels=3, out_channels=2).to(device)
discriminator = PatchDiscriminator(in_channels=3).to(device)

# Metrics

In [1]:
def compute_metrics(generated_image, target_image):
    """
    Compute SSIM, PSNR, MSE, RMSE between generated and target images.
    Both inputs are PyTorch tensors in the range [-1, 1].
    """
    # Move to CPU, detach, and convert to numpy
    generated_image = generated_image.squeeze().detach().cpu().numpy()
    target_image = target_image.squeeze().detach().cpu().numpy()

    # Rescale from [-1, 1] to [0, 1]
    generated_image = (generated_image + 1) / 2.0
    target_image = (target_image + 1) / 2.0

    # Transpose from (C, H, W) to (H, W, C)
    generated_image = np.transpose(generated_image, (1, 2, 0))
    target_image = np.transpose(target_image, (1, 2, 0))

    # Clip values
    generated_image = np.clip(generated_image, 0, 1)
    target_image = np.clip(target_image, 0, 1)

    # Determine win_size
    min_height = min(generated_image.shape[0], target_image.shape[0])
    min_width = min(generated_image.shape[1], target_image.shape[1])
    win_size = min(min_height, min_width)
    win_size = win_size if win_size % 2 == 1 else win_size - 1  # make it odd

    # Ensure win_size is at least 3
    win_size = max(3, win_size)

    # Compute metrics
    ssim = structural_similarity(
        target_image, generated_image,
        channel_axis=-1, data_range=1.0,
        win_size=win_size
    )
    psnr = peak_signal_noise_ratio(target_image, generated_image, data_range=1.0)
    mse = mean_squared_error(target_image.flatten(), generated_image.flatten())
    rmse = np.sqrt(mse)

    return {
        "SSIM": ssim,
        "PSNR": psnr,
        "MSE": mse,
        "RMSE": rmse
    }

def evaluate_model(test_loader, generator, device):
    """
    Evaluate the model on the test set using image quality metrics.
    """
    ssim_scores, psnr_scores, mse_scores, rmse_scores = [], [], [], []

    generator.to(device)
    generator.eval()

    with torch.no_grad():
        for input_image, target_image in test_loader:
            input_image = input_image.to(device)
            target_image = target_image.to(device)

            # Ensure input is 4D (B, C, H, W)
            if input_image.dim() == 3:
                input_image = input_image.unsqueeze(0)
                target_image = target_image.unsqueeze(0)

            generated_image = generator(input_image)

            metrics = compute_metrics(generated_image, target_image)
            ssim_scores.append(metrics["SSIM"])
            psnr_scores.append(metrics["PSNR"])
            mse_scores.append(metrics["MSE"])
            rmse_scores.append(metrics["RMSE"])

    # Print averaged results
    print(f"SSIM: {np.mean(ssim_scores):.4f} | PSNR: {np.mean(psnr_scores):.2f} dB | MSE: {np.mean(mse_scores):.4f} | RMSE: {np.mean(rmse_scores):.4f}")



# Loss Functions

In [11]:
LAMBDA_L1 = 100
LAMBDA_PERC = 0.01
loss_object = nn.BCEWithLogitsLoss()
l1_loss_fn = nn.L1Loss()


#### Perceptual loss

In [12]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

# ------------------> Build VGG Feature Extractor <------------------
class VGG19FeatureExtractor(nn.Module):
    def __init__(self):
        super(VGG19FeatureExtractor, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.feature_extractor = nn.Sequential(*list(vgg.children())[:12])  # Up to relu3_3
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.feature_extractor(x)

# ------------------> Preprocess Function for VGG <------------------
def preprocess(img_tensor):
    """
    img_tensor: (B, 3, H, W), in [-1, 1]
    Output: normalized for VGG (mean-subtracted and scaled)
    """
    # Convert [-1, 1] to [0, 1]
    img_tensor = (img_tensor + 1) / 2.0
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    return normalize(img_tensor)

# ------------------> Perceptual Loss Class <------------------
class PerceptualLoss(nn.Module):
    def __init__(self, device):
        super(PerceptualLoss, self).__init__()
        self.vgg = VGG19FeatureExtractor().to(device)
        self.l1 = nn.L1Loss()

    def forward(self, y_true, y_pred):
        # Preprocess and move to same device as VGG
        y_true = preprocess(y_true).to(next(self.vgg.parameters()).device)
        y_pred = preprocess(y_pred).to(next(self.vgg.parameters()).device)

        # Extract VGG features
        features_true = self.vgg(y_true)
        features_pred = self.vgg(y_pred)

        # Compute L1 loss between features
        return self.l1(features_true, features_pred)

# ------------------> Instantiate on the Correct Device <------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
perceptual_loss = PerceptualLoss(device)




#### Generator Loss = l1_loss * lambda_l1 + perceptaul_loss * lambda

In [13]:
def generator_loss(disc_generated_output, gen_output, target, include_perceptual):
    real_labels = torch.ones_like(disc_generated_output)
    gan_loss = loss_object(disc_generated_output, real_labels)
    l1 = l1_loss_fn(gen_output, target)

    if include_perceptual:
        perc = perceptual_loss(target, gen_output)
        total_loss = gan_loss + (LAMBDA_L1 * l1) + (LAMBDA_PERC * perc)
        return total_loss, gan_loss, l1, perc
    else:
        total_loss = gan_loss + (LAMBDA_L1 * l1)
        return total_loss, gan_loss, l1

#### Discriminator Loss

In [14]:
def discriminator_loss(disc_real_output, disc_generated_output):
    real_labels = torch.ones_like(disc_real_output)
    fake_labels = torch.zeros_like(disc_generated_output)

    real_loss = loss_object(disc_real_output, real_labels)
    fake_loss = loss_object(disc_generated_output, fake_labels)

    total_disc_loss = real_loss + fake_loss
    return total_disc_loss


# Optimizers

In [15]:
GEN_LR = 0.0002
DISC_LR = 0.0002
BETA_1 = 0.5
BETA_2 = 0.999


In [16]:
generator_optimizer = Adam(generator.parameters(), lr=GEN_LR, betas=(BETA_1, BETA_2))
discriminator_optimizer = Adam(discriminator.parameters(), lr=DISC_LR, betas=(BETA_1, BETA_2))

# Train Step

In [17]:
def train_step(grayscale, color, include_perceptual=False):
    grayscale = grayscale.to(device)
    color = color.to(device)

    # Forward pass
    fake_color = generator(grayscale)

    disc_real_output = discriminator(grayscale, color)
    disc_fake_output = discriminator(grayscale, fake_color.detach())

    if include_perceptual:
        gen_loss, adv_loss, l1_loss, perc_loss = generator_loss(disc_fake_output, fake_color, color, include_perceptual)
    else:
        gen_loss, adv_loss, l1_loss = generator_loss(disc_fake_output, fake_color, color, include_perceptual)

    disc_loss = discriminator_loss(disc_real_output, disc_fake_output)

    # Backward pass and optimization
    generator_optimizer.zero_grad()
    gen_loss.backward(retain_graph=True)
    generator_optimizer.step()

    discriminator_optimizer.zero_grad()
    disc_loss.backward()
    discriminator_optimizer.step()

    if include_perceptual:
        return gen_loss, adv_loss, l1_loss, perc_loss, disc_loss
    else:
        return gen_loss, adv_loss, l1_loss, disc_loss


In [18]:
CHECKPOINT_DIR = 'checkpoints_Pix2PixPerceptual'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
loss_history = []

In [19]:
EPOCHS = 100
loss_history = []
include_metrics = True
include_perceptual = True

# Train Module 

In [24]:
loss_history = []
for epoch in range(1, EPOCHS + 1):
    start_time = time.time()

    # Accumulators
    gen_loss_total = adv_loss_total = l1_loss_total = perc_loss_total = disc_loss_total = 0
    num_batches = len(subset_loader)

    for grayscale, color in subset_loader:
        grayscale, color = grayscale.to(device), color.to(device)

        if include_perceptual:
            gen_loss, adv_loss, l1_loss, perc_loss, disc_loss = train_step(grayscale, color, include_perceptual)
            perc_loss_total += perc_loss.item()
        else:
            gen_loss, adv_loss, l1_loss, disc_loss = train_step(grayscale, color, include_perceptual)
            perc_loss = torch.tensor(0.0)  # Placeholder if not used

        # Accumulate all losses
        gen_loss_total += gen_loss.item()
        adv_loss_total += adv_loss.item()
        l1_loss_total += l1_loss.item()
        disc_loss_total += disc_loss.item()

    # Epoch time
    elapsed_time = time.time() - start_time
    minutes, seconds = divmod(elapsed_time, 60)

    # Store loss values (averaged per batch)
    loss_history.append([
        gen_loss_total / num_batches,
        adv_loss_total / num_batches,
        l1_loss_total / num_batches,
        perc_loss_total / num_batches if include_perceptual else 0.0,
        disc_loss_total / num_batches
    ])

    # Print loss summary
    if include_perceptual:
        print(f"Epoch {epoch}/{EPOCHS} | Gen: {gen_loss_total / num_batches:.4f} | "
              f"Adv: {adv_loss_total / num_batches:.4f} | L1: {l1_loss_total / num_batches:.4f} | "
              f"Perc: {perc_loss_total / num_batches:.4f} | Disc: {disc_loss_total / num_batches:.4f} "
              f"| Time: {int(minutes)}m {int(seconds)}s")
    else:
        print(f"Epoch {epoch}/{EPOCHS} | Gen: {gen_loss_total / num_batches:.4f} | "
              f"Adv: {adv_loss_total / num_batches:.4f} | L1: {l1_loss_total / num_batches:.4f} | "
              f"Disc: {disc_loss_total / num_batches:.4f} | Time: {int(minutes)}m {int(seconds)}s")

    # Evaluate model every 10 epochs
    if epoch % 10 == 0 and include_metrics:
        print("-------------------------------------")
        evaluate_model(subset_loader, generator, device)
        print("-------------------------------------")

        # Save models
        torch.save(generator.state_dict(), f"Models/generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"Models/discriminator_epoch_{epoch}.pth")


Epoch 1/100 | Gen: 29.3811 | Adv: 9.9509 | L1: 0.1942 | Perc: 1.4057 | Disc: 0.0118 | Time: 0m 24s
Epoch 2/100 | Gen: 29.4556 | Adv: 10.7915 | L1: 0.1865 | Perc: 1.4008 | Disc: 0.0006 | Time: 0m 24s
Epoch 3/100 | Gen: 29.4244 | Adv: 11.3540 | L1: 0.1806 | Perc: 1.3965 | Disc: 0.0003 | Time: 0m 24s
Epoch 4/100 | Gen: 27.9414 | Adv: 10.4167 | L1: 0.1751 | Perc: 1.3914 | Disc: 0.0585 | Time: 0m 25s
Epoch 5/100 | Gen: 26.6102 | Adv: 9.6184 | L1: 0.1698 | Perc: 1.3866 | Disc: 0.0075 | Time: 0m 23s
Epoch 6/100 | Gen: 27.0140 | Adv: 10.3690 | L1: 0.1663 | Perc: 1.3836 | Disc: 0.0012 | Time: 0m 23s
Epoch 7/100 | Gen: 27.0042 | Adv: 10.8702 | L1: 0.1612 | Perc: 1.3786 | Disc: 0.0007 | Time: 0m 21s
Epoch 8/100 | Gen: 25.9858 | Adv: 10.2474 | L1: 0.1572 | Perc: 1.3748 | Disc: 0.0268 | Time: 0m 26s
Epoch 9/100 | Gen: 25.9290 | Adv: 10.5534 | L1: 0.1536 | Perc: 1.3709 | Disc: 0.0027 | Time: 0m 23s
Epoch 10/100 | Gen: 26.2037 | Adv: 11.1695 | L1: 0.1502 | Perc: 1.3673 | Disc: 0.0011 | Time: 0m 25s
-

In [25]:
for epoch in range(101, 151):
    start_time = time.time()

    # Accumulators
    gen_loss_total = adv_loss_total = l1_loss_total = perc_loss_total = disc_loss_total = 0
    num_batches = len(subset_loader)

    for grayscale, color in subset_loader:
        grayscale, color = grayscale.to(device), color.to(device)

        if include_perceptual:
            gen_loss, adv_loss, l1_loss, perc_loss, disc_loss = train_step(grayscale, color, include_perceptual)
            perc_loss_total += perc_loss.item()
        else:
            gen_loss, adv_loss, l1_loss, disc_loss = train_step(grayscale, color, include_perceptual)
            perc_loss = torch.tensor(0.0)  # Placeholder if not used

        # Accumulate all losses
        gen_loss_total += gen_loss.item()
        adv_loss_total += adv_loss.item()
        l1_loss_total += l1_loss.item()
        disc_loss_total += disc_loss.item()

    # Epoch time
    elapsed_time = time.time() - start_time
    minutes, seconds = divmod(elapsed_time, 60)

    # Store loss values (averaged per batch)
    loss_history.append([
        gen_loss_total / num_batches,
        adv_loss_total / num_batches,
        l1_loss_total / num_batches,
        perc_loss_total / num_batches if include_perceptual else 0.0,
        disc_loss_total / num_batches
    ])

    # Print loss summary
    if include_perceptual:
        print(f"Epoch {epoch}/{EPOCHS} | Gen: {gen_loss_total / num_batches:.4f} | "
              f"Adv: {adv_loss_total / num_batches:.4f} | L1: {l1_loss_total / num_batches:.4f} | "
              f"Perc: {perc_loss_total / num_batches:.4f} | Disc: {disc_loss_total / num_batches:.4f} "
              f"| Time: {int(minutes)}m {int(seconds)}s")
    else:
        print(f"Epoch {epoch}/{EPOCHS} | Gen: {gen_loss_total / num_batches:.4f} | "
              f"Adv: {adv_loss_total / num_batches:.4f} | L1: {l1_loss_total / num_batches:.4f} | "
              f"Disc: {disc_loss_total / num_batches:.4f} | Time: {int(minutes)}m {int(seconds)}s")

    # Evaluate model every 10 epochs
    if epoch % 10 == 0 and include_metrics:
        print("-------------------------------------")
        evaluate_model(subset_loader, generator, device)
        print("-------------------------------------")

        # Save models
        torch.save(generator.state_dict(), f"Models/generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"Models/discriminator_epoch_{epoch}.pth")

Epoch 101/100 | Gen: 23.5989 | Adv: 16.2576 | L1: 0.0733 | Perc: 1.1129 | Disc: 0.0211 | Time: 0m 25s
Epoch 102/100 | Gen: 22.0563 | Adv: 14.7718 | L1: 0.0727 | Perc: 1.1106 | Disc: 0.0070 | Time: 0m 25s
Epoch 103/100 | Gen: 22.1233 | Adv: 14.8447 | L1: 0.0727 | Perc: 1.1094 | Disc: 0.0005 | Time: 0m 23s
Epoch 104/100 | Gen: 22.4289 | Adv: 15.1738 | L1: 0.0724 | Perc: 1.1075 | Disc: 0.0010 | Time: 0m 26s
Epoch 105/100 | Gen: 23.0538 | Adv: 15.8274 | L1: 0.0722 | Perc: 1.1057 | Disc: 0.0003 | Time: 0m 22s
Epoch 106/100 | Gen: 23.6436 | Adv: 16.4542 | L1: 0.0718 | Perc: 1.1038 | Disc: 0.0001 | Time: 0m 21s
Epoch 107/100 | Gen: 23.8929 | Adv: 16.6963 | L1: 0.0719 | Perc: 1.1022 | Disc: 0.0001 | Time: 0m 22s
Epoch 108/100 | Gen: 24.2676 | Adv: 17.0931 | L1: 0.0716 | Perc: 1.1007 | Disc: 0.0000 | Time: 0m 23s
Epoch 109/100 | Gen: 24.5428 | Adv: 17.4048 | L1: 0.0713 | Perc: 1.0988 | Disc: 0.0000 | Time: 0m 22s
Epoch 110/100 | Gen: 23.7970 | Adv: 16.6597 | L1: 0.0713 | Perc: 1.0977 | Disc: 0.

In [26]:
for epoch in range(151, 201):
    start_time = time.time()

    # Accumulators
    gen_loss_total = adv_loss_total = l1_loss_total = perc_loss_total = disc_loss_total = 0
    num_batches = len(subset_loader)

    for grayscale, color in subset_loader:
        grayscale, color = grayscale.to(device), color.to(device)

        if include_perceptual:
            gen_loss, adv_loss, l1_loss, perc_loss, disc_loss = train_step(grayscale, color, include_perceptual)
            perc_loss_total += perc_loss.item()
        else:
            gen_loss, adv_loss, l1_loss, disc_loss = train_step(grayscale, color, include_perceptual)
            perc_loss = torch.tensor(0.0)  # Placeholder if not used

        # Accumulate all losses
        gen_loss_total += gen_loss.item()
        adv_loss_total += adv_loss.item()
        l1_loss_total += l1_loss.item()
        disc_loss_total += disc_loss.item()

    # Epoch time
    elapsed_time = time.time() - start_time
    minutes, seconds = divmod(elapsed_time, 60)

    # Store loss values (averaged per batch)
    loss_history.append([
        gen_loss_total / num_batches,
        adv_loss_total / num_batches,
        l1_loss_total / num_batches,
        perc_loss_total / num_batches if include_perceptual else 0.0,
        disc_loss_total / num_batches
    ])

    # Print loss summary
    if include_perceptual:
        print(f"Epoch {epoch}/{200} | Gen: {gen_loss_total / num_batches:.4f} | "
              f"Adv: {adv_loss_total / num_batches:.4f} | L1: {l1_loss_total / num_batches:.4f} | "
              f"Perc: {perc_loss_total / num_batches:.4f} | Disc: {disc_loss_total / num_batches:.4f} "
              f"| Time: {int(minutes)}m {int(seconds)}s")
    else:
        print(f"Epoch {epoch}/{200} | Gen: {gen_loss_total / num_batches:.4f} | "
              f"Adv: {adv_loss_total / num_batches:.4f} | L1: {l1_loss_total / num_batches:.4f} | "
              f"Disc: {disc_loss_total / num_batches:.4f} | Time: {int(minutes)}m {int(seconds)}s")

    # Evaluate model every 10 epochs
    if epoch % 10 == 0 and include_metrics:
        print("-------------------------------------")
        evaluate_model(subset_loader, generator, device)
        print("-------------------------------------")

        # Save models
        torch.save(generator.state_dict(), f"Models/generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"Models/discriminator_epoch_{epoch}.pth")

Epoch 151/200 | Gen: 24.8637 | Adv: 18.4064 | L1: 0.0645 | Perc: 1.0393 | Disc: 0.0000 | Time: 0m 23s
Epoch 152/200 | Gen: 25.1130 | Adv: 18.6827 | L1: 0.0642 | Perc: 1.0382 | Disc: 0.0289 | Time: 0m 23s
Epoch 153/200 | Gen: 23.5009 | Adv: 17.0833 | L1: 0.0641 | Perc: 1.0370 | Disc: 0.0019 | Time: 0m 22s
Epoch 154/200 | Gen: 23.7227 | Adv: 17.3384 | L1: 0.0637 | Perc: 1.0352 | Disc: 0.0001 | Time: 0m 24s
Epoch 155/200 | Gen: 23.9649 | Adv: 17.6090 | L1: 0.0635 | Perc: 1.0337 | Disc: 0.0001 | Time: 0m 24s
Epoch 156/200 | Gen: 23.9786 | Adv: 17.6078 | L1: 0.0636 | Perc: 1.0333 | Disc: 0.0001 | Time: 0m 23s
Epoch 157/200 | Gen: 24.0169 | Adv: 17.6352 | L1: 0.0637 | Perc: 1.0327 | Disc: 0.0001 | Time: 0m 21s
Epoch 158/200 | Gen: 24.2833 | Adv: 17.9351 | L1: 0.0634 | Perc: 1.0309 | Disc: 0.0000 | Time: 0m 21s
Epoch 159/200 | Gen: 24.6292 | Adv: 18.3246 | L1: 0.0629 | Perc: 1.0290 | Disc: 0.0000 | Time: 0m 23s
Epoch 160/200 | Gen: 23.0824 | Adv: 16.7648 | L1: 0.0631 | Perc: 1.0283 | Disc: 0.

In [29]:
import numpy as np

# Convert to numpy array and save
np.save("Pizx2pix_Percpetual_loss_history.npy", np.array(loss_history))

# Load it later:
# loss_history = np.load("loss_history.npy", allow_pickle=True).tolist()
