# Initialisation

In [None]:
import os, sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_msssim import ssim
import random

# Add the relative "models" directory to the import search path so the
# project modules imported just below can be resolved (presumably that is
# where they live; the path is relative to the current working directory).
sys.path.append("models")
from UNetGenerator import UNetGenerator
from UNetDiscriminator import UNetDiscriminator
from SatelliteDataset import SatelliteDataset

# Report the installed PyTorch version and whether the MPS backend is usable.
print(
    f"PyTorch version: {torch.__version__}, MPS available: {torch.backends.mps.is_available()}"
)

import matplotlib.pyplot as plt
from landcovervis import landcover, landcovernorm

In [None]:
# --- Tile / channel layout ---------------------------------------------------
# Edge length of each square tile (256x256).
tile_size = 256
# Input channels: DEM, land cover, hillshade.
# (Use 2 instead for the DEM + land cover variant without hillshade.)
in_channels = 3

# --- Training rasters --------------------------------------------------------
input_path = "/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/training-data/unet-input.tif"
target_path = "/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/training-data/unet-target.tif"

# --- Runtime -----------------------------------------------------------------
batch_size = 16

# Use the MPS backend when it is available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# ---------------------------------------------------------------------------
# Networks
# ---------------------------------------------------------------------------
# The discriminator takes in_channels + 3 channels: the training loop feeds it
# the conditioning input concatenated with an image along the channel axis
# (presumably a 3-channel RGB image).
genmodel = UNetGenerator(in_channels=in_channels).to(device)
discmodel = UNetDiscriminator(in_channels=in_channels + 3).to(device)

# Start the generator from the best plain U-Net checkpoint.
pretrained_generator_path = f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C{in_channels}_best.pth"
genmodel.load_state_dict(torch.load(pretrained_generator_path, map_location=device))

# Alternative starting points, kept disabled: resume from saved pix2pix
# generator / discriminator weights instead of the U-Net checkpoint.
# genmodel.load_state_dict(
#     torch.load(
#         f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/pix2pix-T{tile_size}C{in_channels}_best.pth",
#         map_location=device,
#     )
# )
#
# discmodel.load_state_dict(
#     torch.load(
#         f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/pix2pix_disc-T{tile_size}C{in_channels}.pth",
#         map_location=device,
#     )
# )

# ---------------------------------------------------------------------------
# Optimisers and losses
# ---------------------------------------------------------------------------
# Both networks use Adam with the same hyper-parameters.
adam_settings = {"lr": 2e-4, "betas": (0.5, 0.999)}
g_opt = torch.optim.Adam(genmodel.parameters(), **adam_settings)
d_opt = torch.optim.Adam(discmodel.parameters(), **adam_settings)

criterion_gan = nn.BCELoss()  # adversarial term
criterion_l1 = nn.L1Loss()  # reconstruction term

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dataset = SatelliteDataset(
    input_path,
    target_path,
    tile_size=tile_size,
    in_channels=in_channels,
    rotate=True,
    # forest_gamma=1.4,
)

# Hold out 10% of the tiles (rounded down) for validation.
n_tiles = len(dataset)
val_size = int(0.1 * n_tiles)
train_size = n_tiles - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)

# Best validation PSNR seen so far; updated by the training loop.
best_psnr = 0.0

In [None]:
# pix2pix training loop.
#
# Per epoch: train the discriminator and generator on a random subset of the
# training split, evaluate PSNR/SSIM on the validation split, checkpoint the
# generator when the validation PSNR improves, and plot a few predictions.
epochs = 40
subset_fraction = 0.1

# Draw the per-epoch subset from the training split only. Sampling from the
# full dataset would also train on the validation tiles, which inflates the
# PSNR/SSIM that drive best-model selection below.
num_samples = int(len(train_dataset) * subset_fraction)

