In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
from tqdm import tqdm
import random
import os, sys
import numpy as np

import dem_diffusion_model as demm
# from dem_diffusion_model import DEMDiffusionModel
from training_dataset import DEMDataset

In [None]:
# --- Run configuration ---
dem_src = "test"  # source handed to DEMDataset below

# Tiling: tile edge length and the coarsening factor
tile_size, coarsen_factor = 256, 2

# Optimisation: mini-batch size and nominal number of epochs
batch_size, num_epochs = 16, 30

In [None]:
# Build the tile dataset from the configured source. The coarsening factor is
# taken from `coarsen_factor` (config cell) instead of a repeated literal, so
# changing the configuration actually changes the dataset.
dataset = DEMDataset(
    dem_src, tile_size=tile_size, coarsen=coarsen_factor, rotate=True
)

# Train with the packaged trainer from `dem_diffusion_model`.
# NOTE(review): the meaning of "load" and `cond_dim` is defined inside that
# module and is not visible from this notebook -- confirm against its source.
model = demm.train_demdiffusionmodel(
    "load",
    dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    subset_fraction=0.1e-2,
    learning_rate=1e-3,
    cond_dim=4,
)

In [None]:
def cosine_beta_schedule(timesteps, s=0.004):
    """Return the per-step betas of a cosine noise schedule.

    A squared-cosine curve is evaluated on ``timesteps + 1`` evenly spaced
    points, normalised so its first value is 1, and converted to betas from
    the ratio of consecutive values.

    Args:
        timesteps: Number of diffusion steps (length of the returned array).
        s: Offset added to the normalised time before the cosine is taken.

    Returns:
        NumPy array of ``timesteps`` betas, clipped to ``[0, 0.999]``.
    """
    grid = np.linspace(0, timesteps, timesteps + 1)
    angle = ((grid / timesteps) + s) / (1 + s) * np.pi * 0.5
    cum_alpha = np.cos(angle) ** 2
    cum_alpha = cum_alpha / cum_alpha[0]  # normalise so the curve starts at 1
    step_ratio = cum_alpha[1:] / cum_alpha[:-1]
    return np.clip(1 - step_ratio, 0, 0.999)


def load_model(save_model, model, dataset, learning_rate, device):
    """Resolve the checkpoint path, build or load the network, and create its optimiser.

    Args:
        save_model: Either a truthy/falsy flag, or the path of an *existing*
            checkpoint file, in which case that path is used as the save
            location. Any other value selects the default path
            ``models/dem_ddpm-log{log_transform}-{tile_size}.pth``.
            NOTE(review): a string path that does not exist yet is not used
            as the save location (the default path is used instead) --
            confirm this is intended.
        model: One of
            * an ``nn.Module`` instance (used as is),
            * the path of an existing state-dict file (loaded into a new ``UNet``),
            * ``"load"``: load the state dict at the resolved checkpoint path
              if it exists, otherwise start from a fresh ``UNet``,
            * ``"new"``: a fresh ``UNet``.
        dataset: Object exposing ``log_transform`` and ``tile_size``; both are
            only used to build the default checkpoint file name.
        learning_rate: Learning rate for the Adam optimiser.
        device: Device the model is moved to and weights are mapped onto.

    Returns:
        Tuple ``(model, optimiser, model_path)``.

    Raises:
        ValueError: If ``model`` is none of the accepted specifications.
    """
    if isinstance(save_model, str) and os.path.exists(save_model):
        model_path = save_model
        save_model = True
    else:
        os.makedirs("models", exist_ok=True)
        model_path = os.path.join(
            "models", f"dem_ddpm-log{dataset.log_transform}-{dataset.tile_size}.pth"
        )
    if save_model:
        print(f"Model will be saved to {model_path}")

    # NOTE(review): `UNet` is not imported in the visible notebook cells (only
    # `dem_diffusion_model as demm` is) -- confirm where it is defined.
    if isinstance(model, nn.Module):
        pass  # already a network; it is moved to `device` below
    elif isinstance(model, str) and os.path.exists(model):
        # Keep the path in its own name: `model` is about to be rebound to the
        # network, and the weights must be read from the path, not the module.
        weights_path = model
        model = UNet()
        model.load_state_dict(torch.load(weights_path, map_location=device))
        print(f"Model loaded successfully from {weights_path}")
    elif isinstance(model, str) and model == "load":
        model = UNet()
        if os.path.exists(model_path):
            print(f"Loading model from {model_path}")
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            print(f"Best model not found at {model_path}, starting from scratch")
    elif isinstance(model, str) and model == "new":
        model = UNet()
    else:
        # Fail with a clear message instead of an AttributeError further down.
        raise ValueError(
            "model must be an nn.Module, the path of an existing checkpoint, "
            f"'load' or 'new'; got {model!r}"
        )

    model.to(device)
    optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

    return model, optimiser, model_path


