In [None]:
# Notebook setup: GPU selection, IPython extensions, and shared imports.
# CUDA_VISIBLE_DEVICES is set here, before torch is imported further down.
%env CUDA_VISIBLE_DEVICES=0
# autoreload mode 2 lets edits to imported modules be picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as L
import torch.nn.functional as F
# Select the 'medium' float32 matmul precision setting for the whole session.
torch.set_float32_matmul_precision('medium')

# Exercise 3

Use POU-Net with neural operators to fit the following functional data. Working in the unit circle domain, $\Omega = \{x \in \mathbb{R}^2 : \|x\| \leq 1\}$, the true operator maps functions on $\Omega$ to functions on $\Omega$, where both the input and output functions vanish at the boundary of the domain. This situation is common when trying to model physical systems, e.g., the flow at the blade of a wind turbine can be assumed to be zero, but the flow elsewhere needs to be modeled. We call the details of the field behavior at the boundaries the boundary conditions.

We'll be working with an operator learning method we developed called MOR-Physics. See https://arxiv.org/pdf/1810.08552 and https://www.sciencedirect.com/science/article/pii/S004578252030685X. It's very similar to FNO. Try implementing it as described in the papers. The formula for the action of the parameterized operator is,
$$
\mathcal{N}(u) = \mathcal{F}^{-1}\left(g(\mathbf{\kappa}) \mathcal{F} (h(u)) \right)
$$
where $\mathcal{F}$ is the Fourier transform, $g$ is a complex-valued function of the wave vector $\kappa$, and $h$ is a point-wise nonlinearity.

Since the method is Fourier based, it only works for periodic domains, while the domain for the data is the unit circle. We can still work on the periodic domain and use MOR-Physics by embedding the circle inside a periodic domain and using a mixture of experts to fit the operator in the domain while mapping the region outside the domain to zero. See this paper for more details and generalizations of this approach in the context of numerical methods for solving PDEs: https://www.sciencedirect.com/science/article/pii/S0021999114000151

Try using MOR-Physics to fit the operator in the unit circle while having it vanish outside. There are a few different levels of complexity you could try here. You could start off by fixing the POUs to be the unit circle and the region outside the circle, and also fixing the experts to be MOR-Physics inside the circle and the zero operator outside. Next, you could try letting the POUs adapt to the domain and/or choosing between several MOR-Physics experts.

In [None]:
import numpy as np

# Read the input (u) and output (v) training arrays from the data directory, in that order.
u_train, v_train = map(np.load, ('data/u_train.npy', 'data/v_train.npy'))

In [None]:
# Display one input function (sample index 5); the red circle outlines the domain boundary.
fig, ax = plt.subplots()
plot_extent = [-1.25, 1.25, -1.25, 1.25]
plt.imshow(u_train[5], extent=plot_extent)
# Transparent face colour so only the red edge of the unit circle is drawn.
domain_boundary = plt.Circle((0, 0), 1, ec='r', fc=[0, 0, 0, 0])
ax.add_patch(domain_boundary)

In [None]:
# Display one output function (sample index 5) with a colorbar; the red circle outlines the domain boundary.
fig, ax = plt.subplots()
plot_extent = [-1.25, 1.25, -1.25, 1.25]
plt.imshow(v_train[5], extent=plot_extent)
plt.colorbar()
# Transparent face colour so only the red edge of the unit circle is drawn.
domain_boundary = plt.Circle((0, 0), 1, ec='r', fc=[0, 0, 0, 0])
ax.add_patch(domain_boundary)

### Hermitian symmetry of the wave-vector function $g$:
In order for IFFT to give real results we need $ g(-\kappa)=\overline {g(\kappa)}$ \
Or you can just take real part after IFFT...

In [None]:
from lightning_utils import *
from MOR_Operator import MOR_Operator
from POU_net import POU_net

In [None]:
# Print torch's FFT sample frequencies for signal lengths 1 through 5.
for n_points in range(1, 6):
    print(torch.fft.fftfreq(n_points))

In [None]:
# Build the train/validation DataLoaders from the (u, v) field pairs.
batch_size = 32
SPLIT_SEED = 0  # fixes the train/val partition across kernel restarts

u_train = torch.as_tensor(u_train).float()
v_train = torch.as_tensor(v_train).float()
# Indexing with None inserts a singleton axis at position 1 of each tensor.
dataset = torch.utils.data.TensorDataset(u_train[:, None], v_train[:, None])

# random_split draws from torch's global RNG unless a generator is supplied, and
# this notebook only calls torch.manual_seed later (in the training cell), so the
# 80/20 partition used to differ on every fresh kernel. Seed the split explicitly.
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
train, val = torch.utils.data.random_split(dataset, [0.8, 0.2], generator=split_rng)

# drop_last=True discards a trailing partial batch from the training loader only.
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, drop_last=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size)

In [None]:
import random

# Pull the first batch from the training loader and report the tensor shapes.
X, y = next(iter(train_loader))
print(f'{X.shape=}, {y.shape=}')

