### 0. Recap GNN

##### Encoder Decoder Framework

<img src="pics/encode-decode.png" alt="Encoder Decoder Structure" title="Encoder-Decoder" width="600"/>

$Encoder$: maps each node to a low-dimensional vector.
- Shallow encoders: the simplest encoding approach, where the encoder is just an embedding lookup.

> - O(|V|) parameters are needed; 
> - Transductive: cannot generate embeddings for nodes that were not seen during training
> - Do not incorporate node features

- Deep Graph Encoders
<br />
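
As a minimal sketch of the shallow encoder described above (the toy sizes and the helper name `shallow_encode` are assumptions made for illustration), the encoder is a table `Z` with one row per node, and encoding a node is a row lookup:

```python
import numpy as np

rng = np.random.default_rng(0)

num_nodes, dim = 5, 3                    # toy sizes, chosen only for illustration
Z = rng.normal(size=(num_nodes, dim))    # one trainable row per node: O(|V|) parameters

def shallow_encode(node_ids):
    """Embedding lookup: ENC(v) is simply row v of Z (hypothetical helper)."""
    return Z[np.asarray(node_ids)]

emb = shallow_encode([0, 3])
print(emb.shape)   # (2, 3)
print(Z.size)      # 15, i.e. num_nodes * dim
```

A node id outside `0..num_nodes-1` has no row in `Z`, which is the transductive limitation listed above.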

$Decoder$: predicts a score from the embeddings to match node similarity. If supervised, typical tasks include:

> - Node classification: Predict a type of a given node
> - Link prediction: Predict whether two nodes are linked
> - Community detection: Identify densely linked clusters of nodes
> - Network similarity: How similar are two (sub)networks

##### Message Passing Networks

<img src="pics/gnn1.png" alt="gnn1" title="gnn1" width="600"/>

\begin{equation}
\mathbf{h}_i^{(k)} = \gamma^{(k)} \left( \mathbf{h}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{h}_i^{(k-1)}, \mathbf{h}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),
\end{equation}

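where $\square$ denotes a differentiable, permutation-invariant function (e.g., sum, mean, or max), $\gamma$ and $\phi$ denote differentiable functions such as MLPs (multi-layer perceptrons), $\mathcal{N}(i)$ is the neighborhood of node $i$, and $\mathbf{e}_{j,i}$ denotes the (optional) features of the edge from $j$ to $i$. The user only has to define these functions.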
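
A rough NumPy sketch of one such layer, under simplifying assumptions that are not part of the formula above: $\phi$ is a single linear map of the neighbor state, $\square$ is a sum, $\gamma$ adds a linear map of the node's own state and applies ReLU, and edge features are ignored. The function name and the toy graph are hypothetical.

```python
import numpy as np

def message_passing_layer(h, edge_index, W_msg, W_self):
    """One layer: h_i' = relu(h_i W_self + sum_{j in N(i)} h_j W_msg)."""
    src, dst = edge_index                     # messages flow src (j) -> dst (i)
    messages = h[src] @ W_msg                 # phi: linear map of the neighbor state
    agg = np.zeros((h.shape[0], W_msg.shape[1]))
    np.add.at(agg, dst, messages)             # square: permutation-invariant sum
    return np.maximum(h @ W_self + agg, 0.0)  # gamma: combine, then ReLU

h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = np.array([[1, 2, 0], [0, 0, 2]])   # edges 1->0, 2->0, 0->2
I = np.eye(2)
out = message_passing_layer(h, edge_index, I, I)
print(out.tolist())   # [[2.0, 2.0], [0.0, 1.0], [2.0, 1.0]]
```

Because the aggregation is a sum, reordering the columns of `edge_index` leaves the result unchanged.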
 

### 1. RGCN

#### Encoder

The relational graph convolutional operator from the paper "Modeling Relational Data with Graph Convolutional Networks" (<https://arxiv.org/abs/1703.06103>):

\begin{equation}
h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
       \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})
\end{equation}

where:
> - $\mathcal{R}$ denotes the set of relations, *i.e.* edge types.
> - $\mathcal{N}^r(i)$ is the neighbor set of node $i$ w.r.t. relation $r$.
> - $c_{i,r}$ is the normalizer, equal to $|\mathcal{N}^r(i)|$.
> - $\sigma$ is an activation function.
> - $W_0$ is the self-loop weight.
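
The update rule can be sketched directly in NumPy. This is an illustrative, unoptimized version and not the layer used later in this notebook: the function name `rgcn_layer`, the row-vector convention (`h @ W` instead of $Wh$), and the choice of ReLU for $\sigma$ are assumptions.

```python
import numpy as np

def rgcn_layer(h, triples, W_rel, W_self):
    """h: (N, d_in); triples: rows of (src j, relation r, dst i);
    W_rel: (R, d_in, d_out); W_self: (d_in, d_out)."""
    num_nodes, num_rel = h.shape[0], W_rel.shape[0]
    src, rel, dst = triples.T
    # c_{i,r} = |N^r(i)|: number of incoming r-edges at each target node i
    c = np.zeros((num_nodes, num_rel))
    np.add.at(c, (dst, rel), 1.0)
    # per-edge message h_j W_r, scaled by 1 / c_{i,r}
    msg = np.einsum('ed,edo->eo', h[src], W_rel[rel]) / c[dst, rel][:, None]
    out = h @ W_self                 # self-loop term W_0 h_i
    np.add.at(out, dst, msg)         # sum over relations and neighbors
    return np.maximum(out, 0.0)      # sigma = ReLU (assumed)

h = np.array([[1.0, 0.0], [0.0, 2.0], [4.0, 0.0]])
triples = np.array([[1, 0, 0], [2, 0, 0], [2, 1, 0]])   # (src, rel, dst)
W_rel = np.stack([np.eye(2), 2.0 * np.eye(2)])
out = rgcn_layer(h, triples, W_rel, np.eye(2))
print(out.tolist())   # [[11.0, 1.0], [0.0, 2.0], [4.0, 0.0]]
```

For node 0 the relation-0 neighbors are averaged ($c_{0,0}=2$), the single relation-1 neighbor is taken as is, and the self-loop term is added on top.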


**Basis Decomposition (Optional)**

The problem with applying the above equation directly is that the number of parameters grows rapidly with the number of relations, especially for highly multi-relational data. To reduce the model's parameter count and prevent overfitting, the original paper proposes basis decomposition.

The basis regularization decomposes $W_r$ by
\begin{equation}
       W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}
 \end{equation}
 
where:
- $B$ is the number of bases.
- $V_b^{(l)}$ are the basis transformations, linearly combined with coefficients $a_{rb}^{(l)}$.

Therefore, the weight $W_r^{(l)}$ is a linear combination of the basis transformations $V_b^{(l)}$ with coefficients $a_{rb}^{(l)}$. The number of bases $B$ is much smaller than the number of relations in the knowledge base.
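
To make the saving concrete, here is a small NumPy sketch that builds every $W_r$ from shared bases and compares parameter counts. The sizes mirror the settings used later in this notebook (36 edge types after adding inverse relations, 4 bases, 100-dimensional features); the random values are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
num_rel, num_bases, d_in, d_out = 36, 4, 100, 100

V = rng.normal(size=(num_bases, d_in, d_out))   # shared bases V_b
a = rng.normal(size=(num_rel, num_bases))       # coefficients a_{rb}

# W_r = sum_b a_{rb} V_b, computed for all relations at once
W = np.einsum('rb,bio->rio', a, V)

full_params = num_rel * d_in * d_out    # one free matrix per relation
basis_params = V.size + a.size          # bases plus coefficients
print(W.shape, full_params, basis_params)   # (36, 100, 100) 360000 40144
```

Only the coefficients `a` depend on the relation, so adding a relation costs `num_bases` extra parameters instead of `d_in * d_out`.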


#### Decoder: Link Prediction

We use DistMult factorization (Yang et al. 2014) as the scoring function, which is known to perform well on standard link prediction benchmarks when used on its own:

\begin{equation}
       f(s, r, o) = e^T_sR_re_o . 
 \end{equation}
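
In DistMult, $R_r$ is a diagonal matrix, so the score reduces to a sum over the element-wise product of three vectors. A minimal NumPy sketch with made-up embeddings (the function name and all values are illustrative only):

```python
import numpy as np

def distmult_score(entity_emb, relation_diag, triples):
    """f(s, r, o) = e_s^T diag(r) e_o for each row (s, r, o) of triples."""
    s = entity_emb[triples[:, 0]]
    r = relation_diag[triples[:, 1]]   # diagonal of R_r, stored as a vector
    o = entity_emb[triples[:, 2]]
    return np.sum(s * r * o, axis=1)

entity_emb = np.array([[1.0, 2.0], [3.0, 0.5], [0.0, 1.0]])
relation_diag = np.array([[1.0, -1.0], [2.0, 2.0]])
triples = np.array([[0, 0, 1], [1, 0, 0], [2, 1, 0]])
scores = distmult_score(entity_emb, relation_diag, triples)
print(scores.tolist())   # [2.0, 2.0, 4.0]
```

The first two triples swap subject and object and receive the same score: with a diagonal $R_r$, DistMult is symmetric in $s$ and $o$.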

In [7]:
import math
import numpy as np
from dputils import load_data, calc_mrr, evaluate
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing

#### Load WN18 Data

In [8]:
entity2id, relation2id, train_triplets, valid_triplets, test_triplets = load_data('./data/wn18')

load data from ./data/wn18
num_entity: 40943
num_relation: 18
num_train_triples: 141442
num_valid_triples: 5000
num_test_triples: 5000


In [9]:
type(train_triplets)

numpy.ndarray

In [4]:
train_triplets[:5]

array([[ 8881,    12, 32231],
       [35365,    17, 15368],
       [13944,    16, 24784],
       [35723,     8,  3497],
       [36279,     8, 21979]])

In [5]:
num_entity = len(entity2id)
num_relation = len(relation2id)

#### Prep Training

In [None]:
def negative_sampling(pos_samples, num_entity, negative_rate):
    
    size_of_batch = len(pos_samples)
    num_to_generate = size_of_batch * negative_rate
    neg_samples = np.tile(pos_samples, (negative_rate, 1))
    labels = np.zeros(size_of_batch * (negative_rate + 1), dtype=np.float32)
    labels[: size_of_batch] = 1
    values = np.random.choice(num_entity, size=num_to_generate)
    choices = np.random.uniform(size=num_to_generate)
    subj = choices > 0.5
    obj = choices <= 0.5
    neg_samples[subj, 0] = values[subj]
    neg_samples[obj, 2] = values[obj]

    return np.concatenate((pos_samples, neg_samples)), labels

In [None]:
from torch_scatter import scatter_add
def edge_normalization(edge_type, edge_index, num_entity, num_relation):
    '''
        Edge normalization trick
        - one_hot: (num_edge, num_relation)
        - deg: (num_node, num_relation)
        - index: (num_edge)
        - deg[edge_index[0]]: (num_edge, num_relation)
        - edge_norm: (num_edge)
    '''
    one_hot = F.one_hot(edge_type.long(), num_classes = 2 * num_relation).to(torch.float)
    deg = scatter_add(one_hot, edge_index[0], dim = 0, dim_size = num_entity)
    index = edge_type + torch.arange(len(edge_index[0])) * (2 * num_relation)
    edge_norm = 1 / deg[edge_index[0]].view(-1)[index]

    return edge_norm
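
The same quantity can be written without `torch_scatter`, which may make the indexing trick easier to follow. This is a NumPy sketch under the assumption that, as in the function above, edges are grouped by `edge_index[0]` and edge type; `edge_norm_numpy` is a hypothetical name.

```python
import numpy as np

def edge_norm_numpy(edge_index, edge_type, num_nodes, num_edge_types):
    """Return 1 / (number of edges sharing the same first node and edge type)."""
    node = edge_index[0]
    deg = np.zeros((num_nodes, num_edge_types))
    np.add.at(deg, (node, edge_type), 1.0)   # count edges per (node, edge type)
    return 1.0 / deg[node, edge_type]        # look the count up again per edge

edge_index = np.array([[0, 0, 0, 1], [1, 2, 3, 2]])
edge_type = np.array([0, 0, 1, 0])
norm = edge_norm_numpy(edge_index, edge_type, num_nodes=4, num_edge_types=2)
print(norm.tolist())   # [0.5, 0.5, 1.0, 1.0]
```

Node 0 has two type-0 edges, so each gets weight 0.5; the remaining edges are the only ones of their type at their node and keep weight 1.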

In [None]:
def generate_sampled_graph_and_labels(triplets, sample_size, split_size, num_entity, num_relation, negative_rate):
    """
        Get training graph and signals
        First perform edge neighborhood sampling on graph, then perform negative
        sampling to generate negative samples
    """
    def sample_edge_uniform(n_triples, sample_size):
        """Sample edges uniformly from all the edges."""
        all_edges = np.arange(n_triples)
        return np.random.choice(all_edges, sample_size, replace=False)

    edges = sample_edge_uniform(len(triplets), sample_size)

    # Select sampled edges
    edges = triplets[edges]
    src, rel, dst = edges.transpose()
    uniq_entity, edges = np.unique((src, dst), return_inverse=True)
    src, dst = np.reshape(edges, (2, -1))
    relabeled_edges = np.stack((src, rel, dst)).transpose()

    # Negative sampling
    samples, labels = negative_sampling(relabeled_edges, len(uniq_entity), negative_rate)

    # further split graph, only half of the edges will be used as graph in message passing
    # structure, while the rest half is used as unseen positive samples
    split_size = int(sample_size * split_size)
    graph_split_ids = np.random.choice(np.arange(sample_size),
                                       size=split_size, replace=False)

    src = torch.tensor(src[graph_split_ids], dtype = torch.long).contiguous()
    dst = torch.tensor(dst[graph_split_ids], dtype = torch.long).contiguous()
    rel = torch.tensor(rel[graph_split_ids], dtype = torch.long).contiguous()

    # Create bi-directional graph
    src, dst = torch.cat((src, dst)), torch.cat((dst, src))
    rel = torch.cat((rel, rel + num_relation))

    edge_index = torch.stack((src, dst))
    edge_type = rel


    entity = torch.from_numpy(uniq_entity)
    edge_norm = edge_normalization(edge_type, edge_index, len(uniq_entity), num_relation)
    samples = torch.from_numpy(samples)
    labels = torch.from_numpy(labels)
    
    '''
    # used in message propagation
    entity : sampled unique entity ids (at most 2N of them)
    edge_index: [2, N], pairs of (src, dst) node ids, positive only, bi-directional
    edge_type: [N], bi-directional, 0 <= x < 2 * num_relation (36 for WN18)
    edge_norm: [N], normalizing factor
    
    # used in loss calculation (single direction)
    samples: [2N, 3]: first half positive, second half negative
    labels: [2N]: [1,1,1,...0,0,0]
    '''
    
    return entity, edge_index, edge_type, edge_norm, samples, labels
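
The relabeling step inside this function relies on `np.unique(..., return_inverse=True)` to map the sampled entity ids onto a compact range `0..len(uniq_entity)-1`. A toy illustration with three made-up edges:

```python
import numpy as np

src = np.array([8881, 35365, 8881])
dst = np.array([32231, 8881, 15368])

# Sorted unique ids, plus for every input element its position in that list
uniq_entity, inverse = np.unique((src, dst), return_inverse=True)
new_src, new_dst = np.reshape(inverse, (2, -1))

print(uniq_entity.tolist())   # [8881, 15368, 32231, 35365]
print(new_src.tolist())       # [0, 3, 0]
print(new_dst.tolist())       # [2, 0, 1]
```

Indexing `uniq_entity` with the new ids recovers the original ids, which is why `entity` is returned alongside the relabeled graph.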

In [None]:
def build_eval_graph(num_entity, num_relation, triplets):
    src, rel, dst = triplets.transpose()

    src = torch.from_numpy(src)
    rel = torch.from_numpy(rel)
    dst = torch.from_numpy(dst)

    src, dst = torch.cat((src, dst)), torch.cat((dst, src))
    rel = torch.cat((rel, rel + num_relation))

    edge_index = torch.stack((src, dst)).long()
    edge_type = rel

    entity = torch.from_numpy(np.arange(num_entity))
    edge_norm = edge_normalization(edge_type, edge_index, num_entity, num_relation)

    return entity, edge_index, edge_type, edge_norm

In [None]:
# full graph without sampling, negative sampling, or splitting
eval_graph = build_eval_graph(num_entity, num_relation, train_triplets) 

all_triplets = torch.LongTensor(np.concatenate((train_triplets, valid_triplets, test_triplets)))
valid_triplets = torch.LongTensor(valid_triplets)
test_triplets = torch.LongTensor(test_triplets)

#### RGCN Model

In [None]:
class RGCNConv(MessagePassing):
    """
    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int): Number of bases used for basis-decomposition.
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, num_relations, num_bases,
                 root_weight=True, bias=True, **kwargs):
        super(RGCNConv, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        self.basis = nn.Parameter(torch.Tensor(num_bases, in_channels, out_channels))
        self.att = nn.Parameter(torch.Tensor(num_relations, num_bases))

        if root_weight:
            self.root = nn.Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        size = self.num_bases * self.in_channels
        uniform(size, self.basis)
        uniform(size, self.att)
        uniform(size, self.root)
        uniform(size, self.bias)


    def forward(self, x, edge_index, edge_type, edge_norm=None, size=None):
        """"""
        return self.propagate(edge_index, size=size, x=x, edge_type=edge_type,
                              edge_norm=edge_norm)


    def message(self, x_j, edge_index_j, edge_type, edge_norm):
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))

        # If no node features are given, we implement a simple embedding
        # lookup based on the target node index and its edge type.
        if x_j is None:
            w = w.view(-1, self.out_channels)
            index = edge_type * self.in_channels + edge_index_j
            out = torch.index_select(w, 0, index)
        else:
            w = w.view(self.num_relations, self.in_channels, self.out_channels)
            w = torch.index_select(w, 0, edge_type)
            out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

        return out if edge_norm is None else out * edge_norm.view(-1, 1)

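    def update(self, aggr_out, x):
        # Start from the aggregated messages so `out` is defined even when
        # root_weight=False (self.root is None).
        out = aggr_out
        if self.root is not None:
            if x is None:
                out = out + self.root
            else:
                out = out + torch.matmul(x, self.root)

        if self.bias is not None:
            out = out + self.bias
        return out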

    def __repr__(self):
        return '{}({}, {}, num_relations={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            self.num_relations)

def uniform(size, tensor):
    bound = 1.0 / math.sqrt(size)
    if tensor is not None:
        tensor.data.uniform_(-bound, bound)

In [None]:
class RGCN(torch.nn.Module):
    def __init__(self, num_entities, num_relations, num_bases, dropout):
        super(RGCN, self).__init__()

        self.entity_embedding = nn.Embedding(num_entities, 100)
        self.relation_embedding = nn.Parameter(torch.Tensor(num_relations, 100))

        nn.init.xavier_uniform_(self.relation_embedding, gain=nn.init.calculate_gain('relu'))

        self.conv1 = RGCNConv(
            100, 100, num_relations * 2, num_bases=num_bases)
        self.conv2 = RGCNConv(
            100, 100, num_relations * 2, num_bases=num_bases)

        self.dropout_ratio = dropout

    def forward(self, entity, edge_index, edge_type, edge_norm):
        x = self.entity_embedding(entity)
        x = F.relu(self.conv1(x, edge_index, edge_type, edge_norm))
        x = F.dropout(x, p = self.dropout_ratio, training = self.training)
        x = self.conv2(x, edge_index, edge_type, edge_norm)
        
        return x

    def distmult(self, embedding, triplets):
        s = embedding[triplets[:,0]]
        r = self.relation_embedding[triplets[:,1]]
        o = embedding[triplets[:,2]]
        score = torch.sum(s * r * o, dim=1)
        
        return score

    def score_loss(self, embedding, triplets, target):
        score = self.distmult(embedding, triplets)

        return F.binary_cross_entropy_with_logits(score, target)

    def reg_loss(self, embedding):
        return torch.mean(embedding.pow(2)) + torch.mean(self.relation_embedding.pow(2))

In [None]:
sample_size = 3000
split_size = 0.5
negative_rate = 1

In [None]:
num_epochs = 10000
batch_size = 5000
evaluate_interval = 1000
num_bases = 4
dropout = 0.2
lr = 0.01 
grad_norm = 1.0
regularization = 1e-2
cuda = torch.cuda.is_available()  # use the GPU only when one is present

In [None]:
model = RGCN(num_entity, num_relation, num_bases=num_bases, dropout=dropout)
optimizer = torch.optim.Adam(model.parameters(), lr = lr )
print(model)
if cuda:
    model.cuda()

In [None]:

entity, edge_index, edge_type, edge_norm, samples, labels = generate_sampled_graph_and_labels(train_triplets,
                                                                                  sample_size = batch_size,
                                                                                  split_size = 0.5,
                                                                                  num_entity = num_entity,
                                                                                  num_relation = num_relation,
                                                                                  negative_rate = negative_rate)

if cuda:
    entity = entity.cuda()
    edge_index = edge_index.cuda()
    edge_type = edge_type.cuda()
    edge_norm = edge_norm.cuda()
    samples = samples.cuda()
    labels = labels.cuda()

entity_embedding = model(entity, edge_index, edge_type, edge_norm)
loss = model.score_loss(entity_embedding, samples, labels) + regularization * model.reg_loss(entity_embedding)


In [None]:
best_mrr  = 0.
for epoch in range(1, (num_epochs + 1)):

    model.train()
    optimizer.zero_grad()
    
    # sample training data
    entity, edge_index, edge_type, edge_norm, samples, labels = generate_sampled_graph_and_labels(train_triplets,
                                                                                      sample_size = batch_size,
                                                                                      split_size = 0.5,
                                                                                      num_entity = num_entity,
                                                                                      num_relation = num_relation,
                                                                                      negative_rate = negative_rate)

    if cuda:
        entity = entity.cuda()
        edge_index = edge_index.cuda()
        edge_type = edge_type.cuda()
        edge_norm = edge_norm.cuda()
        samples = samples.cuda()
        labels = labels.cuda()
    
    # training and backprop
    entity_embedding = model(entity, edge_index, edge_type, edge_norm)
    loss = model.score_loss(entity_embedding, samples, labels) + regularization * model.reg_loss(entity_embedding)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
    optimizer.step()

    if epoch % evaluate_interval == 0:

        print("Train Loss {} at epoch {}".format(loss, epoch))

#         if cuda:
#             model.cpu()

#         model.eval()
#         valid_mrr = evaluate(valid_triplets, model, eval_graph, all_triplets)

#         if valid_mrr > best_mrr:
#             best_mrr = valid_mrr
#             torch.save({'state_dict': model.state_dict(), 'epoch': epoch},
#                         'best_mrr_model.pth')

#         if cuda:
#             model.cuda()

if cuda:
    model.cpu()

model.eval()

# checkpoint = torch.load('best_mrr_model.pth')
# model.load_state_dict(checkpoint['state_dict'])
# test_mrr = evaluate(test_triplets, model, eval_graph, all_triplets)



In [None]:
Train Loss 0.13028721511363983 at epoch 1000
Train Loss 0.09021419286727905 at epoch 2000
Train Loss 0.07641539722681046 at epoch 3000
Train Loss 0.06500016152858734 at epoch 4000
Train Loss 0.059444960206747055 at epoch 5000
Train Loss 0.056531891226768494 at epoch 6000
Train Loss 0.052266258746385574 at epoch 7000
Train Loss 0.04904014989733696 at epoch 8000
Train Loss 0.05327734351158142 at epoch 9000
Train Loss 0.03978230431675911 at epoch 10000

MRR (filtered): 0.191787
Hits (filtered) @ 1: 0.116500
Hits (filtered) @ 3: 0.206700
Hits (filtered) @ 10: 0.340600
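
For reference, the reported metrics are simple functions of the rank that the true entity receives among all candidates; "filtered" means that other known true triples are removed from the candidate list before ranking. The sketch below is not the `evaluate` helper from `dputils`; it only illustrates the definitions on made-up ranks.

```python
import numpy as np

def mrr_and_hits(ranks, ks=(1, 3, 10)):
    """ranks: 1-based rank of the true entity for each test triple."""
    ranks = np.asarray(ranks, dtype=float)
    mrr = float(np.mean(1.0 / ranks))                    # mean reciprocal rank
    hits = {k: float(np.mean(ranks <= k)) for k in ks}   # fraction ranked in top k
    return mrr, hits

mrr, hits = mrr_and_hits([1, 2, 4, 20])
print(round(mrr, 4))   # 0.45
print(hits)            # {1: 0.25, 3: 0.5, 10: 0.75}
```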