# Imports

In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import os

In [None]:
import matplotlib
# Give every new figure a white background.
matplotlib.rcParams.update({'figure.facecolor': 'white'})

In [None]:
# Never truncate DataFrame columns when displaying them.
pd.set_option('display.max_columns', None)

# Load utils

In [None]:
# Dataset whose vocabulary and reports are loaded below; alternative: 'iu' (IU X-Ray).
USE_DATASET = 'mimic_cxr'

## Load vocabulary and report reader

In [None]:
# Execute project helper scripts in the notebook namespace. Later cells use
# load_vocab, ReportReader and CHEXPERT_LABELS, presumably defined by these files.
# Order may matter: later scripts can override names defined by earlier ones.
%run ../datasets/common/constants.py
%run ../datasets/vocab/__init__.py
%run ../utils/nlp.py

In [None]:
# Vocabulary for the selected dataset, plus a reader built on it; the reader's
# text_to_idx is used below to turn each report into an index sequence
# (presumably token ids — confirm in utils/nlp.py).
VOCAB = load_vocab(USE_DATASET)
REPORT_READER = ReportReader(VOCAB)
len(VOCAB)  # cell output: vocabulary size

## Load reports with holistic CheXpert labels (ground truth)

In [None]:
from medai.datasets.iu_xray import DATASET_DIR as IU_DIR
from medai.datasets.mimic_cxr import DATASET_DIR as MIMIC_DIR

In [None]:
# Reports of the selected dataset together with their holistic CheXpert labels.
fpath = os.path.join(
    MIMIC_DIR if 'mimic' in USE_DATASET else IU_DIR,
    'reports',
    'reports_with_chexpert_labels.csv',
)

# Binarize label codes in a single pass: -1 -> 1 and -2 -> 0 (the same mapping
# is applied to the labelers' outputs later). NOTE(review): -1 / -2 are
# presumably the "uncertain" / "not mentioned" codes — confirm with the CSV producer.
df = pd.read_csv(fpath, index_col=0).replace([-1, -2], [1, 0])
df.head()

In [None]:
# Every report text converted to its index sequence, in df row order.
REPORTS_LIST = list(map(REPORT_READER.text_to_idx, df['Reports']))
len(REPORTS_LIST)

In [None]:
def add_suffix(col):
    """Return ``col`` with a ``-gt`` suffix if it is a CheXpert label column, else unchanged."""
    return f'{col}-gt' if col in CHEXPERT_LABELS else col


# Tag the holistic labels as ground truth, so they can coexist with the
# generated ('gen') label columns concatenated onto df later.
df.rename(columns=add_suffix, inplace=True)
df.head()

In [None]:
# Load metric helpers; _compute_metrics_vs_holistic (next cell) uses
# labels_with_suffix and _calculate_metrics, presumably defined by these scripts.
%run ../metrics/report_generation/chexpert.py
# -n: __name__ is not set to '__main__', so only definitions are loaded and any
# `if __name__ == "__main__":` section of the script is skipped.
%run -n ../eval_report_generation_chexpert_labeler.py

In [None]:
def _compute_metrics_vs_holistic(labels):
    """Compare generated CheXpert labels against the holistic ground truth in ``df``.

    Args:
        labels: 2-D array-like or ``torch.Tensor`` with one row per report in
            ``df`` (same order) and one column per name returned by
            ``labels_with_suffix('gen')``.

    Returns:
        ``(full_df, metrics)``: ``full_df`` is ``df`` with the generated label
        columns appended; ``metrics`` is the result of
        ``_calculate_metrics(full_df)`` (callers unpack it as
        ``acc, precision, recall, f1, roc_auc, pr_auc``).

    Raises:
        ValueError: if the number of label rows differs from ``len(df)``.
    """
    columns = labels_with_suffix('gen')

    if isinstance(labels, torch.Tensor):
        # detach() so tensors that track gradients can still be converted.
        labels = labels.detach().cpu().numpy()

    if len(labels) != len(df):
        raise ValueError(
            f'Expected {len(df)} rows of labels (one per report), got {len(labels)}'
        )

    # Reuse df's index: pd.concat(axis=1) aligns rows by index label, so a
    # default RangeIndex would misalign (and NaN-fill) rows whenever df's
    # CSV-derived index is not exactly 0..n-1.
    gen_df = pd.DataFrame(labels, columns=columns, index=df.index)
    full_df = pd.concat([df, gen_df], axis=1)

    return full_df, _calculate_metrics(full_df)

