In [None]:
from IPython.core.display import display, HTML
# Widen the notebook container to the full browser width and load the Roboto
# web font (the charts below set font='Roboto').
display(HTML(
    '<style>.container { width:100% !important; }</style>'
    '<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;700&display=swap" rel="stylesheet">'
))

In [None]:
import altair as alt
import pandas as pd
import numpy as np
from altair_saver import save
import ipywidgets as widgets
# Allow Altair to embed frames of any size (the default transformer refuses
# data above a row limit).
alt.data_transformers.disable_max_rows()
# Use Altair's default renderer for inline chart display.
alt.renderers.enable(name='default')

In [None]:
from functools import partial

def pd_reduce(dataframe, output_column, fn):
    """Sequentially fold ``fn`` over the rows of ``dataframe`` into a new column.

    A copy of ``dataframe`` is made and ``output_column`` is (re)initialised to
    NaN. Rows are then visited in positional order. For each row,
    ``fn(prev, curr, output_column)`` is called and its return value is stored
    in ``output_column`` for that row:

    * ``curr`` is the current row as a Series;
    * ``prev`` is the previous row, already holding its computed
      ``output_column`` value, or ``None`` for the first row.

    Rows are addressed by position rather than by index label. Frames with
    duplicate index labels (e.g. the result of ``pd.concat`` without
    ``ignore_index``) are therefore processed one row at a time, instead of
    one label group at a time.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Input frame; it is not modified.
    output_column : str
        Name of the column to create/overwrite in the returned copy.
    fn : callable
        ``fn(prev, curr, output_column) -> scalar``.

    Returns
    -------
    pd.DataFrame
        Copy of ``dataframe`` with ``output_column`` filled in.
    """
    result = dataframe.copy()
    result[output_column] = np.nan
    out_pos = result.columns.get_loc(output_column)
    for i in range(len(result)):
        # Read prev after the previous iteration's write so fn sees the
        # running value (this is what makes the reduction cumulative).
        prev = None if i == 0 else result.iloc[i - 1]
        curr = result.iloc[i]
        result.iloc[i, out_pos] = fn(prev, curr, output_column)
    return result

def auc_segment(prev, curr, output_column, epsilon=0):
    """Area of one step of the (thresholded) loss-vs-dataset-size curve.

    The step runs from the previous row's ``samples`` to the current row's
    ``samples``; for the first row it starts at 0. Its height is
    ``max(test_loss - epsilon, 0)``, taken from the current row.

    Parameters
    ----------
    prev : pd.Series or None
        Previous row (needs ``samples``), or ``None`` for the first row.
    curr : pd.Series
        Current row (needs ``samples`` and ``test_loss``).
    output_column : str
        Accepted for compatibility with the ``pd_reduce`` callback signature;
        not used.
    epsilon : float
        Loss threshold; only loss above it contributes area.

    Returns
    -------
    float
        ``(curr.samples - prev_samples) * max(curr.test_loss - epsilon, 0)``.
    """
    last_samp = 0 if prev is None else prev.samples
    return (curr.samples - last_samp) * max(curr.test_loss - epsilon, 0)

def sum_reduction(prev, curr, output_column, input_column='auc_segment'):
    """Running-sum callback for ``pd_reduce``.

    Returns the current row's ``input_column`` value plus the total already
    accumulated in the previous row's ``output_column``; the total is 0 for
    the first row, where ``prev`` is ``None``.
    """
    if prev is None:
        running_total = 0
    else:
        running_total = prev[output_column]
    return running_total + curr[input_column]

