diff --git a/snli/model.py b/snli/model.py index e483724e8b..e51f778633 100644 --- a/snli/model.py +++ b/snli/model.py @@ -10,7 +10,7 @@ def forward(self, input): return super(Bottle, self).forward(input) size = input.size()[:2] out = super(Bottle, self).forward(input.view(size[0]*size[1], -1)) - return out.view(*size, -1) + return out.view(size[0], size[1], -1) class Linear(Bottle, nn.Linear):