# LSTM

In [None]:
from pandas import DataFrame, Series
from sklearn.model_selection import StratifiedGroupKFold
from numpy import (
    ones,
    unique,
    ndarray,
    save, 
    load,
    average
)
from typing import Tuple, Iterator, List, override
from ctypes import ArgumentError
from os import makedirs
from os.path import join, exists
from copy import deepcopy
import matplotlib.pyplot as plt

from src.common.helpers import read_dataframe
from src.common.model import MultiRunTrainArgs, TrainArgs
from src.rnn.architecture import RnnArch
from src.rnn.data import WindowGenerator
from src.rnn.model import Rnn, RnnConstructorArgs, RnnModelInitializeArgs, RnnTrainArgs,\
    RnnMultiRunTrainArgs, RnnTestArgs

In [None]:
class ExtendedStratifiedGroupKFold:
    """Stratified group k-fold that yields train/validation/test index triples.

    A single ``StratifiedGroupKFold`` is applied twice: fold ``n`` of the
    outer split provides the test indices, and fold ``n`` of a second split
    over the remaining samples provides the validation indices.

    Args:
        n_splits: Number of folds for both the outer and the inner split.
        shuffle: Forwarded to ``StratifiedGroupKFold``.
    """

    def __init__(self, n_splits: int = 10, shuffle: bool = False):
        self._n_splits = n_splits
        self._shuffle = shuffle

        self._splitter = StratifiedGroupKFold(
            n_splits=self._n_splits, shuffle=self._shuffle
        )

        # Never populated inside this class; kept so existing attribute
        # access does not break.
        self._splits: List[Tuple[ndarray, ndarray, ndarray]] | None = None

    def split(
        self, X: DataFrame, y: Series, groups: Series
    ) -> Iterator[Tuple[ndarray, ndarray, ndarray]]:
        """Yield ``(train_index, val_index, test_index)`` for every fold.

        Args:
            X: Samples; only handed through to the underlying splitter.
            y: Labels used for stratification.
            groups: Group id per sample.

        Yields:
            Three arrays of positional indices into ``X``/``y``/``groups``.
        """
        from numpy import asarray

        # The splitter returns positional indices, so index plain arrays.
        # Indexing a Series with ``[]`` would look the integers up as labels,
        # which is only equivalent for a default RangeIndex.
        y_values = asarray(y)
        group_values = asarray(groups)

        # Compute the outer split once instead of once per fold.
        outer_splits = list(self._splitter.split(X, y_values, group_values))

        for fold, (train_temp, test_index) in enumerate(outer_splits):
            # ``train_temp`` only serves as a placeholder of the right length
            # for the inner split's X argument.
            inner_train, inner_val = list(
                self._splitter.split(
                    train_temp, y_values[train_temp], group_values[train_temp]
                )
            )[fold]

            # Inner indices are relative to ``train_temp``; map them back to
            # positions in the full data set.
            train_index = train_temp[inner_train]
            val_index = train_temp[inner_val]

            yield train_index, val_index, test_index


In [None]:


