In [2]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import argparse
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

sys.path.append("../../")
from src.model.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from src.model.fno import FNO3D
from src.model.diffusion import GaussianDiffusion
from src.train.nuclear_thermal_coupling import load_nt_dataset_emb, cond_emb, normalize, renormalize
from src.utils.utils import L2_norm, get_parameter_net, plot_compare_2d, relative_error

In [3]:
# Device on which the models and tensors in this notebook are placed.
device = "cuda"
# Dataset tag: passed as `dataset=` to load_nt_dataset_emb and used in the
# checkpoint directory name ("<iter>_5000").
# NOTE(review): this name shadows the built-in `iter` for the rest of the notebook.
iter = "iter1"
# Number of diffusion timesteps handed to GaussianDiffusion.
diffusion_step = 250
# Backbone selector; the model cells below branch on "Unet", "ViT" and "FNO".
model_type = "Unet"

In [None]:
# Build the conditional diffusion model for the neutron field and restore its
# trained weights from disk.
train_which = "neutron"
dim = 8
emb = cond_emb(train_which, device=device)
cond, data = load_nt_dataset_emb(field=train_which, dataset=iter, device=device)
if model_type == "Unet":
    model = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=len(cond),
        out_dim=data.shape[1],
        cond_emb=emb,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is not among the imports in the first cell; this
    # branch raises NameError until the class is imported.
    model = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=128,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        Time_Input=True,
        channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        dropout=0.0,
        emb_dropout=0.0,
    )
elif model_type == "FNO":
    model = FNO3D(
        in_channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=8,
        fno_modes=[6, 16, 8],
        cond_emb=emb,
        time_input=True,
    )
else:
    # Fail with a clear message instead of a NameError on `model` further down.
    raise ValueError(f"unsupported model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")
diffusion_neu = GaussianDiffusion(
    model, seq_length=tuple(data.shape[1:]), timesteps=diffusion_step, auto_normalize=False
).to(device)

# Checkpoint directories follow "diffusion<model_type><field>/<iter>_5000".
checkpoint_path = f"../../results/nuclear_thermal_coupling/diffusion{model_type}{train_which}/{iter}_5000/model.pt"
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even if it was saved from a different device.
diffusion_neu.load_state_dict(torch.load(checkpoint_path, map_location=device)["model"])

In [None]:
# Build the conditional diffusion model for the solid (fuel) field and restore
# its trained weights from disk.
train_which = "solid"
dim = 8
emb = cond_emb(train_which, device=device)
cond, data = load_nt_dataset_emb(field=train_which, dataset=iter, device=device)
if model_type == "Unet":
    model = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=len(cond),
        out_dim=data.shape[1],
        cond_emb=emb,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is not among the imports in the first cell; this
    # branch raises NameError until the class is imported.
    model = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=128,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        Time_Input=True,
        channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        dropout=0.0,
        emb_dropout=0.0,
    )
elif model_type == "FNO":
    model = FNO3D(
        in_channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=8,
        fno_modes=[6, 16, 4],
        cond_emb=emb,
        time_input=True,
    )
else:
    # Fail with a clear message instead of a NameError on `model` further down.
    raise ValueError(f"unsupported model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")
diffusion_fuel = GaussianDiffusion(
    model,
    seq_length=tuple(data.shape[1:]),
    timesteps=diffusion_step,
    auto_normalize=False,
).to(device)

# Checkpoint directories follow "diffusion<model_type><field>/<iter>_5000".
checkpoint_path = f"../../results/nuclear_thermal_coupling/diffusion{model_type}{train_which}/{iter}_5000/model.pt"
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even if it was saved from a different device.
# NOTE(review): only the Unet checkpoint is loaded with strict=False, which
# silently ignores missing/unexpected keys -- confirm this is intentional.
diffusion_fuel.load_state_dict(
    torch.load(checkpoint_path, map_location=device)["model"],
    strict=(model_type != "Unet"),
)

In [None]:
# Build the conditional diffusion model for the fluid field and restore its
# trained weights from disk.
train_which = "fluid"
dim = 16
emb = cond_emb(train_which, device=device)
cond, data = load_nt_dataset_emb(field=train_which, dataset=iter, device=device)
if model_type == "Unet":
    model = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=len(cond),
        out_dim=data.shape[1],
        cond_emb=emb,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is not among the imports in the first cell; this
    # branch raises NameError until the class is imported.
    model = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=256,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        Time_Input=True,
        channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        dropout=0.0,
        emb_dropout=0.0,
    )
elif model_type == "FNO":
    model = FNO3D(
        in_channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=16,
        fno_modes=[6, 16, 6],
        cond_emb=emb,
        time_input=True,
    )
