## Imports

In [None]:
import os
import pandas as pd
import json
import re

In [None]:
%run ../utils/__init__.py
%run ../utils/files.py
%run ../metrics/__init__.py
%run ../models/checkpoint/__init__.py

In [None]:
# Pandas display settings: show every column, full cell contents, and floats
# with 3 decimals.
pd.options.display.max_columns = None
pd.options.display.float_format = '{:.3f}'.format
pd.set_option('display.max_colwidth', None)

## Choose task

In [None]:
# Task whose results/models are loaded below; keep exactly one assignment
# active. A tuple selects several tasks at once (_get_run_folders accepts both).
# TASK = 'seg'
TASK = 'rg'
# TASK = 'cls-seg'
# TASK = ('cls', 'cls-seg')

In [None]:
# Columns identifying one results row. For report generation ('rg') rows are
# additionally keyed by the suffix of the metrics filename (see get_suffix).
KEY_COLS = ['run_name', 'timestamp', 'dataset_type']
if TASK == 'rg':
    KEY_COLS.append('free')
KEY_COLS

## Functions

In [None]:
import glob

In [None]:
def _get_run_folders(tasks, target_folder):
    if isinstance(tasks, str):
        tasks = (tasks,)
    results = []
    for task in tasks:
        target_glob = os.path.join(WORKSPACE_DIR, _get_task_folder(task), target_folder, '*')
        results.extend(glob.glob(target_glob))
    return results

In [None]:
def get_suffix(filename):
    """Return the word that follows ``metrics-`` in a ``.json`` filename.

    E.g. ``'chexpert-metrics-free.json'`` -> ``'free'``. Returns ``''`` when
    the filename does not contain ``metrics-<word>.json``.
    """
    # Raw string: '\w' and '\.' in a plain literal are invalid escape sequences
    # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    match = re.search(r'.*metrics-(?P<suffix>\w*)\.json', filename)
    return match.group('suffix') if match is not None else ''

In [None]:
# Metric families used by load_results to group results files: a JSON whose
# filename contains one of these strings is grouped under the first match,
# every other file goes to the 'base' group.
METRIC_TYPES = [
    'chexpert',
    'grad-cam',
    'mirqi',
]

### Load fns

In [None]:
def _extract_timestamp(run_name):
    if re.search('^\d{4}_\d{6}', run_name):
        return run_name[:11]
    return ''

In [None]:
def load_results():
    """Load every results JSON of the current TASK.

    Walks the run folders returned by ``_get_run_folders(TASK, 'results')``,
    skipping the ``debug`` run and any ``thresholds-*`` / ``training-stats*``
    file. Each JSON is read as ``{dataset_type: {metric: value}}`` and becomes
    one row per dataset_type, tagged with ``run_name``, ``timestamp`` and (for
    TASK 'rg') the ``free`` filename suffix.

    Files are grouped by metric type (first entry of METRIC_TYPES contained in
    the filename, else ``'base'``), concatenated per group, and all groups are
    outer-merged on KEY_COLS.

    Returns:
        (df, results_by_metric_type): the merged DataFrame with KEY_COLS first
        (an empty frame with KEY_COLS columns if nothing was found), and the
        dict metric_type -> concatenated DataFrame.
    """
    frames_by_metric_type = {}

    for run_folder in _get_run_folders(TASK, 'results'):
        run_name = os.path.basename(run_folder)
        if run_name == 'debug':
            continue

        for filename in os.listdir(run_folder):
            filepath = os.path.join(run_folder, filename)
            if not os.path.isfile(filepath) or not filename.endswith('json'):
                continue
            if any(s in filename for s in ('thresholds-', 'training-stats')):
                continue

            metric_type = next(
                (met for met in METRIC_TYPES if met in filename),
                'base',  # Default if no specific metric_type is found
            )

            with open(filepath, 'r') as f:
                results_dict = json.load(f)

            results_df = pd.DataFrame.from_dict(results_dict, orient='index')
            results_df = results_df.reset_index().rename(columns={'index': 'dataset_type'})
            results_df['run_name'] = run_name
            results_df['timestamp'] = _extract_timestamp(run_name)
            if TASK == 'rg':
                results_df['free'] = get_suffix(filename)

            frames_by_metric_type.setdefault(metric_type, []).append(results_df)

    # DataFrame.append was removed in pandas 2.0: concatenate each group once
    # (which also avoids the quadratic cost of appending inside the loop).
    results_by_metric_type = {
        metric_type: pd.concat(frames, ignore_index=True)
        for metric_type, frames in frames_by_metric_type.items()
    }

    if not results_by_metric_type:
        # Nothing found: return an empty frame instead of failing on None[...]
        return pd.DataFrame(columns=KEY_COLS), results_by_metric_type

    df = None
    cols_in_order = list(KEY_COLS)
    for results in results_by_metric_type.values():
        cols_in_order += [col for col in results.columns if col not in cols_in_order]
        df = results if df is None else df.merge(results, on=KEY_COLS, how='outer')

    return df[cols_in_order], results_by_metric_type

