diff --git a/tltorch/factorized_conv/_tucker_conv.py b/tltorch/factorized_conv/_tucker_conv.py index 9891d12..386f547 100644 --- a/tltorch/factorized_conv/_tucker_conv.py +++ b/tltorch/factorized_conv/_tucker_conv.py @@ -76,7 +76,6 @@ class TuckerConv(BaseFactorizedConv): def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_rank=None, order=None, implementation='reconstructed', stride=1, padding=0, dilation=1, bias=False): super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias) - self.rank = validate_tucker_rank(self.kernel_shape, rank=self.rank, fixed_modes=modes_fixed_rank) if modes_fixed_rank is None: self.modes_fixed_rank = None @@ -85,6 +84,8 @@ def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_ran else: self.modes_fixed_rank = modes_fixed_rank + self.rank = validate_tucker_rank(self.kernel_shape, rank=self.rank, fixed_modes=self.modes_fixed_rank) + self.core = nn.Parameter(torch.Tensor(*self.rank)) self.factors = ParameterList(nn.Parameter(torch.Tensor(s, r))\ for (s, r) in zip(self.kernel_shape, self.rank))