In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import argparse
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

sys.path.append("../../")
from src.model.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from src.model.fno import FNO3D
from src.model.diffusion import GaussianDiffusion
from src.train.nuclear_thermal_coupling_couple import load_nt_dataset_emb, cond_emb, normalize, renormalize
from src.inference.compose import compose_diffusion
from src.utils.utils import L2_norm, get_parameter_net, plot_compare_2d, relative_error

In [2]:
# Run configuration shared by every cell below.
model_type = "Unet"      # "Unet" | "ViT" | "FNO": architecture used by the solid/fluid cells
device = "cuda"          # device the datasets and diffusion models are placed on
diffusion_step = 250     # `timesteps` passed to every GaussianDiffusion below
iter = "iter1"           # checkpoint directory prefix; NOTE(review): shadows the builtin `iter`

In [None]:
# Neutron field: build the 3D U-Net surrogate, wrap it in a Gaussian diffusion
# process and restore the trained weights. Unlike the solid/fluid cells, this
# cell always uses the U-Net regardless of `model_type`.
train_which = "neutron"
dim = 8
emb = cond_emb()
cond, data = load_nt_dataset_emb(field=train_which, device=device)

unet_kwargs = dict(
    dim=dim,
    cond_dim=len(cond),
    out_dim=data.shape[1],
    cond_emb=emb,
    dim_mults=(1, 2, 4),
    use_sparse_linear_attn=False,
    attn_dim_head=16,
)
model = Unet3D_with_Conv3D(**unet_kwargs).to(device)
diffusion_neu = GaussianDiffusion(
    model,
    seq_length=tuple(data.shape[1:]),
    timesteps=diffusion_step,
    auto_normalize=False,
).to(device)

# Strict load: every key in the checkpoint's "model" entry must match.
ckpt_path = "../../results/nuclear_thermal_coupling_couple/diffusionUnetneutron/" + iter + "_5000/model-50.pt"
diffusion_neu.load_state_dict(torch.load(ckpt_path)["model"])

In [None]:
# Evaluate the neutron surrogate on the trailing samples, in physical units.
# Fresh names are used instead of re-binding `cond`/`data`, so re-running this
# cell does not slice or renormalize the ground truth a second time.
b = -32  # evaluate on the last |b| samples
cond_eval = [c[b:] for c in cond]
with torch.no_grad():
    pred = renormalize(diffusion_neu.sample(cond_eval[0].shape[0], cond_eval), "neutron")
    data_eval = renormalize(data[b:], "neutron")
    rel_err = relative_error(data_eval, pred)  # relative error (previously mislabelled `rmse`)
    mse = F.mse_loss(pred, data_eval)
rel_err, mse

In [8]:
# Solid (fuel) field settings: the field name passed to the loader, the U-Net
# base width, and the conditioning embedding handed to the model.
train_which, dim = "solid", 8
emb = cond_emb()

In [None]:
# Solid (fuel) field: build the surrogate selected by `model_type`, wrap it in a
# Gaussian diffusion process and restore the trained weights.
cond, data = load_nt_dataset_emb(field=train_which, device=device)
if model_type == "Unet":
    model = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=len(cond),
        out_dim=data.shape[1],
        cond_emb=emb,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is never imported in this notebook, so this branch
    # raises NameError as written. TODO: import the ViT class before using it.
    model = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=128,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        Time_Input=True,
        channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        dropout=0.0,
        emb_dropout=0.0,
    )
elif model_type == "FNO":
    model = FNO3D(
        in_channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=8,
        fno_modes=[6, 16, 4],
        cond_emb=emb,
        time_input=True,
    )
else:
    # Without this guard an unrecognised `model_type` would silently reuse the
    # `model` object left over from an earlier cell (the neutron U-Net).
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")
diffusion_fuel = GaussianDiffusion(
    model,
    seq_length=tuple(data.shape[1:]),
    timesteps=diffusion_step,
    auto_normalize=False,
).to(device)

# Checkpoint per architecture. The U-Net weights come from the `_couple` results
# tree and a numbered milestone file, unlike ViT/FNO.
solid_ckpt_paths = {
    "Unet": "../../results/nuclear_thermal_coupling_couple/diffusionUnetsolid/" + iter + "_5000/model-50.pt",
    "ViT": "../../results/nuclear_thermal_coupling/diffusionViTsolid/" + iter + "_5000/model.pt",
    "FNO": "../../results/nuclear_thermal_coupling/diffusionFNOsolid/" + iter + "_5000/model.pt",
}
# The U-Net checkpoint is loaded non-strictly (as before). Report any skipped
# keys so a partially-restored model does not go unnoticed.
load_info = diffusion_fuel.load_state_dict(
    torch.load(solid_ckpt_paths[model_type], map_location=device)["model"],
    strict=model_type != "Unet",
)
if load_info.missing_keys or load_info.unexpected_keys:
    print(f"missing keys: {load_info.missing_keys}\nunexpected keys: {load_info.unexpected_keys}")