In [None]:
def load_training_stats():
    """Load every ``training-stats*.json`` of the current TASK into one DataFrame.

    Walks the run folders returned by ``_get_run_folders(TASK, 'models')``,
    skipping the ``debug`` run. For each stats file:
      * nested dict values are flattened into a single level;
      * the legacy keys ``n_epochs`` / ``epochs`` are copied into both
        ``final_epoch`` and ``current_epoch``;
      * ``total_time`` (= secs_per_epoch * (current_epoch - initial_epoch)) and
        ``time_per_epoch`` are added, formatted with ``duration_to_str``;
      * ``run_name`` is added.

    Raises KeyError if a file lacks ``secs_per_epoch``, ``current_epoch`` or
    ``initial_epoch`` after flattening.

    Returns:
        DataFrame with one row per stats file and ``run_name`` as the first
        column (an empty frame with only ``run_name`` if no file is found).
    """
    # '.match' anchors at the start; the dot before 'json' is escaped so only
    # real '.json' names match.
    re_filename = re.compile(r'training-stats.*\.json')

    all_training_stats = []

    for run_folder in _get_run_folders(TASK, 'models'):
        run_name = os.path.basename(run_folder)
        if run_name == 'debug':
            continue

        for filename in os.listdir(run_folder):
            if not re_filename.match(filename):
                continue

            filepath = os.path.join(run_folder, filename)
            if not os.path.isfile(filepath):
                continue

            with open(filepath, 'r') as f:
                training_stats_original_dict = json.load(f)

            # Unwrap nested dicts and map legacy epoch keys
            training_stats = {}
            for key, value in training_stats_original_dict.items():
                if isinstance(value, dict):
                    training_stats.update(value)
                elif key in ('n_epochs', 'epochs'):
                    training_stats['final_epoch'] = value
                    training_stats['current_epoch'] = value
                else:
                    training_stats[key] = value

            secs_per_epoch = training_stats['secs_per_epoch']
            n_epochs = training_stats['current_epoch'] - training_stats['initial_epoch']
            training_stats['total_time'] = duration_to_str(secs_per_epoch * n_epochs)
            training_stats['time_per_epoch'] = duration_to_str(secs_per_epoch)
            training_stats['run_name'] = run_name

            all_training_stats.append(training_stats)

    df = pd.DataFrame(all_training_stats)
    if df.empty:
        # pd.DataFrame([]) has no 'run_name' column; df[cols] would KeyError.
        return pd.DataFrame(columns=['run_name'])

    cols = ['run_name'] + [c for c in df.columns if c != 'run_name']
    return df[cols]

### Filter fns

In [None]:
def _filter_df_run_name_contains(df, contains):
    if contains:
        filter_contains = lambda d, s: d.loc[d['run_name'].str.contains(s)]
        if isinstance(contains, (list, tuple)):
            for c in contains:
                df = filter_contains(df, c)
        elif isinstance(contains, str):
            df = filter_contains(df, contains)
    return df

def __rename_run_name(run_name, replace_strs):
    s = run_name
    for target, replace_with in replace_strs:
        s = re.sub(target, replace_with, s)
    return s

def _df_rename_runs(df, rename_runs):
    if rename_runs and 'run_name' in df:
        df['run_name'] = [
            __rename_run_name(r, rename_runs)
            for r in df['run_name']
        ]
    return df

In [None]:
def get_renamer(replace_strs):
    """Build a one-argument function applying *replace_strs* to a string.

    *replace_strs* is a sequence of ``(pattern, replacement)`` pairs passed to
    re.sub in order; the sequence is looked up at call time, not copied.
    """
    def _rename_run(run_name):
        result = run_name
        for pattern, replacement in replace_strs:
            result = re.sub(pattern, replacement, result)
        return result

    return _rename_run

In [None]:
def filter_results(dataset_type=None, metrics=None,
                   metrics_contain=None, free=None,
                   contains=None, doesnt_contain=None,
                   drop=None, drop_na_rows=False, drop_key_cols=False,
                   timestamp_col=False,
                   rename_runs=None, remove_timestamp=False,
                  ):
    """Return a filtered / reshaped copy of the global RESULTS_DF.

    Args:
        dataset_type: keep rows whose ``dataset_type`` equals this str, or is
            in this list/tuple.
        metrics: list of metric columns to keep (KEY_COLS are always kept).
        metrics_contain: keep metric columns whose name contains this
            substring; takes precedence over ``metrics``.
        free: if not None, keep rows whose ``free`` column is ``'free'``
            (True) or ``'notfree'`` (False).
        contains: regex, or list/tuple of regexes, that ``run_name`` must
            contain a match for (see _filter_df_run_name_contains).
        doesnt_contain: regex, or list/tuple of regexes, that ``run_name``
            must NOT contain a match for.
        drop: regex; rows whose ``run_name`` contains a match are dropped.
        drop_na_rows: drop rows with any NaN.
        drop_key_cols: drop KEY_COLS except ``run_name`` (and ``timestamp``
            when ``timestamp_col`` is true).
        rename_runs: ``(pattern, replacement)`` pairs applied to ``run_name``.
        remove_timestamp: strip a leading ``dddd_dddddd_`` from string cells.

    Columns that are entirely NaN are always dropped. RESULTS_DF itself is
    never modified.
    """
    df = RESULTS_DF

    if dataset_type:
        if isinstance(dataset_type, str):
            df = df[df['dataset_type'] == dataset_type]
        elif isinstance(dataset_type, (list, tuple)):
            df = df[df['dataset_type'].isin(set(dataset_type))]

    if free is not None:
        free_str = 'free' if free else 'notfree'
        df = df.loc[df['free'] == free_str]

    df = _filter_df_run_name_contains(df, contains)

    if doesnt_contain:
        if isinstance(doesnt_contain, str):
            doesnt_contain = [doesnt_contain]
        if isinstance(doesnt_contain, (list, tuple)):
            for pattern in doesnt_contain:
                df = df.loc[~df['run_name'].str.contains(pattern)]

    if drop:
        df = df.loc[~df['run_name'].str.contains(drop)]

    if metrics_contain:
        df = df[KEY_COLS + [c for c in df.columns if metrics_contain in c]]
    elif metrics:
        # list() so a tuple of metrics works too
        df = df[KEY_COLS + list(metrics)]

    # Reassign instead of inplace=True: when no filter above applied, df is
    # still the global RESULTS_DF and an inplace drop would mutate it.
    if drop_na_rows:
        df = df.dropna(axis=0, how='any')

    # Drop cols with all na
    df = df.dropna(axis=1, how='all')

    if drop_key_cols:
        df = df[[
            c for c in df.columns
            if c == 'run_name' or (timestamp_col and c == 'timestamp') or c not in KEY_COLS
        ]]

    if rename_runs:
        # _df_rename_runs assigns into its argument; give it a private copy.
        df = _df_rename_runs(df.copy(), rename_runs)

    if remove_timestamp:
        df = df.replace(r'^\d{4}_\d{6}_', '', regex=True)

    return df

