diff --git a/torchqrnn/qrnn.py b/torchqrnn/qrnn.py index 71f3cda..ca5a96b 100644 --- a/torchqrnn/qrnn.py +++ b/torchqrnn/qrnn.py @@ -35,7 +35,7 @@ def __init__(self, input_size, hidden_size=None, save_prev_x=False, zoneout=0, w assert window in [1, 2], "This QRNN implementation currently only handles convolutional window of size 1 or size 2" self.window = window self.input_size = input_size - self.hidden_size = hidden_size if hidden_size else hidden_size + self.hidden_size = hidden_size if hidden_size else input_size self.zoneout = zoneout self.save_prev_x = save_prev_x self.prevX = None