# Experiment 1:

The source notebooks were run on [Kaggle][source]. The inputs were generated for both the original MNIST and permuted MNIST following the [code][bnn-meta-paper-repo] of the [metaplasticity BNN paper][bnn-meta-paper]. These were then saved together with their [naturally corrupted][natcrpt-paper] versions (only 1 level of severity was used for testing for now), obtained by adapting the [robustness code][natcrpt-paper-repo], specifically adapting [`corruptions.py`][natcrpt-paper-code] to work with monochrome images.

- for variations of `meta` parameter: [Version 2][V2] of [source]:
  - variation of `meta` parameter: `[0.0, 0.7, 1.35]`
  - these following parameters were fixed:

    ``` bash
    --hidden-layers 2048 2048 --lr 0.005 --decay 1e-7 --epochs-per-task 25
    ```

  - the inputs were generated and then downloaded into `data/input/pmnist_robustness`
  - the outputs were downloaded into `data/output/exp1-pmnist_robustness`
  - note: the logs for [V2] in the corruption part were not correct due to a typo, but the saved data were correct. This has been fixed for future use.


[bnn-meta-paper]: https://www.nature.com/articles/s41467-021-22768-y
[bnn-meta-paper-repo]: https://github.com/Laborieux-Axel/SynapticMetaplasticityBNN
[natcrpt-paper]: https://arxiv.org/abs/1903.12261
[natcrpt-paper-repo]: https://github.com/hendrycks/robustness
[natcrpt-paper-code]: https://github.com/hendrycks/robustness/blob/master/ImageNet-C/imagenet_c/imagenet_c/corruptions.py
[source]: https://www.kaggle.com/penguinsfly/bnn-cf-vs-robust/
[V2]: https://www.kaggle.com/penguinsfly/bnn-cf-vs-robust/data?scriptVersionId=79007488


In [1]:
%%capture
# Install/upgrade the plotting dependencies; `%%capture` hides the pip output.
!pip install --upgrade plotly
# kaleido is the package plotly relies on for static image export
# (`fig.write_image` is used for every figure below).
!pip install -U kaleido

In [2]:
# Mount Google Drive in the Colab runtime and switch to the project folder so
# that the relative `data/...` and `figures/...` paths used below resolve.
from google.colab import drive
drive.mount("/content/drive")
%cd "/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust


In [3]:
import os, glob 
import pandas as pd 
import numpy as np 
import yaml 
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go

In [4]:
# Root folder that is handed to `load_dfs`; it is expected to hold one
# sub-folder per experiment run.
# NOTE(review): the notes at the top of this file spell the output folder
# `exp1-pmnist_robustness` (underscore before "robustness"); confirm which
# spelling actually exists on disk.
data_root = Path('data/output/exp1-pmnist-robustness')

# Destination folder for every exported figure.
fig_root = Path('figures/exp1-pmnist_robustness')
# `parents=True` also creates the enclosing `figures/` folder, so this cell
# does not fail with FileNotFoundError on a fresh checkout.
fig_root.mkdir(parents=True, exist_ok=True)

In [5]:
# Axis styling shared by every line plot: visible black axis lines with
# outside ticks and no grid.
axis_config = {
    'showline': True,
    'showgrid': False,
    'showticklabels': True,
    'linecolor': 'rgb(0, 0, 0)',
    'linewidth': 2,
    'ticks': 'outside',
    'tickwidth': 2,
}

# Default text styling.
font_config = {
    'family': "Fira Sans",
    'size': 15,
    'color': 'black',
}

# Centered title anchored near the top of the figure.
title_config = {
    'title_x': 0.5,
    'title_y': 0.95,
    'title_xanchor': 'center',
    'title_yanchor': 'top',
    'title_font_size': 20,
}

# Layout used by the line plots (includes the axis styling).
general_layout = go.Layout(
    xaxis=axis_config,
    yaxis=axis_config,
    font=font_config,
    margin={'autoexpand': True, 'l': 100, 'r': 50, 't': 100, 'b': 120},
    showlegend=True,
    plot_bgcolor='white',
    autosize=True,
    **title_config,
)

# Layout used by the heatmaps: same as above but without axis styling.
heatmap_layout = go.Layout(
    font=font_config,
    margin={'autoexpand': True, 'l': 100, 'r': 50, 't': 100, 'b': 120},
    showlegend=True,
    plot_bgcolor='white',
    autosize=True,
    **title_config,
)

# Apply the same axis styling to the numbered axes (1..29) that faceted
# figures / subplots use.
for subplot_idx in range(1, 30):
    for axis_prefix in ('xaxis', 'yaxis'):
        general_layout[axis_prefix + str(subplot_idx)] = axis_config

