## Imports

In [None]:
import os
import pandas as pd
import json
import re

In [None]:
%run ../utils/__init__.py
%run ../utils/files.py
%run ../metrics/__init__.py

In [None]:
# Show every column when a DataFrame is displayed (no column truncation).
pd.options.display.max_columns = None

## Choose task

In [None]:
# Task whose results are inspected; keep exactly one assignment uncommented.
# The section headings of this notebook suggest: 'seg' = segmentation,
# 'rg' = report generation, 'cls' = classification.
TASK = 'seg'
# TASK = 'rg'
# TASK = 'cls'

In [None]:
# Columns that jointly identify one result row. Results of the 'rg' task are
# additionally keyed by the 'free' column.
KEY_COLS = ['run_name', 'dataset_type'] + (['free'] if TASK == 'rg' else [])
KEY_COLS

## Functions

In [None]:
# Locate the folder holding the per-run result files of the selected task.
# `_get_task_folder` and `WORKSPACE_DIR` are not defined in this notebook;
# presumably they come from the `%run` utility scripts above -- verify.
TASK_FOLDER = _get_task_folder(TASK)
BASE_FOLDER = os.path.join(WORKSPACE_DIR, TASK_FOLDER)
RESULTS_FOLDER = os.path.join(BASE_FOLDER, 'results')

In [None]:
def get_suffix(filename):
    """Return the suffix that follows ``metrics-`` in a results filename.

    Args:
        filename: Name of a results file, e.g. ``'metrics-free.json'``.

    Returns:
        The run of word characters between ``metrics-`` and ``.json``
        (possibly empty), or an empty string when the filename does not
        follow that pattern. A suffix containing a hyphen does not match,
        so such filenames also yield an empty string.
    """
    # Raw string: the pattern contains regex escapes that are invalid as
    # plain string escapes.
    match = re.search(r'.*metrics-(?P<suffix>\w*)\.json', filename)
    return match.group('suffix') if match else ''

In [None]:
# Substrings looked up in a results filename to decide its metric type;
# load_results() uses the first entry contained in the filename and falls
# back to 'base' when none matches.
METRIC_TYPES = [
    'chexpert',
    'grad-cam',
    'mirqi',
]

In [None]:
def load_results():
    """Load every JSON results file under RESULTS_FOLDER into DataFrames.

    Each sub-folder of RESULTS_FOLDER (except ``'debug'``) is treated as one
    run. Every file in it whose name ends with ``json`` is read as a mapping
    from dataset type to a dict of metric values and turned into one row per
    dataset type. A file's metric type is the first entry of METRIC_TYPES
    contained in its filename, or ``'base'`` when none is.

    Returns:
        A tuple ``(df, results_by_metric_type)`` where
        ``results_by_metric_type`` maps each metric type to a DataFrame with
        the rows of all runs, and ``df`` is the outer merge of those frames
        on KEY_COLS, with KEY_COLS as the leading columns. When no results
        are found, ``df`` is an empty DataFrame with the KEY_COLS columns
        and the dict is empty.
    """
    frames_by_metric_type = {}

    for run_name in os.listdir(RESULTS_FOLDER):
        if run_name == 'debug':
            continue

        folder = os.path.join(RESULTS_FOLDER, run_name)
        if not os.path.isdir(folder):
            # Stray files next to the run folders are not runs.
            continue

        for filename in os.listdir(folder):
            filepath = os.path.join(folder, filename)
            if not os.path.isfile(filepath) or not filename.endswith('json'):
                continue

            metric_type = next(
                (met for met in METRIC_TYPES if met in filename),
                'base',  # Default if no specific metric_type is found
            )

            with open(filepath, 'r') as f:
                results_dict = json.load(f)

            results_df = pd.DataFrame.from_dict(results_dict, orient='index')
            results_df.reset_index(inplace=True)
            results_df.rename(columns={'index': 'dataset_type'}, inplace=True)
            results_df['run_name'] = run_name
            if TASK == 'rg':
                results_df['free'] = get_suffix(filename)

            frames_by_metric_type.setdefault(metric_type, []).append(results_df)

    # Concatenate once per metric type: DataFrame.append was removed in
    # pandas 2.0, and appending inside the loop copied the frame every time.
    results_by_metric_type = {
        metric_type: pd.concat(frames, ignore_index=True)
        for metric_type, frames in frames_by_metric_type.items()
    }

    df = None
    cols_in_order = list(KEY_COLS)
    for results in results_by_metric_type.values():
        cols_in_order += [col for col in results.columns if col not in cols_in_order]

        if df is None:
            df = results
        else:
            df = df.merge(results, on=KEY_COLS, how='outer')

    if df is None:
        # Nothing was loaded: return an empty frame instead of failing on
        # indexing None.
        return pd.DataFrame(columns=cols_in_order), results_by_metric_type

    return df[cols_in_order], results_by_metric_type

