
hypergraph update

rusty1s committed May 8, 2019
1 parent 1bcda82 commit ed167bf1441b2dd6a5503a03fec92262bec2ead8
Showing with 122 additions and 135 deletions.
  1. +1 −0 README.md
  2. +24 −4 test/nn/conv/test_hypergraph_conv.py
  3. +97 −131 torch_geometric/nn/conv/hypergraph_conv.py
@@ -152,6 +152,7 @@ In detail, the following methods are currently implemented:
 * **[XConv](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.conv.XConv)** from Li *et al.*: [PointCNN: Convolution On X-Transformed Points](https://arxiv.org/abs/1801.07791) [(official implementation)](https://github.com/yangyanli/PointCNN) (NeurIPS 2018)
 * **[PPFConv](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.conv.PPFConv)** from Deng *et al.*: [PPFNet: Global Context Aware Local Features for Robust 3D Point Matching](https://arxiv.org/abs/1802.02669) (CVPR 2018)
 * **[GMMConv](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.conv.GMMConv)** from Monti *et al.*: [Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs](https://arxiv.org/abs/1611.08402) (CVPR 2017)
+* **[HypergraphConv](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.conv.HypergraphConv)** from Bai *et al.*: [Hypergraph Convolution and Hypergraph Attention](https://arxiv.org/abs/1901.08150) (CoRR 2019)
 * A **[MetaLayer](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.meta.MetaLayer)** for building any kind of graph network similar to the [TensorFlow Graph Nets library](https://github.com/deepmind/graph_nets) from Battaglia *et al.*: [Relational Inductive Biases, Deep Learning, and Graph Networks](https://arxiv.org/abs/1806.01261) (CoRR 2018)
 * **[GlobalAttention](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.glob.GlobalAttention)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016)
 * **[Set2Set](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.glob.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016)
@@ -4,12 +4,32 @@

 def test_hypergraph_conv():
     in_channels, out_channels = (16, 32)
-    hyper_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
-    hyper_weight = torch.tensor([1, 0.5, 0.3, 0.7])
-    num_nodes = hyper_index.max().item() + 1
+    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3], [0, 1, 0, 1, 0, 1]])
+    hyperedge_weight = torch.tensor([1, 0.5])
+    num_nodes = hyperedge_index[0].max().item() + 1
     x = torch.randn((num_nodes, in_channels))

     conv = HypergraphConv(in_channels, out_channels)
     assert conv.__repr__() == 'HypergraphConv(16, 32)'
-    out = conv(x, hyper_index, hyper_weight)
+    out = conv(x, hyperedge_index)
     assert out.size() == (num_nodes, out_channels)
+    out = conv(x, hyperedge_index, hyperedge_weight)
+    assert out.size() == (num_nodes, out_channels)
+
+    conv = HypergraphConv(in_channels,
+                          out_channels,
+                          use_attention=True,
+                          heads=2)
+    out = conv(x, hyperedge_index)
+    assert out.size() == (num_nodes, 2 * out_channels)
+    out = conv(x, hyperedge_index, hyperedge_weight)
+    assert out.size() == (num_nodes, 2 * out_channels)
+
+    conv = HypergraphConv(in_channels,
+                          out_channels,
+                          use_attention=True,
+                          heads=2,
+                          concat=True,
+                          dropout=0.5)
+    out = conv(x, hyperedge_index, hyperedge_weight)
+    assert out.size() == (num_nodes, 2 * out_channels)
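
A note on the new argument format used in this test: each column of hyperedge_index is a (node, hyperedge) incidence pair, with node indices in row 0 and hyperedge indices in row 1, and hyperedge_weight holds one weight per hyperedge. A minimal sketch (not part of the commit) that materializes the dense incidence matrix H for the fixture above:

    import torch

    # Each column (i, j) of hyperedge_index says: node i belongs to hyperedge j.
    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3],
                                    [0, 1, 0, 1, 0, 1]])
    num_nodes = hyperedge_index[0].max().item() + 1   # 4 nodes
    num_edges = hyperedge_index[1].max().item() + 1   # 2 hyperedges

    # Dense incidence matrix H (N x M): H[i, j] = 1 iff node i is in hyperedge j.
    H = torch.zeros(num_nodes, num_edges)
    H[hyperedge_index[0], hyperedge_index[1]] = 1
    # H = [[1., 1.],
    #      [1., 1.],
    #      [1., 0.],
    #      [0., 1.]], i.e. hyperedge 0 = {0, 1, 2} and hyperedge 1 = {0, 1, 3}.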
@@ -1,20 +1,25 @@
 import torch
-from torch import nn
+from torch.nn import Parameter
 from torch.nn import functional as F
 from torch_scatter import scatter_add
-from torch_geometric.utils import softmax
+from torch_geometric.utils import softmax, degree
 from torch_geometric.nn.conv import MessagePassing

+from ..inits import glorot, zeros
+

 class HypergraphConv(MessagePassing):
-    r"""
-    The Hypergraph convolutional operator fro the `"Hypergraph Convolution
-    and Hypergraph Attention"<https://arxiv.org/pdf/1901.08150.pdf>`_ paper
+    r"""The hypergraph convolutional operator from the `"Hypergraph Convolution
+    and Hypergraph Attention" <https://arxiv.org/abs/1901.08150>`_ paper
+
     .. math::
-        \mathbf{X}^{\prime} =\sigma\left(\mathbf{D}^{-1} \mathbf{H}
-        \mathbf{W} \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{P}
-        \right)
+        \mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W}
+        \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}
+
+    where :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` is the incidence
+    matrix, :math:`\mathbf{W}` is the diagonal hyperedge weight matrix, and
+    :math:`\mathbf{D}` and :math:`\mathbf{B}` are the corresponding degree
+    matrices.

     Args:
         in_channels (int): Size of each input sample.
@@ -44,145 +49,106 @@ def __init__(self,
                  negative_slope=0.2,
                  dropout=0,
                  bias=True):
-        super().__init__("add", flow="target_to_source")
+        super().__init__('add')

         self.in_channels = in_channels
         self.out_channels = out_channels
         self.use_attention = use_attention

-        self.linear = nn.Linear(in_channels, heads * out_channels, bias=False)
-
         if self.use_attention:
-            self.head = heads
-            self.alpha_initialized = False
+            self.heads = heads
             self.concat = concat
             self.negative_slope = negative_slope
             self.dropout = dropout
-            self.aggregate_method = "cat" if concat else "avg"
-            self.attrs = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels))
-
-    def norm(self,
-             hyper_edge_index,
-             dim_size,
-             hyper_edge_weight=None,
-             dim=0,
-             dtype=None):
-        """
-        Args:
-            hyper_index (Tensor): hype edge connect with node ,[2,E]
-            dim_size (int): size of degree
-            hyper_edge_weight (Tensor): weight of all hyper edges
-            dim (int): normalization in which dimension
-            dtype (torch.Dtype): data type of newly created tensor
-        :return: attention weight for each edge between nodes and hyper edges
-        """
-        if hyper_edge_weight is None:
-            hyper_edge_weight = torch.ones((hyper_edge_index.size(1), ),
-                                           dtype=dtype,
-                                           device=hyper_edge_index.device)
-        else:
-            hyper_edge_weight = torch.index_select(hyper_edge_weight,
-                                                   dim=0,
-                                                   index=hyper_edge_index[1])
-        if self.use_attention:
-            hyper_edge_weight = hyper_edge_weight.view(
-                -1, 1, 1) * self.attention_weight
-
-        # hyper_edge_weight = hyper_edge_weight.view(-1)
-        deg = scatter_add(hyper_edge_weight,
-                          hyper_edge_index[dim],
-                          dim=0,
-                          dim_size=dim_size)
-        deg_inv = deg.pow(-1)
-        deg_inv[deg_inv == float("inf")] = 0
-        return torch.index_select(deg_inv, 0,
-                                  hyper_edge_index[dim]) * hyper_edge_weight
-
-    def attention(self, x, hyper_edge_index):
-        """
-        Args:
-            x (Tensor): features for all nodes in a graph
-            hyper_edge_index (Tensor): hype edge connect with node ,[2,E]
-        """
-        row, col = hyper_edge_index
-        x_i = torch.index_select(x, dim=0, index=row)
-        x_j = torch.index_select(x, dim=0, index=col)
-        alpha = (torch.cat([x_i, x_j], dim=-1) * self.attrs).sum(dim=-1)
-        alpha = F.leaky_relu(alpha, self.negative_slope)
-        alpha = softmax(alpha, row, x.size(0))
-        if self.training and self.dropout > 0:
-            alpha = F.dropout(alpha, p=self.dropout, training=True)
-        return alpha.view(-1, self.head, 1)
-
-    def forward(self, x, hyper_edge_index, hyper_edge_weight):
-        """
-        Args:
-            x (Tensor): node feature matrix, [N,C].
-            hyper_edge_index (Tensor): hype edge connect with node ,[2,E]
-            hyper_edge_weight (Tensor): weight for each unique hyper edge, [M,]
-        """
-        x = self.linear(x)
-        if self.use_attention:
-            # sim compute
-
-            assert hyper_edge_weight.size(0) == x.size(
-                0), "Attention can only applied when hyper edge is node"
-            x = x.view(-1, self.head, self.out_channels)
-            self.attention_weight = self.attention(x, hyper_edge_index)
-        norm_HW = self.norm(
-            hyper_edge_index,
-            x.size(0),
-            hyper_edge_weight=hyper_edge_weight,
-            dim=0,
-            dtype=x.dtype,
-        )
-        norm_H = self.norm(hyper_edge_index,
-                           hyper_edge_weight.size(0),
-                           dim=1,
-                           dtype=x.dtype)
-        #
-        tmp = self.propagate(
-            edge_index=hyper_edge_index,
-            x=x,
-            norm=norm_H,
-            size=(hyper_edge_weight.size(0), x.size(0)),
-        )  # M,C
-        return self.propagate(
-            edge_index=hyper_edge_index,
-            x=tmp,
-            norm=norm_HW,
-            size=(x.size(0), hyper_edge_weight.size(0)),
-        )  # N,C
-
-    def message(self, x_j, norm):
-        """
-        Args:
-            x_j (Tensor): features of connected nodes
-            norm (Tensor): transition weights for connected nodes
-        """
-        if self.use_attention:
-            return norm * x_j.view(-1, self.head, self.out_channels)
-        return norm.view(-1, 1) * x_j  # E,1
-
-    def att_aggregate(self, x):
-        """
-        x (Tensor): [N,head_num,C]
-        """
-        if self.aggregate_method == "cat":
-            return x.view(-1, self.head * self.out_channels)
-        elif self.aggregate_method == "avg":
-            return x.mean(dim=1)
-        else:
-            raise NotImplementedError(
-                "Aggregation method %s is not implemented.")
-
-    def update(self, aggr_out):
-        """
-        Args:
-            aggr_out (Tensor): output of hyper graph layer
-        """
-        if self.use_attention:
-            aggr_out = self.att_aggregate(aggr_out)
-        return aggr_out
+            self.weight = Parameter(
+                torch.Tensor(in_channels, heads * out_channels))
+            self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))
+        else:
+            self.heads = 1
+            self.concat = True
+            self.weight = Parameter(torch.Tensor(in_channels, out_channels))
+
+        if bias:
+            self.bias = Parameter(torch.Tensor(self.heads * out_channels))
+        else:
+            self.register_parameter('bias', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        glorot(self.weight)
+        if self.use_attention:
+            glorot(self.att)
+        zeros(self.bias)
+
+    def __forward__(self,
+                    x,
+                    hyperedge_index,
+                    hyperedge_weight=None,
+                    alpha=None):
+
+        if hyperedge_weight is None:
+            D = degree(hyperedge_index[0], x.size(0), x.dtype)
+        else:
+            D = scatter_add(hyperedge_weight[hyperedge_index[1]],
+                            hyperedge_index[0],
+                            dim=0,
+                            dim_size=x.size(0))
+        D = 1.0 / D
+        D[D == float("inf")] = 0
+
+        num_edges = hyperedge_index[1].max().item() + 1
+        B = 1.0 / degree(hyperedge_index[1], num_edges, x.dtype)
+        B[B == float("inf")] = 0
+        if hyperedge_weight is not None:
+            B = B * hyperedge_weight
+
+        self.flow = 'source_to_target'
+        out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha)
+        self.flow = 'target_to_source'
+        out = self.propagate(hyperedge_index, x=out, norm=D, alpha=alpha)
+
+        return out
+
+    def message(self, x_j, edge_index_i, norm, alpha):
+        out = norm[edge_index_i].view(-1, 1, 1) * x_j.view(
+            -1, self.heads, self.out_channels)
+        if alpha is not None:
+            out = alpha.view(-1, self.heads, 1) * out
+        return out
+
+    def forward(self, x, hyperedge_index, hyperedge_weight=None):
+        r"""
+        Args:
+            x (Tensor): Node feature matrix :math:`\mathbf{X}`.
+            hyperedge_index (LongTensor): Hyperedge indices from
+                :math:`\mathbf{H}`.
+            hyperedge_weight (Tensor, optional): Sparse hyperedge weights from
+                :math:`\mathbf{W}`. (default: :obj:`None`)
+        """
+        x = torch.matmul(x, self.weight)
+        alpha = None
+
+        if self.use_attention:
+            x = x.view(-1, self.heads, self.out_channels)
+            x_i, x_j = x[hyperedge_index[0]], x[hyperedge_index[1]]
+            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
+            alpha = F.leaky_relu(alpha, self.negative_slope)
+            alpha = softmax(alpha, hyperedge_index[0], x.size(0))
+            if self.training and self.dropout > 0:
+                alpha = F.dropout(alpha, p=self.dropout, training=True)
+
+        out = self.__forward__(x, hyperedge_index, hyperedge_weight, alpha)
+
+        if self.concat is True:
+            out = out.view(-1, self.heads * self.out_channels)
+        else:
+            out = out.mean(dim=1)
+
+        if self.bias is not None:
+            out = out + self.bias
+
+        return out

     def __repr__(self):
         return "{}({}, {})".format(self.__class__.__name__, self.in_channels,
                                    self.out_channels)

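As a cross-check on the two-pass propagation in __forward__ (nodes to hyperedges with norm B, then hyperedges back to nodes with norm D), the non-attention path should reproduce the dense update X' = D^{-1} H W B^{-1} H^T X Θ from the docstring. A minimal sketch, not part of the commit, assuming the post-commit layer is importable from torch_geometric.nn and that bias is disabled so the two computations agree up to floating-point error:

    import torch
    from torch_geometric.nn import HypergraphConv

    torch.manual_seed(0)

    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3],
                                    [0, 1, 0, 1, 0, 1]])
    hyperedge_weight = torch.tensor([1.0, 0.5])
    x = torch.randn(4, 16)

    conv = HypergraphConv(16, 32, bias=False)
    out = conv(x, hyperedge_index, hyperedge_weight)

    # Dense reference: X' = D^{-1} H W B^{-1} H^T X Theta.
    H = torch.zeros(4, 2)
    H[hyperedge_index[0], hyperedge_index[1]] = 1
    W = torch.diag(hyperedge_weight)
    D_inv = torch.diag(1.0 / (H @ hyperedge_weight))  # weighted node degrees
    B_inv = torch.diag(1.0 / H.sum(dim=0))            # hyperedge degrees
    ref = D_inv @ H @ W @ B_inv @ H.t() @ x @ conv.weight

    assert torch.allclose(out, ref, atol=1e-4)

Here D collects the hyperedge-weighted node degrees and B the plain hyperedge degrees, exactly as scatter_add and degree build them in __forward__; the weight multiplication folded into B in the code corresponds to the W B^{-1} factor above.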