# Empirical weak proxy analysis on UKBB data 

Method to identify potentially weak proxies X and Z that violate the dual and primal equations. 
For the following explanation, assume we are looking at the dual equation $E[X(D-\gamma'Z)]$. 

1. Compute $Cov(D, X)$. A covariance is deemed non-zero when the Pearson correlation p-value between a feature $X_i$ and $D$ is below the Bonferroni threshold $0.05/d_x$. We only consider these features $X_i \in X_{ss}$ henceforth.


2. Compute $Cov(XZ)$. From here we can identify potentially weak $X$ and $Z$.
    1. To identify weak $X$, we look at all $X_i \in X_{ss}$, and if there are no $Z$ features with non-zero covariance (determined by pearson p-value again) with $X_i$, we identify $X_i$ as a weak proxy. 
    2. We want the minimal subset of $Z$ whose covariances with $X$ still span $Cov(D, X)$. To find it, we keep the $Z$s that have the most associations with the $X_i \in X_{ss}$. Let $Cov(X, Z)_{ss}$ denote the binary $d_x \times d_z$ matrix whose entry is 1 when the corresponding X and Z features have a p-value $< 0.05/(d_x d_z)$. The hyperparameter $N$ (`popular_Z_thresh` in the code) is compared against the column sums of $Cov(X_{ss}, Z)_{ss}$, i.e. the number of $X_i \in X_{ss}$ each $Z$ feature is associated with. We keep only the $Z$ features with more than $N$ associations.
    

We repeat this for the primal, flipping D <-> Y and Z <-> X. The final X, Z features kept must pass the identification tests for both the primal and the dual. Varying $N$ controls how strict the filtering is.

Note: I did not correct for the repetition of these tests across analyses — at least (number of (D, Y) pairs) × 2 (dual + primal) — so the p-value correction might need to be adjusted.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from proximalde.ukbb_proximal import ProximalDE_UKBB
from proximalde.ukbb_data_utils import *
import seaborn as sns 
pd.options.display.max_columns = None
from tqdm import tqdm 

### Tools for loading data and visualization

In [None]:
from scipy.stats import t
import textwrap 
def wrap_labels(labels, max_characters):
    """Wrap every label so that no line is longer than `max_characters`.

    Returns a new list in the same order; each entry is the corresponding
    label with newlines inserted by `textwrap.fill`.
    """
    wrapped = []
    for text in labels:
        wrapped.append(textwrap.fill(text, max_characters))
    return wrapped

def get_cov(x, y, get_pvals=False):
    """
    Return the sample covariance matrix (ddof=1) between the columns of x and y.

    Parameters
    ----------
    x : np.ndarray, shape (n, dx)
    y : np.ndarray, shape (n, dy)
    get_pvals : bool
        If True, also return the two-sided p-values of a Pearson correlation
        t-test for every (x column, y column) pair, and the Bonferroni
        significance threshold .05 / (dx * dy).

    Returns
    -------
    cov : np.ndarray, shape (dx, dy)
    pvals : np.ndarray, shape (dx, dy)     (only if get_pvals)
        Pairs whose correlation is undefined (a constant column) get p = 1.
    thresh : float                          (only if get_pvals)
    """
    xc = x - np.mean(x, axis=0)
    yc = y - np.mean(y, axis=0)
    cov = np.dot(xc.T, yc) / (xc.shape[0] - 1)

    if not get_pvals:
        return cov

    dx, dy = xc.shape[1], yc.shape[1]
    std_x = np.std(xc, axis=0, ddof=1)
    std_y = np.std(yc, axis=0, ddof=1)
    dof = xc.shape[0] - 2

    # Division by zero is expected: zero-variance columns (0/0 -> NaN) and
    # |r| == 1 (t -> inf, p -> 0). Silence those warnings deliberately.
    with np.errstate(divide='ignore', invalid='ignore'):
        corr_matrix = cov / np.outer(std_x, std_y)
        # Rounding can push |r| slightly above 1; then 1 - r**2 < 0, the
        # t-statistic becomes NaN and a perfect correlation would get p = 1.
        corr_matrix = np.clip(corr_matrix, -1.0, 1.0)
        t_stat = corr_matrix * np.sqrt(dof / (1 - corr_matrix**2))
    pvals = 2 * t.sf(np.abs(t_stat), dof)
    # Undefined correlations (constant columns) are treated as not significant
    pvals[np.isnan(pvals)] = 1

    return cov, pvals, .05 / (dx * dy)
    
# Directory holding the preprocessed UKBB inputs read by _load_data and
# load_DY_data. Paths are built by plain string concatenation, so the value
# must end with '/'.
UKBB_DATA_DIR = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/cohort_creation/data/'

def _load_data(fname: str):
    """
    Read the `<fname>_data_rd.npy` array and its `<fname>_feats_rd.npy`
    feature names from UKBB_DATA_DIR (pickle loading disabled).

    Returns (data, feats). Fails with AssertionError if `data` holds any NaN.
    """
    data_path = f'{UKBB_DATA_DIR}{fname}_data_rd.npy'
    feats_path = f'{UKBB_DATA_DIR}{fname}_feats_rd.npy'
    data = np.load(data_path, allow_pickle=False)
    feats = np.load(feats_path, allow_pickle=False)
    assert not np.isnan(data).any(), 'NaN values cannot exist in data'
    return data, feats
    
def load_XZ_data():
    """
    Return (X, X_feats, Z, Z_feats).

    X is loaded from the 'biomMed' arrays and Z from the 'srMntSlp' arrays via
    _load_data; Z is read first.
    """
    z_pair = _load_data(fname = 'srMntSlp')
    x_pair = _load_data(fname = 'biomMed')
    return (*x_pair, *z_pair)

def load_DY_data(D_label, Y_label):
    """
    Return (D, Y) read from UKBB_DATA_DIR.

    D is column `D_label` of updated_sa_df_pp.csv; Y is column `Y_label` of
    updated_Y_labels.csv with a trailing axis added via `[:, None]`.
    """
    sa_table = pd.read_csv(UKBB_DATA_DIR + 'updated_sa_df_pp.csv')
    D = sa_table[D_label].to_numpy()
    y_table = pd.read_csv(UKBB_DATA_DIR + 'updated_Y_labels.csv')
    Y = y_table[Y_label].to_numpy()[:, None]
    return D, Y

def load_res_data(D_label, Y_label):
    """
    Load the residualized arrays for one (D_label, Y_label) pair.

    Underscores are stripped from D_label before building file names. The X,
    Z and Y residual files carry a `_Wrm<D_label>` suffix; the D residual
    file does not. Returns (Xres, Zres, Yres, Dres).
    """
    res_dir = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/causal_analysis/data_hm/'
    d_tag = D_label.replace('_', '')
    w_suffix = f'_Wrm{d_tag}'

    def _load(name):
        return np.load(res_dir + name)

    Yres = _load(f'Yres_{Y_label}{w_suffix}.npy')
    Dres = _load(f'Dres_{d_tag}.npy')
    Xres = _load(f'Xres{w_suffix}.npy')
    Zres = _load(f'Zres{w_suffix}.npy')
    return Xres, Zres, Yres, Dres

def XZ_hparam_plot(dual_or_primal='dual', max_N=40):
    """
    Visualize Cov(X,Z) significance and how the popularity threshold N
    affects which proxy features survive.

    For every D label (and, in primal mode, every Y label) two panels are drawn:
      left:  binary map of st.sig. Cov(X,Z) entries (Pearson p < .05/(dx*dz)),
             restricted to the X feats st.sig. associated with D;
      right: for N = 0 .. max_N-1, the number of Z feats with > N st.sig.
             associations (blue) and the number of X feats left with no
             such Z feat (green).
    In primal mode the roles flip: X <-> Z and D <-> Y.

    Parameters
    ----------
    dual_or_primal : str
        'dual' (only Y='OA' is used) or anything else for primal.
    max_N : int
        Number of thresholds swept in the right panel (default 40).
    """
    D_labels = ['Female', 'Obese', 'Black', 'Asian']
    if dual_or_primal == 'dual':
        Y_labels = ['OA']
    else:
        Y_labels = ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']
    Ns = range(max_N)

    for D_label in D_labels:
        print(D_label)
        for Y_label in Y_labels:
            # Yres is needed for the primal Cov(Y, Z) below, so keep it.
            Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label=Y_label)
            if dual_or_primal == 'dual':
                Xrpr, Zrpr, Drpr, label = 'X', 'Z', 'D', D_label
                XZres_cov, XZres_pvals, XZres_thresh = get_cov(Xres, Zres, get_pvals=True)
                DXres_cov, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
            else:
                # Primal: roles flip (X <-> Z, D <-> Y)
                Xrpr, Zrpr, Drpr, label = 'Z', 'X', 'Y', Y_label
                XZres_cov, XZres_pvals, XZres_thresh = get_cov(Zres, Xres, get_pvals=True)
                DXres_cov, DXres_pvals, DXres_thresh = get_cov(Yres, Zres, get_pvals=True)

            # We only care about X feats with st.sig. assn with D
            ss_DXidx = (DXres_pvals < DXres_thresh).squeeze()
            # Binary st.sig. Cov(X,Z) restricted to those X feats (computed once)
            ss_XZ = XZres_pvals[ss_DXidx] < XZres_thresh
            # For each Z feat: number of st.sig. associated X feats
            z_assn_counts = ss_XZ.sum(axis=0)

            width = 20 if dual_or_primal == 'dual' else 12
            fig, axs = plt.subplots(1, 2, figsize=(width, 5), dpi=70)

            im = axs[0].imshow(ss_XZ, aspect='auto', cmap='Blues', interpolation='nearest')
            axs[0].set_title(f"Nonzero Covariance({Xrpr},{Zrpr})\n(if pearson pvalue < .05/(dz*dx))", fontsize=12)
            axs[0].set_ylabel(f"{Xrpr} feats", fontsize=10)
            axs[0].set_xlabel(f"{Zrpr} feats", fontsize=10)
            cbar = fig.colorbar(im, ax=axs[0], orientation='vertical')
            cbar.set_ticks([0, 1])
            cbar.set_ticklabels(['Zero', 'Nonzero'])

            # Sweep N: keep Z feats with > N associations, then count the
            # X feats that are left without any kept Z feat.
            nZfeats = []
            Xfeats_w_zero_Zfeats = []
            for N in Ns:
                keep = z_assn_counts > N
                nZfeats.append(keep.sum())
                Xfeats_w_zero_Zfeats.append((ss_XZ[:, keep].sum(axis=1) == 0).sum())

            ax3 = axs[1]
            ax3.plot(Ns, nZfeats, color='blue')
            ax3.set_ylabel(f"# {Zrpr} feats correlated with >N {Xrpr} feats", color='blue')
            ax3.set_xlabel(f"N\n(hparam for filtering {Zrpr} feats)")
            ax3.tick_params(axis='y', labelcolor='blue')
            ax3.set_title(f'How removing {Zrpr} feats affects Cov(X,Z)')
            ax3.grid(True, axis='y', linestyle='--', linewidth=0.5, color='lightgray')

            ax3b = ax3.twinx()
            ax3b.plot(Ns, Xfeats_w_zero_Zfeats, color='green')
            ax3b.set_ylabel(f"# {Xrpr} feats w/ 0 st.sig. feat correlation\nafter rm {Zrpr} feats correlated with <N {Xrpr} feats", color='green')
            ax3b.tick_params(axis='y')

            # Add the suptitle before tight_layout so the layout accounts for it
            plt.suptitle(f"{dual_or_primal} violation plots for {Drpr}={label}\nusing Xres, Zres")
            plt.tight_layout()
            plt.show()

def _augment_cov(xz, yz, d_x):
    """Return [[xz, d_x^T], [yz, 0]].

    Cov(Y,Z) (shape (1, dz)) is appended as the last row and Cov(D,X)
    (shape (1, dx)) as the last column, aligned with the X rows. The
    bottom-right corner is a 0 placeholder.
    """
    top = np.concatenate([xz, yz], axis=0)
    right = np.vstack([d_x.reshape(-1, 1), np.array([[0]])])
    return np.concatenate([top, right], axis=1)


def _vis_cov_heatmap(mat, row_labels, col_labels, title, xlabel, ylabel,
                     col_fontsize=8, hline=None, vline=None):
    """Show |mat| as a Blues heatmap with one tick label per row and column."""
    plt.subplots(1, 1, figsize=(12, 8), dpi=60)
    sns.heatmap(np.abs(mat), cmap='Blues')
    # Separator lines between the feature block and the appended Y row / D column
    if hline is not None:
        plt.axhline(y=hline, color='black', linewidth=2)
    if vline is not None:
        plt.axvline(x=vline, color='black', linewidth=2)
    plt.yticks(ticks=np.arange(len(row_labels)) + .5, labels=row_labels, rotation=0)
    plt.xticks(ticks=np.arange(len(col_labels)) + .5, labels=wrap_labels(col_labels, 50),
               rotation=90, fontsize=col_fontsize)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


def XZ_vis_cov(rmX_zeroZ_dual, popssZ_dual, rmZ_zeroX_primal, popssX_primal):
    """
    Plot covariance heatmaps of the residualized X and Z features that survive
    both the dual and the primal weak-proxy filters, for every (D, Y) pair.

    The arguments are nested {Y_label: {D_label: bool array}} dicts, as
    returned by id_weak_XZ_proxies. The dual ones are read under Y_label
    'OA' only. Kept X feats = popssX_primal & ~rmX_zeroZ_dual;
    kept Z feats = ~rmZ_zeroX_primal & popssZ_dual.

    Four figures are drawn per pair:
      1. |Cov(X,Z)| with Cov(Y,Z) as the last row and Cov(D,X) as the last
         column (bottom-right is a 0 placeholder);
      2. the same layout binarized by the Bonferroni p-value thresholds;
      3. |Cov([D, Y, X], Z)|;
      4. |Cov([D, Y, Z], X)|.

    Raises AssertionError if, after filtering, a kept X or Z feat has no
    st.sig. association with the other set. Uses the module-level Xint / Zint
    feature-name arrays.
    """
    for D_label in ['Female', 'Obese', 'Black', 'Asian']:
        for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
            print(f"{D_label}->{Y_label}")
            Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label)

            # Keep only X, Z feats that pass both the primal and the dual checks
            x_keep = popssX_primal[Y_label][D_label] & ~rmX_zeroZ_dual['OA'][D_label]
            z_keep = ~rmZ_zeroX_primal[Y_label][D_label] & popssZ_dual['OA'][D_label]
            Xres = Xres[:, x_keep]
            Zres = Zres[:, z_keep]
            x_names = list(Xint[x_keep])
            z_names = list(Zint[z_keep])
            n_x, n_z = Xres.shape[1], Zres.shape[1]
            pair = f'{D_label}->{Y_label}'

            XZres_cov, XZres_pvals, XZres_thresh = get_cov(Xres, Zres, get_pvals=True)
            DXres_cov, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
            YZres_cov, YZres_pvals, YZres_thresh = get_cov(Yres, Zres, get_pvals=True)

            # Augmented layout: rows = X feats then Y, columns = Z feats then D
            aug_rows = x_names + [f'Y={Y_label}']
            aug_cols = z_names + [f'D={D_label}']

            _vis_cov_heatmap(_augment_cov(XZres_cov, YZres_cov, DXres_cov),
                             aug_rows, aug_cols,
                             f'|Cov(X,Z)| after filtering X,Z\n{pair}',
                             'Z feats', 'X feats', hline=n_x, vline=n_z)

            # Every kept Z (columns) and X (rows) feat must keep >= 1 st.sig. partner
            ss_XZ = XZres_pvals < XZres_thresh
            assert (ss_XZ.sum(axis=0) > 0).all()
            assert (ss_XZ.sum(axis=1) > 0).all()
            _vis_cov_heatmap(_augment_cov(ss_XZ, YZres_pvals < YZres_thresh,
                                          DXres_pvals < DXres_thresh),
                             aug_rows, aug_cols,
                             f'Nonzero Cov(X,Z) after filtering X,Z\n{pair}',
                             'Z feats', 'X feats', hline=n_x, vline=n_z)

            DYX_cov = get_cov(np.concatenate([Dres, Yres, Xres], axis=1), Zres)
            _vis_cov_heatmap(DYX_cov, [f'D={D_label}', f'Y={Y_label}'] + x_names, z_names,
                             f'|Cov(DYX,Z)| after filtering X,Z\n{pair}',
                             'Z feats', 'X feats', col_fontsize=10)

            DYZ_cov = get_cov(np.concatenate([Dres, Yres, Zres], axis=1), Xres)
            _vis_cov_heatmap(DYZ_cov, [f'D={D_label}', f'Y={Y_label}'] + z_names, x_names,
                             f'|Cov(DYZ,X)| after filtering X,Z\n{pair}',
                             'X feats', 'Z feats', col_fontsize=10)

