In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

Load every available dataset.

In [3]:
%%time
adatas = []
for i, row in pd.read_csv('personal.csv').iterrows():
    try:
        adata = sc.read(f'datasets/{row.Author}_{row.Year}.h5ad')
        adata.obs['dataset'] = f'{row.Author}_{row.Year}'
        adatas.append(adata)
    except FileNotFoundError:
        pass

CPU times: user 1min 37s, sys: 7.85 s, total: 1min 45s
Wall time: 1min 47s


In [4]:
%%time
adata = ad.concat(adatas)

CPU times: user 22.5 s, sys: 1min 14s, total: 1min 36s
Wall time: 1min 37s


Take the guide RNAs that exist across multiple datasets.

In [38]:
df = pd.crosstab(adata.obs.perturbation_name, adata.obs.dataset).T
df = df.T[np.count_nonzero(df, axis=0) > 1]
df

dataset,Dixit_2016,Frangieh_2021,Norman_2019,Srivatsan_2019
perturbation_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
EGR1,6181,0,321,0
IRF1,3837,0,427,0
NCL,0,121,237,0
SET,0,415,985,0


In [14]:
subset = adata[adata.obs.perturbation_name.isin(df.index)]
subset

View of AnnData object with n_obs × n_vars = 12524 × 12738
    obs: 'n_genes', 'leiden', 'perturbation_name', 'perturbation_type', 'perturbation_value', 'perturbation_unit', 'dataset'
    obsm: 'X_pca', 'X_umap'

Generate random train-test split.

In [27]:
test_idx = sc.pp.subsample(subset, .2, copy=True).obs.index
train_idx = list(set(subset.obs.index) - set(test_idx))

test = subset[test_idx]
train = subset[train_idx]

Train a model.

In [28]:
clf = LogisticRegression()
clf.fit(train.X, train.obs.perturbation_name.values)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression()

In [30]:
clf.score(train.X, train.obs.perturbation_name.values)

0.9612774451097804

In [29]:
clf.score(test.X, test.obs.perturbation_name.values)

0.5938498402555911

In [33]:
print(classification_report(test.obs.perturbation_name.values, clf.predict(test.X)))

              precision    recall  f1-score   support

        EGR1       0.63      0.64      0.64      1329
        IRF1       0.45      0.45      0.45       846
         NCL       0.59      0.36      0.45        75
         SET       0.84      0.92      0.88       254

    accuracy                           0.59      2504
   macro avg       0.63      0.59      0.60      2504
weighted avg       0.59      0.59      0.59      2504



Let's try leaving out one Norman et al. 2019 and only learn the treatments from the other two datasets.

In [34]:
test = subset[subset.obs.dataset == 'Norman_2019']
train = subset[subset.obs.dataset != 'Norman_2019']

In [35]:
clf = LogisticRegression()
clf.fit(train.X, train.obs.perturbation_name.values)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression()

In [36]:
print(classification_report(test.obs.perturbation_name.values, clf.predict(test.X)))

              precision    recall  f1-score   support

        EGR1       0.14      0.46      0.22       321
        IRF1       0.17      0.37      0.23       427
         NCL       0.00      0.00      0.00       237
         SET       0.00      0.00      0.00       985

    accuracy                           0.16      1970
   macro avg       0.08      0.21      0.11      1970
weighted avg       0.06      0.16      0.09      1970



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Looks like basicaly information can transfer from Dixit to Norman without batch correction, but not from Frangieh.