In [13]:
from xno.data.datasets.hdf5_dataset import H5pyDataset
from utils import MatReader
from torch.utils.data import DataLoader, TensorDataset, Dataset, default_collate
import numpy as np

In [14]:
import torch
import matplotlib.pyplot as plt
import sys
from xno.models import XNO
from xno.utils import count_model_params
from xno.training import AdamW
from xno.training.incremental import IncrementalXNOTrainer
from xno.data.transforms.data_processors import IncrementalDataProcessor
from xno import LpLoss, H1Loss

In [32]:
# Location of the .mat file read by the loading cell below (relative path).
data_path = "data/1d_lorenz96.mat"

In [33]:
# 1) Read field "X" from the .mat file
reader = MatReader(data_path)
X = reader.read_field("X")  # shape [T, D], e.g. [2001, 40]

# 's' selects the single-step task in the split cell; any other value selects multi-step.
step_mode = "s"

# 2) Next-step pairs: x_ns[t] = X[t] is the input, y_ns[t] = X[t+1] is the target
x_ns, y_ns = X[:-1], X[1:]

In [34]:
# 3) Multi-step example: window=5 in, horizon=10 out
# Each sample pairs `window` consecutive rows of X (input) with the `horizon`
# rows that immediately follow them (target).
window, horizon = 5, 10
n_windows = len(X) - window - horizon + 1
if n_windows <= 0:
    raise ValueError(
        f"Series of length {len(X)} is too short for window={window} and horizon={horizon}"
    )

# torch.stack keeps the windows as torch tensors, matching the type of the
# next-step pairs (x_ns / y_ns); np.array would turn them into NumPy arrays,
# which the later `.unsqueeze` calls cannot handle.
X_in = torch.stack([X[i : i + window] for i in range(n_windows)])  # shape [N, window, D]
X_out = torch.stack(
    [X[i + window : i + window + horizon] for i in range(n_windows)]
)  # shape [N, horizon, D]

In [35]:
# 4) Simple chronological train/test split (80/20) for both tasks
# The split index is derived from the data length in both modes, so the
# single-step branch no longer relies on a hard-coded 1600.
TRAIN_FRACTION = 0.8

if step_mode == 's':
    print('Step mode is: Single')
    inputs, targets = x_ns, y_ns
else:
    print('Step mode is: Multiple')
    inputs, targets = X_in, X_out

# The first `split` samples go to training, the remainder to testing (no shuffling).
split = int(TRAIN_FRACTION * len(inputs))
x_train, x_test = inputs[:split], inputs[split:]
y_train, y_test = targets[:split], targets[split:]

# 5) Print shapes to verify
print("Sshapes:", x_train.shape, y_train.shape, x_test.shape, y_test.shape)


Step mode is: Single
Sshapes: torch.Size([1600, 40]) torch.Size([1600, 40]) torch.Size([400, 40]) torch.Size([400, 40])


In [36]:
# Define the custom Dataset
class DictDataset(Dataset):
    """Dataset that pairs inputs and targets and yields them as ``{'x': ..., 'y': ...}`` dicts.

    Parameters
    ----------
    x : indexable, sized
        Input samples; ``x[idx]`` is returned under the key ``'x'``.
    y : indexable, sized
        Target samples; ``y[idx]`` is returned under the key ``'y'``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not contain the same number of samples.
    """

    def __init__(self, x, y):
        # Fail early on a length mismatch instead of raising an IndexError
        # (or silently dropping samples) part-way through an epoch.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same number of samples, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with keys ``'x'`` and ``'y'``."""
        return {'x': self.x[idx], 'y': self.y[idx]}

In [37]:
# Sanity check: shapes of the training inputs and targets.
(x_train.shape, y_train.shape)

(torch.Size([1600, 40]), torch.Size([1600, 40]))

In [38]:
# Sanity check: shapes of the test inputs and targets.
(x_test.shape, y_test.shape)

(torch.Size([400, 40]), torch.Size([400, 40]))

In [40]:
def _ensure_channel_axis(t):
    """Return ``t`` with a size-1 axis inserted at position 1 if ``t`` is 2-D, else ``t`` unchanged."""
    return t.unsqueeze(1) if t.ndim == 2 else t


