# This is a basic LSTM model used to test dataset preparation and model evaluation

## Imports

In [1]:
import torch.nn as nn
import torch
import full_iri_dataset_generator as iri
from training_loop import train_model

## Constants

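- `SEQUENCE_LENGTH` is the number of historical measurements preceding the target element that are provided to the model
- `NUM_FEATURES_PER_SAMPLE` is how many values each measurement has. `IRI-only` data has 3: left_iri, right_iri, and time_since_first_measurement. The value below is 6, presumably because the dataset is loaded with `one_hot=True` (**TODO**: confirm which extra columns this adds)
- `NUM_LAYERS` is the number of stacked LSTM layers to use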

In [2]:
# How many past measurements precede each prediction target.
SEQUENCE_LENGTH: int = 10
# Values per measurement. NOTE(review): the IRI-only feature list has 3 entries;
# the extra columns presumably come from loading with one_hot=True — confirm.
NUM_FEATURES_PER_SAMPLE: int = 6
# Depth of the stacked LSTM.
NUM_LAYERS: int = 5

## Dataset Preparation

Load train and test datasets

In [3]:
# Build the (train, test) pair from the two parquet files; each sample presumably
# spans SEQUENCE_LENGTH measurements (TODO confirm in full_iri_dataset_generator).
train, test = iri.load_iri_datasets(
    path="../training_data/final_data.parquet",
    construction_path="../training_data/construction_data.parquet",
    seq_length=SEQUENCE_LENGTH,
    one_hot=True,
)

                                                                          

## Model Definition

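Here a basic LSTM classifier model is defined.

1. Input sequences are passed to the LSTM directly (the model itself performs no flattening)
2. Stacked LSTM layers process the sequence and update the hidden state
3. A final linear layer maps the LSTM output at the last step to 3 class scores (logits)
4. The scores are converted to log-probabilities with a log-softmax function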

In [4]:
class RNN(nn.Module):
    """Stacked-LSTM classifier returning log-probabilities over classes.

    Every constructor argument is optional; when omitted it falls back to the
    notebook-level constants, so ``RNN()`` builds the same network as before.

    Args:
        input_size: size of the last dimension of ``x`` fed to the LSTM.
            Defaults to ``SEQUENCE_LENGTH``.
        hidden_size: LSTM hidden-state width. Defaults to
            ``SEQUENCE_LENGTH * NUM_FEATURES_PER_SAMPLE``.
        num_layers: number of stacked LSTM layers. Defaults to ``NUM_LAYERS``.
        num_classes: number of output classes. Defaults to 3.
    """

    def __init__(self, input_size=None, hidden_size=None, num_layers=None, num_classes=3):
        super().__init__()
        # Defaults are resolved here rather than in the signature so the class
        # definition does not require the notebook constants at definition time.
        if input_size is None:
            # NOTE(review): with batch_first=True, nn.LSTM expects x shaped
            # (batch, seq_len, input_size), so input_size is normally the
            # per-step feature count (NUM_FEATURES_PER_SAMPLE). Using
            # SEQUENCE_LENGTH only works if the dataset yields
            # (batch, features, seq_len), in which case the LSTM steps over
            # features instead of over time — confirm the tensor layout
            # produced by full_iri_dataset_generator.
            input_size = SEQUENCE_LENGTH
        if hidden_size is None:
            hidden_size = SEQUENCE_LENGTH * NUM_FEATURES_PER_SAMPLE
        if num_layers is None:
            num_layers = NUM_LAYERS

        self.rnn = nn.LSTM(input_size=input_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers,
                           batch_first=True)
        self.final = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        Args:
            x: batch-first input, ``(batch, steps, input_size)``.
        """
        # Without an explicit (h_0, c_0), nn.LSTM starts from zeros, allocated
        # on x's device and dtype (no CPU allocation + copy, no dtype mismatch).
        out, _ = self.rnn(x)
        # Classify from the LSTM output at the last step only.
        logits = self.final(out[:, -1, :])
        # log_softmax is idempotent, so feeding these outputs to
        # nn.CrossEntropyLoss (which applies its own log-softmax) is harmless.
        return torch.log_softmax(logits, dim=1)

## Training

In [11]:
# Seed torch before building the model so weight initialisation (and, presumably,
# any torch-driven shuffling inside train_model) is reproducible on Run All.
SEED = 42
torch.manual_seed(SEED)

model = RNN()
# CrossEntropyLoss is kept (rather than NLLLoss) because it also accepts
# class-probability / one-hot targets if train_model passes those through.
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.33 every 20 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.33)

train_model(model, train, test, loss, optimizer,
            epochs=100, test_every_n=10, batch_size=512, lr_scheduler=scheduler)

Training Epoch: 100%|██████████| 100/100 [02:32<00:00,  1.53s/it, Train Loss=0.0596, Test Loss=0.0989]


## Metric Computation (accuracy, precision, recall, F1)

In [25]:
from torcheval.metrics import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score
from torch.utils.data import DataLoader

def compute_metric(dataset, metric, net=None, batch_size=512):
    """Run a model over ``dataset`` and accumulate its predictions into ``metric``.

    Args:
        dataset: map-style dataset whose items carry the inputs at index 0 and a
            per-class target (one-hot in this notebook, presumably) at index 1;
            targets are reduced to class indices with ``argmax`` over dim 1.
        metric: object exposing ``update(predictions, targets)`` and
            ``compute()`` (e.g. a ``torcheval.metrics`` instance); mutated in place.
        net: model to evaluate. Defaults to the notebook-level ``model`` so
            existing two-argument calls keep working.
        batch_size: evaluation batch size.

    Returns:
        Whatever ``metric.compute()`` returns.
    """
    # `model` is read from the notebook globals only when no net is supplied.
    net = model if net is None else net
    # shuffle=False: the metric is order-independent, and shuffling would only
    # consume the global torch RNG and make later cells less reproducible.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for batch in loader:
        inputs, targets = batch[0], batch[1]
        metric.update(net(inputs), targets.argmax(dim=1))
    return metric.compute()

def show_metrics_for(dataset, state, class_names=("GOOD", "ACCEPTABLE", "POOR")):
    """Print per-class accuracy, precision, recall and F1 for ``dataset``.

    Args:
        dataset: dataset handed to ``compute_metric``.
        state: label printed at the start of the header line (e.g. "TRAIN").
        class_names: column labels, one per class; their count sets ``num_classes``.
    """
    num_classes = len(class_names)
    # average=None is the documented torcheval value for "one score per class".
    # Note: per-class accuracy is the share of each class's samples predicted
    # correctly — the same quantity as per-class recall — so those rows match.
    metrics = {
        "ACCURACY": MulticlassAccuracy(average=None, num_classes=num_classes),
        "PRECISION": MulticlassPrecision(average=None, num_classes=num_classes),
        "RECALL": MulticlassRecall(average=None, num_classes=num_classes),
        "F1": MulticlassF1Score(average=None, num_classes=num_classes),
    }
    scores = {name: compute_metric(dataset, metric=metric) for name, metric in metrics.items()}

    print(f"{state}: " + " | ".join(class_names))
    for name, per_class in scores.items():
        # Two spaces around "|" keep the column layout of the saved outputs.
        print(f"{name}:", "  |  ".join(str(float(value)) for value in per_class))

# Switch to eval mode on the CPU and report metrics for both splits without
# tracking gradients; a blank line separates the two reports.
model.eval()
model.to("cpu")
with torch.no_grad():
    for position, (split, label) in enumerate(((train, "TRAIN"), (test, "TEST"))):
        if position > 0:
            print("")
        show_metrics_for(split, label)

TRAIN: GOOD | ACCEPTABLE | POOR
ACCURACY: 1.0  |  0.7968936562538147  |  0.9908524751663208
PRECISION: 0.998706579208374  |  0.8917112350463867  |  0.9817655086517334
RECALL: 1.0  |  0.7968936562538147  |  0.9908524751663208
F1: 0.9993528723716736  |  0.8416403532028198  |  0.9862880110740662

TEST: GOOD | ACCEPTABLE | POOR
ACCURACY: 0.9963924884796143  |  0.6551724076271057  |  0.9790703058242798
PRECISION: 0.9985538721084595  |  0.7638190984725952  |  0.9644097089767456
RECALL: 0.9963924884796143  |  0.6551724076271057  |  0.9790703058242798
F1: 0.9974720478057861  |  0.7053363919258118  |  0.9716846942901611
