In [1]:
import numpy as np
from sklearn.pipeline import Pipeline
from swimnetworks_local import dense, linear
from swimnetworks_local import deeponet

### Read the data and split it into training and test sets

In [2]:
# Load both arrays and transpose them; after the transpose the split below
# runs along axis 1, so each column is one sample.
u0 = np.load("u0.npy").T  # initial conditions
u1 = np.load("u1.npy").T  # solutions

# Train/test split: the first `train_size` columns are used for training,
# everything after them for testing.
train_size = 12000
u0_train, u0_test = np.split(u0, [train_size], axis=1)
u1_train, u1_test = np.split(u1, [train_size], axis=1)

# Grid: 256 evenly spaced points on [0, 2*pi] (both endpoints included),
# stored as a single row.
epsilon = np.linspace(0, 2*np.pi, 256)[np.newaxis, :]  # grid

# Report the shape of every array produced above
print(
    f"u0_train shape: {u0_train.shape}",
    f"u0_test shape: {u0_test.shape}",
    f"u1_train shape: {u1_train.shape}",
    f"u1_test shape: {u1_test.shape}",
    f"grid shape: {epsilon.shape}",
    sep="\n",
)

u0_train shape: (256, 12000)
u0_test shape: (256, 3000)
u1_train shape: (256, 12000)
u1_test shape: (256, 3000)
grid shape: (1, 256)


### Define the branch net and the trunk net, then fit the DeepONet

In [3]:
def _swim_steps(seed):
    """Return the ``[("dense", ...), ("linear", ...)]`` step list for one sub-network.

    The branch and trunk networks are configured identically except for the
    random seed passed to the dense layer, so both are built through this helper.
    """
    hidden_layer = dense.Dense(
        layer_width=1024,
        activation="tanh",
        parameter_sampler="tanh",
        random_seed=seed,
    )
    output_layer = linear.Linear(regularization_scale=1e-10)
    return [("dense", hidden_layer), ("linear", output_layer)]


# Branch net (dense layer seeded with 42)
branch_steps = _swim_steps(42)
branch_net = Pipeline(branch_steps)

# Trunk net (dense layer seeded with 43)
trunk_steps = _swim_steps(43)
trunk_net = Pipeline(trunk_steps)

# Combine both pipelines into the DeepONet and fit it on the training split
# together with the grid.
# NOTE(review): the recorded run of this cell logs a loss that jumps between
# roughly 0.3 and 948 from one iteration to the next -- worth confirming that
# the fit actually converges before relying on the model.
model = deeponet.DeepONet(branch_pipeline=branch_net, trunk_pipeline=trunk_net)
model.fit(u0_train, u1_train, epsilon)


Iteration 0 | Loss: 0.9494296607789089
Iteration 1 | Loss: 21.55759843813135
Iteration 2 | Loss: 59.202861974266085
Iteration 3 | Loss: 348.2756716000173
Iteration 4 | Loss: 0.3027977794428526
Iteration 5 | Loss: 28.456023661824617
Iteration 6 | Loss: 77.56888875604922
Iteration 7 | Loss: 948.4305507478466
Iteration 8 | Loss: 111.41598943835525
Iteration 9 | Loss: 2.04661466567235


In [4]:
# Evaluate the fitted model on the held-out initial conditions and the grid.
predictions = model.transform(u0_test, epsilon)

# Guard against silent NumPy broadcasting: if the prediction and target shapes
# differ but happen to be broadcast-compatible, the subtraction below would
# still succeed and the reported MSE would be meaningless.
assert np.shape(predictions) == np.shape(u1_test), (
    f"prediction shape {np.shape(predictions)} does not match "
    f"target shape {np.shape(u1_test)}"
)

# MSE, averaged over every entry of the test targets
mse = np.mean((predictions - u1_test)**2)
print(f"Mean Squared Error on Test Data: {mse}")

Mean Squared Error on Test Data: 2398.7133245974214
