In [27]:
%%file jit_functions.py
from numba import jit
import numpy as np

np.random.seed(123)

@jit

def full_X(Z,X,s_x,s_a):
    D = X.shape[1]
    N = Z.shape[0]
    K = Z.shape[1]
    """The constant part"""
    zz = Z.T@Z+np.diag([(s_x**2)/(s_a**2)]*K) #zz -- K*K
    determ = np.linalg.det(zz)
    log_const = 0.5*N*D*np.log(2*np.pi)+(N-K)*D*np.log(s_x)+K*D*np.log(s_a)+0.5*D*np.log(determ)
    log_const = -log_const
    """The exponential part"""
    L = np.linalg.cholesky(zz)
    inv_L = np.linalg.inv(L)
    ii =  np.eye(N)-(Z @ inv_L.T @ inv_L @ Z.T)
    tr =  np.trace(X.T @ ii @X)
    expon = -tr/(2*s_x**2)
    return(log_const+expon)

Overwriting jit_functions.py


In [28]:
import numpy as np
import math
import jit_functions as func

def new_K_jit(alpha,X,N,Z,s_x,s_a,obj):
    k_prob = np.zeros(5)
    for i in range(0,5):
        l = alpha/N
        new_zi = np.zeros((N,i))
        new_zi[obj,:] = np.ones((1,i))
        new_Z = np.hstack([Z,new_zi.reshape(N,i)])
        LH = func.full_X(new_Z,X,s_x, s_a)
        log_prior = i*np.log(l)-l-np.log(math.factorial(i))
        k_prob[i] = LH + log_prior#likelihood*prior = posterior
 
    k_prob = np.exp(k_prob-max(k_prob))
    k_prob = k_prob/sum(k_prob)
    if (abs(sum(k_prob)-1)>0.001):
        return(sum(k_prob),'wrong k sum')
    
    new_k = np.random.choice(5,1,p = k_prob)
    return (new_k)

def gibbs_sampler_jit(X,init_alpha,init_sig_x,init_sig_a,seed,mcmc):
    N = X.shape[0]
    chain_alpha = np.zeros(mcmc)
    chain_sigma_a = np.zeros(mcmc)
    chain_sigma_x = np.zeros(mcmc)
    chain_K = np.zeros(mcmc)
    chain_Z = list()
    #initial matrix Z
    Z = np.array(np.random.choice(2,N,p = [0.5,0.5])).reshape(N,1)

    chain_alpha[0] = alpha = init_alpha 
    chain_sigma_a[0] = sigma_a = init_sig_a 
    chain_sigma_x[0] = sigma_x = init_sig_x
    chain_K[0] = K = 1
    chain_Z.append(Z)
    P = np.zeros(2)
    
    Hn = 0
    for i in range(1,mcmc):
        #gibbs
        alpha = np.random.gamma(1+K,1/(1+Hn))
        print(i,K)
        Hn = 0
        for im in range(0,N): #loop over images
            Hn = Hn + 1/(im+1)
            #sample new Z_i
            for k in range(0,K):#loop over features
                zk_sum = np.sum(Z[:,k])
                if zk_sum == 0:
                     lz = -10**5
                else:
                     lz = np.log(zk_sum)-np.log(N)
                if zk_sum == N:
                     lz0 = -10**5
                else:
                     lz0 = np.log(N-zk_sum)-np.log(N)
                Z[im,k] = 1
                P[0] = func.full_X(Z,X,sigma_x,sigma_a)+lz
                Z[im,k] = 0
                P[1] = func.full_X(Z,X,sigma_x,sigma_a)+lz0

                P=np.exp(P - max(P))
                P[0] = P[0]/(P[0]+P[1])
                if np.random.uniform(0,1,1)<P[0]:
                    Z[im,k] = 1
                else:
                    Z[im,k] = 0

            #sample K---num of new features
            new_k = new_K_jit(alpha,X,N,Z,sigma_x,sigma_a,im)[0]
            if Z.shape[1]>(K+new_k):
                Ztemp=Z
                Ztemp[im,K:(K+new_k)]=1       
            else:
                Ztemp=np.zeros((Z.shape[0],K+new_k))
                Ztemp[0:Z.shape[0],0:Z.shape[1]]=Z
                Ztemp[im,K:(K+new_k)] = 1

            Z=Ztemp
            K = K + new_k

            #sample a new sigma_x and sigma_a with MH,invgamma(2,2) prior/invgamma(1,1) proposal
            #for mh in range(0,5):
            '''propose new sigma_x'''
            current_LH = func.full_X(Z,X,sigma_x,sigma_a)
            #sig_x_str = sigma_x + (np.random.rand(1)[0]-0.5)
            sig_x_str = 1/np.random.gamma(3,2)#propose a new sigma_x from invgamma(3,2)
            pos_str = func.full_X(Z,X,sig_x_str,sigma_a)-3*np.log(sig_x_str)-1/(2*sig_x_str)
            pos = current_LH-3*np.log(sigma_x)-1/(2*sigma_x)
            if((pos_str-pos)>0):
                sigma_x = sig_x_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_x = sig_x_str
                '''propose new sigma_a'''
            #sig_a_str = sigma_a + (np.random.rand(1)[0]-0.5)
            sig_a_str = 1/np.random.gamma(3,2)
            pos_str = func.full_X(Z,X,sigma_x,sig_a_str)-3*np.log(sig_a_str)-1/(2*sig_a_str)
            pos = current_LH-3*np.log(sigma_a)-1/(2*sigma_a)
            if((pos_str-pos) > 0):
                sigma_a = sig_a_str
            else:
                move = np.random.rand(1)
                if(np.log(move[0]) < (pos_str-pos)):
                    sigma_a = sig_a_str

        #remove features that have only 1 object
        index = np.sum(Z,0)>1
        Z = Z[:,index]
        K = Z.shape[1]

        #store chain values                
        chain_alpha[i] = alpha
        chain_sigma_a[i] = sigma_a
        chain_sigma_x[i] = sigma_x
        chain_K[i] = K
        chain_Z.append(Z)
        
    return(chain_alpha,chain_sigma_a,chain_sigma_x,chain_K,chain_Z)

