In [1]:
using JuMP, Gurobi

[1m[36mINFO: [39m[22m[36mPrecompiling module JuMP.
[39m[1m[36mINFO: [39m[22m[36mPrecompiling module Gurobi.
[39m

### Create the OCT from the paper with a toy dataset

In [61]:
# generate XOR-like data
"""
    f(x1::Real, x2::Real)

Return the XOR-like class label of the point `(x1, x2)`.

The label is `0` when both coordinates lie on the same side of the 0.5 threshold
(both `<= 0.5` or both `> 0.5`) and `1` otherwise.
"""
function f(x1::Real, x2::Real)
    both_low = x1 <= 0.5 && x2 <= 0.5
    both_high = x1 > 0.5 && x2 > 0.5
    return (both_low || both_high) ? 0 : 1
end

k = 2  # number of classes (labels are 0, ..., k - 1)
p = 2  # number of features
n = 10  # number of data points
x = rand(n, p)  # n-by-p feature matrix
# Label each row directly: one call to f per data point.
y = [f(x[i, 1], x[i, 2]) for i in 1:n]
# Y[i, c] is +1 when point i has label c - 1 and -1 otherwise (n-by-k matrix).
Y = [y[i] == c - 1 ? 1 : -1 for i in 1:n, c in 1:k];

In [120]:
# find left and right ancestors
"""
    get_ancestors(t)

Return `(Al, Ar)`, the left and right ancestors of node `t` in a binary tree whose nodes
are numbered so that the parent of node `s` is `div(s, 2)` and the root is node 1.

`Al` collects the ancestors that are reached from an even-numbered (left) child on the
path from `t` up to the root, `Ar` those reached from an odd-numbered (right) child. Both
vectors are ordered from the parent of `t` towards the root and are empty for `t == 1`.

Throws an `ArgumentError` when `t` is not an integer value `>= 1`, because such a node
index would never reach the root.
"""
function get_ancestors(t)
    (isinteger(t) && t >= 1) ||
        throw(ArgumentError("node index must be an integer >= 1, got $t"))

    Al = Int[]
    Ar = Int[]

    node = convert(Int, t)
    while node != 1
        parent = div(node, 2)  # exact integer halving
        # Even node numbers are left children, odd ones are right children.
        push!(iseven(node) ? Al : Ar, parent)
        node = parent
    end

    return Al, Ar
end

get_ancestors (generic function with 1 method)

In [129]:
# set up model
m = Model(solver=GurobiSolver())

# parameters
D = 2  # depth
N_min = 1  # min number of data points in each leaf
alpha = 1  # complexity penalty weight

# Nodes 1:Tb are branch nodes and nodes Tb + 1:T are leaf nodes.
T = 2^(D + 1) - 1  # number of nodes in the tree
Tb = div(T, 2)  # number of branch nodes
Tl = T - Tb  # number of leaf nodes

# variables
@variable(m, a[1:p, 1:Tb], Bin)  # vector a for each branch node
@variable(m, b[1:Tb])  # split value b
@variable(m, d[1:Tb], Bin)  # d indicates if a branch node applies split
@variable(m, z[1:n, Tb + 1:T], Bin)  # z indicates if xi is in leaf node t
@variable(m, l[Tb + 1:T], Bin)  # l indicates if leaf node t contains any points
# NOTE(review): presumably c[k, t] = 1 when leaf t is assigned class k; c is not used in
# the constraints below — confirm against the paper.
@variable(m, c[1:k, Tb + 1:T], Bin)

# constraints
# Sums are written as explicit generators so they do not rely on slicing the containers
# indexed by Tb + 1:T.
@constraint(m, [t = 1:Tb], sum(a[j, t] for j in 1:p) == d[t])  # equation (2)
@constraint(m, [t = 1:Tb], b[t] <= d[t])  # equation (3)
@constraint(m, [t = 1:Tb], b[t] >= 0)  # equation (3)
@constraint(m, [t = 2:Tb], d[t] <= d[div(t, 2)])  # equation (5): parent of t is div(t, 2)
@constraint(m, [i = 1:n, t = Tb + 1:T], z[i, t] <= l[t])  # equation (6)
@constraint(m, [t = Tb + 1:T], sum(z[i, t] for i in 1:n) >= N_min * l[t])  # equation (7)
@constraint(m, [i = 1:n], sum(z[i, t] for t in Tb + 1:T) == 1)  # equation (8)

# equation (13), (14)
for t in Tb + 1:T
    # NOTE(review): only the right ancestors Ar are used in this loop; Al is computed but
    # unused here, so the left-ancestor constraints appear to be missing — confirm.
    Al, Ar = get_ancestors(t)

    for s in Ar, i in 1:n
        # When z[i, t] = 1 this enforces a[:, s]' * x[i, :] >= b[s]; when z[i, t] = 0 the
        # right-hand side is relaxed by 1.
        @constraint(m, sum(a[j, s] * x[i, j] for j in 1:p) - b[s] - z[i, t] + 1 >= 0)
    end

end