In [None]:
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.datasets.molecule_net import MoleculeNet
import random 
import numpy as np
from tqdm import tqdm
import copy
import itertools
import json
import random

## Datasets

In [None]:
def gen_cycle_pairs(sizes):
  """Build (single cycle, two disjoint cycles) graph pairs with equal node counts.

  For every n in `sizes` and every split point k with 3 <= k <= n - 3, the pair is
  (C_n, C_k disjoint-union C_{n-k}). Both orientations of a split, (k, n-k) and
  (n-k, k), are emitted, so asymmetric splits appear twice.

  Args:
    sizes: iterable of total node counts n.

  Returns:
    List of (networkx.Graph, networkx.Graph) tuples, ordered by n, then by k.
  """
  return [
      (nx.cycle_graph(total),
       nx.disjoint_union(nx.cycle_graph(split), nx.cycle_graph(total - split)))
      for total in sizes
      for split in range(3, total - 2)
  ]


def to_pyg(g, label, n_features=50):
  """Convert a networkx graph into a labelled PyTorch Geometric `Data` object.

  Node features are all-zero placeholders. `n_features` sets their width; the
  default of 50 matches the previously hard-coded value.

  Args:
    g: networkx graph.
    label: integer class label, stored as `data.y` with shape [1].
    n_features: number of (zero) feature columns in `data.x`.

  Returns:
    Data with x of shape (g.number_of_nodes(), n_features) and y of shape [1].
  """
  data = torch_geometric.utils.from_networkx(g)
  data.x = torch.zeros((g.number_of_nodes(), n_features))
  data.y = torch.tensor([label])
  return data


def cycles_dataset(sizes):
  """Labelled PyG graphs for the single-cycle vs. two-cycles classification task.

  For each pair from `gen_cycle_pairs(sizes)`, emits the single cycle C_n with
  label 1 followed by its two-cycle counterpart with label 0. The result
  therefore alternates positive and negative examples.

  Args:
    sizes: iterable of total node counts n.

  Returns:
    Flat list of `Data` objects.
  """
  graph_pairs = gen_cycle_pairs(sizes)
  # A flat comprehension instead of sum(list_of_lists, []), which re-copies the
  # accumulated list on every addition (quadratic in the number of pairs).
  return [to_pyg(g, label)
          for g1, g2 in graph_pairs
          for g, label in ((g1, 1), (g2, 0))]


def cycles_dict(n1, n2):
  """Map each cycle length in n1..n2 (inclusive) to a positive-labelled PyG C_n graph."""
  by_length = {}
  for length in range(n1, n2 + 1):
    by_length[length] = to_pyg(nx.cycle_graph(length), 1)
  return by_length

In [None]:
# Train on cycle lengths 6, 7, 9 and 10 and hold out length 8, to test
# generalisation to an unseen size. batch_size=1 because EDU_QGC.forward
# treats whatever it receives as a single graph.
cycles_train_loader, cycles_test_loader = (
    torch_geometric.data.DataLoader(cycles_dataset(lengths), batch_size=1, shuffle=True)
    for lengths in ([6, 7, 9, 10], [8])
)
# Positive-labelled single cycles C_3 ... C_13, keyed by length.
cycles = cycles_dict(3, 13)



In [None]:
def prepare_tinymolhiv(root='./molhiv', max_nodes=10, neg_ratio=3, n_splits=5, seed=0):
  """Build class-rebalanced cross-validation folds of small MoleculeNet HIV graphs.

  Keeps molecules with at most `max_nodes` nodes and keeps at most `neg_ratio`
  negatives per positive. Positives and negatives are dealt round-robin into
  `n_splits` folds. Within each fold, positives are repeated `neg_ratio` times
  (the same Data objects, not copies) so the classes are roughly balanced.
  Each fold is then shuffled. Labels are squeezed to shape [1] and node
  features are cast to float.

  With the default arguments, the folds are identical to the original
  seed-0 version. The difference is that a private RNG is used, so the
  caller's global `random` state is no longer reset.

  Args:
    root: MoleculeNet root directory.
    max_nodes: node-count cut-off passed as `pre_filter`. NOTE(review): PyG
      applies pre_filter only when it first processes `root`; confirm the
      cache was built with the same cut-off.
    neg_ratio: negatives kept per positive, and the positive oversampling factor.
    n_splits: number of folds.
    seed: seed for the private shuffling RNG.

  Returns:
    List of `n_splits` lists of Data objects.
  """
  tiny_molhiv = MoleculeNet(root=root, name='HIV',
                            pre_filter=lambda g: len(g.x) <= max_nodes)
  pos_samples = [g for g in tiny_molhiv if g.y == 1]
  neg_samples = [g for g in tiny_molhiv if g.y == 0]
  # Private RNG: the same shuffle sequence as random.seed(seed) followed by
  # random.shuffle, without touching the global random module.
  rng = random.Random(seed)
  rng.shuffle(pos_samples)
  rng.shuffle(neg_samples)
  neg_samples = neg_samples[:neg_ratio * len(pos_samples)]

  for g in pos_samples + neg_samples:
    g.y = g.y.squeeze(dim=0)

  pos_splits = [pos_samples[i::n_splits] for i in range(n_splits)]
  neg_splits = [neg_samples[i::n_splits] for i in range(n_splits)]

  # Oversample positives so that each fold is roughly class-balanced.
  splits = [neg_ratio * pos_splits[i] + neg_splits[i] for i in range(n_splits)]
  for split in splits:
    rng.shuffle(split)
    for g in split:
      g.x = g.x.float()
  return splits

