Skip to content

Commit

Permalink
corrected bug in rmspropconfig
Browse files Browse the repository at this point in the history
  • Loading branch information
viorik committed Apr 26, 2016
1 parent bf7fdae commit 40b8241
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 180 deletions.
5 changes: 4 additions & 1 deletion BilinearSamplerBHWD.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
local BilinearSamplerBHWD, parent = torch.class('nn.BilinearSamplerBHWD', 'nn.Module')
assert(nn.BilinearSamplerBHWD, "stnbhwd package not preloaded")

-- we overwrite the module of the same name found in the stnbhwd package
local BilinearSamplerBHWD, parent = nn.BilinearSamplerBHWD, nn.Module

--[[
BilinearSamplerBHWD() :
Expand Down
161 changes: 9 additions & 152 deletions ConvLSTM.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,16 @@ require 'dpnn'
require 'rnn'
require 'extracunn'

local ConvLSTM, parent = torch.class('nn.ConvLSTM', 'nn.AbstractRecurrent')
local ConvLSTM, parent = torch.class('nn.ConvLSTM', 'nn.LSTM')

function ConvLSTM:__init(inputSize, outputSize, rho, kc, km, stride)
parent.__init(self, rho or 10)
self.inputSize = inputSize
self.outputSize = outputSize
function ConvLSTM:__init(inputSize, outputSize, rho, kc, km, stride, batchSize)
self.kc = kc
self.km = km
self.padc = torch.floor(kc/2)
self.padm = torch.floor(km/2)
self.stride = stride or 1

-- build the model
self.recurrentModule = self:buildModel()
-- make it work with nn.Container
self.modules[1] = self.recurrentModule
self.sharedClones[1] = self.recurrentModule

-- for output(0), cell(0) and gradCell(T)
self.zeroTensor = torch.Tensor()

self.cells = {}
self.gradCells = {}
self.batchSize = batchSize or nil
parent.__init(self, inputSize, outputSize, rho or 10)
end

-------------------------- factory methods -----------------------------
Expand Down Expand Up @@ -139,14 +126,17 @@ function ConvLSTM:buildModel()
return model
end

------------------------- forward backward -----------------------------
function ConvLSTM:updateOutput(input)
local prevOutput, prevCell

if self.step == 1 then
prevOutput = self.userPrevOutput or self.zeroTensor
prevCell = self.userPrevCell or self.zeroTensor
self.zeroTensor:resize(self.outputSize,input:size(2),input:size(3)):zero()
if self.batchSize then
self.zeroTensor:resize(self.batchSize,self.outputSize,input:size(3),input:size(4)):zero()
else
self.zeroTensor:resize(self.outputSize,input:size(2),input:size(3)):zero()
end
else
-- previous output and memory of this module
prevOutput = self.output
Expand All @@ -164,13 +154,6 @@ function ConvLSTM:updateOutput(input)
output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell})
end

if self.train ~= false then
local input_ = self.inputs[self.step]
self.inputs[self.step] = self.copyInputs
and nn.rnn.recursiveCopy(input_, input)
or nn.rnn.recursiveSet(input_, input)
end

self.outputs[self.step] = output
self.cells[self.step] = cell

Expand All @@ -186,137 +169,11 @@ function ConvLSTM:updateOutput(input)
return self.output
end

