forked from neuraloperator/neuraloperator
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request neuraloperator#150 from JeanKossaifi/sfno
Adds SFNO + example
- Loading branch information
Showing
9 changed files
with
660 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
""" | ||
Training a SFNO on the spherical Shallow Water equations | ||
======================================================== | ||
In this example, we demonstrate how to use the small Spherical Shallow Water Equations example we ship with the package | ||
to train a Spherical Fourier-Neural Operator | ||
""" | ||
|
||
# %% | ||
# | ||
|
||
|
||
import torch
import matplotlib.pyplot as plt
import sys
from neuralop.models import SFNO
from neuralop import Trainer
from neuralop.datasets import load_spherical_swe
from neuralop.utils import count_params
from neuralop import LpLoss, H1Loss

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# %%
# Loading the spherical Shallow Water equations dataset.
# Training uses the loader's default resolution; we evaluate at two
# resolutions, (128, 256) and (256, 512).
train_loader, test_loaders = load_spherical_swe(
    n_train=128,
    batch_size=4,
    test_resolutions=[(128, 256), (256, 512)],
    n_tests=[10, 10],
    test_batch_sizes=[4, 4],
)


# %%
# We create a SFNO model

model = SFNO(n_modes=(64, 128), in_channels=3, out_channels=3, hidden_channels=32,
             projection_channels=64, factorization='dense')
model = model.to(device)

n_params = count_params(model)
print(f'\nOur model has {n_params} parameters.')
sys.stdout.flush()


# %%
# Create the optimizer
optimizer = torch.optim.Adam(model.parameters(),
                             lr=8e-4,
                             weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)


# %%
# Creating the losses
l2loss = LpLoss(d=2, p=2, reduce_dims=(0, 1))
h1loss = H1Loss(d=2, reduce_dims=(0, 1))

train_loss = h1loss
eval_losses = {'h1': h1loss, 'l2': l2loss}


# %%


print('\n### MODEL ###\n', model)
print('\n### OPTIMIZER ###\n', optimizer)
print('\n### SCHEDULER ###\n', scheduler)
print('\n### LOSSES ###')
print(f'\n * Train: {train_loss}')
print(f'\n * Test: {eval_losses}')
sys.stdout.flush()


# %%
# Create the trainer
trainer = Trainer(model, n_epochs=20,
                  device=device,
                  mg_patching_levels=0,
                  wandb_log=False,
                  log_test_interval=3,
                  use_distributed=False,
                  verbose=True)


# %%
# Actually train the model on our small spherical Shallow Water dataset

trainer.train(train_loader, test_loaders,
              None,
              model,
              optimizer,
              scheduler,
              regularizer=False,
              training_loss=train_loss,
              eval_losses=eval_losses)


# %%
# Plot the prediction, and compare with the ground-truth
# Note that we trained on a very small resolution for
# a very small number of epochs
# In practice, we would train at larger resolution, on many more samples.
#
# However, for practicality, we created a minimal example that
# i) fits in just a few Mb of memory
# ii) can be trained quickly on CPU
#
# In practice we would train a Neural Operator on one or multiple GPUs

# The test loaders are keyed by resolution tuple (not by a single int), so we
# pick one of the resolutions requested above.
plot_resolution = (256, 512)
test_samples = test_loaders[plot_resolution].dataset

fig = plt.figure(figsize=(7, 7))
for index in range(3):
    data = test_samples[index]
    # Input x
    x = data['x']
    # Ground-truth
    y = data['y']
    # Model prediction: the model lives on `device`, so move the input there
    # and bring the result back to the CPU for plotting.
    with torch.no_grad():
        out = model(x.unsqueeze(0).to(device))
    out = out.squeeze(0).cpu()

    # Samples have 3 channels, which imshow cannot display in channel-first
    # layout, so only the first channel of each field is shown.
    ax = fig.add_subplot(3, 3, index*3 + 1)
    ax.imshow(x[0].cpu().numpy(), cmap='gray')
    if index == 0:
        ax.set_title('Input x')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 2)
    ax.imshow(y[0].cpu().numpy())
    if index == 0:
        ax.set_title('Ground-truth y')
    plt.xticks([], [])
    plt.yticks([], [])

    ax = fig.add_subplot(3, 3, index*3 + 3)
    ax.imshow(out[0].numpy())
    if index == 0:
        ax.set_title('Model prediction')
    plt.xticks([], [])
    plt.yticks([], [])

