Skip to content

Commit

Permalink
Merge pull request neuraloperator#185 from btolooshams/main
Browse files Browse the repository at this point in the history
adding flag option to only pad the last dim
  • Loading branch information
JeanKossaifi committed Aug 11, 2023
2 parents 89c13fe + be8fdd0 commit ddfb10d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 23 deletions.
44 changes: 33 additions & 11 deletions neuralop/layers/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ class DomainPadding(nn.Module):
Parameters
----------
domain_padding : float
domain_padding : float or list
typically, between zero and one, percentage of padding to use
if a list, make sure if matches the dim of (d1, ..., dN)
padding_mode : {'symmetric', 'one-sided'}, optional
whether to pad on both sides, by default 'one-sided'
Notes
-----
This class works for any input resolution, as long as it is in the form
Expand Down Expand Up @@ -39,8 +40,12 @@ def pad(self, x):
"""
resolution = x.shape[2:]

# if domain_padding is list, then to pass on
if isinstance(self.domain_padding, (float, int)):
self.domain_padding = [float(self.domain_padding)]*len(resolution)

assert len(self.domain_padding) == len(resolution), "domain_padding length must match the number of spatial/time dimensions (excluding batch, ch)"

if self.output_scaling_factor is None:
self.output_scaling_factor = [1]*len(resolution)
elif isinstance(self.output_scaling_factor, (float, int)):
Expand All @@ -52,33 +57,50 @@ def pad(self, x):

except KeyError:
padding = [int(round(p*r)) for (p, r) in zip(self.domain_padding, resolution)]

print(f'Padding inputs of {resolution=} with {padding=}, {self.padding_mode}')



# padding is being applied in reverse order (so we must reverse the padding list)
padding = padding[::-1]

output_pad = padding

output_pad = [int(round(i*j)) for (i,j) in zip(self.output_scaling_factor,output_pad)]



# the F.pad(x, padding) funtion pads the tensor 'x' in reverse order of the "padding" list i.e. the last axis of tensor 'x' will be
# padded by the amount mention at the first position of the 'padding' vector.
# The details about F.pad can be found here : https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

if self.padding_mode == 'symmetric':
# Pad both sides
unpad_indices = (Ellipsis, ) + tuple([slice(p, -p, None) for p in output_pad[::-1] ])
unpad_list = list()
for p in output_pad[::-1]:
if p == 0:
padding_end = None
padding_start = None
else:
padding_end = p
padding_start = -p
unpad_list.append(slice(padding_end, padding_start, None))
unpad_indices = (Ellipsis, ) + tuple(unpad_list)

padding = [i for p in padding for i in (p, p)]

elif self.padding_mode == 'one-sided':
# One-side padding
unpad_indices = (Ellipsis, ) + tuple([slice(None, -p, None) for p in output_pad[::-1]])
unpad_list = list()
for p in output_pad[::-1]:
if p == 0:
padding_start = None
else:
padding_start = -p
unpad_list.append(slice(None, padding_start, None))
unpad_indices = (Ellipsis, ) + tuple(unpad_list)
padding = [i for p in padding for i in (0, p)]
else:
raise ValueError(f'Got {self.padding_mode=}')

self._padding[f'{resolution}'] = padding


padded = F.pad(x, padding, mode='constant')

Expand Down
38 changes: 34 additions & 4 deletions neuralop/layers/tests/test_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,46 @@
import pytest

@pytest.mark.parametrize('mode', ['one-sided', 'symmetric'])
def test_DomainPadding(mode):
out_size = {'one-sided': 12, 'symmetric': 14}
@pytest.mark.parametrize('padding', [0.2, [0.1, 0.2]])
def test_DomainPadding_2d(mode, padding):
if isinstance(padding, float):
out_size = {'one-sided': [12, 12], 'symmetric': [14, 14]}
else:
out_size = {'one-sided': [11, 12], 'symmetric': [12, 14]}

data = torch.randn((2, 3, 10, 10))
padder = DomainPadding(0.2, mode)
padder = DomainPadding(padding, mode)
padded = padder.pad(data)

target_shape = list(padded.shape)
target_shape[-1] = target_shape[-2] = out_size[mode]
# create the target shape from hardcoded out_size
for pad_dim in range(1,3):
target_shape[-pad_dim] = out_size[mode][-pad_dim]
assert list(padded.shape) == target_shape

unpadded = padder.unpad(padded)
assert unpadded.shape == data.shape


@pytest.mark.parametrize('mode', ['one-sided', 'symmetric'])
@pytest.mark.parametrize('padding', [0.2, [0.1, 0, 0.2]])
def test_DomainPadding_3d(mode, padding):
if isinstance(padding, float):
out_size = {'one-sided': [12, 12, 12], 'symmetric': [14, 14, 14]}
else:
out_size = {'one-sided': [11, 10, 12], 'symmetric': [12, 10, 14]}

data = torch.randn((2, 3, 10, 10, 10))
padder = DomainPadding(padding, mode)
padded = padder.pad(data)

target_shape = list(padded.shape)
# create the target shape from hardcoded out_size
for pad_dim in range(1,4):
target_shape[-pad_dim] = out_size[mode][-pad_dim]
assert list(padded.shape) == target_shape

unpadded = padder.unpad(padded)
assert unpadded.shape == data.shape


20 changes: 12 additions & 8 deletions neuralop/models/fno.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from functools import partialmethod

Expand Down Expand Up @@ -44,18 +45,20 @@ class FNO(nn.Module):
By default None, otherwise tanh is used before FFT in the FNO block
use_mlp : bool, optional
Whether to use an MLP layer after each FNO block, by default False
mlp_dropout : float
droupout parameter of MLP layer (default is 0)
mlp_expansion : float
expansion parameter of MLP layer (default is 0.5)
mlp_dropout : float , optional
droupout parameter of MLP layer, by default 0
mlp_expansion : float, optional
expansion parameter of MLP layer, by default 0.5
non_linearity : nn.Module, optional
Non-Linearity module to use, by default F.gelu
norm : F.module, optional
Normalization layer to use, by default None
preactivation : bool, default is False
if True, use resnet-style preactivation
skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use, by default 'soft-gating'
fno_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use in fno, by default 'linear'
mlp_skip : {'linear', 'identity', 'soft-gating'}, optional
Type of skip connection to use in mlp, by default 'soft-gating'
separable : bool, default is False
if True, use a depthwise separable spectral convolution
factorization : str or None, {'tucker', 'cp', 'tt'}
Expand Down Expand Up @@ -135,10 +138,11 @@ def __init__(self, n_modes, hidden_channels,
# When updated, change should be reflected in fno blocks
self._incremental_n_modes = incremental_n_modes

if domain_padding is not None and domain_padding > 0:
self.domain_padding = DomainPadding(domain_padding=domain_padding, padding_mode=domain_padding_mode, output_scaling_factor=output_scaling_factor)
if domain_padding is not None and ((isinstance(domain_padding, list) and sum(domain_padding) > 0) or (isinstance(domain_padding, (float, int)) and domain_padding > 0)):
self.domain_padding = DomainPadding(domain_padding=domain_padding, padding_mode=domain_padding_mode, output_scaling_factor=output_scaling_factor)
else:
self.domain_padding = None

self.domain_padding_mode = domain_padding_mode

if output_scaling_factor is not None and not joint_factorization:
Expand Down

0 comments on commit ddfb10d

Please sign in to comment.