## Imports

In [None]:
import os
import json
import re
from collections import defaultdict
import pandas as pd

In [None]:
%run ../utils/__init__.py
%run ../utils/files.py
%run ../metrics/__init__.py
%run ../models/checkpoint/__init__.py

In [None]:
pd.options.display.max_columns = None
pd.options.display.float_format = '{:.3f}'.format
pd.set_option('display.max_colwidth', None)

## Choose task

In [None]:
# TASK = 'seg'
# TASK = 'rg'
# TASK = 'cls-seg'
TASK = 'cls' # ('cls', 'cls-seg')

In [None]:
KEY_COLS = ['run_name', 'timestamp', 'dataset_type']
if TASK == 'rg':
    KEY_COLS.extend(['free', 'best', 'beam'])
KEY_COLS

## Functions

In [None]:
import glob

In [None]:
def _get_run_folders(tasks, target_folder):
    """Collect the entries found under one sub-folder of one or more tasks.

    Args:
        tasks: a single task name (str) or an iterable of task names.
        target_folder: sub-folder of each task folder whose direct children
            are listed (the callers in this notebook pass 'results' or 'models').

    Returns:
        A flat list with the paths matched by glob, task after task.
    """
    task_names = (tasks,) if isinstance(tasks, str) else tasks
    return [
        path
        for task_name in task_names
        for path in glob.glob(
            os.path.join(WORKSPACE_DIR, _get_task_folder(task_name), target_folder, '*'),
        )
    ]

In [None]:
def get_free_suffix_and_beam(filename):
    """Parse the decoding options encoded in a metrics filename.

    Recognised names look like
    ``<prefix>metrics-<free|notfree>[-<suffix>][.bs<beam>].json``.

    Args:
        filename: name of a results file.

    Returns:
        A ``(free, suffix, beam)`` tuple. ``free`` is ``'free'`` or
        ``'notfree'``; ``suffix`` and ``beam`` are the captured strings, or
        ``None`` when that part is absent. When the filename does not follow
        the pattern at all, the fallback ``('free', None, 0)`` is returned.
    """
    # The alternation replaces the former character class `[notfree]+`, which
    # accepted any mix of those letters (e.g. 'foo' or 'tree').
    match = re.search(
        r'.*metrics-(?P<free>notfree|free)(-(?P<suffix>[\w\-]+))?(\.bs(?P<beam>\d+))?\.json',
        filename,
    )
    if match is None:
        return 'free', None, 0
    return match.group('free'), match.group('suffix'), match.group('beam')

In [None]:
get_free_suffix_and_beam('chexpert-metrics-notfree.bs100.json')

In [None]:
METRIC_TYPES = [
    'chexpert',
    'grad-cam',
    'mirqi',
    'bertscore',
    'bleurt',
]

### Load fns

In [None]:
def _extract_timestamp(run_name):
    if re.search('^\d{4}_\d{6}', run_name):
        return run_name[:11]
    return ''

In [None]:
def load_results():
    """Load every metrics JSON of the current TASK into one wide DataFrame.

    Each run folder returned by ``_get_run_folders(TASK, 'results')`` is
    scanned (the 'debug' run is skipped). Files whose name contains
    'thresholds-' or 'training-stats' are ignored. The metric type of a file
    is the first entry of METRIC_TYPES contained in its name, or 'base'.
    The top-level keys of each JSON become the 'dataset_type' column.

    Returns:
        A ``(df, results_by_metric_type)`` tuple. ``results_by_metric_type``
        maps each metric type to the concatenation of its per-file frames;
        ``df`` is the outer merge of those frames on KEY_COLS, with the key
        columns first. When no file is found, ``df`` is an empty frame holding
        only the key columns and the dict is empty.
    """
    frames_by_metric_type = {}

    for run_folder in _get_run_folders(TASK, 'results'):
        run_name = os.path.basename(run_folder)

        if run_name == 'debug':
            continue

        for filename in os.listdir(run_folder):
            filepath = os.path.join(run_folder, filename)
            if not os.path.isfile(filepath) or not filename.endswith('json'):
                continue

            if any(s in filename for s in ('thresholds-', 'training-stats')):
                continue

            metric_type = next(
                (met for met in METRIC_TYPES if met in filename),
                'base',  # Default if no specific metric_type is found
            )

            with open(filepath, 'r') as f:
                results_dict = json.load(f)

            results_df = pd.DataFrame.from_dict(results_dict, orient='index')
            results_df.reset_index(inplace=True)
            results_df.rename(columns={'index': 'dataset_type'}, inplace=True)
            results_df['run_name'] = run_name
            results_df['timestamp'] = _extract_timestamp(run_name)
            if TASK == 'rg':
                free, best_metric, beam_size = get_free_suffix_and_beam(filename)
                results_df['free'] = free
                results_df['best'] = best_metric
                results_df['beam'] = int(beam_size or 0)

            frames_by_metric_type.setdefault(metric_type, []).append(results_df)

    # One concat per metric type: DataFrame.append was removed in pandas 2.0
    # and copied the accumulated frame on every call.
    results_by_metric_type = {
        metric_type: pd.concat(frames, ignore_index=True)
        for metric_type, frames in frames_by_metric_type.items()
    }

    df = None
    cols_in_order = list(KEY_COLS)
    for results in results_by_metric_type.values():
        cols_in_order += [col for col in results.columns if col not in cols_in_order]

        if df is None:
            df = results
        else:
            df = df.merge(results, on=KEY_COLS, how='outer')

    if df is None:
        # Nothing was loaded: return an empty frame instead of failing on None[...]
        return pd.DataFrame(columns=cols_in_order), results_by_metric_type

    return df[cols_in_order], results_by_metric_type

