# Import

In [None]:
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import matplotlib
# Set the default figure background colour to white for every figure created
# after this point.
matplotlib.rcParams['figure.facecolor'] = 'white'

# Load run

In [None]:
%run ../utils/files.py

In [None]:
# Identifier of the run whose saved outputs are analysed in this notebook.
# NOTE(review): the meaning of the three positional arguments is defined by
# RunId, which is not visible here (presumably provided by the %run of
# ../utils/files.py) — confirm before changing them.
run_id = RunId('0422_163242', False, 'cls-seg')
run_id

In [None]:
# `os` is imported explicitly: the notebook's import cells never import it,
# so without this line the cell only works if one of the %run scripts
# happens to leak `os` into the interactive namespace.
import os

# Locate the results folder of the selected run and load its per-sample
# outputs table.
results_folder = get_results_folder(run_id)
fpath = os.path.join(results_folder, 'outputs.csv')
df = pd.read_csv(fpath)
df.head()

# Calculate thresholds

In [None]:
%run ../metrics/classification/optimize_threshold.py
# %run ../utils/common.py

In [None]:
# Disease names extracted from the results table by a project helper.
# NOTE(review): presumably derived from the '<disease>-gt' / '<disease>-pred'
# column names — confirm in the helper's definition.
diseases = _get_diseases_from_results_df(df)
len(diseases)

In [None]:
# Keep only the rows of the validation split; the curves below are computed
# on this subset.
val_df = df[df['dataset_type'] == 'val']
print(len(val_df))

# Sanity check: every filename appears exactly once in the validation rows.
assert val_df['filename'].is_unique
val_df.head()

## Plot ROC curve

In [None]:
def get_gt_pred_for_disease(disease, source_df=None):
    """Return the ground-truth and prediction columns of one disease.

    Args:
        disease: Disease name. The columns ``f'{disease}-gt'`` and
            ``f'{disease}-pred'`` are read.
        source_df: DataFrame to read the columns from. When omitted, the
            notebook-level ``val_df`` is used, so existing one-argument
            calls behave exactly as before.

    Returns:
        Tuple ``(gt, pred)`` of NumPy arrays, in the row order of the
        DataFrame.

    Raises:
        KeyError: If either column is missing from the DataFrame.
    """
    if source_df is None:
        # Looked up at call time, so the function follows later
        # re-assignments of the notebook global.
        source_df = val_df
    gt = source_df[f'{disease}-gt'].to_numpy()
    pred = source_df[f'{disease}-pred'].to_numpy()
    return gt, pred

In [None]:
# Disease analysed in the cells below; it selects the '<disease>-gt' and
# '<disease>-pred' columns and appears in the plot titles.
disease = 'Cardiomegaly'

In [None]:
# ROC operating points for the selected disease on the validation rows.
gt, pred = get_gt_pred_for_disease(disease)
fpr, tpr, thresholds = roc_curve(gt, pred)
# Overwrite the first threshold with 1.
# NOTE(review): assuming roc_curve is sklearn.metrics.roc_curve, its first
# threshold is an artificial value above every score (inf or max score + 1);
# replacing it presumably keeps the threshold axis of the plot bounded, and
# assumes the predictions lie in [0, 1] — confirm.
thresholds[0] = 1

# J statistic (tpr - fpr) at every threshold; the threshold that maximises it
# is taken as the best operating point.
J_stat = tpr - fpr
best_idx = J_stat.argmax()
# Displayed as the cell output: (best threshold, J value at that threshold).
thresholds[best_idx], J_stat[best_idx]

In [None]:
# One figure with two side-by-side panels: the ROC curve on the left and the
# J statistic as a function of the threshold on the right.
n_rows, n_cols = 1, 2
fig = plt.figure(figsize=(14, 5))

# Left panel: ROC curve, with the J-maximising operating point highlighted.
roc_ax = fig.add_subplot(n_rows, n_cols, 1)
roc_ax.plot(fpr, tpr)
roc_ax.set_title(f'ROC-curve ({disease})', fontsize=25)
roc_ax.set_xlabel('False positive rate', fontsize=18)
roc_ax.set_ylabel('True positive rate', fontsize=18)
roc_ax.scatter(fpr[best_idx], tpr[best_idx], marker='o', color='orange')

# Right panel: J value obtained at each candidate threshold.
j_ax = fig.add_subplot(n_rows, n_cols, 2)
j_ax.plot(thresholds, J_stat)
j_ax.set_title('Optimal threshold', fontsize=20)
j_ax.set_xlabel('Threshold', fontsize=18)
j_ax.set_ylabel('J = tpr - fpr', fontsize=18)

## Plot PR curves

In [None]:
# Precision/recall operating points for the same gt/pred arrays.
# NOTE: this rebinds `thresholds` (and `best_idx` below), replacing the values
# computed in the ROC cell; re-running the ROC plot cell after this one would
# mix PR thresholds with ROC statistics.
precision, recall, thresholds = precision_recall_curve(gt, pred)

# Drop the last precision/recall entry.
# NOTE(review): assuming sklearn.metrics.precision_recall_curve, that final
# point has no matching threshold, so trimming it aligns the three arrays —
# confirm.
precision, recall = precision[:-1], recall[:-1]

# F1 at every threshold, computed through the project's divide_arrays helper
# (presumably a division that tolerates zero denominators — confirm).
f1 = divide_arrays(2 * precision * recall, precision + recall)
best_idx = f1.argmax()
# Displayed as the cell output: (index, threshold, F1) of the best F1 point.
best_idx, thresholds[best_idx], f1[best_idx]

In [None]:
# One figure with two side-by-side panels: the PR curve on the left and F1 as
# a function of the threshold on the right.
fig = plt.figure(figsize=(14, 5))

# Left panel: precision against recall, with the best-F1 point highlighted.
pr_ax = fig.add_subplot(1, 2, 1)
pr_ax.plot(recall, precision)
pr_ax.set_title(f'PR-curve ({disease})', fontsize=25)
pr_ax.set_xlabel('Recall', fontsize=18)
pr_ax.set_ylabel('Precision', fontsize=18)
pr_ax.scatter(recall[best_idx], precision[best_idx], marker='o', color='orange')

# Right panel: F1 value obtained at each candidate threshold.
f1_ax = fig.add_subplot(1, 2, 2)
f1_ax.plot(thresholds, f1)
f1_ax.set_title('Optimal threshold', fontsize=20)
f1_ax.set_xlabel('Threshold', fontsize=18)
f1_ax.set_ylabel('F1', fontsize=18)