In [1]:
%%capture output
!pip install --upgrade pip
!pip install --upgrade pandas
!pip install tables   
# necessary for pd.read_hdf()

!pip install ipywidgets
!pip install --upgrade jupyter
!pip install IProgress
!pip install catboost
!pip install shap
!pip install anndata

In [15]:
import os
import numpy as np
import pandas as pd
import shap

import anndata as ad

In [3]:
%matplotlib inline
from tqdm.notebook import tqdm
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Always a torch.device (never a bare 'cpu' string), so the object has the same
# type on GPU and CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pd.set_option('display.max_rows', 500)
# pd.set_option('display.max_columns', 500)

## data load

In [4]:
# Root directory of the project checkout; every path below is built from it.
# NOTE(review): this is an absolute, machine-specific path - adjust when running elsewhere.
lrz_path = '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/'

# model_path_for_now = '/dss/dsshome1/02/di93zoj/valentina/open-problems-multimodal-3rd-solution/'

# The trailing comments give the relative path that each absolute path replaces.
raw_path =  lrz_path + 'input/raw/'  # '../../../input/raw/'

# CITE inputs: targets, engineered features, and model checkpoints (MLP / cb).
cite_target_path = lrz_path + 'input/target/cite/'   # '../../../input/target/cite/'
cite_feature_path = lrz_path + 'input/features/cite/'   # '../../../input/features/cite/'
cite_mlp_path = lrz_path + 'model/cite/mlp/'   # '../../../model/cite/mlp/'
cite_cb_path = lrz_path + 'model/cite/cb/'   # '../../../model/cite/cb/'

# Multi paths are not used in this notebook.
# multi_target_path = lrz_path + 'input/target/multi/'   # '../../../input/target/multi/'
# multi_feature_path = lrz_path + 'input/features/multi/'   # '../../../input/features/multi/'
# multi_mlp_path = lrz_path + 'model/multi/mlp/'   # '../../../model/multi/mlp/'
# multi_cb_path = lrz_path + 'model/multi/cb/'   # '../../../model/multi/cb/'

# Directory holding the *_idxcol.npz files that are read further down.
index_path = lrz_path + 'input/preprocess/cite/'

output_path = lrz_path + 'output/'   # '../../../output/'

## Cite

In [5]:
# get model name
#mlp_model_path = os.listdir(cite_mlp_path)

markdown

In [None]:
# check model names and lists/dict/...

In [5]:
# Short names of the CITE MLP models. Each one is matched as a substring against
# the file names in cite_mlp_path, and the order of this list fixes the order of
# model_name_list (which model_feat_dict later indexes by position).
mlp_model_name = [
    'corr_add_con_imp',
    'corr_last_v3', 
    'corr_c_add_w2v_v1_mish_flg',
    'corr_c_add_w2v_v1_flg',
    'corr_c_add_84_v1',
    'corr_c_add_120_v1',
    'corr_w2v_cell_flg',
    'corr_best_cell_120',
    'corr_cluster_cell',
    'corr_w2v_128',
    'corr_imp_w2v_128',
    'corr_snorm',
    'corr_best_128',
    'corr_best_64',
    'corr_cluster_128',
    'corr_cluster_64',
    'corr_svd_128',
    'corr_svd_64',
             ]

In [6]:
# Resolve every short model name to the checkpoint file whose name contains it.
# The outer loop runs over mlp_model_name, so model_name_list follows that order.
model_name_list = []

# List the directory once instead of once per short name.
checkpoint_files = os.listdir(cite_mlp_path)

for i in mlp_model_name:
    model_name_list.extend(j for j in checkpoint_files if i in j)

# model_feat_dict indexes model_name_list by position, so a missing or duplicated
# match would silently pair models with the wrong feature file and weight.
assert len(model_name_list) == len(mlp_model_name), (
    f'expected {len(mlp_model_name)} checkpoints, found {len(model_name_list)}'
)
model_name_list