In [None]:
def load_training_stats():
    """Load the training-stats JSON files of the current TASK into a DataFrame.

    Each run folder returned by ``_get_run_folders(TASK, 'models')`` is
    scanned (the 'debug' run is skipped) for files named
    ``training-stats*.json``. Nested dicts are flattened one level, and the
    old keys 'n_epochs' / 'epochs' are mapped to both 'final_epoch' and
    'current_epoch'. Three columns are added per file: 'total_time' and
    'time_per_epoch' (formatted with ``duration_to_str``) and 'run_name'.

    Returns:
        A DataFrame with one row per stats file and 'run_name' as first
        column; an empty frame with only 'run_name' when no file is found.

    Raises:
        KeyError: if a stats file lacks 'secs_per_epoch', 'current_epoch' or
            'initial_epoch' after flattening.
    """
    # Escaped dot and end anchor: only names really ending in '.json' match
    re_filename = re.compile(r'training-stats.*\.json$')

    all_training_stats = []

    for run_folder in _get_run_folders(TASK, 'models'):
        run_name = os.path.basename(run_folder)

        if run_name == 'debug':
            continue

        for filename in os.listdir(run_folder):
            if not re_filename.match(filename):
                continue

            filepath = os.path.join(run_folder, filename)
            if not os.path.isfile(filepath):
                continue

            with open(filepath, 'r') as f:
                training_stats_original_dict = json.load(f)

            # Unwrap dicts and fix old key-values
            training_stats = dict()
            for key, value in training_stats_original_dict.items():
                if isinstance(value, dict):
                    training_stats.update(value)
                elif key == 'n_epochs' or key == 'epochs':
                    training_stats['final_epoch'] = value
                    training_stats['current_epoch'] = value
                else:
                    training_stats[key] = value

            # Add total-time column
            secs_per_epoch = training_stats['secs_per_epoch']
            n_epochs = training_stats['current_epoch'] - training_stats['initial_epoch']
            total_time = secs_per_epoch * n_epochs
            training_stats['total_time'] = duration_to_str(total_time)

            # Add pretty-time columns
            training_stats['time_per_epoch'] = duration_to_str(secs_per_epoch)

            # Add run_name column
            training_stats['run_name'] = run_name

            all_training_stats.append(training_stats)

    if not all_training_stats:
        # An empty frame has no 'run_name' column to move to the front
        return pd.DataFrame(columns=['run_name'])

    df = pd.DataFrame(all_training_stats)
    cols = ['run_name'] + [c for c in df.columns if c != 'run_name']
    df = df[cols]

    return df

### Filter fns

In [None]:
def _filter_df_run_name_contains(df, contains):
    if contains:
        filter_contains = lambda d, s: d.loc[d['run_name'].str.contains(s)]
        if isinstance(contains, (list, tuple)):
            for c in contains:
                df = filter_contains(df, c)
        elif isinstance(contains, str):
            df = filter_contains(df, contains)
    return df

def __rename_run_name(run_name, replace_strs):
    s = run_name
    for target, replace_with in replace_strs:
        s = re.sub(target, replace_with, s)
    return s

def _df_rename_runs(df, rename_runs):
    """Rewrite df['run_name'] IN PLACE with the given regex substitutions.

    Args:
        df: DataFrame; left untouched when it has no 'run_name' column.
        rename_runs: list of (regex, replacement) pairs; falsy means no-op.

    Returns:
        The same df object that was passed in.
    """
    if not rename_runs or 'run_name' not in df:
        return df

    df['run_name'] = [
        __rename_run_name(name, rename_runs) for name in df['run_name']
    ]
    return df

In [None]:
def get_renamer(replace_strs):
    """Build a function that rewrites a name through regex substitutions.

    Args:
        replace_strs: iterable of (regex, replacement) pairs, applied in
            order; each replacement is a re.sub() template.

    Returns:
        A one-argument function mapping a name to its rewritten form.
    """
    def _rename_run(run_name):
        renamed = run_name
        for pattern, replacement in replace_strs:
            renamed = re.sub(pattern, replacement, renamed)
        return renamed

    return _rename_run

In [None]:
def filter_results(dataset_type=None, metrics=None,
                   metrics_contain=None,
                   contains=None, doesnt_contain=None,
                   drop=None, drop_na_rows=False, drop_key_cols=False,
                   timestamp_col=False, best_col=False, beam_col=False,
                   rename_runs=None, remove_timestamp=False,
                   free=None,
                   beam_size=None,
                   best=None,
                  ):
    """Select rows and columns of RESULTS_DF without modifying it.

    Args:
        dataset_type: a value (str) or list/tuple of values of the
            'dataset_type' column to keep.
        metrics: metric columns to keep, next to KEY_COLS.
        metrics_contain: substring; keeps KEY_COLS plus every column whose
            name contains it. Takes precedence over ``metrics``.
        contains: regex (or list/tuple of regexes, all required) that
            'run_name' must contain.
        doesnt_contain: regex (or list/tuple of regexes) that 'run_name'
            must NOT contain.
        drop: regex; rows whose 'run_name' contains it are removed.
        drop_na_rows: drop the rows holding any NaN.
        drop_key_cols: drop the KEY_COLS other than 'run_name'.
        timestamp_col, best_col, beam_col: with ``drop_key_cols``, keep the
            'timestamp' / 'best' / 'beam' column anyway.
        rename_runs: list of (regex, replacement) pairs applied to 'run_name'.
        remove_timestamp: strip a leading ``dddd_dddddd_`` from the string
            values of the frame.
        free: True keeps rows with free == 'free', False keeps 'notfree'.
        beam_size: keeps rows with that 'beam' value, plus rows with a null
            'best' and beam == 0. Prints an error if there is no 'beam' column.
        best: keeps rows with that 'best' value, plus rows with a null 'best'.

    Returns:
        A new DataFrame; columns that are entirely NaN are always dropped.
    """
    df = RESULTS_DF

    if dataset_type:
        if isinstance(dataset_type, str):
            df = df[df['dataset_type'] == dataset_type]
        elif isinstance(dataset_type, (list, tuple)):
            df = df[df['dataset_type'].isin(set(dataset_type))]

    if free is not None:
        free_str = 'free' if free else 'notfree'
        df = df.loc[df['free'] == free_str]

    if best is not None:
        # Keep null to keep paper ones
        df = df.loc[(df['best'] == best) | df['best'].isnull()]

    if beam_size is not None:
        if 'beam' not in df.columns:
            print('ERROR: cannot filter by beam_size, beam column not found')
        else:
            df = df.loc[
                (df['beam'] == beam_size) |
                (df['best'].isnull() & (df['beam'] == 0))  # Other cases
            ]

    df = _filter_df_run_name_contains(df, contains)

    if doesnt_contain:
        if isinstance(doesnt_contain, str):
            excluded_patterns = [doesnt_contain]
        elif isinstance(doesnt_contain, (list, tuple)):
            excluded_patterns = doesnt_contain
        else:
            excluded_patterns = []
        for pattern in excluded_patterns:
            df = df.loc[~df['run_name'].str.contains(pattern)]

    if drop:
        df = df.loc[~df['run_name'].str.contains(drop)]

    if metrics_contain:
        df = df[KEY_COLS + [c for c in df.columns if metrics_contain in c]]
    elif metrics:
        # list(): also accept a tuple of metric names
        df = df[KEY_COLS + list(metrics)]

    # dropna() is used without inplace=True on purpose: when no filter above
    # ran, df is still RESULTS_DF itself and must not be modified.
    if drop_na_rows:
        df = df.dropna(axis=0, how='any')

    # Drop cols with all na
    df = df.dropna(axis=1, how='all')

    if drop_key_cols:
        columns = [
            c for c in df.columns
            if c == 'run_name' or
                (timestamp_col and c == 'timestamp') or
                (best_col and c == 'best') or
                (beam_col and c == 'beam') or
                c not in KEY_COLS
        ]
        df = df[columns]

    if rename_runs:
        # _df_rename_runs writes in place: give it a private copy
        df = _df_rename_runs(df.copy(), rename_runs)

    if remove_timestamp:
        df = df.replace(r'^\d{4}_\d{6}_', '', regex=True)

    return df

