# Satellite Image Prediction with Diffusion
## Model Training

In [None]:
%load_ext autoreload
%autoreload 2

from satellite_dataset import SatelliteDataset
from satellite_diffusion_model import train_satellitediffusionmodel

In [None]:
# Names of the geo-information fields passed to SatelliteDataset as conditioning.
geoinfo_keys = list(
    (
        "dem_mean",
        "dem_min",
        "dem_max",
        "tile_scale",
        "month",
        "solar_zenith",
        "solar_azimuth",
        "north_dir",
    )
)

# preload_to_ram=True: presumably caches every sample in memory for the many
# training epochs below — confirm in SatelliteDataset.
dataset = SatelliteDataset(
    "satellite",
    geoinfo_keys=geoinfo_keys,
    preload_to_ram=True,
)

In [None]:
# Hyper-parameters for this training run. subset_fraction=0.005 presumably
# restricts training to ~0.5% of the dataset (quick experiment) — confirm
# before launching a full run.
train_config = dict(
    batch_size=16,
    num_epochs=100,
    subset_fraction=0.005,
    learning_rate=1e-4,
    rgb_loss_weight=2.0,
)

# The leading "load" mode string is interpreted by train_satellitediffusionmodel
# (presumably resume from an existing checkpoint — TODO confirm).
model = train_satellitediffusionmodel("load", dataset, **train_config)

## GEE Data Postprocessing
See `gee_tiles.ipynb`

## Plotting

In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from satellite_diffusion_model import cosine_beta_schedule
import matplotlib.pyplot as plt
import random


def _first_item_to_hwc(tensor):
    """Return batch item 0 of a (1, C, H, W) tensor as an (H, W, C) numpy array."""
    return tensor[0].cpu().numpy().transpose(1, 2, 0)


@torch.no_grad()
def denoise_samples(
    model,
    dataset,
    num_samples=4,
    num_timesteps=50,
    device="cuda",
    schedule_steps=100,
):
    """Noise random dataset images to one timestep and denoise them in a single step.

    For each sample a random item is drawn from ``dataset``. Its
    ``"target_image"`` is noised with the closed-form forward process
    ``x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps`` at step
    ``num_timesteps``. The model then predicts ``eps``, and ``x_0`` is
    recovered by inverting that same formula. Three diagnostic losses are
    printed per sample.

    Args:
        model: Network called as ``model(cat([x_t, geoinfo_spatial]), t, geoinfo_vector)``
            that returns a noise estimate with the same shape as the image.
        dataset: Indexable dataset whose items contain ``"target_image"``,
            ``"geoinfo_spatial"`` and ``"geoinfo_vector"`` tensors.
        num_samples: Number of random items to process.
        num_timesteps: The single diffusion step ``t`` (0..schedule_steps) to
            noise to. Despite the name, this is a step index, not a count.
        device: Torch device used for all computation.
        schedule_steps: Length of the cosine beta schedule. ``t`` is passed to
            the model normalised by this value. The default, 100, is the value
            previously hard-coded.

    Returns:
        List of ``(original, noisy, denoised)`` tuples of (H, W, C) numpy arrays.

    Raises:
        ValueError: If ``num_timesteps`` is outside ``[0, schedule_steps]``.
    """
    # Validate before touching the model/schedule: the abar table below only
    # has schedule_steps + 1 entries.
    if not 0 <= num_timesteps <= schedule_steps:
        raise ValueError(
            f"num_timesteps must be in [0, {schedule_steps}], got {num_timesteps}"
        )

    model.eval()
    samples = []

    # abar_t for the full cosine schedule, with abar_0 = 1 prepended so that
    # index t corresponds to diffusion step t.
    betas = cosine_beta_schedule(schedule_steps)
    alpha_bars = np.insert(np.cumprod(1.0 - betas), 0, 1.0)
    alpha_bars = torch.tensor(alpha_bars, dtype=torch.float32, device=device)

    # These depend only on the chosen timestep, so compute them once.
    alpha = alpha_bars[num_timesteps].view(1, 1, 1, 1)
    sqrt_alpha = alpha.sqrt()
    sqrt_one_minus_alpha = (1 - alpha).sqrt()
    t = torch.full((1,), num_timesteps / schedule_steps, device=device)  # normalised t

    # Fixed reference noise level for the "loss of image" diagnostic.
    alpha_ref = 0.5
    sqrt_alpha_ref = alpha_ref**0.5
    sqrt_one_minus_alpha_ref = (1 - alpha_ref) ** 0.5

    for _ in range(num_samples):
        # Index the dataset once. Repeated dataset[...] calls may re-read the
        # item (e.g. with preload_to_ram=False), and all three fields must come
        # from the same item.
        item = dataset[random.randint(0, len(dataset) - 1)]
        rgb = item["target_image"].unsqueeze(0).to(device)
        geoinfo_spatial = item["geoinfo_spatial"].unsqueeze(0).to(device)
        geoinfo_vector = item["geoinfo_vector"].unsqueeze(0).to(device)

        # STEP 1: forward process at the chosen timestep.
        noise = torch.randn_like(rgb)
        rgb_noisy = sqrt_alpha * rgb + sqrt_one_minus_alpha * noise

        # STEP 2: predict the noise.
        noise_pred = model(
            torch.cat([rgb_noisy, geoinfo_spatial], dim=1), t, geoinfo_vector
        )

        # STEP 3: invert the forward process to estimate x_0.
        rgb_pred = (rgb_noisy - sqrt_one_minus_alpha * noise_pred) / sqrt_alpha

        # Measure how far the estimate leaves the [0, 1] image range, then clamp to it.
        oob_loss = F.mse_loss(rgb_pred, torch.clamp(rgb_pred, 0, 1.0))
        rgb_pred = torch.clamp(rgb_pred, 0, 1.0)

        # Reuse the same noise and prediction at the reference level. The
        # forward and inverse steps must both use the sqrt weights (as in
        # STEP 1/3); a perfect noise_pred then reconstructs rgb exactly.
        # NOTE: noise_pred was predicted for num_timesteps, not alpha_ref.
        rgb_noisy_ref = sqrt_alpha_ref * rgb + sqrt_one_minus_alpha_ref * noise
        rgb_pred_ref = (
            rgb_noisy_ref - sqrt_one_minus_alpha_ref * noise_pred
        ) / sqrt_alpha_ref
        rgb_pred_ref = torch.clamp(rgb_pred_ref, 0, 1.0)

        samples.append(
            (
                _first_item_to_hwc(rgb),
                _first_item_to_hwc(rgb_noisy),
                _first_item_to_hwc(rgb_pred),
            )
        )

        print(f"loss of noise: {F.mse_loss(noise_pred, noise):.4f}")
        print(f"loss of image: {F.mse_loss(rgb_pred_ref, rgb):.4f}")
        print(f"loss of OOB:   {oob_loss:.4f}")

    return samples


