# Training the model

This notebook explains, in practice, how to train the end-to-end model from the master's thesis “Fast, Accurate, and Scalable Numerical Wave Propagation: Enhancement by Deep Learning” by Luis Kaiser, supervised by Prof. Tsai (University of Texas at Austin) and Prof. Klingenberg (University of Wuerzburg).

First, make sure that you have installed all necessary libraries specified in `requirements.txt` by running the command below (use `pip` or `pip3`, depending on your setup).

In [None]:
# Upgrade pip itself, then install every package listed in requirements.txt.
# Replace pip3 with pip if that is the installer name in your environment.
!pip3 install --upgrade pip
!pip3 install -r requirements.txt

We made this notebook easier to read, with fewer code dependencies than the original implementation. We first load the data, then set up the model and optimizer, and lastly train and evaluate the model. By default, we use `datagen_test.npz` as training data and `datagen_test2.npz` as validation data. Please note that the model does not train well on small datasets.

In [None]:
import torch
from model import save_model, Model_end_to_end
from config import get_params
from generate_data import fetch_data_end_to_end
import random
import numpy as np


def _as_path_list(paths):
    '''
    Normalize a single data path or an iterable of data paths to a list.

    Parameters
    ----------
    paths : (string, os.PathLike, or iterable of those) one data file or several

    Returns
    -------
    list of paths; a single path is wrapped in a one-element list, which is
    exactly what the original implementation passed to the data loader
    '''
    import os  # local import keeps this notebook cell self-contained

    if isinstance(paths, (str, bytes, os.PathLike)):
        return [paths]
    return list(paths)


def _train_one_epoch(model, loader, optimizer, loss_f, device):
    '''
    Run one training epoch with the one-step loss and return the per-batch losses.

    Parameters
    ----------
    model : end-to-end model, called as model(input_tensor, data[:, input_idx, 4])
    loader : iterable of batches; batch[0] is a tensor indexed as
        (batch, snapshot, channel, ...). Channels :3 are the prediction targets,
        channel 3 is the velocity profile, and channel 4 is passed as the model's
        second argument (its physical meaning is not visible here -- TODO confirm).
    optimizer : torch optimizer over the model's parameters
    loss_f : loss function comparing a prediction with its label
    device : torch.device each batch is moved to

    Returns
    -------
    list of floats, the mean step loss of each batch

    Raises
    ------
    ValueError : if a batch holds fewer than two snapshots, since then no
        (input, label) pair exists
    '''
    model.train()
    batch_losses = []

    for batch in loader:
        snapshots = batch[0].to(device)  # use GPUs if available for faster training
        n_snaps = snapshots.shape[1]  # number of snapshots defined by data input
        if n_snaps < 2:
            raise ValueError(
                f"need at least 2 snapshots per sample to form an (input, label) pair, got {n_snaps}"
            )

        step_losses = []  # losses of this batch, backpropagated together below

        # Draw n_snaps start indices (with replacement) from [0, ..., n_snaps - 2].
        # Every start index needs a successor snapshot as its one-step label, so
        # n_snaps - 2 is the last valid start. Using range(n_snaps - 2) would never
        # train on the final transition and would fail outright for n_snaps == 2.
        # i.e. (input_idx * dt_star) is the point in time where we start using the data
        for input_idx in random.choices(range(n_snaps - 1), k=n_snaps):
            # detach so the computation graph does not reach back into the data
            input_tensor = snapshots[:, input_idx].detach()

            # One-step loss. The loop is kept to show that "input_idx + 2" can be
            # replaced by a larger bound to obtain a multi-step loss (cf. paper).
            for label_idx in range(input_idx + 1, input_idx + 2):
                label = snapshots[:, label_idx, :3]
                output = model(input_tensor, snapshots[:, input_idx, 4].detach())
                step_losses.append(loss_f(output, label))

                # IMPORTANT: feed the prediction (plus velocity channel 3) back in
                # as the next input if a multi-step loss is used
                input_tensor = torch.cat((output, input_tensor[:, 3].unsqueeze(dim=1)), dim=1)

        optimizer.zero_grad()
        sum(step_losses).backward()
        optimizer.step()

        # mean step loss of this batch, for the epoch summary
        batch_losses.append(torch.stack(step_losses).detach().mean().item())

    return batch_losses


def _validate(model, loader, loss_f, device):
    '''
    Roll the model out from each sample's first snapshot and return all step losses.

    Parameters
    ----------
    model : end-to-end model, called as model(input_tensor, data[:, 0, 4])
    loader : iterable of batches laid out like the training batches
        (batch[0] indexed as (batch, snapshot, channel, ...))
    loss_f : loss function comparing a prediction with its label
    device : torch.device each batch is moved to

    Returns
    -------
    list of floats, one loss per predicted time step per batch
    '''
    model.eval()
    step_losses = []

    with torch.no_grad():
        for batch in loader:
            snapshots = batch[0].to(device)  # use GPUs if available for faster execution
            n_snaps = snapshots.shape[1]

            input_tensor = snapshots[:, 0].detach()
            vel = input_tensor[:, 3].unsqueeze(dim=1)  # velocity profile, reused every step

            # advance the wave field (n_snaps - 1) times, each time from the
            # previous prediction rather than from ground truth
            for label_idx in range(1, n_snaps):
                label = snapshots[:, label_idx, :3]
                # NOTE(review): channel 4 is always taken from snapshot 0 here,
                # whereas training uses the start snapshot's channel 4
                output = model(input_tensor, snapshots[:, 0, 4].detach())
                step_losses.append(loss_f(output, label).item())
                input_tensor = torch.cat((output, vel), dim=1)

    return step_losses


def train_model(
    model_name = "test",
    lr = .001,
    batch_size = 1,
    n_epochs = 10,
    downsampling_model = "Interpolation",
    upsampling_model = "UNet3",
    data_paths = "data/datagen_test.npz",
    val_paths = "data/datagen_test2.npz"
):
    '''
    Train the end-to-end model and save its parameters.

    Parameters
    ----------
    model_name : (string) name of model, used as name for output-file containing model parameters
    lr : (float) learning rate of model
    batch_size : (int) batch size
    n_epochs : (int) number of iterations model sees all data; i.e. amount of epochs model is trained
    downsampling_model : (string) name of downsampling model
    upsampling_model : (string) name of upsampling model
    data_paths : (string or list of strings) path(s) of the training data file(s)
    val_paths : (string or list of strings) path(s) of the validation data file(s)

    Returns
    -------
    None; trained model parameters are written via `save_model` as a ".pt"-file
    into the "results/" directory
    '''
    # model setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    param_dict = get_params()  # dict with all model specifications such as numerical params
    model = Model_end_to_end(param_dict, downsampling_model, upsampling_model)
    model = torch.nn.DataParallel(model).to(device)  # multi-GPU use

    # data setup
    train_loader, val_loader, _ = fetch_data_end_to_end(
        _as_path_list(data_paths), batch_size, _as_path_list(val_paths)
    )

    # deep learning setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_f = torch.nn.MSELoss()

    for epoch in range(n_epochs):
        train_losses = _train_one_epoch(model, train_loader, optimizer, loss_f, device)
        val_losses = _validate(model, val_loader, loss_f, device)

        print(f'epoch {epoch + 1}, train loss: {np.array(train_losses).mean():.5f}, '
              f'test loss: {np.array(val_losses).mean():.5f}')

    # save model parameters as a ".pt"-file in the "results/" directory
    save_model(model, model_name, "results/")

In [None]:
# Train with all default arguments of train_model (model name "test", 10 epochs,
# the default data files); the trained parameters are saved under "results/".
train_model()