# Model 2: LSTM sequence model
## Read data

In [None]:
# FIXME: Only for local usage!
import sys
sys.path.insert(0, "/home/pandavid/uni/WiSe2324/ProjectWork/machine-boom-project/src/utils")

from pathlib import Path
from typing import Any, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, Subset, random_split

from file_io import read_all_data_dumps_in
from preprocessing import reshape_dataframe_for_learning

In [None]:
# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Read the raw data dumps from the data folder and reshape them with the
# project preprocessing helper. This cell is for inspection only:
# read_trajectory_datasets below reads and reshapes the same folder again;
# only `data_folder` is reused later.
data_folder = Path("../../data/").resolve()
raw_data = read_all_data_dumps_in(data_folder)
data = reshape_dataframe_for_learning(raw_data)
# Sanity checks: shape of the first entry of the "features" column and the first rows.
print(data["features"][0].shape)
print(data.head())

## Create a trajectory dataset from the dataframe

In [None]:
from typing import Tuple, Any

from torch.utils.data import Dataset, random_split

In [None]:
class TrajectoryDataset(Dataset):
    """Cut a per-sample dataframe into consecutive, non-overlapping trajectories.

    Item ``k`` covers rows ``[k * trajectory_length, (k + 1) * trajectory_length)``.
    The arrays in the first column are stacked into the feature tensor and the
    arrays in the second column into the target ("true low points") tensor, so
    each item is a pair of tensors whose leading axis has length
    ``trajectory_length``.

    Trailing rows that do not fill a whole trajectory are dropped. Every item
    therefore has the same length, and the default ``DataLoader`` collate can
    stack items into batches.

    NOTE(review): the tensors keep the numpy dtype of the stored arrays, while
    ``nn.LSTM`` weights default to float32. Confirm the preprocessing yields
    float32 arrays.
    """

    def __init__(self, dataframe: pd.DataFrame, trajectory_length: int):
        super().__init__()
        if trajectory_length < 1:
            raise ValueError(f"trajectory_length must be positive, got {trajectory_length}")
        self.dataframe = dataframe
        self.trajectory_length = trajectory_length

    def __len__(self) -> int:
        # Floor division: an incomplete tail is not a trajectory.
        return len(self.dataframe.index) // self.trajectory_length

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        count = len(self)
        if index < 0:
            index += count
        # Raise IndexError rather than the ValueError np.stack gives for an
        # empty window. Plain iteration over the dataset then stops cleanly.
        if not 0 <= index < count:
            raise IndexError(f"trajectory index out of range (dataset has {count} trajectories)")
        start = index * self.trajectory_length
        window = self.dataframe.iloc[start:start + self.trajectory_length]
        features = torch.from_numpy(np.stack(window.iloc[:, 0].to_numpy()))
        true_lowpoints = torch.from_numpy(np.stack(window.iloc[:, 1].to_numpy()))
        return features, true_lowpoints


def read_trajectory_datasets(data_folder: Path, train_split: float, test_split: float, validation_split: float, trajectory_length: int = 10, seed: Optional[int] = None) -> List[Subset]:
    """Read the dumps in ``data_folder`` and randomly split them into trajectory datasets.

    The reshaped dataframe is cut into consecutive ``trajectory_length``-row
    trajectories (``TrajectoryDataset``). Those trajectories are then split
    with ``torch.utils.data.random_split``. The three split values are
    fractions of the trajectory count and must sum to 1 (fractional lengths
    need torch >= 1.13).

    Args:
        data_folder: folder passed to ``read_all_data_dumps_in``.
        train_split: fraction of trajectories in the training subset.
        test_split: fraction of trajectories in the test subset.
        validation_split: fraction of trajectories in the validation subset.
        trajectory_length: rows per trajectory.
        seed: optional seed for a dedicated ``torch.Generator`` so the split is
            reproducible. ``None`` keeps using torch's global RNG, as before.

    Returns:
        ``[train, test, validation]`` subsets, in that order.
    """
    data = read_all_data_dumps_in(data_folder)
    preprocessed = reshape_dataframe_for_learning(data)
    complete_dataset = TrajectoryDataset(preprocessed, trajectory_length)
    fractions = [train_split, test_split, validation_split]
    if seed is None:
        return random_split(complete_dataset, fractions)
    return random_split(complete_dataset, fractions, generator=torch.Generator().manual_seed(seed))


In [None]:
# 80 % / 15 % / 5 % of the trajectories go to train / test / validation.
train_set, test_set, validation_set = read_trajectory_datasets(
    data_folder,
    train_split=0.8,
    test_split=0.15,
    validation_split=0.05,
)
print(*(len(subset) for subset in (train_set, test_set, validation_set)))

