This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Commit

initial commit for decisiontree
Nicholas Leonard committed May 31, 2017
0 parents commit 1b05061
Showing 43 changed files with 6,406 additions and 0 deletions.
48 changes: 48 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,48 @@
SET(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR})

CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.6)

FIND_PACKAGE(Torch REQUIRED)

SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE

SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}")
SET(src
   init.c
   hash_map.c
)
SET(luasrc
   _env.lua
   benchmark.lua
   CartNode.lua
   CartTrainer.lua
   CartTree.lua
   DataSet.lua
   DecisionForest.lua
   DecisionForestTrainer.lua
   DecisionTree.lua
   DFD.lua
   GiniState.lua
   GradientBoostState.lua
   GradientBoostTrainer.lua
   init.lua
   LogitBoostCriterion.lua
   math.lua
   MSECriterion.lua
   RandomForestTrainer.lua
   Sparse2Dense.lua
   SparseTensor.lua
   test.lua
   TreeState.lua
   utils.lua
   WorkPool.lua
)

ADD_TORCH_PACKAGE(decisiontree "${src}" "${luasrc}" "A decision tree library, for Torch")

TARGET_LINK_LIBRARIES(decisiontree luaT TH)

SET_TARGET_PROPERTIES(decisiontree_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH")

INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/decisiontree")
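Once installed into ${Torch_INSTALL_LUA_PATH_SUBDIR}/decisiontree, the package should load like any other Torch package. A minimal smoke test, assuming init.lua (listed above, not shown in this excerpt) returns the package table:

-- minimal smoke test; assumes the package was built and installed to Torch's Lua path
local dt = require 'decisiontree'
assert(dt.CartNode and dt.CartTrainer and dt.CartTree, "decisiontree classes not registered")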
42 changes: 42 additions & 0 deletions CartNode.lua
@@ -0,0 +1,42 @@
local dt = require 'decisiontree._env'
local CartNode = torch.class("dt.CartNode", dt)

function CartNode:__init(nodeId, leftChild, rightChild, splitFeatureId, splitFeatureValue, score, splitGain)
   self.nodeId = nodeId or 0
   self.leftChild = leftChild
   self.rightChild = rightChild
   self.splitFeatureId = splitFeatureId or -1
   self.splitFeatureValue = splitFeatureValue or 0
   self.score = score or 0
   self.splitGain = splitGain
end

function CartNode:__tostring__()
   return self:recursivetostring()
end

function CartNode:recursivetostring(indent)
   indent = indent or ' '

   -- Is this a leaf node?
   local res = ''
   if not (self.leftChild or self.rightChild) then
      res = res .. self.score .. '\n'
   else
      -- Print the split criterion
      res = res .. 'input[' .. self.splitFeatureId .. '] <' .. self.splitFeatureValue .. '?\n'

      -- Print the branches
      if self.leftChild then
         res = res .. indent .. 'True->' .. self.leftChild:recursivetostring(indent .. ' ')
      end
      if self.rightChild then
         res = res .. indent .. 'False->' .. self.rightChild:recursivetostring(indent .. ' ')
      end
   end
   return res
end

function CartNode:clone()
   return CartNode(self.nodeId, self.leftChild, self.rightChild, self.splitFeatureId, self.splitFeatureValue, self.score, self.splitGain)
end
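For illustration, a minimal sketch of assembling and printing a tiny tree by hand. The feature id, threshold and leaf scores are hypothetical, and loading the package as dt assumes init.lua (not shown here) returns the package table:

local dt = require 'decisiontree'

-- hypothetical stump: split on feature 3 at value 0.5, with two leaf scores
local leftLeaf  = dt.CartNode(1, nil, nil, -1, 0, -1.2)
local rightLeaf = dt.CartNode(2, nil, nil, -1, 0, 0.8)
local root      = dt.CartNode(0, leftLeaf, rightLeaf, 3, 0.5)

print(root:recursivetostring())
-- input[3] <0.5?
--  True->-1.2
--  False->0.8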
180 changes: 180 additions & 0 deletions CartTrainer.lua
@@ -0,0 +1,180 @@
local dt = require "decisiontree._env"
local _ = require "moses"

local CartTrainer = torch.class("dt.CartTrainer", dt)

-- Generic CART trainer
function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes)
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   self.dataset = dataset
   self.minLeafSize = assert(minLeafSize) -- min examples per leaf
   self.maxLeafNodes = assert(maxLeafNodes) -- max leaf nodes in tree

   -- by default, single thread
   self.parallelMode = 'singlethread'
end

function CartTrainer:train(rootTreeState, activeFeatures)
   assert(torch.isTypeOf(rootTreeState, 'dt.TreeState'))
   assert(torch.isTensor(activeFeatures))
   local root = dt.CartNode()
   root.id = 0
   root.score = rootTreeState:score(self.dataset)

   local nleaf = 1

   -- TODO : nodeparallel: parallelize here. The queue is a workqueue.
   local queue = {}
   table.insert(queue, 1, {cartNode=root, treeState=rootTreeState})

   while #queue > 0 and nleaf < self.maxLeafNodes do
      local treeGrowerArgs = table.remove(queue, #queue)
      local currentTreeState = treeGrowerArgs.treeState

      -- Note: if minLeafSize = 1 and maxLeafNodes = inf, then each example will be its own leaf...
      if self:hasEnoughTrainingExamplesToSplit(currentTreeState.exampleIds:size(1)) then
         nleaf = self:processNode(nleaf, queue, treeGrowerArgs.cartNode, currentTreeState, activeFeatures)
      end
   end

   -- CartTree with random branching (when feature is missing)
   local branchleft = function() return math.random() < 0.5 end
   return dt.CartTree(root, branchleft), nleaf
end

function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures)
   local bestSplit
   if self.parallelMode == 'singlethread' then
      bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures)
   elseif self.parallelMode == 'featureparallel' then
      bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
   else
      error("Unrecognized parallel mode: " .. self.parallelMode)
   end

   if bestSplit then
      local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset)
      assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!")
      self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf)

      table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState})
      table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState})

      return nleaf + 1
   end

   return nleaf
end

function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures)
   local timer = torch.Timer()
   local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1)

   if bestSplit then
      assert(torch.type(bestSplit) == 'table')
   end

   if dt.PROFILE then
      print("findBestSplitForAllFeatures time="..timer:time().real)
   end
   return bestSplit
end

-- Updates the parentNode with the bestSplit information by creating the left/right child nodes.
function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf)
   assert(torch.isTypeOf(parentNode, 'dt.CartNode'))
   assert(torch.type(bestSplit) == 'table')
   assert(torch.isTypeOf(leftState, 'dt.TreeState'))
   assert(torch.isTypeOf(rightState, 'dt.TreeState'))
   assert(torch.type(nleaf) == 'number')

   local leftChild = dt.CartNode()
   leftChild.score = leftState:score(self.dataset)
   leftChild.nodeId = 2 * nleaf - 1

   local rightChild = dt.CartNode()
   rightChild.score = rightState:score(self.dataset)
   rightChild.nodeId = 2 * nleaf

   parentNode.splitFeatureId = bestSplit.splitId
   parentNode.splitFeatureValue = bestSplit.splitValue
   parentNode.leftChild = leftChild
   parentNode.rightChild = rightChild
   parentNode.splitGain = bestSplit.splitGain
end

-- We need at least 2 * N examples in the parent to guarantee >= N examples per child
function CartTrainer:hasEnoughTrainingExamplesToSplit(count)
   return count >= 2 * self.minLeafSize
end

-- call before training to enable feature-parallelization
function CartTrainer:featureParallel(workPool)
   assert(self.parallelMode == 'singlethread', self.parallelMode)
   self.parallelMode = 'featureparallel'
   self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool
   assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))

   -- this deletes all SparseTensor hash maps so that they aren't serialized
   self.dataset:deleteIndex()

   -- require the dt package
   self.workPool:update('require', {libname='decisiontree',varname='dt'})
   -- setup worker store (each worker will have its own copy)
   local store = {
      dataset=self.dataset,
      minLeafSize=self.minLeafSize
   }
   self.workPool:update('storeKeysValues', store)
end

-- feature parallel
function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
   local timer = torch.Timer()
   local bestSplit
   if treeState.findBestSplitFP then
      bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread)
   end

   if not bestSplit then
      for i=1,self.workPool.nThread do
         -- upvalues
         local treeState = treeState
         local shardId = i
         local nShard = self.workPool.nThread
         local featureIds = activeFeatures
         -- closure
         local task = function(store)
            assert(store.dataset)
            assert(store.minLeafSize)
            if treeState.threadInitialize then
               treeState:threadInitialize()
            end

            local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard)
            return bestSplit
         end

         self.workPool:writeup('execute', task)
      end

      for i=1,self.workPool.nThread do
         local taskname, candidateSplit = self.workPool:read()
         assert(taskname == 'execute')
         if candidateSplit then
            if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then
               bestSplit = candidateSplit
            end
         end
      end
   end

   if bestSplit then
      assert(torch.type(bestSplit) == 'table')
   end

   if dt.PROFILE then
      print("findBestSplitForAllFeaturesFP time="..timer:time().real)
   end
   return bestSplit
end
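A rough usage sketch of the trainer follows. It is not part of this commit: the DataSet and GiniState constructor arguments are assumptions (those files are not shown in this excerpt), while the CartTrainer, featureParallel and train calls match the signatures above.

local dt = require 'decisiontree'

local nExample, nFeature = 100, 10
local input = torch.randn(nExample, nFeature)            -- dense inputs
local target = torch.LongTensor(nExample):random(0, 1)   -- binary targets
local dataset = dt.DataSet(input, target)                -- assumed constructor signature

local minLeafSize, maxLeafNodes = 5, 16
local trainer = dt.CartTrainer(dataset, minLeafSize, maxLeafNodes)

-- optional: enable feature-parallel training with 2 worker threads
-- trainer:featureParallel(2)

-- assumed: GiniState (a TreeState subclass) is constructed from the tensor of example ids
local rootState = dt.GiniState(torch.LongTensor():range(1, nExample))
local activeFeatures = torch.LongTensor():range(1, nFeature)
local cartTree, nleaf = trainer:train(rootState, activeFeatures)
print(nleaf)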
90 changes: 90 additions & 0 deletions CartTree.lua
@@ -0,0 +1,90 @@
local _ = require "moses"
local dt = require 'decisiontree._env'

-- CART (Classification And Regression Tree).
-- By default, an example is branched to the left when the splitting feature is missing.
local CartTree = torch.class("dt.CartTree", "dt.DecisionTree", dt)

function CartTree:__init(root, branchleft)
   assert(torch.isTypeOf(root, 'dt.CartNode'))
   self.root = root
   self.branchleft = branchleft or function() return true end
end

-- TODO optimize this
function CartTree:score(input, stack, optimized)
   if optimized == true and stack == nil and torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then
      return input.nn.CartTreeFastScore(input, self.root, input.new())
   end
   return self:recursivescore(self.root, input, stack)
end

-- Continuous: if input[node.splitFeatureId] < node.splitFeatureValue then leftNode else rightNode
-- Binary: if input[node.splitFeatureId] == 0 then leftNode else rightNode
-- when stack is provided, it is returned as the third argument containing the stack of nodes from root to leaf
function CartTree:recursivescore(node, input, stack)
   assert(torch.isTypeOf(node, 'dt.CartNode'))

   if stack then
      stack = torch.type(stack) == 'table' and stack or {}
      table.insert(stack, node)
   end

   if not (node.leftChild or node.rightChild) then
      return node.score, node.nodeId, stack
   elseif not node.leftChild then
      return self:recursivescore(node.rightChild, input, stack)
   elseif not node.rightChild then
      return self:recursivescore(node.leftChild, input, stack)
   end

   local splitId = node.splitFeatureId
   local splitVal = node.splitFeatureValue

   if input[splitId] then -- the key is present
      local featureVal = input[splitId]
      local nextNode = featureVal < splitVal and node.leftChild or node.rightChild
      return self:recursivescore(nextNode, input, stack)
   end

   -- the feature is missing: branch according to branchleft (left by default)
   local nextNode = self.branchleft() and node.leftChild or node.rightChild
   return self:recursivescore(nextNode, input, stack)
end

function CartTree:__tostring__()
   return self.root:recursivetostring()
end

-- expects a stack returned by score
function CartTree:stackToString(stack, input)
   assert(torch.type(stack) == 'table')
   assert(torch.isTypeOf(stack[1], 'dt.CartNode'))

   local res = 'Stack nodes from root to leaf\n'
   for i,node in ipairs(stack) do
      if not (node.leftChild or node.rightChild) then
         res = res .. "score="..node.score .. '\n'
      else
         local istr = ''
         if input then
            istr = '=' .. (input[node.splitFeatureId] or 'nil')
         end
         res = res .. 'input[' .. node.splitFeatureId .. ']' .. istr .. ' < ' .. node.splitFeatureValue .. ' ? '
         res = res .. '(' .. ((node.leftChild and node.rightChild) and 'LR' or node.leftChild and 'L' or node.rightChild and 'R' or 'WAT?') .. ') '
         if node.leftChild == stack[i+1] then
            res = res .. 'Left\n'
         elseif node.rightChild == stack[i+1] then
            res = res .. 'Right\n'
         else
            error("stackToString error")
         end
      end
   end
   return res .. #stack .. " nodes"
end

function CartTree:clone()
   return CartTree(self.root:clone(), self.branchleft)
end
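A short scoring sketch against the same kind of hand-built stump as in the CartNode example (feature id, threshold and scores are again hypothetical); the table-style input and the optional stack follow the recursivescore contract above:

local dt = require 'decisiontree'

local leftLeaf  = dt.CartNode(1, nil, nil, -1, 0, -1.2)
local rightLeaf = dt.CartNode(2, nil, nil, -1, 0, 0.8)
local tree = dt.CartTree(dt.CartNode(0, leftLeaf, rightLeaf, 3, 0.5))

-- table input keyed by feature id: feature 3 is 0.2 < 0.5, so the example goes left
local score, nodeId, stack = tree:score({[3] = 0.2}, {})
print(score, nodeId)                          -- -1.2   1
print(tree:stackToString(stack, {[3] = 0.2})) -- root split, then the left leaf

-- when feature 3 is missing, the default branchleft sends the example left
local missing = tree:score({})
print(missing)                                -- -1.2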
