- Compute feature attributions for all genes instead of for the SVD components.
- Load the normalized SVD components `svd_comp_norm`.
- `get_attr_all_features()` creates `attr_all_22085_genes` (22001 genes back-propagated through the SVD + 84 handselected genes).

In [1]:
import numpy as np
import pandas as pd
import os
import pickle

import anndata as ad

In [2]:
# Move from this notebook's folder to the project root (two levels up) so that the
# relative data paths used below ('2.preprocess_to_feature/...', '4.model/...') resolve.
# The guard keeps the cell idempotent: re-running it must not keep climbing the tree.
if not os.path.isdir('2.preprocess_to_feature'):
    os.chdir('../..')

### svd contributions:

In [3]:
# Normalized SVD components: one row per SVD component, one column per gene
# (rows are weighted against the SVD SHAP values in get_attr_all_features).
svd_comp_norm = np.loadtxt(
    '2.preprocess_to_feature/cite/svd_comp_norm.txt',
    delimiter=',',
)
print(np.shape(svd_comp_norm))

(128, 22001)


In [4]:
# The SHAP plots flag base_svd_2 as an important feature -> look at its per-gene weights.
svd_comp_norm[2, :]

array([-1.80979350e-05, -1.49000311e-06,  3.09174447e-06, ...,
        1.26464845e-04,  2.65770446e-04,  1.75062174e-04])

=> This means that the contribution can be computed as follows, using row `svd_comp_norm[2]` as the gene weights:

$$x_2 = -0.00001809\,\text{geneA} - 0.00000149\,\text{geneB} + 0.0000030917\,\text{geneC} + \dots + 0.0001264\,\text{geneX} + 0.00026577\,\text{geneY} + 0.000175\,\text{geneZ}$$

### column names:

In [5]:
# all_22001_genes: names of the 22001 genes that went through the SVD (one per column
# of svd_comp_norm); handselected_gene_ids: the 84 handselected genes that the model
# receives as separate features.
# ndmin=1 keeps the result 1-D even if a file contains a single entry.
all_22001_genes = np.loadtxt('2.preprocess_to_feature/cite/all_22001_genes_names.txt', dtype=str, ndmin=1)
handselected_gene_ids = np.loadtxt('2.preprocess_to_feature/cite/handselected_84_gene_ids.txt', dtype=str, ndmin=1)

# These names become the column labels of the back-propagated attributions, so they
# must line up one-to-one with the columns of svd_comp_norm.
assert len(all_22001_genes) == svd_comp_norm.shape[1], (len(all_22001_genes), svd_comp_norm.shape)

### Propagate SHAP values back through the SVD to the original genes -> store in `attr_all_22085_genes`
#### Currently only the first class (`shap_values[0]`) is used -> TODO: consider all 140 classes

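Steps:\
multiply SHAP(svd_n) by the contribution of gene $g$ to component $n$, then sum over the components:

$$\text{attr}(g) = \sum_{n=0}^{127} \text{SHAP}(\text{svd}_n)\cdot \text{svd\_comp\_norm}[n, g]$$

Each dot in the summary_plot is the attribution for one cell -> loop over all cells.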

Dimensions:\
212 features: 128 SVD components and 84 handselected genes\
cells: number_of_samples_per_cell_type * 7 (there are 7 unique cell types)\
predicted "classes": 140

Contribution of gene $g$ to component $n$: `svd_comp_norm[n, g]`\
SHAP(svd_n) for the 128 SVD components (= the first 128 feature columns)

In [6]:
def get_attr_all_features(xtest, shap_values, svd_comp_norm,
                          gene_names=None, handselected_names=None, pred_idx=0):
    """Back-propagate SHAP values of the SVD features to genes and append the handselected genes.

    The model's feature axis is laid out as ``[svd_0, ..., svd_{n_svd-1}, handselected genes]``
    with ``n_svd = svd_comp_norm.shape[0]`` (128 SVD + 84 handselected = 212 features here).
    For every cell, the attribution of gene ``g`` is

        attr[g] = sum_n SHAP(svd_n) * svd_comp_norm[n, g]

    which is a single matrix product of the SVD part of the SHAP values with ``svd_comp_norm``.

    Parameters
    ----------
    xtest : sized container (e.g. AnnData)
        Test cells the SHAP values were computed for. Only ``len(xtest)`` is used, to check
        that it matches the cell axis of ``shap_values``.
    shap_values : np.ndarray, shape (n_preds, n_cells, n_svd + n_handselected)
        SHAP values indexed as [predicted output, cell, feature].
    svd_comp_norm : np.ndarray, shape (n_svd, n_genes)
        Normalized SVD components (weight of each gene in each component).
    gene_names : sequence of str, optional
        Column names for the ``n_genes`` back-propagated genes. Defaults to the notebook
        global ``all_22001_genes``.
    handselected_names : sequence of str, optional
        Column names for the handselected genes. Defaults to the notebook global
        ``handselected_gene_ids``.
    pred_idx : int, default 0
        Which predicted output (class) to explain. TODO: handle all 140 outputs.

    Returns
    -------
    pd.DataFrame, shape (n_cells, n_genes + n_handselected)
        Back-propagated gene attributions followed by the handselected-gene SHAP values.

    Raises
    ------
    ValueError
        If ``len(xtest)`` differs from the number of cells in ``shap_values``.
    """
    # Fall back to the notebook-level name lists so existing calls keep working.
    if gene_names is None:
        gene_names = all_22001_genes
    if handselected_names is None:
        handselected_names = handselected_gene_ids

    n_svd = svd_comp_norm.shape[0]
    n_cells = shap_values.shape[1]
    if len(xtest) != n_cells:
        raise ValueError(
            f"xtest has {len(xtest)} cells but shap_values has {n_cells} cells"
        )

    # Only the requested output is needed; computing all outputs (as a per-cell Python
    # loop would) costs n_preds times the time and memory for nothing.
    shap_pred = shap_values[pred_idx]                       # (n_cells, n_features)
    attr_genes = shap_pred[:, :n_svd] @ svd_comp_norm       # (n_cells, n_genes)
    # Handselected genes are model inputs themselves: their SHAP values are used as-is.
    attr_handselected = shap_pred[:, n_svd:]                # (n_cells, n_handselected)

    attr_all_22085_genes = np.hstack((attr_genes, attr_handselected))
    print(attr_all_22085_genes.shape)
    return pd.DataFrame(attr_all_22085_genes,
                        columns=list(gene_names) + list(handselected_names))

In [7]:
# SHAP values indexed as [predicted output, cell, feature]: 128 SVD features first,
# then the 84 handselected genes. (An alternative run with 50 samples per cell type is
# stored in '4.model/pred/shap_values_16_50_samples.npy'.)
# NOTE: allow_pickle=True can execute arbitrary code on load -- only use trusted files.
shap_values = np.load('4.model/pred/shap_values_16_restructured.npy', allow_pickle=True).astype(float)

xtest = ad.read_h5ad('4.model/pred/X_test_shap_16.h5ad')

# Sanity checks: the feature axis must be exactly SVD components + handselected genes,
# and the cell axis must match the test cells.
assert shap_values.ndim == 3, shap_values.shape
assert shap_values.shape[2] == svd_comp_norm.shape[0] + len(handselected_gene_ids), shap_values.shape
assert shap_values.shape[1] == len(xtest), (shap_values.shape, len(xtest))

# TODO: in the original dataset all 22085 columns are sorted alphabetically; here the
# first 22001 columns are sorted, followed by the 84 (separately sorted) handselected
# genes. Reorder the columns afterwards?
attr_all_22085_genes = get_attr_all_features(xtest, shap_values, svd_comp_norm)
attr_all_22085_genes.head(2)

(35, 22085)


Unnamed: 0,ENSG00000121410_A1BG,ENSG00000268895_A1BG-AS1,ENSG00000175899_A2M,ENSG00000245105_A2M-AS1,ENSG00000166535_A2ML1,ENSG00000128274_A4GALT,ENSG00000094914_AAAS,ENSG00000081760_AACS,ENSG00000109576_AADAT,ENSG00000103591_AAGAB,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
0,-0.0002805793,-0.00035,0.001712,-0.000411,-6e-05,-0.001241,0.001953,-0.006214,-0.000883,-0.000264,...,-0.005195,0.0,0.0,0.0,-0.022825,0.0,0.024688,0.0,0.017971,0.0
1,4.067557e-07,-2.7e-05,0.000132,-4.8e-05,9e-06,0.000188,-0.000248,0.001903,-0.000227,-0.000268,...,0.0,0.0,0.0,-0.017512,0.004803,0.0,0.0,0.0,0.0,0.0


In [8]:
# Persist the gene-level attribution table (cells x 22085 genes) for downstream use.
with open('4.model/pred/attr_all_22085_genes_16_5_samples.pkl', 'wb') as out_file:
    pickle.dump(
        attr_all_22085_genes,
        out_file,
    )

In [None]:
# TODO: repeat the same gene-level propagation for shap_values_17