## Defining the LSTM model

In [None]:
from torch import nn

In [None]:
class DecoderLSTM(nn.Module):
    """LSTM followed by dropout and a linear layer applied at every time step.

    ``forward`` expects ``x`` of shape ``(sequence_length, batch_size, input_dim)``
    (the LSTM is not ``batch_first``). It returns
    ``(sequence_length, batch_size, out_dim)``.

    ``total_epochs`` counts finished training epochs. The training loop reads
    and increments it so training can be resumed.
    """

    def __init__(self, input_dim: int, hidden_dim: int, out_dim: int, dropout_lstm: float = 0.25, dropout_final: float = 0.25, num_lstm_layers: int = 1, bidirectional: bool = False) -> None:
        super().__init__()
        self.total_epochs = 0
        self.hidden_dim = hidden_dim
        self.d = 2 if bidirectional else 1
        self.num_lstm_layers = num_lstm_layers

        # nn.LSTM only applies dropout between stacked layers. With a single
        # layer a non-zero value has no effect and only raises a UserWarning.
        inter_layer_dropout = dropout_lstm if num_lstm_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_lstm_layers, dropout=inter_layer_dropout, bidirectional=bidirectional)
        self.final_dropout = nn.Dropout(dropout_final)
        self.out = nn.Linear(hidden_dim * self.d, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Zero initial states, created on x's device/dtype. The former
        # torch.randn states always lived on the CPU, which broke CUDA runs,
        # and made eval-mode predictions differ between identical calls.
        state_shape = (self.d * self.num_lstm_layers, x.shape[1], self.hidden_dim)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros_like(h0)
        # output: (sequence_length, batch_size, d * hidden_dim)
        output, _ = self.lstm(x, (h0, c0))
        return self.out(self.final_dropout(output))
        


In [None]:
# The checkpoint-loading cell must rebuild exactly these dimensions for
# load_state_dict to accept the saved weights.
model = DecoderLSTM(input_dim=8, hidden_dim=6, out_dim=3).to(device)

## Train the model

In [None]:
import os

from torch.utils.data import DataLoader
from dotenv import load_dotenv

from file_io import save_model
from evaluation import compute_loss_on

In [None]:
dotenv_path = Path("../../models/lstm/.env").resolve()
print(dotenv_path)
load_dotenv(dotenv_path=dotenv_path)
# Hyperparameters come from the .env file above. Indexing os.environ raises a
# KeyError naming the missing variable, whereas float(os.getenv(...)) fails
# with an opaque TypeError on None.
learning_rate = float(os.environ["LEARNING_RATE"])
batch_size = int(os.environ["BATCH_SIZE"])

# inputs for training loop
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_function = torch.nn.MSELoss()
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Only training benefits from reshuffling. Validation/test batches are just
# evaluated, and a fixed order keeps evaluation deterministic. It also lets
# the test predictions of different models be compared step by step.
validation_dataloader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
def train(epochs: int, train_dataloader: DataLoader, validation_dataloader: DataLoader, model, loss_function, optimizer, checkpoint_path: Path, report_interval: int = 1000):
    """Train ``model`` until ``model.total_epochs`` reaches ``epochs``.

    ``epochs`` is the target total, not a number of additional epochs. The
    loop starts at ``model.total_epochs``, so calling this again on an
    already-trained model resumes rather than restarts.

    After every epoch the validation loss is computed under ``torch.no_grad``.
    When it is lower than the best value seen during this call, the model is
    written with ``save_model`` to ``checkpoint_path / "<dir name>_<epoch>.pt"``.

    Returns the model as it is after the final epoch, which is not
    necessarily the best checkpoint.
    """
    best_seen = float("inf")
    start_epoch = model.total_epochs

    for epoch in range(start_epoch, epochs):
        print(f"Epoch: {epoch + 1}")

        model.train(True)
        train_loss = train_epoch(train_dataloader, model, loss_function, optimizer, report_interval)

        model.eval()
        with torch.no_grad():
            val_loss = compute_loss_on(validation_dataloader, model, loss_function)

        print(f"Loss on train: {train_loss}, loss on validation: {val_loss}")

        if val_loss < best_seen:
            best_seen = val_loss
            save_model(model, checkpoint_path / f"{checkpoint_path.name}_{epoch}.pt")

        # Incremented after the checkpoint is written; a later call resumes
        # from the next epoch.
        model.total_epochs += 1

    return model


def train_epoch(train_dataloader: DataLoader, model, loss_function, optimizer, report_interval: int = 1000):
    """Run one optimisation pass over ``train_dataloader`` and return a loss summary.

    Each batch is assumed to arrive as ``(batch, sequence, features)``, the
    layout the default collate builds from ``TrajectoryDataset`` items. It is
    transposed to ``(sequence, batch, features)``, the layout
    ``DecoderLSTM.forward`` expects. Inputs and targets are moved to the
    device of the model's parameters.

    Every ``report_interval`` batches, and after the last batch, the mean
    loss of the batches since the previous report is printed.

    Returns:
        The mean loss (a Python float) of the last reporting window. This is
        the whole epoch when it has at most ``report_interval`` batches. Returns
        ``0.0`` if the loader yields no batches.
    """
    first_parameter = next(model.parameters(), None)
    device = first_parameter.device if first_parameter is not None else torch.device("cpu")
    num_batches = len(train_dataloader)

    running_loss = 0.0
    window_size = 0
    last_loss = 0.0

    for i, (inputs, true_values) in enumerate(train_dataloader):
        # transpose, not view: view(s1, s0, s2) keeps the memory order and
        # would scramble samples across the batch and time axes.
        inputs = inputs.transpose(0, 1).to(device)
        true_values = true_values.transpose(0, 1).to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, true_values)
        loss.backward()
        optimizer.step()

        # .item() detaches the value. Summing the loss tensors themselves would
        # chain every batch into one ever-growing autograd graph.
        running_loss += loss.item()
        window_size += 1

        # Reporting sits inside the loop so it fires every report_interval
        # batches, and it divides by the real number of batches in the window.
        if i % report_interval == report_interval - 1 or i + 1 == num_batches:
            last_loss = running_loss / window_size
            print(f"batch {i + 1}, Mean Squared Error: {last_loss}")
            running_loss = 0.0
            window_size = 0

    return last_loss

