In [1]:
import torch
import matplotlib.pyplot as plt
import sys
from xno.models import XNO
from xno.data.datasets import load_darcy_flow_small
from xno.utils import count_model_params
from xno.training import AdamW
from xno.training.incremental import IncrementalFNOTrainer
from xno.data.transforms.data_processors import IncrementalDataProcessor
from xno import LpLoss, H1Loss

In [None]:
# Load the small Darcy-flow dataset: 100 training samples in batches of 16,
# plus test loader(s) keyed by the resolutions listed in `test_resolutions`.
# NOTE(review): `n_tests` and `test_batch_sizes` each have two entries but only
# one test resolution (32) is requested; the second entries are presumably
# ignored — confirm against `load_darcy_flow_small` and trim them if so.
train_loader, test_loaders, output_encoder = load_darcy_flow_small(
    n_train=100,
    batch_size=16,
    n_tests=[100, 50],
    test_batch_sizes=[32, 32],
    test_resolutions=[32],
)

In [None]:
# Peek at one training batch. It is indexed by string keys: 'x' (presumably the
# input field) and 'y' (presumably the target field).
batch = next(iter(train_loader))
(
    type(train_loader),
    type(batch),
    batch['x'].shape,
    batch['y'].shape,
)

In [None]:
# Peek at one batch from the resolution-32 test loader (`test_loaders` is
# indexed by test resolution) to compare its shapes with the training batch.
batch = next(iter(test_loaders[32]))
(
    type(test_loaders),
    type(batch),
    batch['x'].shape,
    batch['y'].shape,
)

In [None]:
# Number of samples in the training dataset (should match `n_train` above).
n_train_samples = len(train_loader.dataset)
n_train_samples

In [6]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Initial number of modes for the model: start small (2, 2) when training
# incrementally, otherwise start directly at (16, 16).
incremental = True
starting_modes = (2, 2) if incremental else (16, 16)

In [None]:
# Build the operator model. `n_modes` comes from `starting_modes` so that the
# `incremental` switch in the previous cell actually controls the initial mode
# count. `max_n_modes` is presumably the upper bound that incremental training
# may grow the modes to — confirm against the XNO documentation.
model = XNO(
    max_n_modes=(16, 16),
    n_modes=starting_modes,
    hidden_channels=32,
    in_channels=1,
    out_channels=1,
    transformation="wno",
    # NOTE(review): wavelet_size is fixed at 16x16, while the data processor
    # subsamples the resolution-16 training data and evaluation runs at
    # resolution 32; confirm the WNO transformation handles inputs whose size
    # differs from wavelet_size.
    transformation_kwargs={"wavelet_level": 3, "wavelet_size": [16, 16]},
)
model = model.to(device)
# Parameter count (via xno's count_model_params), printed in the summary cell.
n_params = count_model_params(model)

In [9]:
# Optimizer (xno's AdamW) and a cosine-annealing learning-rate schedule.
# NOTE(review): T_max=30 while the trainer below runs n_epochs=10. If the
# scheduler is stepped once per epoch, only the first third of the cosine
# curve is used and the LR never reaches its minimum — confirm this is intended.
optimizer = AdamW(
    model.parameters(),
    lr=8e-3,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=30,
)

In [None]:
# Data processor for incremental-resolution training, moved to `device`.
# The resolution-16 training data is presumably subsampled by the rates in
# `subsampling_rates` (2, then 1), advancing every `epoch_gap` epochs;
# `dataset_indices=[2, 3]` presumably selects the spatial dimensions to
# subsample — confirm against IncrementalDataProcessor.
data_transform = IncrementalDataProcessor(
    in_normalizer=None,
    out_normalizer=None,
    device=device,
    subsampling_rates=[2, 1],
    dataset_resolution=16,
    dataset_indices=[2, 3],
    epoch_gap=10,
    verbose=True,
).to(device)

In [None]:
# Losses: H1 is the training objective; H1 and Lp (p=2) are both evaluated on
# the test loaders under the keys "h1" and "l2".
h1loss = H1Loss(d=2)
l2loss = LpLoss(d=2, p=2)
train_loss = h1loss
eval_losses = {"h1": h1loss, "l2": l2loss}

