In [None]:
from dataclasses import dataclass, field
from torch import Tensor
from tqdm.auto import tqdm
from typing import *

import itertools
import mat73
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
@dataclass
class FNOConfig:
    """Hyper-parameters shared by every module of the FNO model.

    The seven size fields have no defaults and must always be supplied
    (positionally or by keyword). ``dtype_complex`` is derived from ``dtype``
    after construction and cannot be passed to the constructor.
    """

    # Grid sizes along the two spatial axes
    nx: int
    ny: int
    # Number of time steps fed to / produced by the model
    nt_in: int
    nt_out: int
    # Number of Fourier modes retained along x, y and t
    mx: int
    my: int
    mt: int
    # Channel widths: input, lifted (hidden) and output
    nc_in: int = 2
    nc_lift: int = 20
    nc_out: int = 1
    # Number of stacked FNOBlock layers
    n_blocks: int = 4
    # Non-linearity applied after each layer
    act_fn: Callable = F.gelu
    # Where parameters are allocated and which real dtype they use
    device: torch.device = torch.device('cuda')
    dtype: torch.dtype = torch.float32
    # Complex counterpart of `dtype`; filled in by __post_init__
    dtype_complex: torch.dtype = field(init=False)

    def __post_init__(self) -> None:
        """Derive the complex dtype: complex64 for float32, complex128 otherwise."""
        if self.dtype == torch.float32:
            self.dtype_complex = torch.complex64
        else:
            self.dtype_complex = torch.complex128

