# Generative trajectory interpolation

Wan, Z., & Dodge, S. (2023, November). A Generative Trajectory Interpolation Method for Imputing Gaps in Wildlife Movement Data. In *Proceedings of the 1st ACM SIGSPATIAL International Workshop on AI-driven Spatio-temporal Data Analysis for Wildlife Conservation* (pp. 1-8).

## Preparation

In [None]:
# Import libraries
import os
import copy
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from itertools import chain

from tqdm import tqdm, trange

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    print("GPU is available:", torch.cuda.get_device_name())
    device = torch.device("cuda:0")
else:
    print("GPU is unavailable")
    # Without a GPU no "cuda:N" device can be used, so fall back to the CPU.
    device = torch.device("cpu")

In [None]:
# Project and sample directories.
# NOTE(review): the "..." values look like placeholders to be replaced with
# real paths before running the notebook - confirm.
proj_dir = "..."
sample_dir = "..."

## LSTM-GAN-based trajectory interpolation model

### Encoder

In [None]:
class EncoderLstm(nn.Module):
    '''
    LSTM path encoding module.

    Wraps a batch-first ``nn.LSTM`` and keeps its recurrent state ``(h, c)``
    on the instance (``self.lstm_h``), so the state produced by one call to
    ``forward`` is the initial state of the next call. Call ``init_lstm``
    before encoding a new, unrelated batch.

    Parameters:
        input_size: int
            number of features per time step
        hidden_size: int
            size of the LSTM hidden and cell states
        n_layers: int
            number of stacked LSTM layers
    '''
    def __init__(self, input_size, hidden_size, n_layers=1):
        # Initialise nn.Module first so attribute registration is set up
        super(EncoderLstm, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # The LSTM cell; inputs are [batch, seq_len, input_size]
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=n_layers, batch_first=True)
        # None lets nn.LSTM start from an all-zero state, so forward() also
        # works when init_lstm() has not been called yet
        self.lstm_h = None

    def init_lstm(self, h, c):
        '''Set the LSTM hidden state ``h`` and cell state ``c`` used by the next forward call.'''
        self.lstm_h = (h, c)

    def forward(self, x, batch_size=None):
        '''
        Run the LSTM over ``x`` starting from the stored state.

        Input:
            x: torch.tensor
                input passed to the batch-first LSTM
            batch_size: int or None
                unused; kept so that existing callers do not break

        Returns:
            (y, (h, c)): the LSTM outputs and the updated state, which is
            also stored in ``self.lstm_h``
        '''
        # Applies the LSTM over a whole sequence or over one single step
        y, self.lstm_h = self.lstm(x, self.lstm_h)

        return y, self.lstm_h

### Decoder

