- get feature attributions for all genes instead of svd components
- load in normalized svd components svd_comp_norm
- get_attr_all_features() creates attr_all_22085_genes

In [1]:
import numpy as np
import pandas as pd
import os
import pickle

import anndata as ad

In [2]:
# Move to the project root so the relative paths used below
# ('2.preprocess_to_feature/...', '4.model/...') resolve.
# The root is resolved only once: a plain os.chdir('../..') climbs two more
# levels every time the cell is re-run, which silently breaks those paths.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('../..')
os.chdir(PROJECT_ROOT)

In [3]:
# Absolute base directory (machine-specific); must end with a slash because
# sub-paths are appended by plain string concatenation.
lrz_path = '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/'

# Directory prefix under which the attribution pickles are written.
path_for_attr_all_genes = f'{lrz_path}large_preprocessed_files/attr_all_genes/'

### svd contributions:

Load in svd components from TruncatedSVD fitted in 2.preprocess_to_feature/cite/make_base_feature.py

In [4]:
# normalized svd components
# Read from a comma-separated text file into a 2-D float array. The printed
# shape is (number of svd components, number of genes); get_attr_all_features
# indexes the rows as components and the columns as genes.
svd_comp_norm = np.loadtxt('2.preprocess_to_feature/cite/svd_comp_norm.txt', delimiter=',')
print(svd_comp_norm.shape)

(128, 22001)


In [5]:
# Show row 2 of svd_comp_norm: one weight per gene for this svd component.
# According to the shap plots, base_svd_2 is an important feature
# (presumably base_svd_2 corresponds to row index 2 - TODO confirm the indexing).
svd_comp_norm[2]

array([-1.80979350e-05, -1.49000311e-06,  3.09174447e-06, ...,
        1.26464845e-04,  2.65770446e-04,  1.75062174e-04])

=> Each entry of this row is the weight of one gene in svd component 2, i.e. the component is a weighted sum of the genes: \
x_2 = -0.00001809 * geneA - 0.00000149 * geneB + 0.0000030917 * geneC + ... + 0.0001264 * geneX + 0.00026577 * geneY + 0.000175 * geneZ

### column names:

In [6]:
# all_22001_genes: 22001 features, handselected_gene_ids: 84 handselected genes
# Both files are read as arrays of strings. Together they provide the column
# names of the attribution DataFrames built in get_attr_all_features
# (all_22001_genes first, then handselected_gene_ids).
all_22001_genes = np.loadtxt('2.preprocess_to_feature/cite/all_22001_genes_names.txt', dtype=str)
handselected_gene_ids = np.loadtxt('2.preprocess_to_feature/cite/handselected_84_gene_ids.txt', dtype=str)

### propagate shap values back through svd to get original genes -> store in attr_all_22085_genes
#### get_attr_all_features() processes all classes (shap_values.shape[0], 140 here); the preview cells below only display the first class (attr_all_22085_genes[0])

steps:\
for each cell and each gene A: multiply SHAP(svd_n) * contribution of gene A to component n -> then sum over all components n \
each dot in summary_plot is the attribution for one cell -> loop over all cells

dimensions:
- features: 212 (128 svd and 84 handselected genes)
- cells: number_of_samples_per_cell_type * 7  (there are 7 unique cell types)
- predicted "classes": 140

inputs:
- contribution of gene A to component n: svd_comp_norm
- SHAP(svd_n): the first 128 feature columns of shap_values (= the 128 svd features)

