diff --git a/src/model/RecurrentFWNetwork.lua b/src/model/RecurrentFWNetwork.lua
index 3cea038..d15bfe1 100644
--- a/src/model/RecurrentFWNetwork.lua
+++ b/src/model/RecurrentFWNetwork.lua
@@ -15,6 +15,8 @@ return function(opt, params)
    local eps = eps or 1e-5
    local outputType = opt.outputType or 'last' -- 'last' or 'all'
    local relu = nn.ReLU()
+   local mm = nn.MM(false, false) -- A * B
+   local mmT = nn.MM(false, true) -- A * B'
 
    -- container:
    params = params or {}
@@ -34,7 +36,6 @@ return function(opt, params)
         x = torch.view(x, 1, torch.size(x, 1), torch.size(x, 2))
      end
      local batch = torch.size(x, 1)
-     assert(batch == 1, 'Batch mode not yet supprted.')
      local steps = torch.size(x, 2)
 
      -- hiddens:
@@ -44,7 +45,8 @@ return function(opt, params)
      local hp = prevState.h or torch.zero(x.new(batch, hiddenFeatures))
 
      -- fast weights
-     local A = prevState.A or torch.zero(x.new(hiddenFeatures, hiddenFeatures))
+     local A = prevState.A or
+        torch.zero(x.new(batch, hiddenFeatures, hiddenFeatures))
      local hs = {}
 
      -- go over time:
@@ -55,8 +57,11 @@ return function(opt, params)
         -- prev h
         hp = hs[t-1] or hp
 
+        -- vector to matrix
+        local hpMat = torch.view(hp, batch, -1, 1)
+
         -- fast weights update
-        A = l * A + e * (torch.t(hp) * hp)
+        A = l * A + e * mmT{hpMat, hpMat}
 
         -- pack all dot products:
         local dot = torch.cat(xt, hp, 2) * p.W
@@ -64,8 +69,13 @@ return function(opt, params)
         hs[t] = torch.zero(x.new(batch, hiddenFeatures))
         for s = 0, S do
+
+           -- vector to matrix
+           local hstMat = torch.view(hs[t], batch, -1, 1)
+
            -- next h:
-           hs[t] = dot + hs[t] * A
+           hs[t] = dot + torch.view(mm{A, hstMat}, batch, -1)
+
            if LayerNorm then
               local h = hs[t]
               if torch.nDimension(hs[t]) == 1 then
@@ -78,7 +88,10 @@ return function(opt, params)
                  torch.size(h))
               hs[t] = torch.view(torch.cdiv(h, std), torch.size(hs[t]))
            end
+
+           -- apply non-linearity
            hs[t] = relu(hs[t])
+
         end
      end
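
Notes (not part of the patch): the batch-mode change replaces the explicit outer
product torch.t(hp) * hp, which assumes a single 1 x H example, with a batched
nn.MM(false, true) call over batch x H x 1 views. A minimal standalone sketch of
that equivalence, assuming plain torch/nn (outside autograd's functional wrapper
the module is invoked via :forward rather than called directly); the sizes are
illustrative:

-- Sketch: the batched outer product used by the patch matches the
-- old single-example code path, per batch element.
local torch = require 'torch'
local nn = require 'nn'

local batch, H = 4, 8
local hp = torch.randn(batch, H)

-- batched path from the patch: view each hidden state as a column
-- vector and multiply by its own transpose -> batch x H x H
local mmT = nn.MM(false, true)
local hpMat = hp:view(batch, H, 1)
local outer = mmT:forward({hpMat, hpMat})

-- old path, valid only for one example: on a 1 x H row vector,
-- torch.t(hp) * hp is the same H x H outer product
for b = 1, batch do
   local row = hp[{{b}, {}}]      -- 1 x H slice
   local ref = row:t() * row      -- H x H outer product
   assert((outer[b] - ref):abs():max() < 1e-6)
end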
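The inner-loop change also swaps the row-vector form hs[t] * A for the batched
column form torch.view(mm{A, hstMat}, batch, -1). The two agree because A is
built only from decayed outer products h h', so it stays symmetric. A quick
check under the same assumptions as above:

-- Sketch: with a symmetric fast-weight matrix A, the old row form
-- hs * A equals the new batched column form view(mm{A, hs'}).
local torch = require 'torch'
local nn = require 'nn'

local batch, H = 4, 8
local hs = torch.randn(batch, H)

-- build a symmetric A per batch element, as the update rule does
local A = torch.zeros(batch, H, H)
for b = 1, batch do
   local row = hs[{{b}, {}}]
   A[b] = row:t() * row           -- symmetric outer product
end

local mm = nn.MM(false, false)
local out = mm:forward({A, hs:view(batch, H, 1)}):view(batch, H)

for b = 1, batch do
   local old = hs[{{b}, {}}] * A[b]   -- 1 x H row form
   assert((out[{{b}, {}}] - old):abs():max() < 1e-6)
end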