def auc(dataframe, carry_column='auc_segment', output_column='auc_agg', epsilon=0):
    """Cumulative (thresholded) area under the loss-vs-samples curve.

    Runs two ``pd_reduce`` passes:

    1. store each row's step area (``auc_segment`` with ``epsilon``) in
       ``carry_column``;
    2. accumulate those areas into ``output_column`` (``sum_reduction``).

    Rows are processed in the frame's existing order.
    NOTE(review): callers presumably pass rows sorted by ``samples``; this is
    not enforced here.
    """
    segment_fn = partial(auc_segment, epsilon=epsilon)
    running_sum_fn = partial(sum_reduction, input_column=carry_column)
    with_segments = pd_reduce(dataframe, carry_column, segment_fn)
    return pd_reduce(with_segments, output_column, running_sum_fn)

def auc_per_data(df, epsilons):
    """Add cumulative AUC (MDL/SDL-style) columns, computed separately per ``label``.

    For each ``epsilon`` in ``epsilons`` the following columns are added
    (dots in the epsilon are replaced by underscores in the column names):

    * ``auc_agg@<eps>`` — running area from ``auc`` with threshold
      ``epsilon``, computed independently for each label;
    * ``str_auc_agg@<eps>`` — that value rounded to 2 decimals as a string.
      When ``epsilon > 0``, it is prefixed with ``"> "`` on rows whose
      ``test_loss`` is still above ``epsilon``.

    Each epsilon pass re-concatenates the per-label results, so rows end up
    grouped by label in order of first appearance. ``auc`` also leaves its
    default ``auc_segment`` carry column in the frame.

    Parameters
    ----------
    df : pd.DataFrame
        Needs ``label``, ``samples`` and ``test_loss`` columns.
    epsilons : iterable of float
        Loss thresholds (0 gives the un-thresholded area).

    Returns
    -------
    pd.DataFrame
    """
    for epsilon in epsilons:
        colname = f'auc_agg@{epsilon}'.replace('.', '_')
        strcol = f'str_{colname}'
        results = [
            auc(df[df.label == label], output_column=colname, epsilon=epsilon)
            for label in df.label.unique()
        ]
        df = pd.concat(results)
        df[strcol] = df[colname].round(2).astype(str)
        if epsilon > 0:
            above = df['test_loss'] > epsilon
            df.loc[above, strcol] = "> " + df.loc[above, strcol]
    return df

def sc_segment(prev, curr, output_column, epsilon=0):
    """Running-minimum callback for ``pd_reduce`` (epsilon sample complexity).

    A row qualifies when its ``test_loss`` is at most ``epsilon``. A
    qualifying row contributes its ``samples`` value; any other row
    contributes the sentinel ``1e20``.

    The result is the minimum of that contribution and the previous row's
    ``output_column`` (``1e20`` for the first row). That is the smallest
    qualifying ``samples`` seen so far, or ``1e20`` if none has qualified yet.
    """
    sentinel = 1e20
    carried = sentinel if prev is None else prev[output_column]
    candidate = curr.samples if curr.test_loss <= epsilon else sentinel
    return min(carried, candidate)

def sc(dataframe, carry_column='sc_segment', output_column='sc', epsilon=0):
    """Epsilon sample complexity along the rows of ``dataframe``.

    A single ``pd_reduce`` pass with ``sc_segment`` writes the running minimum
    qualifying ``samples`` (or ``1e20`` if none yet) into ``output_column``.
    """
    # NOTE(review): ``carry_column`` is accepted for signature parity with
    # ``auc`` but is unused — the single reduction writes into ``output_column``.
    running_min_fn = partial(sc_segment, epsilon=epsilon)
    return pd_reduce(dataframe, output_column, running_min_fn)

