In [None]:
import pandas as pd
%load_ext autoreload
%autoreload 2
%cd ..

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
import copy
from typing import Dict, Iterable
import os
from pathlib import Path
from typing import List
from collections import Counter

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
import torch.utils.data
from tqdm.notebook import tqdm

from ai_based.data_handling.ai_datasets import AiDataset
from ai_based.data_handling.training_batch import TrainingBatch
from ai_based.networks import MLP, Cnn1D
from ai_based.utilities.evaluators import ConfusionMatrixEvaluator
from util.paths import RESULTS_PATH_AI, TRAIN_TEST_SPLIT_YAML, DATA_PATH
from util.train_test_split import read_train_test_split_yaml
from ai_based.utilities.evaluators import BaseEvaluator
from util.datasets import GroundTruthClass, RespiratoryEvent, RespiratoryEventType, RESPIRATORY_EVENT_TYPE__GROUND_TRUTH_CLASS
from util.mathutil import cluster_1d, IntRange
from ai_based.utilities.inference import retrieve_respiratory_events

# Some preparations to pretty-print tensors & ndarrays: show 10 items at each edge of
# long axes and allow 400-character lines so that rows are not wrapped.
# (The public `linewidth` option replaces the former private `np.core.arrayprint._line_width`
# assignment, which has no effect in current NumPy versions.)
np.set_printoptions(edgeitems=10, linewidth=400)
torch.set_printoptions(linewidth=400, threshold=10_000)

# Number of CPU cores this process may run on (sizes the DataLoader workers further down).
# `os.sched_getaffinity` is only available on some platforms (e.g. Linux); fall back to
# the total core count elsewhere. `os.cpu_count()` may return None, hence the `or 1`.
if hasattr(os, "sched_getaffinity"):
    n_cpu_cores = len(os.sched_getaffinity(0))
else:
    n_cpu_cores = os.cpu_count() or 1

## Load model, training config, training logs, etc.
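- Training runs are organized as so-called **experiments**.
- An experiment may be run multiple times in different model & hyper-parameter configurations. Each configuration is called a **combination**.
- Each combination may itself be trained several times; each of these runs is called a **repetition**.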

In [None]:
EXPERIMENT_DIR = RESULTS_PATH_AI / "cnn-3-gt_point-bs128"
COMBINATION_DIR = EXPERIMENT_DIR / "combination_0"
REPETITION_DIR = COMBINATION_DIR / "repetition_0"

# `is_dir()` already implies existence; the message names the directory that is missing.
for _directory in (RESULTS_PATH_AI, EXPERIMENT_DIR, COMBINATION_DIR, REPETITION_DIR):
    assert _directory.is_dir(), f"Directory not found: {_directory}"

# NOTE(review): `torch.load` unpickles these files, so only open results you trust.
# With newer PyTorch versions that default to `weights_only=True`, loading `params.pt`
# (it stores the model class) may need `weights_only=False` -- TODO confirm for the
# installed version.
_cpu = torch.device("cpu")
config = torch.load(EXPERIMENT_DIR / "config.pt", map_location=_cpu)
hyperparams = torch.load(COMBINATION_DIR / "params.pt", map_location=_cpu)

# The training log feeds the training-curve cells below. It is optional so that the
# model can still be evaluated when the log file is missing.
_log_file = REPETITION_DIR / "log.pt"
if _log_file.exists():
    log = torch.load(_log_file, map_location=_cpu)
else:
    log = None
    print(f"No training log found at {_log_file}; skip the training-log cells.")

# Prefer the final weights; fall back to the best-weights checkpoint.
_weight_candidates = [REPETITION_DIR / "weights.pt", REPETITION_DIR / "checkpoint_best_weights.pt"]
_weights_file = next((candidate for candidate in _weight_candidates if candidate.exists()), None)
if _weights_file is None:
    raise RuntimeError("No weights file found")
weights = torch.load(_weights_file, map_location=_cpu)

# Rebuild the network from the stored class + config and switch it to inference mode.
model = hyperparams["model"](hyperparams["model_config"])
model.load_state_dict(weights)
model.eval()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")
model.to(device)

In [None]:
# How many training-loss values were logged ...
_n_logged_losses = len(log["training"]["loss"])
print(_n_logged_losses)
# ... followed by whatever was logged for the test split (rendered as the cell output).
log["test"]

## Plot metrics that were logged during training
These include, amongst others:
- Training loss
- Accuracy

In [None]:
def _plot_logged_metric(ax, title, key):
    """Draw the training & validation curves of metric `key` from `log` onto `ax`.

    Splits in which `key` was not logged are skipped. A legend is only added when at
    least one curve was drawn (calling `legend()` on an empty axes just emits a
    warning); otherwise the panel is labelled as empty.
    """
    ax.set_title(title)
    n_curves = 0
    for split in ("training", "validation"):
        # assumes `log` is a dict of dicts (it is indexed as log["training"]["loss"]) -- confirm
        values = log.get(split, {}).get(key)
        if values is not None:
            ax.plot(values, label=split)
            n_curves += 1
    if n_curves:
        ax.legend()
    else:
        ax.text(0.5, 0.5, f"'{key}' not logged", ha="center", va="center", transform=ax.transAxes)


fig, axes = plt.subplots(3, 2, figsize=(24, 36))

axes[0, 0].set_title("Loss")
axes[0, 0].plot(log["training"]["loss"])

# `training_miou` / `validation_miou` are not defined before this cell, so this panel stays empty.
axes[0, 1].set_title("Weighted mean IoU")
axes[0, 1].text(0.5, 0.5, "not computed", ha="center", va="center", transform=axes[0, 1].transAxes)

_plot_logged_metric(axes[1, 0], "Detection Recall", "detection_recall")
_plot_logged_metric(axes[1, 1], "Placement Accuracy", "placement_accuracy")
_plot_logged_metric(axes[2, 0], "Object Height Accuracy", "height_classification_accuracy")
_plot_logged_metric(axes[2, 1], "Pixel Height Accuracy", "accuracy")

plt.show()

## Load a sample dataset for validation purposes

In [None]:
# To evaluate on the official test split instead, its folders can be obtained via
#   read_train_test_split_yaml(input_yaml=TRAIN_TEST_SPLIT_YAML,
#                              prefix_base_folder=DATA_PATH / "Physionet_preprocessed").test

# Single recording used for the sample evaluation below.
SAMPLE_RECORDING_DIR = DATA_PATH / "tr12-0261"
assert SAMPLE_RECORDING_DIR.is_dir(), f"Directory not found: {SAMPLE_RECORDING_DIR}"

# Instantiate an AiDataset with exactly __one__ contained SlidingWindowDataset.
# The config is deep-copied so that the loaded `hyperparams` are left untouched.
ai_dataset_config = copy.deepcopy(hyperparams["test_dataset_config"])
ai_dataset_config.dataset_folders = [SAMPLE_RECORDING_DIR]
ai_dataset_config.noise_mean_std = None  # presumably disables noise injection -- confirm in AiDataset
ai_dataset = AiDataset(config=ai_dataset_config)

len(ai_dataset)

## Quantitative evaluation: aggregate *model performance metrics* over the whole sample dataset loaded above

In [None]:
batch_size = 512

# Keep one core for the notebook process itself; never pass a negative worker count.
data_loader = torch.utils.data.DataLoader(
    ai_dataset,
    batch_size,
    shuffle=False,
    collate_fn=TrainingBatch.from_iterable,
    num_workers=max(n_cpu_cores - 1, 0),
)

# Iterate over the dataset and accumulate one confusion-matrix evaluator over all batches.
overall_evaluator = ConfusionMatrixEvaluator.empty()
# Inference only: `no_grad` stops autograd from recording a computation graph for every
# batch, which saves memory and time. The former `torch.autograd.Variable` wrapper is
# deprecated and returns the tensor unchanged, so the input is passed directly.
with torch.no_grad():
    for batch in tqdm(data_loader):
        batch.to_device(model.device)
        net_output = model(batch.input_data)
        overall_evaluator += ConfusionMatrixEvaluator(model_output_batch=net_output, ground_truth_batch=batch.ground_truth)

overall_evaluator.print_exhausting_metrics_results(include_short_summary=False)


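## Respiratory (apnea) event detection
Run the model over the sample dataset, count the detected respiratory events per event type, and list every detected event.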

In [None]:
# Only the first element of the returned value holds the detected events used here.
respiratory_events = retrieve_respiratory_events(model=model, ai_dataset=ai_dataset, progress_fn=tqdm)[0]

# Tally how often each event type was detected.
respiratory_event_type_counter = Counter(event.event_type for event in respiratory_events)

# Pre-format the report lines: one per event type, then one per detected event.
type_lines = [f"{event_type.name}: {count}" for event_type, count in respiratory_event_type_counter.items()]
event_lines = [f"#{idx}: {event}" for idx, event in enumerate(respiratory_events)]

print("Respiratory event types as per AI detection:")
print(" - " + "\n - ".join(type_lines))
print()
print(f"{len(respiratory_events)} detected respiratory events:")
print(" - " + "\n - ".join(event_lines))