def train_demdiffusionmodel(
    model,
    dataset,
    batch_size=16,
    learning_rate=1e-4,
    num_epochs=20,
    subset_fraction=1.0,
    num_timesteps=100,
    save_model=True,
    device=(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    ),
    l2loss_weight=1.0,  # (slope-weighted) L2 loss weight
    tvloss_weight=0.01,  # total variation loss weight
    slope_scale=5.0,  # slope scaling factor
):
    """Train a noise-prediction diffusion model on DEM tiles.

    Every epoch draws a fresh random subset of ``dataset``, noises each batch
    to a random timestep of a cosine schedule and fits ``model`` to predict
    that noise. After each epoch the mean batch loss decides whether the
    weights are checkpointed or rolled back to the last checkpoint.

    Args:
        model: An ``nn.Module``, a checkpoint path, ``"load"`` or ``"new"``;
            forwarded to ``load_model``.
        dataset: Sized, indexable dataset that also exposes ``log_transform``
            and ``tile_size`` (read by ``load_model``). Batches are assumed to
            be ``[B, C, H, W]`` tensors -- TODO confirm against DEMDataset.
        batch_size: Mini-batch size.
        learning_rate: Adam learning rate.
        num_epochs: Nominal number of epochs. When ``subset_fraction < 1`` it
            is divided by ``subset_fraction`` so that roughly the same number
            of samples is seen in total.
        subset_fraction: Fraction of the dataset sampled per epoch.
        num_timesteps: Number of diffusion steps in the noise schedule.
        save_model: Flag or existing checkpoint path; see ``load_model``.
        device: Device used for the model and all tensors.
        l2loss_weight: Weight of the slope-weighted reconstruction term.
        tvloss_weight: Weight of the total-variation term.
        slope_scale: Multiplier applied to the slope map in the L2 weighting.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``subset_fraction`` is not positive or ``dataset`` is empty.
    """
    if subset_fraction <= 0:
        raise ValueError(f"subset_fraction must be positive, got {subset_fraction}")
    if len(dataset) == 0:
        raise ValueError("dataset is empty")

    if subset_fraction < 1.0:
        num_epochs = round(num_epochs / subset_fraction)
        # Two decimals so that sub-percent fractions are not reported as "0%".
        print(
            f"Using {subset_fraction:.2%} of the dataset, training extended to {num_epochs} epochs"
        )

    print(f"Using device: {device}")

    # Load model
    model, optimiser, model_path = load_model(
        save_model, model, dataset, learning_rate, device
    )

    # Actual training starts here
    loss_prev = np.inf

    # Cumulative alpha table, built once on the target device so the per-batch
    # lookup is a plain tensor index instead of a CPU round trip.
    betas = cosine_beta_schedule(num_timesteps)
    alphas_cumprod = torch.tensor(
        np.cumprod(1.0 - betas), dtype=torch.float32, device=device
    )

    # At least one sample per epoch (a tiny fraction would otherwise produce an
    # empty loader), and never more than the dataset holds.
    num_samples = min(len(dataset), max(1, int(len(dataset) * subset_fraction)))

    for epoch in range(num_epochs):
        indices = random.sample(range(len(dataset)), num_samples)
        sampler = SubsetRandomSampler(indices)
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for x0 in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x0 = x0.to(device)

            # Sample time step
            t = torch.randint(1, num_timesteps, (x0.size(0),), device=device)

            # Retrieve alpha_cumprod for each t
            alpha = alphas_cumprod[t].view(-1, 1, 1, 1)

            # Add noise to input
            noise = torch.randn_like(x0)
            xt = alpha.sqrt() * x0 + (1 - alpha).sqrt() * noise

            # Predict the noise
            pred_noise = model(xt, t.float() / num_timesteps)
            mse_loss = F.mse_loss(pred_noise, noise)

            # Slope-weighted reconstruction loss (L_x0) and TV regularisation.
            # NOTE(review): both terms are evaluated without gradients, so
            # they change the reported loss value but contribute nothing to
            # the parameter update; only `mse_loss` is optimised. Confirm
            # whether they are meant as monitoring only.
            # NOTE(review): `DDPMUNet` is not imported in the visible
            # notebook cells -- confirm where it is defined.
            with torch.no_grad():
                # Reconstruct x0 estimate
                x0_pred = (xt - (1 - alpha).sqrt() * pred_noise) / alpha.sqrt()
                slope = DDPMUNet.compute_slope_map(x0)
                slope_weight = 1.0 + slope_scale * slope  # e.g. slope_scale = 5.0
                weighted_l2 = ((x0_pred - x0) ** 2 * slope_weight).mean()
                tv_reg = DDPMUNet.total_variation(x0_pred)

            # Final combined loss
            loss = mse_loss + l2loss_weight * weighted_l2 + tvloss_weight * tv_reg

            # Optimise
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Judge the epoch by its mean batch loss: the last batch alone is
        # noisy because each batch draws its own random timesteps.
        mean_loss = epoch_loss / max(num_batches, 1)
        progress = f"Epoch {round((epoch+1)*subset_fraction):04d}/{round(num_epochs*subset_fraction):04d}"

        if not np.isfinite(mean_loss) or mean_loss > loss_prev * 100:
            print(f"{progress}  Loss: {mean_loss:.3f}. Loss diverged, back tracking.")
            if os.path.exists(model_path):
                # Restore the weights in place so the optimiser keeps pointing
                # at the parameters that are actually being trained, and drop
                # the Adam moment estimates built up during the diverged epoch.
                model.load_state_dict(torch.load(model_path, map_location=device))
                model.to(device)
                optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
            else:
                print(f"No checkpoint found at {model_path}; continuing with current weights")
        else:
            loss_prev = mean_loss
            if save_model:
                torch.save(
                    model.state_dict(),
                    model_path,
                )
                print(f"{progress}  Loss: {mean_loss:.3f}. Model saved")
            else:
                print(f"{progress}  Loss: {mean_loss:.3f}")
    return model