def sc_per_data(df, epsilons):
    """Add epsilon-sample-complexity columns, computed separately per ``label``.

    For each ``epsilon`` the following columns are added (dots replaced by
    underscores in the names):

    * ``sc@<eps>`` — from ``sc``;
    * ``str_sc@<eps>`` — a display column holding the integer sample count at
      which the label first reached ``test_loss <= epsilon``. Rows where it
      has not been reached yet (value above ``1e10``, i.e. the ``1e20``
      sentinel) show ``"> <samples>"`` instead.

    Parameters
    ----------
    df : pd.DataFrame
        Needs ``label``, ``samples`` and ``test_loss`` columns.
    epsilons : iterable of float
        Loss thresholds.

    Returns
    -------
    pd.DataFrame
    """
    for epsilon in epsilons:
        colname = f'sc@{epsilon}'.replace('.', '_')
        strcol = f'str_{colname}'
        results = [
            sc(df[df.label == label], output_column=colname, epsilon=epsilon)
            for label in df.label.unique()
        ]
        df = pd.concat(results)
        not_reached = df[colname] > 1e10
        # Cast only real sample counts to int. The 1e20 sentinel lies outside
        # the int64 range, so casting it directly overflows. Those rows are
        # overwritten with the "> samples" text on the next line anyway.
        df[strcol] = df[colname].where(~not_reached, 0).astype(int).astype(str)
        df.loc[not_reached, strcol] = "> " + df.loc[not_reached, 'samples'].astype(str)
    # note that this overwrites `df` many times! not having an outer concat is by design
    return df

In [None]:
def loss_data_chart(df, title='', xdomain=None, ydomain=None, xrules=None, yrules=None,
                    color_title='Representation', final=False):
    """Log-log chart of mean test loss vs. dataset size, one line per ``label``.

    Parameters
    ----------
    df : pd.DataFrame
        Needs ``samples``, ``test_loss`` and ``label`` columns. Rows with
        ``samples < 10`` are dropped before plotting.
    title : str
        Chart title.
    xdomain, ydomain : list of two numbers, optional
        Fixed log-scale axis domains. Defaults are ``[8, 60000]`` and
        ``[0.008, 2]``.
    xrules, yrules : list of numbers, optional
        Positions of dashed vertical / horizontal reference lines (default none).
    color_title : str
        Legend title for the color/shape legend.
    final : bool
        Publication styling: larger fonts and no grid lines.

    Returns
    -------
    Configured Altair layered chart (rules, lines, points).
    """
    xdomain = [8, 60000] if xdomain is None else xdomain
    ydomain = [0.008, 2] if ydomain is None else ydomain
    xrules = [] if xrules is None else xrules
    yrules = [] if yrules is None else yrules

    line_width = 5
    if final:
        label_size, title_size = 24, 30
    else:
        label_size, title_size = 14, 20

    rules_df = pd.concat([
        pd.DataFrame({'x': xrules}),
        pd.DataFrame({'y': yrules})
    ], sort=False)

    colorscheme = 'set1'
    stroke_color = '333'
    plot_df = df[df.samples >= 10]
    # Shared by the line and point layers so both use identical axes.
    x_enc = alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False), title='Dataset size')
    y_enc = alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title='Test loss')
    color_scale = alt.Scale(scheme=colorscheme)

    # The line layer hides its legend; the point layer's color+shape legend is shown.
    line = alt.Chart(plot_df, title=title).mark_line(size=line_width, opacity=0.4).encode(
        x=x_enc,
        y=y_enc,
        color=alt.Color('label:N', title=color_title, scale=color_scale, legend=None),
    )
    point = alt.Chart(plot_df, title=title).mark_point(size=80, opacity=1).encode(
        x=x_enc,
        y=y_enc,
        color=alt.Color('label:N', title=color_title, scale=color_scale),
        shape=alt.Shape('label:N', title=color_title),
        tooltip=['samples', 'label'],
    )

    rule_style = dict(size=3, color='999', strokeDash=[4, 4])
    rule_x = alt.Chart(rules_df).mark_rule(**rule_style).encode(x='x')
    rule_y = alt.Chart(rules_df).mark_rule(**rule_style).encode(y='y')

    chart = alt.layer(rule_x, rule_y, line, point).resolve_scale(
        color='independent',
        shape='independent'
    )
    chart = chart.properties(width=600, height=500, background='white')
    chart = chart.configure(
        title=alt.TitleConfig(fontSize=title_size, fontWeight='normal'),
        axis=alt.AxisConfig(titleFontSize=title_size, labelFontSize=label_size, grid=(not final),
                            domainWidth=5, domainColor=stroke_color,
                            tickWidth=3, tickSize=9, tickCount=4, tickColor=stroke_color, tickOffset=0),
        # labelLimit/titleLimit=0 keep long legend labels from being truncated.
        legend=alt.LegendConfig(titleFontSize=title_size, labelFontSize=label_size, labelLimit=0, titleLimit=0,
                                orient='top-right', padding=10,
                                titlePadding=10, rowPadding=5,
                                fillColor='white', strokeColor='black', cornerRadius=0),
        view=alt.ViewConfig(strokeWidth=0, stroke=stroke_color),
        font='Roboto',
    )
    return chart

