In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataloader import load_train, load_val
import matplotlib.pyplot as plt
from tqdm import tqdm
import training
import numpy as np
from model_autoregressive import Seq2SeqLSTM

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# are picked up on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Ask the project's training module which compute device to use; the result is
# later passed to `.to(device)` for the model and for the validation batches.
# NOTE(review): presumably returns a torch.device or device string (CUDA/MPS/CPU) — confirm.
device = training.determine_device()

In [None]:
# Hyperparameters
# NOTE(review): none of these four values is read by the visible cells — each entry
# in `configs` carries its own epochs / hidden_size / num_layers, and the learning
# rate comes from `defaults` / per-config `optimizer_args`. They look like leftovers
# from a single-model version of this notebook; confirm before relying on them.
num_epochs = 5
lr = 1e-3
hidden_size = 64
num_layers = 1

In [None]:
# Hyperparameter configurations to compare. Each entry describes one training run:
#   name           - identifier of the run (also used to locate its checkpoint)
#   epochs         - number of training epochs
#   model_kwargs   - keyword arguments forwarded to the model constructor
#   optimizer_args - optional per-run overrides of the optimizer arguments
#
# The shared `defaults` dict (optimizer class and baseline optimizer arguments) is
# defined once, in a later cell right before training, so it cannot drift out of
# sync with a second copy here.
#
# NOTE(review): the per-run `optimizer_args` only list the keys that differ from the
# defaults. Whether training.train_all merges them key-by-key or replaces the whole
# dict is not visible here — confirm. For Adam the omitted keys fall back to
# lr=1e-3 / weight_decay=0 either way, which matches the defaults.
configs = [
    {
        "name": "mini_autoreg_lstm",
        "epochs": 10,
        "model_kwargs": {
            "hidden_size": 64,
            "num_layers": 2,
        },
    },
    {
        "name": "small_autoreg_lstm",
        "epochs": 40,
        "model_kwargs": {
            "hidden_size": 128,
            "num_layers": 3,
        },
    },
    {
        "name": "medium_autoreg_lstm",
        "epochs": 40,
        "model_kwargs": {
            "hidden_size": 512,
            "num_layers": 3,
        },
        "optimizer_args": {
            "weight_decay": 1e-4,
        },
    },
    {
        "name": "deeper_autoreg_lstm_2",
        "epochs": 90,
        "model_kwargs": {
            "hidden_size": 512,
            "num_layers": 4,
        },
        "optimizer_args": {
            "weight_decay": 1e-4,
        },
    },
    {
        "name": "deeper_autoreg_lstm",
        "epochs": 90,
        "model_kwargs": {
            "hidden_size": 512,
            "num_layers": 5,
        },
        "optimizer_args": {
            "lr": 5e-4,
        },
    },
]




In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 2048

# load_train() returns the training dataset together with a scaler object; that
# same scaler is handed to load_val() when building the validation dataset.
# NOTE(review): presumably the scaler is fitted on the training split only — confirm in dataloader.py.
train_ds, scaler = load_train()
val_ds = load_val(scaler)

# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)


In [None]:
# Baseline training settings handed to training.train_all as `defaults`:
# the optimizer class plus the keyword arguments it is constructed with.
# NOTE(review): individual configs may carry their own "optimizer_args"; how those
# are combined with the values below is decided inside training.train_all — confirm.
defaults = {
    "optimizer": torch.optim.Adam,
    "optimizer_args": {
        "lr": 1e-3,
        "weight_decay": 0,
    }
}



In [None]:
# Train every configuration in `configs` and collect the per-run results.
# Validation has to run on the held-out split (val_ds): validating on the training
# data would make the validation curves and the best-model selection meaningless.
# NOTE(review): Dataset objects are passed here, not the DataLoaders built earlier.
# If training.train_all expects DataLoaders, switch to train_loader / val_loader —
# confirm against its signature.
results = training.train_all(Seq2SeqLSTM, configs, train=train_ds, val=val_ds, defaults=defaults)

In [None]:
# Draw one figure per trained configuration with its training and validation MSE curves.
for result in results.values():
    history = result['history']
    training_loss = history['train_mse']
    validation_loss = history['val_mse']

    # The x-axis is derived from the recorded history rather than the configured
    # epoch count, so a run that logged fewer epochs still plots, and the
    # notebook-level `num_epochs` hyperparameter is not overwritten by this loop.
    fig, ax = plt.subplots()
    ax.plot(range(1, len(training_loss) + 1), training_loss, 'r', label='Training loss')
    ax.plot(range(1, len(validation_loss) + 1), validation_loss, 'b', label='Validation loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE Loss')
    # Include the run name so the figures can be told apart.
    ax.set_title(f"Training and Validation Loss ({result['config']['name']})")
    ax.legend()
    plt.show()

In [None]:
# Pick the run whose validation MSE reached the lowest value in any epoch, rebuild
# its network with the same constructor arguments and restore the saved weights.
# `best_model` is the result entry of that run (config + history), not the network.
best_model = min(results.values(), key=lambda x: min(x['history']['val_mse']))
best_net = Seq2SeqLSTM(**best_model['config']['model_kwargs']).to(device)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine can still be loaded on a CPU-only one.
# NOTE(review): assumes the checkpoint holds the weights of the best epoch rather
# than the last one — confirm in training.py.
best_net.load_state_dict(
    torch.load(
        training.checkpoint_model_path(best_model['config']['name']),
        map_location=device,
    )
)

In [None]:
from plot_trajectory import plot_paths

# Number of validation batches to visualise (one figure per batch).
num_samples_to_plot = 5

# Use the restored best network in evaluation mode.
net = best_net
net.eval()

# For each of the first `num_samples_to_plot` validation batches, run the model and
# plot only the first sequence of that batch.
for idx, (x, y) in enumerate(val_loader):
    if idx >= num_samples_to_plot:
        break

    x = x.to(device)
    y = y.to(device)

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        y_pred = net(x)

    # Move the first element of the batch to host memory as NumPy arrays.
    # NOTE(review): presumably x[0] is (30, features) and y[0] is (10, features) — confirm.
    first_input = x[0].cpu()
    first_target = y[0].cpu()
    first_prediction = y_pred[0].cpu()

    x_np = first_input.numpy()
    y_np = first_target.numpy()
    # The prediction is reshaped to 10 rows of 2 values each.
    y_pred_np = first_prediction.numpy().reshape(10, 2)

    plot_paths(x_np, y_np, y_pred_np, idx)
