# Comparing GNN and baseline performance
---

**Libraries**

In [79]:
import os
import pandas as pd
import pickle
import plotly.express as px

In [80]:
def load_dict(path: str, filename: str) -> dict:
    """Unpickle and return the object stored in `filename` under `path`.

    The file is opened in binary mode and handed to `pickle.load`, so it
    should only be pointed at files you trust.
    """
    full_path = os.path.join(path, filename)
    with open(full_path, 'rb') as handle:
        return pickle.load(handle)

def data2df(path: str) -> pd.DataFrame:
    """Load every pickled run stored in `path` into a single DataFrame.

    Each file is read with `load_dict` and must yield a dict with:
      * 'meta': an object exposing the attributes `k`, `dataset` and `model`;
      * 'results': a mapping holding the 'train acc' and 'test acc' sequences.
    One file contributes `k` rows, so both sequences are presumably of
    length `k` (pandas rejects list columns of mismatched length).

    Args:
        path: Directory containing the pickled run files. Sub-directories
            are skipped; files are read in sorted name order so the row
            order is reproducible.

    Returns:
        A DataFrame with the columns 'dataset', 'model', 'k', 'train_acc',
        'test_acc', 'train_size' and 'test_size' (empty, with those same
        columns, when `path` holds no file).
    """
    columns = ['dataset', 'model', 'k', 'train_acc', 'test_acc',
               'train_size', 'test_size']

    # Collect one frame per file and concatenate once at the end instead of
    # growing a DataFrame inside the loop.
    frames = []
    for file in sorted(os.listdir(path)):
        if not os.path.isfile(os.path.join(path, file)):
            continue  # e.g. nested folders: nothing to unpickle

        data = load_dict(path, file)
        meta = data['meta']
        results = data['results']
        k = meta.k
        held_out = 2 * (1 / k)  # test + val

        frames.append(pd.DataFrame({
            'dataset': [meta.dataset] * k,
            'model': [meta.model] * k,
            'k': [k] * k,
            'train_size': [1 - held_out] * k,
            'test_size': [held_out] * k,
            'train_acc': results.get('train acc'),
            'test_acc': results.get('test acc'),
        }))

    if not frames:
        return pd.DataFrame(columns=columns)

    # Reindex the columns to keep the historical column order.
    return pd.concat(frames, ignore_index=True)[columns]

## Load data

In [88]:
# Runs are read from the `runs` folder that sits in the parent of the
# current working directory.
RUNPATH = os.path.join(
    os.path.dirname(os.getcwd()),
    'runs',
)

# One row per entry of every run file found in RUNPATH.
df = data2df(RUNPATH)

In [92]:
# Display only the rows produced by the GCN model.
df.loc[df['model'] == 'GCN']

Unnamed: 0,dataset,model,k,train_acc,test_acc,train_size,test_size
507,cora,GCN,3,0.946475,0.891433,0.333333,0.666667
508,cora,GCN,3,0.944937,0.888971,0.333333,0.666667
509,cora,GCN,3,0.95718,0.878262,0.333333,0.666667


In [82]:
# Compute average and standard deviation according to values of k
grouped = df.groupby(['dataset', 'model', 'k', 'train_size', 'test_size'])[['train_acc', 'test_acc']].agg({'train_acc': ['mean', 'std'], 'test_acc': ['mean', 'std']}).reset_index()
grouped.columns = ['dataset', 'model', 'k', 'train_size', 'test_size', 'train_acc_avg', 'train_acc_std', 'test_acc_avg', 'test_acc_std', ]
grouped

Unnamed: 0,dataset,model,k,train_size,test_size,train_acc_avg,train_acc_std,test_acc_avg,test_acc_std
0,cora,Diffusion,3,0.333333,0.666667,1.000000,0.000000,0.807238,0.002528
1,cora,Diffusion,4,0.500000,0.500000,1.000000,0.000000,0.836965,0.010810
2,cora,Diffusion,5,0.600000,0.400000,1.000000,0.000000,0.840107,0.009280
3,cora,Diffusion,6,0.666667,0.333333,1.000000,0.000000,0.846197,0.004844
4,cora,Diffusion,7,0.714286,0.285714,1.000000,0.000000,0.853762,0.014027
...,...,...,...,...,...,...,...,...,...
59,cora,PageRank,14,0.857143,0.142857,0.944455,0.002600,0.851741,0.019324
60,cora,PageRank,15,0.866667,0.133333,0.944041,0.002581,0.852290,0.021916
61,cora,PageRank,16,0.875000,0.125000,0.943686,0.002369,0.853027,0.021498
62,cora,PageRank,17,0.882353,0.117647,0.943083,0.002418,0.853575,0.019967


In [86]:
# Display the aggregated scores of the PageRank baseline only.
grouped.loc[grouped['model'] == "PageRank"]