In [None]:
def _as_pattern_list(patterns):
    """Normalize a str / list / tuple of patterns into a list (else empty)."""
    if not patterns:
        return []
    if isinstance(patterns, str):
        return [patterns]
    if isinstance(patterns, (list, tuple)):
        return list(patterns)
    return []


def filter_results(dataset_type=None, metrics=None,
                   metrics_contain=None, free=None,
                   contains=None, doesnt_contain=None,
                   drop=None, drop_na_rows=False, drop_key_cols=False):
    """Return a filtered view of RESULTS_DF without modifying it.

    Args:
        dataset_type: A dataset type (str) or a list/tuple of them to keep.
        metrics: Metric columns to keep next to KEY_COLS. Ignored when
            ``metrics_contain`` is given.
        metrics_contain: Keep KEY_COLS plus the columns whose name contains
            this substring.
        free: If not None, keep rows whose 'free' column equals ``'free'``
            (truthy) or ``'notfree'`` (falsy).
        contains: Pattern, or list/tuple of patterns, that the run name must
            all match (``Series.str.contains`` semantics, i.e. regex).
        doesnt_contain: Pattern, or list/tuple of patterns, that the run
            name must not match.
        drop: Pattern; rows whose run name matches it are removed.
        drop_na_rows: Drop rows having a NaN in any remaining column. This
            runs before all-NaN columns are removed.
        drop_key_cols: Drop the KEY_COLS other than 'run_name'.

    Returns:
        A new DataFrame; columns that are entirely NaN are always removed.
    """
    df = RESULTS_DF

    if dataset_type:
        if isinstance(dataset_type, str):
            df = df[df['dataset_type'] == dataset_type]
        elif isinstance(dataset_type, (list, tuple)):
            df = df[df['dataset_type'].isin(set(dataset_type))]

    if free is not None:
        free_str = 'free' if free else 'notfree'
        df = df.loc[df['free'] == free_str]

    for pattern in _as_pattern_list(contains):
        df = df.loc[df['run_name'].str.contains(pattern)]

    for pattern in _as_pattern_list(doesnt_contain):
        df = df.loc[~df['run_name'].str.contains(pattern)]

    if drop:
        df = df.loc[~df['run_name'].str.contains(drop)]

    if metrics_contain:
        columns = KEY_COLS + [c for c in df.columns if metrics_contain in c]
        df = df[columns]
    elif metrics:
        columns = KEY_COLS + list(metrics)
        df = df[columns]

    # dropna is deliberately NOT done in place: when no filter above created
    # a new frame, `df` is RESULTS_DF itself and an in-place drop would
    # permanently remove rows/columns from the global results.
    if drop_na_rows:
        df = df.dropna(axis=0, how='any')

    # Drop cols with all na
    df = df.dropna(axis=1, how='all')

    if drop_key_cols:
        columns = [c for c in df.columns if c == 'run_name' or c not in KEY_COLS]
        df = df[columns]

    return df

## Load results

