# LSTM

In [None]:
from numpy import ones, unique
from sklearn.model_selection import StratifiedGroupKFold

from src.common.helpers import read_dataframe
from src.rnn.data import WindowGenerator


In [None]:
# Load the pre-computed feature table used for the RNN experiments.
all_features = read_dataframe("data/df/rnn/cvs_features.pkl")

# Split metadata: group ids (kept intact across folds) and class labels.
labels = all_features["label"]
groups = all_features["group"]

# StratifiedGroupKFold only uses X to know how many samples there are, so a
# dummy vector of ones with one entry per row stands in for the feature matrix.
n_rows = all_features.shape[0]
feature_placeholder = ones(shape=n_rows)

In [None]:
from numpy import asarray

N_SPLITS = 10  # number of folds for both the outer (test) and inner (val) split
fold_num = 1  # 1-based fold to use, must be in 1..N_SPLITS

# fold_num - 1 is used as a list index below; without this guard 0 or negative
# values would silently wrap around to the last folds instead of failing.
if not 1 <= fold_num <= N_SPLITS:
    raise ValueError(f"fold_num must be in 1..{N_SPLITS}, got {fold_num}")

# split() returns *positional* indices. Indexing a pandas Series with them is
# label-based, which only works for a 0..n-1 index, so index plain arrays instead.
label_values = asarray(labels)
group_values = asarray(groups)

# Outer split: hold out one fold of whole groups as the test set.
sgkf1 = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=False)
train_temp, test_index = list(
    sgkf1.split(feature_placeholder, label_values, group_values)
)[fold_num - 1]

# Inner split of the remaining rows carves out a validation set of whole groups.
# train_temp is passed as X only for its length.
sgkf2 = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=False)
train_index, val_index = list(
    sgkf2.split(train_temp, label_values[train_temp], group_values[train_temp])
)[fold_num - 1]
# Inner indices are relative to train_temp; map them back to rows of all_features.
train_index = train_temp[train_index]
val_index = train_temp[val_index]

train_groups = unique(group_values[train_index])
val_groups = unique(group_values[val_index])
test_groups = unique(group_values[test_index])

print(f"Fold {fold_num}:")
print(f"  Train: groups={train_groups}")
print(f"  Val:  groups={val_groups}")
print(f"  Test:  groups={test_groups}")

In [None]:
# Build the windowed train/val/test datasets from the group split computed above.
w1 = WindowGenerator(
    input_width=5,
    label_width=1,
    shift=1,
    data=all_features,
    train_groups=train_groups,
    val_groups=val_groups,
    test_groups=test_groups,
)
print(w1)

In [None]:
# Inspect the fold split held by the window generator (presumably a summary of
# the train/val/test group assignment; see WindowGenerator.inspect_fold_split).
w1.inspect_fold_split()

In [None]:
# Visualize the window generator's data; no model is passed, so no predictions are drawn
# (assumption based on the model argument used in a later plot call — confirm in WindowGenerator.plot).
w1.plot()

In [None]:
from keras.api.models import Sequential, Model
from keras.api.layers import LSTM, Dense, Reshape
from keras.api.callbacks import EarlyStopping
from keras.api.optimizers import Adam
from keras.api.losses import CategoricalCrossentropy
from keras.api.metrics import CategoricalAccuracy

from src.labels import get_valid_label_count

In [None]:
# Query the number of classes once instead of calling the helper twice.
num_labels = get_valid_label_count()

# Many-to-one LSTM classifier:
# - the LSTM keeps only its final hidden state (return_sequences=False);
# - Dense maps that state to per-class softmax probabilities;
# - Reshape((-1, num_labels)) turns (batch, num_labels) into (batch, 1, num_labels).
# The extra time axis is presumably there to match the window labels
# (label_width=1); TODO confirm against WindowGenerator's label shape.
lstm_model = Sequential([
    LSTM(128, return_sequences=False),
    Dense(units=num_labels, activation="softmax"),
    Reshape((-1, num_labels)),
])

In [None]:
def compile_and_fit(model: Model, window: WindowGenerator, patience: int = 3, epochs: int = 5):
    """Compile *model* for multi-class classification and train it on *window*.

    Args:
        model: Keras model; it is compiled in place with categorical
            cross-entropy loss, the Adam optimizer and categorical accuracy.
        window: Supplies ``train_ds`` for training and ``val_ds`` for validation.
        patience: Epochs without improvement in ``val_categorical_accuracy``
            before early stopping halts training.
        epochs: Maximum number of training epochs. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        The history object returned by ``model.fit``.
    """
    # Stop when validation accuracy stops improving. With the defaults
    # (epochs=5, patience=3) early stopping has little room to act.
    early_stopping = EarlyStopping(monitor='val_categorical_accuracy', patience=patience)

    model.compile(loss=CategoricalCrossentropy(), optimizer=Adam(), metrics=[CategoricalAccuracy()])

    history = model.fit(window.train_ds, epochs=epochs, validation_data=window.val_ds,
        callbacks=[early_stopping])
    return history