for epoch in range(epochs):
    indices = random.sample(range(len(train_dataset)), num_samples)
    sampler = SubsetRandomSampler(indices)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

    genmodel.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Discriminator step: real (input, target) pairs vs. generated pairs.
        discmodel.zero_grad()
        real_pair = torch.cat([x, y], dim=1)
        # detach() keeps the discriminator loss from back-propagating into the generator.
        fake_y = genmodel(x).detach()
        fake_pair = torch.cat([x, fake_y], dim=1)
        real_output = discmodel(real_pair)
        fake_output = discmodel(fake_pair)
        real_label = torch.ones_like(real_output)
        fake_label = torch.zeros_like(fake_output)
        d_real_loss = criterion_gan(real_output, real_label)
        d_fake_loss = criterion_gan(fake_output, fake_label)
        d_loss = (d_real_loss + d_fake_loss) * 0.5
        d_loss.backward()
        d_opt.step()

        # Generator step: fool the discriminator while staying close to the
        # target in L1. Gradients that reach the discriminator here are
        # discarded by its zero_grad() at the start of the next iteration.
        genmodel.zero_grad()
        fake_y = genmodel(x)
        fake_pair = torch.cat([x, fake_y], dim=1)
        g_gan_loss = criterion_gan(discmodel(fake_pair), real_label)
        g_l1_loss = criterion_l1(fake_y, y)
        g_loss = 0.1 * g_gan_loss + 100 * g_l1_loss
        g_loss.backward()
        g_opt.step()

    # Validation metrics, accumulated as Python floats so that avg_psnr and
    # best_psnr do not hold on to device tensors.
    genmodel.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val, y_val = x_val.to(device), y_val.to(device)
            pred_val = genmodel(x_val)
            # PSNR with a peak value of 1.0.
            mse = nn.functional.mse_loss(pred_val, y_val)
            total_psnr += (20 * torch.log10(1.0 / torch.sqrt(mse))).item()
            total_ssim += ssim(
                pred_val, y_val, data_range=1.0, size_average=True
            ).item()

    avg_psnr = total_psnr / len(val_loader)
    avg_ssim = total_ssim / len(val_loader)
    # The reported D/G losses are those of the last training batch of the epoch.
    print(
        f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | PSNR: {avg_psnr:.2f} | SSIM: {avg_ssim:.4f}"
    )

    # Save best generator
    if avg_psnr > best_psnr:
        best_psnr = avg_psnr
        torch.save(
            genmodel.state_dict(),
            f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/pix2pix-T{tile_size}C{in_channels}_best.pth",
        )
        print(
            f"New best model saved with PSNR: {avg_psnr:.2f} at epoch {epoch+1}/{epochs}"
        )

    # Visual check: inputs, prediction and target for the first nTests samples
    # of one training batch (the generator is still in eval mode here).
    nTests = 2
    n_cols = 2 + in_channels  # input channels + prediction + target
    with torch.no_grad():
        x, y = next(iter(train_loader))
        x, y = x.to(device), y.to(device)
        pred = genmodel(x).cpu().numpy()
        fig, axes = plt.subplots(nTests, n_cols, figsize=(5 * n_cols, nTests * 3 + 1))
        for row in range(nTests):
            axes[row, 0].imshow(x[row][0].cpu(), cmap="terrain", vmin=0, vmax=4000)
            axes[row, 1].imshow(x[row][1].cpu(), cmap=landcover, norm=landcovernorm)
            if in_channels == 3:
                axes[row, 2].imshow(
                    x[row][2].cpu(), cmap="gray", vmin=0, vmax=1
                )  # Hillshade
            axes[row, in_channels].imshow(np.transpose(pred[row], (1, 2, 0)))
            axes[row, in_channels + 1].imshow(y[row].cpu().numpy().transpose(1, 2, 0))

            # Cosmetic: strip the ticks from every column. (This used to stop
            # at column 3, leaving ticks on the last panel when in_channels == 3.)
            for col in range(n_cols):
                axes[row, col].set_xticks([])
                axes[row, col].set_yticks([])
            # Only the first column has its axis switched off, as before.
            axes[row, 0].set_axis_off()

            if row == 0:
                axes[row, 0].set_title("DEM")
                axes[row, 1].set_title("Land Cover")
                if in_channels == 3:
                    axes[row, 2].set_title("Hillshade")
                axes[row, in_channels].set_title("Predicted RGB")
                axes[row, in_channels + 1].set_title("True RGB")
        plt.tight_layout()
        plt.show()
        # Release the figure so one figure per epoch does not pile up.
        plt.close(fig)

