In [12]:
%%file IBP/jit_functions.py
from numba import jit
import numpy as np

np.random.seed(123)

@jit

def full_X(Z,X,s_x,s_a):
    D = X.shape[1]
    N = Z.shape[0]
    K = Z.shape[1]
    """The constant part"""
    zz = np.dot(Z.T,Z)+((s_x**2)/(s_a**2))*np.eye(K) #zz -- K*K
    determ = np.linalg.det(zz)
    log_const = 0.5*N*D*np.log(2*np.pi)+(N-K)*D*np.log(s_x)+K*D*np.log(s_a)+0.5*D*np.log(determ)
    log_const = -log_const
    """The exponential part"""
    ii =  np.eye(N)-np.dot(np.dot(Z , np.linalg.inv(zz)) , Z.T)
    tr =  np.trace(np.dot(np.dot(X.T , ii) ,X))
    expon = -tr/(2*s_x**2)
    return(log_const+expon)

Overwriting IBP/jit_functions.py


In [13]:
import numpy as np
import math
import IBP.jit_functions as func

def new_K_jit(alpha,X,N,Z,s_x,s_a,obj):
    k_prob = np.zeros(5)
    for i in range(0,5):
        l = alpha/N
        new_zi = np.zeros((N,i))
        new_zi[obj,:] = np.ones((1,i))
        new_Z = np.hstack([Z,new_zi.reshape(N,i)])
        LH = func.full_X(new_Z,X,s_x, s_a)
        log_prior = i*np.log(l)-l-np.log(math.factorial(i))
        k_prob[i] = LH + log_prior#likelihood*prior = posterior
 
    k_prob = np.exp(k_prob-max(k_prob))
    k_prob = k_prob/sum(k_prob)
    if (abs(sum(k_prob)-1)>0.001):
        return(sum(k_prob),'wrong k sum')
    
    new_k = np.random.choice(5,1,p = k_prob)
    return (new_k)

def gibbs_sampler_jit(X,init_alpha,init_sig_x,init_sig_a,mcmc):
    
    N = X.shape[0]
    chain_alpha = np.zeros(mcmc)
    chain_sigma_a = np.zeros(mcmc)
    chain_sigma_x = np.zeros(mcmc)
    chain_K = np.zeros(mcmc)
    chain_Z = list()
    #initial matrix Z
    Z = np.array(np.random.choice(2,N,p = [0.5,0.5])).reshape(N,1)

    chain_alpha[0] = alpha = init_alpha 
    chain_sigma_a[0] = sigma_a = init_sig_a 
    chain_sigma_x[0] = sigma_x = init_sig_x
    chain_K[0] = K = 1
    chain_Z.append(Z)
    P = np.zeros(2)
    
    Hn = 0
    for i in range(1,mcmc):
        #gibbs
        alpha = np.random.gamma(1+K,1/(1+Hn))
        print(i,K)
        Hn = 0
        for im in range(0,N): #loop over images
            Hn = Hn + 1/(im+1)
            #sample new Z_i
            for k in range(0,K):#loop over features
                zk_sum = np.sum(Z[:,k])
                if zk_sum == 0:
                     lz = -10**5
                else:
                     lz = np.log(zk_sum)-np.log(N)
                if zk_sum == N:
                     lz0 = -10**5
                else:
                     lz0 = np.log(N-zk_sum)-np.log(N)
                Z[im,k] = 1
                P[0] = func.full_X(Z,X,sigma_x,sigma_a)+lz
                Z[im,k] = 0
                P[1] = func.full_X(Z,X,sigma_x,sigma_a)+lz0

                P=np.exp(P - max(P))
                P[0] = P[0]/(P[0]+P[1])
                if np.random.uniform(0,1,1)<P[0]:
                    Z[im,k] = 1
                else:
                    Z[im,k] = 0

            #sample K---num of new features
            new_k = new_K_jit(alpha,X,N,Z,sigma_x,sigma_a,im)[0]
            if Z.shape[1]>(K+new_k):
                Ztemp=Z
                Ztemp[im,K:(K+new_k)]=1       
            else:
                Ztemp=np.zeros((Z.shape[0],K+new_k))
                Ztemp[0:Z.shape[0],0:Z.shape[1]]=Z
                Ztemp[im,K:(K+new_k)] = 1

            Z=Ztemp
            K = K + new_k

            #sample a new sigma_x and sigma_a with MH,invgamma(2,2) prior/invgamma(1,1) proposal
            #for mh in range(0,5):
            '''propose new sigma_x'''
            current_LH = func.full_X(Z,X,sigma_x,sigma_a)
            #sig_x_str = sigma_x + (np.random.rand(1)[0]-0.5)
            sig_x_str = 1/np.random.gamma(3,2)#propose a new sigma_x from invgamma(3,2)
            pos_str = func.full_X(Z,X,sig_x_str,sigma_a)-3*np.log(sig_x_str)-1/(2*sig_x_str)
            pos = current_LH-3*np.log(sigma_x)-1/(2*sigma_x)
            if((pos_str-pos)>0):
                sigma_x = sig_x_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_x = sig_x_str
                '''propose new sigma_a'''
            #sig_a_str = sigma_a + (np.random.rand(1)[0]-0.5)
            sig_a_str = 1/np.random.gamma(3,2)
            pos_str = func.full_X(Z,X,sigma_x,sig_a_str)-3*np.log(sig_a_str)-1/(2*sig_a_str)
            pos = current_LH-3*np.log(sigma_a)-1/(2*sigma_a)
            if((pos_str-pos) > 0):
                sigma_a = sig_a_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_a = sig_a_str

        #remove features that have only 1 object
        index = np.sum(Z,0)>1
        Z = Z[:,index]
        K = Z.shape[1]

        #store chain values                
        chain_alpha[i] = alpha
        chain_sigma_a[i] = sigma_a
        chain_sigma_x[i] = sigma_x
        chain_K[i] = K
        chain_Z.append(Z)
        
    return(chain_alpha,chain_sigma_a,chain_sigma_x,chain_K,chain_Z)

