In [1]:
# Automatically reload imported modules before each cell executes, so edits to
# local source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import pandas as pd

In [3]:
# Location of the tab-separated experiment results (it is read, not written, below).
outpath = Path('results') / 'origin_distance_mse.tsv'

In [4]:
# Load every run; `read_table` is `read_csv` with a tab separator by default.
results = pd.read_table(outpath)
results

Unnamed: 0,model of geometry,distance function,hdt_reg_mse,edt_reg_mse,max_depth,n_dim,seed,noise_std,n_samples,n_classes
0,poincare,euclidean,8.711850e-04,2.149039e+05,5,2,42,0.1,1000,2
1,poincare,poincare,8.443482e-03,2.139521e+05,5,2,42,0.1,1000,2
2,hyperboloid,euclidean,5.954323e+00,2.031758e+05,5,2,42,0.1,1000,2
3,hyperboloid,hyperboloid,4.197303e-08,1.922590e+05,5,2,42,0.1,1000,2
4,poincare,euclidean,3.745208e-03,2.050664e+05,5,2,42,0.1,1000,5
...,...,...,...,...,...,...,...,...,...,...
1075,hyperboloid,hyperboloid,,1.735006e+05,20,32,44,0.5,1000,5
1076,poincare,euclidean,,1.681987e+05,20,32,44,0.5,1000,10
1077,poincare,poincare,,1.578095e+05,20,32,44,0.5,1000,10
1078,hyperboloid,euclidean,,2.953678e+31,20,32,44,0.5,1000,10


In [5]:
# Collapse the per-seed runs into a mean and std for each experimental configuration.
# NOTE(review): dropna() discards every row with any missing value, so configurations
# whose hdt_reg_mse is NaN (e.g. the n_dim=32 rows in the raw results above) vanish
# from `stats` entirely, including their otherwise valid edt_reg_mse values.
results = results.dropna()
groups = results.groupby([
    'model of geometry', 'distance function',
    'max_depth', 'n_dim', 'n_samples', 'noise_std', 'n_classes',
])
stats = (
    groups
    .agg(['mean', 'std'])
    .reset_index()
    # `seed` is not a grouping key, so it was aggregated too; its mean/std are meaningless.
    .drop('seed', axis=1, level=0)
)
# Flatten ('hdt_reg_mse', 'mean') -> 'hdt_reg_mse mean'; key columns have an empty
# second level, so strip() turns ('n_dim', '') back into 'n_dim'.
stats.columns = [f'{metric} {statistic}'.strip() for metric, statistic in stats.columns]
stats

Unnamed: 0,model of geometry,distance function,max_depth,n_dim,n_samples,noise_std,n_classes,hdt_reg_mse mean,hdt_reg_mse std,edt_reg_mse mean,edt_reg_mse std
0,hyperboloid,euclidean,5,2,1000,0.1,2,7.369487e+02,1.045553e+03,1.901376e+05,2.169340e+04
1,hyperboloid,euclidean,5,2,1000,0.1,5,1.991190e+02,1.870364e+02,1.818566e+05,4.331971e+03
2,hyperboloid,euclidean,5,2,1000,0.1,10,2.073553e+03,2.057575e+03,1.933143e+05,1.183979e+04
3,hyperboloid,euclidean,5,2,1000,0.5,2,2.301315e+07,3.039256e+07,2.543656e+07,3.432465e+07
4,hyperboloid,euclidean,5,2,1000,0.5,5,2.850955e+06,4.251437e+06,4.384914e+05,3.162543e+05
...,...,...,...,...,...,...,...,...,...,...,...
247,poincare,poincare,20,8,1000,0.5,5,6.052810e+00,2.575006e-01,2.079987e+05,1.529293e+04
248,poincare,poincare,20,8,1000,0.5,10,6.139616e+00,5.902031e-01,2.098772e+05,2.069556e+04
249,poincare,poincare,20,16,1000,0.1,2,2.738591e+00,2.539980e-01,2.063239e+05,1.750993e+04
250,poincare,poincare,20,16,1000,0.1,5,2.677655e+00,4.063953e-01,1.942163e+05,2.438256e+04


In [6]:
from collections.abc import Iterator

