In [None]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import six
import tensorflow as tf
import pickle

from edward.models import (
    Categorical, Dirichlet, Empirical, InverseGamma,
    MultivariateNormalDiag, Normal, ParamMixture, Bernoulli, PointMass, Mixture)

from sklearn.cluster import KMeans
from sklearn.metrics import roc_auc_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestNeighbors

plt.style.use('ggplot')

In [None]:
class dataset(object):
    """Simulated mixture dataset with latent codes, features and labels.

    ``create`` draws, for each of ``N`` points, a component index ``c`` from
    ``M`` equally likely components, a latent code ``z`` (dimension ``D1``)
    from that component's Gaussian, a feature vector ``x ~ N(z @ WX, I)``
    (dimension ``D2``) and a label ``y`` from a logistic model on ``z`` when
    ``K == 2`` or a softmax model on ``z`` otherwise.

    Attributes:
        Z: (N, D1) float32 latent codes.
        X: (N, D2) float32 observed features.
        Y: (N, 1) float64 labels in ``{0, ..., K-1}``.
        C: (N, 1) float32 mixture-component index of each point.
        beta, mus, stds, WX, WY, sigmaX: generating parameters, filled in by
            ``create`` (``None`` until then).
    """

    def __init__(self, N, M, D1, D2, K=2):
        self.N = N
        self.M = M
        self.D1 = D1
        self.D2 = D2
        self.K = K
        self.Z = np.zeros((N,D1), dtype=np.float32)
        self.X = np.zeros((N,D2), dtype=np.float32)
        self.Y = np.zeros((N,1))
        self.C = np.zeros((N,1), dtype=np.float32)
        self.beta = None
        self.mus = None
        self.stds = None
        self.WX = None
        self.WY = None
        self.sigmaX = None

    @staticmethod
    def _sigmoid(logits):
        """Elementwise logistic function."""
        return 1 / (1 + np.exp(-logits))

    @staticmethod
    def _softmax(logits, axis=-1):
        """Softmax along ``axis``, shifted by the max so ``exp`` cannot overflow."""
        shifted = np.exp(logits - np.max(logits, axis=axis, keepdims=True))
        return shifted / np.sum(shifted, axis=axis, keepdims=True)

    def _label_probs(self, logits, K):
        """Label probabilities: logistic for ``K == 2``, row-wise softmax otherwise."""
        if K == 2:
            return self._sigmoid(logits)
        return self._softmax(logits, axis=-1)

    def create(self):
        """Sample the generating parameters and all N points, storing them on self."""
        # Uniform mixture weights (alternative tried: np.random.dirichlet([1]*self.M)).
        beta = [1 / self.M] * self.M
        mus = np.random.randn(self.M, self.D1) * 4
        # Diagonal of each component's covariance; passed to np.diag as the cov below.
        stds = [[1] * self.D1] * self.M
        WX = np.random.randn(self.D1, self.D2)
        if self.K == 2:
            WY = np.random.randn(self.D1)
        else:
            WY = np.random.randn(self.D1, self.K)
        sigmaX = [1] * self.D2

        for n in range(self.N):
            c = np.argmax(np.random.multinomial(1, beta))
            self.C[n, :] = c
            self.Z[n, :] = np.random.multivariate_normal(mus[c], np.diag(stds[c]))
            self.X[n, :] = np.random.multivariate_normal(
                np.matmul(self.Z[n], WX), np.diag(sigmaX))
            if self.K == 2:
                self.Y[n, :] = np.random.binomial(
                    1, self._sigmoid(np.matmul(self.Z[n], WY)))
            else:
                self.Y[n, :] = np.argmax(np.random.multinomial(
                    1, self._softmax(np.matmul(self.Z[n], WY))))

        self.beta = beta
        self.mus = mus
        self.stds = stds
        self.WX = WX
        self.WY = WY
        self.sigmaX = sigmaX

    def print_params(self):
        """Print mixture weights, per-component centres in Z and X, and label counts per component."""
        print("Cluster Probabilities:", self.beta)
        print("Centers:")
        # Noise-free image of each latent centre in feature space.
        Xcenters = np.matmul(self.mus, self.WX)
        for i in range(self.M):
            print(i, self.mus[i, :], Xcenters[i, :])
        for i in range(self.M):
            # (label values, counts) among the points drawn from component i.
            classcounts = np.unique(self.Y[np.where(self.C == i)], return_counts=True)
            print(i, classcounts)

    def visualize(self):
        """Plot Z (first two dims) and X (first three dims), coloured by component.

        Colour encodes the mixture component ``C`` and the marker encodes the
        label ``Y``; both cycle if ``M`` or ``K`` exceed the palette. The 3-D
        feature plot is skipped when ``D2 < 3``.
        """
        colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
        # 'o' is the circle marker ('0' is not a valid matplotlib marker).
        markers = ['x', '+', 'o']
        labels = self.Y[:, 0]
        components = self.C[:, 0]
        groups = [(i, j, (labels == i) & (components == j))
                  for i in range(self.K) for j in range(self.M)]

        fig, ax = plt.subplots()
        for i, j, mask in groups:
            ax.plot(self.Z[mask, 0], self.Z[mask, 1],
                    colors[j % len(colors)] + markers[i % len(markers)])
        ax.set_title("Simulated dataset")
        plt.show()

        if self.D2 >= 3:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            for i, j, mask in groups:
                # Use the 3-D axes' scatter: its third positional argument is
                # the z coordinate, whereas plt.scatter's is the marker size.
                ax.scatter(self.X[mask, 0], self.X[mask, 1], self.X[mask, 2],
                           c=colors[j % len(colors)],
                           marker=markers[i % len(markers)])
            plt.show()

    def MAPdiagnostics(self, mode='print', qmu=None, qwx=None, qwy=None, Xtest=None, Ytest=None, K=2, k=2):
        """Inspect point-estimate posteriors for mu, wx and wy.

        Args:
            mode: 'print' shows each inferred prototype next to the true
                parameters; 'evaluate' scores the test set. Any other value
                only reads the parameters.
            qmu, qwx, qwy: posteriors exposing ``.params`` (e.g. PointMass)
                with shapes (M, D1), (D2, D1) and (1 or K, D1).
            Xtest, Ytest: test features and labels (required for 'evaluate').
            K: number of label classes (2 -> logistic, otherwise softmax).
            k: number of nearest prototypes used by 'evaluate'.

        In 'evaluate' mode, latent test codes are recovered with the
        pseudo-inverse of wx, and two predictors are compared:
        ``ynn`` (label probabilities of the k nearest prototypes, weighted by
        inverse distance) and ``ymat`` (label weights applied directly to the
        recovered codes). Accuracy is printed for both, plus ROC AUC when K == 2.

        Raises:
            ValueError: if mode is 'evaluate' and Xtest or Ytest is missing.
        """
        sess = ed.get_session()
        zproto = sess.run(qmu.params)
        weightx = sess.run(qwx.params)
        weighty = sess.run(qwy.params)
        xcenters = np.matmul(zproto, weightx.transpose())
        ycenters = self._label_probs(np.matmul(zproto, weighty.transpose()), K)

        if mode == 'print':
            print("Inferred prototypes axes:")
            self.print_params()
            print("__________________________________________\n")
            for i in range(zproto.shape[0]):
                print(i, zproto[i, :], ycenters[i, :], xcenters[i, :])

        elif mode == 'evaluate':
            if Xtest is None or Ytest is None:
                raise ValueError("mode='evaluate' requires Xtest and Ytest")
            print(k)
            ztest = np.matmul(Xtest, np.linalg.pinv(weightx).transpose())
            ymat = self._label_probs(np.matmul(ztest, weighty.transpose()), K)

            nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(zproto)
            distances, indices = nbrs.kneighbors(ztest)
            # Clamp so a test point lying exactly on a prototype gets a very
            # large but finite weight instead of inf/inf = nan.
            invdist = 1.0 / np.maximum(distances, 1e-12)
            probdist = invdist / invdist.sum(axis=1, keepdims=True)
            # ynn[i] = sum_j probdist[i, j] * ycenters[indices[i, j]]
            ynn = np.sum(probdist[:, :, None] * ycenters[indices], axis=1)

            total = ynn.shape[0]
            # Column vector so comparisons against (n, 1) predictions never
            # broadcast to (n, n), whatever shape Ytest arrives in.
            ytest_col = np.reshape(Ytest, (-1, 1))
            if K == 2:
                print(np.sum((ynn > 0.5) == ytest_col) / total,
                      np.sum((ymat > 0.5) == ytest_col) / total)
                print(roc_auc_score(ytest_col.ravel(), ynn.ravel()),
                      roc_auc_score(ytest_col.ravel(), ymat.ravel()), "\n")
            else:
                ytest_flat = ytest_col.ravel()
                countnn = np.sum(np.argmax(ynn, axis=1) == ytest_flat)
                countmat = np.sum(np.argmax(ymat, axis=1) == ytest_flat)
                print(countnn / total, countmat / total)

    def EMdiagnostics(self, qmu=None, qwx=None, qwy=None, qz=None):
        """Print label counts per inferred cluster and show decoded prototypes as images.

        NOTE(review): this method reads module-level names rather than
        arguments: ``qc`` (expected to have ``.probs``, i.e. a Categorical
        posterior, which the EM cell leaves commented out), ``Ytrain``,
        ``Xscale``, ``Xmean`` and ``utils`` (none of the last three are
        defined or imported in this notebook). It also hard-codes 10 label
        columns and a 28x28 image reshape, so it looks copied from an MNIST
        experiment and will not run on this simulated data as-is.
        ``qwy`` and ``qz`` are accepted but unused.
        """
        sess = ed.get_session()
        probs = sess.run(qc.probs)
        # Hard assignment of each training point to its most probable cluster.
        cluster = np.argmax(probs, axis=1)
        clusterlabels = np.zeros([self.M, 10])
        for i in range(self.M):
            temp = Ytrain[np.where(cluster==i)]
            elem, count = np.unique(temp, return_counts=True)
            elem = elem.astype(int)
            for j in range(elem.shape[0]):
                clusterlabels[i,elem[j]] = count[j]

        zproto = sess.run(qmu.mean())
        dictionary = sess.run(qwx.mean())
        # Decode prototypes to feature space and undo a (global) standardisation.
        dictionary = np.matmul(zproto,dictionary.transpose())*Xscale+Xmean
        # np.place(dictionary, dictionary<0, 0)
        for i in range(dictionary.shape[0]):
            print(clusterlabels[i,:].astype(int))
            utils.show(dictionary[i,:].reshape((28,28)))
        