In [None]:
# Human-readable names for the X and Z features, used to label plots and prints
X, X_feats, Z, Z_feats = load_XZ_data()
Xint, Zint = (get_int_feats(feats) for feats in (X_feats, Z_feats))

# Production 

### Step 1. We need to pick `popular_Z_thresh` and optionally `min_Z_assns_per_X`
`popular_Z_thresh`: we want the minimal set of Z features, Z', such that Cov(Z', X) still linearly spans Cov(D, X). As a crude proxy, we keep only the Z features that have a non-zero covariance with more than `popular_Z_thresh` X features. Choose the largest threshold (i.e. the smallest retained Z set) for which every X_i with non-zero Cov(D, X_i) still has at least one associated Z_j. This may be a different number for each D.

<!-- IGNORE FOR NOW:
`min_Z_assns_per_X`: we trivially suggest removing all X feats that have no Z feature associations (an alternative I haven't looked into would be to add on Zs). We can also choose to remove X feats that have < `popular_Z_thresh` Z feats associated with it. Default is to only remove  < `popular_Z_thresh` = 1. 
 -->

In [None]:
# Sweep the popularity threshold N for the dual check (Z feats vs D-associated X feats)
XZ_hparam_plot(dual_or_primal='dual')

In [None]:
# Same sweep for the primal check (roles of X/Z and D/Y flipped), for every Y label
XZ_hparam_plot(dual_or_primal='primal')

### Step 2. Identify weak X and Z proxies

In [None]:
def id_weak_XZ_proxies(popular_Z_thresh: int, dual_or_primal='dual', verbose:bool=False):
    '''
    Flag weak X and Z proxies for every (D, Y) pair.

    For dual, checks whether cov(D,X) lies in span( cov(X,Z) ).
    For primal, checks whether cov(Y,Z) lies in span( cov(Z,X) ).
    A covariance is treated as nonzero when its Pearson correlation t-test
    p-value is below the Bonferroni threshold .05/m returned by get_cov.

    Below, "X" / "Z" / "D" refer to the dual roles; in primal mode read them
    as Z / X / Y respectively. Dual mode only iterates Y_label 'OA'; primal
    mode iterates all Y labels.

    Parameters
    ----------
    popular_Z_thresh: int
        A Z feat is kept ("popular") only if it has > popular_Z_thresh
        st.sig. associations with X feats.
    dual_or_primal: str
        'dual' or 'primal'.
    verbose : bool
        Print extra diagnostics (some ignored feats and, for X feats with
        index < 10, their top Z associations). Does not change the returned
        masks.

    Returns
    -------
    Three nested dicts {Y_label: {D_label: np.ndarray[bool]}}:

    ssDX_dual
        size dx: ith X feat is st.sig. associated with D.
    rmX_zeroZ_dual
        size dx: ith X feat should be removed, because it is not st.sig.
        associated with D, or no popular Z feat is st.sig. associated with it.
    popssZ_dual
        size dz: ith Z feat has > popular_Z_thresh st.sig. associations with
        X feats, and should be kept.

    Raises
    ------
    AssertionError
        If dual_or_primal is not 'dual' or 'primal'.
    '''
    assert dual_or_primal in ['dual', 'primal']
    ssDX_dual = {}
    rmX_zeroZ_dual = {}
    popssZ_dual = {}

    D_labels = ['Female', 'Obese', 'Black', 'Asian']
    if dual_or_primal == 'dual':
        Y_labels = ['OA']
    else:
        Y_labels = ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']

    for Y_label in Y_labels:
        ssDX_dual[Y_label] = {}
        rmX_zeroZ_dual[Y_label] = {}
        popssZ_dual[Y_label] = {}

        for D_label in D_labels:
            print(f"{D_label}->{Y_label}")

            Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label)
            if dual_or_primal == 'dual':
                Xfeat_int, Zfeat_int = Xint, Zint
                Xrpr, Zrpr, Drpr, label = 'X', 'Z', 'D', D_label
                XZres_cov, XZres_pvals, XZres_thresh = get_cov(Xres, Zres, get_pvals=True)
                DXres_cov, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
            else:
                # Primal: roles flip (X <-> Z, D <-> Y)
                Xfeat_int, Zfeat_int = Zint, Xint
                Xrpr, Zrpr, Drpr, label = 'Z', 'X', 'Y', Y_label
                XZres_cov, XZres_pvals, XZres_thresh = get_cov(Zres, Xres, get_pvals=True)
                DXres_cov, DXres_pvals, DXres_thresh = get_cov(Yres, Zres, get_pvals=True)
            DX_cov = DXres_cov.squeeze()

            # We only care about X feats with st.sig. assn with D (dual)
            #        Z feats with st.sig. assn with Y (primal)
            ss_DXidx = (DXres_pvals < DXres_thresh).squeeze()
            ssDX_dual[Y_label][D_label] = ss_DXidx

            print(f"Ignored {sum(~ss_DXidx)} {Xrpr} feats due to low assn to {Drpr}={label}")
            if verbose:
                for idx in np.where(~ss_DXidx)[0][:5]:
                    print(f"{Xrpr} feat: {Xfeat_int[idx]}")
                    print(f"\tCov w/ {Drpr}={label}: {round(DX_cov[idx], 5)}")

            # Keep only "popular" Z feats, i.e. those with > popular_Z_thresh
            # st.sig. correlations with X feats
            ss_XZ = XZres_pvals < XZres_thresh
            popular_Zidx = ss_XZ.sum(axis=0) > popular_Z_thresh
            popssZ_dual[Y_label][D_label] = popular_Zidx
            print(f"Keeping {popular_Zidx.sum()} {Zrpr} feats w/ > {popular_Z_thresh} influences")

            # An (X, Z) pair counts only if it is st.sig. AND the Z feat is popular
            ss_popular_Zidx = ss_XZ & popular_Zidx[None, :]
            n_popular_assns = ss_popular_Zidx.sum(axis=1)

            # The removal rule is the same whether or not `verbose` is set:
            # drop X feats not associated with D, or left without popular Z feats.
            rmX_zeroZ_dual[Y_label][D_label] = (n_popular_assns == 0) | ~ss_DXidx

            if verbose:
                # np.where returns ascending indices, so stop at the first idx >= 10
                for idx in np.where(ss_DXidx)[0]:
                    if idx >= 10:
                        break
                    print("~"*10)
                    print(f"{Xrpr} feat: {Xfeat_int[idx]}")
                    print(f"\tCov w/ {Drpr}={label}: {round(DX_cov[idx], 5)}")
                    if n_popular_assns[idx] == 0:
                        continue
                    print(f"{Zrpr} feats w/ highest cov. to {Xrpr} feat ({n_popular_assns[idx]} found):")
                    by_abs_cov = np.argsort(np.abs(XZres_cov[idx]))[::-1]
                    top_Z = [j for j in by_abs_cov if ss_popular_Zidx[idx, j]][:5]
                    for j in top_Z:
                        print(f'\t{Zfeat_int[j]}: {round(XZres_cov[idx, j], 5)}')

            print(f"{Xrpr} feats w/out any {Zrpr} assns: {Xfeat_int[rmX_zeroZ_dual[Y_label][D_label]]}")
            print()
    return ssDX_dual, rmX_zeroZ_dual, popssZ_dual



