In [None]:
import pandas as pd
%load_ext autoreload
%autoreload 2
%cd ..

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
import copy
from typing import Dict, Iterable, List, Tuple
import os
from pathlib import Path
from typing import List
from collections import Counter
import functools

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
import torch.utils.data
from tqdm.notebook import tqdm

from ai_based.data_handling.ai_datasets import AiDataset
from ai_based.data_handling.training_batch import TrainingBatch
from ai_based.networks import MLP, Cnn1D
from ai_based.utilities.evaluators import ConfusionMatrixEvaluator
from util.paths import RESULTS_PATH_AI, TRAIN_TEST_SPLIT_YAML, DATA_PATH
from util.train_test_split import read_train_test_split_yaml
from ai_based.utilities.evaluators import BaseEvaluator
from util.datasets import GroundTruthClass, RespiratoryEvent, RespiratoryEventType, \
    RESPIRATORY_EVENT_TYPE__GROUND_TRUTH_CLASS, SlidingWindowDataset
from util.mathutil import cluster_1d, IntRange
from ai_based.utilities.inference import retrieve_respiratory_events
from util.event_based_metrics import get_overlaps, get_n_detected_annotations, OverlapsBasedConfusionMatrix

# Some preparations to pretty-print tensors & ndarrays.
# The line width is set through numpy's public API: assigning the private `np.core.arrayprint._line_width` has no
# effect on numpy >= 1.14 (print options live in an internal options store), and `np.core` is deprecated in numpy 2.
np.set_printoptions(edgeitems=10, linewidth=400)
torch.set_printoptions(linewidth=400, threshold=10_000)

# Number of CPU cores usable by this process. `os.sched_getaffinity` only exists on some Unix platforms
# (e.g. Linux), so fall back to the total core count elsewhere (`os.cpu_count()` may return None).
if hasattr(os, "sched_getaffinity"):
    n_cpu_cores = len(os.sched_getaffinity(0))
else:
    n_cpu_cores = os.cpu_count() or 1
batch_size = 512

## Load model, training config, training logs, etc.
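- Training runs are organized as so-called **experiments**.
- An experiment may be run multiple times in different model & hyper-parameter configurations. Each such configuration is called a **combination**.
- Each combination may in turn be trained several times; each of these runs is called a **repetition**.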

In [None]:
EXPERIMENT_DIR = RESULTS_PATH_AI / "cnn-5-gt_point-bs128-peakified_signals-low_wd-train_noise"
COMBINATION_DIR = EXPERIMENT_DIR / "combination_0"
REPETITION_DIR = COMBINATION_DIR / "repetition_0"

# Fail early if any level of the results hierarchy (results root -> experiment -> combination -> repetition) is missing
for results_dir_ in (RESULTS_PATH_AI, EXPERIMENT_DIR, COMBINATION_DIR, REPETITION_DIR):
    assert results_dir_.is_dir() and results_dir_.exists()
del results_dir_

# Every artifact is loaded onto the CPU, independent of the device it was saved from.
# NOTE(review): torch.load unpickles arbitrary Python objects (params.pt stores e.g. the model class itself), so only
# load trusted result files. PyTorch >= 2.6 defaults to weights_only=True, which presumably rejects these pickled
# objects - pass weights_only=False there.
cpu_device_ = torch.device("cpu")

# Optional artifacts: a missing training log / final evaluation is reported, not fatal
log_file_ = REPETITION_DIR / "log.pt"
if not log_file_.exists():
    log = None
    print("No training logs available")
else:
    log = torch.load(log_file_, map_location=cpu_device_)

eval_file_ = REPETITION_DIR / "eval.pt"
if not eval_file_.exists():
    final_validation_eval_results = None
    print("No final eval results available")
else:
    final_validation_eval_results = torch.load(eval_file_, map_location=cpu_device_)

# Mandatory artifacts: experiment config and the hyper-parameters of this combination
config = torch.load(EXPERIMENT_DIR / "config.pt", map_location=cpu_device_)
hyperparams = torch.load(COMBINATION_DIR / "params.pt", map_location=cpu_device_)

# Prefer the final weights, fall back to the best checkpoint
for weights_file_ in (REPETITION_DIR / "weights.pt", REPETITION_DIR / "checkpoint_best_weights.pt"):
    if weights_file_.exists():
        weights = torch.load(weights_file_, map_location=cpu_device_)
        break
else:
    raise RuntimeError("No weights file found")
del cpu_device_, log_file_, eval_file_, weights_file_

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print()
print(f"Device: {device}")

# Re-create the network from its stored class & config, then restore the trained weights for inference
model_factory_ = hyperparams["model"]
model = model_factory_(hyperparams["model_config"])
del model_factory_
model.load_state_dict(weights)
model.eval()
model.to(device)

model

## Plot parameters that were logged during training
These are, amongst others:
- Training loss
- Training metrics (precision, recall, f1-score)

In [None]:
# Plots of parameters logged during training. `log` is None when no "log.pt" was found (see the loading cell).
if log is None:
    print("No training logs available - skipping the plots of logged training parameters")
else:
    fig, axes = plt.subplots(3, 2, figsize=(24, 36))

    axes[0, 0].set_title("Loss")
    axes[0, 0].plot(log["training"]["loss"])

    # The remaining panels are placeholders; their curves are currently not plotted. Their former (disabled)
    # plotting code read log["training"][key] and log["validation"][key] for the keys "detection_recall",
    # "placement_accuracy", "height_classification_accuracy" and "accuracy" (plus training/validation mIoU for the
    # "Weighted mean IoU" panel). TODO: restore those curves if the training loop logs these metrics.
    # No legend() calls here: without labelled artists, matplotlib would only emit a "No artists with labels" warning.
    placeholder_panel_titles = {
        (0, 1): "Weighted mean IoU",
        (1, 0): "Detection Recall",
        (1, 1): "Placement Accuracy",
        (2, 0): "Object Height Accuracy",
        (2, 1): "Pixel Height Accuracy",
    }
    for (row, col), title in placeholder_panel_titles.items():
        axes[row, col].set_title(title)

    plt.show()

### Training-specific metrics
The metrics in this sub-section are computed per batch, the same way as during training. They are only meant
for comparing different states of the same model (e.g. different epochs).

To compare detectors of any kind (i.e. rule-based or AI-based), it is strongly recommended
to rely on event-based metrics (see `util.event_based_metrics`) instead. These event-based metrics are used
throughout the examples below.

In [None]:
dataset_folder = DATA_PATH / "training" / "tr03-0005"

# Test-time dataset config, restricted to a single recording and with noise injection disabled
ai_dataset_config = copy.deepcopy(hyperparams["test_dataset_config"])
ai_dataset_config.dataset_folders = [dataset_folder]
ai_dataset_config.noise_mean_std = None
ai_dataset = AiDataset(config=ai_dataset_config)
data_loader = torch.utils.data.DataLoader(
    ai_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=TrainingBatch.from_iterable,
    num_workers=max(n_cpu_cores - 1, 0),  # all but one core for loading; never negative
)

# Iterate over the dataset and gather performance information.
# This is inference only: torch.no_grad() avoids building (and keeping) an autograd graph for every batch.
# The former `torch.autograd.Variable` wrapper is dropped; it has been a no-op since PyTorch 0.4.
overall_evaluator = ConfusionMatrixEvaluator.empty()
with torch.no_grad():
    for batch in tqdm(data_loader):
        batch.to_device(model.device)
        net_output = model(batch.input_data)
        overall_evaluator += ConfusionMatrixEvaluator(model_output_batch=net_output, ground_truth_batch=batch.ground_truth)

overall_evaluator.print_exhausting_metrics_results(include_short_summary=False)

### AI detection run over a single dataset

Load a single dataset and output a few statistics about it.

In [None]:
dataset_folder = DATA_PATH / "training" / "tr03-0005"

# Test-time dataset config, narrowed down to one recording and without noise injection
ai_dataset_config = copy.deepcopy(hyperparams["test_dataset_config"])
ai_dataset_config.dataset_folders = [dataset_folder]
ai_dataset_config.noise_mean_std = None
ai_dataset = AiDataset(config=ai_dataset_config)
print(f"#Sliding window positions: {len(ai_dataset)}")

sub_datasets_ = ai_dataset._sliding_window_datasets
assert len(sub_datasets_) == 1, \
    "The given AiDataset contains more than one SlidingWindowDatasets. This example's intention is to show predictions on a single dataset sample."

# AiDataset's preprocessing modified the signal data of its SlidingWindowDataset. For plotting, re-load an
# unmodified SlidingWindowDataset from the same folder with the same config.
preprocessed_ = sub_datasets_[0]
sliding_window_dataset = SlidingWindowDataset(
    config=preprocessed_.config,
    dataset_folder=preprocessed_.dataset_folder,
    allow_caching=True,
)
del sub_datasets_, preprocessed_

print(f"#Physionet dataset samples: {len(sliding_window_dataset.signals)}")
center_points_ = sliding_window_dataset.valid_center_points
print(f"Timeframe of sliding window positions: {center_points_[-1] - center_points_[0]}")
has_annotations_ = sliding_window_dataset.respiratory_events is not None
print(f"Respiratory events list present: {has_annotations_}")

if not has_annotations_:
    print("\n")  # two empty lines
    print("The given dataset does not provide annotated respiratory events. For the following example, we need those annotations!")
del center_points_, has_annotations_


Output some information on the annotated respiratory events.

In [None]:
# This example requires annotations; the previous cell reports whether they are available
annotated_respiratory_events = sliding_window_dataset.respiratory_events
assert annotated_respiratory_events is not None, \
    "The given dataset does not provide annotated respiratory events, which this example requires."

respiratory_event_type_counter = Counter([e.event_type for e in annotated_respiratory_events])
print("Respiratory event types as per annotations:")
print(" - " + "\n - ".join(f"{klass.name}: {cnt}" for klass, cnt in respiratory_event_type_counter.items()))
print()
print(f"{len(annotated_respiratory_events)} annotated respiratory events:")
print(" - " + "\n - ".join([f"#{i}: {evt}" for i, evt in enumerate(annotated_respiratory_events)]))

# Enrich whole sliding window dataset by "is awake" row
awake_series = sliding_window_dataset.awake_series
sliding_window_dataset.signals[awake_series.name] = awake_series
del awake_series

# Enrich whole sliding window dataset by an outline of the annotated events: 1 from the sample nearest to an
# event's start up to (excluding) the sample nearest to its end, 0 elsewhere
signal_index = sliding_window_dataset.signals.index
annotated_events_outline_mat = np.zeros(shape=(len(signal_index),))
for event in annotated_respiratory_events:
    # Index.get_loc(..., method="nearest") is deprecated since pandas 1.4 and removed in pandas 2.0;
    # get_indexer(..., method="nearest") is the documented replacement
    start_idx, end_idx = signal_index.get_indexer([event.start, event.end], method="nearest")
    annotated_events_outline_mat[start_idx:end_idx] = 1
sliding_window_dataset.signals["Annotated respiratory events"] = pd.Series(data=annotated_events_outline_mat, index=signal_index)

del signal_index, annotated_events_outline_mat

Perform the detection. Also generate an outline of the detected events,
which is useful for plotting.

In [None]:
# NOTE(review): min_cluster_length_s=7 is passed here, while the multi-dataset detection run further below relies on
# the default of retrieve_respiratory_events - confirm that this difference is intended.
events_dict = retrieve_respiratory_events(model=model, ai_dataset=ai_dataset, batch_size=batch_size, progress_fn=tqdm, min_cluster_length_s=7)
detected_respiratory_events = events_dict[sliding_window_dataset.dataset_name]

detected_hypopnea_events_ = [d_ for d_ in detected_respiratory_events if d_.event_type == RespiratoryEventType.Hypopnea]
detected_apnea_events_ = [d_ for d_ in detected_respiratory_events if d_.event_type != RespiratoryEventType.Hypopnea]

print()
print(f"Detected {len(detected_respiratory_events)} respiratory events")
print(f" ..of which are {len(detected_hypopnea_events_)} hypopneas")

# Enrich the sliding window dataset by an outline of the detected events: 1 from the sample nearest to an event's
# start up to (excluding) the sample nearest to its end, 0 elsewhere
signal_index = sliding_window_dataset.signals.index
detected_events_outline_mat = np.zeros(shape=(len(signal_index),))
for event in detected_respiratory_events:
    # Index.get_loc(..., method="nearest") is deprecated since pandas 1.4 and removed in pandas 2.0;
    # get_indexer(..., method="nearest") is the documented replacement
    start_idx, end_idx = signal_index.get_indexer([event.start, event.end], method="nearest")
    detected_events_outline_mat[start_idx:end_idx] = 1
sliding_window_dataset.signals["Detected respiratory events"] = pd.Series(data=detected_events_outline_mat, index=signal_index)

del signal_index, detected_events_outline_mat

Generate and output some statistics on the detection performance. These are:
- Overlaps of detected & annotated respiratory events
- Confusion matrix based metrics
- Confusion matrix plot

In [None]:
# Get overlapping annotated/detected events & derive some statistics
overlapping_events = get_overlaps(annotated_events=annotated_respiratory_events, detected_events=detected_respiratory_events)
detected_but_not_annotated = [d_ for d_ in detected_respiratory_events if not any(a_.overlaps(d_) for a_ in annotated_respiratory_events)]
annotated_but_not_detected = [a_ for a_ in annotated_respiratory_events if not any(d_.overlaps(a_) for d_ in detected_respiratory_events)]

print(f"Number of annotated events: {len(annotated_respiratory_events)}")
print(f"Number of detected events: {len(detected_respiratory_events)}")
print()
print(f"Number of OVERLAPPING events: {len(overlapping_events)}")
print(f"- Coverage of annotated respiratory events {len(overlapping_events)/len(annotated_respiratory_events)*100:.1f}%")
print(f"- Detected events that also appear in annotations: {len(overlapping_events)/len(detected_respiratory_events)*100:.1f}%")
print()

# Obtain confusion-matrix based metrics
confusion_matrix = OverlapsBasedConfusionMatrix(annotated_events=annotated_respiratory_events, detected_events=detected_respiratory_events)
macro_scores = confusion_matrix.get_macro_scores()
print("Confusion-matrix based macro scores:")
print(f" -> {macro_scores}")

confusion_matrix.plot(title="Confusion matrix for classification confidence over a single dataset")


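The following cell plots the signals around a single annotated or detected respiratory event. Select the event by adjusting `event_num` and the event list at the top of the cell.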

In [None]:
# Select the event to inspect: `event_num` indexes into one of the event lists below (uncomment the list of interest)
event_num = 11
event_candidates_ = detected_respiratory_events
# event_candidates_ = annotated_respiratory_events
# event_candidates_ = detected_but_not_annotated
# event_candidates_ = annotated_but_not_detected
# event_candidates_ = detected_hypopnea_events_
assert -len(event_candidates_) <= event_num < len(event_candidates_), \
    f"event_num={event_num} is out of range for the selected event list ({len(event_candidates_)} events)"
event = event_candidates_[event_num]
del event_candidates_

# Center a window of the dataset's configured size on the selected event
window_center_point = event.start + (event.end-event.start)/2
window_start = window_center_point - sliding_window_dataset.config.time_window_size / 2
window_end = window_center_point + sliding_window_dataset.config.time_window_size / 2

# Events that at least partially intersect the window
annotated_in_window = [e for e in annotated_respiratory_events if e.end > window_start and e.start < window_end]
detected_in_window = [e for e in detected_respiratory_events if e.end > window_start and e.start < window_end]
print()
print("Annotated respiratory events in window:")
print(" - " + "\n - ".join([f"{e.event_type.name}: {(e.end-e.start).total_seconds():.1f}s" for e in annotated_in_window]))
print()
print("Detected respiratory events in window:")
print(" - " + "\n - ".join([f"{e.event_type.name}: {(e.end-e.start).total_seconds():.1f}s" for e in detected_in_window]))

window_data = sliding_window_dataset.get(center_point=window_center_point)
# Finish with plt.show() instead of assigning the plot result to `_`: a user assignment to `_` stops IPython from
# updating its `_`/`__`/`___` output-history variables for the rest of the session
window_data.signals.plot(figsize=(25, 12), subplots=True)
plt.show()

### Detection run over multiple datasets

Run the AI-based detector on a number of datasets.

In [None]:
# Use training dataset folders as per train-test-split
data_folder = DATA_PATH / "training"
train_test_folders = read_train_test_split_yaml(input_yaml=TRAIN_TEST_SPLIT_YAML, prefix_base_folder=data_folder)
dataset_folders = train_test_folders.train
del train_test_folders

# Use a given, small set of dataset folders
dataset_names = ("tr03-0005", "tr03-0289", "tr03-0921", "tr04-1078", "tr07-0168")
dataset_folders = [DATA_PATH / "training" / name for name in dataset_names]

In [None]:
# Load our datasets
ai_dataset_config = copy.deepcopy(hyperparams["test_dataset_config"])
ai_dataset_config.dataset_folders = dataset_folders
ai_dataset_config.noise_mean_std = None
ai_dataset = AiDataset(config=ai_dataset_config)

assert all(ds.respiratory_events is not None for ds in ai_dataset._sliding_window_datasets),\
    "At least one of the sub-datasets has no annotations."

# Perform the detection run
progress_fn = functools.partial(tqdm, desc="Detecting")
detected_events_dict = retrieve_respiratory_events(model=model, ai_dataset=ai_dataset, batch_size=batch_size, progress_fn=progress_fn)

# Reload our SlidingWindowDatasets for later plotting purposes. That's necessary, because their signals were modified by AiDataset
print()
print("Reload SlidingWindowDatasets for plotting purposes.. ", end="")
sliding_window_datasets = []
for ds in ai_dataset._sliding_window_datasets:
    ds_reloaded = SlidingWindowDataset(config=ds.config, dataset_folder=ds.dataset_folder, allow_caching=True)
    sliding_window_datasets += [ds_reloaded]
print("Finished")


In [None]:
detected_events_dict_flattened_ = [event for event_list in detected_events_dict.values() for event in event_list]
detected_hypopnea_events_ = [d_ for d_ in detected_events_dict_flattened_ if d_.event_type == RespiratoryEventType.Hypopnea]
detected_apnea_events_ = [d_ for d_ in detected_events_dict_flattened_ if d_.event_type != RespiratoryEventType.Hypopnea]

print()
print(f"Detected {len(detected_events_dict_flattened_)} respiratory events in all {len(sliding_window_datasets)} datasets")
print(f" ..of which are {len(detected_hypopnea_events_)} hypopneas")
del detected_events_dict_flattened_, detected_hypopnea_events_, detected_apnea_events_

# Enrich each SlidingWindowDataset by an outline of its detected events: 1 from the sample nearest to an event's
# start up to (excluding) the sample nearest to its end, 0 elsewhere
for sliding_window_dataset in sliding_window_datasets:
    signal_index = sliding_window_dataset.signals.index
    detected_events_outline_mat = np.zeros(shape=(len(signal_index),))
    for event in detected_events_dict[sliding_window_dataset.dataset_name]:
        # Index.get_loc(..., method="nearest") is deprecated since pandas 1.4 and removed in pandas 2.0;
        # get_indexer(..., method="nearest") is the documented replacement
        start_idx, end_idx = signal_index.get_indexer([event.start, event.end], method="nearest")
        detected_events_outline_mat[start_idx:end_idx] = 1
    sliding_window_dataset.signals["Detected respiratory events"] = pd.Series(data=detected_events_outline_mat, index=signal_index)
    # Clean up per iteration: a `del` after the loop would raise NameError if there were no datasets at all
    del signal_index, detected_events_outline_mat


Compute a few metrics on the freshly detected respiratory events. Also print some statistics on the overlaps
of annotated and detected respiratory events, from which the __annotation recall__ score is derived.

In [None]:
# Run metrics on our detections
overall_confusion_matrix = OverlapsBasedConfusionMatrix.empty()
n_annotated_events: int = 0
n_detected_events: int = 0
n_detected_annotations: int = 0
for sliding_window_dataset in sliding_window_datasets:
    detected_events = detected_events_dict[sliding_window_dataset.dataset_name]
    n_detected_events += len(detected_events)
    n_annotated_events += len(sliding_window_dataset.respiratory_events)

    cm_ = OverlapsBasedConfusionMatrix(annotated_events=sliding_window_dataset.respiratory_events, detected_events=detected_events)
    overall_confusion_matrix += cm_
    o_ = get_n_detected_annotations(annotated_events=sliding_window_dataset.respiratory_events, detected_events=detected_events)
    n_detected_annotations += o_

print(f"Number of annotated respiratory events: {n_annotated_events}")
print(f"Number of detected respiratory events: {n_detected_events}")
print()
print(f"Number of detected annotations (overlaps): {n_detected_annotations} out of {n_annotated_events}")
print(f" -> Annotation recall: {n_detected_annotations/n_annotated_events:.3f}")

Print the derived scores and plot the confusion matrix.

In [None]:
macro_scores = overall_confusion_matrix.get_macro_scores()
print("Confusion-matrix based macro scores:")
print(f" -> {macro_scores}")

plt.figure(figsize=(7, 7))
# Count the datasets the confusion matrix was actually accumulated over (sliding_window_datasets), consistent with
# the dataset count printed after the detection run, rather than the number of requested dataset folders
overall_confusion_matrix.plot(title=f"Confusion matrix for classification confidence over {len(sliding_window_datasets)} datasets")