In [12]:
import altair as alt
import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

from loss import create_loss_fn
from main import create_and_train_model
from model import FlexNet
from utils.data import flex_graph, gen_rectangular_channel_matrix

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)
# Select Altair's 'html' renderer for displaying charts in this notebook.
alt.renderers.enable('html')

RendererRegistry.enable('html')

In [3]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any cell visible here; eval_model
# maps checkpoints onto the CPU explicitly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def train_multiple_models(batch_size, lr):
    """Train and save one model for every (batch size, learning rate) pair.

    A single dataset is generated up front (seed 11) and reused for every
    run; each run gets its own shuffling ``DataLoader`` and writes its
    checkpoint to ``./experiments/`` under a name that encodes both
    hyperparameters.

    Args:
        batch_size: Iterable of batch sizes to try.
        lr: Iterable of learning rates to try.
    """
    n = 10000
    k = 32
    graphs = flex_graph(gen_rectangular_channel_matrix(k, k, n, seed=11))
    for bs in batch_size:
        for step in lr:
            loader = DataLoader(graphs, batch_size=bs, shuffle=True)
            create_and_train_model(
                n=n,
                k=k,
                data=loader,
                batch_size=bs,
                noise_var=1.,
                path=f"./experiments/flexible_experiment_batch_size_{bs}_lr_{step}.pth",
                lr=step,
            )


def evaluate_models(batch_size, lr):
    """Score every saved checkpoint in the hyperparameter grid.

    A fresh dataset is generated with seed 899 and each checkpoint written by
    ``train_multiple_models`` is scored on it via ``eval_model``.

    Args:
        batch_size: Sequence of batch sizes that were trained.
        lr: Sequence of learning rates that were trained.

    Returns:
        A ``(len(batch_size), len(lr))`` NumPy array; entry ``[i, j]`` is the
        score of the model trained with ``batch_size[i]`` and ``lr[j]``.
    """
    k = 32
    eval_set = flex_graph(gen_rectangular_channel_matrix(k, k, 10000, seed=899))
    scores = np.empty((len(batch_size), len(lr)))
    for row, bs in enumerate(batch_size):
        for col, step in enumerate(lr):
            checkpoint = f"./experiments/flexible_experiment_batch_size_{bs}_lr_{step}.pth"
            scores[row, col] = eval_model(path=checkpoint, data=eval_set).item()
    return scores

def eval_model(path, data, k=32, aggr='add'):
    """Load a saved FlexNet checkpoint and score it on the first batch of ``data``.

    Both model outputs are thresholded at 0.5 to hard 0/1 values before being
    passed to the loss function, and the loss value is negated before it is
    returned.

    Args:
        path: Location of the state dict to load.
        data: Dataset accepted by ``DataLoader``; only its first ``n`` (5000)
            samples, in stored order, are evaluated.
        k: First argument forwarded to ``create_loss_fn``.
        aggr: Argument forwarded to the ``FlexNet`` constructor.

    Returns:
        The negated result of the loss function (callers use ``.item()`` on it).
    """
    n = 5000  # size of the single evaluation batch
    model = FlexNet(aggr)
    # NOTE(review): strict=False silently ignores missing/unexpected keys, so a
    # mismatched checkpoint would be scored with randomly initialised weights.
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
    model.eval()

    # Use the builtin next(): Python 3 iterators expose __next__, and recent
    # PyTorch DataLoader iterators no longer provide a .next() alias.
    new_data = next(iter(DataLoader(data, batch_size=n, shuffle=False)))
    # Inference only, so skip autograd bookkeeping for the whole batch.
    with torch.no_grad():
        p, t = model(new_data)
        t = torch.where(t >= 0.5, 1., 0.)
        p = torch.where(p >= 0.5, 1., 0.)
        # Second argument is presumably the noise variance (training passes
        # noise_var=1.) - TODO confirm against create_loss_fn.
        rate = create_loss_fn(k, 1.)
        sum_r = rate((p, t), new_data.y)
    return -sum_r

In [5]:
# Hyperparameter grid: every (batch size, learning rate) pair is one experiment.
batch_size = [32, 64, 128, 256, 512, 1024, 2048]
lr = [0.001, 0.002, 0.004, 0.006, 0.008, 0.01]
# Training is left disabled here; uncomment to (re)create the checkpoints under
# ./experiments/ that evaluate_models loads.
# train_multiple_models(batch_size, lr)

In [6]:
# Score every checkpoint in the grid; rows follow batch_size, columns follow lr.
perf = evaluate_models(batch_size, lr)
perf

array([[4.08904123, 4.09442091, 4.04036045, 4.03999376, 4.04334831,
        3.96661282],
       [4.0378933 , 4.08158112, 4.05126429, 4.03951931, 4.03462744,
        4.03098249],
       [4.01309204, 4.03474283, 4.05080509, 4.07147503, 4.00365257,
        4.04746866],
       [3.96966839, 4.06361866, 4.06947088, 4.04727745, 3.87133098,
        4.07352304],
       [4.00527   , 4.02116108, 4.03391075, 3.98576093, 4.02218008,
        3.815238  ],
       [3.96181989, 3.9712503 , 3.99372697, 4.03279305, 4.03354359,
        3.908288  ],
       [3.90001249, 3.94372702, 3.99829698, 3.74330854, 4.02735758,
        4.02474499]])

In [7]:
# Coordinate grids with the same (len(batch_size), len(lr)) layout as perf:
# x holds the learning rate and y the batch size of each cell.
x, y = np.meshgrid(lr, batch_size)

In [13]:
# Long-form table for plotting: one row per (learning rate, batch size) cell,
# with z holding that cell's score from perf.
source = pd.DataFrame({'x': x.ravel(),
                     'y': y.ravel(),
                     'z': perf.ravel()})

# Shared base chart: ordinal axes for both hyperparameters, plus a derived
# `percent` field that both layers below use.
base = alt.Chart(source).encode(
    x=alt.X('x:O', title='Learning Rate', axis=alt.Axis(labelAngle=0)),
    y=alt.Y('y:O', title='Batch Size')
).transform_joinaggregate(
        max='max(z)',  # best score over the whole grid, attached to every row
).transform_calculate(
        percent="datum.z / datum.max"  # score relative to the best configuration
)

# Text layer: print each cell's percentage; labels turn white above 97%
# (presumably for contrast on the darkest cells - confirm against the colour scheme).
text = base.mark_text().encode(
    text=alt.Text('percent:Q', format='.1%'),
    color=alt.condition(
        alt.datum.percent > 0.97,
        alt.value('white'),
        alt.value('black')
    )
)

# Heatmap layer: fill colour encodes the relative score.
color_plot = base.mark_rect().encode(
    color=alt.Color('percent:Q',
                    title='Performance',
                    legend=alt.Legend(format='.0%')
                   )
).properties(
    width=300,
    height=200
)

# Overlay the labels on the heatmap and display the layered chart.
color_plot + text