In [1]:
import os

import altair as alt
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

from loss import create_loss_fn
from model import FlexNet
from utils.data import gen_rectangular_channel_matrix, flex_graph

# Seed torch's global RNG for repeatable runs (presumably covers FlexNet weight init and
# DataLoader shuffling — confirm). Channel data is seeded separately via explicit `seed=` args.
torch.manual_seed(0)

<torch._C.Generator at 0x7fca40bb4330>

In [2]:
# Preferred compute device for this notebook (train() resolves its own device the same way).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def generate_data(n, k, batch_size, seed=13, shuffle=True):
    """Build a DataLoader over ``n`` randomly generated ``k``-node channel graphs.

    Args:
        n: Number of samples (presumably channel realisations) to generate.
        k: Node count; passed as both dimensions of the channel matrix.
        batch_size: Mini-batch size of the returned loader.
        seed: Seed forwarded to ``gen_rectangular_channel_matrix``. The default
            (13) is the value that used to be hard-coded.
        shuffle: Whether the loader reshuffles the samples every epoch.

    Returns:
        A ``torch_geometric`` DataLoader over the graphs built by ``flex_graph``.
    """
    h_batch = gen_rectangular_channel_matrix(k, k, n, seed=seed)
    return DataLoader(flex_graph(h_batch), batch_size=batch_size, shuffle=shuffle)

