In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
import numpy as np
import numpy.linalg as npl

from functools import partial

import holoviews as hv
hv.extension('matplotlib')

In [3]:
from solversuperres_v3.linop_Gaussian_MATIRF        import           Gaussian2D_MATIRF
from solversuperres_v3.linop_Gaussian_MATIRF_Sketch import Sketching_Gaussian2D_MATIRF

from solversuperres_v3.analyse_utils import RMSE, pair_GT_estimation

In [4]:
def load_data(file_name, allow_pickle=False, ignore_traj_X=True):
    dict_data = dict()
    object_data = np.load(file_name, allow_pickle=allow_pickle)

    for key in object_data.keys():
        if key == 'traj_X' and ignore_traj_X:
            continue

        dict_data[key] = object_data[key]

    return dict_data

# Results : comparison between sketched and classical linear operator

## Saving wrapper

In [5]:
save = False

def save_fig(fig, name):
    hv.save(fig, 'figs/' + f'fig_' + name, fmt='svg')

## Parameters

In [6]:
N = 50
rep = 2
m = 12

In [7]:
name_prefix_sketch  = f"data_k{N:03d}_m{m:02d}_rep{rep:02d}_"
name_prefix_classic = f"data_classical_k{N:03d}_rep{rep:02d}_"
name_suffix = ".npz"
directory = "data/"

file_name_true_data           = directory + name_prefix_sketch  + "true"    + name_suffix

file_name_opCOMP_data_sketch  = directory + name_prefix_sketch  + "opCOMP"  + name_suffix
file_name_PGD_data_sketch     = directory + name_prefix_sketch  + "PGD"     + name_suffix

file_name_opCOMP_data_classic = directory + name_prefix_classic + "opCOMP"  + name_suffix
file_name_PGD_data_classic    = directory + name_prefix_classic + "PGD"     + name_suffix

## Linear operators

In [8]:
sigma_x = sigma_y = None
m_gaussian = 64
m_sketch = m * N

linop_gauss  = Gaussian2D_MATIRF(sigma_x, sigma_y, N1=m_gaussian, N2=m_gaussian)
linop_sketch = Sketching_Gaussian2D_MATIRF(m_sketch, linop_gauss.sigma_x, linop_gauss.sigma_y)

## Display Functions

In [9]:
opts_true = dict(marker = 'x', color = 'r')
opts_init = dict(marker = '1', color = 'g')
opts_esti = dict(marker = '+', color = 'b')
fig_latex = True

In [10]:
def plot_observation(y_image, **kwargs):
    bounds = (0, 0, 6.4, 6.4)
    opts_image = dict(
        colorbar=True,
        fig_latex=fig_latex,
    )
    opts_image.update(kwargs)
    options = dict(
        fig_latex = fig_latex
    )
    dict_images = {k + 1: hv.Image(np.flipud(y_image[..., k]), bounds=bounds).opts(**opts_image) for k in range(linop_gauss.K)}
    plot = hv.NdLayout(dict_images, kdims='2D Grid number').cols(2)
    return plot.opts(**options)

In [11]:
def plot_curve_error(
    error, method='', sketching=False,
    normalized=False, logy=True
):
    xlabel = 'Number of iterations'
    xlabel += f' by {method}' if method else ''
    xlabel += '\n' + (' with' if sketching else ' without') + ' sketching'
    ylabel = 'Norm of residue'
    ylabel += '\nnormalized' if normalized else ''
    options = dict(
        logy=logy,
        show_grid=True,
        xlabel=xlabel,
        ylabel=ylabel,
        aspect=.6,
        fig_latex = fig_latex,
    )
    plot = hv.Curve(error)

    return plot.opts(**options)

In [12]:
def plot_functional_wrt_time(error, time_dict, method=''):
    options = dict(
        logy=True,
        xlabel='Time elapsed (in seconds)',
        ylabel='Norm of residue normalized',
        show_grid=True,
        aspect=.6,
        fig_latex=fig_latex,
    )

    time_keys = sorted(time_dict.keys())
    time_values = [time_dict[key] for key in time_keys]
    time_values = np.cumsum(time_values) / 10**9

    time_values_interp = np.interp(np.arange(error.size), time_keys, time_values)

    curve = hv.Curve({'x': time_values_interp, 'y': error})

    return curve.opts(**options)

