Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added convnet template

  • Loading branch information...
commit 7b98aee1127830f2d73679b69e91cb6d6f59c01d 1 parent dc8b803
@clementfarabet clementfarabet authored
Showing with 51 additions and 0 deletions.
  1. +20 −0 g.lua
  2. +31 −0 test.lua
View
20 g.lua
@@ -451,6 +451,26 @@ function g.MultiLayerPerceptron(sizes, input)
return groupNodes(layers, last)
end
+-- A standard ConvNet
+function g.ConvNet(nfeatures, fanins, filters, poolings, input)
+ local layers = {}
+ local last = input
+ for i=2,#nfeatures do
+ local c = nn.SpatialConvolution(nfeatures[i-1], nfeatures[i], filters[i-1], filters[i-1]){last}
+ table.insert(layers, c)
+ if i ~= #nfeatures then
+ local s = nn.Tanh(){c.output}
+ local p = nn.SpatialMaxPooling(poolings[i-1], poolings[i-1]){s.output}
+ table.insert(layers, s)
+ table.insert(layers, p)
+ last = p.output
+ else
+ last = c.output
+ end
+ end
+ return groupNodes(layers, last)
+end
+
-- An Elman network has three fully connected layers (in, hidden, out),
-- with the activations of the hidden layer feeding back into the
-- input, with a time-delay.
View
31 test.lua
@@ -186,6 +186,37 @@ function tests.backward()
print("gradients", firstlinear.twin.gradParameters.read()[1])
end
+-- Testing a backward pass
+function tests.convnet()
+ -- define convnet
+ input = g.DataNode()
+ features = {3, 8, 16, 32}
+ fanins = {1, 4, 16}
+ filters = {7, 7, 7}
+ poolings = {2, 2}
+ convnet = g.ConvNet(features, fanins, filters, poolings, input)
+
+ -- and a linear classifier for a 4-class problem
+ reshaper = nn.Reshape(32){convnet.output}
+ classifier = nn.Linear(32, 4){reshaper.output}
+
+ -- loss
+ target = g.DataNode()
+ logsoftmax = nn.LogSoftMax(){classifier.output}
+ loss = nn.ClassNLLCriterion(){logsoftmax.output, target}
+
+ -- random input: a 3-channel 46x46 image
+ input.write(lab.randn(3, 46, 46))
+
+ -- let's do a forward to check that the network works
+ print("forward", logsoftmax.output.read())
+
+ -- evaluate the loss
+ target.write(3)
+ print("target:", target.read())
+ print("loss:", loss.output.read())
+end
+
-- Testing criterion
function tests.criterion()
input = g.DataNode()
Please sign in to comment.
Something went wrong with that request. Please try again.