In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
from tqdm import tqdm
import random
import os
import numpy as np

import dem_diffusion_model as demm
from dem_dataset import DEMDataset

In [None]:
# Data / tiling configuration
tile_size = 256  # tile edge passed to DEMDataset(tile_size=...)
coarsen_factor = 2  # coarsening factor for the DEM tiles

# Training configuration
batch_size = 8
num_epochs = 30
dem_src = "dem_ea"

# Seed every global RNG the notebook touches (`random` picks evaluation tiles,
# torch draws the diffusion noise) so Restart & Run All is reproducible.
# NOTE(review): training reproducibility assumes train_demdiffusionmodel draws
# from these global RNGs — confirm.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Training tiles. Use the config-cell coarsening factor rather than repeating
# the literal, so editing the config actually changes the dataset.
dataset = DEMDataset(
    dem_src,
    tile_size=tile_size,
    coarsen=coarsen_factor,
    rotate=False,
    preload_to_ram=True,
)

In [None]:
# Train (or, presumably, resume from a checkpoint — the leading "load" argument is
# interpreted inside dem_diffusion_model; TODO document its meaning there).
model = demm.train_demdiffusionmodel(
    "load",
    dataset,
    # --- optimisation / data sampling ---
    learning_rate=1e-4,
    num_epochs=num_epochs,
    batch_size=batch_size,
    subset_fraction=0.5e-2,  # NOTE(review): presumably trains on ~0.5% of tiles — confirm
    # --- conditioning stats and loss weighting ---
    stats_dim=4,
    stats_loss_weight=1e-3,
    dem_loss_weight=0.0,  # DEM reconstruction term disabled for this run
)

In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from dem_diffusion_model import cosine_beta_schedule, denormalise_batch
import matplotlib.pyplot as plt
import random


@torch.no_grad()
def denoise_samples(
    model, dataset, num_samples=4, num_timesteps=50, device="cuda", total_timesteps=100
):
    """Noise random dataset tiles to a fixed timestep and reconstruct them in one shot.

    For each of ``num_samples`` randomly chosen tiles, the clean DEM is noised to
    timestep ``num_timesteps`` of a cosine beta schedule of length
    ``total_timesteps``. The model then predicts the noise (and the conditioning
    stats), and the clean DEM is recovered directly from the noise prediction.
    This is a single x0 estimate, not an iterative reverse process. RMS errors for
    the stats, the noise and the DEM are printed for every sample.

    Args:
        model: network called as ``model(dem_noisy, t, stats)`` and returning
            ``(noise_pred, stats_pred)``; it is switched to eval mode.
        dataset: indexable dataset whose items are ``(dem, stats)`` where ``dem``
            is a tensor (assumed ``[1, H, W]`` so that a batch dim gives
            ``[1, 1, H, W]`` — TODO confirm against DEMDataset).
        num_samples: number of random tiles to evaluate.
        num_timesteps: timestep index ``t`` to noise to (0 means no noise).
        device: torch device to run on.
        total_timesteps: length of the beta schedule. The model receives ``t``
            normalised as ``num_timesteps / total_timesteps``. The default of 100
            matches the previously hard-coded value.

    Returns:
        List of ``(original, noisy, denoised)`` squeezed numpy arrays, one per sample.

    Raises:
        ValueError: if ``num_timesteps`` is outside ``[0, total_timesteps]``.
    """
    if not 0 <= num_timesteps <= total_timesteps:
        raise ValueError(
            f"num_timesteps must be in [0, {total_timesteps}], got {num_timesteps}"
        )

    model.eval()
    samples = []

    # Cumulative-product table with 1.0 prepended, so alphas_bar[t] is ᾱ_t and ᾱ_0 = 1.
    betas = cosine_beta_schedule(total_timesteps)  # full schedule, not truncated
    alphas_bar = np.insert(np.cumprod(1.0 - betas), 0, 1.0)
    alphas_bar = torch.tensor(alphas_bar, dtype=torch.float32, device=device)

    # Loop-invariant: every sample is noised to the same timestep.
    alpha = alphas_bar[num_timesteps].view(1, 1, 1, 1)
    t_tensor = torch.full((1,), num_timesteps / total_timesteps, device=device)

    for _ in range(num_samples):
        idx = random.randint(0, len(dataset) - 1)
        dem, stats = dataset[idx]

        dem = dem.to(device).unsqueeze(0)  # add batch dim
        # as_tensor accepts numpy arrays and tensors alike (torch.tensor warns on tensors).
        stats = torch.as_tensor(stats, dtype=torch.float32, device=device).unsqueeze(0)
        # First three true stats, compared column-wise against stats_pred
        # (presumably [mean, min, max] like the prediction layout — confirm).
        stats_true = stats[:, :3]

        # Forward process: x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε
        noise = torch.randn_like(dem)
        dem_noisy = alpha.sqrt() * dem + (1 - alpha).sqrt() * noise

        noise_pred, stats_pred = model(dem_noisy, t_tensor, stats)

        # One-shot x_0 estimate obtained by inverting the forward process with ε̂.
        dem_pred = (dem_noisy - (1 - alpha).sqrt() * noise_pred) / alpha.sqrt()

        # Remove mean and clamp to [-1, 1]
        dem_pred = dem_pred - dem_pred.mean(dim=[1, 2, 3], keepdim=True)
        dem_pred = torch.clamp(dem_pred, -1.0, 1.0)

        # Predicted stats columns: 0 = mean, 1 = min, 2 = max.
        mean_pred = stats_pred[:, 0]
        min_pred = torch.clamp(stats_pred[:, 1], min=0.0)
        max_pred = stats_pred[:, 2]
        mean_pred = torch.clamp(mean_pred, min=min_pred, max=max_pred)
        stats_pred = torch.stack([mean_pred, min_pred, max_pred], dim=1)

        samples.append(
            (
                dem.squeeze().cpu().numpy(),
                dem_noisy.squeeze().cpu().numpy(),
                dem_pred.squeeze().cpu().numpy(),
            )
        )

        print(f"true stats     : {stats_true.cpu().numpy()}")
        print(f"predicted stats: {stats_pred.cpu().numpy()}")
        print(f"loss of stats  : {F.mse_loss(stats_pred, stats_true).sqrt()}")
        print(f"loss of noise  : {F.mse_loss(noise_pred, noise).sqrt()}")
        print(f"loss of dem    : {F.mse_loss(dem_pred, dem).sqrt()}")

    return samples


