# Model 2: LSTM sequence model
## Read data

In [None]:
from pathlib import Path
import numpy as np
import torch 

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Import local modules from 'src/utils' as package 'utils'
import sys; sys.path.insert(0, '../')

In [None]:
import utils

## Create Trajectory dataset from dataframe

In [None]:
from utils.file_io import read_trajectory_datasets

In [None]:
# Load the four dataset splits from the data folder (relative to this notebook).
# NOTE(review): the four fractions sum to 1.0 and are presumably the split sizes in the
# same order as the unpacked names (train, test, validation, visualization); the meaning
# of the trailing 64 is not visible here — confirm both against read_trajectory_datasets.
data_folder = Path("../../data/")
train_set, test_set, validation_set, visualization_set = read_trajectory_datasets(data_folder, 0.8, 0.15, 0.045, 0.005, 64, standardize_features=True)

In [None]:
# Hard-coded input / output feature dimensions that are later handed to the model.
# FIXME: Total loaded size correct?
input_shape = 8
output_shape = 3
print(f"Data shape {input_shape} / {output_shape} of total {sum(len(split) for split in (train_set, test_set, validation_set, visualization_set))} data rows!")

## Defining the LSTM model

In [None]:
from torch import nn, Tensor

In [None]:
class DecoderLSTM(nn.Module):
    """LSTM followed by dropout and a linear read-out applied at every time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        out_dim: Number of values predicted per time step.
        dropout_lstm: Dropout between stacked LSTM layers. PyTorch applies it after
            every layer except the last, so it has no effect with a single layer.
        dropout_final: Dropout applied to the LSTM output before the linear layer.
        num_lstm_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, input_dim: int, hidden_dim: int, out_dim: int, dropout_lstm: float = 0.25, dropout_final: float = 0.25,
                 num_lstm_layers: int = 1, bidirectional: bool = False) -> None:
        super().__init__()
        # Number of completed training epochs; a plain attribute, so it is not part of state_dict().
        self.total_epochs = 0
        self.hidden_dim = hidden_dim
        # Number of directions (2 when bidirectional); scales state and output sizes.
        self.d = 2 if bidirectional else 1
        self.num_lstm_layers = num_lstm_layers

        # nn.LSTM warns when dropout > 0 is combined with a single layer (there is no
        # "between layers" to apply it to), so pass 0.0 in that case; behaviour is unchanged.
        inter_layer_dropout = dropout_lstm if num_lstm_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_lstm_layers, dropout=inter_layer_dropout, bidirectional=bidirectional)
        self.final_dropout = nn.Dropout(dropout_final)
        self.out = nn.Linear(hidden_dim * self.d, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Map a sequence batch to per-step predictions.

        Args:
            x: Tensor of shape (sequence_length, batch_size, input_dim).

        Returns:
            Tensor of shape (sequence_length, batch_size, out_dim).
        """
        batch_size = x.shape[1]
        state_shape = (self.d * self.num_lstm_layers, batch_size, self.hidden_dim)

        if self.training:
            # Random initial state during training, as before, but allocated
            # directly on the input's device instead of CPU-then-copy.
            h0 = torch.randn(state_shape, device=x.device, dtype=x.dtype)
            c0 = torch.randn(state_shape, device=x.device, dtype=x.dtype)
        else:
            # A random state at inference time makes predictions differ between
            # identical calls; use the zero state so evaluation is deterministic.
            h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
            c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        # output shape is (sequence_length, batch_size, d * hidden_dim)
        output, _ = self.lstm(x, (h0, c0))
        output = self.final_dropout(output)
        return self.out(output)

## Load parameters, functions and dataloaders

In [None]:
import os

from torch.utils.data import DataLoader
from dotenv import load_dotenv

from utils.file_io import save_model
from utils.file_io import define_dataloader_from_subset
from utils.evaluation import compute_loss_on
from utils.loss_functions import maximum_squared_error

In [None]:
# Absolute path of the LSTM model folder; the .env file is read from here and the
# tuning checkpoints and results CSV are written below it.
model_path = Path("../../models/lstm/").absolute()

In [None]:
# Read the baseline hyperparameters from the model folder's .env file.
dotenv_path = model_path / ".env"
load_dotenv(dotenv_path=dotenv_path)