In [6]:
def load_df(data_dir, conf_fns = None):
    """Load the config and the two performance tables of a single run.

    Args:
        data_dir: folder (``pathlib.Path`` or str) containing
            ``exp-config.yaml``, ``perf_forget.csv`` and ``perf_robust.csv``.
        conf_fns: optional mapping ``column name -> callable(exp_config)``.
            Each callable is evaluated on the loaded config and its result is
            added to both tables under that column name (via
            ``DataFrame.assign``).

    Returns:
        Tuple ``(exp_config, df_forget, df_robust)``: the parsed YAML config
        and the two CSV tables as DataFrames.
    """
    data_dir = Path(data_dir)

    # Use a context manager so the config file handle is closed right away.
    with open(data_dir / 'exp-config.yaml', 'r') as config_file:
        exp_config = yaml.safe_load(config_file)

    prm2save = {k: fn(exp_config) for k, fn in conf_fns.items()} if conf_fns else {}

    df_forget = pd.read_csv(data_dir / 'perf_forget.csv').assign(**prm2save)
    df_robust = pd.read_csv(data_dir / 'perf_robust.csv').assign(**prm2save)

    return exp_config, df_forget, df_robust

def load_dfs(data_root, conf_fns = None, concat_dfs = True):
    """Load every run found directly under ``data_root``.

    Args:
        data_root: folder (``pathlib.Path`` or str) whose sub-folders are
            individual runs, each loadable with ``load_df``.
        conf_fns: forwarded unchanged to ``load_df`` for every run.
        concat_dfs: when true, the per-run tables are concatenated into one
            DataFrame each (with a fresh index); otherwise lists of per-run
            DataFrames are returned.

    Returns:
        Tuple ``(exp_config, df_forget, df_robust)`` where ``exp_config`` is
        a list with one config per run, and the other two are either single
        DataFrames or lists of DataFrames depending on ``concat_dfs``.

    Raises:
        ValueError: from ``pd.concat`` when ``concat_dfs`` is true and no run
            folder was found.
    """
    data_root = Path(data_root)

    # Directory listing order is arbitrary, so sort for reproducible row
    # order. Skip plain files and hidden folders (e.g. `.ipynb_checkpoints`),
    # which are not run folders and would make `load_df` fail.
    run_dirs = sorted(
        entry for entry in data_root.iterdir()
        if entry.is_dir() and not entry.name.startswith('.')
    )

    exp_config = []
    df_forget = []
    df_robust = []
    for run_dir in run_dirs:
        conf, forget, robust = load_df(run_dir, conf_fns)

        exp_config.append(conf)
        df_forget.append(forget)
        df_robust.append(robust)

    if concat_dfs:
        df_forget = pd.concat(df_forget, ignore_index=True)
        df_robust = pd.concat(df_robust, ignore_index=True)

    return exp_config, df_forget, df_robust


In [7]:
# Extra columns to attach to every loaded table, each computed from a run's
# config: `meta` is the first element of the config's `meta` entry.
conf_fns = {
    'meta': lambda conf: conf['meta'][0],
}

# Load all runs under `data_root` into concatenated tables.
exp_config, df_forget, df_robust = load_dfs(data_root, conf_fns)

# Plot the training progress and final accuracy

In [8]:
# Reshape to long format: keep `meta`, `glob_epoch` and the `test_acc*`
# columns, then melt so that every accuracy column becomes rows labelled by
# `test_set`.
df_forget = (
    df_forget
    .filter(regex='(meta|glob_epoch|test_acc.*)', axis=1)
    .sort_values(by=['meta', 'glob_epoch'], ignore_index=True)
    .melt(
        id_vars=['meta', 'glob_epoch'],
        var_name='test_set',
        value_name='test_acc',
    )
)

# Drop the `test_acc::` prefix that the original column names carried.
df_forget['test_set'] = df_forget['test_set'].str.replace('test_acc::', '', regex=False)

In [9]:
# Test accuracy over global epochs: one panel per `meta` value (two panels
# per row), one coloured line per test set.
fig = px.line(
    df_forget,
    x="glob_epoch",
    y="test_acc",
    color="test_set",
    color_discrete_sequence=px.colors.sequential.Plasma_r,
    facet_col="meta",
    facet_col_wrap=2,
    facet_col_spacing=0.05,
    facet_row_spacing=0.1,
)

fig.update_layout(
    general_layout,
    title_text='Variation of meta  (2048 x 2048)',
    width=1800,
    height=800,
)

# Export to SVG before displaying inline.
fig.write_image(fig_root / 'vary-meta-progress.svg')
fig.show()

In [10]:
# Keep only the rows recorded at the last global epoch, i.e. the accuracy on
# each test set at the end of training.
max_globepoch = max(df_forget.glob_epoch)
final_df_forget = (
    df_forget.loc[df_forget['glob_epoch'] == max_globepoch]
    .filter(['meta', 'test_set', 'test_acc'])
    .reset_index(drop=True)
    .rename(columns={'test_acc': 'final test acc'})
)

# Turn the 'task-<k>' labels into integers so they plot as a numeric x axis.
final_df_forget['test_set'] = (
    final_df_forget['test_set']
    .str.replace('task-', '', regex=False)
    .astype(int)
)


