Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Doc fixes to GlobalAttention #3946

Merged
merged 5 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions torch_geometric/nn/glob/attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import softmax
Expand All @@ -17,7 +20,7 @@ class GlobalAttention(torch.nn.Module):

where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to
\mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.*
MLPS.
MLPs.

Args:
gate_nn (torch.nn.Module): A neural network :math:`h_{\mathrm{gate}}`
Expand All @@ -29,8 +32,17 @@ class GlobalAttention(torch.nn.Module):
shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`
before combining them with the attention scores, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)

Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F)`,
batch vector :math:`(|\mathcal{V}|)` *(optional)*
- **output:**
graph features :math:`(|\mathcal{G}|, 2 * F)` where
:math:`|\mathcal{G}|` denotes the number of graphs in the batch
"""
def __init__(self, gate_nn, nn=None):
def __init__(self, gate_nn: torch.nn.Module,
nn: Optional[torch.nn.Module] = None):
super().__init__()
self.gate_nn = gate_nn
self.nn = nn
Expand All @@ -41,10 +53,21 @@ def reset_parameters(self):
reset(self.gate_nn)
reset(self.nn)

def forward(self, x, batch, size=None):
""""""
def forward(self, x: Tensor, batch: Optional[Tensor] = None,
size: Optional[int] = None) -> Tensor:
r"""
Args:
x (Tensor): The input node features.
batch (LongTensor, optional): A vector that maps each node to its
respective graph identifier. (default: :obj:`None`)
size (int, optional): The number of graphs in the batch.
(default: :obj:`None`)
"""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.int64)

x = x.unsqueeze(-1) if x.dim() == 1 else x
size = batch[-1].item() + 1 if size is None else size
size = int(batch.max()) + 1 if size is None else size

gate = self.gate_nn(x).view(-1, 1)
x = self.nn(x) if self.nn is not None else x
Expand Down
21 changes: 12 additions & 9 deletions torch_geometric/nn/glob/set2set.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Set2Set(torch.nn.Module):
node features :math:`(|\mathcal{V}|, F)`,
batch vector :math:`(|\mathcal{V}|)` *(optional)*
- **output:**
set features :math:`(|\mathcal{G}|, 2 * F)` where
graph features :math:`(|\mathcal{G}|, 2 * F)` where
:math:`|\mathcal{G}|` denotes the number of graphs in the batch
"""
def __init__(self, in_channels: int, processing_steps: int,
Expand All @@ -57,28 +57,31 @@ def __init__(self, in_channels: int, processing_steps: int,
def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, batch: Optional[Tensor] = None,
size: Optional[int] = None) -> Tensor:
r"""
Args:
x (Tensor): The input node features.
batch (LongTensor, optional): A vector that maps each node to its
respective graph identifier. (default: :obj:`None`)
size (int, optional): The number of graphs in the batch.
(default: :obj:`None`)
"""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.int64)

batch_size = batch.max().item() + 1
size = int(batch.max()) + 1 if size is None else size

h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
x.new_zeros((self.num_layers, batch_size, self.in_channels)))
q_star = x.new_zeros(batch_size, self.out_channels)
h = (x.new_zeros((self.num_layers, size, self.in_channels)),
x.new_zeros((self.num_layers, size, self.in_channels)))
q_star = x.new_zeros(size, self.out_channels)

for _ in range(self.processing_steps):
q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.in_channels)
q = q.view(size, self.in_channels)
e = (x * q.index_select(0, batch)).sum(dim=-1, keepdim=True)
a = softmax(e, batch, num_nodes=batch_size)
r = scatter_add(a * x, batch, dim=0, dim_size=batch_size)
a = softmax(e, batch, num_nodes=size)
r = scatter_add(a * x, batch, dim=0, dim_size=size)
q_star = torch.cat([q, r], dim=-1)

return q_star
Expand Down