In [205]:
using JuMP, Gurobi

### Build the optimal classification tree (OCT) MIO model from Bertsimas & Dunn, "Optimal Classification Trees" (2017), on a toy dataset

In [213]:
# generate XOR-like data
"""
    f(x1, x2)

Return the XOR-like class label of the point `(x1, x2)`.

The label is `1` when both coordinates lie on the same side of the threshold 0.5
(both `<= 0.5`, or both `> 0.5`) and `2` otherwise. Any real coordinates are accepted.
"""
function f(x1::Real, x2::Real)
    both_low = x1 <= 0.5 && x2 <= 0.5
    both_high = x1 > 0.5 && x2 > 0.5
    return (both_low || both_high) ? 1 : 2
end

K = 2  # number of classes
p = 2  # number of features (f labels a point by its first two coordinates)
n = 20  # number of data points
x = rand(n, p)  # features drawn uniformly from [0, 1)
# label each data point (row of x) with the XOR-like rule
y = [f(x[i, 1], x[i, 2]) for i in 1:n]
# Y[i, k] = +1 if point i has label k, -1 otherwise (n x K); needed in equation (15)
Y = [y[i] == k ? 1 : -1 for i in 1:n, k in 1:K];

In [214]:
# find epsilon, needed in equation (13): epsilon[j] is the smallest nonzero gap
# between consecutive values of feature j after sorting them. Equation (14) sends a
# point right when a'x >= b, so a left-branch point needs a'x < b. Adding this
# smallest gap to x in (13) enforces that strict inequality without excluding any
# split that separates two distinct feature values.

"""
    smallest_gap(values)

Return the smallest positive difference between consecutive sorted entries of
`values`.

When there is no such difference (a constant feature, or fewer than two points),
return `1.0`. The fallback only has to be positive. `1.0` is the width of the [0, 1)
range produced by `rand`, which keeps the big-M coefficient `1 + maximum(epsilon)`
of equation (13) small.
"""
function smallest_gap(values)
    gaps = filter(g -> g > 0, diff(sort(values)))
    return isempty(gaps) ? 1.0 : minimum(gaps)
end

epsilon = Float64[smallest_gap(x[:, j]) for j in 1:p]

In [215]:
# find left and right ancestors, needed in equations (13), (14)
"""
    get_ancestors(t)

Return `(Al, Ar)` for node `t` of a binary tree numbered from the root `1`, where the
children of node `k` are `2k` (left) and `2k + 1` (right).

- `Al` lists the ancestors at which the path to `t` takes the left branch.
- `Ar` lists the ancestors at which the path takes the right branch.

Both vectors are ordered from the parent of `t` up to the root, and both are empty for
`t == 1`.

Throws `ArgumentError` if `t < 1`; such indices would otherwise never reach the root.
"""
function get_ancestors(t)
    t >= 1 || throw(ArgumentError("node index must be >= 1, got $t"))
    Al = Int[]
    Ar = Int[]
    node = t
    while node != 1
        parent = div(node, 2)
        # an even index is the left child (2k) of its parent, an odd one the right (2k + 1)
        push!(node % 2 == 0 ? Al : Ar, parent)
        node = parent
    end
    return Al, Ar
end

get_ancestors (generic function with 1 method)

In [228]:
# set up the OCT mixed-integer model; equation numbers refer to the paper's formulation
m = Model(solver=GurobiSolver())

# parameters
D = 2  # maximum depth of the tree
N_min = 1  # minimum number of data points in each non-empty leaf, equation (7)
alpha = 0  # complexity penalty weight per applied split (see objective)

T = 2^(D+1)-1  # number of nodes in a complete binary tree of depth D
Tb = div(T, 2)  # branch nodes are 1:Tb, leaf nodes are Tb+1:T

# variables
@variable(m, a[1:p, 1:Tb], Bin)  # a[j, t] = 1 if branch node t splits on feature j
@variable(m, b[1:Tb])  # split threshold of each branch node
@variable(m, d[1:Tb], Bin)  # d[t] = 1 if branch node t applies a split
@variable(m, z[1:n, Tb+1:T], Bin)  # z[i, t] = 1 if data point i is in leaf node t
@variable(m, l[Tb+1:T], Bin)  # l[t] = 1 if leaf node t contains any points
@variable(m, c[1:K, Tb+1:T], Bin)  # c[k, t] = 1 if leaf node t predicts label k
@variable(m, Nt[Tb+1:T])  # total number of points in leaf node t
@variable(m, Nkt[1:K, Tb+1:T])  # number of points of label k in leaf node t
@variable(m, L[Tb+1:T])  # misclassification loss at each leaf node