Unnamed: 0,dataset,model,k,train_size,test_size,train_acc_avg,train_acc_std,test_acc_avg,test_acc_std
48,cora,PageRank,3,0.333333,0.666667,0.981537,0.002778,0.800776,0.004103
49,cora,PageRank,4,0.5,0.5,0.971381,0.004447,0.824225,0.00597
50,cora,PageRank,5,0.6,0.4,0.962951,0.005236,0.828844,0.009299
51,cora,PageRank,6,0.666667,0.333333,0.958919,0.003112,0.83678,0.007337
52,cora,PageRank,7,0.714286,0.285714,0.954284,0.003724,0.843239,0.015752
53,cora,PageRank,8,0.75,0.25,0.950702,0.002936,0.844167,0.011755
54,cora,PageRank,9,0.777778,0.222222,0.94904,0.001975,0.846743,0.018936
55,cora,PageRank,10,0.8,0.2,0.947932,0.00282,0.845076,0.018892
56,cora,PageRank,11,0.818182,0.181818,0.946332,0.002971,0.84804,0.01817
57,cora,PageRank,12,0.833333,0.166667,0.945717,0.002906,0.850063,0.014403


## Plot results

In [87]:
# Average training accuracy against the held-out fraction (`test_size`,
# i.e. test + val), one line per model, with a band of +/- one standard
# deviation around each line.
# NOTE: `line` is the helper defined in the "Annexes" cell at the end of the
# notebook; that cell has to be executed before this one.
fig = line(
    data_frame=grouped,
    x='test_size',
    y='train_acc_avg',
    error_y='train_acc_std',
    error_y_mode='band',
    color='model',
    title='Training accuracy',
    markers=True,
)
fig.show()

In [84]:
# Average test accuracy against the held-out fraction (`test_size`,
# i.e. test + val), one line per model, with a band of +/- one standard
# deviation around each line.
# NOTE: `line` is the helper defined in the "Annexes" cell at the end of the
# notebook; that cell has to be executed before this one.
fig = line(
    data_frame=grouped,
    x='test_size',
    y='test_acc_avg',
    error_y='test_acc_std',
    error_y_mode='band',
    color='model',
    title='Test accuracy',
    markers=True,
)
fig.show()

## Annexes
---

In [21]:
import plotly.express as px
import plotly.graph_objs as go

def _band_fill_color(color):
    """Return `color` as a 30%-opaque 'rgba(r,g,b,.3)' string.

    Accepts 'rgb(r, g, b)' strings as well as '#RRGGBB' hex strings; any
    other input raises ValueError while parsing the hex digits.
    """
    if color.startswith('rgb(') and color.endswith(')'):
        channels = color[4:-1].replace(' ', '')
        return f"rgba({channels},.3)"
    hex_digits = color.lstrip('#')
    red, green, blue = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({red},{green},{blue},.3)"


def line(error_y_mode=None, **kwargs):
    """Extension of `plotly.express.line` to use error bands.

    Args:
        error_y_mode: How the `error_y` values are drawn. 'bar', 'bars' or
            None delegate to `px.line` unchanged; 'band' or 'bands' replace
            the error bars with a translucent filled area around each line.
        **kwargs: Forwarded to `px.line`. 'error_y' is mandatory in band
            mode.

    Returns:
        The plotly figure. In band mode its traces are ordered so that each
        band comes right before the line it belongs to.

    Raises:
        ValueError: If `error_y_mode` is not a supported value, or if band
            mode is requested without 'error_y'.
    """
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar','bars',None}:
        return px.line(**kwargs)

    if 'error_y' not in kwargs:
        raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")

    # The figure with error bars is only used as a source of per-trace
    # x / y / error arrays; the returned figure is built without them.
    figure_with_error_bars = px.line(**kwargs)
    fig = px.line(**{arg: val for arg, val in kwargs.items() if arg != 'error_y'})

    for data in figure_with_error_bars.data:
        x = list(data['x'])
        error = data['error_y']
        # Symmetric band unless an explicit lower error is available.
        lower_error = error['array'] if error['arrayminus'] is None else error['arrayminus']
        y_upper = list(data['y'] + error['array'])
        y_lower = list(data['y'] - lower_error)
        fig.add_trace(
            go.Scatter(
                # Closed polygon: upper edge left-to-right, then the lower
                # edge right-to-left.
                x=x + x[::-1],
                y=y_upper + y_lower[::-1],
                fill='toself',
                fillcolor=_band_fill_color(data['line']['color']),
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
                legendgroup=data['legendgroup'],
                xaxis=data['xaxis'],
                yaxis=data['yaxis'],
            )
        )

    # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
    # The first half of `fig.data` holds the lines, the second half the
    # bands; interleave them as (band, line) pairs.
    half = len(fig.data) // 2
    reordered_data = []
    for line_trace, band_trace in zip(fig.data[:half], fig.data[half:]):
        reordered_data.extend((band_trace, line_trace))
    fig.data = tuple(reordered_data)
    return fig