In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset, default_collate
import torch.nn.functional as F
import matplotlib.pyplot as plt
import sys
from utils import MatReader
from pathlib import Path
import numpy as np

In [2]:
import sys
import os
sys.path.append(os.path.abspath(".."))

from xno.models import XNO
from xno.data.datasets import Burgers1dTimeDataset
from xno.utils import count_model_params
from xno.training import AdamW
from xno.training.incremental import IncrementalXNOTrainer
from xno.data.transforms.data_processors import IncrementalDataProcessor
from xno import LpLoss, H1Loss

In [3]:
# Define the custom Dataset
class DictDataset(Dataset):
    """Map-style dataset that serves paired samples as ``{'x': ..., 'y': ...}`` dicts.

    Parameters
    ----------
    x : sized, indexable
        Input samples; ``x[idx]`` is returned under the ``'x'`` key.
    y : sized, indexable
        Target samples; ``y[idx]`` is returned under the ``'y'`` key.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not hold the same number of samples.
    """

    def __init__(self, x, y):
        # Fail fast: without this check a mismatch only shows up as an
        # IndexError mid-epoch (y shorter) or silently ignores the trailing
        # targets (y longer), because __len__ is driven by x alone.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same number of samples, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with keys ``'x'`` and ``'y'``."""
        return {'x': self.x[idx], 'y': self.y[idx]}

# # Loading Lorenz-96 1D dataset

# ## Settings

# ### Data Settings

In [4]:
# Number of samples in the train / test split
ntrain, ntest = 495_000, 5_000

In [5]:
data_path = 'data/lorenz96_ashesh6810.mat'
data_name = '1d_lorenz96_ashesh6810'
time_step_mode = "s"

batch_size = 1024
dataset_resolution = 8

# XNO (model) 
max_modes = (4, )
n_modes = (4, )
in_channels = 1
out_channels = 1
n_layers = 4
hidden_channels = 32
transformation = "fno"
kwargs = {
    "wavelet_level": 3, 
    "wavelet_size": [dataset_resolution], "wavelet_filter": ['db6']
} if transformation.lower() == "wno" else {}

conv_non_linearity = None
mlp_non_linearity = None

match transformation.lower():
    case "fno" | "hno":
        conv_non_linearity = F.gelu
        mlp_non_linearity = F.gelu
    case "wno":
        conv_non_linearity = F.gelu
        mlp_non_linearity = F.gelu
    case "lno":
        conv_non_linearity = torch.sin
        mlp_non_linearity = torch.tanh

# AdamW (optimizer) 
learning_rate = 1e-3
weight_decay = 1e-4
# CosineAnnealingLR (scheduler) 
step_size = 100 if transformation.lower() == "lno" else 50
gamma = 0.5

# IncrementalDataProcessor (data_transform) 
dataset_resolution = dataset_resolution
dataset_indices = [2]

# IncrementalXNOTrainer (trainer) 
n_epochs = 250 # 500
save_every = 50
save_testing = True
save_dir = f"save/{data_name}/{transformation.lower()}/"


# Open the file at the start of the script
# output_file = open(f"{data_name}_{transformation.lower()}.txt", "w")
# sys.stdout = output_file  # Redirect stdout to the file

In [6]:
# 1) Load dataset
reader = MatReader(data_path)

X = reader.read_field("data")
shift_k = 0
# Tie the window length to the split sizes so changing ntrain / ntest cannot
# silently truncate the test split.
train_length = ntrain + ntest

# The label window is shifted one step ahead of the input window, so one
# extra row is required; otherwise `label - train` fails on a shape mismatch.
if X.shape[0] < shift_k + train_length + 1:
    raise ValueError(
        f"Need at least {shift_k + train_length + 1} time steps, "
        f"but the dataset only has {X.shape[0]}"
    )

# 2) Shape X & Y, Test & Train
train = X[shift_k:shift_k+train_length,:]
label = X[1+shift_k:1+shift_k+train_length,:]

x_ns = train
y_ns = label - train # y is the delta_x

x_train, x_test = x_ns[:ntrain], x_ns[ntrain:ntrain+ntest]
y_train, y_test = y_ns[:ntrain], y_ns[ntrain:ntrain+ntest]

# 3) Print shapes to verify
print("Sshapes:", x_train.shape, y_train.shape, x_test.shape, y_test.shape)

Sshapes: torch.Size([495000, 8]) torch.Size([495000, 8]) torch.Size([5000, 8]) torch.Size([5000, 8])