class RnnFoldCrossValidation:
    """Train and test one ``Rnn`` model per cross-validation fold.

    Each fold gets its own model, named ``<name>-fold<N>``. The group split of
    a fold is stored as ``split/{train,val,test}.npy`` inside that model's
    directory and reused when the files already exist.
    """

    @property
    def model_constructor_args(self) -> RnnConstructorArgs:
        """Constructor arguments shared by all fold models."""
        return self._model_args

    def __init__(
        self,
        model_args: RnnConstructorArgs
    ):
        self._model_args = model_args
        self._splitter = ExtendedStratifiedGroupKFold()

    def get_full_data_list(self) -> DataFrame:
        """Read ``df/rnn/cvs_features.pkl`` below the configured data root."""
        path_to_all = join(self.model_constructor_args.data_root_path, "df", "rnn",
            "cvs_features.pkl")
        return read_dataframe(path_to_all)

    def train_folds(self, train_run_args: MultiRunTrainArgs, verbose: bool = False):
        """Train and test every fold, then show the accuracy box plot.

        Args:
            train_run_args: Number of runs and per-run training arguments.
            verbose: Print the group assignment of every fold.
        """
        full_data = self.get_full_data_list()

        # The splitter only needs the number of samples, not real features.
        feature_placeholder = ones(shape=(full_data.shape[0]))
        labels = full_data["label"]
        groups = full_data["group"]

        split_iterator = self._splitter.split(feature_placeholder, labels, groups)

        for fold_num, split in enumerate(split_iterator, start=1):
            train_index, val_index, test_index = split
            # The indices are positions, so use .iloc; plain [] would treat
            # them as index labels.
            train_groups = unique(groups.iloc[train_index])
            val_groups = unique(groups.iloc[val_index])
            test_groups = unique(groups.iloc[test_index])

            if verbose:
                # NOTE: these are the freshly computed groups; __train_fold
                # uses a stored split instead when one exists on disk.
                print(f"Fold {fold_num}:")
                print(f"  Train: groups={train_groups}")
                print(f"  Val:  groups={val_groups}")
                print(f"  Test:  groups={test_groups}")

            self.__train_fold(train_run_args, fold_num, full_data, train_groups, val_groups,
                test_groups, verbose)

        self.print_box_plot()

    def print_box_plot(self):
        """Print the mean test accuracy over all folds and show a box plot."""
        metrics = self.get_test_accuracy_metrics()

        print(f"Average Top 1 categorical accuracy: {average(metrics)}")

        plt.figure()
        plt.boxplot(metrics)
        plt.show()

    def get_test_accuracy_metrics(self) -> List[float]:
        """Return the test accuracy metric of every fold model, in fold order."""
        # Use the splitter's fold count rather than a hard-coded 10.
        fold_nums = range(1, self._splitter._n_splits + 1)
        return [
            self.__init_fold_model(fold_num).get_test_accuracy_metric()
            for fold_num in fold_nums
        ]

    def test_folds(self):
        """Test every fold model using its stored split, then show the box plot.

        Raises:
            FileNotFoundError: If the split files of a fold are missing.
        """
        full_data = self.get_full_data_list()

        for fold_index in range(self._splitter._n_splits):
            fold_num = fold_index + 1

            model = self.__init_fold_model(fold_num)
            model_dir = model._get_model_dir()

            if not self.__split_files_exist(model_dir):
                raise FileNotFoundError(
                    "split files should already exist when just testing the models"
                )
            (train_groups, val_groups, test_groups) = self.__load_split(model_dir)

            wg = self.__build_fold(fold_num, full_data, train_groups, val_groups, test_groups)

            additional_config = self.__get_additional_config(
                context_config={"fold": fold_num}
            )
            model.test_model(
                args=RnnTestArgs(
                    window_generator=wg,
                    write_to_wandb=True, additional_config=additional_config)
            )

        self.print_box_plot()

    def __get_additional_config(self, context_config: dict | None = None) -> dict:
        """Return a new dict with ``context_config`` plus class-level extras."""
        return (context_config or {}) | {
            # add values from model_initialize_args
        }

    def __train_fold(self, train_run_args: MultiRunTrainArgs, fold_num: int, data: DataFrame,
            train_groups: ndarray, val_groups: ndarray, test_groups: ndarray, verbose: bool):
        """Train and then test the model of a single fold.

        A split already stored in the model directory takes precedence over
        the groups passed in; otherwise the passed groups are saved there.
        """
        model = self.__init_fold_model(fold_num)
        model_dir = model._get_model_dir()

        if self.__split_files_exist(model_dir):
            (train_groups, val_groups, test_groups) = self.__load_split(model_dir)
        else:
            self.__save_split(model_dir, (train_groups, val_groups, test_groups))

        wg = self.__build_fold(fold_num, data, train_groups, val_groups, test_groups, verbose)

        additional_config = self.__get_additional_config(
            context_config={"fold": fold_num}
        )

        # Caller-supplied entries win over the fold context on key clashes.
        # ``or {}`` guards against additional_config being unset (None).
        caller_config = train_run_args.train_args.additional_config or {}
        rnn_train_run_args = RnnMultiRunTrainArgs(
            train_args=RnnTrainArgs(
                window_generator = wg,
                epochs = train_run_args.train_args.epochs,
                additional_config = additional_config | caller_config
            ),
            runs=train_run_args.runs,
        )
        model.execute_train_runs(rnn_train_run_args)

        model.test_model(
            args=RnnTestArgs(window_generator=wg, write_to_wandb=True,
                additional_config=additional_config)
        )

    def __init_fold_model(self, fold_num: int) -> Rnn:
        """Create the ``Rnn`` for a fold, named ``<name>-fold<fold_num>``."""
        adapted_args = self._model_args.copy_with(
            name=f"{self._model_args.name}-fold{fold_num}"
        )
        return Rnn(adapted_args)

    def __split_files_exist(self, model_dir) -> bool:
        """Return True only if all three split files exist in ``model_dir``."""
        return (
            exists(join(model_dir, "split", "train.npy"))
            and exists(join(model_dir, "split", "val.npy"))
            and exists(join(model_dir, "split", "test.npy"))
        )

    def __load_split(self, model_dir) -> Tuple[ndarray, ndarray, ndarray]:
        """Load the stored (train, val, test) group arrays of a fold."""
        return (
            load(join(model_dir, "split", "train.npy")),
            load(join(model_dir, "split", "val.npy")),
            load(join(model_dir, "split", "test.npy")),
        )

    def __save_split(self, model_dir, split: Tuple[ndarray, ndarray, ndarray]):
        """Store the (train, val, test) group arrays below ``model_dir/split``."""
        (train, val, test) = split

        makedirs(join(model_dir, "split"), exist_ok=True)
        save(join(model_dir, "split", "train.npy"), train)
        save(join(model_dir, "split", "val.npy"), val)
        save(join(model_dir, "split", "test.npy"), test)

    def __build_fold(self, fold_num: int, data: DataFrame, train_groups: ndarray,
            val_groups: ndarray, test_groups: ndarray,
            verbose: bool = False) -> WindowGenerator:
        """Build the ``WindowGenerator`` for one fold.

        ``verbose`` defaults to False because ``test_folds`` calls this
        without it.
        """
        if verbose:
            print(f"Building fold {fold_num} ...")

        # Positional 5, 1, 1 — presumably input_width, label_width and shift;
        # confirm against WindowGenerator's signature.
        wg = WindowGenerator(5, 1, 1, data, train_groups, val_groups, test_groups)
        if verbose:
            print(wg)
            wg.inspect_fold_split()
        return wg
    


