# Setup

In [None]:
import math
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from topodiff import gaussian_diffusion as gd
from topodiff.respace import SpacedDiffusion, space_timesteps
from topodiff.unet import EncoderUNetModel, UNetModel

%matplotlib inline

# Models

## Initialization

In [None]:
def alpha_bar(t):
    """Return the squared-cosine noise-schedule value at fractional time ``t``.

    ``t`` is a position on the schedule expressed as ``step / steps`` (so it
    runs from 0 to 1); the value is ``cos(((t + 0.008) / 1.008) * pi / 2) ** 2``.
    """
    phase = (t + 0.008) / 1.008 * math.pi / 2
    return math.cos(phase) ** 2


steps = 1000
# Beta for step i is 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i) on the grid
# t_i = i / steps, capped at 0.999; stored as a float64 array of length `steps`.
betas = np.array(
    [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), 0.999)
        for i in range(steps)
    ]
)

# Diffusion process over the schedule above; `space_timesteps(steps, [100])`
# presumably keeps 100 of the 1000 steps for sampling (TODO confirm in respace).
diffusion = SpacedDiffusion(
    betas=betas,
    use_timesteps=space_timesteps(steps, [100]),
    # Output parameterization: EPSILON mean with LEARNED_RANGE variance.
    model_mean_type=gd.ModelMeanType.EPSILON,
    model_var_type=gd.ModelVarType.LEARNED_RANGE,
    loss_type=gd.LossType.MSE,
    rescale_timesteps=False,
)

image_size = 64
channel_multiplier = (1, 2, 3, 4)


