Merge branch 'getParameters'

clementfarabet committed Sep 7, 2012
2 parents 4282763 + 4697b5a commit b2168b91d15f293a357dcc05bba72603806ab875
Showing with 130 additions and 58 deletions.
  1. +33 −56 extra/nn/Module.lua
  2. +97 −2 extra/nn/test/test.lua
extra/nn/Module.lua
@@ -139,70 +139,47 @@ function Module:getParameters()
    -- get parameters
    local parameters,gradParameters = self:parameters()
+   local function storageInSet(set, storage) --this is waste of time (need correct hash)
+      for key, val in pairs(set) do
+         if key == storage then
+            return val
+         end
+      end
+   end
+
    -- this function flattens arbitrary lists of parameters,
    -- even complex shared ones
    local function flatten(parameters)
-      -- already flat ?
-      local flat = true
-      for k = 2,#parameters do
-         if parameters[k]:storage() ~= parameters[k-1]:storage() then
-            flat = false
-            break
+      local storages = {}
+      local nParameters = 0
+      for k = 1,#parameters do
+         if not storageInSet(storages, parameters[k]:storage()) then
+            storages[parameters[k]:storage()] = nParameters
+            nParameters = nParameters + parameters[k]:storage():size()
          end
       end
-      if flat then
-         local nParameters = 0
-         for k,param in ipairs(parameters) do
-            nParameters = nParameters + param:nElement()
-         end
-         local flatParameters = parameters[1].new(parameters[1]:storage())
-         if nParameters ~= flatParameters:nElement() then
-            error('flattenParameters(): weird parameters')
-         end
-         return flatParameters
+
+      local flatParameters = torch.Tensor(nParameters):fill(1)
+      local flatStorage = flatParameters:storage()
+
+      for k = 1,#parameters do
+         local storageOffset = storageInSet(storages, parameters[k]:storage())
+         parameters[k]:set(flatStorage,
+                           storageOffset + parameters[k]:storageOffset(),
+                           parameters[k]:size(),
+                           parameters[k]:stride())
+         parameters[k]:zero()
       end
-      -- compute offsets of each parameter
-      local offsets = {}
-      local sizes = {}
-      local strides = {}
-      local elements = {}
-      local storageOffsets = {}
-      local params = {}
-      local nParameters = 0
-      for k,param in ipairs(parameters) do
-         table.insert(offsets, nParameters+1)
-         table.insert(sizes, param:size())
-         table.insert(strides, param:stride())
-         table.insert(elements, param:nElement())
-         table.insert(storageOffsets, param:storageOffset())
-         local isView = false
-         for i = 1,k-1 do
-            if param:storage() == parameters[i]:storage() then
-               offsets[k] = offsets[i]
-               if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
-                  error('flattenParameters(): cannot flatten shared weights with different structures')
-               end
-               isView = true
-               break
-            end
-         end
-         if not isView then
-            nParameters = nParameters + param:nElement()
-         end
+      if (flatParameters:sum() ~= 0) then
+         print("<getParameters()> WARNING: found "
+               .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
+               .. flatParameters:sum() .. " storage elements that are unused, this "
+               .. "might be an issue for your optimization procedure)")
       end
-      -- create flat vector
-      local flatParameters = parameters[1].new(nParameters)
-      local storage = flatParameters:storage()
-      -- reallocate all parameters in flat vector
-      for i = 1,#parameters do
-         local data = parameters[i]:clone()
-         parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i],strides[i]):copy(data)
-         data = nil
-         collectgarbage()
+
+      for k, v in pairs(storages) do
+         flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
       end
-      -- cleanup
-      collectgarbage()
-      -- return flat param
       return flatParameters
    end
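Note (not part of the diff): the sketch below illustrates how the flat vector returned by the new getParameters() is typically consumed in a training loop. The two-return form (flat parameters plus flat gradients), the toy network, the criterion and the learning rate are illustrative assumptions, not something this commit defines.

require 'nn'

-- illustrative network; any container is flattened the same way
local net = nn.Sequential()
net:add(nn.Linear(10, 10))
net:add(nn.Tanh())
net:add(nn.Linear(10, 1))

-- assumed API: flat parameter and flat gradient vectors; after the call,
-- the module weight/bias tensors are views into these vectors
local params, gradParams = net:getParameters()

local criterion = nn.MSECriterion()
local input, target = torch.randn(10), torch.randn(1)

-- one plain SGD step operating directly on the flat vectors
net:zeroGradParameters()
local output = net:forward(input)
local loss = criterion:forward(output, target)
net:backward(input, criterion:backward(output, target))
params:add(-0.01, gradParams)   -- params <- params - lr * gradParams
print(loss)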
extra/nn/test/test.lua
@@ -1078,10 +1078,105 @@ function nntest.VolumetricConvolution()
    mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
 end
+function nntest.Module_getParameters_1()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   local p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'getParameters(): weights wrong')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'getParameters(): bias wrong')
+end
+
+function nntest.Module_getParameters_2()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   local p = n:getParameters()
+
+   n:add( nn.Linear(10,10) )
+   p = n:getParameters()
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when appending new module')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when appending new module')
+end
+
+function nntest.Module_getParameters_3()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone() )
+   local p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+
+   n:reset()
+
+   mytester:assertgt((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:assertgt((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+end
+
+function nntest.Module_getParameters_4()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone() )
+   local p = n:getParameters()
+
+   n:add(nn.Linear(10,10))
+   p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {221,320} }] - n.modules[3].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {321,330} }] - n.modules[3].bias):norm(), 0, 'error when using cloning')
+end
+
+function nntest.Module_getParameters_5()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone('weight','bias') )
+   local p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+
+   n:reset()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+end
+
+function nntest.Module_getParameters_6()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone('weight','bias') )
+   local p = n:getParameters()
+
+   n:add(nn.Linear(10,10))
+   p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[3].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[3].bias):norm(), 0, 'error when using cloning+sharing')
+end
 mytester:add(nntest)
---mytester:add(test_SpatialConvolution)
---mytester:add(test_AbsCriterion)
 if not nn then
    require 'nn'
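Side note (not part of the commit): the sharing behaviour exercised by Module_getParameters_5 and _6 can be checked by hand. A minimal sketch, assuming an interactive torch session with nn available:

require 'nn'

local n = nn.Sequential()
n:add(nn.Linear(10, 10))
n:add(n.modules[1]:clone('weight', 'bias'))   -- second layer shares weight/bias storage with the first

local p = n:getParameters()
print(p:nElement())   -- 110: the shared weight and bias are stored only once

-- writing through the flat vector is visible through both shared modules
p:fill(1)
print((n.modules[1].weight - n.modules[2].weight):norm())   -- 0
print(n.modules[1].weight[1][1])                            -- 1

Because both clones end up as views into the same region of the flat storage, a write through p (or through either module) is immediately visible in the other, which is what the asserteq calls in Module_getParameters_5 verify.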