In [None]:
# Dual weak-proxy filters with popularity threshold N = 25
ssDX_dual, rmX_zeroZ_dual, popssZ_dual = id_weak_XZ_proxies(dual_or_primal="dual", verbose=False, popular_Z_thresh=25)


In [None]:
# Primal weak-proxy filters (roles of X/Z flipped) with popularity threshold N = 15
ssYZ_primal, rmZ_zeroX_primal, popssX_primal = id_weak_XZ_proxies(dual_or_primal="primal", verbose=False, popular_Z_thresh=15)

In [None]:
# Covariance heatmaps of the X, Z feats kept by both the dual and primal filters
XZ_vis_cov(rmX_zeroZ_dual, popssZ_dual, rmZ_zeroX_primal, popssX_primal)

# Re-run

In [None]:
# NOTE(review): this header said "N = 35", but the filter dicts used below are
# built in the cells above with popular_Z_thresh=25 (dual) and 15 (primal) —
# confirm which N these results correspond to.
# Refit ProximalDE for each (D, Y) pair, using only the X and Z residual
# columns that pass both the primal and the dual weak-proxy filters.
for D_label in ['Female', 'Obese','Black', 'Asian']:
#     for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
    for Y_label in ['OA', 'myoc', 'deprs', 'back']:
        print(f"{D_label}->{Y_label}")
        W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