else:
    # Fail with a clear message instead of a NameError on `model` further down.
    raise ValueError(f"unsupported model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")

diffusion_fluid = GaussianDiffusion(
    model,
    seq_length=tuple(data.shape[1:]),
    timesteps=diffusion_step,
    auto_normalize=False,
).to(device)

# Checkpoint directories follow "diffusion<model_type><field>/<iter>_5000".
checkpoint_path = f"../../results/nuclear_thermal_coupling/diffusion{model_type}{train_which}/{iter}_5000/model.pt"
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even if it was saved from a different device.
# NOTE(review): only the Unet checkpoint is loaded with strict=False, which
# silently ignores missing/unexpected keys -- confirm this is intentional.
diffusion_fluid.load_state_dict(
    torch.load(checkpoint_path, map_location=device)["model"],
    strict=(model_type != "Unet"),
)

In [7]:
def update_neu(
    alpha, t, model, field_noise, mult_p_estimate, mult_p_estimate_before, other_condition, normalize, renormalize
):
    """Take one reverse-diffusion step for the neutron field.

    Every field estimate is first blended as
    ``alpha * current + (1 - alpha) * previous``. The condition handed to the
    model is the single entry of ``other_condition`` plus the blended field at
    index 1 concatenated (last dim) with channel 0 of the blended field at
    index 2.

    Args:
        alpha: weight of the current-iteration estimates in the blend.
        t: diffusion time step forwarded to ``model.p_sample``.
        model: object exposing ``p_sample(x, t, cond)``.
        field_noise: current noisy sample of this field.
        mult_p_estimate: per-field estimates from the current outer iteration.
        mult_p_estimate_before: per-field estimates from the previous outer iteration.
        other_condition: one-element sequence holding the extra condition.
        normalize, renormalize: unused here; kept for the shared update signature.

    Returns:
        The two values produced by ``model.p_sample``: the next noisy sample and
        the predicted clean field.
    """
    blended = [
        alpha * current + (1 - alpha) * mult_p_estimate_before[idx]
        for idx, current in enumerate(mult_p_estimate)
    ]
    (boundary_cond,) = other_condition
    coupled_fields = torch.concat((blended[1], blended[2][:, 0:1]), dim=-1)
    next_sample, predicted_x0 = model.p_sample(field_noise, t, [boundary_cond, coupled_fields])
    return next_sample, predicted_x0


def update_fuel(
    alpha, t, model, field_noise, mult_p_estimate, mult_p_estimate_before, other_condition, normalize, renormalize
):
    """Take one reverse-diffusion step for the fuel (solid) field.

    Every field estimate is first blended as
    ``alpha * current + (1 - alpha) * previous``. The condition handed to the
    model is the first 8 entries along the last dim of the blended field at
    index 0, plus channel 0 / last-dim index 0 of the blended field at index 2.

    Args:
        alpha: weight of the current-iteration estimates in the blend.
        t: diffusion time step forwarded to ``model.p_sample``.
        model: object exposing ``p_sample(x, t, cond)``.
        field_noise: current noisy sample of this field.
        mult_p_estimate: per-field estimates from the current outer iteration.
        mult_p_estimate_before: per-field estimates from the previous outer iteration.
        other_condition, normalize, renormalize: unused here; kept for the
            shared update signature.

    Returns:
        The two values produced by ``model.p_sample``: the next noisy sample and
        the predicted clean field.
    """
    blended = [
        alpha * current + (1 - alpha) * mult_p_estimate_before[idx]
        for idx, current in enumerate(mult_p_estimate)
    ]
    neutron_part = blended[0][..., :8]
    fluid_boundary = blended[2][:, 0:1, :, :, 0:1]
    next_sample, predicted_x0 = model.p_sample(field_noise, t, [neutron_part, fluid_boundary])
    return next_sample, predicted_x0


def k(t):
    """Evaluate the quadratic ``c0 + c1 * t + c2 * t * t`` with fixed coefficients.

    NOTE(review): `update_fluid` multiplies a temperature difference by this
    value, so it is presumably a temperature-dependent thermal conductivity
    correlation -- confirm the source and units of the constants.

    Args:
        t: scalar or tensor argument (anything supporting ``*`` and ``+``).

    Returns:
        The polynomial value, with the same type/shape as ``t``.
    """
    constant_term = 17.5 * (1 - 0.223) / (1 + 0.161)
    linear_coeff = 1.54e-2 * (1 + 0.0061) / (1 + 0.161)
    quadratic_coeff = 9.38e-6
    return constant_term + linear_coeff * t + quadratic_coeff * t * t


