# Imports

In [None]:
import os
from collections import Counter, defaultdict
import importlib
import json
import numpy as np

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../../utils/__init__.py
config_logging(logging.INFO)

In [None]:
%run ../../datasets/common/constants.py

In [None]:
%run -n ./nlp_in_chexpert_groups.py

# Debug running experiments

In [None]:
%run -n ./nlp_in_chexpert_groups.py

In [None]:
dataset_info = init_dataset_info('iu')
dataset_info.name

In [None]:
# Initialize the experiment for one abnormality on the selected dataset
exp = init_experiment('Cardiomegaly', dataset_info)
# Display the experiment (the previous `exp_LO` is not defined anywhere)
exp

In [None]:
%%time

kwargs = {
    # 'metric': 'bleu',
    # 'metric': 'rouge',
    'metric': 'cider-IDF',
    'k_times': 500,
    # 'k_gts': 1,
    'max_n': 500,
}
exp.append(calc_score_matrices(exp.grouped_2, dataset_info, groups=(0, 1), **kwargs))
# exp.append(calc_score_matrices(exp.grouped, dataset_info, **kwargs))

In [None]:
exp[-1].cube

## Analyze/plot experiments

In [None]:
exp = load_experiment_pickle('iu-cardiomegaly')
exp

In [None]:
exp.results[-1].cube

In [None]:
RESULT_I = -1
METRIC_I = 0

In [None]:
plt.figure(figsize=(15, 6))

plt.subplot(1, 2, 1)
plot_heatmap(exp, result_i=RESULT_I, metric_i=METRIC_I)

if len(exp.results) > 1:
    plt.subplot(1, 2, 2)
    plot_heatmap(exp, result_i=-2, metric_i=METRIC_I)

In [None]:
dataset_info.log_ref_len

In [None]:
plt.figure(figsize=(8, 6))

plt.subplot(2, 1, 1)
plot_hists(
    exp, [
        (0, 0), (0, 1),
        # (0, 0), (1, 0),
        # (1, 1), (0, 1),
        # (1, 0), (0, 0),
    ],
    result_i=RESULT_I, metric_i=METRIC_I,
    xlabel=False, bins=50, range=(0, 1),
    add_n_to_label=True,
)

plt.subplot(2, 1, 2)
plot_hists(
    exp, [
        (1, 1), (1, 0),
        # (1, 1), (0, 1),
        # (0, 0), (1, 0),
        # (-2, -1), (-1, -1),
    ],
    result_i=RESULT_I, metric_i=METRIC_I, add_n_to_label=True,
    title=False, bins=50, range=(0, 0.2))

# Attempting to optimize threshold

## Attempt 1

(failed)

In [None]:
result = exp[0]
result

In [None]:
# Pick the two (a, b) group pairs whose score arrays are compared below
# target1, target2 = (0, 0), (0, 1) # TN, FP (specificity)
target1, target2 = (1, 1), (1, 0) # TP, FN (recall: TP / (TP + FN))

arr1 = result.dists[target1]
arr2 = result.dists[target2]
arr1.shape, arr2.shape

In [None]:
# target1 must be a matching pair (a, a); its group value becomes the label
# of "correct" samples and the complement labels the samples from target2.
assert target1[0] == target1[1]
CORRECT = target1[0]
INCORRECT = 1 - CORRECT

# Tag every score with its label, then sort the (score, label) tuples:
# ascending when CORRECT is 1, descending when CORRECT is 0.
merged = [(value, CORRECT) for value in arr1] + [(value, INCORRECT) for value in arr2]
merged = sorted(merged, reverse=bool(not CORRECT))
merged[:2], merged[-2:]

In [None]:
# Walk the sorted scores, treating each one as a candidate threshold and
# recording the running fraction of CORRECT-labelled samples seen so far.
all_threshs = []
denominator = 0 # TP + FN
numerator = 0 # TP
for value, label in merged:
    current_thresh = value
    if label == CORRECT: # add 1 TP
        numerator += 1

    # Every visited sample counts towards the denominator
    denominator += 1

    all_threshs.append((current_thresh, numerator / denominator))
all_threshs[:5]

In [None]:
max(all_threshs, key=lambda x: x[1])

In [None]:
x, y = tuple(zip(*all_threshs))
plt.plot(x, y)

## Attempt 2

Using sklearn's `precision_recall_curve`.

(failed)

In [None]:
from sklearn.metrics import precision_recall_curve as pr_curve

In [None]:
pred, gt = tuple(zip(*merged))
pred = np.array(pred)
gt = np.array(gt)
pred.shape, gt.shape