#         (previous unfiltered fit, kept for reference)
#         np.random.seed(4)
#         est = ProximalDE_UKBB(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
#                     multitask=False, n_jobs=-1, random_state=3, verbose=1)
#         est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)                
#         sm = est.summary(decimals=5)
#         print("OLD:")
#         print(sm.tables[0])
#         print(sm.tables[2])   
                
        # Primal masks: popular X feats, and Z feats not flagged for removal
        Xprm_idx = popssX_primal[Y_label][D_label]    
        Zprm_idx = ~rmZ_zeroX_primal[Y_label][D_label]
        
        # Dual masks are only computed for Y='OA', so they are reused for every Y
        Zdual_idx = popssZ_dual['OA'][D_label]   
        Xdual_idx = ~rmX_zeroZ_dual['OA'][D_label]
        # Global seed set right before building and fitting the estimator
        np.random.seed(4)
        # NOTE(review): ProximalDE is not imported by name in this notebook (only
        # ProximalDE_UKBB is); presumably it comes from the
        # `from proximalde.ukbb_data_utils import *` star import — confirm.
        est = ProximalDE(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
                    multitask=False, n_jobs=-1, random_state=3, verbose=1)
        est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label, 
                Zres_idx=(Zprm_idx & Zdual_idx), Xres_idx=(Xprm_idx & Xdual_idx))                
        sm = est.summary(decimals=5)
        print("New:")
        print(sm.tables[0])
        print(sm.tables[2])   
        print()
        print()
        print()