def plot_denoising_results(samples):
    """Plot original, noisy and denoised DEMs plus the reconstruction residual.

    Draws one row per sample and four columns: original, noisy and denoised DEM
    (``terrain`` colormap, colour limits [-1, 1]) and ``denoised - original``
    (``RdBu`` colormap, limits [-0.5, 0.5]). Each panel gets its own colorbar.

    Args:
        samples: sequence of ``(real, noisy, generated)`` 2-D arrays, as returned
            by ``denoise_samples``. An empty sequence draws nothing.
    """
    num = len(samples)
    if num == 0:
        # plt.subplots(0, ...) raises; there is simply nothing to show.
        return

    # squeeze=False always yields a 2-D axes grid, even for a single row.
    fig, axes = plt.subplots(num, 4, figsize=(20, 4 * num), squeeze=False)

    for i, (real, noisy, generated) in enumerate(samples):
        # (image, colormap, colour limits, title) for each column of the row.
        panels = (
            (real, "terrain", (-1, 1), f"Original DEM #{i+1}"),
            (noisy, "terrain", (-1, 1), "Noisy DEM"),
            (generated, "terrain", (-1, 1), "Denoised DEM"),
            (generated - real, "RdBu", (-0.5, 0.5), "Residual noise"),
        )
        for ax, (image, cmap, clim, title) in zip(axes[i], panels):
            im = ax.imshow(image, cmap=cmap)
            im.set_clim(*clim)
            fig.colorbar(im, ax=ax)
            ax.set_title(title)
            ax.axis("off")
            ax.set_aspect("equal")

    plt.tight_layout()
    plt.show()

In [None]:
from dem_dataset import DEMDataset
from dem_diffusion_model import DEMDiffusionModel

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Held-out evaluation tiles. This cell redefines its own settings so it can be
# run on its own after the imports.
dem_src = "dem_test"
tile_size = 256
coarsen = 2

dataset = DEMDataset(
    dem_src,
    tile_size=tile_size,
    coarsen=coarsen,
    rotate=False,
    preload_to_ram=False,
)

# The checkpoint is consumed as a state_dict, so restrict unpickling to tensors
# and plain containers (weights_only=True) instead of arbitrary pickled objects.
model = DEMDiffusionModel(cond_dim=4)
model.load_state_dict(
    torch.load(
        "models/dem_diffusion-T256-A4.pth", map_location=device, weights_only=True
    )
)
model.to(device)

samples = denoise_samples(
    model, dataset, num_samples=1, num_timesteps=75, device=device
)
plot_denoising_results(samples)