- get feature attributions for all genes instead of svd components
- load in normalized svd components svd_comp_norm
- get_attr_all_features() creates attr_all_22085_genes

In [1]:
import numpy as np
import pandas as pd
import os
import pickle

import anndata as ad

In [2]:
# Move from the notebook's folder to the repository root so that the relative
# paths used below ('2.preprocess_to_feature/...', '4.model/...') resolve.
# Guarded so that re-running this cell does not climb two more levels.
if not os.path.isdir('2.preprocess_to_feature'):
    os.chdir('../..')

In [3]:
# Base directory on the LRZ file system (absolute, machine-specific; keeps its trailing slash).
lrz_path = '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/'

# Folder (with trailing slash) that receives the pickled attr_all_22085_genes dictionaries.
path_for_attr_all_genes = f'{lrz_path}large_preprocessed_files/attr_all_genes/'

### svd contributions:

Load in svd components from TruncatedSVD fitted in 2.preprocess_to_feature/cite/make_base_feature.py

In [4]:
# Normalized SVD components, read from a comma-separated text file.
# One row per SVD component, one column per gene (see the printed shape).
svd_comp_norm = np.loadtxt('2.preprocess_to_feature/cite/svd_comp_norm.txt', delimiter=',')
print(svd_comp_norm.shape)

(128, 22001)


In [5]:
# Gene weights of SVD component 2, i.e. feature base_svd_2, which the SHAP plots
# flagged as an important feature.
svd_comp_norm[2]

array([-1.80979350e-05, -1.49000311e-06,  3.09174447e-06, ...,
        1.26464845e-04,  2.65770446e-04,  1.75062174e-04])

=> This means that the contribution can be computed as follows: \
contribution x_2 = -0.00001809 * geneA - 0.00000149 * geneB + 0.0000030917 * geneC + ... + 0.0001264 * geneX + 0.00026577 * geneY + 0.000175 * geneZ

### column names:

In [6]:
# Column names for the attribution tables built below:
#   all_22001_genes       - names of the 22001 genes that went into the SVD
#   handselected_gene_ids - names of the 84 handselected genes used as direct model features
# Both are read as string arrays; get_attr_all_features() concatenates them in this order.
all_22001_genes = np.loadtxt('2.preprocess_to_feature/cite/all_22001_genes_names.txt', dtype=str)
handselected_gene_ids = np.loadtxt('2.preprocess_to_feature/cite/handselected_84_gene_ids.txt', dtype=str)

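### propagate SHAP values back through the SVD to the original genes -> stored in attr_all_22085_genes
#### attributions are computed for every class in shap_values (dict keyed by class index); only the first class (attr_all_22085_genes[0]) is inspected below -> TODO consider all 140 classes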

steps:\
multiply SHAP(svd_n) * contribution of gene A to component n -> then sum \
each dot in summary_plot is attribution for one cell -> loop over all cells

dimensions:
- features: 212 (128 SVD components + 84 handselected genes)
- cells: number_of_samples_per_cell_type * 7 (there are 7 unique cell types)
- predicted "classes": 140

inputs:
- contribution of gene A to component n: `svd_comp_norm`
- SHAP(svd_n) for the 128 SVD features: the first 128 columns of `shap_values`

In [7]:
def get_attr_all_features(xtest, shap_values, svd_comp_norm, gene_names=None, handselected_names=None, n_svd=None):
    """Map SHAP values of SVD features back to genes and append the handselected genes.

    For every predicted class, the attribution of a gene is the sum over SVD
    components of SHAP(component) * weight of that gene in the component, which
    is the matrix product ``shap_values[c, :, :n_svd] @ svd_comp_norm[:n_svd]``.
    The SHAP values of the handselected genes (the trailing feature columns)
    are appended unchanged to the right.

    Parameters
    ----------
    xtest : sized object
        Test input the SHAP values were computed for. Only ``len(xtest)`` is
        used, to check that it matches the number of cells in ``shap_values``.
    shap_values : array, shape (n_classes, n_cells, n_features)
        SHAP values with the SVD features in the leading columns and the
        handselected genes in the trailing columns.
    svd_comp_norm : array, shape (n_components, n_genes)
        Normalized SVD components (row = component, column = gene).
    gene_names : sequence of str, optional
        Names of the ``n_genes`` columns of ``svd_comp_norm``. Defaults to the
        notebook-level ``all_22001_genes``.
    handselected_names : sequence of str, optional
        Names of the handselected genes, in the order of the trailing SHAP
        columns. Defaults to the notebook-level ``handselected_gene_ids``.
    n_svd : int, optional
        Number of leading SHAP columns that are SVD features. By default all
        non-handselected columns, capped at the number of available components
        (128 SVD + 84 handselected columns give the previous behaviour).
        NOTE(review): this assumes the model's SVD inputs are the leading rows
        of ``svd_comp_norm`` - confirm for models trained on fewer components.

    Returns
    -------
    dict[int, pandas.DataFrame]
        One DataFrame per class index, shape (n_cells, n_genes + n_handselected),
        with columns ``list(gene_names) + list(handselected_names)``.

    Raises
    ------
    ValueError
        If the shapes of the inputs or the number of names are inconsistent.
    """
    # TODO how to save attr_all_22085_genes? Huge files!
    shap_values = np.asarray(shap_values)
    svd_comp_norm = np.asarray(svd_comp_norm)

    # fall back to the name arrays loaded at notebook level
    if gene_names is None:
        gene_names = all_22001_genes
    if handselected_names is None:
        handselected_names = handselected_gene_ids
    cols = list(gene_names) + list(handselected_names)
    n_handselected = len(handselected_names)

    if shap_values.ndim != 3:
        raise ValueError(f"shap_values must be 3-D (classes, cells, features), got shape {shap_values.shape}")
    if svd_comp_norm.ndim != 2:
        raise ValueError(f"svd_comp_norm must be 2-D (components, genes), got shape {svd_comp_norm.shape}")

    n_classes, n_cells, n_features = shap_values.shape
    n_components, n_genes = svd_comp_norm.shape

    if len(xtest) != n_cells:
        raise ValueError(f"xtest has {len(xtest)} cells but shap_values has {n_cells}")
    if n_handselected > n_features:
        raise ValueError(f"{n_handselected} handselected names but only {n_features} SHAP features")
    if len(cols) != n_genes + n_handselected:
        raise ValueError(f"expected {n_genes + n_handselected} column names, got {len(cols)}")

    if n_svd is None:
        n_svd = min(n_components, n_features - n_handselected)
    if not 0 <= n_svd <= min(n_components, n_features - n_handselected):
        raise ValueError(f"n_svd={n_svd} is incompatible with {n_components} components and {n_features} SHAP features")

    components = svd_comp_norm[:n_svd]
    # explicit start index: a negative slice would select everything when there are 0 handselected genes
    handselected_start = n_features - n_handselected

    # dict of (n_cells x n_columns) dataframes with column names -> one entry per class
    attr_all_22085_genes = {}
    for class_idx in range(n_classes):
        class_shap = shap_values[class_idx]
        # backpropagation of the SVD feature attributions for all cells at once
        attr_genes_only = class_shap[:, :n_svd] @ components
        # hstack: SVD-backpropagated genes on the left, handselected genes on the right (same order as cols)
        attr_all_22085_genes[class_idx] = pd.DataFrame(
            np.hstack((attr_genes_only, class_shap[:, handselected_start:])),
            columns=cols,
        )

    return attr_all_22085_genes

Get attr_all_22085_genes for data corresponding to model #16.\
First, use 5 samples per cell type:

In [13]:
# Model #16, small test run with 5 samples per cell type.
# SHAP values are loaded and cast to float; xtest is the matching test input (AnnData).
# NOTE(review): allow_pickle=True can execute code stored in the file - only load trusted files.
shap_values = np.load('4.model/pred/shap_values_16_restructured.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/X_test_shap_16_5_samples.h5ad')

attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)   # columns: 22001 SVD-backpropagated genes first, then the 84 handselected genes (each group sorted, per the original note)
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,-0.0002805793,-0.00035,0.001712,-0.000411,-6e-05,-0.001241,0.001953,-0.006214,-0.000883,-0.000264,...,-0.005195,0.0,0.0,0.0,-0.022825,0.0,0.024688,0.0,0.017971,0.0
1,4.067557e-07,-2.7e-05,0.000132,-4.8e-05,9e-06,0.000188,-0.000248,0.001903,-0.000227,-0.000268,...,0.0,0.0,0.0,-0.017512,0.004803,0.0,0.0,0.0,0.0,0.0


In [14]:
# Pickle the dict of per-class attribution DataFrames to the large-file storage
# (path_for_attr_all_genes) rather than to 4.model/pred/ inside the repository.
with open(path_for_attr_all_genes + 'attr_all_22085_genes_16_5_samples.pkl', 'wb') as f:
    pickle.dump(attr_all_22085_genes, f)

same for 50 samples:

In [15]:
# Model #16, 50 samples per cell type; uses the "restructured" SHAP value file.
# This overwrites shap_values, xtest and attr_all_22085_genes from the 5-sample run.
shap_values = np.load('4.model/pred/shap_values_16_50_samples_restructured.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/X_test_shap_16_50_samples.h5ad')

attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)     # columns: 22001 SVD-backpropagated genes first, then the 84 handselected genes (each group sorted, per the original note)
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,-3.9e-05,-0.000108,0.001539,0.000235,-1.5e-05,-0.001444,0.000596,-0.005306,-0.000497,0.000329,...,-0.01017,-0.002227,0.0,0.0,-0.013633,0.006514,0.020438,0.0,0.0,0.0
1,6.4e-05,0.000346,-0.000896,-0.000224,1.6e-05,0.000908,-0.001883,0.005648,9.6e-05,-4.1e-05,...,0.0,0.0,0.0,0.0,0.0,0.020803,0.037623,0.0,0.0,0.0


In [16]:
# Pickle the model #16 / 50-sample attributions (dict of per-class DataFrames) to the large-file storage.
with open(path_for_attr_all_genes + 'attr_all_22085_genes_16_50_samples.pkl', 'wb') as f:
    pickle.dump(attr_all_22085_genes, f)

Get attr_all_22085_genes for data corresponding to model #17.

In [9]:
# Model #17, 50 samples per cell type; uses the "restructured" SHAP value file.
# This overwrites shap_values, xtest and attr_all_22085_genes from the model #16 runs.
shap_values = np.load('4.model/pred/shap_values_17_50_samples_restructured.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/X_test_shap_17_50_samples.h5ad')

attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)     # columns: 22001 SVD-backpropagated genes first, then the 84 handselected genes (each group sorted, per the original note)
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,0.000224,0.000277,0.000304,0.001039,2.1e-05,-0.001013,-0.004067,-0.003513,-0.000175,0.000285,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,-2.1e-05,0.000105,-8e-06,-8.6e-05,-8e-06,-0.000221,0.000356,-6e-06,-0.00015,-0.000447,...,0.0,0.006332,0.0,0.002259,-0.001898,0.0,0.0,0.0,0.0,0.0


In [10]:
# Pickle the model #17 / 50-sample attributions (dict of per-class DataFrames) to the large-file storage.
with open(path_for_attr_all_genes + 'attr_all_22085_genes_17_50_samples.pkl', 'wb') as f:
    pickle.dump(attr_all_22085_genes, f)

### same for shap values from private data

model #16:

In [9]:
# Model #16 on the private test data, 50 samples per cell type ('_p' files).
shap_values = np.load('4.model/pred/shap_values_16_50_samples_p.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/private_test_input_128_svd_50_samples.h5ad')

attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)     # columns: 22001 SVD-backpropagated genes first, then the 84 handselected genes (each group sorted, per the original note)
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,0.000298,0.000451,-0.000898,0.000341,3.2e-05,0.001201,-0.002951,0.00755,0.000339,-9e-05,...,0.0,0.0,-0.000284,0.0,0.0,0.009688,0.0,0.0,-0.010834,0.0
1,-1.7e-05,0.00016,-0.00039,-0.000453,-5e-06,0.000746,0.000328,0.00182,-0.000185,0.000238,...,0.0,0.0012,0.0,0.0,-0.00411,0.01252,0.0,0.0,0.0,0.0


In [10]:
# Pickle the model #16 private-data attributions (dict of per-class DataFrames) to the large-file storage.
with open(path_for_attr_all_genes + 'attr_all_22085_genes_16_50_samples_p.pkl', 'wb') as f:
    pickle.dump(attr_all_22085_genes, f)

model #17:

In [8]:
# Model #17 on the private test data, 50 samples per cell type ('_p' files).
# NOTE(review): the input file name mentions 64 SVD components - confirm that the
# SVD / handselected column split applied by get_attr_all_features matches these shap_values.
shap_values = np.load('4.model/pred/shap_values_17_50_samples_p.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/private_test_input_64_svd_50_samples.h5ad')

attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)     # columns: 22001 SVD-backpropagated genes first, then the 84 handselected genes (each group sorted, per the original note)
attr_all_22085_genes[0].head(2)

Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,9.2e-05,0.000246,0.000361,0.000903,2.4e-05,-0.000299,-0.002838,-0.001823,0.000231,0.00014,...,0.0,0.007116,0.0,0.0,0.0,0.000979,0.0,0.0,0.0,0.0
1,-7.4e-05,0.000388,-0.00024,0.001097,6.2e-05,-0.001001,0.005089,0.006846,0.001264,0.000882,...,0.001284,0.0,0.0,-0.022162,-0.019837,0.0,0.0,-0.010808,0.0,0.0


In [9]:
# Pickle the model #17 private-data attributions (dict of per-class DataFrames) to the large-file storage.
with open(path_for_attr_all_genes + 'attr_all_22085_genes_17_50_samples_p.pkl', 'wb') as f:
    pickle.dump(attr_all_22085_genes, f)