diff --git a/model.py b/model.py
index ef66da0..34bf296 100644
--- a/model.py
+++ b/model.py
@@ -175,14 +175,15 @@ def __init__(self, input_dim, blocks_dim, n_classes, kernel_size=17, dropout_rate
         self.bn1 = nn.BatchNorm1d(n_filters_out)
 
         # Residual block layers
-        self.res_blocks = []
+        self.n_blk = len(blocks_dim)  # block count; forward() uses it to look the modules back up
         for i, (n_filters, n_samples) in enumerate(blocks_dim):
             n_filters_in, n_filters_out = n_filters_out, n_filters
             n_samples_in, n_samples_out = n_samples_out, n_samples
             downsample = _downsample(n_samples_in, n_samples_out)
-            resblk1d = ResBlock1d(n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate)
-            self.add_module('resblock1d_{0}'.format(i), resblk1d)
-            self.res_blocks += [resblk1d]
+            setattr(self, f'resblock1d_{i}', ResBlock1d(
+                n_filters_in, n_filters_out, downsample,
+                kernel_size, dropout_rate
+            ))
 
         # Linear layer
         n_filters_last, n_samples_last = blocks_dim[-1]
@@ -198,8 +198,8 @@ def forward(self, x):
 
         # Residual blocks
         y = x
-        for blk in self.res_blocks:
-            x, y = blk(x, y)
+        for i in range(self.n_blk):
+            x, y = getattr(self, f'resblock1d_{i}')(x, y)
 
         # Flatten array
         x = x.view(x.size(0), -1)
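For reviewers skimming the pattern: assigning an `nn.Module` to an attribute inside `__init__` registers it as a submodule, exactly like the `add_module` call this patch removes, so the `resblock1d_{i}` keys in any existing `state_dict` checkpoints are preserved. Below is a minimal standalone sketch of the register-by-`setattr` / fetch-by-`getattr` pattern, with `nn.Linear` standing in for `ResBlock1d` and an assumed block count of 3:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, n_blk=3):
        super().__init__()
        self.n_blk = n_blk
        for i in range(n_blk):
            # Plain attribute assignment of an nn.Module registers it in
            # self._modules under 'resblock1d_{i}', same as add_module(),
            # so its parameters appear in parameters() and state_dict().
            setattr(self, f'resblock1d_{i}', nn.Linear(8, 8))

    def forward(self, x):
        # Retrieve each registered block by the same name pattern.
        for i in range(self.n_blk):
            x = getattr(self, f'resblock1d_{i}')(x)
        return x

m = Demo()
print([name for name, _ in m.named_children()])
# ['resblock1d_0', 'resblock1d_1', 'resblock1d_2']
print(m(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```

An `nn.ModuleList` would register the blocks just as well and keep the simpler `for blk in self.res_blocks` loop, but its parameters would be keyed `res_blocks.{i}.*`, which would break checkpoints saved under the `resblock1d_{i}` names.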