['cite_mlp_corr_add_con_imp_flg_donor_val_50',
 'cite_mlp_corr_last_v3_flg_donor_val_55',
 'cite_mlp_corr_c_add_w2v_v1_mish_flg_donor_val_66',
 'cite_mlp_corr_c_add_w2v_v1_flg_donor_val_66',
 'cite_mlp_corr_c_add_84_v1_flg_donor_val_47',
 'cite_mlp_corr_c_add_120_v1_flg_donor_val_63',
 'cite_mlp_corr_w2v_cell_flg_donor_val_51',
 'cite_mlp_corr_best_cell_120_flg_donor_val_51',
 'cite_mlp_corr_cluster_cell_flg_donor_val_64',
 'cite_mlp_corr_w2v_128_flg_donor_val_42',
 'cite_mlp_corr_imp_w2v_128_flg_donor_val_38',
 'cite_mlp_corr_snorm_flg_donor_val_39',
 'cite_mlp_corr_best_128_flg_donor_val_45',
 'cite_mlp_corr_best_64_flg_donor_val_50',
 'cite_mlp_corr_cluster_128_flg_donor_val_51',
 'cite_mlp_corr_cluster_64_flg_donor_val_57',
 'cite_mlp_corr_svd_128_flg_donor_val_30',
 'cite_mlp_corr_svd_64_flg_donor_val_38']

In [7]:
# Ensemble weights; the values repeat the second element of each
# model_feat_dict entry below, in the same order.
weight = [1, 0.3, 1, 1, 1, 1, 1, 1, 1, 0.8, 0.8, 0.8, 0.8, 0.5, 0.5, 0.5, 1, 1, 2, 2]
weight_sum = np.array(weight).sum()
# NOTE: not displayed - only the last expression of a cell is echoed.
weight_sum

# Maps a model name to [test-feature pickle file name, ensemble weight].
# Keys are taken from model_name_list by position, so that list must keep the
# order of mlp_model_name.
model_feat_dict = {model_name_list[0]:['X_test_add_con_imp.pickle', 1],
                   model_name_list[1]:['X_test_last_v3.pickle', 0.3],
                   model_name_list[2]:['X_test_c_add_w2v_v1.pickle', 1],
                   model_name_list[3]:['X_test_c_add_w2v_v1.pickle', 1],
                   model_name_list[4]:['X_test_c_add_84_v1.pickle', 1],
                   model_name_list[5]:['X_test_c_add_v1.pickle', 1],
                   
                   model_name_list[6]:['X_test_feature_w2v_cell.pickle', 1],
                   model_name_list[7]:['X_test_best_cell_128_120.pickle', 1],
                   model_name_list[8]:['X_test_cluster_cell_128.pickle', 1],
                   
                   model_name_list[9]:['X_test_feature_w2v.pickle', 0.8],
                   model_name_list[10]:['X_test_feature_imp_w2v.pickle',0.8],
                   model_name_list[11]:['X_test_feature_snorm.pickle', 0.8],
                   model_name_list[12]:['X_test_best_128.pickle', 0.8],
                   model_name_list[13]:['X_test_best_64.pickle', 0.5],
                   model_name_list[14]:['X_test_cluster_128.pickle', 0.5],
                   model_name_list[15]:['X_test_cluster_64.pickle', 0.5],
                   model_name_list[16]:['X_test_svd_128.pickle', 1],
                   model_name_list[17]:['X_test_svd_64.pickle', 1],
                   
                   # NOTE(review): these two keys are not MLP checkpoint names -
                   # presumably the non-MLP (cb) models; confirm.
                   'best_128':['X_test_best_128.pickle', 2],
                   'best_64':['X_test_best_64.pickle', 2],
                  }

### cite model

