Merge pull request neuraloperator#147 from JeanKossaifi/superres
Adds super-resolution to FNO
JeanKossaifi committed Jun 1, 2023
2 parents b98e89f + a2edc7f commit ad06a4d
Showing 4 changed files with 58 additions and 9 deletions.
10 changes: 7 additions & 3 deletions neuralop/models/fno_block.py
@@ -34,7 +34,9 @@ def __init__(self, in_channels, out_channels, n_modes,

        if output_scaling_factor is not None:
            if isinstance(output_scaling_factor, (float, int)):
-               output_scaling_factor = [float(output_scaling_factor)]*len(self.n_modes)
+               output_scaling_factor = [[float(output_scaling_factor)]*len(self.n_modes)]*n_layers
+           elif isinstance(output_scaling_factor[0], (float, int)):
+               output_scaling_factor = [[s]*len(self.n_modes) for s in output_scaling_factor]
        self.output_scaling_factor = output_scaling_factor

        self._incremental_n_modes = incremental_n_modes
@@ -107,12 +109,14 @@ def forward(self, x, index=0):

        x_skip_fno = self.fno_skips[index](x)
        if self.convs.output_scaling_factor is not None:
-           x_skip_fno = resample(x_skip_fno, self.convs.output_scaling_factor, list(range(-len(self.convs.output_scaling_factor), 0)))
+           # x_skip_fno = resample(x_skip_fno, self.convs.output_scaling_factor[index], list(range(-len(self.convs.output_scaling_factor[index]), 0)))
+           x_skip_fno = resample(x_skip_fno, self.output_scaling_factor[index], list(range(-len(self.output_scaling_factor[index]), 0)))

        if self.mlp is not None:
            x_skip_mlp = self.mlp_skips[index](x)
            if self.convs.output_scaling_factor is not None:
-               x_skip_mlp = resample(x_skip_mlp, self.convs.output_scaling_factor, list(range(-len(self.convs.output_scaling_factor), 0)))
+               # x_skip_mlp = resample(x_skip_mlp, self.convs.output_scaling_factor[index], list(range(-len(self.convs.output_scaling_factor[index]), 0)))
+               x_skip_mlp = resample(x_skip_mlp, self.output_scaling_factor[index], list(range(-len(self.output_scaling_factor[index]), 0)))

        x_fno = self.convs(x, index)

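The hunks above broadcast output_scaling_factor into one list of per-dimension factors per layer. A minimal standalone sketch of that broadcasting rule (the helper name is illustrative, not part of the library):

# A sketch of the rule introduced in this commit: a scalar is repeated for
# every layer and every spatial dimension; a per-layer list of scalars is
# repeated across spatial dimensions only.
def normalize_scaling_factor(output_scaling_factor, n_dim, n_layers):
    if output_scaling_factor is None:
        return None
    if isinstance(output_scaling_factor, (float, int)):
        return [[float(output_scaling_factor)] * n_dim] * n_layers
    if isinstance(output_scaling_factor[0], (float, int)):
        return [[s] * n_dim for s in output_scaling_factor]
    return output_scaling_factor  # already per-layer, per-dimension

assert normalize_scaling_factor(2, n_dim=2, n_layers=3) == [[2.0, 2.0]] * 3
assert normalize_scaling_factor([1, 2, 1], n_dim=2, n_layers=3) == [[1, 1], [2, 2], [1, 1]]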
8 changes: 5 additions & 3 deletions neuralop/models/spectral_convolution.py
@@ -202,7 +202,7 @@ def __init__(self, in_channels, out_channels, n_modes, incremental_n_modes=None,
        half_total_n_modes = [m//2 for m in n_modes]
        self.half_total_n_modes = half_total_n_modes

-       # WE use half_total_n_modes to build the full weights
+       # We use half_total_n_modes to build the full weights
        # During training we can adjust incremental_n_modes which will also
        # update half_n_modes
        # So that we can train on a smaller part of the Fourier modes and total weights
@@ -215,7 +215,9 @@ def __init__(self, in_channels, out_channels, n_modes, incremental_n_modes=None,

        if output_scaling_factor is not None:
            if isinstance(output_scaling_factor, (float, int)):
-               output_scaling_factor = [float(output_scaling_factor)]*len(self.n_modes)
+               output_scaling_factor = [[float(output_scaling_factor)]*len(self.n_modes)]*n_layers
+           elif isinstance(output_scaling_factor[0], (float, int)):
+               output_scaling_factor = [[s]*len(self.n_modes) for s in output_scaling_factor]
        self.output_scaling_factor = output_scaling_factor

        if init_std == 'auto':
@@ -335,7 +337,7 @@ def forward(self, x, indices=0):
            out_fft[idx_tuple] = self._contract(x[idx_tuple], self._get_weight(self.n_weights_per_layer*indices + i), separable=self.separable)

        if self.output_scaling_factor is not None:
-           mode_sizes = tuple([int(round(s*r)) for (s, r) in zip(mode_sizes, self.output_scaling_factor)])
+           mode_sizes = tuple([int(round(s*r)) for (s, r) in zip(mode_sizes, self.output_scaling_factor[indices])])

        x = torch.fft.irfftn(out_fft, s=(mode_sizes), norm=self.fft_norm)

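Indexing self.output_scaling_factor[indices] selects the current layer's factor, and rescaling mode_sizes before torch.fft.irfftn is what changes the spatial resolution: the inverse FFT evaluates the same Fourier coefficients on a denser (or coarser) grid. A shape-only demonstration, with illustrative tensor sizes:

import torch

s_in, factor = 16, 2
x = torch.randn(1, 1, s_in, s_in)        # (batch, channels, height, width)
x_ft = torch.fft.rfftn(x, dim=(-2, -1))  # spectrum of the coarse signal
s_out = [int(round(s_in * factor))] * 2  # rescaled mode_sizes, as in the hunk
x_up = torch.fft.irfftn(x_ft, s=s_out, dim=(-2, -1))
print(x_up.shape)  # torch.Size([1, 1, 32, 32])
# Note: this only demonstrates how the s argument sets the output size; a
# faithful spectral upsampler would also keep negative-frequency modes in place.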
42 changes: 39 additions & 3 deletions neuralop/models/tests/test_tfno.py
@@ -1,8 +1,10 @@

import torch
from neuralop import TFNO3d, TFNO2d, TFNO1d, TFNO
+from neuralop.models import FNO
import pytest
from tensorly import tenalg
+from math import prod
from configmypy import Bunch
tenalg.set_backend('einsum')

@@ -35,8 +37,8 @@ def test_tfno(factorization, implementation, n_dim):

    rank = 0.2
    size = (s, )*n_dim
-   m_modes = (modes,)*n_dim
-   model = TFNO(hidden_channels=width, n_modes=m_modes,
+   n_modes = (modes,)*n_dim
+   model = TFNO(hidden_channels=width, n_modes=n_modes,
                 factorization=factorization,
                 implementation=implementation,
                 rank=rank,
@@ -62,4 +64,38 @@
        if param.grad is None:
            n_unused_params += 1
    assert n_unused_params == 0, f'{n_unused_params} parameters were unused!'
+
+
+@pytest.mark.parametrize('output_scaling_factor',
+                         [[2, 1, 1], [1, 2, 1], [1, 1, 2], [1, 2, 2], [1, 0.5, 1]])
+def test_fno_superresolution(output_scaling_factor):
+    device = 'cpu'
+    s = 16
+    modes = 5
+    hidden_channels = 15
+    fc_channels = 32
+    batch_size = 3
+    n_layers = 3
+    use_mlp = True
+    n_dim = 2
+    rank = 0.2
+    size = (s, )*n_dim
+    n_modes = (modes,)*n_dim
+
+    model = FNO(n_modes, hidden_channels,
+                in_channels=3,
+                out_channels=1,
+                factorization='cp',
+                implementation='reconstructed',
+                rank=rank,
+                output_scaling_factor=output_scaling_factor,
+                n_layers=n_layers,
+                use_mlp=use_mlp,
+                fc_channels=fc_channels).to(device)
+
+    in_data = torch.randn(batch_size, 3, *size).to(device)
+    # Test forward pass
+    out = model(in_data)
+
+    # Check output size
+    factor = prod(output_scaling_factor)
+    assert list(out.shape) == [batch_size, 1] + [int(round(factor*s)) for s in size]
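The assertion multiplies each input size by prod(output_scaling_factor) because the per-layer factors compose as data flows through the n_layers blocks, each block resampling the grid by its own factor. A quick check of that arithmetic:

from math import prod

s, output_scaling_factor = 16, [1, 2, 2]
expected = int(round(prod(output_scaling_factor) * s))  # 16 -> 16 -> 32 -> 64
assert expected == 64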
7 changes: 7 additions & 0 deletions neuralop/models/tfno.py
@@ -109,6 +109,7 @@ def __init__(self, n_modes, hidden_channels,
                 lifting_channels=256,
                 projection_channels=256,
                 n_layers=4,
+                output_scaling_factor=None,
                 incremental_n_modes=None,
                 use_mlp=False, mlp_dropout=0, mlp_expansion=0.5,
                 non_linearity=F.gelu,
@@ -159,10 +160,16 @@ def __init__(self, n_modes, hidden_channels,
        self.domain_padding = None
        self.domain_padding_mode = domain_padding_mode

+       if output_scaling_factor is not None and not joint_factorization:
+           if isinstance(output_scaling_factor, (float, int)):
+               output_scaling_factor = [output_scaling_factor]*self.n_layers
+       self.output_scaling_factor = output_scaling_factor
+
        self.fno_blocks = FNOBlocks(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            n_modes=self.n_modes,
+           output_scaling_factor=output_scaling_factor,
            use_mlp=use_mlp, mlp_dropout=mlp_dropout, mlp_expansion=mlp_expansion,
            non_linearity=non_linearity,
            norm=norm, preactivation=preactivation,
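A hedged usage sketch, mirroring the test above and assuming the package layout at this commit: per-layer factors whose product is 2 give an output grid twice as fine as the input in every spatial dimension.

import torch
from neuralop.models import FNO

model = FNO((16, 16), 32,
            in_channels=3, out_channels=1,
            n_layers=3,
            output_scaling_factor=[1, 1, 2])  # product of factors == 2
x = torch.randn(4, 3, 64, 64)
y = model(x)    # each block resamples by its own factor
print(y.shape)  # expected: torch.Size([4, 1, 128, 128])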
