In [None]:
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import trange
import os
from scipy.sparse import issparse
import matplotlib.gridspec as gridspec
from sklearn.neighbors import NearestNeighbors
import warnings

In [None]:
# Load the 'stem', 'prog' and 'diff' covariance dictionaries from disk.
# Later cells read each entry as dict_cov[header]['Cov'].
#
# SECURITY: allow_pickle=True unpickles arbitrary Python objects, which can
# execute code on load -- only use a 'covariances.npz' from a trusted source.
#
# The context manager closes the underlying archive once the three stored
# objects have been unwrapped with .item().
with np.load('covariances.npz', allow_pickle=True) as saved_data:
    dict_cov_stem = saved_data['dict_cov_stem'].item()
    dict_cov_prog = saved_data['dict_cov_prog'].item()
    dict_cov_diff = saved_data['dict_cov_diff'].item()

In [None]:
def project_W_to_PD(W, D, margin=1e-3):
    """Rescale ``W`` in place so that ``D - W.T @ W`` stays positive definite.

    ``D - W.T @ W`` is positive definite exactly when the spectral norm of
    ``D^{-1/2} W^T`` is below 1.  When that norm exceeds ``1 - margin`` the
    whole matrix ``W`` is multiplied by ``(1 - margin) / norm``; otherwise
    ``W`` is left untouched.

    Parameters
    ----------
    W : numpy.ndarray
        Floating-point 2-D array.  It is modified IN PLACE (``W *= ...``), so
        callers that ignore the return value still see the rescaled matrix.
    D : scalar, 1-D array-like or 2-D array
        Positive diagonal entries.  A 2-D input is treated as a diagonal
        matrix (only its main diagonal is read); a scalar stands for
        ``scalar * identity``.
    margin : float, optional
        Safety gap below 1 kept for the spectral norm.

    Returns
    -------
    numpy.ndarray
        The same array object that was passed in as ``W``.
    """
    # Nothing to rescale; also avoids indexing an empty singular-value array.
    if W.size == 0:
        return W

    # Accept lists / Python scalars as well as arrays.
    d = np.asarray(D, dtype=float)
    if d.ndim == 2:
        d = np.diag(d)
    # atleast_1d lets a scalar D broadcast over every row below.
    inv_sqrt_d = 1.0 / np.sqrt(np.atleast_1d(d))

    # Z = D^{-1/2} W^T : row i of W^T (column i of W) scaled by d_i^{-1/2}.
    Z = W.T * inv_sqrt_d[:, None]

    # Singular values come back in descending order, so [0] is the spectral norm.
    smax = np.linalg.svd(Z, compute_uv=False)[0]
    cap = 1.0 - margin
    if smax > cap:
        W *= cap / smax
    return W