# Run summary; flush stdout so it is printed before any trainer output.
print("\n### N PARAMS ###\n", n_params)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print("\n### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()

In [12]:
# Assemble the incremental trainer from the model, data processor and device.
# Only the gradient-based criterion is enabled (incremental_grad=True,
# incremental_loss_gap=False). The loss-gap settings (incremental_loss_eps,
# incremental_buffer, incremental_max_iter) presumably have no effect while
# that criterion is disabled — confirm.
# NOTE(review): n_epochs=10 equals the data processor's epoch_gap=10. If the
# resolution only advances every epoch_gap epochs, this run never reaches
# subsampling rate 1 — confirm whether more epochs were intended.
trainer = IncrementalFNOTrainer(
    model=model,
    data_processor=data_transform,
    device=device,
    n_epochs=10,
    verbose=True,
    incremental_grad=True,
    incremental_grad_eps=0.9999,
    incremental_grad_max_iter=2,
    incremental_loss_gap=False,
    incremental_loss_eps=0.001,
    incremental_buffer=5,
    incremental_max_iter=1,
)

In [None]:
# Train with H1 as the objective and evaluate both losses on every test loader.
# The return value is kept under a name and displayed as the cell output.
train_metrics = trainer.train(
    train_loader,
    test_loaders,
    optimizer,
    scheduler,
    regularizer=False,
    training_loss=train_loss,
    eval_losses=eval_losses,
)
train_metrics

In [None]:
# Recorded metrics from an earlier run labelled FNO (pasted training output,
# not recomputed when the notebook is re-run). '32_h1' / '32_l2' are the H1 and
# L2 losses on the resolution-32 test loader. They appeared as `tensor(...)` in
# the log and are stored as plain floats because bare `tensor` is not defined
# in this notebook, so the pasted repr would raise NameError on a fresh kernel.
fno_results = {
    'train_err': 5.598510350499835,
    'avg_loss': 0.3918957245349884,
    'avg_lasso_loss': None,
    'epoch_train_time': 0.46678220800095005,
    '32_h1': 0.9681,
    '32_l2': 0.3976,
}
fno_results

In [None]:
# Recorded metrics from an earlier run labelled HNO (pasted training output,
# not recomputed when the notebook is re-run). '32_h1' / '32_l2' are the H1 and
# L2 losses on the resolution-32 test loader. They appeared as `tensor(...)` in
# the log and are stored as plain floats because bare `tensor` is not defined
# in this notebook, so the pasted repr would raise NameError on a fresh kernel.
hno_results = {
    'train_err': 6.751643555504935,
    'avg_loss': 0.47261504888534545,
    'avg_lasso_loss': None,
    'epoch_train_time': 0.41804891700303415,
    '32_h1': 0.8712,
    '32_l2': 0.3615,
}
hno_results

In [None]:
# Recorded metrics from an earlier run labelled LNO (pasted training output,
# not recomputed when the notebook is re-run). '32_h1' / '32_l2' are the H1 and
# L2 losses on the resolution-32 test loader. They appeared as `tensor(...)` in
# the log and are stored as plain floats because bare `tensor` is not defined
# in this notebook, so the pasted repr would raise NameError on a fresh kernel.
lno_results = {
    'train_err': 8.302332741873604,
    'avg_loss': 0.5811632919311523,
    'avg_lasso_loss': None,
    'epoch_train_time': 0.932207040998037,
    '32_h1': 0.9699,
    '32_l2': 0.4257,
}
lno_results

In [None]:
# Recorded metrics from an earlier run labelled WNO (pasted training output,
# not recomputed when the notebook is re-run). '32_h1' / '32_l2' are the H1 and
# L2 losses on the resolution-32 test loader. They appeared as `tensor(...)` in
# the log and are stored as plain floats because bare `tensor` is not defined
# in this notebook, so the pasted repr would raise NameError on a fresh kernel.
wno_results = {
    'train_err': 6.326124395642962,
    'avg_loss': 0.4428287076950073,
    'avg_lasso_loss': None,
    'epoch_train_time': 5.432985792002,
    '32_h1': 1.4008,
    '32_l2': 0.7668,
}
wno_results