In [None]:
def filter_training_stats(contains=None, columns=None,
                          rename_runs=None, remove_timestamp=False,
                         ):
    """Select rows and columns of TRAINING_STATS_DF without modifying it.

    Args:
        contains: regex (or list/tuple of regexes, all required) that
            'run_name' must contain.
        columns: columns to keep, in that order; None keeps all of them.
        rename_runs: list of (regex, replacement) pairs applied to 'run_name'.
        remove_timestamp: strip a leading ``dddd_dddddd_`` from the string
            values of the frame.

    Returns:
        The selected DataFrame (TRAINING_STATS_DF itself when no option is set).
    """
    df = _filter_df_run_name_contains(TRAINING_STATS_DF, contains)

    if rename_runs:
        # _df_rename_runs writes in place; without the copy an unfiltered call
        # would permanently rename the runs of TRAINING_STATS_DF.
        df = _df_rename_runs(df.copy(), rename_runs)

    if remove_timestamp:
        df = df.replace(r'^\d{4}_\d{6}_', '', regex=True)

    if columns is not None:
        df = df[columns]
    return df

## Load results

In [None]:
import warnings

warnings.filterwarnings("ignore", message="The frame.append method is deprecated and will be removed from pandas in a future version")

In [None]:
RESULTS_DF, debug = load_results()
print(len(RESULTS_DF))

In [None]:
TRAINING_STATS_DF = load_training_stats()
print(len(TRAINING_STATS_DF))
# TRAINING_STATS_DF.tail(2)

In [None]:
# set(
#     col.replace('-', '_').split('_')[0]
#     for col in RESULTS_DF.columns
# )

## Segmentation

In [None]:
def add_macro_avg_column(target_col):
    """Store in RESULTS_DF[target_col] the row-wise mean of its sub-columns.

    The sub-columns are the columns of RESULTS_DF whose name starts with
    ``target_col``, excluding ``target_col`` itself, so the cell can be run
    again after the macro column has been added. Exactly 3 are expected.

    Raises:
        AssertionError: if the number of matching columns is not 3.
    """
    matching_cols = [
        c for c in RESULTS_DF.columns
        if c.startswith(target_col) and c != target_col
    ]
    assert len(matching_cols) == 3, f'Matching cols not 3: {matching_cols}'
    RESULTS_DF[target_col] = RESULTS_DF[matching_cols].mean(axis=1)
    print(f'Calculated col {target_col}')

In [None]:
add_macro_avg_column('n-shapes-gen')
add_macro_avg_column('n-holes-gen')

In [None]:
# Organ suffixes appended to every segmentation metric name
organs = ('heart', 'left lung', 'right lung')
# Ordered list of segmentation metric columns, filled by _add_metric()
SEG_METRICS = []


def _add_metric(metric_name, macro=True):
    """Append a metric's per-organ names (preceded by the macro name) to SEG_METRICS."""
    names = [f'{metric_name}-{organ}' for organ in organs]
    if macro:
        names.insert(0, metric_name)
    SEG_METRICS.extend(names)


_add_metric('iou')
# _add_metric('dice')
_add_metric('n-shapes-gen')
_add_metric('n-holes-gen')
SEG_METRICS

In [None]:
replace_strs = [
    # (r'^\d{4}_\d{6}_', ''),
    (r'jsrt_scan_', ''),
#     ('most-similar-image', '1nn'),
#     ('_lr[\d\.]+', ''),
#     ('_size256', ''),
#     (r'_\d{4}_\d{6}_.*', ''),
#     ('dummy-', ''),
#     ('common', 'top'),
#     ('-v2', ''),
#     (r'top-(\w)\w+-(\d+)', r'top-\1-\2'),
#     ('_densenet-121', ''),
]

In [None]:
filter_results(
    metrics=SEG_METRICS,
    dataset_type='test',
    drop='1105_180035',
    rename_runs=replace_strs,
).sort_values(
    ['n-shapes-gen', 'n-holes-gen'],
    ascending=True,
).set_index('run_name')

## Report generation

In [None]:
# CHEXPERT_METRICS = ['recall', 'prec', 'f1'] # 'acc', 'roc_auc', 
CHEXPERT_DISEASE_METRICS = [
    c
    for c in RESULTS_DF.columns
    if any(
        c.startswith(f'{ch}-')
        for ch in ('f1', 'recall', 'prec')
    ) and not c.endswith('-woNF')
]
# CHEXPERT_RUNTIME_METRICS = [col for col in RESULTS_DF.columns if col.startswith('chex')]
# VAR_METRICS = [c for c in RESULTS_DF.columns if 'distinct' in c]
# MIRQI_METRICS = [c for c in RESULTS_DF.columns if 'MIRQI' in c]
MIRQI_METRICS_v1 = ['MIRQI-f', 'MIRQI-p', 'MIRQI-r']
# MIRQI_METRICS_v2 = [f'MIRQI-v2-{s}' for s in ('f', 'p', 'r', 'np', 'sp', 'attr-p', 'attr-r')]
# MIRQI_METRICS_v2 = [c for c in RESULTS_DF.columns if 'MIRQI-v2' in c]
# MIRQI_METRICS_v3 = ['MIRQI-v3-clean-f', 'MIRQI-v3-clean-p', 'MIRQI-v3-clean-r']
# MIRQI_METRICS_v4 = ['MIRQI-v4-pos-f', 'MIRQI-v4-pos-p', 'MIRQI-v4-pos-r']
# MIRQI_METRICS_v5 = ['MIRQI-v5-game-f', 'MIRQI-v5-game-p', 'MIRQI-v5-game-r']
# MIRQI_METRICS_v6 = ['MIRQI-v6-game-f', 'MIRQI-v6-game-p', 'MIRQI-v6-game-r']
# MIRQI_METRICS_v7 = [f'MIRQI-v7-attr-only-{s}' for s in ('f', 'p', 'r')]