# os.environ[...] raises a KeyError naming the missing variable, instead of the
# opaque "float() argument must be ... not 'NoneType'" that os.getenv would lead to.
learning_rate = float(os.environ["LEARNING_RATE"])
batch_size = int(os.environ["BATCH_SIZE"])
num_epochs = int(os.environ["NUM_EPOCHS"])
# Despite its name this value is later used as the LSTM hidden dimension.
hidden_layers = int(os.environ["HIDDEN_LAYERS"])

In [None]:
def get_optimizer_function(model: nn.Module, learning_rate: float) -> torch.optim.Optimizer:
    """Create an AdamW optimizer over all parameters of ``model``.

    Args:
        model: Module whose parameters are optimized.
        learning_rate: Step size passed to AdamW as ``lr``.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    return torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
def get_loss_function() -> nn.Module:
    """Build the criterion used for training and evaluation: mean squared error."""
    mse_criterion = nn.MSELoss()
    return mse_criterion

In [None]:
# Wrap the train / validation / test subsets in dataloaders using the batch size from .env.
# NOTE(review): shuffle=True is passed once for all three loaders — confirm whether the helper
# also shuffles validation and test data.
train_dataloader, validation_dataloader, test_dataloader = define_dataloader_from_subset(train_set, validation_set, test_set, batch_size=batch_size, shuffle=True)

## Define train methods

In [None]:
from ray import train as ray_train
from ray.train import Checkpoint

In [None]:
def train_epoch(train_dataloader: DataLoader, model, loss_function, optimizer,
                device: torch.device, report_interval: int = 128):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        train_dataloader: Iterable of ``(inputs, true_values)`` batches whose first
            two axes are (batch, sequence).
        model: Module called with inputs laid out as (sequence, batch, features).
        loss_function: Criterion called as ``loss_function(outputs, true_values)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to.
        report_interval: Number of batches averaged per progress print.

    Returns:
        The mean loss (float) of the last complete reporting window, or the mean over
        all batches seen when no window was completed (0.0 for an empty dataloader).
    """
    running_loss = 0.0
    batches_in_window = 0
    last_loss = 0.0
    window_reported = False

    for i, (inputs, true_values) in enumerate(train_dataloader):
        # Swap the batch and sequence axes. transpose really exchanges the axes;
        # ``view`` with swapped sizes only reinterprets the memory and would
        # interleave time steps of different samples.
        # NOTE(review): evaluation helpers must feed the model the same layout — confirm.
        inputs = inputs.to(device).transpose(0, 1).contiguous()
        true_values = true_values.to(device).transpose(0, 1).contiguous()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, true_values)
        loss.backward()
        optimizer.step()

        # .item() detaches the value; accumulating the tensor itself would keep
        # every batch's autograd graph alive until the window is reset.
        running_loss += loss.item()
        batches_in_window += 1

        if i % report_interval == report_interval - 1:
            last_loss = running_loss / report_interval
            window_reported = True
            print(f"batch {i + 1}, Mean Squared Error: {last_loss}")
            running_loss = 0.0
            batches_in_window = 0

    # With fewer batches than report_interval no window completes; report the
    # mean over what was seen instead of a misleading 0.
    if not window_reported and batches_in_window:
        last_loss = running_loss / batches_in_window

    return last_loss

In [None]:
def train(epochs: int, train_dataloader: DataLoader, validation_dataloader: DataLoader, model: nn.Module,
           loss_function, optimizer, checkpoint_path: Path, device: torch.device = 'cpu', report_interval: int = 1000, tune: bool = False) -> nn.Module:
    """Train ``model`` until its epoch counter reaches ``epochs``, validating after each epoch.

    Args:
        epochs: Target value of ``model.total_epochs``; training resumes from the current counter.
        train_dataloader: Batches used for optimisation.
        validation_dataloader: Batches used to compute the validation loss.
        model: Module exposing a ``total_epochs`` attribute.
        loss_function: Criterion passed to the training and validation helpers.
        optimizer: Optimizer stepping the model's parameters.
        checkpoint_path: Directory receiving ``checkpoint.pt``.
        device: Device the model and batches are moved to.
        report_interval: Batches per progress print inside an epoch.
        tune: When True, resume from a Ray checkpoint if one exists, save a checkpoint
            every epoch and report the validation loss to Ray.

    Returns:
        The trained model (wrapped in ``nn.DataParallel`` when several GPUs are visible).
    """
    best_val_loss = float("inf")

    # Keep a handle on the bare module: the DataParallel wrapper does not expose
    # ``total_epochs`` and prefixes state_dict keys with "module.", which would make
    # the checkpoint unloadable into a plain model.
    core_model = model.module if isinstance(model, nn.DataParallel) else model

    # NOTE(review): DataParallel splits inputs along dim 0 by default while the model
    # is fed (sequence, batch, features) — confirm the multi-GPU path splits the right axis.
    if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)

    model.to(device)

    if tune:
        checkpoint = ray_train.get_checkpoint()

        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                # map_location lets a checkpoint written on another device be restored here.
                model_state = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
                core_model.load_state_dict(model_state)

    for epoch in range(core_model.total_epochs, epochs):
        print(f"Epoch: {epoch + 1}")

        model.train(True)
        avg_loss = train_epoch(train_dataloader, model, loss_function, optimizer, device, report_interval)
        model.eval()

        with torch.no_grad():
            avg_val_loss = compute_loss_on(validation_dataloader, model, loss_function, device=device)

        print(f"Loss on train: {avg_loss}, loss on validation: {avg_val_loss}")

        core_model.total_epochs += 1

        # Outside of tuning only improvements are saved; while tuning every epoch is
        # saved because a checkpoint is reported to Ray each epoch.
        if avg_val_loss < best_val_loss or tune:
            best_val_loss = avg_val_loss

            checkpoint_path.mkdir(parents=True, exist_ok=True)
            torch.save(core_model.state_dict(), checkpoint_path / "checkpoint.pt")

        if tune:
            ray_train.report(metrics={ "loss": float(avg_val_loss) }, checkpoint=Checkpoint.from_directory(checkpoint_path))

    return model