fig.suptitle('Inputs, ground-truth output and prediction.', y=0.98)
plt.tight_layout()
fig.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
# from .burgers import load_burgers | ||
from .darcy import load_darcy_pt, load_darcy_flow_small | ||
from .spherical_swe import load_spherical_swe | ||
from .navier_stokes import load_navier_stokes_pt | ||
#, load_navier_stokes_zarr | ||
# from .navier_stokes import load_navier_stokes_hdf5 | ||
# from .burgers import load_burgers | ||
from .darcy import load_darcy_pt, load_darcy_flow_small | ||
|
||
# from .positional_encoding import append_2d_grid_positional_encoding, get_grid_positional_encoding | ||
from .pt_dataset import load_pt_traintestsplit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from math import ceil, floor | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
from torch_harmonics.examples import ShallowWaterSolver | ||
|
||
def load_spherical_swe(n_train, n_tests, batch_size, test_batch_sizes,
                       train_resolution=(256, 512), test_resolutions=((256, 512),),
                       device=torch.device('cpu')):
    """Load the Spherical Shallow Water equations Dataloader

    Parameters
    ----------
    n_train : int
        number of training samples
    n_tests : sequence of int
        number of test samples, one entry per test resolution
    batch_size : int
        batch size of the training loader
    test_batch_sizes : sequence of int
        batch size of each test loader, one entry per test resolution
    train_resolution : tuple, default is (256, 512)
        ``dims`` passed to the training dataset
    test_resolutions : sequence of tuple, default is ((256, 512),)
        ``dims`` of each test dataset; iterated in lockstep with
        ``n_tests`` and ``test_batch_sizes``
    device : torch.device, default is cpu
        device passed on to the datasets

    Returns
    -------
    train_loader : DataLoader
    test_loaders : dict
        maps each test resolution to its DataLoader
    """
    print(f'Loading train dataloader at resolution {train_resolution} with {n_train} samples and batch-size={batch_size}')
    train_dataset = SphericalSWEDataset(dims=train_resolution, num_examples=n_train, device=device)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, persistent_workers=False)

    test_loaders = dict()
    for (res, n_test, test_batch_size) in zip(test_resolutions, n_tests, test_batch_sizes):
        print(f'Loading test dataloader at resolution {res} with {n_test} samples and batch-size={test_batch_size}')

        test_dataset = SphericalSWEDataset(dims=res, num_examples=n_test, device=device)
        # Evaluation data does not need to be shuffled.
        test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0, persistent_workers=False)
        test_loaders[res] = test_loader

    # Return the whole mapping, not just the last loader built in the loop.
    return train_loader, test_loaders
|
||
|
||
class SphericalSWEDataset(torch.utils.data.Dataset):
    """Custom Dataset class for PDE training data

    Samples are produced on the fly by a ``ShallowWaterSolver``: every call to
    ``__getitem__`` draws an initial condition, advances it by ``nsteps``
    solver steps and returns the (input, target) pair on the grid. The index
    passed to ``__getitem__`` is not used.

    Parameters
    ----------
    dt : number, default is 3600
        total time between input and target; it is split into ``nsteps``
        solver steps of size ``dt / nsteps``
    dims : tuple, default is (256, 512)
        ``(nlat, nlon)`` grid size
    initial_condition : {'random', 'galewsky'}, default is 'random'
        which solver initial condition to use
    num_examples : int, default is 32
        reported length of the dataset when ``initial_condition`` is 'random'
    device : torch.device, default is cpu
        device the solver is moved to
    normalize : bool, default is True
        if True, inputs and targets are standardized with the per-channel mean
        and variance of one sample drawn at construction time
    stream : optional
        stored on the instance; not used by this class itself
    """

    def __init__(self, dt=3600, dims=(256, 512), initial_condition='random', num_examples=32,
                 device=torch.device('cpu'), normalize=True, stream=None):
        super().__init__()

        # Caution: this is a heuristic which can break and lead to diverging results
        dt_min = 256 / dims[0] * 150
        # Take at least one solver step: with dt < dt_min the floor would be 0
        # and the step size below would divide by zero.
        nsteps = max(1, int(floor(dt / dt_min)))

        self.num_examples = num_examples
        self.device = device
        self.stream = stream

        self.nlat = dims[0]
        self.nlon = dims[1]

        # number of solver steps used to compute the target
        self.nsteps = nsteps
        self.normalize = normalize

        lmax = ceil(self.nlat/3)
        mmax = lmax
        dt_solver = dt / float(self.nsteps)
        self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float()

        self.set_initial_condition(ictype=initial_condition)

        if self.normalize:
            # Normalization statistics come from a single sample, reduced over
            # the last two dimensions and reshaped to broadcast against it.
            inp0, _ = self._get_sample()
            self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
            self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)

    def __len__(self):
        """Return ``num_examples`` for 'random' initial conditions, else 1."""
        length = self.num_examples if self.ictype == 'random' else 1
        return length

    def set_initial_condition(self, ictype='random'):
        """Select the initial condition type used by ``_get_sample``."""
        self.ictype = ictype

    def set_num_examples(self, num_examples=32):
        """Set the dataset length reported for 'random' initial conditions."""
        self.num_examples = num_examples

    def _get_sample(self):
        """Generate one (input, target) pair on the grid.

        Raises
        ------
        ValueError
            if the configured initial condition type is not supported
        """
        if self.ictype == 'random':
            inp = self.solver.random_initial_condition(mach=0.2)
        elif self.ictype == 'galewsky':
            inp = self.solver.galewsky_initial_condition()
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                f"Unknown initial condition {self.ictype!r}, expected 'random' or 'galewsky'"
            )

        # solve pde for n steps to return the target
        tar = self.solver.timestep(inp, self.nsteps)
        inp = self.solver.spec2grid(inp)
        tar = self.solver.spec2grid(tar)

        return inp, tar

    def __getitem__(self, index):
        """Return a freshly generated sample as ``{'x': input, 'y': target}``."""
        # inference_mode already disables gradient tracking, so no extra
        # no_grad context is needed.
        with torch.inference_mode():
            inp, tar = self._get_sample()

            if self.normalize:
                # Both input and target are scaled with the input statistics.
                inp_std = torch.sqrt(self.inp_var)
                inp = (inp - self.inp_mean) / inp_std
                tar = (tar - self.inp_mean) / inp_std

        # Clone outside of inference mode, as the original code did.
        return {'x': inp.clone(), 'y': tar.clone()}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .tfno import TFNO, TFNO1d, TFNO2d, TFNO3d | ||
from .tfno import FNO, FNO1d, FNO2d, FNO3d | ||
from .tfno import SFNO | ||
from .uno import UNO | ||
from .model_dispatcher import get_model |
Oops, something went wrong.