-- Back-propagation through time over the steps stored by updateOutput.
-- Walks backwards from step `timeStep-1` for at most `rho` steps and
-- rebuilds `self.gradInputs` (used by Sequencer, Repeater).
--
-- timeStep : (optional) step to start from; defaults to `self.step`.
-- rho      : (optional) maximum number of steps to go back; defaults to
--            `self.rho` and is capped at `timeStep-1`.
--
-- Returns the gradInput of the earliest step that was processed.
-- When `self.fastBackward` is set, gradients w.r.t. input and parameters are
-- computed in a single `backward` call per step; otherwise the work is
-- delegated to updateGradInputThroughTime + accGradParametersThroughTime.
function ConvLSTM:backwardThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {} -- used by Sequencer, Repeater
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   if self.fastBackward then
      -- Declared local: the original assigned to an undeclared name here,
      -- which silently created/overwrote a global `gradInput`.
      local gradInput
      for step=timeStep-1,math.max(stop,1),-1 do
         -- set the output/gradOutput states of current Module
         local recurrentModule = self:getStepModule(step)

         -- backward propagate through this step
         local gradOutput = self.gradOutputs[step]
         if self.gradPrevOutput then
            -- add the gradient flowing back from the following step
            self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
            nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
            gradOutput = self._gradOutputs[step]
         end

         local scale = self.scales[step]
         -- step 1 has no predecessor: use user-provided state or zeros
         local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
         local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
         local inputTable = {self.inputs[step], output, cell}
         -- the last forward step has no successor: use user-provided gradCell or zeros
         local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
         local gradInputTable = recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale)
         gradInput, self.gradPrevOutput, gradCell = unpack(gradInputTable)
         self.gradCells[step-1] = gradCell
         table.insert(self.gradInputs, 1, gradInput)
         if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
      end
      self.gradParametersAccumulated = true
      return gradInput
   else
      local gradInput = self:updateGradInputThroughTime()
      self:accGradParametersThroughTime()
      return gradInput
   end
end

