In [None]:
# Standard modules
from typing import Dict, NamedTuple, List, Tuple
import os
import pickle
import glob

# External modules
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import KFold
import yaml

# Models
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn import svm

# Internal modules
from aliases import *
from models import *
from snippet_accuracy import calculate_snippet_accuracy
from streaming_accuracy import calculate_streaming_accuracy, DataSample, load_data_samples
from train_models import Snippets, load_snippet_files, process_snippets

In [None]:
# CSV of recording metadata (named "Actual Event Times"); the same path is
# handed to load_data_samples in a later cell.
CLASSIFICATION_FILE: FilePath = "recordings/Actual Event Times.csv"
df: pd.DataFrame = pd.read_csv(CLASSIFICATION_FILE)
# Bare expression so the notebook renders the table as this cell's output.
df

In [None]:
# Build the list of DataSample objects from the CSV at CLASSIFICATION_FILE.
# `data` is read as a notebook global by the test_simplemodel_* helpers.
data: List[DataSample] = load_data_samples(CLASSIFICATION_FILE)

In [None]:
# 10 shuffled folds; the fixed integer seed makes the shuffle reproducible,
# and this one splitter is reused for every model evaluated below.
kfold = KFold(n_splits=10, shuffle=True, random_state=1)

In [None]:
# Read the SpikerBox settings file; the handle is only needed for parsing
CONFIG_PATH: FilePath = "settings/spiker_box.srconfig"
with open(CONFIG_PATH, "r") as config_file:
    config_data: Dict = yaml.safe_load(config_file)

# SpikerBox arguments (top-level keys of the settings file)
buffer_time: float = float(config_data["buffer_time"])
update_factor: float = float(config_data["update_factor"])
wait_time: float = float(config_data["wait_time"])
num_samples: int = int(config_data["num_samples"])
quality_factor: float = float(config_data["quality_factor"])

# Per-classifier settings are nested under the "classifier" key
_classifier_settings: Dict = config_data["classifier"]

# Modified simple classifier parameters
_msc_settings: Dict = _classifier_settings["MSC"]
m_event_threshold: int = int(_msc_settings["event_threshold"])
positive_amplitude: float = float(_msc_settings["positive_amplitude"])
negative_amplitude: float = float(_msc_settings["negative_amplitude"])
spacing: float = float(_msc_settings["spacing"])

# Simple classifier parameters
s_event_threshold: int = int(_classifier_settings["USC"]["event_threshold"])

# Catch22 model paths
knn_path: FilePath = _classifier_settings["KNN"]["file_path"]
rfc_path: FilePath = _classifier_settings["RFC"]["file_path"]
svc_path: FilePath = _classifier_settings["SVC"]["file_path"]

# Snippets folder
SNIPPET_FOLDER: FilePath = "snippets"

# Streaming Accuracies


In [None]:
def test_simplemodel_streaming(model: ModelBase, kfold: KFold) -> List[float]:
    """Evaluate a ready-to-use model's streaming accuracy on every k-fold test split.

    The model is used as given (no fitting happens here), so only the test
    indices of each split are needed. Relies on notebook globals: ``data`` and
    the SpikerBox settings ``buffer_time``, ``update_factor``, ``wait_time``,
    ``num_samples`` and ``quality_factor``.

    Args:
        model: Model passed unchanged to ``calculate_streaming_accuracy``.
        kfold: Splitter whose ``split(data)`` yields ``(train, test)`` index pairs.

    Returns:
        One accuracy per fold, in the order the folds were produced.
    """
    accuracies: List[float] = []
    # enumerate(..., start=1) yields the 1-based trial number printed below.
    for fold_number, (_, test) in enumerate(kfold.split(data), start=1):
        # A set gives constant-time membership tests while the enumerate
        # filter keeps the samples in their original order.
        held_out = set(test)
        test_data: List[DataSample] = [sample for idx, sample in enumerate(data) if idx in held_out]

        trial_accuracy: float = calculate_streaming_accuracy(
            test_data,
            model,
            True,
            buffer_time,
            update_factor,
            wait_time,
            num_samples,
            quality_factor,
        )
        accuracies.append(trial_accuracy)
        print(f"Finished trial {fold_number}")
    return accuracies