In [None]:
# Experiment configuration.
SEED = 0
N = 1000                    # total number of simulated points
M = 2                       # number of mixture components
D1 = 2                      # latent dimension
D2 = 3                      # observed feature dimension
K = 3                       # label classes: 2 -> Bernoulli, >2 -> Categorical
inference = 'EM'            # 'VI', 'MAP', 'EM' or 'MCMC'
model = 'collapsed'         # 'collapsed' (Mixture) or anything else (ParamMixture)
initialization = 'kmeans'   # 'random' or 'kmeans' (EM only)
experiment = 'new'          # 'new' simulates; anything else loads dataset.pkl
N_TRAIN = 750               # first N_TRAIN points train, the rest test

# Seed both NumPy (data simulation, wx init) and the TF graph (variable inits).
np.random.seed(SEED)
tf.set_random_seed(SEED)

if experiment == 'new':
    data = dataset(N, M, D1, D2, K)
    data.create()
else:
    # pickle.load can execute arbitrary code: only load files you created.
    with open('dataset.pkl', 'rb') as infile:
        data = pickle.load(infile)
# data.visualize()

assert 0 < N_TRAIN < data.X.shape[0], "N_TRAIN must leave a non-empty test set"

# N is rebound to the training-set size because the model and the
# variational families in later cells are sized by N.
N = N_TRAIN
Xtrain = data.X[:N]
Ytrain = data.Y[:N]
Xtest = data.X[N:]
Ytest = data.Y[N:]
# The Categorical likelihood (K > 2) has shape (N,), so labels are flattened;
# the Bernoulli likelihood (K == 2) has shape (N, 1) and keeps the column.
if K > 2:
    Ytrain = Ytrain[:, 0]