In [None]:
NLP_METRICS = [
    'bleu', 'bleu1', 'bleu2', 'bleu3', 'bleu4',
    'rougeL', 'ciderD',
]
ESSENTIAL_METRICS = [
    ## (holistic) CHEXPERT:
    # 'acc',
    'f1', 'prec', 'recall',

    # 'f1-woNF', 'prec-woNF', 'recall-woNF', # *CHEXPERT_DISEASE_METRICS,
    
    ## NLP
    # 'bleu1', 'bleu2', 'bleu3', 'bleu4',
    'bleu1', 'bleu4',
    'rougeL', 'ciderD',
    'bleurt', 'bertscore-f1',
    # 'meteor',

    # 'chex_f1', 'chex_acc', # 'chex_recall', 'chex_prec', # Runtime-chexpert

    # *MIRQI_METRICS_v1,
#     *MIRQI_METRICS_v2,
    # *MIRQI_METRICS_v5,
    # *MIRQI_METRICS_v6,
    # *MIRQI_METRICS_v3,
    # *MIRQI_METRICS_v7,
]

### Main table

In [None]:
rename_runs = [
    # (r'_precnn-\d{4}-\d{6}', ''),
    (r'(mimic-cxr|iu-x-ray)_', ''),
    # ('most-similar-image', '1nn'),
    # (r'_lr(-\w+)?[\d\.e\-]+', ''),
    # (r'_lr[\d\.]+', ''),
    ('_size256', ''),
    # ('-v2', ''),
    ('_front', ''),
    (r'__[\w\-]*', ''),
    (r'_(pre)?cnn\-\d{4}\-\d{6}', ''),
    ('_densenet-121-v2', ''),
]

In [None]:
IU = True
# MICCAI experiments:
# iu lstm 0612_035549, best-bleu: 0621_134437
# mimic lstm best-bleu: 0621_231122
# iu h-coatt: 0623_120544|0623_110053
# mimic h-coatt: 0623_192208
CONTAINS = \
    ('iu-x-ray', # paper
     # 0623_142422|0623_142452|
     # lstm-att-v2 = 0612_035549
     # ST | SAT: 1123_001440|1119_183609
     # 1-nns (euclidean|cosine) 0612_160902|1210_212248
     r'((tpl|most-similar-image).*cnn-1118-203841.*v4-1)|1123_001440|1119_183609|0623_202003|1103_133310|1103_133405|0612_160842|0612_160823|paper(?!_show|_coatt_re-impl)') \
    if IU else \
    ('mimic-cxr',
     # Old tpl-chex-v1|tpl-m-chex-grouped: 0702_140740|0702_143050 | 1102_100501
     # 1-nn: euclidean|cosine: 1103_111912|1210_212245
     # v4-2 experiments: constant-mimic|words|sentences|1nn|rand|lstm-att|h-coatt|tpl-chex-v1|grouped
     # lstm-att-v2 = 0702_183533
     # SAT | ST = 1113_185718|1119_183153
     # SAT | ST = 1201_150847|1202_161321
     r'paper(?!_show|_coatt_re-impl)|1102_115221|1201_150847|1202_161321|0702_145200|1112_125550|1112_131626|1103_111912|1210_212245|0702_150811|0703_144847|1102_190559|1129_212630',
    )
# OLD MIMIC:
# r'dummy-m|0617_144209|0623_103308|0625_184437|0612_233628|tpl-(chex-v1|m-chex-grouped-v6)-ordbest-v2.*cnn-0612-082139'
# r'dummy|((tpl-(chex-v1-ord|m-chex-grouped-v6)|h-coatt|lstm-att-v2).*cnn-0612-082139)',

res = filter_results(
    # H-coatt models
    # contains=('iu-x-ray', 'h-coatt.*v4-1|paper_coatt'), # 0623_202003 # .*mti
    # 0623_120544 vs paper_coatt_re-impl-hrgr
    # contains=('mimic-cxr', '_h-coatt.*v4-2|paper_coatt'),
    
    ### LSTM models (show and tell, show attend, etc)
    # contains=('iu-x-ray', '_lstm-v2.*v4-1.*front|show-tell|boag-et-al-cnn'),
    # contains=('iu-x-ray', '_lstm-att-v2.*v4-1.*front|s-att|show-attend-tell'),
    # contains=('mimic-cxr', '_lstm-v2.*v4-2|show-tell|boag-et-al-cnn'),
    # contains=('mimic-cxr', '_lstm-att-v2.*1101-115743.*v4-2|s-att|show-attend-tell'), # re-impl-liu-2021-et-al-CA
    # contains=('iu-x-ray', 's-tell|show-tell|rtex'),
    # contains=('iu-x-ray', 's-att-tell|show-attend-tell'),
    # contains=('mimic-cxr', 's-tell|show-tell|liu-et-al|ratchet'),
    # contains=('mimic-cxr', 's-att-tell|show-attend-tell|liu-et-al'),
    
    ### Beam experiments
    # contains=('iu-x-ray', 's-att|s-tell'), # , 'ema'
    # contains=('mimic-cxr_s-', 'lr0\.0001'), # , '(?!ema)'
    
    ### Template sets (stress tests, ablations, etc) # new cnn: 1118_203841
    # contains=('iu-x-ray', 'tpl.*cnn-1118-203841', 'ordbest-v2'),
    # OLD TPL models: 0623_142422|0623_142452|gaming
    # IU with new CNN: 1118_210509|1118_210821

    # contains=('mimic-cxr', '1102_190559|1129_212630|gaming', 'ordbest-v2'),
    # MIMIC old: 0702_140740|0702_143050
    
    ### Checkpointing by metric
    # contains=('iu-x-ray', '0612_035549|0621_134437|0621_131422|0621_132927'),
    
    ### Stress test 3
    # contains=('iu-x-ray.*v4-1', 'constant'),
    #contains=('mimic-cxr.*v4-2', 'constant|chex-v6'),
    contains='1102_190559|1102_115221|0702_160242|1104_134722|1102_205924',
    # contains=('iu-x-ray.*v4-1', 'tpl'),
    
    
    ## New papers 2022
    # contains=('mimic-cxr', 'know|prog|kgae-supv'),
    # contains=('iu-x-ray', 'kgae|know|prog'),
    # contains=('iu-x-ray', 'paper'),
    # contains=('mimic-cxr', 'paper'),
    
    # contains=('iu-x-ray', 'dummy-common-sentences', 'v4-1'),
    # contains=('mimic-cxr', 'dummy-common-words', 'v4-2'),
    
    # contains=CONTAINS,
    doesnt_contain=(
        'paper',
        # 'dummy-common', 
        'dummy-baseline', '_satt', '_ssent', '_COPY', 'tiny',
        'boag-et-al-1nn', 'liu-et-al-ccr', 'tienet', 'rtmic',
        'most-similar-image_0519-181554', 'cls-seg', 'noisy',
        'vgg-19',
        # 'constant-mimic',
        're-impl',
        # 'nguyen-et-al', # for now
        'show-tell_re-impl-boag-et-al', # repeated from boag-et-al-cnn-rnn...
        'miura-et-al-fcen', # This ablation is ignored
        r'paper_boag-et-al-cnn-rnn(?!-beam)',
    ),
    dataset_type='test',
    free=True,
    metrics=ESSENTIAL_METRICS,
    drop_key_cols=True,
    timestamp_col=True,
    # drop_na_rows=True,
    rename_runs=rename_runs,
    remove_timestamp=True,
    best='lighter-chex_f1',
    # best='ciderD',
    beam_size=0,
    # best_col=True, # beam_col=True,
    # best='bleu4',
).set_index('run_name')
# res = res.sort_values(['run_name', 'best'], ascending=True) # 'beam'
# res = res.sort_values('bleu4')
res = res.sort_index()
res

