forked from soumith/cuda-convnet2.torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialResponseNormalization.lua
54 lines (41 loc) · 2.19 KB
/
SpatialResponseNormalization.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
-- Handle to the ccn2 C bindings; the methods below look kernels up on it by name
-- ('convResponseNorm' and 'convResponseNormUndo').
local C = ccn2.C
-- Register the class as 'ccn2.SpatialResponseNormalization' deriving from 'nn.Module'.
-- The two values returned by torch.class are bound to the class table and to `parent`,
-- which the constructor uses to chain to the base-class __init.
local SpatialResponseNormalization, parent = torch.class('ccn2.SpatialResponseNormalization', 'nn.Module')
--- Construct a spatial response-normalization module.
-- @param size      (number, required) window size forwarded to the C kernels.
--                  NOTE(review): presumably the side of a square spatial window, since
--                  addScale is divided by size * size -- confirm against the kernel.
-- @param addScale  (number, optional) additive scale; defaults to 0.001. The value
--                  stored in self.addScale is addScale / (size * size).
-- @param powScale  (number, optional) exponent passed to the kernels; defaults to 0.75.
-- @param minDiv    (number, optional) passed to convResponseNorm by updateOutput;
--                  defaults to 1.0.
function SpatialResponseNormalization:__init(size, addScale, powScale, minDiv)
   parent.__init(self)

   -- `size` has no default; without this check a missing value only shows up as a
   -- bare "attempt to perform arithmetic on a nil value" from the division below.
   assert(type(size) == 'number',
          'ccn2.SpatialResponseNormalization: size must be a number')

   self.size = size
   -- dic['scale'] /= dic['size'] if self.norm_type == self.CROSSMAP_RESPONSE_NORM else dic['size']**2
   self.addScale = (addScale or 0.001) / (size * size)
   self.powScale = powScale or 0.75
   self.minDiv = minDiv or 1.0
   -- TODO: check layer.py:1333

   -- Empty tensors; updateOutput/updateGradInput resize output and gradInput, and
   -- denoms is handed to both kernels as a shared buffer.
   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()
   self.denoms = torch.Tensor()
end
--- Forward pass: run the 'convResponseNorm' kernel on `input`.
-- The input is flattened to a 2-D view (size1*size2*size3 rows, size4 columns) before
-- it is handed to the kernel, and the result is viewed back to the input's four sizes.
-- @param input  tensor with four dimensions that passes ccn2.typecheck and
--               ccn2.inputcheck (both raise on their own terms).
-- @return self.output, viewed with the same four sizes as `input`.
function SpatialResponseNormalization:updateOutput(input)
   ccn2.typecheck(input)
   ccn2.inputcheck(input)

   -- Read each dimension once; they are needed for the flatten, the kernel call and
   -- the final reshape.
   local d1, d2, d3, d4 = input:size(1), input:size(2), input:size(3), input:size(4)

   local inputC = input:view(d1 * d2 * d3, d4)
   self.output:resize(inputC:size())

   -- The fifth argument is the first input dimension.
   -- NOTE(review): presumably the number of feature maps -- confirm against the kernel.
   C['convResponseNorm'](cutorch.getState(), inputC:cdata(), self.denoms:cdata(), self.output:cdata(),
                         d1, self.size,
                         self.addScale, self.powScale, self.minDiv)

   -- Expose the result with the caller's 4-D layout; the view shares storage with the
   -- 2-D buffer the kernel wrote into.
   self.output = self.output:view(d1, d2, d3, d4)
   return self.output
end
--- Backward pass: run the 'convResponseNormUndo' kernel.
-- `input`, `gradOutput` and the cached self.output are flattened to 2-D views
-- (size1*size2*size3 rows, size4 columns) for the kernel; the gradient is then viewed
-- back to the four sizes of `input`.
-- @param input       the tensor given to updateOutput; four dimensions.
-- @param gradOutput  gradient w.r.t. the module output; four dimensions. Its sizes are
--                    also used to flatten self.output.
-- @return self.gradInput, viewed with the same four sizes as `input`.
function SpatialResponseNormalization:updateGradInput(input, gradOutput)
   ccn2.typecheck(input); ccn2.typecheck(gradOutput);
   ccn2.inputcheck(input); ccn2.inputcheck(gradOutput);

   local d1, d2, d3, d4 = input:size(1), input:size(2), input:size(3), input:size(4)
   local inputC = input:view(d1 * d2 * d3, d4)

   -- gradOutput and the cached output are flattened with gradOutput's own sizes.
   local gradRows = gradOutput:size(1) * gradOutput:size(2) * gradOutput:size(3)
   local gradCols = gradOutput:size(4)
   local gradOutputC = gradOutput:view(gradRows, gradCols)
   local outputC = self.output:view(gradRows, gradCols)

   self.gradInput:resize(inputC:size())

   -- NOTE(review): the trailing `0, 1` look like the scaleTargets / scaleOutput
   -- arguments of the underlying kernel (overwrite the target, unit output scale) --
   -- confirm against the C signature.
   C['convResponseNormUndo'](cutorch.getState(), gradOutputC:cdata(), self.denoms:cdata(), inputC:cdata(), outputC:cdata(),
                             self.gradInput:cdata(), d1, self.size,
                             self.addScale, self.powScale, 0, 1)

   self.gradInput = self.gradInput:view(d1, d2, d3, d4)
   return self.gradInput
end