In [3]:
%load_ext Cython

The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython


In [4]:
import numpy as np
alpha=1
N=100
A=np.genfromtxt("data_files/true_A.csv", delimiter=",")
X=np.genfromtxt("data_files/true_X.csv", delimiter=",")
Z=np.genfromtxt("data_files/true_Z.csv", delimiter=",")
s_x=0.5
s_a=1.7

In [14]:
%%cython --annotate

import cython
cimport cython
import numpy as np
cimport numpy as np
import math

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)

def new_K_cy(double alpha, X, int N, Z,double s_x, double s_a, int obj):
    cdef double[5] k_prob
    cdef double l,LH,log_prior
    cdef int i
#     with cython.nogil:
    for i in range(5):
        l = alpha/N
        new_zi = np.zeros((N,i))
        new_zi[obj,:] = np.ones((1,i))
        new_Z = np.hstack([Z,new_zi.reshape(N,i)])
        LH = full_X_cython(new_Z,X,s_x, s_a)
        log_prior = i*np.log(l)-l-np.log(math.factorial(i))
        k_prob[i] = LH + log_prior#likelihood*prior = posterior

    k_prob = np.exp(k_prob-np.max(k_prob))
    k_prob = k_prob/np.sum(k_prob)
    if (abs(np.sum(k_prob)-1)>0.001):
        return(np.sum(k_prob),'wrong k sum')

    new_k = np.random.choice(5,1,p = k_prob)
    return (new_k)


def full_X_cython( Z, X,double s_x,double s_a):
    cdef double[:,:] zz
    cdef double determ 
    cdef double log_const
    cdef double[:,:] ii
    cdef double tr
    cdef double expon
    
    cdef int D = X.shape[1]
    cdef int N = Z.shape[0]
    cdef int K = Z.shape[1]
    """The constant part"""
    zz = np.dot(Z.T,Z)+((s_x**2)/(s_a**2))*np.eye(K) #zz -- K*K
    determ = np.linalg.det(zz)
    log_const = 0.5*N*D*np.log(2*np.pi)+(N-K)*D*np.log(s_x)+K*D*np.log(s_a)+0.5*D*np.log(determ)
    log_const = -log_const
    """The exponential part"""
    ii =  np.eye(N)-np.dot(np.dot(Z , np.linalg.inv(zz)) , Z.T)
    tr =  np.trace(np.dot(np.dot(X.T , ii) ,X))
    expon = -tr/(2*s_x**2)
    return(log_const+expon)



In [9]:
np.random.seed(123)
def gibbs_sampler_cython(X,init_alpha,init_sig_x,init_sig_a,mcmc):
    N = X.shape[0]
    chain_alpha = np.zeros(mcmc)
    chain_sigma_a = np.zeros(mcmc)
    chain_sigma_x = np.zeros(mcmc)
    chain_K = np.zeros(mcmc)
    chain_Z = list()
    #initial matrix Z
    Z = np.array(np.random.choice(2,N,p = [0.5,0.5])).reshape(N,1)

    chain_alpha[0] = alpha = init_alpha 
    chain_sigma_a[0] = sigma_a = init_sig_a 
    chain_sigma_x[0] = sigma_x = init_sig_x
    chain_K[0] = K = 1
    chain_Z.append(Z)
    P = np.zeros(2)
    
    Hn = 0
    for i in range(1,mcmc):
        #gibbs
        alpha = np.random.gamma(1+K,1/(1+Hn))
        print(i,K)
        Hn = 0
        for im in range(0,N): #loop over images
            Hn = Hn + 1/(im+1)
            #sample new Z_i
            for k in range(0,K):#loop over features
                zk_sum = np.sum(Z[:,k])
                if zk_sum == 0:
                     lz = -10**5
                else:
                     lz = np.log(zk_sum)-np.log(N)
                if zk_sum == N:
                     lz0 = -10**5
                else:
                     lz0 = np.log(N-zk_sum)-np.log(N)
                Z[im,k] = 1
                P[0] = full_X_cython(Z,X,sigma_x,sigma_a)+lz
                Z[im,k] = 0
                P[1] = full_X_cython(Z,X,sigma_x,sigma_a)+lz0

                P=np.exp(P - max(P))
                P[0] = P[0]/(P[0]+P[1])
                if np.random.uniform(0,1,1)<P[0]:
                    Z[im,k] = 1
                else:
                    Z[im,k] = 0

            #sample K---num of new features
            new_k = new_K_cy(alpha,X,N,Z,sigma_x,sigma_a,im)[0]
            if Z.shape[1]>(K+new_k):
                Ztemp=Z
                Ztemp[im,K:(K+new_k)]=1       
            else:
                Ztemp=np.zeros((Z.shape[0],K+new_k))
                Ztemp[0:Z.shape[0],0:Z.shape[1]]=Z
                Ztemp[im,K:(K+new_k)] = 1

            Z=Ztemp
            K = K + new_k

            #sample a new sigma_x and sigma_a with MH,invgamma(2,2) prior/invgamma(1,1) proposal
            #for mh in range(0,5):
            '''propose new sigma_x'''
            current_LH = full_X_cython(Z,X,sigma_x,sigma_a)
            #sig_x_str = sigma_x + (np.random.rand(1)[0]-0.5)
            sig_x_str = 1/np.random.gamma(3,2)#propose a new sigma_x from invgamma(3,2)
            pos_str = full_X_cython(Z,X,sig_x_str,sigma_a)-3*np.log(sig_x_str)-1/(2*sig_x_str)
            pos = current_LH-3*np.log(sigma_x)-1/(2*sigma_x)
            if((pos_str-pos)>0):
                sigma_x = sig_x_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_x = sig_x_str
                '''propose new sigma_a'''
            #sig_a_str = sigma_a + (np.random.rand(1)[0]-0.5)
            sig_a_str = 1/np.random.gamma(3,2)
            pos_str = full_X_cython(Z,X,sigma_x,sig_a_str)-3*np.log(sig_a_str)-1/(2*sig_a_str)
            pos = current_LH-3*np.log(sigma_a)-1/(2*sigma_a)
            if((pos_str-pos) > 0):
                sigma_a = sig_a_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_a = sig_a_str

        #remove features that have only 1 object
        index = np.sum(Z,0)>1
        Z = Z[:,index]
        K = Z.shape[1]

        #store chain values                
        chain_alpha[i] = alpha
        chain_sigma_a[i] = sigma_a
        chain_sigma_x[i] = sigma_x
        chain_K[i] = K
        chain_Z.append(Z)
        
    return(chain_alpha,chain_sigma_a,chain_sigma_x,chain_K,chain_Z)