-- Computes the gradients w.r.t. the inputs by stepping backwards through the
-- stored forward steps, starting at `timeStep-1` and covering at most `rho`
-- steps. `self.gradInputs` is rebuilt (earliest step first) and the gradient
-- w.r.t. each previous cell is stored in `self.gradCells`.
--
-- timeStep : (optional) defaults to `self.step`.
-- rho      : (optional) defaults to `self.rho`, capped at `timeStep-1`.
--
-- Returns the gradInput of the earliest step processed.
function ConvLSTM:updateGradInputThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   timeStep = timeStep or self.step
   local nBack = math.min(rho or self.rho, timeStep-1)
   local firstStep = math.max(timeStep - nBack, 1)
   local gradInput

   for t = timeStep-1, firstStep, -1 do
      -- module clone holding the state of step t
      local stepModule = self:getStepModule(t)

      -- gradient w.r.t. this step's output, plus whatever flowed back
      -- from the step after it
      local gradOut = self.gradOutputs[t]
      if self.gradPrevOutput then
         self._gradOutputs[t] = nn.rnn.recursiveCopy(self._gradOutputs[t], self.gradPrevOutput)
         nn.rnn.recursiveAdd(self._gradOutputs[t], gradOut)
         gradOut = self._gradOutputs[t]
      end

      -- state fed to step t during the forward pass
      local prevOutput = (t == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[t-1]
      local prevCell = (t == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[t-1]
      -- gradient w.r.t. the cell produced by step t
      local gradCell = (t == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[t]

      local stepGrads = stepModule:updateGradInput(
         {self.inputs[t], prevOutput, prevCell},
         {gradOut, gradCell}
      )
      gradInput, self.gradPrevOutput, gradCell = unpack(stepGrads)

      self.gradCells[t-1] = gradCell
      table.insert(self.gradInputs, 1, gradInput)
      if self.userPrevOutput then
         self.userGradPrevOutput = self.gradPrevOutput
      end
   end

   return gradInput
end

-- Accumulates parameter gradients by stepping backwards through the stored
-- forward steps, starting at `timeStep-1` and covering at most `rho` steps.
-- Relies on the per-step gradients (`self._gradOutputs`, `self.gradCells`)
-- already stored on `self`; updateGradInputThroughTime is what fills them
-- in this file.
--
-- timeStep : (optional) defaults to `self.step`.
-- rho      : (optional) defaults to `self.rho`, capped at `timeStep-1`.
--
-- Returns nothing. The original ended with `return gradInput`, but no such
-- local exists in this function, so it read an unrelated global (normally nil).
function ConvLSTM:accGradParametersThroughTime(timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      -- step 1 has no predecessor: use user-provided state or zeros
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
      local inputTable = {self.inputs[step], output, cell}
      -- the last forward step uses the raw gradOutput; earlier steps use the
      -- accumulated one kept in self._gradOutputs
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
      local gradOutputTable = {gradOutput, gradCell}
      recurrentModule:accGradParameters(inputTable, gradOutputTable, scale)
   end

   self.gradParametersAccumulated = true
end

-- Accumulates and applies parameter updates in one pass by stepping backwards
-- through the stored forward steps, starting at `timeStep-1` and covering at
-- most `rho` steps.
--
-- lr       : learning rate; multiplied by each step's stored scale.
-- timeStep : (optional) defaults to `self.step`.
-- rho      : (optional) defaults to `self.rho`, capped at `timeStep-1`.
--
-- Returns nothing. The original ended with `return gradInput`, but no such
-- local exists in this function, so it read an unrelated global (normally nil).
function ConvLSTM:accUpdateGradParametersThroughTime(lr, timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho

   for step=timeStep-1,math.max(stop,1),-1 do
      -- set the output/gradOutput states of current Module
      local recurrentModule = self:getStepModule(step)

      -- backward propagate through this step
      local scale = self.scales[step]
      -- step 1 has no predecessor: use user-provided state or zeros
      local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
      local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
      local inputTable = {self.inputs[step], output, cell}
      -- the last forward step uses the raw gradOutput; earlier steps use the
      -- accumulated one kept in self._gradOutputs
      local gradOutput = (step == self.step-1) and self.gradOutputs[step] or self._gradOutputs[step]
      local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
      -- Fix: pass the selected gradOutput (as accGradParametersThroughTime
      -- does). The original computed it but then passed the raw
      -- self.gradOutputs[step], dropping the gradient accumulated from
      -- later steps.
      local gradOutputTable = {gradOutput, gradCell}
      recurrentModule:accUpdateGradParameters(inputTable, gradOutputTable, lr*scale)
   end
end


-- Fills the gate biases with constant values.
--
-- forgetBias : value written to the forget gate bias (default 1).
-- otherBias  : value written to the input, output and cell gate biases
--              (default 0).
--
-- Only the bias found at `gate.modules[2].modules[1]` is touched; the
-- sibling at `gate.modules[2].modules[2]` is left as it is.
function ConvLSTM:initBias(forgetBias, otherBias)
   local forgetValue = forgetBias or 1
   local otherValue = otherBias or 0

   local function fillGateBias(gate, value)
      gate.modules[2].modules[1].bias:fill(value)
   end

   fillGateBias(self.inputGate, otherValue)
   fillGateBias(self.outputGate, otherValue)
   fillGateBias(self.cellGate, otherValue)
   fillGateBias(self.forgetGate, forgetValue)
end
16 changes: 9 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# ConvLSTM

Source code associated with [Spatio-temporal video autoencoder with differentiable memory](http://arxiv.org/abs/1511.06309), submitted to ICLR2016.
Source code associated with [Spatio-temporal video autoencoder with differentiable memory](http://arxiv.org/abs/1511.06309), to appear in ICLR2016 Workshop track.

This is a demo version to be trained on our modified version of moving MNIST dataset, available [here](http://mi.eng.cam.ac.uk/~vp344/). Some videos obtained on real test sequences are also available [here](http://mi.eng.cam.ac.uk/~vp344/). In case you have issues seeing the content of the webpage when using Ubuntu Chrome, try using Mozilla.
This is a demo version to be trained on a modified version of the moving MNIST dataset, available [here](http://mi.eng.cam.ac.uk/~vp344/). Some videos obtained on real test sequences are also available [here](http://mi.eng.cam.ac.uk/~vp344/) (though they are not up to date).

This code extends the [rnn](https://github.com/Element-Research/rnn) package by providing a spatio-temporal convolutional version of LSTM cells.
The repository also contains a demo of training a simple model using the ConvLSTM module to predict the next frame in a sequence. The difference between this model and the one in the paper is that the former does not explicitly estimate the optical flow to generate the next frame.

To run this demo, you first need to install the [extracunn](https://github.com/viorik/extracunn) package, which contains cuda code for SpatialConvolutionalNoBias layer and Huber gradient computation.

You also need to install the [stn](https://github.com/qassemoquab/stnbhwd) package, and replace the existing BilinearSamplerBHWD.lua with the file provided here.
The ConvLSTM module can be used as is. Optionally, the untied version, implemented in the UntiedConvLSTM class, can be employed. The latter uses a separate model for the first step in the sequence, which has no memory. This can be helpful when training on shorter sequences, to reduce the impact of the first memoryless step on the training.
#### Dependencies

More details soon.
* [rnn](https://github.com/Element-Research/rnn): our code extends [rnn](https://github.com/Element-Research/rnn) by providing a spatio-temporal convolutional version of LSTM cells.
* [extracunn](https://github.com/viorik/extracunn): contains CUDA code for the SpatialConvolutionalNoBias layer and Huber gradient computation.
* [stn](https://github.com/qassemoquab/stnbhwd).



Expand Down
26 changes: 12 additions & 14 deletions data-mnist.lua
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
local data_verbose = false

function getdataSeq_mnist(datafile)
local data = torch.DiskFile(datafile,'r'):readObject()
--local data = torch.load(datafile) -- uncomment this line if dataset in binary format
local data = torch.DiskFile(datafile,'r'):readObject() -- uncomment this line if dataset in ascii format
local datasetSeq ={}
data = data:float()/255.0
-- local std = std or 0.2
local nsamples = data:size(1)
local nseq = data:size(2)
local nrows = data:size(4)
Expand All @@ -14,18 +12,18 @@ function getdataSeq_mnist(datafile)
return nsamples
end

local idx = 1
local shuffle = torch.randperm(nsamples)
function datasetSeq:selectSeq()
local imageok = false
if simdata_verbose then
print('selectSeq')
end
while not imageok do
local i = math.ceil(torch.uniform(1e-12,nsamples))
--image index

local im = data:select(1,i)
return im,i
if idx>nsamples then
shuffle = torch.randperm(nsamples)
idx = 1
print ('data: Shuffle the data')
end
local i = shuffle[idx]
local seq = data:select(1,i)
idx = idx + 1
return seq,i
end

dsample = torch.Tensor(nseq,1,nrows,ncols)
Expand Down
7 changes: 3 additions & 4 deletions model.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
require 'nn'
require 'rnn'
require 'ConvLSTM'
require 'UntiedConvLSTM'
require 'DenseTransformer2D'
require 'SmoothHuberPenalty'
require 'encoder'
Expand All @@ -18,13 +18,12 @@ model:add(seqe)

-- memory branch
local memory_branch = nn.Sequential()
local seq = nn.Sequencer(nn.ConvLSTM(opt.nFiltersMemory[1],opt.nFiltersMemory[2], opt.nSeq, opt.kernelSize, opt.kernelSizeMemory, opt.stride))
--local seq = nn.Sequencer(nn.ConvLSTM(opt.nFiltersMemory[1],opt.nFiltersMemory[2], opt.nSeq, opt.kernelSize, opt.kernelSizeMemory, opt.stride))
local seq = nn.Sequencer(nn.UntiedConvLSTM(opt.nFiltersMemory[1],opt.nFiltersMemory[2], opt.nSeq, opt.kernelSize, opt.kernelSizeMemory, opt.stride))
seq:remember('both')
seq:training()
memory_branch:add(seq)
memory_branch:add(nn.SelectTable(opt.nSeq))
--memory_branch:add(nn.SelectTable(opt.nSeq))
--memory_branch:add(nn.L1Penalty(opt.constrWeight[2]))
memory_branch:add(flow)

-- keep last frame to apply optical flow on
Expand Down
5 changes: 3 additions & 2 deletions opts-mnist.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ opt.memorySizeW = 32
opt.memorySizeH = 32

opt.dataFile = 'dataset_fly_64x64_lines_train.t7'
opt.dataFileTest = 'dataset_fly_64x64_lines_test.t7'
opt.statInterval = 50 -- interval for printing error
opt.v = false -- be verbose
opt.display = true -- display stuff
opt.displayInterval = opt.statInterval*10
opt.displayInterval = opt.statInterval
opt.save = true -- save models

opt.saveInterval = 10000

if not paths.dirp(opt.dir) then
os.execute('mkdir -p ' .. opt.dir)
Expand Down

0 comments on commit 40b8241

Please sign in to comment.