In [None]:
pd.options.display.float_format = '{:.3f}'.format

### Main-table to latex

In [None]:
def bold(s):
    r"""Wrap the string ``s`` in a LaTeX ``\textbf{...}`` command."""
    # Escaped backslash: a bare '\t' in a normal string literal is a TAB
    return '\\textbf{' + s + '}'

# Shortens metric column names for the LaTeX header. The substitutions run in
# order, each one on the output of the previous one.
shorten_cols = get_renamer([
    ('-woNF', '-d'),
    ('ciderD', 'C-D'),
    (r'bleu(\d)', r'B-\1'),
    # Negative lookahead: do not turn the 'bleurt' metric into 'Brt'
    (r'bleu(?!rt)', 'B'),
    ('rougeL', 'R-L'),
    ('acc', 'Acc'),
    ('prec', 'P'),
    ('recall', 'R'),
    ('f1', 'F-1'),
    ('MIRQI-f', 'M-F-1'),
    ('MIRQI-r', 'M-R'),
    ('MIRQI-p', 'M-P'),
])
def latexify_cols(col):
    """Return the shortened form of a column name, wrapped by bold()."""
    shortened = shorten_cols(col)
    return bold(shortened)

# Maps an internal run name to the model name shown in the paper tables.
# Every replacement is a re.sub() template, so a literal LaTeX backslash is
# written as '\\' inside a raw string. A single backslash would either be read
# as a control character ('\r', '\t') or rejected by re as a bad escape
# ('\s', '\c').
# The substitutions run in order, each one on the output of the previous one.
get_official_run_name = get_renamer([
    # All trained models
    ('_reports-v4-1', ''),
    ('_reports-v4-2', ''),
    (r'_(cnn-)?\d{4}-\d{6}', ''),
    ('_densenet-121', ''),
    # Dummy models
    (r'most-similar-image.*dist-cos', '1-NN (cosine)'),
    (r'most-similar-image', '1-NN (euclidean)'),
    ('dummy-', ''),
    ('common-', 'Top-'),
    ('constant', 'Constant'),
    ('random', 'Random retrieval'),
    # DL models
    ('lstm-att.*', 'CNN-LSTM-att'),
    ('s-tell_.*', r'ST\\reimplemented{}, \\shortciteauthor{vinyals2015showtell}\\categoryLSTM{}'),
    ('s-att-tell_.*', r'SAT\\reimplemented{}, \\shortciteauthor{xu2015showattendtell}\\categoryLSTM{}'),
    # Template models
    # ('ordbest-v2', ''),
    ('tpl', 'Templ.'),
    (r'-chex-v1-ordbest-v2.*', ' single'),
    # (r'-chex-v1-noisy.*', ' top-char.'),
    (r'-chex-v2-grouped-ordbest-v2', ' grouped'),
    # (r'-m-chex-grouped-v6-ordbest-v2', ' grouped'), # DEPRECATED
    (r'-m-chex-grouped-v8-ordbest-v2', ' grouped', ),
    (r'-chex-v1-gaming-rm-neg-ordbest-v2', ' abn-only'),
    (r'-chex-v1-gaming-dup-ordbest-v2', ' repeated'),
    (r'-ord\w+', ''),
    (r'h-coatt.*(__)?', r'CoAtt\1\\reimplemented{}, \\shortciteauthor{jing2017automatic}\\categoryLSTM{}'),
    # Papers
    ('paper_', ''),
    ('arl', r'ARL, \\shortciteauthor{hou2021arl}\\textsuperscript{L,RL}'),
    ('rtex', r'RTEX, \\shortciteauthor{kougia2021rtex}\\categoryRetrieval{}'),
    ('zhang-et-al-mirqi', r'\\shortciteauthor{zhang2020graph}\\textsuperscript{L,f+i}'),
    ('lovelace-et-al', r'\\shortciteauthor{lovelace2020learning}\\categoryTransformer{}'),
    ('liu-et-al-full', r'\\shortciteauthor{liu2019clinically}\\textsuperscript{L,RL}'),
    ('liu-2021-et-al-CA', r'\\shortciteauthor{liu2021contrastive}\\textsuperscript{L,CA}'),
    # ('boag-et-al-1nn', '1-nn \shortciteauthor{boag2020baselines}'),
    ('boag-et-al-cnn-rnn-beam', r'\\shortciteauthor{boag2020baselines}\\categoryLSTM{}'),
    ('chen-et-al', r'\\shortciteauthor{chen2020memory}\\categoryTransformer{}'),
    ('clara', r'CLARA, \\shortciteauthor{biswal2020clara}\\categoryRetrieval{}'),
    ('coatt', r'CoAtt, \\shortciteauthor{jing2017automatic}\\textsuperscript{L,f+i}'),
    ('ni-et-al', r'CVSE, \\shortciteauthor{ni2020cvse}\\textsuperscript{R,Ab}'),
    ('hrgr', r'HRGR, \\shortciteauthor{li2018hybrid}\\categoryRetrieval{}'),
    # 'knowledge' must run before 'kerp': the KERP cite key (li2019knowledge)
    # contains the word and would otherwise be rewritten by this rule.
    ('knowledge', r'\\cite{yang2021knowledge}\\categoryTransformer{}'),
    ('kerp', r'KERP, \\shortciteauthor{li2019knowledge}\\categoryRetrieval{}'),
    ('syeda-et-al', r'\\shortciteauthor{syeda2020chest}\\textsuperscript{R,f+i}'),
    ('nguyen-et-al', r'\\shortciteauthor{nguyen2021automated}\\categoryTransformer{}'),
    ('nishino-et-al', r'\\shortciteauthor{nishino2020reinforcement}\\textsuperscript{L,RL}'),
    ('ratchet', r'RATCHET, \\shortciteauthor{hou2021ratchet}\\categoryTransformer{}'),
    ('vti', r'\\shortciteauthor{najdenkoska2021variational}\\textsuperscript{T,f+i}'),
    ('kgae-supv', r'\\cite{liu2021kgae}\\categoryTransformer{}'),
    ('progressive', r'\\cite{nooralahzadeh2021progressive}\\categoryTransformer{}'),
    ('miura-et-al-fcen', 'DELETEME'),
    ('miura-et-al-fce', r'\\shortciteauthor{miura-etal-2021-improving}\\textsuperscript{T,RL}'),
    # ('-mirqi', ''),
    # (r'(\w+)-et-al', r'\1 et al.'),
])

