In [3]:
import scanpy as sc
import argparse
from evaluate import AnndataProcessor
from accelerate import Accelerator
import glob
import numpy as np

default_arguments = {
    # Anndata Processing Arguments
    'adata_path': '',
    'dir': "./",
    'species': "human",
    'filter': True,
    'skip': True,

    # Model Arguments
    'model_loc': None,
    'batch_size': 25,
    'pad_length': 1536,
    'pad_token_idx': 0,
    'chrom_token_left_idx': 1,
    'chrom_token_right_idx': 2,
    'cls_token_idx': 3,
    'CHROM_TOKEN_OFFSET': 143574,
    'sample_size': 1024,
    'CXG': True,
    'nlayers': 4,
    'output_dim': 1280,
    'd_hid': 5120,
    'token_dim': 5120,
    'multi_gpu': False,

    # Misc Arguments
    'spec_chrom_csv_path': "./model_files/species_chrom.csv",
    'token_file': "./model_files/all_tokens.torch",
    'protein_embeddings_dir': "./model_files/protein_embeddings/",
    'offset_pkl_path': "./model_files/species_offsets.pkl"
}
args = argparse.Namespace(**default_arguments)

accelerator = Accelerator(project_dir=args.dir)


In [None]:
for f in glob.glob('/dfs/project/virtual-cell-data/train_data/perturb/cleaned/*/*/*.h5ad'):
    if 'scaling' not in f and 'replogle' not in f and 'preprocessed' not in f:
        adata_name = '_'.join(f.split('/')[-3:]).split('_adata.h5ad')[0]
        args.adata_path = f
        processor = AnndataProcessor(args, accelerator, name = adata_name)
        processor.preprocess_anndata()
        processor.generate_idxs()


Using sample 4 layer model
Proccessing dixit_K562_33
/dfs/project/virtual-cell-data/train_data/perturb/cleaned/dixit/K562//33_adata.h5ad
3747.0
dixit_K562_33 (51898, 12953)
Wrote Shapes Dict
12952
Max Code: 613
Proccessing peipei_CD8+T_33
/dfs/project/virtual-cell-data/train_data/perturb/cleaned/peipei/CD8+T//33_adata.h5ad
8028.0
peipei_CD8+T_33 (116676, 13936)
Wrote Shapes Dict
13936
Max Code: 613
Proccessing ursu_A549_33_TPS3
/dfs/project/virtual-cell-data/train_data/perturb/cleaned/ursu/A549//33_TPS3_adata.h5ad
4757.0
ursu_A549_33_TPS3 (12702, 10854)
Wrote Shapes Dict
10854
Max Code: 613
Proccessing ursu_A549_33_KRAS
/dfs/project/virtual-cell-data/train_data/perturb/cleaned/ursu/A549//33_KRAS_adata.h5ad
3540.0
ursu_A549_33_KRAS (11633, 10774)
Wrote Shapes Dict
10774
Max Code: 613
Proccessing frangieh_TIL_33_perturb_CITE
/dfs/project/virtual-cell-data/train_data/perturb/cleaned/frangieh/TIL//33_perturb_CITE_adata.h5ad
4437.0
frangieh_TIL_33_perturb_CITE (218331, 16236)


### Replogle essentials

In [4]:
##replogle_datasets = ['/dfs/project/perturb-gnn/datasets/replogle2022_unpublished/jurkat_raw_singlecell_01_varnames_mod.h5ad',
##                     '/dfs/project/perturb-gnn/datasets/replogle2022_unpublished/hepg2_raw_singlecell_01_varnames_mod_fix.h5ad']


##replogle_datasets = ['/dfs/project/perturb-gnn/datasets/replogle/anndata/rpe1_raw_singlecell_01_varnames_mod.h5ad',
##                     '/dfs/project/perturb-gnn/datasets/replogle/anndata/K562_essential_raw_singlecell_01_varnames_mod.h5ad']