In [11]:
# Final test accuracy per task, one line per `meta` value.
fig = px.line(
    final_df_forget,
    x="test_set",
    y="final test acc",
    color="meta",
    markers=True,
    color_discrete_sequence=px.colors.sequential.Burg,
)

# Thicker lines and larger markers for readability.
fig.update_traces(marker_size=10, line={'width': 3})

fig.update_layout(
    general_layout,
    title_text='Variations of meta (2048 x 2048)',
    title_y=0.9,
    font_size=16,
    width=700,
    height=500,
)

# Export to SVG before displaying inline.
fig.write_image(fig_root / 'vary-meta-final.svg')
fig.show()

# Plot the robustness to natural corruption tests 

In [12]:
def _short_data_key(key):
    """Collapse any key containing 'original' to 'original'; otherwise strip the 'corruptions::' prefix."""
    if 'original' in key:
        return 'original'
    return key.replace('corruptions::', '')


df_robust['data_key'] = df_robust['data_key'].map(_short_data_key)

# Last distinct train phase in order of first appearance in the table.
final_train = df_robust['train_phase'].unique()[-1]


In [13]:
# Robustness results of the final train phase only, with numeric task ids and
# `meta` values, ordered by (meta, source_task).
final_df_robust = (
    df_robust.loc[df_robust['train_phase'] == final_train]
    .assign(
        source_task=lambda df: df['source_task'].str.replace('task-', '', regex=False).astype(int),
        meta=lambda df: df['meta'].astype(float),
    )
    .sort_values(by=['meta', 'source_task'])
    .reset_index(drop=True)
)

In [46]:
# Accuracy after the final train phase on each task's data: one panel per
# `data_key` value ('original' plus each corruption, seven panels per row),
# one line per `meta` value.
fig = px.line(
    final_df_robust,
    x="source_task",
    y="test_acc",
    color="meta",
    facet_col="data_key",
    color_discrete_sequence=px.colors.sequential.Burg,
    facet_col_wrap=7,
    facet_col_spacing=0.03,
    facet_row_spacing=0.12,
    markers=True,
    labels = dict(
        data_key='corruption',
        test_acc='final test acc',
        source_task = 'task data source'
    )
)

fig.update_traces(line=dict(width=3), marker=dict(size=15))

fig.update_layout(
    general_layout,
    width=1700,
    title_y = 0.97,
    height=800,
    font_size = 22, title_font_size = 27,
    # Title typo fixed ("variatons" -> "variations").
    title_text='Robustness due to naturalistic corruption with variations of meta for BNN (2048 x 2048), tested at the end of metaplasticity training',
    # Legend placed in the bottom-right corner of the figure.
    legend=dict(
        yanchor="bottom",
        y=0.01,
        xanchor="right",
        x=0.99
    )
)

# Facet titles read 'corruption=<name>' because `data_key` is relabelled to
# 'corruption' above; keep only the name.
for anno in fig['layout']['annotations']:
    anno['text']=anno['text'].replace('corruption=', '')

fig.write_image(fig_root / 'vary-meta-final-robust.svg')

fig.show()


In [29]:
# Split the corruption types into `num_lists` groups so each exported figure
# stays readable; every group also includes the uncorrupted 'original' data
# as a reference panel.
num_lists = 3
crpt_types = [x for x in df_robust.data_key.unique() if x != 'original']
crpt_lists = [['original'] + list(x) for x in np.array_split(crpt_types, num_lists)]

for i, sel_dkeys in enumerate(crpt_lists):
    # Zero-padded index: the previous '%2d' padded with a space, which put a
    # blank inside the file name (e.g. 'corruptset 1.svg').
    fig_name = 'vary-meta-robust-alltrainphases-corruptset%02d.svg' % (i + 1)

    sel_df = df_robust.query('data_key in @sel_dkeys')\
                    .sort_values(by=['meta', 'train_phase', 'source_task'])\
                    .reset_index(drop=True)

    # Grid of heatmaps: columns = data_key (corruption), rows = meta;
    # within each heatmap x = task data source, y = train phase.
    # NOTE(review): density_heatmap aggregates `z` per cell (plotly's default
    # when `z` is given is a sum), so cells show raw accuracies only if each
    # (meta, data_key, train_phase, source_task) combination has exactly one
    # row - confirm against the data.
    fig = px.density_heatmap(
        sel_df,
        x="source_task", y="train_phase", z='test_acc',
        facet_col = 'data_key', facet_row = 'meta',
        labels = dict(
            source_task = 'task data source',
            test_acc = 'test acc',
            train_phase = 'train phase (task)',
            data_key = 'corruption'
        )
    )

    fig.update_layout(
        heatmap_layout,
        height=750, width=1800,
        title_y = 0.95,
        font_size = 16, title_font_size = 20,
        title_text = 'Robustness due to nat. corruptions (2048 x 2048 BNN), showing all train phases'
    )

    # Facet titles read 'corruption=<name>'; keep only the name.
    for anno in fig['layout']['annotations']:
        anno['text']=anno['text'].replace('corruption=', '')

    fig.write_image(fig_root / fig_name)

    fig.show()