In [1]:
using StatsBase
using PyCall

@pyimport sklearn.datasets as ds

In [11]:
# workspace()

# Common supertype of the decision-tree pieces declared below
# (`Node` for internal splits, `Leaf` for terminal nodes).
abstract BTree

# Internal node of the decision tree.
type Node <: BTree
    # (feature index, threshold) pair, as produced by `find_predicate`.
    predicate::Tuple{Int64, Float64}
    # Subtree followed when the object's feature value is `<` the threshold.
    left::BTree
    # Subtree followed otherwise.
    right::BTree
end

# Terminal node of the decision tree.
type Leaf <: BTree
    # Label => share mapping; `build_tree` fills it with `proportionmap(labels)`
    # for the labels that reached this leaf.
    proportion::Dict{Int64,Float64}
end

In [4]:
# Impurity of a collection of labels: one minus the sum of the squared label
# proportions. Despite the name, this formula is the Gini impurity rather than
# an entropy. It is 0.0 when every label is the same.
function find_entropy(objs)
    shares = proportions(objs)
    return 1.0 - sum(shares .^ 2)
end

# Exhaustively search for the best binary split of `labels`.
#
# Every distinct value of every column of `objs` is tried as a threshold; the
# rows with `objs[:, fid] .< fval` form one side and the remaining rows the
# other. A candidate is scored by the size-weighted average of `find_entropy`
# over the two sides.
#
# Returns the `(fid, fval)` pair with the lowest score, or `(0, 0.)` when no
# candidate scores strictly below the impurity of the unsplit labels.
function find_predicate(objs, labels)
    best_score = find_entropy(labels)
    best = (0, 0.)
    total = size(labels, 1)
    for fid = 1:size(objs, 2)
        column = objs[:, fid]
        for fval = unique(column)
            below = column .< fval
            lhs = labels[below]
            rhs = labels[!below]
            n_lhs = size(lhs, 1)
            n_rhs = size(rhs, 1)
            # A threshold that leaves one side empty separates nothing.
            if n_lhs != 0 && n_rhs != 0
                w_lhs = find_entropy(lhs) * n_lhs
                w_rhs = find_entropy(rhs) * n_rhs
                score = (w_lhs + w_rhs) / total
                if score < best_score
                    best_score = score
                    best = (fid, fval)
                end
            end
        end
    end
    return best
end

# Recursively grow a decision tree over the rows of `objs` and their `labels`.
#
# When `find_predicate` reports no useful split (feature index 0) the result is
# a `Leaf` holding the label proportions. Otherwise the rows are partitioned by
# `objs[:, fid] .< fval` and a `Node` is built from the two recursive subtrees.
function build_tree(objs, labels)
    fid, fval = find_predicate(objs, labels)
    if fid == 0
        return Leaf(proportionmap(labels))
    end
    goes_left = objs[:, fid] .< fval
    goes_right = !goes_left
    left_subtree = build_tree(objs[goes_left, :], labels[goes_left])
    right_subtree = build_tree(objs[goes_right, :], labels[goes_right])
    return Node((fid, fval), left_subtree, right_subtree)
end

# Base case of the tree descent: a leaf answers with its stored proportions;
# `obj` is not inspected.
function find_leaf(tree::Leaf, obj)
    return tree.proportion
end

# Descend from an internal node: take the left subtree when the object's value
# for the node's feature is below the threshold, the right subtree otherwise,
# and continue the lookup there.
function find_leaf(tree::Node, obj)
    fid, fval = tree.predicate
    branch = obj[fid] < fval ? tree.left : tree.right
    return find_leaf(branch, obj)
end


# Look up the leaf proportions for every row of `objs`.
#
# Returns a vector with one `Dict{Int64,Float64}` per row, in row order; each
# entry is whatever `find_leaf` returns for that row.
function make_prediction(tree::BTree, objs)
    return Dict{Int64,Float64}[find_leaf(tree, objs[row, :]) for row = 1:size(objs, 1)]
end

make_prediction (generic function with 1 method)

In [9]:
# Generate a synthetic classification data set through scikit-learn
# (500 samples, 10 features).
objs, labels = ds.make_classification(n_samples=500, n_features=10)

# Fit a tree on the whole data set, then run prediction on the same objects.
tree = build_tree(objs, labels)
make_prediction(tree, objs)
# NOTE(review): presumably here so the notebook cell echoes `true` instead of
# the full prediction array — confirm.
true

true