In [1]:
import anndata
import scanpy as sc
import numpy as np
import os

import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# Location of the preprocessed input AnnData. Set the CHEMCPA_DATA_DIR environment
# variable to point at another directory instead of editing this absolute path.
file_dir = os.environ.get('CHEMCPA_DATA_DIR', '/home/manu/chemCPA/chemCPA')
file_name = 'Norman_pp.h5ad'

In [3]:
# Load the preprocessed dataset from file_dir/file_name.
# TODO: potentially load several AnnData objects here and concatenate them;
#       decide how to handle differing `adata.var` tables before doing so.
adata = sc.read_h5ad(filename=os.path.join(file_dir, file_name))

## Assign cell line ids

In [4]:
# Give every distinct cell line a 0-based integer id, in alphabetical order,
# and record each cell's id in obs['cell_line_id'].
cell_lines = sorted(set(adata.obs['cell_line']))
D_cell_lines = {line_name: line_idx for line_idx, line_name in enumerate(cell_lines)}
adata.obs['cell_line_id'] = [D_cell_lines[line_name] for line_name in adata.obs['cell_line']]

## Assign treatment ids

In [None]:
# TODO: Filter out drugs/genes for which no SMILES string / ESM embedding is available

In [5]:
# Gene knockouts and drug treatments get separate ID numberings, since their
# embeddings (SMILES-based for drugs, ESM for genes) will be stored in different matrices.

In [6]:
# Split the per-cell metadata by treatment type; each part is indexed separately below.
df0 = adata.obs.loc[adata.obs['treatment_type'] == 'drug_perturbation']  # drug treatments
df1 = adata.obs.loc[adata.obs['treatment_type'] == 'gene_knockout']      # gene knockouts

In [7]:
# Index the treatments by alphabetical ordering, where 0 corresponds to control
def find_indexes(df, column='treatment', sep='+', control='control'):
    """Map every single treatment appearing in ``df[column]`` to a string id.

    Combination treatments such as ``'A+B'`` are split on ``sep`` so that each
    component gets its own id. Non-control components are numbered from 1 in
    alphabetical order; ``control`` is always mapped to ``'0'`` (even if it does
    not occur in ``df``) and is inserted last.

    Args:
        df: Table (e.g. a subset of ``adata.obs``) with a column of treatment names.
        column: Name of the column holding the treatment names.
        sep: Separator used to encode combination treatments.
        control: Name of the control treatment, which receives id ``'0'``.

    Returns:
        dict mapping treatment name -> id as a string ('0', '1', '2', ...).
    """
    # str.split returns [name] when sep is absent, so single treatments need no special case.
    singles = {part for combo in df[column] for part in combo.split(sep)}
    singles.discard(control)

    D_treatments = {name: str(i) for i, name in enumerate(sorted(singles), start=1)}
    D_treatments[control] = '0'
    return D_treatments

In [8]:
# One independent name -> id mapping per treatment type, so drug ids and gene ids
# can index separate embedding matrices.
D_treatments = {
    'drug_perturbation': find_indexes(df0),
    'gene_knockout': find_indexes(df1),
}

In [11]:
# Encode each cell's treatment as its id(s) under the mapping for its treatment type.
# Combinations keep their written order, e.g. 'FOXA1+FOXA3' -> '35+36'.
treatment_types = list(adata.obs['treatment_type'])
treatments = list(adata.obs['treatment'])
treatment_ids = [
    '+'.join(D_treatments[t_type][single] for single in combo.split('+'))
    for t_type, combo in zip(treatment_types, treatments)
]
adata.obs['treatment_id'] = treatment_ids

In [12]:
# Keep the per-type treatment-name -> id mappings in .uns so they are written out
# together with the AnnData, and obs['treatment_id'] can be decoded back to names.
adata.uns['D_treatments']=D_treatments

In [13]:
# Inspect the encoded ids; combination treatments show up as '+'-joined ids (e.g. '56+62').
adata.obs.loc[:, 'treatment_id']

index
AAACCTGAGAAGAAGC-1         0
AAACCTGAGGCATGTG-1        98
AAACCTGAGGCCCTTG-1     56+62
AAACCTGCACGAAGCA-1         0
AAACCTGCAGACGTAG-1     18+82
                       ...  
TTTGTCATCAGTACGT-8        36
TTTGTCATCCACTCCA-8        19
TTTGTCATCCCAACGG-8         7
TTTGTCATCCTCCTAG-8    77+102
TTTGTCATCTGGCGAC-8        63
Name: treatment_id, Length: 111255, dtype: object

In [14]:
# For now only a subsample of the AnnData is used, to set everything up more efficiently.
# Set SUBSAMPLE_FRACTION = None to keep the full dataset.
# NOTE: sc.pp.subsample modifies `adata` in place (no copy=True), so every later cell,
# including the write below, sees only the subsample.
SUBSAMPLE_FRACTION = 0.01
SUBSAMPLE_SEED = 0  # same as scanpy's default random_state, pinned explicitly for reproducibility
if SUBSAMPLE_FRACTION is not None:
    sc.pp.subsample(adata, fraction=SUBSAMPLE_FRACTION, random_state=SUBSAMPLE_SEED)

In [15]:
# Save the processed AnnData next to the input file (path built the same way as when loading).
# When the cells are run in order, this is the subsampled object, not the full dataset.
adata.write(os.path.join(file_dir, 'Comb.h5ad'))

In [16]:
# Inspect the per-cell metadata, which now includes cell_line_id and treatment_id
# (when run in order, only the subsampled cells remain).
adata.obs

Unnamed: 0_level_0,cell_line,treatment,treatment_dose_uM,treatment_time_h,treatment_type,cell_line_id,treatment_id
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
TGACGGCAGACCTAGG-6,K562,FOXA1+FOXA3,,168.0,gene_knockout,0,35+36
AGGTCATTCATAAAGG-2,K562,CBL,,168.0,gene_knockout,0,12
CTGCCTAGTTATTCTC-8,K562,CNN1,,168.0,gene_knockout,0,23
TAAGCGTCACCATCCT-4,K562,BCORL1,,168.0,gene_knockout,0,7
CCAATCCAGTCCCACG-6,K562,CLDN6+KLF1,,168.0,gene_knockout,0,22+56
...,...,...,...,...,...,...,...
ATGTGTGAGAATTGTG-2,K562,CBL+PTPN12,,168.0,gene_knockout,0,12+77
TAGACCAGTGGCGAAT-2,K562,BAK1,,168.0,gene_knockout,0,5
GCATGTAAGAGTAATC-8,K562,control,,168.0,gene_knockout,0,0
GACCAATCACAACGCC-4,K562,BCORL1,,168.0,gene_knockout,0,7
