/
nn_conv.py
126 lines (102 loc) · 4.65 KB
/
nn_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from typing import Callable, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset, zeros
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
class NNConv(MessagePassing):
    r"""The continuous kernel-based convolutional operator from the
    `"Neural Message Passing for Quantum Chemistry"
    <https://arxiv.org/abs/1704.01212>`_ paper.

    This convolution is also known as the edge-conditioned convolution from the
    `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on
    Graphs" <https://arxiv.org/abs/1704.02901>`_ paper (see
    :class:`torch_geometric.nn.conv.ECConv` for an alias):

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
        h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),

    where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*
    an MLP.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
            maps edge features :obj:`edge_attr` of shape :obj:`[-1,
            num_edge_features]` to shape
            :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`.
        aggr (str, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, nn: Callable, aggr: str = 'add',
                 root_weight: bool = True, bias: bool = True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)

        # Keep the user-supplied value (int or tuple) for `__repr__`.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        self.root_weight = root_weight

        # Normalize to a (source, target) pair.
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Source-side feature size used to reshape the edge network output
        # in `message`. May be -1 when the size is to be inferred lazily.
        self.in_channels_l = in_channels[0]

        if root_weight:
            # The root transformation acts on target-side features.
            self.lin = Linear(in_channels[1], out_channels, bias=False,
                              weight_initializer='uniform')

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the edge network, the optional root transformation and
        the bias.
        """
        super().reset_parameters()
        reset(self.nn)
        if self.root_weight:
            self.lin.reset_parameters()
        zeros(self.bias)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or tuple): Node features. A single tensor is used
                for both the source and the target side; a pair is
                interpreted as :obj:`(x_source, x_target)`.
            edge_index: The edge indices handed to :meth:`propagate`.
            edge_attr (torch.Tensor, optional): Edge features that are fed to
                the edge network :obj:`nn` inside :meth:`message`.
            size: Optional size argument forwarded to :meth:`propagate`.

        Returns:
            torch.Tensor: The aggregated messages, plus the transformed
            target features (if :obj:`root_weight` is set and target features
            are given), plus the bias (if present).
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        # The root term needs target features, which may be absent in a pair.
        x_r = x[1]
        if x_r is not None and self.root_weight:
            out = out + self.lin(x_r)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        r"""Computes one message per edge by multiplying the source features
        with an edge-specific weight matrix produced by :obj:`nn`.

        Args:
            x_j (torch.Tensor): Source node features gathered per edge
                (assumed to be of shape :obj:`[num_edges, in_channels]`).
            edge_attr (torch.Tensor): Edge features passed to :obj:`nn`.

        Returns:
            torch.Tensor: The result of the per-edge vector-matrix product
            with the middle singleton dimension removed.
        """
        # With lazy initialization (`in_channels=-1`) the stored size is
        # negative; using it in `view` together with the leading `-1` would
        # ask PyTorch to infer two dimensions at once and fail. Fall back to
        # the actual feature size of the incoming source features instead.
        in_channels = self.in_channels_l
        if in_channels < 0:
            in_channels = x_j.size(-1)

        weight = self.nn(edge_attr)
        weight = weight.view(-1, in_channels, self.out_channels)
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggr={self.aggr}, nn={self.nn})')