In [None]:
def make_latex(df, ns, stack=False, group_n=True, epsilons=None, column_format='llrrr'):
    """Build a LaTeX table of loss, MDL, SDL and epsilon-SC at selected dataset sizes.

    Parameters
    ----------
    df : pd.DataFrame
        Results with at least ``label``, ``data``, ``samples`` and
        ``test_loss`` columns. The running AUC/SC reductions depend on row
        order, so rows are presumably sorted by ``samples`` within each label.
    ns : list
        Values of ``samples`` to report.
    stack : bool
        If True, rows are indexed by (Representation, n). If False, the table
        is transposed so metrics are rows and (n, Representation) are columns.
    group_n : bool
        Only used with ``stack=True``. If True, rows are grouped by n with the
        metrics as sub-rows. If False, n is rendered as ``$n=...$``.
    epsilons : list of float, optional
        Loss thresholds for the SDL and epsilon-SC columns (default ``[0.5, 0.1]``).
    column_format : str
        Passed to ``DataFrame.to_latex`` (default ``'llrrr'``).

    Returns
    -------
    ipywidgets.Output
        Widget whose stdout holds the LaTeX source.
    """
    if epsilons is None:
        epsilons = [0.5, 0.1]
    epsilons = list(epsilons)
    df = auc_per_data(df, epsilons + [0]).reset_index(drop=True)
    df = sc_per_data(df, epsilons)
    df.reset_index(drop=True, inplace=True)
    auc_cols = {f'str_auc_agg@{eps}'.replace('.', '_'):  f'SDL, $\\varepsilon$={eps}' for eps in epsilons}
    auc_cols['str_auc_agg@0'] = 'MDL'
    sc_cols = {f'str_sc@{eps}'.replace('.', '_'):  f'$\\varepsilon$SC, $\\varepsilon$={eps}' for eps in epsilons}
    # The formatted string columns are group keys, so they survive the mean.
    # numeric_only=True drops the remaining non-numeric columns explicitly:
    # pandas >= 2.0 raises on them, while older versions dropped them silently.
    output_df = (df[df.samples.isin(ns)]
                 .groupby(['label', 'data', 'samples', *auc_cols.keys(), *sc_cols.keys()])
                 .mean(numeric_only=True)
                 .reset_index())
    output_df = output_df[['samples', 'label', 'test_loss', *auc_cols.keys(), *sc_cols.keys()]]
    output_df = output_df.sort_values('samples')

    output_df = output_df.rename(columns={'label': 'Representation', 'samples': 'n', 'test_loss': 'Val loss', **auc_cols, **sc_cols})
    auc_cols.pop('str_auc_agg@0')
    metric_cols = ['Val loss', 'MDL', *auc_cols.values(), *sc_cols.values()]
    output_df = output_df.reindex(['Representation', 'n', *metric_cols], axis=1)
    output_df['n'] = output_df['n'].astype(int)
    if stack:
        if not group_n:
            output_df['n'] = '$n=' + output_df['n'].astype(str) + '$'
        output_df = output_df.set_index(['Representation', 'n'])
        if group_n:
            output_df = output_df.transpose()
            # Keep the metric rows in the intended order (assign the result;
            # reindex does not modify the frame in place).
            output_df = output_df.reindex(metric_cols)
            display(output_df)
            output_df = output_df.stack()
            # Stable sort so metric rows keep their order within each n.
            output_df = output_df.swaplevel().sort_values('n', ascending=True, kind='mergesort')
    else:
        # 'label' was renamed to 'Representation' above.
        output_df = output_df.set_index(['n', 'Representation'])
        output_df = output_df.transpose()
    out = widgets.Output(layout={'border': '1px solid black'})
    latex_str = output_df.to_latex(multicolumn_format='c', float_format="{:0.2f}".format, escape=False, column_format=column_format)
    out.append_stdout(latex_str)
    return out