In [None]:
# pred /= 10 # CIDER re-scaling

In [None]:
precision, recall, thresholds = pr_curve(gt, pred, pos_label=CORRECT)
precision.shape, recall.shape, thresholds.shape

In [None]:
f1 = divide_arrays(2*precision*recall, precision + recall)
f1.shape

In [None]:
best_idx = f1.argmax()
best_idx

In [None]:
thresholds[best_idx], f1[best_idx], precision[best_idx], recall[best_idx]

## Attempt 3: accuracy/prec/recall

The CheXpert 4-class classification task becomes a binary classification task when judged through NLP scores
(i.e. NLP scores carry less information).

In [None]:
exp = load_experiment_pickle('mimic-cardiomegaly')
exp

In [None]:
result = exp.results[-1]
result.metric

In [None]:
result.dists

In [None]:
# Flatten the four (a, b) score arrays into one list of tuples.
# Matching pairs (0, 0) and (1, 1) are labelled correct (1),
# mismatching pairs (0, 1) and (1, 0) are labelled incorrect (0).
merged = [
    # Value, correct-or-not, original-key
    (value, 1, (0, 0)) for value in result.dists[(0, 0)]
] + [
    (value, 1, (1, 1)) for value in result.dists[(1, 1)]
] + [
    (value, 0, (0, 1)) for value in result.dists[(0, 1)]
] + [
    (value, 0, (1, 0)) for value in result.dists[(1, 0)]
]
# Tuples sort by their first field, so the list ends up ordered by score
merged = sorted(merged)
len(merged), merged[:3]

In [None]:
n_correct = sum(1 for _, correct, _ in merged if correct)
n_incorrect = sum(1 for _, correct, _ in merged if not correct)
n_correct, n_incorrect

In [None]:
def smart_division(a, b):
    """Return ``a / b``, or ``0`` when the denominator ``b`` equals zero."""
    return a / b if b != 0 else 0

In [None]:
all_threshs = []

# Sweep the decision threshold over the scores in `merged` (sorted ascending).
# A sample counts as a positive prediction while its score is strictly above
# the current threshold.
#
# Before the sweep the threshold is below every score
# --> No negative predictions, all positive predictions
# --> TN = FN = 0
TP = sum(1 for _, correct, _ in merged if correct)
FP = sum(1 for _, correct, _ in merged if not correct)
TN, FN = 0, 0

total = len(merged)

assert TP + FP + FN + TN == total, f'Begin: {TP + FP + FN + TN} vs {total}'

for idx, (value, correct, _) in enumerate(merged):
    current_thresh = value

    # Raising the threshold to `value` turns this sample into a negative prediction
    if correct:
        TP -= 1
        FN += 1
    else:
        TN += 1
        FP -= 1

    assert TP + FP + FN + TN == total, f'Thresh={value}: {TP + FP + FN + TN} vs {total}'

    # Samples sharing the same score cannot be separated by any threshold.
    # Record the metrics only once the whole tie group has been moved,
    # otherwise unreachable intermediate confusion matrices would be stored
    # (and could be picked as the "best" threshold).
    if idx + 1 < total and merged[idx + 1][0] == value:
        continue

    acc = (TP + TN) / total
    prec = smart_division(TP, TP + FP)
    recall = smart_division(TP, TP + FN)
    f1 = smart_division(2*prec*recall, prec+recall)
    spec = smart_division(TN, TN + FP)
    npv = smart_division(TN, TN + FN)
    f1_neg = smart_division(2*npv*spec, spec+npv)
    CM = (TP, FN, FP, TN)

    all_threshs.append({
        'thresh': current_thresh,
        'acc': acc,
        'prec': prec,
        'recall': recall,
        'f1': f1,
        'npv': npv,
        'spec': spec,
        'f1_neg': f1_neg,
        'CM': CM,
    })
all_threshs[:1]

In [None]:
max(all_threshs, key=lambda x: x['acc'])

In [None]:
def sl(k):
    """Return the thresholds and the values stored under ``k`` in ``all_threshs``.

    The result is ``tuple(zip(...))`` over ``(thresh, value)`` pairs, i.e. one
    tuple of thresholds followed by one tuple of values.
    """
    pairs = [(entry['thresh'], entry[k]) for entry in all_threshs]
    return tuple(zip(*pairs))

In [None]:
plt.figure(figsize=(6, 5))
keys = ('prec', 'recall', 'acc', 'f1') # 'f1', 
# keys = ('acc', )
# keys = ('npv', 'spec', 'f1_neg')
for k in keys:
    thresh, y = sl(k)
    plt.plot(thresh, y, label=k)