In [5]:
for f in replogle_datasets:
    args.adata_path = f
    processor = AnndataProcessor(args, accelerator)
    processor.preprocess_anndata()
    processor.generate_idxs()


Using sample 4 layer model
Proccessing rpe1_raw_singlecell_01_varnames_mod
/dfs/project/perturb-gnn/datasets/replogle/anndata//rpe1_raw_singlecell_01_varnames_mod.h5ad
3061.0
rpe1_raw_singlecell_01_varnames_mod (247914, 8427)
Wrote Shapes Dict
8426
Max Code: 612
Proccessing K562_essential_raw_singlecell_01_varnames_mod
/dfs/project/perturb-gnn/datasets/replogle/anndata//K562_essential_raw_singlecell_01_varnames_mod.h5ad
2277.0
K562_essential_raw_singlecell_01_varnames_mod (310385, 8105)
Wrote Shapes Dict
8103
Max Code: 612
Proccessing K562_gwps_raw_singlecell_01
/dfs/project/perturb-gnn/datasets/replogle/anndata//K562_gwps_raw_singlecell_01.h5ad


RuntimeError: stack expects a non-empty TensorList

### Replogle genome-wide

In [9]:
replogle_datasets = ['/dfs/project/virtual-cell-data/train_data/perturb/cleaned/replogle/K562/33_part1_adata.h5ad',
                    '/dfs/project/virtual-cell-data/train_data/perturb/cleaned/replogle/K562/33_part2_adata.h5ad'
                    ]

In [None]:
for f in replogle_datasets:
    adata_name = '_'.join(f.split('/')[-3:]).split('_adata.h5ad')[0]
    args.adata_path = f
    processor = AnndataProcessor(args, accelerator, name = adata_name)
    processor.preprocess_anndata()
    processor.generate_idxs()


Proccessing replogle_K562_33_part1
/dfs/project/virtual-cell-data/train_data/perturb/cleaned/replogle/K562//33_part1_adata.h5ad


### McFaline

In [4]:
f = '/dfs/project/virtual-cell-data/train_data/perturb/mcfaline-figueroa/mcfaline23_gxe_processed.h5ad'

In [5]:
args.adata_path = f
processor = AnndataProcessor(args, accelerator)
processor.preprocess_anndata()
processor.generate_idxs()

Using sample 4 layer model
Proccessing mcfaline23_gxe_processed
/dfs/project/virtual-cell-data/train_data/perturb/mcfaline-figueroa//mcfaline23_gxe_processed.h5ad
17362.0
mcfaline23_gxe_processed (878226, 17809)
Wrote Shapes Dict
PE Idx, Chrom and Starts files already created


## Setting up training data

In [24]:
import glob
import pickle
import torch

shape_dicts_files = glob.glob('/dfs/user/yhr/Arc/uce_edit/*_shapes_dict.pkl')
chroms_files = glob.glob('/dfs/user/yhr/Arc/uce_edit/*_chroms.pkl')
starts_files = glob.glob('/dfs/user/yhr/Arc/uce_edit/*_starts.pkl')
pe_idxs_files = glob.glob('/dfs/user/yhr/Arc/uce_edit/*_pe_idx.torch')

shape_dicts = {}
chroms = {}
starts = {}
pe_idxs = {}

for f in shape_dicts_files:
    with open(f, 'rb') as file:
        read_in = pickle.load(file)
    shape_dicts.update(read_in)
    
for f in chroms_files:
    with open(f, 'rb') as file:
        read_in = pickle.load(file)
    chroms.update(read_in)
    
for f in starts_files:
    with open(f, 'rb') as file:
        read_in = pickle.load(file)
    starts.update(read_in)
    
for f in pe_idxs_files:
    read_in = torch.load(f)
    pe_idxs.update(read_in)
    

In [20]:
## Total number of additional training cells

sum([x[0] for x in shape_dicts.values()])

6214929