In [18]:
import time

t0 = time.time()
chain_alpha1,chain_sigma_a1,chain_sigma_x1,chain_K1,chain_Z1 = gibbs_sampler_cython(X,init_alpha=1,init_sig_x=0.5,init_sig_a=1.7,mcmc=1000)
t1 = time.time()
total_cy=t1-t0


(1, 1)
(2, 6)




(3, 6)
(4, 6)
(5, 6)
(6, 6)
(7, 6)
(8, 6)
(9, 6)
(10, 6)
(11, 6)
(12, 6)
(13, 6)
(14, 6)
(15, 6)
(16, 6)
(17, 6)
(18, 6)
(19, 6)
(20, 6)
(21, 6)
(22, 6)
(23, 6)
(24, 6)
(25, 6)
(26, 6)
(27, 6)
(28, 6)
(29, 6)
(30, 6)
(31, 6)
(32, 6)
(33, 6)
(34, 6)
(35, 6)
(36, 6)
(37, 6)
(38, 6)
(39, 6)
(40, 6)
(41, 6)
(42, 6)
(43, 6)
(44, 6)
(45, 6)
(46, 6)
(47, 6)
(48, 6)
(49, 6)
(50, 6)
(51, 6)
(52, 6)
(53, 6)
(54, 6)
(55, 6)
(56, 5)
(57, 5)
(58, 5)
(59, 5)
(60, 5)
(61, 5)
(62, 5)
(63, 5)
(64, 5)
(65, 5)
(66, 5)
(67, 5)
(68, 5)
(69, 5)
(70, 5)
(71, 4)
(72, 4)
(73, 4)
(74, 4)
(75, 4)
(76, 4)
(77, 4)
(78, 4)
(79, 4)
(80, 4)
(81, 4)
(82, 4)
(83, 4)
(84, 4)
(85, 4)
(86, 4)
(87, 4)
(88, 4)
(89, 4)
(90, 4)
(91, 4)
(92, 4)
(93, 4)
(94, 4)
(95, 4)
(96, 4)
(97, 4)
(98, 4)
(99, 4)
(100, 4)
(101, 4)
(102, 4)
(103, 4)
(104, 4)
(105, 4)
(106, 4)
(107, 4)
(108, 4)
(109, 4)
(110, 4)
(111, 4)
(112, 4)
(113, 4)
(114, 4)
(115, 4)
(116, 4)
(117, 4)
(118, 4)
(119, 4)
(120, 4)
(121, 4)
(122, 4)
(123, 4)
(124, 4)
(125, 

(925, 4)
(926, 4)
(927, 4)
(928, 4)
(929, 4)
(930, 4)
(931, 4)
(932, 4)
(933, 4)
(934, 4)
(935, 4)
(936, 4)
(937, 4)
(938, 4)
(939, 4)
(940, 4)
(941, 4)
(942, 4)
(943, 4)
(944, 4)
(945, 4)
(946, 4)
(947, 4)
(948, 4)
(949, 4)
(950, 4)
(951, 4)
(952, 4)
(953, 4)
(954, 4)
(955, 4)
(956, 4)
(957, 4)
(958, 4)
(959, 4)
(960, 4)
(961, 4)
(962, 4)
(963, 4)
(964, 4)
(965, 4)
(966, 4)
(967, 4)
(968, 4)
(969, 4)
(970, 4)
(971, 4)
(972, 4)
(973, 4)
(974, 4)
(975, 4)
(976, 4)
(977, 4)
(978, 4)
(979, 4)
(980, 4)
(981, 4)
(982, 4)
(983, 4)
(984, 4)
(985, 4)
(986, 4)
(987, 4)
(988, 4)
(989, 4)
(990, 4)
(991, 4)
(992, 4)
(993, 4)
(994, 4)
(995, 4)
(996, 4)
(997, 4)
(998, 4)
(999, 4)


In [19]:
total_cy

457.5609200000763