In [None]:
# Priors: mixture weights, component means and per-dimension variances.
beta = Dirichlet(tf.ones(M))
mu = Normal(tf.zeros(D1), tf.ones(D1), sample_shape=M)
sigmasq = InverseGamma(tf.ones(D1), tf.ones(D1), sample_shape=M)

if model != "collapsed":
    # Explicit assignments: z.cat holds each point's component index.
    z = ParamMixture(beta, {'loc': mu, 'scale_diag': tf.sqrt(sigmasq)},
                     MultivariateNormalDiag,
                     sample_shape=N)
    c = z.cat
else:
    # Collapsed: the component index is marginalised inside Mixture.
    cat = Categorical(probs=beta, sample_shape=N)
    # sigmasq is a variance, so take its square root for scale_diag
    # (consistent with the ParamMixture branch above).
    components = [
        MultivariateNormalDiag(mu[m], tf.sqrt(sigmasq[m]), sample_shape=N)
        for m in range(M)]
    z = Mixture(cat=cat, components=components, sample_shape=N)

# Likelihoods shared by both variants: features x and labels y given z.
wx = Normal(loc=tf.zeros([D2, D1]), scale=tf.ones([D2, D1]))
x = Normal(loc=tf.matmul(z, wx, transpose_b=True), scale=tf.ones([N, D2]))

