diff --git a/CMul.lua b/CMul.lua
index ea3485a86..6787410fd 100644
--- a/CMul.lua
+++ b/CMul.lua
@@ -114,7 +114,7 @@ function CMul:accGradParameters(input, gradOutput, scale)
end
end
-function CMul:type(type)
+function CMul:type(type, tensorCache)
if type then
self._input = nil
self._output = nil
@@ -124,5 +124,5 @@ function CMul:type(type)
self._repeat = nil
self._sum = nil
end
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/Copy.lua b/Copy.lua
index 8ef09dcb1..d87c2b272 100644
--- a/Copy.lua
+++ b/Copy.lua
@@ -34,9 +34,9 @@ function Copy:updateGradInput(input, gradOutput)
return self.gradInput
end
-function Copy:type(type)
+function Copy:type(type, tensorCache)
if type and self.dontCast then
return self
end
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/Criterion.lua b/Criterion.lua
index f6e0d82ab..4efb279c2 100644
--- a/Criterion.lua
+++ b/Criterion.lua
@@ -28,11 +28,11 @@ function Criterion:clone()
return clone
end
-function Criterion:type(type)
+function Criterion:type(type, tensorCache)
assert(type, 'Criterion: must provide a type to convert to')
-- find all tensors and convert them
for key,param in pairs(self) do
- self[key] = nn.utils.recursiveType(param, type)
+ self[key] = nn.utils.recursiveType(param, type, tensorCache)
end
return self
end
diff --git a/Euclidean.lua b/Euclidean.lua
index ae3fee907..25aa2e484 100644
--- a/Euclidean.lua
+++ b/Euclidean.lua
@@ -171,7 +171,7 @@ function Euclidean:accGradParameters(input, gradOutput, scale)
end
end
-function Euclidean:type(type)
+function Euclidean:type(type, tensorCache)
if type then
-- prevent premature memory allocations
self._input = nil
@@ -186,5 +186,5 @@ function Euclidean:type(type)
self._repeat = nil
self._repeat2 = nil
end
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/FlattenTable.lua b/FlattenTable.lua
index 3a88588cd..69abd0762 100644
--- a/FlattenTable.lua
+++ b/FlattenTable.lua
@@ -94,7 +94,7 @@ function FlattenTable:updateGradInput(input, gradOutput)
return self.gradInput
end
-function FlattenTable:type(type)
+function FlattenTable:type(type, tensorCache)
-- This function just stores references so we don't need to do any type
-- conversions. Just force the tables to be empty.
self.output = {}
diff --git a/JoinTable.lua b/JoinTable.lua
index c143bd451..c787a3b3e 100644
--- a/JoinTable.lua
+++ b/JoinTable.lua
@@ -64,7 +64,7 @@ function JoinTable:updateGradInput(input, gradOutput)
return self.gradInput
end
-function JoinTable:type(type)
+function JoinTable:type(type, tensorCache)
self.gradInput = {}
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/LookupTable.lua b/LookupTable.lua
index b7cf49040..fa6920813 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -71,8 +71,8 @@ function LookupTable:accGradParameters(input, gradOutput, scale)
self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
end
-function LookupTable:type(type)
- parent.type(self, type)
+function LookupTable:type(type, tensorCache)
+ parent.type(self, type, tensorCache)
if type == 'torch.CudaTensor' then
-- CUDA uses _sorted and _indices temporary tensors
diff --git a/MixtureTable.lua b/MixtureTable.lua
index 16e583263..16dafc4c1 100644
--- a/MixtureTable.lua
+++ b/MixtureTable.lua
@@ -149,7 +149,7 @@ function MixtureTable:updateGradInput(input, gradOutput)
return self.gradInput
end
-function MixtureTable:type(type)
+function MixtureTable:type(type, tensorCache)
self._gaterView = nil
self._expert = nil
self._expertView = nil
@@ -157,5 +157,5 @@ function MixtureTable:type(type)
self._gradInput = nil
self._expert2 = nil
self._expertView2 = nil
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/Module.lua b/Module.lua
index 4c8c21cd0..5bf771a3e 100644
--- a/Module.lua
+++ b/Module.lua
@@ -116,10 +116,11 @@ end
function Module:type(type, tensorCache)
assert(type, 'Module: must provide a type to convert to')
+ tensorCache = tensorCache or {}
+
-- find all tensors and convert them
for key,param in pairs(self) do
- self[key] = nn.utils.recursiveType(param, type)
-
+ self[key] = nn.utils.recursiveType(param, type, tensorCache)
end
return self
@@ -281,6 +282,19 @@ function Module:__call__(input, gradOutput)
end
end
+-- Run a callback (called with the module as an argument) in preorder over this
+-- module and its children.
+--
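+-- For example, to visit every submodule in a network (a minimal usage
+-- sketch; `model` stands for any module or container):
+--
+--   model:apply(function(m) print(torch.type(m)) end)
+--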
+function Module:apply(callback)
+ callback(self)
+
+ if self.modules then
+ for _, module in ipairs(self.modules) do
+ module:apply(callback)
+ end
+ end
+end
+
function Module:findModules(typename, container)
container = container or self
local nodes = {}
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua
index 84d4ee19f..30064a2a2 100644
--- a/ParallelCriterion.lua
+++ b/ParallelCriterion.lua
@@ -34,7 +34,7 @@ function ParallelCriterion:updateGradInput(input, target)
return self.gradInput
end
-function ParallelCriterion:type(type)
+function ParallelCriterion:type(type, tensorCache)
self.gradInput = {}
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/SelectTable.lua b/SelectTable.lua
index 12dc71868..9fd634877 100644
--- a/SelectTable.lua
+++ b/SelectTable.lua
@@ -51,8 +51,8 @@ function SelectTable:updateGradInput(input, gradOutput)
return self.gradInput
end
-function SelectTable:type(type)
+function SelectTable:type(type, tensorCache)
self.gradInput = {}
self.output = {}
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
diff --git a/WeightedEuclidean.lua b/WeightedEuclidean.lua
index 8acd35147..606510cd2 100644
--- a/WeightedEuclidean.lua
+++ b/WeightedEuclidean.lua
@@ -209,7 +209,7 @@ function WeightedEuclidean:accGradParameters(input, gradOutput, scale)
end
end
-function WeightedEuclidean:type(type)
+function WeightedEuclidean:type(type, tensorCache)
if type then
-- prevent premature memory allocations
self._input = nil
@@ -226,7 +226,7 @@ function WeightedEuclidean:type(type)
self._repeat2 = nil
self._repeat3 = nil
end
- return parent.type(self, type)
+ return parent.type(self, type, tensorCache)
end
function WeightedEuclidean:parameters()
diff --git a/doc/module.md b/doc/module.md
index 97e14a07c..1c1d40839 100755
--- a/doc/module.md
+++ b/doc/module.md
@@ -150,8 +150,6 @@ Note that this function if called on a [Container](containers.md#nn.Containers)
module will share the same parameters for all the contained modules as
well.
-**NOTE: If you ever type-cast your network to another precision, i.e. net:cuda() for example, the sharing gets untied, and you have to reshare your modules again.**
-
Example:
```lua
@@ -186,8 +184,6 @@ If arguments are provided to the `clone(...)` function it also calls
module after creating it, hence making a deep copy of this module with
some shared parameters.
-**NOTE: If you ever type-cast your network to another precision, i.e. net:cuda() for example, the sharing gets untied, and you have to reshare your modules again.**
-
Example:
```lua
-- make an mlp
@@ -206,12 +202,35 @@ print(mlp2:get(1).bias[1])
```
-### type(type) ###
+### type(type[, tensorCache]) ###
This function converts all the parameters of a module to the given
`type`. The `type` can be one of the types defined for
[torch.Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md).
+If tensors (or their storages) are shared between multiple modules in a
+network, this sharing will be preserved after `type` is called.
+
+To preserve sharing between multiple modules and/or tensors, use
+`nn.utils.recursiveType`:
+
+```lua
+-- make an mlp
+mlp1=nn.Sequential();
+mlp1:add(nn.Linear(100,10));
+
+-- make a second mlp
+mlp2=nn.Sequential();
+mlp2:add(nn.Linear(100,10));
+
+-- the second mlp shares the bias of the first
+mlp2:share(mlp1,'bias');
+
+-- mlp1 and mlp2 will be converted to float, and will share bias
+-- note: tensors can be provided as inputs as well as modules
+nn.utils.recursiveType({mlp1, mlp2}, 'torch.FloatTensor')
+```
+
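+The same sharing can also be preserved by passing a common `tensorCache`
+table to successive `type` calls (a sketch reusing the `mlp1`/`mlp2`
+modules defined above):
+
+```lua
+-- a shared cache ties the conversions of both networks together,
+-- so the shared bias remains shared after the cast
+tensorCache = {}
+mlp1:type('torch.FloatTensor', tensorCache)
+mlp2:type('torch.FloatTensor', tensorCache)
+```
+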
### float() ###
diff --git a/test.lua b/test.lua
index 3fff1513d..25be158be 100644
--- a/test.lua
+++ b/test.lua
@@ -39,6 +39,7 @@ for test_name, component in pairs(tostringTestModules) do
end
end
+
function nntest.Add()
local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case
local ini = math.random(3,5)
@@ -310,7 +311,7 @@ function nntest.PReLU()
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
- 'error on weight [%s]', t))
+ 'error on weight [%s]', t))
end
-- 2D
@@ -328,7 +329,7 @@ function nntest.PReLU()
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
- 'error on weight [%s]', t))
+ 'error on weight [%s]', t))
end
-- 4D
@@ -347,7 +348,7 @@ function nntest.PReLU()
for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
mytester:assertlt(err, precision, string.format(
- 'error on weight [%s]', t))
+ 'error on weight [%s]', t))
end
-- IO
@@ -1067,7 +1068,7 @@ function nntest.LogSoftmax()
local module = nn.LogSoftMax()
local err = jac.testJacobian(module,input)
- mytester:assertlt(err,1e-3, 'error on state ')
+ mytester:assertlt(err, 1e-3, 'error on state ')
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
@@ -2999,7 +3000,7 @@ function nntest.AddConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
-
+
-- inplace comparisons
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -3024,7 +3025,7 @@ function nntest.AddConstant()
local gradInput1 = module1:backward(input1, gradOutput1)
local gradInput2 = module2:backward(input2, gradOutput2)
- mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
torch.typename(module1) .. ' - in-place backward err ')
local input1 = torch.rand(ink, inj, ini)
@@ -3034,7 +3035,7 @@ function nntest.AddConstant()
module1:backward(module1.output,torch.rand(input1:size()))
local err = (input1-input2):abs():max()
- mytester:asserteq(err, 0, torch.typename(module1) ..
+ mytester:asserteq(err, 0, torch.typename(module1) ..
' - inplace input change err ')
end
@@ -3056,7 +3057,7 @@ function nntest.MulConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
-
+
-- inplace comparisons
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -3075,13 +3076,13 @@ function nntest.MulConstant()
local out1 = module1:forward(input1)
local out2 = module2:forward(input2)
- mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
+ mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
' - in-place forward err ')
local gradInput1 = module1:backward(input1, gradOutput1)
local gradInput2 = module2:backward(input2, gradOutput2)
-
- mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
torch.typename(module1) .. ' - in-place backward err ')
local input1 = torch.rand(ink, inj, ini)
@@ -3091,7 +3092,7 @@ function nntest.MulConstant()
module1:backward(module1.output,torch.rand(input1:size()))
local err = (input1-input2):abs():max()
- mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) ..
+ mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) ..
' - inplace input change err ')
end
@@ -4154,6 +4155,126 @@ function nntest.addSingletonDimension()
"invalid dimension not detected")
end
+function nntest.Typecast()
+ local function make_network()
+ local seq = nn.Sequential()
+ seq:add(nn.Linear(15, 10))
+ seq:add(nn.Linear(15, 10))
+ seq.modules[1].bias:fill(1)
+ seq.modules[2].bias:fill(2)
+ return seq
+ end
+
+ -- make sure that the typecasts aren't nops
+ assert(torch.getdefaulttensortype() == 'torch.DoubleTensor')
+
+ -- basic net
+ local net = make_network()
+ net.modules[1].empty_tensor = torch.Tensor()
+ net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor',
+ net.modules[1].bias:type())
+ assert(net.modules[1].empty_tensor:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias ~= net.modules[2].bias)
+ net.modules[1].bias:fill(3)
+ assert(net.modules[1].bias[1] == 3)
+ assert(net.modules[2].bias[1] == 2)
+
+ -- shared tensors remain shared
+ local net = make_network()
+ net.modules[2].bias = net.modules[1].bias
+ net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias == net.modules[2].bias)
+ assert(net.modules[1].bias[1] == 1)
+
+ -- shared storages remain shared
+ local net = make_network()
+ net.modules[2].bias:set(net.modules[1].bias)
+ local net = net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias ~= net.modules[2].bias)
+ net.modules[1].bias:fill(3)
+ assert(net.modules[1].bias[1] == 3)
+ assert(net.modules[2].bias[1] == 3)
+
+ -- tricky: overlapping views on the same storage are preserved
+ local net = make_network()
+ local overlap_storage = torch.Tensor(15):fill(1)
+ net.modules[1].bias = overlap_storage:narrow(1, 1, 10)
+ net.modules[2].bias = overlap_storage:narrow(1, 6, 10)
+ net:float()
+ assert(net.modules[1].bias:type() == 'torch.FloatTensor')
+ assert(net.modules[1].bias ~= net.modules[2].bias)
+ net.modules[1].bias:fill(3)
+ assert(net.modules[1].bias[1] == 3)
+ assert(net.modules[2].bias[1] == 3)
+ assert(net.modules[2].bias[6] == 1) -- only the first 5 elements overlapped
+
+   -- sharing across two networks: independent casts break sharing, while a
+   -- shared tensorCache or recursiveType on a table preserves it
+ local net1 = make_network()
+ local net2 = make_network()
+ net2.modules[1].bias:set(net1.modules[1].bias)
+ net1:float()
+ net2:float()
+ net1.modules[1].bias:fill(3)
+ assert(net2.modules[1].bias[1] == 1)
+
+ local net1 = make_network()
+ local net2 = make_network()
+ net2.modules[1].bias:set(net1.modules[1].bias)
+
+ local tensorCache = {}
+ net1:type('torch.FloatTensor', tensorCache)
+ net2:type('torch.FloatTensor', tensorCache)
+ net1.modules[1].bias:fill(3)
+ assert(net2.modules[1].bias[1] == 3)
+
+ local net1 = make_network()
+ local net2 = make_network()
+ net2.modules[1].bias:set(net1.modules[1].bias)
+
+ nn.utils.recursiveType({net1, net2}, 'torch.FloatTensor')
+ net1.modules[1].bias:fill(3)
+ assert(net2.modules[1].bias[1] == 3)
+
+ -- smoke test some modules with custom type methods
+ local custom_type_modules = {
+ nn.MixtureTable(3),
+ nn.ConcatTable(),
+ nn.Copy(),
+ nn.Copy(nil, nil, nil, true),
+ nn.SpatialContrastiveNormalization(),
+ nn.DotProduct(),
+ nn.PairwiseDistance(1),
+ nn.SpatialDivisiveNormalization(),
+ nn.SpatialSubtractiveNormalization()
+ }
+ for _, module in ipairs(custom_type_modules) do
+ module:float()
+ end
+end
+
+function nntest.Module_apply()
+ local s = nn.Sequential()
+ s:add(nn.Linear(10,10))
+ local s2 = nn.Sequential()
+ s2:add(nn.Linear(10,5))
+ s:add(s2)
+ s:add(nn.Tanh())
+
+ local seen = 0
+ s:apply(function(module)
+ if torch.type(module) == 'nn.Linear' then
+ module.bias:resize(20)
+ seen = seen + 1
+ end
+ end)
+ mytester:asserteq(seen, 2)
+ mytester:asserteq(s.modules[1].bias:size(1), 20)
+ mytester:asserteq(s2.modules[1].bias:size(1), 20)
+end
+
mytester:add(nntest)
if not nn then
diff --git a/utils.lua b/utils.lua
index 40cc37737..375137637 100644
--- a/utils.lua
+++ b/utils.lua
@@ -1,15 +1,75 @@
nn.utils = {}
-function nn.utils.recursiveType(param, type_str)
+-- torch.Storage does not implement :type() yet, so emulate it here.
+-- TODO: replace with torch.Storage.type once it is implemented upstream.
+local function torch_Storage_type(self, type)
+ local current = torch.typename(self)
+ if not type then return current end
+ if type ~= current then
+ local new = torch.getmetatable(type).new()
+ if self:size() > 0 then
+ new:resize(self:size()):copy(self)
+ end
+ return new
+ else
+ return self
+ end
+end
+
+-- tensorCache maintains a cache of all tensors and storages that have been
+-- converted (recursively) by calls to recursiveType() and type().
+-- It caches conversions in order to preserve sharing semantics, i.e.
+-- if two tensors share a common storage, then type conversion should
+-- preserve that sharing.
+--
+-- You can preserve sharing semantics across multiple networks by
+-- passing tensorCache between the calls to type, e.g.
+--
+-- > tensorCache = {}
+-- > net1:type('torch.CudaTensor', tensorCache)
+-- > net2:type('torch.CudaTensor', tensorCache)
+-- > nn.utils.recursiveType(anotherTensor, 'torch.CudaTensor', tensorCache)
+--
+-- Implementation note: to make Lua table lookup behave correctly,
+-- tensor keys are stored as actual tensor objects, while storage
+-- keys are stored as the pointers themselves (as numbers).
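+--
+-- For illustration, the cache holds entries of roughly this shape
+-- (a sketch, not part of the public API):
+--
+--   tensorCache[originalTensor]                 = convertedTensor
+--   tensorCache[torch.pointer(originalStorage)] = convertedStorage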
+function nn.utils.recursiveType(param, type, tensorCache)
+ tensorCache = tensorCache or {}
+
if torch.type(param) == 'table' then
for k, v in pairs(param) do
- param[k] = nn.utils.recursiveType(v, type_str)
+ param[k] = nn.utils.recursiveType(v, type, tensorCache)
end
elseif torch.isTypeOf(param, 'nn.Module') or
torch.isTypeOf(param, 'nn.Criterion') then
- param:type(type_str)
+ param:type(type, tensorCache)
elseif torch.isTensor(param) then
- param = param:type(type_str)
+ if torch.typename(param) ~= type then
+ local newparam
+ if tensorCache[param] then
+ newparam = tensorCache[param]
+ else
+ newparam = torch.Tensor():type(type)
+ local storageType = type:gsub('Tensor','Storage')
+ if param:storage() then
+ local storage_key = torch.pointer(param:storage())
+ if not tensorCache[storage_key] then
+ tensorCache[storage_key] = torch_Storage_type(
+ param:storage(), storageType)
+ end
+ assert(torch.type(tensorCache[storage_key]) == storageType)
+ newparam:set(
+ tensorCache[storage_key],
+ param:storageOffset(),
+ param:size(),
+ param:stride()
+ )
+ tensorCache[param] = newparam
+ end
+ end
+ assert(torch.type(newparam) == type)
+ param = newparam
+ end
end
return param
end
@@ -90,5 +150,4 @@ function nn.utils.addSingletonDimension(t, dim)
return view
end
-
table.unpack = table.unpack or unpack