In [None]:
def filter_training_stats(contains=None, columns=None,
                          rename_runs=None, remove_timestamp=False,
                         ):
    """Return a filtered copy of the global TRAINING_STATS_DF.

    Args:
        contains: regex, or list/tuple of regexes, that ``run_name`` must
            contain a match for.
        columns: if given, the columns to return (in this order).
        rename_runs: ``(pattern, replacement)`` pairs applied to ``run_name``.
        remove_timestamp: strip a leading ``dddd_dddddd_`` from string cells.

    TRAINING_STATS_DF itself is never modified.
    """
    df = _filter_df_run_name_contains(TRAINING_STATS_DF, contains)

    if rename_runs:
        # With contains=None, df IS the global frame; _df_rename_runs assigns
        # into its argument, so hand it a copy.
        df = _df_rename_runs(df.copy(), rename_runs)

    if remove_timestamp:
        df = df.replace(r'^\d{4}_\d{6}_', '', regex=True)

    if columns is not None:
        df = df[columns]
    return df

## Load results

In [None]:
# Build the global tables read by filter_results / filter_training_stats.
# `debug` receives load_results' dict of per-metric-type DataFrames.
RESULTS_DF, debug = load_results()
print(len(RESULTS_DF))
# RESULTS_DF.head(2)

In [None]:
TRAINING_STATS_DF = load_training_stats()
print(len(TRAINING_STATS_DF))
# TRAINING_STATS_DF.tail(2)

In [None]:
# Exploration: distinct column-name prefixes of RESULTS_DF.
# set(
#     col.replace('-', '_').split('_')[0]
#     for col in RESULTS_DF.columns
# )

## Segmentation

In [None]:
def add_macro_avg_column(target_col, df=None):
    """Add column *target_col* = row-wise mean of the ``<target_col>-*`` columns.

    Modifies *df* in place; *df* defaults to the global RESULTS_DF.

    Raises:
        ValueError: unless exactly 3 columns named ``<target_col>-<something>``
            exist (one per organ: heart, left lung, right lung).
    """
    if df is None:
        df = RESULTS_DF
    # Match '<target>-<organ>' only: a bare startswith(target_col) also picks
    # up the macro column itself, so re-running the cell used to fail.
    prefix = f'{target_col}-'
    matching_cols = [c for c in df.columns if c.startswith(prefix)]
    if len(matching_cols) != 3:
        raise ValueError(f'Matching cols not 3: {matching_cols}')
    df[target_col] = df[matching_cols].mean(axis=1)
    print(f'Calculated col {target_col}')

In [None]:
# Per-organ counts are macro-averaged so the table below can be sorted on them.
add_macro_avg_column('n-shapes-gen')
add_macro_avg_column('n-holes-gen')

In [None]:
# Segmentation columns to display: for each metric, the macro column (when
# macro=True) followed by one '<metric>-<organ>' column per organ.
SEG_METRICS = []
organs = ('heart', 'left lung', 'right lung')
def _add_metric(metric_name, macro=True):
    if macro: SEG_METRICS.append(metric_name)
    SEG_METRICS.extend(f'{metric_name}-{organ}' for organ in organs)
_add_metric('iou')
# _add_metric('dice')
_add_metric('n-shapes-gen')
_add_metric('n-holes-gen')
SEG_METRICS

In [None]:
# (pattern, replacement) regex pairs applied to run names for display.
replace_strs = [
    # (r'^\d{4}_\d{6}_', ''),
    (r'jsrt_scan_', ''),
#     ('most-similar-image', '1nn'),
#     ('_lr[\d\.]+', ''),
#     ('_size256', ''),
#     (r'_\d{4}_\d{6}_.*', ''),
#     ('dummy-', ''),
#     ('common', 'top'),
#     ('-v2', ''),
#     (r'top-(\w)\w+-(\d+)', r'top-\1-\2'),
#     ('_densenet-121', ''),
]