In [7]:
# Report the split shapes exactly as loaded (before the channel axis is added)
print("\n=== Data shape after importing from raw dataset ===\n")
print(
    f"X_Train Shape: {x_train.shape}",
    f"Y_Train Shape: {y_train.shape}",
    f"X_Test Shape: {x_test.shape}",
    f"Y_Test Shape: {y_test.shape}",
    sep="\n",
)


=== Data shape after importing from raw dataset ===

X_Train Shape: torch.Size([495000, 8])
Y_Train Shape: torch.Size([495000, 8])
X_Test Shape: torch.Size([5000, 8])
Y_Test Shape: torch.Size([5000, 8])


In [8]:
# Insert the channel axis: [Batch, D1] -> [Batch, Channel, D1].
# Each step is guarded on the tensor rank so that re-running this cell does
# not stack a second singleton axis onto already-reshaped data.
if x_train.ndim == 2:
    x_train = x_train.unsqueeze(1)
if y_train.ndim == 2:
    y_train = y_train.unsqueeze(1)
if x_test.ndim == 2:
    x_test = x_test.unsqueeze(1)
if y_test.ndim == 2:
    y_test = y_test.unsqueeze(1)

In [9]:


# Report the split shapes once the channel axis is in place
print("\n=== Data shape after reshaping based on [Batch, Channel, D1, D2, ...] ===\n")
print(
    f"X_Train Shape: {x_train.shape}",
    f"Y_Train Shape: {y_train.shape}",
    f"X_Test Shape: {x_test.shape}",
    f"Y_Test Shape: {y_test.shape}",
    sep="\n",
)


=== Data shape after reshaping based on [Batch, Channel, D1, D2, ...] ===

X_Train Shape: torch.Size([495000, 1, 8])
Y_Train Shape: torch.Size([495000, 1, 8])
X_Test Shape: torch.Size([5000, 1, 8])
Y_Test Shape: torch.Size([5000, 1, 8])


In [10]:
# Wrap each split in a dict-style dataset (samples are {'x': ..., 'y': ...})
train_loader, test_loader = (
    DictDataset(x_train, y_train),
    DictDataset(x_test, y_test),
)

In [11]:
# Build the loaders straight from the tensors so this cell can be re-run on
# its own: wrapping the previous value of `train_loader` / `test_loader` would
# nest a DataLoader (or the dict below) inside a new DataLoader on a second run.
train_loader = DataLoader(
    DictDataset(x_train, y_train),
    batch_size=batch_size,
    shuffle=True,
)
# Test loaders are passed as a dict keyed by resolution. Evaluation data is
# not shuffled so the batch order is the same on every run.
test_loader = {
    dataset_resolution: DataLoader(
        DictDataset(x_test, y_test),
        batch_size=batch_size,
        shuffle=False,
    )
}

In [12]:
# Pull a single batch from the training loader and show its structure
print("\n=== One batch of the Train Loader ===\n")
batch = next(iter(train_loader))
print(
    f"Loader Type: {type(train_loader)}",
    f"Batch Type: {type(batch)}",
    f"Batch['x'].shape: {batch['x'].shape}",
    f"Batch['y'].shape: {batch['y'].shape}",
    sep="\n",
)


=== One batch of the Train Loader ===

Loader Type: <class 'torch.utils.data.dataloader.DataLoader'>
Batch Type: <class 'dict'>
Batch['x'].shape: torch.Size([20, 1, 8])
Batch['y'].shape: torch.Size([20, 1, 8])


In [13]:
# Pull a single batch from the test loader registered for `dataset_resolution`
print("\n=== One batch of the Test Loader ===\n")
batch = next(iter(test_loader[dataset_resolution]))
print(
    f"Loader Type: {type(test_loader[dataset_resolution])}",
    f"Batch Type: {type(batch)}",
    f"Batch['x'].shape: {batch['x'].shape}",
    f"Batch['y'].shape: {batch['y'].shape}",
    sep="\n",
)


=== One batch of the Test Loader ===

Loader Type: <class 'torch.utils.data.dataloader.DataLoader'>
Batch Type: <class 'dict'>
Batch['x'].shape: torch.Size([20, 1, 8])
Batch['y'].shape: torch.Size([20, 1, 8])


In [14]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"\n=== Device: {device} ===\n")


=== Device: cpu ===



In [15]:
# Build the operator from the settings cell and place it on the chosen device
model = XNO(
    # spectral modes
    n_modes=n_modes,
    max_n_modes=max_modes,
    # channel widths and depth
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=hidden_channels,
    n_layers=n_layers,
    # kernel type and its activations
    transformation=transformation,
    transformation_kwargs=kwargs,
    conv_non_linearity=conv_non_linearity,
    mlp_non_linearity=mlp_non_linearity,
).to(device)
n_params = count_model_params(model)