### Ignore

In [None]:
# Filtered ProximalDE refit for each (D, Y) using the 'lite' W covariates
# (load_ukbb_data(..., W='lite')); results are saved with the '_lite' suffix.
# 'Asian' is not included in this cell.
for D_label in ['Female', 'Obese','Black']:
    for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
        print(f"{D_label}->{Y_label}")
        W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label, W='lite')

#         (previous unfiltered fit, kept for reference)
#         np.random.seed(4)
#         est = ProximalDE_UKBB(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
#                     multitask=False, n_jobs=-1, random_state=3, verbose=1)
#         est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)                
#         sm = est.summary(decimals=5)
#         print("OLD:")
#         print(sm.tables[0])
#         print(sm.tables[2])   
                
        # Primal masks for this (Y, D)
        Xprm_idx = popssX_primal[Y_label][D_label]    
        Zprm_idx = ~rmZ_zeroX_primal[Y_label][D_label]
        
        # Dual masks are only computed for Y='OA'
        Zdual_idx = popssZ_dual['OA'][D_label]   
        Xdual_idx = ~rmX_zeroZ_dual['OA'][D_label]
        
        # Global seed set right before building and fitting the estimator
        np.random.seed(4)
        est = ProximalDE(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
                    multitask=False, n_jobs=-1, random_state=3, verbose=1)
        est.fit(W, D, Z, X, Y, D_label=D_label, 
                Y_label=Y_label, Zres_idx=(Zprm_idx & Zdual_idx), 
                Xres_idx=(Xprm_idx & Xdual_idx), save_fname_addn='_lite')                
        sm = est.summary(decimals=5)
        print("New:")
        print(sm.tables[0])
        print(sm.tables[2])   
        print()
        print()
        print()

In [None]:
# Refit ProximalDE for D='Black' across all Y labels with try_split='Z'.
# NOTE(review): the primal/dual masks below are computed but not passed to
# est.fit, so they do not affect this fit — confirm this is intended.
# Any exception is printed and swallowed so one failing Y does not stop the
# loop; afterwards `est` is the estimator constructed in the last iteration.
for D_label in ['Black']:
    for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
        try:
            print(f"{D_label}->{Y_label}")
            W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

            Xprm_idx = popssX_primal[Y_label][D_label]    
            Zprm_idx = ~rmZ_zeroX_primal[Y_label][D_label]

            Zdual_idx = popssZ_dual['OA'][D_label]   
            Xdual_idx = ~rmX_zeroZ_dual['OA'][D_label]
            np.random.seed(4)
            est = ProximalDE(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
                        multitask=False, n_jobs=-1, random_state=3, verbose=1)
            est.fit(W, D, Z, X, Y, D_label=D_label, 
                    Y_label=Y_label,try_split='Z')                
            sm = est.summary(decimals=5)
            print("New:")
            print(sm.tables[0])
            print(sm.tables[2])   
            print()
            print()
            print()
        except Exception as e:
            print(e)

In [None]:
# Covariance rank test on `est`, the estimator left in the namespace by the
# most recently executed cell that assigned it (its D_label / Y_label loop
# values are reused for the plot title in the next cell).
svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)


In [None]:
# Scatter the singular values against their critical threshold
n_above = np.sum(svalues >= svalues_crit)
header = (f"D={D_label}_Y={Y_label}\nNumber of singular values above threshold: {n_above}. "
          f"\nThreshold={svalues_crit:.3f}. Top singular value={svalues[0]:.3f}")
plt.title(header)
plt.scatter(np.arange(len(svalues)), svalues)
plt.axhline(svalues_crit)
plt.show()