In [None]:
# Cross-validation driver for the ARCH1 architecture. The base name "arch1" is
# extended to "arch1-fold<N>" for each fold's model by RnnFoldCrossValidation.
rnn_folds = RnnFoldCrossValidation(
    model_args=RnnConstructorArgs(
        name="arch1",
        model_initialize_args=RnnModelInitializeArgs(
            model_arch=RnnArch.ARCH1
        )
    )
)

In [None]:
# Train and test every fold, then show the accuracy box plot.
# runs=5 and epochs=10 are forwarded to each fold model's execute_train_runs.
rnn_folds.train_folds(train_run_args=MultiRunTrainArgs(
    runs=5,
    train_args=TrainArgs(
        epochs=10
    )
))

In [None]:
# def get_group_split(
#     csv_data: DataFrame, split_idx: int, verbose: bool = False
# ) -> Tuple[list, list, list]:

#     n_splits = 10
#     if not (0 <= split_idx and split_idx < n_splits):
#         raise Exception(f"split_idx must be in [0, {n_splits-1}]")

#     feature_placeholder = ones(shape=(csv_data.shape[0]))
#     groups = csv_data["group"]
#     labels = csv_data["label"]

#     sgkf1 = StratifiedGroupKFold(n_splits=n_splits, shuffle=False)
#     train_temp, test_index = list(sgkf1.split(feature_placeholder, labels, groups))[
#         split_idx
#     ]

#     train_index, val_index = list(
#         sgkf1.split(train_temp, labels[train_temp], groups[train_temp])
#     )[split_idx]
#     train_index = train_temp[train_index]
#     val_index = train_temp[val_index]

#     train_groups = unique(groups[train_index])
#     val_groups = unique(groups[val_index])
#     test_groups = unique(groups[test_index])


#     return train_groups, val_groups, test_groups

In [None]:
from pandas import DataFrame, Series, concat
from numpy import arange, array_equal
from numpy.random import rand, choice

from src.labels import iterate_valid_labels
from src.rnn.data import get_group_split