In [None]:
def bold_best_value_in_values(values):
    """Format numbers with 3 decimals, bolding the largest and dashing NaNs.

    Args:
        values: sequence or array of numbers; NaN marks a missing value.

    Returns:
        A list of strings, one per input value: '-' for NaN, otherwise the
        value formatted as ``'{:.3f}'``. Every value whose formatted text
        equals the formatted maximum is wrapped with bold(), so ties at 3
        decimals are all bolded. An empty or all-NaN input bolds nothing.
    """
    # Local import: numpy is not imported in the visible cells of the notebook
    import numpy as np

    formatter = '{:.3f}'.format

    values = np.asarray(values, dtype=float)
    # Explicit mask instead of a -1 sentinel, so a real -1 value is still shown
    is_missing = np.isnan(values)

    max_value = None
    if not is_missing.all():
        max_value = formatter(np.nanmax(values))

    values_str = []
    for value, missing in zip(values, is_missing):
        if missing:
            values_str.append('-')
            continue
        value_s = formatter(value)
        if value_s == max_value:
            value_s = bold(value_s)
        values_str.append(value_s)

    return values_str

In [None]:
def bold_best_value_by_column(df, apply=True):
    """Format each non-object column with bold_best_value_in_values().

    Args:
        df: DataFrame to format.
        apply: when False, df is returned as is.

    Returns:
        A copy of df whose non-object columns hold the formatted strings;
        object-dtype columns (e.g. with strings) are left unchanged.
    """
    if not apply:
        return df

    formatted = df.copy()
    numeric_cols = [name for name in df.columns if df[name].dtypes != 'O']
    for name in numeric_cols:
        formatted[name] = bold_best_value_in_values(df[name].values)
    return formatted

In [None]:
def _rotated_multirow_args():
    """Return '{<number of rows in res>}{<dataset name>}' for the current IU flag."""
    dataset = 'MIMIC-CXR' if not IU else 'IU X-ray'
    return f'{{{len(res)}}}{{{dataset}}}'

In [None]:
table = res.drop(columns='timestamp') if 'timestamp' in res.columns else res
table = bold_best_value_by_column(table, False).rename(
    index=get_official_run_name,
    columns=latexify_cols,
)
n_metrics = len(table.columns)
table = table.reset_index().rename(columns={'run_name': bold('Model')}).to_latex(
    float_format='%.3f',
    column_format='l' + 'c' * n_metrics,
    na_rep='-',
    index=False,
    escape=False,
    # bold_rows=True,
)
table = re.sub(r' +', ' ', table, flags=re.M)
# Add this additional column for the dataset (IU or MIMIC)
# table = re.sub(r'^ +', '& ', table, flags=re.M)
# table = re.sub(
#     r'^\& (CLARA|Templ\. simple|Boag)',
#     r'\cline{2-11}\n& \1', table, flags=re.M,
# )
# table = re.sub(
#     r'^\\midrule',
#     r'\\midrule\n\\rotatedMultirow' + _rotated_multirow_args(),
#     table, flags=re.M,
# )
print(table)

### Chexpert by disease table

In [None]:
def bold_best_value_by_row(df):
    """Format each row of df with bold_best_value_in_values().

    Args:
        df: DataFrame of numbers.

    Returns:
        A new object-dtype DataFrame, with the same index and columns,
        holding the formatted strings.
    """
    # Cast to object first: the formatted cells are strings, and writing
    # strings into float columns is deprecated in recent pandas versions.
    df2 = df.astype(object)
    for row in df.index:
        df2.loc[row] = bold_best_value_in_values(df.loc[row].values)
    return df2

In [None]:
base = 'f1'
metrics = [c for c in CHEXPERT_DISEASE_METRICS if base in c] + [base]

In [None]:
rename_runs_2 = [
    ('iu-x-ray_', ''),
    ('mimic-cxr_', ''),
    ('_front', ''),
    ('tpl-chex-v1-grouped-ordbest_cnn-0611-155356_densenet-121-v2', 'densenet-121 + templates'),
    ('dummy-', ''),
    (r'_(precnn-)?\d{4}-\d{6}', ''),
    (r'_lr(-emb)?[\d\.]+', ''),
    (r'__\w+', ''),
    ('-v2', ''),
    ('_cnn-freeze', ''),
]

In [None]:
TRANSPOSE = True

In [None]:
_ORDER = [
    'most-similar',
    's-att-tell',
    'paper_boag',
    'paper_lovelace',
    'paper_ratchet',
    'paper_miura',
    'paper_ni',
    'tpl',
]
def _order_runs(index):
    def _get_order(run):
        for i, o in enumerate(_ORDER):
            if run.startswith(o):
                return i
        raise Exception(f'{run} not considered in order!')
    return pd.Index([_get_order(run) for run in index])