# Run the notebook-local trainer: "load" resumes from the default checkpoint
# when one exists, and each epoch samples 0.5 % of the tiles.
model = train_demdiffusionmodel(
    model="load",
    dataset=dataset,
    learning_rate=1e-3,
    subset_fraction=0.5e-2,
    num_epochs=num_epochs,
    batch_size=batch_size,
)

In [None]:
import matplotlib.pyplot as plt

# Resolve the compute device in this cell: `device` is otherwise only assigned
# in a later cell, so the lines below would raise NameError on a fresh kernel.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Reload the checkpoint written during training.
# NOTE(review): the "logTrue" part of the file name is hard-coded; the trainer
# derives it from `dataset.log_transform` -- confirm the two agree.
model = UNet()
model.load_state_dict(
    torch.load(f"models/dem_ddpm-logTrue-{tile_size}.pth", map_location=device)
)
model.to(device)


# --- 4. Visualisation (Denoising one sample) ---
def denoise_sample(model, shape, steps=100, device=None):
    """Iteratively denoise pure Gaussian noise with ``model``.

    Starting from ``torch.randn(shape)``, the model's noise prediction is
    removed ``steps`` times, walking the normalised time from 1 down to
    ``1 / steps``; the sample is clamped to ``[0, 1]`` after every step.

    Args:
        model: Network called as ``model(x, t)`` with ``t`` of shape ``[B]``.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
        steps: Number of denoising iterations.
        device: Device to sample on. ``None`` (default) uses the device of the
            model's first parameter, so the noise always lives where the
            weights do; a model without parameters falls back to ``"cpu"``.

    Returns:
        Tensor of the requested shape on the CPU, with values in ``[0, 1]``.
    """
    if device is None:
        # Follow the model instead of assuming a specific accelerator exists.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    model.eval()
    x = torch.randn(shape, device=device)
    with torch.no_grad():
        for i in reversed(range(1, steps + 1)):
            t = torch.full((shape[0],), i / steps, device=device)
            noise_pred = model(x, t)
            # NOTE(review): this fixed linear alpha (1 - 0.01 * t) is not
            # derived from the cosine schedule used in
            # train_demdiffusionmodel -- confirm this update rule is intended.
            alpha = 1 - 0.01 * t[:, None, None, None]
            x = (x - (1 - alpha).sqrt() * noise_pred) / alpha.sqrt()
            x = x.clamp(min=0.0, max=1.0)
    return x.cpu()