## Train the model with Ray Tune and Optuna hyperparameter search

In [None]:
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from typing import Dict
from utils.evaluation import compute_predictions, compute_losses_from

In [None]:
def parameter_train(parameter: Dict, train_epochs: int, train_data: DataLoader, validation_data: DataLoader, model_input_shape: int,
                    model_output_shape: int, checkpoint_path: Path, device: torch.device) -> None:
    """Run a single Ray Tune trial: build a model from ``parameter`` and train it.

    Args:
        parameter: Sampled configuration; reads the keys ``hidden_layers``, ``lstm_dropout``,
            ``final_dropout``, ``lstm_layers``, ``bidirectional`` and ``lr``.
        train_epochs: Number of epochs to train for.
        train_data: Dataloader used for optimisation.
        validation_data: Dataloader used for the reported validation loss.
        model_input_shape: Input feature dimension of the model.
        model_output_shape: Output feature dimension of the model.
        checkpoint_path: Parent directory; the trial writes to a sub-folder named after its trial id.
        device: Device to train on.
    """
    run_id = ray_train.get_context().get_trial_id()
    run_checkpoint = checkpoint_path / run_id
    # exist_ok: a trial that is restored keeps its id, so its folder may already exist.
    run_checkpoint.mkdir(parents=True, exist_ok=True)

    # NOTE(review): parameter["batch_size"] is sampled by the search space but never used
    # here — the dataloaders are built once outside the trial with a fixed batch size.
    model = DecoderLSTM(
        model_input_shape,
        parameter["hidden_layers"],  # used as the LSTM hidden dimension
        model_output_shape,
        parameter["lstm_dropout"],
        parameter["final_dropout"],
        parameter["lstm_layers"],
        parameter["bidirectional"],
    )

    optimizer = get_optimizer_function(model, parameter["lr"])
    loss_function = get_loss_function()

    train(train_epochs, train_data, validation_data, model, loss_function, optimizer, run_checkpoint, device, report_interval=50, tune=True)

In [None]:
# Half-widths of the search ranges placed around the baseline values read from .env.
learning_rate_radius = 1e-3
batch_size_radius = 10
hidden_layers_radius = 4
# Number of configurations (trials) the tuner samples.
num_samples = 100

In [None]:
# Search space sampled by Ray Tune, centred on the baseline values from .env.
parameter_space = {
    # loguniform needs 0 < lower < upper. The previous lower bound had its operands
    # swapped (radius - lr), which is zero or negative whenever lr >= radius. Use
    # lr - radius, floored at a tenth of the baseline so it always stays positive.
    "lr": tune.loguniform(max(learning_rate - learning_rate_radius, learning_rate / 10), learning_rate + learning_rate_radius),
    "batch_size": tune.choice(list(range(batch_size - batch_size_radius, batch_size + batch_size_radius, 4))),
    # Despite the name, this value is used as the LSTM hidden dimension.
    "hidden_layers": tune.choice(list(range(hidden_layers - hidden_layers_radius, hidden_layers + hidden_layers_radius, 1))),
    "bidirectional": tune.choice([True, False]),
    "lstm_layers": tune.choice([1, 2, 3]),
    "lstm_dropout": tune.uniform(0.1, 0.5),
    "final_dropout": tune.uniform(0.1, 0.5)
}