In [None]:
# Evaluate the solid surrogate on the trailing samples, in physical units.
# Fresh names are used instead of re-binding `cond`/`data`, so re-running this
# cell does not slice or renormalize the ground truth a second time.
b = -32  # evaluate on the last |b| samples
cond_eval = [c[b:] for c in cond]
with torch.no_grad():
    pred = renormalize(diffusion_fuel.sample(cond_eval[0].shape[0], cond_eval), "solid")
    data_eval = renormalize(data[b:], "solid")
    rel_err = relative_error(data_eval, pred)  # relative error (previously mislabelled `rmse`)
    mse = F.mse_loss(pred, data_eval)
rel_err, mse

In [12]:
# Fluid field settings: the field name passed to the loader, the U-Net base
# width (wider than for the other fields), and the conditioning embedding.
train_which, dim = "fluid", 16
emb = cond_emb()

In [None]:
# Fluid field: build the surrogate selected by `model_type`, wrap it in a
# Gaussian diffusion process and restore the trained weights.
cond, data = load_nt_dataset_emb(field=train_which, device=device)
if model_type == "Unet":
    model = Unet3D_with_Conv3D(
        dim=dim,
        cond_dim=len(cond),
        out_dim=data.shape[1],
        cond_emb=emb,
        dim_mults=(1, 2, 4),
        use_sparse_linear_attn=False,
        attn_dim_head=16,
    ).to(device)
elif model_type == "ViT":
    # NOTE(review): `ViT` is never imported in this notebook, so this branch
    # raises NameError as written. TODO: import the ViT class before using it.
    model = ViT(
        image_size=data.shape[-2:],
        image_patch_size=(8, 2),
        frames=data.shape[2],
        frame_patch_size=2,
        dim=256,
        depth=2,
        heads=8,
        mlp_dim=256,
        cond_emb=emb,
        Time_Input=True,
        channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        dropout=0.0,
        emb_dropout=0.0,
    )
elif model_type == "FNO":
    model = FNO3D(
        in_channels=len(emb) + data.shape[1],
        out_channels=data.shape[1],
        nr_fno_layers=3,
        fno_layer_size=16,
        fno_modes=[6, 16, 6],
        cond_emb=emb,
        time_input=True,
    )
else:
    # Without this guard an unrecognised `model_type` would silently reuse the
    # `model` object left over from an earlier cell.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'Unet', 'ViT' or 'FNO'")

diffusion_fluid = GaussianDiffusion(
    model,
    seq_length=tuple(data.shape[1:]),
    timesteps=diffusion_step,
    auto_normalize=False,
).to(device)

# Checkpoint per architecture. The U-Net weights come from the `_couple` results
# tree and a numbered milestone file (model-15, not model-50 as for solid).
fluid_ckpt_paths = {
    "Unet": "../../results/nuclear_thermal_coupling_couple/diffusionUnetfluid/" + iter + "_5000/model-15.pt",
    "ViT": "../../results/nuclear_thermal_coupling/diffusionViTfluid/" + iter + "_5000/model.pt",
    "FNO": "../../results/nuclear_thermal_coupling/diffusionFNOfluid/" + iter + "_5000/model.pt",
}
# The U-Net checkpoint is loaded non-strictly (as before). Report any skipped
# keys so a partially-restored model does not go unnoticed.
load_info = diffusion_fluid.load_state_dict(
    torch.load(fluid_ckpt_paths[model_type], map_location=device)["model"],
    strict=model_type != "Unet",
)
if load_info.missing_keys or load_info.unexpected_keys:
    print(f"missing keys: {load_info.missing_keys}\nunexpected keys: {load_info.unexpected_keys}")

In [None]:
# Sample the fluid surrogate on the trailing samples, then map both the
# prediction and the ground truth back to physical units.
# NOTE(review): this cell overwrites `cond` entries and re-binds `data` to its
# renormalized value, so running it twice renormalizes `data` again. Re-run the
# model-loading cell before re-running this one.
b = -32
for k, c in enumerate(cond):
    cond[k] = c[b:]
data = data[b:]
with torch.no_grad():
    n_samples = cond[0].shape[0]
    pred = diffusion_fluid.sample(n_samples, cond)
    pred = renormalize(pred, field="fluid")
    data = renormalize(data, field="fluid")

In [None]:
# Relative error of each fluid output channel (printed), followed by the overall
# relative error and the unweighted mean of the per-channel errors.
n_channels = data.shape[1]  # previously hard-coded to 4; follow the actual channel count
per_channel_err = []
for ch in range(n_channels):
    err = relative_error(data[:, ch], pred[:, ch])
    print(err)
    per_channel_err.append(err)
loss_fluid = sum(per_channel_err)  # summed per-channel error, as before
relative_error(data, pred), loss_fluid / n_channels