In [23]:
import torch
import torch_geometric.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_geometric.nn import GCNConv, GraphUNet, LayerNorm, Linear
from typing import Union, Callable
from torch_geometric.typing import OptTensor

In [24]:
class GCNBlock(torch.nn.Module):
    """Stack of ``GCNConv`` layers with an optional residual connection.

    The first convolution maps ``in_channels`` to ``out_channels``; the
    remaining ``n_layers - 1`` convolutions keep the width at
    ``out_channels``. When ``sum_res`` is true, the output of the first
    convolution is added to the output of the last one.

    Args:
        in_channels: Input width of the first convolution.
        out_channels: Output width of every convolution in the block.
        n_layers: Total number of convolutions; must be greater than 1.
        sum_res: If true, add the first convolution's output to the result.
        act: Activation applied after every convolution except the first.
            Either a callable, or an activation name that is resolved with
            ``torch_geometric.nn.resolver.activation_resolver``.

    Raises:
        AssertionError: If ``n_layers`` is not greater than 1.
    """

    def __init__(
        self,
        in_channels : int,
        out_channels : int,
        n_layers : int,
        sum_res : bool = True,
        act: Union[str, Callable] = F.relu
    ) -> None:
        super().__init__()
        assert n_layers > 1, "GCNBlock requires n_layers > 1"
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_layers = n_layers
        self.sum_res = sum_res

        # The signature advertises string activations; storing the raw string
        # would only fail later, inside forward(), with "'str' object is not
        # callable". Resolve it up front instead.
        if isinstance(act, str):
            # Imported lazily so the default (callable) path does not depend
            # on the resolver module being available.
            from torch_geometric.nn.resolver import activation_resolver
            act = activation_resolver(act)
        self.act = act

        # NOTE(review): ``dtype`` is forwarded to GCNConv exactly as before;
        # confirm the installed PyG version accepts this keyword.
        self.layers = ModuleList()
        self.layers.append(
            GCNConv(self.in_channels, self.out_channels, dtype=torch.float32)
        )
        for _ in range(self.n_layers - 1):
            self.layers.append(
                GCNConv(self.out_channels, self.out_channels, dtype=torch.float32)
            )

    def reset_parameters(self):
        """Call ``reset_parameters()`` on every convolution in the block."""
        for conv in self.layers:
            conv.reset_parameters()

    def forward(self, x, edge_index, batch : OptTensor = None):
        """Run the block on node features ``x`` over the graph ``edge_index``.

        Args:
            x: Node features passed to the first convolution.
            edge_index: Graph connectivity passed to every convolution.
            batch: Accepted but not used by this block.

        Returns:
            The output of the last convolution after activation, plus the
            first convolution's output when ``sum_res`` is true.
        """
        first_conv, *later_convs = self.layers

        # NOTE(review): no activation is applied to the first convolution's
        # output before it feeds the second one - confirm this is intended.
        x_temp = first_conv(x, edge_index)
        x = x_temp
        for conv in later_convs:
            x = self.act(conv(x, edge_index))
        if self.sum_res:
            # Residual connection from the first convolution's output.
            x = x + x_temp
        return x

In [27]:
# Smoke test: run a 3-layer block on a toy graph (5 nodes, 4 edges).
# Seed before building the model: without it the layer weights, and therefore
# the displayed tensor, change on every Restart & Run All.
torch.manual_seed(0)
net = GCNBlock(in_channels=3, out_channels=2, n_layers=3)
x = torch.ones(5, 3, dtype=torch.float32)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 2, 3, 4]], dtype=torch.long)
# Presumably a node-to-graph assignment vector; it is not passed to the call below.
batch = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long)
net(x, edge_index)

tensor([[0.4625, 1.4834],
        [0.5583, 1.7906],
        [0.4625, 1.4834],
        [0.4625, 1.4834],
        [0.4625, 1.4834]], grad_fn=<AddBackward0>)