In [None]:
# Load all runs. RESULTS_DF is the merged table; `debug` receives the second
# element returned by load_results() (the per-metric-type frames).
RESULTS_DF, debug = load_results()
print(len(RESULTS_DF))
RESULTS_DF.head()

In [None]:
# Distinct leading tokens of the column names, treating '-' and '_' alike.
{col.replace('-', '_').split('_')[0] for col in RESULTS_DF.columns}

## Segmentation

In [None]:
def add_macro_avg_column(target_col):
    """Add to RESULTS_DF a column with the row-wise mean of its sub-columns.

    The sub-columns are the columns of RESULTS_DF whose name starts with
    ``target_col``; exactly three are expected. The averaged column itself
    is excluded from the match so the function can be run again after the
    column has been added.

    Args:
        target_col: Prefix of the columns to average, and name of the
            column that stores the result.

    Raises:
        AssertionError: If the number of matching sub-columns is not 3.
    """
    matching_cols = [
        c for c in RESULTS_DF.columns
        if c.startswith(target_col) and c != target_col
    ]
    assert len(matching_cols) == 3, f'Matching cols not 3: {matching_cols}'
    RESULTS_DF[target_col] = RESULTS_DF[matching_cols].mean(axis=1)
    print(f'Calculated col {target_col}')

In [None]:
# Add macro-averaged 'n-shapes-gen' and 'n-holes-gen' columns to RESULTS_DF.
add_macro_avg_column('n-shapes-gen')
add_macro_avg_column('n-holes-gen')

In [None]:
# Ordered list of segmentation metric columns, filled by _add_metric().
organs = ('heart', 'left lung', 'right lung')
SEG_METRICS = []


def _add_metric(metric_name, macro=True):
    """Register a metric's per-organ columns, preceded by the macro column.

    With ``macro`` set, the bare ``metric_name`` column is listed first;
    the ``'<metric_name>-<organ>'`` columns always follow, in `organs` order.
    """
    per_organ = [f'{metric_name}-{organ}' for organ in organs]
    SEG_METRICS.extend([metric_name] + per_organ if macro else per_organ)


_add_metric('iou')
# _add_metric('dice')
_add_metric('n-shapes-gen')
_add_metric('n-holes-gen')
SEG_METRICS

In [None]:
# Ordered (regex pattern, replacement) pairs applied by rename_runs();
# commented-out entries are alternatives that are currently disabled.
replace_strs = [
    # (r'^\d{4}_\d{6}_', ''),
    (r'jsrt_scan_', ''),
#     ('most-similar-image', '1nn'),
#     ('_lr[\d\.]+', ''),
#     ('_size256', ''),
#     (r'_\d{4}_\d{6}_.*', ''),
#     ('dummy-', ''),
#     ('common', 'top'),
#     ('-v2', ''),
#     (r'top-(\w)\w+-(\d+)', r'top-\1-\2'),
#     ('_densenet-121', ''),
]


def rename_runs(run_name):
    """Shorten a run name by applying every substitution in `replace_strs`.

    The substitutions are applied in list order, each on the output of the
    previous one; `replace_strs` is looked up when the function is called.
    """
    renamed = run_name
    for pattern, replacement in replace_strs:
        renamed = re.sub(pattern, replacement, renamed)
    return renamed

In [None]:
# Segmentation metrics on the 'test' dataset type, without the runs whose
# name matches '1105_180035', sorted ascending by the macro-averaged shape
# and hole counts and indexed by the shortened run name.
filter_results(
    metrics=SEG_METRICS,
    dataset_type='test',
    drop='1105_180035',
).sort_values(
    ['n-shapes-gen', 'n-holes-gen'],
    ascending=True,
).set_index('run_name').rename(index=rename_runs)

## Report generation

