In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import argparse
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F

sys.path.append("../../")
from src.filepath import ABSOLUTE_PATH
from src.model.UNet1d import Unet1D
from src.utils.utils import plot_compare_2d, relative_error
from src.train.reaction_diffusion import cond_emb, renormalize
from src.train.reaction_diffusion import normalize_to_neg_one_to_one as normalize

In [None]:
# Architecture shared by both surrogates: each network takes `channel - out_dim`
# input channels and produces `out_dim` output channels.
channel, out_dim, dim = 20, 9, 24


def _load_surrogate(checkpoint_path, device="cuda"):
    """Build a Unet1D surrogate and restore its trained weights.

    Args:
        checkpoint_path: Path to a checkpoint file whose "model" entry holds the state dict.
        device: Device the network is placed on and the checkpoint is mapped to.

    Returns:
        The Unet1D instance with weights loaded, switched to evaluation mode.
    """
    model = Unet1D(
        dim=dim,
        cond_emb=cond_emb(),
        out_dim=out_dim,
        dim_mults=(1, 2),
        channels=channel - out_dim,
    ).to(device)
    # map_location makes the load independent of the device the checkpoint was saved from.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    # This notebook only runs inference, so disable any train-time-only layer behaviour.
    return model.eval()


# model1 is restored from the "surrogateUnetu5000" run, model2 from "surrogateUnetv5000".
model1 = _load_surrogate("../../results/reaction_diffusion_couple_data/surrogateUnetu5000/model.pt")
model2 = _load_surrogate("../../results/reaction_diffusion_couple_data/surrogateUnetv5000/model.pt")

In [3]:
# Device that the evaluation tensors in the cells below are moved to.
device = "cuda"

In [None]:
# Stand-alone accuracy of model1 when it is conditioned on the ground-truth second field.
uv = np.load(ABSOLUTE_PATH + "/data/reaction_diffusion/reaction_diffusion_uv.npy").transpose(0, 2, 1)
# After the transpose, the first 20 entries of the last axis are the target field
# and the remaining entries are the conditioning field.
data = uv[..., :20]
cond = uv[..., 20:]
# The first slice of the target along axis 1 (u0) is appended to the conditioning
# input; only the remaining slices are predicted.
cond = np.concatenate((cond, data[:, :1]), axis=1)
data = data[:, 1:]

data, cond = torch.tensor(data).to(device).float(), torch.tensor(cond).to(device).float()

# Pure evaluation: skip autograd so no graph is kept for the full-dataset forward pass.
with torch.no_grad():
    u_pred = model1(normalize(cond))

# Despite its name, `rmse` holds the relative error in physical (renormalized) units;
# the second value is the MSE in normalized units.
rmse = relative_error(data, renormalize(u_pred))
rmse, F.mse_loss(normalize(data), u_pred)

In [None]:
# Stand-alone accuracy of model2 when it is conditioned on the ground-truth first field.
uv = np.load(ABSOLUTE_PATH + "/data/reaction_diffusion/reaction_diffusion_uv.npy").transpose(0, 2, 1)
# Mirror image of the model1 evaluation: the entries from index 20 onward of the
# last axis are the target field, the first 20 entries are the conditioning field.
data = uv[..., 20:]
cond = uv[..., :20]
# The first slice of the target along axis 1 (v0) is appended to the conditioning
# input; only the remaining slices are predicted.
cond = np.concatenate((cond, data[:, :1]), axis=1)
data = data[:, 1:]

data, cond = torch.tensor(data).to(device).float(), torch.tensor(cond).to(device).float()

# Pure evaluation: skip autograd so no graph is kept for the full-dataset forward pass.
# NOTE: the name `u_pred` is kept from the previous cell, but here it is model2's output.
with torch.no_grad():
    u_pred = model2(normalize(cond))

# Despite its name, `rmse` holds the relative error in physical (renormalized) units;
# the second value is the MSE in normalized units.
rmse = relative_error(data, renormalize(u_pred))
rmse, F.mse_loss(normalize(data), u_pred)

## Compose: couple the two surrogates by iterating them until their predictions stop changing

In [None]:
# Load the dataset once more, this time as a normalized GPU tensor, and split it
# into the two fields used by the coupled iteration.
uv_raw = np.load("../../data/reaction_diffusion/reaction_diffusion_uv.npy")
data = normalize(torch.tensor(uv_raw).float().to("cuda")).permute(0, 2, 1)

# First 20 entries of the last axis -> u, the rest -> v.
u = data[..., :20]
v = data[..., 20:]

# Leading slice along axis 1 of each field; these stay fixed during the iteration.
u_intial = u[:, :1]
v_intial = v[:, :1]

u.shape, u_intial.shape

In [None]:
# Alternate the two surrogates, feeding each one the other's latest prediction,
# until the summed L1 change between successive iterates is at most 2e-5 or
# 100 iterations have been run.
with torch.no_grad():
    # Both unknown fields start from an all-ones guess.
    u_iter = torch.ones_like(u[:, 1:])
    v_iter = torch.ones_like(v[:, 1:])
    for i in range(1, 101):
        # Both updates read the previous iterates, so the order of the two calls does not matter.
        u_next = model1(torch.concat((v_intial, v_iter, u_intial), dim=1))
        v_next = model2(torch.concat((u_intial, u_iter, v_intial), dim=1))
        eps = F.l1_loss(u_iter, u_next) + F.l1_loss(v_iter, v_next)
        u_iter, v_iter = u_next, v_next
        print("iteration: ", i, " eps: ", eps)
        if not (eps > 2e-5):
            break

# MSE of the converged fields against the ground truth (normalized units).
F.mse_loss(u_iter, u[:, 1:]), F.mse_loss(v_iter, v[:, 1:])

In [None]:
# Stack u and v along a new axis 1 (index 0 -> u, index 1 -> v) and map both the
# ground truth and the coupled prediction back through `renormalize`.
u_true, v_true = u[:, 1:], v[:, 1:]
mult_p_true = renormalize(torch.concat((u_true.unsqueeze(1), v_true.unsqueeze(1)), dim=1))
mult_p_pred = renormalize(torch.concat((u_iter.unsqueeze(1), v_iter.unsqueeze(1)), dim=1))

# NOTE(review): the single-surrogate evaluation cells call relative_error(data, prediction),
# i.e. reference first, while here the prediction is passed first. Confirm that
# relative_error is symmetric in its arguments, or align the order.
u_rel_err = relative_error(mult_p_pred[:, 0], mult_p_true[:, 0])
v_rel_err = relative_error(mult_p_pred[:, 1], mult_p_true[:, 1])
u_rel_err, v_rel_err

In [None]:
# Sample to visualise. The index is pinned to the last sample so the saved figure
# is reproducible; the previous random draw was overwritten immediately and had
# no effect, so it has been removed.
random_n = -1

# Compare ground truth and coupled prediction for field index 0 (u) of that sample.
# `savep` presumably is the path the figure is written to (confirm in plot_compare_2d).
plot_compare_2d(
    true_d=mult_p_true[random_n, 0],
    pred_d=mult_p_pred[random_n, 0],
    savep=ABSOLUTE_PATH + "/results/reaction_diffusion/surrogate.pdf",
)