In [None]:
# Test-set segmentation results (run 1105_180035 excluded), sorted ascending
# by the macro-averaged shape/hole counts.
filter_results(
    metrics=SEG_METRICS,
    dataset_type='test',
    drop='1105_180035',
    rename_runs=replace_strs,
).sort_values(
    ['n-shapes-gen', 'n-holes-gen'],
    ascending=True,
).set_index('run_name')

## Report generation

In [None]:
# NLP_METRICS = ['bleu1', 'bleu2', 'bleu3', 'bleu4', 'bleu', 'rougeL', 'ciderD']
# CHEXPERT_METRICS = ['recall', 'prec', 'f1'] # 'acc', 'roc_auc', 
CHEXPERT_DISEASE_METRICS = [
    c
    for c in RESULTS_DF.columns
    if any(
        c.startswith(f'{ch}-')
        for ch in ('f1', 'recall', 'prec')
    ) and not c.endswith('-woNF')
]
# CHEXPERT_RUNTIME_METRICS = [col for col in RESULTS_DF.columns if col.startswith('chex')]
# VAR_METRICS = [c for c in RESULTS_DF.columns if 'distinct' in c]
# MIRQI_METRICS = [c for c in RESULTS_DF.columns if 'MIRQI' in c]
MIRQI_METRICS_v1 = ['MIRQI-f', 'MIRQI-p', 'MIRQI-r']

In [None]:
NLP_METRICS = [
    'bleu', 'bleu1', 'bleu2', 'bleu3', 'bleu4',
    'rougeL', 'ciderD',
]
ESSENTIAL_METRICS = [
    'bleu', 'rougeL', 'ciderD',
    # 'chex_f1', 'chex_acc', # 'chex_recall', 'chex_prec', # Runtime-chexpert
    # 'MIRQI-v2-f',
    # 
    # Holistic-chexpert:
    # 'acc',
    'f1', 'prec', 'recall',

    # woNF:
    # 'f1-woNF', 'prec-woNF', 'recall-woNF',
    # 'pr_auc', 'pr_auc-woNF',
    # 'acc',
    # *CHEXPERT_DISEASE_METRICS,
    *MIRQI_METRICS_v1,
]

### Main table

In [None]:
# (pattern, replacement) regex pairs that shorten run names in the main table.
rename_runs = [
    # (r'_precnn-\d{4}-\d{6}', ''),
    ('mimic-cxr_', ''),
    ('iu-x-ray_', ''),
    # ('most-similar-image', '1nn'),
    (r'_lr(-\w+)?[\d\.e\-]+', ''),
    # (r'_lr[\d\.]+', ''),
    ('_size256', ''),
    # ('-v2', ''),
    ('_front', ''),
    (r'__[\w\-]*', ''),
    (r'_(pre)?cnn\-\d{4}\-\d{6}', ''),
    ('_densenet-121-v2', ''),
]

In [None]:
# IU=True builds the IU X-ray table, IU=False the MIMIC-CXR one.
IU = False
# MICCAI experiments:
# iu lstm 0612_035549, best-bleu: 0621_134437
# mimic lstm best-bleu: 0621_231122
# iu h-coatt: 0623_120544|0623_110053
# mimic h-coatt: 0623_192208
# CONTAINS = (dataset regex, run-selection regex); filter_results keeps runs
# whose name contains a match for both.
CONTAINS = \
    ('iu-x-ray', r'(paper|(0612_035549|0623_202003|dummy|tpl.*-ordbest-v2.*0611-155356).*v4-1)') \
    if IU else \
    ('mimic-cxr',
     r'06.*dummy-m|0617_144209|0623_103308|0625_184437|0612_233628|tpl-(chex-v1|m-chex-grouped-v6)-ordbest-v2.*cnn-0612-082139|paper',
    )

# Main table: test split, 'free' outputs, ESSENTIAL_METRICS only, keyed by
# run_name (the timestamp column is kept alongside).
res = filter_results(
    # contains=('iu-x-ray', '_lstm-att-v2.*hs\-512.*_front'),
    # contains=('mimic-cxr', r'tpl-(chex-v1-ordbest|m-chex-grouped-v6)'),
    
    # H-coatt models
    # contains=('iu-x-ray', 'h-coatt.*v4-1.*mti|paper_coatt'), # __og2
    # contains=('mimic-cxr', 'h-coatt'),
    
    # MICCAI template experiments:
    # contains=('iu-x-ray', 'chex-v1|chex-v2-grouped', '0611.155356', 'v4-1'), # 0611.162006

    contains=CONTAINS,
    doesnt_contain=(
        'dummy-baseline', 'dummy-common', '_satt', '_ssent', '_COPY', 'tiny',
        'boag-et-al-1nn', 'liu-et-al-ccr', 'tienet', 'rtmic',
        'most-similar-image_0519-181554', 'cls-seg', 'noisy',
        # 'constant-mimic',
        're-impl',
    ),
    dataset_type='test',
    free=True,
    metrics=ESSENTIAL_METRICS,
    drop_key_cols=True,
    timestamp_col=True,
    # drop_na_rows=True,
    rename_runs=rename_runs,
    remove_timestamp=True,
)
res = res.set_index('run_name').sort_index() # .sort_values('f1', ascending=False)
res

### Main-table to latex

