In [18]:
import scanpy as sc
import numpy as np
import torch
from omnicell.config.config import Config
from pathlib import Path
import yaml
from omnicell.models.VAE.predictor import VAEPredictor
from omnicell.models.nearest_neighbor.predictor import NearestNeighborPredictor
from omnicell.data.preprocessing import preprocess
from omnicell.constants import *

import logging

In [19]:
logger = logging.getLogger(__name__)    

logging.basicConfig(filename= 'output.log', filemode= 'w', level='INFO', format='%(asctime)s - %(levelname)s - %(message)s')
logger.info("Application started")


In [21]:
task_config_path = '/orcd/archive/abugoot/001/Projects/opitcho/sandbox/omnicell/configs/tasks/test.yaml'
model_vae_config_path = '/orcd/archive/abugoot/001/Projects/opitcho/sandbox/omnicell/configs/models/vae.yaml'
model_nn_config_path = '/orcd/archive/abugoot/001/Projects/opitcho/sandbox/omnicell/configs/models/nearest-neighbor.yaml'

model_config_path = model_nn_config_path


model_path = Path(model_config_path).resolve()
task_path = Path(task_config_path).resolve()


config_model = yaml.load(open(model_path), Loader=yaml.UnsafeLoader)
config_task = yaml.load(open(task_path), Loader=yaml.UnsafeLoader)

config = Config.empty().add_model_config(config_model).add_task_config(config_task).add_train_args({'test_mode': True})


In [22]:
adata = sc.read(config.get_data_path())
adata



AnnData object with n_obs × n_vars = 328542 × 2054
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'bc1_well', 'bc2_well', 'bc3_well', 'percent.mito', 'cell_type', 'pathway', 'RNA_snn_res.0.9', 'seurat_clusters', 'sample_ID', 'Batch_info', 'guide', 'gene', 'mixscale_score'
    var: 'gene', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'

In [23]:

adata = preprocess(adata, config)
adata

adta control View of AnnData object with n_obs × n_vars = 14582 × 2054
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'bc1_well', 'bc2_well', 'bc3_well', 'percent.mito', 'cell_type', 'pathway', 'RNA_snn_res.0.9', 'seurat_clusters', 'sample_ID', 'Batch_info', 'guide', 'gene', 'mixscale_score', 'pert', 'cell'
    var: 'gene', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'


  adata = adata_first.concatenate(adata_ctrl)


AnnData object with n_obs × n_vars = 24582 × 2054
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'bc1_well', 'bc2_well', 'bc3_well', 'percent.mito', 'cell_type', 'pathway', 'RNA_snn_res.0.9', 'seurat_clusters', 'sample_ID', 'Batch_info', 'guide', 'gene', 'mixscale_score', 'pert', 'cell', 'batch'
    var: 'gene', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'

In [24]:



input_dim = adata.shape[1]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pert_ids = adata.obs[PERT_KEY].unique()


In [27]:



#model = VAEPredictor(config.get_model_config(), input_dim, device, pert_ids)
model = NearestNeighborPredictor(config.get_model_config())


data = adata.X.toarray().astype(np.float32)
data = torch.tensor(data).to(device)

res = model.model.encode(data)
res[0].shape   

AttributeError: 'NearestNeighborPredictor' object has no attribute 'model'

In [28]:
model.train(adata)