def plot_denoising_results(samples):
    """Show each (original, noisy, denoised) triple as one row of a figure.

    Each row has four panels: the original image, the noisy input, the
    denoised output, and the absolute residual ``|generated - real|``.

    Args:
        samples: Sequence of ``(real, noisy, generated)`` (H, W, C) arrays,
            as returned by ``denoise_samples``.

    An empty ``samples`` is a no-op, because matplotlib cannot create a
    subplot grid with zero rows.
    """
    num = len(samples)
    if num == 0:
        return

    # squeeze=False always yields a 2-D (num, 4) array of Axes, so a single
    # sample needs no special-casing.
    _, axes = plt.subplots(num, 4, figsize=(20, 4 * num), squeeze=False)

    for i, (real, noisy, generated) in enumerate(samples):
        panels = (
            (real, f"Original image #{i+1}"),
            (noisy, "Noisy image"),
            (generated, "Denoised image"),
            (np.abs(generated - real), "Residual noise"),
        )
        for ax, (image, title) in zip(axes[i], panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis("off")
            ax.set_aspect("equal")

    plt.tight_layout()
    plt.show()

In [None]:
from satellite_dataset import SatelliteDataset
from satellite_diffusion_model import SatelliteDiffusionUNet, load_model
import torch

# Choose the compute device: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

satellite_src = "satellite"
tile_size = 256  # not referenced in this cell

# Conditioning keys; must be the ones the checkpoint below was trained with.
geoinfo_keys = [
    "dem_mean", "dem_min", "dem_max", "tile_scale",
    "month", "solar_zenith", "solar_azimuth", "north_dir",
]

# preload_to_ram=False: only a single item is drawn below, so preloading is
# presumably unnecessary here.
dataset = SatelliteDataset(
    satellite_src,
    geoinfo_keys=geoinfo_keys,
    preload_to_ram=False,
)

# load_model returns a 3-tuple; only the model is used. The meaning of the
# positional False / 1e-4 arguments is defined by load_model — TODO confirm.
model, _, _ = load_model(
    False,
    "models/satellite_diffusion-T256-A7_best.pth",
    dataset,
    1e-4,
    device,
)

# Noise one random tile to step 75, denoise it in one shot and plot the result.
samples = denoise_samples(model, dataset, num_samples=1, num_timesteps=75, device=device)
plot_denoising_results(samples)

In [None]:
import torch
import numpy as np
from dem_diffusion_model import cosine_beta_schedule
from satellite_dataset import SatelliteDataset
from satellite_diffusion_model import load_model
import torch


def _show_sampling_step(x0_pred, x, t_step):
    """Plot batch item 0 of the current x0 estimate next to the current sample x."""
    img = x0_pred[0].clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
    img_noisy = x[0].clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
    _, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[0].set_title(f"t_step: {t_step} (x0_pred)")
    ax[1].imshow(img_noisy)
    ax[1].set_title(f"t_step: {t_step} (x)")
    plt.show()


@torch.no_grad()
def generate_from_noise(
    model, dataset, num_samples=4, num_timesteps=100, device="cuda", plot_every=10
):
    """Generate images from pure Gaussian noise, conditioned on random dataset items.

    The sampler runs for t = num_timesteps, ..., 1. At each step it:
      1. predicts the noise,
      2. estimates x_0 and clamps it to [0, 1],
      3. re-noises that estimate to level t-1 with *fresh* Gaussian noise.
    The final step returns the x_0 estimate directly.

    Args:
        model: Network called as ``model(cat([x_t, geo_spatial]), t, geo_vector)``
            with ``t`` normalised by ``num_timesteps``.
        dataset: Indexable dataset whose items contain ``"target_image"``
            (used only for its shape), ``"geoinfo_spatial"`` and
            ``"geoinfo_vector"``.
        num_samples: Number of images to generate.
        num_timesteps: Length of the cosine schedule and number of sampling steps.
        device: Torch device used for all computation.
        plot_every: Plot the intermediate state on every step divisible by this
            value. The default, 10, is the previous behaviour; 0/None disables it.

    Returns:
        List of (H, W, C) numpy arrays clipped to [0, 1].

    Raises:
        AssertionError: If a NaN appears in the sample (unless run with -O).
    """
    # Use the satellite model's schedule, as denoise_samples does. The cell-level
    # import takes it from dem_diffusion_model, which may not match the schedule
    # this checkpoint was trained with.
    from satellite_diffusion_model import cosine_beta_schedule

    model.eval()
    samples = []

    # abar_t with abar_0 = 1 prepended, so alphas_cumprod[t] = \bar{alpha}_t.
    betas = cosine_beta_schedule(num_timesteps)
    alphas_cumprod = np.insert(np.cumprod(1.0 - betas), 0, 1.0)
    alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32, device=device)

    for _ in range(num_samples):
        # Pick a random conditioning input (indexed once).
        cond = dataset[np.random.randint(0, len(dataset))]
        geo_spatial = cond["geoinfo_spatial"].unsqueeze(0).to(device)
        geo_vector = cond["geoinfo_vector"].unsqueeze(0).to(device)

        # Start from pure Gaussian noise shaped like a target image.
        x = torch.randn_like(cond["target_image"].unsqueeze(0)).to(device)

        # Denoise one step at a time from t = T down to 1.
        for t_step in range(num_timesteps, 0, -1):
            t_tensor = torch.full((1,), t_step / num_timesteps, device=device)
            noise_pred = model(torch.cat([x, geo_spatial], dim=1), t_tensor, geo_vector)

            # Estimate x_0 from x_t and the predicted noise.
            alpha_bar_t = alphas_cumprod[t_step].view(1, 1, 1, 1)
            x0_pred = (x - (1 - alpha_bar_t).sqrt() * noise_pred) / alpha_bar_t.sqrt()
            x0_pred = torch.clamp(x0_pred, 0, 1.0)

            if t_step > 1:
                # Re-noise the clamped estimate to level t-1 with new noise.
                # This reuses neither x_t nor noise_pred.
                alpha_bar_prev = alphas_cumprod[t_step - 1].view(1, 1, 1, 1)
                z = torch.randn_like(x)
                x = alpha_bar_prev.sqrt() * x0_pred + (1 - alpha_bar_prev).sqrt() * z
            else:
                x = x0_pred  # final step: use x0 directly

            if plot_every and t_step % plot_every == 0:
                _show_sampling_step(x0_pred, x, t_step)

            assert not torch.isnan(x).any(), f"NaN detected in x at t={t_step}"

        samples.append(x.clamp(0, 1).squeeze().cpu().numpy().transpose(1, 2, 0))

    return samples


