In [3]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import argparse
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F

sys.path.append("../../")
from src.filepath import ABSOLUTE_PATH
from src.model.diffusion import GaussianDiffusion
from src.model.UNet2d import Unet2D
from src.model.fno import FNO2D
from src.utils.utils import plot_compare_2d, relative_error
from src.train.reaction_diffusion import cond_emb, renormalize
from src.train.reaction_diffusion import normalize_to_neg_one_to_one as normalize

In [4]:
# Backbone architecture shared by the u- and v-field diffusion models.
MODEL_TYPES = ("Unet", "ViT", "FNO")
model_type = "Unet"
# Fail fast here rather than with a NameError on `model1` when the models are built.
if model_type not in MODEL_TYPES:
    raise ValueError(f"model_type must be one of {MODEL_TYPES}, got {model_type!r}")

In [5]:
# Hyper-parameters shared by the u- and v-field diffusion models.
dim = 24
out_dim = 1
channel = 3
nx = 20
diffusion_step = 250
device = "cuda"


def _make_backbone():
    """Build one denoising backbone of type `model_type`.

    Called once per physics field, so the u and v models get independent parameters
    (and independent `cond_emb()` instances) with identical architecture.

    Returns:
        torch.nn.Module: the backbone to be wrapped by `GaussianDiffusion`.

    Raises:
        ValueError: if `model_type` is not one of "Unet", "ViT", "FNO".
    """
    if model_type == "Unet":
        return Unet2D(dim=dim, cond_emb=cond_emb(), out_dim=out_dim, dim_mults=(1, 2), channels=channel)
    if model_type == "ViT":
        # NOTE(review): `ViT` is not imported anywhere in this notebook, so this branch raises
        # NameError on a fresh kernel -- add the import (its module path is not visible here).
        # NOTE(review): channels=20 / out_channels=9 do not obviously match `channel` / `out_dim`
        # used by the other backbones -- confirm against the trained ViT checkpoints.
        return ViT(
            seq_len=20,
            patch_size=2,
            dim=64,
            depth=2,
            heads=8,
            mlp_dim=128,
            cond_emb=cond_emb(),
            Time_Input=True,
            dropout=0.0,
            emb_dropout=0.0,
            channels=20,
            out_channels=9,
        ).to(device)
    if model_type == "FNO":
        return FNO2D(
            in_channels=channel,
            out_channels=out_dim,
            nr_fno_layers=4,
            fno_layer_size=24,
            fno_modes=[6, 12],
            time_input=True,
            cond_emb=cond_emb(),
        )
    raise ValueError(f'model_type must be "Unet", "ViT" or "FNO", got {model_type!r}')


model1 = _make_backbone()  # denoiser for the u field (paired with update_u below)
model2 = _make_backbone()  # denoiser for the v field (paired with update_v below)

# seq_length = (channels, 10, nx); 10 is presumably the number of time steps -- TODO confirm.
diffusion1 = GaussianDiffusion(model1, seq_length=(out_dim, 10, nx), timesteps=diffusion_step, auto_normalize=False).to(
    device
)
diffusion2 = GaussianDiffusion(model2, seq_length=(out_dim, 10, nx), timesteps=diffusion_step, auto_normalize=False).to(
    device
)

In [None]:
def _load_checkpoint(diffusion, field):
    """Load trained weights for the `field` ("u" or "v") model of the current `model_type`.

    The checkpoint path is
    ``../../results/reaction_diffusion/diffusion{model_type}{field}10000/model.pt``.
    Its ``"model"`` entry is loaded as the state dict of `diffusion`.
    An unknown `model_type` now fails loudly on the missing file instead of silently
    leaving the model with random weights.
    """
    ckpt_path = f"../../results/reaction_diffusion/diffusion{model_type}{field}10000/model.pt"
    # map_location lets a checkpoint saved on another device load onto `device`.
    diffusion.load_state_dict(torch.load(ckpt_path, map_location=device)["model"])


_load_checkpoint(diffusion1, "u")
_load_checkpoint(diffusion2, "v")

In [None]:
# Keep only the first N_SAMPLES trajectories. Slicing happens on the host *before* `.to(device)`,
# so only the retained rows are copied to device memory.
N_SAMPLES = 2000
data = (
    torch.tensor(np.load(ABSOLUTE_PATH + "/data/reaction_diffusion/reaction_diffusion_uv.npy")[:N_SAMPLES])
    .float()
    .to(device)
)