In [13]:
def plot_points_3D(
    X, marker, color, label='',
    xlim=(0, 6.4), ylim=(0, 6.4), zlim=(0, 0.8),
    padding=-2,
):
    def hook_pad_labels(plot, element):
        plot.handles['axis'].tick_params(axis='x', pad=-1)
        plot.handles['axis'].tick_params(axis='y', pad=-1)
        plot.handles['axis'].tick_params(axis='z', pad=0)
        plot.handles['axis'].set_xlabel('x', labelpad=-4)
        plot.handles['axis'].set_ylabel('y', labelpad=-4)
        plot.handles['axis'].set_zlabel('z', labelpad=-4)

    options = dict(
        s=abs(hv.dim('size')) * 200,
        marker=marker, c=color,
        xlim=xlim, ylim=ylim, zlim=zlim,
        #xlabel="x", ylabel="y", zlabel="z",
        legend_position='top_left',
        hooks=[hook_pad_labels],
        fig_latex = fig_latex,
    )
    kdims = ['x', 'y', 'z'] # ['$x_1$','$x_2$','$x_3$']
    plot = hv.Scatter3D(X, kdims=kdims, vdims=['size'], label=label)

    return plot.opts(**options)

In [14]:
def plot_points_2D(
    X, marker, color, label='',
    xlim=(0, 6.4), ylim=(0, 6.4)
):
    options = dict(
        s=abs(hv.dim('size')) * 100,
        marker=marker, color=color,
        xlim=xlim, ylim=ylim,
        xlabel="$x$", ylabel="$y$",
        legend_position='top_left',
        fig_latex = fig_latex,
    )
    plot = hv.Scatter(X, vdims=['y', 'z', 'size'], label=label)

    return plot.opts(**options)

## Classical linop

### Data to recover

In [15]:
object_true_data = load_data(file_name_true_data)
X_true = object_true_data['X_true']

In [16]:
points_true = plot_points_3D(X_true, **opts_true)
points_true

In [17]:
X_true.shape

(50, 4)

In [18]:
points_true_2D = plot_points_2D(X_true, **opts_true)
points_true_2D

  artist = plot_fn(*plot_args, **plot_kwargs)


In [19]:
y_gauss = linop_gauss.Ax(X_true[:, -1], X_true[:, :-1])
y_gauss_image = linop_gauss.image(y_gauss)

In [20]:
fig_y_A_classic = plot_observation(y_gauss_image) * points_true_2D
fig_y_A_classic_k1 = fig_y_A_classic[1]

if save:
    save_fig(fig_y_A_classic,    f'y_classic_x_true_rep{rep:02d}')
    save_fig(fig_y_A_classic_k1, f'y_classic_x_true_k1_rep{rep:02d}')

fig_y_A_classic

  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)


In [21]:
fig_y_A_classic_k1

  artist = plot_fn(*plot_args, **plot_kwargs)


### Data init : OP-COMP

In [22]:
object_opCOMP_data_classic = load_data(file_name_opCOMP_data_classic)
X_init_classic       = object_opCOMP_data_classic['X_init']
error_opCOMP_classic = object_opCOMP_data_classic['error_opCOMP']
residue_classic      = object_opCOMP_data_classic['residue']
time_opCOMP_classic  = object_opCOMP_data_classic['time_opCOMP'][0] / 10**9 / 60

In [23]:
error_opCOMP_classic_norm = error_opCOMP_classic**.5 / npl.norm(y_gauss)

In [24]:
plot_error_opCOMP_classic = plot_curve_error(
    error_opCOMP_classic, method='OP-COMP'
)
plot_error_opCOMP_classic_norm = plot_curve_error(
    error_opCOMP_classic_norm, method='OP-COMP',
    normalized=True
)

if save:
    save_fig(plot_error_opCOMP_classic     , f'residue_opCOMP_classic_rep{rep:02d}')
    save_fig(plot_error_opCOMP_classic_norm, f'residue_opCOMP_classic_norm_rep{rep:02d}')

(plot_error_opCOMP_classic + plot_error_opCOMP_classic_norm).opts(shared_axes=False)

In [25]:
error_opCOMP_classic[0], error_opCOMP_classic[-1]

(33079.75663271473, 209.68353843175697)

In [26]:
error_opCOMP_classic_norm[0], error_opCOMP_classic_norm[-1]

(1.0, 0.079616121203283)

In [27]:
points_init_classic = plot_points_3D(X_init_classic, **opts_init)
points_init_classic

In [28]:
fig_x_true_init_classic = (points_true * points_init_classic).opts(fig_latex=fig_latex)

if save:
    save_fig(fig_x_true_init_classic, f'x_true_init_classic_rep{rep:02d}')

fig_x_true_init_classic

### Data esti : PGD (FISTA Restart)