In [7]:
def get_attr_all_features(xtest, shap_values, svd_comp_norm, gene_names=None, handselected_ids=None, n_svd=None):
    """Propagate SHAP values of the svd features back onto the original genes.

    For every class and every cell, the attribution of a gene is the sum over
    all svd components n of ``SHAP(svd_n) * weight of the gene in component n``.
    The SHAP values of the handselected genes are appended unchanged as the
    last columns.

    Parameters
    ----------
    xtest : sized object
        The cells the SHAP values were computed for. Only ``len(xtest)`` is
        used; it must equal ``shap_values.shape[1]``.
    shap_values : array-like, shape (n_classes, n_cells, n_features)
        SHAP values. The svd features are the first columns, the handselected
        genes are the last ``len(handselected_ids)`` columns.
    svd_comp_norm : array-like, shape (>= n_svd, n_genes)
        Weight of every gene in every svd component. Only the first ``n_svd``
        rows are used.
        NOTE(review): for a model trained on fewer components this assumes
        its components are the leading rows of this matrix - confirm.
    gene_names : sequence of str, optional
        Column names of the genes behind the svd. Defaults to the
        notebook-level ``all_22001_genes``.
    handselected_ids : sequence of str, optional
        Column names of the handselected genes. Defaults to the
        notebook-level ``handselected_gene_ids``.
    n_svd : int, optional
        Number of leading svd columns in ``shap_values``. Defaults to
        ``n_features - len(handselected_ids)`` instead of a hard-coded 128,
        so that handselected columns are never mixed into the svd
        back-projection when a model uses fewer svd components.

    Returns
    -------
    dict[int, pandas.DataFrame]
        One DataFrame per class index with shape
        ``(n_cells, n_genes + n_handselected)``. The result is large;
        how best to store it is still an open question.

    Raises
    ------
    ValueError
        If the shapes of the inputs are inconsistent with each other.
    """
    # Fall back to the notebook-level name arrays to stay compatible with the
    # original three-argument call.
    if gene_names is None:
        gene_names = all_22001_genes
    if handselected_ids is None:
        handselected_ids = handselected_gene_ids

    shap_values = np.asarray(shap_values, dtype=float)
    svd_comp_norm = np.asarray(svd_comp_norm, dtype=float)
    if shap_values.ndim != 3:
        raise ValueError(f"shap_values must be 3-D (classes, cells, features), got shape {shap_values.shape}")
    if svd_comp_norm.ndim != 2:
        raise ValueError(f"svd_comp_norm must be 2-D (components, genes), got shape {svd_comp_norm.shape}")

    n_classes, n_cells, n_features = shap_values.shape
    n_handselected = len(handselected_ids)
    if n_svd is None:
        n_svd = n_features - n_handselected

    if len(xtest) != n_cells:
        raise ValueError(f"xtest has {len(xtest)} cells but shap_values has {n_cells}")
    if n_svd < 0 or n_svd + n_handselected > n_features:
        raise ValueError(
            f"{n_svd} svd columns and {n_handselected} handselected columns do not fit into {n_features} features"
        )
    if n_svd > svd_comp_norm.shape[0]:
        raise ValueError(f"need {n_svd} svd components but svd_comp_norm only has {svd_comp_norm.shape[0]}")
    if len(gene_names) != svd_comp_norm.shape[1]:
        raise ValueError(
            f"{len(gene_names)} gene names given but svd_comp_norm has {svd_comp_norm.shape[1]} gene columns"
        )

    # Column order: genes reached through the svd first, handselected genes last.
    cols = list(gene_names) + list(handselected_ids)
    components = svd_comp_norm[:n_svd]
    # Explicit start index instead of a negative slice, so zero handselected
    # genes yields an empty block rather than the whole array.
    first_handselected = n_features - n_handselected

    attr_all_22085_genes = {}
    for class_idx in range(n_classes):
        # (cells, n_svd) @ (n_svd, genes): for every cell and gene, the sum
        # over components of SHAP(component) * gene weight. Working one class
        # at a time avoids holding a second classes x cells x genes array.
        attr_genes_only = shap_values[class_idx, :, :n_svd] @ components
        attr_handselected = shap_values[class_idx, :, first_handselected:]
        attr_all_22085_genes[class_idx] = pd.DataFrame(
            np.hstack((attr_genes_only, attr_handselected)), columns=cols
        )

    return attr_all_22085_genes

Get attr_all_22085_genes for data corresponding to model #16.\
First, use 5 samples per cell type:

In [8]:
# # testing on 5 samples:
# shap_values = np.load('4.model/pred/shap_values_16_restructured.npy', allow_pickle=True).astype(float)

# xtest = ad.read_h5ad('4.model/pred/X_test_shap_16_5_samples.h5ad')

# attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)   # first 22001 columns are sorted alphabetically, then next 84 are sorted
# attr_all_22085_genes[0].head(2)

In [9]:
# # with open('4.model/pred/attr_all_22085_genes_16_5_samples.pkl', 'wb') as f:
# with open(path_for_attr_all_genes + 'attr_all_22085_genes_16_5_samples.pkl', 'wb') as f:
#     pickle.dump(attr_all_22085_genes, f)

same for 50 samples:

In [10]:
# Model #16, 50 samples per cell type.
# The attribution call at the bottom is commented out, so this cell currently
# only loads the inputs; they are overwritten by the model #17 cell below.
# NOTE: np.load(..., allow_pickle=True) unpickles the file - only use it on trusted files.
# shap_values = np.load('4.model/pred/shap_values_16_50_samples.npy', allow_pickle=True).astype(float)
shap_values = np.load('4.model/pred/shap_values_16_50_samples_med.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/X_test_shap_16_50_samples.h5ad')

# attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)     # first 22001 columns are sorted alphabetically, then next 84 are sorted
# attr_all_22085_genes[0].head(2)

In [11]:
# with open(path_for_attr_all_genes + 'attr_all_22085_genes_16_50_samples_med.pkl', 'wb') as f:
#     pickle.dump(attr_all_22085_genes, f)

