In [117]:
def rank_genes_groups_by_cov(
    adata,
    groupby,
    control_group,
    covariate,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    gene_dict = {}
    cov_categories = adata.obs[covariate].unique()
    for cov_cat in cov_categories:
        # name of the control group in the groupby obs column
        control_group_cov = control_group  # "_".join([cov_cat, control_group])
        # subset adata to cells belonging to a covariate category
        adata_cov = adata[adata.obs[covariate] == cov_cat]
        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # add entries to dictionary of gene sets
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata):
    adata.obs.loc[:, "control"] = adata.obs.condition.apply(lambda x: 1 if x == "control" else 0)
    adata.obs = adata.obs.astype("category")
    rank_genes_groups_by_cov(
        adata,
        groupby="condition",
        covariate="cell_line",
        control_group="ctrl",
        n_genes=50,
        key_added="rank_genes_groups_cov_all",
    )
    return adata

### First split.

In [118]:
import scanpy as sc
import pickle
import os
import anndata
import numpy as np
import pandas as pd

output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman"

In [119]:
train = sc.read_h5ad(os.path.join(output_dir, "adata_train_2_seen_genes.h5ad"))
ood = sc.read_h5ad(os.path.join(output_dir, "adata_ood_2_seen_genes.h5ad"))
test = sc.read_h5ad(os.path.join(output_dir, "adata_test_2_seen_genes.h5ad"))

In [120]:
custom_split = {}

In [121]:
train.obs.condition = train.obs.condition.str.replace('control', 'ctrl')
ood.obs.condition = ood.obs.condition.str.replace('control', 'ctrl')
test.obs.condition = test.obs.condition.str.replace('control', 'ctrl')

In [122]:
def modify_condition(condition):
    if '+' not in condition:
        if condition == 'ctrl':
            return condition
        else:
            return condition + '+ctrl'
    else:
        return condition

def modify_condition_name(condition):
    if condition != 'ctrl':
        return 'K562_'+ condition + '_1+1'
    else:
        return 'K562_'+ condition + '+1'

# Apply the function to the condition column
train.obs['condition'] = train.obs['condition'].apply(modify_condition)
ood.obs['condition'] = ood.obs['condition'].apply(modify_condition)
test.obs['condition'] = test.obs['condition'].apply(modify_condition)

train.obs['condition_name'] = train.obs['condition']
train.obs['condition_name'] = train.obs['condition'].apply(modify_condition_name)
ood.obs['condition_name'] = ood.obs['condition']
ood.obs['condition_name'] = ood.obs['condition'].apply(modify_condition_name)
test.obs['condition_name'] = test.obs['condition']
test.obs['condition_name'] = test.obs['condition'].apply(modify_condition_name)

In [123]:
ood.obs.condition

index
AAACCTGAGGCCCTTG-1     KLF1+MAP2K6
AAACCTGGTTCACCTC-1     MAP2K6+SPI1
AAACCTGTCCATTCTA-1     FOXA1+FOXA3
AAACGGGAGCGATTCT-1    IKZF3+MAP2K6
AAACGGGGTAGCAAAT-1      FOXA1+KLF1
                          ...     
TGTATTCTCACAAACC-8            ctrl
TTAGGACAGGCTCATT-8            ctrl
TTCCCAGCACGAAACG-8            ctrl
TTGAACGTCACTTACT-8            ctrl
TTGACTTGTATCAGTC-8            ctrl
Name: condition, Length: 15112, dtype: object

In [124]:
train_list = []
ood_list = []

for pert in train.obs.condition.unique():
    if pert != 'ctrl':
        train_list.append(pert)

for pert in ood.obs.condition.unique():
    if pert != 'ctrl':
        ood_list.append(pert)

custom_split['train'] = train_list
custom_split['test'] = ood_list
custom_split['val'] = [train_list[1]]