plt.legend()
plt.xlabel('Thresh')
plt.ylabel('Value')
plt.title('Optimize by')

In [None]:
best = max(all_threshs, key=lambda x: x['acc'])
best

In [None]:
def plot_cm(cm, title=None):
    """Plot a 2x2 confusion matrix as an annotated heatmap via ``sns.heatmap``.

    ``cm`` is unpacked as ``(TP, FN, FP, TN)``. Rows are labelled 'Real' and
    columns 'Scored by Metric', both using the Entailment/Contradiction ticks.
    ``title`` is applied only when it is truthy.
    """
    true_pos, false_neg, false_pos, true_neg = cm
    labels = ['Entailment', 'Contradiction']
    matrix = [
        [true_pos, false_neg],
        [false_pos, true_neg],
    ]
    sns.heatmap(
        matrix,
        annot=True,
        square=True,
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        fmt=',',
    )
    plt.xlabel('Scored by Metric')
    plt.ylabel('Real')
    if not title:
        return
    plt.title(title)

In [None]:
plot_cm(best['CM'], title=f'CM for {exp.abnormality} with {get_pretty_metric(result.metric)}')

## Attempt 4: use AUC

In [None]:
from sklearn.metrics import roc_curve, roc_auc_score

In [None]:
def prepare_gt_pred_for_roc(result, metric_i=0, keys=None):
    """Flatten ``result.dists`` into parallel ground-truth and score lists.

    For every ``(a, b)`` key, the scores are appended to the prediction list
    and labelled ``1`` when ``a == b`` and ``0`` otherwise. Two-dimensional
    entries (BLEU case) are first reduced to row ``metric_i``. When ``keys``
    is ``None`` every key of ``result.dists`` is used.

    Returns ``(gt, pred)`` as two lists of equal length.
    """
    selected_keys = list(result.dists.keys()) if keys is None else keys

    labels, scores = [], []
    for first, second in selected_keys:
        values = result.dists[(first, second)]
        if values.ndim == 2:
            values = values[metric_i]  # BLEU case
        scores.extend(values)
        labels.extend([int(first == second)] * len(values))

    return labels, scores

In [None]:
result = exp.results[-1]

In [None]:
gt, pred = prepare_gt_pred_for_roc(result)

In [None]:
# ROC curve of the metric scores against the labels built above
fpr, tpr, thresholds = roc_curve(gt, pred)

# Youden's J statistic (TPR - FPR); its maximum selects the threshold shown below
J_stat = tpr - fpr
best_idx = J_stat.argmax()

thresholds[best_idx], J_stat[best_idx]

In [None]:
roc = roc_auc_score(gt, pred)
roc

## Compute AUC for all abnormalities

In [None]:
# Load experiments
def load_experiments(dataset_name):
    """Load the pickled experiment of every abnormality for ``dataset_name``.

    Iterates ``CHEXPERT_DISEASES[1:]`` and builds each pickle name as
    ``<dataset_name>-<abnormality>``, with the abnormality lower-cased and its
    spaces replaced by dashes. Names without a pickle are collected and
    printed instead of raising.

    Returns a dict mapping abnormality name to its loaded experiment.
    """
    loaded = {}
    missing = []
    for abnormality in CHEXPERT_DISEASES[1:]:
        slug = abnormality.replace(" ", "-").lower()
        fname = f'{dataset_name}-{slug}'
        if exist_experiment_pickle(fname):
            loaded[abnormality] = load_experiment_pickle(fname)
        else:
            missing.append(fname)

    if missing:
        print('Not found: ', missing)

    return loaded

In [None]:
exp_by_abn_iu = load_experiments('iu')
len(exp_by_abn_iu)

In [None]:
# Compute one ROC-AUC per stored result of every loaded abnormality.
show = True  # when False, the tqdm progress bars are disabled
# NOTE(review): dataset_name is not read in this cell; the loop uses exp_by_abn_iu
dataset_name = 'iu'
# Optional filters: None keeps every sampler / every group set
target_sampler = None # 'random-gen_k500_n500'
target_groups = None # [-2, -1, 0, 1]
# keys = [(0, 0), (0, 1), (1, 1), (1, 0)]
keys = None

# Tuples of (abnormality, metric, groups, sampler, roc)
final_records = []