# Generate four single-channel 128x128 samples and show each in its own figure.
samples = denoise_sample(model, (4, 1, 128, 128))
for i in range(len(samples)):
    # Undo the base-10 log transform.
    # NOTE(review): the factor 4 presumably equals log10 of the 1e4
    # normalisation factor used for the dataset -- confirm.
    dem = 10 ** (samples[i, 0].numpy() * 4) - 1
    fig, ax = plt.subplots()
    image = ax.imshow(dem, cmap="terrain")
    fig.colorbar(image, ax=ax)
    ax.set_title(f"Generated DEM #{i+1}")
    plt.show()

In [None]:
from DEMDataset import DEMDataset
import numpy as np
import matplotlib.pyplot as plt

# Dataset of 256-pixel tiles, log-transformed, without rotation.
dataset = DEMDataset(
    "training-data/unet-input-dem.tif",
    log_transform=True,
    normalise_factor=1e4,
    rotate=False,
    tile_size=256,
)

# Pick one tile at random and display its first channel on a fixed 0-1 scale.
iSample = np.random.randint(len(dataset))
sample = dataset[iSample][0].numpy()
fig, ax = plt.subplots()
image = ax.imshow(sample, cmap="terrain")
image.set_clim(0, 1)
fig.colorbar(image, ax=ax)
ax.set_title(f"DEM #{iSample}")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def cosine_beta_schedule(timesteps, s=0.004):
    """Return the per-step betas of a cosine noise schedule.

    This cell redefines the function that the training cell of this notebook
    already defines. The default offset ``s`` is kept at the training cell's
    value (0.004) so that running this cell neither changes the schedule used
    by a later training run nor makes the visualisation noise its inputs with
    a different schedule than the one the model was trained on.
    NOTE(review): confirm 0.004 is the value the saved checkpoint was trained with.

    Args:
        timesteps: Number of diffusion steps (length of the returned array).
        s: Offset added to the normalised time before the cosine is taken.

    Returns:
        NumPy array of ``timesteps`` betas, clipped to ``[0, 0.999]``.
    """
    steps = timesteps + 1
    x = np.linspace(0, timesteps, steps)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]  # curve starts at 1
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)


# 100-step schedule: per-step betas, per-step alphas and their running product.
betas = cosine_beta_schedule(timesteps=100)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod()


