Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
anewell committed Mar 17, 2016
1 parent 6b3b867 commit d105c8e
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.t7
2 changes: 2 additions & 0 deletions main.lua
@@ -0,0 +1,2 @@
-- Load pre-trained model

36 changes: 36 additions & 0 deletions residual.lua
@@ -0,0 +1,36 @@
local conv = nnlib.SpatialConvolution
local batchnorm = nn.SpatialBatchNormalization
local relu = nnlib.ReLU

-- Main convolutional block
local function convBlock(numIn,numOut)
return nn.Sequential()
:add(conv(numIn,numOut/2,1,1))
:add(batchnorm(numOut/2))
:add(relu(true))
:add(conv(numOut/2,numOut/2,3,3,1,1,1,1))
:add(batchnorm(numOut/2))
:add(relu(true))
:add(conv(numOut/2,numOut,1,1))
:add(batchnorm(numOut))
end

-- Skip layer
local function skipLayer(numIn,numOut)
if numIn == numOut then
return nn.Identity()
else
return nn.Sequential()
:add(conv(numIn,numOut,1,1))
:add(batchnorm(numOut))
end
end

-- Residual block
function Residual(numIn,numOut)
return nn.Sequential()
:add(nn.ConcatTable()
:add(convBlock(numIn,numOut))
:add(skipLayer(numIn,numOut)))
:add(nn.CAddTable(true))
end
77 changes: 77 additions & 0 deletions stacked-hourglass-model.lua
@@ -0,0 +1,77 @@
paths.dofile('residual.lua')

local function hourglass(n, numIn, numOut, inp)
-- Upper branch
local up1 = Residual(numIn,256)(inp)
local up2 = Residual(256,256)(up1)
local up4 = Residual(256,numOut)(up2)

-- Lower branch
local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp)
local low1 = Residual(numIn,256)(pool)
local low2 = Residual(256,256)(low1)
local low5 = Residual(256,256)(low2)
local low6
if n > 1 then
low6 = hourglass(n-1,256,numOut,low5)
else
low6 = Residual(256,numOut)(low5)
end
local low7 = Residual(numOut,numOut)(low6)
local up5 = nn.SpatialUpSamplingNearest(2)(low7)

-- Bring two branches together
return nn.CAddTable()({up4,up5})
end

local function lin(numIn,numOut,inp)
-- Apply 1x1 convolution, no stride, no padding
local l_ = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp)
return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l_))
end

function createModel()

local inp = nn.Identity()()

-- Initial processing of the image
local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128
local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_))
local r1 = Residual(64,128)(cnv1)
local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64
local r4 = Residual(128,128)(pool)
local r5 = Residual(128,128)(r4)
local r6 = Residual(128,256)(r5)

-- First hourglass
local hg1 = hourglass(4,256,512,r6)

-- Linear layers to produce first set of predictions
local l1 = lin(512,512,hg1)
local l2 = lin(512,256,l1)

-- First predicted heatmaps
local out1 = nnlib.SpatialConvolution(256,outputDim[1][1],1,1,1,1,0,0)(l2)
local out1_ = nnlib.SpatialConvolution(outputDim[1][1],256+128,1,1,1,1,0,0)(out1)

-- Concatenate with previous linear features
local cat1 = nn.JoinTable(2)({l2,pool})
local cat1_ = nnlib.SpatialConvolution(256+128,256+128,1,1,1,1,0,0)(cat1)
local int1 = nn.CAddTable()({cat1_,out1_})

-- Second hourglass
local hg2 = hourglass(4,256+128,512,int1)

-- Linear layers to produce predictions again
local l3 = lin(512,512,hg2)
local l4 = lin(512,512,l3)

-- Output heatmaps
local out2 = nnlib.SpatialConvolution(512,outputDim[2][1],1,1,1,1,0,0)(l4)

-- Final model
local model = nn.gModule({inp}, {out1,out2})

return model

end

0 comments on commit d105c8e

Please sign in to comment.