Get attr_all_22085_genes for data corresponding to model #17.

In [12]:
# Model #17, 50 samples per cell type: load the SHAP values and the matching
# test cells, then back-propagate the svd attributions to gene level.
# NOTE: np.load(..., allow_pickle=True) unpickles the file - only use it on trusted files.
shap_values = np.load('4.model/pred/shap_values_17_50_samples_med.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/X_test_shap_17_50_samples.h5ad')

# Result: dict with one DataFrame per class.
# first 22001 columns are sorted alphabetically, then next 84 are sorted
attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)
# Preview the first two cells of class 0 only.
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,0.000413,0.000556,0.000846,0.0011,4.6e-05,0.000108,-0.006612,-0.001009,0.000792,0.000163,...,0.0,0.0,0.0,0.02081,-0.02712,0.0,0.0,0.0,0.0,0.045797
1,1.7e-05,-3.8e-05,0.000252,2.3e-05,-2e-06,-0.000174,-0.000361,-0.00032,-0.000231,0.000216,...,0.0,0.010838,0.0,0.0,0.0,0.026258,0.0,0.0,0.0,-0.013764


In [13]:
# Pickle the dict of per-class attribution DataFrames (model #17) under path_for_attr_all_genes.
pickle_path = path_for_attr_all_genes + 'attr_all_22085_genes_17_50_samples_med.pkl'
with open(pickle_path, 'wb') as pickle_file:
    pickle.dump(attr_all_22085_genes, pickle_file)

### same for shap values from private data

model #16:

In [8]:
# Private data, model #16: load the SHAP values and the matching input cells,
# then back-propagate the svd attributions to gene level.
# NOTE: np.load(..., allow_pickle=True) unpickles the file - only use it on trusted files.
shap_values = np.load('4.model/pred/shap_values_16_50_samples_p_ct_distr.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/private_test_input_128_svd_50_samples.h5ad')

# Result: dict with one DataFrame per class.
# first 22001 columns are sorted alphabetically, then next 84 are sorted
attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)
# Preview the first two cells of class 0 only.
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,5.6e-05,-0.000113,0.000543,0.000505,3e-06,-6.4e-05,-0.002281,0.000939,-5.8e-05,0.000303,...,-0.018111,0.0,0.0,0.024404,0.0,0.0,-0.032505,0.008612,0.0,-0.019831
1,6.3e-05,0.000227,-0.000275,0.000597,2.3e-05,-0.000329,0.001407,0.002708,0.00028,0.000883,...,-0.009935,0.004501,0.0,0.012049,0.0,0.0,0.000894,0.0,0.0,0.002347


In [10]:
# Pickle the dict of per-class attribution DataFrames (private data, model #16) under path_for_attr_all_genes.
pickle_path = path_for_attr_all_genes + 'attr_all_22085_genes_16_50_samples_p_ct_distr.pkl'
with open(pickle_path, 'wb') as pickle_file:
    pickle.dump(attr_all_22085_genes, pickle_file)

model #17:

In [11]:
# Private data, model #17: load the SHAP values and the matching input cells,
# then back-propagate the svd attributions to gene level.
# NOTE: np.load(..., allow_pickle=True) unpickles the file - only use it on trusted files.
shap_values = np.load('4.model/pred/shap_values_17_50_samples_p_ct_distr.npy', allow_pickle=True).astype(float)

# NOTE(review): the input file name suggests this model uses 64 svd components.
# Confirm that shap_values.shape[2] matches the svd / handselected column split
# that get_attr_all_features assumes.
xtest = ad.read_h5ad('4.model/pred/private_test_input_64_svd_50_samples.h5ad')

# Result: dict with one DataFrame per class.
# first 22001 columns are sorted alphabetically, then next 84 are sorted
attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)
# Preview the first two cells of class 0 only.
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,0.000136,0.000431,0.000498,0.000939,7e-06,-0.000766,-0.003348,-4.8e-05,-0.00022,-0.000737,...,-0.006766,-0.020778,0.0,0.0,0.0,0.0,-0.011902,0.0,-0.024068,0.0
1,-0.000138,0.000377,0.001151,0.001888,6.6e-05,-0.002372,0.006538,0.001002,0.0011,0.002942,...,0.0,-0.014653,0.0,0.011124,-0.026157,0.012642,-0.013345,-0.009018,0.0,0.0


In [12]:
# Pickle the dict of per-class attribution DataFrames (private data, model #17) under path_for_attr_all_genes.
pickle_path = path_for_attr_all_genes + 'attr_all_22085_genes_17_50_samples_p_ct_distr.pkl'
with open(pickle_path, 'wb') as pickle_file:
    pickle.dump(attr_all_22085_genes, pickle_file)