In [None]:
df = filter_results(
    # lstm-att: 0612_233628
    # 1-nn old-cnn: 0612_215504 , euclidean: 1103_111912, cosine: 1210_212245
    contains=('mimic-cxr', r'1210_212245|1102_190559|1113_185718|paper_(boag-et-al-cnn-rnn-beam|lovelace|ni\-et|ratchet|miura)'),
    # contains=('1102_190559|1129_191853'),
    # contains=('chex-v5'),
    # old: 0612_215709 
    # contains=('iu-x-ray', r'dummy|tpl|__base|paper'),
    # contains=('iu-x-ray', r'0611_182321|0612_012900'), # 0612_012741
    # dummy-most|dummy-random|__freeze
    doesnt_contain=('dummy-baseline', '_satt', '_ssent', '_COPY', 'tiny', 'miura-et-al-fcen'),
    dataset_type='test',
    free=True,
    metrics=metrics,
    rename_runs=rename_runs_2,
    drop_key_cols=True,
    # timestamp_col=True,
    # drop_na_rows=True,
    best='lighter-chex_f1',
    beam_size=0,
    # best_col=True, beam_col=True,
    remove_timestamp=True,
).set_index('run_name').sort_index(key=_order_runs)
if TRANSPOSE:
    df = df.transpose()
df

In [None]:
%run ../datasets/common/constants.py

In [None]:
# Builds the per-disease LaTeX table from `df`.
# LaTeX commands are written with an escaped backslash: a bare '\t' in a
# normal string literal is a TAB, and a bare '\m' in a re.sub() template is
# rejected by re as a bad escape.
if TRANSPOSE:
    d = df.rename(index={base: f'{base}-macro'}).rename(
        columns=lambda x: '\\tableDiseaseColname{%s}' % get_official_run_name(x),
        index=get_renamer([
            (r'{}-macro'.format(base), 'Macro average'),
            (r'{}-(\w+)'.format(base), r'\1'),
        ])
    )
    d = bold_best_value_by_row(d)
    n_models = len(d.columns)
    d.columns.rename(bold(f'{base.capitalize()}-scores'), inplace=True)
    table = d.to_latex(
        float_format='%.3f',
        column_format='l' + 'c' * n_models,
        # na_rep='-',
        # index=False,
        escape=False,
    )
else:
    def _disease_name(col):
        """Drop a leading 'f1-' from a column name, if present."""
        # str.strip('f1-') removes any of the characters 'f', '1', '-' from
        # BOTH ends, which is not a prefix removal.
        return col[len('f1-'):] if col.startswith('f1-') else col

    d = df.rename(columns={base: f'Macro'}).rename(
        columns=lambda x: '\\tableDiseaseColname{%s}' % ABN_SHORTCUTS.get(_disease_name(x), x),
        index=get_official_run_name,
    )
    d = bold_best_value_by_column(d)
    table = d.to_latex(
        float_format='%.3f',
        column_format='l' + 'c' * len(d.columns),
        # na_rep='-',
        # index=False,
        escape=False,
    )
table = re.sub(r' +', ' ', table, flags=re.M)
# Insert a \midrule before the row that starts with 'Macro'
table = re.sub(
    r'^(Macro)',
    r'\\midrule\n\1', table, flags=re.M,
)
print(table)

### Training stats

In [None]:
replace_strs = [
    (r'_precnn-\d{4}-\d{6}', ''),
    (r'_lr[\d\.]+', ''),
    (r'_lr-emb[\d\.]+', ''),
    ('_size256', ''),
    ('-v2', ''),
    ('_front', ''),
    (r'__[\w\-]*', ''),
]

In [None]:
# Training stats of the '__base' runs, indexed by cleaned-up run name
cols = [
    'run_name',
    'time_per_epoch', 'total_time',
    'current_epoch', 'final_epoch',
    'batch_size', 'device', 'visible',
]
res = filter_training_stats(
    contains='__base',
    columns=cols,
    rename_runs=replace_strs,
)
res = res.replace(r'^\d{4}_\d{6}_(.*)', r'\1', regex=True)
# rename_runs is a list of (regex, replacement) pairs; DataFrame.rename needs
# a mapping or a callable, so the list is wrapped with get_renamer().
res = res.set_index('run_name').rename(index=get_renamer(rename_runs))
res.sort_index()

### Compare runtime chexpert vs holistic chexpert

In [None]:
def subtract_cols(df, cols_a, cols_b, drop_na_rows=True):
    """Compute ``df[cols_a] - df[cols_b]`` position by position.

    Args:
        df: DataFrame holding KEY_COLS, ``cols_a`` and ``cols_b``.
        cols_a: columns to subtract from; the result columns take these names.
        cols_b: columns to subtract, paired with ``cols_a`` by position.
        drop_na_rows: drop the rows holding any NaN.

    Returns:
        A new DataFrame with KEY_COLS followed by the differences, indexed
        like ``df``.
    """
    differences = pd.DataFrame(
        df[cols_a].to_numpy() - df[cols_b].to_numpy(),
        columns=cols_a,
        # Reuse df's index: with the default RangeIndex, concat would misalign
        # the rows of any frame that was filtered beforehand.
        index=df.index,
    )
    df_2 = pd.concat([df[KEY_COLS], differences], axis=1)

    if drop_na_rows:
        df_2 = df_2.dropna(axis=0, how='any')

    return df_2

In [None]:
metric = 'f1'

runtime_chexpert = [c for c in RESULTS_DF.columns if c.startswith(f'chex_{metric}')]
holistic_chexpert = [c for c in RESULTS_DF.columns if c.startswith(metric)]

In [None]:
df = RESULTS_DF
df = df.loc[~df['run_name'].str.contains('dummy')]
len(df)

In [None]:
set(df['run_name'])

In [None]:
df = subtract_cols(df, runtime_chexpert, holistic_chexpert)
df.head()

In [None]:
df.describe()

In [None]:
df

In [None]:
from collections import Counter

In [None]:
run_name = '0112_154506_lstm-v2_lr0.001_densenet-121-v2_noes'
debug = False
d1 = load_rg_outputs(run_name, debug=debug, free=True)
d2 = load_rg_outputs(run_name, debug=debug, free=False)
len(d1), len(d2)