# Final (last-epoch) weights of both networks.
torch.save(
    genmodel.state_dict(),
    f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/pix2pix-T{tile_size}C{in_channels}.pth",
)
torch.save(
    discmodel.state_dict(),
    f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/pix2pix_disc-T{tile_size}C{in_channels}.pth",
)

In [None]:
# Side-by-side comparison of the plain U-Net and the pix2pix generator on a
# few hand-picked tiles; one figure per tile is written to figures/.
# The panel layout below is fixed to three input channels.
in_channels = 3
tile_size = 256

dataset = SatelliteDataset(
    input_path,
    target_path,
    tile_size=tile_size,
    in_channels=in_channels,
    rotate=False,
)

genmodel_unet = UNetGenerator(in_channels=in_channels).to(device)
genmodel_unet.load_state_dict(
    torch.load(
        f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C{in_channels}_best.pth",
        map_location=device,
    )
)
genmodel_unet.eval()

# Built with in_channels (not a literal 3) so it always matches the checkpoint
# name and the dataset configured above.
genmodel_pix2pix = UNetGenerator(in_channels=in_channels).to(device)
genmodel_pix2pix.load_state_dict(
    torch.load(
        f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/pix2pix-T{tile_size}C{in_channels}_best.pth",
        map_location=device,
    )
)
genmodel_pix2pix.eval()

testId = [1158, 1171, 64]  # Choose specific sample indices here

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)
plt.rcParams.update({"font.size": 6})

with torch.no_grad():
    x_list, y_list = zip(*(dataset[sample_id] for sample_id in testId))
    x = torch.stack(x_list).to(device)
    y = torch.stack(y_list).to(device)

    pred_unet = genmodel_unet(x).cpu().numpy()
    pred_pix2pix = genmodel_pix2pix(x).cpu().numpy()

    for iTest, sample_id in enumerate(testId):
        fig, axes = plt.subplots(2, 3, figsize=(5, 3))
        # Top row: the three input channels.
        axes[0, 0].imshow(x[iTest][0].cpu(), cmap="terrain", vmin=0, vmax=4000)
        axes[0, 1].imshow(x[iTest][1].cpu(), cmap=landcover, norm=landcovernorm)
        axes[0, 2].imshow(x[iTest][2].cpu(), cmap="gray", vmin=0, vmax=1)  # Hillshade
        # Bottom row: both predictions next to the target.
        axes[1, 0].imshow(np.transpose(pred_unet[iTest], (1, 2, 0)))
        axes[1, 1].imshow(np.transpose(pred_pix2pix[iTest], (1, 2, 0)))
        axes[1, 2].imshow(y[iTest].cpu().numpy().transpose(1, 2, 0))

        # Cosmetic: no ticks anywhere; the first column also has its axis
        # switched off, as before.
        for ax in axes.flat:
            ax.set_xticks([])
            ax.set_yticks([])
        for ax in axes[:, 0]:
            ax.set_axis_off()

        axes[0, 0].set_title("DEM")
        axes[0, 1].set_title("Land Cover")
        axes[0, 2].set_title("HS")
        axes[1, 0].set_title("Prediction, U-Net")
        axes[1, 1].set_title("Prediction, Pix2Pix")
        axes[1, 2].set_title("Ground truth")

        plt.tight_layout()
        plt.savefig(f"figures/pix2pix_test-{sample_id}.png", dpi=600, bbox_inches="tight")
        plt.show()
        # Release the figure once it has been saved and shown.
        plt.close(fig)