Skip to content

Commit

Permalink
Add comment in convolutional shift (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristandeleu committed Jun 27, 2016
1 parent e7afa62 commit 52408e5
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ntm/heads.py
Expand Up @@ -377,6 +377,8 @@ def get_weights(self, h_t, w_tm1, M_t, **kwargs):
w_g = g_t * w_c + (1. - g_t) * w_tm1

# Convolutional Shift (3.3.2)
# NOTE: This library is using a flat (zero-padded) convolution instead of the circular
# convolution from the original paper. In practice, this change has a minimal impact.
w_g_padded = w_g.reshape((h_t.shape[0] * num_heads, self.memory_shape[0])).dimshuffle(0, 'x', 'x', 1)
conv_filter = s_t.reshape((h_t.shape[0] * num_heads, self.num_shifts)).dimshuffle(0, 'x', 'x', 1)
pad = (self.num_shifts // 2, (self.num_shifts - 1) // 2)
Expand Down

0 comments on commit 52408e5

Please sign in to comment.