def prepare_loaders(splits, eval_split=0):
  """Build (train_loader, eval_loader) for one cross-validation fold.

  Args:
    splits: list of folds (lists of Data objects), e.g. from prepare_tinymolhiv().
    eval_split: index of the held-out fold. Negative indices count from the end.

  Returns:
    (train_loader, eval_loader). The train loader shuffles the concatenation of
    all other folds; the eval loader iterates the held-out fold in a fixed order.

  Raises:
    IndexError: if eval_split is out of range for `splits`.
  """
  # Normalise negative indices so that `i != eval_split` really excludes the
  # held-out fold. Previously eval_split=-1 leaked the eval fold into training.
  eval_split = range(len(splits))[eval_split]
  train_data = [g for i, split in enumerate(splits) if i != eval_split for g in split]
  eval_data = splits[eval_split]
  # batch_size=1: EDU_QGC.forward treats each batch as a single graph.
  train_loader = torch_geometric.data.DataLoader(train_data, batch_size=1, shuffle=True)
  # Evaluation metrics are order-independent, so there is no need to shuffle.
  eval_loader = torch_geometric.data.DataLoader(eval_data, batch_size=1, shuffle=False)
  return train_loader, eval_loader

In [None]:
# Build the class-rebalanced Tiny-MolHIV folds once (molecules with at most 10 nodes,
# data stored under ./molhiv). Pass them to prepare_loaders() to pick the held-out fold.
hiv_splits = prepare_tinymolhiv()

## Models

### Helper functions

In [None]:
def tensorpow_squaremat(t, n):
  """Return the n-fold Kronecker power t ⊗ t ⊗ ... ⊗ t of a square matrix.

  Args:
    t: square 2-D tensor of shape (k, k).
    n: number of factors. n == 0 gives the 1x1 identity (the empty product).

  Returns:
    Tensor of shape (k**n, k**n) with t's dtype.

  Raises:
    ValueError: if n is negative.
  """
  if n < 0:
    raise ValueError("n must be non-negative, got {}".format(n))
  if n == 0:
    # Previously this returned t itself, which has the wrong shape for 0 factors.
    return torch.ones((1, 1), dtype=t.dtype, device=t.device)
  res = t
  side = t.shape[0]
  dim = side
  for _ in range(n - 1):
    # The outer product holds res[a, b] * t[c, e] at index (a, b, c, e).
    # Reordering to (a, c, b, e) and flattening gives row a*side + c and
    # column b*side + e, which is exactly kron(res, t).
    res = torch.tensordot(res, t, dims=0).permute((0, 2, 1, 3))
    dim *= side
    res = res.reshape((dim, dim))
  return res

def tensormul_vecs(terms):
  """Return the Kronecker (tensor) product v_0 ⊗ v_1 ⊗ ... of a sequence of 1-D tensors.

  The vectors may have different lengths. The result has length prod(len(v_i)),
  and the index of the first vector is the most significant. A single-element
  sequence returns that element unchanged. An empty sequence raises IndexError.

  Args:
    terms: non-empty sequence of 1-D tensors.

  Returns:
    1-D tensor.
  """
  res = terms[0]
  for t in terms[1:]:
    # The (len(res), len(t)) outer product, flattened row-major, equals kron(res, t).
    # reshape(-1) also handles factors whose length differs from terms[0]; the
    # old fixed-stride reshape crashed on those.
    res = torch.tensordot(res, t, dims=0).reshape(-1)
  return res

