From 5382ee0ff330c7fddd2ba7ff0b3bea8d515dfa50 Mon Sep 17 00:00:00 2001 From: Hellsegga Date: Tue, 27 Jun 2023 13:11:41 +0200 Subject: [PATCH] Using torch.nn.ModuleList --- topomodelx/nn/simplicial/snn_layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/topomodelx/nn/simplicial/snn_layer.py b/topomodelx/nn/simplicial/snn_layer.py index 8044da7a..28a8278e 100644 --- a/topomodelx/nn/simplicial/snn_layer.py +++ b/topomodelx/nn/simplicial/snn_layer.py @@ -36,11 +36,13 @@ def __init__(self, in_channels, out_channels, K): self.out_channels = out_channels self.K = K - self.convs = [ + convs = [ Conv(in_channels=in_channels, out_channels=out_channels, update_func="relu") for _ in range(self.K) ] + self.convs = torch.nn.ModuleList(convs) + self.aggr = Aggregation(aggr_func="sum", update_func="relu") def reset_parameters(self):