# Satellite Image Prediction with Diffusion
## Model Training

In [None]:
%load_ext autoreload
%autoreload 2

from satellite_dataset import SatelliteDataset
from satellite_diffusion_model import train_satellitediffusionmodel

In [None]:
# Load the "satellite" dataset fully into RAM, then run a short training job.
# NOTE(review): the meaning of the leading "load" argument and the exact
# semantics of subset_fraction are defined in satellite_diffusion_model —
# presumably "resume from saved weights" and "train on 5% of samples".
dataset = SatelliteDataset("satellite", preload_to_ram=True)

training_options = dict(
    batch_size=16,
    num_epochs=1,
    subset_fraction=0.05,
    rgb_loss_weight=1.0,
)
model = train_satellitediffusionmodel("load", dataset, **training_options)

## GEE Data Postprocessing

In [None]:
import glob
import json
import os
import rasterio

def truncate_tile(input_path, size=(256, 256), output_path=None):
    """Crop a GeoTIFF to its top-left ``size`` window and save it.

    The raster is fully read and the source dataset closed before the output
    is opened, so ``output_path`` may equal ``input_path`` (the default), in
    which case the file is truncated in place.

    Args:
        input_path: Raster file to crop.
        size: ``(height, width)`` of the crop in pixels. Lists are accepted
            too; the default is a tuple to avoid a shared mutable default.
        output_path: Destination file; defaults to ``input_path``. Its
            parent folder is created if missing.

    Returns:
        None. If ``input_path`` does not exist, a message is printed and
        nothing is created or written.

    Raises:
        ValueError: If the raster is smaller than ``size`` in either
            dimension. This is checked before anything is written, so an
            in-place target is never clobbered.
    """
    if output_path is None:
        output_path = input_path

    # Validate the input first so a missing tile does not leave behind
    # freshly created (empty) output folders.
    if not os.path.exists(input_path):
        print(f"File {input_path} not found")
        return

    # dirname() is "" for a bare filename; os.makedirs("") would raise.
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output folder {output_dir}")

    with rasterio.open(input_path, "r") as input_img:
        img = input_img.read()  # (bands, rows, cols)
        # Capture metadata while the dataset is still open; it is closed
        # before the (possibly same-path) output file is written.
        band_count = input_img.count
        crs = input_img.crs
        transform = input_img.transform

    # numpy slicing silently clamps out-of-range bounds, so check explicitly
    # that the tile is large enough for the requested window.
    if img.shape[1] < size[0] or img.shape[2] < size[1]:
        raise ValueError(
            f"Error cropping image {input_path} to {size[0]}x{size[1]}: "
            f"source is only {img.shape[1]}x{img.shape[2]}"
        )
    img_cropped = img[:, : size[0], : size[1]]

    # Cropping from the top-left keeps the origin, so the transform is reused.
    with rasterio.open(
        output_path,
        "w",
        driver="GTiff",
        height=size[0],
        width=size[1],
        count=band_count,
        dtype=img.dtype,
        crs=crs,
        transform=transform,
    ) as output_img:
        output_img.write(img_cropped)
    print(f"Truncated {os.path.basename(input_path)} to {size[0]}x{size[1]}")