In [None]:
def train_catch22model(untrained_model: SupportsPredict, training_files: List[FilePath]) -> Catch22Model:
    """Fit ``untrained_model`` on the snippets of the given recordings and wrap it.

    For every path in ``training_files`` the base file name, minus a trailing
    ``.npy`` suffix, is used to collect all files matching
    ``{SNIPPET_FOLDER}/<name>_*``. Those snippet files are loaded and processed
    with the notebook-global ``num_samples`` and used to fit the model in place.

    Args:
        untrained_model: Estimator exposing ``fit``; it is mutated by fitting.
        training_files: Paths of the recordings whose snippets form the training set.

    Returns:
        A ``Catch22Model`` wrapping the fitted estimator.
    """
    suffix = ".npy"
    train_data: List[FilePath] = []
    for file_path in training_files:
        tail = os.path.basename(file_path)
        # Remove the exact ".npy" suffix. str.rstrip(".npy") would instead strip
        # every trailing '.', 'n', 'p' or 'y' character ("happy.npy" -> "ha").
        if tail.endswith(suffix):
            tail = tail[: -len(suffix)]
        train_data.extend(glob.glob(f"{SNIPPET_FOLDER}/{tail}_*"))
    snippets: Snippets = load_snippet_files({}, train_data)
    snippet_data, labels = process_snippets(snippets, num_samples)
    untrained_model.fit(snippet_data, labels)
    return Catch22Model(untrained_model)

In [None]:
def test_catch22model_streaming(data: List[DataSample], kfold: KFold, model_type: ModelType) -> List[float]:
    """Cross-validate a Catch22-based classifier on streaming accuracy.

    For each fold a fresh estimator of the requested type is trained (via
    ``train_catch22model``) on the snippets of the training recordings, then
    scored with ``calculate_streaming_accuracy`` on the held-out samples.
    Relies on the notebook-global SpikerBox settings ``buffer_time``,
    ``update_factor``, ``wait_time``, ``num_samples`` and ``quality_factor``.

    Args:
        data: Samples to split; each must expose ``file_name``.
        kfold: Splitter whose ``split(data)`` yields ``(train, test)`` index pairs.
        model_type: One of ``ModelType.KNN``, ``ModelType.RFC`` or ``ModelType.SVC``.

    Returns:
        One accuracy per fold, in the order the folds were produced.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """
    accuracies: List[float] = []
    # enumerate(..., start=1) yields the 1-based trial number printed below.
    for fold_number, (train, test) in enumerate(kfold.split(data), start=1):
        # Sets give constant-time membership tests; the enumerate filters keep
        # the samples in their original order.
        train_idx, test_idx = set(train), set(test)

        # Train model
        training_files: List[FilePath] = [sample.file_name for idx, sample in enumerate(data) if idx in train_idx]
        untrained_model: SupportsPredict
        if model_type == ModelType.KNN:
            untrained_model = KNeighborsClassifier(n_neighbors=5)
        elif model_type == ModelType.RFC:
            # NOTE(review): no random_state is passed, so RFC results can vary between runs.
            untrained_model = RandomForestClassifier(n_estimators=100)
        elif model_type == ModelType.SVC:
            untrained_model = svm.SVC()
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported model type: {model_type!r}")
        model: Catch22Model = train_catch22model(untrained_model, training_files)

        # Test data
        test_data: List[DataSample] = [sample for idx, sample in enumerate(data) if idx in test_idx]

        trial_accuracy: float = calculate_streaming_accuracy(
            test_data,
            model,
            True,
            buffer_time,
            update_factor,
            wait_time,
            num_samples,
            quality_factor,
        )
        accuracies.append(trial_accuracy)
        print(f"Finished trial {fold_number}")
    return accuracies

In [None]:
# Per-fold streaming accuracies keyed by model type; filled in by the cells below.
# Re-running this cell discards any results collected so far.
streaming_accuracies: Dict[ModelType, List[float]] = {}

