In [215]:
import math

import torch
import torch.nn.functional as F
from torch.nn import Parameter as Param
from torch_geometric.datasets import Entities
from torch_geometric.nn.conv import MessagePassing

In [216]:
name = 'MUTAG'
path = './data/MUTAG'
dataset = Entities(path, name)
data = dataset[0]
print(data)

Data(edge_index=[2, 148454], edge_norm=[148454], edge_type=[148454], test_idx=[68], test_y=[68], train_idx=[272], train_y=[272])


In [217]:
def uniform(size, tensor):
    stdv = 1.0 / math.sqrt(size)
    if tensor is not None:
        tensor.data.uniform_(-stdv, stdv)

In [218]:
class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_0 \cdot \mathbf{x}_i +
        \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int): Number of bases used for basis-decomposition.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_relations,
                 num_bases,
                 bias=True):
        super(RGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        self.basis = Param(torch.Tensor(num_bases, in_channels, out_channels))
        self.att = Param(torch.Tensor(num_relations, num_bases))
        self.root = Param(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        size = self.num_bases * self.in_channels
        uniform(size, self.basis)
        uniform(size, self.att)
        uniform(size, self.root)
        uniform(size, self.bias)


    def forward(self, x, edge_index, edge_type, edge_norm=None):
        print('in_channels: ', self.in_channels)
        print('out_channels: ', self.out_channels)
        print('num_relations: ', self.num_relations)
        print('num_bases: ', self.num_bases)
        
        print('basis: ', self.basis.shape)
        print('att: ', self.att.shape)
        print('root: ', self.root.shape)

        """"""
        if x is None:
            x = torch.arange(
                edge_index.max().item() + 1,
                dtype=torch.long,
                device=edge_index.device)

        print('x: ', x.shape)
        print('edge_index: ', edge_index.shape)
        
        return self.propagate(
            'add', edge_index, x=x, edge_type=edge_type, edge_norm=edge_norm)


    def message(self, x_j, edge_type, edge_norm):
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
        print('w1 ', w.shape)
        print('x_j: ', x_j.shape, x_j.min(), x_j.max())
        print('edge_type: ', edge_type.shape, edge_type.min(), edge_type.max())
        print('edge_norm: ', edge_norm)

        # ネットワークの最初の段階で，one-hot vectorを入力した場合
        if x_j.dtype == torch.long:
            print('torch is long')
            w = w.view(-1, self.out_channels)
            print('w2: ', w.shape)
            index = edge_type * self.in_channels + x_j
            print('index: ', index.shape)
            out = w[index]
            print('out: ', out.shape)
            return out if edge_norm is None else out * edge_norm.view(-1, 1)
        
        # ネットワークの中間層の段階で，中間特徴量を入力した場合
        else:
            print('torch is not long')
            w = w.view(self.num_relations, self.in_channels, self.out_channels)
            w = w[edge_type]
            out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
            return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out, x):
        # propagateで指定されたaggregateが行われた結果が，aggr_outとしてくる．
        print('aggr_out: ', aggr_out.shape, aggr_out.min(), aggr_out.max())
        print('x: ', x.shape, x.min(), x.max())
        print('root: ', self.root.shape, self.root.min(), self.root.max())
        # ネットワークの最初の段階で，one-hot vectorを入力した場合    
        if x.dtype == torch.long:
            print('self.root[x]: ', self.root[x].shape)
            # root[x]を足して，self-loopを別の重みであることを実現している．
            # つまり，rootがself-loopのweightを示している．
            out = aggr_out + self.root[x]
            
        # ネットワークの中間層の段階で，中間特徴量を入力した場合
        else:
            # rootとxをかけて，Wx(self-loop)を行なっている
            out = aggr_out + torch.matmul(x, self.root)

        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        return '{}({}, {}, num_relations={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            self.num_relations)


In [219]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = RGCNConv(
            data.num_nodes, 16, dataset.num_relations, num_bases=30)
        self.conv2 = RGCNConv(
            16, dataset.num_classes, dataset.num_relations, num_bases=30)

    def forward(self, edge_index, edge_type, edge_norm):
        x = F.relu(self.conv1(None, edge_index, edge_type))
        print(' ')
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)

In [220]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)

In [221]:
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.edge_index, data.edge_type, data.edge_norm)
    F.nll_loss(out[data.train_idx], data.train_y).backward()
    optimizer.step()