In [None]:
# Load per-seed probe results for three MNIST representations (raw pixels,
# CIFAR-supervised features, MNIST VAE), seeds 0-7 each.
# NOTE: read_pickle unpickles arbitrary objects — only use on trusted local results.
df = pd.concat([
    *[pd.read_pickle(f'results/realprobe_paper_mnist-repr_raw_dim784_level3-seed{seed}.pkl')
      for seed in range(8)],
    *[pd.read_pickle(f'results/realprobe_paper_mnist-repr_cifar_supervised_dim784_level3-seed{seed}.pkl')
      for seed in range(8)],
    *[pd.read_pickle(f'results/realprobe_paper_mnist-repr_mnist_vae_dim8_level3-seed{seed}.pkl')
      for seed in range(8)],
], sort=False).reset_index(drop=True)

if 'name' not in df:
    df['name'] = ''
# Assign back instead of calling fillna(inplace=True) on a column selection.
# That is a chained assignment, which does not update `df` under copy-on-write.
df['name'] = df['name'].fillna('')
# Strip seed suffixes. regex=True must be explicit because pandas >= 2.0
# treats the pattern literally by default.
df['name'] = df.name.str.replace('_?seeds?[0-9]*', '', regex=True)
df['test_error'] = 1 - df.test_accuracy
df['zero'] = 0
df['label'] = df.data + ' ' + df.repr + '-' + df.repr_dim.astype(str) + ' ' + df.name
# Short legend labels for the known representations.
df.loc[df.repr == 'cifar_supervised', 'label'] = "CIFAR"
df.loc[df.repr == 'raw', 'label'] = "Pixels"
df.loc[df.repr == 'mnist_vae', 'label'] = "VAE"

# Dataset sizes reported in the table (drawn as vertical rules) and loss
# thresholds for SDL / epsilon-SC (drawn as horizontal rules).
ns = [60, 20398]
epsilons = [0.1, 0.02]
chart = loss_data_chart(df, title="", xrules=ns, yrules=epsilons, final=True, ydomain=[0.005, 1])
display(chart)
save(chart, 'mnist_reprs.pdf')
# Average the per-seed rows for each (label, samples, data). numeric_only=True
# drops string columns such as `repr`/`name`; pandas >= 2.0 raises on them
# otherwise, while older versions dropped them silently.
df = df.groupby(['label', 'samples', 'data']).mean(numeric_only=True).reset_index()

make_latex(df, ns=ns, epsilons=epsilons, stack=True)

In [None]:
# Load raw-pixel MNIST probe results with clean vs. noisy ground-truth labels
# (seeds 0, 2, 4, 6).
# NOTE: read_pickle unpickles arbitrary objects — only use on trusted local results.
df = pd.concat([
    *[pd.read_pickle(f'results/realprobe_mnist-repr_raw_dim784-ntrain_50000-seed{seed}.pkl')
      for seed in [0, 2, 4, 6]],
    *[pd.read_pickle(f'results/realprobe_mnist_noisygt-repr_raw_dim784_level3-seed{seed}.pkl')
      for seed in [0, 2, 4, 6]],
], sort=False).reset_index(drop=True)