Dimentionality: 1D
Transformation: [ Fourier Neural Operator (FNO) Kernel ]
>>> Overview:
The FNO leverages Fourier Transform to map input data into the spectral domain, where
convolutional operations are performed by truncating high-frequency modes.

>>> Key Features:
- Effective for parameterized Partial Differential Equations (PDEs).
- Reduces computational complexity by retaining only significant modes.

>>> Reference:
Li, Z. et al. 'Fourier Neural Operator for Parametric Partial Differential Equations' (ICLR 2021).
Link: https://arxiv.org/pdf/2010.08895

>>> Normaliztion: None
>>> Activation Function: 



In [16]:
# Optimizer over all model parameters
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Step schedule: the learning rate is scaled by `gamma` every `step_size` epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [17]:
# Data processor used by the incremental trainer; no input/output normalizers
data_transform = IncrementalDataProcessor(
    in_normalizer=None,
    out_normalizer=None,
    subsampling_rates=[2, 1],
    dataset_resolution=dataset_resolution,
    dataset_indices=dataset_indices,
    device=device,
    verbose=True,
).to(device)

Original Incre Res: change index to 0
Original Incre Res: change sub to 2
Original Incre Res: change res to 4


In [18]:
# The samples are 1-D fields shaped [Batch, Channel, D1], so both losses are
# built for one spatial dimension (d=1) rather than the d=2 of a 2-D problem.
# NOTE(review): assumes `d` is the number of spatial dimensions, as in the
# neuraloperator losses these classes appear to mirror; confirm against xno.
l2loss = LpLoss(d=1, p=2)
h1loss = H1Loss(d=1)
train_loss = h1loss
eval_losses = {"h1": h1loss, "l2": l2loss}

# Summarize the training setup before the run starts
print("\n### N PARAMS ###\n", n_params)
print("\n### OPTIMIZER ###\n", optimizer)
print("\n### SCHEDULER ###\n", scheduler)
print("\n### LOSSES ###")
print("\n### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###")
print(f"\n * Train: {train_loss}")
print(f"\n * Test: {eval_losses}")
sys.stdout.flush()


### N PARAMS ###
 45857

### OPTIMIZER ###
 AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.001
    lr: 0.001
    weight_decay: 0.0001
)

### SCHEDULER ###
 <torch.optim.lr_scheduler.StepLR object at 0x7f19c9800fe0>

### LOSSES ###

### INCREMENTAL RESOLUTION + GRADIENT EXPLAINED ###

 * Train: <xno.losses.data_losses.H1Loss object at 0x7f19c95aa690>

 * Test: {'h1': <xno.losses.data_losses.H1Loss object at 0x7f19c95aa690>, 'l2': <xno.losses.data_losses.LpLoss object at 0x7f19d2aafda0>}


In [19]:
# Assemble the trainer from the model, data processor and run settings
trainer = IncrementalXNOTrainer(
    # core run settings
    model=model,
    data_processor=data_transform,
    n_epochs=n_epochs,
    device=device,
    verbose=True,
    # incremental schedule: gradient-based mode on, loss-gap mode off
    incremental_grad=True,
    incremental_loss_gap=False,
    incremental_grad_eps=0.9999,
    incremental_grad_max_iter=2,
    incremental_loss_eps=0.001,
    incremental_max_iter=1,
    incremental_buffer=5,
)

In [None]:
# Run training; checkpoint saving stays disabled (options kept below for reference)
mess = trainer.train(
    train_loader, test_loader, optimizer, scheduler,
    training_loss=train_loss,
    eval_losses=eval_losses,
    regularizer=False,
    # save_every=save_every,
    # save_testing=save_testing,
    # save_dir=save_dir
)

# Show whatever the trainer returned
print(mess)

Training on 495000 samples
Testing on [5000] samples         on resolutions [8].
Raw outputs of shape torch.Size([20, 1, 4])


In [None]:

# At the end of the script
# Only undo the stdout redirection if it was actually set up: `output_file`
# exists only when the (currently commented-out) redirect lines were run.
# Without this guard the cell raises NameError on a fresh "Run All".
# NOTE(review): inside a notebook kernel, sys.__stdout__ is the process-level
# stream, not the notebook's own output stream; confirm this is the intended
# restore target if the redirect is re-enabled.
if "output_file" in globals():
    sys.stdout = sys.__stdout__  # Restore original stdout
    output_file.close()  # Close the file