In [None]:
# Compile and train the LSTM on w1's train/val datasets (with early stopping).
hist = compile_and_fit(model=lstm_model, window=w1)

In [None]:
# Print the layer summary. The Sequential definition declares no input shape,
# so layer shapes are only known once the model has been built (e.g. by fit).
lstm_model.summary()

In [None]:
# Plot the RIGHT_KNEE_x column for the generator's windows, passing the trained
# model (presumably to overlay its predictions; confirm in WindowGenerator.plot).
w1.plot(lstm_model, plot_col="RIGHT_KNEE_x")

In [None]:
# Evaluate on the held-out test dataset; return_dict=True yields {metric_name: value}.
performance = lstm_model.evaluate(x=w1.test_ds, return_dict=True)

In [None]:
from enum import Enum
from typing import Optional, override

from src.common.model import ClassificationModel, ModelConstructorArgs, ModelInitializeArgs,\
    TrainArgs, MultiRunTrainArgs

In [None]:
class RnnArch(Enum):
    """Identifiers of the available RNN architectures (currently only one)."""

    ARCH1 = 0

class RnnConstructorArgs(ModelConstructorArgs):
    """Constructor arguments for RNN models.

    Extends the base constructor arguments with the ``WindowGenerator`` that
    provides the model's windowed datasets.
    """

    @property
    def window_generator(self) -> WindowGenerator:
        """The window generator passed to the constructor."""
        return self._window_generator

    # Must be a property like the other overridden accessors in this file
    # (e.g. Rnn.model_arch); without @property, args.model_arch returned a
    # bound method instead of the RnnArch value.
    @override
    @property
    def model_arch(self) -> RnnArch:
        """Architecture enum for the model (see RnnArch)."""
        return self._model_arch

    @override
    def __init__(self, name: str, model_arch: RnnArch, window_generator: WindowGenerator,
            data_root_path = "data", dataset_name = "techniques"):
        # NOTE(review): model_arch above reads self._model_arch, which is
        # presumably set by ModelConstructorArgs.__init__; confirm there.
        super().__init__(name, model_arch, data_root_path, dataset_name)
        self._window_generator = window_generator

class RnnTrainArgs(TrainArgs):
    """Training arguments for RNN models; balancing is always disabled."""

    @override
    @property
    def balance(self) -> bool:
        """Rnn models are never trained on balanced data."""
        return False

    def __init__(self, epochs=10, additional_config: Optional[dict] = None):
        # A literal `{}` default is created once and shared by every instance
        # built without an explicit config; create a fresh dict per call instead.
        if additional_config is None:
            additional_config = {}
        # The second positional argument is always False, presumably the base
        # class's balance flag (consistent with the balance property above).
        super().__init__(epochs, False, additional_config)

class RnnModelInitializeArgs(ModelInitializeArgs):
    """Model-initialization arguments for RNN models; adds nothing to the base class."""

class RnnMultiRunTrainArgs(MultiRunTrainArgs):
    """Arguments for executing several training runs (``runs``) of an RNN model."""

    @override
    @property
    def model_initialize_args(self) -> RnnModelInitializeArgs:
        """Arguments for initializing the RNN model."""
        return self._model_initialize_args

    @override
    def __init__(self, model_initialize_args: RnnModelInitializeArgs,
            runs = 5,
            train_args: Optional[RnnTrainArgs] = None):
        # A default of `RnnTrainArgs()` in the signature is evaluated once at
        # class-definition time and shared by every instance; build one per call.
        if train_args is None:
            train_args = RnnTrainArgs()
        super().__init__(model_initialize_args, runs, train_args)

class Rnn(ClassificationModel):
    """RNN classification model configured via RnnConstructorArgs."""

    @override
    @property
    def model_arch(self) -> RnnArch:
        """Enum that is mapped to a factory function"""
        return self._model_arch

    @override
    def __init__(self, args: RnnConstructorArgs):
        super().__init__(args)

    @override
    def execute_train_runs(self, args: RnnMultiRunTrainArgs):
        """Run the base multi-run training loop and return its result."""
        return super().execute_train_runs(args)

    @override
    def initialize_model(self, args: RnnModelInitializeArgs):
        """Initialize the model via the base implementation and return its result."""
        # Use super() like the other overrides and propagate the base return
        # value; the explicit ClassificationModel call previously discarded it.
        return super().initialize_model(args)
