# Complete Meta-Learning Distribution

This notebook generates the complete meta-learning distribution for MetaChest.

In [None]:
from os import makedirs
from os.path import join

import pandas as pd
import seaborn as sns

from common import MCLASSES, filter_msets, read_toml, save_toml

## Generate

In [None]:
# Name of this distribution; used as the stem of both output files.
name = 'complete'
mclasses = MCLASSES
metachest_dir = read_toml('config.toml')['metachest_dir']
mtl_dir = join(metachest_dir, 'mtl')

df = pd.read_csv(join(metachest_dir, 'metachest.csv'))

# One row per sample in `df`, one column per meta-set, every cell set to 1.
# The row width follows `mclasses` instead of a hard-coded three, so the
# constructor keeps working if a meta-set is added or removed.
filter_df = pd.DataFrame(
    [[1] * len(mclasses)] * len(df),
    columns=list(mclasses.keys()),
)
# `filter_msets` comes from `common`; its return value is not used here.
# NOTE(review): presumably it updates `filter_df` in place - confirm.
filter_msets(df, filter_df, mclasses)

# Write the meta-set definition and the filter next to each other.
makedirs(mtl_dir, exist_ok=True)
save_toml(join(mtl_dir, f'{name}.toml'), mclasses)
filter_df.to_csv(join(mtl_dir, f'{name}.csv'), index=False)

## Explore

In [None]:
# Total number of rows in the loaded table, then the row count per dataset.
print(f'Total {df.shape[0]}')
print(df['dataset'].value_counts())

Compute the per-dataset sums of every pathology column, plus a copy with an extra `total` row and column:

In [None]:
# Sum every pathology column within each dataset.
# NOTE(review): assumes the pathology columns start at position 5 - confirm against metachest.csv.
paths = df.columns[5:].tolist()
ds_sum_df = df[['dataset', *paths]].groupby('dataset').sum().astype(int)

# Row totals (per dataset) and column totals (per pathology), largest first.
ds_sum_sr = ds_sum_df.sum(axis=1).sort_values(ascending=False)
pt_sum_sr = ds_sum_df.sum(axis=0).sort_values(ascending=False)

# Reorder the rows, then the columns, to follow those descending totals.
ds_sum_df = ds_sum_df.reindex(ds_sum_sr.index.tolist())
ds_sum_df = ds_sum_df[pt_sum_sr.index.tolist()]

# Same table with an extra 'total' column and an extra 'total' row.
total_df = ds_sum_df.assign(total=ds_sum_df.sum(axis=1))
total_df.loc['total', :] = total_df.sum(axis=0)
# Adding the row upcasts the values, so cast back to integers.
total_df = total_df.astype(int)

In [None]:
# Stacked horizontal bars: one bar per pathology column of ds_sum_df,
# one colored segment per dataset row.
ds_sum_df.T.plot.barh(stacked=True, figsize=(10, 5),
                      color=sns.color_palette('deep'))

In [None]:
# Transpose so the pathology columns (and 'total') become rows, then reverse the row order.
total_df.T.iloc[::-1]

### Partition

In [None]:
# Class lists of the three meta-sets, in the order stored in MCLASSES.
mtrn, mval, mtst = MCLASSES.values()
# Rows: the classes of all three meta-sets, concatenated in that order.
# Columns: the datasets plus 'total'.
mset = total_df[[*mtrn, *mval, *mtst]].T
# Tag every row with the meta-set its class belongs to.
mset_vals = [
    tag
    for tag, classes in (('mtrn', mtrn), ('mval', mval), ('mtst', mtst))
    for _ in classes
]
mset.insert(0, 'mset', mset_vals)
mset

In [None]:
# Sum each column within a meta-set; sort=False keeps the groups in order of first appearance.
mset.groupby('mset', sort=False).sum()

In [None]:
# Display the per-dataset sums, with rows and columns ordered by descending total.
ds_sum_df

