In [10]:
import os
import pickle

In [11]:
def load_dict(path: str, filename: str) -> dict:
    """Load a pickled object (expected to be a dict) from ``path/filename``.

    Parameters
    ----------
    path : str
        Directory containing the pickle file.
    filename : str
        Name of the pickle file inside ``path``.

    Returns
    -------
    dict
        The unpickled object. Its type is not checked; callers assume a dict.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while unpickling, so only use
    this on run files you produced yourself.
    """
    # SECURITY: never point this at untrusted files (pickle executes code on load).
    with open(os.path.join(path, filename), 'rb') as f:
        return pickle.load(f)

In [12]:
_PARENT_DIR = os.path.dirname(os.getcwd())  # parent of the kernel's current working directory
RUNPATH = os.path.join(_PARENT_DIR, 'runs')  # directory whose files are loaded as pickled run results

In [69]:
import pandas as pd

# Columns of the per-fold results table (one row per fold of every run).
RUN_COLUMNS = ['dataset', 'model', 'k', 'train_size', 'test_size', 'train_acc', 'test_acc']

# Build one frame per run and concatenate once at the end: concatenating inside
# the loop is quadratic, and concatenating onto an empty frame forces object dtypes.
run_frames = []
# sorted() makes the row order independent of the filesystem's listing order.
# NOTE(review): assumes every entry in RUNPATH is a pickled run dict with
# 'meta' (attributes k/dataset/model) and 'results' ('train acc'/'test acc', length k).
for run_file in sorted(os.listdir(RUNPATH)):
    run = load_dict(RUNPATH, run_file)
    meta = run['meta']
    results = run['results']
    k = meta.k
    run_frames.append(pd.DataFrame({
        'dataset': [meta.dataset] * k,
        'model': [meta.model] * k,
        'k': [k] * k,
        # train_size records 1/k and test_size its complement
        # (presumably training on one fold, testing on the rest — confirm against the run script).
        'train_size': [1 / k] * k,
        'test_size': [1 - (1 / k)] * k,
        'train_acc': results['train acc'],
        'test_acc': results['test acc'],
    }))

# Fall back to an empty frame with the full schema so the grouping cell still works with no runs.
df = (
    pd.concat(run_frames, ignore_index=True)
    if run_frames
    else pd.DataFrame(columns=RUN_COLUMNS)
)

In [70]:
# Mean and standard deviation of train/test accuracy across the folds of each
# (dataset, model, k) configuration. Named aggregation labels every output column
# explicitly, so names cannot drift out of sync with the aggregation order.
grouped = (
    df.groupby(['dataset', 'model', 'k', 'train_size', 'test_size'])
    .agg(
        train_acc_avg=('train_acc', 'mean'),
        train_acc_std=('train_acc', 'std'),
        test_acc_avg=('test_acc', 'mean'),
        test_acc_std=('test_acc', 'std'),
    )
    .reset_index()
)
grouped

Unnamed: 0,dataset,model,k,train_size,test_size,train_acc_avg,train_acc_std,test_acc_avg,test_acc_std
0,cora,Diffusion,3,0.333333,0.666667,1.000000,0.000000,0.807238,0.002528
1,cora,Diffusion,4,0.250000,0.750000,1.000000,0.000000,0.836965,0.010810
2,cora,Diffusion,5,0.200000,0.800000,1.000000,0.000000,0.840107,0.009280
3,cora,Diffusion,6,0.166667,0.833333,1.000000,0.000000,0.846197,0.004844
4,cora,Diffusion,7,0.142857,0.857143,1.000000,0.000000,0.853762,0.014027
...,...,...,...,...,...,...,...,...,...
59,cora,PageRank,14,0.071429,0.928571,0.944455,0.002600,0.851741,0.019324
60,cora,PageRank,15,0.066667,0.933333,0.944041,0.002581,0.852290,0.021916
61,cora,PageRank,16,0.062500,0.937500,0.943686,0.002369,0.853027,0.021498
62,cora,PageRank,17,0.058824,0.941176,0.943083,0.002418,0.853575,0.019967


In [72]:
import plotly.express as px

# Mean training accuracy vs. training fraction, with a ±1 std band per model.
# NOTE: `line` (px.line with shaded error bands) is defined in the Annexes cell at the
# end of the notebook; run that cell first or this cell fails on a fresh kernel.
fig = line(
    data_frame=grouped,
    x='train_size',
    y='train_acc_avg',
    error_y='train_acc_std',
    error_y_mode='band',
    color='model',
    title='Training accuracy',
    markers=True,
    labels={'train_size': 'Training fraction (1/k)', 'train_acc_avg': 'Mean training accuracy', 'model': 'Model'},
)
fig.show()

In [73]:
# Mean test accuracy vs. test fraction, with a ±1 std band per model.
# NOTE: `line` is defined in the Annexes cell at the end of the notebook;
# run that cell first or this cell fails on a fresh kernel.
fig = line(
    data_frame=grouped,
    x='test_size',
    y='test_acc_avg',
    error_y='test_acc_std',
    error_y_mode='band',
    color='model',
    title='Test accuracy',
    markers=True,
    labels={'test_size': 'Test fraction (1 - 1/k)', 'test_acc_avg': 'Mean test accuracy', 'model': 'Model'},
)
fig.show()

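## Annexes

Helper definitions used by the plotting cells above (`line`: `plotly.express.line` with shaded error bands). When running the notebook on a fresh kernel, execute this section before those cells.

---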

In [21]:
import plotly.express as px
import plotly.graph_objs as go

def line(error_y_mode=None, **kwargs):
    """Extension of `plotly.express.line` to use error bands."""
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar','bars',None}:
        fig = px.line(**kwargs)
    elif error_y_mode in {'band','bands'}:
        if 'error_y' not in kwargs:
            raise ValueError(f"If you provide argument 'error_y_mode' you must also provide 'error_y'.")
        figure_with_error_bars = px.line(**kwargs)
        fig = px.line(**{arg: val for arg,val in kwargs.items() if arg != 'error_y'})
        for data in figure_with_error_bars.data:
            x = list(data['x'])
            y_upper = list(data['y'] + data['error_y']['array'])
            y_lower = list(data['y'] - data['error_y']['array'] if data['error_y']['arrayminus'] is None else data['y'] - data['error_y']['arrayminus'])
            color = f"rgba({tuple(int(data['line']['color'].lstrip('#')[i:i+2], 16) for i in (0, 2, 4))},.3)".replace('((','(').replace('),',',').replace(' ','')
            fig.add_trace(
                go.Scatter(
                    x = x+x[::-1],
                    y = y_upper+y_lower[::-1],
                    fill = 'toself',
                    fillcolor = color,
                    line = dict(
                        color = 'rgba(255,255,255,0)'
                    ),
                    hoverinfo = "skip",
                    showlegend = False,
                    legendgroup = data['legendgroup'],
                    xaxis = data['xaxis'],
                    yaxis = data['yaxis'],
                )
            )
        # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
        reordered_data = []
        for i in range(int(len(fig.data)/2)):
            reordered_data.append(fig.data[i+int(len(fig.data)/2)])
            reordered_data.append(fig.data[i])
        fig.data = tuple(reordered_data)
    return fig