In [None]:
import multiA as ma
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from matplotlib import pyplot as plt
from Node import BranchingProgram
from Node import Node
import scipy.io

In [None]:
# Geometric grid of 2*L values with ratio `lam`, anchored at n_cells[L] == 1.
# Indices above L grow by a factor of lam per step, indices below L shrink by
# the same factor. Index 0 is never written and therefore stays 0.
# NOTE(review): n_cells is later handed to BranchingProgram; whether the
# leading 0 is an intentional lower edge should be confirmed against that class.
L = 30
lam = 1.3
n_cells = np.zeros(2*L)
n_cells[L] = 1

# Upper half: L+1 .. 2L-1, each entry is the previous one times lam.
for idx in range(L + 1, 2*L):
    n_cells[idx] = n_cells[idx - 1] * lam

# Lower half: L-1 .. 1, each entry is the next one divided by lam.
for idx in range(L - 1, 0, -1):
    n_cells[idx] = n_cells[idx + 1] / lam

print(n_cells)

In [None]:
def normal_kl(mu0,mu1,cov0,cov1):
    """Return KL( N(mu0, cov0) || N(mu1, cov1) ) in bits.

    Uses the closed form for two multivariate normals:

        0.5 * [ tr(cov1^-1 cov0) + (mu1-mu0)^T cov1^-1 (mu1-mu0)
                - dim + ln(det cov1 / det cov0) ]

    and divides by ln(2) to convert nats to bits.

    Parameters
    ----------
    mu0, mu1 : array-like
        Mean vectors of the first and second distribution.
    cov0, cov1 : array-like, square
        Covariance matrices of the first and second distribution.

    Returns
    -------
    The divergence in bits (a NumPy scalar for 1-D mean vectors).

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``cov1`` is singular.
    """
    mu0 = np.asarray(mu0, dtype=float)
    mu1 = np.asarray(mu1, dtype=float)
    cov0 = np.asarray(cov0, dtype=float)
    cov1 = np.asarray(cov1, dtype=float)

    diff = mu1 - mu0

    # Solve linear systems instead of forming an explicit inverse of cov1.
    trace_term = np.trace(np.linalg.solve(cov1, cov0))
    quad_term = np.matmul(np.transpose(diff), np.linalg.solve(cov1, diff))

    # Work with log-determinants directly: det() under/overflows for larger
    # or badly scaled matrices, which turned the log-ratio into NaN/inf.
    logdet0 = np.linalg.slogdet(cov0)[1]
    logdet1 = np.linalg.slogdet(cov1)[1]

    twice_kl_nats = trace_term + quad_term - len(mu0) + (logdet1 - logdet0)
    return twice_kl_nats/(2*np.log(2))

def return_p_sample(N, k, means_p, sigma):
    """Draw a shuffled sample from an equal-weight Gaussian mixture.

    Each of the ``k`` components (rows of ``means_p``, shared covariance
    ``sigma``) contributes ``int(N/k)`` points, so the result has
    ``k * int(N/k)`` rows, which can be fewer than ``N``. Rows are shuffled
    with ``np.random.permutation`` before returning.
    """
    n_each = int(N/k)
    # Components are drawn in index order so the global RNG stream is
    # consumed exactly as a sequential draw-and-append would.
    draws = [np.random.multivariate_normal(means_p[c,:], sigma, n_each)
             for c in range(k)]
    pooled = np.concatenate(draws, axis = 0)
    return np.random.permutation(pooled)

def return_r_sample(N, k, means_r, sigma):
    """Draw a shuffled sample from an equal-weight Gaussian mixture.

    Each of the ``k`` components (rows of ``means_r``, shared covariance
    ``sigma``) contributes ``int(N/k)`` points, so the result has
    ``k * int(N/k)`` rows, which can be fewer than ``N``. Rows are shuffled
    with ``np.random.permutation`` before returning.
    """
    n_each = int(N/k)
    # Components are drawn in index order so the global RNG stream is
    # consumed exactly as a sequential draw-and-append would.
    draws = [np.random.multivariate_normal(means_r[c,:], sigma, n_each)
             for c in range(k)]
    pooled = np.concatenate(draws, axis = 0)
    return np.random.permutation(pooled)

def return_p_sample_label(N, k, means_p, sigma):
    """Draw an unshuffled mixture sample together with per-component labels.

    Component ``c`` (row ``c`` of ``means_p``, covariance ``sigma``)
    contributes ``int(N/k)`` consecutive rows, all labelled ``2*c``
    (i.e. 0, 2, 4, ...).

    Returns
    -------
    (data, y) : the stacked points and a float label vector of equal length.
    """
    n_each = int(N/k)
    # Draw components in index order; rows stay grouped by component.
    blocks = [np.random.multivariate_normal(means_p[c,:], sigma, n_each)
              for c in range(k)]
    labels = [(2*c)*np.ones(n_each) for c in range(k)]
    return np.concatenate(blocks, axis = 0), np.concatenate(labels, axis = 0)

