In [1]:
# %%
import os
from timeit import default_timer

import numpy as np
import scipy.io
import torch
from einops import rearrange

from wavebench.nn.fno import FNO2d
from wavebench.nn.lploss import LpLoss


# reading data
class MatReader(object):
  """Reads named fields out of a ``.mat`` file loaded via ``scipy.io.loadmat``.

  The file is loaded eagerly on construction. ``read_field`` then returns a
  field, optionally cast to float32, wrapped in a torch tensor and moved to
  the GPU, depending on the ``to_float`` / ``to_torch`` / ``to_cuda`` flags.

  Note: ``_load_file`` always sets ``old_mat`` to True, so the axis-reversal
  branch in ``read_field`` only runs if a caller clears that flag.
  """

  def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
    super().__init__()

    self.to_torch, self.to_cuda, self.to_float = to_torch, to_cuda, to_float
    self.file_path = file_path

    # Both are filled in by _load_file().
    self.data = None
    self.old_mat = None
    self._load_file()

  def _load_file(self):
    """Loads ``self.file_path`` into ``self.data`` and marks it as old-style."""
    self.data = scipy.io.loadmat(self.file_path)
    self.old_mat = True

  def load_file(self, file_path):
    """Points the reader at a different file and loads it."""
    self.file_path = file_path
    self._load_file()

  def read_field(self, field):
    """Returns ``self.data[field]`` converted according to the reader flags."""
    value = self.data[field]

    if not self.old_mat:
      # Materialise the dataset and reverse the order of its axes.
      value = value[()]
      reversed_axes = tuple(reversed(range(len(value.shape))))
      value = np.transpose(value, axes=reversed_axes)

    if self.to_float:
      value = value.astype(np.float32)

    if not self.to_torch:
      return value

    tensor = torch.from_numpy(value)
    return tensor.cuda() if self.to_cuda else tensor

  def set_cuda(self, to_cuda):
    """Toggles moving tensors to the GPU in ``read_field``."""
    self.to_cuda = to_cuda

  def set_torch(self, to_torch):
    """Toggles conversion to torch tensors in ``read_field``."""
    self.to_torch = to_torch

  def set_float(self, to_float):
    """Toggles the float32 cast in ``read_field``."""
    self.to_float = to_float

# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
  """Pointwise standardisation with statistics taken over dimension 0.

  The mean and standard deviation of ``x`` along dimension 0 are stored at
  construction time. ``encode`` maps data to ``(x - mean) / (std + eps)`` and
  ``decode`` applies the inverse map.

  Args:
    x: tensor whose dimension 0 is reduced to obtain the statistics.
    eps: constant added to the standard deviation to avoid division by zero.
  """

  def __init__(self, x, eps=0.00001):
    super(UnitGaussianNormalizer, self).__init__()

    # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
    self.mean = torch.mean(x, 0)
    self.std = torch.std(x, 0)
    self.eps = eps

  def encode(self, x):
    """Returns ``x`` standardised with the stored mean and std."""
    # Align devices the same way decode() does, so encode() also works after
    # the statistics or the input have been moved to another device.
    mean = self.mean.to(x.device)
    std = self.std.to(x.device)
    return (x - mean) / (std + self.eps)

  def decode(self, x, sample_idx=None):
    """Inverts ``encode``.

    Args:
      x: standardised tensor.
      sample_idx: optional index. The rank of ``sample_idx[0]`` is compared
        with the rank of the stored statistics to decide whether they are
        indexed on their first or on their second dimension.

    Returns:
      ``x * (std + eps) + mean`` on the device of ``x``.

    Raises:
      ValueError: if ``sample_idx[0]`` has a higher rank than the stored
        statistics, a case for which no indexing rule is defined.
    """
    if sample_idx is None:
      std = self.std + self.eps # n
      mean = self.mean
    else:
      stat_rank = len(self.mean.shape)
      idx_rank = len(sample_idx[0].shape)
      if stat_rank == idx_rank:
        std = self.std[sample_idx] + self.eps  # batch*n
        mean = self.mean[sample_idx]
      elif stat_rank > idx_rank:
        std = self.std[:,sample_idx]+ self.eps # T*batch*n
        mean = self.mean[:,sample_idx]
      else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
          'sample_idx[0] has rank %d, which exceeds the rank %d of the '
          'stored statistics' % (idx_rank, stat_rank))

    # x is in shape of batch*n or T*batch*n
    x = (x * std.to(x.device)) + mean.to(x.device)
    return x

  def cuda(self, device = torch.device('cuda:0')):
    """Moves the stored statistics to ``device`` (``cuda:0`` by default)."""
    self.mean = self.mean.to(device)
    self.std = self.std.to(device)

  def cpu(self):
    """Moves the stored statistics back to the CPU."""
    self.mean = self.mean.cpu()
    self.std = self.std.cpu()


In [2]:
import ml_collections

# All tunable settings of the experiment live in one ConfigDict.
config = ml_collections.ConfigDict({
  # number of training / test samples and the mini-batch size
  'ntrain': 1000,
  'ntest': 100,
  'batch_size': 20,
  # FNO architecture: passed as modes1/modes2 and hidden_width to FNO2d
  'fno_modes': 12,
  'fno_width': 32,
  # optimisation schedule
  'learning_rate': 0.001,
  'num_epochs': 200,
})

## Dataset

Download the Darcy flow data (`piececonst_r421_N1024_smooth1.mat` and `piececonst_r421_N1024_smooth2.mat`) from [this Google Drive folder](https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-).