if 'name' not in df:
    df['name'] = ''
# Assign back instead of calling fillna(inplace=True) on a column selection.
# That is a chained assignment, which does not update `df` under copy-on-write.
df['name'] = df['name'].fillna('')
# Strip seed suffixes. regex=True must be explicit because pandas >= 2.0
# treats the pattern literally by default.
df['name'] = df.name.str.replace('_?seeds?[0-9]*', '', regex=True)
df['test_error'] = 1 - df.test_accuracy
df['zero'] = 0
df['label'] = df.data + ' ' + df.repr + '-' + df.repr_dim.astype(str) + ' ' + df.name
# Legend labels by dataset variant.
df.loc[df.data == 'mnist', 'label'] = "Raw pixels"
df.loc[df.data == 'mnist_noisygt', 'label'] = "Noisy labels"

# Figure settings for the clean-vs-noisy-label comparison.
final = True
ydomain = [0.01, 1]
title = ''
xdomain = [8, 60000]
xrules = []
yrules = []
color_title = 'Representation'

# Enlarged sizes for the final ("bold") figure.
if final:
    line_width = 8
    label_size = 30
    title_size = 54
else:
    line_width = 5
    label_size = 14
    title_size = 20

# Reference-rule positions; both lists are empty here.
rules_df = pd.concat([
    pd.DataFrame({'x': xrules}),
    pd.DataFrame({'y': yrules})
], sort=False)

colorscheme = 'set1'
stroke_color = '333'
# The line layer hides its legend; the point layer's color+shape legend is shown.
line = alt.Chart(df[df.samples >= 10], title=title).mark_line(size=line_width, opacity=0.5).encode(
    x=alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False), title='Dataset size'),
    y=alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title=''),
    color=alt.Color('label:N', title=color_title, scale=alt.Scale(scheme=colorscheme), legend=None),
)

point = alt.Chart(df[df.samples >= 10], title=title).mark_point(size=120, opacity=1).encode(
    x=alt.X('samples', scale=alt.Scale(type='log', domain=xdomain, nice=False), title='Dataset size'),
    y=alt.Y('mean(test_loss)', scale=alt.Scale(type='log', domain=ydomain, nice=False), title=''),
    color=alt.Color('label:N', title=color_title, scale=alt.Scale(scheme=colorscheme)),
    shape=alt.Shape('label:N', title=color_title),
)

rule_x = alt.Chart(rules_df).mark_rule(size=3, color='999', strokeDash=[4, 4]).encode(x='x')
rule_y = alt.Chart(rules_df).mark_rule(size=3, color='999', strokeDash=[4, 4]).encode(y='y')

chart = alt.layer(rule_x, rule_y, line, point).resolve_scale(
    color='independent',
    shape='independent'
)
chart = chart.properties(width=600, height=500, background='white')
# A single configure() call carries all config, including legend labelLimit=0.
chart = chart.configure(
    title=alt.TitleConfig(fontSize=title_size, fontWeight='normal'),
    axis=alt.AxisConfig(titleFontSize=title_size, labelFontSize=label_size, grid=(not final),
                        domainWidth=5, domainColor=stroke_color,
                        tickWidth=3, tickSize=9, tickCount=4, tickColor=stroke_color, tickOffset=0),
    axisX=alt.AxisConfig(titlePadding=50),
    legend=alt.LegendConfig(titleFontSize=36, labelFontSize=label_size, labelLimit=0, titleLimit=0,
                            orient='top-right', padding=10,
                            titlePadding=10, rowPadding=5,
                            fillColor='white', strokeColor='black', cornerRadius=0),
    view=alt.ViewConfig(strokeWidth=0, stroke=stroke_color),
    font='Roboto',
)

save(chart, "noisygt_bold.pdf")
chart