# Compare light and full CheXpert labelers vs holistic labels (runtime and metrics)

## Calculate CheXpert labels with the light labeler

In [None]:
# Expected to define ChexpertLightLabeler (used in the next cell).
%run ../metrics/report_generation/labeler_correctness/light_labeler.py

In [None]:
# Light labeler over the same vocabulary used to encode REPORTS_LIST.
labeler = ChexpertLightLabeler(VOCAB)

In [None]:
%%time

# Label every encoded report with the light labeler; %%time prints the timing.
labels = labeler(REPORTS_LIST)
labels.shape  # cell output: shape of the returned labels

In [None]:
# Apply the ground-truth binarization to the labeler output, in place:
# -1 -> 1 and -2 -> 0 (the two replacements are independent, so order is irrelevant).
labels[labels == -1] = 1
labels[labels == -2] = 0
labels

In [None]:
# _compute_metrics_vs_holistic returns (full_df, metrics); the metrics tuple
# must be unpacked as a nested group (unpacking the pair into six names raises).
full_df, (acc, precision, recall, f1, roc_auc, pr_auc) = _compute_metrics_vs_holistic(labels)
acc, precision, recall, f1, roc_auc, pr_auc

## Calculate CheXpert labels with the full labeler

In [None]:
# Expected to define ChexpertFullLabeler (used in the next cell).
%run ../metrics/report_generation/labeler_correctness/full_labeler.py
# NOTE(review): nlp.py was already run earlier in this notebook; re-running it
# here presumably restores names that full_labeler.py overrides — confirm.
%run ../utils/nlp.py

In [None]:
# Full labeler over the same vocabulary; the bare name displays the object.
labeler = ChexpertFullLabeler(VOCAB)
labeler

In [None]:
%%time

# Label every encoded report with the full labeler; %%time prints the timing.
labels = labeler(REPORTS_LIST)
labels.shape  # cell output: shape of the returned labels

In [None]:
# Apply the ground-truth binarization to the labeler output, in place:
# -1 -> 1 and -2 -> 0 (the two replacements are independent, so order is irrelevant).
labels[labels == -1] = 1
labels[labels == -2] = 0
labels

In [None]:
# _compute_metrics_vs_holistic returns (full_df, metrics); the metrics tuple
# must be unpacked as a nested group (unpacking the pair into six names raises).
full_df, (acc, precision, recall, f1, roc_auc, pr_auc) = _compute_metrics_vs_holistic(labels)
acc, precision, recall, f1, roc_auc, pr_auc

# Compare lighter labeler vs holistic labels

In [None]:
# Expected to define ChexpertLighterLabeler (used in the next cell).
%run ../metrics/report_generation/labeler_correctness/lighter_labeler/__init__.py

In [None]:
# Lighter labeler over the same vocabulary, constructed with device='cpu'.
labeler = ChexpertLighterLabeler(VOCAB, device='cpu')

In [None]:
%%time

# Label every encoded report with the lighter labeler; %%time prints the timing.
labels = labeler(REPORTS_LIST)
labels.size()  # cell output; this labeler's output is handled as a torch tensor below

In [None]:
# When the lighter labeler returns 13 columns, prepend an all-zero first column
# (nf_column: presumably 'No Finding' — TODO confirm against CHEXPERT_LABELS) so the
# layout matches the columns expected by _compute_metrics_vs_holistic.
if labels.size(1) == 13:
    # new_zeros inherits labels' dtype and device; torch.zeros would default to
    # float32 and make torch.cat mix dtypes when labels are integer-typed.
    nf_column = labels.new_zeros((labels.size(0), 1))
    labels = torch.cat((nf_column, labels), dim=1)
labels.size()

In [None]:
# Metrics of the (possibly column-padded) lighter-labeler output vs the holistic labels.
full_df, (acc, precision, recall, f1, roc_auc, pr_auc) = _compute_metrics_vs_holistic(labels)
len(acc)  # cell output: number of entries in acc (presumably one per label column)

In [None]:
# F1 without the first entry (the column that is zero-padded when the lighter
# labeler returns 13 columns), plus the mean of the remaining values.
f1[1:], f1[1:].mean()