In [21]:
from torch_geometric.data import Data
import torch
import torch.nn as nn
import torch.nn.functional as F 
import random
from torch_geometric.datasets import CoraFull, TUDataset
import numpy as np
import torch_geometric as tg
from utils import count_in_degree, count_out_degree, create_pointer_graph
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import GCNConv

## add dataset and shuffle

In [22]:
# Load the PROTEINS dataset and shuffle it once with a fixed seed, so that the
# train/test split taken from this ordering is the same on every fresh run.
SEED = 0
graphs = TUDataset("", "PROTEINS")
perm = np.random.RandomState(SEED).permutation(len(graphs))
# Materialise the shuffled order as a plain Python list of graphs.
graphs = [graphs[int(i)] for i in perm]
# One label row per graph, in the same shuffled order as `graphs`.
ys = torch.cat([graph.y.unsqueeze(0) for graph in graphs])

## get train and test datasets

In [23]:
# Hold out the last 100 shuffled graphs (and their labels) for evaluation;
# everything before them forms the training set.
graphs_train, graphs_test = graphs[:-100], graphs[-100:]
ys_train, ys_test = ys[:-100], ys[-100:]

## create starting embedding

In [67]:
# NOTE(review): `embdim` is assigned here, but the padding below uses the
# literal width 16 rather than this value.
embdim = 64
# Right-pad every graph's node-feature matrix with zero columns so that all
# graphs end up with 16 feature columns. `graph.x` is replaced on each graph.
for graph in graphs:
    n_rows = graph.x.shape[0]
    n_missing = 16 - graph.x.shape[1]
    padding = torch.zeros(n_rows, n_missing)
    graph.x = torch.cat((graph.x, padding), dim=1)

In [178]:
class GCNLayer(nn.Module):
    """Graph-convolution layer: neighbour-sum aggregation with a residual MLP update.

    For every node the layer sums the features of the nodes it is connected to
    in ``graph.edge_index``, concatenates that sum with a linear transform of
    the node's own features, passes the result through an MLP and adds the
    input back as a residual connection.

    Args:
        embdim: width of the node embeddings consumed and produced by the layer.
    """

    def __init__(self, embdim):
        super().__init__()
        self.lin = nn.Linear(embdim, embdim)
        self.mlp = MLP([2*embdim, 64, embdim])

    def forward(self, x, graph, igraph):
        """Compute updated node embeddings.

        Args:
            x: node embeddings with one row per node and ``embdim`` columns.
            graph: object exposing ``edge_index``, a 2 x num_edges index tensor.
            igraph: index of the graph in the dataset. Unused unless the
                optional degree normalisation below is enabled.

        Returns:
            Tensor with the same shape as ``x``.
        """
        self_h = self.lin(x)

        # Messages are the raw embeddings. To normalise by out-degree, divide
        # here instead, e.g.  h = x / degs[igraph]['out']
        # (`degs` is not defined in the visible part of this notebook).
        h = x

        edges = graph.edge_index
        num_nodes = x.shape[0]
        # Unit edge weights created with the embeddings' dtype and device, so
        # the sparse product also works for non-float32 or non-CPU inputs.
        weights = torch.ones(edges.shape[1], dtype=h.dtype, device=h.device)
        adjacency = torch.sparse_coo_tensor(edges, weights, (num_nodes, num_nodes))
        # Row i of the product is the sum of h[j] over all edges (i, j).
        aggregated = torch.spmm(adjacency, h)

        combined = torch.cat((aggregated, self_h), dim=1)
        return self.mlp(combined) + x

class MLP(nn.Module):
    """Stack of linear layers with configurable activations.

    Args:
        arch: sequence of layer widths; consecutive pairs define one
            ``nn.Linear`` each (``arch[0]`` is the input width, ``arch[-1]``
            the output width).
        last_activation: callable applied after the final linear layer.
        middle_activation: callable applied after every other linear layer.
    """

    def __init__(self, arch, last_activation=F.relu, middle_activation=F.relu):
        super().__init__()
        layer_dims = zip(arch, arch[1:])
        self.lins = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in layer_dims])
        self.last_activation = last_activation
        self.middle_activation = middle_activation

    def forward(self, h):
        """Run ``h`` through every layer and return the activated output."""
        hidden_layers, output_layer = self.lins[:-1], self.lins[-1]
        for layer in hidden_layers:
            h = self.middle_activation(layer(h))
        return self.last_activation(output_layer(h))

