In [1]:
import numpy as np
from sklearn.pipeline import Pipeline
from swimnetworks_local import dense, linear
from swimnetworks_local import deeponet, pod_deeponet

### Read the data and split it into training and test sets

In [2]:
u0 = np.load("u0.npy") # initial conditions
u1 = np.load("u1.npy") # solutions

# The same row split is applied to both arrays below, so they must hold the
# same number of samples; fail here rather than later with a shape mismatch.
assert u0.shape[0] == u1.shape[0], (
    f"u0 has {u0.shape[0]} samples but u1 has {u1.shape[0]}"
)

# Split data: the first `train_size` rows are used for training, the
# remaining rows for testing.
train_size = 12000
assert 0 < train_size < u0.shape[0], (
    f"train_size={train_size} must leave a non-empty train and test set "
    f"out of {u0.shape[0]} samples"
)

u0_train, u0_test = u0[:train_size, :], u0[train_size:, :]
u1_train, u1_test = u1[:train_size, :], u1[train_size:, :]

# Uniform grid on [0, 2*pi] (both endpoints included), stored as a column
# vector. The number of points is taken from the data instead of being
# hard-coded, so the grid stays in sync if the files are regenerated at a
# different resolution.
# NOTE(review): presumably one grid point per column of u1 -- confirm.
n_grid_points = u1.shape[1]
epsilon = np.linspace(0, 2*np.pi, n_grid_points).reshape(-1, 1)


# Print shapes
print(f"u0_train shape: {u0_train.shape}")
print(f"u0_test shape: {u0_test.shape}")
print(f"u1_train shape: {u1_train.shape}")
print(f"u1_test shape: {u1_test.shape}")
print(f"grid shape: {epsilon.shape}")

u0_train shape: (12000, 256)
u0_test shape: (3000, 256)
u1_train shape: (12000, 256)
u1_test shape: (3000, 256)
grid shape: (256, 1)


### Define the branch net and the trunk net, then build and fit the DeepONet

In [3]:
# The branch and trunk pipelines use the same two-step layout (a tanh Dense
# layer followed by a Linear layer) and differ only in the random seed, so
# the shared settings live in one place.
LAYER_WIDTH = 1024
REGULARIZATION_SCALE = 1e-10


def make_net_steps(random_seed):
    """Return the ``(name, estimator)`` steps for one sub-network pipeline.

    Parameters
    ----------
    random_seed : int
        Seed passed to the Dense layer.

    Returns
    -------
    list of (str, estimator) tuples
        A "dense" step followed by a "linear" step, built from fresh
        estimator instances on every call.
    """
    return [
        ("dense", dense.Dense(layer_width=LAYER_WIDTH, activation="tanh",
                              parameter_sampler="tanh",
                              random_seed=random_seed)),
        ("linear", linear.Linear(regularization_scale=REGULARIZATION_SCALE)),
    ]


# Branch net
branch_steps = make_net_steps(random_seed=42)
branch_net = Pipeline(branch_steps)

# Trunk net (different seed from the branch net)
trunk_steps = make_net_steps(random_seed=43)
trunk_net = Pipeline(trunk_steps)


model = deeponet.DeepONet(branch_pipeline=branch_net, trunk_pipeline=trunk_net)

# NOTE(review): the recorded run of this cell stopped inside fit(), after
# logging iteration 0, with
#   IndexError: index 7164 is out of bounds for axis 0 with size 256
# 256 is the number of rows in `epsilon` and 7164 is below the 12000 training
# samples, so presumably a sample index is being applied to a grid-sized
# array inside the library -- confirm against the DeepONet implementation
# before relying on any result that depends on `model`.
model.fit(u0_train, u1_train, epsilon)


POD modes shape: (256, 32)
T_tilda shape: (256, 32)
After T shape: (256, 32)
Iteration 0 | Relative L2 Loss: 0.41769499941868676


IndexError: index 7164 is out of bounds for axis 0 with size 256

In [None]:
# Run the model on the test initial conditions.
predictions = model.transform_branch(u0_test)

# Mean Relative L2 Loss
# For each test row: ||prediction - target||_2 / ||target||_2 (norms taken
# along axis 1), then averaged over the number of rows in u1_test.
error_norms = np.linalg.norm(predictions - u1_test, axis=1)
target_norms = np.linalg.norm(u1_test, axis=1)
relative_L2_loss = np.sum(error_norms / target_norms) / u1_test.shape[0]
print(f"Mean Relative L2 Loss on Test Data: {relative_L2_loss}")