In [8]:
def std(x):
    """Standardize each row of ``x``.

    ``x`` is converted with ``np.array``; the mean and standard deviation are
    taken along axis 1 and applied as column vectors, so every row of the
    result has zero mean and unit standard deviation.
    """
    values = np.array(x)
    row_mean = values.mean(1).reshape(-1, 1)
    row_spread = values.std(1).reshape(-1, 1)
    centered = values - row_mean
    return centered / row_spread

In [9]:
class CiteDataset_test(Dataset):
    """Inference-only dataset: wraps an indexable feature container, no targets."""

    def __init__(self, feature):
        self.feature = feature

    def __len__(self):
        return len(self.feature)

    def __getitem__(self, index):
        # Every sample is a dict with the single key "X".
        return {"X": self.feature[index]}

In [10]:
class CiteModel(nn.Module):
    """Three-stage MLP with ReLU activations and 140 outputs.

    Every stage is Linear -> Linear -> LayerNorm -> ReLU and narrows the
    representation (feature_num -> 128 -> 32 -> 8). The outputs of all three
    stages are concatenated (128 + 32 + 8 values) and fed to a linear head.
    """

    def __init__(self, feature_num):
        super().__init__()

        self.layer_seq_256 = nn.Sequential(
            nn.Linear(feature_num, 256),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
        )
        self.layer_seq_64 = nn.Sequential(
            nn.Linear(128, 64),
            nn.Linear(64, 32),
            nn.LayerNorm(32),
            nn.ReLU(),
        )
        self.layer_seq_8 = nn.Sequential(
            nn.Linear(32, 16),
            nn.Linear(16, 8),
            nn.LayerNorm(8),
            nn.ReLU(),
        )

        self.head = nn.Linear(128 + 32 + 8, 140)

    def _run_layers(self, X):
        """Run the three stages and the head on a tensor on the model's device."""
        X_256 = self.layer_seq_256(X)
        X_64 = self.layer_seq_64(X_256)
        X_8 = self.layer_seq_8(X_64)

        stacked = torch.cat([X_256, X_64, X_8], dim=1)
        return self.head(stacked)

    def forward(self, X, y=None):
        """Compute the 140 outputs for a batch.

        Parameters
        ----------
        X : torch.Tensor or np.ndarray
            2-D batch with ``feature_num`` columns. NumPy input is accepted so
            the model can be used directly as a plain prediction function
            (e.g. by ``shap.KernelExplainer``).
        y : unused
            Accepted for signature compatibility only.

        Returns
        -------
        torch.Tensor for tensor input; np.ndarray (on CPU, detached) for
        NumPy input.
        """
        # Use the model's own parameters as the reference for device and dtype
        # instead of a notebook-level global.
        reference = self.head.weight

        if isinstance(X, np.ndarray):
            # NumPy callers only want predictions: cast to the weight dtype
            # (float64 arrays would otherwise fail against float32 weights) and
            # skip building an autograd graph.
            tensor = torch.from_numpy(X).to(device=reference.device, dtype=reference.dtype)
            with torch.no_grad():
                return self._run_layers(tensor).cpu().numpy()

        return self._run_layers(X.to(reference.device))

