In [1]:
import sys, os
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import numpy as np
import pandas as pd
import scanpy as sc

We load the single-cell data using scanpy. The single-cell data is stored in a special data structure called AnnData (short: adata).

In [2]:
adata = sc.read('L008_rna_counts.h5ad')
adata

AnnData object with n_obs × n_vars = 316138 × 36604
    obs: 'condition', 'target_gene'

The counts are stored in adata.X


In [3]:
adata.X

<316138x36604 sparse matrix of type '<class 'numpy.float32'>'
	with 1217701057 stored elements in Compressed Sparse Row format>

adata.obs is a dataframe containing annotation for each cell, such as e.g. batch, cell type, perturbation. Or other technical annotations at the cell level.

In [4]:
adata.obs

Unnamed: 0,condition,target_gene
0,"2, Donor#2 Cas9+ IFNg+",ZNF331
1,"2, Donor#2 Cas9+ IFNg+",MAPK11
2,"2, Donor#2 Cas9+ IFNg+",OAS3
3,"2, Donor#2 Cas9+ IFNg+",IL7R
4,"2, Donor#2 Cas9+ IFNg+",SOX4
...,...,...
316133,"3, Donor#3 Cas9+ restim",
316134,"3, Donor#3 Cas9+ restim",
316135,"3, Donor#3 Cas9+ restim",
316136,"3, Donor#3 Cas9+ restim",


In [5]:
adata = adata[adata.obs['target_gene'] != 'NA']

In [6]:
adata.obs['donor'] = adata.obs['condition'].apply(lambda x: x.split(' ')[1])
adata.obs['crispr'] = adata.obs['condition'].apply(lambda x: x.split(' ')[2])
adata.obs['cell_type'] = adata.obs['condition'].apply(lambda x: x.split(' ')[3])

  adata.obs['donor'] = adata.obs['condition'].apply(lambda x: x.split(' ')[1])


adata.var is a dataframe containing annotation for each gene, usually some statistics such as dispersion, or gene names, pathways etc.

In [7]:
adata.var

MIR1302-2HG
FAM138A
OR4F5
AL627309.1
AL627309.3
...
AC007325.4
AC007325.2
GFP
RFP670
REF-polyA


## Quality control
Check the quality of the data and see if it's necessary to remove some cells.

In [8]:
adata.obs['n_counts'] = np.ravel(adata.X.sum(1)) #number of counts in the cell
adata.obs['n_genes'] = np.ravel(np.sum(adata.X > 0, axis=1)) #number of genes with at least 1 count per cell
adata.var['mito'] = adata.var_names.str.contains("MT-") #flag for mitochondrial genes
adata.obs['mt_frac'] = np.ravel(adata.X[:, adata.var.mito].sum(1)) / adata.obs['n_counts'].values #fraction of mitochondrial gene exp, high values mean dead or bad quality cells

In [9]:
# filtering
adata = adata[adata.obs['n_counts'] > 500]
adata = adata[adata.obs['n_genes'] > 750]
adata = adata[adata.obs['mt_frac'] < 0.2]
adata

View of AnnData object with n_obs × n_vars = 139564 × 36604
    obs: 'condition', 'target_gene', 'donor', 'crispr', 'cell_type', 'n_counts', 'n_genes', 'mt_frac'
    var: 'mito'

## Normalization

In [10]:
adata.X.max() #check it's an int, to make sure it's count data and not preprocessed data

3709.0

In [11]:
sc.pp.normalize_total(adata)
#adata.layers['counts'] = adata.X.copy()
sc.pp.log1p(adata)

  view_to_actual(adata)


## Feature (gene) selection
We select only the top N most variable genes.

In [12]:
sc.pp.filter_genes(adata, min_cells=1)
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor='cell_ranger')

In [13]:
gene_targeted = [n in adata.obs['target_gene'].cat.categories for n in adata.var_names]
adata = adata[:, [a or b for a, b in zip(gene_targeted, adata.var.highly_variable)]]

In [14]:
adata

View of AnnData object with n_obs × n_vars = 139564 × 2048
    obs: 'condition', 'target_gene', 'donor', 'crispr', 'cell_type', 'n_counts', 'n_genes', 'mt_frac'
    var: 'mito', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'

## Out-of-distribution selection

In [15]:
results = []
for cond1 in adata.obs.target_gene.unique():
    ad1 = adata[adata.obs.target_gene == cond1]
    ad2 = adata[adata.obs.target_gene != cond1]
    mean1 = ad1.X.mean(0)
    mean2 = ad2.X.mean(0)
    l2 = np.linalg.norm(mean1-mean2)
    results.append({
        'cond1': cond1,
        'L2': l2
    })
df_vs_rest = pd.DataFrame(results)

Pick biggest signals

In [16]:
df_vs_rest.sort_values(by='L2').tail(20)

