This repository has been archived by the owner on Nov 1, 2021. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Nicholas Leonard
committed
May 31, 2017
0 parents
commit 1b05061
Showing
43 changed files
with
6,406 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
SET(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}) | ||
|
||
CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) | ||
CMAKE_POLICY(VERSION 2.6) | ||
|
||
FIND_PACKAGE(Torch REQUIRED) | ||
|
||
SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE | ||
|
||
SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}") | ||
SET(src | ||
init.c | ||
hash_map.c | ||
) | ||
SET(luasrc | ||
_env.lua | ||
benchmark.lua | ||
CartNode.lua | ||
CartTrainer.lua | ||
CartTree.lua | ||
DataSet.lua | ||
DecisionForest.lua | ||
DecisionForestTrainer.lua | ||
DecisionTree.lua | ||
DFD.lua | ||
GiniState.lua | ||
GradientBoostState.lua | ||
GradientBoostTrainer.lua | ||
init.lua | ||
LogitBoostCriterion.lua | ||
math.lua | ||
MSECriterion.lua | ||
RandomForestTrainer.lua | ||
Sparse2Dense.lua | ||
SparseTensor.lua | ||
test.lua | ||
TreeState.lua | ||
utils.lua | ||
WorkPool.lua | ||
) | ||
|
||
ADD_TORCH_PACKAGE(decisiontree "${src}" "${luasrc}" "A decision tree library, for Torch") | ||
|
||
TARGET_LINK_LIBRARIES(decisiontree luaT TH) | ||
|
||
SET_TARGET_PROPERTIES(decisiontree_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH") | ||
|
||
INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/decisiontree") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
local dt = require 'decisiontree._env' | ||
local CartNode = torch.class("dt.CartNode", dt) | ||
|
||
function CartNode:__init(nodeId, leftChild, rightChild, splitFeatureId, splitFeatureValue, score, splitGain) | ||
self.nodeId = nodeId or 0 | ||
self.leftChild = leftChild | ||
self.rightChild = rightChild | ||
self.splitFeatureId = splitFeatureId or -1 | ||
self.splitFeatureValue = splitFeatureValue or 0 | ||
self.score = score or 0 | ||
self.splitGain = splitGain | ||
end | ||
|
||
function CartNode:__tostring__() | ||
return self:recursivetostring() | ||
end | ||
|
||
function CartNode:recursivetostring(indent) | ||
indent = indent or ' ' | ||
|
||
-- Is this a leaf node? | ||
local res = '' | ||
if not (self.leftChild or self.rightChild) then | ||
res = res .. self.score .. '\n' | ||
else | ||
-- Print the criteria | ||
res = res .. 'input[' .. self.splitFeatureId .. '] <' .. self.splitFeatureValue .. '?\n' | ||
|
||
-- Print the branches | ||
if self.leftChild then | ||
res = res .. indent .. 'True->' .. self.leftChild:recursivetostring(indent .. ' ') | ||
end | ||
if self.rightChild then | ||
res = res .. indent .. 'False->' .. self.rightChild:recursivetostring(indent .. ' ') | ||
end | ||
end | ||
return res | ||
end | ||
|
||
function CartNode:clone() | ||
return CartNode(self.nodeId, self.leftChild, self.rightChild, self.splitFeatureId, self.splitFeatureValue, self.score, self.splitGain) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
local dt = require "decisiontree._env" | ||
local _ = require "moses" | ||
|
||
local CartTrainer = torch.class("dt.CartTrainer", dt) | ||
|
||
-- Generic CART trainer | ||
function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes) | ||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | ||
self.dataset = dataset | ||
self.minLeafSize = assert(minLeafSize) -- min examples per leaf | ||
self.maxLeafNodes = assert(maxLeafNodes) -- max leaf nodes in tree | ||
|
||
-- by default, single thread | ||
self.parallelMode = 'singlethread' | ||
end | ||
|
||
function CartTrainer:train(rootTreeState, activeFeatures) | ||
assert(torch.isTypeOf(rootTreeState, 'dt.TreeState')) | ||
assert(torch.isTensor(activeFeatures)) | ||
local root = dt.CartNode() | ||
root.id = 0 | ||
root.score = rootTreeState:score(self.dataset) | ||
|
||
local nleaf = 1 | ||
|
||
-- TODO : nodeparallel: parallelize here. The queue is a workqueue. | ||
local queue = {} | ||
table.insert(queue, 1, {cartNode=root, treeState=rootTreeState}) | ||
|
||
while #queue > 0 and nleaf < self.maxLeafNodes do | ||
local treeGrowerArgs = table.remove(queue, #queue) | ||
local currentTreeState = treeGrowerArgs.treeState | ||
|
||
-- Note: if minLeafSize = 1 and maxLeafNode = inf, then each example will be its own leaf... | ||
if self:hasEnoughTrainingExamplesToSplit(currentTreeState.exampleIds:size(1)) then | ||
nleaf = self:processNode(nleaf, queue, treeGrowerArgs.cartNode, currentTreeState, activeFeatures) | ||
end | ||
end | ||
|
||
-- CartTree with random branching (when feature is missing) | ||
local branchleft = function() return math.random() < 0.5 end | ||
return dt.CartTree(root, branchleft), nleaf | ||
end | ||
|
||
function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures) | ||
local bestSplit | ||
if self.parallelMode == 'singlethread' then | ||
bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures) | ||
elseif self.parallelMode == 'featureparallel' then | ||
bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures) | ||
else | ||
error("Unrecognized parallel mode: " .. self.parallelMode) | ||
end | ||
|
||
if bestSplit then | ||
local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset) | ||
assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!") | ||
self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf) | ||
|
||
table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState}) | ||
table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState}) | ||
|
||
return nleaf + 1 | ||
end | ||
|
||
return nleaf | ||
end | ||
|
||
function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures) | ||
local timer = torch.Timer() | ||
local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1) | ||
|
||
if bestSplit then | ||
assert(torch.type(bestSplit) == 'table') | ||
end | ||
|
||
if dt.PROFILE then | ||
print("findBestSplitForAllFeatures time="..timer:time().real) | ||
end | ||
return bestSplit | ||
end | ||
|
||
-- Updates the parentNode with the bestSplit information by creates left/right child Nodes. | ||
function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf) | ||
assert(torch.isTypeOf(parentNode, 'dt.CartNode')) | ||
assert(torch.type(bestSplit) == 'table') | ||
assert(torch.isTypeOf(leftState, 'dt.TreeState')) | ||
assert(torch.isTypeOf(rightState, 'dt.TreeState')) | ||
assert(torch.type(nleaf) == 'number') | ||
|
||
local leftChild = dt.CartNode() | ||
leftChild.score = leftState:score(self.dataset) | ||
leftChild.nodeId = 2 * nleaf - 1 | ||
|
||
local rightChild = dt.CartNode() | ||
rightChild.score = rightState:score(self.dataset) | ||
rightChild.nodeId = 2 * nleaf | ||
|
||
parentNode.splitFeatureId = bestSplit.splitId | ||
parentNode.splitFeatureValue = bestSplit.splitValue | ||
parentNode.leftChild = leftChild | ||
parentNode.rightChild = rightChild | ||
parentNode.splitGain = bestSplit.splitGain | ||
end | ||
|
||
-- We minimally need 2 * N examples in the parent to satisfy >= N examples per child | ||
function CartTrainer:hasEnoughTrainingExamplesToSplit(count) | ||
return count >= 2 * self.minLeafSize | ||
end | ||
|
||
-- call before training to enable feature-parallelization | ||
function CartTrainer:featureParallel(workPool) | ||
assert(self.parallelMode == 'singlethread', self.parallelMode) | ||
self.parallelMode = 'featureparallel' | ||
self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool | ||
assert(torch.isTypeOf(self.workPool, 'dt.WorkPool')) | ||
|
||
-- this deletes all SparseTensor hash maps so that they aren't serialized | ||
self.dataset:deleteIndex() | ||
|
||
-- require the dt package | ||
self.workPool:update('require', {libname='decisiontree',varname='dt'}) | ||
-- setup worker store (each worker will have its own copy) | ||
local store = { | ||
dataset=self.dataset, | ||
minLeafSize=self.minLeafSize | ||
} | ||
self.workPool:update('storeKeysValues', store) | ||
end | ||
|
||
-- feature parallel | ||
function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures) | ||
local timer = torch.Timer() | ||
local bestSplit | ||
if treeState.findBestSplitFP then | ||
bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread) | ||
end | ||
|
||
if not bestSplit then | ||
for i=1,self.workPool.nThread do | ||
-- upvalues | ||
local treeState = treeState | ||
local shardId = i | ||
local nShard = self.workPool.nThread | ||
local featureIds = activeFeatures | ||
-- closure | ||
local task = function(store) | ||
assert(store.dataset) | ||
assert(store.minLeafSize) | ||
if treeState.threadInitialize then | ||
treeState:threadInitialize() | ||
end | ||
|
||
local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard) | ||
return bestSplit | ||
end | ||
|
||
self.workPool:writeup('execute', task) | ||
end | ||
|
||
for i=1,self.workPool.nThread do | ||
local taskname, candidateSplit = self.workPool:read() | ||
assert(taskname == 'execute') | ||
if candidateSplit then | ||
if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then | ||
bestSplit = candidateSplit | ||
end | ||
end | ||
end | ||
end | ||
|
||
if bestSplit then | ||
assert(torch.type(bestSplit) == 'table') | ||
end | ||
|
||
if dt.PROFILE then | ||
print("findBestSplitForAllFeaturesFP time="..timer:time().real) | ||
end | ||
return bestSplit | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
local _ = require "moses" | ||
local dt = require 'decisiontree._env' | ||
|
||
-- CART (classification-regression decision tree). | ||
-- The example is always branched to the left when the splitting feature is missing. | ||
local CartTree = torch.class("dt.CartTree", "dt.DecisionTree", dt) | ||
|
||
function CartTree:__init(root, branchleft) | ||
assert(torch.isTypeOf(root, 'dt.CartNode')) | ||
self.root = root | ||
self.branchleft = branchleft or function() return true end | ||
end | ||
|
||
-- TODO optimize this | ||
function CartTree:score(input, stack, optimized) | ||
if optimized == true and stack == nil and torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then | ||
return input.nn.CartTreeFastScore(input, self.root, input.new()) | ||
end | ||
return self:recursivescore(self.root, input, stack) | ||
end | ||
|
||
-- Continuous: if input[node.splitFeatureId] < node.splitFeatureValue then leftNode else rightNode | ||
-- Binary: if input[node.splitFeatureId] == 0 then leftNode else rightNode | ||
-- when stack is provided, it is returned as the third argument containing the stack of nodes from root to leaf | ||
function CartTree:recursivescore(node, input, stack) | ||
assert(torch.isTypeOf(node, 'dt.CartNode')) | ||
|
||
if stack then | ||
stack = torch.type(stack) == 'table' and stack or {} | ||
table.insert(stack, node) | ||
end | ||
|
||
if not (node.leftChild or node.rightChild) then | ||
return node.score, node.nodeId, stack | ||
elseif not node.leftChild then | ||
return self:recursivescore(node.rightChild, input, stack) | ||
elseif not node.rightChild then | ||
return self:recursivescore(node.leftChild, input, stack) | ||
end | ||
|
||
local splitId = node.splitFeatureId | ||
local splitVal = node.splitFeatureValue | ||
|
||
if input[splitId] then -- if has key | ||
local featureVal = input[splitId] | ||
local nextNode = featureVal < splitVal and node.leftChild or node.rightChild | ||
return self:recursivescore(nextNode, input, stack) | ||
end | ||
|
||
-- if feature is missing, branch left | ||
local nextNode = self.branchleft() and node.leftChild or node.rightChild | ||
return self:recursivescore(nextNode, input, stack) | ||
end | ||
|
||
function CartTree:__tostring__() | ||
return self.root:recursivetostring() | ||
end | ||
|
||
-- expects a stack returned by score | ||
function CartTree:stackToString(stack, input) | ||
assert(torch.type(stack) == 'table') | ||
assert(torch.isTypeOf(stack[1], 'dt.CartNode')) | ||
|
||
local res = 'Stack nodes from root to leaf\n' | ||
for i,node in ipairs(stack) do | ||
if not (node.leftChild or node.rightChild) then | ||
res = res .. "score="..node.score .. '\n' | ||
else | ||
local istr = '' | ||
if input then | ||
istr = '=' .. (input[node.splitFeatureId] or 'nil') | ||
end | ||
res = res .. 'input[' .. node.splitFeatureId .. ']' .. istr ..' < ' .. node.splitFeatureValue .. ' ? ' | ||
res = res .. '(' .. ((node.leftChild and node.rightChild) and 'LR' or node.leftChild and 'L' or node.rightChild and 'R' or 'WAT?') .. ') ' | ||
if node.leftChild == stack[i+1] then | ||
res = res .. 'Left\n' | ||
elseif node.rightChild == stack[i+1] then | ||
res = res .. 'Right\n' | ||
else | ||
error"stackToString error" | ||
end | ||
end | ||
end | ||
return res .. #stack .. " nodes" | ||
end | ||
|
||
function CartTree:clone() | ||
return CartTree(self.root:clone(), self.branchleft) | ||
end | ||
|
Oops, something went wrong.