Commit

Merge pull request #363 from adamlerer/getParameters
getParameters: improve memory use, fix bug with non-compact tensors
soumith committed Sep 4, 2015
2 parents 98ecaf9 + 7105419 commit bab75d4
Showing 2 changed files with 140 additions and 58 deletions.
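For context on what the changed function provides: getParameters flattens every parameter and gradient in a network into two tensors whose storage is aliased by the modules' own weight and bias tensors, so a single vector operation updates the whole model. A minimal usage sketch; the three-layer net and the learning rate are illustrative, not part of this commit:

require 'nn'

-- Hypothetical network; any nn module tree works the same way.
local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 5))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(5, 1))

-- One flat tensor each for parameters and gradients; the layers' weight and
-- bias tensors are remapped to views into this shared storage.
local params, gradParams = mlp:getParameters()

local criterion = nn.MSECriterion()
local input, target = torch.randn(10), torch.randn(1)

gradParams:zero()
local output = mlp:forward(input)
criterion:forward(output, target)
mlp:backward(input, criterion:backward(output, target))

-- A single in-place update touches every parameter in the network.
params:add(-0.01, gradParams)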
159 changes: 101 additions & 58 deletions Module.lua
@@ -113,12 +113,15 @@ function Module:clone(...)
return clone
end

- function Module:type(type)
function Module:type(type, tensorCache)
assert(type, 'Module: must provide a type to convert to')

-- find all tensors and convert them
for key,param in pairs(self) do
self[key] = nn.utils.recursiveType(param, type)

end

return self
end

@@ -137,95 +140,135 @@ end
function Module:reset()
end

-- this function flattens arbitrary lists of parameters,
-- even complex shared ones
-- This function is not easy to understand. It works as follows:
--
-- - gather all parameter tensors for this module (and children);
-- count all parameter values (floats)
-- - create one ginormous memory area (Storage object) with room for all
-- parameters
-- - remap each parameter tensor to point to an area within the ginormous
-- Storage, and copy it there
--
-- It has the effect of making all parameters point to the same memory area,
-- which is then returned.
--
-- The purpose is to allow operations over all parameters (such as momentum
-- updates and serialization), but it assumes that all parameters are of
-- the same type (and, in the case of CUDA, on the same device), which
-- is not always true. Use for_each() to iterate over this module and
-- children instead.
--
-- Module._flattenTensorBuffer can be used by other packages (e.g. cunn)
-- to specify the type of temporary buffers. For example, the temporary
-- buffers for CudaTensor could be FloatTensor, to avoid GPU memory usage.
--
-- TODO: This logically belongs to torch.Tensor, not nn.
Module._flattenTensorBuffer = {}
function Module.flatten(parameters)
- local function storageInSet(set, storage)
- local storageAndOffset = set[torch.pointer(storage)]
- if storageAndOffset == nil then
- return nil
- end
- local _, offset = table.unpack(storageAndOffset)
- return offset
- end

-- returns true if tensor occupies a contiguous region of memory (no holes)
local function isCompact(tensor)
local sortedStride, perm = torch.sort(
torch.LongTensor(tensor:nDimension()):set(tensor:stride()), 1, true)
local sortedSize = torch.LongTensor(tensor:nDimension()):set(
tensor:size()):index(1, perm)
local nRealDim = torch.clamp(sortedStride, 0, 1):sum()
sortedStride = sortedStride:narrow(1, 1, nRealDim):clone()
sortedSize = sortedSize:narrow(1, 1, nRealDim):clone()
local t = tensor.new():set(tensor:storage(), 1,
sortedSize:storage(),
sortedStride:storage())
return t:isContiguous()
end

if not parameters or #parameters == 0 then
return torch.Tensor()
end
local Tensor = parameters[1].new
local dtype = parameters[1]:type()
local TmpTensor = Module._flattenTensorBuffer[torch.type(parameters[1])] or Tensor

-- 1. construct the set of all unique storages referenced by parameter tensors
local storages = {}
local nParameters = 0
local parameterMeta = {}
for k = 1,#parameters do
if parameters[k]:type() ~= dtype then
error("Inconsistent parameter types. " .. parameters[k]:type() ..
" ~= " .. dtype)
end
local param = parameters[k]
local storage = parameters[k]:storage()
- if not storageInSet(storages, storage) then
- storages[torch.pointer(storage)] = {storage, nParameters}
local storageKey = torch.pointer(storage)

if not storages[storageKey] then
storages[storageKey] = {storage, nParameters}
nParameters = nParameters + storage:size()
end

parameterMeta[k] = {storageOffset = param:storageOffset() +
storages[storageKey][2],
size = param:size(),
stride = param:stride()}
end

- local flatParameters = Tensor(nParameters):fill(1)
- local flatStorage = flatParameters:storage()
-- 2. construct a single tensor that will hold all the parameters
local flatParameters = TmpTensor(nParameters):zero()

-- 3. determine if there are elements in the storage that none of the
-- parameter tensors reference ('holes')
local tensorsCompact = true
for k = 1,#parameters do
- local storageOffset = storageInSet(storages, parameters[k]:storage())
- parameters[k]:set(flatStorage,
- storageOffset + parameters[k]:storageOffset(),
- parameters[k]:size(),
- parameters[k]:stride())
- parameters[k]:zero()
local meta = parameterMeta[k]
local tmp = TmpTensor():set(
flatParameters:storage(), meta.storageOffset, meta.size, meta.stride)
tmp:fill(1)
tensorsCompact = tensorsCompact and isCompact(tmp)
end

- local maskParameters = flatParameters:float():clone()
- local cumSumOfHoles = flatParameters:float():cumsum(1)
- local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
- local flatUsedParameters = Tensor(nUsedParameters)
- local flatUsedStorage = flatUsedParameters:storage()
local maskParameters = flatParameters:byte():clone()
local compactOffsets = flatParameters:long():cumsum(1)
local nUsedParameters = compactOffsets[-1]

- for k = 1,#parameters do
- local offset = cumSumOfHoles[parameters[k]:storageOffset()]
- parameters[k]:set(flatUsedStorage,
- parameters[k]:storageOffset() - offset,
- parameters[k]:size(),
- parameters[k]:stride())
-- 4. copy storages into the flattened parameter tensor
for _, storageAndOffset in pairs(storages) do
local storage, offset = table.unpack(storageAndOffset)
flatParameters[{{offset+1,offset+storage:size()}}]:copy(Tensor():set(storage))
end

- for _, storageAndOffset in pairs(storages) do
- local k, v = table.unpack(storageAndOffset)
- flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
-- 5. allow garbage collection
storages = nil
for k = 1,#parameters do
parameters[k]:set(Tensor())
end

- if cumSumOfHoles:sum() == 0 then
- flatUsedParameters:copy(flatParameters)
- else
- local counter = 0
- for k = 1,flatParameters:nElement() do
- if maskParameters[k] == 0 then
- counter = counter + 1
- flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
- end
-- 6. compact the flattened parameters if there were holes
if nUsedParameters ~= nParameters then
assert(tensorsCompact,
"Cannot gather tensors that are not compact")

flatParameters = TmpTensor(nUsedParameters):copy(
flatParameters:maskedSelect(maskParameters))
for k = 1,#parameters do
parameterMeta[k].storageOffset =
compactOffsets[parameterMeta[k].storageOffset]
end
- assert (counter == nUsedParameters)
end
- return flatUsedParameters

if TmpTensor ~= Tensor then
flatParameters = Tensor(flatParameters:nElement()):copy(flatParameters)
end

-- 7. fix up the parameter tensors to point at the flattened parameters
for k = 1,#parameters do
parameters[k]:set(flatParameters:storage(),
parameterMeta[k].storageOffset,
parameterMeta[k].size,
parameterMeta[k].stride)
end

return flatParameters
end

function Module:getParameters()
-- get parameters
local parameters,gradParameters = self:parameters()
- -- flatten parameters and gradients
- local flatParameters = nn.Module.flatten(parameters)
- collectgarbage()
- local flatGradParameters = nn.Module.flatten(gradParameters)
- collectgarbage()

- -- return new flat vector that contains all discrete parameters
- return flatParameters, flatGradParameters
return Module.flatten(parameters), Module.flatten(gradParameters)
end

function Module:__call__(input, gradOutput)
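The bug fix hinges on the distinction drawn by the new isCompact check: a tensor may be non-contiguous yet still cover its slice of storage with no holes, in which case it can still be flattened, whereas a view that skips over storage elements cannot. A small illustrative sketch mirroring the new tests; the tensor sizes and fill values are arbitrary:

local base = torch.Tensor(10, 10):fill(2)

-- A full transpose is non-contiguous, but its elements still cover a
-- contiguous region of the underlying storage, so it is 'compact' and
-- getParameters can gather it (see Module_getParameters_10 below).
local transposed = base:t()

-- Selecting a column visits every 10th storage element, leaving holes in
-- between; such a tensor is not compact and getParameters now raises
-- "Cannot gather tensors that are not compact" (see Module_getParameters_11).
local column = base:select(2, 2)

print(transposed:isContiguous())  -- false, yet compact
print(column:isContiguous())      -- false, and not compact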
39 changes: 39 additions & 0 deletions test.lua
@@ -2793,6 +2793,45 @@ function nntest.Module_getParameters_8()

end

function nntest.Module_getParameters_10()
-- tensors are non-contiguous but compact; they can be gathered
local L = nn.Linear(10,10)
L.weight = torch.Tensor(10,10):t():fill(1)
local tmp = torch.Tensor(10,10):fill(2)
L.bias = tmp:select(1,2)
local P = L:getParameters()
mytester:asserteq(L.weight:mean(), 1)
mytester:asserteq(L.bias:mean(), 2)
mytester:asserteq(L.weight:storage(), L.bias:storage())
mytester:asserteq(P:nElement(), 110)
mytester:asserteq(P:storage():size(), 110)
mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
end

function nntest.Module_getParameters_11()
-- tensors are non-compact; they can't be gathered
local L = nn.Linear(10,10)
local tmp = torch.Tensor(10,10):fill(2)
L.bias = tmp:select(2,2)
local ok, err = pcall(L.getParameters, L)
mytester:assert(not ok)
end

function nntest.Module_getParameters_12()
-- tensors are expanded (i.e. have dimension 0)
local L = nn.Linear(10,10)
L.weight = torch.Tensor(10, 1):fill(1)
torch.expand(L.weight, 10, 10)
L.bias = torch.Tensor(10):fill(2)
local P = L:getParameters()
mytester:asserteq(L.weight:mean(), 1)
mytester:asserteq(L.bias:mean(), 2)
mytester:asserteq(L.weight:storage(), L.bias:storage())
mytester:asserteq(P:nElement(), 20)
mytester:asserteq(P:storage():size(), 20)
mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
end

function nntest.Module_listModules()
local batchSize = 4
local inputSize, outputSize = 7, 6
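Beyond the tests, the memory improvement comes from staging the flattening through Module._flattenTensorBuffer. As the new comment in Module.lua notes, a package such as cunn can register a CPU tensor type for CUDA parameters so the temporary buffer does not consume GPU memory. A hedged sketch of how a package might use the hook; it assumes cutorch/cunn are installed, and the registration line is the pattern suggested by the comment, not code from this commit:

require 'cunn'

-- Build the temporary flattening buffer as a FloatTensor instead of a
-- CudaTensor, so the intermediate copy stays in host memory.
nn.Module._flattenTensorBuffer['torch.CudaTensor'] = torch.FloatTensor

local net = nn.Sequential()
net:add(nn.Linear(4096, 4096))
net:add(nn.ReLU())
net:cuda()

-- Module.flatten builds its scratch tensor with the registered type and only
-- converts the result back to CudaTensor at the end (step 7 of the new code).
local params, gradParams = net:getParameters()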