def hermitian(t):
  """Return the Hermitian matrix t + t^H built from an arbitrary square matrix t.

  The result equals its own conjugate transpose, so exp(1j * result) is unitary.
  """
  adjoint = torch.conj(t).T
  return t + adjoint

### EQGC classes

In [None]:
class EDU_QGC(torch.nn.Module):
  """Statevector simulation of a layered quantum graph circuit.

  Each node carries `qb_per_node` qubits (local dimension d = 2**qb_per_node).
  Each of the `n_layers` layers does two things:
    1. It multiplies the state by a diagonal phase for every edge; the phase
       parameters are shared by all edges.
    2. It applies the same node unitary exp(1j * H), with H Hermitian, to
       every node.
  `forward` returns the squared magnitudes of the final amplitudes: a
  probability vector of length d ** n_nodes. In the basis-state index,
  node 0 is the most significant digit.

  Args:
    qb_per_node: qubits per node.
    n_layers: number of (edge layer, node layer) rounds.
    init_u3: if True, node j starts in cos(x_j0)|0> + exp(1j x_j1) sin(x_j0)|1>,
      built from its first two features (single-qubit nodes only). Otherwise
      the circuit starts from the uniform superposition.
  """
  def __init__(self, qb_per_node=1, n_layers = 1, init_u3=False):
    super(EDU_QGC, self).__init__()
    self.qb_per_node = qb_per_node
    self.node_state_dim = 2 ** qb_per_node
    self.n_layers = n_layers
    self.init_u3 = init_u3
    # One unconstrained complex (d, d) matrix per layer; hermitian() turns it into H.
    self.node_halfH = torch.nn.ParameterList([
      torch.nn.Parameter(
          torch.randn((self.node_state_dim, self.node_state_dim), dtype=torch.cfloat))
      for i in range(n_layers)
    ])
    # One real vector of d*d edge phases per layer, shared by all edges.
    self.edge_D = torch.nn.ParameterList([
      torch.nn.Parameter(
          torch.randn(self.node_state_dim ** 2))
      for i in range(n_layers)
    ])

  def init_state(self, xs):
    """Return the initial statevector for a graph whose node features are the rows of `xs`.

    Raises:
      ValueError: if init_u3 is used with more than one qubit per node. The
        per-node state built here is 2-dimensional, so it would not match the
        other layers.
    """
    if self.init_u3:
      if self.node_state_dim != 2:
        raise ValueError("init_u3 requires qb_per_node=1")
      node_states = [
        torch.tensor([torch.cos(feat[0]), torch.exp(1j*feat[1])*torch.sin(feat[0])])
        for feat in xs
      ]
      return tensormul_vecs(node_states)
    n_nodes = len(xs)
    full_dim = self.node_state_dim ** n_nodes
    return torch.ones(full_dim, dtype = torch.cfloat) / np.sqrt(full_dim)

  def prep_node_layer(self, node_halfH, n_nodes):
    """Return U ⊗ ... ⊗ U (n_nodes factors) for U = exp(1j * (A + A^H)), A = node_halfH."""
    node_H = hermitian(node_halfH)
    node_U = torch.matrix_exp(1j * node_H)
    return tensorpow_squaremat(node_U, n_nodes)

  def prep_edge_layer(self, edge_D, n_nodes, edge_index):
    """Return the diagonal of one edge layer as a vector of length d ** n_nodes.

    For every column (n1, n2) of `edge_index` (shape (2, n_edges)), basis state
    |x_0 ... x_{n-1}> picks up the factor exp(1j * edge_D[x_n1 * d + x_n2]),
    where d = node_state_dim. The factors of all edges are multiplied together.

    Raises:
      ValueError: on a self-loop (n1 == n2). The endpoint-axis permutation
        below cannot represent one; previously this failed inside permute().
    """
    d_node = self.node_state_dim
    full_dim = d_node ** n_nodes
    v = torch.ones(full_dim, dtype=torch.cfloat)
    if edge_index.numel() == 0:
      return v
    # The phase table is the same for every edge, so build it once outside the
    # loop. The tiled tensor has n_nodes axes of size d: its last two axes
    # index the two endpoints' states, and the other axes are broadcast copies.
    phases = torch.exp(1j*edge_D).reshape(d_node, d_node)
    tiled = phases.repeat([d_node]*(n_nodes-2)+[1,1])
    for n1, n2 in edge_index.T:
      n1, n2 = int(n1), int(n2)
      if n1 == n2:
        raise ValueError("self-loops are not supported (node {})".format(n1))
      # Move axis n_nodes-2 (n1's state) to position n1 and axis n_nodes-1
      # (n2's state) to position n2, keeping the other axes in order.
      if n2 > n1:
        perm = list(range(n1)) + [n_nodes-2] + list(range(n1, n2-1)) + [n_nodes-1] + list(range(n2-1, n_nodes-2))
      else:
        perm = list(range(n2)) + [n_nodes-1] + list(range(n2, n1-1)) + [n_nodes-2] + list(range(n1-1, n_nodes-2))
      # Out-of-place multiply: no in-place mutation of a tensor in the autograd graph.
      v = v * tiled.permute(perm).flatten()
    return v

  def forward(self, g):
    """Run the circuit on graph `g` (uses g.x and g.edge_index) and return basis-state probabilities."""
    state = self.init_state(g.x)
    n = len(g.x)
    for i in range(self.n_layers):
      state = state * self.prep_edge_layer(self.edge_D[i], n, g.edge_index)
      node_u = self.prep_node_layer(self.node_halfH[i], n)
      state = node_u @ state
    probs = torch.square(torch.abs(state))
    probs = probs / probs.sum() # normalize for floating point inaccuracies
    return probs