# Swap the last two axes; afterwards the last axis holds u (first nx entries) followed by v.
data = data.permute(0, 2, 1)

# Split into the u and v fields and add a singleton channel axis at position 1.
u, v = data[..., :nx].unsqueeze(1), data[..., nx:].unsqueeze(1)

# Conditioning inputs: the first slice along axis 2 (presumably the initial time step),
# repeated 10 times to match the diffusion models' seq_length of (1, 10, nx).
u_intial, v_intial = u[:, :, 0:1].expand(-1, -1, 10, -1), v[:, :, 0:1].expand(-1, -1, 10, -1)

u.shape, u_intial.shape

In [8]:
def _blend(alpha, current, previous):
    """Return the element-wise mix ``alpha * current + (1 - alpha) * previous``."""
    return alpha * current + (1 - alpha) * previous


def update_u(
    alpha, t, model, field_noise, mult_p_estimate, mult_p_estimate_before, other_condition, normalize, renormalize
):
    """Take one reverse-diffusion step for the u field.

    The step is conditioned on the v estimate and on the initial u.

    Args:
        alpha: weight of the current estimates vs. the previous outer iteration's estimates.
        t: diffusion time step, forwarded to ``model.p_sample``.
        model: u-field diffusion model exposing ``p_sample(x, t, cond)``.
        field_noise: current noisy u sample.
        mult_p_estimate: current estimates of all fields, ordered (u, v).
        mult_p_estimate_before: previous outer iteration's estimates, ordered (u, v).
        other_condition: ``[initial_u, initial_v]``.
        normalize, renormalize: unused; kept so all update functions share one signature.

    Returns:
        tuple: ``(field_noise_next, x0)`` from ``model.p_sample``.
    """
    # Only the v estimate (index 1) conditions u, so only that blend is computed.
    v_weighted = _blend(alpha, mult_p_estimate[1], mult_p_estimate_before[1])
    cond = [torch.concat((v_weighted, other_condition[0]), dim=1)]
    field_noise_next, x0 = model.p_sample(field_noise, t, cond)
    return field_noise_next, x0


def update_v(
    alpha, t, model, field_noise, mult_p_estimate, mult_p_estimate_before, other_condition, normalize, renormalize
):
    """Take one reverse-diffusion step for the v field.

    The step is conditioned on the u estimate and on the initial v.

    Args:
        alpha: weight of the current estimates vs. the previous outer iteration's estimates.
        t: diffusion time step, forwarded to ``model.p_sample``.
        model: v-field diffusion model exposing ``p_sample(x, t, cond)``.
        field_noise: current noisy v sample.
        mult_p_estimate: current estimates of all fields, ordered (u, v).
        mult_p_estimate_before: previous outer iteration's estimates, ordered (u, v).
        other_condition: ``[initial_u, initial_v]``.
        normalize, renormalize: unused; kept so all update functions share one signature.

    Returns:
        tuple: ``(field_noise_next, x0)`` from ``model.p_sample``.
    """
    # Only the u estimate (index 0) conditions v, so only that blend is computed.
    u_weighted = _blend(alpha, mult_p_estimate[0], mult_p_estimate_before[0])
    cond = [torch.concat((u_weighted, other_condition[1]), dim=1)]
    field_noise_next, x0 = model.p_sample(field_noise, t, cond)
    return field_noise_next, x0

