In [1]:
import altair as alt
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

from main_layered import create_and_train_model, eval_model
from utils.data import gen_rectangular_channel_matrix, flex_graph

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any later cell in this notebook.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def train_multiple_models(layers, k=32, n_samples=10000, batch_size=64, lr=0.002,
                          noise_var=1., seed=11):
    """Train one model per layer count and save each checkpoint under ./experiments/.

    One dataset is generated once and reused for every model. Runs therefore
    differ only in the ``layers`` value passed to ``create_and_train_model``.

    Parameters
    ----------
    layers : iterable of int
        Layer counts to train. Model ``i`` is saved to
        ``./experiments/flexible_experiment_{i}_layers.pth``.
    k : int, default 32
        Passed as both size arguments of ``gen_rectangular_channel_matrix`` and
        as ``k`` to ``create_and_train_model``.
    n_samples : int, default 10000
        Number of generated training samples. Also forwarded as ``n``.
    batch_size : int, default 64
        Batch size of the ``DataLoader``. Also forwarded to ``create_and_train_model``.
    lr : float, default 0.002
        Learning rate forwarded to ``create_and_train_model``.
    noise_var : float, default 1.0
        Forwarded to ``create_and_train_model``.
    seed : int, default 11
        Seed passed to ``gen_rectangular_channel_matrix``.
    """
    channels = gen_rectangular_channel_matrix(k, k, n_samples, seed=seed)
    loader = DataLoader(flex_graph(channels), batch_size=batch_size, shuffle=True)
    for i in layers:
        # Forward the same k / n_samples / batch_size that built `loader`, so the
        # model configuration cannot drift away from the data it is trained on.
        create_and_train_model(
            n=n_samples,
            k=k,
            batch_size=batch_size,
            noise_var=noise_var,
            path=f"./experiments/flexible_experiment_{i}_layers.pth",
            data=loader,
            lr=lr,
            layers=i,
        )

def evaluate_models(layers, k=32, n_samples=100, seed=899):
    """Evaluate the saved layer-count checkpoints on one freshly generated dataset.

    Parameters
    ----------
    layers : iterable of int
        Layer counts to evaluate. Entry ``i`` loads
        ``./experiments/flexible_experiment_{i}_layers.pth`` and is passed to
        ``eval_model`` as ``layers``.
    k : int, default 32
        Passed as both size arguments of ``gen_rectangular_channel_matrix``.
    n_samples : int, default 100
        Number of generated evaluation samples.
    seed : int, default 899
        Seed passed to ``gen_rectangular_channel_matrix``. Presumably chosen to
        differ from the training seed so that evaluation uses unseen data.

    Returns
    -------
    list of float
        ``eval_model(...).item()`` for each entry of ``layers``, in the same order.
    """
    # Build the evaluation set once; every checkpoint is scored on the same data.
    test_graphs = flex_graph(gen_rectangular_channel_matrix(k, k, n_samples, seed=seed))
    return [
        eval_model(path=f"./experiments/flexible_experiment_{i}_layers.pth",
                   data=test_graphs, layers=i).item()
        for i in layers
    ]

In [4]:
# Layer counts compared in this experiment.
layers = [1, 2, 3, 4, 5]

# Training is opt-in. By default, Run All only evaluates the checkpoints already
# saved under ./experiments/ (the paths train_multiple_models writes and
# evaluate_models reads). Set to True to regenerate them.
RETRAIN_MODELS = False
if RETRAIN_MODELS:
    train_multiple_models(layers)

In [5]:
# One evaluation score per checkpoint, in the same order as `layers`.
perf_layers = evaluate_models(layers=layers)

-3.4525604248046875
-3.677065134048462
-3.9491357803344727
-3.9616055488586426
-4.01860237121582


In [6]:
def evaluate_sample_models(samples, k=32, n_samples=10000, seed=899,
                           path_template="./flexible_experiment_{}_training_samples.pth"):
    """Evaluate checkpoints that were trained with different training-set sizes.

    This was a second ``def evaluate_models``, which silently shadowed the
    layer-count ``evaluate_models`` defined earlier. Re-running the
    ``perf_layers`` cell after this cell would have loaded the wrong
    checkpoints. Giving it a distinct name removes that hidden-state hazard.

    Parameters
    ----------
    samples : iterable
        Values substituted into ``path_template`` to locate each checkpoint.
    k : int, default 32
        Passed as both size arguments of ``gen_rectangular_channel_matrix``.
    n_samples : int, default 10000
        Number of generated evaluation samples.
    seed : int, default 899
        Seed passed to ``gen_rectangular_channel_matrix``.
    path_template : str
        Checkpoint path with one ``{}`` placeholder for the entry of ``samples``.
        NOTE(review): the default points at the working directory, not at
        ./experiments/ like the layer checkpoints. Confirm where these files live.

    Returns
    -------
    list of float
        ``eval_model(...).item()`` for each entry of ``samples``, in order. This
        matches the return type of the layer-count evaluator.
    """
    test_graphs = flex_graph(gen_rectangular_channel_matrix(k, k, n_samples, seed=seed))
    # `layers` is not passed, so eval_model uses its own default architecture.
    return [eval_model(path=path_template.format(i), data=test_graphs).item()
            for i in samples]

In [8]:
# Long-format frame: one row per layer count, with its evaluation score.
source = pd.DataFrame({'Layers': layers, 'Perf': perf_layers}).melt(
    'Layers', var_name='category', value_name='Performance'
)

# Plot each model's score as a fraction of the best score.
# The scores printed during evaluation above are negative (-3.45 ... -4.02).
# Dividing by max(Performance) would pick the *least* negative score, which
# pushes every ratio to >= 1, outside the [0.85, 1] y-domain. Normalizing by
# magnitude gives ratios in (0, 1] for negative scores and is identical to the
# old behavior for positive scores.
# NOTE(review): assumes a larger |Performance| means a better model. Confirm
# against eval_model.
layers_chart = alt.Chart(source).transform_calculate(
    magnitude="abs(datum.Performance)"
).transform_joinaggregate(
    max='max(magnitude)',
).transform_calculate(
    percent="datum.magnitude / datum.max"
).transform_filter(
    0.85 <= alt.datum.percent
).mark_area(
    color="lightblue",
    line=True,
    point=True,
    interpolate='monotone'
).encode(
    alt.X('Layers:Q', axis=alt.Axis(tickCount=5)),
    alt.Y('percent:Q', title='Performance',
          scale=alt.Scale(zero=False, nice=True, type='linear',
                          domain=[0.85, 1], padding=0),
          axis=alt.Axis(tickCount=6, format='.0%')),
    # The encoding colour takes precedence over the mark's "lightblue".
    color=alt.value('#D95F02')
).configure_point(
    size=50
)
layers_chart

Displaying chart at http://localhost:58486/


## Plot in SVG format for viewing on GitHub

![svg plot](https://user-images.githubusercontent.com/17931435/212996585-4f5da4bf-a744-467e-b4b4-21cc423f4f75.svg)