In [30]:
import os
import sys
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

sys.path.append("../../")
from src.filepath import ABSOLUTE_PATH
from src.train.heatpipe import load_data, renormalize
from src.model.transolver import Transolver
from src.model.GeoFNO import GeoFNO2d as FNO
from src.utils.utils import relative_error, to_np, plot_scatter_compare, find_max_min
from src.model.diffusion import GaussianDiffusion
from src.inference.compose import compose_diffusion_multiE

In [None]:
# --- Configuration -----------------------------------------------------------
device = "cuda"
diffusion_step = 250  # number of diffusion timesteps passed to GaussianDiffusion
model_type = "transformer"  # "transformer" (Transolver) or "FNO" (GeoFNO2d)

# Only the test loader is kept; the training loader is dropped straight away.
train_loader, test_loader = load_data(ABSOLUTE_PATH, 64, 14000, model_type=model_type, device=device)
del train_loader

# --- Denoising backbone ------------------------------------------------------
if model_type == "transformer":
    model = Transolver(
        space_dim=2,
        n_layers=5,
        n_hidden=64,
        dropout=0.0,
        n_head=8,
        Time_Input=True,
        act="gelu",
        mlp_ratio=1,
        fun_dim=13,
        out_dim=3,
        slice_num=16,
        ref=8,
        unified_pos=False,
    ).to(device)
elif model_type == "FNO":
    modes = [8, 8, 8]
    model = FNO(
        modes1=modes[0], modes2=modes[1], modes3=modes[2], width=32, in_channels=13, out_channels=3, time_input=True
    ).to(device)
else:
    # Fail here instead of with a NameError on `model` a few lines below.
    raise ValueError(f"Unsupported model_type: {model_type!r} (expected 'transformer' or 'FNO')")


# --- Diffusion wrapper -------------------------------------------------------
diffusion = GaussianDiffusion(
    model,
    seq_length=(804, 3),
    timesteps=diffusion_step,
    auto_normalize=False,
).to(device)


# map_location makes the checkpoint loadable on whatever `device` is configured,
# independent of the device it was saved from.
diffusion.load_state_dict(
    torch.load("../../results/heatpipe/diffusion/" + model_type + "/model.pt", map_location=device)["model"]
)

In [58]:
def update(
    alpha,
    t,
    model,
    neighbors,
    cond_shape,
    boundary_emb,
    mult_e_noise,
    mult_e_estimate,
    mult_e_estimate_before,
    other_condition,
    normalize=nn.Identity(),
    renormalize=nn.Identity(),
):
    """Perform one reverse-diffusion step for every element, conditioned on its neighbours.

    For each element a per-node condition tensor is assembled: one ``channel``-wide
    slot per neighbour (either the blended field estimate of that neighbour or a
    boundary embedding), followed by the element's flux in the remaining columns.

    Args:
        alpha: blend weight; the neighbour field is
            ``mult_e_estimate_before * (1 - alpha) + mult_e_estimate * alpha``.
        t: current diffusion timestep, forwarded to ``model.p_sample``.
        model: object exposing ``p_sample(x, t, cond)`` returning ``(x_next, x0)``.
        neighbors: mapping from 1-based element id to an iterable of neighbours.
            ``int`` entries are 1-based element ids; anything else is treated as a
            boundary descriptor and passed to ``boundary_emb``.
        cond_shape: width of the last dimension of the condition tensor.
        boundary_emb: callable turning a boundary descriptor into an array that can
            be assigned to an ``(n_node, channel)`` slice.
        mult_e_noise: current noisy sample, forwarded to ``model.p_sample``.
        mult_e_estimate: current field estimate; dims 0, 1 and -1 are read as
            element, node and channel.
        mult_e_estimate_before: field estimate from the previous outer iteration.
        other_condition: sequence whose item 0 is the coordinates and item 1 the flux.
        normalize: accepted for interface compatibility; not used here.
        renormalize: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(mult_e_noise_next, x0)`` as produced by ``model.p_sample``.
    """
    # The boundary order in ``neighbors`` should correspond with the boundary embedding.
    target_device = mult_e_estimate.device
    n_element = mult_e_estimate.shape[0]
    n_node = mult_e_estimate.shape[1]
    channel = mult_e_estimate.shape[-1]
    coord, flux = other_condition[0], other_condition[1]

    # Neighbour information is a blend of the previous and the current estimate.
    blended = mult_e_estimate_before * (1 - alpha) + mult_e_estimate * alpha

    node_feature = torch.zeros((n_element, n_node, cond_shape)).to(target_device)
    for elem in range(n_element):
        for slot, neighbor in enumerate(neighbors[elem + 1]):
            start = slot * channel
            if isinstance(neighbor, int):
                # Neighbour ids are 1-based, tensors are 0-based.
                source = blended[neighbor - 1]
            else:
                source = torch.tensor(boundary_emb(neighbor)).to(target_device)
            node_feature[elem, :, start : start + channel] = source
        # Flux fills every column after the last neighbour slot. This relies on the
        # final value of ``slot``, so each element is assumed to have >= 1 neighbour.
        node_feature[elem, :, (slot + 1) * channel :] = flux[elem]

    mult_e_noise_next, x0 = model.p_sample(mult_e_noise, t, (coord, node_feature))
    return mult_e_noise_next, x0


def boundary_emb_f(b_type, n_nods=804):
    """Return the per-node embedding used in place of a neighbour field at a boundary.

    Args:
        b_type: boundary kind, ``"sym"`` or ``"free"``.
        n_nods: number of nodes; the 3-component embedding is repeated once per node.

    Returns:
        ``np.ndarray`` of shape ``(n_nods, 3)``: rows of ``[0, 1, 1]`` for ``"sym"``
        and rows of ``[0, 0, 0]`` for ``"free"``.

    Raises:
        ValueError: if ``b_type`` is not a supported boundary kind. Previously the
            function fell through and returned ``None``, which only failed later
            when the caller tried to build a tensor from it.
    """
    if b_type == "sym":
        emb = np.array([0, 1, 1])
    elif b_type == "free":
        emb = np.array([0, 0, 0])
    else:
        raise ValueError(f"Unsupported boundary type: {b_type!r} (expected 'sym' or 'free')")
    return np.tile(emb, (n_nods, 1))

In [33]:
def coord_transform(coord):
    """Map normalised per-element coordinates to their placed physical positions.

    The input is first de-normalised to the x-range ``[0.0455, 0.065345]`` and
    y-range ``[0.072, 0.08918]``. Which inverse mapping is used depends on the
    module-level ``model_type``: ``"FNO"`` inputs are treated as ``[0, 1]``,
    everything else as ``[-1, 1]``. Then, for the first 16 entries along dim 0,
    the y-coordinate is optionally mirrored about a symmetry axis and the entry
    is translated by a fixed offset.

    Args:
        coord: tensor indexed as ``[element, node, (x, y)]`` with at least 16
            entries along dim 0. It is not modified.

    Returns:
        A new tensor of the same shape holding the transformed coordinates.
    """
    # Translation (dx, dy) per element; entries 10-15 reuse the offsets of 0-5.
    first_ten = [
        [0, 0],
        [-0.013856407, -0.024],
        [0.013856406, -0.024],
        [-0.027712807, -0.048],
        [0.0, -0.048],
        [0.027712813, -0.048],
        [-0.041569220, -0.072],
        [-0.013856407, -0.072],
        [0.013856406, -0.0720],
        [0.0415692190, -0.072],
    ]
    offsets = np.array(first_ten + first_ten[:6])

    # Mirror axis in y per element; a negative value means "do not mirror".
    mirror_axes = [-1] * 10 + [0.072] * 6

    def mirror_y(y_values, axis):
        if axis < 0:
            return y_values
        return 2 * axis - y_values

    x_span = 0.065345 - 0.0455
    y_span = 0.08918 - 0.072

    placed = torch.zeros_like(coord)
    if model_type == "FNO":
        placed[:, :, 0] = (coord[:, :, 0]) * x_span + 0.0455
        placed[:, :, 1] = (coord[:, :, 1]) * y_span + 0.072
    else:
        placed[:, :, 0] = (coord[:, :, 0] + 1) / 2 * x_span + 0.0455
        placed[:, :, 1] = (coord[:, :, 1] + 1) / 2 * y_span + 0.072

    for idx, ((dx, dy), axis) in enumerate(zip(offsets, mirror_axes)):
        placed[idx, :, 0] = placed[idx, :, 0] + dx
        placed[idx, :, 1] = mirror_y(placed[idx, :, 1], axis) + dy
    return placed


# Boundary label used for each outer side of the assembled domain. The labels
# are placed directly into `neighbors` below; `update` hands every non-int
# neighbour entry to the boundary embedding function, and `boundary_emb_f`
# understands "sym" and "free".
left = "sym"
right = "free"
bottom = "sym"
# Alternative boundary setup kept for reference (not active):
# left = "free"
# right = "free"
# bottom = (renomalize_disp(-0.0832), renomalize_disp(-0.2587))  # "fix"

# Adjacency table for the 64 elements.
#   key   : 1-based element id (`update` reads `neighbors[i + 1]` for tensor row i)
#   value : 3-tuple, one entry per neighbour slot; an int is the 1-based id of the
#           adjacent element, a non-int is one of the boundary labels above.
# NOTE(review): the slot order presumably has to match the neighbour ordering the
# model was trained with — TODO confirm before reordering any tuple.
neighbors = {
    # element id: (slot 0, slot 1, slot 2)
    1: (left, right, 11),
    2: (left, 11, 12),
    3: (11, right, 13),
    4: (left, 12, 14),
    5: (12, 13, 15),
    6: (13, right, 16),
    7: (left, 14, 26),
    8: (14, 15, 25),
    9: (15, 16, 24),
    10: (16, right, 23),
    11: (3, 2, 1),
    12: (5, 4, 2),
    13: (6, 5, 3),
    14: (8, 7, 4),
    15: (9, 8, 5),
    16: (10, 9, 6),
    17: (55, 42, 27),
    18: (52, 27, 28),
    19: (27, 38, 29),
    20: (50, 28, 30),
    21: (28, 29, 31),
    22: (29, 35, 32),
    23: (49, 30, 10),
    24: (30, 31, 9),
    25: (31, 32, 8),
    26: (32, 33, 7),
    27: (19, 18, 17),
    28: (21, 20, 18),
    29: (22, 21, 19),
    30: (24, 23, 20),
    31: (25, 24, 21),
    32: (26, 25, 22),
    33: (left, 26, 43),
    34: (left, 43, 44),
    35: (43, 22, 45),
    36: (left, 44, 46),
    37: (44, 45, 47),
    38: (45, 19, 48),
    39: (left, 46, bottom),
    40: (46, 47, bottom),
    41: (47, 48, bottom),
    42: (48, 17, bottom),
    43: (35, 34, 33),
    44: (37, 36, 34),
    45: (38, 37, 35),
    46: (40, 39, 36),
    47: (41, 40, 37),
    48: (42, 41, 38),
    49: (23, right, 59),
    50: (20, 59, 60),
    51: (59, right, 61),
    52: (18, 60, 62),
    53: (60, 61, 63),
    54: (61, right, 64),
    55: (17, 62, bottom),
    56: (62, 63, bottom),
    57: (63, 64, bottom),
    58: (64, right, bottom),
    59: (51, 50, 49),
    60: (53, 52, 50),
    61: (54, 53, 51),
    62: (56, 55, 52),
    63: (57, 56, 53),
    64: (58, 57, 54),
}

In [93]:
import math
from tqdm.auto import tqdm



def compose_diffusion_multiE(
    model,
    shape,
    cond_shape,
    update_f,
    adj,
    boundary_emb,
    normalize_f=nn.Identity(),
    unnormalize_f=nn.Identity(),
    other_condition=[],
    num_iter=2,
    device="cuda",
):
    """Compose a conditional diffusion model over multiple coupled elements.

    Every outer iteration runs a full reverse-diffusion pass for all elements at
    once. Within a pass, ``update_f`` receives the estimate from the previous
    outer iteration and the running estimate of the current one, together with a
    blend weight ``alpha``. ``alpha`` is 1 during the first outer iteration and
    follows ``cos(pi/2 * t / (T - 1))`` afterwards, i.e. it rises from ~0 at the
    first (noisiest) step to 1 at ``t == 0``.

    Note: this notebook-level definition shadows the function of the same name
    imported from ``src.inference.compose``.

    Args:
        model: conditional diffusion model exposing ``num_timesteps`` and
            ``unnormalize``; it is forwarded to ``update_f``.
        shape (tuple): shape of the field of a single element.
        cond_shape: size of the condition feature, forwarded to ``update_f``.
        update_f: callable performing one denoising step. It is called
            positionally as ``update_f(alpha, t, model, adj, cond_shape,
            boundary_emb, noise, estimate, estimate_before, other_condition,
            normalize_f, unnormalize_f)`` and must return ``(next_noise, x0)``.
        adj (dict): neighbours of each element; ``len(adj)`` is the number of
            elements that are composed.
        boundary_emb: embedding function for boundaries, forwarded to ``update_f``.
        normalize_f: normalisation function, forwarded to ``update_f``.
        unnormalize_f: un-normalisation function, forwarded to ``update_f``.
        other_condition (list): further conditions such as initial state or
            source term, forwarded to ``update_f``. The default list is never
            mutated here.
        num_iter (int, optional): number of outer iterations (>= 1). Defaults to 2.
        device (str, optional): device the sampled tensors are moved to.
            Defaults to ``"cuda"``.

    Returns:
        Tensor: the final sample of shape ``(len(adj),) + shape``.

    Raises:
        ValueError: if ``num_iter`` is smaller than 1 (nothing would be sampled).
    """
    if num_iter < 1:
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")

    with torch.no_grad():
        n_compose = len(adj)
        timestep = model.num_timesteps
        full_shape = (n_compose,) + shape

        # Initial field estimate. Noise is drawn on the CPU and then moved, which
        # keeps the random stream independent of the target device.
        mult_e_estimate = torch.randn(full_shape).to(device)

        for k in range(num_iter):
            mult_e_estimate_before = mult_e_estimate.clone()
            mult_e_estimate = torch.randn(full_shape).to(device)
            mult_e = torch.randn(full_shape).to(device)

            for t in tqdm(reversed(range(0, timestep)), desc="sampling loop time step", total=timestep):
                # Blend schedule (as a function of s = t / (timestep - 1)), each going 0 -> 1:
                #   linear: 1 - s
                #   cos:    cos(pi / 2 * s)      <- active
                #   power1: 1 - s ** 2
                #   power2: (s - 1) ** 2
                if k == 0:
                    alpha = 1
                elif timestep > 1:
                    alpha = math.cos(math.pi / 2 * (t / (timestep - 1)))
                else:
                    # Single-step model: t == 0 is the end of the schedule, and the
                    # general formula would divide by zero.
                    alpha = 1.0

                # Clones protect the loop state from in-place edits inside update_f.
                single_p, x0 = update_f(
                    alpha,
                    t,
                    model,
                    adj,
                    cond_shape,
                    boundary_emb,
                    mult_e.clone(),
                    mult_e_estimate.clone(),
                    mult_e_estimate_before.clone(),
                    other_condition,
                    normalize_f,
                    unnormalize_f,
                )

                mult_e = single_p
                mult_e_estimate = model.unnormalize(x0)
    return mult_e

In [None]:
# Build the conditioning inputs: node coordinates replicated for 64 elements and
# the validation flux. x is rescaled from [0.0455, 0.065345] and y from
# [0.072, 0.08918]; flux is rescaled with offset 1e5 and span 9e5 to [-1, 1].
if model_type != "FNO":
    # Coordinates scaled to [-1, 1].
    coord = torch.tensor(np.load(ABSOLUTE_PATH + "/data/heatpipe/coord.npy")).to(device).float()
    coord[:, 0] = (coord[:, 0] - 0.0455) / (0.065345 - 0.0455) * 2 - 1
    coord[:, 1] = (coord[:, 1] - 0.072) / (0.08918 - 0.072) * 2 - 1
else:
    # FNO branch: coordinates scaled to [0, 1] and stored in the first two of
    # three columns; the last column stays zero.
    coordxy = torch.tensor(np.load(ABSOLUTE_PATH + "/data/heatpipe/coord.npy")).to(device).float()
    coordxy[:, 0] = (coordxy[:, 0] - 0.0455) / (0.065345 - 0.0455)
    coordxy[:, 1] = (coordxy[:, 1] - 0.072) / (0.08918 - 0.072)
    coord = torch.zeros(804, 3).to(device)
    coord[:, :-1] = coordxy
coord = coord.expand(64, -1, -1)

flux = (torch.tensor(np.load(ABSOLUTE_PATH + "/data/heatpipe/val_flux.npy")).to(device) - 1e5) / 9e5 * 2 - 1
coord.shape, flux.shape

In [36]:
def coord_transform2(coord):
    """Replicate a coordinate array into four placements stacked along axis 0.

    The result is the concatenation of:
      1. the input itself,
      2. a copy with ``x -> 2 * 0.0554256215 - x`` and ``y -> -y``,
      3. a copy shifted by ``(-0.055425618, -0.096)``,
      4. a copy shifted by ``(+0.055425618, -0.096)``.

    Only components 0 and 1 of the last axis are written in the three derived
    copies; any further components are left at zero there.

    Args:
        coord: NumPy array whose last axis holds ``(x, y, ...)``.

    Returns:
        NumPy array with four times as many entries along axis 0 as ``coord``.
    """
    x, y = coord[..., 0], coord[..., 1]
    derived_xy = (
        (2 * 0.0554256215 - x, -y),
        (x - 0.055425618, y - 0.096),
        (x + 0.055425618, y - 0.096),
    )

    pieces = [coord]
    for new_x, new_y in derived_xy:
        block = np.zeros_like(coord)
        block[..., 0] = new_x
        block[..., 1] = new_y
        pieces.append(block)
    return np.concatenate(tuple(pieces), axis=0)

In [86]:
def _load_normalized_xy():
    """Load the node coordinates and normalise x/y for the active ``model_type``.

    x is rescaled from ``[0.0455, 0.065345]`` and y from ``[0.072, 0.08918]`` to
    ``[0, 1]`` for ``"FNO"`` and to ``[-1, 1]`` otherwise. A fresh tensor is read
    from disk on every call.
    """
    xy = torch.tensor(np.load(ABSOLUTE_PATH + "/data/heatpipe/coord.npy")).to(device).float()
    if model_type == "FNO":
        xy[:, 0] = (xy[:, 0] - 0.0455) / (0.065345 - 0.0455)
        xy[:, 1] = (xy[:, 1] - 0.072) / (0.08918 - 0.072)
    else:
        xy[:, 0] = (xy[:, 0] - 0.0455) / (0.065345 - 0.0455) * 2 - 1
        xy[:, 1] = (xy[:, 1] - 0.072) / (0.08918 - 0.072) * 2 - 1
    return xy


def run():
    """Sample the composed 64-element field once and score it against validation data.

    Steps:
      1. Build the conditioning inputs (coordinates, flux) and run
         ``compose_diffusion_multiE`` with the module-level ``diffusion`` model.
      2. Reconstruct the physical node positions of all elements with
         ``coord_transform`` followed by ``coord_transform2``.
      3. For every reconstructed node, pick the validation sample whose
         coordinate is nearest, so prediction and reference share one ordering.
      4. Compute the relative error.

    Returns:
        Tuple ``(r_e1, r_e2)``: relative error of output channel 0 and of the
        remaining channels, respectively.
    """
    # --- 1. conditioning inputs and sampling ---------------------------------
    xy = _load_normalized_xy()
    if model_type == "FNO":
        # FNO takes three coordinate columns; the last one stays zero.
        coord = torch.zeros(804, 3).to(device)
        coord[:, :-1] = xy
    else:
        coord = xy
    coord = coord.expand(64, -1, -1)

    flux = torch.tensor(np.load(ABSOLUTE_PATH + "/data/heatpipe/val_flux.npy")).to(device)
    flux = (flux - 1e5) / 9e5 * 2 - 1

    mult_e = compose_diffusion_multiE(
        model=diffusion,
        shape=(804, 3),
        cond_shape=10,
        update_f=update,
        adj=neighbors,
        boundary_emb=boundary_emb_f,
        other_condition=[coord, flux],
        num_iter=2,
    )
    mult_e = to_np(renormalize(mult_e))

    # --- 2. physical node positions of every element -------------------------
    # coord_transform places 16 elements; coord_transform2 replicates them 4x.
    coord = _load_normalized_xy().expand(16, -1, -1)
    coord = to_np(coord_transform(coord))

    # --- 3. align the validation data with the predicted node ordering -------
    coord_val = np.load(ABSOLUTE_PATH + "/data/heatpipe/coord_val.npy")
    val_y = np.load(ABSOLUTE_PATH + "/data/heatpipe/val_y.npy")
    val_y = renormalize(val_y)
    tolerance = 5e-5

    coord_struture_big = coord_transform2(to_np(coord))
    x_new, y_new = coord_struture_big[..., 0].reshape(-1), coord_struture_big[..., 1].reshape(-1)
    val_y_sorted = np.empty_like(val_y)
    for i, (x_val, y_val) in enumerate(zip(x_new, y_new)):
        distances = np.sqrt((coord_val[:, 0] - x_val) ** 2 + (coord_val[:, 1] - y_val) ** 2)
        min_distance_index = np.argmin(distances)
        val_y_sorted[i] = val_y[min_distance_index]
        if distances[min_distance_index] > tolerance:
            # Even the nearest validation node is further away than the tolerance.
            print(
                f"No matching coordinate found for ({i},{x_val}, {y_val}) within tolerance, the min distance is {distances[min_distance_index]}."
            )

    # --- 4. error metrics ----------------------------------------------------
    mult_e = mult_e.reshape(-1, 3)
    r_e1 = relative_error(mult_e[..., :1], val_y_sorted[..., :1])
    r_e2 = relative_error(mult_e[..., 1:], val_y_sorted[..., 1:])
    return r_e1, r_e2

In [73]:
def mean_stddev(data):
    """Return the mean and the sample standard deviation (n - 1 denominator).

    Args:
        data: sized collection of numbers (list, tuple, array, ...), or ``None``.

    Returns:
        Tuple ``(mean, stddev)``.
        * ``(None, None)`` when ``data`` is ``None`` or empty.
        * ``(mean, nan)`` for a single sample, where the sample standard
          deviation is undefined. Previously this case divided by zero
          (``ZeroDivisionError`` for Python floats).
    """
    if data is None or len(data) == 0:
        return None, None

    count = len(data)
    mean = sum(data) / count
    if count == 1:
        return mean, float("nan")

    stddev = (sum((x - mean) ** 2 for x in data) / (count - 1)) ** 0.5
    return mean, stddev

In [None]:
# Repeat the full sample-and-evaluate pipeline `num` times and summarise both
# error metrics as (mean, sample stddev).
num = 5
run_errors = [run() for _ in range(num)]
e1_l = [errors[0] for errors in run_errors]
e2_l = [errors[1] for errors in run_errors]
mean_stddev(e1_l), mean_stddev(e2_l)

Recorded results per alpha schedule: ((mean, stddev) of e1, (mean, stddev) of e2)
pow1
((0.008363432815531436, 0.00024148665810661744),
 (0.022048669798078714, 0.0007914473439841704))
pow2
((0.008530787912456263, 0.00020660809828848567),
 (0.022533556806864113, 0.0005989850890467279))
linear
((0.008163951702652005, 3.950417688465311e-05),
 (0.021729155752877934, 0.00034757450324333016))
cos
((0.008022428352570717, 5.0478009322086946e-05),
 (0.02119709851787716, 0.00018905868963163024))

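Single-run results (stddev is nan with one sample); the leading number is presumably num_iter — TODO confirm
1 ((0.008832847685884425, nan), (0.022201415749382905, nan))
3 ((0.008095351677258247, nan), (0.021685304319997444, nan))
4 ((0.007847557604063202, nan), (0.020908780046219617, nan))
5 ((0.007542297709852646, nan), (0.02024037768972978, nan))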