# Initialisation

In [None]:
import os, sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
import numpy as np
from PIL import Image
import random

# Make the project-local modules importable. "models" is a relative entry, so
# it is resolved against the current working directory of the kernel/process.
# NOTE(review): presumably UNetGenerator.py and SatelliteDataset.py live in
# ./models -- confirm the notebook is always started from the repository root.
sys.path.append("models")
from UNetGenerator import UNetGenerator
from SatelliteDataset import SatelliteDataset

# Report the PyTorch build and whether the Apple-silicon (MPS) backend is usable.
print(
    f"PyTorch version: {torch.__version__}, MPS available: {torch.backends.mps.is_available()}"
)

import matplotlib.pyplot as plt
from landcovervis import landcover, landcovernorm

# Data Preparation

In [None]:
# Compute device: use the MPS backend when it is available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Locations of the input raster and of the target raster used for training.
input_path = "/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/training-data/unet-input.tif"
target_path = "/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/training-data/unet-target.tif"

# Hyper-parameters shared by the cells below.
tile_size = 256  # edge length of a square tile (256x256)
# Input channels: 2 = DEM + land cover; 3 = DEM + land cover + hillshade.
in_channels = 2
batch_size = 16
learning_rate = 1e-3

In [None]:
# Build the generator and initialise it from the "best" checkpoint that matches
# the current tile size and channel count.
best_checkpoint_path = f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C{in_channels}_best.pth"
genmodel = UNetGenerator(in_channels=in_channels).to(device)
genmodel.load_state_dict(torch.load(best_checkpoint_path, map_location=device))

# Loss and optimiser.
criterion = nn.MSELoss()
optimiser = torch.optim.Adam(genmodel.parameters(), lr=learning_rate)

# Training dataset with rotation enabled. The optional `forest_gamma=1.2`
# argument is currently left out, so the dataset's default is used.
dataset = SatelliteDataset(
    input_path, target_path, tile_size=tile_size, in_channels=in_channels, rotate=True
)
train_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Training loop.
#
# Each epoch trains on a freshly drawn random subset of the dataset, stores the
# weights whenever the summed batch loss improves, and then plots a small
# preview (inputs, prediction, ground truth) for the first samples of a batch.
epochs = 200
subset_fraction = 0.25  # fraction of the dataset visited per epoch
nTests = 2  # number of samples shown in the per-epoch preview

# NOTE: best_loss restarts at +inf, so the first epoch always overwrites the
# "_best" checkpoint on disk, whatever loss the previously stored weights had.
best_loss = float("inf")

# Loop invariants: the subset size and the preview layout never change.
num_samples = int(len(dataset) * subset_fraction)
n_cols = 2 + in_channels  # input channels + prediction + ground truth

for epoch in range(epochs):
    # Draw a new random subset (without replacement) for this epoch.
    indices = random.sample(range(len(dataset)), num_samples)
    sampler = SubsetRandomSampler(indices)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    genmodel.train()
    epoch_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = genmodel(x)
        loss = criterion(pred, y)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        epoch_loss += loss.item()

    # The number of batches is the same in every epoch, so comparing the summed
    # loss ranks epochs exactly like comparing the mean batch loss would.
    mean_loss = epoch_loss / len(train_loader)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(
            genmodel.state_dict(),
            f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C{in_channels}_best.pth",
        )
        print(
            f"Epoch {epoch+1:2d}/{epochs:2d}  Loss: {mean_loss:.4f}  Saved best model"
        )
    else:
        print(f"Epoch {epoch+1:2d}/{epochs:2d}  Loss: {mean_loss:.4f}")

    # Preview. Switch to eval mode so that any train-only behaviour the model
    # may have (e.g. dropout, batch-norm statistics updates) does not affect the
    # preview or the weights; genmodel.train() at the top of the next epoch
    # switches back.
    genmodel.eval()
    with torch.no_grad():
        x, y = next(iter(train_loader))
        x, y = x.to(device), y.to(device)
        pred = genmodel(x).cpu().numpy()
        fig, axes = plt.subplots(nTests, n_cols, figsize=(5 * n_cols, nTests * 3 + 1))
        for i in range(nTests):
            axes[i, 0].imshow(x[i][0].cpu(), cmap="terrain", vmin=0, vmax=4000)
            axes[i, 1].imshow(x[i][1].cpu(), cmap=landcover, norm=landcovernorm)
            if in_channels == 3:
                axes[i, 2].imshow(
                    x[i][2].cpu(), cmap="gray", vmin=0, vmax=1
                )  # Hillshade
            # The last two columns hold the prediction and the ground truth;
            # both are converted from (C, H, W) to (H, W, C) for imshow.
            axes[i, in_channels].imshow(np.transpose(pred[i], (1, 2, 0)))
            axes[i, in_channels + 1].imshow(y[i].cpu().numpy().transpose(1, 2, 0))

            # Cosmetic: strip the ticks from every column (also the 5th one
            # that exists when in_channels == 3).
            for j in range(n_cols):
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
            axes[i, 0].set_axis_off()

            if i == 0:
                axes[i, 0].set_title("DEM")
                axes[i, 1].set_title("Land Cover")
                if in_channels == 3:
                    axes[i, 2].set_title("Hillshade")
                axes[i, in_channels].set_title("Predicted RGB")
                axes[i, in_channels + 1].set_title("True RGB")
        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise one figure per epoch stays registered
        # with pyplot when the backend does not close it on show().
        plt.close(fig)