def fit_W_l1_l2(
    X_cov,
    inverse_variances,
    w_diag_val,
    l1_reg_strength=1e-2,
    l2_reg_strength=1e-2,
    n_iters=10000,
    learning_rate=0.001,
    tolerance=1e-16,          # for eigvals, as before
    early_stopping=True,
    es_patience=500,          # number of steps with no improvement
    es_min_delta=1e-6         # minimum improvement to reset patience
):
    """Fit ``W`` so that ``D - W.T @ W`` approximates the inverse of ``X_cov``.

    With ``D = diag(inverse_variances)`` and ``X`` the eigenvalue-clamped
    ``X_cov``, the identity ``(D - W^T W) X = I`` is rewritten as
    ``D X - I = W^T W X`` and the squared Frobenius norm of the difference of
    both sides is minimised with Adam.  An entry-wise L1 penalty on ``W`` is
    added to the loss when ``l1_reg_strength > 0``; L2 regularisation is
    applied through Adam's ``weight_decay``.

    Parameters
    ----------
    X_cov : numpy.ndarray, shape (n, n)
        Symmetric matrix; its eigenvalues are clamped to at least
        ``tolerance`` before fitting.
    inverse_variances : scalar or 1-D array-like of length n
        Diagonal of ``D``.  A scalar stands for ``scalar * identity``.
    w_diag_val : float or None
        If not None, the diagonal of ``W`` is reset to this value after every
        optimiser step.
    l1_reg_strength, l2_reg_strength : float
        Regularisation weights (see above).
    n_iters : int
        Maximum number of optimiser steps.
    learning_rate : float
        Adam learning rate.
    tolerance : float
        Lower bound applied to the eigenvalues of ``X_cov``.
    early_stopping : bool
        Stop once the loss has not improved by more than ``es_min_delta`` for
        ``es_patience`` consecutive steps.
    es_patience, es_min_delta
        Early-stopping controls; ignored when ``early_stopping`` is False.

    Returns
    -------
    W : numpy.ndarray, shape (n, n)
        The LAST iterate (not the best-loss iterate), with the diagonal pin
        applied when ``w_diag_val`` is given.
    loss_history : numpy.ndarray
        Loss evaluated at the start of each performed step.
    std_history : numpy.ndarray
        Standard deviation of the off-diagonal entries of ``W`` after each
        performed step; always the same length as ``loss_history``.

    Notes
    -----
    ``W`` is initialised from ``torch.randn``; call ``torch.manual_seed``
    beforehand for reproducible results.
    """
    n = X_cov.shape[0]

    # Rebuild X_cov from its spectrum with eigenvalues floored at `tolerance`.
    vals, vecs = np.linalg.eigh(X_cov)
    vals = np.maximum(vals, tolerance)
    X_fit = torch.tensor((vecs * vals) @ vecs.T, dtype=torch.float32)

    # D @ X as a row scaling.  reshape(-1, 1) makes a scalar broadcast as
    # scalar * X and a length-n vector scale row i by d_i.  (The target must be
    # D X, not X D: only D X - I = W^T W X follows from (D - W^T W) X = I.)
    inv_var = torch.tensor(np.asarray(inverse_variances, dtype=np.float32))
    fixed_X_matr = X_fit * inv_var.reshape(-1, 1) - torch.eye(n)

    W = torch.randn(size=(n, n)) / np.sqrt(n)
    W = W.requires_grad_()
    optimizer = optim.Adam([W], lr=learning_rate, weight_decay=l2_reg_strength)

    # Loop-invariant index sets.
    diag_idx = torch.arange(n)
    upper_idx = np.triu_indices(n, k=1)
    lower_idx = np.tril_indices(n, k=-1)

    loss_history = []
    std_history = []

    best_loss = None
    steps_no_improve = 0

    for step in trange(n_iters):
        optimizer.zero_grad()

        residual = fixed_X_matr - (W.T @ W) @ X_fit
        loss = residual.pow(2).sum()          # squared Frobenius norm

        if l1_reg_strength > 0:
            loss = loss + l1_reg_strength * W.abs().sum()

        loss.backward()
        optimizer.step()

        current_loss = float(loss.item())
        loss_history.append(current_loss)

        # Re-impose the diagonal constraint and record diagnostics BEFORE the
        # early-stopping test, so that a `break` can neither return an
        # unpinned W nor leave the two histories with different lengths.
        with torch.no_grad():
            if w_diag_val is not None:
                W[diag_idx, diag_idx] = w_diag_val
            W_np = W.detach().cpu().numpy()
            off_diag = np.concatenate((W_np[upper_idx], W_np[lower_idx]))
            std_history.append(off_diag.std())

        # --- early stopping bookkeeping ---
        if early_stopping:
            if best_loss is None or (best_loss - current_loss) > es_min_delta:
                best_loss = current_loss
                steps_no_improve = 0
            else:
                steps_no_improve += 1
                if steps_no_improve >= es_patience:
                    break

    loss_history = np.array(loss_history)
    std_history = np.array(std_history)
    return W.detach().cpu().numpy(), loss_history, std_history