In [None]:
# List the Z feats kept by both the primal and the dual filters for each (D, Y).
for D_label in ['Female', 'Obese','Black', 'Asian']:
    for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
        print(f"{D_label}->{Y_label}")

        Xprm_idx = popssX_primal[Y_label][D_label]
        # Dual results are nested {Y: {D: mask}} and only computed for Y='OA'
        Zdual_idx = popssZ_dual['OA'][D_label]

        # Keep-masks are sized from the filter masks themselves rather than
        # from Xres/Zres left over by other cells; indexing with the rm entries
        # works whether they are bool masks or index lists.
        Zprm_idx = np.ones(Zdual_idx.shape[0], dtype=bool)
        Zprm_idx[rmZ_zeroX_primal[Y_label][D_label]] = False

        Xdual_idx = np.ones(Xprm_idx.shape[0], dtype=bool)
        Xdual_idx[rmX_zeroZ_dual['OA'][D_label]] = False

#         print(Xint[~(Xprm_idx & Xdual_idx)])
        print(Zint[(Zprm_idx & Zdual_idx)])
                


In [None]:
# For each (D, Y), refit ProximalDE twice with the combined primal & dual
# weak-proxy masks: once filtering only the Z residual columns ("Zonly") and
# once filtering only the X residual columns ("Xonly").
for D_label in ['Female', 'Obese','Black', 'Asian']:
    for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
        print(f"{D_label}->{Y_label}")
        W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

        Xprm_idx = popssX_primal[Y_label][D_label]
        # Dual results are nested {Y: {D: mask}} and only computed for Y='OA'
        Zdual_idx = popssZ_dual['OA'][D_label]

        # Keep-masks are sized from the filter masks themselves rather than
        # from Xres/Zres left over by other cells; indexing with the rm entries
        # works whether they are bool masks or index lists.
        Zprm_idx = np.ones(Zdual_idx.shape[0], dtype=bool)
        Zprm_idx[rmZ_zeroX_primal[Y_label][D_label]] = False

        Xdual_idx = np.ones(Xprm_idx.shape[0], dtype=bool)
        Xdual_idx[rmX_zeroZ_dual['OA'][D_label]] = False

        # Global seed set right before building and fitting each estimator
        np.random.seed(4)
        est = ProximalDE(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
                    multitask=False, n_jobs=-1, random_state=3, verbose=1)
        est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label, Zres_idx=(Zprm_idx & Zdual_idx))
        sm = est.summary(decimals=5)
        print("Zonly:")
        print(sm.tables[0])
        print(sm.tables[2])

        np.random.seed(4)
        est = ProximalDE(cv=3, semi=True, dual_type='Z', ivreg_type='adv',
                    multitask=False, n_jobs=-1, random_state=3, verbose=1)
        est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label, Xres_idx=(Xprm_idx & Xdual_idx))
        sm = est.summary(decimals=5)
        print("Xonly:")
        print(sm.tables[0])
        print(sm.tables[2])

        print()
        print()
        print()

In [None]:
# # From this we choose 
# popular_Z_thresh = {'Female': 20, 'Obese': 30, 'Black': 20, "Asian": 20}
# popular_Z_thresh = {'Female': 5, 'Obese': 5, 'Black': 5, "Asian": 5}

In [None]:
# Legacy per-D version of the dual check (superseded by id_weak_XZ_proxies).
# NOTE(review): as written this depends on names not defined at this point in
# the notebook: `popular_Z_thresh` must be a dict keyed by D_label (its only
# assignments, in the cell just above, are commented out), and `verbose` is
# only assigned in a later cell. It also writes flat D_label keys into
# ssDX_dual / rmX_zeroZ_dual / popssZ_dual, whereas id_weak_XZ_proxies returns
# them as nested {Y_label: {D_label: ...}} dicts.
Y_label = 'OA'
for D_label in ['Female', 'Obese','Black', 'Asian']:
    print("\n" + "#"*20)

    print(D_label)
    Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label)
    DXres_cov, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
    
    # We only care about X feats with st.sig. assn with D
    ss_DXidx = (DXres_pvals < DXres_thresh).squeeze()
    
    ssDX_dual[D_label] = ss_DXidx
    # Collects integer indices (not a bool mask) of X feats to remove
    rmX_zeroZ_dual[D_label] = []
    
    if verbose:
        print(f"Ignored {sum(~ss_DXidx)} X feats due to low assn to D={D_label}")
        for idx in np.where(~ss_DXidx)[0][:2]:
            print(f"X feat: {Xint[idx]}")
            print(f"\tCov w/ D={D_label}: {round(DXres_cov.squeeze()[idx], 5)}")
    
    # W the same for D in [Black, Asian] so can reuse XZres
    # (relies on 'Asian' being iterated immediately after 'Black')
    if D_label != 'Asian':
        XZres_cov, XZres_pvals, XZres_thresh = get_cov(Xres, Zres, get_pvals=True)

    popular_Zidx = ((XZres_pvals < XZres_thresh).sum(axis=0) > popular_Z_thresh[D_label])
    popssZ_dual[D_label] = popular_Zidx.squeeze()

    if verbose:
        print(f"Keeping {popular_Zidx.sum()} Z feats w/ > {popular_Z_thresh[D_label]} influences")

    # Only look at X feats w/ nonzero D association
    for idx in np.where(ss_DXidx)[0]:
        print("\n" + "~"*10)
        print(f"X feat: {Xint[idx]}")
        print(f"\tCov w/ D={D_label}: {round(DXres_cov.squeeze()[idx], 5)}")

        # Which Z feats have st.sig. assn for each X feat idx? 
        ss_Zidx = (XZres_pvals[idx].squeeze() < XZres_thresh) 
        # Filter to keep only the most "popular" Z feats
        ss_popular_Zidx = ss_Zidx & popular_Zidx
        
        
        if verbose and sum(ss_Zidx) != sum(ss_popular_Zidx):
            print(f"Keeping only popular Z's {sum(ss_Zidx) } -> {sum(ss_popular_Zidx)} ss Z feats w/ this X")

        if sum(ss_popular_Zidx) == 0:
            print("ERROR: X FEAT HAS NO ASSD Z FEATS")
            rmX_zeroZ_dual[D_label].append(idx)
        else: 
            if verbose:
                print(f"Z feats w/ highest cov. to X feat ({sum(ss_popular_Zidx)} found):")
                sorted_XZ_idx = np.argsort(np.abs(XZres_cov)[idx])[::-1]
                sorted_XZ_idx = np.array([x for x in sorted_XZ_idx if ss_popular_Zidx[x]])[:2]
                for j in sorted_XZ_idx:
                    print(f'\t{Zint[j]}: {round(XZres_cov[idx, j], 5)}')
    # NOTE(review): if no X feat was flagged, np.array([]) has float dtype and
    # using it as an index into Xint raises IndexError — confirm / use dtype=int.
    print(f"Consider removing these X feats, which don't have any Z assns: {Xint[np.array(rmX_zeroZ_dual[D_label])]}")