In [None]:
def bold(s):
    r"""Return *s* wrapped in the LaTeX bold command ``\textbf{...}``."""
    # Raw string: in a plain literal '\t' is a TAB, which produced
    # '<TAB>extbf{...}' instead of '\textbf{...}'.
    return r'\textbf{' + s + '}'

# Abbreviations applied to metric column names in the LaTeX header.
shorten_cols = get_renamer([
    ('-woNF', '-d'),
    ('ciderD', 'C-D'),
    ('bleu', 'B'),
    ('rougeL', 'R-L'),
    ('acc', 'Acc'),
    ('prec', 'P'),
    ('recall', 'R'),
    ('f1', 'F-1'),
    ('MIRQI-f', 'M-F-1'),
    ('MIRQI-r', 'M-R'),
    ('MIRQI-p', 'M-P'),
])
def latexify_cols(col):
    """Return the shortened, bold LaTeX header for metric column *col*."""
    return bold(shorten_cols(col))

# Maps internal run names to the model names printed in the paper table.
# The replacements are re.sub *templates*: a literal backslash must be written
# as '\\' inside a raw string. A single '\c' raises re.error "bad escape"
# (Python >= 3.7, raised for every input since re compiles the template
# eagerly), and a non-raw '\r' is a carriage return.
get_official_run_name = get_renamer([
    # All trained models
    ('_reports-v4-1', ''),
    (r'_(cnn-)?\d{4}-\d{6}', ''),
    ('_densenet-121', ''),
    # Dummy models
    (r'most-similar-image', '1-nn'),
    ('dummy-', ''),
    ('common-', 'top-'),
    ('constant-.*', 'Constant'),
    ('random', 'Random'),
    # DL models
    ('lstm-att.*', 'CNN-LSTM-att'),
    # Template models
    ('tpl', 'Templ.'),
    (r'-chex-v1-ordbest-v2.*', ' single'),
    # (r'-chex-v1-noisy.*', ' top-char.'),
    (r'-chex-v2-grouped-ordbest-v2', ' grouped'),
    (r'-m-chex-grouped-v6-ordbest-v2', ' grouped'),
    (r'-ord\w+', ''),
    ('h-coatt.*', r'CoAtt\\reimplemented{}\\cite{jing2017automatic}'),
    # Papers
    ('paper_', ''),
    ('rtex', r'RTEX \\cite{kougia2021rtex}'),
    ('zhang-et-al-mirqi', r'Zhang et al. \\findingsAndImpression{}\\cite{zhang2020graph}'),
    ('lovelace-et-al', r'Lovelace et al. \\cite{lovelace2020learning}'),
    ('liu-et-al-full', r'Liu et al. \\cite{liu2019clinically}'),
    ('boag-et-al-1nn', r'Boag et al. (1-nn) \\cite{boag2020baselines}'),
    ('boag-et-al-cnn-rnn-beam', r'Boag et al. \\cite{boag2020baselines}'),
    ('chen-et-al', r'Chen et al. \\cite{chen2020memory}'),
    ('clara', r'CLARA \\cite{biswal2020clara}'),
    ('coatt', r'CoAtt \\findingsAndImpression{}\\cite{jing2017automatic}'),
    ('ni-et-al', r'CVSE \\cite{ni2020embeddings}'),
    ('hrgr', r'HRGR \\cite{li2018hybrid}'),
    ('kerp', r'KERP \\cite{li2019knowledge}'),
    ('syeda-et-al', r'S-M et al. \\findingsAndImpression{}\\cite{syeda2020chest}'),
    # ('-mirqi', ''),
    # (r'(\w+)-et-al', r'\1 et al.'),
])

In [None]:
def bold_best_value_in_values(values):
    """Format numeric *values* as '%.3f' strings, bolding the maximum.

    NaN values become '-'. Every value whose 3-decimal formatting equals the
    formatted maximum is bolded, so values that tie after rounding are all
    bolded. Empty input returns [], all-NaN input returns all '-'.
    """
    # numpy is not among the notebook's visible imports; import it here.
    import numpy as np

    def formatter(x):
        return f'{x:.3f}'

    values = np.asarray(values, dtype=float)
    # Explicit NaN mask: the previous -1 sentinel also swallowed real -1 values.
    is_nan = np.isnan(values)
    if is_nan.all():
        return ['-'] * len(values)

    max_value = formatter(values[~is_nan].max())

    values_str = []
    for value, missing in zip(values, is_nan):
        if missing:
            values_str.append('-')
            continue
        value_s = formatter(value)
        values_str.append(bold(value_s) if value_s == max_value else value_s)
    return values_str

In [None]:
def bold_best_value_by_column(df):
    """Return a copy of *df* with every column formatted by bold_best_value_in_values.

    Each column's maximum is bolded and NaN becomes '-'; *df* is not modified.
    """
    # (dropped the unused METRICS_RANGE_100 local)
    formatted = df.copy()
    for col in df.columns:
        formatted[col] = bold_best_value_in_values(df[col].values)
    return formatted

In [None]:
def _rotated_multirow_args():
    """Return the ``{<n rows>}{<dataset>}`` arguments for the rotated multirow.

    The row count comes from the global ``res`` table and the dataset label
    from the global ``IU`` flag.
    """
    if IU:
        dataset_label = 'IU X-ray'
    else:
        dataset_label = 'MIMIC-CXR'
    return f'{{{len(res)}}}{{{dataset_label}}}'