def test():
    model.eval()
    out = model(data.edge_index, data.edge_type, data.edge_norm)
    pred = out[data.test_idx].max(1)[1]
    acc = pred.eq(data.test_y).sum().item() / data.test_y.size(0)
    return acc

In [222]:
for epoch in range(1):
    train()
#     test_acc = test()
#     print('Epoch: {:02d}, Accuracy: {:.4f}'.format(epoch, test_acc))

in_channels:  23644
out_channels:  16
num_relations:  46
num_bases:  30
basis:  torch.Size([30, 23644, 16])
att:  torch.Size([46, 30])
root:  torch.Size([23644, 16])
x:  torch.Size([23644])
edge_index:  torch.Size([2, 148454])
w1  torch.Size([46, 378304])
x_j:  torch.Size([148454]) tensor(0) tensor(23643)
edge_type:  torch.Size([148454]) tensor(0) tensor(45)
edge_norm:  None
torch is long
w2:  torch.Size([1087624, 16])
index:  torch.Size([148454])
out:  torch.Size([148454, 16])
aggr_out:  torch.Size([23644, 16]) tensor(-0.0003, grad_fn=<MinBackward1>) tensor(0.0003, grad_fn=<MaxBackward1>)
x:  torch.Size([23644]) tensor(0) tensor(23643)
root:  torch.Size([23644, 16]) tensor(-0.0012, grad_fn=<MinBackward1>) tensor(0.0012, grad_fn=<MaxBackward1>)
self.root[x]:  torch.Size([23644, 16])
 
in_channels:  16
out_channels:  2
num_relations:  46
num_bases:  30
basis:  torch.Size([30, 16, 2])
att:  torch.Size([46, 30])
root:  torch.Size([16, 2])
x:  torch.Size([23644, 16])
edge_index:  torch.Siz

# Experiment

In [223]:
in_c = 100
out_c = 10
num_relations = 5

ord_basis = [Param(torch.Tensor(1, in_c, out_c)) for _ in range(num_relations)]

In [224]:
len(ord_basis)

5

In [225]:
tmp, w = 0, 0
for relation in range(num_relations):
    tmp = ord_basis[relation]
    if relation == 0:
        w = tmp
    else:
        w = torch.cat((w, tmp), 0)

In [226]:
w.requires_grad

True

In [227]:
ord_basis[0].requires_grad

True

In [228]:
data.edge_index.shape

torch.Size([2, 148454])

In [229]:
data.edge_type.shape

torch.Size([148454])

In [230]:
dataset.num_relations

46

In [238]:
num_node = data.edge_index.max() + 1

In [262]:
edge = torch.where(data.edge_index == 23000, 
                   torch.ones(data.edge_index.shape, 
                              dtype=torch.int64)*dataset.num_relations, 
                   torch.zeros(data.edge_index.shape, dtype=torch.int64))


print(data.edge_index.dtype, edge.dtype)
print(torch.sum(edge[0]))

relation = data.edge_type + edge[0]

import numpy as np
print(set(list(np.array(relation))))

print(dataset.num_relations)

relation = torch.where(relation >= dataset.num_relations, 
#                        torch.ones(relation.shape, dtype=relation.dtype),
                       relation - dataset.num_relations, 
                       torch.zeros(relation.shape, dtype=relation.dtype))

print(set(list(np.array(relation))))

# edge_normは最終的に，edges x 1のshapeになる必要がある．
# つまり，target_edgeのもつedgeの本数がわかればよく，
# target_edges x 1のshapeで要素にnum_edgesが入る．
# そこからそれの逆行列を取り，それをoutに通常の乗算を行う．

node_norm = torch.zeros(num_node, dtype=data.edge_index.dtype)
print(node_norm.size(0))

import time

start = time.time()
for idx in range(node_norm.size(0)):
    edge = torch.where(
                        data.edge_index == idx,
                        torch.ones(data.edge_index.shape,
                                   dtype=data.edge_index.dtype),
                        torch.zeros(data.edge_index.shape, 
                                    dtype=data.edge_index.dtype)
                        )
    
    node_norm[idx] = torch.sum(edge)
    
print('Elapsed: ', time.time() - start)

torch.int64 torch.int64
tensor(230)
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 49, 53, 54}
46
{0, 8, 3, 7}
23644
Elapsed:  99.88504409790039


In [None]:
print(node_norm.mean())