In [None]:
# ASHA scheduler that minimises the reported "loss" metric; max_t bounds how long a
# single trial may run, here set to the configured number of epochs.
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=num_epochs
)

In [None]:
# Optuna-backed search algorithm that proposes configurations minimising "loss".
search_alg = OptunaSearch(
    metric="loss",
    mode="min"
) 

In [None]:
# Start from a clean Ray runtime so re-running this cell does not hit an
# already-initialised instance.
if ray.is_initialized():
    ray.shutdown()

# Ship the local `utils` package to the Ray workers.
ray.init(runtime_env={ "py_modules": [utils] })

In [None]:
# Trainable for Ray Tune: the lambda closes over the dataloaders, shapes, model path and
# device, so each trial only receives its sampled configuration.
ray_resources_manager = tune.with_resources(
    trainable=lambda param: parameter_train(param, num_epochs, train_dataloader, validation_dataloader, input_shape, output_shape, model_path, device),
    # See: https://stackoverflow.com/questions/58967793/what-is-the-way-to-make-tune-run-parallel-trials-across-multiple-gpus
    # Each trial reserves 3 CPUs and a quarter of a GPU.
    resources={ "cpu": 3, "gpu": 0.25 }
)

# Combine trainable, search space, scheduler and search algorithm into one tuner.
tuner = tune.Tuner(
    ray_resources_manager,
    param_space=parameter_space,
    tune_config=tune.TuneConfig(
        scheduler=scheduler,
        search_alg=search_alg,
         num_samples=num_samples
    )
)

In [None]:
# Run the hyperparameter search and keep the results for the cells below.
results = tuner.fit()

In [None]:
# The search is finished: shut Ray down if it is still running.
if ray.is_initialized():
    ray.shutdown()

In [None]:
# Save the tuning results dataframe as a csv file inside the model folder
results.get_dataframe().to_csv(model_path / "trail_grid.csv")

In [None]:
# Pick the trial, and within it the checkpoint, with the lowest reported "loss".
best_result = results.get_best_result("loss", "min")
best_checkpoint = best_result.get_best_checkpoint("loss", "min")

# Note: this is the saved state dict, not a module. map_location="cpu" lets a checkpoint
# written on a GPU be opened for the CPU evaluation below, even without CUDA available.
best_model = torch.load(f"{best_checkpoint.path}/checkpoint.pt", map_location="cpu")

In [None]:
# Summarise the winning trial: its loss first, then every hyperparameter it used.
print(f"Best trail by loss value {best_result.metrics['loss']}", "\n------")
for name, value in best_result.config.items():
    print(f"Best trail: {name} value {value}")

## Evaluation

In [None]:
from utils.visualization import create_trace_animation
from matplotlib import pyplot as plt
from IPython.display import HTML

In [None]:
# Compute evaluation on the cpu (this rebinds the global `device` chosen at the top of the notebook)
device = 'cpu'

In [None]:
%matplotlib notebook
 
# Render animations as JavaScript HTML and raise the figure resolution.
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  

### Loading the best model

In [None]:
# Criterion used to score the predictions on the test set below.
loss_function = get_loss_function()

In [None]:
# Rebuild the network with the winning trial's hyperparameters, restore its
# weights and switch to inference mode.
model = DecoderLSTM(
    input_dim=input_shape,
    hidden_dim=best_result.config["hidden_layers"],
    out_dim=output_shape,
    dropout_lstm=best_result.config["lstm_dropout"],
    dropout_final=best_result.config["final_dropout"],
    num_lstm_layers=best_result.config["lstm_layers"],
    bidirectional=best_result.config["bidirectional"],
)
model.load_state_dict(best_model)
model.eval()

In [None]:
# Predict on the test dataloader and score the predictions against the ground truth.
# NOTE(review): compute_losses_from presumably returns one loss per sample (its mean is
# printed) — confirm against utils.evaluation.
y, y_true = compute_predictions(test_dataloader, model, device)
test_losses = compute_losses_from(y, y_true, loss_function)
print(f"The mean squared error of the loaded model on test is: {test_losses.mean()}")

In [None]:
# Animate the predicted traces next to the true ones and embed the result as HTML.
animation = create_trace_animation(y.numpy(), y_true.numpy())
HTML(animation.to_jshtml())