### Aggregators

In [None]:
class OneCountAggregator(torch.nn.Module):
  """Map a basis-state distribution to P(positive) using the number of 1 bits.

  P(positive) = sum_k P(#ones = k) * sigmoid(w[k]), with one learnable logit
  w[k] for each possible count k = 0..max_graph_size. Here #ones is the
  popcount of the basis-state index.

  Args:
    max_graph_size: the largest number of bits a basis-state index may have.
      Probability vectors longer than 2**max_graph_size are rejected.
    verbose: print the count distribution on every forward pass.
  """
  def __init__(self, max_graph_size, verbose = False):
    super(OneCountAggregator, self).__init__()
    self.max_n = max_graph_size
    self.w = torch.nn.Parameter(torch.zeros(max_graph_size+1))
    self.verbose = verbose

  def forward(self, probs):
    """Return sum_k P(#ones = k) * sigmoid(w[k]) for a 1-D probability vector `probs`.

    Raises:
      ValueError: if len(probs) > 2**max_graph_size. Such indices need more
        bits than w has counts for; previously the extra bits were silently
        ignored, which gave wrong counts.
    """
    n_states = len(probs)
    if n_states > 2 ** self.max_n:
      raise ValueError(
          "probs has {} entries but max_graph_size={} allows at most {}".format(
              n_states, self.max_n, 2 ** self.max_n))
    # Vectorised popcount of every basis-state index. This replaces a per-state
    # Python loop that also created one autograd node per state.
    idx = torch.arange(n_states)
    ones = torch.zeros(n_states, dtype=torch.long)
    for bit in range(self.max_n):
      ones += (idx >> bit) & 1
    count_probs = torch.zeros(self.max_n+1)
    count_probs = count_probs.index_add(0, ones, probs.to(count_probs.dtype))
    if self.verbose:
      print("count probs", count_probs)
    cond_probs = torch.sigmoid(self.w)
    total_prob = cond_probs @ count_probs # sum P(count = i) x P(pos | count = i)
    return total_prob