In [None]:
class DecoderLstm(nn.Module):
    '''
    LSTM path decoding module.

    One call to ``forward`` performs a single decoding step: the previous
    point is embedded, concatenated with the noise vector and the latent
    code, fed through one LSTM step, and mapped to ``output_dim`` values by a
    fully connected head. The recurrent state is kept in ``self.lstm_h``
    between calls; set it with ``init_lstm`` before decoding a new batch.

    Parameters:
        input_dim: int
            number of features of the decoder input (previous point)
        output_dim: int
            number of features produced per step
        hidden_dim: int
            embedding size and LSTM hidden size
        noise_len: int
            length of the noise vector z
        n_latent_codes: int
            length of the latent code vector
        slope: float
            negative slope of the LeakyReLU activations
    '''
    def __init__(self, input_dim, output_dim, hidden_dim, noise_len, n_latent_codes, slope=0.2):
        super(DecoderLstm, self).__init__()
        # Decoding LSTM
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.noise_len = noise_len
        self.n_latent_codes = n_latent_codes

        self.emd = nn.Linear(input_dim, hidden_dim)
        self.lstm = torch.nn.LSTM(hidden_dim + noise_len + n_latent_codes, hidden_dim, num_layers=1, batch_first=True)
        # Fully connected sub-network. Input is hidden_dim, output is output_dim.
        # NOTE(review): there is no activation between the last two linear
        # layers; left unchanged so existing trained weights keep their meaning.
        self.fc = nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(slope),
                                torch.nn.Linear(hidden_dim, hidden_dim // 2), nn.LeakyReLU(slope),
                                torch.nn.Linear(hidden_dim // 2, hidden_dim // 4),
                                torch.nn.Linear(hidden_dim // 4, output_dim))

        # None lets nn.LSTM start from an all-zero state until init_lstm() is
        # called (an empty list is not a valid LSTM state)
        self.lstm_h = None

    def init_lstm(self, h, c):
        '''Set the LSTM hidden state ``h`` and cell state ``c`` used by the next decoding step.'''
        self.lstm_h = (h, c)

    def forward(self, dec_inp, z, latent_code):
        '''
        Perform one decoding step.

        Input:
            dec_inp: torch.tensor
                previous point, reshaped to [batch_size, hidden_dim] after embedding
            z: torch.tensor
                noise; its first dimension defines the batch size
            latent_code: torch.tensor
                latent code, concatenated with the embedding and z along dim 1

        Returns:
            torch.tensor of shape [batch_size, output_dim]
        '''
        # batch size
        batch_size = z.shape[0]

        h = self.emd(dec_inp).view(batch_size, self.hidden_dim)
        # For each sample in the batch, concatenate h (embedding), z (noise), and latent_code
        inp = torch.cat([h, z, latent_code], dim=1)
        # Applies a forward step (sequence length 1).
        out, self.lstm_h = self.lstm(inp.unsqueeze(1), self.lstm_h)
        # Drop only the length-1 time axis; a bare squeeze() would also drop
        # the batch axis when batch_size == 1
        out = self.fc(out.squeeze(1))
        return out

### Discriminator

In [None]:
class LSTMDiscriminator(nn.Module):
    '''
    LSTM discriminator with latent code inference.

    The two observed segments and the (real or generated) gap are stitched
    into one sequence, encoded by an LSTM, and the final hidden state is
    mapped to (1) a real/fake score and (2) an estimate of the latent code.

    Parameters:
        input_dim: int
            number of features per time step
        hidden_dim: int
            LSTM hidden size
        n_latent_code: int
            length of the latent code to infer
        bidirectional: bool
            whether the LSTM encoder is bidirectional
        slope: float
            negative slope of the LeakyReLU in the latent code head
    '''
    def __init__(self, input_dim, hidden_dim, n_latent_code, bidirectional=False, slope=0.2):
        super(LSTMDiscriminator, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.lstm_encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional)
        self.n_latent_code = n_latent_code

        # number of LSTM directions
        self.bd = 2 if bidirectional else 1

        # real/fake score head
        self.hidden2label = nn.Sequential(
            nn.Linear(hidden_dim * self.bd, hidden_dim), nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 2, 1)
        )

        # latent code inference: input is the final LSTM hidden state
        # (hidden_dim per direction), output is n_latent_code values
        self.latent_decoder = nn.Sequential(nn.Linear(hidden_dim * self.bd, hidden_dim // 2), nn.LeakyReLU(slope),
                                            nn.Linear(hidden_dim // 2, n_latent_code))

    def forward(self, x1, x2, y):
        '''
        Score a stitched trajectory and infer its latent code.

        Input:
            x1, y, x2: torch.tensor
                segments concatenated along dim 1 in the order x1, y, x2

        Returns:
            label: torch.tensor [batch_size, 1], real/fake score
            code_hat: torch.tensor [batch_size, n_latent_code]
        '''
        batch_size = x1.shape[0]
        # Zero initial state on the same device/dtype as the input. One state
        # slice per direction is required by nn.LSTM.
        lstm_h_c0 = (torch.zeros(self.bd, batch_size, self.hidden_dim, device=x1.device, dtype=x1.dtype),
                     torch.zeros(self.bd, batch_size, self.hidden_dim, device=x1.device, dtype=x1.dtype))

        # construct whole seqs
        whole_seqs = torch.cat([x1, y, x2], dim=1)  # [batch_size, seq_len, input_dim]

        # use lstm hidden states for classification and code inference
        _, (hn, _) = self.lstm_encoder(whole_seqs, lstm_h_c0)
        # hn is [n_directions, batch_size, hidden_dim]; move the batch axis
        # first before flattening so that the directions of ONE sample are
        # concatenated (a plain view would mix different samples when bd == 2)
        hidden = hn.transpose(0, 1).reshape(batch_size, self.hidden_dim * self.bd)
        label = self.hidden2label(hidden)
        code_hat = self.latent_decoder(hidden)

        return label, code_hat

    def load(self, backup):
        '''
        Restore this discriminator from ``backup`` (a discriminator with the
        same architecture, e.g. a ``copy.deepcopy`` taken earlier).

        All parameters are restored, including the LSTM encoder; copying only
        the linear layers would leave the encoder in its later state.
        '''
        self.load_state_dict(backup.state_dict())

### Model

In [None]:
class LSTMGANIntp(nn.Module):
    '''
    LSTM-GAN trajectory interpolation model.

    Two LSTM encoders summarise the observed segments before (x1) and after
    (x2) a gap. Their final hidden and cell states are concatenated and
    projected by ``h_fc`` / ``c_fc`` to initialise an LSTM decoder, which
    generates the gap autoregressively for ``y_max_len`` steps starting from
    the last point of x1. A discriminator ``D`` scores the stitched sequence
    and regresses the latent code.

    The model runs on whichever device its parameters live on: a CUDA device
    when one is available at construction time, otherwise the CPU.

    Parameters:
        input_size: int
            number of features per observed point
        dec_output_dim: int
            number of features generated per point
        hidden_size: int
            hidden size shared by encoders, decoder and discriminator
        samp_int: int
            stored on the instance; not used inside this class
        n_lstm_layers: int
            number of encoder LSTM layers
            NOTE(review): the encoder states initialise the decoder LSTM
            directly, so this presumably has to match the decoder's layer
            count - confirm against DecoderLstm
        noise_len: int
            length of the noise vector
        y_max_len: int
            number of points generated for the gap
        disc_bd: bool
            whether the discriminator LSTM is bidirectional
        n_latent_codes: int
            length of the latent code
        n_unrolling_steps: int
            number of extra discriminator updates per batch (unrolled GAN)
        slope: float
            negative slope of the LeakyReLU activations
    '''
    def __init__(
            self, input_size=5, dec_output_dim=5, hidden_size=128, samp_int=300,
            n_lstm_layers=1, noise_len=128,
            y_max_len=100, disc_bd=False, n_latent_codes=2,
            n_unrolling_steps=10, slope=0.2
    ):
        super(LSTMGANIntp, self).__init__()
        self.input_size = input_size
        self.dec_output_dim = dec_output_dim
        self.hidden_size = hidden_size
        self.n_lstm_layers = n_lstm_layers

        self.samp_int = samp_int
        self.y_max_len = y_max_len
        self.noise_len = noise_len
        self.n_latent_codes = n_latent_codes
        self.disc_bd = disc_bd
        self.n_unrolling_steps = n_unrolling_steps

        self.encoder1 = EncoderLstm(input_size, hidden_size, n_lstm_layers)
        self.encoder2 = EncoderLstm(input_size, hidden_size, n_lstm_layers)
        self.decoder = DecoderLstm(input_size, dec_output_dim, hidden_size, noise_len, n_latent_codes, slope)
        self.D = LSTMDiscriminator(input_size, hidden_size, n_latent_codes, disc_bd, slope)

        # projections of the concatenated x1 and x2 encoder states
        self.h_fc = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size * 2), nn.LeakyReLU(slope),
            nn.Linear(hidden_size * 2, hidden_size)
        )
        self.c_fc = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size * 2), nn.LeakyReLU(slope),
            nn.Linear(hidden_size * 2, hidden_size)
        )

        # Move the WHOLE model to the GPU only when one exists, so that
        # construction also works on CPU-only machines and h_fc / c_fc end up
        # on the same device as the other sub-networks.
        if torch.cuda.is_available():
            self.cuda()

    # ------------------------------------------------------------------ helpers

    def _model_device(self):
        # Device the model parameters currently live on.
        return next(self.parameters()).device

    def _set_training(self, mode):
        # Switch the four sub-networks between training (True) and eval (False) mode.
        for net in (self.encoder1, self.encoder2, self.decoder, self.D):
            net.train(mode)

    def _soft_labels(self, batch_size, dev):
        # Smoothed discriminator targets: "fake" in [0, 0.1), "real" in [0.9, 1.0).
        zeros = (torch.zeros(batch_size, 1) + np.random.uniform(0, 0.1)).to(dev)
        ones = (torch.ones(batch_size, 1) * np.random.uniform(0.9, 1.0)).to(dev)
        return zeros, ones

    def _sample_noise_and_code(self, batch_size, dev):
        # Uniform [0, 1) noise and latent code, sampled on the CPU and then moved to dev.
        noise = torch.rand(batch_size, self.noise_len).to(dev)
        latent_code = torch.rand(batch_size, self.n_latent_codes).to(dev)
        return noise, latent_code

    @staticmethod
    def _n_used_batches(loader, batch_skip, batch_size):
        # Number of batches that contributed to the running sums.
        n_used = len(loader) - batch_skip
        if n_used <= 0:
            raise ValueError(
                f"No full batch of size {batch_size} was available: every batch was skipped."
            )
        return n_used

    def _discriminator_loss(self, x1, x2, y, y_fake, zeros, ones, latent_code, criterion, loss_info_w):
        # Discriminator objective: fake -> zeros, real -> ones, plus the
        # weighted latent code reconstruction error on the fake sample.
        fake_labels, code_hat = self.D(x1, x2, y_fake)  # classify fake samples
        d_loss_fake = criterion(fake_labels, zeros)
        d_loss_info = criterion(code_hat.view(latent_code.shape), latent_code)
        real_labels, _ = self.D(x1, x2, y)  # classify real samples
        d_loss_real = criterion(real_labels, ones)
        return d_loss_fake + d_loss_real + loss_info_w * d_loss_info

    def _generator_loss(self, x1, x2, y, y_pred, ones, latent_code, criterion, loss_info_w, loss_dist_w):
        # Generator objective: adversarial + weighted distance (first two
        # features) + attribute (remaining features) + weighted information loss.
        gen_labels, code_hat = self.D(x1, x2, y_pred)
        g_loss_dist = criterion(y_pred[:, :, :2], y[:, :, :2])
        g_loss_attr = criterion(y_pred[:, :, 2:], y[:, :, 2:])
        # Adversarial loss (classification labels should be close to one)
        g_loss_fooling = criterion(gen_labels, ones)
        g_loss_info = criterion(code_hat.view(latent_code.shape), latent_code)
        return g_loss_fooling + loss_dist_w * g_loss_dist + g_loss_attr + loss_info_w * g_loss_info

    # --------------------------------------------------------------- public API

    def predict(self, x1, x2, noise, latent_code):
        '''
        Generate the gap between x1 and x2.

        Input:
            x1: 1st part of the observed traj, [batch_size, len, input_size]
            x2: 2nd part of the observed traj
            noise: noise vector fed to the decoder at every step
            latent_code: latent code fed to the decoder at every step

        Returns:
            torch.tensor [batch_size, y_max_len, dec_output_dim]
        '''
        batch_size = x1.shape[0]
        dev = x1.device

        # Initial values for the hidden and cell states (zero)
        lstm_h_c = (torch.zeros(self.n_lstm_layers, batch_size, self.hidden_size, device=dev),
                    torch.zeros(self.n_lstm_layers, batch_size, self.hidden_size, device=dev))
        self.encoder1.init_lstm(lstm_h_c[0], lstm_h_c[1])
        self.encoder2.init_lstm(lstm_h_c[0], lstm_h_c[1])

        # encode the observed sequence
        _, x1_h = self.encoder1(x1, batch_size)
        _, x2_h = self.encoder2(x2, batch_size)

        # concatenation of both encodings, projected back to hidden_size
        enc_h_cat = (
            self.h_fc(torch.cat([x1_h[0], x2_h[0]], dim=2)),
            self.c_fc(torch.cat([x1_h[1], x2_h[1]], dim=2))
        )

        # last pt before gap is the first decoder input
        dec_inp = x1[:, -1, :]  # [batch_size, input_size]

        # init decoder state from the combined encoder states
        self.decoder.init_lstm(h=enc_h_cat[0], c=enc_h_cat[1])

        # For all the steps to predict, applies a step of the decoder and
        # feeds the output back as the next input
        pred_ls = []
        for _ in range(self.y_max_len):
            dec_out = self.decoder(dec_inp, noise, latent_code).view(batch_size, self.dec_output_dim)
            pred_ls.append(dec_out)
            dec_inp = dec_out

        return torch.stack(pred_ls, 1)


    def train_model(
            self, train_loader, epoch, train_batch_size, predictor_optimizer, D_optimizer,
            scaler_relxy_vxy, criterion=nn.MSELoss(), loss_info_w=0.5, loss_dist_w=1
    ):
        '''
        Train for one epoch.

        Batches with fewer than train_batch_size samples are skipped. For each
        batch the discriminator is updated n_unrolling_steps + 1 times, the
        generator once, and (when unrolling) the discriminator is then restored
        to its state after the first update.

        Returns:
            (d_loss, g_loss, ADE, FDE) averaged over the used batches; d_loss
            is additionally averaged over the discriminator updates.
        '''
        self._set_training(True)
        dev = self._model_device()

        # running sums
        d_loss_batchsum = 0.0
        g_loss_batchsum = 0.0
        ADE_batchsum = 0.0
        FDE_batchsum = 0.0
        batch_skip = 0

        for x1, x2, y, start_pts, x1_lens, x2_lens, y_lens in tqdm(train_loader):
            if x1.shape[0] < train_batch_size:  # not enough data
                batch_skip += 1  # skip the last batch
                continue

            # data sent to the model's device
            x1 = x1.to(dev)
            x2 = x2.to(dev)
            y = y.to(dev)

            # zero the gradient for each optimizer
            predictor_optimizer.zero_grad()
            D_optimizer.zero_grad()

            zeros, ones = self._soft_labels(train_batch_size, dev)
            noise, latent_code = self._sample_noise_and_code(train_batch_size, dev)

            # ============== Train Discriminator ================
            # The generator is not updated inside this loop, so its output is
            # the same at every unrolling step: compute it once.
            with torch.no_grad():
                y_fake = self.predict(x1, x2, noise, latent_code)

            for u in range(self.n_unrolling_steps + 1):
                # Zero the gradient buffers of all parameters
                self.D.zero_grad()
                d_loss = self._discriminator_loss(
                    x1, x2, y, y_fake, zeros, ones, latent_code, criterion, loss_info_w
                )
                d_loss.backward()  # update D
                D_optimizer.step()

                d_loss_batchsum += d_loss.item()  # loss.item() returns the loss as a float

                # keep the discriminator as it is after the first update
                if u == 0 and self.n_unrolling_steps > 0:
                    backup = copy.deepcopy(self.D)

            # =============== Train Generator =================
            # Zero the gradient buffers of the discriminator and the generator
            self.D.zero_grad()
            predictor_optimizer.zero_grad()
            # Applies a forward step of prediction (with gradients this time)
            y_pred = self.predict(x1, x2, noise, latent_code)

            g_loss = self._generator_loss(
                x1, x2, y, y_pred, ones, latent_code, criterion, loss_info_w, loss_dist_w
            )
            g_loss.backward()
            predictor_optimizer.step()

            g_loss_batchsum += g_loss.item()

            # undo the unrolled discriminator updates
            if self.n_unrolling_steps > 0:
                self.D.load(backup)
                del backup

            # calculate error
            with torch.no_grad():
                ADE, FDE = calc_error(y_pred.detach(), y, scaler=scaler_relxy_vxy)

            ADE_batchsum += ADE
            FDE_batchsum += FDE

        n_batches = self._n_used_batches(train_loader, batch_skip, train_batch_size)
        d_loss_batchavg = d_loss_batchsum / n_batches / (self.n_unrolling_steps + 1)
        g_loss_batchavg = g_loss_batchsum / n_batches
        ADE_batchavg = ADE_batchsum / n_batches
        FDE_batchavg = FDE_batchsum / n_batches

        # progress bar
        print(f"Epoch {epoch+1} | Training | d_loss {d_loss_batchavg:.4f} | g_loss {g_loss_batchavg:.4f}  \
              | ADE: {ADE_batchavg:.4f} | | FDE: {FDE_batchavg:.4f}")

        return d_loss_batchavg, g_loss_batchavg, ADE_batchavg, FDE_batchavg


    def eval_model(
            self, val_loader, epoch, val_batch_size, scaler_relxy_vxy, criterion=nn.MSELoss(),
            loss_info_w=0.5, loss_dist_w=1
    ):
        '''
        Evaluate on the validation set without updating any parameter.

        Batches with fewer than val_batch_size samples are skipped.

        Returns:
            (d_loss, g_loss, ADE, FDE) averaged over the used batches.
        '''
        self._set_training(False)
        dev = self._model_device()

        # running sums
        d_loss_batchsum = 0.0
        g_loss_batchsum = 0.0
        ADE_batchsum = 0.0
        FDE_batchsum = 0.0
        batch_skip = 0

        for x1, x2, y, start_pts, x1_lens, x2_lens, y_lens in tqdm(val_loader):
            if x1.shape[0] < val_batch_size:  # not enough data
                batch_skip += 1  # skip the last batch
                continue

            # data sent to the model's device
            x1 = x1.to(dev)
            x2 = x2.to(dev)
            y = y.to(dev)

            zeros, ones = self._soft_labels(val_batch_size, dev)
            noise, latent_code = self._sample_noise_and_code(val_batch_size, dev)

            with torch.no_grad():
                # One forward pass serves both losses: nothing is updated here,
                # so predicting again with the same inputs gives the same gap.
                y_pred = self.predict(x1, x2, noise, latent_code)

                # ============== Validate Discriminator ================
                d_loss = self._discriminator_loss(
                    x1, x2, y, y_pred, zeros, ones, latent_code, criterion, loss_info_w
                )
                d_loss_batchsum += d_loss.item()

                # =============== Validate Generator =================
                g_loss = self._generator_loss(
                    x1, x2, y, y_pred, ones, latent_code, criterion, loss_info_w, loss_dist_w
                )
                g_loss_batchsum += g_loss.item()

                # calculate error
                ADE, FDE = calc_error(y_pred, y, scaler=scaler_relxy_vxy)

            ADE_batchsum += ADE
            FDE_batchsum += FDE

        n_batches = self._n_used_batches(val_loader, batch_skip, val_batch_size)
        d_loss_batchavg = d_loss_batchsum / n_batches
        g_loss_batchavg = g_loss_batchsum / n_batches
        ADE_batchavg = ADE_batchsum / n_batches
        FDE_batchavg = FDE_batchsum / n_batches

        # progress bar
        print(f"Validating | d_loss {d_loss_batchavg:.4f} | g_loss {g_loss_batchavg:.4f}  \
              | ADE: {ADE_batchavg:.4f} | | FDE: {FDE_batchavg:.4f}")

        return d_loss_batchavg, g_loss_batchavg, ADE_batchavg, FDE_batchavg


    def test_model(self, test_loader, epoch, test_batch_size, scaler_relxy_vxy):
        '''
        Compute ADE / FDE on the test set.

        Batches with fewer than test_batch_size samples are skipped (and
        reported on stdout).

        Returns:
            (ADE, FDE) averaged over the used batches.
        '''
        self._set_training(False)
        dev = self._model_device()

        # running sums
        ADE_batchsum = 0.0
        FDE_batchsum = 0.0
        batch_skip = 0

        for x1, x2, y, start_pts, x1_lens, x2_lens, y_lens in tqdm(test_loader):
            if x1.shape[0] < test_batch_size:  # not enough data
                batch_skip += 1  # skip the last batch
                print(f"Batch skip. x1 shape: {x1.shape[0]}")
                continue

            # data sent to the model's device
            x1 = x1.to(dev)
            x2 = x2.to(dev)
            y = y.to(dev)

            noise, latent_code = self._sample_noise_and_code(test_batch_size, dev)

            # Applies a forward step of prediction
            with torch.no_grad():
                y_pred = self.predict(x1, x2, noise, latent_code)

                # calculate error
                ADE, FDE = calc_error(y_pred, y, scaler=scaler_relxy_vxy)
                ADE_batchsum += ADE
                FDE_batchsum += FDE

        n_batches = self._n_used_batches(test_loader, batch_skip, test_batch_size)
        ADE_batchavg = ADE_batchsum / n_batches
        FDE_batchavg = FDE_batchsum / n_batches

        # progress bar
        print(f"Testing | Epoch {epoch+1} | ADE: {ADE_batchavg:.4f} | | FDE: {FDE_batchavg:.4f}")

        return ADE_batchavg, FDE_batchavg


    def model_predict(self, x1, x2, latent_code=None, noise=None, y=None, scaler=None):
        '''
        Interpolate the gap for a batch of samples.

        Input:
            x1, x2: torch.tensor
                observed segments before / after the gap
            latent_code, noise: torch.tensor or None
                sampled uniformly in [0, 1) when None
            y: torch.tensor or None
                ground truth gap; when given, ADE / FDE are printed
            scaler: object with inverse_transform or None
                scaler passed to calc_error; when None the notebook-level
                ``scaler_relxy_vxy`` is used, as before

        Returns:
            torch.tensor [batch_size, y_max_len, dec_output_dim]
        '''
        self._set_training(False)
        dev = self._model_device()

        # data sent to the model's device
        x1 = x1.to(dev)
        x2 = x2.to(dev)

        pred_batch_size = x1.shape[0]
        if noise is None:
            noise = torch.rand(pred_batch_size, self.noise_len)
        if latent_code is None:
            latent_code = torch.rand(pred_batch_size, self.n_latent_codes)
        # user supplied tensors are moved as well, so they may live anywhere
        noise = noise.to(dev)
        latent_code = latent_code.to(dev)

        with torch.no_grad():
            y_pred = self.predict(x1, x2, noise, latent_code)

        if y is not None:
            y = y.to(dev)
            if scaler is None:
                # backward compatible fallback to the global scaler
                scaler = scaler_relxy_vxy
            ADE, FDE = calc_error(y_pred, y, scaler=scaler)
            print(f"ADE: {ADE}, FDE: {FDE}")

        return y_pred

## Loss/Error

In [None]:
def calc_error(y_pred, y_gt, scaler=None):
    '''
    Average displacement error (ADE)
    Final displacement error (FDE)

    Both errors are computed on the first two features of each point. The
    value is the root mean squared error of each of the two features, averaged
    over the two features (this is what scikit-learn 1.x
    ``mean_squared_error(..., squared=False)`` returned for two output
    columns). ADE uses all points of all sequences, FDE only the last point of
    each sequence.

    Input:
        y_pred: torch.tensor
            prediction, [batch_size, seq_len, attr_dim]
        y_gt: torch.tensor
            ground truth, same shape as y_pred
        scaler: object with ``inverse_transform`` or None
            when given, it is applied to the [n_points, attr_dim] arrays
            before the errors are computed

    Returns:
        (ade, fde): tuple of float
    '''
    batch_size, seq_len, attr_dim = y_pred.shape

    def _to_numpy(ts):
        return ts.detach().cpu().numpy()

    def _mean_axis_rmse(a, b):
        # RMSE per column, then the plain average over the columns
        return float(np.sqrt(np.mean((a - b) ** 2, axis=0)).mean())

    # last pts
    y_pred_last_pts = _to_numpy(y_pred[:, -1, :])
    y_gt_last_pts = _to_numpy(y_gt[:, -1, :])

    # convert 3d tensor to 2d array (reshape also accepts non-contiguous tensors)
    y_pred = _to_numpy(y_pred.reshape(batch_size * seq_len, attr_dim))
    y_gt = _to_numpy(y_gt.reshape(batch_size * seq_len, attr_dim))

    if scaler is not None:
        # undo the scaling applied during preprocessing
        y_gt_last_pts = scaler.inverse_transform(y_gt_last_pts)
        y_pred_last_pts = scaler.inverse_transform(y_pred_last_pts)
        y_pred = scaler.inverse_transform(y_pred)
        y_gt = scaler.inverse_transform(y_gt)

    fde = _mean_axis_rmse(y_gt_last_pts[:, :2], y_pred_last_pts[:, :2])
    ade = _mean_axis_rmse(y_gt[:, :2], y_pred[:, :2])

    return ade, fde

## Trainer

In [None]:
class Trainer():
    '''
    Training driver for the LSTM-GAN interpolation model.

    Holds the data loaders, the two Adam optimizers (generator side:
    encoders, decoder and the h_fc / c_fc projections; discriminator side:
    model.D) and the bookkeeping of the per-epoch losses.

    Parameters:
        model: the model whose parameters are optimised
        train_loader, val_loader: data loaders passed to the model
        scaler_relxy_vxy: scaler forwarded to the model for ADE / FDE
        save_dir: directory (relative to the global ``proj_dir``) for checkpoints
        loss_dist_w: weight of the distance loss of the generator
        criterion: loss module used for all loss terms
        n_epochs: number of epochs run by ``fit``
        patience: number of initial epochs during which no checkpoint is
            written, even when the validation ADE improves
        lr_g, lr_d: learning rates of the generator / discriminator optimizers
        train_batch_size, val_batch_size: batch sizes forwarded to the model
    '''
    def __init__(
            self, model, train_loader, val_loader, scaler_relxy_vxy, save_dir, loss_dist_w, criterion=nn.MSELoss(),
            n_epochs=200, patience=1, lr_g=1e-3, lr_d=1e-4, train_batch_size=32, val_batch_size=256
    ):
        self.scaler_relxy_vxy = scaler_relxy_vxy
        self.loss_dist_w = loss_dist_w

        self.n_epochs = n_epochs
        self.patience = patience

        self.save_dir = save_dir

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.criterion = criterion.to(device)
        self.predictor_optimizer = optim.Adam(
            chain(
                model.encoder1.parameters(), model.encoder2.parameters(), model.decoder.parameters(),
                model.h_fc.parameters(), model.c_fc.parameters()
            )
            , lr=lr_g, betas=(0.9, 0.999)
        )
        self.D_optimizer = optim.Adam(model.D.parameters(), lr=lr_d, betas=(0.9, 0.999))
        self.lr_g = lr_g
        self.lr_d = lr_d


    def fit(self, model, model_save_prefix, min_val_loss=None):
        '''
        Train ``model`` for ``n_epochs`` epochs.

        After each epoch the validation ADE is compared with the best value so
        far; on improvement (and once ``patience`` epochs have passed) the
        whole model is saved with ``torch.save`` under
        ``proj_dir/save_dir/<model_save_prefix>_epoch..._trainADE..._valADE....pt``.

        Input:
            model: the model to train (moved to the global ``device``)
            model_save_prefix: str, prefix of the checkpoint file names
            min_val_loss: float or None, initial best validation ADE
                (None means +inf)

        Returns:
            pandas.DataFrame with one row per epoch (losses, ADE / FDE and the
            checkpoint name or "no_save"); also stored in ``self.loss_log_df``.
        '''
        print("Training model...")
        model = model.to(device)

        # initialize loss logger
        if min_val_loss is None:
            min_val_loss = float('inf')
        self.best_models = []
        loss_log = []
        model_save_name_log = []

        for epoch in range(self.n_epochs):
            # use the criterion configured on the trainer instead of a fresh MSELoss
            train_d_loss, train_g_loss, train_ADE, train_FDE = model.train_model(
                self.train_loader, epoch, self.train_batch_size, self.predictor_optimizer, self.D_optimizer,
                self.scaler_relxy_vxy, criterion=self.criterion, loss_info_w=0.5, loss_dist_w=self.loss_dist_w
            )
            val_d_loss, val_g_loss, val_ADE, val_FDE = model.eval_model(
                self.val_loader, epoch, self.val_batch_size, self.scaler_relxy_vxy, criterion=self.criterion,
                loss_info_w=0.5, loss_dist_w=self.loss_dist_w
            )

            # save to log
            loss_log.append(
                [train_d_loss, train_g_loss, train_ADE, train_FDE,
                 val_d_loss, val_g_loss, val_ADE, val_FDE]
            )

            # compare to min_val_loss
            if val_ADE >= min_val_loss:
                model_save_name_log.append("no_save")
                continue

            min_val_loss = val_ADE
            self.best_models.append([epoch, val_d_loss, val_g_loss, val_ADE, val_FDE])

            if epoch > self.patience-1:
                model_save_name = model_save_prefix + f"_epoch{epoch+1}_trainADE{train_ADE:.4f}_valADE{val_ADE:.4f}.pt"

                # save model; create the target directory first so that a
                # missing folder does not abort a long training run
                model_save_dir = os.path.join(proj_dir, self.save_dir)
                os.makedirs(model_save_dir, exist_ok=True)
                model_save_path = os.path.join(model_save_dir, model_save_name)
                torch.save(model, model_save_path)

                print(f"val_ADE decreased to {val_ADE:.4f} at epoch {epoch+1}. Model saved.")
                model_save_name_log.append(model_save_name)

            else:
                print(f"val_ADE decreased to {val_ADE:.4f} at epoch {epoch+1}. No save.")
                model_save_name_log.append("no_save")

        # building the frame from the list also works when no epoch was run
        loss_log_df = pd.DataFrame(loss_log, columns=[
            "train_d_loss", "train_g_loss", "train_ADE", "train_FDE",
            'val_d_loss', 'val_g_loss', 'val_ADE', 'val_FDE'
        ])
        loss_log_df["save_name"] = model_save_name_log
        self.loss_log_df = loss_log_df

        return loss_log_df

    def plot_train_log(self, loss_log_df=None):
        '''
        Plot the training / validation losses (left) and ADE / FDE (right).

        Input:
            loss_log_df: pandas.DataFrame or None
                log as returned by ``fit``; None uses ``self.loss_log_df``
        '''
        # "== None" on a DataFrame is element-wise and cannot be used in an if
        if loss_log_df is None:
            loss_log_df = self.loss_log_df

        f, ax = plt.subplots(1, 2, figsize=(12, 8))
        # x axis follows the frame that is actually plotted
        x = np.arange(len(loss_log_df))

        # loss
        ax[0].plot(x, loss_log_df['train_g_loss'], color='dodgerblue', label='Train generator loss')
        ax[0].plot(x, loss_log_df['train_d_loss'], color='coral', label='Train discriminator loss')
        ax[0].plot(x, loss_log_df['val_g_loss'], color='blue', label='Val generator loss')
        ax[0].plot(x, loss_log_df['val_d_loss'], color='red', label='Val discriminator loss')

        # ade, fde
        ax[1].plot(x, loss_log_df['train_ADE'], color='dodgerblue', label='Train ADE')
        ax[1].plot(x, loss_log_df['train_FDE'], color='coral', label='Train FDE')
        ax[1].plot(x, loss_log_df['val_ADE'], color='blue', label='Val ADE')
        ax[1].plot(x, loss_log_df['val_FDE'], color='red', label='Val FDE')

        ax[0].legend(loc="best")
        ax[1].legend(loc="best")

        ax[0].set_xlabel("Epoch")
        ax[0].set_ylabel("Loss")
        ax[1].set_xlabel("Epoch")
        ax[1].set_ylabel("Displacement error (m)")

        ax[0].set_title("Loss")
        ax[1].set_title("Displacement error")

## Dataset in pytorch format
### Trajectory dataset class

All tensors must use dtype `float32` (`torch.float`) rather than `float64` (`torch.double`), so that the input data match the dtype of the model parameters.

In [None]:
def tensorize_input(X1_ls, X2_ls, Y_gap_ls):
    '''
    Convert the three sample lists to float tensors and record their lengths.

    Input:
        X1_ls, X2_ls, Y_gap_ls: indexable collections with one entry per sample

    Returns:
        six lists, in this order: X1 tensors, X2 tensors, Y tensors,
        X1 lengths, X2 lengths, Y lengths
    '''
    sources = (X1_ls, X2_ls, Y_gap_ls)
    # one output list per source, for the tensors and for the lengths
    tensor_lists = ([], [], [])
    length_lists = ([], [], [])

    # the number of samples is taken from the first list
    for idx in trange(len(X1_ls)):
        for src, tensors, lengths in zip(sources, tensor_lists, length_lists):
            item = src[idx]
            tensors.append(torch.tensor(item, dtype=torch.float))
            lengths.append(len(item))

    return (*tensor_lists, *length_lists)

In [None]:
class IntpTrajDataset(Dataset):
    '''
    Dataset of (X1, X2, Y) trajectory samples, zero-padded to common lengths.

    X1 and X2 are padded to the longest sequence found in either of them,
    Y to the longest gap. The original (unpadded) lengths and the gap start
    points are kept alongside.
    '''
    def __init__(self, X1_ls, X2_ls, Y_gap_ls, start_xy_arr):
        x1_seqs, x2_seqs, y_seqs, x1_lengths, x2_lengths, y_lengths = tensorize_input(X1_ls, X2_ls, Y_gap_ls)

        # target lengths: one shared by X1 / X2, one for Y
        longest_x = max([0] + [len(seq) for seq in x1_seqs] + [len(seq) for seq in x2_seqs])
        longest_y = max([0] + [len(seq) for seq in y_seqs])

        def _extend_rows(seq, target_len):
            # append zero rows at the end of a [len, n_features] tensor
            n_missing = target_len - seq.shape[0]
            return nn.ConstantPad2d((0, 0, 0, n_missing), 0)(seq)

        # Padding the first sequence to the target length makes pad_sequence
        # bring every other sequence to that same length.
        x1_seqs[0] = _extend_rows(x1_seqs[0], longest_x)
        x2_seqs[0] = _extend_rows(x2_seqs[0], longest_x)
        y_seqs[0] = _extend_rows(y_seqs[0], longest_y)

        self.X1 = pad_sequence(x1_seqs, batch_first=True)
        self.X2 = pad_sequence(x2_seqs, batch_first=True)
        self.Y = pad_sequence(y_seqs, batch_first=True)
        print(f"padded shape: X1: {self.X1.shape}, X2: {self.X2.shape}, Y: {self.Y.shape}")

        # start_xy_arr is cast to float before being turned into a tensor
        self.start_pts = torch.tensor(start_xy_arr.astype(float), dtype=torch.float)

        # unpadded lengths
        self.X1_lens = torch.tensor(x1_lengths, dtype=torch.int)
        self.X2_lens = torch.tensor(x2_lengths, dtype=torch.int)
        self.Y_lens = torch.tensor(y_lengths, dtype=torch.int)

    def __len__(self):
        # number of samples
        return self.X1.shape[0]

    def __getitem__(self, i):
        # (X1, X2, Y, start point, X1 length, X2 length, Y length) of sample i
        padded = tuple(seqs[i, :, :] for seqs in (self.X1, self.X2, self.Y))
        lengths = (self.X1_lens[i], self.X2_lens[i], self.Y_lens[i])
        return (*padded, self.start_pts[i, :], *lengths)

### Training, validation, and test

In [None]:
def train_val_test_set(X1_ls, X2_ls, Y_gap_ls, start_xy_arr, train_pct=0.7, test_pct=0.33, random_state=None):
    '''
    Split the samples into training, validation and test datasets.

    The samples are first split into a training part (train_pct) and a
    held-out part; the held-out part is then split into validation and test.

    Input:
        X1_ls, X2_ls, Y_gap_ls, start_xy_arr: aligned collections with one
            entry per sample
        train_pct: float in (0, 1)
            fraction of all samples used for training
        test_pct: float in (0, 1)
            fraction of the HELD-OUT samples used for testing,
            i.e. test / (val + test)
        random_state: int, RandomState instance or None
            forwarded to both train_test_split calls; pass an int to get a
            reproducible split

    Returns:
        (train_set, val_set, test_set): IntpTrajDataset instances

    Raises:
        ValueError: if train_pct or test_pct is not strictly between 0 and 1
    '''
    if not 0 < train_pct < 1:
        raise ValueError(f"train_pct must be in (0, 1), got {train_pct}")
    if not 0 < test_pct < 1:
        raise ValueError(f"test_pct must be in (0, 1), got {test_pct}")

    # train vs. held-out (validation + test)
    val_test_pct = 1 - train_pct
    X1_train, X1_test, X2_train, X2_test, Y_train, Y_test, start_xy_arr_train, start_xy_arr_test = train_test_split(
        X1_ls, X2_ls, Y_gap_ls, start_xy_arr, test_size=val_test_pct, random_state=random_state
    )
    # held-out -> validation vs. test
    X1_val, X1_test, X2_val, X2_test, Y_val, Y_test, start_xy_arr_val, start_xy_arr_test = train_test_split(
        X1_test, X2_test, Y_test, start_xy_arr_test, test_size=test_pct, random_state=random_state
    )

    print("Train, validation, test split done! Next: create PyTorch datasets")
    # PyTorch dataset
    train_set = IntpTrajDataset(X1_train, X2_train, Y_train, start_xy_arr_train)
    val_set = IntpTrajDataset(X1_val, X2_val, Y_val, start_xy_arr_val)
    test_set = IntpTrajDataset(X1_test, X2_test, Y_test, start_xy_arr_test)

    return train_set, val_set, test_set

## Experiments

In [None]:
# Read the preprocessed tracking points and the fitted scaler from a pickle.
# NOTE: pickle.load executes arbitrary code from the file - only load
# pickles that come from a trusted source.
pk_name = "...stork.pickle"  # preprocessed stork tracking dataset in the pickle format
with open(os.path.join(proj_dir, pk_name), 'rb') as my_file_obj:
    pts_df, scaler_relxy_vxy = pickle.load(my_file_obj)

print(f"Number of tracking points: {len(pts_df)}")
pts_df.head()

In [None]:
# Sample length used to select the sample file below
seq_len = 200

# Number of points to generate for the gap: half of seq_len, rounded up to
# the next even number
y_max_len = seq_len // 2
if y_max_len % 2 != 0:
    y_max_len += 1

# Read the samples from a pickle.
# NOTE: pickle.load executes arbitrary code from the file - only load
# pickles that come from a trusted source.
pk_name = f"...{seq_len}.pickle"  # samples containing lists of X1, X2, Y, and gap start pt
with open(os.path.join(sample_dir, pk_name), 'rb') as my_file_obj:
    X1_ls, X2_ls, Y_gap_ls, start_xy_arr = pickle.load(my_file_obj)

print(f"Number of samples: {len(X1_ls)}")

In [None]:
# Create the training, validation and test datasets.
# NOTE(review): create_toy_dataset is not defined in the visible part of this
# notebook - confirm it exists before running this cell.
train_set, val_set, test_set = create_toy_dataset(
    X1_ls, X2_ls, Y_gap_ls, start_xy_arr, size=20000, train_pct=0.7
)

print(f"Train set size: {len(train_set)}")
print(f"Val set size: {len(val_set)}")
print(f"Test set size: {len(test_set)}")

# Save the three datasets together in one pickle
pk_name = f"....pickle"
with open(os.path.join(proj_dir, "...", pk_name), 'wb') as my_file_obj:
    pickle.dump([train_set, val_set, test_set], my_file_obj)

In [None]:
# Read previously saved train / validation / test datasets from a pickle.
# NOTE: pickle.load executes arbitrary code from the file - only load
# pickles that come from a trusted source.
pk_name = f"....pickle"
with open(os.path.join(sample_dir, pk_name), 'rb') as my_file_obj:
    train_set, val_set, test_set = pickle.load(my_file_obj)

print(f"Train set size: {len(train_set)}")
print(f"Val set size: {len(val_set)}")
print(f"Test set size: {len(test_set)}")

In [None]:
# Batch sizes and data loaders for the three datasets
train_batch_size = 32
val_batch_size = 64
test_batch_size = 32

# Only the training loader shuffles its samples
train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=val_batch_size)
test_loader = DataLoader(test_set, batch_size=test_batch_size)

