nn.Module preserve type sharing semantics (#187); add nn.Module.apply
adamlerer committed Sep 4, 2015
1 parent bab75d4 commit dd9d7fd
Showing 15 changed files with 258 additions and 45 deletions.
4 changes: 2 additions & 2 deletions CMul.lua

@@ -114,7 +114,7 @@ function CMul:accGradParameters(input, gradOutput, scale)
   end
end

-function CMul:type(type)
+function CMul:type(type, tensorCache)
   if type then
      self._input = nil
      self._output = nil
@@ -124,5 +124,5 @@ function CMul:type(type)
      self._repeat = nil
      self._sum = nil
   end
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
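
Each of the per-module diffs below follows the same pattern as CMul: the `type` override accepts the optional `tensorCache` and forwards it to the parent class, so that one cache spans the whole conversion. As a sketch, on a hypothetical `nn.MyModule` (the module name and its `_buffer` field are illustrative, not part of this commit):

```lua
-- Illustrative only: a hypothetical module showing the recurring pattern in
-- this commit. The override clears type-specific scratch state, then threads
-- tensorCache through to the parent's type() so sharing is preserved.
local MyModule, parent = torch.class('nn.MyModule', 'nn.Module')

function MyModule:type(type, tensorCache)
   if type then
      self._buffer = nil  -- drop cached scratch tensors before casting
   end
   return parent.type(self, type, tensorCache)
end
```
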
4 changes: 2 additions & 2 deletions Copy.lua

@@ -34,9 +34,9 @@ function Copy:updateGradInput(input, gradOutput)
   return self.gradInput
end

-function Copy:type(type)
+function Copy:type(type, tensorCache)
   if type and self.dontCast then
      return self
   end
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
4 changes: 2 additions & 2 deletions Criterion.lua

@@ -28,11 +28,11 @@ function Criterion:clone()
   return clone
end

-function Criterion:type(type)
+function Criterion:type(type, tensorCache)
   assert(type, 'Criterion: must provide a type to convert to')
   -- find all tensors and convert them
   for key,param in pairs(self) do
-      self[key] = nn.utils.recursiveType(param, type)
+      self[key] = nn.utils.recursiveType(param, type, tensorCache)
   end
   return self
end
4 changes: 2 additions & 2 deletions Euclidean.lua

@@ -171,7 +171,7 @@ function Euclidean:accGradParameters(input, gradOutput, scale)
   end
end

-function Euclidean:type(type)
+function Euclidean:type(type, tensorCache)
   if type then
      -- prevent premature memory allocations
      self._input = nil
@@ -186,5 +186,5 @@ function Euclidean:type(type)
      self._repeat = nil
      self._repeat2 = nil
   end
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
2 changes: 1 addition & 1 deletion FlattenTable.lua

@@ -94,7 +94,7 @@ function FlattenTable:updateGradInput(input, gradOutput)
   return self.gradInput
end

-function FlattenTable:type(type)
+function FlattenTable:type(type, tensorCache)
   -- This function just stores references so we don't need to do any type
   -- conversions. Just force the tables to be empty.
   self.output = {}
4 changes: 2 additions & 2 deletions JoinTable.lua

@@ -64,7 +64,7 @@ function JoinTable:updateGradInput(input, gradOutput)
   return self.gradInput
end

-function JoinTable:type(type)
+function JoinTable:type(type, tensorCache)
   self.gradInput = {}
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
4 changes: 2 additions & 2 deletions LookupTable.lua

@@ -71,8 +71,8 @@ function LookupTable:accGradParameters(input, gradOutput, scale)
   self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
end

-function LookupTable:type(type)
-   parent.type(self, type)
+function LookupTable:type(type, tensorCache)
+   parent.type(self, type, tensorCache)

   if type == 'torch.CudaTensor' then
      -- CUDA uses _sorted and _indices temporary tensors
4 changes: 2 additions & 2 deletions MixtureTable.lua

@@ -149,13 +149,13 @@ function MixtureTable:updateGradInput(input, gradOutput)
   return self.gradInput
end

-function MixtureTable:type(type)
+function MixtureTable:type(type, tensorCache)
   self._gaterView = nil
   self._expert = nil
   self._expertView = nil
   self._sum = nil
   self._gradInput = nil
   self._expert2 = nil
   self._expertView2 = nil
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
18 changes: 16 additions & 2 deletions Module.lua

@@ -116,10 +116,11 @@ end
-function Module:type(type)
+function Module:type(type, tensorCache)
   assert(type, 'Module: must provide a type to convert to')

+   tensorCache = tensorCache or {}
+
   -- find all tensors and convert them
   for key,param in pairs(self) do
-      self[key] = nn.utils.recursiveType(param, type)
+      self[key] = nn.utils.recursiveType(param, type, tensorCache)
   end

   return self
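
The `tensorCache` table is what preserves sharing: it remembers, for each tensor already converted, the converted result, so two fields that pointed at one tensor before the cast point at one tensor afterwards. A simplified sketch of the idea, assuming only whole tensors are cached (the real `nn.utils.recursiveType` also recurses into tables and handles storages):

```lua
-- Simplified sketch of the caching idea behind tensorCache; the helper name
-- convertWithCache is illustrative and not part of this commit.
local function convertWithCache(tensor, type, tensorCache)
   if not tensorCache[tensor] then
      tensorCache[tensor] = tensor:type(type)  -- convert on first encounter
   end
   return tensorCache[tensor]                  -- reuse on every later one
end

local t = torch.DoubleTensor(3):fill(1)
local cache = {}
local a = convertWithCache(t, 'torch.FloatTensor', cache)
local b = convertWithCache(t, 'torch.FloatTensor', cache)
assert(a == b)  -- one shared FloatTensor comes back both times
```
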
@@ -281,6 +282,19 @@ function Module:__call__(input, gradOutput)
   end
end

+-- Run a callback (called with the module as an argument) in preorder over this
+-- module and its children.
+--
+function Module:apply(callback)
+   callback(self)
+
+   if self.modules then
+      for _, module in ipairs(self.modules) do
+         module:apply(callback)
+      end
+   end
+end
+
function Module:findModules(typename, container)
   container = container or self
   local nodes = {}
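
The new `Module:apply` visits the module itself first, then recurses into `self.modules` when present (i.e. into containers). A small usage sketch, assuming a plain `nn.Sequential` network:

```lua
-- Usage sketch for the new preorder traversal: record the class name of
-- every module, container first, then its children in order.
local net = nn.Sequential()
net:add(nn.Linear(10, 5))
net:add(nn.Tanh())

local names = {}
net:apply(function(module)
   table.insert(names, torch.type(module))
end)
-- names is { 'nn.Sequential', 'nn.Linear', 'nn.Tanh' }
```
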
4 changes: 2 additions & 2 deletions ParallelCriterion.lua

@@ -34,7 +34,7 @@ function ParallelCriterion:updateGradInput(input, target)
   return self.gradInput
end

-function ParallelCriterion:type(type)
+function ParallelCriterion:type(type, tensorCache)
   self.gradInput = {}
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
4 changes: 2 additions & 2 deletions SelectTable.lua

@@ -51,8 +51,8 @@ function SelectTable:updateGradInput(input, gradOutput)
   return self.gradInput
end

-function SelectTable:type(type)
+function SelectTable:type(type, tensorCache)
   self.gradInput = {}
   self.output = {}
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end
4 changes: 2 additions & 2 deletions WeightedEuclidean.lua

@@ -209,7 +209,7 @@ function WeightedEuclidean:accGradParameters(input, gradOutput, scale)
   end
end

-function WeightedEuclidean:type(type)
+function WeightedEuclidean:type(type, tensorCache)
   if type then
      -- prevent premature memory allocations
      self._input = nil
@@ -226,7 +226,7 @@ function WeightedEuclidean:type(type)
      self._repeat2 = nil
      self._repeat3 = nil
   end
-   return parent.type(self, type)
+   return parent.type(self, type, tensorCache)
end

function WeightedEuclidean:parameters()
29 changes: 24 additions & 5 deletions doc/module.md

@@ -150,8 +150,6 @@ Note that this function if called on a [Container](containers.md#nn.Containers)
module will share the same parameters for all the contained modules as
well.

-**NOTE: If you ever type-cast your network to another precision, i.e. net:cuda() for example, the sharing gets untied, and you have to reshare your modules again.**
-
Example:
```lua

@@ -186,8 +184,6 @@ If arguments are provided to the `clone(...)` function it also calls
module after creating it, hence making a deep copy of this module with
some shared parameters.

-**NOTE: If you ever type-cast your network to another precision, i.e. net:cuda() for example, the sharing gets untied, and you have to reshare your modules again.**
-
Example:
```lua
-- make an mlp
@@ -206,12 +202,35 @@ print(mlp2:get(1).bias[1])
```

<a name="nn.Module.type"></a>
-### type(type) ###
+### type(type[, tensorCache]) ###

This function converts all the parameters of a module to the given
`type`. The `type` can be one of the types defined for
[torch.Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md).

+If tensors (or their storages) are shared between multiple modules in a
+network, this sharing will be preserved after type is called.
+
+To preserve sharing between multiple modules and/or tensors, use
+`nn.utils.recursiveType`:
+
+```lua
+-- make an mlp
+mlp1=nn.Sequential();
+mlp1:add(nn.Linear(100,10));
+
+-- make a second mlp
+mlp2=nn.Sequential();
+mlp2:add(nn.Linear(100,10));
+
+-- the second mlp shares the bias of the first
+mlp2:share(mlp1,'bias');
+
+-- mlp1 and mlp2 will be converted to float, and will share bias
+-- note: tensors can be provided as inputs as well as modules
+nn.utils.recursiveType({mlp1, mlp2}, 'torch.FloatTensor')
+```
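
Because `Module:type` now threads one `tensorCache` through the whole traversal, sharing within a single network also survives a cast; a short sketch of the documented behavior:

```lua
-- Sketch: bias sharing inside one network is preserved across a cast.
mlp = nn.Sequential()
mlp:add(nn.Linear(100, 10))
mlp:add(nn.Linear(100, 10))
mlp:get(2):share(mlp:get(1), 'bias')

mlp:float()  -- same as mlp:type('torch.FloatTensor')

mlp:get(1).bias[1] = 42
print(mlp:get(2).bias[1])  -- prints 42: both layers still see one bias
```
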

<a name="nn.Module.float"></a>
### float() ###