In [None]:
# Modified simple classifier
# Model parameters, in the positional order they are passed to ModifiedModel
model_parameters: List[float] = [
    m_event_threshold,
    positive_amplitude,
    negative_amplitude,
    spacing,
]
# Initialise model
model: ModelBase = ModifiedModel(*model_parameters)
# Run the k-fold streaming evaluation (one accuracy per fold)
accuracies: List[float] = test_simplemodel_streaming(model, kfold)
# Store the per-fold accuracies under the MSC key
streaming_accuracies[ModelType.MSC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# Simple classifier
# Initialise model
model: ModelBase = SimpleModel(s_event_threshold)
# Run the k-fold streaming evaluation (one accuracy per fold)
accuracies: List[float] = test_simplemodel_streaming(model, kfold)
# Store the per-fold accuracies under the USC key
streaming_accuracies[ModelType.USC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# KNN: train and score a fresh model on every fold (streaming accuracy)
accuracies: List[float] = test_catch22model_streaming(data, kfold, ModelType.KNN)
# Store the per-fold accuracies under the KNN key
streaming_accuracies[ModelType.KNN] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# RFC: train and score a fresh model on every fold (streaming accuracy)
accuracies: List[float] = test_catch22model_streaming(data, kfold, ModelType.RFC)
# Store the per-fold accuracies under the RFC key
streaming_accuracies[ModelType.RFC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# SVC: train and score a fresh model on every fold (streaming accuracy)
accuracies: List[float] = test_catch22model_streaming(data, kfold, ModelType.SVC)
# Store the per-fold accuracies under the SVC key
streaming_accuracies[ModelType.SVC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# Output folder for the plots and pickled accuracies written below.
DIAGNOSTICS_FOLDER: FilePath = "diagnostics"
# exist_ok=True makes this idempotent without a separate check-then-create
# step, and also creates missing parent directories if the path is ever nested.
os.makedirs(DIAGNOSTICS_FOLDER, exist_ok=True)

In [None]:
# Boxplot comparing all the accuracies across models
# Each label is paired with its own dictionary key, so the boxes cannot be
# mislabelled if the evaluation cells were first run in a different order
# (dict.values() follows insertion order, not the order of the labels).
streaming_plot_order = [
    ("MSC", ModelType.MSC),
    ("USC", ModelType.USC),
    ("KNN", ModelType.KNN),
    ("RFC", ModelType.RFC),
    ("SVC", ModelType.SVC),
]
plt.boxplot(
    [streaming_accuracies[model_type] for _, model_type in streaming_plot_order],
    labels=[label for label, _ in streaming_plot_order],
)
plt.ylim(0, 1.1);
plt.ylabel("Accuracy")
plt.title("10-fold cross validation streaming accuracies")
plt.savefig(f"{DIAGNOSTICS_FOLDER}/streaming_boxplot.png");

In [None]:
# Pickle the streaming accuracies into the diagnostics folder.
# NOTE(review): an existing cache file is never overwritten, so results from a
# newer run are not saved until the old pickle is deleted - confirm this is intended.
streaming_accuracies_cache: FilePath = f"{DIAGNOSTICS_FOLDER}/streaming_accuracies.pickle"
if not os.path.isfile(streaming_accuracies_cache):
    with open(streaming_accuracies_cache, "wb") as handle:
        pickle.dump(streaming_accuracies, handle)

In [None]:
if os.path.isfile(streaming_accuracies_cache):
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # were written by this notebook.
    with open(streaming_accuracies_cache, "rb") as handle:
        accuracy_data = pickle.load(handle)
    # The file handle is no longer needed once the data is loaded, so plot
    # outside the `with` block. Each label is paired with its own key so the
    # boxes cannot be mislabelled by the insertion order of the cached dict.
    cached_streaming_plot_order = [
        ("MSC", ModelType.MSC),
        ("USC", ModelType.USC),
        ("KNN", ModelType.KNN),
        ("RFC", ModelType.RFC),
        ("SVC", ModelType.SVC),
    ]
    plt.boxplot(
        [accuracy_data[model_type] for _, model_type in cached_streaming_plot_order],
        labels=[label for label, _ in cached_streaming_plot_order],
    )
    plt.ylim(0, 1.1);
    plt.ylabel("Accuracy")
    plt.title("10-fold cross validation streaming accuracies")

# Snippet Accuracies


In [None]:
def test_simplemodel_snippet(model: ModelBase, kfold: KFold) -> List[float]:
    """Evaluate a ready-to-use model's snippet accuracy on every k-fold test split.

    The model is used as given (no fitting happens here), so only the test
    indices of each split are needed. For every held-out recording the base
    file name, minus a trailing ``.npy`` suffix, is used to collect the files
    matching ``{SNIPPET_FOLDER}/<name>_*``. Relies on the notebook globals
    ``data``, ``SNIPPET_FOLDER`` and ``num_samples``.

    Args:
        model: Model passed unchanged to ``calculate_snippet_accuracy``.
        kfold: Splitter whose ``split(data)`` yields ``(train, test)`` index pairs.

    Returns:
        One accuracy per fold, in the order the folds were produced.
    """
    suffix = ".npy"
    accuracies: List[float] = []
    # enumerate(..., start=1) yields the 1-based trial number printed below.
    for fold_number, (_, test) in enumerate(kfold.split(data), start=1):
        # A set gives constant-time membership tests while the enumerate
        # filter keeps the samples in their original order.
        held_out = set(test)
        test_files: List[FilePath] = [sample.file_name for idx, sample in enumerate(data) if idx in held_out]
        test_data: List[FilePath] = []
        for file_path in test_files:
            tail = os.path.basename(file_path)
            # Remove the exact ".npy" suffix. str.rstrip(".npy") would instead strip
            # every trailing '.', 'n', 'p' or 'y' character ("happy.npy" -> "ha").
            if tail.endswith(suffix):
                tail = tail[: -len(suffix)]
            test_data.extend(glob.glob(f"{SNIPPET_FOLDER}/{tail}_*"))

        trial_accuracy: float = calculate_snippet_accuracy(
            model,
            test_data,
            num_samples,
        )
        accuracies.append(trial_accuracy)
        print(f"Finished trial {fold_number}")
    return accuracies

In [None]:
def test_catch22model_snippets(data: List[DataSample], kfold: KFold, model_type: ModelType) -> List[float]:
    """Cross-validate a Catch22-based classifier on snippet accuracy.

    For each fold a fresh estimator of the requested type is trained (via
    ``train_catch22model``) on the snippets of the training recordings, then
    scored with ``calculate_snippet_accuracy`` on the snippet files of the
    held-out recordings. Snippet files are found by globbing
    ``{SNIPPET_FOLDER}/<name>_*`` where ``<name>`` is the recording's base file
    name minus a trailing ``.npy`` suffix. Relies on the notebook globals
    ``SNIPPET_FOLDER`` and ``num_samples``.

    Args:
        data: Samples to split; each must expose ``file_name``.
        kfold: Splitter whose ``split(data)`` yields ``(train, test)`` index pairs.
        model_type: One of ``ModelType.KNN``, ``ModelType.RFC`` or ``ModelType.SVC``.

    Returns:
        One accuracy per fold, in the order the folds were produced.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """
    suffix = ".npy"
    accuracies: List[float] = []
    # enumerate(..., start=1) yields the 1-based trial number printed below.
    for fold_number, (train, test) in enumerate(kfold.split(data), start=1):
        # Sets give constant-time membership tests; the enumerate filters keep
        # the samples in their original order.
        train_idx, test_idx = set(train), set(test)

        # Train model
        training_files: List[FilePath] = [sample.file_name for idx, sample in enumerate(data) if idx in train_idx]
        untrained_model: SupportsPredict
        if model_type == ModelType.KNN:
            untrained_model = KNeighborsClassifier(n_neighbors=5)
        elif model_type == ModelType.RFC:
            # NOTE(review): no random_state is passed, so RFC results can vary between runs.
            untrained_model = RandomForestClassifier(n_estimators=100)
        elif model_type == ModelType.SVC:
            untrained_model = svm.SVC()
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported model type: {model_type!r}")
        model: Catch22Model = train_catch22model(untrained_model, training_files)

        # Test data
        test_files: List[FilePath] = [sample.file_name for idx, sample in enumerate(data) if idx in test_idx]
        test_data: List[FilePath] = []
        for file_path in test_files:
            tail = os.path.basename(file_path)
            # Remove the exact ".npy" suffix. str.rstrip(".npy") would instead strip
            # every trailing '.', 'n', 'p' or 'y' character ("happy.npy" -> "ha").
            if tail.endswith(suffix):
                tail = tail[: -len(suffix)]
            test_data.extend(glob.glob(f"{SNIPPET_FOLDER}/{tail}_*"))

        trial_accuracy: float = calculate_snippet_accuracy(
            model,
            test_data,
            num_samples,
        )
        accuracies.append(trial_accuracy)
        print(f"Finished trial {fold_number}")
    return accuracies

In [None]:
# Per-fold snippet accuracies keyed by model type; filled in by the cells below.
# Re-running this cell discards any results collected so far.
snippet_accuracies: Dict[ModelType, List[float]] = {}

In [None]:
# Modified simple classifier
# Model parameters, in the positional order they are passed to ModifiedModel
model_parameters: List[float] = [
    m_event_threshold,
    positive_amplitude,
    negative_amplitude,
    spacing,
]
# Initialise model
model: ModelBase = ModifiedModel(*model_parameters)
# Run the k-fold snippet evaluation (one accuracy per fold)
accuracies: List[float] = test_simplemodel_snippet(model, kfold)
# Store the per-fold accuracies under the MSC key
snippet_accuracies[ModelType.MSC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# Simple classifier
# Initialise model
model: ModelBase = SimpleModel(s_event_threshold)
# Run the k-fold snippet evaluation (one accuracy per fold)
accuracies: List[float] = test_simplemodel_snippet(model, kfold)
# Store the per-fold accuracies under the USC key
snippet_accuracies[ModelType.USC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# KNN: train and score a fresh model on every fold (snippet accuracy)
accuracies: List[float] = test_catch22model_snippets(data, kfold, ModelType.KNN)
# Store the per-fold accuracies under the KNN key
snippet_accuracies[ModelType.KNN] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# RFC: train and score a fresh model on every fold (snippet accuracy)
accuracies: List[float] = test_catch22model_snippets(data, kfold, ModelType.RFC)
# Store the per-fold accuracies under the RFC key
snippet_accuracies[ModelType.RFC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# SVC: train and score a fresh model on every fold (snippet accuracy)
accuracies: List[float] = test_catch22model_snippets(data, kfold, ModelType.SVC)
# Store the per-fold accuracies under the SVC key
snippet_accuracies[ModelType.SVC] = accuracies
# Report the median accuracy across folds
print(np.median(accuracies))

In [None]:
# Boxplot comparing all the accuracies across models
# Each label is paired with its own dictionary key, so the boxes cannot be
# mislabelled if the evaluation cells were first run in a different order
# (dict.values() follows insertion order, not the order of the labels).
snippet_plot_order = [
    ("MSC", ModelType.MSC),
    ("USC", ModelType.USC),
    ("KNN", ModelType.KNN),
    ("RFC", ModelType.RFC),
    ("SVC", ModelType.SVC),
]
plt.boxplot(
    [snippet_accuracies[model_type] for _, model_type in snippet_plot_order],
    labels=[label for label, _ in snippet_plot_order],
)
plt.ylim(0, 1.1);
plt.ylabel("Accuracy")
plt.title("10-fold cross validation snippet accuracies")
plt.savefig(f"{DIAGNOSTICS_FOLDER}/snippet_boxplot.png");

In [None]:
# Pickle the snippet accuracies into the diagnostics folder.
# NOTE(review): an existing cache file is never overwritten, so results from a
# newer run are not saved until the old pickle is deleted - confirm this is intended.
snippet_accuracies_cache: FilePath = f"{DIAGNOSTICS_FOLDER}/snippet_accuracies.pickle"
if not os.path.isfile(snippet_accuracies_cache):
    with open(snippet_accuracies_cache, "wb") as handle:
        pickle.dump(snippet_accuracies, handle)

In [None]:
if os.path.isfile(snippet_accuracies_cache):
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # were written by this notebook.
    with open(snippet_accuracies_cache, "rb") as handle:
        accuracy_data = pickle.load(handle)
    # The file handle is no longer needed once the data is loaded, so plot
    # outside the `with` block. Each label is paired with its own key so the
    # boxes cannot be mislabelled by the insertion order of the cached dict.
    cached_snippet_plot_order = [
        ("MSC", ModelType.MSC),
        ("USC", ModelType.USC),
        ("KNN", ModelType.KNN),
        ("RFC", ModelType.RFC),
        ("SVC", ModelType.SVC),
    ]
    plt.boxplot(
        [accuracy_data[model_type] for _, model_type in cached_snippet_plot_order],
        labels=[label for label, _ in cached_snippet_plot_order],
    )
    plt.ylim(0, 1.1);
    plt.ylabel("Accuracy")
    plt.title("10-fold cross validation snippet accuracies")