Updated docstring for shape info #3697

Merged: 10 commits, Dec 15, 2021
19 changes: 18 additions & 1 deletion docs/source/_static/css/custom.css
@@ -125,7 +125,7 @@ h2,
   text-transform: uppercase;
   background: inherit;
   border: none;
-  padding: 6px 0;
+  margin-bottom: 6px !important;
 }

 .rst-content dl:not(.docutils) .property:first-child .pre {
@@ -194,3 +194,20 @@ table.colwidths-given tr td:first-child p,
 table.colwidths-given tr th:first-child p {
   text-align: left;
 }
+
+dl.class dl.simple > dt {
+  text-transform: uppercase;
+  background: inherit !important;
+  color: inherit !important;
+  border: none !important;
+}
+
+html.writer-html5 .rst-content dl.field-list > dt,
+html.writer-html5 .rst-content dl.footnote > dt {
+  padding-left: 0;
+}
+
+html.writer-html5 .rst-content dl.field-list,
+html.writer-html5 .rst-content dl.footnote {
+  display: inherit;
+}
28 changes: 9 additions & 19 deletions test/nn/conv/test_cg_conv.py
@@ -14,15 +14,13 @@ def test_cg_conv():
     assert conv.__repr__() == 'CGConv(8, dim=0)'
     out = conv(x1, edge_index)
     assert out.size() == (4, 8)
-    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
     assert conv(x1, adj.t()).tolist() == out.tolist()

-    t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit(x1, edge_index).tolist() == out.tolist()
-    assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()

-    t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit(x1, adj.t()).tolist() == out.tolist()

@@ -31,15 +29,13 @@ def test_cg_conv():
     assert conv.__repr__() == 'CGConv((8, 16), dim=0)'
     out = conv((x1, x2), edge_index)
     assert out.size() == (2, 16)
-    assert conv((x1, x2), edge_index, size=(4, 2)).tolist() == out.tolist()
     assert conv((x1, x2), adj.t()).tolist() == out.tolist()

-    t = '(PairTensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(PairTensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit((x1, x2), edge_index).tolist() == out.tolist()
-    assert jit((x1, x2), edge_index, size=(4, 2)).tolist() == out.tolist()

-    t = '(PairTensor, SparseTensor, OptTensor, Size) -> Tensor'
+    t = '(PairTensor, SparseTensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit((x1, x2), adj.t()).tolist() == out.tolist()

@@ -49,13 +45,11 @@ def test_cg_conv():
     assert conv.__repr__() == 'CGConv(8, dim=0)'
     out = conv(x1, edge_index)
     assert out.size() == (4, 8)
-    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
     assert conv(x1, adj.t()).tolist() == out.tolist()

-    t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit(x1, edge_index).tolist() == out.tolist()
-    assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()


 def test_cg_conv_with_edge_features():
@@ -70,15 +64,13 @@ def test_cg_conv_with_edge_features():
     assert conv.__repr__() == 'CGConv(8, dim=3)'
     out = conv(x1, edge_index, value)
     assert out.size() == (4, 8)
-    assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
     assert conv(x1, adj.t()).tolist() == out.tolist()

-    t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit(x1, edge_index, value).tolist() == out.tolist()
-    assert jit(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()

-    t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
+    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit(x1, adj.t()).tolist() == out.tolist()

@@ -87,14 +79,12 @@ def test_cg_conv_with_edge_features():
     assert conv.__repr__() == 'CGConv((8, 16), dim=3)'
     out = conv((x1, x2), edge_index, value)
     assert out.size() == (2, 16)
-    assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out.tolist()
     assert conv((x1, x2), adj.t()).tolist() == out.tolist()

-    t = '(PairTensor, Tensor, OptTensor, Size) -> Tensor'
+    t = '(PairTensor, Tensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit((x1, x2), edge_index, value).tolist() == out.tolist()
-    assert jit((x1, x2), edge_index, value, (4, 2)).tolist() == out.tolist()

-    t = '(PairTensor, SparseTensor, OptTensor, Size) -> Tensor'
+    t = '(PairTensor, SparseTensor, OptTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit((x1, x2), adj.t()).tolist() == out.tolist()
6 changes: 2 additions & 4 deletions test/nn/conv/test_cluster_gcn_conv.py
@@ -13,14 +13,12 @@ def test_cluster_gcn_conv():
     assert conv.__repr__() == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'
     out = conv(x, edge_index)
     assert out.size() == (4, 32)
-    assert conv(x, edge_index, size=(4, 4)).tolist() == out.tolist()
     assert torch.allclose(conv(x, adj.t()), out, atol=1e-5)

-    t = '(Tensor, Tensor, Size) -> Tensor'
+    t = '(Tensor, Tensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert jit(x, edge_index).tolist() == out.tolist()
-    assert jit(x, edge_index, size=(4, 4)).tolist() == out.tolist()

-    t = '(Tensor, SparseTensor, Size) -> Tensor'
+    t = '(Tensor, SparseTensor) -> Tensor'
     jit = torch.jit.script(conv.jittable(t))
     assert torch.allclose(jit(x, adj.t()), out, atol=1e-5)
6 changes: 6 additions & 0 deletions torch_geometric/nn/conv/agnn_conv.py
@@ -34,6 +34,12 @@ class AGNNConv(MessagePassing):
             self-loops to the input graph. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F)`,
+          edge indices :math:`(2, |\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F)`
     """
     def __init__(self, requires_grad: bool = True, add_self_loops: bool = True,
                  **kwargs):
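
To make the new `Shapes:` contract concrete, here is a minimal sketch (not part of this diff; graph and feature sizes are made up) exercising `AGNNConv`:

```python
import torch
from torch_geometric.nn import AGNNConv

x = torch.randn(4, 16)                     # node features (|V|, F)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])  # edge indices (2, |E|)

conv = AGNNConv()
out = conv(x, edge_index)
assert out.size() == (4, 16)               # output keeps the shape (|V|, F)
```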
7 changes: 7 additions & 0 deletions torch_geometric/nn/conv/appnp.py
@@ -45,6 +45,13 @@ class APPNP(MessagePassing):
             symmetric normalization. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F)`,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge weights :math:`(|\mathcal{E}|)` *(optional)*
+        - **output:** node features :math:`(|\mathcal{V}|, F)`
     """
     _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
     _cached_adj_t: Optional[SparseTensor]
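
A corresponding sketch for `APPNP` (illustrative sizes only), including the optional edge weights:

```python
import torch
from torch_geometric.nn import APPNP

x = torch.randn(4, 16)                        # node features (|V|, F)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_weight = torch.rand(edge_index.size(1))  # optional edge weights (|E|,)

conv = APPNP(K=10, alpha=0.1)
out = conv(x, edge_index, edge_weight)
assert out.size() == (4, 16)                  # propagation only; F is unchanged
```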
7 changes: 7 additions & 0 deletions torch_geometric/nn/conv/arma_conv.py
@@ -50,6 +50,13 @@ class ARMAConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})`,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge weights :math:`(|\mathcal{E}|)` *(optional)*
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """
     def __init__(self, in_channels: int, out_channels: int,
                  num_stacks: int = 1, num_layers: int = 1,
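
For `ARMAConv`, a hedged sketch of the documented in/out shapes (the hyper-parameters here are arbitrary):

```python
import torch
from torch_geometric.nn import ARMAConv

x = torch.randn(4, 16)                     # (|V|, F_in)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = ARMAConv(in_channels=16, out_channels=32, num_stacks=2, num_layers=2)
out = conv(x, edge_index)                  # edge weights are optional
assert out.size() == (4, 32)               # (|V|, F_out)
```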
16 changes: 13 additions & 3 deletions torch_geometric/nn/conv/cg_conv.py
@@ -1,5 +1,5 @@
 from typing import Union, Tuple
-from torch_geometric.typing import PairTensor, Adj, OptTensor, Size
+from torch_geometric.typing import PairTensor, Adj, OptTensor

 import torch
 from torch import Tensor
@@ -39,6 +39,16 @@ class CGConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F)` or
+          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
+          if bipartite,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
+        - **output:** node features :math:`(|\mathcal{V}|, F)` or
+          :math:`(|\mathcal{V_t}|, F_{t})` if bipartite
     """
     def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0,
                  aggr: str = 'add', batch_norm: bool = False,
@@ -67,13 +77,13 @@ def reset_parameters(self):
         self.bn.reset_parameters()

     def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
-                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
+                edge_attr: OptTensor = None) -> Tensor:
         """"""
         if isinstance(x, Tensor):
             x: PairTensor = (x, x)

         # propagate_type: (x: PairTensor, edge_attr: OptTensor)
-        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
+        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
         out = out if self.bn is None else self.bn(out)
         out += x[1]
         return out
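
A minimal sketch of the bipartite `CGConv` contract after this change (sizes invented; note that `forward` no longer takes a `size` argument):

```python
import torch
from torch_geometric.nn import CGConv

x1 = torch.randn(4, 8)                     # source node features (|V_s|, F_s)
x2 = torch.randn(2, 16)                    # target node features (|V_t|, F_t)
edge_index = torch.tensor([[0, 1, 2, 3],   # source indices
                           [0, 0, 1, 1]])  # target indices
edge_attr = torch.randn(4, 3)              # optional edge features (|E|, D)

conv = CGConv((8, 16), dim=3)
out = conv((x1, x2), edge_index, edge_attr)
assert out.size() == (2, 16)               # bipartite output (|V_t|, F_t)
```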
10 changes: 10 additions & 0 deletions torch_geometric/nn/conv/cheb_conv.py
@@ -62,6 +62,16 @@ class ChebConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})`,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge weights :math:`(|\mathcal{E}|)` *(optional)*,
+          batch vector :math:`(|\mathcal{V}|)` *(optional)*,
+          maximum :obj:`lambda` value :math:`(|\mathcal{G}|)` *(optional)*
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
+
     """
     def __init__(self, in_channels: int, out_channels: int, K: int,
                  normalization: Optional[str] = 'sym', bias: bool = True,
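
`ChebConv` is the one layer here with per-graph optional inputs. A sketch under invented sizes (two graphs of two nodes each):

```python
import torch
from torch_geometric.nn import ChebConv

x = torch.randn(4, 16)                     # (|V|, F_in)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
batch = torch.tensor([0, 0, 1, 1])         # optional batch vector (|V|,)
lambda_max = torch.tensor([2.0, 3.0])      # optional, one value per graph (|G|,)

conv = ChebConv(in_channels=16, out_channels=32, K=3)
out = conv(x, edge_index, batch=batch, lambda_max=lambda_max)
assert out.size() == (4, 32)               # (|V|, F_out)
```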
12 changes: 9 additions & 3 deletions torch_geometric/nn/conv/cluster_gcn_conv.py
@@ -1,4 +1,4 @@
-from torch_geometric.typing import Adj, Size, OptTensor
+from torch_geometric.typing import Adj, OptTensor

 from torch import Tensor
 from torch_sparse import SparseTensor, matmul, set_diag, sum as sparsesum
@@ -32,6 +32,12 @@ class ClusterGCNConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})`,
+          edge indices :math:`(2, |\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """
     def __init__(self, in_channels: int, out_channels: int,
                  diag_lambda: float = 0., add_self_loops: bool = True,
@@ -55,11 +61,11 @@ def reset_parameters(self):
         self.lin_out.reset_parameters()
         self.lin_root.reset_parameters()

-    def forward(self, x: Tensor, edge_index: Adj, size: Size = None) -> Tensor:
+    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
         """"""
         edge_weight: OptTensor = None
         if isinstance(edge_index, Tensor):
-            num_nodes = size[1] if size is not None else x.size(self.node_dim)
+            num_nodes = x.size(self.node_dim)
             if self.add_self_loops:
                 edge_index, _ = remove_self_loops(edge_index)
                 edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
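
A sketch of the simplified `ClusterGCNConv` call after the `size` argument is dropped (sizes invented):

```python
import torch
from torch_geometric.nn import ClusterGCNConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = ClusterGCNConv(in_channels=16, out_channels=32, diag_lambda=1.0)
out = conv(x, edge_index)                  # `size` is gone; num_nodes comes from x
assert out.size() == (4, 32)
```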
7 changes: 7 additions & 0 deletions torch_geometric/nn/conv/dna_conv.py
@@ -220,6 +220,13 @@ class DNAConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, L, F)` where :math:`L` is the
+          number of layers,
+          edge indices :math:`(2, |\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F)`
     """

     _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
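
`DNAConv` is the only layer here with a three-dimensional input; a sketch with invented sizes:

```python
import torch
from torch_geometric.nn import DNAConv

num_nodes, num_layers, channels = 4, 3, 16
x = torch.randn(num_nodes, num_layers, channels)  # (|V|, L, F)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = DNAConv(channels=16, heads=4)        # heads must divide channels
out = conv(x, edge_index)
assert out.size() == (num_nodes, channels)  # the L axis is attended away: (|V|, F)
```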
9 changes: 9 additions & 0 deletions torch_geometric/nn/conv/edge_conv.py
@@ -34,6 +34,15 @@ class EdgeConv(MessagePassing):
             (default: :obj:`"max"`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})` or
+          :math:`((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))`
+          if bipartite,
+          edge indices :math:`(2, |\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
+          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
     """
     def __init__(self, nn: Callable, aggr: str = 'max', **kwargs):
         super().__init__(aggr=aggr, **kwargs)
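
For `EdgeConv` the output shape is set by the wrapped `nn`; a sketch with an invented MLP:

```python
import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import EdgeConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

# The wrapped MLP sees [x_i, x_j - x_i], i.e. 2 * F_in input features.
mlp = Sequential(Linear(2 * 16, 32), ReLU(), Linear(32, 32))
conv = EdgeConv(mlp, aggr='max')
out = conv(x, edge_index)
assert out.size() == (4, 32)               # (|V|, F_out)
```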
6 changes: 6 additions & 0 deletions torch_geometric/nn/conv/eg_conv.py
@@ -67,6 +67,12 @@ class EGConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})`,
+          edge indices :math:`(2, |\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """

     _cached_edge_index: Optional[Tuple[Tensor, OptTensor]]
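
A minimal `EGConv` sketch (sizes invented; `out_channels` must be divisible by the default number of heads):

```python
import torch
from torch_geometric.nn import EGConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = EGConv(in_channels=16, out_channels=32)
out = conv(x, edge_index)
assert out.size() == (4, 32)
```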
10 changes: 10 additions & 0 deletions torch_geometric/nn/conv/fa_conv.py
@@ -50,6 +50,16 @@ class FAConv(MessagePassing):
             the layer's :meth:`forward` method. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F)`,
+          initial node features :math:`(|\mathcal{V}|, F)`,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge weights :math:`(|\mathcal{E}|)` *(optional)*
+        - **output:** node features :math:`(|\mathcal{V}|, F)` or
+          :math:`((|\mathcal{V}|, F), ((2, |\mathcal{E}|),
+          (|\mathcal{E}|)))` if :obj:`return_attention_weights=True`
     """
     _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
     _cached_adj_t: Optional[SparseTensor]
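
`FAConv` takes both the current and the initial features; a sketch with invented sizes:

```python
import torch
from torch_geometric.nn import FAConv

x = torch.randn(4, 16)                     # current features (|V|, F)
x_0 = torch.randn(4, 16)                   # initial features (|V|, F)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = FAConv(channels=16, eps=0.1)
out = conv(x, x_0, edge_index)
assert out.size() == (4, 16)

# Asking for attention weights changes the return type to a tuple; alpha holds
# one coefficient per returned edge (self-loops may have been added).
out, (edge_index_out, alpha) = conv(x, x_0, edge_index,
                                    return_attention_weights=True)
```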
9 changes: 9 additions & 0 deletions torch_geometric/nn/conv/feast_conv.py
@@ -39,6 +39,15 @@ class FeaStConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})` or
+          :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))`
+          if bipartite,
+          edge indices :math:`(2, |\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
+          :math:`(|\mathcal{V_t}|, F_{out})` if bipartite
     """
     def __init__(self, in_channels: int, out_channels: int, heads: int = 1,
                  add_self_loops: bool = True, bias: bool = True, **kwargs):
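
Unlike `GATConv` below, `FeaStConv` heads do not widen the output; a sketch (sizes invented):

```python
import torch
from torch_geometric.nn import FeaStConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = FeaStConv(in_channels=16, out_channels=32, heads=2)
out = conv(x, edge_index)
assert out.size() == (4, 32)     # heads gate weight matrices; output stays (|V|, F_out)
```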
10 changes: 10 additions & 0 deletions torch_geometric/nn/conv/film_conv.py
@@ -52,6 +52,16 @@ class FiLMConv(MessagePassing):
             (default: :obj:`"mean"`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})` or
+          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
+          if bipartite,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge types :math:`(|\mathcal{E}|)`
+        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
+          :math:`(|\mathcal{V_t}|, F_{out})` if bipartite
     """
     def __init__(
         self,
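
`FiLMConv` additionally consumes per-edge relation ids; a sketch with invented sizes:

```python
import torch
from torch_geometric.nn import FiLMConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_type = torch.tensor([0, 1, 1, 0])     # one relation id per edge (|E|,)

conv = FiLMConv(in_channels=16, out_channels=32, num_relations=2)
out = conv(x, edge_index, edge_type)
assert out.size() == (4, 32)
```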
15 changes: 15 additions & 0 deletions torch_geometric/nn/conv/gat_conv.py
@@ -81,6 +81,21 @@ class GATConv(MessagePassing):
             an additive bias. (default: :obj:`True`)
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
+
+    Shapes:
+        - **input:**
+          node features :math:`(|\mathcal{V}|, F_{in})` or
+          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
+          if bipartite,
+          edge indices :math:`(2, |\mathcal{E}|)`,
+          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
+        - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or
+          :math:`(|\mathcal{V}_t|, H * F_{out})` if bipartite.
+          If :obj:`return_attention_weights=True`, then
+          :math:`((|\mathcal{V}|, H * F_{out}),
+          ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))`
+          or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|),
+          (|\mathcal{E}|, H)))` if bipartite
     """
     _alpha: OptTensor

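
Finally, a sketch of the `GATConv` contract (sizes invented), showing both the concatenated heads and the attention-weight return:

```python
import torch
from torch_geometric.nn import GATConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = GATConv(in_channels=16, out_channels=32, heads=4)
out = conv(x, edge_index)
assert out.size() == (4, 4 * 32)           # heads are concatenated: (|V|, H * F_out)

out, (edge_index_out, alpha) = conv(x, edge_index,
                                    return_attention_weights=True)
assert alpha.size() == (edge_index_out.size(1), 4)  # (|E|, H), incl. self-loops
```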