In [None]:
# Training an FNO for wave propagation

In [None]:
pip install torch

In [1]:
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn
from sklearn.model_selection import train_test_split

import operator
from functools import reduce
from functools import partial

import torch.nn.functional as F

import matplotlib.pyplot as plt

from timeit import default_timer

from torch.optim import Adam

from FNO2D import *
from utilities3 import *

torch.manual_seed(0)
np.random.seed(0)
torch.set_printoptions(precision=8)

In [5]:
# hyperparameters # 

batch_size = 20
learning_rate = 0.0025

epochs = 100
step_size = 100
gamma = 0.5

modes = 64
width = 5

s = 64
T = 530
ntrain = 80
nval = 20

In [3]:
# load the data
reader = MatReader('w_test.mat')
vel = reader.read_field('vel')
sr = reader.read_field('sr')


In [4]:
# split into train/ validation : 80 / 20
x_train, x_val, y_train, y_val = train_test_split(vel, sr, random_state=0, test_size=0.20)

In [7]:
# data normalisation
x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_val = x_normalizer.encode(x_val)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

x_train = x_train.reshape(ntrain,s,s,1)
x_val = x_val.reshape(nval,s,s,1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False)


In [8]:
# training the FNO
# Unable to train 2D FNO with both spatial and time dimensions
# uncomment .cuda() sections if cuda is available to improve performance time

model = FNO2d(modes, modes, width) #.cuda()

optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)
t1 = default_timer()

# y_normalizer.cuda()
for ep in range(epochs):
    model.train()
    train_l2 = 0
    
    for x, y in train_loader:
        # x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x).reshape(batch_size, s, s, T)
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)

        loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1))
        
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    val_l2 = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            # x, y = x.cuda(), y.cuda()

            out = model(x).reshape(batch_size, s, s, T)
            out = y_normalizer.decode(out)

            val_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item()

    train_l2/= ntrain
    val_l2 /= nval

    t2 = default_timer()
    print("Ep:",ep,"Time:", t2-t1, "Train L:", train_l2, "Val L:", val_l2)

RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [20, 5, 64, 37]->[20, 1, 64, 37, 5] [5, 5, 64, 64]->[1, 5, 64, 64, 5]