In [29]:
object_PGD_data_classic = load_data(file_name_PGD_data_classic, allow_pickle=True)
X_esti_classic          = object_PGD_data_classic['X_esti']
error_PGD_classic       = object_PGD_data_classic['error_PGD']
time_dict_classic       = object_PGD_data_classic['time_dict'].tolist()

In [30]:
error_PGD_classic_norm = error_PGD_classic**.5 / npl.norm(y_gauss)

In [31]:
error_opCOMP_classic[-1], error_PGD_classic[0]

(209.68353843175697, 209.68353843175697)

In [32]:
error_opCOMP_classic_norm[-1], error_PGD_classic_norm[0]

(0.079616121203283, 0.079616121203283)

In [33]:
plot_error_PGD_classic = plot_curve_error(
    error_PGD_classic , method='PGD',
)
plot_error_PGD_classic_norm = plot_curve_error(
    error_PGD_classic_norm, method='PGD',
    normalized=True,
)

if save:
    save_fig(plot_error_PGD_classic     , f'residue_PGD_classic_rep{rep:02d}')
    save_fig(plot_error_PGD_classic_norm, f'residue_PGD_classic_norm_rep{rep:02d}')

(plot_error_PGD_classic + plot_error_PGD_classic_norm).opts(shared_axes=False)

In [34]:
points_esti_classic = plot_points_3D(X_esti_classic, **opts_esti)
points_esti_classic

In [35]:
fig_x_true_esti_classic = (points_true * points_esti_classic).opts(fig_latex=fig_latex)

if save:
    save_fig(fig_x_true_esti_classic, f'x_true_esti_classic_rep{rep:02d}')

fig_x_true_esti_classic

In [36]:
time_PGD_classic = np.sum(list(time_dict_classic.values())) / 10**9 / 60
time_opCOMP_classic, time_PGD_classic, time_opCOMP_classic + time_PGD_classic

(0.7953103953333333, 123.30726114875, 124.10257154408333)

## Sketched linop

### Data to recover

In [37]:
object_true_data = load_data(file_name_true_data)
X_true = object_true_data['X_true']
w      = object_true_data['w']

In [38]:
linop_sketch.w = w

In [39]:
points_true = plot_points_3D(X_true, **opts_true)
points_true

In [40]:
X_true.shape

(50, 4)

In [41]:
points_true_2D = plot_points_2D(X_true, **opts_true)
points_true_2D

  artist = plot_fn(*plot_args, **plot_kwargs)


In [42]:
y_sketch = linop_sketch.Ax(X_true[:, -1], X_true[:, :-1])
y_sketch_image = linop_sketch.image(linop_gauss.grid, y_sketch)

In [43]:
fig_y_A_sketch = plot_observation(y_sketch_image, clim=(0, y_sketch_image.max())).cols(2) * points_true_2D
fig_y_A_sketch_k1 = fig_y_A_sketch[1].opts(fig_latex=fig_latex)

if save:
    save_fig(fig_y_A_sketch,    f'y_sketch_x_true_m{m:02d}_rep{rep:02d}')
    save_fig(fig_y_A_sketch_k1, f'y_sketch_x_true_k1_m{m:02d}_rep{rep:02d}')

fig_y_A_sketch

  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)


In [44]:
fig_y_A_sketch_k1

  artist = plot_fn(*plot_args, **plot_kwargs)


In [45]:
t_grid = np.copy(linop_gauss.grid)
a_grid = np.copy(y_gauss)
t_grid.shape, a_grid.shape

((64, 64, 2), (16384,))

In [46]:
y_fourier = linop_sketch.DFT_DiracComb(a_grid, t_grid)
y_fourier_image = linop_sketch.image(linop_gauss.grid, y_fourier)

In [47]:
fig_y_A_sketch_approx = plot_observation(y_fourier_image, clim=(0, y_sketch_image.max())).cols(2) * points_true_2D
fig_y_A_sketch_approx_k1 = fig_y_A_sketch_approx[1].opts(fig_latex=fig_latex)

if save:
    save_fig(fig_y_A_sketch_approx,    f'y_sketch_approx_x_true_m{m:02d}_rep{rep:02d}')
    save_fig(fig_y_A_sketch_approx_k1, f'y_sketch_approx_k1_x_true_m{m:02d}_rep{rep:02d}')

fig_y_A_sketch_approx

  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)
  artist = plot_fn(*plot_args, **plot_kwargs)


In [48]:
fig_y_A_sketch_approx_k1

  artist = plot_fn(*plot_args, **plot_kwargs)