In [None]:
# Groups of metric column names used by the report-generation tables.
NLP_METRICS = ['bleu1', 'bleu2', 'bleu3', 'bleu4', 'bleu', 'rougeL', 'ciderD']
# Result columns whose name starts with one of the listed prefixes.
CHEXPERT_METRICS = [
    name for name in RESULTS_DF.columns
    if name.startswith(('acc', 'roc_auc', 'recall', 'prec', 'f1'))
]
# Result columns selected by prefix / substring.
CHEXPERT_RUNTIME_METRICS = [name for name in RESULTS_DF.columns if name.startswith('chex')]
VAR_METRICS = [name for name in RESULTS_DF.columns if 'distinct' in name]
MIRQI_METRICS = [name for name in RESULTS_DF.columns if 'MIRQI' in name]

In [None]:
# Metric columns shown in the report-generation summary table below;
# commented-out entries are optional extras.
ESSENTIAL_METRICS = [
    'bleu', 'ciderD', 'rougeL',
    # 'chex_f1', 'chex_acc', # 'chex_recall', 'chex_prec', # Runtime-chexpert
    # 'MIRQI-v2-f',
    'roc_auc', 'acc', 'f1', 'prec', 'recall', # Holistic-chexpert
]

In [None]:
# Columns holding the runtime ('chex_f1*') and holistic ('f1*') F1 scores.
runtime_chexpert = [c for c in RESULTS_DF.columns if c.startswith('chex_f1')]
holistic_chexpert = [c for c in RESULTS_DF.columns if c.startswith('f1')]
# metrics = runtime_chexpert + holistic_chexpert
# Hand-picked subset used instead of the full lists above.
metrics = [
    'chex_f1', 'chex_f1_No Finding',
    'f1', 'f1-No Finding'
]

In [None]:
# Display floats with three decimals.
pd.options.display.float_format = '{:.3f}'.format

In [None]:
# Essential metrics of the 'free' results on the 'test' dataset type, keeping
# only rows without NaN. The regex replace strips a leading
# '<4 digits>_<6 digits>_' prefix from string cells (i.e. the run names)
# before sorting by run name.
filter_results(
    # contains='^01\d{2}_\d{6}_(h-|lstm)',
    # contains='.*dummy',
    dataset_type='test',
    free=True,
    metrics=ESSENTIAL_METRICS, # CHEXPERT_RUNTIME_METRICS,
    drop_key_cols=True,
    drop_na_rows=True,
).replace(r'^\d{4}_\d{6}_(.*)', r'\1', regex=True).sort_values('run_name')

### Compare runtime chexpert vs holistic chexpert

In [None]:
def subtract_cols(df, cols_a, cols_b, drop_na_rows=True):
    """Return the key columns of *df* plus the differences ``cols_a - cols_b``.

    Args:
        df: DataFrame containing KEY_COLS, ``cols_a`` and ``cols_b``.
        cols_a: Columns to subtract from; the result columns reuse these
            names.
        cols_b: Columns to subtract, paired positionally with ``cols_a``.
        drop_na_rows: If True, drop result rows containing any NaN.

    Returns:
        A new DataFrame, indexed like *df*, with KEY_COLS followed by one
        difference column per entry of ``cols_a``.
    """
    array_a = df[cols_a].to_numpy()
    array_b = df[cols_b].to_numpy()

    # Reuse df's index for the differences. With a default RangeIndex the
    # concat below would align on labels and, for a filtered df, attach the
    # differences to the wrong rows (and pad the rest with NaN).
    differences = pd.DataFrame(array_a - array_b, columns=cols_a, index=df.index)
    df_2 = pd.concat([df[KEY_COLS], differences], axis=1)

    if drop_na_rows:
        df_2 = df_2.dropna(axis=0, how='any')

    return df_2

In [None]:
# Metric to compare, and the matching column lists: 'chex_<metric>*' columns
# (runtime) versus '<metric>*' columns (holistic).
metric = 'f1'

runtime_chexpert = [c for c in RESULTS_DF.columns if c.startswith(f'chex_{metric}')]
holistic_chexpert = [c for c in RESULTS_DF.columns if c.startswith(metric)]