In [11]:
class CiteModel_mish(nn.Module):
    """Three-stage MLP with Mish activations and 140 outputs.

    Same layout as ``CiteModel`` but with ``nn.Mish`` instead of ``nn.ReLU``:
    every stage is Linear -> Linear -> LayerNorm -> Mish
    (feature_num -> 128 -> 32 -> 8), and the three stage outputs are
    concatenated (128 + 32 + 8 values) before the linear head.
    """

    def __init__(self, feature_num):
        super().__init__()

        self.layer_seq_256 = nn.Sequential(
            nn.Linear(feature_num, 256),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.Mish(),
        )
        self.layer_seq_64 = nn.Sequential(
            nn.Linear(128, 64),
            nn.Linear(64, 32),
            nn.LayerNorm(32),
            nn.Mish(),
        )
        self.layer_seq_8 = nn.Sequential(
            nn.Linear(32, 16),
            nn.Linear(16, 8),
            nn.LayerNorm(8),
            nn.Mish(),
        )

        self.head = nn.Linear(128 + 32 + 8, 140)

    def _run_layers(self, X):
        """Run the three stages and the head on a tensor on the model's device."""
        X_256 = self.layer_seq_256(X)
        X_64 = self.layer_seq_64(X_256)
        X_8 = self.layer_seq_8(X_64)

        stacked = torch.cat([X_256, X_64, X_8], dim=1)
        return self.head(stacked)

    def forward(self, X, y=None):
        """Compute the 140 outputs for a batch.

        Parameters
        ----------
        X : torch.Tensor or np.ndarray
            2-D batch with ``feature_num`` columns. NumPy input is accepted
            (matching ``CiteModel``) so this variant can also be handed to
            ``shap.KernelExplainer`` as a plain prediction function.
        y : unused
            Accepted for signature compatibility only.

        Returns
        -------
        torch.Tensor for tensor input; np.ndarray (on CPU, detached) for
        NumPy input.
        """
        # Device/dtype come from the model's own parameters.
        reference = self.head.weight

        if isinstance(X, np.ndarray):
            # Prediction-only path: cast to the weight dtype and skip autograd.
            tensor = torch.from_numpy(X).to(device=reference.device, dtype=reference.dtype)
            with torch.no_grad():
                return self._run_layers(tensor).cpu().numpy()

        return self._run_layers(X.to(reference.device))

In [12]:
def test_loop(model, loader):
    """Run ``model`` in eval mode over ``loader`` and collect its outputs.

    Each batch is expected to be a dict with key ``'X'``; the batch is moved to
    ``device``, passed through the model without gradient tracking, and the
    result is converted to a NumPy array.

    Returns the per-batch arrays joined with ``np.concatenate``.
    """
    model.eval()
    batch_outputs = []

    for batch in tqdm(loader):
        with torch.no_grad():
            features = batch['X'].to(device)
            scores = model(features)
        batch_outputs.append(scores.detach().cpu().numpy())

    return np.concatenate(batch_outputs)

### pred

In [16]:
# only need model, not whole prediction

# model #16: cite_mlp_corr_svd_128_flg_donor_val_30

model_name = 'cite_mlp_corr_svd_128_flg_donor_val_30'

test_file = model_feat_dict[model_name][0]
# The test features are not needed to build the model itself, only to read the
# input width (212 for model #16).
X_test = pd.read_pickle(cite_feature_path + test_file)
X_test = np.array(X_test)
feature_dims = X_test.shape[1]

# test_ds = CiteDataset_test(X_test)
# test_dataloader = DataLoader(test_ds, batch_size=128, pin_memory=True, 
#                               shuffle=False, drop_last=False, num_workers=4)

# Choose the architecture from the checkpoint name itself, not from a loop
# variable left over from an earlier cell.
if 'mish' in model_name:
    model = CiteModel_mish(feature_dims)
else:
    model = CiteModel(feature_dims)

model = model.to(device)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(f'{cite_mlp_path}/{model_name}', map_location=device))

<All keys matched successfully>

### add cell_ids and cell_type to train and test data

In [30]:
# Read the "index" / "columns" arrays stored alongside the raw CITE inputs.
# Each archive is opened once and closed again; the arrays are fully loaded on
# key access, so they stay valid after the `with` block.
with np.load(index_path + "train_cite_raw_inputs_idxcol.npz", allow_pickle=True) as train_ids:
    train_index = train_ids["index"]
    train_column = train_ids["columns"]

with np.load(index_path + "test_cite_raw_inputs_idxcol.npz", allow_pickle=True) as test_ids:
    test_index = test_ids["index"]

print(len(train_index))
print(len(test_index))

70988
48203


In [110]:
# Competition metadata, restricted to cells of the CITE test set and keyed by cell_id.
metadata = pd.read_csv('/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/neurips_competition_data/' + 'metadata.csv')
metadata_filtered = (
    metadata[metadata['cell_id'].isin(test_index)]
    .set_index('cell_id')
)
metadata_filtered.shape   # somehow only 41187 matching cell_ids instead of 48203