for abnormality in CHEXPERT_DISEASES[1:]:
    if abnormality not in exp_by_abn_iu:
        continue
    exp = exp_by_abn_iu[abnormality]

    for result in tqdm(exp.results, desc=abnormality, disable=not show):
        if target_sampler is not None and result.sampler != target_sampler:
            continue
        if target_groups is not None and sorted(result.groups) != target_groups:
            continue

        # Default metric_i=0: first score row for 2-D (BLEU) results
        gt, pred = prepare_gt_pred_for_roc(result, keys=keys)
        roc = roc_auc_score(gt, pred)

        final_records.append((abnormality, result.metric, result.groups, result.sampler, roc))

        if result.metric == 'bleu':
            # Additionally score row 3 of the BLEU results, recorded as '<metric>-4'
            gt, pred = prepare_gt_pred_for_roc(result, metric_i=3, keys=keys)
            roc = roc_auc_score(gt, pred)
            final_records.append((
                abnormality, f'{result.metric}-4', result.groups, result.sampler, roc,
            ))

len(final_records)

In [None]:
final_records[:1]

In [None]:
cols = ['disease', 'metric', 'groups', 'sampler', 'roc']
df = pd.DataFrame(final_records, columns=cols)
df.head(2)

In [None]:
df = df.loc[df['groups'] == (0, 1)]
df = df.loc[df['sampler'] == 'random-gen_k500_n500']
del df['sampler'], df['groups']
df.head(2)

In [None]:
d = df.groupby('disease').apply(lambda subdf: {
    row[1]: row[2] # row is: [disease, metric, roc]
    for row in list(subdf.values)
}).apply(pd.Series)
d

In [None]:
s = d.to_latex()
print(s)

In [None]:
def get_result(exp_by_abn, abnormality, metric,
               groups=(0, 1), sampler='random-gen_k500_n500'):
    """Find the first result of an abnormality's experiment matching the filters.

    Args:
        exp_by_abn: mapping from abnormality name to an experiment (an object
            exposing a ``results`` sequence).
        abnormality: key looked up in ``exp_by_abn``.
        metric: value that ``result.metric`` must equal.
        groups: iterable of groups the result must have; compared with
            ``result.groups`` regardless of order. ``None`` disables the filter.
        sampler: value that ``result.sampler`` must equal. ``None`` disables
            the filter.

    Returns:
        ``(exp, index)`` for the first matching result, ``(exp, None)`` when no
        result matches, or ``(None, None)`` when the abnormality has no
        experiment. A message is printed in the two failure cases.
    """
    if abnormality not in exp_by_abn:
        print(f'No exp for {abnormality}')
        return None, None
    exp = exp_by_abn[abnormality]

    # Normalise once: sorting makes the comparison order-insensitive, and
    # ``None`` is kept as "no filter" instead of failing in ``list(None)``.
    wanted_groups = None if groups is None else sorted(groups)

    for i, result in enumerate(exp.results):
        if sampler is not None and result.sampler != sampler:
            continue
        if wanted_groups is not None and sorted(result.groups) != wanted_groups:
            continue
        if result.metric != metric:
            continue

        return exp, i

    print('No experiment found with conditions')
    return exp, None

In [None]:
exp, result_i = get_result(exp_by_abn_iu, 'Atelectasis', 'bleu')
exp

In [None]:
plot_heatmap(exp, result_i=result_i, metric_i=3)

In [None]:
plot_hists(exp, keys=[(0, 0), (0, 1)], result_i=result_i, metric_i=3, bins=50)

# Statistical tests

In [None]:
from scipy.stats import ttest_ind, mannwhitneyu, f_oneway, kruskal

In [None]:
# exp = load_experiment_pickle('mimic-cardiomegaly')
len(exp.results)

In [None]:
plot_heatmap(exp, result_i=-1)

In [None]:
EXP_I = -1
result = exp[EXP_I]
result.metric

In [None]:
key1 = (0, 0)
key2 = (0, 1)
group1 = result.dists[key1]
group2 = result.dists[key2]
if result.metric == 'bleu':
    group1 = group1[0]
    group2 = group2[0]
group1.shape, group2.shape

In [None]:
plot_hists(exp, [key1, key2], result_i=EXP_I, bins=50, range=(0, 1))

In [None]:
r = mannwhitneyu(group1, group2)
r

In [None]:
r = ttest_ind(group1, group2, equal_var=False)
r

In [None]:
# One score array per (a, b) group pair, compared jointly by the tests below
groups = [result.dists[k] for k in [(0, 0), (0, 1), (1, 0), (1, 1)]]
if result.metric == 'bleu':
    # BLEU results hold several score rows per pair; keep the first row,
    # matching how group1/group2 are selected for the two-sample tests
    groups = [g[0] for g in groups]

In [None]:
anova = f_oneway(*groups)
anova

In [None]:
kru = kruskal(*groups)
kru