n_label_outputs = 1 if K == 2 else K
wy = Normal(loc=tf.zeros([n_label_outputs, D1]), scale=tf.ones([n_label_outputs, D1]))
logits_y = tf.matmul(z, wy, transpose_b=True)
if K == 2:
    y = Bernoulli(logits=logits_y)
else:
    y = Categorical(logits=logits_y)
    

In [None]:
# Every method is fit on the training split (the model is sized by N).
# Labels are reshaped to the likelihood's shape -- (N, 1) for Bernoulli,
# (N,) for Categorical -- and cast to int32 so any K is fed consistently.
Yobs = np.reshape(Ytrain, y.shape.as_list()).astype(np.int32)
n_label_outputs = 1 if K == 2 else K

# Edward objects get their own names so the `inference` setting is never
# overwritten and this cell can be re-run.
if inference == 'VI':
    qz = Normal(loc=tf.Variable(tf.random_normal([N, D1])),
                scale=tf.nn.softplus(tf.Variable(tf.random_normal([N, D1]))))
    qmu = Normal(loc=tf.Variable(tf.random_normal([M, D1])),
                 scale=tf.nn.softplus(tf.Variable(tf.random_normal([M, D1]))))
    qwx = Normal(loc=tf.Variable(tf.random_normal([D2, D1])),
                 scale=tf.nn.softplus(tf.Variable(tf.random_normal([D2, D1]))))
    qwy = Normal(loc=tf.Variable(tf.random_normal([n_label_outputs, D1])),
                 scale=tf.nn.softplus(tf.Variable(tf.random_normal([n_label_outputs, D1]))))
    # Only mu is given a variational family here; qz, qwx and qwy are built
    # but not passed to KLqp (an earlier variant also mapped z, wy and wx).
    inference_vi = ed.KLqp({mu: qmu}, data={x: Xtrain, y: Yobs})
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    inference_vi.run(n_iter=500, n_print=100, n_samples=50, optimizer=optimizer)

elif inference == 'MAP':
    qz = PointMass(params=tf.Variable(tf.random_normal([N, D1])))
    qmu = PointMass(params=tf.Variable(tf.random_normal([M, D1])))
    qwx = PointMass(params=tf.Variable(tf.random_normal([D2, D1])))
    qwy = PointMass(params=tf.Variable(tf.random_normal([n_label_outputs, D1])))
    qsigmasq = PointMass(params=tf.Variable(tf.ones([M, D1])))
    inference_map = ed.MAP({mu: qmu, wx: qwx, wy: qwy, sigmasq: qsigmasq},
                           data={x: Xtrain, y: Yobs})
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    inference_map.run(n_iter=5000, n_print=100, optimizer=optimizer)

elif inference == 'EM':
    # Variational E-step over z, point-estimate M-step over the parameters.
    qz = Normal(loc=tf.Variable(tf.random_normal([N, D1])),
                scale=tf.nn.softplus(tf.Variable(tf.random_normal([N, D1]))))
    wxinit = np.random.normal(size=[D2, D1]).astype(np.float32)
    qwx = PointMass(params=tf.Variable(wxinit))
    qwy = PointMass(params=tf.Variable(tf.random_normal([n_label_outputs, D1])))
    qsigmasq = PointMass(params=tf.Variable(tf.ones([M, D1])))
    if initialization == 'random':
        qmu = PointMass(params=tf.Variable(tf.random_normal([M, D1])))
    elif initialization == 'kmeans':
        # Cluster the observed features, then map the centres into latent
        # space through the pseudo-inverse of the initial wx.
        kmeans = KMeans(n_clusters=M, random_state=0, n_init=5).fit(Xtrain)
        xinit = kmeans.cluster_centers_
        zinit = np.matmul(xinit, np.linalg.pinv(wxinit).transpose()).astype(np.float32)
        qmu = PointMass(params=tf.Variable(zinit))
    else:
        raise ValueError("unknown initialization: %r" % (initialization,))

    inference_e = ed.KLqp({z: qz}, data={x: Xtrain, y: Yobs, mu: qmu, wx: qwx,
                                         wy: qwy, sigmasq: qsigmasq})
    inference_m = ed.MAP({mu: qmu, wx: qwx, wy: qwy, sigmasq: qsigmasq},
                         data={x: Xtrain, y: Yobs, z: qz})
    inference_e.initialize(optimizer=tf.train.AdamOptimizer(learning_rate=1e-3))
    inference_m.initialize()
    tf.global_variables_initializer().run()

    EM_ITERATIONS = 500
    E_STEPS_PER_M_STEP = 5
    for _ in range(EM_ITERATIONS):
        for _ in range(E_STEPS_PER_M_STEP):
            info_dict_e = inference_e.update()
        info_dict_m = inference_m.update()
        inference_m.print_progress(info_dict_m)

