Skip to content

Commit

Permalink
tucker_conv: fix a bug with modes_fixed_rank
Browse files Browse the repository at this point in the history
  • Loading branch information
Wonmin Byeon committed Feb 18, 2021
1 parent b6cc797 commit 0b9e893
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tltorch/factorized_conv/_tucker_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class TuckerConv(BaseFactorizedConv):
def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_rank=None, order=None,
implementation='reconstructed', stride=1, padding=0, dilation=1, bias=False):
super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias)
self.rank = validate_tucker_rank(self.kernel_shape, rank=self.rank, fixed_modes=modes_fixed_rank)

if modes_fixed_rank is None:
self.modes_fixed_rank = None
Expand All @@ -84,6 +83,8 @@ def __init__(self, in_channels, out_channels, kernel_size, rank, modes_fixed_ran
else:
self.modes_fixed_rank = modes_fixed_rank

self.rank = validate_tucker_rank(self.kernel_shape, rank=self.rank, fixed_modes=self.modes_fixed_rank)

self.core = nn.Parameter(torch.Tensor(*self.rank))
self.factors = nn.ParameterList(nn.Parameter(torch.Tensor(s, r))\
for (s, r) in zip(self.kernel_shape, self.rank))
Expand Down

0 comments on commit 0b9e893

Please sign in to comment.