# New developments in deep learning - R-GCNs
Isotropic Graph Neural Networks 
- Representations are learned via a differentiable message-passing scheme
- All neighbors are treated as equally important
- Starting points:
  - Kipf & Welling: “Semi-Supervised Classification with Graph Convolutional Networks” (https://arxiv.org/abs/1609.02907)
  - Schlichtkrull et al.: “Modeling Relational Data with Graph Convolutional Networks” (https://arxiv.org/abs/1703.06103)
- Task:
  - Implement a Relational Graph Convolutional Network (R-GCN) for node classification

Blog posts:
- https://towardsdatascience.com/how-to-do-deep-learning-on-graphs-with-graph-convolutional-networks-7d2250723780
- http://tkipf.github.io/graph-convolutional-networks/

Jupyter notebook tutorial:
- https://github.com/TobiasSkovgaardJepsen/posts/blob/master/HowToDoDeepLearningOnGraphsWithGraphConvolutionalNetworks/Part2_SemiSupervisedLearningWithSpectralGraphConvolutions/notebook.ipynb

Keras implementation:
- https://github.com/tkipf/relational-gcn
PyTorch implementations:
- https://github.com/tkipf/pygcn 
- https://github.com/masakicktashiro/rgcn_pytorch_implementation
- https://github.com/mjDelta/relation-gcn-pytorch

Dataset:
- https://ogb.stanford.edu
- https://ogb.stanford.edu/docs/nodeprop/#ogbn-proteins
- https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-proteins 

## Relational GCNs - Theory
Extension of GCNs: Use a set of relation-specific weight matrices $W_r^{(l)}$, where $r \in R$ denotes the relation type

Propagation model:
> $h_i^{(l+1)} = \sigma\left(\sum_{r\in R}\sum_{j\in N^r_i}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+\underbrace{W_0^{(l)}h_i^{(l)}}_{\text{self-connection}}\right)$

where 
- $N^r_i$ denotes the set of neighbor indices of node $i$ under relation $r \in R$, 
- $c_{i,r}$ is a problem-specific normalization constant that can either be learned or chosen in advance (such as $c_{i,r} = |N_i^r|$).

Neural network layer update: evaluate the message-passing update in parallel for every node $i \in V$.
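The propagation rule, evaluated for all nodes at once, can be sketched in plain PyTorch as below. This is a minimal illustration under stated assumptions, not PyG's `RGCNConv` implementation: edges are stored PyG-style as a `[2, E]` tensor of `(source j, target i)` pairs with a parallel `edge_type` tensor, $\sigma$ is taken to be ReLU, $c_{i,r} = |N_i^r|$, and the function name `rgcn_layer` is hypothetical.

```python
import torch


def rgcn_layer(h, edge_index, edge_type, weights, w_self):
    """One R-GCN propagation step for all nodes at once (illustrative sketch).

    h:          [N, d_in] node representations h^{(l)}
    edge_index: [2, E] edges stored as (source=j, target=i), i.e. j in N_i^r
    edge_type:  [E] relation id r of each edge
    weights:    [R, d_out, d_in] relation-specific matrices W_r^{(l)}
    w_self:     [d_out, d_in] self-connection matrix W_0^{(l)}
    """
    num_nodes = h.size(0)
    num_relations = weights.size(0)
    src, dst = edge_index

    # Message W_r h_j for every edge, using the matrix of that edge's relation.
    messages = torch.einsum('eoi,ei->eo', weights[edge_type], h[src])

    # c_{i,r} = |N_i^r|: count incoming edges per (target node, relation) pair.
    key = dst * num_relations + edge_type
    counts = h.new_zeros(num_nodes * num_relations).index_add_(
        0, key, h.new_ones(key.numel()))
    messages = messages / counts[key].unsqueeze(-1)

    # Sum the normalized messages into their target nodes, add the self-connection.
    out = h.new_zeros(num_nodes, weights.size(1)).index_add_(0, dst, messages)
    out = out + h @ w_self.T
    return torch.relu(out)  # sigma assumed to be ReLU


torch.manual_seed(0)
N, R, d_in, d_out = 4, 2, 3, 5
h = torch.randn(N, d_in)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 1, 0, 0]])
edge_type = torch.tensor([0, 0, 1, 0])
out = rgcn_layer(h, edge_index, edge_type,
                 torch.randn(R, d_out, d_in), torch.randn(d_out, d_in))
print(out.shape)  # torch.Size([4, 5])
```

Gathering `weights[edge_type]` materializes one matrix per edge, which keeps the code close to the formula but is memory-hungry; grouping edges by relation would be the cheaper choice for large graphs.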

Parameter sharing for highly multi-relational data: basis decomposition of the relation-specific weight matrices
> $W_r^{(l)} = \sum^B_{b=1}a^{(l)}_{r,b}V_b^{(l)}$

Linear combination of basis transformations $V_b^{(l)} \in \mathbb{R}^{d^{(l+1)}\times d^{(l)}}$ with learnable coefficients $a^{(l)}_{r,b}$ such that only the coefficients depend on $r$. $B$, the number of basis functions, is a hyperparameter.
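The basis decomposition can be sketched as a single tensor contraction, together with the parameter count it saves. The function names `basis_weights` and `num_params` are hypothetical, and the sizes in the usage lines are arbitrary.

```python
import torch


def basis_weights(coeffs, bases):
    """Build W_r = sum_b a_{r,b} V_b for every relation r at once.

    coeffs: [R, B] learnable coefficients a_{r,b}
    bases:  [B, d_out, d_in] shared basis matrices V_b
    returns [R, d_out, d_in] relation-specific weights W_r
    """
    return torch.einsum('rb,boi->roi', coeffs, bases)


def num_params(num_relations, num_bases, d_in, d_out):
    """Relation-weight parameter count without and with basis decomposition."""
    full = num_relations * d_out * d_in
    decomposed = num_bases * d_out * d_in + num_relations * num_bases
    return full, decomposed


R, B, d_in, d_out = 8, 2, 16, 4
W = basis_weights(torch.randn(R, B), torch.randn(B, d_out, d_in))
print(W.shape)                        # torch.Size([8, 4, 16])
print(num_params(R, B, d_in, d_out))  # (512, 144)
```

Only the `R x B` coefficient matrix grows with the number of relations, which is why the decomposition helps on datasets with many relation types.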

For entity classification, the paper minimizes the cross-entropy over all labeled nodes:
> $\mathcal{L} = -\sum_{i\in Y}\sum^K_{k=1}t_{i,k}\ln h_{i,k}^{(L)}$

where:
- $Y$ is the set of node indices with labels
- $K$ is the number of classes
- $t_{i,k}$ is the ground-truth label, i.e. $1$ if node $i$ belongs to class $k$ and $0$ otherwise
- $h_{i,k}^{(L)}$ is the $k$-th entry of the network output for the $i$-th labeled node, taken after a softmax in the final layer $L$
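The loss can be written out term by term and checked against PyTorch's built-in negative log-likelihood. This is a sketch; `rgcn_loss` is a hypothetical name and the labels are assumed to be given for all nodes, with only the labeled indices used. Note that the paper's formula is a sum, whereas `F.nll_loss` with its default `reduction='mean'` (as used in the training loop of this notebook) divides by $|Y|$, which only rescales the loss.

```python
import torch
import torch.nn.functional as F


def rgcn_loss(logits, labels, labeled_idx):
    """Sum of -t_{i,k} ln h_{i,k} over labeled nodes i and classes k.

    logits:      [N, K] final-layer outputs before the softmax
    labels:      [N] integer class ids (only entries in labeled_idx are used)
    labeled_idx: indices of the labeled nodes Y
    """
    log_h = F.log_softmax(logits[labeled_idx], dim=1)  # ln h_{i,k}
    t = F.one_hot(labels[labeled_idx], logits.size(1)).to(log_h.dtype)  # t_{i,k}
    return -(t * log_h).sum()


torch.manual_seed(0)
logits = torch.randn(6, 3)
labels = torch.tensor([0, 2, 1, 1, 0, 2])
idx = torch.tensor([0, 2, 5])
manual = rgcn_loss(logits, labels, idx)
builtin = F.nll_loss(F.log_softmax(logits[idx], dim=1), labels[idx], reduction='sum')
print(torch.allclose(manual, builtin))  # True
```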

Training and evaluation
- 2-layer model with 16 hidden units (dimension of the hidden node representation)
- 50 epochs with learning rate 0.01 using the Adam optimizer
- normalization constant $c_{i,r} = |N_i^r|$, i.e. average all incoming messages from a particular relation type
- $l2$ penalty on first-layer weights $C_{l2} \in \{0, 5\cdot 10^{-4}\}$
- number of basis functions $B \in \{0, 10, 20, 30, 40\}$, where $B = 0$ means no basis decomposition
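The paper applies the $l2$ penalty to the first layer only, while the notebook's optimizer below passes `weight_decay` for all parameters. One way to restrict it is Adam's per-parameter groups, sketched here with a hypothetical two-layer stand-in (`nn.Linear` in place of `RGCNConv`); in the notebook the groups would be `model.conv1` and `model.conv2`. PyTorch's Adam `weight_decay` adds `weight_decay * w` to the gradient, i.e. it acts as an L2 penalty, and grouping by layer also penalizes that layer's bias.

```python
import torch


class TwoLayer(torch.nn.Module):
    """Hypothetical 2-layer stand-in; nn.Linear replaces RGCNConv for brevity."""

    def __init__(self, d_in, hidden, num_classes):
        super().__init__()
        self.layer1 = torch.nn.Linear(d_in, hidden)
        self.layer2 = torch.nn.Linear(hidden, num_classes)

    def forward(self, x):
        return self.layer2(torch.relu(self.layer1(x)))


model = TwoLayer(d_in=32, hidden=16, num_classes=4)
# L2 penalty (weight_decay) on the first layer only; no penalty on the second.
optimizer = torch.optim.Adam([
    {'params': model.layer1.parameters(), 'weight_decay': 5e-4},
    {'params': model.layer2.parameters(), 'weight_decay': 0.0},
], lr=0.01)
```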

Results reported
- Accuracy and standard error over 10 runs
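Aggregating the 10 runs can be sketched as follows, assuming "standard error" means the standard error of the mean (sample standard deviation divided by $\sqrt{n}$). The function name is hypothetical; its input would be the test accuracies of independently seeded training runs.

```python
import math
import statistics


def mean_and_standard_error(accuracies):
    """Mean accuracy and standard error of the mean over independent runs."""
    mean = statistics.mean(accuracies)
    # Standard error of the mean = sample standard deviation / sqrt(number of runs).
    sem = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, sem
```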

## Datasets: AIFB, MUTAG
**AIFB**: 
Describes the AIFB research institute; the task is to predict affiliation to a research group

4 classes, 45 relations, 8k entities, 28k edges, 176 labelled instances

**MUTAG**:
Information about complex molecules that are potentially carcinogenic

2 classes, 23 relations, 23k entities, 74k edges, 340 labelled instances

### References
Ristoski, P., De Vries, G. K. D., & Paulheim, H. (2016, October). A collection of benchmark datasets for systematic evaluations of machine learning on the semantic web. In International Semantic Web Conference (pp. 186-194). Springer, Cham

## Loading the data

In [None]:
dataset_name = "AIFB"  # choices=['AIFB', 'MUTAG', 'BGS', 'AM']

In [None]:
# https://github.com/rusty1s/pytorch_geometric/blob/master/examples/rgcn.py
import os.path as osp

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.datasets import Entities
from torch_geometric.nn import RGCNConv

path = osp.join('.', 'data', 'Entities')
dataset = Entities(path, dataset_name)
data = dataset[0]

In [None]:
dataset.num_relations

In [None]:
dataset.num_classes

In [None]:
data

In [None]:
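# Adapted from https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/utils/num_nodes.py
def maybe_num_nodes(edge_index, num_nodes=None):
    # Prefer an explicitly given node count.
    if num_nodes is not None:
        return num_nodes
    # Dense [2, E] edge_index: infer the count from the largest node id.
    # Isolated nodes with ids above that maximum are not counted.
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    # Otherwise assume a sparse [num_nodes, num_nodes] adjacency object.
    else:
        return max(edge_index.size(0), edge_index.size(1))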

In [None]:
data.num_nodes = maybe_num_nodes(data.edge_index)

In [None]:
data.num_nodes

## R-GCN

In [None]:
# https://github.com/rusty1s/pytorch_geometric/blob/master/examples/rgcn.py
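class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Nodes have no input features (x=None in forward), so in_channels is
        # the number of nodes and the first layer learns one embedding per node.
        self.conv1 = RGCNConv(data.num_nodes, 16, dataset.num_relations,
                              num_bases=30)
        # Second layer maps the 16 hidden units to one output per class.
        self.conv2 = RGCNConv(16, dataset.num_classes, dataset.num_relations,
                              num_bases=30)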

    def forward(self, edge_index, edge_type):
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)


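device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
# Note: weight_decay applies an L2 penalty to all parameters here, whereas
# the paper penalizes only the first-layer weights.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)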


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.edge_index, data.edge_type)
    loss = F.nll_loss(out[data.train_idx], data.train_y)
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)
    train_acc = pred[data.train_idx].eq(data.train_y).to(torch.float).mean()
    test_acc = pred[data.test_idx].eq(data.test_y).to(torch.float).mean()
    return train_acc.item(), test_acc.item()


for epoch in range(1, 51):
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Test: {test_acc:.4f}')
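The first layer is called with `x=None` and `in_channels=data.num_nodes`. PyG's `RGCNConv` documentation states that when no input features are given, `in_channels` should equal the number of nodes, so the layer learns one weight row per node. That is equivalent to feeding one-hot identity features, as the small sketch below illustrates with a plain matrix (hypothetical sizes; this shows the idea, not PyG's internal code).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, hidden = 5, 3  # hypothetical sizes
# One learnable row per node: with featureless input this acts as an embedding table.
weight = torch.randn(num_nodes, hidden)

node_ids = torch.tensor([3, 0, 4])
# Explicit one-hot features e_i multiplied by the weight matrix ...
one_hot = F.one_hot(node_ids, num_classes=num_nodes).to(weight.dtype)
via_one_hot = one_hot @ weight
# ... select exactly the same rows as a direct lookup, without building e_i.
via_lookup = weight[node_ids]

print(torch.allclose(via_one_hot, via_lookup))  # True
```

The lookup avoids building an `N x N` identity matrix, which matters for graphs such as MUTAG with tens of thousands of entities.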