### Training example: with distance loss weight 1

In [None]:
# Number of points to generate for the gap: half of seq_len, rounded up to
# the next even number
y_max_len = seq_len // 2
if y_max_len % 2 != 0:
    y_max_len += 1

# Build the model
gan_intp_model = LSTMGANIntp(
    input_size=5, dec_output_dim=5, hidden_size=128, n_lstm_layers=1, noise_len=32,
    y_max_len=y_max_len, disc_bd=False, n_latent_codes=2, n_unrolling_steps=10, slope=0.2
)

# Build the trainer; loss_dist_w=1 is the distance loss weight of this example
trainer = Trainer(
    gan_intp_model, train_loader, val_loader, scaler_relxy_vxy, save_dir=f"...", loss_dist_w=1,
    criterion=nn.MSELoss(), n_epochs=100, patience=3, lr_g=1e-3, lr_d=1e-4, train_batch_size=train_batch_size, val_batch_size=val_batch_size
)

# Train and collect the per-epoch log
model_save_prefix = f"..."
loss_log_df = trainer.fit(gan_intp_model, model_save_prefix)

# Save the training log to a pickle
pk_name = f"....pickle"
with open(os.path.join(proj_dir, f"...", pk_name), 'wb') as my_file_obj:
    pickle.dump(loss_log_df, my_file_obj)

# print the best model info (epoch with the lowest validation ADE)
ade_min_idx = loss_log_df['val_ADE'].idxmin()
print(f"Min validation ADE {loss_log_df.loc[ade_min_idx, 'val_ADE']} at epoch {ade_min_idx+1}")
print(f"with training ADE {loss_log_df.loc[ade_min_idx, 'train_ADE']}")