# Complete Meta-Learning Distribution

This notebook generates the complete meta-learning distribution for MetaChest.

In [None]:
import sys
from os.path import join

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sys.path.append('../')
from common import ALIASES, _filter_mset, plot_venn, plot_coocc, read_toml

## Generate

In [None]:
# Read the project configuration and collect the class list of each meta-set.
config = read_toml('../config.toml')
mclasses = {mset: config[mset] for mset in ('mtrn', 'mval', 'mtst')}
mclasses

In [None]:
# Load the MetaChest CSV from the directory named in the configuration.
metachest_csv_path = join(config['metachest_dir'], 'metachest.csv')
df = pd.read_csv(metachest_csv_path)
df.head()

## Exploration

In [None]:
# Overall row count, followed by the number of rows per source dataset.
print(f'Total {df.shape[0]}', df['dataset'].value_counts(), sep='\n')

Compute total dataframe:

In [None]:
# Per-dataset sums of every column from position 5 onward (the pathology labels).
paths = df.columns[5:].tolist()
ds_sum_df = df[['dataset', *paths]].groupby('dataset').sum().astype(int)

# Totals per dataset and per pathology as series, largest first.
ds_sum_sr = ds_sum_df.sum(axis=1).sort_values(ascending=False)
pt_sum_sr = ds_sum_df.sum(axis=0).sort_values(ascending=False)

# Reorder the rows and columns of the grouped frame to follow those rankings.
ds_sum_df = ds_sum_df.loc[ds_sum_sr.index, pt_sum_sr.index]

# Same table with an extra 'total' column and an extra 'total' row.
total_df = ds_sum_df.copy()
total_df['total'] = total_df.sum(axis=1)
total_df.loc['total', :] = total_df.sum(axis=0)
# cast back to integers after appending the row
total_df = total_df.astype(int)

In [None]:
# Stacked horizontal bars: one bar per pathology, one colour per dataset.
ds_sum_df.transpose().plot.barh(
    color=sns.color_palette('deep'),
    figsize=(10, 5),
    stacked=True,
)

In [None]:
# Totals table with pathologies as rows, displayed in reverse row order.
total_df.transpose().iloc[::-1]

### MTL Setup

In [None]:
# Class lists of the three meta-sets, in the order stored in `mclasses`.
mtrn, mval, mtst = mclasses.values()
# One meta-set tag per class, aligned with the concatenated class order below.
mset_vals = [
    tag
    for tag, classes in zip(('mtrn', 'mval', 'mtst'), (mtrn, mval, mtst))
    for _ in classes
]
# Classes as rows (columns come from `total_df`), with the tag as first column.
mset_df = total_df[[*mtrn, *mval, *mtst]].transpose()
mset_df.insert(loc=0, column='mset', value=mset_vals)
mset_df

In [None]:
# Sum the per-class rows within each meta-set, keeping first-seen group order.
mset_df.groupby(by='mset', sort=False).sum()

Distribution per meta-set

