In [1]:
import numpy as np
from sklearn.datasets import make_sparse_spd_matrix
from scipy import linalg as LA

from infoband.band_info import InfoCorrBand
from wlpy.covariance import Covariance
from utils.adpt_correlation_threshold import AdptCorrThreshold

In [2]:
def cov2cor(S: np.ndarray):
    D = np.diag(np.sqrt(np.diag(S)))
    D_inv = np.linalg.inv(D)
    return D_inv @ S @ D_inv

In [3]:
def gen_S_AR1(rho = 0.8,N = 500) -> np.ndarray:
    # generate the covariance matrix of AR(1) process
    S_block = np.zeros(shape=[N, N])
    for j in range(0, N):
        S_block = S_block + np.diag(np.ones(N-j)*(rho**j), -j) + \
        np.diag(np.ones(N-j)*(rho**j), j)
    S = S_block - np.eye(N)
    return S

In [5]:
rng = np.random.RandomState(100)
N = 100
T = 80

In [29]:
# S = gen_S_AR1(N = N)
S = make_sparse_spd_matrix(N, alpha = 0.95, random_state = 100)
R = cov2cor(S)
# print(S[:5, :5])
X = rng.multivariate_normal(mean =np.zeros(N), cov = S, size = T)

In [30]:
c = InfoCorrBand(X)
# c.sample_cov()[:3, :3]
# c.sample_corr()[:3, :3]



In [31]:
L = abs(R)
c.feed_info(L)
# print(L[:5, :5])

In [32]:
# c.find_biggest_k_for_pd()
# c.plot_k_pd(range(N-50, N+1))

In [33]:
k = c.k_by_cv()
print(k)

15


In [34]:
R_est = c.fit_info_corr_band(k)
S_est = c.fit_info_cov_band(k)

In [35]:
def show_rs(S: np.ndarray, 
            c: InfoCorrBand, m: Covariance, 
            ord = 'fro'):
    # ord: norm type
    R = cov2cor(S)
    print('Correlation itself', LA.norm(R, ord))
    print('Error:')
    print('Sample', LA.norm(c.sample_corr() - R, ord))
    print('Linear Shrinkage', LA.norm(cov2cor(m.lw_lin_shrink()) - R, ord))
    print('Nonlinear Shrinkage', LA.norm(cov2cor(m.nonlin_shrink()) - R, ord))
    print()
    print('Covariance itself', LA.norm(S, ord))
    print('Error:')
    print('Sample', LA.norm(c.sample_cov() - S, ord))
    print('Linear Shrinkage', LA.norm(m.lw_lin_shrink() - S, ord))
    print('Nonlinear Shrinkage', LA.norm(m.nonlin_shrink() - S, ord))
    return

In [36]:
m = Covariance(X)

In [37]:
show_rs(S, c, m, 'fro')
print(LA.norm(R - R_est))
print(LA.norm(S - S_est))

Correlation itself 13.063598940290655
Error:
Sample 10.90843326315147
Linear Shrinkage 6.67466190622478
Nonlinear Shrinkage 6.550023379802649

Covariance itself 24.474579742722565
Error:
Sample 19.750055312281983
Linear Shrinkage 12.961697294614458
Nonlinear Shrinkage 12.718546057157571
3.3325967451589675
7.673033587446809


In [38]:
show_rs(S, c, m, 2)
print(LA.norm(R - R_est, 2))
print(LA.norm(S - S_est, 2))

Correlation itself 3.521618568406586
Error:
Sample 3.9268388174047457
Linear Shrinkage 1.9705457115334815
Nonlinear Shrinkage 1.9122338325807837

Covariance itself 7.296498981115177
Error:
Sample 6.75868433127621
Linear Shrinkage 4.358202636782313
Nonlinear Shrinkage 4.231550286041183
0.8487290134093068
2.2463442512338916


In [39]:
show_rs(S, c, m, 1)
print(LA.norm(R - R_est))
print(LA.norm(S - S_est))

Correlation itself 5.842057143391176
Error:
Sample 11.17148784996043
Linear Shrinkage 6.327856331904511
Nonlinear Shrinkage 6.370594558945804

Covariance itself 14.073948569692456
Error:
Sample 22.733167485618015
Linear Shrinkage 14.345945481031292
Nonlinear Shrinkage 14.422168561841556
3.3325967451589675
7.673033587446809


In [44]:
# def gen_L(S: np.ndarray, eta = 0.8):
R = cov2cor(S)
L = abs(L)
c = InfoCorrBand(X = np.eye(N), L = L) # you can ignore the 'X' parameter, I create this object solely to get 'self.rowSort'
rowSort = c.rowSort

In [47]:
x = rowSort[0]


array([  1.,  32.,  51.,  50.,  48.,  14.,  38.,  44.,  15.,   6.,  36.,
        37.,  39.,  40.,  41.,   3.,  12.,  42.,  35.,  43.,  45.,  46.,
        52.,  10.,  49.,   7.,  33.,  34.,  25.,  17.,  13.,  18.,  19.,
        20.,  21.,  22.,  16.,  24.,  23.,  26.,  27.,  28.,  29.,  30.,
        31.,  11.,  47.,  54.,  78., 100.,  81.,  82.,  83.,  84.,  85.,
        86.,  87.,  88.,  79.,  89.,  91.,  92.,  93.,  94.,  95.,   4.,
        96.,  97.,  98.,   8.,  99.,  90.,  77.,  66.,  76.,  56.,  57.,
        58.,  59.,  60.,  61.,  62.,  63.,  64.,   5.,  55.,  65.,  67.,
        68.,  69.,  70.,  71.,  72.,   9.,  73.,  74.,  75.,   2.,  80.,
        53.])