In [None]:
# Main table -> LaTeX. 'timestamp' is bookkeeping only and is not printed.
metrics_table = res.drop(columns='timestamp') if 'timestamp' in res.columns else res
table = bold_best_value_by_column(metrics_table).rename(
    index=get_official_run_name,
    columns=latexify_cols,
).reset_index().rename(columns={'run_name': bold('Model')}).to_latex(
    float_format='%.3f',
    # 'l' for the dataset column prepended below, then one 'c' for Model and
    # one per metric -- independent of whether res carried 'timestamp'.
    column_format='l' + 'c' * (len(metrics_table.columns) + 1),
    na_rep='-',
    index=False,
    escape=False,
    # bold_rows=True,
)
table = re.sub(r' +', ' ', table, flags=re.M)
# Add this additional column for the dataset (IU or MIMIC)
table = re.sub(r'^ +', '& ', table, flags=re.M)
# Insert a \cline over columns 2-11 before rows starting with these names.
# NOTE(review): get_official_run_name produces 'Templ. single', so the
# 'Templ\. simple' alternative never matches -- confirm which is intended.
table = re.sub(
    r'^\& (CLARA|Templ\. simple|Boag)',
    # '\\cline' is escaped: a bare '\c' in a replacement template is a
    # "bad escape" re.error since Python 3.7.
    r'\\cline{2-11}\n& \1', table, flags=re.M,
)
table = re.sub(
    r'^\\midrule',
    r'\\midrule\n\\rotatedMultirow' + _rotated_multirow_args(),
    table, flags=re.M,
)
print(table)

### Chexpert by disease table

In [None]:
def bold_best_value_by_row(df):
    """Return a copy of *df* with every row formatted by bold_best_value_in_values.

    Each row's maximum is bolded and NaN becomes '-'; *df* is not modified.
    """
    # Cast to object first: writing strings into float columns row by row
    # triggers pandas' incompatible-dtype FutureWarning (an error in pandas 3).
    formatted = df.astype(object)
    # Positional access so duplicate index labels are handled row by row.
    for pos, row_values in enumerate(df.to_numpy()):
        formatted.iloc[pos] = bold_best_value_in_values(row_values)
    return formatted

In [None]:
# Metric whose per-disease columns (plus its macro column) are tabulated below.
base = 'f1'
metrics = [c for c in CHEXPERT_DISEASE_METRICS if base in c] + [base]

In [None]:
# Run-name regex pairs used by the per-disease table.
rename_runs_2 = [
    ('iu-x-ray_', ''),
    ('mimic-cxr_', ''),
    ('_front', ''),
    ('tpl-chex-v1-grouped-ordbest_cnn-0611-155356_densenet-121-v2', 'densenet-121 + templates'),
    ('dummy-', ''),
    (r'_(precnn-)?\d{4}-\d{6}', ''),
    (r'_lr(-emb)?[\d\.]+', ''),
    (r'__\w+', ''),
    ('-v2', ''),
    ('_cnn-freeze', ''),
]

In [None]:
# Per-disease table: diseases as rows, models as columns, macro average last.
df = filter_results(
    contains=('mimic-cxr', r'0612_215504|0612_215709|0612_233628|paper_(boag|lovelace|ni)'),
    # contains=('iu-x-ray', r'dummy|tpl|__base|paper'),
    # contains=('iu-x-ray', r'0611_182321|0612_012900'), # 0612_012741
    # dummy-most|dummy-random|__freeze
    doesnt_contain=('dummy-baseline', '_satt', '_ssent', '_COPY', 'tiny'),
    dataset_type='test',
    free=True,
    metrics=metrics,
    rename_runs=rename_runs_2,
    drop_key_cols=True,
    # timestamp_col=True,
    # drop_na_rows=True,
    remove_timestamp=True,
).set_index('run_name').sort_index().transpose().rename(index={base: f'{base}-macro'})
df = df.rename(
    columns=get_official_run_name,
    index=get_renamer([
        (r'{}-macro'.format(base), 'Macro average'),
        (r'{}-(\w+)'.format(base), r'\1'),
    ])
)
df = bold_best_value_by_row(df)
df.columns.rename(f'{base.capitalize()} by disease', inplace=True)
table = df.to_latex(
    float_format='%.3f',
    # 'l' for the disease (index) column plus one 'c' per model column of
    # THIS table (not of the main-table `res`).
    column_format='l' + 'c' * len(df.columns),
    # na_rep='-',
    # index=False,
    escape=False,
)
table = re.sub(r' +', ' ', table, flags=re.M)
table = re.sub(
    r'^(Macro)',
    # '\\midrule' is escaped: a bare '\m' in a replacement template is a
    # "bad escape" re.error since Python 3.7.
    r'\\midrule\n\1', table, flags=re.M,
)
print(table)

### Training stats

In [None]:
# Run-name regex pairs for the training-stats table below.
replace_strs = [
    (r'_precnn-\d{4}-\d{6}', ''),
    (r'_lr[\d\.]+', ''),
    (r'_lr-emb[\d\.]+', ''),
    ('_size256', ''),
    ('-v2', ''),
    ('_front', ''),
    (r'__[\w\-]*', ''),
]

