Skip to content
This repository has been archived by the owner on Feb 12, 2022. It is now read-only.

Commit

Permalink
Bugfix: QRNN(window=2) fails if batch size sequence length is 1 due t…
Browse files Browse the repository at this point in the history
…o empty tensor slicing
  • Loading branch information
Smerity committed Nov 25, 2017
1 parent 2a30bb6 commit 2ffbd32
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchqrnn/qrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def forward(self, X, hidden=None):
source = X
elif self.window == 2:
# Construct the x_{t-1} tensor with optional x_{-1}, otherwise a zeroed out value for x_{-1}
Xm1 = [self.prevX if self.prevX is not None else X[:1, :, :] * 0, X[:-1, :, :]]
Xm1 = []
Xm1.append(self.prevX if self.prevX is not None else X[:1, :, :] * 0)
# Note: in case of len(X) == 1, X[:-1, :, :] results in slicing of empty tensor == bad
if len(X) > 1:
Xm1.append(X[:-1, :, :])
Xm1 = torch.cat(Xm1, 0)
# Convert two (seq_len, batch_size, hidden) tensors to (seq_len, batch_size, 2 * hidden)
source = torch.cat([X, Xm1], 2)
Expand Down

0 comments on commit 2ffbd32

Please sign in to comment.