def return_r_sample_label(N, k, means_r, sigma):
    """Draw an unshuffled mixture sample together with per-component labels.

    Component ``c`` (row ``c`` of ``means_r``, covariance ``sigma``)
    contributes ``int(N/k)`` consecutive rows, all labelled ``2*c + 1``
    (i.e. 1, 3, 5, ...).

    Returns
    -------
    (data, y) : the stacked points and a float label vector of equal length.
    """
    n_each = int(N/k)
    # Draw components in index order; rows stay grouped by component.
    blocks = [np.random.multivariate_normal(means_r[c,:], sigma, n_each)
              for c in range(k)]
    labels = [(2*c+1)*np.ones(n_each) for c in range(k)]
    return np.concatenate(blocks, axis = 0), np.concatenate(labels, axis = 0)

In [None]:
# Experiment configuration.
# K: mixture sizes to sweep; the driver loop runs once per entry and uses it
#    as the number of Gaussian components.
K = [5,10,15]
# repeats: number of independent runs per mixture size; per-run results are
#          stored in arrays of this length.
repeats = 5

In [None]:
def classifier():
    """Factory returning a new, unfitted classifier on every call.

    The function itself (not an instance) is handed to the estimators in
    this notebook, so each caller can build as many fresh models as it needs.
    """
    forest = RandomForestClassifier(n_estimators=10, max_depth=5)
    return forest
    # Alternative model that was tried here:
    #     return LogisticRegression(max_iter=500, C=1)

def _draw_pair(N, k, means_p, means_r, sigma):
    """Draw a fresh (P, R) mixture sample pair; P is sampled before R."""
    data_p = return_p_sample(N, k, means_p, sigma)
    data_r = return_r_sample(N, k, means_r, sigma)
    return data_p, data_r


# Constants shared by every run (none of them touch the RNG).
d = 2                      # dimension of the Gaussians
delta = 2.5*np.ones(d)     # shift applied to every P mean to obtain the R mean
sigma = np.eye(d)          # shared component covariance
N = 500000                 # requested mixture sample size per draw
per_comp = 100000          # sample size for each single-component evaluation

for k in K:
    # One slot per repeat: overall KL estimates and the spread of the
    # per-component estimates, for both methods.
    mc_kl = np.zeros(repeats)
    me_kl = np.zeros(repeats)
    mc_kl_std = np.zeros(repeats)
    me_kl_std = np.zeros(repeats)
    for count in range(repeats):
        # Random component means for P; R is the same mixture shifted by delta.
        means_p = np.random.multivariate_normal(np.zeros(d), k*10000*sigma, k)
        means_r = means_p + delta

        # Diagnostic: KL between R component ii and P component jj (ii < jj).
        # Entries that are not computed keep the 10000 placeholder.
        # NOTE(review): presumably a check that components are well separated.
        kl_approx = 10000*np.ones((k,k))
        for ii in range(k):
            for jj in np.arange(ii+1,k):
                kl_approx[ii,jj] = normal_kl(means_r[ii,:],means_p[jj,:],sigma,sigma)
        print(kl_approx.min())

        # KL between the matching R and P components (index 1); it depends
        # only on delta and sigma, so it is the same for every component.
        true_kl = normal_kl(means_r[1,:],means_p[1,:],sigma,sigma)
        print(true_kl)

        # Fit the branching-program estimator, then score it on a fresh sample.
        # NOTE(review): the meaning of 25, 0.02 and reweigh is defined by
        # BranchingProgram and is not visible here.
        data_p, data_r = _draw_pair(N, k, means_p, means_r, sigma)
        A = BranchingProgram(n_cells, 25, 0.02, classifier, reweigh = 0)
        A.fit(data_p,data_r)
        A.compute_kl()

        data_p, data_r = _draw_pair(N, k, means_p, means_r, sigma)
        mc_kl[count] = A.predict_kl_pair(data_p, data_r)

        # Fit the MaxEnt estimator, then score it on a fresh sample.
        data_p, data_r = _draw_pair(N, k, means_p, means_r, sigma)
        me = ma.MaxEnt()
        me.fit(data_p, data_r, iter=1000, clr_maker=classifier, eps=0.02, eta=0.02)

        data_p, data_r = _draw_pair(N, k, means_p, means_r, sigma)
        me_kl[count] = me.compute_KL(data_r)

        # Evaluate both fitted estimators on each component in isolation.
        mc_subgroup = np.zeros(k)
        me_subgroup = np.zeros(k)
        for ii in range(k):
            data_p = np.random.multivariate_normal(means_p[ii,:], sigma, per_comp)
            data_r = np.random.multivariate_normal(means_r[ii,:], sigma, per_comp)
            mc_subgroup[ii] = A.predict_kl_pair(data_p, data_r)
            me_subgroup[ii] = me.compute_KL(data_r)

        mc_kl_std[count] = np.std(mc_subgroup)
        me_kl_std[count] = np.std(me_subgroup)

    # Persist the results for this k. means_p, means_r, mc_subgroup and
    # me_subgroup hold the values from the last repeat only.
    mdic = {"N": N, "means_p": means_p, "means_r": means_r,
            "mc_subgroup": mc_subgroup, "me_subgroup": me_subgroup,
            "delta": delta, "me_kl": me_kl, "mc_kl": mc_kl,
            "d": d, "k": k, "true_kl": true_kl,
            "me_kl_std": me_kl_std, "mc_kl_std": mc_kl_std}
    name = 'gaussian_d'+str(d)+'_k'+str(k)+'.mat'
    scipy.io.savemat(name, mdic)