In [None]:
# Timing / setup summary of the '__base' runs.
cols = [
    'run_name',
    'time_per_epoch', 'total_time',
    'current_epoch', 'final_epoch',
    'batch_size', 'device', 'visible',
]
res = filter_training_stats(
    contains='__base',
    columns=cols,
    rename_runs=replace_strs,
)
res = res.replace(r'^\d{4}_\d{6}_(.*)', r'\1', regex=True)
# DataFrame.rename needs a mapping or a callable; rename_runs is a list of
# (pattern, replacement) pairs (TypeError when passed directly), so wrap it.
res = res.set_index('run_name').rename(index=get_renamer(rename_runs))
res.sort_index()

### Compare runtime chexpert vs holistic chexpert

In [None]:
def subtract_cols(df, cols_a, cols_b, drop_na_rows=True, key_cols=None):
    """Return the key columns of *df* plus the element-wise ``cols_a - cols_b``.

    Columns are paired positionally (``cols_a[i] - cols_b[i]``) and the
    difference columns are named after *cols_a*.

    Args:
        df: source DataFrame (not modified).
        cols_a, cols_b: equally long lists of column names.
        drop_na_rows: drop rows with any NaN (key or difference).
        key_cols: identifying columns copied to the output; defaults to the
            global KEY_COLS.
    """
    if key_cols is None:
        key_cols = KEY_COLS
    differences = pd.DataFrame(
        df[cols_a].to_numpy() - df[cols_b].to_numpy(),
        columns=cols_a,
        # Keep df's index: with a default RangeIndex, concat(axis=1) misaligned
        # rows whenever df had been filtered (non-contiguous index).
        index=df.index,
    )
    result = pd.concat([df[key_cols].copy(), differences], axis=1)
    if drop_na_rows:
        result = result.dropna(axis=0, how='any')
    return result

In [None]:
# Runtime CheXpert columns ('chex_<metric>*') vs holistic ones ('<metric>*').
metric = 'f1'

runtime_chexpert = [c for c in RESULTS_DF.columns if c.startswith(f'chex_{metric}')]
holistic_chexpert = [c for c in RESULTS_DF.columns if c.startswith(metric)]

In [None]:
# Exclude dummy baselines.
df = RESULTS_DF
df = df.loc[~df['run_name'].str.contains('dummy')]
len(df)

In [None]:
set(df['run_name'])

In [None]:
# NOTE(review): subtract_cols pairs columns positionally, so both lists must
# contain matching metrics in the same order -- verify before trusting this.
df = subtract_cols(df, runtime_chexpert, holistic_chexpert)
df.head()

In [None]:
df.describe()

In [None]:
df

In [None]:
from collections import Counter

In [None]:
# Sanity check: compare the generated outputs of one run for free=True vs
# free=False. NOTE: this rebinds `debug`, which previously held load_results'
# per-metric-type dict.
run_name = '0112_154506_lstm-v2_lr0.001_densenet-121-v2_noes'
debug = False
d1 = load_rg_outputs(run_name, debug=debug, free=True)
d2 = load_rg_outputs(run_name, debug=debug, free=False)
len(d1), len(d2)

In [None]:
# Number of rows per filename in each output.
c1 = Counter(d1['filename'])
c2 = Counter(d2['filename'])
len(c1), len(c2)

In [None]:
# Report filenames whose counts differ (only filenames present in d1 are checked).
for fname in c1.keys():
    v1 = c1[fname]
    v2 = c2[fname]
    if v1 != v2:
        print('Wrong: ', fname, v1, v2)

In [None]:
d2.head()

In [None]:
set(d2['dataset_type'])

### Pretty-print (latex)

In [None]:
# Run-name regex pairs for the (older) pretty-printed LaTeX table.
replace_strs = [
    (r'^\d{4}_\d{6}_', ''),
    ('most-similar-image', '1nn'),
    ('_lr[\d\.]+', ''),
    ('_size256', ''),
    (r'_\d{4}_\d{6}_.*', ''),
    ('dummy-', ''),
    ('common', 'top'),
    ('-v2', ''),
    (r'top-(\w)\w+-(\d+)', r'top-\1-\2'),
    ('_densenet-121', ''),
]

In [None]:
# NOTE(review): CHEXPERT_METRICS and MIRQI_METRICS are only defined in
# commented-out lines of the metric-definition cell, so this raises NameError.
columns = ['bleu', 'rougeL', 'ciderD'] + CHEXPERT_METRICS + MIRQI_METRICS

In [None]:
df = filter_results(dataset_type='test',
                    free=True,
                    metrics=columns,
                    contains='(?=_lstm-att-v2.*densenet|_lstm-v2.*densenet|dummy)',
                    drop='0915_173951|0915_174222|0916_104739',
                    drop_na_rows=True,
                    rename_runs=replace_strs,
                   )
df

In [None]:
# NOTE(review): rebinding shorten_cols also changes latexify_cols, which looks
# it up globally, if the main-table cell is re-run afterwards.
shorten_cols = lambda s: s.replace('MIRQI-v2', 'v2')

In [None]:
# NOTE(review): rename_runs is a list of (pattern, replacement) pairs, which
# DataFrame.rename cannot use as a mapper; wrap it with get_renamer(...) if
# this cell is revived.
print(df.set_index('run_name').rename(
    index=rename_runs,
    columns=shorten_cols,
).sort_index().to_latex(
    columns=[shorten_cols(c) for c in columns],
    float_format='%.3f',
    column_format='l' + 'c' * len(columns),
))