# Turn [N, D] samples into [N, 1, D]. The ndim guard makes this cell safe to
# re-run: a plain `x = x.unsqueeze(1)` would add another axis on every execution.
# NOTE(review): 3-D multi-step windows are left untouched here — confirm the
# channel layout the model expects before using that mode.
x_train, y_train = _ensure_channel_axis(x_train), _ensure_channel_axis(y_train)
x_test, y_test = _ensure_channel_axis(x_test), _ensure_channel_axis(y_test)

In [48]:
# Confirm the size-1 axis was added at position 1 of the test targets.
y_test.size()

torch.Size([400, 1, 40])

In [49]:
# Wrap the tensors in dict-style datasets ({'x': ..., 'y': ...} per sample).
# These names hold Dataset objects here; the next cell rebinds them to DataLoaders.
train_loader = DictDataset(x_train, y_train)
# The test inputs must be x_test: passing y_test as the input would have the
# model evaluated on the targets themselves instead of on the held-out inputs.
test_loader = DictDataset(x_test, y_test)

In [51]:
# Batch the datasets created in the previous cell (this cell expects
# `train_loader` / `test_loader` to still be Dataset objects, so re-run that
# cell first if this one has already been executed).
train_loader = DataLoader(train_loader, batch_size=20, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps the inspected test
# batches reproducible between runs.
test_loader = DataLoader(test_loader, batch_size=20, shuffle=False)
# Test loaders are passed to the trainer as a dict; the key 40 matches the
# length of the last axis of the samples.
test_loader = {
    40: test_loader
}

In [52]:
# Draw one mini-batch from the training loader and display its container types and shapes.
train_iter = iter(train_loader)
batch = next(train_iter)
(type(train_loader), type(batch), batch['x'].shape, batch['y'].shape)

(torch.utils.data.dataloader.DataLoader,
 dict,
 torch.Size([20, 1, 40]),
 torch.Size([20, 1, 40]))

In [53]:
# Same check for the test side: `test_loader` is a dict, so pull a batch from the loader stored under key 40.
test_iter = iter(test_loader[40])
batch = next(test_iter)
(type(test_loader), type(batch), batch['x'].shape, batch['y'].shape)

(dict, dict, torch.Size([20, 1, 40]), torch.Size([20, 1, 40]))

In [14]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# Model configuration: one input channel, one output channel, and 1-tuples for
# the mode counts (a single spatial axis).
model_kwargs = dict(
    max_n_modes=(16, ),
    n_modes=(2, ),
    hidden_channels=32,
    in_channels=1,
    out_channels=1,
    transformation="hno",
    # Disabled extra option, kept for reference (presumably for a wavelet transformation — confirm):
    # transformation_kwargs={"wavelet_level": 2, "wavelet_size": [2048]}
)

# Instantiate the model, move it to the selected device, and count its parameters.
model = XNO(**model_kwargs).to(device)
n_params = count_model_params(model)

In [16]:
# AdamW over all model parameters, with a cosine-annealed learning rate.
learning_rate, weight_decay = 8e-3, 1e-4
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# NOTE(review): T_max=30 is longer than the n_epochs=10 given to the trainer below,
# so only the first third of the cosine cycle would be used — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

In [None]:
# Data processor used by the incremental trainer; no input/output normalizers are applied.
data_transform = IncrementalDataProcessor(
    in_normalizer=None,
    out_normalizer=None,
    device=device,
    subsampling_rates=[2, 1],
    # Take the resolution from the data itself (last axis of the samples, 40 for
    # this dataset) rather than the hard-coded 2048, which matches neither the
    # tensors built above nor the test-loader key.
    dataset_resolution=x_train.shape[-1],
    # Index 2 is the last axis of the [N, 1, D] samples.
    dataset_indices=[2],
    # NOTE(review): epoch_gap=10 equals the trainer's n_epochs=10 — presumably the
    # resolution schedule therefore never advances during training; confirm.
    epoch_gap=10,
    verbose=True,
)

data_transform = data_transform.to(device)

In [None]:
# The samples have a single spatial axis ([batch, 1, 40]), so both losses use
# d=1; with d=2 the size-1 channel axis would be treated as a second spatial axis.
l2loss = LpLoss(d=1, p=2)
h1loss = H1Loss(d=1)

# Train on H1; report both H1 and L2 at evaluation time.
train_loss = h1loss
eval_losses = {"h1": h1loss, "l2": l2loss}

# Summary of the training setup.
print("\n### N PARAMS ###\n", n_params)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print("\n### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()

In [19]:
# Finally pass all of these to the Trainer
# The class imported at the top of the notebook is IncrementalXNOTrainer;
# `IncrementalFNOTrainer` is never imported and raises NameError on a fresh kernel.
trainer = IncrementalXNOTrainer(
    model=model,
    n_epochs=10,
    data_processor=data_transform,
    device=device,
    verbose=True,
    # Incremental settings: the gradient-based option is switched on and the
    # loss-gap option off.
    incremental_loss_gap=False,
    incremental_grad=True,
    incremental_grad_eps=0.9999,
    incremental_loss_eps=0.001,
    incremental_buffer=5,
    incremental_max_iter=1,
    incremental_grad_max_iter=2,
)

In [None]:
# Launch training: H1 as the training loss, H1 and L2 as evaluation losses,
# and no regularizer. The call's return value is shown as the cell output.
train_options = dict(
    regularizer=False,
    training_loss=train_loss,
    eval_losses=eval_losses,
)
trainer.train(train_loader, test_loader, optimizer, scheduler, **train_options)

In [None]:
# FNO — metrics pasted from a previous run, kept for comparison.
# `torch.tensor` replaces the bare `tensor(...)` repr, which is not a defined name
# and made this cell raise NameError when executed.
# NOTE(review): the '2048_*' keys do not match the test-loader key 40 used in this
# notebook — presumably recorded from a different dataset/run; confirm.
results_fno = {
    'train_err': 10.971330422621508,
    'avg_loss': 0.713136477470398,
    'avg_lasso_loss': None,
    'epoch_train_time': 4.013646583998707,
    '2048_h1': torch.tensor(0.7238),
    '2048_l2': torch.tensor(0.9066),
}
results_fno

In [None]:
# HNO — metrics pasted from a previous run, kept for comparison.
# `torch.tensor` replaces the bare `tensor(...)` repr, which is not a defined name
# and made this cell raise NameError when executed.
# NOTE(review): the '2048_*' keys do not match the test-loader key 40 used in this
# notebook — presumably recorded from a different dataset/run; confirm.
results_hno = {
    'train_err': 11.142363548278809,
    'avg_loss': 0.7242536306381225,
    'avg_lasso_loss': None,
    'epoch_train_time': 5.140133208000407,
    '2048_h1': torch.tensor(0.7229),
    '2048_l2': torch.tensor(0.7630),
}
results_hno

In [None]:
# WNO — metrics pasted from a previous run, kept for comparison.
# `torch.tensor` replaces the bare `tensor(...)` repr, which is not a defined name
# and made this cell raise NameError when executed.
# NOTE(review): the '2048_*' keys do not match the test-loader key 40 used in this
# notebook — presumably recorded from a different dataset/run; confirm.
results_wno = {
    'train_err': 6.489716823284443,
    'avg_loss': 0.42183159351348876,
    'avg_lasso_loss': None,
    'epoch_train_time': 5.868622500000129,
    '2048_h1': torch.tensor(0.7528),
    '2048_l2': torch.tensor(0.7417),
}
results_wno

In [None]:
# LNO — metrics pasted from a previous run, kept for comparison.
# `torch.tensor` replaces the bare `tensor(...)` repr, which is not a defined name
# and made this cell raise NameError when executed.
# NOTE(review): the '2048_*' keys do not match the test-loader key 40 used in this
# notebook — presumably recorded from a different dataset/run; confirm.
results_lno = {
    'train_err': 11.74565157523522,
    'avg_loss': 0.7634673523902893,
    'avg_lasso_loss': None,
    'epoch_train_time': 9.407964875001198,
    '2048_h1': torch.tensor(0.7606),
    '2048_l2': torch.tensor(0.7453),
}
results_lno