In [49]:
diff_y_sketch_fourier = y_sketch_image - y_fourier_image
fig_diff_y_sketch_fourier = plot_observation(diff_y_sketch_fourier, cmap='seismic', symmetric=True)
fig_diff_y_sketch_fourier_k1 = fig_diff_y_sketch_fourier[1].opts(fig_latex=fig_latex)

if save:
    save_fig(fig_diff_y_sketch_fourier,    f'diff_y_sketch_fourier_m{m:02d}_rep{rep:02d}')
    save_fig(fig_diff_y_sketch_fourier_k1, f'diff_y_sketch_fourier_k1_m{m:02d}_rep{rep:02d}')

fig_diff_y_sketch_fourier

In [50]:
fig_diff_y_sketch_fourier_k1

In [53]:
norm_y_gauss = npl.norm(y_gauss)
norm_diff_y_sketch_fourier = npl.norm(y_sketch - y_fourier)
norm_y_gauss, norm_diff_y_sketch_fourier

(181.87841167305902, 0.005746600451108866)

In [54]:
0.006 / 180

3.3333333333333335e-05

In [55]:
norm_diff_y_sketch_fourier / norm_y_gauss

3.1595835911734476e-05

### Data init : OP-COMP

In [56]:
object_opCOMP_data_sketch = load_data(file_name_opCOMP_data_sketch)
X_init_sketch       = object_opCOMP_data_sketch['X_init']
error_opCOMP_sketch = object_opCOMP_data_sketch['error_opCOMP']
residue_sketch      = object_opCOMP_data_sketch['residue']
time_opCOMP_sketch  = object_opCOMP_data_sketch['time_opCOMP'][0] / 10**9 / 60

In [57]:
error_opCOMP_sketch_norm = error_opCOMP_sketch**.5 / npl.norm(y_fourier)

In [58]:
plot_error_opCOMP_sketch = plot_curve_error(
    error_opCOMP_sketch , method='OP-COMP',
)
plot_error_opCOMP_sketch_norm = plot_curve_error(
    error_opCOMP_sketch_norm, method='OP-COMP',
    sketching=True, normalized=True,
)

if save:
    save_fig(plot_error_opCOMP_sketch     , f'residue_opCOMP_sketch_m{m:02d}_rep{rep:02d}')
    save_fig(plot_error_opCOMP_sketch_norm, f'residue_opCOMP_sketch_norm_m{m:02d}_rep{rep:02d}')

(plot_error_opCOMP_sketch + plot_error_opCOMP_sketch_norm).opts(shared_axes=False, fig_latex=fig_latex)

In [59]:
error_opCOMP_sketch[-1], error_opCOMP_sketch_norm[-1]

(621.2696004186798, 0.14519087054685786)

In [60]:
points_init_sketch = plot_points_3D(X_init_sketch, **opts_init)
points_init_sketch

In [61]:
fig_x_true_init_sketch = (points_true * points_init_sketch).opts(fig_latex=fig_latex)

if save:
    save_fig(fig_x_true_init_sketch, f'x_true_init_sketch_m{m:02d}_rep{rep:02d}')

fig_x_true_init_sketch

### Data esti : PGD (FISTA Restart)

In [62]:
object_PGD_data_sketch = load_data(file_name_PGD_data_sketch, allow_pickle=True)
X_esti_sketch          = object_PGD_data_sketch['X_esti']
error_PGD_sketch       = object_PGD_data_sketch['error_PGD']
time_dict_sketch       = object_PGD_data_sketch['time_dict'].tolist()

In [63]:
error_PGD_sketch_norm = error_PGD_sketch**.5 / npl.norm(y_fourier)

In [64]:
error_opCOMP_sketch[-1], error_PGD_sketch[0]

(621.2696004186798, 621.2696004186798)

In [65]:
error_opCOMP_sketch_norm[-1], error_PGD_sketch_norm[0]

(0.14519087054685786, 0.14519087054685786)

In [66]:
plot_error_PGD_sketch = plot_curve_error(
    error_PGD_sketch , method='PGD',
)
plot_error_PGD_sketch_norm = plot_curve_error(
    error_PGD_sketch_norm, method='PGD',
    sketching=True, normalized=True
)

if save:
    save_fig(plot_error_PGD_sketch     , f'residue_PGD_sketch_m{m:02d}_rep{rep:02d}')
    save_fig(plot_error_PGD_sketch_norm, f'residue_PGD_sketch_norm_m{m:02d}_rep{rep:02d}')