# constraints
# equation (2): an applied split uses exactly one feature, an unapplied one none
@constraint(m, [t = 1:Tb], sum(a[:, t]) == d[t])
@constraint(m, [t = 1:Tb], b[t] <= d[t])  # equation (3)
@constraint(m, [t = 1:Tb], b[t] >= 0)  # equation (3)
# equation (5): a node may only split if its parent splits
@constraint(m, [t = 2:Tb], d[t] <= d[div(t, 2)])
@constraint(m, [i = 1:n, t = Tb+1:T], z[i, t] <= l[t])  # equation (6)
@constraint(m, [t = Tb+1:T], sum(z[:, t]) >= N_min*l[t])  # equation (7)
# equation (8): each point is assigned to exactly one leaf
@constraint(m, [i = 1:n], sum(z[i, :]) == 1)

# equations (13), (14): a point may only sit in leaf t if it follows t's path
# through every ancestor split. epsilon_max is loop-invariant, so compute it once.
epsilon_max = maximum(epsilon)
for t in Tb+1:T
    Al, Ar = get_ancestors(t)
    # equation (14): taking the right branch at ancestor s requires a'x >= b
    for s in Ar, i in 1:n
        @constraint(m, sum(x[i, j]*a[j, s] for j in 1:p) >= b[s] - (1 - z[i, t]))
    end
    # equation (13): taking the left branch at ancestor s requires a'x < b, made
    # non-strict by adding epsilon to x
    for s in Al, i in 1:n
        @constraint(m, sum((x[i, j] + epsilon[j])*a[j, s] for j in 1:p) <=
                       b[s] + (1 + epsilon_max)*(1 - z[i, t]))
    end
end

# equation (15): (1 + Y[i, k])/2 is 1 if point i has label k and 0 otherwise
@constraint(m, [k = 1:K, t = Tb+1:T],
            Nkt[k, t] == 0.5*sum((1 + Y[i, k])*z[i, t] for i in 1:n))
@constraint(m, [t = Tb+1:T], Nt[t] == sum(z[:, t]))  # equation (16)
# a non-empty leaf predicts exactly one label, an empty leaf none
@constraint(m, [t = Tb+1:T], sum(c[:, t]) == l[t])
@constraint(m, [t = Tb+1:T], L[t] >= 0)  # equation (22)
# equation (20): if leaf t predicts k (c[k, t] = 1), L[t] is at least the number of
# points in t not labelled k; the n*(1 - c[k, t]) term relaxes the bound otherwise
@constraint(m, [k = 1:K, t = Tb+1:T], L[t] >= Nt[t]-Nkt[k, t]-n*(1-c[k, t]))
@constraint(m, [k = 1:K, t = Tb+1:T], L[t] <= Nt[t]-Nkt[k, t]+n*c[k, t])  # equation (21)

# set up objective: total misclassification loss plus alpha per applied split
@objective(m, Min, sum(L)+alpha*sum(d));

In [229]:
# print the problem
#print(m)

In [230]:
# Solve the model and report the outcome. The solver status is printed too, because
# the objective value alone does not show whether the solve reached optimality.
status = solve(m)
println("Solver status: ", status)
println("Objective value: ", getobjectivevalue(m))

Academic license - for non-commercial use only
Optimize a model with 311 rows, 120 columns and 1238 nonzeros
Variable types: 19 continuous, 101 integer (101 binary)
Coefficient statistics:
  Matrix range     [3e-03, 2e+01]
  Objective range  [1e+00, 1e+00]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 2e+01]
Found heuristic solution: objective 7
Presolve removed 9 rows and 2 columns
Presolve time: 0.00s
Presolved: 302 rows, 118 columns, 1185 nonzeros
Variable types: 3 continuous, 115 integer (99 binary)

Root relaxation: objective 0.000000e+00, 65 iterations, 0.00 seconds

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0    0.00000    0    2    7.00000    0.00000   100%     -    0s
H    0     0                       6.0000000    0.00000   100%     -    0s
H    0     0                       5.0000000    0.00000   100%     -    0s
     0     0    0.00000    0 