In [14]:
#load data
X=np.genfromtxt("data_files/true_X.csv", delimiter=",")

In [15]:
import time
t0 = time.time()
chain_alpha,chain_sigma_a,chain_sigma_x,chain_K,chain_Z = gibbs_sampler_jit(X,init_alpha=1,init_sig_x=0.5,init_sig_a=1.7,mcmc=1000)
t1 = time.time()
total=t1-t0

(1, 1)
(2, 3)


  del sys.path[0]
  del sys.path[0]


(3, 3)
(4, 3)
(5, 3)
(6, 3)
(7, 3)
(8, 3)
(9, 3)
(10, 3)
(11, 3)
(12, 3)
(13, 3)
(14, 3)
(15, 3)
(16, 3)
(17, 3)
(18, 3)
(19, 3)
(20, 3)
(21, 3)
(22, 3)
(23, 3)
(24, 3)
(25, 3)
(26, 3)
(27, 3)
(28, 3)
(29, 3)
(30, 3)
(31, 3)
(32, 3)
(33, 3)
(34, 3)
(35, 3)
(36, 3)
(37, 3)
(38, 3)
(39, 3)
(40, 3)
(41, 3)
(42, 3)
(43, 3)
(44, 3)
(45, 3)
(46, 3)
(47, 3)
(48, 3)
(49, 3)
(50, 3)
(51, 3)
(52, 3)
(53, 3)
(54, 3)
(55, 3)
(56, 3)
(57, 3)
(58, 3)
(59, 3)
(60, 3)
(61, 3)
(62, 3)
(63, 3)
(64, 3)
(65, 3)
(66, 3)
(67, 3)
(68, 3)
(69, 3)
(70, 3)
(71, 3)
(72, 3)
(73, 3)
(74, 3)
(75, 3)
(76, 3)
(77, 3)
(78, 3)
(79, 3)
(80, 3)
(81, 3)
(82, 3)
(83, 3)
(84, 3)
(85, 3)
(86, 3)
(87, 3)
(88, 3)
(89, 3)
(90, 3)
(91, 3)
(92, 3)
(93, 3)
(94, 3)
(95, 3)
(96, 3)
(97, 3)
(98, 3)
(99, 3)
(100, 3)
(101, 3)
(102, 3)
(103, 3)
(104, 3)
(105, 3)
(106, 3)
(107, 3)
(108, 3)
(109, 3)
(110, 3)
(111, 3)
(112, 3)
(113, 3)
(114, 3)
(115, 3)
(116, 3)
(117, 3)
(118, 3)
(119, 3)
(120, 3)
(121, 3)
(122, 3)
(123, 3)
(124, 3)
(125, 

(926, 3)
(927, 3)
(928, 3)
(929, 3)
(930, 3)
(931, 3)
(932, 3)
(933, 3)
(934, 3)
(935, 3)
(936, 3)
(937, 3)
(938, 3)
(939, 3)
(940, 3)
(941, 3)
(942, 3)
(943, 3)
(944, 3)
(945, 3)
(946, 3)
(947, 3)
(948, 3)
(949, 3)
(950, 3)
(951, 3)
(952, 3)
(953, 3)
(954, 3)
(955, 3)
(956, 3)
(957, 3)
(958, 3)
(959, 3)
(960, 3)
(961, 3)
(962, 3)
(963, 3)
(964, 3)
(965, 3)
(966, 3)
(967, 3)
(968, 3)
(969, 3)
(970, 3)
(971, 3)
(972, 3)
(973, 3)
(974, 3)
(975, 3)
(976, 3)
(977, 3)
(978, 3)
(979, 3)
(980, 3)
(981, 3)
(982, 3)
(983, 3)
(984, 3)
(985, 3)
(986, 3)
(987, 3)
(988, 3)
(989, 3)
(990, 3)
(991, 3)
(992, 3)
(993, 3)
(994, 3)
(995, 3)
(996, 3)
(997, 3)
(998, 3)
(999, 3)


In [16]:
total

154.85736322402954