In [77]:
with open('pert_shapes_dict.pkl', 'wb') as f:
    pickle.dump(shape_dicts, f)
    
with open('pert_chroms.pkl', 'wb') as f:
    pickle.dump(chroms, f)
    
with open('pert_starts.pkl', 'wb') as f:
    pickle.dump(starts, f)
    
torch.save(pe_idxs, 'pert_pe_idx.torch')

In [54]:
df = pd.DataFrame.from_dict(shape_dicts, orient='index').reset_index()
df = df.rename(columns={0:'num_cells', 1:'num_genes', 'index':'names'})
df['path'] = [x +'.h5ad' for x in df['names']]
df['census'] = 'no'
df['species'] = 'human'
df['species'] = df['names'].map(species_map)

In [62]:
## Add counts data to dataset_file

import pandas as pd
all_train_datasets = pd.read_csv('/dfs/project/uce/response/full_train_datasets_pert.csv',  index_col = 0)
all_train_datasets = pd.concat([df, all_train_datasets])
all_train_datasets.to_csv('/dfs/project/uce/response/full_train_datasets_pert.csv')

In [73]:
## Add other params: starts

with open('/lfs/local/0/yhr/dataset_to_chroms_new.pkl', 'rb') as f:
    loaded = pickle.load(f)
    loaded.update(chroms)
    
with open('/lfs/local/0/yhr/dataset_to_chroms_new.pkl', 'wb') as f:
    pickle.dump(loaded, f)    

In [75]:
## Add other params: chroms

with open('/lfs/local/0/yhr/dataset_to_starts_new.pkl', 'rb') as f:
    loaded = pickle.load(f)
    loaded.update(starts)
    
with open('/lfs/local/0/yhr/dataset_to_starts_new.pkl', 'wb') as f:
    pickle.dump(loaded, f)    

In [84]:
## Add other params: protein idxs

loaded = torch.load('/lfs/local/0/yhr/reduced_datasets_to_pe_chrom_5120_new.torch')
loaded.update(pe_idxs)

torch.save(loaded, '/lfs/local/0/yhr/reduced_datasets_to_pe_chrom_5120_new.torch')

In [88]:
np.sort(list(loaded.keys()))

array(['00476f9f-ebc1-4b72-b541-32f912ce36ea',
       '01209dce-3575-4bed-b1df-129f57fbc031',
       '0129dbd9-a7d3-4f6b-96b9-1da155a93748',
       '019c7af2-c827-4454-9970-44d5e39ce068',
       '01ad3cd7-3929-4654-84c0-6db05bd5fd59',
       '03d38670-1444-4001-bc53-9936e61d9b20',
       '047d57f2-4d14-45de-aa98-336c6f583750',
       '07854d9c-5375-4a9b-ac34-fa919d3c3686',
       '07b1d7c8-5c2e-42f7-9246-26f746cd6013',
       '090da8ea-46e8-40df-bffc-1f78e1538d27',
       '095940cb-7422-4510-96e2-cbafd961eb88',
       '0b4a15a7-4e9e-4555-9733-2423e5c66469',
       '0b75c598-0893-4216-afe8-5414cab7739d',
       '0ba636a1-4754-4786-a8be-7ab3cf760fd6',
       '0c774045-26a7-40f8-9b07-6742d3c771c0',
       '0ee5ae70-c3f5-473f-bd1c-287f4690ffc5',
       '0f4865d5-8000-4f68-8ac7-f5efea9e5e70',
       '1185a7d3-a9c1-4280-9ba5-d61895b15cac',
       '11ff73e8-d3e4-4445-9309-477a2c5be6f6',
       '124744b8-4681-474a-9894-683896122708',
       '12967895-3d58-4e93-be2c-4e1bcf4388d5',
       '182f6

In [89]:
loaded['K562_essential_raw_singlecell_01_varnames_mod']

tensor([24381, 21976, 20529,  ..., 23680, 23673, 22831])