Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
46 lines (36 sloc) 1.48 KB
local dt = require "decisiontree._env"
-- Ref: slide 17 in https://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf
-- equivalent to nn.Sigmoid() + nn.BCECriterion()
local LogitBoostCriterion, parent = torch.class("nn.LogitBoostCriterion", "nn.Criterion")
function LogitBoostCriterion:__init(sizeAverage)
parent.__init(self)
self.sizeAverage = sizeAverage
self.hessInput = self.gradInput.new()
self._output = torch.Tensor()
end
function LogitBoostCriterion:updateOutput(input, target)
input.nn.LogitBoostCriterion_updateOutput(input, target, self._output, self.sizeAverage)
self.output = self._output[1]
return self.output
end
function LogitBoostCriterion:updateGradInput(input, target)
input.nn.LogitBoostCriterion_updateGradInput(input, target, self.gradInput)
return self.gradInput
end
function LogitBoostCriterion:updateHessInput(input, target)
input.nn.LogitBoostCriterion_updateHessInput(input, target, self.hessInput)
return self.hessInput
end
-- returns gradInput and hessInput
function LogitBoostCriterion:backward2(input, target)
return self:updateGradInput(input, target), self:updateHessInput(input, target)
end
local gradWrapper = function(input, target, grad)
input.nn.LogitBoostCriterion_updateGradInput(input, target, grad)
end
local hessianWrapper = function(input, target, hessian)
input.nn.LogitBoostCriterion_updateHessInput(input, target, hessian)
end
function LogitBoostCriterion:getWrappers()
return gradWrapper, hessianWrapper
end
You can’t perform that action at this time.