# Training

In [1]:
import torch
import scipy.io as sio
import glob
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.nn.utils import clip_grad_norm_
from neuralop.models import FNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop import LpLoss, H1Loss
from neuralop.utils import count_model_params
from neuralop.data.transforms.normalizers import UnitGaussianNormalizer
from abc import abstractmethod
import time

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [2]:
# --- 1. Find and Load All Data Files ---
# Each matching file is treated as one run; paths are sorted so the
# train/test split below is deterministic for a given directory listing.
data_path = './FNO_Dataset_PT/'
file_paths = sorted(glob.glob(f"{data_path}/FNO_dataset_run_*.pt"))

if not file_paths:
    # The glob above matches .pt files, so report that extension.
    raise FileNotFoundError(f"No .pt files found in {data_path}")

print(f"Found {len(file_paths)} data files.")

# --- 2. Split File Paths into Train and Test ---
# The split is done per file, so no single file contributes to both sets.
# NOTE(review): the first 20% of the files go to training and the remaining
# 80% to testing. Confirm this is intended and not an inverted 80/20 split.
train_split = int(0.2 * len(file_paths))
train_paths = file_paths[:train_split]
test_paths = file_paths[train_split:]

if not train_paths:
    # With fewer than 5 files int(0.2 * n) is 0, which would leave nothing
    # to train on and fail later with a far less obvious error.
    raise ValueError(
        f"Only {len(file_paths)} data file(s) found in {data_path}; "
        "the 20% training split is empty."
    )

# --- 3. Load ALL Data into RAM for Normalization ---
# The normalizer is fitted on the full training tensor, so all
# training data has to be loaded into memory first.

def load_data_from_paths(paths):
    """Load every ``.pt`` file in ``paths`` and concatenate them along dim 0.

    Each file is expected to contain a single tensor (``.float()`` is called
    on the loaded object). Files that fail to load are reported with a
    warning and skipped, so one bad file does not abort the whole load.

    Args:
        paths: Iterable of file paths to load, in the order they should
            appear in the result.

    Returns:
        A float32 CPU tensor holding all successfully loaded tensors,
        concatenated along dim 0.

    Raises:
        RuntimeError: If no file could be loaded (including when ``paths``
            is empty), since there is nothing to concatenate.
    """
    all_tensors = []
    for p in paths:
        try:
            # NOTE: torch.load unpickles the file; only use trusted data files.
            # map_location keeps the data in host RAM even if the tensor was
            # saved from a GPU.
            tensor_data = torch.load(p, map_location='cpu').float()
            all_tensors.append(tensor_data)
        except Exception as e:
            # Best-effort: skip unreadable files but make the failure visible.
            print(f"Warning: Error loading {p}: {e}")

    if not all_tensors:
        # torch.cat on an empty list fails with an opaque message; be explicit.
        raise RuntimeError("No tensors could be loaded from the given paths.")

    # Concatenate all runs along the time dimension (dim=0)
    return torch.cat(all_tensors, dim=0)

# Read the training runs, then the test runs, each into a single tensor.
print("Loading training data into memory...")
train_data_sequence = load_data_from_paths(train_paths)
print(f"Full training sequence shape: {train_data_sequence.shape}")

print("Loading test data into memory...")
test_data_sequence = load_data_from_paths(test_paths)
print(f"Full test sequence shape: {test_data_sequence.shape}")

# --- 4. Fit and Transform ---
# Statistics are reduced over dims 0, 2, 3 and 4, i.e. every dim except dim 1.
normalizer = UnitGaussianNormalizer(dim=[0, 2, 3, 4])

# The statistics come from the training tensor only; the test tensor is
# transformed with those same statistics and never takes part in the fit.
print("Fitting normalizer on training data...")
normalizer.fit(train_data_sequence)
print("Fit complete.")

# Apply the fitted normalizer to both splits.
print("Normalizing data...")
train_data = normalizer.transform(train_data_sequence)
test_data = normalizer.transform(test_data_sequence)

# Sanity check: print the overall mean and std of the normalized training data.
print(f"Normalized train data mean: {train_data.mean()}")
print(f"Normalized train data std: {train_data.std()}")

# Drop the names of the raw tensors now that the normalized copies exist.
del train_data_sequence, test_data_sequence
print("Normalization complete. Raw data cleared from RAM.")

Found 1000 data files.
Loading training data into memory...
Full training sequence shape: torch.Size([20200, 3, 32, 32, 32])
Loading test data into memory...
Full test sequence shape: torch.Size([80800, 3, 32, 32, 32])
Fitting normalizer on training data...
Fit complete.
Normalizing data...
Normalized train data mean: -2.0463277738969055e-09
Normalized train data std: 0.9999977946281433
Normalization complete. Raw data cleared from RAM.


