diff --git a/topomodelx/nn/simplicial/snn_layer.py b/topomodelx/nn/simplicial/snn_layer.py index 8044da7a..28a8278e 100644 --- a/topomodelx/nn/simplicial/snn_layer.py +++ b/topomodelx/nn/simplicial/snn_layer.py @@ -36,11 +36,13 @@ def __init__(self, in_channels, out_channels, K): self.out_channels = out_channels self.K = K - self.convs = [ + convs = [ Conv(in_channels=in_channels, out_channels=out_channels, update_func="relu") for _ in range(self.K) ] + self.convs = torch.nn.ModuleList(convs) + self.aggr = Aggregation(aggr_func="sum", update_func="relu") def reset_parameters(self):