# Testing PCA for Causal Inference
I need to get a baseline for what causal inference I can do using genotype PCs. So to start I'll just run a small simulation under three scenarios

1. Null case (no association between variables)
1. Full causal mediation
1. Independent association of genotype PC with mediator and response

First I'll just import some of the code I've written to generate datasets and run inference, then I'll construct the PCs and run a causal inference test on varying numbers of PCs for each set.

In [21]:
import cit
import joblib 
import data_model as dm
import multiprocessing as mp
import numpy as np
import scipy as sp
import pandas as pd

In [25]:
import warnings
warnings.filterwarnings(action='once')
def to_df(results):
    '''Merge result dictionary to a pd.DataFrame'''
    out_rows = []
    for res in results:
        cur_row = {}
        for j in range(1,5):
            cur_test = 'test{}'.format(j)
            cur_p = 'p{}'.format(j)
            for key in res[cur_test]:
                if key in ['rss', 'r2']: #single value for test
                    cur_key = '{}_{}'.format(cur_test,key)
                    cur_row[cur_key] = res[cur_test][key]
                else:
                    for k in range(len(res[cur_test][key])):
                        cur_key = '{}_{}{}'.format(cur_test,key,k)
                        cur_row[cur_key] = res[cur_test][key][k]
            cur_row[cur_p] = res[cur_p]
        cur_row['omni_p'] = res['omni_p']
        out_rows.append(cur_row)
    return pd.DataFrame(out_rows)

## Null Case

In [23]:
NUM_SAMPLES = 100
null_data = [dm.generate_null(n=500) for i in range(NUM_SAMPLES)]
SVDs = [np.linalg.svd(genotype, full_matrices=False, compute_uv=True) for (_,_,genotype) in null_data]
PCs = [U@np.diag(D) for (U,D,Vh) in SVDs]

In [28]:
import importlib as imp
imp.reload(cit)
null_results = []
for i in range(1,6):
    args = [(trait,gene_exp, Z[:, 0:i], 100) for ((trait, gene_exp, _), Z) in zip(null_data, PCs)]
    with mp.Pool(processes=4) as pool:
        %time null_result = pool.starmap(cit.cit, args)
    null_results.append(null_result)

(array([62.45983416]), 1879065.2317911393, 1879065.2317911389, array([61.30359258]), array([1.01886091])) (array([11.73429944, 99.46183278]), 49418.47117749621, 1879065.2317911389, array([12.97441424, 16.34622547]), array([0.90441844, 0.71785988, 7.66599793, 6.08469723])) 18437.723083601035
(array([-0.2041362 ,  0.10006083]), 4.623143275211819, 18818.141596747635, array([0.13727581, 0.00156855]), array([  -1.48705145, -130.14343755,    0.72890353,   63.79201721])) (array([-0.1965866 ,  0.0993324 ,  0.07440783]), 4.596213110667739, 18818.141596747635, array([0.16870846, 0.00964396, 0.97207368]), array([ -1.16524446, -20.38442603,  -0.20223426,   0.58878135,
        10.29995879,   0.10218608,   0.44104383,   7.71548419,
         0.07654546])) 2.9120259344248804
(array([2.05496466, 9.99146585]), 461.63898519824994, 1879065.2317911389, array([1.34903809, 0.15662565]), array([ 1.52328142, 13.12023109,  7.40636304, 63.79201721])) (array([2.0701728 , 9.97320131, 0.18669177]), 461.470378974890

KeyboardInterrupt: 

In [None]:
import pickle
with open('3_5_2019_PCA_null.dat','wb') as f:
    pickle.dump(PCs, f)
    pickle.dump(null_data, f)
    pickle.dump(null_results,f)

In [27]:
test = to_df(null_results)
test.filter(regex = 'p')

Unnamed: 0,omni_p,p1,p2,p3,p4,test1_p0,test1_p1,test2_p0,test2_p1,test2_p2,test3_p0,test3_p1,test3_p2,test4_p0,test4_p1,test4_p2
0,1.0,1.110223e-16,0.010995,0.709518,1.0,0.191287,3.811862e-11,0.854053,0.0,0.454447,0.116443,0.0,0.506666,0.116443,0.0,0.506666
1,1.0,1.110223e-16,0.088546,0.670197,1.0,0.183105,1.164912e-09,0.877761,0.0,0.469508,0.093392,0.0,0.492379,0.093392,0.0,0.492379
2,1.0,1.110223e-16,0.331121,0.184564,1.0,0.173182,8.445725e-10,0.868151,0.0,0.482606,0.099658,0.0,0.476249,0.099658,0.0,0.476249
3,1.0,1.110223e-16,0.918301,0.032265,1.0,0.175917,9.554985e-10,0.853707,0.0,0.498164,0.114547,0.0,0.461656,0.114547,0.0,0.461656
4,1.0,1.110223e-16,0.005545,0.557278,1.0,0.170181,5.372109e-10,0.852141,0.0,0.450304,0.116436,0.0,0.510503,0.116436,0.0,0.510503
5,1.0,1.110223e-16,0.115837,0.537427,1.0,0.167269,1.420746e-09,0.859509,0.0,0.471849,0.108959,0.0,0.488963,0.108959,0.0,0.488963
6,1.0,1.110223e-16,0.004645,0.517948,1.0,0.186132,5.313568e-09,0.846459,0.0,0.449278,0.123206,0.0,0.51157,0.123206,0.0,0.51157
7,1.0,1.110223e-16,0.047777,0.909885,1.0,0.175768,2.497311e-09,0.860579,0.0,0.464556,0.109838,0.0,0.497975,0.109838,0.0,0.497975
8,1.0,1.110223e-16,0.01126,0.943617,1.0,0.163724,4.054744e-10,0.855244,0.0,0.454596,0.109167,0.0,0.501266,0.109167,0.0,0.501266
9,1.0,1.110223e-16,0.016817,0.89908,1.0,0.172542,2.193259e-09,0.840494,0.0,0.457178,0.126928,0.0,0.50227,0.126928,0.0,0.50227
