From b4ebdf2f95ee9f1429825a0d7b0948721e407d82 Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Thu, 4 Feb 2016 15:53:04 +0100 Subject: [PATCH 1/2] nn.clearState --- BatchNormalization.lua | 11 +++++++++++ Bilinear.lua | 5 +++++ CMul.lua | 21 ++++++++++++++------- CMulTable.lua | 5 +++++ Container.lua | 23 +++++++++++++++++++++++ Cosine.lua | 12 ++++++++++++ CosineDistance.lua | 17 +++++++++++++++++ DotProduct.lua | 10 ++++++++++ Dropout.lua | 8 ++++++++ Euclidean.lua | 29 ++++++++++++++++++----------- FlattenTable.lua | 6 +++++- GradientReversal.lua | 2 +- Identity.lua | 18 ++++++++++++++++++ Jacobian.lua | 1 + L1Cost.lua | 5 +++++ L1Penalty.lua | 4 ++++ Linear.lua | 4 ++++ LogSigmoid.lua | 6 ++++++ LookupTable.lua | 4 ++++ Max.lua | 5 +++++ Min.lua | 5 +++++ MixtureTable.lua | 12 ++++++++++++ Module.lua | 4 ++++ Normalize.lua | 13 +++++++++++++ PReLU.lua | 5 +++++ PairwiseDistance.lua | 6 ++++++ RReLU.lua | 5 +++++ Reshape.lua | 5 +++++ SoftMin.lua | 7 ++++++- SoftSign.lua | 7 ++++++- SparseLinear.lua | 5 +++++ SpatialAdaptiveMaxPooling.lua | 15 +++++++++------ SpatialBatchNormalization.lua | 11 +++++++++++ SpatialConvolution.lua | 6 ++++++ SpatialConvolutionLocal.lua | 5 +++++ SpatialConvolutionMM.lua | 8 +++++++- SpatialCrossMapLRN.lua | 6 ++++++ SpatialDivisiveNormalization.lua | 8 ++++++++ SpatialDropout.lua | 7 +++++++ SpatialFractionalMaxPooling.lua | 14 +++++++------- SpatialFullConvolution.lua | 6 ++++++ SpatialMaxPooling.lua | 15 +++++++++------ SpatialMaxUnpooling.lua | 5 +---- SpatialSubtractiveNormalization.lua | 7 +++++++ TemporalMaxPooling.lua | 12 ++++++------ VolumetricAveragePooling.lua | 5 +---- VolumetricConvolution.lua | 7 ++++++- VolumetricMaxPooling.lua | 12 ++++++------ VolumetricMaxUnpooling.lua | 5 +---- utils.lua | 23 +++++++++++++++++++++++ 50 files changed, 380 insertions(+), 67 deletions(-) diff --git a/BatchNormalization.lua b/BatchNormalization.lua index a1c18d426..32a437c54 100644 --- a/BatchNormalization.lua +++ b/BatchNormalization.lua @@ -154,3 +154,14 @@ function BN:accGradParameters(input, gradOutput, scale) self.gradBias:add(scale, self.buffer) end end + +function BN:clearState() + nn.utils.clear(self, { + 'buffer', + 'buffer2', + 'centered', + 'std', + 'normalized', + }) + return parent.clearState(self) +end diff --git a/Bilinear.lua b/Bilinear.lua index fe007c475..3dc687304 100644 --- a/Bilinear.lua +++ b/Bilinear.lua @@ -140,3 +140,8 @@ function Bilinear:__tostring__() (self.bias == nil and ' without bias' or '') ) end + +function Bilinear:clearState() + if self.buff then self.buff:set() end + return parent.clearState(self) +end diff --git a/CMul.lua b/CMul.lua index e8aad0171..e84f7ba05 100644 --- a/CMul.lua +++ b/CMul.lua @@ -116,13 +116,20 @@ end function CMul:type(type, tensorCache) if type then - self._input = nil - self._output = nil - self._weight = nil - self._gradWeight = nil - self._expand = nil - self._repeat = nil - self._sum = nil + self:clearState() end return parent.type(self, type, tensorCache) end + +function CMul:clearState() + nn.utils.clear(self, { + '_input', + '_output', + '_weight', + '_gradWeight', + '_expand', + '_repeat', + '_sum', + }) + return parent.clearState(self) +end diff --git a/CMulTable.lua b/CMulTable.lua index 0689f3358..b47378e83 100644 --- a/CMulTable.lua +++ b/CMulTable.lua @@ -48,3 +48,8 @@ function CMulTable:updateGradInput(input, gradOutput) return self.gradInput end + +function CMulTable:clearState() + if self.tout then self.tout:set() end + return parent.clearState(self) +end diff --git 
a/Container.lua b/Container.lua index 340153690..bca6e41a4 100644 --- a/Container.lua +++ b/Container.lua @@ -75,3 +75,26 @@ function Container:parameters() end return w,gw end + +function Container:clearState() + -- don't call set because it might reset referenced tensors + local function clear(f) + if self[f] then + if torch.isTensor(self[f]) then + self[f] = self[f].new() + elseif type(self[f]) == 'table' then + self[f] = {} + else + self[f] = nil + end + end + end + clear('output') + clear('gradInput') + if self.modules then + for i,module in pairs(self.modules) do + module:clearState() + end + end + return self +end diff --git a/Cosine.lua b/Cosine.lua index 893aaac4c..e655b9e0f 100644 --- a/Cosine.lua +++ b/Cosine.lua @@ -161,3 +161,15 @@ function Cosine:type(type, tensorCache) end return parent.type(self, type, tensorCache) end + +function Cosine:clearState() + nn.utils.clear(self, { + '_input', + '_weight', + '_gradOutput', + '_sum', + '_inputNorm', + '_weightNorm', + }) + return parent.clearState(self) +end diff --git a/CosineDistance.lua b/CosineDistance.lua index d135e05e4..2988c657c 100644 --- a/CosineDistance.lua +++ b/CosineDistance.lua @@ -73,6 +73,11 @@ function CosineDistance:updateGradInput(input, gradOutput) not_batch = true end + if #self.gradInput ~= 2 then + self.gradInput[1] = self.gradInput[1] or v1.new() + self.gradInput[2] = self.gradInput[2] or v1.new() + end + local gw1 = self.gradInput[1] local gw2 = self.gradInput[2] gw1:resizeAs(v1):copy(v2) @@ -97,3 +102,15 @@ function CosineDistance:updateGradInput(input, gradOutput) return self.gradInput end + +function CosineDistance:clearState() + nn.utils.clear(self, { + 'buffer', + 'w1', + 'w22', + 'w', + 'w32', + 'ones', + }) + return parent.clearState(self) +end diff --git a/DotProduct.lua b/DotProduct.lua index 2d3d1dd6e..021602e13 100644 --- a/DotProduct.lua +++ b/DotProduct.lua @@ -26,6 +26,11 @@ function DotProduct:updateGradInput(input, gradOutput) local v2 = input[2] local not_batch = false + if #self.gradInput ~= 2 then + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[2].new() + end + if v1:dim() == 1 then v1 = v1:view(1,-1) v2 = v2:view(1,-1) @@ -49,3 +54,8 @@ function DotProduct:updateGradInput(input, gradOutput) return self.gradInput end + +function DotProduct:clearState() + if self.buffer then self.buffer:set() end + return parent.clearState(self) +end diff --git a/Dropout.lua b/Dropout.lua index 788ef9d1b..d676d968f 100644 --- a/Dropout.lua +++ b/Dropout.lua @@ -64,3 +64,11 @@ end function Dropout:__tostring__() return string.format('%s(%f)', torch.type(self), self.p) end + + +function Dropout:clearState() + if self.noise then + self.noise:set() + end + return Parent.clearState(self) +end diff --git a/Euclidean.lua b/Euclidean.lua index 25aa2e484..8269d13a4 100644 --- a/Euclidean.lua +++ b/Euclidean.lua @@ -174,17 +174,24 @@ end function Euclidean:type(type, tensorCache) if type then -- prevent premature memory allocations - self._input = nil - self._output = nil - self._gradOutput = nil - self._weight = nil - self._div = nil - self._sum = nil - self._expand = nil - self._expand2 = nil - self._expand3 = nil - self._repeat = nil - self._repeat2 = nil + self:clearState() end return parent.type(self, type, tensorCache) end + +function Euclidean:clearState() + nn.utils.clear(self, { + '_input', + '_output', + '_gradOutput', + '_weight', + '_div', + '_sum', + '_expand', + '_expand2', + '_expand3', + '_repeat', + '_repeat2', + }) + return 
parent.clearState(self) +end diff --git a/FlattenTable.lua b/FlattenTable.lua index 69abd0762..1c182557c 100644 --- a/FlattenTable.lua +++ b/FlattenTable.lua @@ -97,6 +97,10 @@ end function FlattenTable:type(type, tensorCache) -- This function just stores references so we don't need to do any type -- conversions. Just force the tables to be empty. - self.output = {} + self:clearState() +end + +function FlattenTable:clearState() self.input_map = {} + return parent.clearState(self) end diff --git a/GradientReversal.lua b/GradientReversal.lua index 75dc61d03..fdf98ed12 100644 --- a/GradientReversal.lua +++ b/GradientReversal.lua @@ -1,7 +1,7 @@ local GradientReversal = torch.class('nn.GradientReversal', 'nn.Module') function GradientReversal:updateOutput(input) - self.output = input + self.output:set(input) return self.output end diff --git a/Identity.lua b/Identity.lua index 088cc343b..5e6ccb624 100644 --- a/Identity.lua +++ b/Identity.lua @@ -10,3 +10,21 @@ function Identity:updateGradInput(input, gradOutput) self.gradInput = gradOutput return self.gradInput end + +function Identity:clearState() + -- don't call set because it might reset referenced tensors + local function clear(f) + if self[f] then + if torch.isTensor(self[f]) then + self[f] = self[f].new() + elseif type(self[f]) == 'table' then + self[f] = {} + else + self[f] = nil + end + end + end + clear('output') + clear('gradInput') + return self +end diff --git a/Jacobian.lua b/Jacobian.lua index 51ca13920..d54f21140 100644 --- a/Jacobian.lua +++ b/Jacobian.lua @@ -305,6 +305,7 @@ function nn.Jacobian.testIO(module,input, minval, maxval) -- write module local filename = os.tmpname() local f = torch.DiskFile(filename, 'w'):binary() + module:clearState() f:writeObject(module) f:close() -- read module diff --git a/L1Cost.lua b/L1Cost.lua index 5c5a89c64..6b58e0ec9 100644 --- a/L1Cost.lua +++ b/L1Cost.lua @@ -23,3 +23,8 @@ function L1Cost:updateGradInput(input) ) return self.gradInput end + +function L1Cost:clearState() + if self.output_tensor then self.output_tensor:set() end + return parent.clearState(self) +end diff --git a/L1Penalty.lua b/L1Penalty.lua index 457a343f2..998c331a3 100644 --- a/L1Penalty.lua +++ b/L1Penalty.lua @@ -41,3 +41,7 @@ function L1Penalty:updateGradInput(input, gradOutput) return self.gradInput end +function L1Penalty:clearState() + if self.loss then self.loss:set() end + return parent.clearState(self) +end diff --git a/Linear.lua b/Linear.lua index 589f6eb09..e302a643a 100644 --- a/Linear.lua +++ b/Linear.lua @@ -95,6 +95,10 @@ end -- we do not need to accumulate parameters when sharing Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters +function Linear:clearState() + if self.addBuffer then self.addBuffer:set() end + return parent.clearState(self) +end function Linear:__tostring__() return torch.type(self) .. 
diff --git a/LogSigmoid.lua b/LogSigmoid.lua index b2b773a6b..1e21d3f9a 100644 --- a/LogSigmoid.lua +++ b/LogSigmoid.lua @@ -23,3 +23,9 @@ function LogSigmoid:updateGradInput(input, gradOutput) ) return self.gradInput end + +function LogSigmoid:clearState() + if self.buffer then self.buffer:set() end + return parent.clearState(self) +end + diff --git a/LookupTable.lua b/LookupTable.lua index 4b658a9dd..95bf0fc0b 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -100,5 +100,9 @@ function LookupTable:type(type, tensorCache) return self end +function LookupTable:clearState() + return self +end + -- we do not need to accumulate parameters when sharing LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters diff --git a/Max.lua b/Max.lua index d31688412..c2495f80e 100644 --- a/Max.lua +++ b/Max.lua @@ -63,3 +63,8 @@ function Max:type(type, tensorCache) end return self end + +function Max:clearState() + nn.utils.clear(self, '_indices', '_output') + return parent.clearState(self) +end diff --git a/Min.lua b/Min.lua index 5331d7919..2e58708e2 100644 --- a/Min.lua +++ b/Min.lua @@ -63,3 +63,8 @@ function Min:type(type, tensorCache) end return self end + +function Min:clearState() + nn.utils.clear(self, '_indices', '_output') + return parent.clearState(self) +end diff --git a/MixtureTable.lua b/MixtureTable.lua index 613d1646e..17c307e75 100644 --- a/MixtureTable.lua +++ b/MixtureTable.lua @@ -156,3 +156,15 @@ function MixtureTable:type(type, tensorCache) self._expertView2 = nil return parent.type(self, type, tensorCache) end + +function MixtureTable:clearState() + nn.utils.clear(self, { + '_gaterView', + '_expert', + '_expertView', + '_sum', + '_expert2', + '_expertView2', + }) + return parent.clearState(self) +end diff --git a/Module.lua b/Module.lua index a7310d894..2c2b6e14c 100644 --- a/Module.lua +++ b/Module.lua @@ -364,3 +364,7 @@ function Module:listModules() end return modules end + +function Module:clearState() + return nn.utils.clear(self, 'output', 'gradInput') +end diff --git a/Normalize.lua b/Normalize.lua index 8e9a111ea..8bf936a7a 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -140,3 +140,16 @@ function Normalize:type(type, tensorCache) end return self end + +function Normalize:clearState() + nn.utils.clear(self, { + '_output', + '_indices', + '_gradInput', + 'buffer', + 'norm', + 'normp', + 'cross', + }) + return parent.clearState(self) +end diff --git a/PReLU.lua b/PReLU.lua index 322f3cc2c..4405c6632 100644 --- a/PReLU.lua +++ b/PReLU.lua @@ -45,3 +45,8 @@ function PReLU:accGradParameters(input, gradOutput, scale) ) return self.gradWeight end + +function PReLU:clearState() + nn.utils.clear(self, 'gradWeightBuf', 'gradWeightBuf2') + return parent.clearState(self) +end diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua index c7dc09661..d5022a74d 100644 --- a/PairwiseDistance.lua +++ b/PairwiseDistance.lua @@ -10,6 +10,7 @@ function PairwiseDistance:__init(p) end function PairwiseDistance:updateOutput(input) + self.output:resize(1) if input[1]:dim() == 1 then self.output:resize(1) self.output[1]=input[1]:dist(input[2],self.norm) @@ -83,3 +84,8 @@ function PairwiseDistance:updateGradInput(input, gradOutput) self.gradInput[2]:zero():add(-1, self.gradInput[1]) return self.gradInput end + +function PairwiseDistance:clearState() + nn.utils.clear(self, 'diff', 'outExpand', 'grad', 'ones') + return parent.clearState(self) +end diff --git a/RReLU.lua b/RReLU.lua index 0c2825142..843415f7e 100644 --- a/RReLU.lua +++ b/RReLU.lua @@ -43,3 +43,8 @@ end 
function RReLU:__tostring__() return string.format('%s (l:%f, u:%f)', torch.type(self), self.lower, self.upper) end + +function RReLU:clearState() + if self.noise then self.noise:set() end + return parent.clearState(self) +end diff --git a/Reshape.lua b/Reshape.lua index dcc787b75..80f8b014c 100644 --- a/Reshape.lua +++ b/Reshape.lua @@ -67,3 +67,8 @@ function Reshape:__tostring__() return torch.type(self) .. '(' .. table.concat(self.size:totable(), 'x') .. ')' end + +function Reshape:clearState() + nn.utils.clear(self, '_input', '_gradOutput') + return parent.clearState(self) +end diff --git a/SoftMin.lua b/SoftMin.lua index aac32cd63..7da2a6589 100644 --- a/SoftMin.lua +++ b/SoftMin.lua @@ -1,4 +1,4 @@ -local SoftMin, _ = torch.class('nn.SoftMin', 'nn.Module') +local SoftMin, parent = torch.class('nn.SoftMin', 'nn.Module') function SoftMin:updateOutput(input) self.mininput = self.mininput or input.new() @@ -24,3 +24,8 @@ function SoftMin:updateGradInput(input, gradOutput) self.gradInput:mul(-1) return self.gradInput end + +function SoftMin:clearState() + if self.mininput then self.mininput:set() end + return parent.clearState(self) +end diff --git a/SoftSign.lua b/SoftSign.lua index 480894c71..ee72011f1 100644 --- a/SoftSign.lua +++ b/SoftSign.lua @@ -1,4 +1,4 @@ -local SoftSign = torch.class('nn.SoftSign', 'nn.Module') +local SoftSign, parent = torch.class('nn.SoftSign', 'nn.Module') function SoftSign:updateOutput(input) self.temp = self.temp or input.new() @@ -13,3 +13,8 @@ function SoftSign:updateGradInput(input, gradOutput) self.gradInput:resizeAs(input):copy(gradOutput):cdiv(self.tempgrad) return self.gradInput end + +function SoftSign:clearState() + nn.utils.clear(self, 'temp', 'tempgrad') + return parent.clearState(self) +end diff --git a/SparseLinear.lua b/SparseLinear.lua index 3c1a59c3c..b574319b9 100644 --- a/SparseLinear.lua +++ b/SparseLinear.lua @@ -76,3 +76,8 @@ function SparseLinear:zeroGradParameters() parent.zeroGradParameters(self) end end + +function SparseLinear:clearState() + if self.lastInput then self.lastInput:set() end + return parent.clearState(self) +end diff --git a/SpatialAdaptiveMaxPooling.lua b/SpatialAdaptiveMaxPooling.lua index 6dfa58b58..efe4515c2 100644 --- a/SpatialAdaptiveMaxPooling.lua +++ b/SpatialAdaptiveMaxPooling.lua @@ -29,11 +29,14 @@ function SpatialAdaptiveMaxPooling:updateGradInput(input, gradOutput) return self.gradInput end +-- for backward compat function SpatialAdaptiveMaxPooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) - self.indices:resize() - self.indices:storage():resize(0) + self:clearState() +end + +function SpatialAdaptiveMaxPooling:clearState() + if self.indices then + self.indices:set() + end + return parent.clearState(self) end diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua index b03377883..f3dbb587f 100644 --- a/SpatialBatchNormalization.lua +++ b/SpatialBatchNormalization.lua @@ -140,3 +140,14 @@ function BN:read(file, version) end end end + +function BN:clearState() + nn.utils.clear(self, { + 'buffer', + 'buffer2', + 'centered', + 'std', + 'normalized', + }) + return parent.clearState(self) +end diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index f3ed2a148..452ad82b4 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -171,3 +171,9 @@ function SpatialConvolution:__tostring__() end return s .. 
')' end + +function SpatialConvolution:clearState() + nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput') + return parent.clearState(self) +end + diff --git a/SpatialConvolutionLocal.lua b/SpatialConvolutionLocal.lua index 3fef70484..983dbc5a0 100644 --- a/SpatialConvolutionLocal.lua +++ b/SpatialConvolutionLocal.lua @@ -164,3 +164,8 @@ function SpatialConvolutionLocal:__tostring__() end return s .. ')' end + +function SpatialConvolutionLocal:clearState() + nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput') + return parent.clearState(self) +end diff --git a/SpatialConvolutionMM.lua b/SpatialConvolutionMM.lua index 46b813fbe..f523d8b03 100644 --- a/SpatialConvolutionMM.lua +++ b/SpatialConvolutionMM.lua @@ -23,7 +23,7 @@ function SpatialConvolutionMM:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, self.finput = torch.Tensor() self.fgradInput = torch.Tensor() - + self:reset() end @@ -137,3 +137,9 @@ function SpatialConvolutionMM:__tostring__() end return s .. ')' end + +function SpatialConvolutionMM:clearState() + nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput') + return parent.clearState(self) +end + diff --git a/SpatialCrossMapLRN.lua b/SpatialCrossMapLRN.lua index 17f79dfdb..674ca200f 100644 --- a/SpatialCrossMapLRN.lua +++ b/SpatialCrossMapLRN.lua @@ -127,3 +127,9 @@ function SpatialCrossMapLRN:updateGradInput(input, gradOutput) return self.gradInput end + + +function SpatialCrossMapLRN:clearState() + nn.utils.clear(self, 'scale', 'paddedRatio', 'accumRatio') + return parent.clearState(self) +end diff --git a/SpatialDivisiveNormalization.lua b/SpatialDivisiveNormalization.lua index 8395e231b..bdc7baca8 100644 --- a/SpatialDivisiveNormalization.lua +++ b/SpatialDivisiveNormalization.lua @@ -126,3 +126,11 @@ function SpatialDivisiveNormalization:updateGradInput(input, gradOutput) -- done return self.gradInput end + +function SpatialDivisiveNormalization:clearState() + if self.ones then self.ones:set() end + if self._coef then self._coef:set() end + self.meanestimator:clearState() + self.stdestimator:clearState() + return parent.clearState(self) +end diff --git a/SpatialDropout.lua b/SpatialDropout.lua index 37680d4de..35daa1829 100644 --- a/SpatialDropout.lua +++ b/SpatialDropout.lua @@ -45,3 +45,10 @@ end function SpatialDropout:__tostring__() return string.format('%s(%f)', torch.type(self), self.p) end + +function SpatialDropout:clearState() + if self.noise then + self.noise:set() + end + return Parent.clearState(self) +end diff --git a/SpatialFractionalMaxPooling.lua b/SpatialFractionalMaxPooling.lua index f0bfc73b0..b2267bad0 100644 --- a/SpatialFractionalMaxPooling.lua +++ b/SpatialFractionalMaxPooling.lua @@ -138,14 +138,14 @@ function SpatialFractionalMaxPooling:updateGradInput(input, gradOutput) return self.gradInput end +-- backward compat function SpatialFractionalMaxPooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) - self.indices:resize() - self.indices:storage():resize(0) - self.randomSamples = nil + self:clearState() +end + +function SpatialFractionalMaxPooling:clearState() + nn.utils.clear(self, 'indices', 'randomSamples') + return parent.clearState(self) end function SpatialFractionalMaxPooling:__tostring__() diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua index 1ede037bc..10142b8d9 100644 --- a/SpatialFullConvolution.lua +++ b/SpatialFullConvolution.lua @@ -111,3 +111,9 @@ function 
SpatialFullConvolution:__tostring__() end return s .. ')' end + +function SpatialFullConvolution:clearState() + nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput') + return parent.clearState(self) +end + diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index aa7251c8d..76b0d96c6 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -59,13 +59,9 @@ function SpatialMaxPooling:updateGradInput(input, gradOutput) return self.gradInput end +-- for backward compat function SpatialMaxPooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) - self.indices:resize() - self.indices:storage():resize(0) + self:clearState() end function SpatialMaxPooling:__tostring__() @@ -78,3 +74,10 @@ function SpatialMaxPooling:__tostring__() return s end + +function SpatialMaxPooling:clearState() + if self.indices then + self.indices:set() + end + return parent.clearState(self) +end diff --git a/SpatialMaxUnpooling.lua b/SpatialMaxUnpooling.lua index 401112e5a..219c5560a 100644 --- a/SpatialMaxUnpooling.lua +++ b/SpatialMaxUnpooling.lua @@ -33,10 +33,7 @@ function SpatialMaxUnpooling:updateGradInput(input, gradOutput) end function SpatialMaxUnpooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) + self:clearState() end function SpatialMaxUnpooling:__tostring__() diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua index 84d943ae9..e2da2c6a2 100644 --- a/SpatialSubtractiveNormalization.lua +++ b/SpatialSubtractiveNormalization.lua @@ -106,3 +106,10 @@ function SpatialSubtractiveNormalization:updateGradInput(input, gradOutput) -- done return self.gradInput end + +function SpatialSubtractiveNormalization:clearState() + if self.ones then self.ones:set() end + if self._coef then self._coef:set() end + self.meanestimator:clearState() + return parent.clearState(self) +end diff --git a/TemporalMaxPooling.lua b/TemporalMaxPooling.lua index b8fdd3e08..881eba21e 100644 --- a/TemporalMaxPooling.lua +++ b/TemporalMaxPooling.lua @@ -22,10 +22,10 @@ function TemporalMaxPooling:updateGradInput(input, gradOutput) end function TemporalMaxPooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) - self.indices:resize() - self.indices:storage():resize(0) + self:clearState() +end + +function TemporalMaxPooling:clearState() + if self.indices then self.indices:set() end + return parent.clearState(self) end diff --git a/VolumetricAveragePooling.lua b/VolumetricAveragePooling.lua index 8b5c6d5e0..f5adcd079 100644 --- a/VolumetricAveragePooling.lua +++ b/VolumetricAveragePooling.lua @@ -38,8 +38,5 @@ function VolumetricAveragePooling:updateGradInput(input, gradOutput) end function VolumetricAveragePooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) + return parent.clearState(self) end diff --git a/VolumetricConvolution.lua b/VolumetricConvolution.lua index cdf37ebe8..60dcbc2dd 100644 --- a/VolumetricConvolution.lua +++ b/VolumetricConvolution.lua @@ -25,7 +25,7 @@ function VolumetricConvolution:__init(nInputPlane, nOutputPlane, kT, kW, kH, dT, self.gradBias = torch.Tensor(nOutputPlane) -- temporary buffers for unfolding (CUDA) self.finput = torch.Tensor() - self.fgradInput = torch.Tensor() + self.fgradInput = torch.Tensor() self:reset() end @@ -175,3 
+175,8 @@ function VolumetricConvolution:type(type, tensorCache) self.fgradInput:set() return parent.type(self, type, tensorCache) end + +function VolumetricConvolution:clearState() + nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput') + return parent.clearState(self) +end diff --git a/VolumetricMaxPooling.lua b/VolumetricMaxPooling.lua index 0fca62c36..8e2deaf47 100644 --- a/VolumetricMaxPooling.lua +++ b/VolumetricMaxPooling.lua @@ -61,12 +61,12 @@ function VolumetricMaxPooling:updateGradInput(input, gradOutput) end function VolumetricMaxPooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) - self.indices:resize() - self.indices:storage():resize(0) + self:clearState() +end + +function VolumetricMaxPooling:clearState() + if self.indices then self.indices:set() end + return parent.clearState(self) end function VolumetricMaxPooling:read(file, version) diff --git a/VolumetricMaxUnpooling.lua b/VolumetricMaxUnpooling.lua index 57d0ee0d5..1bb04ed18 100644 --- a/VolumetricMaxUnpooling.lua +++ b/VolumetricMaxUnpooling.lua @@ -56,10 +56,7 @@ function VolumetricMaxUnpooling:updateGradInput(input, gradOutput) end function VolumetricMaxUnpooling:empty() - self.gradInput:resize() - self.gradInput:storage():resize(0) - self.output:resize() - self.output:storage():resize(0) + self:clearState() end function VolumetricMaxUnpooling:__tostring__() diff --git a/utils.lua b/utils.lua index 3fcc7caec..f2ca07cd3 100644 --- a/utils.lua +++ b/utils.lua @@ -162,4 +162,27 @@ function nn.utils.contiguousView(output, input, ...) return output end +-- go over specified fields and clear them. accepts +-- nn.utils.clearState(self, {'_buffer', '_buffer2'}) and +-- nn.utils.clearState(self, '_buffer', '_buffer2') +function nn.utils.clear(self, ...) + local arg = {...} + if #arg > 0 and type(arg[1]) == 'table' then + arg = arg[1] + end + local function clear(f) + if self[f] then + if torch.isTensor(self[f]) then + self[f]:set() + elseif type(self[f]) == 'table' then + self[f] = {} + else + self[f] = nil + end + end + end + for i,v in ipairs(arg) do clear(v) end + return self +end + table.unpack = table.unpack or unpack From c92fd21305273a27bc4ff240aa26619fe06430ef Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Wed, 3 Feb 2016 01:33:47 +0100 Subject: [PATCH 2/2] clearState doc --- doc/module.md | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/doc/module.md b/doc/module.md index f65eb51de..1339fc61d 100644 --- a/doc/module.md +++ b/doc/module.md @@ -117,8 +117,8 @@ situations. Keep in mind that, this function uses a simple trick to achieve its goal and it might not be valid for a custom module. -Also note that compared to accGradParameters(), the gradients are not retained -for future use. +Also note that compared to accGradParameters(), the gradients are not retained +for future use. ```lua function Module:accUpdateGradParameters(input, gradOutput, lr) @@ -154,12 +154,12 @@ Example: ```lua -- make an mlp -mlp1=nn.Sequential(); +mlp1=nn.Sequential(); mlp1:add(nn.Linear(100,10)); -- make a second mlp -mlp2=nn.Sequential(); -mlp2:add(nn.Linear(100,10)); +mlp2=nn.Sequential(); +mlp2:add(nn.Linear(100,10)); -- the second mlp shares the bias of the first mlp2:share(mlp1,'bias'); @@ -187,7 +187,7 @@ some shared parameters. 
Example: ```lua -- make an mlp -mlp1=nn.Sequential(); +mlp1=nn.Sequential(); mlp1:add(nn.Linear(100,10)); -- make a copy that shares the weights and biases @@ -208,7 +208,7 @@ This function converts all the parameters of a module to the given `type`. The `type` can be one of the types defined for [torch.Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md). -If tensors (or their storages) are shared between multiple modules in a +If tensors (or their storages) are shared between multiple modules in a network, this sharing will be preserved after type is called. To preserve sharing between multiple modules and/or tensors, use @@ -216,12 +216,12 @@ To preserve sharing between multiple modules and/or tensors, use ```lua -- make an mlp -mlp1=nn.Sequential(); +mlp1=nn.Sequential(); mlp1:add(nn.Linear(100,10)); -- make a second mlp -mlp2=nn.Sequential(); -mlp2:add(nn.Linear(100,10)); +mlp2=nn.Sequential(); +mlp2:add(nn.Linear(100,10)); -- the second mlp shares the bias of the first mlp2:share(mlp1,'bias'); @@ -254,7 +254,7 @@ a `Module`. The object pointer is _never_ supposed to change. However, its contents (including its size if it is a Tensor) are supposed to change. In general state variables are -[Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md). +[Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md). However, some special sub-classes like [table layers](table.md#nn.TableLayers) contain something else. Please, refer to each module specification for further information. @@ -269,7 +269,7 @@ This contains the output of the module, computed with the last call of #### gradInput #### This contains the gradients with respect to the inputs of the module, computed with the last call of -[updateGradInput(input, gradOutput)](#nn.Module.updateGradInput). +[updateGradInput(input, gradOutput)](#nn.Module.updateGradInput). ### Parameters and gradients w.r.t parameters ### @@ -353,9 +353,9 @@ end ### listModules() ### -List all Modules instances in a network. Returns a flattened list of modules, -including container modules (which will be listed first), self, and any other -component modules. +List all Modules instances in a network. Returns a flattened list of modules, +including container modules (which will be listed first), self, and any other +component modules. For example : ```lua @@ -392,3 +392,9 @@ nn.Linear(10 -> 20) nn.Tanh nn.ReLU ``` + +### clearState() ### + +Clears intermediate module states as `output`, `gradInput` and others. +Useful when serializing networks and running low on memory. Internally calls `set()` +on tensors so it does not break buffer sharing.
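
Usage sketch for the API introduced above (illustrative, not part of the patch; the model layout and file name are assumed): `clearState` releases `output`, `gradInput` and the per-module buffers listed in each `clearState` override before the network is written to disk, and because the tensors are released with `set()` rather than replaced, any buffer sharing configured beforehand is preserved.

```lua
require 'nn'

-- build and exercise a small network so the intermediate buffers get allocated
local model = nn.Sequential()
model:add(nn.Linear(10, 20))
model:add(nn.ReLU())
model:add(nn.Linear(20, 2))

local input = torch.randn(4, 10)
model:forward(input)
model:backward(input, torch.randn(4, 2))

-- drop output/gradInput and module-specific buffers, then serialize
model:clearState()
torch.save('model.t7', model) -- file name is illustrative
```

A custom module with private buffers can hook into the same mechanism by following the pattern used throughout this patch: pass the buffer field names to `nn.utils.clear` and defer to the parent class (a minimal sketch; `MyModule`, `parent`, `_buffer` and `_buffer2` stand for whatever the module actually defines):

```lua
function MyModule:clearState()
   -- set() the listed tensor fields, empty table fields, nil out the rest
   nn.utils.clear(self, '_buffer', '_buffer2')
   -- parent handles the standard output/gradInput fields
   return parent.clearState(self)
end
```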