11 changes: 11 additions & 0 deletions BatchNormalization.lua
@@ -154,3 +154,14 @@ function BN:accGradParameters(input, gradOutput, scale)
self.gradBias:add(scale, self.buffer)
end
end

function BN:clearState()
nn.utils.clear(self, {
'buffer',
'buffer2',
'centered',
'std',
'normalized',
})
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions Bilinear.lua
@@ -140,3 +140,8 @@ function Bilinear:__tostring__()
(self.bias == nil and ' without bias' or '')
)
end

function Bilinear:clearState()
if self.buff then self.buff:set() end
return parent.clearState(self)
end
21 changes: 14 additions & 7 deletions CMul.lua
@@ -116,13 +116,20 @@ end

function CMul:type(type, tensorCache)
if type then
self._input = nil
self._output = nil
self._weight = nil
self._gradWeight = nil
self._expand = nil
self._repeat = nil
self._sum = nil
self:clearState()
end
return parent.type(self, type, tensorCache)
end

function CMul:clearState()
nn.utils.clear(self, {
'_input',
'_output',
'_weight',
'_gradWeight',
'_expand',
'_repeat',
'_sum',
})
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions CMulTable.lua
@@ -48,3 +48,8 @@ function CMulTable:updateGradInput(input, gradOutput)

return self.gradInput
end

function CMulTable:clearState()
if self.tout then self.tout:set() end
return parent.clearState(self)
end
23 changes: 23 additions & 0 deletions Container.lua
@@ -75,3 +75,26 @@ function Container:parameters()
end
return w,gw
end

function Container:clearState()
-- don't call set because it might reset referenced tensors
local function clear(f)
if self[f] then
if torch.isTensor(self[f]) then
self[f] = self[f].new()
elseif type(self[f]) == 'table' then
self[f] = {}
else
self[f] = nil
end
end
end
clear('output')
clear('gradInput')
if self.modules then
for i,module in pairs(self.modules) do
module:clearState()
end
end
return self
end
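
Note on the comment above (illustrative, not part of the diff): tensor:set() empties a tensor in place, so every other reference to that same tensor object sees the reset, while rebinding the field to self[f].new() leaves previously shared tensors untouched. A minimal Lua sketch of the difference:

-- in-place reset: both references see the empty tensor
local a = torch.Tensor(3):fill(1)
local b = a      -- b is the same tensor object as a
a:set()          -- now b:nElement() == 0 as well

-- rebinding: the old tensor survives through the other reference
local c = torch.Tensor(3):fill(1)
local d = c
c = c.new()      -- d still holds its 3 elements
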
12 changes: 12 additions & 0 deletions Cosine.lua
@@ -161,3 +161,15 @@ function Cosine:type(type, tensorCache)
end
return parent.type(self, type, tensorCache)
end

function Cosine:clearState()
nn.utils.clear(self, {
'_input',
'_weight',
'_gradOutput',
'_sum',
'_inputNorm',
'_weightNorm',
})
return parent.clearState(self)
end
17 changes: 17 additions & 0 deletions CosineDistance.lua
@@ -73,6 +73,11 @@ function CosineDistance:updateGradInput(input, gradOutput)
not_batch = true
end

if #self.gradInput ~= 2 then
   self.gradInput[1] = self.gradInput[1] or v1.new()
   self.gradInput[2] = self.gradInput[2] or v1.new()
end

Review comment (Contributor): clearState is missing here; it should clean the following fields.
Not an important note: for consistency I'd do something like

self.gradInput[1] = self.gradInput[1] or v1.new()
self.gradInput[2] = self.gradInput[2] or v1.new()

local gw1 = self.gradInput[1]
local gw2 = self.gradInput[2]
gw1:resizeAs(v1):copy(v2)
@@ -97,3 +102,15 @@ function CosineDistance:updateGradInput(input, gradOutput)

return self.gradInput
end

function CosineDistance:clearState()
nn.utils.clear(self, {
'buffer',
'w1',
'w22',
'w',
'w32',
'ones',
})
return parent.clearState(self)
end
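
Why the guard added to updateGradInput above is needed (a reading of the patch, not stated in it): the base clearState resets a table-valued gradInput to an empty table, so the module must lazily re-create its two entries on the next backward pass. A hypothetical round trip the guard makes safe:

-- illustrative usage; sizes are arbitrary
local m = nn.CosineDistance()
local input = {torch.randn(5), torch.randn(5)}
m:forward(input)
m:clearState()                     -- gradInput becomes an empty table
m:forward(input)
m:backward(input, torch.ones(1))   -- guard rebuilds gradInput[1] and gradInput[2]
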
10 changes: 10 additions & 0 deletions DotProduct.lua
@@ -26,6 +26,11 @@ function DotProduct:updateGradInput(input, gradOutput)
local v2 = input[2]
local not_batch = false

if #self.gradInput ~= 2 then
   self.gradInput[1] = self.gradInput[1] or input[1].new()
   self.gradInput[2] = self.gradInput[2] or input[2].new()
end

Review comment (Contributor): Same comment as the one from CosineDistance. Also, the buffer should be cleaned here too.

if v1:dim() == 1 then
v1 = v1:view(1,-1)
v2 = v2:view(1,-1)
@@ -49,3 +54,8 @@ function DotProduct:updateGradInput(input, gradOutput)

return self.gradInput
end

function DotProduct:clearState()
if self.buffer then self.buffer:set() end
return parent.clearState(self)
end
8 changes: 8 additions & 0 deletions Dropout.lua
@@ -64,3 +64,11 @@ end
function Dropout:__tostring__()
return string.format('%s(%f)', torch.type(self), self.p)
end


function Dropout:clearState()
if self.noise then
self.noise:set()
end
return Parent.clearState(self)
end
29 changes: 18 additions & 11 deletions Euclidean.lua
@@ -174,17 +174,24 @@ end
function Euclidean:type(type, tensorCache)
if type then
-- prevent premature memory allocations
self._input = nil
self._output = nil
self._gradOutput = nil
self._weight = nil
self._div = nil
self._sum = nil
self._expand = nil
self._expand2 = nil
self._expand3 = nil
self._repeat = nil
self._repeat2 = nil
self:clearState()
end
return parent.type(self, type, tensorCache)
end

function Euclidean:clearState()
nn.utils.clear(self, {
'_input',
'_output',
'_gradOutput',
'_weight',
'_div',
'_sum',
'_expand',
'_expand2',
'_expand3',
'_repeat',
'_repeat2',
})
return parent.clearState(self)
end
6 changes: 5 additions & 1 deletion FlattenTable.lua
@@ -97,6 +97,10 @@ end
function FlattenTable:type(type, tensorCache)
-- This function just stores references so we don't need to do any type
-- conversions. Just force the tables to be empty.
self.output = {}
self:clearState()
end

function FlattenTable:clearState()
self.input_map = {}
return parent.clearState(self)
end
2 changes: 1 addition & 1 deletion GradientReversal.lua
@@ -1,7 +1,7 @@
local GradientReversal = torch.class('nn.GradientReversal', 'nn.Module')

function GradientReversal:updateOutput(input)
self.output = input
self.output:set(input)
return self.output
end

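
Context for this one-line change (my reading of the patch): with self.output = input, the module's output was the caller's tensor itself, so the new clearState, which empties output in place, would have wiped the user's input too. self.output:set(input) instead keeps self.output a distinct tensor object that shares the input's storage, so clearing it is harmless:

-- illustrative only
local m = nn.GradientReversal()
local x = torch.randn(4)
m:forward(x)      -- m.output shares x's storage but is a separate object
m:clearState()    -- empties m.output; x is left intact
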
18 changes: 18 additions & 0 deletions Identity.lua
@@ -10,3 +10,21 @@ function Identity:updateGradInput(input, gradOutput)
self.gradInput = gradOutput
return self.gradInput
end

function Identity:clearState()
-- don't call set because it might reset referenced tensors
local function clear(f)
if self[f] then
if torch.isTensor(self[f]) then
self[f] = self[f].new()
elseif type(self[f]) == 'table' then
self[f] = {}
else
self[f] = nil
end
end
end
clear('output')
clear('gradInput')
return self
end
1 change: 1 addition & 0 deletions Jacobian.lua
@@ -305,6 +305,7 @@ function nn.Jacobian.testIO(module,input, minval, maxval)
-- write module
local filename = os.tmpname()
local f = torch.DiskFile(filename, 'w'):binary()
module:clearState()
f:writeObject(module)
f:close()
-- read module
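
The added clearState call reflects the intended workflow for the new API: strip transient buffers before serializing so the saved module is small and still loads cleanly. A typical (hypothetical) use outside this test:

-- illustrative; layer sizes and the filename are arbitrary
local model = nn.Sequential()
   :add(nn.Linear(10, 10))
   :add(nn.ReLU())
model:forward(torch.randn(10))   -- populates output/gradInput buffers
model:clearState()               -- drops them before saving
torch.save('model.t7', model)
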
5 changes: 5 additions & 0 deletions L1Cost.lua
@@ -23,3 +23,8 @@ function L1Cost:updateGradInput(input)
)
return self.gradInput
end

function L1Cost:clearState()
if self.output_tensor then self.output_tensor:set() end
return parent.clearState(self)
end
4 changes: 4 additions & 0 deletions L1Penalty.lua
@@ -41,3 +41,7 @@ function L1Penalty:updateGradInput(input, gradOutput)
return self.gradInput
end

function L1Penalty:clearState()
if self.loss then self.loss:set() end
return parent.clearState(self)
end
4 changes: 4 additions & 0 deletions Linear.lua
@@ -95,6 +95,10 @@ end
-- we do not need to accumulate parameters when sharing
Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters

function Linear:clearState()
if self.addBuffer then self.addBuffer:set() end
return parent.clearState(self)
end

function Linear:__tostring__()
return torch.type(self) ..
6 changes: 6 additions & 0 deletions LogSigmoid.lua
@@ -23,3 +23,9 @@ function LogSigmoid:updateGradInput(input, gradOutput)
)
return self.gradInput
end

function LogSigmoid:clearState()
if self.buffer then self.buffer:set() end
return parent.clearState(self)
end

4 changes: 4 additions & 0 deletions LookupTable.lua
@@ -100,5 +100,9 @@ function LookupTable:type(type, tensorCache)
return self
end

function LookupTable:clearState()
return self
end

-- we do not need to accumulate parameters when sharing
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters
5 changes: 5 additions & 0 deletions Max.lua
@@ -63,3 +63,8 @@ function Max:type(type, tensorCache)
end
return self
end

function Max:clearState()
nn.utils.clear(self, '_indices', '_output')
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions Min.lua
@@ -63,3 +63,8 @@ function Min:type(type, tensorCache)
end
return self
end

function Min:clearState()
nn.utils.clear(self, '_indices', '_output')
return parent.clearState(self)
end
12 changes: 12 additions & 0 deletions MixtureTable.lua
@@ -156,3 +156,15 @@ function MixtureTable:type(type, tensorCache)
self._expertView2 = nil
return parent.type(self, type, tensorCache)
end

function MixtureTable:clearState()
nn.utils.clear(self, {
'_gaterView',
'_expert',
'_expertView',
'_sum',
'_expert2',
'_expertView2',
})
return parent.clearState(self)
end
4 changes: 4 additions & 0 deletions Module.lua
@@ -364,3 +364,7 @@ function Module:listModules()
end
return modules
end

function Module:clearState()
return nn.utils.clear(self, 'output', 'gradInput')
end
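
This base implementation covers the two fields every module owns; the overrides elsewhere in this diff clear their module-specific scratch tensors first and then delegate here. As the usages across the patch show, nn.utils.clear accepts the field names either as varargs or as a single table. A sketch of the override pattern (hypothetical module and field names):

-- hypothetical example following the pattern used across this PR
function MyModule:clearState()
   nn.utils.clear(self, 'buffer', '_scratch')  -- module-specific temporaries
   return parent.clearState(self)              -- then clear output/gradInput
end
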
13 changes: 13 additions & 0 deletions Normalize.lua
@@ -140,3 +140,16 @@ function Normalize:type(type, tensorCache)
end
return self
end

function Normalize:clearState()
nn.utils.clear(self, {
'_output',
'_indices',
'_gradInput',
'buffer',
'norm',
'normp',
'cross',
})
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions PReLU.lua
@@ -45,3 +45,8 @@ function PReLU:accGradParameters(input, gradOutput, scale)
)
return self.gradWeight
end

function PReLU:clearState()
nn.utils.clear(self, 'gradWeightBuf', 'gradWeightBuf2')
return parent.clearState(self)
end
6 changes: 6 additions & 0 deletions PairwiseDistance.lua
@@ -10,6 +10,7 @@ function PairwiseDistance:__init(p)
end

function PairwiseDistance:updateOutput(input)
self.output:resize(1)
if input[1]:dim() == 1 then
self.output:resize(1)
self.output[1]=input[1]:dist(input[2],self.norm)
@@ -83,3 +84,8 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end

function PairwiseDistance:clearState()
nn.utils.clear(self, 'diff', 'outExpand', 'grad', 'ones')
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions RReLU.lua
@@ -43,3 +43,8 @@ end
function RReLU:__tostring__()
return string.format('%s (l:%f, u:%f)', torch.type(self), self.lower, self.upper)
end

function RReLU:clearState()
if self.noise then self.noise:set() end
return parent.clearState(self)
end