In [1]:
%load_ext Cython

In [20]:
%%file IBP/cython_setup.py
from distutils.core import setup
from Cython.Build import cythonize
#ext_modules = cythonize("cython_functions.pyx")
setup(
  #name = 'MyProject',
  #ext_modules = cythonize(["*.pyx"]),
    ext_modules = cythonize("cython_functions.pyx"),
)

Overwriting IBP/cython_setup.py


In [21]:
%%file IBP/cython_functions.pyx

from __future__ import division
import numpy as np
import math

np.random.seed(123)

def new_K(alpha,X,N,Z,s_x,s_a,obj):
    k_prob = np.zeros(5)
    for i in range(0,5):
        l = alpha/N
        new_zi = np.zeros((N,i))
        new_zi[obj,:] = np.ones((1,i))
        new_Z = np.hstack([Z,new_zi.reshape(N,i)])
        LH = full_X(new_Z,X,s_x, s_a)
        log_prior = i*np.log(l)-l-np.log(math.factorial(i))
        k_prob[i] = LH + log_prior#likelihood*prior = posterior
 
    k_prob = np.exp(k_prob-max(k_prob))
    k_prob = k_prob/sum(k_prob)
    if (abs(sum(k_prob)-1)>0.001):
        return(sum(k_prob),'wrong k sum')
    
    new_k = np.random.choice(5,1,p = k_prob)
    return (new_k)

def full_X(Z,X,s_x,s_a):
    D = X.shape[1]
    N = Z.shape[0]
    K = Z.shape[1]
    """The constant part"""
    zz = np.dot(Z.T,Z)+((s_x**2)/(s_a**2))*np.eye(K) #zz -- K*K
    determ = np.linalg.det(zz)
    log_const = 0.5*N*D*np.log(2*np.pi)+(N-K)*D*np.log(s_x)+K*D*np.log(s_a)+0.5*D*np.log(determ)
    log_const = -log_const
    """The exponential part"""
    ii =  np.eye(N)-np.dot(np.dot(Z , np.linalg.inv(zz)) , Z.T)
    tr =  np.trace(np.dot(np.dot(X.T , ii) ,X))
    expon = -tr/(2*s_x**2)
    return(log_const+expon)




Overwriting IBP/cython_functions.pyx


In [16]:
%%file IBP/sampler_cython.py

from __future__ import division
import numpy as np
import math
from IBP.cython_functions import new_K, full_X

def gibbs_sampler_cython(X,init_alpha,init_sig_x,init_sig_a,mcmc):
    N = X.shape[0]
    chain_alpha = np.zeros(mcmc)
    chain_sigma_a = np.zeros(mcmc)
    chain_sigma_x = np.zeros(mcmc)
    chain_K = np.zeros(mcmc)
    chain_Z = list()
    #initial matrix Z
    Z = np.array(np.random.choice(2,N,p = [0.5,0.5])).reshape(N,1)

    chain_alpha[0] = alpha = init_alpha 
    chain_sigma_a[0] = sigma_a = init_sig_a 
    chain_sigma_x[0] = sigma_x = init_sig_x
    chain_K[0] = K = 1
    chain_Z.append(Z)
    P = np.zeros(2)
    
    Hn = 0
    for i in range(1,mcmc):
        #gibbs
        alpha = np.random.gamma(1+K,1/(1+Hn))
        print(i,K)
        Hn = 0
        for im in range(0,N): #loop over images
            Hn = Hn + 1/(im+1)
            #sample new Z_i
            for k in range(0,K):#loop over features
                zk_sum = np.sum(Z[:,k])
                if zk_sum == 0:
                     lz = -10**5
                else:
                     lz = np.log(zk_sum)-np.log(N)
                if zk_sum == N:
                     lz0 = -10**5
                else:
                     lz0 = np.log(N-zk_sum)-np.log(N)
                Z[im,k] = 1
                P[0] = full_X(Z,X,sigma_x,sigma_a)+lz
                Z[im,k] = 0
                P[1] = full_X(Z,X,sigma_x,sigma_a)+lz0

                P=np.exp(P - max(P))
                P[0] = P[0]/(P[0]+P[1])
                if np.random.uniform(0,1,1)<P[0]:
                    Z[im,k] = 1
                else:
                    Z[im,k] = 0

            #sample K---num of new features
            new_k = new_K(alpha,X,N,Z,sigma_x,sigma_a,im)[0]
            if Z.shape[1]>(K+new_k):
                Ztemp=Z
                Ztemp[im,K:(K+new_k)]=1       
            else:
                Ztemp=np.zeros((Z.shape[0],K+new_k))
                Ztemp[0:Z.shape[0],0:Z.shape[1]]=Z
                Ztemp[im,K:(K+new_k)] = 1

            Z=Ztemp
            K = K + new_k

            #sample a new sigma_x and sigma_a with MH,invgamma(2,2) prior/invgamma(1,1) proposal
            #for mh in range(0,5):
            '''propose new sigma_x'''
            current_LH = full_X(Z,X,sigma_x,sigma_a)
            #sig_x_str = sigma_x + (np.random.rand(1)[0]-0.5)
            sig_x_str = 1/np.random.gamma(3,2)#propose a new sigma_x from invgamma(3,2)
            pos_str = full_X(Z,X,sig_x_str,sigma_a)-3*np.log(sig_x_str)-1/(2*sig_x_str)
            pos = current_LH-3*np.log(sigma_x)-1/(2*sigma_x)
            if((pos_str-pos)>0):
                sigma_x = sig_x_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_x = sig_x_str
                '''propose new sigma_a'''
            #sig_a_str = sigma_a + (np.random.rand(1)[0]-0.5)
            sig_a_str = 1/np.random.gamma(3,2)
            pos_str = full_X(Z,X,sigma_x,sig_a_str)-3*np.log(sig_a_str)-1/(2*sig_a_str)
            pos = current_LH-3*np.log(sigma_a)-1/(2*sigma_a)
            if((pos_str-pos) > 0):
                sigma_a = sig_a_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_a = sig_a_str

        #remove features that have only 1 object
        index = np.sum(Z,0)>1
        Z = Z[:,index]
        K = Z.shape[1]

        #store chain values                
        chain_alpha[i] = alpha
        chain_sigma_a[i] = sigma_a
        chain_sigma_x[i] = sigma_x
        chain_K[i] = K
        chain_Z.append(Z)
        
    return(chain_alpha,chain_sigma_a,chain_sigma_x,chain_K,chain_Z)