(41187, 4)

In [112]:
# Train features as AnnData, labelled with the train cell ids.
X_train_cell_ids = pd.read_pickle(cite_feature_path  + 'X_svd_128.pickle')   # = X_svd_128 in make-features second to last cell
X_train_cell_ids = ad.AnnData(X=X_train_cell_ids)
X_train_cell_ids.obs_names = train_index
# X_train_cell_ids.to_df().head()

# Test features as AnnData; cell type comes from the metadata.
X_test_cell_ids = pd.read_pickle(cite_feature_path  + test_file)   # X_test_svd_128.pickle
X_test_cell_ids = ad.AnnData(X=X_test_cell_ids)
X_test_cell_ids.obs_names = test_index

# Keep only cells that have a metadata row; copy so we own the data instead of
# working on a view of the full object.
has_metadata = X_test_cell_ids.obs_names.isin(metadata_filtered.index)
X_test_cell_ids = X_test_cell_ids[has_metadata, :].copy()

# Assigning a DataFrame to .obs takes the frame as-is (no join on the index),
# so the metadata must first be put into the row order of X. Without this the
# cell ids / cell types could be attached to the wrong feature rows.
X_test_cell_ids.obs = metadata_filtered.loc[X_test_cell_ids.obs_names]
X_test_cell_ids

Transforming to str index.
Transforming to str index.


AnnData object with n_obs × n_vars = 41187 × 212
    obs: 'day', 'donor', 'cell_type', 'technology'

In [113]:
# Number of annotated test cells per cell type.
X_test_cell_ids.obs['cell_type'].value_counts()

cell_type
HSC     9451
MasP    9064
EryP    8788
NeuP    7719
MkP     4844
MoP     1215
BP       106
Name: count, dtype: int64

### create dataset of n samples per cell type for SHAP beeswarm plot

In [119]:
def get_samples(samples_per_cell_type, data, random_state=42):
    """Draw up to ``samples_per_cell_type`` observations from every cell type.

    Parameters
    ----------
    samples_per_cell_type : int
        Number of observations to draw per cell type. A cell type with fewer
        observations contributes all of them instead of raising.
    data : AnnData
        Must have an ``obs`` column named ``'cell_type'``.
    random_state : int, default 42
        Seed passed to ``DataFrame.sample`` for every cell type.

    Returns
    -------
    ``data`` indexed by the sampled observation names, grouped by cell type in
    the order returned by ``np.unique``.
    """
    cell_types = data.obs['cell_type']
    sampled_ids = []

    for cell_type in np.unique(cell_types):
        candidates = data.obs[cell_types == cell_type]
        # Never ask for more rows than exist (sampling is without replacement).
        n_draw = min(samples_per_cell_type, len(candidates))
        drawn = candidates.sample(n=n_draw, random_state=random_state)
        sampled_ids.extend(drawn.index.tolist())

    # Select the rows of data whose names were drawn.
    return data[sampled_ids]

# Draw 5 cells per cell type from the annotated test set for the SHAP analysis.
X_test_shap = get_samples(5, X_test_cell_ids)