Unnamed: 0,cond1,L2
77,SRSF7,3.169078
79,IL2RG,3.196164
56,MX1,3.21841
51,RAC2,3.303963
36,NR4A2,3.317241
71,SARAF,3.317464
7,F8,3.447231
70,CD247,3.4644
64,IRF9,3.514952
63,THY1,3.525789


In [17]:
KO_OOD = df_vs_rest.sort_values(by='L2').tail(20).cond1.values
KO_OOD

array(['SRSF7', 'IL2RG', 'MX1', 'RAC2', 'NR4A2', 'SARAF', 'F8', 'CD247',
       'IRF9', 'THY1', 'OAS3', 'JUNB', 'STAG3', 'PDCD1', 'ENTPD1',
       'FURIN', 'MAPK11', 'CD44', 'SLC2A3', 'ZNF331'], dtype=object)

## Prepare for the model

In [18]:
adata.uns['fields'] = {}

  self.data[key] = value


In [19]:
adata.obs['target_gene'] = adata.obs['target_gene'].astype(str)
adata.uns['fields']['perturbation'] = 'target_gene'

In [20]:
adata.obs['control'] = [1 if x == 'Cas9-' else 0 for x in adata.obs['crispr'].values]
adata.obs.loc[adata.obs['control'] == 1, 'target_gene'] = 'ctrl'
adata.uns['fields']['control'] = 'control'

In [21]:
adata.obs['dose'] = 1.0
adata.uns['fields']['dose'] = 'dose'

In [22]:
adata.uns['fields']['covariates'] = ['cell_type', 'donor']

In [23]:
del adata.obs['condition']
del adata.uns['log1p']

In [24]:
# split dataset

from sklearn.model_selection import train_test_split

adata.obs['split'] = 'NA'
adata.uns['fields']['split'] = 'split'

adata.obs.loc[
    (adata.obs['cell_type'] == 'restim') & (adata.obs['target_gene'].isin(KO_OOD)),
    'split'
] = 'ood'

idx = np.where(adata.obs['split']=='NA')[0]
idx_train, idx_test = train_test_split(idx, test_size=0.2, random_state=42)

adata.obs.iloc[idx_train, adata.obs.columns.get_loc('split')] = 'train'
adata.obs.iloc[idx_test, adata.obs.columns.get_loc('split')] = 'test'

Rank DE genes (optional)

In [25]:
# this will be done in main script if it's not done here

cov_names = []
for cov in adata.uns['fields']['covariates']:
    cov_names.append(np.array(adata.obs[cov].values))
cov_names = ["_".join(c) for c in zip(*cov_names)]
adata.obs["cov_name"] = cov_names

cov_pert_names = []
for i in range(len(adata)):
    comb_name = (
        f"{adata.obs['cov_name'].values[i]}"
        f"_{adata.obs[adata.uns['fields']['perturbation']].values[i]}"
    )
    cov_pert_names.append(comb_name)
adata.obs["cov_pert_name"] = cov_pert_names

import warnings

from vci.utils.data_utils import rank_genes_groups

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    rank_genes_groups(adata,
        groupby="cov_pert_name",
        reference="cov_name",
        control_key="control"
    )

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
adata.obs

Unnamed: 0,target_gene,donor,crispr,cell_type,n_counts,n_genes,mt_frac,control,dose,split,cov_name,cov_pert_name
0,ZNF331,Donor#2,Cas9+,IFNg+,4055.0,1881,0.123551,0,1.0,test,IFNg+_Donor#2,IFNg+_Donor#2_ZNF331
1,MAPK11,Donor#2,Cas9+,IFNg+,1997.0,1073,0.019529,0,1.0,train,IFNg+_Donor#2,IFNg+_Donor#2_MAPK11
2,OAS3,Donor#2,Cas9+,IFNg+,11044.0,3155,0.045817,0,1.0,train,IFNg+_Donor#2,IFNg+_Donor#2_OAS3
3,IL7R,Donor#2,Cas9+,IFNg+,38119.0,6660,0.040295,0,1.0,train,IFNg+_Donor#2,IFNg+_Donor#2_IL7R
4,SOX4,Donor#2,Cas9+,IFNg+,2585.0,1561,0.042553,0,1.0,test,IFNg+_Donor#2,IFNg+_Donor#2_SOX4
...,...,...,...,...,...,...,...,...,...,...,...,...
316126,TOX,Donor#3,Cas9+,restim,19578.0,4508,0.069670,0,1.0,test,restim_Donor#3,restim_Donor#3_TOX
316127,SLC2A3,Donor#3,Cas9+,restim,3733.0,1689,0.091347,0,1.0,ood,restim_Donor#3,restim_Donor#3_SLC2A3
316128,BTG2,Donor#3,Cas9+,restim,32049.0,5401,0.074043,0,1.0,train,restim_Donor#3,restim_Donor#3_BTG2
316131,NTC,Donor#3,Cas9+,restim,30929.0,5601,0.070258,0,1.0,train,restim_Donor#3,restim_Donor#3_NTC


In [27]:
adata.write('L008_prepped.h5ad')