In [None]:
def plot_metasets(df, mclasses, figsize=(8, 5), xmax=130000, xstep=10000):
    """Plot one stacked horizontal bar chart per meta-set.

    Every subplot shows, for each pathology of one meta-set, the values of
    ``df`` stacked by dataset. Subplot heights are proportional to the
    number of pathologies in the corresponding meta-set, and only the
    bottom-most subplot carries labelled x ticks.

    Args:
        df: DataFrame indexed by dataset with one column per pathology.
        mclasses: Iterable with one list of pathology column names per
            meta-set, ordered meta-train, meta-val, meta-test (the subplot
            titles are assigned in that order).
        figsize: Figure size forwarded to ``plt.subplots``.
        xmax: Upper limit of the x range shared by all subplots.
        xstep: Spacing between the labelled x ticks.
    """
    titles = ['Meta-Train (Seen) ', 'Meta-Val (Unseen) ', 'Meta-Test (Unseen) ']
    idxs = {'chexpert': 'CheXpert', 'mimic': 'MIMIC',
            'chestxray14': 'ChestX-ray14', 'padchest': 'PadChest'}
    # Materialise once: the collection is traversed several times below.
    mclasses = list(mclasses)
    fig, axs = plt.subplots(
        nrows=len(mclasses), ncols=1, tight_layout=True,
        # squeeze=False always yields a 2-D array of axes, so a single
        # meta-set no longer returns a bare (non-iterable) Axes.
        squeeze=False,
        gridspec_kw={'height_ratios': [len(mset) for mset in mclasses]},
        figsize=figsize,
    )
    for mset, title, ax in zip(mclasses, titles, axs[:, 0]):
        ds_mset = df[mset]
        cols = {col: col.replace('_', ' ').capitalize() for col in ds_mset.columns}
        ds_mset = ds_mset.rename(columns=cols, index=idxs)
        ds_mset.T.plot.barh(stacked=True, width=0.8, ax=ax,
                            color=sns.color_palette('deep'))
        # Ticks are hidden on every subplot; the last one gets them back below.
        ax.set_xticks([])
        # The negative pad draws the title inside the axes area.
        ax.set_title(title, fontsize=10, loc='right', y=1.0, pad=-14)
        ax.set_xlim(0, xmax)
        # A single legend (on the meta-test subplot) serves the whole figure.
        if 'Test' in title:
            ax.legend(loc='lower right', fontsize='small', labelspacing=0.25)
        else:
            ax.get_legend().remove()
        ax.tick_params(axis='y', which='major', labelsize='small')
    # `ax` is the last subplot drawn by the loop; label its ticks in thousands
    # (the tick at 10000 reads '10k').
    xticks = list(range(xstep, xmax + 1, xstep))
    ax.set_xticks(xticks, [f'{x // 1000}k' for x in xticks], fontsize=7)

plot_metasets(ds_sum_df, mclasses.values())

Plot Venn diagrams and co-occurrence matrices per meta-set.

In [None]:
# Use a dedicated name for the filtered frame: reusing `mset_df` here would
# overwrite the summary table built in the MTL Setup section, so re-running
# the cells that display or aggregate it would operate on the wrong frame.
for mset in ('mtrn', 'mval', 'mtst'):
    mset_samples_df = _filter_mset(mset, mclasses, df)
    plot_venn(mset)
    # Columns from position 5 onward are passed as the label columns; this
    # assumes `_filter_mset` keeps the column layout of `df` (TODO confirm).
    plot_coocc('MetaChest', mset, mset_samples_df.iloc[:, 5:])

## Unused data 

In [None]:
def compute_unused():
    """Illustrate and count the rows labelled with both mval and mtst classes.

    Draws a schematic three-set Venn diagram (mval, mtst and mtrn classes)
    in which the regions shared by mval and mtst are shaded grey, then
    selects the rows of the global ``df`` that have at least one truthy
    mval class and at least one truthy mtst class.

    Prints the total number of selected rows and returns a Series with the
    number of selected rows per dataset.
    """
    from matplotlib_venn import venn3

    plt.figure(figsize=(2, 2))
    venn = venn3(
        (1, 1, 1, 1, 1, 1, 1),
        set_labels=('mval\nclasses', 'mtst\nclasses', 'mtrn\nclasses'),
    )

    # Region ids are membership bits in (mval, mtst, mtrn) order.
    shaded = ('110', '111')
    outlined = ('100', '010', '001', '111')
    for region in ('100', '010', '110', '001', '101', '011', '111'):
        venn.get_label_by_id(region).set_text('')
        patch = venn.get_patch_by_id(region)
        patch.set_color('grey' if region in shaded else 'white')
        if region in outlined:
            patch.set_edgecolor('black')

    for label in venn.set_labels[:3]:
        label.set_fontsize('small')

    plt.title('unused\ndata')
    plt.show()

    # Rows carrying at least one class from each of the two meta-sets.
    in_mval = df[mclasses['mval']].any(axis=1)
    in_mtst = df[mclasses['mtst']].any(axis=1)
    both_df = df[in_mval & in_mtst]

    unused_count = both_df.groupby(by='dataset')['dataset'].count()

    print(f'Unused {unused_count.sum()}')
    return unused_count

compute_unused()