def visualize_denoising(
    model, dataset, alphas_cumprod, device=None, num_samples=3, timestep=None
):
    """Plot one-step denoising results for random dataset tiles.

    For each selected tile the function adds noise at one timestep, asks the
    model for its noise prediction, reconstructs the tile from it and shows
    the original, the noisy input, the reconstruction and the residual.

    Args:
        model: Network called as ``model(xt, t)``.
        dataset: Sized, indexable dataset of tensors; each item is given a
            leading batch axis before use (``[1, 1, H, W]`` per the original
            comment -- TODO confirm against DEMDataset).
        alphas_cumprod: Sequence of cumulative alphas, one per timestep.
        device: Device to run on. ``None`` (default) uses the device of the
            model's first parameter; a model without parameters falls back
            to ``"cpu"``.
        num_samples: Number of tiles to show (capped at ``len(dataset)``).
        timestep: Index into ``alphas_cumprod`` to noise at; ``None`` draws a
            random index in ``[1, T - 1]`` for every tile.

    Raises:
        ValueError: If ``timestep`` is given but outside ``[0, T - 1]``.
    """
    if device is None:
        # Follow the model instead of assuming a specific accelerator exists.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    alpha_array = np.asarray(alphas_cumprod, dtype=np.float32)
    T = len(alpha_array)
    if timestep is not None and not 0 <= timestep < T:
        raise ValueError(f"timestep must be in [0, {T - 1}], got {timestep}")

    model.eval()
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for idx in indices:
        x0 = dataset[idx].unsqueeze(0).to(device)  # [1, 1, H, W]

        # Choose timestep t
        t_val = random.randint(1, T - 1) if timestep is None else timestep
        t = torch.tensor([t_val / T], dtype=torch.float32, device=device)  # [1]

        # Add noise
        alpha_t = torch.tensor(
            alpha_array[t_val], dtype=torch.float32, device=device
        ).view(1, 1, 1, 1)
        noise = torch.randn_like(x0)
        xt = alpha_t.sqrt() * x0 + (1 - alpha_t).sqrt() * noise

        # Predict and denoise
        with torch.no_grad():
            pred_noise = model(xt, t)
            x0_pred = (xt - (1 - alpha_t).sqrt() * pred_noise) / alpha_t.sqrt()

        # Convert to numpy for plotting
        x0_np = x0.squeeze().cpu().numpy()
        xt_np = xt.squeeze().cpu().numpy()
        x0_pred_np = x0_pred.squeeze().cpu().numpy()
        diff_np = x0_np - x0_pred_np

        # Plot: three terrain panels on a common 0-1 scale, then the residual
        fig, axes = plt.subplots(1, 4, figsize=(20, 4))
        for ax in axes:
            ax.axis("off")
        terrain_panels = [
            (x0_np, "Original DEM $x_0$"),
            (xt_np, f"Noisy Input $x_t$ (t={t_val})"),
            (x0_pred_np, "Predicted DEM ${x_0}'$"),
        ]
        for ax, (image, title) in zip(axes, terrain_panels):
            ax.imshow(image, cmap="terrain", vmin=0, vmax=1)
            ax.set_title(title)
        im = axes[3].imshow(diff_np, cmap="coolwarm", vmin=-0.5, vmax=0.5)
        axes[3].set_title("Residual Error $x_0 - {x_0}'$")
        fig.colorbar(im, ax=axes[3], fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure so repeated calls do not pile up


dataset = DEMDataset(
    "training-data/unet-input-dem.tif",
    tile_size=256,
    rotate=False,
    normalise_factor=1e4,
    log_transform=True,
)

# Best available accelerator, falling back to the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Reload the checkpoint whose file name matches this dataset's settings.
model = UNet()
model.load_state_dict(
    torch.load(
        f"models/dem_ddpm-log{dataset.log_transform}-{dataset.tile_size}.pth",
        map_location=device,
    )
)
model.to(device)

# Run on the device the model was just moved to, rather than a hard-coded
# "mps" that fails or mismatches on machines without that backend.
visualize_denoising(model, dataset, alphas_cumprod, device=device, timestep=75)