In [None]:
import numpy as np
import scipy
from scipy import stats

from init import init_Y
from metrics import X_in_A, rho, eta, distance, ESS, prop_alive_func
from stats import sampling_params, random_walk_on_params, pi_n, X_update, metropolis_X_theta

In [None]:
# Step 0 : initialisation
# Seed NumPy's global RNG so that Restart & Run All reproduces the same run.
# NOTE(review): assumes the helper modules (init / stats) draw from np.random's global state — confirm.
SEED = 0
np.random.seed(SEED)

alpha = 0.9                       # quantile level used by the main loop to pick each new tolerance epsilon_t
N = 1000                          # number of particles
Nt = 500                          # NOTE(review): not used in the visible cells
Y = init_Y()                      # presumably the observed data — confirm in init.py
# One independent list per particle: `[[1]]*N` would put the *same* list object in all N slots,
# so any in-place change to one particle would show up in every particle.
Xs = [[1] for _ in range(N)]
epsilon_final = 0.00045           # the main loop runs while epsilon_t > epsilon_final
epsilon_t = np.inf                # initial tolerance, passed to the burn-in moves below
random_walk_stds = (0.05, 0.05, 0.7)  # presumably proposal stds for (phi, tau, xi) — confirm in stats.py
Ws = 1/N * np.ones(N)             # uniform initial weights
prop_alive = prop_alive_func(Ws)
thetas = np.array([sampling_params() for _ in range(N)])  # one (phi, tau, xi) row per particle

# Run a couple of Metropolis moves so the particles do not all start from the same X.
N_BURNIN_MOVES = 2
for _ in range(N_BURNIN_MOVES):
    for index, X in enumerate(Xs):
        phi, tau, xi = thetas[index]
        new_X, new_phi, new_tau, new_xi = metropolis_X_theta(X, Y, phi, tau, xi,
                                                             epsilon_t, random_walk_stds)
        Xs[index], thetas[index] = new_X, (new_phi, new_tau, new_xi)

# Record snapshots (copies) so that later in-place updates of Xs / Ws / thetas
# cannot silently rewrite the stored initial state.
results = {'Xs': [list(Xs)], 'Ws': [Ws.copy()], 'thetas': [thetas.copy()], 'epsilon_list': [epsilon_t]}

In [None]:
t = 0
while epsilon_t > epsilon_final:

    # Step 1 : sampling
    print(f'Step 1 : t={t}')
    # Distance of every particle to Y.
    distances = np.array([distance(X, Y) for X in Xs])
    # Choose the new tolerance among *alive* particles only. Dead particles (weight 0)
    # are never moved in Step 3, so their distances stay above the previous tolerance;
    # including them can push epsilon_t back up and make the loop stall.
    alive_distances = np.sort(distances[Ws > 0])
    if alive_distances.size == 0:
        raise RuntimeError('All particles have zero weight; cannot choose a new epsilon_t.')
    epsilon_t = alive_distances[int(alpha * alive_distances.size)]  # alpha quantile of alive distances
    # Rebind instead of `*=` so weight arrays already stored in `results` are not mutated.
    Ws = Ws * np.array([int(X_in_A(Y, X, epsilon_t)) for X in Xs])

    # Step 2 : resampling
    if prop_alive_func(Ws) < 0.5:
        print('Step 2 : Resampling')
        indices = np.random.choice(range(N), size=N, replace=True, p=(Ws/np.sum(Ws)))
        Xs = [Xs[i] for i in indices]
        thetas = [thetas[i] for i in indices]
        Ws = 1/N * np.ones(N)

    # Step 3 : random walk
    print('Step 3')
    # Work on fresh containers so the snapshots appended to `results` are never
    # modified afterwards (.copy() exists on both list and ndarray, keeping the type).
    Xs, thetas = Xs.copy(), thetas.copy()
    for index, X in enumerate(Xs):
        if Ws[index] > 0:  # only alive particles are moved
            phi, tau, xi = thetas[index]
            new_X, new_phi, new_tau, new_xi = metropolis_X_theta(X, Y, phi, tau, xi,
                                                                 epsilon_t, random_walk_stds)
            Xs[index], thetas[index] = new_X, (new_phi, new_tau, new_xi)
    t += 1
    print(f'espilon_t = {epsilon_t}, epsilon_final = {epsilon_final}')

    results['Xs'].append(Xs)
    results['Ws'].append(Ws)
    results['thetas'].append(thetas)
    results['epsilon_list'].append(epsilon_t)
    # Checkpoint after every iteration. np.save pickles this dict (as a 0-d object array);
    # reload with np.load('results.npy', allow_pickle=True).item().
    np.save('results.npy', results)