In [121]:
# Gene ids used to rename the imp_ columns.
# The list has 84 entries and replaces the last 84 feature names of X_test_shap,
# so its order must match the order of those columns.
gene_ids = ['ENSG00000075340_ADD2', 'ENSG00000233968_AL157895.1',
        'ENSG00000029534_ANK1', 'ENSG00000135046_ANXA1',
        'ENSG00000130208_APOC1', 'ENSG00000047648_ARHGAP6',
        'ENSG00000101200_AVP', 'ENSG00000166710_B2M',
        'ENSG00000130303_BST2', 'ENSG00000172247_C1QTNF4',
        'ENSG00000170458_CD14', 'ENSG00000134061_CD180',
        'ENSG00000177455_CD19', 'ENSG00000116824_CD2',
        'ENSG00000206531_CD200R1L', 'ENSG00000012124_CD22',
        'ENSG00000272398_CD24', 'ENSG00000139193_CD27',
        'ENSG00000105383_CD33', 'ENSG00000174059_CD34',
        'ENSG00000135218_CD36', 'ENSG00000004468_CD38',
        'ENSG00000010610_CD4', 'ENSG00000026508_CD44',
        'ENSG00000117091_CD48', 'ENSG00000169442_CD52',
        'ENSG00000135404_CD63', 'ENSG00000173762_CD7',
        'ENSG00000137101_CD72', 'ENSG00000019582_CD74',
        'ENSG00000105369_CD79A', 'ENSG00000085117_CD82',
        'ENSG00000114013_CD86', 'ENSG00000010278_CD9',
        'ENSG00000002586_CD99', 'ENSG00000166091_CMTM5',
        'ENSG00000119865_CNRIP1', 'ENSG00000100368_CSF2RB',
        'ENSG00000100448_CTSG', 'ENSG00000051523_CYBA',
        'ENSG00000116675_DNAJC6', 'ENSG00000142227_EMP3',
        'ENSG00000143226_FCGR2A', 'ENSG00000167996_FTH1',
        'ENSG00000139278_GLIPR1', 'ENSG00000130755_GMFG',
        'ENSG00000169567_HINT1', 'ENSG00000206503_HLA-A',
        'ENSG00000234745_HLA-B', 'ENSG00000204287_HLA-DRA',
        'ENSG00000196126_HLA-DRB1', 'ENSG00000204592_HLA-E',
        'ENSG00000171476_HOPX', 'ENSG00000076662_ICAM3',
        'ENSG00000163565_IFI16', 'ENSG00000142089_IFITM3',
        'ENSG00000160593_JAML', 'ENSG00000055118_KCNH2',
        'ENSG00000105610_KLF1', 'ENSG00000139187_KLRG1',
        'ENSG00000133816_MICAL2', 'ENSG00000198938_MT-CO3',
        'ENSG00000107130_NCS1', 'ENSG00000090470_PDCD7',
        'ENSG00000143627_PKLR', 'ENSG00000109099_PMP22',
        'ENSG00000117450_PRDX1', 'ENSG00000112077_RHAG',
        'ENSG00000108107_RPL28', 'ENSG00000198918_RPL39',
        'ENSG00000145425_RPS3A', 'ENSG00000198034_RPS4X',
        'ENSG00000196154_S100A4', 'ENSG00000197956_S100A6',
        'ENSG00000188404_SELL', 'ENSG00000124570_SERPINB6',
        'ENSG00000235169_SMIM1', 'ENSG00000095932_SMIM24',
        'ENSG00000137642_SORL1', 'ENSG00000128040_SPINK2',
        'ENSG00000072274_TFRC', 'ENSG00000205542_TMSB4X',
        'ENSG00000133112_TPT1', 'ENSG00000026025_VIM']

In [127]:
# X_test_shap currently has base_svd and imp columns -> rename imp (hand-selected genes).
# The number of trailing columns to replace is taken from gene_ids itself so the
# two can never drift apart.
n_gene_columns = len(gene_ids)
X_test_shap.var_names = X_test_shap.var_names[:-n_gene_columns].tolist() + gene_ids
X_test_shap.to_df().head()

