In [1]:
%%HTML
<!-- Display tweak only: overrides the CSS width of the notebook container, menubar and main toolbar elements. -->
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>


In [2]:
import sys
sys.path.append("../")

import numpy as np
import pickle
from sklearn import linear_model
from tqdm import tqdm_notebook as tqdm

from ampy import utils
from ampy import StabilitySelectionVAMPSolver
%matplotlib inline

In [3]:
def _make_standardized_problem(M, N, ρ):
    """Draw one synthetic sparse-regression instance and standardize it.

    The statements run in exactly the order the experiment relies on
    (matrix, then signal, then noise), so calling ``np.random.seed(s)``
    right before this function reproduces the same instance for a given ``s``.

    Parameters
    ----------
    M, N : int
        Sizes passed to ``utils.utils.make_random_dct_matrix``; ``N`` is also
        the length of the signal and ``M`` the length of the noise vector.
    ρ : float
        Probability that an entry of the signal is non-zero (Bernoulli mask).

    Returns
    -------
    A : array
        The matrix after subtracting ``A.mean(axis=0)`` and dividing by
        ``A.std(axis=0)`` (both in place).
    x_0 : array
        The Gaussian-times-Bernoulli signal used to generate ``y``.
    y : array
        ``A @ x_0`` plus Gaussian noise of scale 0.1, computed from the
        matrix *before* it is standardized, then mean-centred.
    """
    A = utils.utils.make_random_dct_matrix(M, N)
    x_0 = np.random.normal(0.0, 1.0, N) * np.random.binomial(1, ρ, N)
    y = A @ x_0 + np.random.normal(0.0, 1e-1, M)
    A -= A.mean(axis=0)
    A /= A.std(axis=0)
    y -= y.mean()
    return A, x_0, y


# Sweep over α (M = int(N * α)); one pickle file is written per α.
for α in tqdm([0.05, 0.075, 0.09, 0.1, 0.15], desc="alpha"):
    N = 4096
    M = int(N * α)
    ρ = 0.05

    # ---- cross validation: pick the LASSO penalty on the seed-0 instance ----
    np.random.seed(0)
    A, x_0, y = _make_standardized_problem(M, N, ρ)

    # max_iter must be an integer: a float such as 1e7 is rejected by
    # scikit-learn's parameter validation in recent releases.
    lasso_cv = linear_model.LassoCV(n_jobs=3, n_alphas=50, cv=10, tol=1e-3, max_iter=int(1e7), fit_intercept=False)
    lasso_cv.fit(A, y)

    # ---- VAMP: record the solver's diff history on n_sample fresh instances ----
    μ = 0.5

    dumping = 1.0
    # NOTE(review): this tolerance is far below the abs_diff values the solver
    # reports (~1e-20), so every run appears to use all 200 iterations.
    # Presumably intentional, so that every history has the same length — confirm.
    tol = 1e-200
    message = False

    history_list = []
    n_sample = 10
    for seed in tqdm(range(n_sample)):
        np.random.seed(seed)
        A, x_0, y = _make_standardized_problem(M, N, ρ)

        # The cross-validated penalty is scaled by M and applied uniformly to all N coefficients.
        vamp_solver = StabilitySelectionVAMPSolver(A, y, regularization_strength=lasso_cv.alpha_ * M * np.ones(N),
                                                   dumping_coefficient=dumping, mu=μ, clip_min=1e-12, clip_max=1e12)  # solver

        _ = vamp_solver.solve(max_iteration=200, tolerance=tol, message=message)  # fit
        history_list.append(vamp_solver.diff_history)

    history_array = np.array(history_list)

    # e.g. α = 0.075 -> "0p075", so the value can be embedded in a file name.
    alpha_string = "{0:.3f}".format(α).replace(".", "p")
    with open("time_evolution_random_dct_manytimes_alpha" + alpha_string + ".pickle", "wb") as f:
        pickle.dump(history_array, f, pickle.HIGHEST_PROTOCOL)