import matplotlib.pyplot as plt


def plot_generated_samples(samples, titles=None, figsize=(10, 10)):
    """Display images in a grid of at most four columns.

    Args:
        samples: Sequence of images accepted by ``imshow``.
        titles: Optional sequence of per-image titles, aligned with ``samples``.
        figsize: Matplotlib figure size.

    An empty ``samples`` is a no-op. Without that check, the grid computation
    divides by ``cols == 0``.
    """
    n = len(samples)
    if n == 0:
        return

    cols = min(n, 4)
    rows = (n + cols - 1) // cols  # ceiling division

    plt.figure(figsize=figsize)
    for i, img in enumerate(samples):
        ax = plt.subplot(rows, cols, i + 1)
        ax.imshow(img)
        ax.axis("off")
        if titles:
            ax.set_title(titles[i])
    plt.tight_layout()
    plt.show()


# Choose the compute device: CUDA first, then Apple MPS, otherwise CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

satellite_src = "satellite"
tile_size = 256

# The same checkpoint is loaded in the evaluation cell above with these
# conditioning keys. Pass them explicitly here too, so geoinfo_vector matches
# what the model expects; without them SatelliteDataset presumably uses its
# own default keys, which may differ.
geoinfo_keys = [
    "dem_mean",
    "dem_min",
    "dem_max",
    "tile_scale",
    "month",
    "solar_zenith",
    "solar_azimuth",
    "north_dir",
]

dataset = SatelliteDataset(
    satellite_src, geoinfo_keys=geoinfo_keys, preload_to_ram=False
)

model, _, _ = load_model(
    False,
    "models/satellite_diffusion-T256-A7_best.pth",
    dataset,
    1e-4,
    device,
)


samples = generate_from_noise(model, dataset, num_samples=1, device=device)
plot_generated_samples(samples)