In [None]:
# Keep only the runs whose name does not contain 'dummy'.
df = RESULTS_DF
df = df.loc[~df['run_name'].str.contains('dummy')]
len(df)

In [None]:
# Distinct run names that remain after the filter.
set(df['run_name'])

In [None]:
# Per-row difference: runtime columns minus holistic columns (paired by
# position in the two lists).
df = subtract_cols(df, runtime_chexpert, holistic_chexpert)
df.head()

In [None]:
# Summary statistics of the differences.
df.describe()

In [None]:
# Full table of differences.
df

In [None]:
from collections import Counter

In [None]:
# Load the outputs of one run with free=True and free=False.
# `load_rg_outputs` is not defined in this notebook; presumably it comes from
# the `%run` utility scripts -- verify.
# NOTE: this rebinds `debug`, which earlier held the per-metric-type frames
# returned by load_results().
run_name = '0112_154506_lstm-v2_lr0.001_densenet-121-v2_noes'
debug = False
d1 = load_rg_outputs(run_name, debug=debug, free=True)
d2 = load_rg_outputs(run_name, debug=debug, free=False)
len(d1), len(d2)

In [None]:
# Occurrences of each filename in the two outputs, and the number of
# distinct filenames in each.
c1 = Counter(d1['filename'])
c2 = Counter(d2['filename'])
len(c1), len(c2)

In [None]:
# Report every filename whose number of occurrences differs between the two
# outputs. Filenames found only in c2 are checked as well (a Counter returns
# 0 for a missing key); c1's order comes first, then c2-only names.
for fname in list(c1) + [name for name in c2 if name not in c1]:
    v1 = c1[fname]
    v2 = c2[fname]
    if v1 != v2:
        print('Wrong: ', fname, v1, v2)

In [None]:
# Peek at the first rows of the free=False outputs.
d2.head()

In [None]:
# Dataset types present in the free=False outputs.
set(d2['dataset_type'])

### Pretty-print (latex)

In [None]:
# Ordered (regex pattern, replacement) pairs used to shorten run names for
# the LaTeX table. Order matters: each substitution runs on the output of
# the previous one (e.g. 'common' -> 'top' must precede the 'top-...' rule).
replace_strs = [
    (r'^\d{4}_\d{6}_', ''),
    ('most-similar-image', '1nn'),
    # Raw string: '\d' and '\.' are regex escapes, not string escapes.
    (r'_lr[\d\.]+', ''),
    ('_size256', ''),
    (r'_\d{4}_\d{6}_.*', ''),
    ('dummy-', ''),
    ('common', 'top'),
    ('-v2', ''),
    (r'top-(\w)\w+-(\d+)', r'top-\1-\2'),
    ('_densenet-121', ''),
]


def rename_runs(run_name):
    """Shorten a run name by applying every substitution in `replace_strs`.

    Args:
        run_name: Original run name.

    Returns:
        The run name after all substitutions were applied in list order.
        `replace_strs` is looked up when the function is called.
    """
    s = run_name
    for target, replace_with in replace_strs:
        s = re.sub(target, replace_with, s)
    return s

In [None]:
# Columns of the LaTeX table: three NLP metrics followed by the CHEXPERT_METRICS
# and MIRQI_METRICS column lists.
columns = ['bleu', 'rougeL', 'ciderD'] + CHEXPERT_METRICS + MIRQI_METRICS

In [None]:
# Rows of the LaTeX table: 'free' results on the 'test' dataset type for run
# names matching one of the alternatives in `contains`, minus the three runs
# listed in `drop`; rows with any NaN in the selected columns are removed.
df = filter_results(dataset_type='test',
                    free=True,
                    metrics=columns,
                    contains='(?=_lstm-att-v2.*densenet|_lstm-v2.*densenet|dummy)',
                    drop='0915_173951|0915_174222|0916_104739',
                    drop_na_rows=True,
                   )

In [None]:
# Inspect the selected rows before printing them as LaTeX.
df