# Always keep the weights of the final epoch as well (separate file name).
torch.save(
    genmodel.state_dict(),
    f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C{in_channels}.pth",
)

In [None]:
# Two generators are loaded side by side for comparison: one fed with
# 2 input channels and one fed with 3 input channels (the "_hs" variant).
# Both datasets are built with rotation disabled.

# 2-channel dataset and model, restored from the matching "best" checkpoint.
checkpoint_c2 = f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C2_best.pth"
dataset = SatelliteDataset(
    input_path, target_path, tile_size=tile_size, in_channels=2, rotate=False
)
genmodel = UNetGenerator(in_channels=2).to(device)
genmodel.load_state_dict(torch.load(checkpoint_c2, map_location=device))
genmodel.eval()

# 3-channel dataset and model, restored from the matching "best" checkpoint.
checkpoint_c3 = f"/Users/williameclee/Documents/college/MATH/2025_1-MATH496T/satellite-image-predictor/models/unet-T{tile_size}C3_best.pth"
dataset_hs = SatelliteDataset(
    input_path, target_path, tile_size=tile_size, in_channels=3, rotate=False
)
genmodel_hs = UNetGenerator(in_channels=3).to(device)
genmodel_hs.load_state_dict(torch.load(checkpoint_c3, map_location=device))
genmodel_hs.eval()

# Compare the 2-channel and 3-channel generators on a few hand-picked tiles
# and export one figure per tile.
testId = [30, 800, 64]  # dataset indices of the tiles to visualise

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)

# Font size for all figures below (loop-invariant, so set once).
plt.rcParams.update({"font.size": 6})

with torch.no_grad():
    x_list, x_list_hs, y_list = [], [], []
    for iTest in testId:  # Choose specific sample indices here
        x_item, y_item = dataset[iTest]
        x_item_hs, _ = dataset_hs[iTest]
        x_list.append(x_item)
        x_list_hs.append(x_item_hs)
        y_list.append(y_item)

    # Stack the selected samples into one batch per model.
    x = torch.stack(x_list).to(device)
    x_hs = torch.stack(x_list_hs).to(device)
    y = torch.stack(y_list).to(device)

    pred = genmodel(x).cpu().numpy()
    pred_hs = genmodel_hs(x_hs).cpu().numpy()

    for iTest, sample_id in enumerate(testId):
        fig, axes = plt.subplots(2, 3, figsize=(5, 3))
        # Top row: the three input channels of the 3-channel dataset.
        axes[0, 0].imshow(x_hs[iTest][0].cpu(), cmap="terrain", vmin=0, vmax=4000)
        axes[0, 1].imshow(x_hs[iTest][1].cpu(), cmap=landcover, norm=landcovernorm)
        axes[0, 2].imshow(
            x_hs[iTest][2].cpu(), cmap="gray", vmin=0, vmax=1
        )  # Hillshade
        # Bottom row: both predictions and the target, converted from
        # (C, H, W) to (H, W, C) for imshow.
        axes[1, 0].imshow(np.transpose(pred[iTest], (1, 2, 0)))
        axes[1, 1].imshow(np.transpose(pred_hs[iTest], (1, 2, 0)))
        axes[1, 2].imshow(y[iTest].cpu().numpy().transpose(1, 2, 0))

        # Cosmetic
        for i in range(2):
            for j in range(3):
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])
            axes[i, 0].set_axis_off()

        axes[0, 0].set_title("DEM")
        axes[0, 1].set_title("Land Cover")
        axes[0, 2].set_title("HS")
        axes[1, 0].set_title("Prediction, w/o HS")
        axes[1, 1].set_title("Prediction, w/ HS")
        axes[1, 2].set_title("Ground truth")

        plt.tight_layout()
        plt.savefig(f"figures/unet_test-{sample_id}.png", dpi=600, bbox_inches="tight")
        plt.show()
        # Release the figure once it has been saved and shown.
        plt.close(fig)