## Classification

### Check results

In [None]:
# Candidate run-name filters (kept for reference).
# contains = 'covid-x'
# contains = 'cxr14'
# contains = 'e0'
# contains = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid'
# contains = '0717_101812_covid-x_densenet-121_lr1e-06_os-max2_aug-covid'
# run_name = '0717_120222_covid-x_densenet-121_lr1e-06_os_aug-covid' # WINNER

# contains = '0717_101812_covid-x_densenet-121_lr1e-06_os-max2_aug-covid'
# contains = 'covid-uc'

In [None]:
# Classification metric columns to display.
metrics = [
    'roc_auc', 'pr_auc', # 'hamming', #
]
# metrics = [
#     'acc', 'roc_auc', 'prec', 'recall', 'roc_auc_Cardiomegaly', 'roc_auc_Pneumonia',
#     'recall_Cardiomegaly', 'recall_Pneumonia',
#     'iobb-masks', 'iobb-masks-Cardiomegaly', 'iobb-masks-Pneumonia',
# ]

In [None]:
# Run-name regex pairs for classification runs.
replace_strs = [
    # (r'^\d{4}_\d{6}_', ''),
    # (r'_precnn-\d{4}-\d{6}', ''),
    (r'_lr[e\-\d\.]+', ''),
    # (r'(cxr14|chexpert|iu-x-ray)_', ''),
    ('_size256', ''),
    (r'_cl-wbce_seg-w', ''),
    (r'_seg-unw', ''),
    # (r'_aug\d-(touch|double|single)', ''),
    ('_shuffle', ''),
    ('_sch-(roc|pr)[\-_]auc-p\d-f0.5(-c\d)?', ''),
    ('_best-(roc|pr)[\-_]auc', ''),
    ('_norm[SD]', ''),
    ('_labels13', ''),
    # ('_front', ''),
    # (r'__[\w\-]*', ''),
]

In [None]:
# CONTAINS = r'cxr14.*(?:small|tiny)|0402_062551'
CONTAINS = 'cxr14_densenet-121'
# CONTAINS = 'chexpert_densenet-121'
# CONTAINS = r'chexpert' # .*(?:small|tiny)
# CONTAINS = r'iu-x-ray.*(?:tiny)|0420_175514'
# CONTAINS = r'iu-x-ray.*(?:small)|0420_175514'
# CONTAINS = r'iu-x-ray_densenet-121' # 0420_175514

# CheXpert runs are evaluated on the 'val' split, all others on 'test'.
DATASET_TYPE = 'val' if 'chex' in CONTAINS else 'test'

# Classification results sorted by PR-AUC (best first).
d = filter_results(
    contains=CONTAINS,
    doesnt_contain=['hint', 'balance', 'Cardiomeg', 'Pneumonia'],
    dataset_type=DATASET_TYPE,
    metrics=metrics,
    drop_key_cols=True,
    # rename_runs=replace_strs,
).sort_values('pr_auc', ascending=False)
d.set_index('run_name')

In [None]:
# NOTE(review): `meta` is only defined when the load_metadata line below is
# uncommented; as written this raises NameError.
# meta = load_metadata(RunId('0406_230221', False, 'cls'))
meta['hparams']

### Check training stats

In [None]:
# Run-name regex pairs for the classification training-stats table.
replace_strs = [
    (r'_lr[e\-\d\.]+', ''),
    # (r'(cxr14|chexpert|iu-x-ray)_', ''),
    (r'_pre\d{4}-\d{6}', ''),
    ('_size256', ''),
    (r'_cl-wbce_seg-w', ''),
    (r'_seg-unw', ''),
    (r'_aug\d-(touch|double|single)', ''),
    ('_shuffle', ''),
    ('_sch-(roc|pr)[\-_]auc-p\d-f0.5(-c\d)?', ''),
    ('_best-(roc|pr)[\-_]auc', ''),
    ('_norm[SD]', ''),
    ('_labels13', ''),
]

In [None]:
# Timing / setup summary of the cxr14 and chexpert classification runs.
cols = [
    'run_name',
    'time_per_epoch', 'total_time',
    'current_epoch', 'initial_epoch', 'final_epoch',
    'batch_size', 'visible',
]
res = filter_training_stats(
    contains=r'cxr14|chexpert',
    columns=cols,
    rename_runs=replace_strs,
)
# res = res.replace(r'^\d{4}_\d{6}_(.*)', r'\1', regex=True)
# DataFrame.rename needs a mapping or a callable; rename_runs is a list of
# (pattern, replacement) pairs (TypeError when passed directly), so wrap it.
res = res.set_index('run_name').rename(index=get_renamer(rename_runs))
res.sort_index()

## Report-generation: results at different report lengths

In [None]:
# Report-length limits to evaluate (max words / max sentences); None
# presumably means "no limit" (it yields an empty suffix below).
vals_words = [20, 25, 27, 33, 44, None]
vals_sents = [3, 4, 5, 6, None]

In [None]:
# NOTE(review): load_results() in this notebook takes no arguments, and
# create_results_df is not defined in the visible code -- this cell looks
# stale and raises TypeError as written.
max_words = vals_words[0]
suffix = f'max-words-{max_words}' if max_words else ''
all_results = load_results(suffix)
results_df_test = create_results_df(all_results, 'test')
results_df_test