adata.obs

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,sample,bc1_well,bc2_well,bc3_well,percent.mito,cell_type,pathway,RNA_snn_res.0.9,seurat_clusters,sample_ID,Batch_info,guide,gene,mixscale_score,pert,cell,batch
07_48_88_1_1_1_1_1_1_1_1_1-0,7,9816,4122,A549_IFNB,A7,D12,H4,1.161369,A549,IFNB,15.0,15.0,sample_1,Rep1,TRAFD1g3,TRAFD1,-0.290358,TRAFD1,A549,0
06_04_63_1_1_1_1_1_1_1_1_1-0,6,9359,4112,A549_IFNB,A6,A4,F3,3.835880,A549,IFNB,15.0,15.0,sample_1,Rep1,HES4g3,HES4,0.121449,HES4,A549,0
06_28_67_1_1_1_1_1_1_1_1_1-0,6,8999,3854,A549_IFNB,A6,C4,F7,9.189910,A549,IFNB,15.0,15.0,sample_1,Rep1,NTg8,NT,0.000000,ctrl,A549,0
06_27_93_1_1_1_1_1_1_1_1_1-0,6,8384,3600,A549_IFNB,A6,C3,H9,3.268130,A549,IFNB,15.0,15.0,sample_1,Rep1,STAT5Ag1,STAT5A,0.377627,STAT5A,A549,0
06_81_38_1_1_1_1_1_1_1_1_1-0,6,7925,3580,A549_IFNB,A6,G9,D2,3.798107,A549,IFNB,15.0,15.0,sample_1,Rep1,STAT4g2,STAT4,1.000000,STAT4,A549,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11_60_02_2_2-1,11,3379,2005,MCF7_IFNB,A11,E12,A2,5.622965,MCF7,IFNB,,,sample_16,Rep2,NTg4,NT,0.000000,ctrl,MCF7,1
12_37_18_2_2-1,12,3195,1809,MCF7_IFNB,A12,D1,B6,4.538341,MCF7,IFNB,,,sample_16,Rep2,NTg13,NT,0.000000,ctrl,MCF7,1
10_31_02_2_2-1,10,3130,1736,MCF7_IFNB,A10,C7,A2,4.984026,MCF7,IFNB,,,sample_16,Rep2,NTg13,NT,0.000000,ctrl,MCF7,1
12_18_26_2_2-1,12,3017,1892,MCF7_IFNB,A12,B6,C2,2.850514,MCF7,IFNB,,,sample_16,Rep2,NTg6,NT,0.000000,ctrl,MCF7,1


In [29]:

adata_test = adata[(adata.obs[PERT_KEY] == CONTROL_PERT) & (adata.obs[CELL_KEY] == 'A549')]

adata_test.obs

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,sample,bc1_well,bc2_well,bc3_well,percent.mito,cell_type,pathway,RNA_snn_res.0.9,seurat_clusters,sample_ID,Batch_info,guide,gene,mixscale_score,pert,cell,batch
06_28_67_1_1_1_1_1_1_1_1_1-0,6,8999,3854,A549_IFNB,A6,C4,F7,9.189910,A549,IFNB,15.0,15.0,sample_1,Rep1,NTg8,NT,0.0,ctrl,A549,0
05_69_22_1_1_1_1_1_1_1_1_1-0,5,6291,3106,A549_IFNB,A5,F9,B10,3.163249,A549,IFNB,15.0,15.0,sample_1,Rep1,NTg9,NT,0.0,ctrl,A549,0
08_42_41_1_1_1_1_1_1_1_1_1-0,8,5107,2567,A549_IFNB,A8,D6,D5,7.538672,A549,IFNB,15.0,15.0,sample_1,Rep1,NTg11,NT,0.0,ctrl,A549,0
06_01_93_1_1_1_1_1_1_1_1_1-0,6,4723,2552,A549_IFNB,A6,A1,H9,3.726445,A549,IFNB,15.0,15.0,sample_1,Rep1,NTg1,NT,0.0,ctrl,A549,0
07_27_61_1_1_1_1_1_1_1_1_1-0,7,4683,2617,A549_IFNB,A7,C3,F1,1.281230,A549,IFNB,15.0,15.0,sample_1,Rep1,NTg11,NT,0.0,ctrl,A549,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
05_15_40_2_2-1,5,3747,2112,A549_IFNB,A5,B3,D4,0.693888,A549,IFNB,,,sample_16,Rep2,NTg1,NT,0.0,ctrl,A549,1
07_88_53_2_2-1,7,3580,2168,A549_IFNB,A7,H4,E5,0.865922,A549,IFNB,,,sample_16,Rep2,NTg9,NT,0.0,ctrl,A549,1
05_35_76_2_2-1,5,3476,2083,A549_IFNB,A5,C11,G4,1.323360,A549,IFNB,,,sample_16,Rep2,NTg14,NT,0.0,ctrl,A549,1
05_70_76_2_2-1,5,3367,2015,A549_IFNB,A5,F10,G4,4.959905,A549,IFNB,,,sample_16,Rep2,NTg9,NT,0.0,ctrl,A549,1


In [30]:

preds = model.make_predict(adata_test, 'IFNAR1', 'A549')

NotImplementedError: Both cell type and perturbation are in the training data, in distribution prediction not implemented yet