In [1]:
import argparse

import dgl
import networkx as nx
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from graphregression import *
from models import GraphRegression

import pdb

In [2]:
class Args:
    """Plain attribute container for the experiment's settings.

    The values are class-level attributes, so they can be read either from the
    class itself or from an instance created with ``Args()``.
    """

    # How many examples to generate: the training set size and the size of
    # the held-out evaluation set.
    num_samples = 1000
    num_test_samples = 10

    # Optimisation schedule: graphs per mini-batch and passes over the data.
    batch_size = 64
    epochs = 10

In [3]:
# Settings object read by the data, training and evaluation cells below.
args = Args()

In [4]:
# Prior handed to the training Dataset below.
prior = get_prior()

# Model built with its default constructor arguments.
model = GraphRegression()

# Training set of args.num_samples examples, served in shuffled mini-batches.
# drop_last=False keeps the final, smaller batch instead of discarding it.
dataset = Dataset(args.num_samples, prior)
dataloader = dgl.dataloading.GraphDataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=False,
)

In [5]:
# Adam over all model parameters; no hyperparameters are passed, so the
# optimizer's defaults apply.
opt = torch.optim.Adam(model.parameters())
# Mean-squared-error objective used by the training loop.
loss_fn = nn.MSELoss()

In [6]:
# Training mode for the whole optimisation run.
model.train()
for ep in range(args.epochs):
    for batched_graph, labels in dataloader:
        # Forward pass on one mini-batch of graphs.
        feats = get_normalized_features(batched_graph)
        y_hat = model(batched_graph, feats)

        # Labels get a trailing axis before being compared with the predictions.
        loss = loss_fn(y_hat, labels.unsqueeze(-1))

        # Standard update: clear old gradients, backpropagate, step.
        opt.zero_grad()
        loss.backward()
        opt.step()

In [None]:
# NOTE(review): on a fresh kernel this call comes before the notebook's own
# `def eval_excess_risk` (defined in the last cell), so it only works if the
# star import from graphregression also provides that name — confirm, or move
# the definition above this cell.
eval_excess_risk(model, args)

In [8]:
# Held-out evaluation data: args.num_test_samples graphs with their labels.
# NOTE(review): this rebinds `prior`; if get_prior() is stochastic, the test
# graphs are drawn under a different prior than the training set — confirm
# that this is intended.
prior = get_prior()
graphs, labels = synthetic_dataset(args.num_test_samples, prior)

# Trailing axis added so the labels become a column, as displayed further down.
labels = labels.unsqueeze(-1)

# Evaluation mode. The call returns the module, so its repr is the cell output.
model.eval()

GraphRegression(
  (conv): Sequential(
    (0): GraphConv(in=1, out=8, normalization=both, activation=<function relu at 0x7fc7ea720dc0>)
    (1): GraphConv(in=8, out=8, normalization=both, activation=None)
  )
  (readout): Linear(in_features=8, out_features=1, bias=True)
)

In [12]:
# Reference predictions of h_star on the held-out graphs, one value per graph.
# The batch is built from `graphs` right here instead of reading a leftover
# `batched_graph`: in a top-to-bottom run that name is still bound to the last
# training mini-batch at this point, not to the test set.
h_star(dgl.batch(graphs))

tensor([0.6886, 0.3419, 0.2263, 0.3294, 0.3224, 0.5605, 0.1742, 0.8148, 0.9806,
        0.4388])

In [11]:
# Held-out targets (already a column after the unsqueeze), shown for
# side-by-side comparison with the h_star values.
labels

tensor([[0.6310],
        [0.3240],
        [0.2200],
        [0.3210],
        [0.3130],
        [0.5610],
        [0.1580],
        [0.8050],
        [0.9810],
        [0.4400]])

In [15]:
# Score the trained model and the reference predictor on the held-out graphs.
batched_graph = dgl.batch(graphs)
feats = get_normalized_features(batched_graph)

# Nothing is optimised here, so autograd tracking is switched off; the losses
# come out as plain tensors without a grad_fn.
with torch.no_grad():
    y_hat = model(batched_graph, feats)
    # Test MSE of the trained model.
    L_hat = F.mse_loss(y_hat, labels)
    # Test MSE of h_star. Its per-graph 1-D output gets a trailing axis so it
    # lines up element-wise with the column-shaped `labels`.
    Lstar = F.mse_loss(h_star(batched_graph).unsqueeze(-1), labels)

In [16]:
# Test MSE of the trained model on the held-out batch.
L_hat

tensor(0.0712, grad_fn=<MseLossBackward0>)

In [17]:
# Test MSE of the reference predictor h_star on the same held-out batch.
Lstar

tensor(0.0004)

In [18]:
# Calls the model's underscore-prefixed `_get_norm` helper. Which norm it
# computes is defined outside this notebook — TODO confirm before interpreting.
model._get_norm()

tensor(5.0737)

In [14]:
# Presumably an empirical Rademacher-complexity estimate for the model (going
# by the helper's name); the helper is defined outside this notebook — TODO
# confirm what exactly it measures.
estimate_rademacher(model, args)

tensor(0.4276)

In [None]:
def eval_excess_risk(model, args):
    """Estimate the model's excess risk over ``h_star`` on fresh test data.

    A new prior is obtained from ``get_prior`` and ``args.num_test_samples``
    graphs are generated with ``synthetic_dataset``. The model and the
    reference predictor ``h_star`` are then scored with mean-squared error on
    that same batch.

    Args:
        model: Called as ``model(batched_graph, feats)``. It is switched to
            evaluation mode as a side effect and left in that mode.
        args: Object exposing ``num_test_samples``.

    Returns:
        A scalar tensor holding ``L_hat - Lstar``, i.e. the model's test MSE
        minus the test MSE of ``h_star``. Positive values mean the model is
        worse than the reference predictor.
    """
    prior = get_prior()
    graphs, labels = synthetic_dataset(args.num_test_samples, prior)
    # Column-shaped targets, one row per graph.
    labels = labels.unsqueeze(-1)
    model.train(False)

    batched_graph = dgl.batch(graphs)
    feats = get_normalized_features(batched_graph)

    # Pure evaluation: no autograd graph is needed for either loss.
    with torch.no_grad():
        y_hat = model(batched_graph, feats)
        L_hat = F.mse_loss(y_hat, labels)
        # h_star yields one value per graph as a 1-D tensor. Without the
        # trailing axis, mse_loss would broadcast it against the column-shaped
        # labels and average over every pair instead of matching elements.
        Lstar = F.mse_loss(h_star(batched_graph).unsqueeze(-1), labels)

    # Excess risk is the model's risk minus the reference risk.
    return L_hat - Lstar