This notebook is primarily to make sure your environment installed properly. If anything in here does not run, certainly nothing else will.

In [1]:
1

1

In [2]:
import torch
import warnings
import scvi
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
sc.set_figure_params(dpi=100, frameon=False, color_map='Reds', facecolor=None)
sc.logging.print_header()
assert(scvi.__version__=='0.16.3')

Global seed set to 0


scanpy==1.9.1 anndata==0.8.0 umap==0.5.3 numpy==1.21.5 scipy==1.9.1 pandas==1.4.4 scikit-learn==1.1.2 statsmodels==0.13.2 pynndescent==0.5.7


In [3]:
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']

In [4]:
import lime
from lime import lime_tabular

In [5]:
base_path = '/home/icb/yuge.ji/projects/feature-attribution-sc'  # should be changed to shared dir when I can find one

In [6]:
torch.cuda.is_available()

True

## scANVI (HLCA)

In [7]:
hlca_path = '/lustre/groups/ml01/workspace/hlca_lisa.sikkema_malte.luecken/HLCA_reproducibility/data/HLCA_core_h5ads/HLCA_v1_integration/HLCA_v1_scANVI_input.h5ad'
adata = sc.read(hlca_path)


In [8]:
model = scvi.model.SCANVI.load('/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/scanvi_model/', adata)

