Skip to content
This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Add batch support via nn.MM()
  • Loading branch information
Atcold committed Jan 12, 2017
1 parent 816c70c commit a5e586a
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/model/RecurrentFWNetwork.lua
Expand Up @@ -15,6 +15,8 @@ return function(opt, params)
local eps = eps or 1e-5
local outputType = opt.outputType or 'last' -- 'last' or 'all'
local relu = nn.ReLU()
local mm = nn.MM(false, false) -- A * B
local mmT = nn.MM(false, true) -- A * B'

-- container:
params = params or {}
Expand All @@ -34,7 +36,6 @@ return function(opt, params)
x = torch.view(x, 1, torch.size(x, 1), torch.size(x, 2))
end
local batch = torch.size(x, 1)
assert(batch == 1, 'Batch mode not yet supported.')
local steps = torch.size(x, 2)

-- hiddens:
Expand All @@ -44,7 +45,8 @@ return function(opt, params)
local hp = prevState.h or torch.zero(x.new(batch, hiddenFeatures))

-- fast weights
local A = prevState.A or torch.zero(x.new(hiddenFeatures, hiddenFeatures))
local A = prevState.A or
torch.zero(x.new(batch, hiddenFeatures, hiddenFeatures))

local hs = {}
-- go over time:
Expand All @@ -55,17 +57,25 @@ return function(opt, params)
-- prev h
hp = hs[t-1] or hp

-- vector to matrix
local hpMat = torch.view(hp, batch, -1, 1)

-- fast weights update
A = l * A + e * (torch.t(hp) * hp)
A = l * A + e * mmT{hpMat, hpMat}

-- pack all dot products:
local dot = torch.cat(xt, hp, 2) * p.W
+ torch.expand(p.b, batch, hiddenFeatures)

hs[t] = torch.zero(x.new(batch, hiddenFeatures))
for s = 0, S do

-- vector to matrix
local hstMat = torch.view(hs[t], batch, -1, 1)

-- next h:
hs[t] = dot + hs[t] * A
hs[t] = dot + torch.view(mm{A, hstMat}, batch, -1)

if LayerNorm then
local h = hs[t]
if torch.nDimension(hs[t]) == 1 then
Expand All @@ -78,7 +88,10 @@ return function(opt, params)
torch.size(h))
hs[t] = torch.view(torch.cdiv(h, std), torch.size(hs[t]))
end

-- apply non-linearity
hs[t] = relu(hs[t])

end
end

Expand Down

0 comments on commit a5e586a

Please sign in to comment.