# Metrics calculation

In [None]:
# Reload edited modules (such as the local `visualize` module) automatically before running code.
%load_ext autoreload
%autoreload 2

In [None]:
import cc3d
import os
from pathlib import Path
from visualize import load_data, DATASETS
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd


# Location of the nnU-Net results; passed to load_data, which presumably resolves
# the per-configuration predictions beneath it (TODO confirm).
res_folder = Path("/media/liushifeng/KINGSTON/nnUNet_results/")
# Fold-0 validation predictions of Dataset002; only used here to enumerate case file names.
val_folder = Path("/media/liushifeng/KINGSTON/nnUNet_results/Dataset002_3dlesion_ctfm_seg/nnUNetTrainer__nnUNetPlans__3d_fullres/fold_0/validation/")
# Raw inputs and ground-truth labels of the same dataset.
images_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset002_3dlesion_ctfm_seg/imagesTr")
labels_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset002_3dlesion_ctfm_seg/labelsTr")

# Keep only NIfTI files, then the subset whose name starts with "uls" (case-insensitive).
val_names = [fname for fname in os.listdir(val_folder) if fname.endswith(".nii.gz")]
uls_val_names = [fname for fname in val_names if fname.lower().startswith("uls")]
len(val_names), len(uls_val_names)

In [None]:
# Dataset id -> input configuration name, for reference. Presumably mirrors
# visualize.DATASETS (TODO confirm); note that id 4 is not listed.
# 1: "ct",
# 2: "ct+seg",
# 3: "ct+seg+box",
# 5: "ct+seg+2box",
# 6: "ct+seg+mask",
# 7: "ct+seg+2mask",

### Voxel-level calculations

In [None]:
# Voxel-level metrics: for every ULS validation case and every input
# configuration in DATASETS, count true-positive, false-positive and
# false-negative voxels of the prediction against the ground-truth label.
rows = []
for name in tqdm(uls_val_names):
    _, label, preds = load_data(name, images_folder, labels_folder, res_folder, load_input=False)
    # Any non-zero voxel is foreground. Counting via `pred - label == 1` would
    # raise for boolean masks (NumPy forbids boolean `-`) and miscount label
    # values other than 0/1.
    label_fg = np.asarray(label) != 0
    for pred, (dataset_id, dataset_name) in zip(preds, DATASETS.items()):
        pred_fg = np.asarray(pred) != 0
        tp = np.count_nonzero(pred_fg & label_fg)
        fp = np.count_nonzero(pred_fg & ~label_fg)
        fn = np.count_nonzero(~pred_fg & label_fg)
        # Empty prediction: precision is reported as 0 by convention.
        precision = tp / (tp + fp) if tp + fp else 0
        # Empty ground truth: recall is undefined, so record NaN explicitly
        # rather than relying on a 0/0 division.
        recall = tp / (tp + fn) if tp + fn else float("nan")
        rows.append({
            "lesion_name": name,
            "dataset_id": dataset_id,
            "dataset_name": dataset_name,
            "p": round(precision, 3),
            "r": round(recall, 3),
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "volume": label.sum(),
            "pred_volume": pred.sum(),
        })

df_vox = pd.DataFrame(rows)
df_vox

In [None]:
# Make sure the output directory exists; otherwise to_csv fails after the expensive metrics loop.
Path("metrics").mkdir(parents=True, exist_ok=True)
df_vox.to_csv("metrics/fold_0_metrics_voxel.csv", index=False)

### Lesion-level metrics

In [None]:
# Lesion-level metrics. GT lesions and predicted lesions are the connected
# components (cc3d) of the label mask and of each prediction mask.
# - A GT lesion is detected (tp) when the prediction overlaps it by MORE than
#   MIN_OVERLAP_VOXELS voxels; otherwise it is missed (fn).
# - A predicted component is a false positive (fp) when it overlaps the GT
#   foreground by at most MIN_OVERLAP_VOXELS voxels.
MIN_OVERLAP_VOXELS = 1

