In [1]:
import sys
# Put the parent directory first on the module search path so the local
# `solversuperres_v3` package imported below can be found (presumably it
# lives one level above this notebook's folder).
sys.path.insert(0, '..')

In [2]:
import numpy as np
import numpy.linalg as npl

from scipy.spatial import distance_matrix

from tqdm import tqdm
from functools import partial
from typing import Callable, Tuple

import holoviews as hv
hv.extension('matplotlib')  # render all HoloViews objects below with the matplotlib backend

In [3]:
from solversuperres_v3.data_utils        import semi_gridded_init
from solversuperres_v3.g_utils           import gradient_g, g
from solversuperres_v3.descent_utilities import clip_domain, projection_X
from solversuperres_v3.opCOMP_init       import opCOMP

import solversuperres_v3.FISTA_restart_descent_v2 as FISTA
import solversuperres_v3.init_utils               as init

from solversuperres_v3.linop_Gaussian_MATIRF        import           Gaussian2D_MATIRF
from solversuperres_v3.linop_Gaussian_MATIRF_Sketch import Sketching_Gaussian2D_MATIRF

In [4]:
# Gaussian PSF widths. None presumably lets Gaussian2D_MATIRF choose its own
# defaults (an explicit value such as 0.04 can be passed instead); the
# sketching operator reuses whatever widths the Gaussian operator settled on.
sigma_x = sigma_y = None

m_gaussian = 64                # image side length in pixels (N1 = N2)
m_per_spike = 20               # sketch-size budget per spike (assumed meaning of "m_spike")
m_sketch = 18 * m_per_spike    # 18 = number of spikes K used further down

linop_gauss = Gaussian2D_MATIRF(
    sigma_x, sigma_y, N1=m_gaussian, N2=m_gaussian
)
linop_sketch = Sketching_Gaussian2D_MATIRF(
    m_sketch,
    linop_gauss.sigma_x,
    linop_gauss.sigma_y,
)

In [5]:
d: int = 3  # coordinates per spike: (x, y, z)

In [6]:
# Ground-truth spikes: a semi-gridded layout with `density` points per axis
# (x, y, z), shrunk by `factor` towards the centre of the lateral domain, with
# amplitudes drawn uniformly in [1, 2).
SEED = 0
np.random.seed(SEED)  # reproducible amplitudes (and presumably the random opCOMP init later)

density = np.array([3, 3, 2])
t_raw = semi_gridded_init(density=density)
K = len(t_raw)  # number of spikes, taken from the layout itself (expected 3 * 3 * 2 = 18)
a = np.random.uniform(low=1, high=2, size=K)

# Shrink x and y towards the centre of [0, b1] x [0, b2]; z is left untouched.
# Each axis is shifted by its own bound, so a non-square domain stays centred.
# NOTE(review): assumes semi_gridded_init returns points inside that box.
factor = .8
shift_x = linop_gauss.b1 * (1 - factor) / 2
shift_y = linop_gauss.b2 * (1 - factor) / 2
t = t_raw * np.array([factor, factor, 1]) + np.array([shift_x, shift_y, 0])

In [7]:
# Minimum pairwise distance between ground-truth spikes. The zero self-distances
# are masked with +inf, which works for any number of points and any scale.
pairwise_dist = distance_matrix(t, t)
np.fill_diagonal(pairwise_dist, np.inf)
pairwise_dist.min()

0.8327139935538276

In [8]:
# Noiseless measurements of the ground truth with the Gaussian operator, shown
# as one panel per channel k < linop_gauss.K (presumably the MA-TIRF angles).
y = linop_gauss.Ax(a, t)
y_image = linop_gauss.image(y)
hv.NdLayout(dict(
    (k, hv.Image(np.flipud(y_image[..., k]), bounds=(0, 0, 1, 1)).opts(colorbar=True))
    for k in range(linop_gauss.K)
))

In [9]:
# Sketched measurements of the same ground truth, then rendered on the Gaussian
# pixel grid so they can be compared visually with the images above.
y_sketch = linop_sketch.Ax(a, t)
y_sketch_image = linop_sketch.image(linop_gauss.grid, y_sketch)

In [10]:
# One panel per channel of the rendered sketch.
sketch_panels = {
    k: hv.Image(np.flipud(y_sketch_image[..., k]), bounds=(0, 0, 1, 1)).opts(colorbar=True)
    for k in range(linop_gauss.K)
}
hv.NdLayout(sketch_panels)

In [11]:
# Inputs for DFT_DiracComb below: positions = copy of the pixel grid,
# weights = copy of the raw measurements y (all channels).
t_grid, a_grid = np.copy(linop_gauss.grid), np.copy(y)
t_grid.shape, a_grid.shape

((64, 64, 2), (16384,))

In [12]:
# Sketch data computed from the pixelised images rather than from the spikes:
# the image values a_grid on t_grid are (presumably) treated as a Dirac comb.
# y_fourier is the data the sketched solver below is fitted to.
y_fourier = linop_sketch.DFT_DiracComb(a_grid, t_grid)
y_fourier_image = linop_sketch.image(linop_gauss.grid, y_fourier)
fourier_panels = dict(
    (k, hv.Image(np.flipud(y_fourier_image[..., k]), bounds=(0, 0, 1, 1)).opts(colorbar=True))
    for k in range(linop_gauss.K)
)
hv.NdLayout(fourier_panels)

In [13]:
# Gap between the sketch of the spike model and the sketch of the pixel images,
# shown with a diverging colour map that is symmetric around zero.
sketch_gap = y_sketch_image - y_fourier_image
hv.NdLayout({
    k: hv.Image(np.flipud(sketch_gap[..., k]), bounds=(0, 0, 1, 1))
         .opts(colorbar=True, symmetric=True, cmap='seismic')
    for k in range(linop_gauss.K)
})

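## With sketching

Recover the spikes from the Fourier-domain data `y_fourier`: a greedy opCOMP initialisation followed by FISTA with restarts.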

In [14]:
# Sketched problem: the objective g and its gradient, bound to the Fourier data
# y_fourier and the sketching operator.
functional = partial(g, y=y_fourier, linop=linop_sketch)
gradient_functional = partial(gradient_g, y=y_fourier, linop=linop_sketch)

# Projection applied by FISTA (use an identity function to disable it).
proj = partial(projection_X, eps_proj=0.01, cut_off=3e-2)


def clip(X):
    """Identity map: X is returned unchanged (no clipping in this experiment)."""
    return X


# NOTE(review): update_residue is not used by the cells visible in this notebook.
update_residue = partial(FISTA.update_residue_from_y, y=y_fourier, linop=linop_sketch)

exit_cond = partial(FISTA.exit_cond, abs_tol_cst=1e-8, rel_tol_cst=1e-10)

In [15]:
# Greedy initialisation with opCOMP on the sketched data y_fourier: between K
# and 3*K iterations, with new spike positions supplied by
# init.init_position_random. descent_nit presumably bounds the local descent
# run inside each iteration.
a_init, t_init, errors_opCOMP, residue = opCOMP(
    y=y_fourier, linop=linop_sketch,
    step=1e-4, nb_tests=1,
    min_iter=K, max_iter=3*K,
    descent_nit=1_000,
    tol_criterion=0.08,
    init_position=init.init_position_random,
    multicore=False
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54/54 [00:14<00:00,  3.64it/s]


In [16]:
# One row per spike: positions first, amplitude in the last column.
X_init = np.column_stack((t_init, a_init))

In [17]:
# opCOMP error history on a log scale.
hv.Curve(errors_opCOMP).opts(logy=True)

In [18]:
# opCOMP error history normalised by its first value (log scale). np.asarray
# keeps the division valid whether opCOMP returns a list or an ndarray.
hv.Curve(np.asarray(errors_opCOMP) / errors_opCOMP[0]).opts(logy=True)

In [19]:
# opCOMP residual (sketch domain) rendered on the pixel grid, one panel per channel.
residue_image = linop_sketch.image(linop_gauss.grid, residue)
residue_panels = dict(
    (k, hv.Image(np.flipud(residue_image[..., k]), bounds=(0, 0, 1, 1)).opts(colorbar=True))
    for k in range(linop_gauss.K)
)
hv.NdLayout(residue_panels)

In [20]:
# Ground truth in the same layout as the estimates: (positions..., amplitude).
X_true = np.column_stack((t, a))

In [21]:
# Shared 3-D scatter options: axes span the (b1, b2, b3) domain and the marker
# area is |50 * amplitude|, read from the 'size' value dimension (last column).
point_opts = {
    'xlim': (0, linop_sketch.b1),
    'ylim': (0, linop_sketch.b2),
    'zlim': (0, linop_sketch.b3),
    's': abs(hv.dim('size') * 50),
}
points_true = hv.Scatter3D(X_true, vdims=['size']).opts(**point_opts, c='g', marker='+')  # ground truth
points_init = hv.Scatter3D(X_init, vdims=['size']).opts(**point_opts, c='b', marker='*')  # opCOMP init
points_true + points_init

In [22]:
# Overlay ground truth and opCOMP initialisation in one enlarged figure.
(points_true * points_init).opts(fig_size=500)

In [23]:
# Refine the opCOMP estimate with restarted FISTA on the sketched objective.
# functional_y is the squared norm of the data y_fourier (presumably used to
# normalise the reported errors).
X_esti, error_PGD, error_PGD_norm, time_diffs = FISTA.FISTA_restart(
    X=X_init, nit=3_000, step=1e-4, 
    functional=functional,
    gradient_functional=gradient_functional,
    exit_cond=exit_cond,  # alternative: the un-partialled FISTA.exit_cond
    restart_cond=FISTA.restart_cond,
    project=proj, clip=clip, 
    functional_y=npl.norm(y_fourier)**2
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:20<00:00, 143.01it/s]


In [24]:
# FISTA error history (sketched run), log scale.
hv.Curve(error_PGD).opts(logy=True)

In [25]:
# Squared error_PGD_norm history, log scale.
# NOTE(review): the non-sketched section plots error_PGD_norm unsquared; confirm which is intended.
hv.Curve(error_PGD_norm**2).opts(logy=True)

In [26]:
# opCOMP error (second-to-last entry; NOTE(review): reason for [-2] rather than [-1]
# is not visible here) next to the final FISTA error.
errors_opCOMP[-2], error_PGD[-1]

(70.59217478310457, 1.09170147473188e-06)

In [27]:
# Relative change of the FISTA error between its last two recorded iterations.
err_prev, err_last = error_PGD[-2], error_PGD[-1]
abs(err_last - err_prev) / abs(err_prev)

0.0031686912855449782

In [28]:
# Ground truth (green +), opCOMP init (blue *) and FISTA estimate (red x) side
# by side; marker area is |50 * amplitude| and axes span the full domain.
point_opts = {
    'xlim': (0, linop_sketch.b1),
    'ylim': (0, linop_sketch.b2),
    'zlim': (0, linop_sketch.b3),
    's': abs(hv.dim('size') * 50),
}
points_true = hv.Scatter3D(X_true, vdims=['size']).opts(**point_opts, c='g', marker='+')
points_init = hv.Scatter3D(X_init, vdims=['size']).opts(**point_opts, c='b', marker='*')
points_esti = hv.Scatter3D(X_esti, vdims=['size']).opts(**point_opts, c='r', marker='x')
points_true + points_init + points_esti

In [29]:
# Overlay ground truth and FISTA estimate in one enlarged figure.
(points_true * points_esti).opts(fig_size=500)

In [30]:
# Residual of the FISTA estimate in the sketch domain, rendered per channel.
# X_esti rows are (positions..., amplitude), so split off the last column.
amplitudes_esti, positions_esti = X_esti[:, -1], X_esti[:, :-1]
residue_PGD = y_fourier - linop_sketch.Ax(amplitudes_esti, positions_esti)
residue_PGD_image = linop_sketch.image(linop_gauss.grid, residue_PGD)
hv.NdLayout({
    k: hv.Image(np.flipud(residue_PGD_image[..., k]), bounds=(0, 0, 1, 1)).opts(colorbar=True)
    for k in range(linop_gauss.K)
})

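## Without sketching

The same two-stage pipeline (opCOMP initialisation, then FISTA with restarts), applied directly to the full Gaussian image data `y` with `linop_gauss`.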

In [31]:
# Non-sketched problem: the objective g, its gradient and the residual update
# are bound to the raw Gaussian measurements y and the operator linop_gauss.
functional = partial(g, y=y, linop=linop_gauss)
gradient_functional = partial(gradient_g, y=y, linop=linop_gauss)

# Projection applied by FISTA (use an identity function to disable it).
proj = partial(projection_X, eps_proj=0.01, cut_off=3e-2)


def clip(X):
    """Identity map: X is returned unchanged (no clipping in this experiment)."""
    return X


# NOTE(review): update_residue is not used by the cells visible in this notebook.
update_residue = partial(FISTA.update_residue_from_y, y=y, linop=linop_gauss)

exit_cond = partial(FISTA.exit_cond, abs_tol_cst=1e-8, rel_tol_cst=1e-10)

In [32]:
# Greedy opCOMP initialisation on the raw image data y. Unlike the sketched
# run, new spike positions come from init.init_position_max_val.
a_init, t_init, errors_opCOMP, residue = opCOMP(
    y=y, linop=linop_gauss,
    step=1e-4, nb_tests=1,
    min_iter=K, max_iter=3*K,
    descent_nit=1_000,
    tol_criterion=0.08,
    init_position=init.init_position_max_val,
    multicore=False,
)

 31%|█████████████████████████████████████████████▉                                                                                                    | 17/54 [00:01<00:02, 12.87it/s]


In [33]:
# One row per spike: positions first, amplitude in the last column.
X_init = np.column_stack((t_init, a_init))

In [34]:
# opCOMP error history (non-sketched run), log scale.
hv.Curve(errors_opCOMP).opts(logy=True)

In [35]:
# opCOMP residual of the non-sketched run on the pixel grid. Every channel
# k < linop_gauss.K gets its own panel, as in the other residual plots.
residue_image = linop_gauss.image(residue)
hv.NdLayout({
    k: hv.Image(np.flipud(residue_image[..., k]), bounds=(0, 0, 1, 1)).opts(colorbar=True)
    for k in range(linop_gauss.K)
})

In [36]:
# Refine the opCOMP estimate with restarted FISTA on the raw-image objective.
# The step (1e-5) is ten times smaller than in the sketched run; functional_y
# is the squared norm of y (presumably used to normalise the reported errors).
X_esti, error_PGD, error_PGD_norm, time_diffs = FISTA.FISTA_restart(
    X=X_init, nit=3_000, step=1e-5, 
    functional=functional,
    gradient_functional=gradient_functional,
    exit_cond=exit_cond,  # alternative: the un-partialled FISTA.exit_cond
    restart_cond=FISTA.restart_cond,
    project=proj, clip=clip,
    functional_y=npl.norm(y)**2
)

  9%|████████████▋                                                                                                                                  | 266/3000 [00:09<01:34, 29.02it/s]


In [37]:
# FISTA error history (non-sketched run), log scale.
hv.Curve(error_PGD).opts(logy=True)

In [38]:
# error_PGD_norm history, log scale.
# NOTE(review): the sketched section plots error_PGD_norm**2; confirm which is intended.
hv.Curve(error_PGD_norm).opts(logy=True)

In [39]:
# Last two recorded FISTA errors (quick convergence check).
error_PGD[-2], error_PGD[-1]

(1.009484364294197e-12, 9.858300113687312e-13)

In [40]:
# Last two recorded values of error_PGD_norm (quick convergence check).
error_PGD_norm[-2], error_PGD_norm[-1]

(1.0003150127816274e-08, 9.885257940094786e-09)

In [41]:
# Ground truth in the same layout as the estimates: (positions..., amplitude).
X_true = np.column_stack((t, a))

In [42]:
# Ground truth (green +), opCOMP init (blue *) and FISTA estimate (red x) for
# the non-sketched run; marker area is |50 * amplitude|, axes span the domain.
point_opts = {
    'xlim': (0, linop_gauss.b1),
    'ylim': (0, linop_gauss.b2),
    'zlim': (0, linop_gauss.b3),
    's': abs(hv.dim('size') * 50),
}
points_true = hv.Scatter3D(X_true, vdims=['size']).opts(**point_opts, c='g', marker='+')
points_init = hv.Scatter3D(X_init, vdims=['size']).opts(**point_opts, c='b', marker='*')
points_esti = hv.Scatter3D(X_esti, vdims=['size']).opts(**point_opts, c='r', marker='x')
points_true + points_init + points_esti

In [43]:
# Overlay ground truth and FISTA estimate (non-sketched run) in one enlarged figure.
(points_true * points_esti).opts(fig_size=500)