elif inference == 'MCMC':
    T = 2000  # number of MCMC samples
    qz = Empirical(tf.Variable(tf.zeros([T, N, D1])))
    qmu = Empirical(tf.Variable(tf.zeros([T, M, D1])))
    qsigmasq = Empirical(tf.Variable(tf.ones([T, M, D1])))
    qwx = Empirical(tf.Variable(tf.random_normal([T, D2, D1])))
    qwy = Empirical(tf.Variable(tf.random_normal([T, n_label_outputs, D1])))
    inference_mcmc = ed.Gibbs({mu: qmu, sigmasq: qsigmasq}, data={x: Xtrain, y: Yobs})
    inference_mcmc.initialize()

    sess = ed.get_session()
    tf.global_variables_initializer().run()

    t_ph = tf.placeholder(tf.int32, [])
    running_cluster_means = tf.reduce_mean(qmu.params[:t_ph], 0)
    running_weight_means = tf.reduce_mean(qwx.params[:t_ph], 0)
    # Build the display op once, rather than adding a new matmul node to the
    # graph every time progress is printed.
    running_x_centers = tf.matmul(running_cluster_means, running_weight_means,
                                  transpose_b=True)

    for _ in range(inference_mcmc.n_iter):
        info_dict = inference_mcmc.update()
        inference_mcmc.print_progress(info_dict)
        t = info_dict['t']
        if t % inference_mcmc.n_print == 0:
            print("\nInferred cluster means:")
            print(sess.run(running_x_centers, {t_ph: t - 1}))
    data.print_params()

else:
    raise ValueError("unknown inference method: %r" % (inference,))

In [None]:
# Show the inferred prototypes next to the generating parameters.
# k (neighbour count) only affects mode='evaluate'.
for k in (1,):
    data.MAPdiagnostics(mode='print', qmu=qmu, qwx=qwx, qwy=qwy,
                        Xtest=Xtest, Ytest=Ytest, K=K, k=k)
# print(np.sum(np.argmax(sess.run(qc.probs), axis=1)==data.C[:,0]))

In [None]:
# z_post = ed.copy(z, {c: qc, mu: qmu})
# z_gen = sess.run(z_post)
# plt.scatter(z_gen[:,0], z_gen[:, 1])
# plt.show()

# x_post = ed.copy(x, {wx: qwx, z: qz, c: qc, wy: qwy, mu: qmu})
# x_gen = sess.run(x_post)
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# plt.scatter(x_gen[:,0], x_gen[:, 1], x_gen[:,2])
# plt.show()

# print(ed.evaluate('log_likelihood', data={x_post: data.X}))

In [None]:
# Persist the simulated dataset object.
# NOTE(review): the loading branch of the config cell reads 'dataset.pkl',
# not 'dataset1.pkl'; rename one of them to reuse the saved file.
with open('dataset1.pkl', 'wb') as outfile:
    pickle.dump(data, outfile)

In [None]:
# Baseline: k-nearest-neighbour classifier fit directly on the observed features.
# ravel() accepts labels as either (n,) or (n, 1).
y_train_flat = np.ravel(Ytrain)
y_test_flat = np.ravel(Ytest)
for k in [1, 2, 4, 8, 16]:
    neigh = KNeighborsClassifier(n_neighbors=k)
    neigh.fit(Xtrain, y_train_flat)
    accuracy = np.mean(neigh.predict(Xtest) == y_test_flat)
    if len(neigh.classes_) == 2:
        auc = roc_auc_score(y_test_flat, neigh.predict_proba(Xtest)[:, 1])
    else:
        # The positive-class ROC AUC only applies to two classes.
        auc = float('nan')
    print(k, accuracy, auc)