[34mINFO    [0m File [35m/home/icb/yuge.ji/projects/HLCA_reproducibility/notebooks/3_atlas_extension/sca[0m
         [35mnvi_model/[0m[95mmodel.pt[0m already downloaded                                               




### Calculate

In [9]:
num_samples = 5000

In [10]:
cell_count = adata.n_obs
indices = np.random.choice(adata.n_obs, size=cell_count, replace=False)
scdl = model._make_data_loader(adata=adata, indices=indices, batch_size=cell_count)
batch = next(scdl.__iter__())
x = batch["X"]
x.requires_grad = True
batch_labels = batch["batch"]

In [11]:
pred = model.module.classify(x, batch_index=batch_labels)

In [12]:
len(np.unique(batch["labels"]))

29

In [13]:
explainer = lime.lime_tabular.LimeTabularExplainer(x.to('cpu').detach().numpy(),  
                                                   mode='classification',
                                                   training_labels=np.array(batch["labels"]),
                                                   feature_names=np.array(adata.var["gene_symbols"])
                                                  )

In [14]:
# Let's try whether the function/method works for a random row

In [15]:
cell_count = 1
indices = np.random.choice(adata.n_obs, size=cell_count, replace=False)
scdl = model._make_data_loader(adata=adata, indices=indices, batch_size=cell_count)
batch = next(scdl.__iter__())
x = batch["X"]
x.requires_grad = True
batch_labels = batch["batch"]

In [16]:
def predictor(bo):
    if isinstance(bo, np.ndarray):
        bo = torch.Tensor(bo).to("cuda:0")
    else:
        raise ValueError    
    a = np.ones((bo.shape[0], 1)) * batch_labels.to('cpu').detach().numpy()[0][0]
    batch_index = torch.Tensor(a).to(bo.device)
    result = model.module.classify(bo, batch_index=batch_index)
    return result.to('cpu').detach().numpy()

predictor(x.to('cpu').detach().numpy())

array([[3.8930426e-13, 1.1191378e-11, 4.7614289e-13, 6.3525785e-13,
        1.0000000e+00, 4.1044598e-13, 3.9094650e-13, 6.8458471e-13,
        8.7400112e-15, 6.5282182e-13, 9.3185507e-13, 7.7077472e-13,
        3.7356641e-13, 9.3867073e-14, 4.1554092e-13, 5.0942377e-13,
        3.8201262e-13, 3.6396182e-13, 2.2587525e-13, 3.9645671e-13,
        3.5576262e-13, 3.9430818e-13, 4.7298865e-16, 2.6510597e-13,
        3.5919087e-13, 4.0717264e-13, 1.6444196e-14, 3.8440708e-13]],
      dtype=float32)

In [17]:
exp = explainer.explain_instance(data_row=x.to('cpu').detach().numpy()[0], 
                           predict_fn=predictor,
                           num_features = 2000,
                           num_samples=num_samples)

In [18]:
## Run experiments

In [19]:
number_of_cells_in_a_class = 50

In [20]:
dff = adata.obs._scvi_labels
indice_dict = dict()
df = pd.DataFrame()

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i in np.unique(dff):
        adatata = adata[dff==i]
        cell_count = number_of_cells_in_a_class if len(adatata) >= number_of_cells_in_a_class else len(adatata)
        indices = np.random.choice(adatata.n_obs, size=cell_count, replace=False)

        for ind, indice in enumerate(indices):
            print(ind, end=',')
            scdl = model._make_data_loader(adata=adatata, indices=np.array([indice]), batch_size=cell_count)
            batch = next(scdl.__iter__())
            x = batch["X"]
            x.requires_grad = True
            batch_labels = batch["batch"]

            def predictor(bo):
                if isinstance(bo, np.ndarray):
                    bo = torch.Tensor(bo).to("cuda:0")
                else:
                    raise ValueError    
                a = np.ones((bo.shape[0], 1)) * batch_labels.to('cpu').detach().numpy()[0][0]
                batch_index = torch.Tensor(a).to(bo.device)
                result = model.module.classify(bo, batch_index=batch_index)
                return result.to('cpu').detach().numpy()

            exp = explainer.explain_instance(data_row=x.to('cpu').detach().numpy()[0], 
                               predict_fn=predictor,
                               num_features = 2000,
                               num_samples=num_samples)

            df_temp = pd.DataFrame(exp.as_list(), columns=['features', 'values'])
            df_temp['cell_type_label'] = i
            df_temp['indice'] = indice
            df = pd.concat([df, df_temp])

0,[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                  
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,0,[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                  
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,0,[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                  
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,0,[34mINFO    [0m AnnData object appears to be a copy. Attempting to transfer setup.                  
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,0,[34mINFO    [0m AnnData 

In [21]:
df.to_pickle("/lustre/groups/ml01/workspace/kemal.inecik/feature_attribution/df.pk")

### Prepare

In [102]:
df = pd.read_pickle("/lustre/groups/ml01/workspace/kemal.inecik/feature_attribution/df.pk")

In [103]:
def isfloat(x):
    try:
        float(x)
        return True
    except ValueError:
        return False
def issign(x):
    return "<" in x or ">" in x or "=" in x

In [107]:
las = adata.obs[["scanvi_label", "_scvi_labels"]].drop_duplicates()
las.set_index('_scvi_labels', inplace=True, verify_integrity=True)
conversion_dict = las.to_dict()['scanvi_label']

In [104]:
feat = []
for i in df['features']:
    isplit = i.split()
    res = [j for j in isplit if not isfloat(j) and not issign(j)]
    assert len(res) == 1
    feat.append(res[0])

In [105]:
df["features"] = feat
len(np.unique(df["features"])) == adata.n_vars

True

In [132]:
df1 = df.groupby(['features', 'cell_type_label'], as_index=False).mean()
del df1['indice']
df1.sort_values(by=["features", "cell_type_label"], inplace=True)
pt1 = pd.pivot_table(df1, index=['features'], columns=['cell_type_label'])
pt1 = pd.DataFrame(np.array(pt1), index=pt.index, columns=[i[1] for i in pt1.columns])
pt1.rename(columns=conversion_dict, inplace=True)

In [142]:
df_temp = df.copy()
df_temp["values"] = df_temp["values"].abs()
df2 = df_temp.groupby(['features', 'cell_type_label'], as_index=False).mean()
del df2['indice']
df2.sort_values(by=["features", "cell_type_label"], inplace=True)
pt2 = pd.pivot_table(df2, index=['features'], columns=['cell_type_label'])
pt2 = pd.DataFrame(np.array(pt2), index=pt.index, columns=[i[1] for i in pt2.columns])
pt2.rename(columns=conversion_dict, inplace=True)

In [151]:
pt1.to_csv("/lustre/groups/ml01/workspace/kemal.inecik/feature_attribution/lime.csv")
pt2.to_csv("/lustre/groups/ml01/workspace/kemal.inecik/feature_attribution/lime_absolute.csv")

In [152]:
pt1

Unnamed: 0_level_0,AT1,AT2,Arterial EC,B cell lineage,Basal,Bronchial Vessel 1,Bronchial Vessel 2,Capillary,Ciliated,Dendritic cells,...,Non-T/B cells,Proliferating cells,Rare,Secretory,Smooth Muscle,Squamous,Submucosal Secretory,T cell lineage,Venous,unlabeled
features,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A1BG,0.001779,0.002617,0.004959,-0.001271,0.002954,0.007373,0.002989,-0.003094,-0.003829,-0.001330,...,0.004007,0.003400,0.004137,-0.000164,0.002330,0.001171,0.004033,-0.002133,0.002923,0.002780
A2M,-0.021970,-0.013727,0.014143,-0.012696,-0.023292,0.011786,0.013547,0.008324,-0.020707,-0.007230,...,-0.015676,-0.012459,-0.023263,-0.021737,0.014450,-0.012662,-0.019499,-0.021884,0.010869,-0.011714
ABCA1,-0.002339,0.000736,0.001854,0.000277,0.000767,0.004644,-0.001174,0.004156,0.003801,-0.000620,...,-0.000763,0.000186,0.010743,0.003737,0.001815,0.008389,0.001305,0.002435,-0.002428,0.002997
ABCA3,-0.009904,0.006089,-0.016655,-0.015035,-0.024026,-0.008725,-0.014750,-0.009145,-0.017162,-0.010948,...,-0.010206,-0.010856,-0.019492,-0.018550,-0.013265,-0.017285,-0.019883,-0.010900,-0.016944,-0.015394
ABCA6,-0.017477,-0.002396,-0.005593,-0.014316,-0.005585,-0.007866,-0.010361,-0.005096,-0.017290,-0.015864,...,-0.019770,-0.010024,-0.005901,-0.011114,-0.008757,-0.016117,-0.006705,-0.005875,-0.008419,-0.008569
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZNF267,0.001891,0.000300,0.007168,0.001514,-0.002904,0.002917,0.005295,0.001639,0.003599,0.000671,...,-0.000079,0.001202,0.007166,0.002902,0.001657,0.000205,0.006255,-0.000978,0.002031,0.002961
ZNF331,-0.000394,0.000981,0.000741,-0.002825,0.000740,0.001139,-0.001152,-0.002960,0.000276,0.003097,...,0.000200,0.001403,-0.000335,0.000537,0.003227,0.002996,0.000020,-0.001755,-0.002710,-0.001847
ZNF683,0.000863,-0.004453,0.012403,0.012902,0.008936,0.003110,-0.005724,0.003230,-0.006797,0.012642,...,-0.000188,-0.002333,0.007184,0.002266,0.009754,-0.003493,-0.005571,0.007981,0.002499,-0.001810
ZNF80,0.000966,0.006352,0.005982,-0.003240,-0.003573,0.011932,0.011614,0.003823,-0.022872,-0.001759,...,-0.000797,0.000740,0.017529,-0.007500,-0.022535,-0.029523,0.029265,-0.023081,-0.041336,0.013932


In [153]:
pt1.columns

Index(['AT1', 'AT2', 'Arterial EC', 'B cell lineage', 'Basal',
       'Bronchial Vessel 1', 'Bronchial Vessel 2', 'Capillary', 'Ciliated',
       'Dendritic cells', 'Fibroblast lineage', 'KRT5- KRT17+ epithelial',
       'Lymphatic EC', 'Macrophages', 'Mast cells', 'Megakaryocytes',
       'Mesothelium', 'Monocytes', 'Neutrophilic', 'Non-T/B cells',
       'Proliferating cells', 'Rare', 'Secretory', 'Smooth Muscle', 'Squamous',
       'Submucosal Secretory', 'T cell lineage', 'Venous', 'unlabeled'],
      dtype='object')