In [None]:
# Training an FNO for wave propagation

In [None]:
pip install torch

In [1]:
import torch
import numpy as np
import scipy.io
import h5py
import torch.nn as nn
from sklearn.model_selection import train_test_split

import operator
from functools import reduce
from functools import partial

import torch.nn.functional as F

import matplotlib.pyplot as plt

from timeit import default_timer

from torch.optim import Adam

from FNO2D import *
from utilities3 import *

torch.manual_seed(0)
np.random.seed(0)
torch.set_printoptions(precision=8)

In [2]:
# hyperparameters # 

batch_size = 2
learning_rate = 0.0025

epochs = 100
step_size = 100
gamma = 0.5

modes = 20
width = 5

s = 64
T = 10
ntrain = 80
nval = 20

In [3]:
# load the data
reader = MatReader('w_test.mat')

In [4]:
vel = reader.read_field('vel')
sr = reader.read_field('sr')

In [5]:
# split into train/ validation : 80 / 20
x_train, x_val, y_train, y_val = train_test_split(vel, sr, random_state=0, test_size=0.20)

In [6]:
x_train = x_train.reshape(ntrain,s,s,1)
x_val = x_val.reshape(nval,s,s,1)

# use one in every 53 timesteps 
y_train = y_train[:,0::53,:,:]
y_val = y_val[:,0::53,:,:]
print(y_train.shape)
print(y_val.shape)

y_train = np.transpose(y_train, (0, 2, 3,1))
y_val = np.transpose(y_val, (0, 2, 3,1))

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False)


torch.Size([80, 10, 64, 64])
torch.Size([20, 10, 64, 64])


In [7]:
# training the FNO

step = 1

model = FNO2d(modes, modes, width) #.cuda()

optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0
    for xx, yy in train_loader:
        loss = 0
        xx = xx.to(device)
        yy = yy.to(device)

        for t in range(0, T, 1):
            y = yy[..., t:t + 1]
            im = model(xx)
            loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1))

            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), -1)

            xx = torch.cat((xx[..., 1:], im), dim=-1)

        train_l2_step += loss.item()
        l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1))
        train_l2_full += l2_full.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for xx, yy in val_loader:
            loss = 0
            xx = xx.to(device)
            yy = yy.to(device)

            for t in range(0, T, 1):
                y = yy[..., t:t + 1]
                im = model(xx)
                loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1))

                if t == 0:
                    pred = im
                else:
                    pred = torch.cat((pred, im), -1)

                xx = torch.cat((xx[..., 1:], im), dim=-1)

            test_l2_step += loss.item()
            test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item()

    t2 = default_timer()
    scheduler.step()
    print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / nval / (T / step),
          test_l2_full / nval)

0 21.928659291006625 nan nan nan nan
1 21.668858940713108 nan nan nan nan
2 21.143709857016802 nan nan nan nan
3 21.301720654591918 nan nan nan nan
4 21.435412214137614 nan nan nan nan


KeyboardInterrupt: 