FIX embeddings: fix for non-contiguous inputs + typo
JeanKossaifi committed Mar 8, 2023
1 parent 923dc54 commit 5f4d0e5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tltorch/factorized_layers/factorized_embedding.py
@@ -97,20 +97,20 @@ def forward(self, input, indices=0):
         #to handle case where input is not 1-D
         output_shape = (*input.shape, self.embedding_dim)

-        flatenned_input = input.view(-1)
+        flattened_input = input.reshape(-1)

         if self.n_layers == 1:
             if indices == 0:
-                embeddings = self.weight[flatenned_input, :]
+                embeddings = self.weight[flattened_input, :]
             else:
-                embeddings = self.weight[indices, flatenned_input, :]
+                embeddings = self.weight[indices, flattened_input, :]

         #CPTensorized returns CPTensorized when indexing
         if self.factorization.lower() == 'cp':
             embeddings = embeddings.to_matrix()

         #TuckerTensorized returns tensor not matrix,
-        #and requires reshape not view for contiguous
+        # and requires reshape not view for contiguous
         elif self.factorization.lower() == 'tucker':
             embeddings = embeddings.reshape(input.shape[0], -1)
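The switch from view(-1) to reshape(-1) is what handles non-contiguous inputs: view requires the flattened result to be expressible over the tensor's existing strides, while reshape falls back to a copy when it is not. A minimal sketch of the failure mode, using a hypothetical index tensor made non-contiguous by a transpose (illustration only, not code from this repository):

import torch

# Hypothetical 2-D index tensor; the transpose makes it non-contiguous.
indices = torch.arange(6).reshape(2, 3).t()
assert not indices.is_contiguous()

# indices.view(-1) would raise a RuntimeError here, because view needs a
# stride-compatible memory layout; reshape copies the data when necessary.
flattened = indices.reshape(-1)
print(flattened.shape)  # torch.Size([6])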
