In [1]:
require 'nn';
if nn.GuidedReLU ~= nil then
  return
end

local GuidedReLU, parent = torch.class('nn.GuidedReLU', 'nn.Module')

function GuidedReLU:__init()
   parent.__init(self)
   self.relu = nn.ReLU()
end

function GuidedReLU:updateOutput(input)
  self.output = self.relu:forward(input)
  return self.output
end

function GuidedReLU:updateGradInput(input, gradOutput)
  self.gradInput = self.relu:backward(input, gradOutput)
  local positive_mask = gradOutput:gt(0):double()
  self.gradInput = self.gradInput:cmul(positive_mask):cmul(gradOutput)
  return self.gradInput
end

In [2]:
input = torch.rand(3,3)-0.5
print(input)

 0.0444  0.2231  0.2875
-0.2768 -0.4780 -0.1057
-0.0444 -0.1649  0.3252
[torch.DoubleTensor of size 3x3]



In [3]:
print(nn.ReLU():forward(input))

 0.0444  0.2231  0.2875
 0.0000  0.0000  0.0000
 0.0000  0.0000  0.3252
[torch.DoubleTensor of size 3x3]



In [4]:
print(nn.GuidedReLU():forward(input))

 0.0444  0.2231  0.2875
 0.0000  0.0000  0.0000
 0.0000  0.0000  0.3252
[torch.DoubleTensor of size 3x3]



In [5]:
nextgrad = torch.rand(3,3)-0.5
print(nextgrad)

-0.0255  0.3835 -0.1734
-0.3290  0.1518 -0.3441
 0.4581 -0.2933  0.1737
[torch.DoubleTensor of size 3x3]



In [6]:
print(nn.ReLU():backward(input, nextgrad))

-0.0255  0.3835 -0.1734
 0.0000  0.0000  0.0000
 0.0000  0.0000  0.1737
[torch.DoubleTensor of size 3x3]



In [7]:
print(nn.GuidedReLU():backward(input, nextgrad))

 0.0000  0.1471  0.0000
-0.0000  0.0000 -0.0000
 0.0000 -0.0000  0.0302
[torch.DoubleTensor of size 3x3]

