/
hsn_layer.py
123 lines (102 loc) · 4.8 KB
/
hsn_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""High Skip Network (HSN) Layer."""
import torch
from topomodelx.base.aggregation import Aggregation
from topomodelx.base.conv import Conv
class HSNLayer(torch.nn.Module):
    """Layer of a High Skip Network (HSN).

    Implementation of the HSN layer proposed in [HRGZ22]_.

    Notes
    -----
    This is the architecture proposed for node classification on simplicial complexes.

    References
    ----------
    .. [HRGZ22] Hajij, Ramamurthy, Guzmán-Sáenz, Zamzmi.
        High Skip Networks: A Higher Order Generalization of Skip Connections.
        Geometrical and Topological Representation Learning Workshop at ICLR 2022.
        https://openreview.net/pdf?id=Sc8glB-k6e9

    Parameters
    ----------
    channels : int
        Dimension of features on each simplicial cell.
    """

    def __init__(
        self,
        channels,
    ):
        super().__init__()
        self.channels = channels
        # Level 1: sigmoid-activated convolutions reading node features,
        # producing intermediate node (0 -> 0) and edge (0 -> 1) messages.
        self.conv_level1_0_to_0 = Conv(
            in_channels=channels,
            out_channels=channels,
            update_func="sigmoid",
        )
        self.conv_level1_0_to_1 = Conv(
            in_channels=channels,
            out_channels=channels,
            update_func="sigmoid",
        )
        # Level 2: linear (no update function) convolutions mapping the
        # level-1 messages back onto the nodes; the nonlinearity is applied
        # once at aggregation instead.
        self.conv_level2_0_to_0 = Conv(
            in_channels=channels,
            out_channels=channels,
            update_func=None,
        )
        self.conv_level2_1_to_0 = Conv(
            in_channels=channels,
            out_channels=channels,
            update_func=None,
        )
        # Inter-neighborhood aggregation on the nodes: sum the two incoming
        # level-2 messages, then apply a sigmoid update.
        self.aggr_on_nodes = Aggregation(aggr_func="sum", update_func="sigmoid")

    def reset_parameters(self):
        r"""Reset learnable parameters of all four convolutions."""
        self.conv_level1_0_to_0.reset_parameters()
        self.conv_level1_0_to_1.reset_parameters()
        self.conv_level2_0_to_0.reset_parameters()
        self.conv_level2_1_to_0.reset_parameters()

    def forward(self, x_0, incidence_1, adjacency_0):
        r"""Forward pass.

        The forward pass was initially proposed in [HRGZ22]_.
        Its equations are given in [TNN23]_ and graphically illustrated in [PSHM23]_.

        .. math::
            \begin{align*}
            &🟥 \quad m_{{y \rightarrow z}}^{(0 \rightarrow 0)} = \sigma ((A_{\uparrow,0})_{xy} \cdot h^{t,(0)}_y \cdot \Theta^{t,(0)1})\\
            &🟥 \quad m_{z \rightarrow x}^{(0 \rightarrow 0)} = (A_{\uparrow,0})_{xy} \cdot m_{y \rightarrow z}^{(0 \rightarrow 0)} \cdot \Theta^{t,(0)2}\\
            &🟥 \quad m_{{y \rightarrow z}}^{(0 \rightarrow 1)} = \sigma((B_1^T)_{zy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow 1)})\\
            &🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1)_{xz} \cdot m_{z \rightarrow x}^{(0 \rightarrow 1)} \cdot \Theta^{t, (1 \rightarrow 0)}\\
            &🟧 \quad m_{x}^{(0 \rightarrow 0)} = \sum_{z \in \mathcal{L}_\uparrow(x)} m_{z \rightarrow x}^{(0 \rightarrow 0)}\\
            &🟧 \quad m_{x}^{(1 \rightarrow 0)} = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}\\
            &🟩 \quad m_x^{(0)} = m_x^{(0 \rightarrow 0)} + m_x^{(1 \rightarrow 0)}\\
            &🟦 \quad h_x^{t+1,(0)} = I(m_x^{(0)})
            \end{align*}

        References
        ----------
        .. [HRGZ22] Hajij, Ramamurthy, Guzmán-Sáenz, Zamzmi.
            High Skip Networks: A Higher Order Generalization of Skip Connections.
            Geometrical and Topological Representation Learning Workshop at ICLR 2022.
            https://openreview.net/pdf?id=Sc8glB-k6e9
        .. [TNN23] Equations of Topological Neural Networks.
            https://github.com/awesome-tnns/awesome-tnns/
        .. [PSHM23] Papillon, Sanborn, Hajij, Miolane.
            Architectures of Topological Deep Learning: A Survey on Topological Neural Networks.
            (2023) https://arxiv.org/abs/2304.10031.

        Parameters
        ----------
        x_0 : torch.Tensor, shape=[n_nodes, channels]
            Input features on the nodes of the simplicial complex.
        incidence_1 : torch.sparse, shape=[n_nodes, n_edges]
            Incidence matrix :math:`B_1` mapping edges to nodes.
        adjacency_0 : torch.sparse, shape=[n_nodes, n_nodes]
            Adjacency matrix :math:`A_0^{\uparrow}` mapping nodes to nodes via edges.

        Returns
        -------
        _ : torch.Tensor, shape=[n_nodes, channels]
            Output features on the nodes of the simplicial complex.
        """
        incidence_1_transpose = incidence_1.transpose(1, 0)
        # Level 1: node -> node and node -> edge messages (sigmoid inside Conv).
        x_0_level1 = self.conv_level1_0_to_0(x_0, adjacency_0)
        x_1_level1 = self.conv_level1_0_to_1(x_0, incidence_1_transpose)
        # Level 2: map both level-1 messages back onto the nodes (linear).
        x_0_level2 = self.conv_level2_0_to_0(x_0_level1, adjacency_0)
        x_1_level2 = self.conv_level2_1_to_0(x_1_level1, incidence_1)
        # Sum the two node-level messages and apply the sigmoid update.
        x_0 = self.aggr_on_nodes([x_0_level2, x_1_level2])
        return x_0