In [None]:
# Training a Fourier Neural Operator (FNO) for wave propagation

In [None]:
pip install torch

In [1]:
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn
from sklearn.model_selection import train_test_split

import operator
from functools import reduce
from functools import partial

import torch.nn.functional as F

import matplotlib.pyplot as plt

from timeit import default_timer

from torch.optim import Adam

from FNO3D import *
from utilities3 import *

# Reproducibility: seed both random number generators used by this notebook
# (NumPy's legacy global generator and PyTorch's default generator) with 0.
np.random.seed(0)
torch.manual_seed(0)

# Display tensors with 8 decimal places when printed.
torch.set_printoptions(precision=8)

In [2]:
# hyperparameters #

# Optimisation: mini-batch size and the initial learning rate given to Adam.
batch_size, learning_rate = 2, 0.0025

# Schedule: number of training epochs, then the StepLR period and decay factor.
# NOTE(review): step_size equals epochs, so the learning-rate decay only fires
# after the last epoch has finished -- confirm this is intended.
epochs, step_size, gamma = 100, 100, 0.5

# Model size handed to FNO3d: `modes` is passed for all three mode arguments,
# `width` is the fourth argument.
modes, width = 8, 5

# Data dimensions: `s` points per spatial axis, `T` time steps kept per sample.
s, T = 64, 10

# Number of samples in the training and validation splits.
ntrain, nval = 80, 20

In [3]:
# load the data
# Both fields are read from the same .mat file: `vel` is later used as the
# model input and `sr` as the prediction target.
# NOTE(review): physical meaning/units of the two fields are not visible here.
reader = MatReader('w_test.mat')
vel, sr = (reader.read_field(field) for field in ('vel', 'sr'))

In [4]:
# split into train / validation : 80 / 20
# A fixed random_state keeps the partition identical across runs; inputs and
# targets are split together so sample pairs stay aligned.
x_train, x_val, y_train, y_val = train_test_split(
    vel,
    sr,
    test_size=0.20,
    random_state=0,
)

In [5]:
# data normalisation
#
# Inputs: fit the normaliser on the training split only and apply the same
# transform to the validation split, so no validation statistics leak in.
x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_val = x_normalizer.encode(x_val)

# Give every input sample a time axis of length T (the same field repeated at
# each step) and a trailing channel axis: (N, s, s) -> (N, s, s, T, 1).
x_train = x_train.reshape(ntrain, s, s, 1, 1).repeat([1, 1, 1, T, 1])
x_val = x_val.reshape(nval, s, s, 1, 1).repeat([1, 1, 1, T, 1])

# use one in every 53 timesteps
y_train = y_train[:, 0::53, :, :]
y_val = y_val[:, 0::53, :, :]
print(y_train.shape)
print(y_val.shape)

# Move time to the last axis: (N, T, s, s) -> (N, s, s, T). The data are torch
# tensors (see the printed torch.Size), so use Tensor.permute rather than
# np.transpose; .contiguous() keeps later .view() calls valid.
y_train = y_train.permute(0, 2, 3, 1).contiguous()
y_val = y_val.permute(0, 2, 3, 1).contiguous()

# Fit the target normaliser only AFTER subsampling and permuting, so that its
# statistics have the same (s, s, T) layout as the tensors passed to
# y_normalizer.decode() in the training loop. Fitting it on the full-length
# (N, time, s, s) array, as was done before, leaves statistics that do not
# line up with the (batch, s, s, T) batches.
# NOTE(review): assumes UnitGaussianNormalizer keeps per-position statistics
# computed over the sample axis -- confirm in utilities3.
# y_val is intentionally left un-normalised: validation compares decoded
# predictions against the raw targets.
y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train, y_train),
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_val, y_val),
    batch_size=batch_size,
    shuffle=False,
)


torch.Size([80, 10, 64, 64])
torch.Size([20, 10, 64, 64])


In [6]:
# training the FNO
#
# NOTE(review): the recorded run of this cell failed with
# "mat1 and mat2 shapes cannot be multiplied (81920x4 and 5x5)".
# 81920 = batch_size * s * s * T, so a linear layer inside the model received
# 4 features per point while expecting 5. The layer definition lives in FNO3D,
# which is not visible here -- check its input width against the data layout.

step = 1

# Append .cuda() to the model (and move y_normalizer and each batch with
# .cuda()) to train on a GPU.
model = FNO3d(modes, modes, modes, width)
# To resume from a saved model instead:
# model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20')

optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# NOTE(review): with size_average=False the loss is presumably summed over the
# batch, which is why the totals below are divided by the number of samples --
# confirm in utilities3.
myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0  # MSE tracking is currently disabled; name kept defined
    train_l2 = 0
    for x, y in train_loader:
        # Use the actual batch dimension instead of the configured batch_size,
        # so a smaller final batch does not break the reshapes.
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model(x).reshape(n_batch, s, s, T)

        # Training targets were encoded, so map both target and prediction
        # back to the original scale before computing the loss.
        y = y_normalizer.decode(y)
        out = y_normalizer.decode(out)
        l2 = myloss(out.reshape(n_batch, -1), y.reshape(n_batch, -1))
        l2.backward()

        optimizer.step()
        train_l2 += l2.item()

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    model.eval()
    val_l2 = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            n_batch = x.shape[0]

            # Validation targets are not encoded, so only the prediction is
            # decoded here.
            out = model(x).reshape(n_batch, s, s, T)
            out = y_normalizer.decode(out)
            val_l2 += myloss(out.reshape(n_batch, -1), y.reshape(n_batch, -1)).item()

    # Convert the accumulated totals into per-sample averages.
    train_l2 /= ntrain
    val_l2 /= nval

    t2 = default_timer()
    # epoch index, wall-clock seconds, mean training loss, mean validation loss
    print(ep, t2-t1, train_l2, val_l2)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (81920x4 and 5x5)