rows = []
for name in tqdm(uls_val_names):

    _, label, preds = load_data(name, images_folder, labels_folder, res_folder, load_input=False)
    lesion_labels, n = cc3d.connected_components(label, return_N=True)
    label_fg = np.asarray(label) != 0

    for pred, (dataset_id, dataset_name) in zip(preds, DATASETS.items()):
        pred_labels, pred_n = cc3d.connected_components(pred, return_N=True)
        pred_fg = np.asarray(pred) != 0

        # Number of predicted voxels inside each GT lesion (index 0 = background, dropped).
        gt_overlap = np.bincount(lesion_labels[pred_fg].astype(np.int64), minlength=n + 1)[1:]
        tp = int(np.count_nonzero(gt_overlap > MIN_OVERLAP_VOXELS))
        fn = int(n) - tp

        # Count unmatched predicted components directly instead of `pred_n - tp`.
        # That expression mixes GT-side and prediction-side counts and goes
        # negative when one predicted component covers several GT lesions.
        pred_overlap = np.bincount(pred_labels[label_fg].astype(np.int64), minlength=pred_n + 1)[1:]
        fp = int(np.count_nonzero(pred_overlap <= MIN_OVERLAP_VOXELS))

        # No predicted lesion: precision is reported as 0 by convention.
        precision = tp / (tp + fp) if tp + fp else 0
        # No GT lesion: recall is undefined.
        recall = tp / (tp + fn) if tp + fn else float("nan")

        rows.append({
            "lesion_name": name,
            "dataset_id": dataset_id,
            "dataset_name": dataset_name,
            "p": round(precision, 3),
            "r": round(recall, 3),
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "volume": label.sum(),
            "pred_volume": pred.sum(),
        })

df_les = pd.DataFrame(rows)
df_les

In [None]:
# Make sure the output directory exists; otherwise to_csv fails after the expensive metrics loop.
Path("metrics").mkdir(parents=True, exist_ok=True)
df_les.to_csv("metrics/fold_0_metrics_lesion.csv", index=False)

## Metrics aggregation & visualization

In [None]:
import pandas as pd
import seaborn as sns

In [None]:
# Dice score per input configuration (hard-coded; their source is not shown here).
dice = {
    "ct": 0.2609,
    "ct+seg": 0.3806,
    "ct+seg+box": 0.5843,
    "ct+seg+2box": 0.7687,
    "ct+seg+mask": 0.6108,
    "ct+seg+2mask": 0.8123,
}
# One row per configuration, columns ["dataset_name", "dice"], default RangeIndex.
df_dice = pd.DataFrame({
    "dataset_name": list(dice.keys()),
    "dice": list(dice.values()),
})

In [None]:
# Add long-format columns (metric/value) so the dice scores use the same
# faceted bar-plot layout as the other metric plots.
df_dice["metric"], df_dice["value"] = "dice", df_dice["dice"]
g = sns.catplot(
    data=df_dice,
    kind="bar",
    # data mapping: one horizontal bar per configuration, one facet per metric
    x="value",
    y="dataset_name",
    hue="dataset_name",
    col="metric",
    # each value is a single score, so no error bars
    errorbar=None,
    # figure geometry
    height=3,
    aspect=1.25,
    width=0.5,
    sharex=False,
    palette="deep",
)
g.set_axis_labels("", "inputs")
g.set_titles("metric=dice")
g.savefig("metrics/dice_plot.png")

In [None]:
# Metrics table feeding the aggregation below. Swap in
# "metrics/fold_0_metrics_voxel.csv" to aggregate the voxel-level numbers instead.
df = pd.read_csv(filepath_or_buffer="metrics/fold_0_metrics_lesion.csv")

In [None]:
# Mean absolute relative volume error per input configuration, in percent.
# Computed here from the raw columns: 'abs_perc_vol_error' is only added to df
# in a later cell, so reading it here raised KeyError when running top to bottom.
abs_vol_error = (df["pred_volume"] - df["volume"]).abs() / df["volume"]
(abs_vol_error.groupby(df["dataset_name"]).mean() * 100).round(2)

In [None]:
# Relative volume error |pred - gt| / gt per row (inf/NaN when the GT volume is 0).
df['abs_perc_vol_error'] = df['pred_volume'].sub(df['volume']).abs().div(df['volume'])
# Readable metric names for the plot facets.
df = df.rename(columns={"p": "precision", "r": "recall"})
# Long format: one row per original row and metric.
df_v = pd.melt(
    df,
    id_vars=['lesion_name', 'dataset_name'],
    value_vars=['precision', 'recall', 'abs_perc_vol_error'],
    var_name='metric',
    value_name='value',
)

In [None]:
# Bar plot of each metric per input configuration, one facet per metric.
# sharex=False because the metrics live on different scales (the relative
# volume error is not bounded by 1).
g = sns.catplot(
    data=df_v,
    kind="bar",
    # data mapping
    x="value",
    y="dataset_name",
    hue="dataset_name",
    col="metric",
    errorbar=None,
    # figure geometry
    height=3,
    aspect=1.25,
    width=0.5,
    sharex=False,
    palette="deep",
)
g.set_axis_labels("", "inputs")
g.savefig("metrics/lesion_precision_recall_vol_plot.png")