In [8]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import argparse
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F

sys.path.append("../../")
from src.filepath import ABSOLUTE_PATH
from src.model.diffusion import GaussianDiffusion
from src.model.UNet2d import Unet2D
from src.inference.compose import compose_diffusion
from src.utils.utils import plot_compare_2d, relative_error, find_max_min, to_np
from src.train.reaction_diffusion_couple import cond_emb, renormalize
from src.train.reaction_diffusion_couple import normalize_to_neg_one_to_one as normalize

## Load the trained diffusion model

In [None]:
# Architecture hyper-parameters; these must match the checkpoint loaded below.
dim = 24
out_dim = 2  # generated channels (u and v)
channel = 4  # UNet input channels — presumably noisy target + condition channels; TODO confirm
nx = 20  # spatial points per field
nt = 10  # length of the second sample axis (presumably time steps) — TODO confirm
diffusion_step = 250
device = "cuda"
CHECKPOINT_PATH = "../../results/reaction_diffusion_couple_model/diffusionUnet10000/model.pt"

model = Unet2D(
    dim=dim,
    cond_emb=cond_emb(),
    out_dim=out_dim,
    dim_mults=(1, 2),
    channels=channel,
)
# `.eval()` (returns the module) puts every submodule in inference mode, since this
# notebook only samples from the model and never trains it.
diffusion = (
    GaussianDiffusion(model, seq_length=(out_dim, nt, nx), timesteps=diffusion_step, auto_normalize=False)
    .to(device)
    .eval()
)
# map_location loads the tensors straight onto `device`, independent of where they were saved.
diffusion.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device)["model"])

In [None]:
# Held-out evaluation split: trajectories from this index onward.
# NOTE(review): presumably indices [0, TEST_START) were used for training — confirm.
TEST_START = 9000
# Seed for the stochastic reverse-diffusion sampling so the predictions are reproducible.
SAMPLE_SEED = 0

uv = torch.tensor(np.load(ABSOLUTE_PATH + "/data/reaction_diffusion/reaction_diffusion_uv.npy").transpose(0, 2, 1))
# The last axis stores u and v side by side, `nx` points each.
assert uv.shape[-1] == 2 * nx, f"expected last axis of size {2 * nx}, got {uv.shape[-1]}"

# One field per channel: (N, 1, A, nx) each, concatenated into the (N, 2, A, nx) target.
u = uv[TEST_START:, :, :nx].unsqueeze(1)
v = uv[TEST_START:, :, nx:].unsqueeze(1)
data = torch.concat((u, v), dim=1)

# The condition is the first entry of axis 2 (u0, v0), repeated along that axis to match
# the target's shape. Slice dim 2, not dim 1: dim 1 is the singleton channel axis added by
# `unsqueeze(1)`, so `u[:, 0:1]` would select the whole trajectory and leak the target
# into the condition (and make the `expand` below a no-op).
# NOTE(review): confirm this matches how `cond` is built in src/train/reaction_diffusion_couple.py.
cond = torch.concat((u[:, :, 0:1], v[:, :, 0:1]), dim=1).expand(-1, -1, data.shape[2], -1)
# `expand` returns a stride-0 view; materialize it as float32 on the model's device
# (avoids the copy-construct warning raised by `torch.tensor(tensor)`).
cond = cond.float().contiguous().to(device)

torch.manual_seed(SAMPLE_SEED)
with torch.no_grad():  # inference only — no autograd graph over the sampling loop
    uv_pred = diffusion.sample(cond.shape[0], cond=[normalize(cond)])

In [12]:
# Convert ground truth and prediction to NumPy and map the prediction back to data scale.
# The type guards make the cell safe to re-run: without them a second run would apply
# `renormalize` to an already-renormalized prediction.
# NOTE(review): relies on `to_np` returning a non-tensor (presumably np.ndarray).
if torch.is_tensor(data):
    data = to_np(data)
if torch.is_tensor(uv_pred):
    uv_pred = to_np(renormalize(uv_pred))

In [None]:
# Relative error of the prediction per channel, displayed as the tuple (u_error, v_error).
tuple(relative_error(data[:, ch], uv_pred[:, ch]) for ch in (0, 1))

In [None]:
# Visual spot check of one held-out sample: ground truth vs. prediction.
PLOT_SEED = 0  # fixed so the same sample is shown on every run
PLOT_CHANNEL = 0  # 0 -> u, 1 -> v (channel order of `data` / `uv_pred`)
random_n = int(np.random.default_rng(PLOT_SEED).integers(data.shape[0]))
plot_compare_2d(
    true_d=data[random_n, PLOT_CHANNEL],
    pred_d=uv_pred[random_n, PLOT_CHANNEL],
)