In [29]:
#load data
X=np.genfromtxt("data_files/true_X.csv", delimiter=",")

In [30]:
import time
t0 = time.time()
chain_alpha,chain_sigma_a,chain_sigma_x,chain_K,chain_Z = gibbs_sampler_jit(X,init_alpha=1,init_sig_x=0.5,init_sig_a=1.7,seed=123,mcmc=1000)
t1 = time.time()
total=t1-t0
total

1 1
2 3
3 4
4 5
5 5
6 5
7 5
8 5
9 6
10 5
11 5
12 5
13 5
14 5
15 5
16 5
17 5
18 5
19 5
20 5
21 5
22 6
23 6
24 8
25 6
26 6
27 6
28 6
29 5
30 5
31 6
32 6
33 7
34 6
35 5
36 5
37 5
38 5
39 5
40 5
41 6
42 6
43 6
44 5
45 5
46 6
47 7
48 6
49 6
50 6
51 6
52 5
53 5
54 5
55 5
56 5
57 6
58 6
59 6
60 7
61 6
62 6
63 6
64 6
65 5
66 5
67 5
68 5
69 5
70 6
71 6
72 6
73 6
74 7
75 6
76 5
77 5
78 5
79 5
80 6
81 6
82 5
83 5
84 6
85 6
86 5
87 6
88 6
89 6
90 6
91 7
92 6
93 6
94 5
95 5
96 5
97 5
98 5
99 5
100 5
101 5
102 5
103 5
104 5
105 6
106 5
107 5
108 5
109 5
110 5
111 6
112 5
113 6
114 5
115 5
116 5
117 5
118 5
119 5
120 5
121 6
122 5
123 5
124 5
125 5
126 5
127 5
128 5
129 5
130 5
131 6
132 6
133 5
134 6
135 5
136 6
137 5
138 5
139 5
140 6
141 6
142 5
143 6
144 6
145 5
146 5
147 5
148 5
149 6
150 6
151 6
152 7
153 7
154 6
155 6
156 8
157 7
158 7
159 6
160 6
161 6
162 6
163 6
164 6
165 6
166 7
167 7
168 6
169 8
170 6
171 6
172 6
173 6
174 6
175 6
176 6
177 7
178 6
179 6
180 6
181 7
182 6
183 5
184 5
185 

886.8754632472992