In [None]:
# Compare the distribution of all XZ covariances with that of the st.sig. ones
# only (both density-normalized, 50 bins). Uses the XZres_* globals left by an
# earlier cell.
all_covs = XZres_cov.ravel()
sig_covs = XZres_cov[XZres_pvals < XZres_thresh].ravel()
for values, name in ((all_covs, 'before'), (sig_covs, 'after')):
    plt.hist(values, label=name, alpha=.4, bins=50, density=True)
plt.legend()
plt.title("XZ covariances before and after pvalue filtering (residual XZ)")
plt.show()

# OLD: Primal 

In [None]:
# X feats that don't really influence any Z feats
# Names of X feats st.sig. associated with fewer than 3 Z feats, using whichever
# XZres_pvals / XZres_thresh an earlier cell left in the namespace
get_int_feats(X_feats[((XZres_pvals < XZres_thresh).sum(axis=1) < 3)])

In [None]:
# X feats that really influence any Z feats
# Names of X feats st.sig. associated with more than 10 Z feats, using whichever
# XZres_pvals / XZres_thresh an earlier cell left in the namespace
get_int_feats(X_feats[((XZres_pvals < XZres_thresh).sum(axis=1) > 10)])

In [None]:
# Legacy per-(Y, D) version of the primal check (superseded by
# id_weak_XZ_proxies(dual_or_primal="primal", ...)).
# NOTE(review): this re-initializes ssYZ_primal / rmZ_zeroX_primal /
# popssX_primal, overwriting the results of the id_weak_XZ_proxies primal call,
# and the loop that follows stores index lists (not bool masks) in
# rmZ_zeroX_primal.
verbose = True
popular_X_thresh = 5
ssYZ_primal = {}
rmZ_zeroX_primal = {}
popssX_primal = {}
for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
    ssYZ_primal[Y_label] = {}
    rmZ_zeroX_primal[Y_label] = {}
    popssX_primal[Y_label] = {}
    
    for D_label in ['Female', 'Obese','Black', 'Asian']:

        print("\n" + "#"*20)

        print(Y_label, D_label)
        Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label)
        YZres_cov, YZres_pvals, YZres_thresh = get_cov(Yres, Zres, get_pvals=True)

        # We only care about Z feats with st.sig. assn with Y
        ss_YZidx = (YZres_pvals < YZres_thresh).squeeze()

        ssYZ_primal[Y_label][D_label] = ss_YZidx
        rmZ_zeroX_primal[Y_label][D_label] = []

        if verbose:
            print(f"Ignored {sum(~ss_YZidx)} Z feats due to low assn to Y={Y_label}")
#             for idx in np.where(~ss_Zidx)[0][:3]:
#                 print("\n" + "~"*10)
#                 print(f"Z feat: {Zint[idx]}")
#                 print(f"\tCov w/ Y={Y_label}: {round(YZres_cov.squeeze()[idx], 5)}")
#                 print(f"\tCov w/ D={D_label}: {round(DZres_cov.squeeze()[idx], 5)}")

        ZXres_cov, ZXres_pvals, ZXres_thresh = get_cov(Zres, Xres, get_pvals=True)

        # Filter X feats based on how many Z feats the X feat is associated w/
        popular_Xidx = ((ZXres_pvals < ZXres_thresh).sum(axis=0) > popular_X_thresh)
        popssX_primal[Y_label][D_label] = popular_Xidx.squeeze()

        if verbose:
            print(f"Keeping {popular_Xidx.sum()} X feats w/ > {popular_X_thresh} influences")

        # Only look at Z feats w/ nonzero Y association
        for idx in np.where(ss_YZidx)[0]:
#             print("\n" + "~"*10)
#             print(f"Z feat: {Zint[idx]}")
#             print(f"\tCov w/ Y={Y_label}: {round(YZres_cov.squeeze()[idx], 5)}")

            # Which X feats have st.sig. assn for each Z feat idx? 
            ss_Xidx = (ZXres_pvals[idx].squeeze() < ZXres_thresh) 
            # Filter to keep only the most "popular" X feats
            ss_popular_Xidx = ss_Xidx & popular_Xidx


#             if verbose and sum(ss_Xidx) != sum(ss_popular_Xidx):
#                 print(f"Keeping only popular X's {sum(ss_Xidx) } -> {sum(ss_popular_Xidx)} ss X feats w/ this Z")

            if sum(ss_popular_Xidx) == 0:
                print("ERROR: Z FEAT HAS NO ASSD X FEATS")
                rmZ_zeroX_primal[Y_label][D_label].append(idx)
#             else: 
#                 if sum(ss_popular_Xidx) < min_X_assns_per_Z:
#                     rmZ_lowX_primal[Y_label].append(idx)

#                 if verbose:
#                     print(f"X feats w/ highest cov. to Z feat ({sum(ss_popular_Xidx)} found):")
#                     sorted_ZX_idx = np.argsort(np.abs(ZXres_cov)[idx])[::-1]
#                     sorted_ZX_idx = np.array([x for x in sorted_ZX_idx if ss_popular_Xidx[x]])[:3]
#                     for j in sorted_ZX_idx:
#                         print(f'\t{Xint[j]}: {round(ZXres_cov[idx, j], 5)}')
        print(f"Consider removing these Z feats, which don't have any X assns: {Zint[np.array(rmZ_zeroX_primal[Y_label][D_label])]}")