def fit_W_bhatt(
    X_cov,
    inverse_variances,
    w_diag_val,
    n_iters=10000,
    learning_rate=0.001,
    tolerance=1e-16,          # for eigvals, as before
    early_stopping=True,
    es_patience=500,          # number of steps with no improvement
    es_min_delta=1e-6         # minimum improvement to reset patience
):
    """Fit ``W`` by projected gradient descent on a log-determinant loss.

    With ``C`` the eigenvalue-clamped ``X_cov``, ``G = diag(inverse_variances)``
    and ``Y = G - W.T @ W``, the loss is

        0.5 * ( log det(I + C Y) - 0.5 * log det(Y) ).

    Each iteration takes a plain gradient step on ``W`` and then calls
    ``project_W_to_PD`` so that ``Y`` stays positive definite.

    Parameters
    ----------
    X_cov : numpy.ndarray, shape (n, n)
        Symmetric matrix; its eigenvalues are clamped to at least
        ``tolerance``.
    inverse_variances : scalar or 1-D array-like of length n
        Diagonal of ``G``.
    w_diag_val : float or None
        If not None, the diagonal of ``W`` is reset to this value at the end
        of every iteration.
    n_iters : int
        Maximum number of iterations.
    learning_rate : float
        Step size of the gradient update.
    tolerance : float
        Lower bound applied to the eigenvalues of ``X_cov``.
    early_stopping, es_patience, es_min_delta
        Stop once the loss has not improved by more than ``es_min_delta`` for
        ``es_patience`` consecutive iterations.

    Returns
    -------
    W : numpy.ndarray, shape (n, n)
        The last iterate.
    loss_history : numpy.ndarray
        Loss at the start of each performed iteration (NaN when either
        determinant is not positive).
    std_history : numpy.ndarray
        Standard deviation of the off-diagonal entries of ``W`` after each
        performed iteration; same length as ``loss_history``.

    Notes
    -----
    ``W`` is initialised from ``np.random.randn``; call ``np.random.seed``
    beforehand for reproducible results.
    """
    n = X_cov.shape[0]

    # Rebuild X_cov from its spectrum with eigenvalues floored at `tolerance`.
    vals, vecs = np.linalg.eigh(X_cov)
    vals = np.maximum(vals, tolerance)
    X_fit = (vecs * vals) @ vecs.T

    identity = np.eye(n)
    G = identity * inverse_variances

    W = np.random.randn(n, n) / np.sqrt(n)
    # Start inside the feasible set: an unprojected random W usually makes Y
    # indefinite, which turns the very first loss into NaN and, because NaN
    # never "improves", triggers a spurious early stop.
    W = project_W_to_PD(W, G)

    # Loop-invariant index sets.
    diag_idx = np.arange(n)
    upper_idx = np.triu_indices(n, k=1)
    lower_idx = np.tril_indices(n, k=-1)

    loss_history = []
    std_history = []

    best_loss = None
    steps_no_improve = 0

    for step in trange(n_iters):

        Y = G - W.T @ W
        id_plus_C_Y = identity + X_fit @ Y

        # slogdet avoids the overflow/underflow of a raw determinant.
        sign_a, logdet_a = np.linalg.slogdet(id_plus_C_Y)
        sign_y, logdet_y = np.linalg.slogdet(Y)
        if sign_a > 0 and sign_y > 0:
            current_loss = float(0.5 * (logdet_a - 0.5 * logdet_y))
        else:
            # A non-positive determinant has no real logarithm.
            current_loss = float('nan')

        # Descent step: dL/dW = -W @ ((I + C Y)^{-1} C - 0.5 * Y^{-1}).
        direction = np.linalg.solve(id_plus_C_Y, X_fit) - 0.5 * np.linalg.inv(Y)
        W += learning_rate * W @ direction

        W = project_W_to_PD(W, G)
        loss_history.append(current_loss)

        # Pin the diagonal and record diagnostics BEFORE the early-stopping
        # test so that a `break` cannot return an unpinned W or leave the two
        # histories with different lengths.
        # NOTE: the pin is applied after the projection, so it can move W out
        # of the feasible set again; the following iteration's loss may then
        # be NaN.
        if w_diag_val is not None:
            W[diag_idx, diag_idx] = w_diag_val
        std_history.append(np.concatenate((W[upper_idx], W[lower_idx])).std())

        # --- early stopping bookkeeping ---
        if early_stopping:
            if best_loss is None or (best_loss - current_loss) > es_min_delta:
                best_loss = current_loss
                steps_no_improve = 0
            else:
                steps_no_improve += 1
                if steps_no_improve >= es_patience:
                    break

    loss_history = np.array(loss_history)
    std_history = np.array(std_history)
    return W, loss_history, std_history


In [None]:
# Fit W for every covariance matrix of every group ('stem', 'prog', 'diff'),
# store the result, and plot per-fit diagnostics.

# The fits start from random matrices; seed both RNGs so that a
# Restart & Run All reproduces the same W and the same downstream statistics.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