In [None]:
class SpectralConv(nn.Module):
    """Spectral convolution over the last three axes of a 5-D tensor.

    The input is transformed with a real FFT over axes 2, 3 and 4. Only the
    lowest ``mx`` / ``my`` modes (both the leading and the trailing band,
    i.e. positive and negative frequencies) along axes 2 and 3, and the
    lowest ``mt`` modes along the one-sided axis 4, are kept. Each retained
    band is mixed across channels with its own complex weight tensor; every
    other Fourier coefficient is set to zero before transforming back.

    The frequency bands are addressed relative to the ends of each axis, so
    the layer can be applied to inputs whose spatial resolution differs from
    ``config.nx`` / ``config.ny``.
    """

    def __init__(self, config: FNOConfig) -> None:
        """Build the band slices and one complex weight tensor per band.

        Args:
            config: Supplies the mode counts (``mx``, ``my``, ``mt``), the
                channel width ``nc_lift``, the device and the complex dtype.
        """
        super().__init__()
        self.slices = []
        self.weights = nn.ParameterList([])

        modes = (config.mx, config.my, config.mt)

        # Axes 2 and 3 are two-sided, so both their low (0) and high (1) band
        # are used; axis 4 is one-sided after rfftn, so only its low band is.
        for corner in itertools.product((0, 1), (0, 1), (0,)):
            # Leading full slices cover the batch and channel axes.
            sl = [slice(None), slice(None)]
            for m, take_high in zip(modes, corner):
                if not take_high:
                    sl.append(slice(None, m))
                elif m > 0:
                    # End-relative indexing keeps the band correct for any
                    # input length, not only the one stored in the config.
                    sl.append(slice(-m, None))
                else:
                    # slice(-0, None) would select everything; keep it empty.
                    sl.append(slice(0, 0))
            self.slices.append(tuple(sl))

            weight = nn.Parameter(torch.empty(
                config.nc_lift,
                config.nc_lift,
                config.mx,
                config.my,
                config.mt,
                device=config.device,
                dtype=config.dtype_complex
            ))
            nn.init.kaiming_normal_(weight)
            self.weights.append(weight)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the spectral convolution.

        Args:
            x: Real tensor indexed as (batch, channel, x, y, t).

        Returns:
            Real tensor with the same shape as ``x``.
        """
        fft_dims = (2, 3, 4)
        x_ft = torch.fft.rfftn(x, dim=fft_dims)

        # Coefficients outside the retained bands stay zero.
        y_ft = torch.zeros_like(x_ft)
        for sl, w in zip(self.slices, self.weights):
            y_ft[sl] = torch.einsum('dcxyt,bcxyt->bdxyt', w, x_ft[sl])

        # Passing `s` restores the original (possibly odd) axis lengths.
        return torch.fft.irfftn(y_ft, s=x.shape[2:], dim=fft_dims)

In [None]:
class FNOBlock(nn.Module):
    """One FNO layer: a spectral path plus a pointwise linear path, then an activation."""

    def __init__(self, config: FNOConfig) -> None:
        """Create the spectral convolution and the channel-mixing matrix.

        Args:
            config: Supplies the channel width ``nc_lift``, the activation,
                the device and the dtype.
        """
        super().__init__()
        self.sconv = SpectralConv(config)

        # Square matrix that mixes channels independently at every grid point.
        channel_mix = torch.empty(
            config.nc_lift,
            config.nc_lift,
            device=config.device,
            dtype=config.dtype,
        )
        self.weight = nn.Parameter(channel_mix)
        nn.init.kaiming_normal_(self.weight)

        self.act_fn = config.act_fn

    def forward(self, x: Tensor) -> Tensor:
        """Return ``act_fn(sconv(x) + W x)`` for ``x`` indexed as (batch, channel, x, y, t)."""
        spectral = self.sconv(x)
        pointwise = torch.einsum('dc,bcxyt->bdxyt', self.weight, x)
        return self.act_fn(spectral + pointwise)

In [None]:
class FNO(nn.Module):
    """Fourier-neural-operator style network over 5-D inputs.

    The forward pass
      1. lifts the channel axis from ``nc_in`` to ``nc_lift``,
      2. maps the last axis from ``nt_in`` to ``nt_out`` entries,
      3. applies ``n_blocks`` :class:`FNOBlock` layers, and
      4. projects the channels to ``nc_out`` through a hidden width of 128.

    Every linear map is an einsum contraction over a single axis and is
    followed by a broadcast bias.
    """

    def __init__(self, config: FNOConfig) -> None:
        """Allocate and initialise all parameters.

        Args:
            config: Sizes, activation, device and dtype for the network.
        """
        super().__init__()

        # Weights AND biases share one device/dtype; biases created with the
        # framework defaults would stay on the CPU and break the forward pass
        # whenever config.device is an accelerator.
        factory = {'device': config.device, 'dtype': config.dtype}
        nc_hidden = 128  # width of the intermediate projection layer

        # Lifting parameters
        self.c_lift = nn.Parameter(torch.empty(config.nc_lift, config.nc_in, **factory))
        self.t_lift = nn.Parameter(torch.empty(config.nt_out, config.nt_in, **factory))
        self.c_lift_bias = nn.Parameter(torch.zeros(1, config.nc_lift, 1, 1, 1, **factory))
        self.t_lift_bias = nn.Parameter(torch.zeros(1, 1, 1, 1, config.nt_out, **factory))

        # Projection parameters
        self.c_proj0 = nn.Parameter(torch.empty(nc_hidden, config.nc_lift, **factory))
        self.c_proj1 = nn.Parameter(torch.empty(config.nc_out, nc_hidden, **factory))
        self.c_proj0_bias = nn.Parameter(torch.zeros(1, nc_hidden, 1, 1, 1, **factory))
        self.c_proj1_bias = nn.Parameter(torch.zeros(1, config.nc_out, 1, 1, 1, **factory))

        nn.init.kaiming_normal_(self.c_lift)
        nn.init.kaiming_normal_(self.t_lift)
        nn.init.kaiming_normal_(self.c_proj0)
        nn.init.kaiming_normal_(self.c_proj1)

        self.blocks = nn.ModuleList([FNOBlock(config) for _ in range(config.n_blocks)])
        self.act_fn = config.act_fn

    def forward(self, x: Tensor) -> Tensor:
        """Run the network.

        Args:
            x: Tensor indexed as (batch, channel, x, y, t) with ``nc_in``
                channels and ``nt_in`` entries on the last axis.

        Returns:
            Tensor indexed as (batch, channel, x, y, t) with ``nc_out``
            channels and ``nt_out`` entries on the last axis.
        """
        # Lift channels
        x = torch.einsum('dc,bcxyt->bdxyt', self.c_lift, x)
        x = x + self.c_lift_bias
        x = self.act_fn(x)

        # Lift time
        x = torch.einsum('qt,bcxyt->bcxyq', self.t_lift, x)
        x = x + self.t_lift_bias
        x = self.act_fn(x)

        # Perform blocks
        for block in self.blocks:
            x = block(x)

        # Project channels
        x = torch.einsum('dc,bcxyt->bdxyt', self.c_proj0, x)
        x = x + self.c_proj0_bias
        x = self.act_fn(x)

        # Final projection is left linear (no activation on the output).
        x = torch.einsum('dc,bcxyt->bdxyt', self.c_proj1, x)
        x = x + self.c_proj1_bias
        return x

In [None]:
# Seed both random number generators so that weight initialisation (torch)
# and batch shuffling (numpy) are reproducible across kernel restarts.
torch.manual_seed(123)
np.random.seed(123)

# Problem sizes and mode counts; parameters are allocated on the CPU.
config = FNOConfig(
    nx=64,
    ny=64,
    nt_in=10,
    nt_out=40,
    mx=8,
    my=8,
    mt=8,
    device=torch.device('cpu'),
)
model = FNO(config)

In [None]:
# Your data path might be different! Download PDE dataset files from this drive:
# https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-

data_path = os.path.expanduser('~/data/ns_V1e-3_N5000_T50_1200.mat')

# scipy reads MATLAB files up to v7.2 and raises NotImplementedError for
# v7.3 (HDF5-based) files; fall back to mat73 for those instead of requiring
# a manual edit of this cell.
try:
    data = scipy.io.loadmat(data_path)
except NotImplementedError:
    data = mat73.loadmat(data_path)

In [None]:
# The cells below slice `data['u']` with five indices, ordered as
# (batch, channels, x, y, t). If your file stores a different layout,
# write code here to reshape it accordingly.

In [None]:
# Construct grid: the product of normalised x, y and t coordinates in [0, 1],
# used as an extra input channel alongside the data.
grid = np.einsum(
    'x,y,t->xyt',
    *(np.linspace(0, 1, n) for n in (config.nx, config.ny, config.nt_in)),
)

# Move to the model's device/dtype and prepend singleton batch and channel axes.
grid = torch.tensor(grid).to(device=config.device, dtype=config.dtype)[None, None]

In [None]:
n_data = 1000
train_split = 0.8
n_train = int(train_split*n_data)
n_valid = n_data - n_train

u = data['u']

# Convenience: insert a singleton channel axis when the array has only four
# axes, so that the five-index slicing below works.
if u.ndim == 4:
    u = u[:, np.newaxis]

# Time axis: the first nt_in steps are the input, the next nt_out the target.
t_split = config.nt_in
t_stop = config.nt_in + config.nt_out

assert u.ndim == 5, f'expected (batch, channels, x, y, t), got shape {u.shape}'
assert u.shape[0] >= n_data, f'need {n_data} samples, file has {u.shape[0]}'
assert u.shape[-1] >= t_stop, f'need {t_stop} time steps, file has {u.shape[-1]}'

# The validation set is capped at n_data so that it holds exactly n_valid
# samples instead of everything left in the file.
x_train = u[:n_train, :, :, :, :t_split]
x_valid = u[n_train:n_data, :, :, :, :t_split]
y_train = u[:n_train, :, :, :, t_split:t_stop]
y_valid = u[n_train:n_data, :, :, :, t_split:t_stop]

assert len(x_train) == n_train and len(x_valid) == n_valid

In [None]:
# AdamW over every model parameter, with decoupled weight decay.
optim = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

# Start from a clean slate: gradients are reset to None rather than zero-filled.
optim.zero_grad(set_to_none=True)

In [None]:
batch_size = 10
n_epochs = 100

for e in range(n_epochs):

    # Contiguous [start, stop) windows over the training set, visited in a
    # fresh random order every epoch.
    batch_idxs = [(i, i + batch_size) for i in range(0, n_train, batch_size)]
    np.random.shuffle(batch_idxs)

    # Iterating the list directly lets tqdm know the total number of batches.
    pbar = tqdm(batch_idxs)
    for a, b in pbar:

        x = torch.tensor(x_train[a:b]).to(device=config.device, dtype=config.dtype)
        # Append the coordinate grid as an extra channel; expand() broadcasts
        # it over the batch without allocating a copy (cat copies anyway).
        x = torch.cat((x, grid.expand(x.shape[0], -1, -1, -1, -1)), dim=1)
        y = torch.tensor(y_train[a:b]).to(device=config.device, dtype=config.dtype)
        y_hat = model(x)

        # mse_loss takes (input, target).
        loss = F.mse_loss(y_hat, y)
        pbar.set_description(f'epoch = {e:03d}, loss = {loss.item():.6f}')

        # Clear stale gradients, back-propagate, then apply the update.
        # Without step() the parameters never change.
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()