In [1]:
# Reaction-Diffusion Prediction with FNO3D
# 1. Imports
from neuralop.models import FNO  # FNO handles N-D input
from neuralop.training import Trainer
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import h5py
import numpy as np


In [None]:
# 2. Dataset class 
class ReactionDiffusionDataset3D(Dataset):
    def __init__(self, file_path, initial_steps=10, future_steps=1):
        self.file_path = file_path
        self.initial_steps = initial_steps
        self.future_steps = future_steps

        # Open HDF5 in read-only mode
        self.h5_file = h5py.File(file_path, "r")
        self.keys = sorted(self.h5_file.keys())
        self.samples_per_key = self.h5_file[self.keys[0]]["data"].shape[0] - initial_steps - future_steps + 1

    def __len__(self):
        return len(self.keys) * self.samples_per_key

    def __getitem__(self, idx):
        key_idx = idx // self.samples_per_key
        local_idx = idx % self.samples_per_key
        data = self.h5_file[self.keys[key_idx]]["data"][:]

        x = data[local_idx:local_idx+self.initial_steps]
        y = data[local_idx+self.initial_steps:local_idx+self.initial_steps+self.future_steps]

        # Permute to [channels, nx, ny, time]
        x = np.transpose(x, (3, 1, 2, 0))
        y = np.transpose(y, (3, 1, 2, 0))

        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


In [None]:
# 3. Hyperparameters
file_path = "--ENTER FILE NAME HERE --.h5"
initial_steps = 10
future_steps = 10
batch_size = 4
n_epochs = 50

In [None]:
# Wrap dataset to return dict instead of tuple
class DatasetWrapper(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]   
        return {
            "x": x.detach().clone().float(),
            "y": y.detach().clone().float()
        }

In [None]:
# 4. Dataset and DataLoader
from torch.utils.data import Subset

dataset = ReactionDiffusionDataset3D(file_path, initial_steps=initial_steps, future_steps=future_steps)

n_train = 500
n_test = 50

# Random permutation of indices
perm = torch.randperm(len(dataset))

train_indices = perm[:n_train]
test_indices  = perm[n_train:n_train + n_test]

train_dataset = Subset(dataset, train_indices)
test_dataset  = Subset(dataset, test_indices)

# Wrap train and test datasets
train_dataset_wrapped = DatasetWrapper(train_dataset)
test_dataset_wrapped = DatasetWrapper(test_dataset)

# DataLoaders
train_loader = DataLoader(train_dataset_wrapped, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset_wrapped, batch_size=batch_size)

In [None]:
# 5. Initialize FNO3D model
operator = FNO(
    n_modes=(16,16,5),       # 3D Fourier modes: nx, ny, time
    hidden_channels=32,
    in_channels=dataset[0][0].shape[0],  # number of input channels
    out_channels=dataset[0][1].shape[0]  # number of output channels
)


In [None]:
# 6. Train
import torch.nn.functional as F

# Optimizer, scheduler, loss
optimizer = optim.Adam(operator.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
def mse_loss_dict(out, **kwargs):
    y = kwargs["y"]
    return F.mse_loss(out, y, reduction="mean")

trainer = Trainer(model=operator, n_epochs=n_epochs, verbose=True)

trainer.train(
    train_loader=train_loader,
    test_loaders={"test": test_loader},
    optimizer=optimizer,
    scheduler=scheduler,
    regularizer=False,
    training_loss=mse_loss_dict,
    eval_losses={"mse": mse_loss_dict}
)


Training on 500 samples
Testing on [50] samples         on resolutions ['test'].


  return forward_call(*args, **kwargs)


Raw outputs of shape torch.Size([4, 2, 128, 128, 10])
[0] time=663.93, avg_loss=0.0059, train_err=0.0238
Eval: test_mse=0.0019
[1] time=657.45, avg_loss=0.0012, train_err=0.0049
Eval: test_mse=0.0010
[2] time=655.83, avg_loss=0.0006, train_err=0.0026
Eval: test_mse=0.0006
[3] time=657.76, avg_loss=0.0004, train_err=0.0017
Eval: test_mse=0.0004
[4] time=657.42, avg_loss=0.0003, train_err=0.0012
Eval: test_mse=0.0003
[5] time=668.47, avg_loss=0.0002, train_err=0.0010
Eval: test_mse=0.0003
[6] time=770.51, avg_loss=0.0002, train_err=0.0008
Eval: test_mse=0.0002
[7] time=723.86, avg_loss=0.0002, train_err=0.0007
Eval: test_mse=0.0002
[8] time=665.19, avg_loss=0.0002, train_err=0.0007
Eval: test_mse=0.0002
[9] time=630.12, avg_loss=0.0002, train_err=0.0006
Eval: test_mse=0.0002
[10] time=663.06, avg_loss=0.0001, train_err=0.0005
Eval: test_mse=0.0002
[11] time=640.90, avg_loss=0.0001, train_err=0.0005
Eval: test_mse=0.0002
[12] time=677.24, avg_loss=0.0001, train_err=0.0005
Eval: test_mse=0

{'train_err': 0.00032004316663369534,
 'avg_loss': 8.001079165842383e-05,
 'avg_lasso_loss': None,
 'epoch_train_time': 607.3449381000028,
 'test_mse': tensor(9.9853e-05)}

In [None]:
# 7. Save model
operator.save_checkpoint(save_folder="./checkpoints", save_name="fno3d_reactiondiff_bc")

In [None]:
# 8. Test on unseen data
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset

# Select last 150 samples as test set
test_dataset = dataset
test_indices = np.arange(len(test_dataset) - 150, len(test_dataset))
test_subset = Subset(test_dataset, test_indices)
test_loader = DataLoader(test_subset, batch_size=1, shuffle=False)

# Parameters
device = "cpu"
operator.to(device)
operator.eval()

mse_list = []
mae_list = []
r2_list = []

def r2_score(y_true, y_pred):
    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)
    return 1 - ss_res / ss_tot

with torch.no_grad():
    for x, y_true in test_loader:

        x = x.to(device)              # Shape: [1, C, Nx, Ny, 10]
        y_true = y_true.to(device)    # Shape: [1, C, Nx, Ny, 10]

        y_pred = operator(x)                   # → same shape as y_true

        # Compute metrics on tensor (no flatten needed)
        mse_val = torch.mean((y_true - y_pred) ** 2).item()
        mae_val = torch.mean(torch.abs(y_true - y_pred)).item()
        r2_val = r2_score(y_true, y_pred).item()

        mse_list.append(mse_val)
        mae_list.append(mae_val)
        r2_list.append(r2_val)

print("Evaluation on last 150 unseen samples:")
print(f"MSE : {np.mean(mse_list):.6f}")
print(f"MAE : {np.mean(mae_list):.6f}")
print(f"R²  : {np.mean(r2_list):.6f}")


Evaluation on last 150 samples:
MSE : 0.000208
MAE : 0.008181
R²  : 0.995152