In [None]:
checkpointing_path = Path("../../models/lstm/")
# NUM_EPOCHS is the total epoch target (train resumes from model.total_epochs).
# os.environ[...] names a missing variable instead of failing on int(None).
last_model = train(
    int(os.environ["NUM_EPOCHS"]),
    train_dataloader,
    validation_dataloader,
    model,
    loss_function,
    optimizer,
    checkpointing_path,
)

## Evaluation
### Compute mean squared error

In [None]:
from evaluation import compute_predictions, compute_losses_from

In [None]:
# Evaluate the model returned by `train` on the test split. This is the state
# after the final epoch, not necessarily the best checkpoint.
# NOTE(review): the DataLoader yields (batch, sequence, features) while
# DecoderLSTM expects (sequence, batch, features). Confirm that
# compute_predictions converts the batches to the layout the model expects.
y, y_true = compute_predictions(test_dataloader, last_model)
test_losses = compute_losses_from(y, y_true, loss_function)
print(f"The mean squared error on test is: {test_losses.mean()}")

### Draw prediction/truth traces

In [None]:
# IPython magic (not plain Python): interactive matplotlib backend of the
# classic Jupyter notebook.
%matplotlib notebook

from matplotlib import pyplot as plt

# Render matplotlib animations as inline JavaScript/HTML, at a higher DPI.
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  

from IPython.display import HTML

from visualization import create_trace_animation

In [None]:
# Flatten the per-trajectory predictions into one continuous trace of shape
# (N, dim) for the animation. np.asarray accepts the tensors from
# compute_predictions as well as arrays left by an earlier run of this cell.
# reshape(-1, last_dim) is a no-op on already-flat data, so the cell can be
# re-executed.
# NOTE(review): assumes the leading axes are (batch, sequence); confirm
# against compute_predictions.
y, y_true = np.asarray(y), np.asarray(y_true)
y = y.reshape(-1, y.shape[-1])
y_true = y_true.reshape(-1, y_true.shape[-1])
animation = create_trace_animation(y, y_true)
HTML(animation.to_jshtml())

## Loading the best model

In [None]:
from file_io import load_model

In [None]:
# Rebuild the architecture that was trained (DecoderLSTM(8, 6, 3)) so the
# parameter shapes in the saved state dict match. The former call used an
# undefined `input_shape` and a list as hidden_dim.
loaded_model = DecoderLSTM(8, 6, 3).to(device)
model_state_dict = load_model(checkpointing_path)
loaded_model.load_state_dict(model_state_dict)
loaded_model.eval()
y, y_true = compute_predictions(test_dataloader, loaded_model)
test_losses = compute_losses_from(y, y_true, loss_function)
print(f"The mean squared error of the loaded model on test is: {test_losses.mean()}")
# Flatten to one continuous (N, dim) trace, as done for the last model above.
y_np, y_true_np = y.numpy(), y_true.numpy()
animation = create_trace_animation(
    y_np.reshape(-1, y_np.shape[-1]),
    y_true_np.reshape(-1, y_true_np.shape[-1]),
)
HTML(animation.to_jshtml())