In [None]:
using JuMP, Gurobi, GLM, DataFrames

In [None]:
using Plots

In [None]:
using Clustering

In [1]:
"""
    D2ClusterTree(X)

Build and solve a mixed-integer model that assigns every row of `X` to one of the
two leaves of a depth-2 tree (a single split hyperplane `a'x = b`), minimising the
summed absolute deviation of each observation from the centre of its leaf.

# Arguments
- `X`: data matrix indexed as `X[i,j]`, with `size(X) == (n, p)`: `n` observations
  in rows, `p` features in columns.

# Returns
The tuple `(z_soln, μ_soln, a_soln, b_soln, obj)`: solved leaf-assignment
indicators (indexed by observation and leaf number), leaf centres (indexed by
feature and leaf number), hyperplane coefficients, hyperplane offset, and the
objective value. The solver status is printed but not checked.
"""
function D2ClusterTree(X)
    d = 2            # Tree depth
    T = 2^(d-1)      # Number of leaves
    N = 2^d - 1      # Number of nodes in total (including leaves)
    Ti = N-T+1       # Index of the first leaf (2 for d = 2)
    Tf = N           # Index of the last leaf  (3 for d = 2)

    n,p = size(X)
    Msplit = 2                 # big-M that deactivates a split constraint
    Ms = maximum(abs.(X),1)    # per-feature big-M: largest |X[i,j]| in column j

    m = Model(solver = GurobiSolver(OutputFlag=0))

    @variable(m, L_ijt[i=1:n,j=1:p,t=Ti:Tf] >=0) # Loss contribution of obs. i to leaf t, feature j
    @variable(m, μ_jt[j=1:p,t=Ti:Tf])            # Centre of leaf t, feature j
    @variable(m, z[i=1:n,t=Ti:Tf],Bin)           # = 1 if obs. i ends in leaf t
    @variable(m, a[j=1:p])                       # Hyperplane coefficients
    @variable(m, b >= 0)                         # Hyperplane offset

    # L_ijt >= |X[i,j] - μ_jt[j,t]| when z[i,t] = 1. When z[i,t] = 0 the right-hand
    # sides become -μ - Ms and μ - Ms, which are non-positive while |μ| <= Ms.
    @constraint(m, LeafLoss1[i=1:n, j=1:p, t = Ti:Tf],
        L_ijt[i,j,t] >= z[i,t]*X[i,j]-μ_jt[j,t]-Ms[j]*(1-z[i,t]))
    @constraint(m, LeafLoss2[i=1:n, j=1:p, t = Ti:Tf],
        L_ijt[i,j,t] >= μ_jt[j,t] - z[i,t]*X[i,j]-Ms[j]*(1-z[i,t]))

    @constraint(m, bConst, b <= 1)

    # Observations in the first leaf must satisfy a'x <= b, those in the last leaf
    # a'x >= b; the big-M term relaxes the constraint for the other leaf.
    # NOTE(review): a = 0, b = 0 satisfies both families for every choice of z, so
    # as written the hyperplane does not restrict the assignment — confirm whether
    # a margin or a normalisation of `a` was intended.
    @constraint(m, Split1[i=1:n],
        sum(a[j]*X[i,j] for j = 1:p) <= b + Msplit*(1-z[i,Ti]))

    @constraint(m, Split2[i=1:n],
        sum(a[j]*X[i,j] for j = 1:p) >= b - Msplit*(1-z[i,Tf]))

    @constraint(m, norm(a) <= 1)

    # Every observation lands in exactly one leaf.
    @constraint(m, zConst[i=1:n],
        sum(z[i,t] for t = Ti:Tf) == 1)

    @objective(m, Min, sum(L_ijt[i,j,t] for i=1:n, j=1:p, t=Ti:Tf));

    status = solve(m)
    println("Status = ", status)

    z_soln = getvalue(z)
    μ_soln = getvalue(μ_jt)
    a_soln = getvalue(a)
    b_soln = getvalue(b)
    obj = getobjectivevalue(m)
    return z_soln, μ_soln, a_soln, b_soln, obj
end;

LoadError: UndefVarError: @variable not defined

In [None]:
# Synthetic two-group data set: 30 points drawn around (10, 10) labelled 1,
# and 30 points drawn around (0, 0) labelled 0, stacked into X (60×2) and Y (60×1).
n1 = 30
μ1 = [10,10]
σ1 = 1 / 0.9
x1 = μ1' .+ σ1 .* randn(n1, 2)   # row i is μ1 plus scaled Gaussian noise
y1 = ones(Int, n1, 1)

n2 = 30
μ2 = [0,0]
σ2 = 1 / 0.15
x2 = μ2' .+ σ2 .* randn(n2, 2)
y2 = zeros(Int, n2, 1)

X = vcat(x1, x2)
Y = vcat(y1, y2);

In [None]:
# Standardise each feature: subtract the column mean, then divide by the column
# standard deviation.
Xm = broadcast(-, X, mean(X, 1))
Xnor = broadcast(/, Xm, std(Xm, 1));

In [None]:
# Plot the standardised data coloured by the true label (0 -> red, 1 -> blue).
mcols = [:red, :blue]
scatter(Xnor[:, 1], Xnor[:, 2]; markercolor = mcols[Y .+ 1])

In [None]:
# Solve the clustering-tree model on the standardised data and report the run time.
@time (z_soln, μ_soln, a_soln, b_soln, obj) = D2ClusterTree(Xnor);

In [None]:
# Plot the data coloured by the leaf chosen by the model (leaf 2 -> blue, else red).
# The solver returns the binary z values as floats that are integral only up to
# its tolerance (e.g. 0.9999999), so an exact `convert` to Int can throw an
# InexactError; round to the nearest integer instead.
mcols = [:red, :blue]
scatter(Xnor[:,1],Xnor[:,2], markercolor=mcols[round.(Int, z_soln[:,2]) .+ 1])

In [None]:
# Reference clustering: 2-means on the same data. `kmeans` is given the transposed
# matrix; its 1/2 assignments are shifted to 0/1 and used to colour the points.
wstart = kmeans(Xnor', 2).assignments .- 1;
mcols = [:red, :blue]
scatter(Xnor[:, 1], Xnor[:, 2]; markercolor = mcols[wstart .+ 1])