### Import libraries

In [None]:
import numpy as np
from tqdm import tqdm

from functools import partial
from scipy.spatial import distance_matrix

import holoviews as hv
hv.extension('matplotlib')

In [None]:
from linop_MA_TIRF import MA_TIRF

from g_utils import gradient_g, g

from opCOMP_init import opCOMP
from sCOMP_init import sCOMP
from descent_utilities import clip_domain, project_theta_eps

from FISTA_restart_descent import FISTA_restart
from init_utils import FISTA_restart_single_spike, init_position_max_val
from data_utils import create_batches

from analyse_utils import pair_GT_estimation, compute_metrics, RMSE

In [None]:
def projection(X, eps_proj, cut_off=1e-2):
    """Run ``project_theta_eps`` on a stacked parameter array.

    ``X`` is split column-wise: the last column is passed as the first
    argument of ``project_theta_eps`` and the remaining columns as the
    second. The two projected pieces are then stacked back in the same
    column layout (leading columns first, former last column at the end).

    Parameters
    ----------
    X : 2-D array
        One row per point. Presumably positions in the leading columns and
        the amplitude in the last one -- TODO confirm against callers.
    eps_proj, cut_off
        Forwarded unchanged to ``project_theta_eps``; their exact meaning
        is defined there.

    Returns
    -------
    2-D array built from the outputs of ``project_theta_eps``.
    """
    leading_cols = X[:, :-1]
    last_col = X[:, -1]
    last_col, leading_cols = project_theta_eps(
        last_col, leading_cols, eps_proj, cut_off
    )
    # Restore the (leading columns | last column) layout of the input.
    return np.concatenate((leading_cols, last_col[:, None]), axis=-1)

### Settings

In [None]:
# Seed NumPy's global RNG so that random draws are reproducible.
seed = 10
np.random.seed(seed)

# N sizes the random amplitude vector drawn for each batch; k_min and k_max
# are derived from it.
N = 10
k_min, k_max = N, 3 * N

nb_batch = 1
single_batch = (nb_batch == 1)

# Parameters forwarded to the MA_TIRF operator constructor.
lambda_l = 0.66
N1 = N2 = 64
K_angles = 4

In [None]:
# Forward operator configured from the settings defined earlier.
linop = MA_TIRF(lambda_l=lambda_l, N1=N1, N2=N2, K=K_angles)

# NOTE(review): `nbatch` receives N while `nb_batch` is defined but not used
# here -- confirm against `create_batches` that this is intended.
batches = create_batches(nbatch=N, plafrim_path=False)

# Empty accumulators (3 columns for T_*, flat vectors for A_*) that later
# cells extend by concatenation: ground truth, initialization, estimation.
T0, A0 = np.zeros((0, 3)), np.zeros((0,))
T_INIT, A_INIT = np.zeros((0, 3)), np.zeros((0,))
T_ESTI, A_ESTI = np.zeros((0, 3)), np.zeros((0,))

In [None]:
# Work on the first batch only; t0 is later plotted under the 'GT' label.
t0 = batches[0]

In [None]:
# Append this batch's points and freshly drawn amplitudes to the
# ground-truth accumulators.
T0 = np.concatenate((T0, t0), axis=0)
a0 = np.random.uniform(1, 1.5, N)
A0 = np.concatenate((A0, a0))

# Smallest distance between two *distinct* points of t0. The diagonal of the
# distance matrix (each point to itself, always 0) is masked with +inf so it
# can never be picked as the minimum. Masking it with 1 (adding an identity
# matrix) would silently cap min_dist at 1 whenever all points are more than
# one unit apart.
pairwise_dist = distance_matrix(t0, t0)
np.fill_diagonal(pairwise_dist, np.inf)
min_dist = np.min(pairwise_dist)
eps_proj = min_dist * 0.75

# Observation generated by the forward operator from the ground truth.
y = linop.Ax(a0, t0)

### Initialization and estimation: Sliding COMP

In [None]:
# Bind the operator once so `clip` can be handed to sCOMP as a callback.
clip = partial(clip_domain, linop=linop)

# NOTE(review): min_iter (k_min + 1) is larger than max_iter (k_min), and
# k_max is never used -- confirm against sCOMP that this is intended.
sCOMP_kwargs = dict(
    linop=linop,
    step=.2,
    nb_tests=1,
    max_iter=k_min,
    min_iter=k_min + 1,
    clip=clip,
    # Progress bar of the initialization is shown only for a single batch.
    disable_tqdm_init=not single_batch,
)
a_init, t_init, errors, r = sCOMP(y, **sCOMP_kwargs)

In [None]:
# Append the Sliding COMP output to the initialization accumulators.
T_INIT = np.concatenate([T_INIT, t_init], axis=0)
A_INIT = np.concatenate([A_INIT, a_init])

# Stack t_init and a_init side by side, a_init as the last column.
a_init_column = a_init[:, np.newaxis]
X_init = np.concatenate([t_init, a_init_column], axis=-1)

### Visualization

In [None]:
# Ground truth in the same layout as X_init: t0 columns, then a0 as the
# last column.
a0_column = a0[:, np.newaxis]
X_true = np.concatenate([t0, a0_column], axis=-1)

# 3-D scatter plots of the ground truth and of the Sliding COMP result.
plot_true = hv.Scatter3D(X_true, label='GT')
plot_true = plot_true.opts(marker='+', s=100)

plot_init = hv.Scatter3D(X_init, label='Esti Sliding COMP')
plot_init = plot_init.opts(marker='1', s=100)

In [None]:
# Overlay both scatter plots in one figure; the last expression is what the
# notebook cell displays.
overlay_true_init = plot_true * plot_init
overlay_true_init.opts(fig_size=300)

In [None]:
# Curve of the `errors` returned by sCOMP, drawn with a logarithmic y-axis.
plot_error_init = hv.Curve(errors, label='Error Sliding COMP')
plot_error_init = plot_error_init.opts(logy=True)
plot_error_init