def update_fluid(
    alpha, t, model, field_noise, mult_p_estimate, mult_p_estimate_before, other_condition, normalize, renormalize
):
    """Take one reverse-diffusion step for the fluid field.

    Every field estimate is first blended as
    ``alpha * current + (1 - alpha) * previous``. The blended field at index 1
    is mapped back with ``renormalize(..., field="solid")``; the difference
    between its last two entries along the last dim, multiplied by ``k`` of the
    last entry, is normalized with ``field="flux"`` and used as the only condition.

    Args:
        alpha: weight of the current-iteration estimates in the blend.
        t: diffusion time step forwarded to ``model.p_sample``.
        model: object exposing ``p_sample(x, t, cond)``.
        field_noise: current noisy sample of this field.
        mult_p_estimate: per-field estimates from the current outer iteration.
        mult_p_estimate_before: per-field estimates from the previous outer iteration.
        other_condition: unused here; kept for the shared update signature.
        normalize: callable ``normalize(x, field=...)``.
        renormalize: callable ``renormalize(x, field=...)``.

    Returns:
        The two values produced by ``model.p_sample``: the next noisy sample and
        the predicted clean field.
    """
    blended = [
        alpha * current + (1 - alpha) * mult_p_estimate_before[idx]
        for idx, current in enumerate(mult_p_estimate)
    ]
    solid_physical = renormalize(blended[1], field="solid")
    last_slice = solid_physical[..., -1:]
    second_last_slice = solid_physical[..., -2:-1]
    # Difference of the two outermost entries scaled by k evaluated at the last one.
    raw_flux = (second_last_slice - last_slice) * k(last_slice)
    flux_cond = normalize(raw_flux, field="flux")
    next_sample, predicted_x0 = model.p_sample(field_noise, t, [flux_cond])
    return next_sample, predicted_x0

In [8]:
def _load_val_array(path, limit):
    """Read a .npy file as a float tensor on `device`, keeping the first `limit` entries along dim 0."""
    return torch.tensor(np.load(path)).float().to(device)[:limit]


# `b` first acts as the slice limit (None keeps everything) and is then
# overwritten with the resulting leading dimension of `bc`.
b = None
fuel = _load_val_array("../../data/NTcouple/val/fuel.npy", b)
fluid = _load_val_array("../../data/NTcouple/val/fluid.npy", b)
neu = _load_val_array("../../data/NTcouple/val/neu.npy", b)
# The extra condition is normalized with the "neutron" statistics before use.
bc = normalize(_load_val_array("../../data/NTcouple/val/bc.npy", b), "neutron")
b = bc.shape[0]

In [9]:
def mean_stddev(data):
    """Return the mean and the sample standard deviation (n - 1 denominator).

    Args:
        data: sized collection of numbers (or objects supporting ``+``, ``-``,
            ``/`` and ``**``, e.g. 0-d tensors), or None.

    Returns:
        ``(mean, stddev)``. ``(None, None)`` when ``data`` is None or empty.
        With a single sample the sample standard deviation is undefined, so
        ``(mean, nan)`` is returned instead of dividing by zero.
    """
    # len() instead of truthiness so array-like inputs are not ambiguous.
    count = 0 if data is None else len(data)
    if count == 0:
        return None, None

    mean = sum(data) / count
    if count == 1:
        # n - 1 == 0: no spread can be estimated from one observation.
        return mean, float("nan")

    variance = sum((x - mean) ** 2 for x in data) / (count - 1)
    return mean, variance ** 0.5