Overwriting IBP/sampler_cython.py


In [17]:
! python IBP/cython_setup.py build_ext --inplace

Traceback (most recent call last):
  File "IBP/cython_setup.py", line 7, in <module>
    ext_modules = cythonize("cython_functions.py"),
  File "/Users/zhangkuazhuo/anaconda2/lib/python2.7/site-packages/Cython/Build/Dependencies.py", line 920, in cythonize
    aliases=aliases)
  File "/Users/zhangkuazhuo/anaconda2/lib/python2.7/site-packages/Cython/Build/Dependencies.py", line 800, in create_extension_list
    for file in nonempty(sorted(extended_iglob(filepattern)), "'%s' doesn't match any files" % filepattern):
  File "/Users/zhangkuazhuo/anaconda2/lib/python2.7/site-packages/Cython/Build/Dependencies.py", line 125, in nonempty
    raise ValueError(error_msg)
ValueError: 'cython_functions.py' doesn't match any files


In [24]:
from __future__ import division
import numpy as np
import math
from IBP.sampler_cython import gibbs_sampler_cython
import time

#load data
X=np.genfromtxt("data_files/true_X.csv", delimiter=",")

t0 = time.time()
chain_alpha2,chain_sigma_a2,chain_sigma_x2,chain_K2,chain_Z2 = gibbs_sampler_cython(X,init_alpha=1,init_sig_x=0.5,init_sig_a=1.7,mcmc=1000)
t1 = time.time()
total=t1-t0

(1, 1)
(2, 1)
(3, 1)
(4, 2)
(5, 2)
(6, 2)
(7, 2)
(8, 2)
(9, 3)
(10, 3)
(11, 4)
(12, 4)
(13, 4)
(14, 4)
(15, 4)
(16, 4)
(17, 4)
(18, 4)
(19, 4)
(20, 4)
(21, 4)
(22, 5)
(23, 4)
(24, 4)
(25, 4)
(26, 4)
(27, 4)
(28, 4)
(29, 4)
(30, 5)
(31, 6)
(32, 5)
(33, 5)
(34, 5)
(35, 5)
(36, 5)
(37, 4)
(38, 4)
(39, 4)
(40, 4)
(41, 6)
(42, 6)
(43, 6)
(44, 6)
(45, 6)
(46, 7)
(47, 7)
(48, 6)
(49, 6)
(50, 4)
(51, 4)
(52, 4)
(53, 4)
(54, 5)
(55, 4)
(56, 4)
(57, 4)
(58, 4)
(59, 5)
(60, 5)
(61, 5)
(62, 4)
(63, 4)
(64, 4)
(65, 4)
(66, 4)
(67, 6)
(68, 4)
(69, 4)
(70, 4)
(71, 5)
(72, 5)
(73, 5)
(74, 4)
(75, 4)
(76, 5)
(77, 4)
(78, 4)
(79, 4)
(80, 4)
(81, 4)
(82, 4)
(83, 4)
(84, 4)
(85, 4)
(86, 5)
(87, 6)
(88, 5)
(89, 6)
(90, 6)
(91, 5)
(92, 5)
(93, 5)
(94, 5)
(95, 5)
(96, 5)
(97, 4)
(98, 4)
(99, 4)
(100, 4)
(101, 4)
(102, 4)
(103, 4)
(104, 4)
(105, 4)
(106, 4)
(107, 4)
(108, 4)
(109, 4)
(110, 4)
(111, 4)
(112, 4)
(113, 4)
(114, 4)
(115, 4)
(116, 5)
(117, 5)
(118, 4)
(119, 4)
(120, 4)
(121, 4)
(122, 4)
(123, 4)
(

(924, 4)
(925, 4)
(926, 4)
(927, 5)
(928, 4)
(929, 4)
(930, 4)
(931, 4)
(932, 4)
(933, 4)
(934, 4)
(935, 4)
(936, 4)
(937, 5)
(938, 4)
(939, 5)
(940, 6)
(941, 5)
(942, 5)
(943, 5)
(944, 5)
(945, 5)
(946, 5)
(947, 4)
(948, 6)
(949, 4)
(950, 4)
(951, 5)
(952, 5)
(953, 5)
(954, 5)
(955, 5)
(956, 5)
(957, 4)
(958, 5)
(959, 5)
(960, 5)
(961, 5)
(962, 4)
(963, 4)
(964, 4)
(965, 4)
(966, 5)
(967, 5)
(968, 6)
(969, 5)
(970, 4)
(971, 4)
(972, 4)
(973, 4)
(974, 4)
(975, 4)
(976, 4)
(977, 4)
(978, 4)
(979, 4)
(980, 5)
(981, 7)
(982, 6)
(983, 5)
(984, 5)
(985, 4)
(986, 5)
(987, 5)
(988, 5)
(989, 5)
(990, 6)
(991, 5)
(992, 5)
(993, 5)
(994, 5)
(995, 5)
(996, 6)
(997, 6)
(998, 6)
(999, 5)


In [25]:
total

411.7674868106842