
Commit

update size(0) to size(node_dim)
rusty1s committed Mar 17, 2020
1 parent 443115e commit addcdef
Showing 9 changed files with 23 additions and 35 deletions.
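The change common to all nine files: wherever a layer needs the node count, the hard-coded x.size(0) becomes x.size(self.node_dim). node_dim is the MessagePassing attribute naming the axis of x that indexes nodes (0 by default), so querying it keeps the layers correct when a non-default node axis is configured. The sketch below is not part of the commit; ToyConv and the tensor shapes are made up to illustrate the pattern.

# Illustrative sketch, not from the commit: ToyConv is a hypothetical layer
# showing why x.size(self.node_dim) is preferred over x.size(0).
import torch
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops


class ToyConv(MessagePassing):
    def __init__(self, **kwargs):
        # node_dim (default 0) is the axis of x that indexes nodes.
        super(ToyConv, self).__init__(aggr='add', **kwargs)

    def forward(self, x, edge_index):
        # x.size(0) assumes nodes sit on the first axis; x.size(self.node_dim)
        # respects whatever node_dim the layer was constructed with.
        num_nodes = x.size(self.node_dim)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j


x = torch.randn(4, 16)                              # 4 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = ToyConv()(x, edge_index)                      # shape [4, 16]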
8 changes: 4 additions & 4 deletions torch_geometric/nn/conv/agnn_conv.py
@@ -28,7 +28,6 @@ class AGNNConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
-
     def __init__(self, requires_grad=True, **kwargs):
         super(AGNNConv, self).__init__(aggr='add', **kwargs)

@@ -48,12 +47,13 @@ def reset_parameters(self):
     def forward(self, x, edge_index):
         """"""
         edge_index, _ = remove_self_loops(edge_index)
-        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
+        edge_index, _ = add_self_loops(edge_index,
+                                       num_nodes=x.size(self.node_dim))

         x_norm = F.normalize(x, p=2, dim=-1)

-        return self.propagate(
-            edge_index, x=x, x_norm=x_norm, num_nodes=x.size(0))
+        return self.propagate(edge_index, x=x, x_norm=x_norm,
+                              num_nodes=x.size(self.node_dim))

     def message(self, edge_index_i, x_j, x_norm_i, x_norm_j, num_nodes):
         # Compute attention coefficients.
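For reference, a hedged usage sketch of AGNNConv as it stands after this change; the graph and feature sizes below are made up.

import torch
from torch_geometric.nn import AGNNConv

x = torch.randn(5, 32)                                   # 5 nodes along node_dim=0
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
conv = AGNNConv(requires_grad=True)
out = conv(x, edge_index)                                # attention-weighted propagation, shape [5, 32]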
5 changes: 2 additions & 3 deletions torch_geometric/nn/conv/appnp.py
@@ -27,16 +27,15 @@ class APPNP(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
-
     def __init__(self, K, alpha, bias=True, **kwargs):
         super(APPNP, self).__init__(aggr='add', **kwargs)
         self.K = K
         self.alpha = alpha

     def forward(self, x, edge_index, edge_weight=None):
         """"""
-        edge_index, norm = GCNConv.norm(
-            edge_index, x.size(0), edge_weight, dtype=x.dtype)
+        edge_index, norm = GCNConv.norm(edge_index, x.size(self.node_dim),
+                                        edge_weight, dtype=x.dtype)

         hidden = x
         for k in range(self.K):
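APPNP runs K propagation steps over the GCN-normalized graph, mixing the running state back toward the input features with teleport probability alpha. A minimal usage sketch, with illustrative K, alpha, and sizes:

import torch
from torch_geometric.nn import APPNP

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
prop = APPNP(K=10, alpha=0.1)
out = prop(x, edge_index)       # same shape as x after K smoothing steps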
12 changes: 2 additions & 10 deletions torch_geometric/nn/conv/arma_conv.py
@@ -41,16 +41,8 @@ class ARMAConv(torch.nn.Module):
         bias (bool, optional): If set to :obj:`False`, the layer will not learn
             an additive bias. (default: :obj:`True`)
     """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 num_stacks=1,
-                 num_layers=1,
-                 shared_weights=False,
-                 act=F.relu,
-                 dropout=0,
-                 bias=True):
+    def __init__(self, in_channels, out_channels, num_stacks=1, num_layers=1,
+                 shared_weights=False, act=F.relu, dropout=0, bias=True):
         super(ARMAConv, self).__init__()

         self.in_channels = in_channels
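The ARMAConv change is purely cosmetic: the constructor signature is collapsed onto two lines. A hedged instantiation sketch mirroring the arguments shown above, with made-up channel counts and dropout:

import torch
import torch.nn.functional as F
from torch_geometric.nn import ARMAConv

conv = ARMAConv(16, 32, num_stacks=2, num_layers=2,
                shared_weights=False, act=F.relu, dropout=0.25)
x = torch.randn(7, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]])
out = conv(x, edge_index)       # shape [7, 32], averaged over the stacks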
5 changes: 2 additions & 3 deletions torch_geometric/nn/conv/dna_conv.py
@@ -213,7 +213,6 @@ class DNAConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
-
     def __init__(self, channels, heads=1, groups=1, dropout=0, cached=False,
                  bias=True, **kwargs):
         super(DNAConv, self).__init__(aggr='add', **kwargs)

@@ -250,8 +249,8 @@ def forward(self, x, edge_index, edge_weight=None):

         if not self.cached or self.cached_result is None:
             self.cached_num_edges = edge_index.size(1)
-            edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
-                                            dtype=x.dtype)
+            edge_index, norm = GCNConv.norm(edge_index, x.size(self.node_dim),
+                                            edge_weight, dtype=x.dtype)
             self.cached_result = edge_index, norm

         edge_index, norm = self.cached_result
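DNAConv caches the GCN normalization: with cached=True, GCNConv.norm is evaluated on the first forward pass and self.cached_result is reused afterwards. A hedged sketch with made-up sizes; DNAConv expects x of shape [num_nodes, num_layers, channels]:

import torch
from torch_geometric.nn import DNAConv

x = torch.randn(5, 3, 16)       # embeddings of 5 nodes from 3 previous layers
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
conv = DNAConv(channels=16, heads=1, groups=1, cached=True)
out1 = conv(x, edge_index)      # computes and caches the normalized graph
out2 = conv(x, edge_index)      # reuses self.cached_result, shape [5, 16]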
5 changes: 2 additions & 3 deletions torch_geometric/nn/conv/feast_conv.py
@@ -32,7 +32,6 @@ class FeaStConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
-
     def __init__(self, in_channels, out_channels, heads=1, bias=True,
                  **kwargs):
         super(FeaStConv, self).__init__(aggr='mean', **kwargs)

@@ -41,8 +40,8 @@ def __init__(self, in_channels, out_channels, heads=1, bias=True,
         self.out_channels = out_channels
         self.heads = heads

-        self.weight = Parameter(
-            torch.Tensor(in_channels, heads * out_channels))
+        self.weight = Parameter(torch.Tensor(in_channels,
+                                             heads * out_channels))
         self.u = Parameter(torch.Tensor(in_channels, heads))
         self.c = Parameter(torch.Tensor(heads))
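The FeaStConv hunk only rewraps the weight construction. At this version of the code, the parameter shapes it implies can be checked directly; the channel and head counts below are made up:

import torch
from torch_geometric.nn import FeaStConv

conv = FeaStConv(in_channels=16, out_channels=32, heads=2)
print(conv.weight.shape)        # [16, 64]  == [in_channels, heads * out_channels]
print(conv.u.shape)             # [16, 2]   == [in_channels, heads]
print(conv.c.shape)             # [2]       == [heads]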
8 changes: 4 additions & 4 deletions torch_geometric/nn/conv/gat_conv.py
@@ -46,7 +46,6 @@ class GATConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
-
     def __init__(self, in_channels, out_channels, heads=1, concat=True,
                  negative_slope=0.2, dropout=0, bias=True, **kwargs):
         super(GATConv, self).__init__(aggr='add', **kwargs)

@@ -58,8 +57,8 @@ def __init__(self, in_channels, out_channels, heads=1, concat=True,
         self.negative_slope = negative_slope
         self.dropout = dropout

-        self.weight = Parameter(
-            torch.Tensor(in_channels, heads * out_channels))
+        self.weight = Parameter(torch.Tensor(in_channels,
+                                             heads * out_channels))
         self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))

         if bias and concat:

@@ -80,7 +79,8 @@ def forward(self, x, edge_index, size=None):
         """"""
         if size is None and torch.is_tensor(x):
             edge_index, _ = remove_self_loops(edge_index)
-            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
+            edge_index, _ = add_self_loops(edge_index,
+                                           num_nodes=x.size(self.node_dim))

         if torch.is_tensor(x):
             x = torch.matmul(x, self.weight)
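GATConv.forward adds self-loops only when x is a single node-feature tensor and no explicit size is given; bipartite (x_src, x_dst) inputs skip that branch. A hedged single-tensor usage sketch with made-up sizes:

import torch
from torch_geometric.nn import GATConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
conv = GATConv(in_channels=16, out_channels=8, heads=2, concat=True)
out = conv(x, edge_index)       # self-loops added internally, shape [4, 2 * 8]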
2 changes: 1 addition & 1 deletion torch_geometric/nn/conv/sage_conv.py
@@ -65,7 +65,7 @@ def forward(self, x, edge_index, edge_weight=None, size=None,
         """
         if not self.concat and torch.is_tensor(x):
             edge_index, edge_weight = add_remaining_self_loops(
-                edge_index, edge_weight, 1, x.size(0))
+                edge_index, edge_weight, 1, x.size(self.node_dim))

         return self.propagate(edge_index, size=size, x=x,
                               edge_weight=edge_weight, res_n_id=res_n_id)
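add_remaining_self_loops, as called above, appends a self-loop (with weight 1 here) only for nodes that do not already have one, keeping existing loop weights. A hedged standalone sketch of that call, with x.size(self.node_dim) replaced by an explicit node count:

import torch
from torch_geometric.utils import add_remaining_self_loops

edge_index = torch.tensor([[0, 0, 1], [0, 1, 2]])   # node 0 already has a loop
edge_weight = torch.tensor([0.5, 1.0, 2.0])
num_nodes = 3                                        # stands in for x.size(self.node_dim)
edge_index, edge_weight = add_remaining_self_loops(
    edge_index, edge_weight, 1, num_nodes)
# loops for nodes 1 and 2 are added with weight 1; the (0, 0) loop keeps 0.5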
5 changes: 2 additions & 3 deletions torch_geometric/nn/conv/sg_conv.py
@@ -29,7 +29,6 @@ class SGConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
-
     def __init__(self, in_channels, out_channels, K=1, cached=False, bias=True,
                  **kwargs):
         super(SGConv, self).__init__(aggr='add', **kwargs)

@@ -63,8 +62,8 @@ def forward(self, x, edge_index, edge_weight=None):

         if not self.cached or self.cached_result is None:
             self.cached_num_edges = edge_index.size(1)
-            edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
-                                            dtype=x.dtype)
+            edge_index, norm = GCNConv.norm(edge_index, x.size(self.node_dim),
+                                            edge_weight, dtype=x.dtype)

             for k in range(self.K):
                 x = self.propagate(edge_index, x=x, norm=norm)
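SGConv uses the same cached branch as DNAConv, and the K-step propagation loop sits inside that branch, so with cached=True the smoothing is not recomputed on later calls. A hedged usage sketch with made-up sizes:

import torch
from torch_geometric.nn import SGConv

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
conv = SGConv(in_channels=16, out_channels=4, K=2, cached=False)
out = conv(x, edge_index)       # K-step propagation followed by a linear layer, shape [6, 4]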
8 changes: 4 additions & 4 deletions torch_geometric/nn/conv/tag_conv.py
@@ -1,7 +1,7 @@
-import torch
 from torch.nn import Linear
 from torch_scatter import scatter_add
-from torch_geometric.nn.conv import MessagePassing
+import torch
+from torch_geometric.nn.conv import MessagePassing, GCNConv


 class TAGConv(MessagePassing):

@@ -54,8 +54,8 @@ def norm(edge_index, num_nodes, edge_weight=None, dtype=None):

     def forward(self, x, edge_index, edge_weight=None):
         """"""
-        edge_index, norm = self.norm(edge_index, x.size(0), edge_weight,
-                                     dtype=x.dtype)
+        edge_index, norm = GCNConv.norm(edge_index, x.size(self.node_dim),
+                                        edge_weight, dtype=x.dtype)

         xs = [x]
         for k in range(self.K):
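With this hunk, TAGConv delegates normalization to the shared GCNConv.norm instead of its own norm helper, matching the other layers in the commit. A hedged usage sketch with made-up sizes:

import torch
from torch_geometric.nn import TAGConv

x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
conv = TAGConv(in_channels=16, out_channels=8, K=3)
out = conv(x, edge_index)       # aggregates powers 0..K of the normalized adjacency, shape [5, 8]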