In [None]:
# Stacked horizontal bars of ds_sum_df: one bar per pathology, one segment per dataset.
# NOTE(review): this repeats the chart already drawn in the Explore section
# with the same call and arguments - consider removing one of the two cells.
ds_sum_df.T.plot.barh(stacked=True, figsize=(10, 5),
                      color=sns.color_palette('deep'))

In [None]:
# Column-wise sum over all datasets: one total per pathology column
# (the same totals held in pt_sum_sr).
ds_sum_df.sum(axis=0)

In [None]:
import matplotlib.pyplot as plt

def plot_metasets(ds, mclasses, figsize=(8, 5),
                  xtick_step=10000, xtick_max=130000):
    """Plot one stacked horizontal bar chart per meta-set, in a single column.

    Each panel shows the classes of one meta-set as bars, with one colored
    segment per row of `ds`. Panel heights are proportional to the number of
    classes in the meta-set. X ticks are removed from every panel and then
    set, labelled in thousands, on the last panel drawn; the legend is kept
    only on the panel whose title contains 'Test'.

    Args:
        ds: DataFrame whose columns include every class named in `mclasses`
            and whose index holds the dataset identifiers.
        mclasses: Iterable of class-name lists, one per meta-set, matched in
            order to the meta-train, meta-val and meta-test titles. Any
            iterable is accepted (for example `dict.values()`). Only three
            titles are defined, so at most three meta-sets are drawn.
        figsize: Figure size passed to `plt.subplots`.
        xtick_step: Distance between consecutive x ticks on the last panel.
        xtick_max: Largest x tick (inclusive) on the last panel.

    Raises:
        ValueError: If `mclasses` is empty.
    """
    titles = ['Meta-Train (Seen) ', 'Meta-Val (Unseen) ', 'Meta-Test (Unseen) ']
    # Display names for the dataset identifiers; other identifiers are left unchanged.
    idxs = {'chexpert': 'CheXpert', 'mimic': 'MIMIC', 'chestxray14': 'ChestX-ray14', 'padchest': 'PadChest'}

    # Materialize once: the argument is measured and iterated several times.
    mclasses = list(mclasses)
    if not mclasses:
        raise ValueError('mclasses must contain at least one meta-set')

    # squeeze=False always yields a 2-D array of axes, so a single meta-set
    # no longer produces a bare (non-iterable) Axes object.
    fig, axs = plt.subplots(
        nrows=len(mclasses), ncols=1, tight_layout=True, squeeze=False,
        gridspec_kw={'height_ratios': [len(mset) for mset in mclasses]},
        figsize=figsize,
    )
    # NOTE(review): the panels do not share an x-axis, so the tick labels set
    # on the last panel describe only that panel's scale - confirm intended.
    for mset, title, ax in zip(mclasses, titles, axs[:, 0]):
        ds_mset = ds[mset]
        # 'some_class' -> 'Some class' for the bar labels.
        cols = {col: col.replace('_', ' ').capitalize() for col in ds_mset.columns}
        ds_mset = ds_mset.rename(columns=cols, index=idxs)
        ax = ds_mset.T.plot.barh(stacked=True, width=0.8, ax=ax,
                                 color=sns.color_palette('deep'))
        ax.set_xticks([], [])
        # Negative pad pulls the right-aligned title inside the plotting area.
        ax.set_title(title, fontsize=10, loc='right', y=1.0, pad=-14)
        if 'Test' in title:
            ax.legend(loc='lower right', fontsize='small', labelspacing=0.25)
        else:
            ax.get_legend().remove()
        ax.tick_params(axis='y', which='major', labelsize='small')

    # `ax` is the last panel drawn. Labels are in thousands: 10000 -> '10k'
    # (the previous `x//10000` rendered 10000 as '1k').
    xticks = list(range(xtick_step, xtick_max + 1, xtick_step))
    ax.set_xticks(xticks, [f'{x / 1000:g}k' for x in xticks], fontsize=7)

# Draw one panel per meta-set from the per-dataset sums.
plot_metasets(ds_sum_df, MCLASSES.values())