# In [None]:
import logging
# import lime.explanation
import matplotlib
import numpy as np
import os
import pandas as pd
import sys
import torch
from datetime import datetime
from matplotlib import pyplot as plt
from pathlib import PosixPath
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score
)
from sklearn.model_selection import KFold, train_test_split
from time import time
from torch import nn
from torch.utils.data import (
    DataLoader,
    Dataset,
    ConcatDataset
)
from tqdm import tqdm
from typing import List, Literal, Optional, Tuple, Union

# import lime
# from lime import lime_tabular

# Per-session log file. Time fields are separated by dashes rather than
# colons because ':' is not a legal file-name character on Windows.
LOG_FILE_NAME = f'session-{datetime.now().strftime("%d-%m-%Y_%H-%M-%S")}.log'

# Route log records both to the session file and to stdout.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. configured by the hosting environment before this runs).
logging.basicConfig(
    format="%(asctime)s | [%(levelname)-8s] | %(message)s",
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.FileHandler(LOG_FILE_NAME),
        logging.StreamHandler(sys.stdout),
    ],
)

# Random seed passed to every randomized split in this file (train_test_split, KFold).
SEED = 42

# Named logger used throughout; its records propagate to the root handlers above.
LOGGER = logging.getLogger("diabetes_trainer")
LOGGER.setLevel(logging.DEBUG)

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
LOGGER.info(f"Running with device: {DEVICE}")

# Number of input features fed to the model (the five feature columns read by DiabetesDataset).
FEATURE_COUNT = 5
# Number of output logits / classes. CrossEntropyLoss requires every 'diabetes'
# label to lie in [0, OUTPUT_OUTCOME_COUNT) -- confirm against the dataset.
OUTPUT_OUTCOME_COUNT = 4

# -------------------------------------------------------------------------------------