In [None]:
def shorten_cols(s):
    """Shorten a column name by replacing 'MIRQI-v2' with 'v2'."""
    return s.replace('MIRQI-v2', 'v2')

In [None]:
# Print the table as LaTeX: rows indexed by the shortened run name and sorted
# by it, columns renamed with shorten_cols, floats with three decimals, and
# a column format of one left-aligned column followed by centered ones.
print(df.set_index('run_name').rename(
    index=rename_runs,
    columns=shorten_cols,
).sort_index().to_latex(
    columns=[shorten_cols(c) for c in columns],
    float_format='%.3f',
    column_format='l' + 'c' * len(columns),
))

## Classification

In [None]:
# Run-name selectors for the classification section; commented-out lines are
# alternatives tried before.
# contains = 'covid-x'
# contains = 'cxr14'
# contains = 'e0'
# contains = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid'
# contains = '0717_101812_covid-x_densenet-121_lr1e-06_os-max2_aug-covid'
run_name = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid' # WINNER

# Only the last assignment takes effect; the first one is overwritten.
contains = '0717_101812_covid-x_densenet-121_lr1e-06_os-max2_aug-covid'
contains = 'covid-uc'

In [None]:
# Redefines ESSENTIAL_METRICS for the classification section (replacing the
# report-generation list defined earlier in the notebook).
ESSENTIAL_METRICS = [
    'acc', 'roc_auc', 'hamming', # 'prec', 'recall',
]

In [None]:
# Metric columns shown in the classification table below (rebinds `metrics`).
metrics = [
    'acc', 'roc_auc', 'prec', 'recall', 'roc_auc_Cardiomegaly', 'roc_auc_Pneumonia',
    'recall_Cardiomegaly', 'recall_Pneumonia',
    'iobb-masks', 'iobb-masks-Cardiomegaly', 'iobb-masks-Pneumonia',
]

In [None]:
# Display floats with three decimals.
pd.options.display.float_format = '{:.3f}'.format

In [None]:
def simplify_names(s):
    """Reduce a cxr14 run name to its model part, tagging hint runs.

    For a name shaped like '<4 digits>_<6 digits>_cxr14_<model>_lr...', only
    '<model>' is kept; any other name is left unchanged. When the original
    name contains 'hint', '_hint' is appended to the result.
    """
    base = re.sub(r'^\d{4}_\d{6}_cxr14_(.*)_lr.*', r'\1', s)
    return f'{base}_hint' if 'hint' in s else base

In [None]:
# Classification table: runs whose name starts with 01, 12 or 02 and contains
# '_cxr14_', excluding the patterns in `doesnt_contain`, sorted by descending
# 'roc_auc'.
# NOTE: ('val') is a plain string, not a one-element tuple; filter_results
# handles it through its `str` branch.
d = filter_results(
    contains=r'^(01|12|02)\w+_cxr14_',
    doesnt_contain=[
        r'_Card',
        r'_Pneu',
        '1027_144914', # really bad results
    ],
    dataset_type=('val'), # -bbox
    # metrics_contain='iou',
    metrics=metrics, # ESSENTIAL_METRICS,
    drop_key_cols=True,
    # drop_na_rows=True,
).sort_values('roc_auc', ascending=False)
# d['run_name'] = d['run_name'].apply(simplify_names)
d

## Report-generation: results at different report lengths

In [None]:
# Candidate report-length limits, in words and in sentences; the cell below
# treats a falsy max_words (None) as "no suffix".
vals_words = [20, 25, 27, 33, 44, None]
vals_sents = [3, 4, 5, 6, None]

In [None]:
# Load the results stored under the suffix of the first word limit.
# NOTE(review): load_results() as defined in this notebook takes no arguments
# and returns a tuple, and `create_results_df` is not defined here; this cell
# presumably targets helpers from another version or the `%run` scripts --
# confirm before running.
max_words = vals_words[0]
suffix = f'max-words-{max_words}' if max_words else ''
all_results = load_results(suffix)
results_df_test = create_results_df(all_results, 'test')
results_df_test