# Sanity check: get_group_split must return the same train/val/test groups
# when called twice with the same data and split index.
# Synthetic data: 100 rows, 20 random features, one group per row, random labels.
test_features = DataFrame(rand(100, 20), columns=[f"feat{n}" for n in arange(20)])
test_groups = Series(arange(0, 100), name="group")
test_labels = Series(choice(list(iterate_valid_labels()), 100), name="label")
test_all_features = concat([test_features, test_groups, test_labels], axis=1)

groups1 = get_group_split(test_all_features, 5)
groups2 = get_group_split(test_all_features, 5)

# array_equal also reports False for arrays of different length, where
# all(a == b) would raise instead of printing a result.
print(array_equal(groups1[0], groups2[0]))
print(array_equal(groups1[1], groups2[1]))
print(array_equal(groups1[2], groups2[2]))

In [None]:
# Load the feature table and pull out the columns used for splitting.
all_features = read_dataframe("data/df/rnn/cvs_features.pkl")

labels = all_features["label"]
groups = all_features["group"]
# One dummy value per row; the splitter only needs the sample count.
feature_placeholder = ones(len(all_features))

In [None]:
fold_num = 1  # range 1 - 10
n_splits = 10
if not 1 <= fold_num <= n_splits:
    # fold_num = 0 would otherwise select the last fold through index -1.
    raise ValueError(f"fold_num must be in [1, {n_splits}]")

# Outer split: the chosen fold provides the test indices.
sgkf1 = StratifiedGroupKFold(n_splits=n_splits, shuffle=False)
train_temp, test_index = list(sgkf1.split(feature_placeholder, labels, groups))[fold_num - 1]

# Inner split over the remaining samples provides the validation indices.
# train_temp holds positions, so select with .iloc; plain [] on a Series
# would interpret the integers as index labels.
sgkf2 = StratifiedGroupKFold(n_splits=n_splits, shuffle=False)
train_index, val_index = list(
    sgkf2.split(train_temp, labels.iloc[train_temp], groups.iloc[train_temp])
)[fold_num - 1]
# Map the inner (relative) indices back to positions in the full data.
train_index = train_temp[train_index]
val_index = train_temp[val_index]

train_groups = unique(groups.iloc[train_index])
val_groups = unique(groups.iloc[val_index])
test_groups = unique(groups.iloc[test_index])

print(f"Fold {fold_num}:")
print(f"  Train: groups={train_groups}")
print(f"  Val:  groups={val_groups}")
print(f"  Test:  groups={test_groups}")

In [None]:
from pandas import DataFrame, Series, concat
from numpy import arange
from numpy.random import rand, choice

from src.labels import iterate_valid_labels


In [None]:
# Window generator over the full feature table for the fold selected above,
# using the train/val/test groups computed in the previous split cell.
w1 = WindowGenerator(input_width=5, label_width=1, shift=1, data=all_features, 
    train_groups=train_groups, val_groups=val_groups, test_groups=test_groups)
print(w1)

In [None]:
# Presumably reports how the data is divided across train/val/test; the
# output is defined by WindowGenerator.inspect_fold_split.
w1.inspect_fold_split()

In [None]:
# Plot with WindowGenerator.plot's defaults (no model and no plot_col given).
w1.plot()

In [None]:
# Stand-alone ARCH1 model named "arch1" (no fold suffix), independent of the
# cross-validation driver above.
rnn = Rnn(args=RnnConstructorArgs(
    name="arch1",
    model_initialize_args=RnnModelInitializeArgs(
        model_arch=RnnArch.ARCH1
    )
))
rnn.initialize_model()

In [None]:
# Show the return value of the private helper as the cell output —
# presumably the path where the best model is stored; confirm in src.rnn.model.
rnn._get_best_model_path()

In [None]:
# Train the stand-alone model for 5 epochs on the windows from w1.
rnn.train_model(args=RnnTrainArgs(
    window_generator=w1,
    epochs=5))

In [None]:
# Summarize the underlying model object (presumably a Keras-style summary — TODO confirm).
rnn.model.summary()

In [None]:
# Plot again, this time passing the trained model and selecting the
# "RIGHT_KNEE_x" column.
w1.plot(rnn.model, plot_col="RIGHT_KNEE_x")