In [125]:
custom_split['train'] = [pert for pert in custom_split['train'] if not pert in (['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]
custom_split['test'] = [pert for pert in custom_split['train'] if not pert in (['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]

In [126]:
gears_2seen = anndata.concat([train,ood,test])
gears_2seen.var['gene_name'] = gears_2seen.var.index.values

  utils.warn_names_duplicates("obs")


In [127]:
gears_2seen

AnnData object with n_obs × n_vars = 111755 × 2000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_AHR', 'guide_ARID1A', 'guide_ARRDC3', 'guide_ATL1', 'guide_BAK1', 'guide_BCL2L11', 'guide_BCORL1', 'guide_BPGM', 'guide_C19orf26', 'guide_C3orf72', 'guide_CBFA2T3', 'guide_CBL', 'guide_CDKN1A', 'guide_CDKN1B', 'guide_CDKN1C', 'guide_CEBPA', 'guide_CEBPB', 'guide_CEBPE', 'guide_CELF2', 'guide_CITED1', 'guide_CKS1B', 'guide_CLDN6', 'guide_CNN1', 'guide_CNNM4', 'guide_COL1A1', 'guide_COL2A1', 'guide_CSRNP1', 'guide_DLX2', 'guide_DUSP9', 'guide_EGR1', 'guide_ELMSAN1', 'guide_ETS2', 'guide_FEV', 'guide_FOSB', 'guide_FOXA1', 'guide_FOXA3', 'guide_FOXF1', 'guide_FOXL2', 'guide_FOXO4', 'guide_GLB1L2', 'guide_HES7', 'guide_HK2', 'guide_HNF4A', 'guide_HOXA13', 'guide_HOXB9', 'guide_HOXC13', 'guide_IER5L', 'guide_IGDCC3', 'guide_IKZF3', 'guide_IRF1', 'guide_ISL2', 'guide_JUN', 'guide_KIAA1804', 'guide_KIF18B', 'guide_KIF2C', 'g

In [None]:
get_DE_genes(gears_2seen)

In [129]:
gears_2seen_final = anndata.concat([train,ood])

In [130]:
gears_2seen_final = anndata.concat([train,ood])
gears_2seen_final.uns = gears_2seen.uns
gears_2seen_final.var = gears_2seen.var

In [131]:
gears_2seen_final = gears_2seen_final[~gears_2seen_final.obs.condition.isin(['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]

In [132]:
train.obs

Unnamed: 0_level_0,guide_identity,read_count,UMI_count,coverage,gemgroup,good_coverage,number_of_cells,guide_AHR,guide_ARID1A,guide_ARRDC3,...,perturbation_value,perturbation_unit,gene_1,gene_2,cell_line,ood,is_ood,split,control,condition_name
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGAAGC-1,NegCtrl0_NegCtrl0__NegCtrl0_NegCtrl0,1252,67,18.686567,1,True,2,0,0,0,...,,,control,control,K562,not ood,False,train,1,K562_ctrl+1
AAACCTGCACGAAGCA-1,NegCtrl10_NegCtrl0__NegCtrl10_NegCtrl0,958,39,24.564103,1,True,1,0,0,0,...,,,control,control,K562,not ood,False,train,1,K562_ctrl+1
AAACCTGCAGACGTAG-1,CEBPE_RUNX1T1__CEBPE_RUNX1T1,244,14,17.428571,1,True,1,0,0,0,...,,,CEBPE,RUNX1T1,K562,not ood,False,train,0,K562_CEBPE+RUNX1T1_1+1
AAACCTGCAGCCTTGG-1,MAML2_NegCtrl0__MAML2_NegCtrl0,1525,66,23.106061,1,True,1,0,0,0,...,,,MAML2,control,K562,not ood,False,train,0,K562_MAML2+ctrl_1+1
AAACCTGCATCTCCCA-1,NegCtrl0_CEBPE__NegCtrl0_CEBPE,499,30,16.633333,1,True,1,0,0,0,...,,,CEBPE,control,K562,not ood,False,train,0,K562_CEBPE+ctrl_1+1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTCCATGAT-8,TGFBR2_C19orf26__TGFBR2_C19orf26,1306,66,19.787879,8,True,2,0,0,0,...,,,C19orf26,TGFBR2,K562,not ood,False,train,0,K562_C19orf26+TGFBR2_1+1
TTTGTCATCAGTACGT-8,FOXA3_NegCtrl0__FOXA3_NegCtrl0,2068,95,21.768421,8,True,1,0,0,0,...,,,FOXA3,control,K562,not ood,False,train,0,K562_FOXA3+ctrl_1+1
TTTGTCATCCACTCCA-8,CELF2_NegCtrl0__CELF2_NegCtrl0,829,33,25.121212,8,True,1,0,0,0,...,,,CELF2,control,K562,not ood,False,train,0,K562_CELF2+ctrl_1+1
TTTGTCATCCCAACGG-8,BCORL1_NegCtrl0__BCORL1_NegCtrl0,136,9,15.111111,8,True,1,0,0,0,...,,,BCORL1,control,K562,not ood,False,train,0,K562_BCORL1+ctrl_1+1


In [133]:
len(gears_2seen_final.obs_names)

97652

In [134]:
len(set(gears_2seen_final.obs_names))

97652

In [17]:
gears_2seen_final.write(os.path.join(output_dir, "gears", "2seen", "perturb_processed.h5ad"))


  df[key] = c
  df[key] = c


In [18]:
train.write(os.path.join(output_dir, "gears", "2seen", "train_processed.h5ad"))
ood.write(os.path.join(output_dir, "gears", "2seen", "ood_processed.h5ad"))
test.write(os.path.join(output_dir, "gears", "2seen", "test_processed.h5ad"))


In [19]:
with open(os.path.join(output_dir, "gears", "2seen", "custom_split_2seen.pkl"), 'wb') as fp:
    pickle.dump(custom_split, fp)

In [20]:
train.obs

Unnamed: 0_level_0,guide_identity,read_count,UMI_count,coverage,gemgroup,good_coverage,number_of_cells,guide_AHR,guide_ARID1A,guide_ARRDC3,...,perturbation_value,perturbation_unit,gene_1,gene_2,cell_line,ood,is_ood,split,control,condition_name
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGAAGC-1,NegCtrl0_NegCtrl0__NegCtrl0_NegCtrl0,1252,67,18.686567,1,True,2,0,0,0,...,,,control,control,K562,not ood,False,train,1,K562_ctrl+1
AAACCTGCACGAAGCA-1,NegCtrl10_NegCtrl0__NegCtrl10_NegCtrl0,958,39,24.564103,1,True,1,0,0,0,...,,,control,control,K562,not ood,False,train,1,K562_ctrl+1
AAACCTGCAGACGTAG-1,CEBPE_RUNX1T1__CEBPE_RUNX1T1,244,14,17.428571,1,True,1,0,0,0,...,,,CEBPE,RUNX1T1,K562,not ood,False,train,0,K562_CEBPE+RUNX1T1_1+1
AAACCTGCAGCCTTGG-1,MAML2_NegCtrl0__MAML2_NegCtrl0,1525,66,23.106061,1,True,1,0,0,0,...,,,MAML2,control,K562,not ood,False,train,0,K562_MAML2+ctrl_1+1
AAACCTGCATCTCCCA-1,NegCtrl0_CEBPE__NegCtrl0_CEBPE,499,30,16.633333,1,True,1,0,0,0,...,,,CEBPE,control,K562,not ood,False,train,0,K562_CEBPE+ctrl_1+1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTCCATGAT-8,TGFBR2_C19orf26__TGFBR2_C19orf26,1306,66,19.787879,8,True,2,0,0,0,...,,,C19orf26,TGFBR2,K562,not ood,False,train,0,K562_C19orf26+TGFBR2_1+1
TTTGTCATCAGTACGT-8,FOXA3_NegCtrl0__FOXA3_NegCtrl0,2068,95,21.768421,8,True,1,0,0,0,...,,,FOXA3,control,K562,not ood,False,train,0,K562_FOXA3+ctrl_1+1
TTTGTCATCCACTCCA-8,CELF2_NegCtrl0__CELF2_NegCtrl0,829,33,25.121212,8,True,1,0,0,0,...,,,CELF2,control,K562,not ood,False,train,0,K562_CELF2+ctrl_1+1
TTTGTCATCCCAACGG-8,BCORL1_NegCtrl0__BCORL1_NegCtrl0,136,9,15.111111,8,True,1,0,0,0,...,,,BCORL1,control,K562,not ood,False,train,0,K562_BCORL1+ctrl_1+1


### Second split.

In [151]:
import scanpy as sc
import pickle
import os
import anndata
import numpy as np
import pandas as pd

output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman"

In [152]:
train = sc.read_h5ad(os.path.join(output_dir, "adata_train_1_seen_genes.h5ad"))
ood = sc.read_h5ad(os.path.join(output_dir, "adata_ood_1_seen_genes.h5ad"))
test = sc.read_h5ad(os.path.join(output_dir, "adata_test_1_seen_genes.h5ad"))

In [153]:
custom_split = {}

In [154]:
train.obs.condition = train.obs.condition.str.replace('control', 'ctrl')
ood.obs.condition = ood.obs.condition.str.replace('control', 'ctrl')
test.obs.condition = test.obs.condition.str.replace('control', 'ctrl')

In [155]:
def modify_condition(condition):
    if '+' not in condition:
        if condition == 'ctrl':
            return condition
        else:
            return condition + '+ctrl'
    else:
        return condition

def modify_condition_name(condition):
    if condition != 'ctrl':
        return 'K562_'+ condition + '_1+1'
    else:
        return 'K562_'+ condition + '+1'

# Apply the function to the condition column
train.obs['condition'] = train.obs['condition'].apply(modify_condition)
ood.obs['condition'] = ood.obs['condition'].apply(modify_condition)
test.obs['condition'] = test.obs['condition'].apply(modify_condition)

train.obs['condition_name'] = train.obs['condition']
train.obs['condition_name'] = train.obs['condition'].apply(modify_condition_name)
ood.obs['condition_name'] = ood.obs['condition']
ood.obs['condition_name'] = ood.obs['condition'].apply(modify_condition_name)
test.obs['condition_name'] = test.obs['condition']
test.obs['condition_name'] = test.obs['condition'].apply(modify_condition_name)

In [156]:
ood.obs.condition

index
AAACCTGAGGCCCTTG-1      KLF1+MAP2K6
AAACGGGAGCGATTCT-1     IKZF3+MAP2K6
AAAGATGAGAGTACAT-1    PTPN9+UBASH3B
AAAGATGAGCCTCGTG-1    IGDCC3+ZBTB25
AAAGCAAAGCTAGTCT-1    MAP2K3+MAP2K6
                          ...      
TGTATTCTCACAAACC-8             ctrl
TTAGGACAGGCTCATT-8             ctrl
TTCCCAGCACGAAACG-8             ctrl
TTGAACGTCACTTACT-8             ctrl
TTGACTTGTATCAGTC-8             ctrl
Name: condition, Length: 7772, dtype: object

In [None]:
train.obs.condition_name.unique()

In [158]:
train_list = []
ood_list = []

for pert in train.obs.condition.unique():
    if pert != 'ctrl':
        train_list.append(pert)

for pert in ood.obs.condition.unique():
    if pert != 'ctrl':
        ood_list.append(pert)

custom_split['train'] = train_list
custom_split['test'] = ood_list
custom_split['val'] = [train_list[1]]


In [159]:
custom_split['train'] = [pert for pert in custom_split['train'] if not pert in (['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]
custom_split['test'] = [pert for pert in custom_split['train'] if not pert in (['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]

In [160]:
gears_2seen = anndata.concat([train,ood,test])
gears_2seen.var['gene_name'] = gears_2seen.var.index.values

  utils.warn_names_duplicates("obs")


In [161]:
gears_2seen

AnnData object with n_obs × n_vars = 100988 × 2000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_AHR', 'guide_ARID1A', 'guide_ARRDC3', 'guide_ATL1', 'guide_BAK1', 'guide_BCL2L11', 'guide_BCORL1', 'guide_BPGM', 'guide_C19orf26', 'guide_C3orf72', 'guide_CBFA2T3', 'guide_CBL', 'guide_CDKN1A', 'guide_CDKN1B', 'guide_CDKN1C', 'guide_CEBPA', 'guide_CEBPB', 'guide_CEBPE', 'guide_CELF2', 'guide_CITED1', 'guide_CKS1B', 'guide_CLDN6', 'guide_CNN1', 'guide_CNNM4', 'guide_COL1A1', 'guide_COL2A1', 'guide_CSRNP1', 'guide_DLX2', 'guide_DUSP9', 'guide_EGR1', 'guide_ELMSAN1', 'guide_ETS2', 'guide_FEV', 'guide_FOSB', 'guide_FOXA1', 'guide_FOXA3', 'guide_FOXF1', 'guide_FOXL2', 'guide_FOXO4', 'guide_GLB1L2', 'guide_HES7', 'guide_HK2', 'guide_HNF4A', 'guide_HOXA13', 'guide_HOXB9', 'guide_HOXC13', 'guide_IER5L', 'guide_IGDCC3', 'guide_IKZF3', 'guide_IRF1', 'guide_ISL2', 'guide_JUN', 'guide_KIAA1804', 'guide_KIF18B', 'guide_KIF2C', 'g

In [None]:
get_DE_genes(gears_2seen)

In [163]:
gears_2seen_final = anndata.concat([train,ood])
gears_2seen_final.uns = gears_2seen.uns
gears_2seen_final.var = gears_2seen.var

In [164]:
gears_2seen_final = gears_2seen_final[~gears_2seen_final.obs.condition.isin(['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]

In [165]:
train.obs

Unnamed: 0_level_0,guide_identity,read_count,UMI_count,coverage,gemgroup,good_coverage,number_of_cells,guide_AHR,guide_ARID1A,guide_ARRDC3,...,perturbation_value,perturbation_unit,gene_1,gene_2,cell_line,ood,is_ood,split,control,condition_name
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGAAGC-1,NegCtrl0_NegCtrl0__NegCtrl0_NegCtrl0,1252,67,18.686567,1,True,2,0,0,0,...,,,control,control,K562,not ood,False,train,1,K562_ctrl+1
AAACCTGCACGAAGCA-1,NegCtrl10_NegCtrl0__NegCtrl10_NegCtrl0,958,39,24.564103,1,True,1,0,0,0,...,,,control,control,K562,not ood,False,train,1,K562_ctrl+1
AAACCTGCAGACGTAG-1,CEBPE_RUNX1T1__CEBPE_RUNX1T1,244,14,17.428571,1,True,1,0,0,0,...,,,CEBPE,RUNX1T1,K562,not ood,False,train,0,K562_CEBPE+RUNX1T1_1+1
AAACCTGCAGCCTTGG-1,MAML2_NegCtrl0__MAML2_NegCtrl0,1525,66,23.106061,1,True,1,0,0,0,...,,,MAML2,control,K562,not ood,False,train,0,K562_MAML2+ctrl_1+1
AAACCTGCATTACCTT-1,ETS2_MAP7D1__ETS2_MAP7D1,4,1,4.000000,1,False,0,0,0,0,...,,,ETS2,MAP7D1,K562,not ood,False,train,0,K562_ETS2+MAP7D1_1+1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTCATGCAT-8,RHOXF2_SET__RHOXF2_SET,1138,45,25.288889,8,True,1,0,0,0,...,,,RHOXF2,SET,K562,not ood,False,train,0,K562_RHOXF2+SET_1+1
TTTGTCATCCACTCCA-8,CELF2_NegCtrl0__CELF2_NegCtrl0,829,33,25.121212,8,True,1,0,0,0,...,,,CELF2,control,K562,not ood,False,train,0,K562_CELF2+ctrl_1+1
TTTGTCATCCCAACGG-8,BCORL1_NegCtrl0__BCORL1_NegCtrl0,136,9,15.111111,8,True,1,0,0,0,...,,,BCORL1,control,K562,not ood,False,train,0,K562_BCORL1+ctrl_1+1
TTTGTCATCCTCCTAG-8,ZBTB10_PTPN12__ZBTB10_PTPN12,1254,59,21.254237,8,True,3,0,0,0,...,,,PTPN12,ZBTB10,K562,not ood,False,train,0,K562_PTPN12+ZBTB10_1+1


In [166]:
gears_2seen_final

View of AnnData object with n_obs × n_vars = 86985 × 2000
    obs: 'guide_identity', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'guide_AHR', 'guide_ARID1A', 'guide_ARRDC3', 'guide_ATL1', 'guide_BAK1', 'guide_BCL2L11', 'guide_BCORL1', 'guide_BPGM', 'guide_C19orf26', 'guide_C3orf72', 'guide_CBFA2T3', 'guide_CBL', 'guide_CDKN1A', 'guide_CDKN1B', 'guide_CDKN1C', 'guide_CEBPA', 'guide_CEBPB', 'guide_CEBPE', 'guide_CELF2', 'guide_CITED1', 'guide_CKS1B', 'guide_CLDN6', 'guide_CNN1', 'guide_CNNM4', 'guide_COL1A1', 'guide_COL2A1', 'guide_CSRNP1', 'guide_DLX2', 'guide_DUSP9', 'guide_EGR1', 'guide_ELMSAN1', 'guide_ETS2', 'guide_FEV', 'guide_FOSB', 'guide_FOXA1', 'guide_FOXA3', 'guide_FOXF1', 'guide_FOXL2', 'guide_FOXO4', 'guide_GLB1L2', 'guide_HES7', 'guide_HK2', 'guide_HNF4A', 'guide_HOXA13', 'guide_HOXB9', 'guide_HOXC13', 'guide_IER5L', 'guide_IGDCC3', 'guide_IKZF3', 'guide_IRF1', 'guide_ISL2', 'guide_JUN', 'guide_KIAA1804', 'guide_KIF18B', 'guide_KIF

In [167]:
gears_2seen_final.write(os.path.join(output_dir, "gears", "1seen", "perturb_processed.h5ad"))


  df[key] = c
  df[key] = c


In [168]:
train.write(os.path.join(output_dir, "gears", "1seen", "train_processed.h5ad"))
ood.write(os.path.join(output_dir, "gears", "1seen", "ood_processed.h5ad"))
test.write(os.path.join(output_dir, "gears", "1seen", "test_processed.h5ad"))


In [169]:
with open(os.path.join(output_dir, "gears", "1seen", "custom_split_1seen.pkl"), 'wb') as fp:
    pickle.dump(custom_split, fp)

In [170]:
len(gears_2seen_final.obs_names)

86985

### Third split.

In [172]:
import scanpy as sc
import pickle
import os
import anndata
import numpy as np
import pandas as pd

output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman"

In [173]:
train = sc.read_h5ad(os.path.join(output_dir, "adata_train_0_seen_genes.h5ad"))
ood = sc.read_h5ad(os.path.join(output_dir, "adata_ood_0_seen_genes.h5ad"))
test = sc.read_h5ad(os.path.join(output_dir, "adata_test_0_seen_genes.h5ad"))

In [174]:
custom_split = {}

In [175]:
train.obs.condition = train.obs.condition.str.replace('control', 'ctrl')
ood.obs.condition = ood.obs.condition.str.replace('control', 'ctrl')
test.obs.condition = test.obs.condition.str.replace('control', 'ctrl')

In [176]:
def modify_condition(condition):
    if '+' not in condition:
        if condition == 'ctrl':
            return condition
        else:
            return condition + '+ctrl'
    else:
        return condition

def modify_condition_name(condition):
    if condition != 'ctrl':
        return 'K562_'+ condition + '_1+1'
    else:
        return 'K562_'+ condition + '+1'

# Apply the function to the condition column
train.obs['condition'] = train.obs['condition'].apply(modify_condition)
ood.obs['condition'] = ood.obs['condition'].apply(modify_condition)
test.obs['condition'] = test.obs['condition'].apply(modify_condition)

train.obs['condition_name'] = train.obs['condition']
train.obs['condition_name'] = train.obs['condition'].apply(modify_condition_name)
ood.obs['condition_name'] = ood.obs['condition']
ood.obs['condition_name'] = ood.obs['condition'].apply(modify_condition_name)
test.obs['condition_name'] = test.obs['condition']
test.obs['condition_name'] = test.obs['condition'].apply(modify_condition_name)

In [177]:
ood.obs.condition

index
AAACCTGGTTCACCTC-1    MAP2K6+SPI1
AAACGGGGTAGCAAAT-1     FOXA1+KLF1
AAAGATGTCCACGAAT-1    BAK1+TMSB4X
AAAGCAAAGGCGATAC-1     MAPK1+PRTG
AAAGCAAGTCTCTCTG-1    MAP2K6+SPI1
                         ...     
TGTATTCTCACAAACC-8           ctrl
TTAGGACAGGCTCATT-8           ctrl
TTCCCAGCACGAAACG-8           ctrl
TTGAACGTCACTTACT-8           ctrl
TTGACTTGTATCAGTC-8           ctrl
Name: condition, Length: 4987, dtype: object

In [179]:
train_list = []
ood_list = []

for pert in train.obs.condition.unique():
    if pert != 'ctrl':
        train_list.append(pert)

for pert in ood.obs.condition.unique():
    if pert != 'ctrl':
        ood_list.append(pert)

custom_split['train'] = train_list
custom_split['test'] = ood_list
custom_split['val'] = [train_list[1]]


In [180]:
custom_split['train'] = [pert for pert in custom_split['train'] if not pert in (['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]
custom_split['test'] = [pert for pert in custom_split['train'] if not pert in (['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]

In [181]:
gears_2seen = anndata.concat([train,ood,test])
gears_2seen.var['gene_name'] = gears_2seen.var.index.values

  utils.warn_names_duplicates("obs")


In [None]:
get_DE_genes(gears_2seen)

In [183]:
gears_2seen_final = anndata.concat([train,ood])
gears_2seen_final.uns = gears_2seen.uns
gears_2seen_final.var = gears_2seen.var

In [184]:
gears_2seen_final = gears_2seen_final[~gears_2seen_final.obs.condition.isin(['KIAA1804+ctrl', 'IER5L+LYL1', 'IER5L+ctrl'])]

In [185]:
gears_2seen_final.write(os.path.join(output_dir, "gears", "0seen", "perturb_processed.h5ad"))


  df[key] = c
  df[key] = c


In [188]:
train.write(os.path.join(output_dir, "gears", "0seen", "train_processed.h5ad"))
ood.write(os.path.join(output_dir, "gears", "0seen", "ood_processed.h5ad"))
test.write(os.path.join(output_dir, "gears", "0seen", "test_processed.h5ad"))


In [189]:
with open(os.path.join(output_dir, "gears", "0seen", "custom_split_0seen.pkl"), 'wb') as fp:
    pickle.dump(custom_split, fp)