Unnamed: 0_level_0,base_svd_0,base_svd_1,base_svd_2,base_svd_3,base_svd_4,base_svd_5,base_svd_6,base_svd_7,base_svd_8,base_svd_9,...,ENSG00000188404_SELL,ENSG00000124570_SERPINB6,ENSG00000235169_SMIM1,ENSG00000095932_SMIM24,ENSG00000137642_SORL1,ENSG00000128040_SPINK2,ENSG00000072274_TFRC,ENSG00000205542_TMSB4X,ENSG00000133112_TPT1,ENSG00000026025_VIM
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
5cb9daaca7ac,71.003983,13.2008,-6.168917,12.436371,-0.578827,0.227373,1.382157,0.969637,-1.477851,6.248293,...,0.0,0.0,0.0,0.0,0.0,0.947536,0.0,5.313975,4.545222,3.069516
57b5bce6a192,86.716537,5.783796,15.985765,4.867427,-5.404188,-4.447663,-6.89554,-5.07482,6.342696,2.834425,...,0.0,0.62413,0.62413,1.005488,0.0,0.0,0.62413,5.135746,4.22824,3.97024
11e72851054b,70.254166,15.571994,-5.599965,5.991638,8.331441,-4.175835,-1.716176,-2.858795,-2.759818,7.807173,...,0.881844,0.881844,0.0,0.881844,0.0,0.0,2.620177,4.532575,4.919058,1.657475
dd7e07a59b0c,65.834938,17.32021,-10.891003,8.597704,5.161943,-5.098966,-2.595727,-1.467793,-3.764605,4.425799,...,0.0,1.276917,0.0,1.819883,0.0,1.276917,1.276917,3.964831,5.36137,1.276917
02a718f16011,84.209236,11.387871,10.027047,13.640661,-6.556067,-1.784269,1.632261,-4.823346,-3.590804,0.736319,...,0.0,0.0,0.0,0.458628,0.0,0.0,2.001605,4.49367,4.538202,2.38806


In [131]:
# Save the sampled cells (with renamed feature columns) to the working directory.
X_test_shap.write("X_test_shap_16.h5ad")

### Prepare data for SHAP plot

Steps: `shap.KernelExplainer` → `explainer.shap_values` → `shap.summary_plot(shap_values)`

In [128]:
# X_train for model #16: 'X_svd_128.pickle'
X_train = pd.read_pickle(cite_feature_path  + 'X_svd_128.pickle')
X_train = np.array(X_train)
print('X_train: ', X_train.shape)
print('X_test: ', X_test.shape)

# The model object itself is passed as the prediction function, so it is called
# on NumPy arrays; 1000 rows drawn from X_train by shap.sample are the background set.
explainer = shap.KernelExplainer(model, shap.sample(X_train, 1000))
explainer

X_train:  (70988, 212)
X_test:  (48203, 212)


Using 1000 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.


<shap.explainers._kernel.Kernel at 0x7efc11aa3640>

In [130]:
# xtest = X_test_shap.to_df()#.drop(['cell_id', 'cell_type'], axis=1)

In [47]:
# features: genes and svd -> omnipath: genes
# model: mostly relying on genes or svd? -> later

In [132]:
# don't need to run again: np.load('shap_values.npy', allow_pickle=True)
%timeit
shap_values = explainer.shap_values(X_test_shap.to_df(), nsamples=300)  #500? 
print(len(shap_values)) # -> 140 genes
print(len(shap_values[0])) # -> number of samples in xtest
print(shap_values[0].shape)

# TODO rename files once double checked that everything works after restructuring
np.save('shap_values_16_restructured.npy', np.array(shap_values, dtype=object), allow_pickle=True)

  0%|          | 0/35 [00:00<?, ?it/s]

140
35
(35, 212)


In [144]:
# SHAP values for the first of the 140 outputs: one row per sampled cell,
# one column per input feature.
shap_values[0]

array([[ 0.        ,  0.17936803, -0.0179155 , ...,  0.        ,
         0.01797059,  0.        ],
       [-0.03197865,  0.052598  ,  0.20977261, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.22075023,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.03531018,  0.13639785,  0.        , ...,  0.        ,
         0.00061229,  0.        ],
       [ 0.00919145,  0.0482322 , -0.05356169, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.00214199,  0.09497041, -0.03869506, ...,  0.00678889,
         0.        ,  0.00769504]])