def process_from_json(json_path, input_dir, output_dir):
    """Truncate every tile referenced by one metadata JSON file.

    Reads ``json_path`` and, for each of the ``tile_rgb``, ``tile_dem`` and
    ``tile_lct`` entries, crops ``input_dir/<file>`` with ``truncate_tile``
    into ``output_dir/<file>`` (same file name).

    Missing keys and missing files are reported and skipped. The JSON file
    itself is only read, never rewritten, because output file names are
    identical to the input names.

    Args:
        json_path: Path to the tile metadata JSON.
        input_dir: Folder containing the raw tile files.
        output_dir: Folder for the truncated tiles; created if missing.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output folder {output_dir}")

    with open(json_path, "r") as f:
        meta = json.load(f)

    for key in ("tile_rgb", "tile_dem", "tile_lct"):
        file_name = meta.get(key)
        if not file_name:
            # os.path.join(input_dir, None) would raise TypeError.
            print(f"Key {key!r} missing in {json_path}")
            continue

        input_path = os.path.join(input_dir, file_name)
        if os.path.exists(input_path):
            truncate_tile(input_path, output_path=os.path.join(output_dir, file_name))
        else:
            print(f"File {input_path} not found")


# Each tile_*_met.json in satellite-raw names raster files stored under
# satellite-raw/raw; crop all of them into the training folder.
json_files = glob.glob("training-data/satellite-raw/tile_*_met.json")
print(f"Found {len(json_files)} JSON files to process")

raw_tile_dir = "training-data/satellite-raw/raw"
processed_tile_dir = "training-data/satellite"
for meta_path in json_files:
    process_from_json(meta_path, raw_tile_dir, processed_tile_dir)

In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from satellite_diffusion_model import cosine_beta_schedule
import matplotlib.pyplot as plt
import random


@torch.no_grad()
def denoise_samples(model, dataset, num_samples=4, num_timesteps=50, device="cuda"):
    """Noise random dataset images to one timestep and denoise them in one step.

    For each of ``num_samples`` randomly chosen dataset items, the target
    image is forward-diffused to timestep ``num_timesteps`` of the schedule
    from ``cosine_beta_schedule(100)``. The model then predicts the added
    noise, and the clean image is recovered in closed form:

        x0_hat = (x_t - sqrt(1 - abar_t) * eps_hat) / sqrt(abar_t)

    Args:
        model: Noise-prediction network, called as
            ``model(cat([x_t, geoinfo_spatial], dim=1), t / 100, geoinfo_vector)``.
        dataset: Indexable dataset whose items hold ``"target_image"``,
            ``"geoinfo_spatial"`` and ``"geoinfo_vector"`` tensors.
        num_samples: Number of random items to evaluate.
        num_timesteps: The single timestep t to noise to (index into the
            schedule). Despite the name, it is not a count of iterations.
        device: Torch device for all tensors.

    Returns:
        List of ``(original, noisy, denoised)`` numpy arrays in HWC layout.
        The noisy and denoised images are clamped to [0, 1] for display.
    """
    model.eval()
    samples = []

    total_steps = 100  # length of the noise schedule; also normalises t
    betas = cosine_beta_schedule(total_steps)  # full schedule, not truncated
    # alpha_bars[t] = prod(1 - betas[:t]); alpha_bars[0] = 1 (no noise).
    alpha_bars = np.insert(np.cumprod(1.0 - betas), 0, 1.0)
    alpha_bars = torch.tensor(alpha_bars, dtype=torch.float32, device=device)

    # Everything that depends only on the timestep is loop-invariant.
    alpha_bar = alpha_bars[num_timesteps].view(1, 1, 1, 1)
    sqrt_alpha_bar = alpha_bar.sqrt()
    sqrt_one_minus_alpha_bar = (1 - alpha_bar).sqrt()
    t = torch.full((1,), num_timesteps / total_steps, device=device)

    for _ in range(num_samples):
        # Index the dataset once so all fields come from the same item.
        # Repeated indexing may reload from disk (preload_to_ram=False), and it
        # would mix samples if items are generated stochastically.
        item = dataset[random.randint(0, len(dataset) - 1)]
        rgb = item["target_image"].unsqueeze(0).to(device)
        geoinfo_spatial = item["geoinfo_spatial"].unsqueeze(0).to(device)
        geoinfo_vector = item["geoinfo_vector"].unsqueeze(0).to(device)

        # STEP 1: forward diffusion x_t = sqrt(abar) * x0 + sqrt(1 - abar) * eps
        noise = torch.randn_like(rgb)
        rgb_noisy = sqrt_alpha_bar * rgb + sqrt_one_minus_alpha_bar * noise

        # STEP 2: predict the noise from x_t plus the spatial conditioning.
        noise_pred = model(
            torch.cat([rgb_noisy, geoinfo_spatial], dim=1), t, geoinfo_vector
        )

        # STEP 3: closed-form x0 estimate, clamped to the [0, 1] image range.
        rgb_pred = (rgb_noisy - sqrt_one_minus_alpha_bar * noise_pred) / sqrt_alpha_bar
        rgb_pred = torch.clamp(rgb_pred, 0.0, 1.0)

        # Index [0] (not squeeze) so a single-channel image keeps its C axis.
        samples.append(
            (
                rgb[0].permute(1, 2, 0).cpu().numpy(),
                rgb_noisy[0].clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy(),
                rgb_pred[0].permute(1, 2, 0).cpu().numpy(),
            )
        )

        # Both metrics refer to the same noisy input the model actually saw.
        print(f"loss of noise: {F.mse_loss(noise_pred, noise).item()}")
        print(f"loss of image: {F.mse_loss(rgb_pred, rgb).item()}")

    return samples


def plot_denoising_results(samples):
    """Plot each sample as one row of four panels.

    The panels are the original, the noisy input, the denoised image and the
    absolute residual |denoised - original|.

    Args:
        samples: Sequence of ``(real, noisy, generated)`` HWC image arrays,
            such as the list returned by ``denoise_samples``. An empty
            sequence prints a message and draws nothing.
    """
    if not samples:
        # plt.subplots(0, ...) raises ValueError; there is nothing to draw.
        print("No samples to plot")
        return

    num = len(samples)
    # squeeze=False keeps a 2-D axes grid even when num == 1.
    _, axes = plt.subplots(num, 4, figsize=(20, 4 * num), squeeze=False)

    for i, (real, noisy, generated) in enumerate(samples):
        panels = (
            (real, f"Original image #{i+1}"),
            (noisy, "Noisy image"),
            (generated, "Denoised image"),
            (np.abs(generated - real), "Residual image"),
        )
        for ax, (image, title) in zip(axes[i], panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis("off")
            ax.set_aspect("equal")

    plt.tight_layout()
    plt.show()

In [None]:
from satellite_dataset import SatelliteDataset
from satellite_diffusion_model import SatelliteDiffusionUNet, load_model
import torch

# Prefer CUDA, fall back to Apple MPS, and finally to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

satellite_src = "satellite"
tile_size = 256

# Tiles are read on demand rather than preloaded into RAM.
dataset = SatelliteDataset(satellite_src, preload_to_ram=False)

# NOTE(review): the positional arguments of load_model (False, "load", ...,
# 1e-4 — presumably the learning rate) are defined in satellite_diffusion_model;
# confirm their meaning there. Only the model of the returned triple is used.
model, _, _ = load_model(False, "load", dataset, 1e-4, device)

# Noise one random tile to timestep 20, denoise it, and show the result.
samples = denoise_samples(
    model, dataset, device=device, num_samples=1, num_timesteps=20
)
plot_denoising_results(samples)