See also https://github.com/neuraloperator/neuraloperator/issues/7.

In [3]:
# Root folder of the downloaded data. Set the DATASET_DIR environment variable
# to point at your own copy; the fallback is the path originally hard-coded
# here, so existing set-ups keep working.
dataset_dir = os.environ.get('DATASET_DIR', '/home/liu0003/Desktop/datasets')
darcy_dir = os.path.join(dataset_dir, 'pde_data', 'Darcy_421')
TRAIN_PATH = os.path.join(darcy_dir, 'piececonst_r421_N1024_smooth1.mat')
TEST_PATH = os.path.join(darcy_dir, 'piececonst_r421_N1024_smooth2.mat')


ntrain = config.ntrain
ntest = config.ntest

# Side length of the raw fields — taken from the "r421" in the file names;
# TODO confirm against the data.
full_res = 421
r = 5                        # sub-sampling stride applied to both spatial axes
h = (full_res - 1) // r + 1  # grid points per side after sub-sampling
s = h

reader = MatReader(TRAIN_PATH)
x_train = reader.read_field('coeff')[:ntrain,::r,::r][:,:s,:s]
y_train = reader.read_field('sol')[:ntrain,::r,::r][:,:s,:s]

reader.load_file(TEST_PATH)
x_test = reader.read_field('coeff')[:ntest,::r,::r][:,:s,:s]
y_test = reader.read_field('sol')[:ntest,::r,::r][:,:s,:s]

# Inputs are standardised with statistics of the training inputs only.
x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

# Only the training targets are encoded: y_test stays in physical units
# because the evaluation loop decodes the predictions before comparing.
y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

# Append a trailing channel axis: (N, s, s) -> (N, s, s, 1).
x_train = x_train.reshape(ntrain,s,s,1)
x_test = x_test.reshape(ntest,s,s,1)

train_loader = torch.utils.data.DataLoader(
  torch.utils.data.TensorDataset(x_train, y_train),
  batch_size=config.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
  torch.utils.data.TensorDataset(x_test, y_test),
  batch_size=config.batch_size, shuffle=False)

## Training

In [4]:
model = FNO2d(
  modes1=config.fno_modes,
  modes2=config.fno_modes,
  hidden_width=config.fno_width,
  ).cuda()

optimizer = torch.optim.AdamW(
    model.parameters(), lr=config.learning_rate
    )

# scheduler.step() is called once per mini-batch below, so the cosine period
# is expressed in optimizer steps (epochs * batches per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.num_epochs*len(train_loader),
    eta_min=1e-6)

# NOTE(review): with reduction='sum' the loss is presumably summed over the
# samples of a batch, which is why the epoch totals are divided by
# ntrain / ntest below — confirm against LpLoss.
myloss = LpLoss(reduction='sum')
y_normalizer.cuda()
for ep in range(config.num_epochs):
  model.train()
  t1 = default_timer()
  train_l2 = 0.0
  for x, y in train_loader:
    x, y = x.cuda(), y.cuda()
    # Use the actual batch size: the last batch is shorter whenever the
    # dataset size is not a multiple of config.batch_size.
    batch = x.shape[0]

    optimizer.zero_grad()
    out = model(rearrange(x, 'b h w c -> b c h w')).reshape(batch, s, s)

    # Compare prediction and target in physical (un-normalised) units.
    out = y_normalizer.decode(out)
    y = y_normalizer.decode(y)

    loss = myloss(out.reshape(batch, -1), y.reshape(batch, -1))
    loss.backward()

    optimizer.step()
    scheduler.step()
    train_l2 += loss.item()

  model.eval()
  test_l2 = 0.0
  with torch.no_grad():
    for x, y in test_loader:
      x, y = x.cuda(), y.cuda()
      batch = x.shape[0]

      out = model(rearrange(x, 'b h w c -> b c h w')).reshape(batch, s, s)
      # The test targets were never encoded, so only the output is decoded.
      out = y_normalizer.decode(out)

      test_l2 += myloss(
        out.reshape(batch, -1), y.reshape(batch, -1)).item()

  train_l2/= ntrain
  test_l2 /= ntest

  t2 = default_timer()
  # epoch index, wall time in seconds, mean train loss, mean test loss
  print(ep, t2-t1, train_l2, test_l2)





0 2.870572882995475 0.11266509556770325 0.06562536358833312
1 1.6448596470290795 0.04840609937906265 0.03930075824260712
2 1.6392666759784333 0.03413216185569763 0.036280732750892636
3 1.651758732041344 0.028835083723068238 0.02738577425479889
4 1.6398576999781653 0.02435221701860428 0.026867963075637817
5 1.6487251510261558 0.021698871195316313 0.022876895368099212
6 1.6583535440149717 0.02101210716366768 0.02220281183719635
7 1.6562338539515622 0.017889699578285217 0.02010775238275528
8 1.6718751059961505 0.018747543662786484 0.022063536942005156
9 1.6607170030474663 0.01778507339954376 0.020209083259105684
10 1.6657434810185805 0.01498659098148346 0.017445169389247894
11 1.6649010769906454 0.013804593101143837 0.017885723114013673
12 1.6676555869635195 0.012834921434521675 0.016397310495376585
13 1.6696301529882476 0.012135858342051506 0.01565384805202484
14 1.6715576499700546 0.013415244534611702 0.01781418889760971
15 1.664296688977629 0.011317055001854896 0.015269584655761718
16 