
Commit

working cheb conv
rusty1s committed Mar 7, 2018
1 parent 52e1a67 commit e2db3b3
Showing 5 changed files with 48 additions and 28 deletions.
5 changes: 5 additions & 0 deletions examples/cora_gcn.py
@@ -12,6 +12,11 @@
 from torch_geometric.utils import DataLoader2  # noqa
 from torch_geometric.nn.modules import GraphConv  # noqa
 
+
+def preprocess_input(x):
+    return x.sum(1, keepdim=True).pow_(-1) * x
+
+
 path = os.path.dirname(os.path.realpath(__file__))
 path = os.path.join(path, '..', 'data', 'Cora')
 data = Cora(path)[0].cuda().to_variable()
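The added `preprocess_input` helper row-normalizes the node features so that each row sums to one. A minimal standalone sketch of the same arithmetic (written against the current tensor API with an out-of-place `pow`; rows summing to zero would yield infs):

    import torch

    x = torch.tensor([[1., 3.],
                      [2., 2.]])
    # Divide each row by its sum, as preprocess_input does.
    x_norm = x.sum(1, keepdim=True).pow(-1) * x
    print(x_norm)  # tensor([[0.2500, 0.7500],
                   #         [0.5000, 0.5000]])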
22 changes: 17 additions & 5 deletions torch_geometric/nn/functional/cheb_conv.py
@@ -1,15 +1,27 @@
 import torch
 
 from ...sparse import SparseTensor
 from .graph_conv import SparseMM
 
 
-def cheb_conv(x, index, weight, bias=None):
-    K = weight.size(0)
-    row, col = index
-    # Create normalized laplacian.
-    lap = None
+def cheb_conv(x, edge_index, weight, edge_attr=None, bias=None):
+    row, col = edge_index
+    n, e, K = x.size(0), row.size(0), weight.size(0)
+    # raise NotImplementedError
+
+    if edge_attr is None:
+        edge_attr = x.data.new(e).fill_(1)
+
+    # Compute degree.
+    degree = x.data.new(n).fill_(0).scatter_add_(0, row, edge_attr)
+    degree = degree.pow_(-0.5)
+
+    # Compute normalized and rescaled Laplacian.
+    edge_attr *= degree[row]
+    edge_attr *= degree[col]
+    lap = SparseTensor(edge_index, -edge_attr, torch.Size([n, n]))
+
     # Convolution.
     Tx_0 = x
     output = torch.mm(Tx_0, weight[0])
 
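The hunk is collapsed past this point, so the higher-order terms of the filter are not shown. Two things worth noting: `lap` is built from `-edge_attr` with no diagonal entries, which matches the usual λ_max ≈ 2 approximation, under which the rescaled Laplacian 2L/λ_max - I collapses to -D^{-1/2} A D^{-1/2}; and the remaining work is the Chebyshev recurrence T_0 = x, T_1 = L·x, T_k = 2 L·T_{k-1} - T_{k-2} (with L the rescaled Laplacian). A dense reference sketch of that recurrence (our helper, not the commit's sparse code):

    import torch

    def cheb_conv_dense(x, lap, weight):
        # x: [n, in], lap: dense rescaled Laplacian [n, n],
        # weight: [K, in, out]. Reference only; the commit uses SparseMM.
        Tx_0 = x
        out = torch.mm(Tx_0, weight[0])
        if weight.size(0) > 1:
            Tx_1 = torch.mm(lap, x)
            out = out + torch.mm(Tx_1, weight[1])
            for k in range(2, weight.size(0)):
                Tx_2 = 2 * torch.mm(lap, Tx_1) - Tx_0
                out = out + torch.mm(Tx_2, weight[k])
                Tx_0, Tx_1 = Tx_1, Tx_2
        return out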
26 changes: 15 additions & 11 deletions torch_geometric/nn/functional/graph_conv.py
@@ -19,22 +19,26 @@ def backward(self, grad_output):
         return grad_input
 
 
-def graph_conv(x, index, weight, bias=None):
-    row, col = index
+def graph_conv(x, edge_index, weight, edge_attr=None, bias=None):
+    row, col = edge_index
     n, e = x.size(0), row.size(0)
 
-    # Preprocess.
-    zero, one = x.data.new(n).fill_(0), x.data.new(e).fill_(1)
-    degree = zero.scatter_add_(0, row, one) + 1
+    if edge_attr is None:
+        edge_attr = x.data.new(e).fill_(1)
+
+    # Compute degree.
+    degree = x.data.new(n).fill_(0).scatter_add_(0, row, edge_attr) + 1
     degree = degree.pow_(-0.5)
 
-    value = degree[row] * degree[col]
-    value = torch.cat([value, degree * degree], dim=0)
-    loop = torch.arange(0, n, out=index.new()).view(1, -1).repeat(2, 1)
-    index = torch.cat([index, loop], dim=1)
-    adj = SparseTensor(index, value, torch.Size([n, n]))
+    # Normalize adjacency matrix.
+    edge_attr *= degree[row]
+    edge_attr *= degree[col]
+    edge_attr = torch.cat([edge_attr, degree * degree], dim=0)
+    loop = torch.arange(0, n, out=row.new()).view(1, -1).repeat(2, 1)
+    edge_index = torch.cat([edge_index, loop], dim=1)
+    adj = SparseTensor(edge_index, edge_attr, torch.Size([n, n]))
 
-    # Start computation.
+    # Convolution.
     output = SparseMM(adj)(torch.mm(x, weight))
 
     if bias is not None:
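This hunk is also truncated here. For orientation: the rewritten preprocessing assembles A_hat = D^{-1/2}(A + I)D^{-1/2} in sparse form, where the `+ 1` on the degree and the appended `degree * degree` values account for the self-loops added via `loop`. A dense equivalent for the unweighted case (illustrative helper, not part of the library):

    import torch

    def normalized_adj_dense(edge_index, n):
        # Dense version of the sparse `adj` that graph_conv builds
        # (edge_attr of all ones).
        A = torch.zeros(n, n)
        A[edge_index[0], edge_index[1]] = 1.0
        A = A + torch.eye(n)               # add self-loops
        d = A.sum(1).pow(-0.5)             # (deg + 1)^(-1/2)
        return d.view(-1, 1) * A * d.view(1, -1)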
19 changes: 9 additions & 10 deletions torch_geometric/nn/modules/cheb_conv.py
@@ -1,5 +1,5 @@
 import torch
-from torch.nn import Module, Parameter
+from torch.nn import Module, Parameter as Param
 
 from .utils.inits import uniform
 from .utils.repr import repr
@@ -14,34 +14,33 @@ class ChebConv(Module):
     Args:
         in_features (int): Size of each input sample.
         out_features (int): Size of each output sample.
-        kernel_size (int): Chebyshev filter size, i.e. number of hops.
+        K (int): Chebyshev filter size, i.e. number of hops.
         bias (bool, optional): If set to :obj:`False`, the layer will not learn
             an additive bias. (default: :obj:`True`)
     """
 
-    def __init__(self, in_features, out_features, kernel_size, bias=True):
+    def __init__(self, in_features, out_features, K, bias=True):
         super(ChebConv, self).__init__()
 
         self.in_features = in_features
         self.out_features = out_features
-        self.kernel_size = kernel_size
-        weight = torch.Tensor(kernel_size, in_features, out_features)
-        self.weight = Parameter(weight)
+        self.K = K + 1
+        self.weight = Param(torch.Tensor(self.K, in_features, out_features))
 
         if bias:
-            self.bias = Parameter(torch.Tensor(out_features))
+            self.bias = Param(torch.Tensor(out_features))
         else:
            self.register_parameter('bias', None)
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        size = self.kernel_size * self.in_features
+        size = self.K * self.in_features
         uniform(size, self.weight)
         uniform(size, self.bias)
 
-    def forward(self, x, index):
-        return cheb_conv(x, index, self.weight, self.bias)
+    def forward(self, x, edge_index, edge_attr=None):
+        return cheb_conv(x, edge_index, self.weight, edge_attr, self.bias)
 
     def __repr__(self):
         return repr(self)
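A hypothetical usage sketch after this change (plain tensors shown; the 2018-era code actually passed autograd Variables). Note that the constructor stores K + 1 weight matrices, one per Chebyshev order 0..K:

    import torch
    from torch_geometric.nn.modules import ChebConv

    conv = ChebConv(in_features=1433, out_features=16, K=2)
    print(conv.weight.size())   # torch.Size([3, 1433, 16])

    x = torch.rand(2708, 1433)                       # Cora-sized features
    edge_index = torch.randint(0, 2708, (2, 10556))  # random edges for the sketch
    out = conv(x, edge_index)                        # edge_attr defaults to None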
4 changes: 2 additions & 2 deletions torch_geometric/nn/modules/graph_conv.py
@@ -39,8 +39,8 @@ def reset_parameters(self):
         uniform(size, self.weight)
         uniform(size, self.bias)
 
-    def forward(self, x, index):
-        return graph_conv(x, index, self.weight, self.bias)
+    def forward(self, x, edge_index, edge_attr=None):
+        return graph_conv(x, edge_index, self.weight, edge_attr, self.bias)
 
     def __repr__(self):
         return repr(self)
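GraphConv gains the same optional `edge_attr` argument; omitting it falls back to all-ones edge weights, so existing two-argument calls keep working. A sketch under the assumption that the constructor mirrors ChebConv's, minus `K`:

    import torch
    from torch_geometric.nn.modules import GraphConv

    conv = GraphConv(in_features=16, out_features=7)
    x = torch.rand(100, 16)
    edge_index = torch.randint(0, 100, (2, 500))

    out = conv(x, edge_index)                   # unweighted: weights default to 1
    out = conv(x, edge_index, torch.rand(500))  # one weight per edge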
