In [1]:
import os
os.environ["DDE_BACKEND"] = "pytorch"
import torch
import json
from io import BytesIO
import numpy as np
import torch.nn as nn
import imageio.v2 as imageio
from tqdm import tqdm, trange
from utils import DeepONetDataset
from deepxde.nn import DeepONetCartesianProd
import matplotlib.pyplot as plt
import deepxde as dde

Using backend: pytorch
Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.
paddle supports more examples now and is recommended.


In [2]:
# Inference runs on the CPU; later cells hand the resulting tensors to NumPy/matplotlib.
device = torch.device("cpu")

In [3]:
# Input/output locations. The defaults are the paths this notebook was
# originally run with; set SURROGATE_DATA_PATH / SURROGATE_RESULT_PATH in the
# environment to run it on another machine without editing this cell.
data_path = os.environ.get(
    "SURROGATE_DATA_PATH",
    '/Users/reza/Career/DMLab/SURROGATE/Data/psi_web_sample/test',
)
result_path = os.environ.get(
    "SURROGATE_RESULT_PATH",
    '/Users/reza/Career/DMLab/SURROGATE/results/laplace/deeponet/exp_16',
)

# cfg.json sits next to the checkpoints in the experiment's result directory.
# Later cells read cfg["val_files"], cfg["train_min"] and cfg["train_max"] from it.
cfg_path = os.path.join(result_path, 'cfg.json')
# Explicit encoding so the file is decoded the same way on every platform.
with open(cfg_path, 'r', encoding='utf-8') as f:
    cfg = json.load(f)

In [4]:
# Number of entries stored under the config's "val_files" key.
len(cfg["val_files"])

227

In [5]:
# Sub-directory names looked for inside every "cr*" directory, in the order
# their paths should appear in `sim_paths`.
instruments = [
    "kpo_mas_mas_std_0101",
    "mdi_mas_mas_std_0101",
    "hmi_mast_mas_std_0101",
    "hmi_mast_mas_std_0201",
    "hmi_masp_mas_std_0201",
    "mdi_mas_mas_std_0201",
]

# Entries of data_path in sorted order; only those whose name starts with "cr" are kept.
subdir_paths = sorted(os.listdir(data_path))
cr_paths = [
    os.path.join(data_path, entry)
    for entry in subdir_paths
    if entry.startswith("cr")
]

# Keep every <cr dir>/<instrument> combination that exists on disk,
# ordered by cr directory first and by instrument second.
sim_paths = [
    candidate
    for cr_dir in cr_paths
    for candidate in (os.path.join(cr_dir, name) for name in instruments)
    if os.path.exists(candidate)
]

In [6]:
# Build the dataset from the collected paths, passing the min/max recorded in
# the training config as b_min / b_max.
# NOTE(review): presumably these are normalisation bounds, so that the data is
# scaled the same way as during training; confirm against utils.DeepONetDataset.
bounds = {"b_min": cfg["train_min"], "b_max": cfg["train_max"]}
dataset = DeepONetDataset(sim_paths, **bounds)

100%|██████████| 21/21 [00:00<00:00, 101.07it/s]


In [7]:
# The (min, max) pair from the config that was passed to DeepONetDataset as b_min / b_max.
(cfg["train_min"], cfg["train_max"])

(-2.6867804527282715, 2.7947838306427)

In [8]:
# The architecture has to match the checkpoint restored at the end of this cell.
# Branch net: input size 14080 -> hidden layers -> final layer of 128.
branch_layers = [14080, 512, 512, 256, 128]
# Trunk net: 2 coordinates -> hidden layers -> final layer of 17920 (128 * 140).
trunk_layers = [2, 512, 512, 512, 256, 17920]

model = DeepONetCartesianProd(
    layer_sizes_branch=branch_layers,
    layer_sizes_trunk=trunk_layers,
    activation="tanh",
    kernel_initializer="Glorot uniform",
    num_outputs=140,  # 140 output channels
    multi_output_strategy="split_trunk",
)
model = model.to(device)

# Restore the trained weights from the experiment directory.
checkpoint_file = os.path.join(result_path, "199.pth")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code from the file. Only use it on checkpoints you trust;
# if the file holds a plain state_dict, weights_only=True should be sufficient.
state = torch.load(checkpoint_file, map_location=device, weights_only=False)
model.load_state_dict(state)

<All keys matched successfully>

In [9]:
# Fetch the first sample as a quick check that the dataset can be indexed.
# `instance` is reassigned in the next code cell before anything reads it.
instance = dataset[0]