(plot_error_PGD_sketch + plot_error_PGD_sketch_norm).opts(shared_axes=False)

In [67]:
error_PGD_sketch_norm.shape

(37685,)

In [68]:
points_esti_sketch = plot_points_3D(X_esti_sketch, **opts_esti)
points_esti_sketch

In [69]:
fig_x_true_esti_sketch = (points_true * points_esti_sketch).opts(fig_latex=fig_latex)

if save:
    save_fig(fig_x_true_esti_sketch, f'x_true_esti_sketch_m{m:02d}_rep{rep:02d}')

fig_x_true_esti_sketch

In [70]:
time_PGD_sketch = np.sum(list(time_dict_sketch.values())) / 10**9 / 60
time_opCOMP_sketch, time_PGD_sketch, time_opCOMP_sketch + time_PGD_sketch

(1.9512943945833332, 24.048190606933336, 25.999485001516668)

In [71]:
error_PGD_classic_norm[-1], error_PGD_sketch_norm[-1]

(1.722924097046133e-05, 3.107068572002797e-05)

In [72]:
(26 - 127) / 127

-0.7952755905511811

## Merging figs

In [70]:
plot_error_opCOMP_both_norm = hv.NdOverlay(
    {
        'Classical': plot_error_opCOMP_classic_norm,
        'Sketching': plot_error_opCOMP_sketch_norm,
    }, kdims='Method'
).opts(xlabel='Norm of iterations by OP-COMP', fig_latex=fig_latex)

if save:
    save_fig(plot_error_opCOMP_both_norm, f'residue_opCOMP_both_norm_m{m:02d}_rep{rep:02d}')

plot_error_opCOMP_both_norm

In [71]:
plot_error_PGD_classic_norm_wrt_time = plot_functional_wrt_time(error_PGD_classic_norm, time_dict_classic)
plot_error_PGD_sketch__norm_wrt_time  = plot_functional_wrt_time(error_PGD_sketch_norm, time_dict_sketch)

plot_error_PGD_both_norm_wrt_time = hv.NdOverlay(
    {
        'Classical': plot_error_PGD_classic_norm_wrt_time,
        'Sketching': plot_error_PGD_sketch__norm_wrt_time,
    }, kdims='Method'
).opts(fig_latex=fig_latex)

if save:
    save_fig(plot_error_PGD_both_norm_wrt_time, f'residue_PGD_both_norm_wrt_time_m{m:02d}_rep{rep:02d}')

plot_error_PGD_both_norm_wrt_time

In [72]:
plot_error_PGD_both_norm = hv.NdOverlay(
    {
        'Classical': plot_error_PGD_classic_norm,
        'Sketching': plot_error_PGD_sketch_norm,
    }, kdims='Method'
).opts(xlabel='Norm of iterations by PGD', fig_latex=fig_latex)

if save:
    save_fig(plot_error_PGD_both_norm, f'residue_PGD_both_norm_m{m:02d}_rep{rep:02d}')

plot_error_PGD_both_norm

# Evolution of time and energy with `m`

In [74]:
N = 50
range_rep = list(range(0, 16, 1)) # 1
range_m = list(range(3, 51, 3))
range_rep, range_m