def get_attention_resolutions(
    attention_resolutions: List[int], input_size: Optional[int] = None
):
    """Convert attention resolutions into ``input_size // resolution`` factors.

    The models below receive these ratios as their ``attention_resolutions``
    argument (presumably downsample rates, as in guided-diffusion's UNet).

    Args:
        attention_resolutions: Feature-map resolutions, in pixels, at which
            attention should be applied.
        input_size: Image size to divide; defaults to the module-level
            ``image_size``.

    Returns:
        A tuple with one integer factor per requested resolution.

    Raises:
        ValueError: If a resolution does not evenly divide ``input_size``
            (floor division would otherwise silently pick a wrong factor).
    """
    size = image_size if input_size is None else input_size
    factors = []
    for resolution in attention_resolutions:
        if size % resolution != 0:
            raise ValueError(
                f"attention resolution {resolution} does not divide image size {size}"
            )
        factors.append(size // resolution)
    return tuple(factors)


# Denoising UNet: 6 input channels, 2 output channels. With the EPSILON /
# LEARNED_RANGE diffusion setup and 1-channel samples, the 2 outputs are
# presumably the noise estimate plus a variance channel (TODO confirm).
mean_variance = UNetModel(
    # Resolution and I/O channels.
    image_size=image_size,
    in_channels=6,
    out_channels=2,
    # Width and depth.
    model_channels=128,
    channel_mult=channel_multiplier,
    num_res_blocks=3,
    # Attention resolutions [16, 8] become factors (4, 8) for image_size 64.
    attention_resolutions=get_attention_resolutions([16, 8]),
    num_heads=4,
    # Regularization, normalization and precision.
    dropout=0.3,
    use_scale_shift_norm=True,
    use_fp16=True,
)

def _build_encoder(in_channels, out_channels, num_res_blocks, pool):
    """Build an EncoderUNetModel with the settings shared by both guidance models."""
    return EncoderUNetModel(
        image_size=image_size,
        in_channels=in_channels,
        model_channels=128,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=get_attention_resolutions([32, 16, 8]),
        channel_mult=channel_multiplier,
        use_fp16=False,
        num_head_channels=64,
        use_scale_shift_norm=True,
        resblock_updown=True,
        pool=pool,
    )


# Regressor: 8 input channels, one output, spatial pooling, depth 4.
regressor = _build_encoder(
    in_channels=8, out_channels=1, num_res_blocks=4, pool="spatial"
)

# Classifier: 1 input channel, two logits, attention pooling, depth 2.
classifier = _build_encoder(
    in_channels=1, out_channels=2, num_res_blocks=2, pool="attention"
)

## Loading Checkpoints

In [None]:
# Relative to the working directory. Built without a backslash literal because
# "\" is a path separator only on Windows; on POSIX, r".\checkpoints" would name
# a single file literally called ".\checkpoints".
checkpoints_path = Path("checkpoints")
# Checkpoints are deserialized onto the CPU and only then moved to the GPU.
# torch.device("cpu") is the portable equivalent of torch.cpu.current_device(),
# which only exists in newer PyTorch releases and returns a plain string.
cpu = torch.device("cpu")
# Index of the current CUDA device; this raises when CUDA is not available.
gpu = torch.cuda.current_device()
def get_state_dict(path: Path):
    """Deserialize the checkpoint at ``path`` onto the ``cpu`` device.

    ``weights_only=True`` restricts unpickling to tensors, primitive types,
    dictionaries and explicitly allow-listed types.
    """
    state = torch.load(path, map_location=cpu, weights_only=True)
    return state

def _prepare_for_inference(model, checkpoint, *, fp16=False):
    """Load ``checkpoint`` into ``model``, move it to ``gpu`` and switch to eval mode."""
    model.load_state_dict(get_state_dict(checkpoint))
    model.to(gpu)
    # Half-precision conversion happens after the saved weights are loaded.
    if fp16:
        model.convert_to_fp16()
    model.eval()


# Denoiser: built with use_fp16=True, so it is converted after loading.
mean_variance_path = checkpoints_path / "diff_checkpoint" / "model_180000.pt"
_prepare_for_inference(mean_variance, mean_variance_path, fp16=True)

regressor_path = checkpoints_path / "reg_checkpoint" / "model_350000.pt"
_prepare_for_inference(regressor, regressor_path)

classifier_path = checkpoints_path / "class_checkpoint" / "model_299999.pt"
_prepare_for_inference(classifier, classifier_path)

# Sampling

In [8]:
# One single-channel image per sample, at the models' image_size resolution.
batch_size, channel_count = 1, 1
shape = (batch_size, channel_count) + (image_size,) * 2

def cond_fn_1(
    x: torch.Tensor,
    time_steps: torch.Tensor,
    *,
    scale: float = 4.0,
    model=None,
):
    """Regressor guidance: the scaled negative gradient of the regressor output.

    Args:
        x: Regressor input; only the gradient w.r.t. channel 0 is returned.
            (Presumably the noisy sample stacked with its condition channels,
            given the regressor's ``in_channels=8`` -- TODO confirm.)
        time_steps: Timesteps forwarded unchanged to the model.
        scale: Guidance strength; the gradient is multiplied by ``-scale``.
        model: Differentiable module to query; defaults to the global
            ``regressor``.

    Returns:
        Tensor of shape ``(x.shape[0], 1, H, W)``.
    """
    if model is None:
        model = regressor
    with torch.enable_grad():
        x_in = x.detach().requires_grad_()
        logits = model(x_in, time_steps)
        # One backward pass over the summed outputs yields every sample's
        # gradient, assuming the model treats batch elements independently.
        grad = torch.autograd.grad(logits.sum(), x_in)[0]
    # `:1` keeps the channel axis, so the result fits any batch size rather
    # than only the fixed module-level `shape`.
    return -scale * grad[:, :1]

def cond_fn_2(
    x: torch.Tensor,
    time_steps: torch.Tensor,
    *,
    scale: float = 3.0,
    model=None,
):
    """Classifier guidance: the scaled gradient of ``log p(class 1 | x)``.

    Args:
        x: Classifier input; the gradient w.r.t. channel 0 is returned.
        time_steps: Timesteps forwarded unchanged to the model.
        scale: Guidance strength multiplying the gradient.
        model: Module returning ``(batch, num_classes)`` logits; defaults to
            the global ``classifier``.

    Returns:
        Tensor of shape ``(x.shape[0], 1, H, W)``.
    """
    if model is None:
        model = classifier
    with torch.enable_grad():
        x_in = x.detach().requires_grad_()
        logits: torch.Tensor = model(x_in, time_steps)
        log_probs = F.log_softmax(logits, dim=-1)
        # Log-probability of class index 1 for every sample in the batch.
        selected = log_probs[:, 1]
        grad = torch.autograd.grad(selected.sum(), x_in)[0]
    # `:1` keeps the channel axis, so the result fits any batch size rather
    # than only the fixed module-level `shape`.
    return scale * grad[:, :1]


def get_boundary_condition(
    condition_name: str,
    *,
    folder: Path = Path("data", "dataset_1_diff", "test_data_level_1"),
    device=None,
):
    """Load ``cons_<condition_name>_array_200.npy`` as a batched float32 tensor.

    Args:
        condition_name: Name inserted into the file name (this notebook uses
            ``"pf"``, ``"load"`` and ``"bc"``).
        folder: Directory holding the arrays. Built from parts so it resolves
            on every OS; a backslash literal is a separator only on Windows.
        device: Target device; defaults to the module-level ``gpu``.

    Returns:
        A tensor of shape ``(1, d2, d0, d1)`` for a stored array of shape
        ``(d0, d1, d2)`` -- presumably ``(H, W, C)`` -> ``(1, C, H, W)``.
    """
    if device is None:
        device = gpu
    path = folder / f"cons_{condition_name}_array_200.npy"
    # Move the last axis to the front, then add a leading batch dimension.
    array = np.transpose(np.load(path), [2, 0, 1]).astype(np.float32)
    return torch.unsqueeze(torch.as_tensor(array), 0).to(device)

# Conditioning tensors, loaded in the same order the sampler call lists them.
pf_condition = get_boundary_condition("pf")
load_condition = get_boundary_condition("load")
bc_condition = get_boundary_condition("bc")

# Draw a `shape`-sized sample with both guidance functions enabled.
sample = diffusion.p_sample_loop(
    model=mean_variance,
    shape=shape,
    cons=pf_condition,
    loads=load_condition,
    BCs=bc_condition,
    cond_fn_1=cond_fn_1,
    cond_fn_2=cond_fn_2,
    noise=None,
    denoised_fn=None,
    clip_denoised=True,
    model_kwargs={},
    device=gpu,
    progress=False,
)

In [None]:
# Show channel 0 of the first sample in the batch as a grayscale image.
plt.imshow(sample[0, 0].detach().cpu().numpy(), cmap="gray")