In [66]:
import anndata
import scanpy as sc
import numpy as np
import os
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

In [67]:
# Directory holding the .h5ad data and the embedding-matrix pickles; override it
# with the CHEMCPA_DATA_DIR environment variable instead of editing this path.
file_dir = os.environ.get('CHEMCPA_DATA_DIR', '/home/manu/chemCPA/chemCPA')

## Construct the dataset and data loader

In [68]:
class GenePerturbationDataloader(Dataset):
    """Map-style torch ``Dataset`` over a single-cell perturbation AnnData.

    Despite the name this is a ``Dataset``; wrap it in a ``DataLoader`` to batch.
    Each item is the tuple
    ``(expression, cell_line_id, treatment_type, treatment_id, treatment_dose_uM)``:

    - ``expression``: one row of ``adata.X`` as a dense numpy array.
    - ``cell_line_id``: tensor scalar taken from ``adata.obs['cell_line_id']``.
    - ``treatment_type``: value of ``adata.obs['treatment_type']``
      (e.g. ``'gene_knockout'``).
    - ``treatment_id``: value of ``adata.obs['treatment_id']``. Combined
      treatments are ``'+'``-joined strings such as ``'32+48'``; they are kept as
      strings so the default collate function does not try to stack
      variable-length tensors.
    - ``treatment_dose_uM``: float32 tensor scalar (may be NaN).
    """

    def __init__(self, file_dir=None, file_name=None, adata=None):
        """Build the dataset from ``file_dir/file_name`` or an in-memory AnnData.

        Args:
            file_dir: Directory containing the ``.h5ad`` file.
            file_name: Name of the ``.h5ad`` file inside ``file_dir``.
            adata: Optional already-loaded AnnData (anything exposing ``.X`` and
                a pandas ``.obs``). When given, ``file_dir``/``file_name`` are
                ignored and nothing is read from disk.

        Raises:
            ValueError: if neither ``adata`` nor both ``file_dir`` and
                ``file_name`` are provided.
        """
        if adata is None:
            if file_dir is None or file_name is None:
                raise ValueError("Pass either `adata` or both `file_dir` and `file_name`.")
            adata = sc.read_h5ad(os.path.join(file_dir, file_name))

        obs = adata.obs
        self.X = self._to_dense(adata.X)
        # Convert Series explicitly: torch.tensor on a Series relies on its
        # sequence protocol, which is label-indexed in pandas.
        self.cell_line_id = torch.tensor(obs['cell_line_id'].to_numpy())
        self.treatment_type = list(obs['treatment_type'])
        self.treatment_ids = list(obs['treatment_id'])
        self.treatment_dose_uM = torch.tensor(
            obs['treatment_dose_uM'].to_numpy(dtype=np.float32), dtype=torch.float
        )
        #self.treatment_time_h = list(adata.obs['treatment_time_h'])

    @staticmethod
    def _to_dense(X):
        """Return ``X`` as a dense numpy array, whether it is scipy-sparse or dense."""
        # `.A` is not defined on plain ndarrays (dense X would raise
        # AttributeError); `toarray()` is the documented sparse-to-dense API.
        if hasattr(X, 'toarray'):
            return X.toarray()
        return np.asarray(X)

    def __len__(self):
        """Number of cells (rows of ``adata``)."""
        return len(self.cell_line_id)

    def __getitem__(self, idx):
        """Return the per-cell tuple described in the class docstring."""
        return(self.X[idx,:], self.cell_line_id[idx], self.treatment_type[idx],
               self.treatment_ids[idx], self.treatment_dose_uM[idx])

In [69]:
# Reuse the configured `file_dir` rather than repeating the absolute path here.
GPD = GenePerturbationDataloader(file_dir=file_dir, file_name='Comb.h5ad')

In [70]:
# Seeded generator so the shuffle order (and therefore the batch shown below)
# is reproducible under Restart & Run All.
SHUFFLE_SEED = 0
train_dataloader = DataLoader(
    GPD, batch_size=5, shuffle=True, generator=torch.Generator().manual_seed(SHUFFLE_SEED)
)

In [72]:
# Draw a single batch to inspect what the default collate function produces.
batch_iterator = iter(train_dataloader)
one_batch = next(batch_iterator)
one_batch

# expression, cell_line_id, treatment_type, treatment_id, treatment_dose_uM

[tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4922, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.6071,  ..., 0.0000, 0.6071, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.8616, 0.5209]]),
 tensor([0, 0, 0, 0, 0]),
 ('gene_knockout',
  'gene_knockout',
  'gene_knockout',
  'gene_knockout',
  'gene_knockout'),
 ('32+48', '18', '12', '49', '83'),
 tensor([nan, nan, nan, nan, nan])]

In [None]:
# The default DataLoader collate function cannot stack tensors of different
# lengths, so combined treatments (e.g. '32+48') stay '+'-joined strings in the
# batch; parse them here into one index tensor per sample.
treatment_id_tensors = [
    torch.tensor([int(part) for part in ids.split('+')]) for ids in one_batch[3]
]
treatment_id_tensors

## Map treatments to vectors

In [73]:
# Drugs will be mapped with a SMILES-to-embedding model; gene knockouts use ESM embeddings.

In [74]:
# TODO: load the actual embedding matrices; for now these are random placeholder embedding vectors.

In [75]:
# One embedding matrix per treatment type, keyed by the same strings that appear
# in the batch's treatment_type field.
# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.
M = {
    treatment_type: torch.tensor(
        pd.read_pickle(f'{file_dir}/{treatment_type}_matrix.pkl').values
    )
    for treatment_type in ('drug_perturbation', 'gene_knockout')
}

In [76]:
def map_index_to_embedding_vector(one_batch, M):
    """Replace the treatment-id strings of a collated batch with embedding vectors.

    Args:
        one_batch: Collated batch
            ``(expression, cell_line_id, treatment_type, treatment_id, treatment_dose_uM)``.
            Each ``treatment_id`` entry is a ``'+'``-joined string of integer ids,
            e.g. ``'32+48'``.
        M: Mapping from treatment type to a 2-D tensor whose columns are
            indexed by treatment id.

    Returns:
        A new list equal to ``one_batch`` except that position 3 holds one
        embedding tensor per sample. For combined treatments, the columns of all
        listed ids are summed. ``one_batch`` itself is not modified, so the call
        can safely be repeated.
    """
    treatment_types = one_batch[2]
    treatment_ids = one_batch[3]

    embedding_vectors = [
        # Several simultaneous treatments: sum their embedding columns.
        M[treatment_type][:, [int(i) for i in ids.split('+')]].sum(1)
        for treatment_type, ids in zip(treatment_types, treatment_ids)
    ]

    # Build a new list instead of assigning into `one_batch`: mutating the
    # caller's batch made a second call fail, because position 3 would already
    # hold tensors instead of id strings.
    embedded_batch = list(one_batch)
    embedded_batch[3] = embedding_vectors
    return embedded_batch

In [77]:
# Pass a shallow copy and store the result under a new name, so the raw
# `one_batch` (with treatment-id strings) survives and this cell can be re-run.
embedded_batch = map_index_to_embedding_vector(list(one_batch), M)

# expression, cell_line_id, treatment_type, treatment_embedding, treatment_dose_uM
embedded_batch

[tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4922, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.6071,  ..., 0.0000, 0.6071, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.8616, 0.5209]]),
 tensor([0, 0, 0, 0, 0]),
 ('gene_knockout',
  'gene_knockout',
  'gene_knockout',
  'gene_knockout',
  'gene_knockout'),
 [tensor([-0.6943,  2.2602,  1.1828,  2.5050,  0.2024, -4.0970, -1.9896,  0.5536,
          -0.5383, -1.0787], dtype=torch.float64),
  tensor([-0.0374, -1.1477,  0.2969, -1.6982, -0.0310,  0.7577,  1.5921,  1.1752,
           0.4533,  1.0757], dtype=torch.float64),
  tensor([-0.2641, -0.0939,  0.0409, -0.8417, -1.4903, -1.2633,  1.4893, -1.1421,
           0.9217, -0.4295], dtype=torch.float64),
  tensor([ 0.3832,  0.4902, -1.5907,  0.8304, -0.2842, -1.0488, -1.8955, -0.4898,
          -0.6559,  0.2390], dtype=torch.float64),
  tensor([ 0.5296, 

In [None]:
# The embedded batch can then be used to train the model, which must distinguish
# 'drug_perturbation' from 'gene_knockout' embedding vectors when feeding them into the MLPs and other layers.