def train(model, optimizer, loss_fn, dataset, k, path, epochs=10, log_every=5):
    """Train ``model`` on ``dataset`` and checkpoint it to ``path``.

    If ``path`` already holds a checkpoint, it is loaded first (non-strictly). Repeated
    calls with the same path therefore continue from the last saved weights. A missing
    file starts from the model's current weights instead of raising.

    Args:
        model: Module to train; moved to CUDA when available, otherwise CPU.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Callable ``loss_fn(model_output, batch.y)`` returning a scalar tensor.
        dataset: Iterable of batches exposing ``.to(device)`` and ``.y``
            (e.g. the loader returned by ``generate_data``).
        k: Node count of this dataset; not used here, kept for call-site compatibility.
        path: Checkpoint file read at the start (if present) and written at the end.
        epochs: Number of passes over ``dataset``.
        log_every: Print the mean loss over each window of ``log_every`` mini-batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # Resume only when a checkpoint exists; a fresh path used to crash in torch.load.
    if os.path.isfile(path):
        model.load_state_dict(torch.load(path, map_location=device), strict=False)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, batch in enumerate(dataset):
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = loss_fn(out, batch.y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # Average over the window actually accumulated. The old code summed 5
                # batches but divided by 10, halving the reported loss.
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
                running_loss = 0.0
    torch.save(model.state_dict(), path)

def eval_model(path, data, k=32, aggr='add', batch_size=5000, threshold=0.5):
    """Load a FlexNet checkpoint on the CPU and score its first batch of ``data``.

    Both model outputs are binarised at ``threshold`` before being passed to the
    metric built by ``create_loss_fn(k, 1.)``. The metric value is printed, and its
    negation is returned.

    Args:
        path: Checkpoint file (e.g. one written by ``train``).
        data: Sequence of graph samples (e.g. the output of ``flex_graph``).
        k: Node count forwarded to ``create_loss_fn``.
        aggr: Aggregation scheme passed to ``FlexNet``.
        batch_size: Size of the single batch evaluated. Samples beyond the first
            ``batch_size`` are ignored.
        threshold: Cut-off used to binarise the outputs ``p`` and ``t``.

    Returns:
        ``-sum_r`` as a tensor, where ``sum_r`` is the metric value printed.
    """
    model = FlexNet(aggr)
    # NOTE(review): strict=False silently ignores missing/unexpected keys, so a
    # mismatched checkpoint evaluates partly random weights without any error.
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
    model.eval()

    # Use the builtin next(): DataLoader iterators do not reliably provide the
    # Python-2 style `.next()` method.
    batch = next(iter(DataLoader(data, batch_size=batch_size, shuffle=False)))
    rate = create_loss_fn(k, 1.)
    # Inference only: avoid building an autograd graph for the whole batch.
    with torch.no_grad():
        p, t = model(batch)
        t = torch.where(t >= threshold, 1., 0.)
        p = torch.where(p >= threshold, 1., 0.)
        sum_r = rate((p, t), batch.y)
    print(sum_r.item())
    return -sum_r

def evaluate_models(nodes, path="./experiments/flexible_experiment_generalization.pth",
                    n_samples=10000, seed=899):
    """Evaluate one shared (generalisation) checkpoint on several network sizes.

    Args:
        nodes: Iterable of node counts; a fresh test set is generated for each.
        path: Checkpoint evaluated for every node count.
        n_samples: Samples generated per node count (``eval_model`` scores only
            its first batch).
        seed: Seed for the test-set channel matrices.

    Returns:
        List with one ``eval_model`` result per entry of ``nodes``, in order.
    """
    perf = []
    for num_nodes in nodes:
        channels = gen_rectangular_channel_matrix(num_nodes, num_nodes, n_samples, seed=seed)
        graphs = flex_graph(channels)
        perf.append(eval_model(path=path, data=graphs, k=num_nodes))
    return perf

def evaluate_perf_models(nodes):
    """Evaluate the per-size checkpoints, one dedicated model per node count.

    For each node count, a test set (10000 samples, seed 899) is generated and scored
    with ``./experiments/flexible_experiment_{count}_nodes.pth``.

    Returns:
        List with one ``eval_model`` result per entry of ``nodes``, in order.
    """
    def _score(num_nodes):
        checkpoint = f"./experiments/flexible_experiment_{num_nodes}_nodes.pth"
        graphs = flex_graph(gen_rectangular_channel_matrix(num_nodes, num_nodes, 10000, seed=899))
        return eval_model(path=checkpoint, data=graphs, aggr='add', k=num_nodes)

    return [_score(num_nodes) for num_nodes in nodes]

def create_and_train_model(n, batch_size, noise_var, path, lr=0.002, aggr='add',
                           node_counts=(4, 8, 12, 16, 20, 24, 28, 32), rounds=30):
    """Train a single FlexNet on several network sizes by cycling through them.

    One dataset is generated per node count. Then, for ``rounds`` rounds, ``train``
    is called on each dataset in turn with the same model and optimizer; each call
    checkpoints to ``path``.

    Args:
        n: Samples generated per node count (forwarded to ``generate_data``).
        batch_size: Mini-batch size of every training loader.
        noise_var: Forwarded to ``create_loss_fn``. It was previously accepted
            but ignored in favour of a hard-coded ``1.``.
        path: Checkpoint file shared by all ``train`` calls.
        lr: Adam learning rate.
        aggr: Aggregation scheme passed to ``FlexNet``.
        node_counts: Network sizes to train on, in cycling order.
        rounds: Number of passes over ``node_counts``.

    Returns:
        The trained model.
    """
    model = FlexNet(aggr)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    datasets = [generate_data(n, k, batch_size) for k in node_counts]
    for _ in range(rounds):
        for k, dataset in zip(node_counts, datasets):
            loss_fn = create_loss_fn(k, noise_var)
            train(model, optimizer, loss_fn, dataset, k, path)
    return model

In [4]:
# Retraining (see create_and_train_model) is expensive, so it is opt-in.
# Set RETRAIN = True to regenerate the checkpoint evaluated below.
RETRAIN = False
path_ = "./experiments/flexible_experiment_generalization.pth"
if RETRAIN:
    create_and_train_model(n=10000, batch_size=64, noise_var=1., path=path_)

In [5]:
# Network sizes: every multiple of 4 from 4 to 32 nodes.
nodes = list(range(4, 33, 4))
perf = evaluate_models(nodes)

-1.724555253982544
-2.4403460025787354
-2.8788602352142334
-3.2358181476593018
-3.5193471908569336
-3.739914894104004
-3.943530797958374
-4.115488529205322


In [6]:
# Same network sizes as above, each scored with its own dedicated checkpoint.
nodes = [4 * step for step in range(1, 9)]
perf_single = evaluate_perf_models(nodes)

-1.7907705307006836
-2.4896557331085205
-2.9190170764923096
-3.2618682384490967
-3.522634744644165
-3.7365972995758057
-3.9261507987976074
-4.089602947235107


In [7]:
# Convert the single-element tensors returned by eval_model into plain floats for plotting.
# float() also accepts values that are already floats, so re-running this cell is safe
# (the former .item() call raised AttributeError on the second run).
perf = [float(x) for x in perf]
perf_single = [float(x) for x in perf_single]

In [8]:
# Long-format frame: one row per (node count, model category) pair.
source = pd.DataFrame({
    'Nodes': nodes,
    'Multiple models': perf_single,
    'Single model': perf,
}).melt('Nodes', var_name='category', value_name='Sum Rate')

# NOTE(review): achievable-rate metrics are often reported in bit/s/Hz — confirm the axis unit.
encoding = dict(
    x=alt.X('Nodes:Q', scale=alt.Scale(zero=False)),
    y=alt.Y('Sum Rate:Q', title='Sum Rate (bit/s)', scale=alt.Scale(zero=False)),
    color=alt.Color('category:N', scale=alt.Scale(range=["#e7298a", "#66a61e"]),
                    legend=alt.Legend(orient='bottom-right', title=None)),
    shape=alt.Shape('category:N', scale=alt.Scale(range=['square', 'circle'])),
)

# Top-level configuration is applied last, after encoding and scale resolution.
chart = (
    alt.Chart(source)
    .mark_line(point=True, interpolate='monotone')
    .encode(**encoding)
    .resolve_scale(color='independent', shape='independent')
    .configure_point(size=60)
)

chart

### Static SVG export of the plot (for viewing on GitHub)

![Sum rate vs. number of nodes for the single model and the per-size models](https://user-images.githubusercontent.com/17931435/212999054-8edc2b3a-29b3-4039-9262-97dae02d9faa.svg)