In [None]:
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import Pool

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from distance_determination import estimate_dist, get_current_freq
from interpolation.interpolate import interpolate, sklearn_interpolations
from interpolation.metrics import get_RMSE
from interpolation.wsinterp import wsinterp
from simul.signals.augment import signal_add_noise
from simul.utilities.data import load_experiment
from simul.vis.dist_probs import vis_dist_probs
from simul.vis.signals import get_vis_df, vis_signals, vis_signals2d, get_fft_df
from tqdm.auto import tqdm, trange

%load_ext autoreload
%autoreload 2

In [None]:
from distance_determination import estimate_dist, simulate_signals
from run_experiment import experiments

# Analysis window (in samples) cut out of every simulated signal.
exp_start = 1000
exp_size = 5000


def load_signal_window(name: str, start: int, size: int):
    """Simulate experiment ``name`` and slice out the analysis window.

    Returns ``(params, dist, signals_full, signals_window)`` where
    ``signals_window`` is ``signals_full[:, start:start + size]``.
    """
    exp_params = experiments[name]
    exp_dist, signals_full = simulate_signals(exp_params)
    return exp_params, exp_dist, signals_full, signals_full[:, start:start + size]


# Reference signals from the "default_full" experiment.
params, dist, signals_data_full, signals_data = load_signal_window(
    "default_full", exp_start, exp_size
)
# Signals from the "default_random" experiment (presumably randomly pruned).
params, dist, signals_data_pruned_full, signals_data_pruned = load_signal_window(
    "default_random", exp_start, exp_size
)
# Signals from the "default" experiment (presumably regularly pruned).
# Keep this call last: the plotting cells below read `params.tss` and expect
# `params`/`dist` from this experiment.
exp_name = "default"
params, dist, signals_data_pruned_reg_full, signals_data_pruned_reg = load_signal_window(
    exp_name, exp_start, exp_size
)


In [None]:
# Build the visualisation frame for the reference and pruned signals (freqs=[0]) and plot it.
df = get_vis_df(params.tss, signals_data, signals_data_pruned, freqs=[0])
vis_signals2d(df)

In [None]:

def randomized_sinc_interp(x:np.ndarray, xp:np.ndarray, fp:np.ndarray, sigma_coeff=0.8, left=None, right=None)->np.ndarray:
    """Sinc (Whittaker–Shannon style) interpolation of possibly irregular samples.

    The ``len(xp)`` samples are compared against a uniform grid spanning
    ``[xp[0], xp[-1]]`` with period ``Tn = (xp[-1] - xp[0]) / (len(xp) - 1)``.
    Each sinc kernel is centred at ``grid + sigma_coeff * (xp - grid)``.
    With ``sigma_coeff=0`` the kernels sit on the uniform grid. With
    ``sigma_coeff=1`` they sit at the actual sample positions ``xp``.

    Parameters
    ----------
    x : array_like, 1-D
        Points at which to evaluate the interpolant.
    xp : array_like, 1-D
        Sample positions. Needs at least two values and ``xp[-1] > xp[0]``.
    fp : array_like, 1-D
        Sample values, same length as ``xp``.
    sigma_coeff : float
        Blend factor between the uniform grid (0) and ``xp`` (1).
    left, right : scalar, optional
        Values returned for ``x < xp[0]`` and ``x > xp[-1]``.
        Default to ``fp[0]`` and ``fp[-1]``.

    Returns
    -------
    np.ndarray
        Interpolated values, one per element of ``x``.

    Raises
    ------
    ValueError
        If ``xp`` has fewer than two samples, if ``xp[-1] <= xp[0]``, or if
        ``fp`` and ``xp`` differ in length.
    """
    x = np.asarray(x)
    xp = np.asarray(xp)
    fp = np.asarray(fp)

    n = len(xp)
    if n < 2:
        raise ValueError(f"need at least two samples to interpolate, got {n}")
    if len(fp) != n:
        raise ValueError(f"fp has {len(fp)} values but xp has {n}")
    span = xp[-1] - xp[0]
    if span <= 0:
        raise ValueError("xp[-1] must be greater than xp[0]")
    Tn = span / (n - 1)

    # Uniform grid with exactly n points. np.linspace avoids np.arange's
    # floating-point step, whose length/endpoint are not numerically stable.
    xp_regular = np.linspace(xp[0], xp[-1], n)
    xp_result = xp_regular + (xp - xp_regular) * sigma_coeff

    # v[i, j] = (x[i] - xp_result[j]) / Tn, shape (len(x), n), via broadcasting.
    v = (x[:, np.newaxis] - xp_result) / Tn
    # Weighted sum of sinc kernels over the sample axis.
    fp_at_x = np.sum(fp * np.sinc(v), axis=1)

    # Outside the sampled range, hold the boundary values (or the caller's).
    if left is None:
        left = fp[0]
    fp_at_x[x < xp[0]] = left
    if right is None:
        right = fp[-1]
    fp_at_x[x > xp[-1]] = right

    return fp_at_x

def interpolate_rand(signals: np.ndarray, sigma:float = 0.8):
    """Reconstruct every row of ``signals`` with ``randomized_sinc_interp``.

    NaN entries in a row are treated as missing samples. The remaining samples,
    located at their column indices, are interpolated back onto the full index
    grid ``0 .. signals.shape[1] - 1``.

    Returns an array with one interpolated row per input row.
    """
    positions = np.arange(signals.shape[1])

    def _interp_row(row):
        # Keep only the observed (non-NaN) samples of this row.
        observed = ~np.isnan(row)
        return randomized_sinc_interp(positions, positions[observed], row[observed], sigma)

    return np.array([_interp_row(row) for row in tqdm(signals)])



In [None]:
# Sweep the jitter blend coefficient on the signals from the "default_random" experiment.
interp_signals = [
    (interpolate_rand(signals_data_pruned, coeff), f"Sigma:{coeff}")
    for coeff in (0, 0.1, 0.2, 0.5, 0.7, 0.8, 1)
]


In [None]:

# Time-domain comparison: reference, pruned input and each sigma reconstruction.
vis_signals2d(get_vis_df(
    params.tss, signals_data, signals_data_pruned, *interp_signals, freqs=[0],
))

In [None]:
# Spectral comparison (via get_fft_df) of the same set of signals.
vis_signals2d(get_fft_df(
    params.tss, signals_data, signals_data_pruned, *interp_signals, freqs=[0],
))

In [None]:
# Same sigma sweep on the signals from the "default" experiment, followed by the
# interpolate(..., "Whittaker–Shannon") result labelled "WS" for comparison.
interp_signals_reg = [
    (interpolate_rand(signals_data_pruned_reg, coeff), f"Sigma:{coeff}")
    for coeff in (0, 0.1, 0.2, 0.5, 0.7, 0.8, 1)
]
interp_signals_reg.append((interpolate(signals_data_pruned_reg, "Whittaker–Shannon"), "WS"))

In [None]:
# Time-domain comparison for the "default" experiment's reconstructions.
vis_signals2d(get_vis_df(
    params.tss, signals_data, signals_data_pruned_reg, *interp_signals_reg, freqs=[0],
))

In [None]:
# Spectral comparison (via get_fft_df) for the "default" experiment's reconstructions.
vis_signals2d(get_fft_df(
    params.tss, signals_data, signals_data_pruned_reg, *interp_signals_reg, freqs=[0],
))