Commit
added temporary file with 1st-step-untied ConvLSTM model
viorik committed Apr 7, 2016
1 parent 221bfcb commit a7e0699
Showing 15 changed files with 285 additions and 0 deletions.
Empty file modified BilinearSamplerBHWD.lua 100644 → 100755
Empty file modified ConvLSTM.lua 100644 → 100755
Empty file modified DenseTransformer2D.lua 100644 → 100755
Empty file modified README.md 100644 → 100755
Empty file modified SmoothHuberPenalty.lua 100644 → 100755
Empty file modified data-mnist.lua 100644 → 100755
Empty file modified decoder.lua 100644 → 100755
Empty file modified display_flow.lua 100644 → 100755
Empty file modified encoder.lua 100644 → 100755
Empty file modified flow.lua 100644 → 100755
Empty file modified main-mnist.lua 100644 → 100755
Empty file modified model.lua 100644 → 100755
Empty file modified opts-mnist.lua 100644 → 100755
285 changes: 285 additions & 0 deletions tmpConvLSTM.lua
@@ -0,0 +1,285 @@
--[[
  Convolutional LSTM cell for short-term visual memory.
  inputSize  - number of input feature planes
  outputSize - number of output feature planes
  rho        - recurrent sequence length (default 10)
  kc         - size of the convolutional filter applied to the input
  km         - size of the convolutional filter applied to the recurrent
               output(t-1); usually km > kc
  stride     - convolution stride (default 1)
  batchSize  - optional batch size; nil for non-batched input
  The first step is untied: it is handled by a separate module with its own
  weights and no recurrent inputs, since output(0) and cell(0) are zero.
--]]
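--[[
Illustrative usage (a minimal sketch; the sizes below are arbitrary examples,
not taken from this repository):

   local net = nn.ConvLSTM(1, 16, 10, 3, 5, 1)    -- 1->16 planes, rho=10, kc=3, km=5
   local x   = torch.rand(1, 50, 50)              -- one 50x50 input frame
   local h1  = net:forward(x)                     -- 16x50x50, via the untied first step
   local h2  = net:forward(torch.rand(1, 50, 50)) -- later steps use the tied module
   net:forget()                                   -- reset state before a new sequence
--]]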
local _ = require 'moses'
require 'nn'
require 'dpnn'
require 'rnn'
require 'extracunn'
local ConvLSTM, parent = torch.class('nn.ConvLSTM', 'nn.LSTM')
function ConvLSTM:__init(inputSize, outputSize, rho, kc, km, stride, batchSize)
   self.kc = kc
   self.km = km
   self.padc = torch.floor(kc/2)
   self.padm = torch.floor(km/2)
   self.stride = stride or 1
   self.batchSize = batchSize
   -- parent.__init calls self:buildModel(), so kc/km/stride must be set first
   parent.__init(self, inputSize, outputSize, rho or 10)
   self.untiedModule = self:buildModelUntied()
end
-------------------------- factory methods -----------------------------
function ConvLSTM:buildGate()
   -- Note : input is {input(t), output(t-1), cell(t-1)}
   local gate = nn.Sequential()
   gate:add(nn.NarrowTable(1,2)) -- we don't need the cell here
   local input2gate = nn.SpatialConvolution(self.inputSize, self.outputSize, self.kc, self.kc, self.stride, self.stride, self.padc, self.padc)
   local output2gate = nn.SpatialConvolutionNoBias(self.outputSize, self.outputSize, self.km, self.km, self.stride, self.stride, self.padm, self.padm)
   local para = nn.ParallelTable()
   para:add(input2gate):add(output2gate)
   gate:add(para)
   gate:add(nn.CAddTable())
   gate:add(nn.Sigmoid())
   return gate
end
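-- For reference: each gate computes sigmoid(W_x * input(t) + W_h * output(t-1)),
-- where * denotes convolution; input2gate is W_x (kc x kc, with bias) and
-- output2gate is W_h (km x km, bias-free, so each gate has a single bias term).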
function ConvLSTM:buildInputGate()
   self.inputGate = self:buildGate()
   return self.inputGate
end

function ConvLSTM:buildForgetGate()
   self.forgetGate = self:buildGate()
   return self.forgetGate
end
function ConvLSTM:buildCellGate()
   -- Input is {input(t), output(t-1), cell(t-1)}, but we only need {input(t), output(t-1)}
   local hidden = nn.Sequential()
   hidden:add(nn.NarrowTable(1,2))
   local input2gate = nn.SpatialConvolution(self.inputSize, self.outputSize, self.kc, self.kc, self.stride, self.stride, self.padc, self.padc)
   local output2gate = nn.SpatialConvolutionNoBias(self.outputSize, self.outputSize, self.km, self.km, self.stride, self.stride, self.padm, self.padm)
   local para = nn.ParallelTable()
   para:add(input2gate):add(output2gate)
   hidden:add(para)
   hidden:add(nn.CAddTable())
   hidden:add(nn.Tanh())
   self.cellGate = hidden
   return hidden
end
function ConvLSTM:buildCell()
   -- Input is {input(t), output(t-1), cell(t-1)}
   self.inputGate = self:buildInputGate()
   self.forgetGate = self:buildForgetGate()
   self.cellGate = self:buildCellGate()
   -- forget = forgetGate{input(t), output(t-1), cell(t-1)} * cell(t-1)
   local forget = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(self.forgetGate):add(nn.SelectTable(3))
   forget:add(concat)
   forget:add(nn.CMulTable())
   -- input = inputGate{input(t), output(t-1), cell(t-1)} * cellGate{input(t), output(t-1), cell(t-1)}
   local input = nn.Sequential()
   local concat2 = nn.ConcatTable()
   concat2:add(self.inputGate):add(self.cellGate)
   input:add(concat2)
   input:add(nn.CMulTable())
   -- cell(t) = forget + input
   local cell = nn.Sequential()
   local concat3 = nn.ConcatTable()
   concat3:add(forget):add(input)
   cell:add(concat3)
   cell:add(nn.CAddTable())
   self.cell = cell
   return cell
end
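-- In equation form, the block above computes the standard LSTM cell update
-- cell(t) = forgetGate ⊙ cell(t-1) + inputGate ⊙ cellGate, with ⊙ the
-- elementwise product: the two CMulTable branches summed by CAddTable.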
function ConvLSTM:buildOutputGate()
   self.outputGate = self:buildGate()
   return self.outputGate
end
-- cell(t) = cell{input, output(t-1), cell(t-1)}
-- output(t) = outputGate{input, output(t-1)}*tanh(cell(t))
-- output of Model is table : {output(t), cell(t)}
function ConvLSTM:buildModel()
   -- Input is {input(t), output(t-1), cell(t-1)}
   self.cell = self:buildCell()
   self.outputGate = self:buildOutputGate()
   -- assemble
   local concat = nn.ConcatTable()
   concat:add(nn.NarrowTable(1,2)):add(self.cell)
   local model = nn.Sequential()
   model:add(concat)
   -- output of concat is {{input(t), output(t-1)}, cell(t)},
   -- so flatten to {input(t), output(t-1), cell(t)}
   model:add(nn.FlattenTable())
   local cellAct = nn.Sequential()
   cellAct:add(nn.SelectTable(3))
   cellAct:add(nn.Tanh())
   local concat3 = nn.ConcatTable()
   concat3:add(self.outputGate):add(cellAct)
   local output = nn.Sequential()
   output:add(concat3)
   output:add(nn.CMulTable())
   -- we want the model to output {output(t), cell(t)}
   local concat4 = nn.ConcatTable()
   concat4:add(output):add(nn.SelectTable(3))
   model:add(concat4)
   return model
end
function ConvLSTM:buildGateUntied()
   -- Note : input is input(t) only
   local gate = nn.Sequential()
   gate:add(nn.SpatialConvolution(self.inputSize, self.outputSize, self.kc, self.kc, self.stride, self.stride, self.padc, self.padc))
   gate:add(nn.Sigmoid())
   return gate
end

function ConvLSTM:buildCellGateUntied()
   local cellGate = nn.Sequential()
   cellGate:add(nn.SpatialConvolution(self.inputSize, self.outputSize, self.kc, self.kc, self.stride, self.stride, self.padc, self.padc))
   cellGate:add(nn.Tanh())
   self.cellGateUntied = cellGate
   return cellGate
end
function ConvLSTM:buildModelUntied()
   -- Input is input(t); no forget gate is needed since cell(0) is zero
   local model = nn.Sequential()
   self.inputGateUntied = self:buildGateUntied()
   self.cellGateUntied = self:buildCellGateUntied()
   self.outputGateUntied = self:buildGateUntied()
   local concat = nn.ConcatTable()
   concat:add(self.inputGateUntied):add(self.cellGateUntied):add(self.outputGateUntied)
   model:add(concat)
   -- {inputGate, cellGate, outputGate} -> {cell(1) = inputGate*cellGate, outputGate}
   local cellAct = nn.Sequential()
   cellAct:add(nn.NarrowTable(1,2))
   cellAct:add(nn.CMulTable())
   local concat2 = nn.ConcatTable()
   concat2:add(cellAct):add(nn.SelectTable(3))
   model:add(concat2)
   -- {cell(1), outputGate} -> {outputGate, tanh(cell(1)), cell(1)}
   local tanhcell = nn.Sequential()
   tanhcell:add(nn.SelectTable(1)):add(nn.Tanh())
   local concat3 = nn.ConcatTable()
   concat3:add(nn.SelectTable(2)):add(tanhcell):add(nn.SelectTable(1))
   model:add(concat3)
   model:add(nn.FlattenTable())
   -- output(1) = outputGate * tanh(cell(1)); model returns {output(1), cell(1)}
   local output = nn.Sequential()
   output:add(nn.NarrowTable(1,2))
   output:add(nn.CMulTable())
   local concat4 = nn.ConcatTable()
   concat4:add(output):add(nn.SelectTable(3))
   model:add(concat4)
   return model
end
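-- Compared with buildModel(), the untied step omits the recurrent (km x km)
-- convolutions and the forget branch entirely: with output(0) = 0 the bias-free
-- recurrent convolutions contribute nothing, and with cell(0) = 0 the forget
-- term vanishes, so the first step reduces to
--   cell(1)   = inputGate(input(1)) ⊙ cellGate(input(1))
--   output(1) = outputGate(input(1)) ⊙ tanh(cell(1))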
function ConvLSTM:updateOutput(input)
   local prevOutput, prevCell
   -- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)}
   local output, cell
   if self.step == 1 then
      -- zeroTensor is not used by the untied module; it is sized here so it can
      -- serve as the default gradCell in the backward pass
      if self.batchSize then
         self.zeroTensor:resize(self.batchSize, self.outputSize, input:size(3), input:size(4)):zero()
      else
         self.zeroTensor:resize(self.outputSize, input:size(2), input:size(3)):zero()
      end
      -- the first step goes through the untied module, which sees only input(t)
      output, cell = unpack(self.untiedModule:updateOutput(input))
   else
      -- previous output and memory of this module
      prevOutput = self.outputs[self.step-1]
      prevCell = self.cells[self.step-1]
      if self.train ~= false then
         self:recycle()
         local recurrentModule = self:getStepModule(self.step)
         -- the actual forward propagation
         output, cell = unpack(recurrentModule:updateOutput{input, prevOutput, prevCell})
      else
         output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell})
      end
   end
   self.outputs[self.step] = output
   self.cells[self.step] = cell
   self.output = output
   self.cell = cell
   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   self.gradParametersAccumulated = false
   -- note that we don't return the cell, just the output
   return self.output
end
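-- Note: self.step counts forward calls; step 1 is routed to untiedModule and
-- every later step to a shared-weight clone of recurrentModule, so forget()
-- must be called between sequences (as with other rnn-package modules) to
-- return to the untied first step.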
function ConvLSTM:_updateGradInput(input, gradOutput)
   assert(self.step > 1, "expecting at least one updateOutput")
   local step = self.updateGradInputStep - 1
   assert(step >= 1)
   -- the gradient w.r.t. output(step) is the external gradOutput plus the
   -- recurrent gradient propagated back from step+1, if any
   if self.gradPrevOutput then
      self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
      nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
      gradOutput = self._gradOutputs[step]
   end
   local gradInput
   local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
   if step == 1 then
      -- the first step was produced by the untied module, which took only input(t)
      gradInput = self.untiedModule:updateGradInput(input, {gradOutput, gradCell})
   else
      local recurrentModule = self:getStepModule(step)
      local output = self.outputs[step-1]
      local cell = self.cells[step-1]
      local inputTable = {input, output, cell}
      -- backward propagate through this step
      local gradInputTable = recurrentModule:updateGradInput(inputTable, {gradOutput, gradCell})
      gradInput, self.gradPrevOutput, gradCell = unpack(gradInputTable)
   end
   self.gradCells[step-1] = gradCell
   if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
   if self.userPrevCell then self.userGradPrevCell = gradCell end
   return gradInput
end
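-- The self.gradPrevOutput set above is consumed on the next call (for step-1),
-- where it is summed with that step's external gradOutput; this is how the
-- recurrent gradient flows backward through time.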
function ConvLSTM:_accGradParameters(input, gradOutput, scale)
   local step = self.accGradParametersStep - 1
   assert(step >= 1)
   -- recover the same gradOutput/gradCell that were used in _updateGradInput
   local gradOutput = (step == self.step-1) and gradOutput or self._gradOutputs[step]
   local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
   local gradOutputTable = {gradOutput, gradCell}
   if step == 1 then
      self.untiedModule:accGradParameters(input, gradOutputTable, scale)
   else
      local recurrentModule = self:getStepModule(step)
      local output = self.outputs[step-1]
      local cell = self.cells[step-1]
      local inputTable = {input, output, cell}
      recurrentModule:accGradParameters(inputTable, gradOutputTable, scale)
   end
end
function ConvLSTM:initBias(forgetBias, otherBias)
   local fBias = forgetBias or 1
   local oBias = otherBias or 0
   -- tied gates: modules[2] is the ParallelTable, its modules[1] the input convolution
   self.inputGate.modules[2].modules[1].bias:fill(oBias)
   self.outputGate.modules[2].modules[1].bias:fill(oBias)
   self.cellGate.modules[2].modules[1].bias:fill(oBias)
   self.forgetGate.modules[2].modules[1].bias:fill(fBias)
   -- untied gates: the convolution is the first module of each gate
   self.inputGateUntied.modules[1].bias:fill(oBias)
   self.outputGateUntied.modules[1].bias:fill(oBias)
   self.cellGateUntied.modules[1].bias:fill(oBias)
end
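-- Hypothetical usage (not called anywhere in this file): net:initBias(1, 0)
-- starts the forget gate biased toward keeping the previous cell state, a
-- common LSTM initialization; it only affects the tied steps, since the
-- untied first step has no forget gate.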
Empty file modified weight-init.lua 100644 → 100755
