In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount("/content/drive")
# Change the working directory to the project folder; the relative
# 'data/...' paths used by later cells are resolved from here.
%cd "/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust


In [None]:
# List the run directories recorded under the 2021-11-04 folder, then the
# files stored inside one of those runs.
!ls data/test-kaggle/2021-11-04/
!ls data/test-kaggle/2021-11-04/11-11-02_gpu0


10-00-56_gpu0  12-24-57_gpu0  14-51-26_gpu0
11-11-02_gpu0  13-39-42_gpu0  16-01-01_gpu0
11-25-32_weight_distribution.png
11-40-27_weight_distribution.png
11-55-08_weight_distribution.png
12-10-04_weight_distribution.png
12-24-54_bnn_1024-1024_pMNIST1-pMNIST2-pMNIST3-pMNIST4-pMNIST5-.csv
12-24-54_weight_distribution.png
hyperparameters.txt


In [None]:
# Folder that contains the individual run directories, relative to the
# current working directory.
data_root = 'data/test-kaggle/2021-11-04'

In [None]:
import os 
import glob 
import pandas as pd 
import numpy as np 
import yaml 

In [None]:
# Run directories (inside data_root) whose results are loaded and plotted.
data_dirs = [
    '10-00-56_gpu0',
    '11-11-02_gpu0',
    '12-24-57_gpu0',
    '13-39-42_gpu0',
    '14-51-26_gpu0',
]
# Alternative selection, currently disabled:
# data_dirs = ['10-00-56_gpu0', '11-11-02_gpu0', '14-51-26_gpu0','16-01-01_gpu0']

In [None]:
def load_df(data_root, data_dir):
    """Load one run's test accuracies as a long-format DataFrame.

    Reads the alphabetically first ``*csv`` file and ``hyperparameters.txt``
    from ``data_root/data_dir``.

    Args:
        data_root: Directory that contains the run directories.
        data_dir: Name of the run directory inside ``data_root``.

    Returns:
        A DataFrame with the columns ``task_order``, ``epoch``, ``meta``,
        ``test_set`` and ``test_acc``. ``test_set`` holds the name of the
        source column (those matched by ``acc_test_.*``) and ``test_acc`` its
        value. ``epoch`` is shifted by ``(task_order - 1) * max_epoch`` so
        that the epochs of consecutive tasks follow each other on one axis.

    Raises:
        FileNotFoundError: If the run directory holds no CSV file or no
            ``hyperparameters.txt``.
    """
    data_path = os.path.join(data_root, data_dir)

    # glob returns matches in arbitrary order; sort so the chosen file is
    # deterministic, and fail with a clear message when there is none.
    csv_candidates = sorted(glob.glob(data_path + '/*csv'))
    if not csv_candidates:
        raise FileNotFoundError(f"no CSV file found in {data_path!r}")
    csv_data_path = csv_candidates[0]

    hyp_data_path = os.path.join(data_path, 'hyperparameters.txt')
    # Context manager so the file handle is closed once parsing is done.
    with open(hyp_data_path) as hyp_file:
        hyp_entries = yaml.safe_load(hyp_file)
    # The parsed file is iterated as a sequence of mappings; merge them into
    # a single flat dict (later entries win on duplicate keys).
    hyp_params = {k: v for list_item in hyp_entries for (k, v) in list_item.items()}

    # Keep the accuracy columns plus the task/epoch counters, tag every row
    # with the first element of the run's `meta` hyperparameter, and go from
    # one column per test set to one row per (epoch, test set).
    df = pd.read_csv(csv_data_path)\
            .filter(regex='(acc_test_.*|task_order|epoch)', axis=1)\
            .assign(meta=hyp_params['meta'][0])\
            .melt(id_vars=['task_order','epoch','meta'], var_name='test_set', value_name='test_acc')

    # Turn the per-task epoch counter into a global one.
    # NOTE(review): this assumes epochs are numbered 1..max_epoch within each
    # task; with 0-based epochs the last epoch of one task and the first epoch
    # of the next would get the same value - confirm against the CSV.
    max_epoch = df['epoch'].max()
    df['epoch'] = (df['task_order'] - 1) * max_epoch + df['epoch']
    return df

# Load every selected run and stack the per-run frames into one frame with a
# fresh consecutive index.
df = pd.concat(
    objs=[load_df(data_root, run_dir) for run_dir in data_dirs],
    ignore_index=True,
)


In [None]:
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Appearance applied to each axis listed in general_layout: black 2px axis
# line, no grid, tick labels shown, ticks drawn outside the plot area.
axis_config = {
    "showline": True,
    "showgrid": False,
    "showticklabels": True,
    "linecolor": 'rgb(0, 0, 0)',
    "linewidth": 2,
    "ticks": 'outside',
    "tickwidth": 2,
}

# Global font of the figure.
font_config = {"family": "Fira Sans", "size": 18, "color": 'black'}

# Title placement and size, written as flattened `title_*` layout keys.
title_config = {
    "title_x": 0.5,
    "title_y": 0.9,
    "title_xanchor": 'center',
    "title_yanchor": 'top',
    "title_font_size": 23,
}

# Layout shared by the figures of this notebook. The comprehension expands to
# xaxis, yaxis, xaxis2, yaxis2, ..., xaxis4, yaxis4, all using axis_config.
general_layout = go.Layout(
    font=font_config,
    margin={"autoexpand": True, "l": 100, "r": 50, "t": 100, "b": 120},
    showlegend=True,
    plot_bgcolor='white',
    autosize=True,
    **{
        f"{axis}axis{suffix}": axis_config
        for suffix in ("", "2", "3", "4")
        for axis in ("x", "y")
    },
    **title_config,
)


In [None]:
# Test accuracy over the global epoch axis: one line per test set, one facet
# per `meta` value, two facets per row.
fig = px.line(
    df, x="epoch", y="test_acc",
    color="test_set",
    facet_col="meta",
    color_discrete_sequence=px.colors.sequential.Plasma_r,
    facet_col_wrap=2)
fig.update_layout(
    general_layout,
    height=800
)
# general_layout only names xaxis..xaxis4 / yaxis..yaxis4, so a fifth (or
# later) facet would keep plotly's default axis look. update_xaxes /
# update_yaxes apply the same styling to every axis in the figure.
fig.update_xaxes(axis_config)
fig.update_yaxes(axis_config)
fig.show()