From 1c2ae32d67f7f7854542212b229cd95c85cf4026 Mon Sep 17 00:00:00 2001 From: Vinit Ravishankar Date: Thu, 31 Jan 2019 00:03:14 +0100 Subject: [PATCH] fix batch size (#462) --- torchtext/data/iterator.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchtext/data/iterator.py b/torchtext/data/iterator.py index 017287f7e6..b4923c2781 100644 --- a/torchtext/data/iterator.py +++ b/torchtext/data/iterator.py @@ -220,10 +220,15 @@ def __iter__(self): for i in range(0, len(self) * self.bptt_len, self.bptt_len): self.iterations += 1 seq_len = min(self.bptt_len, len(data) - i - 1) + batch_text = data[i:i + seq_len] + batch_target = data[i + 1:i + 1 + seq_len] + if TEXT.batch_first: + batch_text = batch_text.t().contiguous() + batch_target = batch_target.t().contiguous() yield Batch.fromvars( dataset, self.batch_size, - text=data[i:i + seq_len], - target=data[i + 1:i + 1 + seq_len]) + text=batch_text, + target=batch_target) if not self.repeat: return