def filter_df(df, col_to_vals):
    """Return the rows of ``df`` whose values satisfy every column constraint.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to filter; it is not modified.
    col_to_vals : dict
        Maps a column name to the allowed value(s) for that column. A value may
        be a single scalar (e.g. ``'poincare'``, ``5``, ``np.int64(5)``) or any
        list-like collection of allowed values. Constraints are combined with AND.

    Returns
    -------
    pandas.DataFrame
        The subset of ``df`` satisfying all constraints, with its original index.
    """
    for col, vals in col_to_vals.items():
        # ``Series.isin`` only accepts list-likes. Wrap anything pandas does not
        # consider list-like: plain scalars, strings, and NumPy scalars such as
        # np.int64 (which is not an ``int`` instance, so an isinstance check misses it).
        if not pd.api.types.is_list_like(vals):
            vals = [vals]
        df = df[df[col].isin(vals)]
    return df

In [7]:
# Keep only the hyperboloid model of geometry, with both its hyperboloid
# distance function and the Euclidean distance function.
col_to_val = {}
col_to_val['model of geometry'] = ['hyperboloid']
col_to_val['distance function'] = ['hyperboloid', 'euclidean']

df = filter_df(stats, col_to_val)
df

Unnamed: 0,model of geometry,distance function,max_depth,n_dim,n_samples,noise_std,n_classes,hdt_reg_mse mean,hdt_reg_mse std,edt_reg_mse mean,edt_reg_mse std
0,hyperboloid,euclidean,5,2,1000,0.1,2,7.369487e+02,1.045553e+03,1.901376e+05,2.169340e+04
1,hyperboloid,euclidean,5,2,1000,0.1,5,1.991190e+02,1.870364e+02,1.818566e+05,4.331971e+03
2,hyperboloid,euclidean,5,2,1000,0.1,10,2.073553e+03,2.057575e+03,1.933143e+05,1.183979e+04
3,hyperboloid,euclidean,5,2,1000,0.5,2,2.301315e+07,3.039256e+07,2.543656e+07,3.432465e+07
4,hyperboloid,euclidean,5,2,1000,0.5,5,2.850955e+06,4.251437e+06,4.384914e+05,3.162543e+05
...,...,...,...,...,...,...,...,...,...,...,...
121,hyperboloid,hyperboloid,20,8,1000,0.5,5,2.712129e-03,2.381732e-03,1.991542e+05,1.753192e+04
122,hyperboloid,hyperboloid,20,8,1000,0.5,10,3.145761e-03,4.996747e-03,1.864830e+05,1.686768e+04
123,hyperboloid,hyperboloid,20,16,1000,0.1,2,7.268678e-03,1.247736e-02,1.830728e+05,1.985969e+04
124,hyperboloid,hyperboloid,20,16,1000,0.1,5,4.871258e-04,8.178081e-04,1.984191e+05,2.357072e+04


In [26]:
import plotly.express as px
import plotly.graph_objects as go

In [66]:
# Mean `hdt_reg_mse` (over seeds) for the hyperboloid model, coloured by distance
# function. There is one panel per (max_depth row, n_classes column), the marker
# symbol encodes noise_std, and the y-axis is log-scaled.
fig = px.scatter(
    df,
    x='n_dim',
    y='hdt_reg_mse mean',
    log_y=True,
    color='distance function',
    facet_col='n_classes',
    facet_row='max_depth',
    symbol='noise_std',
    title='Hyperboloid Distance Function MSE',
    hover_data=['hdt_reg_mse mean', 'edt_reg_mse mean'],
)

# Blank every per-facet axis title; one shared label per axis is added below.
fig.update_xaxes(title_text='')
fig.update_yaxes(title_text='')

# Shared y-axis label, placed left of the facet grid in paper coordinates.
fig.add_annotation(
    x=-0.1, y=0.5,
    text="MSE",
    textangle=-90,
    xref="paper", yref="paper",
    showarrow=False,
    font=dict(size=16),
)
# Shared x-axis label. The x data is `n_dim`, so the label names the dimension;
# max_depth is shown by the facet rows, not the x-axis.
fig.add_annotation(
    x=0.5, y=-0.15,
    text="Dimension",
    xref="paper", yref="paper",
    showarrow=False,
    font=dict(size=16),
)

fig.show()