In [None]:
# Drill into a single (D, Y) pair: compare the fraction of Z feats that are
# st.sig. associated with Y after residualization vs. on the raw data.
D_label = 'Black'
Y_label = 'deprs'
D, Y = load_DY_data(D_label, Y_label)
D, Y = D.reshape(-1, 1), Y.reshape(-1, 1)
# Reload the residualized data for this exact pair. Otherwise, when the cells
# are run in order, Xres/Zres/Yres/Dres still hold the last pair processed by
# the primal loop ('back', 'Asian'), and the comparison below (and the cells
# that follow) would mix different outcomes/treatments.
Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label)
YZ_cov, YZ_pvals, YZ_thresh = get_cov(Y, Z, get_pvals=True)
YZres_cov, YZres_pvals, YZres_thresh = get_cov(Yres, Zres, get_pvals=True)
# (fraction st.sig. w/ residual Y, fraction st.sig. w/ raw Y)
(YZres_pvals < YZres_thresh).mean(), (YZ_pvals < YZ_thresh).mean()

In [None]:
# Fraction of Z feats st.sig. associated with D: residualized vs. raw data
DZ_cov, DZ_pvals, DZ_thresh = get_cov(D, Z, get_pvals=True)
DZres_cov, DZres_pvals, DZres_thresh = get_cov(Dres, Zres, get_pvals=True)
frac_DZ_res = (DZres_pvals < DZres_thresh).mean()
frac_DZ_raw = (DZ_pvals < DZ_thresh).mean()
frac_DZ_res, frac_DZ_raw

In [None]:
# Boolean mask over Z feats: True where the Z feat is st.sig. associated with Y.
# Displays (Z feats without a Y assn, Z feats with a Y assn).
ss_Zidx = np.squeeze(YZres_pvals < YZres_thresh)
get_int_feats(Z_feats[~ss_Zidx]), get_int_feats(Z_feats[ss_Zidx])

In [None]:
# Report the Y and D covariances of every Z feat dropped for lacking a Y assn
YZres_cov_flat = YZres_cov.squeeze()
DZres_cov_flat = DZres_cov.squeeze()
for z_i in np.nonzero(~ss_Zidx)[0]:
    print("\n" + "#"*20)
    print(f"Z feat: {Zint[z_i]}")
    print(f"\tCov w/ Y={Y_label}: {round(YZres_cov_flat[z_i], 5)}")
    print(f"\tCov w/ D={D_label}: {round(DZres_cov_flat[z_i], 5)}")


In [None]:
# Significance mask of residual (X feat, Z feat) pairs, restricted to Z feats
# that have a st.sig. Y assn
XZ_ss_mask = XZres_pvals[:, ss_Zidx] < XZres_thresh
plt.imshow(XZ_ss_mask)
plt.show()
# Row sums: for each X feat, the # of retained Z feats it is st.sig. associated with
plt.hist(XZ_ss_mask.sum(axis=1), bins=20)
plt.title("X features: # st.sig. correlations with retained Z feats\nAfter removing Z feats w/ no Y assn")

In [None]:
# Histogram of the # of st.sig. Z assns per X feat (over all Z feats)
ss_Z_count_per_X = np.sum(XZres_pvals < XZres_thresh, axis=1)
plt.hist(ss_Z_count_per_X, bins=20)

In [None]:

keep_Xidx_thresh = 20  # keep X feats st.sig. associated w/ more than this many Z feats
keep_Zidx_thresh = 10  # flag Z feats w/ fewer than this many kept X assns as "low"
zero_Z = []  # indices of Z feats with no kept X assn
low_Z = []   # indices of Z feats with some, but fewer than keep_Zidx_thresh, kept X assns
# Row sums of the (X feat, Z feat) significance mask: # of st.sig. Z feats per X feat
keep_Xidx = ((XZres_pvals < XZres_thresh).sum(axis=1) > keep_Xidx_thresh)
print(f"Only {keep_Xidx.sum()} X feats kept!")
# Only Z feats w/ a st.sig. Y assn (ss_Zidx) are examined
for idx in np.where(ss_Zidx)[0]:
    print("\n" + "#"*20)
    print(f"Z feat: {Zint[idx]}")
    print(f"\tCov w/ Y={Y_label}: {round(YZres_cov.squeeze()[idx], 5)}")
    print(f"\tCov w/ D={D_label}: {round(DZres_cov.squeeze()[idx], 5)}")
    # X feats st.sig. associated w/ this Z feat, then restricted to the kept X feats
    ss_Xidx = (XZres_pvals[:, idx].squeeze() < XZres_thresh)
    ss_fltr_Xidx = ss_Xidx & keep_Xidx

    if sum(ss_Xidx) != sum(ss_fltr_Xidx):
        print(f"Rm noninfluential X's rd ss # from {sum(ss_Xidx) } -> {sum(ss_fltr_Xidx)}")

    if sum(ss_fltr_Xidx) == 0:
        print("ERROR: No ss X found!!!")
        zero_Z.append(idx)
    else:
        if sum(ss_fltr_Xidx) < keep_Zidx_thresh:
            low_Z.append(idx)
        # Show up to 10 kept X feats with the largest |cov| to this Z feat
        print(f"X feats w/ highest cov. to Z feat ({sum(ss_fltr_Xidx)} found):")
        sorted_XZ_idx = np.argsort(np.abs(XZres_cov)[:, idx])[::-1]
        sorted_XZ_idx = np.array([x for x in sorted_XZ_idx if ss_fltr_Xidx[x]])[:10]
        for j in sorted_XZ_idx:
            print(f'\t{Xint[j]}: {round(XZres_cov[j, idx], 5)}')
# dtype=int: np.array([]) defaults to float64, which numpy rejects as an index,
# so an empty zero_Z / low_Z list would otherwise crash these summary prints.
print(f"Consider removing these Z feats, which don't have any X assns: {Zint[np.array(zero_Z, dtype=int)]}")
print(f"\twhich have low X assns: {Zint[np.array(low_Z, dtype=int)]}")

In [None]:
# Inspect the dual "popular Z" mask for the current D_label (presumably a 2-D
# X-by-Z boolean mask -- TODO confirm against where popssZ_dual is built).
# Displays (# columns with any True, # popular Z feats, # all-False rows).
popss_mask = popssZ_dual[D_label]
plt.imshow(popss_mask)
(popss_mask.sum(axis=0) > 0).sum(), popular_Zidx.sum(), (popss_mask.sum(axis=1) == 0).sum()