In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataloader import load_train, load_val
from model_lstm import LSTMModel
import training
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

In [2]:
device = training.determine_device()
print("Using device:", device)

Using device: mps


In [None]:
configs = [
    {   
        "name": "mini_lstm",
        "epochs": 10,
        "model_kwargs": {
            "hidden_size": 64,
            "num_layers": 2,
        },
    },
    {
        "name": "small_lstm",
        "epochs": 40,
        "model_kwargs": {
            "hidden_size": 128,
            "num_layers": 3,
        },
    },
    {
        "name": "medium_lstm",
        "epochs": 40,
        "model_kwargs": {
            "hidden_size": 512,
            "num_layers": 3,
        },
        "optimizer_args": {
            "weight_decay": 1e-4,
        },
    },
    {
        "name": "deeper_lstm_2",
        "epochs": 90,
        "model_kwargs": {
            "hidden_size": 512,
            "num_layers": 4,
        },
        "optimizer_args": {
            "weight_decay": 1e-4,
        },
    },
    {
        "name": "deeper_lstm",
        "epochs": 90,
        "model_kwargs": {
            "hidden_size": 512,
            "num_layers": 5,
        },
        "optimizer_args": {
            "lr": 5e-4,
        },
    },
]


In [None]:
batch_size = 2048

#load data
train_ds, scaler = load_train()
val_ds = load_val(scaler)
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4, shuffle=False)


In [None]:
defaults = dict(batch_size=batch_size)
results = training.train_all(LSTMModel, configs, train=train_ds, val=val_ds, defaults=defaults)

In [None]:
for result in results:
    model = result["model"]
    history = result["history"]
    print(f"Model: {result['config']['name']}")
    print(f"Final Training Loss: {history['train_rmse'][-1]:.6f}")
    print(f"Final Validation Loss: {history['val_rmse'][-1]:.6f}")
    print("-" * 30)
best_result = min(results, key=lambda x: x["history"]["val_rmse"][-1])
best_model = best_result["model"]
torch.save(best_model.state_dict(), "best_lstm_model.pt")

In [None]:
for result in results:
    num_epochs = result["config"]["epochs"]
    training_loss = result["history"]["train_rmse"]
    validation_loss = result["history"]["val_rmse"]
    # plot training and validation loss
    plt.figure()
    plt.plot(range(1, num_epochs + 1), training_loss, 'r', label='Training loss')
    plt.plot(range(1, num_epochs + 1), validation_loss, 'b', label='Validation loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.show()

In [None]:
def plot_paths(x_sample, y_true, y_pred, idx):
    # Past  
    x_sample = x_sample.reshape(30, 2) # reshape back
    # True future
    y_true_sample = y_true
    # Predicted future
    y_pred_sample = y_pred

    plt.figure(figsize=(6,6))
    plt.plot(x_sample[:,0], x_sample[:,1], 'bo-', label='Past')         
    plt.plot(y_true_sample[:,0], y_true_sample[:,1], 'go-', label='True')   
    plt.plot(y_pred_sample[:,0], y_pred_sample[:,1], 'ro--', label='Predicted') 
    plt.xlabel('Latitude')
    plt.ylabel('Longitude')
    plt.title(f'Trajectory Sample {idx}')
    plt.legend()
    plt.show()

num_samples_to_plot = 50

# load best model
model = LSTMModel(hidden_size=512, num_layers=4)
net = model.to(device)
net.load_state_dict(torch.load("best_lstm_model.pt"))
net.eval()

# plot
for idx, (x, y) in enumerate(val_loader):
    if idx >= num_samples_to_plot:
        break
    x, y = x.to(device), y.to(device)
    with torch.no_grad():
        y_pred = net(x)
    
    # pick the first element in the batch
    x_np = x[0].cpu().numpy()             # (30,5)
    y_np = y[0].cpu().numpy()             # (10,5)
    y_pred_np = y_pred[0].cpu().numpy().reshape(10, 2)  # reshape flat 50 -> (10,5)

    plot_paths(x_np, y_np, y_pred_np, idx)


RuntimeError: Error(s) in loading state_dict for LSTMModel:
	Unexpected key(s) in state_dict: "lstm.weight_ih_l4", "lstm.weight_hh_l4", "lstm.bias_ih_l4", "lstm.bias_hh_l4". 
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([2048, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.bias_ih_l0: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.bias_hh_l0: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.weight_ih_l1: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.weight_hh_l1: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.bias_ih_l1: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.bias_hh_l1: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.weight_ih_l2: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.weight_hh_l2: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.bias_ih_l2: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.bias_hh_l2: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.weight_ih_l3: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.weight_hh_l3: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1024, 256]).
	size mismatch for lstm.bias_ih_l3: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for lstm.bias_hh_l3: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for l_out.weight: copying a param with shape torch.Size([20, 512]) from checkpoint, the shape in current model is torch.Size([20, 256]).