class OneRatioAggregator(torch.nn.Module):
  """Map a basis-state distribution to P(positive) using the fraction of 1 bits.

  With m = ceil(log2(len(probs))) bits, count k of 1 bits is mapped to the
  ratio k / m. A small MLP turns that ratio into P(positive | k), and the
  output is sum_k P(#ones = k) * MLP(k / m).

  Args:
    mlp_hidden_dim: hidden width of the 1 -> hidden -> 1 sigmoid MLP.
    verbose: print the count distribution on every forward pass.
  """
  def __init__(self, mlp_hidden_dim = 15, verbose = False):
    super(OneRatioAggregator, self).__init__()
    self.mlp = torch.nn.Sequential(
        torch.nn.Linear(1, mlp_hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(mlp_hidden_dim, 1),
        torch.nn.Sigmoid()
    )
    self.verbose = verbose

  def forward(self, probs):
    """Return the aggregated P(positive) for a non-empty 1-D probability vector `probs`.

    Raises:
      ValueError: if probs is empty (no basis states to aggregate).
    """
    n_states = len(probs)
    if n_states == 0:
      raise ValueError("probs must be non-empty")
    # ceil(log2(n_states)) computed exactly with integer arithmetic.
    max_ones = (n_states - 1).bit_length()
    # Vectorised popcount of every basis-state index. This replaces a per-state
    # Python loop that also created one autograd node per state.
    idx = torch.arange(n_states)
    ones = torch.zeros(n_states, dtype=torch.long)
    for bit in range(max_ones):
      ones += (idx >> bit) & 1
    count_probs = torch.zeros(max_ones+1)
    count_probs = count_probs.index_add(0, ones, probs.to(count_probs.dtype))
    if self.verbose:
      print("count probs", count_probs)
    ratios = torch.linspace(0.0, 1.0, max_ones+1)
    cond_probs = self.mlp(ratios.reshape(-1,1)).reshape(-1)
    total_prob = cond_probs @ count_probs # sum P(count = i) x P(pos | count = i)
    return total_prob


## Experiments

### Training and evaluation

In [None]:
def train(model, optimizer, lr_scheduler=None, epochs=200, loader=None, print_metrics=True):
  """Train `model` with per-sample binary cross-entropy and report running metrics.

  Each element `g` of `loader` is one graph whose label `g.y` is 0 or 1.
  `model(g)` must return a scalar probability of the positive class. An
  epoch's metrics are accumulated from outputs computed *before* each
  optimizer step, so they describe the model while it was being updated.

  Args:
    model: module mapping a graph to P(positive).
    optimizer: optimizer over the model's parameters; stepped once per sample.
    lr_scheduler: optional scheduler, stepped once per epoch.
    epochs: number of passes over `loader`.
    loader: iterable of graphs. Defaults to the global `cycles_train_loader`,
      looked up at call time rather than definition time, so re-running the
      data cell is picked up without redefining train().
    print_metrics: if True, print full metrics each epoch; otherwise print one
      loss line.

  Returns:
    The final epoch's metrics dict, or None when epochs == 0:
      'Loss'   mean BCE,
      'Acc:'   mean probability assigned to the true class,
      'rso50'  fraction classified correctly at threshold 0.5,
      'margin' smallest distance from 0.5 among correctly classified samples
               (0.5 if none were correct).

  Raises:
    ValueError: if the loader yields no samples.
  """
  if loader is None:
    loader = cycles_train_loader
  model.train()
  metrics = None
  for i in range(epochs):
    total = 0
    correct = 0.0
    rso50 = 0
    loss_sum = 0
    min_margin = 0.5
    for g in loader:
      optimizer.zero_grad()
      out = model(g).unsqueeze(0)
      loss = F.binary_cross_entropy(out, g.y.float())
      loss.backward()
      optimizer.step()

      total += 1
      loss_sum += loss.detach().numpy()
      p = out.detach().numpy()[0]
      if g.y == 1:
        correct += p
        if p >= 0.5:
          rso50 += 1
          min_margin = min(p-0.5, min_margin)
      else:
        correct += (1 - p)
        if p < 0.5:
          rso50 += 1
          min_margin = min(0.5-p, min_margin)
    if total == 0:
      raise ValueError("train() got an empty loader")
    metrics = {
        'Loss': loss_sum/total,
        # The key keeps its trailing colon because the experiment cells
        # json.dump these dicts, and results already written use this key.
        'Acc:': correct/total,
        'rso50': rso50/total,
        'margin': min_margin
    }
    if print_metrics:
      print("Epoch ", i)
      print("Loss: ", metrics['Loss'])
      print("Acc: ", metrics['Acc:'])
      print("RSo50: ", metrics['rso50'])
      print("MinMargin: ", min_margin)
    else:
      print("Epoch ", i, ", loss ", metrics['Loss'])
    if lr_scheduler is not None:
      lr_scheduler.step()
  return metrics


def evaluate(model, loader=None):
  """Evaluate `model` without gradient tracking and return classification metrics.

  Puts the model in eval mode. Each element `g` of `loader` is one graph
  whose label `g.y` is 0 or 1, and `model(g)` must return a scalar
  probability of the positive class.

  Args:
    model: module mapping a graph to P(positive).
    loader: iterable of graphs. Defaults to the global `cycles_test_loader`,
      looked up at call time rather than definition time, so re-running the
      data cell is picked up without redefining evaluate().

  Returns:
    Dict with 'Loss' (mean BCE), 'Acc:' (mean probability assigned to the true
    class), 'rso50' (fraction classified correctly at threshold 0.5) and
    'margin' (smallest distance from 0.5 among correctly classified samples;
    0.5 if none were correct).

  Raises:
    ValueError: if the loader yields no samples.
  """
  if loader is None:
    loader = cycles_test_loader
  model.eval()
  total = 0
  correct = 0.0
  rso50 = 0
  min_margin = 0.5
  loss_sum = 0
  with torch.no_grad():
    for g in loader:
      out = model(g).unsqueeze(0)
      loss = F.binary_cross_entropy(out, g.y.float())
      loss_sum += loss.detach().numpy()
      total += 1
      p = out.detach().numpy()[0]
      if g.y == 1:
        correct += p
        if p >= 0.5:
          rso50 += 1
          min_margin = min(min_margin, p-0.5)
      else:
        correct += (1 - p)
        if p < 0.5:
          rso50 += 1
          min_margin = min(min_margin, 0.5-p)
  if total == 0:
    raise ValueError("evaluate() got an empty loader")
  return {
    'Loss': loss_sum/total,
    # The key keeps its trailing colon because the experiment cells json.dump
    # these dicts, and results already written use this key.
    'Acc:': correct/total,
    'rso50': rso50/total,
    'margin': min_margin
  }

### Cycles experiment

In [None]:
# Cycles experiment. For each repetition i, train one EDU_QGC + OneRatioAggregator
# model per circuit depth on the cycle train loader, evaluate it on the held-out
# loader, and write every depth's train/eval metrics to res_<i>.json.
N_REPEATS = 10       # independent repetitions (one JSON file each)
MAX_LAYERS = 14      # depths 1..MAX_LAYERS are swept
CYCLES_EPOCHS = 100  # training epochs per model
for i in range(1, N_REPEATS + 1):
  res = {}
  for n_layers in range(1, MAX_LAYERS + 1):
    # Seed each (repetition, depth) run so it is reproducible on its own.
    # Parameter init in EDU_QGC and the MLP uses the torch RNG; loader
    # shuffling presumably does too (PyTorch DataLoader default sampler).
    torch.manual_seed(1000 * i + n_layers)
    model = torch.nn.Sequential(EDU_QGC(n_layers=n_layers), OneRatioAggregator())
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    train_metrics = train(model, optimizer, scheduler, epochs=CYCLES_EPOCHS, print_metrics=False)
    eval_metrics = evaluate(model)
    res[str(n_layers) + '_train'] = train_metrics
    res[str(n_layers) + '_eval'] = eval_metrics
    print(n_layers, 'layers')
    print('Train: ', train_metrics)
    print('Eval: ', eval_metrics)
  with open('res_' + str(i) + '.json', 'w') as fp:
    json.dump(res, fp)

1 layers
Train:  {'Loss': 0.6780364091197649, 'Acc:': 0.5078519197801749, 'rso50': 0.75, 'margin': 0.001971423625946045}
Eval:  {'Loss': 0.6748795012633005, 'Acc:': 0.5093061973651251, 'rso50': 0.8333333333333334, 'margin': 0.006442070007324219}
2 layers
Train:  {'Loss': 0.6177651981512705, 'Acc:': 0.5418155615528425, 'rso50': 0.9583333333333334, 'margin': 9.495019912719727e-05}
Eval:  {'Loss': 0.6040907402833303, 'Acc:': 0.5491860409577688, 'rso50': 0.5, 'margin': 0.10236728191375732}
3 layers
Train:  {'Loss': 0.3950042249634862, 'Acc:': 0.6816817863533894, 'rso50': 1.0, 'margin': 0.039186716079711914}
Eval:  {'Loss': 0.31118466208378476, 'Acc:': 0.7440535922845205, 'rso50': 1.0, 'margin': 0.11598515510559082}
4 layers
Train:  {'Loss': 0.6938727994759878, 'Acc:': 0.49963802595933277, 'rso50': 0.375, 'margin': 7.134675979614258e-05}
Eval:  {'Loss': 0.6931493580341339, 'Acc:': 0.5000005314747492, 'rso50': 0.5, 'margin': 0.0012707710266113281}
5 layers
Train:  {'Loss': 0.3623693004871408