([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
 [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48])

In [75]:
sigma_x = sigma_y = None
m_gaussian = 64
linop_gauss  = Gaussian2D_MATIRF(sigma_x, sigma_y, N1=m_gaussian, N2=m_gaussian)

### Load data

In [76]:
name_suffix = ".npz"
directory = "data/"
dict_m = dict()

for m in range_m:
    dict_m[m] = dict()
    for rep in range_rep:
        dict_m[m][rep] = dict()
        name_prefix_sketch  = f"data_k{N:03d}_m{m:02d}_rep{rep:02d}_"
        file_name_true_data    = directory + name_prefix_sketch + "true"    + name_suffix
        file_name_opCOMP_data  = directory + name_prefix_sketch + "opCOMP"  + name_suffix
        file_name_PGD_data     = directory + name_prefix_sketch + "PGD"     + name_suffix

        dict_m[m][rep]['true'] = load_data(file_name_true_data)
        dict_m[m][rep]['init'] = load_data(file_name_opCOMP_data)
        dict_m[m][rep]['esti'] = load_data(file_name_PGD_data, allow_pickle=True)

        m_sketch = m * N
        linop_sketch = Sketching_Gaussian2D_MATIRF(m_sketch, linop_gauss.sigma_x, linop_gauss.sigma_y)

        # Normalize data
        X_true = dict_m[m][rep]['true']['X_true']
        w      = dict_m[m][rep]['true']['w']
        linop_sketch.w = w
        y_sketch = linop_sketch.Ax(X_true[:, -1], X_true[:, :-1])
        dict_m[m][rep]['init']['error_opCOMP_norm'] = dict_m[m][rep]['init']['error_opCOMP']**.5 / npl.norm(y_sketch)
        dict_m[m][rep]['esti']['error_PGD_norm']    = dict_m[m][rep]['esti']['error_PGD']**.5    / npl.norm(y_sketch)

In [77]:
dict_classic = dict()
dict_classic[0] = dict()

for rep in range_rep:
    dict_classic[0][rep] = dict()
    name_prefix = f"data_classical_k{N:03d}_rep{rep:02d}_"
    file_name_true_data    = directory + name_prefix + "true"    + name_suffix
    file_name_opCOMP_data  = directory + name_prefix + "opCOMP"  + name_suffix
    file_name_PGD_data     = directory + name_prefix + "PGD"     + name_suffix

    dict_classic[0][rep]['true'] = load_data(file_name_true_data)
    dict_classic[0][rep]['init'] = load_data(file_name_opCOMP_data)
    dict_classic[0][rep]['esti'] = load_data(file_name_PGD_data, allow_pickle=True)

    # Normalize data
    X_true = dict_classic[0][rep]['true']['X_true']
    y_gauss = linop_gauss.Ax(X_true[:, -1], X_true[:, :-1])
    dict_classic[0][rep]['init']['error_opCOMP_norm'] = dict_classic[0][rep]['init']['error_opCOMP']**.5 / npl.norm(y_gauss)
    dict_classic[0][rep]['esti']['error_PGD_norm']    = dict_classic[0][rep]['esti']['error_PGD']**.5    / npl.norm(y_gauss)

In [78]:
save = False
fig_latex = True

def save_fig(fig, name):
    hv.save(fig, 'figs/' + f'fig_' + name, fmt='svg')

### Display function

In [79]:
def plot_spread_percentile(dict_m, dict_classic, func, q_spread=90, **opts_output):
    opts_spread = dict(
        alpha = .4,
        linewidth = 0
    )
    opts_curve = dict(
        linewidth = 4,
        linestyle = 'dotted',
    )
    opts = dict(
        show_legend = True,
        xlabel = 'm', aspect = 2,
        fig_size = 300, fontscale = 2,
        legend_cols = 2, # xformatter = PercentFormatter(xmax=1),
        fig_latex = fig_latex,
    )
    opts.update(opts_output)
    range_m = sorted(dict_m.keys())
    range_rep = sorted(dict_m[range_m[0]].keys())
    array_step_key = np.zeros((len(range_m), len(range_rep)))

    for idx_m, m in enumerate(sorted(dict_m.keys())):
        for idx_rep, rep in enumerate(sorted(dict_m[m].keys())):
            array_step_key[idx_m, idx_rep] = func(dict_m, m, rep)

    # classic
    array_step_key_classic = np.zeros(len(range_rep))
    for idx_rep, rep in enumerate(sorted(dict_classic[0].keys())):
        array_step_key_classic[idx_rep] = func(dict_classic, 0, rep)

    margin = round((100 - q_spread) / 2)
    q_low  = margin
    q_high = 100 - margin

    percentile_low  = np.percentile(array_step_key, q=q_low , axis=-1)
    percentile_high = np.percentile(array_step_key, q=q_high, axis=-1)
    median = np.median(array_step_key, axis=-1)
    median_classic = np.median(array_step_key_classic)
    data = dict(
        classic = array_step_key_classic,
        sketch  = array_step_key
    )

    yerrneg = np.abs(median - percentile_low)
    yerrpos = np.abs(percentile_high - median)

    m_classic = linop_gauss.N1 * linop_gauss.N2 * linop_gauss.K
    range_m_sketch = np.array(range_m) * N * linop_sketch.K
    x_range = range_m_sketch / m_classic

    def percent_format(x): return f'${round(100*x)}\%$'
    kdim = hv.Dimension('m', value_format=percent_format, range=(x_range.min(), x_range.max()))

    plot_spread = hv.Spread((x_range, median, yerrneg, yerrpos), kdims=kdim, vdims=['y', 'yerrneg', 'yerrpos']).opts(**opts_spread)
    plot_q_low  = hv.Curve((x_range, percentile_low),  kdims=kdim, label=f"{margin}-th percentile").opts(linewidth=2, linestyle='dashed', color='k')
    plot_q_high = hv.Curve((x_range, percentile_high), kdims=kdim, label=f"{100-margin}-th percentile").opts(linewidth=2, linestyle='dashdot', color='k')
    plot_median = hv.Curve((x_range, median), kdims=kdim, label='Median value').opts(**opts_curve)
    width = (plot_median.data["m"].min(), plot_median.data["m"].max())

    plot_median_classic = hv.Curve((width, [median_classic]*2), kdims=kdim, label="Median without sketching").opts(linewidth=2, color='r')

    plot_output = (plot_spread * plot_q_low * plot_q_high * plot_median * plot_median_classic).opts(**opts)
    return plot_output, data

## Spread of error

In [80]:
func_sum_time_seconds = lambda dict_m, m, rep: (dict_m[m][rep]['init']['time_opCOMP'][0] + np.sum(list(dict_m[m][rep]['esti']['time_dict'].tolist().values()))) / 10**9 / 60
output_time_error_bars, data_time_seconds = plot_spread_percentile(
    dict_m, dict_classic,
    func=func_sum_time_seconds,
    q_spread=90,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='Time elapsed by OP-COMP + PGD (in minutes)',
    logy=False, legend_position='top',
)

if save:
    save_fig(output_time_error_bars, 'aggregate_rep16_time')

output_time_error_bars

In [82]:
sorted(dict_m.keys())

[3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48]

In [85]:
np.median(data_time_seconds['sketch'][3]), np.median(data_time_seconds['classic'][3]),

(27.169801057725, 143.33232419910001)

In [86]:
(np.median(data_time_seconds['sketch'][3]) - np.median(data_time_seconds['classic'][3])) / np.median(data_time_seconds['classic'][3])

-0.8104419138561936

In [81]:
for func in [np.mean, np.median]:
    print(f'Classic {func.__name__} = {func(data_time_seconds["classic"])}')
    print(f'Sketch  {func.__name__} = {func(data_time_seconds["sketch"][0])}')

Classic mean = 105.68478895995312
Sketch  mean = 8.960433733028125
Classic median = 116.167207138525
Sketch  median = 8.013910597483333


In [82]:
func_error_PGD = lambda dict_m, m, rep: dict_m[m][rep]['esti']['error_PGD_norm'][-1]
output_error_PGD_error_bars, data_error_PGD = plot_spread_percentile(
    dict_m, dict_classic,
    func=func_error_PGD,
    q_spread=80,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='Norm residue by PGD normalized',
    logy=True, aspect=1,
    legend_position='top',
)
output_error_PGD_error_bars.opts(legend_cols=2)

if save:
    save_fig(output_error_PGD_error_bars, 'aggregate_rep16_residue_PGD')

output_error_PGD_error_bars

In [83]:
func_error_opCOMP = lambda dict_m, m, rep: dict_m[m][rep]['init']['error_opCOMP_norm'][-1]
output_error_opCOMP_error_bars, data_error_opCOMP = plot_spread_percentile(
    dict_m, dict_classic,
    func=func_error_opCOMP,
    q_spread=90,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='Norm residue by OP-COMP normalized',
    logy=False, aspect=1,
    legend_position='top',
)

if save:
    save_fig(output_error_opCOMP_error_bars, 'aggregate_rep16_residue_opCOMP')

output_error_opCOMP_error_bars

In [84]:
# func_nit_opCOMP = lambda dict_m, m, rep: dict_m[m][rep]['init']['error_opCOMP_norm'].size - 1
# output_nit_opCOMP_error_bars, data_nit_opCOMP = plot_spread_percentile(
#     dict_m, dict_classic,
#     func=func_nit_opCOMP,
#     q_spread=90,
#     xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
#     ylabel='Number of spikes initialized with OP-COMP',
#     logy=False, aspect=1,
# )

# if save and False:
#     save_fig(output_nit_opCOMP_error_bars, 'aggregate_rep16_nit_opCOMP')

# output_nit_opCOMP_error_bars

In [85]:
def compute_RMSE(dict_exp, m, rep, dim=0):
    t_true_paired, t_true_not_paired, t_esti_paired, t_esti_not_paired = pair_GT_estimation(
        dict_exp[m][rep]['true']['X_true'][:, :-1],
        dict_exp[m][rep]['esti']['X_esti'][:, :-1],
    )
    return RMSE(t_true_paired, t_esti_paired)[dim]

In [86]:
RMSE_x = partial(compute_RMSE, dim=0)
RMSE_y = partial(compute_RMSE, dim=1)
RMSE_z = partial(compute_RMSE, dim=2)

In [87]:
output_RMSE_x_error_bars, data_RMSE_x = plot_spread_percentile(
    dict_m, dict_classic,
    func=RMSE_x,
    q_spread=90,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='RMSE along the x-axis (in mm)',
    logy=True, legend_position='top',
)

if save:
    save_fig(output_RMSE_x_error_bars, 'aggregate_rep16_RMSE_x')

output_RMSE_x_error_bars

In [88]:
output_RMSE_y_error_bars, data_RMSE_y = plot_spread_percentile(
    dict_m, dict_classic,
    func=RMSE_y,
    q_spread=90,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='RMSE along the y-axis (in mm)',
    logy=True, legend_position='top',
)

if save:
    save_fig(output_RMSE_y_error_bars, 'aggregate_rep16_RMSE_y')

output_RMSE_y_error_bars

In [89]:
output_RMSE_z_error_bars, data_RMSE_z = plot_spread_percentile(
    dict_m, dict_classic,
    func=RMSE_z,
    q_spread=90,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='RMSE along the z-axis (in mm)',
    logy=True, legend_position='top',
)

if save:
    save_fig(output_RMSE_z_error_bars, 'aggregate_rep16_RMSE_z')

output_RMSE_z_error_bars

In [90]:
def plot_RMSE_all_exps(data_RMSE_x, data_RMSE_y, data_RMSE_z, range_m, **opts_output):
    opts = dict(
        show_legend = True,
        legend_position='top',
        xlabel = 'm',
        fig_size = 300, fontscale = 2,
        legend_cols = 2,
        fig_latex = fig_latex,
        logy=True, show_title = False,
    )
    opts.update(opts_output)

    median = dict(
        sketch = dict(),
        classic = dict(),
    )
    opts_sketch = dict(
        x = dict(marker='^', ms=13, color='r', linestyle='--', linewidth=2),
        y = dict(marker='o', ms=13, color='g', linestyle='--', linewidth=2),
        z = dict(marker='s', ms=13, color='b', linestyle='--', linewidth=2),
    )
    opts_classic = dict(
        x = dict(color='r', linestyle='-', linewidth=2),
        y = dict(color='g', linestyle='-', linewidth=2),
        z = dict(color='b', linestyle='-', linewidth=2),
    )
    plots_sketch_RMSE = dict()
    plots_classic_RMSE = dict()

    range_m = sorted(range_m)
    m_classic = linop_gauss.N1 * linop_gauss.N2 * linop_gauss.K
    range_m_sketch = np.array(range_m) * N * linop_sketch.K
    x_range = range_m_sketch / m_classic

    def percent_format(x): return f'${round(100*x)}\%$'
    kdim = hv.Dimension('m', value_format=percent_format, range=(x_range.min(), x_range.max()))

    method_sketch = 'Sketching'
    method_classic = 'Without Sketching'
    for dim, data in zip(['x', 'y', 'z'], [data_RMSE_x, data_RMSE_y, data_RMSE_z]):
        median['sketch'][dim] = np.median(data['sketch'], axis=-1)
        median['classic'][dim] = np.median(data['classic'])

        sketch_curve_RMSE = hv.Curve((x_range, median['sketch'][dim]), kdims=kdim, group=method_sketch, label=f'RMSE {dim}')
        plots_sketch_RMSE[(f'RMSE {dim}', method_sketch)] = sketch_curve_RMSE.opts(**opts_sketch[dim])

        width = (sketch_curve_RMSE.data["m"].min(), sketch_curve_RMSE.data["m"].max())
        classic_curve_RMSE = hv.Curve((width, [median['classic'][dim]]*2), kdims=kdim, group=method_classic, label=f'RMSE {dim}')
        plots_classic_RMSE[(f'RMSE {dim}', method_classic)] = classic_curve_RMSE.opts(**opts_classic[dim])

    sketch_overlay_RMSE = hv.NdOverlay(plots_sketch_RMSE, kdims=['RMSE', 'Method'])
    classic_overlay_RMSE = hv.NdOverlay(plots_classic_RMSE, kdims=['RMSE', 'Method'])

    return (sketch_overlay_RMSE * classic_overlay_RMSE).opts(**opts)

In [92]:
output_RMSE_all = plot_RMSE_all_exps(
    data_RMSE_x, data_RMSE_y, data_RMSE_z, range_m,
    xlabel='Proportion of number of measurements by $\mathcal{S}A$ when compared to $A$',
    ylabel='RMSE along each axis (in mm)',
)

if save:
    save_fig(output_RMSE_all, 'aggregate_rep16_RMSE_all')

output_RMSE_all