In [3]:
# --- 5. Define Simple Dataset Class ---
class TimeSteppingDataset(Dataset):
    """
    Dataset of consecutive-step pairs ``{'x': frame[t], 'y': frame[t + 1]}``
    taken along dim 0 of a pre-normalized data sequence.

    Args:
        data_sequence: Tensor whose dim 0 indexes time steps.
        run_length: Optional number of consecutive time steps that belong to
            one run. When given, ``data_sequence`` is treated as several
            equal-length runs concatenated along dim 0, and pairs that would
            straddle two runs (last step of one run, first step of the next)
            are skipped. When ``None`` (the default) the whole sequence is
            treated as a single run, which is the original behaviour.

    Raises:
        ValueError: If ``run_length`` is smaller than 2 or does not evenly
            divide the number of time steps.
    """
    def __init__(self, data_sequence, run_length=None):
        if run_length is not None:
            if run_length < 2:
                raise ValueError("run_length must be at least 2 to form a pair")
            if data_sequence.shape[0] % run_length != 0:
                raise ValueError(
                    f"Sequence length {data_sequence.shape[0]} is not a "
                    f"multiple of run_length {run_length}"
                )
        self.data = data_sequence
        self.run_length = run_length

    def __len__(self):
        n_steps = self.data.shape[0]
        if self.run_length is None:
            # One pair per step except the last; never negative for empty data.
            return max(n_steps - 1, 0)
        # Each run of run_length steps yields run_length - 1 pairs.
        return (n_steps // self.run_length) * (self.run_length - 1)

    def __getitem__(self, idx):
        if self.run_length is not None:
            # Map the pair index to a step index, skipping each run's last
            # step so 'y' always comes from the same run as 'x'.
            run, offset = divmod(idx, self.run_length - 1)
            idx = run * self.run_length + offset
        return {'x': self.data[idx], 'y': self.data[idx + 1]}

In [4]:
# --- 6. Create Datasets and DataLoaders ---
# Wrap the normalized tensors so each item is an {'x', 'y'} pair.
# NOTE(review): train_data and test_data each hold several runs concatenated
# along dim 0, so a pair taken at a run boundary mixes two different runs.
# Verify that this is acceptable or handled by the dataset.
train_dataset = TimeSteppingDataset(train_data)
test_dataset = TimeSteppingDataset(test_data)

# Both loaders use the same batch size; lower it if memory becomes a problem.
# Only the training loader shuffles its samples.
BATCH_SIZE = 100
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
# --- 7. Define Model, Optimizer, Loss ---
# Hyperparameters gathered in one place so they are easy to tune.
N_EPOCHS = 1000
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
SCHEDULER_T_MAX = 50

# 3-D FNO with 3 input and 3 output channels, moved to the selected device.
model = FNO(
    n_modes=(16, 16, 16),
    hidden_channels=8,
    in_channels=3,
    out_channels=3,
    n_layers=2,
)
model = model.to(device)

print(f"Model has {count_model_params(model)} parameters.")

optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# NOTE(review): T_max (50) is much smaller than n_epochs (1000), so the cosine
# schedule will turn back upward after epoch 50. Confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=SCHEDULER_T_MAX
)
l2loss = LpLoss(d=3, p=2)
h1loss = H1Loss(d=3)

# --- 8. Create Trainer (No Processor) ---
# Evaluates after every epoch; no wandb logging and no distributed training.
trainer = Trainer(
    model=model,
    n_epochs=N_EPOCHS,
    device=device,
    wandb_log=False,
    eval_interval=1,
    use_distributed=False,
    verbose=True,
)

Model has 590579 parameters.


In [6]:
# --- 9. Start Training ---
print("Starting training on full, normalized dataset...")

# The test loader is keyed by the size of dim 1 of a single test sample.
test_key = test_data[0].shape[1]

# Train on the H1 loss, evaluate on both H1 and L2, and write the training
# state to ./checkpoints/ with save_every=50.
trainer.train(
    train_loader=train_loader,
    test_loaders={test_key: test_loader},
    optimizer=optimizer,
    scheduler=scheduler,
    training_loss=h1loss,
    eval_losses={'h1': h1loss, 'l2': l2loss},
    save_every=50,
    save_dir='./checkpoints/',
)

Starting training on full, normalized dataset...
Training on 20199 samples
Testing on [80799] samples         on resolutions [32].


  return forward_call(*args, **kwargs)


Raw outputs of shape torch.Size([100, 3, 32, 32, 32])


  loss += training_loss(out, **sample)


[0] time=32.95, avg_loss=847389368.4559, train_err=84734741848.7129


  val_loss = loss(out, **sample)


Eval: 32_h1=301908352.0000, 32_l2=224846240.0000
[Rank 0]: saved training state to ./checkpoints/
[1] time=30.92, avg_loss=101753114.4251, train_err=10174807714.2178
Eval: 32_h1=44625072.0000, 32_l2=4055993.7500
[2] time=32.18, avg_loss=31942527.1160, train_err=3194094580.2772
Eval: 32_h1=22646602.0000, 32_l2=2233443.7500
[3] time=32.09, avg_loss=17798774.0256, train_err=1779789289.8218
Eval: 32_h1=14318814.0000, 32_l2=1452944.5000
[4] time=31.30, avg_loss=11858983.9790, train_err=1185839690.0594
Eval: 32_h1=10272504.0000, 32_l2=1047596.0000
[5] time=31.33, avg_loss=8838845.5048, train_err=883840793.8218
Eval: 32_h1=8060312.0000, 32_l2=813801.0000
[6] time=32.46, avg_loss=7091570.3375, train_err=709121926.9703
Eval: 32_h1=6752081.5000, 32_l2=939951.3750
[7] time=32.99, avg_loss=5985795.7812, train_err=598549945.4653
Eval: 32_h1=5791120.0000, 32_l2=729962.0625
[8] time=32.01, avg_loss=5240282.2409, train_err=524002282.0990
Eval: 32_h1=5051724.0000, 32_l2=530373.5000
[9] time=31.22, avg_

KeyboardInterrupt: 

# Testing