# Choose a random sample position within that batch.
i = random.randrange(len(X))
print(f'id={i}')

# Figure 1 shows the chosen input field, figure 2 the matching target field.
for figure_id, fields in enumerate((X, y), start=1):
    plt.figure(figure_id)
    plt.imshow(fields[i].squeeze())
    plt.colorbar()
plt.show()

In [None]:
# Toggle IPython's automatic post-mortem debugger (debugging aid).
# NOTE(review): with no argument %pdb flips the current state, so whether this turns
# the debugger on or off depends on execution history — confirm it is still wanted.
%pdb

In [None]:
# Print the fftshift-ed FFT sample frequencies for signal lengths 1 through 5.
for n_points in range(1, 6):
    shifted_freqs = torch.fft.fftshift(torch.fft.fftfreq(n_points))
    print(shifted_freqs)

In [None]:
# Show the length of the validation DataLoader (its number of batches).
len(val_loader)

In [None]:
%pdb
#import torch
#torch.multiprocessing.set_start_method('spawn') # good solution !!!!
torch.use_deterministic_algorithms(False)
torch.manual_seed(0)

#Expert = lambda **kwd_args: MOR_Operator(n_layers=1, **kwd_args) # works b/c only 1 layer
#Expert = MOR_Operator # works (with 32 modes)
#Expert = lambda *args, **kwd_args: MOR_Operator(*args, **kwd_args, k_modes=16, mlp_second=True) # only kind of works?
#Expert = CNN # works

from POU_net import FieldGatingNet

# The gating net seems to need to have full modes to make the MoE work much better
# but it still results in some compute savings.
gating_net = lambda *args, **kwd_args: FieldGatingNet(*args, **(kwd_args | {'k':5, 'k_modes':32, 'n_layers':12, 'noise_sd': 0.0}))

# train model
model = POU_net(1, 1, 100, lr=0.001, T_max=10, make_gating_net=gating_net,
                k_modes=16, mlp_second=True)
trainer = L.Trainer(max_epochs=1000, accelerator='gpu', devices=1)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
import torchmetrics
# Instantiate an ExplainedVariance metric and immediately ask for its value.
# NOTE(review): compute() is called with no preceding update(); this looks like a
# scratch experiment — confirm whether the cell is still needed.
x = torchmetrics.ExplainedVariance()
x.compute()

In [None]:
# Inspect what the trainer exposes to its progress bar: first the value, then its type.
x = trainer.progress_bar_metrics
for shown in (x, type(x)):
    print(shown)

$$\mathrm{NLL}_{\mathrm{scalar},i} = \frac{(\mu_i-y_i)^2}{2\sigma_i^2} + \ln(\sigma_i)$$

### The 2nd Phase is Very Important!
I'm not entirely sure why this matters so much despite the exponential decay (I've verified again that L2 decay works).
At least part of the reason is that it restarts the learning-rate schedulers and Adam's adaptive learning rates too.

In [None]:
# Second training phase: a fresh Trainer keeps fitting the same model object,
# this time with only the training loader attached.
trainer = L.Trainer(
    max_epochs=1000,
    accelerator='gpu',
    devices=1,
)
trainer.fit(model, train_dataloaders=train_loader)

In [None]:
# Toggle IPython's automatic post-mortem debugger (debugging aid).
# NOTE(review): with no argument %pdb flips the current state, so whether this turns
# the debugger on or off depends on execution history — confirm it is still wanted.
%pdb

In [None]:
# Run the trainer's validation loop for the model over the training DataLoader.
# NOTE(review): this evaluates on train_loader rather than val_loader — confirm that is intended.
trainer.validate(model, dataloaders=train_loader)

In [None]:
# Show input / ground truth / prediction for a handful of randomly drawn training samples.
# Changes relative to the version noted as "verified to work 7/19/24":
#   * `train_dataset` is never defined in this notebook; the training split produced by
#     random_split is named `train`, so the loader is built from that.
#   * the input is moved to whatever device the model currently lives on instead of
#     assuming CUDA.
#   * train mode is restored even if inference or plotting raises, and the trailing
#     model.train() no longer dumps the module repr as cell output.
N_DISPLAY = 11  # the original loop showed samples i = 0..10
shuffle_loader = torch.utils.data.DataLoader(train, shuffle=True)  # default batch size of 1
model.eval()
try:
    # range() is listed first so zip stops without drawing an extra batch from the loader.
    for i, (X, y) in zip(range(N_DISPLAY), shuffle_loader):
        pred = model(X.to(model.device)).cpu().detach()
        panels = (('Input', X), ('Truth', y), ('Pred', pred))
        # Figure numbers 1+3i, 2+3i, 3+3i match the original numbering scheme.
        for offset, (title, field) in enumerate(panels, start=1):
            plt.figure(offset + i * 3)
            plt.imshow(field.squeeze())
            plt.colorbar()
            plt.title(title)
        plt.show()
finally:
    model.train()

In [None]:
import torch

# Print PyTorch's CUDA memory summary report.
memory_report = torch.cuda.memory_summary()
print(memory_report)