Skip to content

Commit

Permalink
Doc improvements to set2set layers (#3889)
Browse files Browse the repository at this point in the history
* added type hints

* doc updates

* set2set parameter to accept None arg

* added shapes of input and output [docs]

* minor fix

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* small update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
arunppsg and rusty1s committed Jan 20, 2022
1 parent 0805d05 commit 3e4891b
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions torch_geometric/nn/glob/set2set.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import softmax


Expand Down Expand Up @@ -27,8 +31,17 @@ class Set2Set(torch.nn.Module):
:obj:`num_layers=2` would mean stacking two LSTMs together to form
a stacked LSTM, with the second LSTM taking in outputs of the first
LSTM and computing the final results. (default: :obj:`1`)
Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F)`,
batch vector :math:`(|\mathcal{V}|)` *(optional)*
- **output:**
set features :math:`(|\mathcal{G}|, 2 * F)` where
:math:`|\mathcal{G}|` denotes the number of graphs in the batch
"""
def __init__(self, in_channels, processing_steps, num_layers=1):
def __init__(self, in_channels: int, processing_steps: int,
num_layers: int = 1):
super().__init__()

self.in_channels = in_channels
Expand All @@ -44,8 +57,16 @@ def __init__(self, in_channels, processing_steps, num_layers=1):
def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x, batch):
""""""
def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
r"""
Args:
x (Tensor): The input node features.
batch (LongTensor, optional): A vector that maps each node to its
respective graph identifier. (default: :obj:`None`)
"""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.int64)

batch_size = batch.max().item() + 1

h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
Expand Down

0 comments on commit 3e4891b

Please sign in to comment.