In [None]:
c1 = Counter(d1['filename'])
c2 = Counter(d2['filename'])
len(c1), len(c2)

In [None]:
for fname in c1.keys():
    v1 = c1[fname]
    v2 = c2[fname]
    if v1 != v2:
        print('Wrong: ', fname, v1, v2)

In [None]:
d2.head()

In [None]:
set(d2['dataset_type'])

### Pretty-print (latex)

In [None]:
replace_strs = [
    (r'^\d{4}_\d{6}_', ''),
    ('most-similar-image', '1nn'),
    ('_lr[\d\.]+', ''),
    ('_size256', ''),
    (r'_\d{4}_\d{6}_.*', ''),
    ('dummy-', ''),
    ('common', 'top'),
    ('-v2', ''),
    (r'top-(\w)\w+-(\d+)', r'top-\1-\2'),
    ('_densenet-121', ''),
]

In [None]:
# Metric columns shown in the pretty-printed table.
# NOTE(review): CHEXPERT_METRICS and MIRQI_METRICS are only assigned in
# commented-out lines of the visible notebook, so this cell raises NameError
# unless they are defined elsewhere — confirm before running.
columns = ['bleu', 'rougeL', 'ciderD'] + CHEXPERT_METRICS + MIRQI_METRICS

In [None]:
df = filter_results(dataset_type='test',
                    free=True,
                    metrics=columns,
                    contains='(?=_lstm-att-v2.*densenet|_lstm-v2.*densenet|dummy)',
                    drop='0915_173951|0915_174222|0916_104739',
                    drop_na_rows=True,
                    rename_runs=replace_strs,
                   )
df

In [None]:
shorten_cols = lambda s: s.replace('MIRQI-v2', 'v2')

In [None]:
# rename_runs is a list of (regex, replacement) pairs; DataFrame.rename needs
# a mapping or a callable, so the list is wrapped with get_renamer().
print(df.set_index('run_name').rename(
    index=get_renamer(rename_runs),
    columns=shorten_cols,
).sort_index().to_latex(
    columns=[shorten_cols(c) for c in columns],
    float_format='%.3f',
    column_format='l' + 'c' * len(columns),
))

## Classification

### Check results

In [None]:
# contains = 'covid-x'
# contains = 'cxr14'
# contains = 'e0'
# contains = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid'
# contains = '0717_101812_covid-x_densenet-121_lr1e-06_os-max2_aug-covid'
# run_name = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid' # WINNER

# contains = '0717_101812_covid-x_densenet-121_lr1e-06_os-max2_aug-covid'
# contains = 'covid-uc'

In [None]:
metrics = [
    'roc_auc', 'pr_auc', # 'hamming', #
]
# metrics = [
#     'acc', 'roc_auc', 'prec', 'recall', 'roc_auc_Cardiomegaly', 'roc_auc_Pneumonia',
#     'recall_Cardiomegaly', 'recall_Pneumonia',
#     'iobb-masks', 'iobb-masks-Cardiomegaly', 'iobb-masks-Pneumonia',
# ]

In [None]:
replace_strs = [
    # (r'^\d{4}_\d{6}_', ''),
    # (r'_precnn-\d{4}-\d{6}', ''),
    (r'_lr[e\-\d\.]+', ''),
    # (r'(cxr14|chexpert|iu-x-ray)_', ''),
    ('_size256', ''),
    (r'_cl-wbce_seg-w', ''),
    (r'_seg-unw', ''),
    # (r'_aug\d-(touch|double|single)', ''),
    ('_shuffle', ''),
    ('_sch-(roc|pr)[\-_]auc-p\d-f0.5(-c\d)?', ''),
    ('_best-(roc|pr)[\-_]auc', ''),
    ('_norm[SD]', ''),
    ('_labels13', ''),
    # ('_front', ''),
    # (r'__[\w\-]*', ''),
]

In [None]:
# CONTAINS = r'cxr14.*(?:small|tiny)|0402_062551'
CONTAINS = 'cxr14_densenet-121'
# CONTAINS = 'chexpert_densenet-121'
# CONTAINS = r'chexpert' # .*(?:small|tiny)
# CONTAINS = r'iu-x-ray.*(?:tiny)|0420_175514'
# CONTAINS = r'iu-x-ray.*(?:small)|0420_175514'
# CONTAINS = r'iu-x-ray_densenet-121' # 0420_175514

DATASET_TYPE = 'val' if 'chex' in CONTAINS else 'test'

d = filter_results(
    contains=CONTAINS,
    doesnt_contain=['hint', 'balance', 'Cardiomeg', 'Pneumonia'],
    dataset_type=DATASET_TYPE,
    metrics=metrics,
    drop_key_cols=True,
    # rename_runs=replace_strs,
).sort_values('pr_auc', ascending=False)
d.set_index('run_name')

In [None]:
# meta = load_metadata(RunId('0406_230221', False, 'cls'))
# NOTE(review): with the line above commented out, `meta` is not assigned
# anywhere in the visible notebook — confirm it is defined before running.
meta['hparams']

### Check training stats

In [None]:
replace_strs = [
    (r'_lr[e\-\d\.]+', ''),
    # (r'(cxr14|chexpert|iu-x-ray)_', ''),
    (r'_pre\d{4}-\d{6}', ''),
    ('_size256', ''),
    (r'_cl-wbce_seg-w', ''),
    (r'_seg-unw', ''),
    (r'_aug\d-(touch|double|single)', ''),
    ('_shuffle', ''),
    ('_sch-(roc|pr)[\-_]auc-p\d-f0.5(-c\d)?', ''),
    ('_best-(roc|pr)[\-_]auc', ''),
    ('_norm[SD]', ''),
    ('_labels13', ''),
]

In [None]:
# Training stats of the cxr14 / chexpert runs, indexed by cleaned-up run name
cols = [
    'run_name',
    'time_per_epoch', 'total_time',
    'current_epoch', 'initial_epoch', 'final_epoch',
    'batch_size', 'visible',
]
res = filter_training_stats(
    contains=r'cxr14|chexpert',
    columns=cols,
    rename_runs=replace_strs,
)
# res = res.replace(r'^\d{4}_\d{6}_(.*)', r'\1', regex=True)
# rename_runs is a list of (regex, replacement) pairs; DataFrame.rename needs
# a mapping or a callable, so the list is wrapped with get_renamer().
res = res.set_index('run_name').rename(index=get_renamer(rename_runs))
res.sort_index()