Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/tensorly/torch into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanKossaifi committed Mar 4, 2021
2 parents f26c79e + 0b9e893 commit 38d2614
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tltorch/factorized_conv/_tucker_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class TuckerConv(BaseFactorizedConv):
def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_rank=None, order=None,
implementation='reconstructed', stride=1, padding=0, dilation=1, bias=False):
super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias)
self.rank = validate_tucker_rank(self.kernel_shape, rank=self.rank, fixed_modes=modes_fixed_rank)

if modes_fixed_rank is None:
self.modes_fixed_rank = None
Expand All @@ -85,6 +84,8 @@ def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_ran
else:
self.modes_fixed_rank = modes_fixed_rank

self.rank = validate_tucker_rank(self.kernel_shape, rank=self.rank, fixed_modes=self.modes_fixed_rank)

self.core = nn.Parameter(torch.Tensor(*self.rank))
self.factors = ParameterList(nn.Parameter(torch.Tensor(s, r))\
for (s, r) in zip(self.kernel_shape, self.rank))
Expand Down

0 comments on commit 38d2614

Please sign in to comment.