In [45]:
# Index of the dataset sample to evaluate; also used to name the output GIF.
I = 10
instance = dataset[I]
x, y = instance  # x: branch input, y: target values for this sample
print(x.shape, y.shape)

# The model is called with [branch input, trunk input]; the branch input gets
# a leading batch dimension of 1.
branch_input = x.to(device).unsqueeze(0)
trunk_input = dataset.trunk_input.to(device)

# Evaluation mode, with gradient tracking switched off for the forward pass.
model.eval()
with torch.no_grad():
    yhats = model([branch_input, trunk_input])

torch.Size([14080]) torch.Size([14080, 140])


In [46]:
# Compare target and prediction shapes; the prediction carries an extra leading batch axis.
(y.shape, yhats.shape)

(torch.Size([14080, 140]), torch.Size([1, 14080, 140]))

In [47]:
# View target and prediction as a stack of 140 images of 110 x 128 pixels.
cube_shape = (140, 110, 128)
# NOTE(review): y is (14080, 140) and yhats is (1, 14080, 140), i.e. the
# 140-axis comes last. A plain reshape to (140, 110, 128) only yields proper
# per-slice images if DeepONetDataset flattened a (140, 110, 128) cube with a
# plain reshape as well. If it stores points x channels, a transpose is needed
# before reshaping -- TODO confirm against utils.DeepONetDataset.
y_unflattened = y.reshape(cube_shape)
yhats_unflattened = yhats.reshape(cube_shape)

In [48]:
def _as_numpy(values):
    """Return ``values`` as a NumPy array; torch tensors are detached and moved to the CPU first."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _render_comparison_frame(gt, pred, index):
    """Render one ground-truth / prediction / error figure and return it as an image array.

    Parameters
    ----------
    gt, pred : 2-D NumPy arrays of identical shape
        Ground-truth and predicted slice to draw side by side.
    index : int
        Zero-based slice index; shown one-based in the panel titles.

    Returns
    -------
    The rendered PNG, decoded by ``imageio.imread``.
    """
    # Absolute error, min-max scaled per frame for display only.
    abs_error = np.abs(gt - pred)
    err_min = abs_error.min()
    err_span = abs_error.max() - err_min
    if err_span > 0:
        error = (abs_error - err_min) / err_span
    else:
        # Constant error (e.g. a perfect prediction): avoid dividing by zero.
        error = np.zeros_like(abs_error)

    # Shared colour limits so the first two panels are directly comparable.
    vmin = min(gt.min(), pred.min())
    vmax = max(gt.max(), pred.max())
    cmap = "coolwarm"

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    try:
        # Panel 1: ground truth.
        im_gt = axes[0].imshow(gt, cmap=cmap, vmin=vmin, vmax=vmax)
        axes[0].set_title(f"gt: br002 at {index+1}")

        # Panel 2: prediction, drawn with the same colour limits as panel 1.
        axes[1].imshow(pred, cmap=cmap, vmin=vmin, vmax=vmax)
        axes[1].set_title(f"pred at {index+1} ")

        # Horizontal colour bar for the value scale shared by panels 1 and 2.
        fig.colorbar(im_gt, ax=axes, orientation="horizontal", fraction=0.1, pad=0.02)

        # Panel 3: normalised absolute error in grayscale.
        im_err = axes[2].imshow(error, cmap="gray")
        axes[2].set_title("|pred-gt|")

        # Vertical colour bar for the error panel.
        fig.colorbar(im_err, ax=axes, orientation="vertical", fraction=0.05, pad=0.02)

        # Render into memory instead of a temporary file on disk.
        buf = BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)  # rewind so the reader starts at the beginning of the PNG
        return imageio.imread(buf)
    finally:
        # Always release this specific figure, even if rendering failed, so
        # figures do not pile up in memory across the loop.
        plt.close(fig)


# Convert once, outside the loop. Separate names are used on purpose so the
# notebook-level `y` from the inference cell is NOT overwritten and the reshape
# cell stays re-runnable.
gt_cube = _as_numpy(y_unflattened)
pred_cube = _as_numpy(yhats_unflattened)

# One frame per slice along the first axis of the unflattened cubes.
frames = [
    _render_comparison_frame(gt_cube[i], pred_cube[i], i)
    for i in trange(gt_cube.shape[0])
]

100%|██████████| 140/140 [00:24<00:00,  5.79it/s]


In [49]:
fps = 10  # frames per second; adjust as needed
output_filename = "b{}.gif".format(I)  # named after the evaluated sample index

# Write the GIF straight from the frames held in memory, one frame at a time.
with imageio.get_writer(output_filename, fps=fps, loop=0) as gif_writer:
    for image in frames:
        gif_writer.append_data(image)