In [1]:
import os

import altair as alt
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

from main_layered import create_and_train_model, eval_model
from utils.data import flex_graph, gen_rectangular_channel_matrix

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def train_multiple_models(samples, k=32, out_dir="./experiments"):
    """Train and save one model per training-set size.

    For every size ``i`` in ``samples``:
    - a dataset of ``i`` samples is generated with
      ``gen_rectangular_channel_matrix`` (fixed seed 11);
    - it is converted with ``flex_graph`` and wrapped in a shuffling
      ``DataLoader``;
    - it is passed to ``create_and_train_model``, which is given the
      checkpoint path.

    Args:
        samples: Iterable of training-set sizes, e.g. ``[2000, 5000]``.
        k: Used as both size arguments of ``gen_rectangular_channel_matrix``
            and as ``k`` for ``create_and_train_model``. Defaults to 32.
        out_dir: Directory for the
            ``flexible_experiment_{i}_training_samples.pth`` checkpoints.
            It is created if missing. Defaults to ``./experiments``, the
            directory ``evaluate_models`` reads checkpoints from.
    """
    batch_size = 64
    os.makedirs(out_dir, exist_ok=True)
    for i in samples:
        path_ = os.path.join(out_dir, f"flexible_experiment_{i}_training_samples.pth")
        # Generate exactly `i` samples for this run, not the whole `samples` list.
        data = gen_rectangular_channel_matrix(k, k, i, seed=11)
        data = flex_graph(data)
        data = DataLoader(data, batch_size=batch_size, shuffle=True)
        # NOTE(review): `n=10000` is passed regardless of `i`. Its meaning is
        # defined in main_layered; confirm it is not meant to track `i`.
        create_and_train_model(n=10000, k=k, batch_size=batch_size, noise_var=1.,
                               path=path_, data=data, lr=0.002, layers='Samp')

def evaluate_models(samples):
    """Score each saved per-sample-count checkpoint on one shared test set.

    The test set is built once, with ``gen_rectangular_channel_matrix(32, 32,
    10000, seed=899)`` followed by ``flex_graph``. Every checkpoint is
    therefore evaluated on identical data.

    Args:
        samples: Training-set sizes whose checkpoints exist at
            ``./experiments/flexible_experiment_{i}_training_samples.pth``.

    Returns:
        A list with one ``eval_model`` result per entry of ``samples``, in the
        same order.
    """
    k_size = 32
    test_graph = flex_graph(gen_rectangular_channel_matrix(k_size, k_size, 10000, seed=899))
    return [
        eval_model(path=f"./experiments/flexible_experiment_{count}_training_samples.pth",
                   data=test_graph, layers='Samp')
        for count in samples
    ]

In [4]:
# Training-set sizes to compare; one model checkpoint per size.
samples = [2000, 5000, 10000, 15000, 20000]

# Training is expensive and the checkpoints are reused. Set to True to
# (re)train one model per entry of `samples`.
RETRAIN = False
if RETRAIN:
    train_multiple_models(samples)

In [5]:
# Evaluate every trained checkpoint on the shared test set.
perf_samples = evaluate_models(samples=samples)

-3.957087993621826
-3.9976425170898438
-4.060915470123291
-4.0777435302734375
-4.085641860961914


In [6]:
# Convert the scores to plain Python floats for plotting. `float()` accepts
# both one-element tensors and floats, so re-running this cell is safe.
# A second `.item()` pass over floats would raise AttributeError.
perf_samples = [float(x) for x in perf_samples]

In [7]:
# Long-form frame for Altair: one row per model, all in a single
# 'Performance' category. It is built directly instead of via `melt`:
# melting with a `value_name` equal to an existing column is deprecated
# (see the FutureWarning this produced) and rejected by newer pandas.
source = pd.DataFrame({'Sample Count': samples,
                       'category': 'Performance',
                       'Performance': perf_samples
                       })

# Plot each run relative to the largest Performance value.
# NOTE(review): the printed eval scores are negative, so `max` picks the
# least-negative run and other runs plot above 100%. Confirm this is the
# intended baseline.
samples_chart = alt.Chart(source).transform_joinaggregate(
    max='max(Performance)',
).transform_calculate(
    percent="datum.Performance / datum.max"
).encode(
    alt.X('Sample Count:Q', scale=alt.Scale(zero=False, nice=True), axis=alt.Axis(format=',.2r', labelAngle=-45)),
    alt.Y('percent:Q', title='Performance',
          scale=alt.Scale(zero=False, nice=True, type='linear',
#                           domain=[0.96, 1.001]
                         ),
          axis=alt.Axis(tickCount=6, format='.1%')),
    alt.Color('category:N', legend=None,
              scale=alt.Scale(scheme="dark2"))
).mark_area(
    color="lightblue",
    line=True,
    point=True,
    interpolate='catmull-rom'
).configure_point(
    size=50
)

samples_chart

  source = source.melt('Sample Count', var_name='category', value_name='Performance')


### Plot in SVG format (for viewing on GitHub, which does not render interactive Altair charts)

![Relative model performance vs. training sample count](https://user-images.githubusercontent.com/17931435/213000831-f4655edc-3749-491a-9ddf-471a69a1e072.svg)