BasicGNN: Final Linear transform #4042

Merged 7 commits on Feb 9, 2022
6 changes: 4 additions & 2 deletions test/graphgym/test_config_store.py
@@ -39,7 +39,7 @@ def test_config_store():
assert cfg.dataset.transform.AddSelfLoops.fill_value is None

# Check `cfg.model`:
-assert len(cfg.model) == 9
+assert len(cfg.model) == 11
assert cfg.model._target_.split('.')[-1] == 'GCN'
assert cfg.model.in_channels == 34
assert cfg.model.out_channels == 4
@@ -48,7 +48,9 @@ def test_config_store():
assert cfg.model.dropout == 0.0
assert cfg.model.act == 'relu'
assert cfg.model.norm is None
-assert cfg.model.jk == 'last'
+assert cfg.model.jk is None
+assert not cfg.model.act_first
+assert cfg.model.act_kwargs is None

# Check `cfg.optim`:
assert len(cfg.optim) == 6
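Note: the new count of 11 follows from the updated `BasicGNN` signature further down: ten constructor arguments plus the `_target_` entry added by the config store. A minimal sketch of the model these config values describe (`hidden_channels` and `num_layers` are illustrative values, not asserted in this test):

from torch_geometric.nn import GCN

model = GCN(
    in_channels=34,
    hidden_channels=16,  # illustrative, not asserted above
    num_layers=2,        # illustrative, not asserted above
    out_channels=4,
    dropout=0.0,
    act='relu',
    norm=None,
    jk=None,             # new default (was 'last')
    act_first=False,     # new argument
    act_kwargs=None,     # new argument
)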
2 changes: 1 addition & 1 deletion test/nn/models/test_basic_gnn.py
@@ -15,7 +15,7 @@
dropouts = [0.0, 0.5]
acts = [None, 'leaky_relu', torch.relu_, F.elu, ReLU()]
norms = [None, BatchNorm1d(16), LayerNorm(16)]
-jks = ['last', 'cat', 'max', 'lstm']
+jks = [None, 'last', 'cat', 'max', 'lstm']


@pytest.mark.parametrize('out_dim,dropout,act,norm,jk',
108 changes: 74 additions & 34 deletions torch_geometric/nn/models/basic_gnn.py
@@ -1,5 +1,5 @@
import copy
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
@@ -12,12 +12,6 @@
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
from torch_geometric.typing import Adj

-ACTS = {
-'relu': torch.nn.ReLU(inplace=True),
-'elu': torch.nn.ELU(inplace=True),
-'leaky_relu': torch.nn.LeakyReLU(inplace=True),
-}


class BasicGNN(torch.nn.Module):
r"""An abstract class for implementing basic GNN models.
@@ -34,33 +28,58 @@ class BasicGNN(torch.nn.Module):
use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
-jk (str, optional): The Jumping Knowledge mode
-(:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
-(default: :obj:`"last"`)
+jk (str, optional): The Jumping Knowledge mode. If specified, the model
+will additionally apply a final linear transformation to transform
+node embeddings to the expected output feature dimensionality.
+(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
+:obj:`"lstm"`). (default: :obj:`None`)
+act_first (bool, optional): If set to :obj:`True`, activation is
+applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+respective activation function defined by :obj:`act`.
+(default: :obj:`None`)
**kwargs (optional): Additional arguments of the underlying
:class:`torch_geometric.nn.conv.MessagePassing` layers.
"""
-def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
-out_channels: Optional[int] = None, dropout: float = 0.0,
-act: Union[str, Callable, None] = "relu",
-norm: Optional[torch.nn.Module] = None, jk: str = "last",
-**kwargs):
+def __init__(
+self,
+in_channels: int,
+hidden_channels: int,
+num_layers: int,
+out_channels: Optional[int] = None,
+dropout: float = 0.0,
+act: Union[str, Callable, None] = "relu",
+norm: Optional[torch.nn.Module] = None,
+jk: Optional[str] = None,
+act_first: bool = False,
+act_kwargs: Optional[Dict[str, Any]] = None,
+**kwargs,
+):
super().__init__()

+from class_resolver.contrib.torch import activation_resolver

self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.num_layers = num_layers

self.dropout = dropout
-self.act = ACTS[act] if isinstance(act, str) else act
+self.act = activation_resolver.make(act, act_kwargs)
self.jk_mode = jk
-self.has_out_channels = out_channels is not None
+self.act_first = act_first

+if out_channels is not None:
+self.out_channels = out_channels
+else:
+self.out_channels = hidden_channels

self.convs = ModuleList()
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
for _ in range(num_layers - 2):
self.convs.append(
self.init_conv(hidden_channels, hidden_channels, **kwargs))
-if self.has_out_channels and self.jk_mode == 'last':
+if out_channels is not None and jk is None:
self.convs.append(
self.init_conv(hidden_channels, out_channels, **kwargs))
else:
@@ -72,23 +91,18 @@ def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
self.norms = ModuleList()
for _ in range(num_layers - 1):
self.norms.append(copy.deepcopy(norm))
-if not (self.has_out_channels and self.jk_mode == 'last'):
+if jk is not None:
self.norms.append(copy.deepcopy(norm))

-if self.jk_mode != 'last':
+if jk is not None and jk != 'last':
self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)

-if self.has_out_channels:
-self.out_channels = out_channels
-if self.jk_mode == 'cat':
-self.lin = Linear(num_layers * hidden_channels, out_channels)
-elif self.jk_mode in {'max', 'lstm'}:
-self.lin = Linear(hidden_channels, out_channels)
-else:
-self.out_channels = hidden_channels
-if self.jk_mode == 'cat':
-self.lin = Linear(num_layers * hidden_channels,
-hidden_channels)
+if jk is not None:
+if jk == 'cat':
+in_channels = num_layers * hidden_channels
+else:
+in_channels = hidden_channels
+self.lin = Linear(in_channels, self.out_channels)

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -109,12 +123,13 @@ def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor:
xs: List[Tensor] = []
for i in range(self.num_layers):
x = self.convs[i](x, edge_index, *args, **kwargs)
-if (i == self.num_layers - 1 and self.has_out_channels
-and self.jk_mode == 'last'):
+if i == self.num_layers - 1 and self.jk_mode is None:
break
+if self.act_first:
+x = self.act(x)
if self.norms is not None:
x = self.norms[i](x)
-if self.act is not None:
+if not self.act_first:
x = self.act(x)
x = F.dropout(x, p=self.dropout, training=self.training)
if hasattr(self, 'jk'):
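Note: the `act_first` flag only changes the per-layer ordering shown above. A rough sketch, assuming the updated signature (model sizes are made up for illustration):

from torch.nn import BatchNorm1d
from torch_geometric.nn import GIN

# act_first=False (default): conv -> norm -> act -> dropout
# act_first=True:            conv -> act -> norm -> dropout
model = GIN(in_channels=8, hidden_channels=16, num_layers=3,
            norm=BatchNorm1d(16), act_first=True)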
@@ -150,6 +165,11 @@ class GCN(BasicGNN):
jk (str, optional): The Jumping Knowledge mode
(:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
(default: :obj:`"last"`)
+act_first (bool, optional): If set to :obj:`True`, activation is
+applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+respective activation function defined by :obj:`act`.
+(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GCNConv`.
"""
@@ -178,6 +198,11 @@ class GraphSAGE(BasicGNN):
jk (str, optional): The Jumping Knowledge mode
(:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
(default: :obj:`"last"`)
+act_first (bool, optional): If set to :obj:`True`, activation is
+applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+respective activation function defined by :obj:`act`.
+(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.SAGEConv`.
"""
@@ -206,6 +231,11 @@ class GIN(BasicGNN):
jk (str, optional): The Jumping Knowledge mode
(:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
(default: :obj:`"last"`)
+act_first (bool, optional): If set to :obj:`True`, activation is
+applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+respective activation function defined by :obj:`act`.
+(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GINConv`.
"""
@@ -235,6 +265,11 @@ class GAT(BasicGNN):
jk (str, optional): The Jumping Knowledge mode
(:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
(default: :obj:`"last"`)
+act_first (bool, optional): If set to :obj:`True`, activation is
+applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+respective activation function defined by :obj:`act`.
+(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GATConv`.
"""
@@ -271,6 +306,11 @@ class PNA(BasicGNN):
jk (str, optional): The Jumping Knowledge mode
(:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
(default: :obj:`"last"`)
+act_first (bool, optional): If set to :obj:`True`, activation is
+applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+respective activation function defined by :obj:`act`.
+(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.PNAConv`.
"""
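Note: the "Final Linear transform" in the PR title refers to this change: `self.lin` is now only created when a Jumping Knowledge mode is requested, while `jk=None` lets the last convolution map directly to `out_channels`. A rough usage sketch on a random toy graph (all values made up for illustration):

import torch
from torch_geometric.nn import GraphSAGE

x = torch.randn(10, 8)                      # 10 nodes, 8 input features
edge_index = torch.randint(0, 10, (2, 40))  # random toy edges

# jk=None (new default): no `self.lin`, the last conv outputs 4 features.
model = GraphSAGE(in_channels=8, hidden_channels=16, num_layers=3,
                  out_channels=4)
assert model(x, edge_index).size() == (10, 4)
assert not hasattr(model, 'lin')

# jk='cat': every layer outputs 16 features, the concatenated 3 * 16 = 48
# features are projected back to 4 by the final Linear transform.
model = GraphSAGE(in_channels=8, hidden_channels=16, num_layers=3,
                  out_channels=4, jk='cat')
assert model(x, edge_index).size() == (10, 4)
assert hasattr(model, 'lin')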
4 changes: 3 additions & 1 deletion torch_geometric/nn/models/mlp.py
@@ -2,7 +2,6 @@

import torch
import torch.nn.functional as F
-from class_resolver.contrib.torch import activation_resolver
from torch import Tensor
from torch.nn import BatchNorm1d, Identity

@@ -80,6 +79,9 @@ def __init__(
relu_first: bool = False,
):
super().__init__()

+from class_resolver.contrib.torch import activation_resolver

act_first = act_first or relu_first # Backward compatibility.
batch_norm_kwargs = batch_norm_kwargs or {}

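Note: moving the `class_resolver` import from module level into `__init__` defers it until a model is actually constructed, presumably to keep `class_resolver` optional at package-import time. A rough sketch of what the resolver call used above does (argument values are made up for illustration):

from class_resolver.contrib.torch import activation_resolver

# Resolves the string 'leaky_relu' plus kwargs into an instantiated module,
# roughly torch.nn.LeakyReLU(negative_slope=0.2).
act = activation_resolver.make('leaky_relu', {'negative_slope': 0.2})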