class SimpleDiabetesModel(nn.Module):
    '''
    A small MLP: Linear -> ReLU -> Linear, returning raw logits.

    With the default arguments it has one hidden ReLU layer of 20 nodes and maps
    FEATURE_COUNT inputs to OUTPUT_OUTCOME_COUNT logits. No softmax is applied,
    because nn.CrossEntropyLoss applies log-softmax itself.

    Args:
        in_features (Optional[int]): Number of input features. Defaults to FEATURE_COUNT.
        hidden_size (int): Width of the hidden layer. Defaults to 20.
        out_features (Optional[int]): Number of output logits (classes).
            Defaults to OUTPUT_OUTCOME_COUNT.
    '''
    def __init__(
        self,
        in_features: Optional[int] = None,
        hidden_size: int = 20,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        # Resolve defaults at construction time so the module-level constants
        # stay the single source of truth.
        if in_features is None:
            in_features = FEATURE_COUNT
        if out_features is None:
            out_features = OUTPUT_OUTCOME_COUNT
        # The attribute names fc1/fc2 become the state_dict keys stored in saved
        # checkpoints, so they must not be renamed.
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Performs the forward pass.

        Args:
            x (torch.Tensor): Float tensor whose last dimension is in_features,
                typically shaped (batch, in_features).

        Returns:
            torch.Tensor: Logits with the same leading dimensions and a last
                dimension of out_features, e.g. (batch, out_features).
        '''
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)

class DiabetesDataset(Dataset):
    """Diabetes dataset backed by a pandas DataFrame."""

    # Input feature columns, in the order they are fed to the model.
    FEATURE_COLUMNS = ['age', 'gender', 'bmi', 'blood_glucose_level', 'hypertension']
    # Class-label column.
    TARGET_COLUMN = 'diabetes'

    def __init__(
            self,
            input_file: Optional[Union[str, os.PathLike, PosixPath]] = None,
            premade_dataset: Optional[pd.DataFrame] = None
        ):
        """
        Sets up the diabetes dataset. Two options are available - either use a CSV file, or use a premade DataFrame.
        For any instance of this, an input must be provided - ValueError will be raised if no input is provided.
        If both are provided, the CSV file takes precedence and the premade DataFrame is ignored.

        To use the CSV file:

        ```python
        data = DiabetesDataset('file.csv')
        ```

        To use a previously made dataset in memory:

        ```python
        # Assume that the premade dataset is named merged_dataset in the code
        data = DiabetesDataset(premade_dataset=merged_dataset)
        ```

        NOTE: The dataset must contain at least the six columns 'age', 'gender', 'bmi',
        'blood_glucose_level', 'hypertension' and 'diabetes'; any other columns are ignored.
        Rows are addressed by position, so the DataFrame's index labels do not matter.

        Args:
            input_file (Optional[Union[str, os.PathLike, PosixPath]], optional): An input file containing the dataset information. Defaults to None.
            premade_dataset (Optional[pd.DataFrame], optional): A premade dataset in memory. Defaults to None.

        Raises:
            ValueError: Raised when there is no input file or premade dataset provided.
        """
        if input_file is None and premade_dataset is None:
            raise ValueError("An input source must be provided as a dataset.")
        if input_file is not None:
            self.data: pd.DataFrame = pd.read_csv(input_file)
        else:
            self.data = premade_dataset

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns (features, label) for the row at position ``index``.

        features is a float32 tensor of the FEATURE_COLUMNS values; label is a
        0-d int64 tensor holding the TARGET_COLUMN value.
        """
        # Positional lookup (iloc): DataLoader samplers and ConcatDataset pass
        # positions in range(len(self)), which only coincide with .loc labels
        # when the DataFrame has a default 0..n-1 index.
        row = self.data.iloc[index]
        features = torch.tensor(
            row[self.FEATURE_COLUMNS].to_numpy(dtype=np.float32),
            dtype=torch.float32,
        )
        output = torch.tensor(row[self.TARGET_COLUMN], dtype=torch.long)
        return features, output


def setup_adam_optimizer(model: nn.Module, learning_rate: float = 0.001):
    """
    Builds an Adam optimizer over every parameter of ``model``.

    Adam is a gradient-descent variant that keeps running averages of each
    parameter's gradient and squared gradient and uses them to scale the step
    size per parameter. Only the learning rate is configured here; every other
    hyper-parameter keeps PyTorch's default.

    Args:
        model (nn.Module): The model whose parameters will be optimized.
        learning_rate (float, optional): The initial learning rate. Defaults to 0.001.

    Returns:
        torch.optim.Adam: The configured optimizer.
    """
    trainable = model.parameters()
    optimizer = torch.optim.Adam(trainable, lr=learning_rate)
    return optimizer

def setup_cross_entropy_loss():
    """
    Builds the classification loss used for training and validation.

    nn.CrossEntropyLoss applies log-softmax to its input internally, so the
    model is expected to output raw logits and has no softmax layer.

    Returns:
        nn.CrossEntropyLoss: The loss module, moved to DEVICE.
    """
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(DEVICE)
    return criterion

def epoch_time(start_time: float, end_time: float) -> Tuple[int, int]:
    """
    Converts the span between two timestamps into whole minutes and seconds.

    Both parts are truncated toward zero with int(), so fractional seconds
    are dropped.

    Args:
        start_time (float): Timestamp taken at the start, in seconds (e.g. from time()).
        end_time (float): Timestamp taken at the end, in seconds.

    Returns:
        Tuple[int, int]: (minutes, seconds) of the elapsed time.
    """
    seconds_total = end_time - start_time
    minutes = int(seconds_total / 60)
    return minutes, int(seconds_total - minutes * 60)

def split_dataset(
    input_file: Union[str, os.PathLike, PosixPath, pd.DataFrame],
    splits: List[float]
) -> Tuple[DiabetesDataset, DiabetesDataset, DiabetesDataset]:
    """
    Splits the dataset into train, test and validation sets.

    Both splitting rounds are stratified on the 'diabetes' column, so each
    subset keeps roughly the class distribution of the full dataset.

    Args:
        input_file (Union[str, os.PathLike, PosixPath, pd.DataFrame]): A CSV path,
            or an already-loaded DataFrame.
        splits (List[float]): Fraction of the full dataset for each subset,
            in order: train, test and val. The list must have three elements
            and is expected to sum to 1.

    Returns:
        Tuple[DiabetesDataset, DiabetesDataset, DiabetesDataset]: (train, test, val),
            each backed by a DataFrame with a fresh 0..n-1 index.

    Raises:
        ValueError: If ``splits`` does not have exactly three elements.
    """
    if len(splits) != 3:
        raise ValueError("The splits must have all three subset splits.")

    dataset = input_file if isinstance(input_file, pd.DataFrame) else pd.read_csv(input_file)

    train_split, test_split, val_split = splits

    # Round 1: carve the test set out of the full dataset.
    train_val_subset, test_subset = train_test_split(
        dataset,
        test_size=test_split,
        random_state=SEED,
        stratify=dataset['diabetes'],
    )
    # Round 2: split the remainder into train and val. The remainder holds
    # (train_split + val_split) of the data, so val must be taken as
    # val_split / (train_split + val_split) of it to end up as val_split of the
    # whole. Dividing by train_split alone would over-size val (e.g. 11.25%
    # instead of 10% with the 0.8/0.1/0.1 defaults).
    train_subset, val_subset = train_test_split(
        train_val_subset,
        test_size=val_split / (train_split + val_split),
        random_state=SEED,
        stratify=train_val_subset['diabetes'],
    )

    # Positional 0..n-1 indices for each subset.
    train_subset = train_subset.reset_index(drop=True)
    test_subset = test_subset.reset_index(drop=True)
    val_subset = val_subset.reset_index(drop=True)

    train_ds = DiabetesDataset(premade_dataset=train_subset)
    test_ds = DiabetesDataset(premade_dataset=test_subset)
    val_ds = DiabetesDataset(premade_dataset=val_subset)

    LOGGER.info("Post-split statistics:")
    for name, subset in (('train', train_subset), ('test', test_subset), ('val', val_subset)):
        LOGGER.info(f"length of {name}: {len(subset)}")
        LOGGER.info(f"ratio of {name}: {len(subset)/len(dataset)}")
        LOGGER.info(f"distribution of {name}:\n{subset['diabetes'].value_counts()/len(subset)}")

    return train_ds, test_ds, val_ds


def _run_train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_func: nn.Module,
) -> float:
    """Runs one optimization pass over ``loader`` and returns the mean per-batch loss."""
    model.train()
    epoch_loss = 0.0
    current_loss = 0.0
    for i, (feature, target) in enumerate(loader):
        feature = feature.to(DEVICE)
        target = target.to(DEVICE)

        output = model(feature)
        loss = loss_func(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        current_loss += loss.item()
        if i % 500 == 499:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0
    # Assumes loss_func returns a per-batch mean (true for the default
    # nn.CrossEntropyLoss from setup_cross_entropy_loss), so average over the
    # number of batches, not samples. max() guards against an empty loader.
    return epoch_loss / max(len(loader), 1)


def _mean_eval_loss(model: nn.Module, loader: DataLoader, loss_func: nn.Module) -> float:
    """Returns the mean per-batch loss of ``model`` over ``loader`` without updating weights."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for feature, target in loader:
            output = model(feature.to(DEVICE))
            total_loss += loss_func(output, target.to(DEVICE)).item()
    return total_loss / max(len(loader), 1)


def _collect_predictions(model: nn.Module, loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
    """Returns (true labels, argmax-predicted labels) of ``model`` over ``loader``."""
    model.eval()
    true_target = []
    predicted_target = []
    with torch.no_grad():
        for feature, target in loader:
            output = model(feature.to(DEVICE))
            true_target.append(target.cpu().numpy())
            predicted_target.append(output.argmax(dim=1).cpu().numpy())
    return np.concatenate(true_target), np.concatenate(predicted_target)


def train_and_eval(
    input_file: Union[str, os.PathLike, PosixPath, pd.DataFrame],
    model: nn.Module,
    training_mode: Literal['naive', 'kfold'] = 'naive',
    train_split: float = 0.8,
    test_split: float = 0.1,
    val_split: float = 0.1,
    batch_size: int = 64,
    learning_rate: float = 0.001,
    epochs: int = 2,
    n_folds: int = 10
) -> List[str]:
    """
    Trains ``model`` and evaluates it, either on a single train/val/test split
    or with k-fold cross validation.

    Args:
        input_file (Union[str, os.PathLike, PosixPath, pd.DataFrame]): Input data for the training
            (a CSV path or an already-loaded DataFrame).
        model (nn.Module): The model to train. It is moved to DEVICE and trained in place.
        training_mode (Literal['naive', 'kfold'], optional): Training mode for the model.
            - 'naive': train for ``epochs`` epochs, keep the weights with the lowest
              validation loss, and log a classification report on the test split.
            - 'kfold': k-fold cross validation over the whole dataset. Every fold
              starts from the same initial weights and is scored (accuracy and
              macro F1 / precision / recall) on its own held-out fold.
            Defaults to 'naive'.
        train_split (float, optional): The split for the train set, if training with 'naive'. Defaults to 0.8.
        test_split (float, optional): The split for the test set, if training with 'naive'. Defaults to 0.1.
        val_split (float, optional): The split for the validation set, if training with 'naive'. Defaults to 0.1.
        batch_size (int, optional): The batch size for training. Defaults to 64.
        learning_rate (float, optional): The learning rate for training. Defaults to 0.001.
        epochs (int, optional): The number of epochs to run the training (per fold in 'kfold'). Defaults to 2.
        n_folds (int, optional): For training with k-fold Cross Validation, how many folds there will be. Defaults to 10.

    Returns:
        List[str]: Paths of the checkpoint files written to the working directory.

    Raises:
        ValueError: If ``training_mode`` is neither 'naive' nor 'kfold'.
    """
    if training_mode not in ('naive', 'kfold'):
        raise ValueError(f"Unknown training_mode {training_mode!r}; expected 'naive' or 'kfold'.")

    model.to(DEVICE)

    train_subset, test_subset, val_subset = split_dataset(input_file, [train_split, test_split, val_split])
    loss_func = setup_cross_entropy_loss()
    trained_model_names = []

    if training_mode == 'naive':
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
        # Evaluation order does not affect the metrics, so no shuffling here.
        test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
        optimizer = setup_adam_optimizer(model, learning_rate)

        ckpt_path = 'diabetes_model_ckpt.pth'
        best_loss = float('inf')

        for epoch in tqdm(range(epochs), desc='Training'):
            start_time = time()
            train_loss = _run_train_epoch(model, train_loader, optimizer, loss_func)
            val_loss = _mean_eval_loss(model, val_loader, loss_func)
            epoch_time_min, epoch_time_sec = epoch_time(start_time, time())

            # Keep the checkpoint with the lowest validation loss seen so far.
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), ckpt_path)

            LOGGER.info(f'Epoch: {epoch + 1} | Time: {epoch_time_min}m {epoch_time_sec}s')
            LOGGER.info(f'\tTrain Loss: {train_loss:.3f}')
            LOGGER.info(f'\tVal Loss: {val_loss:.3f}')

        if best_loss == float('inf'):
            # No epoch improved on inf (epochs == 0 or NaN losses): save the
            # current weights so we never load a missing or stale checkpoint.
            LOGGER.warning("No validation improvement recorded; saving the current weights.")
            torch.save(model.state_dict(), ckpt_path)
        trained_model_names.append(ckpt_path)

        # Evaluate the best checkpoint, not necessarily the last epoch's weights.
        # load_state_dict copies into the existing (already on DEVICE) parameters.
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE, weights_only=True))
        LOGGER.info("Loaded trained model for evaluation.")

        true_target, predicted_target = _collect_predictions(model, test_loader)
        LOGGER.info("Classification Report for the model:")
        LOGGER.info(classification_report(true_target, predicted_target))

    else:
        # Cross-validate over the whole dataset: every sample falls into exactly
        # one held-out fold, and each fold is scored only on data its model
        # never trained on. (The test split is NOT scored separately here.)
        # NOTE(review): KFold is not stratified; StratifiedKFold may suit
        # imbalanced labels better.
        combined_subset = ConcatDataset([train_subset, val_subset, test_subset])
        kfold = KFold(n_splits=n_folds, shuffle=True, random_state=SEED)

        # Snapshot the starting weights so every fold trains from the same
        # initialization instead of continuing from the previous fold.
        initial_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        accuracies = []
        f1s = []
        precisions = []
        recalls = []
        for fold, (train_idx, val_idx) in enumerate(kfold.split(combined_subset)):
            LOGGER.info(f"Fold {fold + 1}/{n_folds}")
            model.load_state_dict(initial_state)
            # Fresh optimizer too: Adam's moment estimates are per-run state.
            optimizer = setup_adam_optimizer(model, learning_rate)

            train_loader = DataLoader(
                combined_subset,
                batch_size=batch_size,
                sampler=torch.utils.data.SubsetRandomSampler(train_idx),
            )
            val_loader = DataLoader(
                combined_subset,
                batch_size=batch_size,
                sampler=torch.utils.data.SubsetRandomSampler(val_idx),
            )

            for epoch in tqdm(range(epochs), desc=f'Training with fold={fold+1}'):
                start_time = time()
                train_loss = _run_train_epoch(model, train_loader, optimizer, loss_func)
                epoch_time_min, epoch_time_sec = epoch_time(start_time, time())

                LOGGER.info(f'Epoch: {epoch + 1} | Time: {epoch_time_min}m {epoch_time_sec}s')
                LOGGER.info(f'\tTrain Loss: {train_loss:.3f}')

            LOGGER.info('Starting testing')

            save_path = f'diabetes_model_fold_{fold}.pth'
            torch.save(model.state_dict(), save_path)
            trained_model_names.append(save_path)

            # Score this fold on its own held-out indices.
            true_target, predicted_target = _collect_predictions(model, val_loader)

            accuracy = accuracy_score(true_target, predicted_target)
            f1 = f1_score(true_target, predicted_target, average='macro')
            precision = precision_score(true_target, predicted_target, average='macro')
            recall = recall_score(true_target, predicted_target, average='macro')

            accuracies.append(accuracy)
            f1s.append(f1)
            precisions.append(precision)
            recalls.append(recall)

            LOGGER.info(f"Accuracy score for fold {fold+1}: {accuracy:.2f}")
            LOGGER.info(f"F1 score for fold {fold+1}: {f1:.2f}")
            LOGGER.info(f"Precision score for fold {fold+1}: {precision:.2f}")
            LOGGER.info(f"Recall score for fold {fold+1}: {recall:.2f}")

        LOGGER.info(f"Average accuracy across {n_folds} folds: {sum(accuracies)/len(accuracies):.2f}")
        LOGGER.info(f"Average f1 score across {n_folds} folds: {sum(f1s)/len(f1s):.2f}")
        LOGGER.info(f"Average precision score across {n_folds} folds: {sum(precisions)/len(precisions):.2f}")
        LOGGER.info(f"Average recall score across {n_folds} folds: {sum(recalls)/len(recalls):.2f}")

    return trained_model_names


# def explain_model(
#         model: nn.Module,
#         model_name: str,
#         explainer: lime_tabular.LimeTabularExplainer,
#         test_sample: np.ndarray
#     ):
#     """
#     Attempts to explain the model with LIME

#     Args:
#         model (nn.Module): A PyTorch-based model for explaination
#         model_name (str): The name of the model to load into.
#         test_bundle (DataLoader): The testing set that needs to be examined.
#     """
#     model.load_state_dict(torch.load(f'{model_name}', weights_only=True))

    # def quick_predict(test: np.ndarray):
    #     #print(test.dtype)
    #     model.eval()
    #     with torch.no_grad():
    #         test = torch.from_numpy(test)
    #         test = test.float()
    #         #print(test.dtype)
    #         test = test.to(DEVICE)
    #         output = model(test)
    #         probas = nn.functional.softmax(output, dim=1)
    #         print(probas.shape)
    #     #print(output.detach().cpu().data.numpy())
    #     #print(probas)
    #     print(probas.numpy())
    #     return probas.numpy()

    # exp: lime.explanation.Explanation = explainer.explain_instance(test_sample, quick_predict, num_features=5)
    # plot = exp.as_pyplot_figure()
    # plot.savefig('test.png')
if __name__ == '__main__':
    # Train a freshly initialized model on the merged CSV using the plain
    # train/val/test workflow for 10 epochs.
    model = SimpleDiabetesModel()
    train_and_eval('MergedDataset.csv', model, training_mode='naive', epochs=10)