Doc fixes to GlobalAttention (#3946)
* added shape

* batch to accept None parameter

* added type hints

* update

* typo

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
arunppsg and rusty1s committed Jan 26, 2022
1 parent d74a132 commit defa0e7
Showing 2 changed files with 40 additions and 14 deletions.
33 changes: 28 additions & 5 deletions torch_geometric/nn/glob/attention.py
@@ -1,4 +1,7 @@
+from typing import Optional
+
import torch
+from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import softmax
@@ -17,7 +20,7 @@ class GlobalAttention(torch.nn.Module):
where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to
\mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.*
-MLPS.
+MLPs.
Args:
gate_nn (torch.nn.Module): A neural network :math:`h_{\mathrm{gate}}`
@@ -29,8 +32,17 @@ class GlobalAttention(torch.nn.Module):
shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`
before combining them with the attention scores, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
+Shapes:
+- **input:**
+node features :math:`(|\mathcal{V}|, F)`,
+batch vector :math:`(|\mathcal{V}|)` *(optional)*
+- **output:**
+graph features :math:`(|\mathcal{G}|, F)` where
+:math:`|\mathcal{G}|` denotes the number of graphs in the batch
"""
-def __init__(self, gate_nn, nn=None):
+def __init__(self, gate_nn: torch.nn.Module,
+nn: Optional[torch.nn.Module] = None):
super().__init__()
self.gate_nn = gate_nn
self.nn = nn
@@ -41,10 +53,21 @@ def reset_parameters(self):
reset(self.gate_nn)
reset(self.nn)

-def forward(self, x, batch, size=None):
-""""""
+def forward(self, x: Tensor, batch: Optional[Tensor] = None,
+size: Optional[int] = None) -> Tensor:
+r"""
+Args:
+x (Tensor): The input node features.
+batch (LongTensor, optional): A vector that maps each node to its
+respective graph identifier. (default: :obj:`None`)
+size (int, optional): The number of graphs in the batch.
+(default: :obj:`None`)
+"""
+if batch is None:
+batch = x.new_zeros(x.size(0), dtype=torch.int64)
+
x = x.unsqueeze(-1) if x.dim() == 1 else x
-size = batch[-1].item() + 1 if size is None else size
+size = int(batch.max()) + 1 if size is None else size

gate = self.gate_nn(x).view(-1, 1)
x = self.nn(x) if self.nn is not None else x
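For reference, here is a minimal usage sketch of the patched GlobalAttention.forward; the feature size, MLP layout, and toy tensors are illustrative assumptions, not part of the commit. Note that the new `size = int(batch.max()) + 1` fallback also works when the batch vector is not sorted, whereas the old `batch[-1].item() + 1` assumed the last entry held the largest graph index.

```python
import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GlobalAttention

in_channels = 16  # illustrative feature size, not from the commit

# gate_nn maps each node's feature vector to a scalar attention score.
gate_nn = Sequential(Linear(in_channels, in_channels), ReLU(),
                     Linear(in_channels, 1))
pool = GlobalAttention(gate_nn)

x = torch.randn(10, in_channels)          # 10 nodes in total
batch = torch.tensor([0] * 4 + [1] * 6)   # two graphs with 4 and 6 nodes

out = pool(x, batch)   # -> [2, 16]: one pooled vector per graph
single = pool(x)       # batch=None now pools all nodes into one graph -> [1, 16]
```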
21 changes: 12 additions & 9 deletions torch_geometric/nn/glob/set2set.py
@@ -37,7 +37,7 @@ class Set2Set(torch.nn.Module):
node features :math:`(|\mathcal{V}|, F)`,
batch vector :math:`(|\mathcal{V}|)` *(optional)*
- **output:**
-set features :math:`(|\mathcal{G}|, 2 * F)` where
+graph features :math:`(|\mathcal{G}|, 2 * F)` where
:math:`|\mathcal{G}|` denotes the number of graphs in the batch
"""
def __init__(self, in_channels: int, processing_steps: int,
@@ -57,28 +57,31 @@ def __init__(self, in_channels: int, processing_steps: int,
def reset_parameters(self):
self.lstm.reset_parameters()

-def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
+def forward(self, x: Tensor, batch: Optional[Tensor] = None,
+size: Optional[int] = None) -> Tensor:
r"""
Args:
x (Tensor): The input node features.
batch (LongTensor, optional): A vector that maps each node to its
respective graph identifier. (default: :obj:`None`)
+size (int, optional): The number of graphs in the batch.
+(default: :obj:`None`)
"""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.int64)

-batch_size = batch.max().item() + 1
+size = int(batch.max()) + 1 if size is None else size

-h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
-x.new_zeros((self.num_layers, batch_size, self.in_channels)))
-q_star = x.new_zeros(batch_size, self.out_channels)
+h = (x.new_zeros((self.num_layers, size, self.in_channels)),
+x.new_zeros((self.num_layers, size, self.in_channels)))
+q_star = x.new_zeros(size, self.out_channels)

for _ in range(self.processing_steps):
q, h = self.lstm(q_star.unsqueeze(0), h)
-q = q.view(batch_size, self.in_channels)
+q = q.view(size, self.in_channels)
e = (x * q.index_select(0, batch)).sum(dim=-1, keepdim=True)
-a = softmax(e, batch, num_nodes=batch_size)
-r = scatter_add(a * x, batch, dim=0, dim_size=batch_size)
+a = softmax(e, batch, num_nodes=size)
+r = scatter_add(a * x, batch, dim=0, dim_size=size)
q_star = torch.cat([q, r], dim=-1)

return q_star
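In the same spirit, a hedged sketch of Set2Set with the new optional size argument; the channel sizes and the explicit size value are assumptions chosen for illustration:

```python
import torch
from torch_geometric.nn import Set2Set

in_channels = 16  # illustrative
pool = Set2Set(in_channels, processing_steps=3)

x = torch.randn(10, in_channels)
batch = torch.tensor([0] * 4 + [1] * 6)

out = pool(x, batch)             # -> [2, 32]: Set2Set returns 2 * in_channels per graph
# Passing size explicitly pads the output to a fixed number of graphs,
# e.g. when the last graphs in the batch happen to contain no nodes.
padded = pool(x, batch, size=4)  # -> [4, 32]
```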