In [9]:
def compose_diffusion(
    model_list,
    shape: list,
    update_f: list,
    normalize_f,
    unnormalize_f,
    other_condition=None,
    num_iter=2,
    device="cuda",
):
    """Jointly sample several coupled physics fields, with one conditional diffusion model per field.

    Each outer iteration runs the full reverse-diffusion loop. At every time step, each field
    ``i`` is advanced by ``update_f[i]``, which receives the current estimates of all fields
    and the estimates left over from the previous outer iteration, together with ``alpha``.
    ``alpha`` is 1 in the first outer iteration and 0 afterwards. After the step, field ``i``'s
    estimate is replaced by ``model_list[i].unnormalize(x0)``.

    Args:
        model_list (list): conditional diffusion model for each physics field. The models must
            expose ``unnormalize``; ``num_timesteps`` is read from the first model only.
        shape (list[tuple]): shape of each field, ``(b, c, *)``; one entry per model.
        update_f (list[callable]): update function for each physics field. Each is called as
            ``update(alpha, t, model, field_noise, estimates, estimates_before,
            other_condition, normalize_f, unnormalize_f)`` and must return
            ``(next_field_noise, x0)``.
        normalize_f: normalization function(s), forwarded unchanged to every update function.
        unnormalize_f: unnormalization function(s), forwarded unchanged to every update function.
        other_condition (list, optional): other conditions such as the initial state or a source
            term, forwarded unchanged to every update function. Defaults to ``[]``.
        num_iter (int, optional): number of outer iterations. Defaults to 2.
        device (str, optional): device on which the Gaussian noise is drawn. Defaults to 'cuda'.

    Returns:
        list[torch.Tensor]: the sample of each field after the last reverse step of the last
        outer iteration, in the same order as ``model_list``.

    Raises:
        ValueError: if ``model_list`` is empty, ``shape`` or ``update_f`` has a different
            length than ``model_list``, or ``num_iter < 1`` (which previously surfaced as an
            ``UnboundLocalError`` on the return value).
    """
    if other_condition is None:
        other_condition = []
    n_compose = len(model_list)
    if n_compose == 0:
        raise ValueError("model_list must contain at least one model")
    if len(shape) != n_compose or len(update_f) != n_compose:
        raise ValueError(
            f"shape ({len(shape)}) and update_f ({len(update_f)}) must each have one entry per model ({n_compose})"
        )
    if num_iter < 1:
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")

    with torch.no_grad():
        timestep = model_list[0].num_timesteps

        # Initial estimates: pure noise. They become `mult_p_estimate_before` of the first iteration.
        mult_p_estimate = [torch.randn(s, device=device) for s in shape]

        for k in range(num_iter):
            mult_p_estimate_before = mult_p_estimate.copy()
            mult_p_estimate = []
            mult_p = []
            # Draw the fresh estimate and the starting noise interleaved per field, so the
            # RNG consumption order is fixed.
            for s in shape:
                mult_p_estimate.append(torch.randn(s, device=device))
                mult_p.append(torch.randn(s, device=device))

            # Weight given to the current estimates vs. the previous iteration's estimates:
            # 1 in the first outer iteration, 0 afterwards. Alternative schedules tried:
            # linear: 1 - t / (timestep - 1), 0->1
            # cos: cos(math.pi/2(t / (timestep - 1))), 0->1
            # power1: 1 - (t / (timestep - 1))**2
            # power2: (t / (timestep - 1)-1)**2
            alpha = 0 if k > 0 else 1

            for t in tqdm(reversed(range(0, timestep)), desc="sampling loop time step", total=timestep):
                for i in range(n_compose):
                    model = model_list[i]
                    single_p, x0 = update_f[i](
                        alpha,
                        t,
                        model,
                        mult_p[i].clone(),
                        mult_p_estimate.copy(),
                        mult_p_estimate_before.copy(),
                        other_condition,
                        normalize_f,
                        unnormalize_f,
                    )
                    mult_p[i] = single_p
                    # Later fields in this same time step already see this updated estimate.
                    mult_p_estimate[i] = model.unnormalize(x0)
    return mult_p

In [None]:
# Seed the global torch RNG so the sampled fields, and the errors reported below, are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Jointly sample u (diffusion1 + update_u) and v (diffusion2 + update_v). Each field is
# conditioned on the other field's estimate and on its own normalized initial condition.
model_lis = [diffusion1, diffusion2]
field_shape = tuple(u_intial.shape)  # (batch, 1, 10, nx): same grid as the conditioning inputs
mult_p = compose_diffusion(
    model_list=model_lis,
    shape=[field_shape, field_shape],
    other_condition=[normalize(u_intial), normalize(v_intial)],
    update_f=[update_u, update_v],
    normalize_f=[normalize, normalize],
    unnormalize_f=[renormalize, renormalize],
    num_iter=2,
    device=device,
)

In [None]:
# Stack ground truth and prediction along the channel axis (channel 0 -> u, channel 1 -> v).
# `renormalize` maps the sampled fields back from the normalized range (presumably the
# inverse of `normalize`).
mult_p_true = torch.cat([u, v], dim=1)
mult_p_pred = renormalize(torch.cat([mult_p[0], mult_p[1]], dim=1))

In [None]:
# Relative error of the composed prediction per field: u (channel 0) first, then v (channel 1).
err_u = relative_error(mult_p_pred[:, 0], mult_p_true[:, 0])
err_v = relative_error(mult_p_pred[:, 1], mult_p_true[:, 1])
print(err_u, err_v)