In [1]:
import os
os.environ["DDE_BACKEND"] = "pytorch"

import numpy as np
import torch
import deepxde as dde
from pyhdf.SD import SD, SDC
from deepxde.nn import DeepONetCartesianProd


Using backend: pytorch
Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.
paddle supports more examples now and is recommended.


In [2]:
def read_hdf(hdf_path, dataset_names):
    f = SD(hdf_path, SDC.READ)
    datasets = []
    for dataset_name in dataset_names:
        datasets.append(f.select(dataset_name).get())
    return datasets

In [3]:
br, phi, theta, rho = read_hdf('/Users/reza/Career/DMLab/SURROGATE/Data/psi_web_sample/train/cr1732/kpo_mas_mas_std_0101/br002.hdf', ["Data-Set-2", "fakeDim0", "fakeDim1", "fakeDim2"])

In [4]:
br.shape, phi.shape, theta.shape, rho.shape

((128, 110, 141), (128,), (110,), (141,))

In [6]:
br = br.transpose(2, 1, 0)
br.shape
br = np.zeros((2, 110, 128))

In [7]:
theta = torch.as_tensor(theta, dtype=torch.float32)  # [111]
phi   = torch.as_tensor(phi,   dtype=torch.float32)  # [128]
br    = torch.as_tensor(br,    dtype=torch.float32)  # [1, 111, 128]

In [8]:
Theta, Phi = torch.meshgrid(theta, phi, indexing="ij")
# Flatten => shape [111 * 128, 2] = [14208, 2]
coords = torch.stack([Theta.flatten(), Phi.flatten()], dim=-1)
print("coords shape:", coords.shape)  # [14208, 2]

coords shape: torch.Size([14080, 2])


In [12]:
# br has shape [1, 111, 128]
# Flatten => [1, 14208]
branch_input = br.reshape(2, -1)
print("branch_input shape:", branch_input.shape)  # [1, 14208]


branch_input shape: torch.Size([2, 14080])


In [13]:
model = DeepONetCartesianProd(
    layer_sizes_branch=[14080, 256, 128],   # Branch final layer = 128
    layer_sizes_trunk=[2, 256, 17920],      # Trunk final layer = 17920 = 128 * 140
    activation="tanh",
    kernel_initializer="Glorot uniform",
    num_outputs=140,                        # produce 140 output channels
    multi_output_strategy="split_trunk",    
)


In [14]:
y_pred = model([branch_input, coords])
print("Raw output:", y_pred.shape)  # [1, 14208, 140]


Raw output: torch.Size([2, 14080, 140])


In [29]:
# Permute to [1, 140, 14208]
y_pred = y_pred.permute(0, 2, 1)

# Reshape to [1, 140, 111, 128]
y_pred = y_pred.reshape(1, 140, 110, 128)

# Drop the batch dim => [140, 111, 128]
y_pred = y_pred.squeeze(0)
print("Final output:", y_pred.shape)  # [140, 111, 128]


Final output: torch.Size([140, 110, 128])