In [27]:
def compose_diffusion(
    model_list,
    shape: list,
    update_f: list,
    normalize_f,
    unnormalize_f,
    other_condition=None,
    num_iter=2,
    device="cuda",
):
    """Jointly sample several coupled fields from per-field conditional diffusion models.

    Each outer iteration restarts every field from fresh Gaussian noise and runs
    the full reverse chain. Inside a time step the fields are updated one after
    another, so a later field already sees the estimate just produced for an
    earlier one. From the second outer iteration on, the update functions
    receive ``alpha = 1 - t / (timestep - 1)`` to blend the current estimates
    with those of the previous outer iteration; in the first one ``alpha`` is 1.

    Args:
        model_list: one diffusion model per field; each must expose
            ``num_timesteps`` and ``unnormalize``. The step count of the first
            model is used for all fields.
        shape (list): one tensor shape per field.
        update_f (list): one update function per field, called as
            ``update(alpha, t, model, field_noise, estimates, estimates_before,
            other_condition, normalize_f, unnormalize_f)`` and returning
            ``(next_noisy_sample, predicted_x0)``.
        normalize_f: normalization callable forwarded to the update functions.
        unnormalize_f: inverse-normalization callable forwarded to the update functions.
        other_condition (list, optional): extra conditions forwarded unchanged
            to every update function. Defaults to an empty list.
        num_iter (int, optional): number of outer iterations. Defaults to 2.
        device (str, optional): device for the sampled noise. Defaults to 'cuda'.

    Returns:
        list: final sample of each field from the last outer iteration.

    Raises:
        ValueError: if ``model_list`` is empty, ``shape``/``update_f`` have fewer
            entries than ``model_list``, or ``num_iter`` < 1.
    """
    # A fresh list per call instead of a shared mutable default.
    if other_condition is None:
        other_condition = []

    n_compose = len(model_list)
    if n_compose == 0:
        raise ValueError("model_list must contain at least one model")
    if len(shape) < n_compose or len(update_f) < n_compose:
        raise ValueError("shape and update_f need one entry per model in model_list")
    if num_iter < 1:
        raise ValueError("num_iter must be >= 1")

    with torch.no_grad():
        timestep = model_list[0].num_timesteps
        # Denominator of the linear blending schedule; zero for one-step models.
        last_step = timestep - 1

        # initial field
        mult_p_estimate = [torch.randn(s, device=device) for s in shape]

        # `outer_idx` (not `k`) so the module-level function `k` is not shadowed.
        for outer_idx in range(num_iter):
            mult_p_estimate_before = mult_p_estimate.copy()
            mult_p_estimate = []
            mult_p = []
            for s in shape:
                mult_p_estimate.append(torch.randn(s, device=device))
                mult_p.append(torch.randn(s, device=device))
            for t in tqdm(reversed(range(0, timestep)), desc="sampling loop time step", total=timestep):
                if outer_idx == 0 or last_step == 0:
                    # First pass has nothing to blend with; with a single step
                    # t == 0 is the final step, where the schedule reaches 1.
                    alpha = 1
                else:
                    alpha = 1 - t / last_step
                # linear: 1 - t / (timestep - 1), 0->1
                # cos: cos(math.pi/2(t / (timestep - 1))), 0->1
                # power1: 1 - (t / (timestep - 1))**2
                # power2: (t / (timestep - 1)-1)**2
                for i in range(n_compose):
                    model = model_list[i]
                    update = update_f[i]
                    # Copies keep the update functions from mutating shared state.
                    single_p, x0 = update(
                        alpha,
                        t,
                        model,
                        mult_p[i].clone(),
                        mult_p_estimate.copy(),
                        mult_p_estimate_before.copy(),
                        other_condition,
                        normalize_f,
                        unnormalize_f,
                    )
                    mult_p[i] = single_p

                    # update estimated physics field
                    mult_p_estimate[i] = model.unnormalize(x0)
    return mult_p

In [34]:
def run():
    """Sample the three coupled fields once and score them against the validation data.

    Runs `compose_diffusion` with three outer iterations on the neutron, fuel
    and fluid models, maps each sample back with `renormalize`, prints the
    relative error of the first four fluid channels, and returns the errors.

    Returns:
        tuple: relative errors of the neutron, fuel and fluid fields, in that order.
    """
    mult_p = compose_diffusion(
        [diffusion_neu, diffusion_fuel, diffusion_fluid],
        [neu.shape, fuel.shape, fluid.shape],
        [update_neu, update_fuel, update_fluid],
        normalize,
        renormalize,
        [bc],
        3,
    )
    neu_pred = renormalize(mult_p[0], "neutron")
    fuel_pred = renormalize(mult_p[1], "solid")
    fluid_pred = renormalize(mult_p[2], "fluid")
    # Per-channel diagnostics; assumes the fluid field has at least 4 channels
    # along dim 1 -- TODO confirm. (The previous running average of these
    # values was never used and has been dropped.)
    for channel in range(4):
        print(relative_error(fluid[:, channel], fluid_pred[:, channel]))
    return relative_error(neu, neu_pred), relative_error(fuel, fuel_pred), relative_error(fluid, fluid_pred)

In [None]:
# Number of independent sampling runs to collect errors from.
num = 1
# Per-run relative errors, in the order returned by run(): neutron, fuel, fluid.
e1_l, e2_l, e3_l = [], [], []

for i in range(num):
    e1, e2, e3 = run()
    for history, error in zip((e1_l, e2_l, e3_l), (e1, e2, e3)):
        history.append(error)

In [None]:
# Display the per-run relative errors of the neutron field (first value returned by run()).
e1_l

In [None]:
# Mean and sample standard deviation of the per-run errors for each field.
# NOTE(review): a sample standard deviation needs at least two runs to be meaningful.
mean_stddev(e1_l), mean_stddev(e2_l), mean_stddev(e3_l)