HBox(children=(IntProgress(value=0, description='alpha', max=5, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

doesn't converged
abs_diff= 1.1377412487886083e-20 
abs_diff_v= 4.8394373874942393e-23 
iteration num=199
### x1 ###
self.chi1x_hat 1.363101792554156
self.q1x_hat 8.214603043992078
self.chi1x 2.40213780828496e-06
self.v1x 3.2539365260477016e-08
### u1 ###
self.chi1u_hat 0.0001343389406214994
self.q1u_hat 0.009836014878485818
self.chi1u 0.4927537825532099
self.v1u 0.02915380655191879
### x2 ###
self.chi2x_hat 927490234413.7625
self.q2x_hat 927490238374.3833
self.chi2x 2.40213780828496e-06
self.v2x 3.2539365260477016e-08
### u2 ###
self.chi2u_hat 0.12005183502767872
self.q2u_hat 2.019577157422516
self.chi2u 0.49275378255320984
self.v2u 0.02915380655191879


doesn't converged
abs_diff= 5.246144043606088e-21 
abs_diff_v= 2.973633229557357e-23 
iteration num=199
### x1 ###
self.chi1x_hat 1.2747695854996433
self.q1x_hat 7.0277621101648675
self.chi1x 1.6432805829210322e-06
self.v1x 2.0083343056733134e-08
### u1 ###
self.chi1u_hat 8.271201503974586e-05
self.q1u_hat 0.0067287187950506215
self.c

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

doesn't converged
abs_diff= 2.6563159091782224e-20 
abs_diff_v= 1.5642190974246596e-22 
iteration num=199
### x1 ###
self.chi1x_hat 1.5907240585286426
self.q1x_hat 12.917570828488406
self.chi1x 5.070656735320288e-06
self.v1x 8.093314730109086e-08
### u1 ###
self.chi1u_hat 0.0003329594486729715
self.q1u_hat 0.020770671509440265
self.chi1u 0.48499119161672827
self.v1u 0.027101542513017392
### x2 ###
self.chi2x_hat 919433593772.0525
self.q2x_hat 919433595721.7145
self.chi2x 5.070656735320288e-06
self.v2x 8.093314730109086e-08
### u2 ###
self.chi2u_hat 0.11501372494281446
self.q2u_hat 2.0411296111401733
self.chi2u 0.4849911916167284
self.v2u 0.027101542513017392


doesn't converged
abs_diff= 2.1518986608985326e-20 
abs_diff_v= 7.317305575437925e-23 
iteration num=199
### x1 ###
self.chi1x_hat 1.4516789148742935
self.q1x_hat 11.199697943656364
self.chi1x 3.5828276934644318e-06
self.v1x 4.8862538209745895e-08
### u1 ###
self.chi1u_hat 0.0002011956422034363
self.q1u_hat 0.014672731563711438
s

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

doesn't converged
abs_diff= 9.237322411277126e-20 
abs_diff_v= 1.2957258427603934e-21 
iteration num=199
### x1 ###
self.chi1x_hat 3.1595593797387087
self.q1x_hat 50.863352501102334
self.chi1x 1.9707775323351947e-05
self.v1x 3.2849352851748005e-07
### u1 ###
self.chi1u_hat 0.0013489678841405075
self.q1u_hat 0.0808760487675954
self.chi1u 0.4470028941647283
self.v1u 0.021918337740930337
### x2 ###
self.chi2x_hat 695800781340.1963
self.q2x_hat 695800790349.9121
self.chi2x 1.9707775323351947e-05
self.v2x 3.2849352851748e-07
### u2 ###
self.chi2u_hat 0.10864119299034576
self.q2u_hat 2.156282163259726
self.chi2u 0.44700289416472827
self.v2u 0.021918337740930337


doesn't converged
abs_diff= 5.3244312032006487e-20 
abs_diff_v= 2.538912280890978e-22 
iteration num=199
### x1 ###
self.chi1x_hat 2.869401355136803
self.q1x_hat 47.64467749613441
self.chi1x 1.5996288518884542e-05
self.v1x 2.2047660685312443e-07
### u1 ###
self.chi1u_hat 0.0009052725067451679
self.q1u_hat 0.06563007860449563
self.ch

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

doesn't converged
abs_diff= 1.8358992324777173e-19 
abs_diff_v= 2.1880950452701536e-21 
iteration num=199
### x1 ###
self.chi1x_hat 4.855790527660249
self.q1x_hat 104.63240348294784
self.chi1x 3.865414209324561e-05
self.v1x 5.671074893920053e-07
### u1 ###
self.chi1u_hat 0.0023209577497001962
self.q1u_hat 0.15920147529804632
self.chi1u 0.40663773540210923
self.v1u 0.017409938154419217
### x2 ###
self.chi2x_hat 376953125193.00214
self.q2x_hat 376953148609.13403
self.chi2x 3.865414209324561e-05
self.v2x 5.671074893920053e-07
### u2 ###
self.chi2u_hat 0.1032589311712595
self.q2u_hat 2.3000494145788566
self.chi2u 0.40663773540210923
self.v2u 0.017409938154419217


doesn't converged
abs_diff= 1.313278206157606e-19 
abs_diff_v= 3.675871537602102e-22 
iteration num=199
### x1 ###
self.chi1x_hat 4.584337028252424
self.q1x_hat 104.64858761248536
self.chi1x 3.3200903545668387e-05
self.v1x 4.2836515515636214e-07
### u1 ###
self.chi1u_hat 0.0017517531772853238
self.q1u_hat 0.13661674813474295
self

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

doesn't converged
abs_diff= 1.481378133704481e-18 
abs_diff_v= 6.483863884780695e-21 
iteration num=199
### x1 ###
self.chi1x_hat 5.827733111619295
self.q1x_hat 179.54758020800733
self.chi1x 0.0001298097070498788
self.v1x 1.7443446642876153e-06
### u1 ###
self.chi1u_hat 0.007087707346905317
self.q1u_hat 0.5438396753660281
self.chi1u 0.2857082222165579
self.v1u 0.009146029192737135
### x2 ###
self.chi2x_hat 103.11348457895991
self.q2x_hat 14220.459971648568
self.chi2x 0.0001298097070498788
self.v2x 1.7443446642876147e-06
### u2 ###
self.chi2u_hat 0.1051315419632193
self.q2u_hat 2.9563844023780717
self.chi2u 0.28570822221655734
self.v2u 0.009146029192737111


doesn't converged
abs_diff= 3.5159879600162694e-19 
abs_diff_v= 1.1339822230514577e-21 
iteration num=199
### x1 ###
self.chi1x_hat 5.551024957365945
self.q1x_hat 192.16145565216817
self.chi1x 0.00010870102020805854
self.v1x 1.2592901540282219e-06
### u1 ###
self.chi1u_hat 0.005107187171868497
self.q1u_hat 0.4544081029360075
self.ch