std_per_type = {}
W_per_type = {}
for tp, dict_cov in [('stem', dict_cov_stem), ('prog', dict_cov_prog), ('diff', dict_cov_diff)]:
    std_per_type[tp] = {}
    W_per_type[tp] = {}
    for header in dict_cov.keys():
        X_cov = dict_cov[header]['Cov']

        # Normalise by the square roots of the diagonal so the matrix passed
        # to the fit has a unit diagonal.
        sd = np.sqrt(np.diag(X_cov))
        X_cov_to_fit = X_cov / np.outer(sd, sd)
        max_ev = np.max(np.linalg.eigvalsh(X_cov_to_fit))
        init_inv_variances = 1 / np.diag(X_cov_to_fit)

        # l1/l2 strengths are left at the function defaults.  The es_* options
        # are omitted because they have no effect with early_stopping=False.
        # fit_W_bhatt takes the same three leading arguments and returns the
        # same triple, so it can be swapped in here.
        W, loss_history, std_history = fit_W_l1_l2(
            X_cov_to_fit,
            init_inv_variances,
            w_diag_val=None,
            n_iters=5000,
            learning_rate=0.001,
            tolerance=1e-16,
            early_stopping=False,
        )

        W_per_type[tp][header] = (W, init_inv_variances, X_cov_to_fit)

        # Spread of the off-diagonal entries of the fitted W.
        W_upper = W[np.triu_indices(W.shape[0], k=1)]
        W_lower = W[np.tril_indices(W.shape[0], k=-1)]
        std_per_type[tp][header] = {'std' : np.concatenate((W_upper, W_lower)).std(), 'max_ev' : max_ev}
        std = std_per_type[tp][header]['std']

        # Panels: off-diagonal std per iteration | loss per iteration | fitted W.
        fig, axs = plt.subplots(1, 3, figsize=(10, 4))
        axs[0].plot(std_history, label=header)
        axs[1].plot(loss_history, label=header)
        axs[2].imshow(W, vmin=-0.5, vmax=0.5, cmap='seismic')
        axs[0].set_title(f'History for {header}')
        axs[2].set_title(f'Std: {std:.3g}, max_ev: {max_ev:.3g}')
        axs[0].set_xlabel('Iteration')
        axs[0].set_ylabel('Std')
        axs[0].legend()
        axs[1].set_yscale('log')
        axs[1].set_xscale('log')
        axs[1].set_xlabel('Iteration')
        axs[1].set_ylabel('Loss')
        axs[1].legend()
        plt.tight_layout()
        plt.show()
        # One figure is created per matrix; release it once it has been shown.
        plt.close(fig)

In [None]:
from scipy.stats import mannwhitneyu
# Gather the off-diagonal std of every fitted W into two samples.
# NOTE: `stem_stds` pools the 'stem' AND 'prog' groups (stem entries first);
# `diff_stds` holds the 'diff' group only.
stem_stds = np.array([
    entry['std']
    for group in ('stem', 'prog')
    for entry in std_per_type[group].values()
])
diff_stds = np.array([entry['std'] for entry in std_per_type['diff'].values()])

In [None]:
# Factor that maps the mean std over both samples to 0.5.
rescaler = 0.5 / np.concatenate((stem_stds, diff_stds)).mean()
# Displayed value: ratio of the mean std of the pooled stem/prog sample to
# that of the diff sample.
stem_stds.mean() / diff_stds.mean()

In [None]:
# Two-sided Mann-Whitney U test comparing the two samples of off-diagonal stds.
mannwhitneyu(x=stem_stds, y=diff_stds, alternative='two-sided')

In [None]:
# For every stored fit, rebuild the matrix implied by W and compare it with
# the matrix that was fitted.
# Iterating W_per_type directly visits the same (tp, header) pairs in the
# same order as the fitting cell, without a hand-written lookup of the
# per-group dictionaries.
for tp, fits_for_type in W_per_type.items():
    for header, (W, init_inv_variances, X_cov_to_fit) in fits_for_type.items():
        # Model: X^{-1} = diag(inverse variances) - W^T W
        precision = np.eye(X_cov_to_fit.shape[0]) * init_inv_variances - W.T @ W
        X_fit = np.linalg.inv(precision)

        # Compute each spectrum once; both the printout and the histogram use it.
        eig_original = np.linalg.eigvalsh(X_cov_to_fit)
        eig_fitted = np.linalg.eigvalsh(X_fit)
        print(eig_fitted)

        # Panels: input matrix | reconstructed matrix | eigenvalue histograms.
        fig, axs = plt.subplots(1, 3, figsize=(14, 4))
        axs[0].imshow(X_cov_to_fit, vmin=-1.0, vmax=1.0, cmap='seismic')
        axs[0].set_title(f'{tp} / {header}: original')
        axs[1].imshow(X_fit, vmin=-1.0, vmax=1.0, cmap='seismic')
        axs[1].set_title('Fitted')
        axs[2].hist(eig_original, bins=10, label='Original')
        axs[2].hist(eig_fitted, alpha=0.7, bins=10, label='Fitted')
        axs[2].set_title('Eigenvalues')
        axs[2].legend()
        plt.show()
        # One figure is created per matrix; release it once it has been shown.
        plt.close(fig)