class Model(nn.Module):
    """Graph classifier: stacked GCN layers, mean pooling, then an MLP head.

    Args:
        n: number of ``GCNLayer`` blocks to stack.
        embdim: width of the node embeddings; input features narrower than
            this are zero-padded on the right.
    """

    def __init__(self, n, embdim=16):
        super().__init__()
        self.layers = nn.ModuleList([GCNLayer(embdim) for _ in range(n)])
        # The head ends in a softmax, so the model outputs class probabilities
        # over 2 classes.
        self.classifier = MLP([embdim, 32, 2], last_activation= lambda x: torch.softmax(x, dim=-1))
        self.embdim = embdim

    def forward(self, graph, igraph):
        """Classify a single graph.

        Args:
            graph: object exposing ``x`` (node features, one row per node); it
                is also handed to every layer.
            igraph: index of the graph, forwarded unchanged to every layer.

        Returns:
            ``(preds, m)``: the classifier output for the graph and the
            mean-pooled node embedding it was computed from, both with a
            leading dimension of 1.

        Raises:
            ValueError: if ``graph.x`` has more columns than ``embdim``.
        """
        feats = graph.x
        missing = self.embdim - feats.shape[1]
        if missing < 0:
            raise ValueError(
                f"graph.x has {feats.shape[1]} feature columns but the model "
                f"was built with embdim={self.embdim}"
            )
        # Allocate the padding on the features' device. torch.cat returns a
        # new tensor, so graph.x itself is never modified.
        padding = torch.zeros(feats.shape[0], missing, device=feats.device)
        h = torch.cat((feats, padding), dim=1)
        for layer in self.layers:
            h = layer(h, graph, igraph)
        # Mean over nodes gives one embedding for the whole graph.
        m = h.mean(dim=0, keepdim=True)
        preds = self.classifier(m)
        return preds, m

# Seven stacked GCN layers; the embedding width is left at Model's default.
model = Model(n=7)

# Train

In [179]:
# Adam over every parameter of the model.
opt = torch.optim.Adam(params=model.parameters(), lr=8e-3)

In [180]:
! mkdir tb
writer = SummaryWriter('tb/gcn_'+str(random.random())) 
eps = 100

out_dim = 2
bpreds = torch.zeros(0, out_dim)
bsize, bpass, ibatch = 32, 0, 0
ys = []
for ep in range(0, eps):
    for igraph, (graph, one_y) in enumerate(zip(graphs_train, ys_train)):
        
        preds, x = model(graph, igraph)
        bpreds = torch.cat((bpreds, preds))
        ys.append(one_y)
        bpass += 1
        if bpass < bsize:
            continue

        y = torch.tensor(ys)
        yhot = F.one_hot(y.long(), num_classes=out_dim)
        # print(y.shape, yhot.shape, bpreds.shape)
        loss = -(yhot*torch.log(bpreds+1e-8) + (1-yhot)*torch.log(1-bpreds+1e-8)).mean()
        acc = (((bpreds>0.5)==yhot).float().sum(dim=1)==yhot.shape[-1]).float().mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        ibatch += 1
        bpass = 0
        bpreds = torch.zeros(0, out_dim)
        ys = []
        
        writer.add_scalar("loss", loss.item(), ep*eps//bsize+ibatch)
        writer.add_scalar("acc" , acc.item() , ep*eps//bsize+ibatch)
        print(ep, ibatch, loss.item(), acc.item())



 545 0.5674636363983154 0.65625
17 546 0.5844385623931885 0.75
17 547 0.6784171462059021 0.53125
17 548 0.5347244739532471 0.78125
17 549 0.6111361980438232 0.71875
17 550 0.4236001670360565 0.84375
17 551 0.5886922478675842 0.75
17 552 0.40338483452796936 0.875
17 553 0.6657277941703796 0.625
17 554 0.48944056034088135 0.78125
17 555 0.6872168779373169 0.6875
17 556 0.6219539046287537 0.75
17 557 0.5899772047996521 0.78125
17 558 0.5678197741508484 0.65625
17 559 0.497464120388031 0.75
17 560 0.5440641641616821 0.6875
17 561 0.5503969788551331 0.75
17 562 0.5817108154296875 0.625
17 563 0.5466321706771851 0.75
17 564 0.5236157178878784 0.75
17 565 0.4851984977722168 0.78125
17 566 0.529097318649292 0.6875
17 567 0.5717544555664062 0.75
17 568 0.500135064125061 0.78125
17 569 0.6102918386459351 0.6875
18 570 0.504416286945343 0.75
18 571 0.705080509185791 0.59375
18 572 0.6129415035247803 0.6875
18 573 0.4417411684989929 0.84375
18 574 0.7552437782287598 0.59375
18 575 0.44240716099739

KeyboardInterrupt: 

# Try on validation set

In [183]:
# Evaluate on the held-out graphs. Only forward passes are needed, so autograd
# is disabled to avoid building a graph for every prediction.
test_preds, test_ys = [], []
with torch.no_grad():
    for igraph, (graph, one_y) in enumerate(zip(graphs_test, ys_test)):
        # Offset the index so it continues after the training graphs.
        preds, _ = model(graph, len(graphs_train)+igraph)
        test_preds.append(preds)
        test_ys.append(one_y)

bpreds = torch.cat(test_preds)
y = torch.stack(test_ys).reshape(-1)
yhot = F.one_hot(y.long(), num_classes=out_dim)
# Same loss and accuracy definitions as in the training loop.
loss = -(yhot*torch.log(bpreds+1e-8) + (1-yhot)*torch.log(1-bpreds+1e-8)).mean()
acc = (((bpreds>0.5)==yhot).float().sum(dim=1)==yhot.shape[-1]).float().mean()
acc

tensor(0.6900)