diff --git a/CMul.lua b/CMul.lua index ea3485a86..6787410fd 100644 --- a/CMul.lua +++ b/CMul.lua @@ -114,7 +114,7 @@ function CMul:accGradParameters(input, gradOutput, scale) end end -function CMul:type(type) +function CMul:type(type, tensorCache) if type then self._input = nil self._output = nil @@ -124,5 +124,5 @@ function CMul:type(type) self._repeat = nil self._sum = nil end - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/Copy.lua b/Copy.lua index 8ef09dcb1..d87c2b272 100644 --- a/Copy.lua +++ b/Copy.lua @@ -34,9 +34,9 @@ function Copy:updateGradInput(input, gradOutput) return self.gradInput end -function Copy:type(type) +function Copy:type(type, tensorCache) if type and self.dontCast then return self end - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/Criterion.lua b/Criterion.lua index f6e0d82ab..4efb279c2 100644 --- a/Criterion.lua +++ b/Criterion.lua @@ -28,11 +28,11 @@ function Criterion:clone() return clone end -function Criterion:type(type) +function Criterion:type(type, tensorCache) assert(type, 'Criterion: must provide a type to convert to') -- find all tensors and convert them for key,param in pairs(self) do - self[key] = nn.utils.recursiveType(param, type) + self[key] = nn.utils.recursiveType(param, type, tensorCache) end return self end diff --git a/Euclidean.lua b/Euclidean.lua index ae3fee907..25aa2e484 100644 --- a/Euclidean.lua +++ b/Euclidean.lua @@ -171,7 +171,7 @@ function Euclidean:accGradParameters(input, gradOutput, scale) end end -function Euclidean:type(type) +function Euclidean:type(type, tensorCache) if type then -- prevent premature memory allocations self._input = nil @@ -186,5 +186,5 @@ function Euclidean:type(type) self._repeat = nil self._repeat2 = nil end - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/FlattenTable.lua b/FlattenTable.lua index 3a88588cd..69abd0762 100644 --- a/FlattenTable.lua +++ b/FlattenTable.lua @@ -94,7 +94,7 @@ function FlattenTable:updateGradInput(input, gradOutput) return self.gradInput end -function FlattenTable:type(type) +function FlattenTable:type(type, tensorCache) -- This function just stores references so we don't need to do any type -- conversions. Just force the tables to be empty. 
self.output = {} diff --git a/JoinTable.lua b/JoinTable.lua index c143bd451..c787a3b3e 100644 --- a/JoinTable.lua +++ b/JoinTable.lua @@ -64,7 +64,7 @@ function JoinTable:updateGradInput(input, gradOutput) return self.gradInput end -function JoinTable:type(type) +function JoinTable:type(type, tensorCache) self.gradInput = {} - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/LookupTable.lua b/LookupTable.lua index b7cf49040..fa6920813 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -71,8 +71,8 @@ function LookupTable:accGradParameters(input, gradOutput, scale) self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale) end -function LookupTable:type(type) - parent.type(self, type) +function LookupTable:type(type, tensorCache) + parent.type(self, type, tensorCache) if type == 'torch.CudaTensor' then -- CUDA uses _sorted and _indices temporary tensors diff --git a/MixtureTable.lua b/MixtureTable.lua index 16e583263..16dafc4c1 100644 --- a/MixtureTable.lua +++ b/MixtureTable.lua @@ -149,7 +149,7 @@ function MixtureTable:updateGradInput(input, gradOutput) return self.gradInput end -function MixtureTable:type(type) +function MixtureTable:type(type, tensorCache) self._gaterView = nil self._expert = nil self._expertView = nil @@ -157,5 +157,5 @@ function MixtureTable:type(type) self._gradInput = nil self._expert2 = nil self._expertView2 = nil - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/Module.lua b/Module.lua index 4c8c21cd0..5bf771a3e 100644 --- a/Module.lua +++ b/Module.lua @@ -116,10 +116,11 @@ end function Module:type(type, tensorCache) assert(type, 'Module: must provide a type to convert to') + tensorCache = tensorCache or {} + -- find all tensors and convert them for key,param in pairs(self) do - self[key] = nn.utils.recursiveType(param, type) - + self[key] = nn.utils.recursiveType(param, type, tensorCache) end return self @@ -281,6 +282,19 @@ function Module:__call__(input, gradOutput) end end +-- Run a callback (called with the module as an argument) in preorder over this +-- module and its children. 
+-- +function Module:apply(callback) + callback(self) + + if self.modules then + for _, module in ipairs(self.modules) do + module:apply(callback) + end + end +end + function Module:findModules(typename, container) container = container or self local nodes = {} diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua index 84d4ee19f..30064a2a2 100644 --- a/ParallelCriterion.lua +++ b/ParallelCriterion.lua @@ -34,7 +34,7 @@ function ParallelCriterion:updateGradInput(input, target) return self.gradInput end -function ParallelCriterion:type(type) +function ParallelCriterion:type(type, tensorCache) self.gradInput = {} - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/SelectTable.lua b/SelectTable.lua index 12dc71868..9fd634877 100644 --- a/SelectTable.lua +++ b/SelectTable.lua @@ -51,8 +51,8 @@ function SelectTable:updateGradInput(input, gradOutput) return self.gradInput end -function SelectTable:type(type) +function SelectTable:type(type, tensorCache) self.gradInput = {} self.output = {} - return parent.type(self, type) + return parent.type(self, type, tensorCache) end diff --git a/WeightedEuclidean.lua b/WeightedEuclidean.lua index 8acd35147..606510cd2 100644 --- a/WeightedEuclidean.lua +++ b/WeightedEuclidean.lua @@ -209,7 +209,7 @@ function WeightedEuclidean:accGradParameters(input, gradOutput, scale) end end -function WeightedEuclidean:type(type) +function WeightedEuclidean:type(type, tensorCache) if type then -- prevent premature memory allocations self._input = nil @@ -226,7 +226,7 @@ function WeightedEuclidean:type(type) self._repeat2 = nil self._repeat3 = nil end - return parent.type(self, type) + return parent.type(self, type, tensorCache) end function WeightedEuclidean:parameters() diff --git a/doc/module.md b/doc/module.md index 97e14a07c..1c1d40839 100755 --- a/doc/module.md +++ b/doc/module.md @@ -150,8 +150,6 @@ Note that this function if called on a [Container](containers.md#nn.Containers) module will share the same parameters for all the contained modules as well. -**NOTE: If you ever type-cast your network to another precision, i.e. net:cuda() for example, the sharing gets untied, and you have to reshare your modules again.** - Example: ```lua @@ -186,8 +184,6 @@ If arguments are provided to the `clone(...)` function it also calls module after creating it, hence making a deep copy of this module with some shared parameters. -**NOTE: If you ever type-cast your network to another precision, i.e. net:cuda() for example, the sharing gets untied, and you have to reshare your modules again.** - Example: ```lua -- make an mlp @@ -206,12 +202,35 @@ print(mlp2:get(1).bias[1]) ``` -### type(type) ### +### type(type[, tensorCache]) ### This function converts all the parameters of a module to the given `type`. The `type` can be one of the types defined for [torch.Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md). +If tensors (or their storages) are shared between multiple modules in a +network, this sharing will be preserved after type is called. 
+ +To preserve sharing between multiple modules and/or tensors, use +`nn.utils.recursiveType`: + +```lua +-- make an mlp +mlp1=nn.Sequential(); +mlp1:add(nn.Linear(100,10)); + +-- make a second mlp +mlp2=nn.Sequential(); +mlp2:add(nn.Linear(100,10)); + +-- the second mlp shares the bias of the first +mlp2:share(mlp1,'bias'); + +-- mlp1 and mlp2 will be converted to float, and will share bias +-- note: tensors can be provided as inputs as well as modules +nn.utils.recursiveType({mlp1, mlp2}, 'torch.FloatTensor') +``` + ### float() ### diff --git a/test.lua b/test.lua index 3fff1513d..25be158be 100644 --- a/test.lua +++ b/test.lua @@ -39,6 +39,7 @@ for test_name, component in pairs(tostringTestModules) do end end + function nntest.Add() local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case local ini = math.random(3,5) @@ -310,7 +311,7 @@ function nntest.PReLU() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + 'error on weight [%s]', t)) end -- 2D @@ -328,7 +329,7 @@ function nntest.PReLU() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + 'error on weight [%s]', t)) end -- 4D @@ -347,7 +348,7 @@ function nntest.PReLU() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + 'error on weight [%s]', t)) end -- IO @@ -1067,7 +1068,7 @@ function nntest.LogSoftmax() local module = nn.LogSoftMax() local err = jac.testJacobian(module,input) - mytester:assertlt(err,1e-3, 'error on state ') + mytester:assertlt(err, 1e-3, 'error on state ') local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') @@ -2999,7 +3000,7 @@ function nntest.AddConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') - + -- inplace comparisons local ini = math.random(3,5) local inj = math.random(3,5) @@ -3024,7 +3025,7 @@ function nntest.AddConstant() local gradInput1 = module1:backward(input1, gradOutput1) local gradInput2 = module2:backward(input2, gradOutput2) - mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), torch.typename(module1) .. ' - in-place backward err ') local input1 = torch.rand(ink, inj, ini) @@ -3034,7 +3035,7 @@ function nntest.AddConstant() module1:backward(module1.output,torch.rand(input1:size())) local err = (input1-input2):abs():max() - mytester:asserteq(err, 0, torch.typename(module1) .. + mytester:asserteq(err, 0, torch.typename(module1) .. ' - inplace input change err ') end @@ -3056,7 +3057,7 @@ function nntest.MulConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') - + -- inplace comparisons local ini = math.random(3,5) local inj = math.random(3,5) @@ -3075,13 +3076,13 @@ function nntest.MulConstant() local out1 = module1:forward(input1) local out2 = module2:forward(input2) - mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. + mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. 
' - in-place forward err ') local gradInput1 = module1:backward(input1, gradOutput1) local gradInput2 = module2:backward(input2, gradOutput2) - - mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), torch.typename(module1) .. ' - in-place backward err ') local input1 = torch.rand(ink, inj, ini) @@ -3091,7 +3092,7 @@ function nntest.MulConstant() module1:backward(module1.output,torch.rand(input1:size())) local err = (input1-input2):abs():max() - mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) .. + mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) .. ' - inplace input change err ') end @@ -4154,6 +4155,126 @@ function nntest.addSingletonDimension() "invalid dimension not detected") end +function nntest.Typecast() + local function make_network() + local seq = nn.Sequential() + seq:add(nn.Linear(15, 10)) + seq:add(nn.Linear(15, 10)) + seq.modules[1].bias:fill(1) + seq.modules[2].bias:fill(2) + return seq + end + + -- make sure that the typecasts aren't nops + assert(torch.getdefaulttensortype() == 'torch.DoubleTensor') + + -- basic net + local net = make_network() + net.modules[1].empty_tensor = torch.Tensor() + net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor', + net.modules[1].bias:type()) + assert(net.modules[1].empty_tensor:type() == 'torch.FloatTensor') + assert(net.modules[1].bias ~= net.modules[2].bias) + net.modules[1].bias:fill(3) + assert(net.modules[1].bias[1] == 3) + assert(net.modules[2].bias[1] == 2) + + -- shared tensors remain shared + local net = make_network() + net.modules[2].bias = net.modules[1].bias + net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor') + assert(net.modules[1].bias == net.modules[2].bias) + assert(net.modules[1].bias[1] == 1) + + -- shared storages remain shared + local net = make_network() + net.modules[2].bias:set(net.modules[1].bias) + local net = net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor') + assert(net.modules[1].bias ~= net.modules[2].bias) + net.modules[1].bias:fill(3) + assert(net.modules[1].bias[1] == 3) + assert(net.modules[2].bias[1] == 3) + + -- tricky: overlapping views on the same storage are preserved + local net = make_network() + local overlap_storage = torch.Tensor(15):fill(1) + net.modules[1].bias = overlap_storage:narrow(1, 1, 10) + net.modules[2].bias = overlap_storage:narrow(1, 6, 10) + net:float() + assert(net.modules[1].bias:type() == 'torch.FloatTensor') + assert(net.modules[1].bias ~= net.modules[2].bias) + net.modules[1].bias:fill(3) + assert(net.modules[1].bias[1] == 3) + assert(net.modules[2].bias[1] == 3) + assert(net.modules[2].bias[6] == 1) -- only the first 5 elements overlapped + + -- check recursiveType on a table + local net1 = make_network() + local net2 = make_network() + net2.modules[1].bias:set(net1.modules[1].bias) + net1:float() + net2:float() + net1.modules[1].bias:fill(3) + assert(net2.modules[1].bias[1] == 1) + + local net1 = make_network() + local net2 = make_network() + net2.modules[1].bias:set(net1.modules[1].bias) + + local tensorCache = {} + net1:type('torch.FloatTensor', tensorCache) + net2:type('torch.FloatTensor', tensorCache) + net1.modules[1].bias:fill(3) + assert(net2.modules[1].bias[1] == 3) + + local net1 = make_network() + local net2 = make_network() + net2.modules[1].bias:set(net1.modules[1].bias) + + nn.utils.recursiveType({net1, net2}, 'torch.FloatTensor') + net1.modules[1].bias:fill(3) + assert(net2.modules[1].bias[1] == 3) + + -- 
smoke test some modules with custom type methods + local custom_type_modules = { + nn.MixtureTable(3), + nn.ConcatTable(), + nn.Copy(), + nn.Copy(nil, nil, nil, true), + nn.SpatialContrastiveNormalization(), + nn.DotProduct(), + nn.PairwiseDistance(1), + nn.SpatialDivisiveNormalization(), + nn.SpatialSubtractiveNormalization() + } + for _, module in ipairs(custom_type_modules) do + module:float() + end +end + +function nntest.Module_apply() + local s = nn.Sequential() + s:add(nn.Linear(10,10)) + local s2 = nn.Sequential() + s2:add(nn.Linear(10,5)) + s:add(s2) + s:add(nn.Tanh()) + + local seen = 0 + s:apply(function(module) + if torch.type(module) == 'nn.Linear' then + module.bias:resize(20) + seen = seen + 1 + end + end) + mytester:asserteq(seen, 2) + mytester:asserteq(s.modules[1].bias:size(1), 20) + mytester:asserteq(s2.modules[1].bias:size(1), 20) +end + mytester:add(nntest) if not nn then diff --git a/utils.lua b/utils.lua index 40cc37737..375137637 100644 --- a/utils.lua +++ b/utils.lua @@ -1,15 +1,75 @@ nn.utils = {} -function nn.utils.recursiveType(param, type_str) +-- oops; someone forgot to add torch.Storage.type +-- TODO replace with torch.Storage.type when implemented +local function torch_Storage_type(self, type) + local current = torch.typename(self) + if not type then return current end + if type ~= current then + local new = torch.getmetatable(type).new() + if self:size() > 0 then + new:resize(self:size()):copy(self) + end + return new + else + return self + end +end + +-- tensorCache maintains a list of all tensors and storages that have been +-- converted (recursively) by calls to recursiveType() and type(). +-- It caches conversions in order to preserve sharing semantics +-- i.e. if two tensors share a common storage, then type conversion +-- should preserve that. +-- +-- You can preserve sharing semantics across multiple networks by +-- passing tensorCache between the calls to type, e.g. +-- +-- > tensorCache = {} +-- > net1:type('torch.CudaTensor', tensorCache) +-- > net2:type('torch.CudaTensor', tensorCache) +-- > nn.utils.recursiveType(anotherTensor, 'torch.CudaTensor', tensorCache) +-- +-- Implementation note: to make Lua table lookup behave correctly, +-- tensor keys are stored as actual tensor objects, while storage +-- keys are stored as the pointers themselves (as numbers). 
+function nn.utils.recursiveType(param, type, tensorCache) + tensorCache = tensorCache or {} + if torch.type(param) == 'table' then for k, v in pairs(param) do - param[k] = nn.utils.recursiveType(v, type_str) + param[k] = nn.utils.recursiveType(v, type, tensorCache) end elseif torch.isTypeOf(param, 'nn.Module') or torch.isTypeOf(param, 'nn.Criterion') then - param:type(type_str) + param:type(type, tensorCache) elseif torch.isTensor(param) then - param = param:type(type_str) + if torch.typename(param) ~= type then + local newparam + if tensorCache[param] then + newparam = tensorCache[param] + else + newparam = torch.Tensor():type(type) + local storageType = type:gsub('Tensor','Storage') + if param:storage() then + local storage_key = torch.pointer(param:storage()) + if not tensorCache[storage_key] then + tensorCache[storage_key] = torch_Storage_type( + param:storage(), storageType) + end + assert(torch.type(tensorCache[storage_key]) == storageType) + newparam:set( + tensorCache[storage_key], + param:storageOffset(), + param:size(), + param:stride() + ) + tensorCache[param] = newparam + end + end + assert(torch.type(newparam) == type) + param = newparam + end end return param end @@ -90,5 +150,4 @@ function nn.utils.addSingletonDimension(t, dim) return view end - table.unpack = table.unpack or unpack
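
Usage sketch (illustrative, not part of the diff): the snippet below shows the sharing-preserving conversion introduced in `utils.lua` and `Module:type`, following the same pattern as the `nntest.Typecast` test added above. The network layout (`nn.Sequential` with a single `nn.Linear`) is assumed only for illustration.

```lua
require 'nn'

-- two networks whose first-layer biases share a storage
local net1 = nn.Sequential():add(nn.Linear(10, 5))
local net2 = nn.Sequential():add(nn.Linear(10, 5))
net2.modules[1].bias:set(net1.modules[1].bias)

-- convert both networks in one call; the shared storage is converted
-- exactly once and stays shared afterwards
nn.utils.recursiveType({net1, net2}, 'torch.FloatTensor')

-- equivalent: thread one tensorCache through separate :type() calls
--   local cache = {}
--   net1:type('torch.FloatTensor', cache)
--   net2:type('torch.FloatTensor', cache)

net1.modules[1].bias:fill(3)
assert(net2.modules[1].bias[1] == 3) -- sharing survived the typecast
```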
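
And a minimal sketch of the new `Module:apply`, mirroring `nntest.Module_apply` above: the callback sees the module itself first and then every child reachable through `self.modules`, so it covers modules nested inside containers. Layer sizes here are illustrative.

```lua
require 'nn'

local net = nn.Sequential()
net:add(nn.Linear(10, 10))
net:add(nn.Sequential():add(nn.Linear(10, 5))) -- nested container
net:add(nn.Tanh())

-- count every nn.Linear in the hierarchy, however deeply nested
local nLinear = 0
net:apply(function(module)
   if torch.type(module) == 'nn.Linear' then
      nLinear = nLinear + 1
   end
end)
assert(nLinear == 2)
```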