Skip to content

Commit

Permalink
Merge pull request neuraloperator#150 from JeanKossaifi/sfno
Browse files Browse the repository at this point in the history
Adds SFNO + example
  • Loading branch information
JeanKossaifi committed Jun 9, 2023
2 parents 835707c + 7bc044f commit 22f60bd
Show file tree
Hide file tree
Showing 9 changed files with 660 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r requirements.txt
- name: Install package
run: |
python -m pip install -e .
Expand Down
141 changes: 141 additions & 0 deletions examples/plot_SFNO_swe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
Training a SFNO on the spherical Shallow Water equations
=============================
In this example, we demonstrate how to use the small Spherical Shallow Water Equations example we ship with the package
to train a Spherical Fourier-Neural Operator
"""

# %%
#


import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import SFNO
from neuralop import Trainer
from neuralop.datasets import load_spherical_swe
from neuralop.utils import count_params
from neuralop import LpLoss, H1Loss

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# %%
# Loading the Navier-Stokes dataset in 128x128 resolution
train_loader, test_loaders = load_spherical_swe(n_train=128, batch_size=4, test_resolutions=[(128, 256), (256, 512)], n_tests=[10, 10], test_batch_sizes=[4, 4],)


# %%
# We create a tensorized FNO model

model = SFNO(n_modes=(64, 128), in_channels=3, out_channels=3, hidden_channels=32, projection_channels=64, factorization='dense')
model = model.to(device)

n_params = count_params(model)
print(f'\nOur model has {n_params} parameters.')
sys.stdout.flush()


# %%
#Create the optimizer
optimizer = torch.optim.Adam(model.parameters(),
lr=8e-4,
weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)


# %%
# Creating the losses
l2loss = LpLoss(d=2, p=2, reduce_dims=(0,1))
h1loss = H1Loss(d=2, reduce_dims=(0,1))

train_loss = h1loss
eval_losses={'h1': h1loss, 'l2': l2loss}


# %%


print('\n### MODEL ###\n', model)
print('\n### OPTIMIZER ###\n', optimizer)
print('\n### SCHEDULER ###\n', scheduler)
print('\n### LOSSES ###')
print(f'\n * Train: {train_loss}')
print(f'\n * Test: {eval_losses}')
sys.stdout.flush()


# %%
# Create the trainer
trainer = Trainer(model, n_epochs=20,
device=device,
mg_patching_levels=0,
wandb_log=False,
log_test_interval=3,
use_distributed=False,
verbose=True)


# %%
# Actually train the model on our small Darcy-Flow dataset

trainer.train(train_loader, test_loaders,
None,
model,
optimizer,
scheduler,
regularizer=False,
training_loss=train_loss,
eval_losses=eval_losses)


# %%
# Plot the prediction, and compare with the ground-truth
# Note that we trained on a very small resolution for
# a very small number of epochs
# In practice, we would train at larger resolution, on many more samples.
#
# However, for practicity, we created a minimal example that
# i) fits in just a few Mb of memory
# ii) can be trained quickly on CPU
#
# In practice we would train a Neural Operator on one or multiple GPUs

test_samples = test_loaders[32].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
data = test_samples[index]
# Input x
x = data['x']
# Ground-truth
y = data['y']
# Model prediction
out = model(x.unsqueeze(0))

ax = fig.add_subplot(3, 3, index*3 + 1)
ax.imshow(x[0], cmap='gray')
if index == 0:
ax.set_title('Input x')
plt.xticks([], [])
plt.yticks([], [])

ax = fig.add_subplot(3, 3, index*3 + 2)
ax.imshow(y.squeeze())
if index == 0:
ax.set_title('Ground-truth y')
plt.xticks([], [])
plt.yticks([], [])

ax = fig.add_subplot(3, 3, index*3 + 3)
ax.imshow(out.squeeze().detach().numpy())
if index == 0:
ax.set_title('Model prediction')
plt.xticks([], [])
plt.yticks([], [])

fig.suptitle('Inputs, ground-truth output and prediction.', y=0.98)
plt.tight_layout()
fig.show()
6 changes: 4 additions & 2 deletions neuralop/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# from .burgers import load_burgers
from .darcy import load_darcy_pt, load_darcy_flow_small
from .spherical_swe import load_spherical_swe
from .navier_stokes import load_navier_stokes_pt
#, load_navier_stokes_zarr
# from .navier_stokes import load_navier_stokes_hdf5
# from .burgers import load_burgers
from .darcy import load_darcy_pt, load_darcy_flow_small

# from .positional_encoding import append_2d_grid_positional_encoding, get_grid_positional_encoding
from .pt_dataset import load_pt_traintestsplit
92 changes: 92 additions & 0 deletions neuralop/datasets/spherical_swe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from math import ceil, floor

import torch
from torch.utils.data import DataLoader
from torch_harmonics.examples import ShallowWaterSolver

def load_spherical_swe(n_train, n_tests, batch_size, test_batch_sizes,
train_resolution=(256, 512), test_resolutions=[(256, 512)],
device=torch.device('cpu')):
"""Load the Spherical Shallow Water equations Dataloader"""

print(f'Loading train dataloader at resolution {train_resolution} with {n_train} samples and batch-size={batch_size}')
train_dataset = SphericalSWEDataset(dims=train_resolution, num_examples=n_train, device=device)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, persistent_workers=False)

test_loaders = dict()
for (res, n_test, test_batch_size) in zip(test_resolutions, n_tests, test_batch_sizes):
print(f'Loading test dataloader at resolution {res} with {n_test} samples and batch-size={test_batch_size}')

test_dataset = SphericalSWEDataset(dims=res, num_examples=n_test, device=device)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=0, persistent_workers=False)
test_loaders[res] = test_loader

return train_loader, test_loader


class SphericalSWEDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
def __init__(self, dt=3600, dims=(256, 512), initial_condition='random', num_examples=32,
device=torch.device('cpu'), normalize=True, stream=None):
# Caution: this is a heuristic which can break and lead to diverging results
dt_min = 256 / dims[0] * 150
nsteps = int(floor(dt / dt_min))

self.num_examples = num_examples
self.device = device
self.stream = stream

self.nlat = dims[0]
self.nlon = dims[1]

# number of solver steps used to compute the target
self.nsteps = nsteps
self.normalize = normalize

lmax = ceil(self.nlat/3)
mmax = lmax
dt_solver = dt / float(self.nsteps)
self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float()

self.set_initial_condition(ictype=initial_condition)

if self.normalize:
inp0, _ = self._get_sample()
self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)

def __len__(self):
length = self.num_examples if self.ictype == 'random' else 1
return length

def set_initial_condition(self, ictype='random'):
self.ictype = ictype

def set_num_examples(self, num_examples=32):
self.num_examples = num_examples

def _get_sample(self):

if self.ictype == 'random':
inp = self.solver.random_initial_condition(mach=0.2)
elif self.ictype == 'galewsky':
inp = self.solver.galewsky_initial_condition()

# solve pde for n steps to return the target
tar = self.solver.timestep(inp, self.nsteps)
inp = self.solver.spec2grid(inp)
tar = self.solver.spec2grid(tar)

return inp, tar

def __getitem__(self, index):

with torch.inference_mode():
with torch.no_grad():
inp, tar = self._get_sample()

if self.normalize:
inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)

return {'x': inp.clone(), 'y': tar.clone()}
1 change: 1 addition & 0 deletions neuralop/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .tfno import TFNO, TFNO1d, TFNO2d, TFNO3d
from .tfno import FNO, FNO1d, FNO2d, FNO3d
from .tfno import SFNO
from .uno import UNO
from .model_dispatcher import get_model
Loading

0 comments on commit 22f60bd

Please sign in to comment.