In [1]:
import sys
import os

import pandas as pd
import scanpy as sc
import numpy as np
import warnings

import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import DataLoader, TensorDataset
from torch.nn import DataParallel
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import anndata
import seaborn as sns
import matplotlib.font_manager
from matplotlib import rcParams

# Pick the plotting font: Helvetica when the system provides it, FreeSans otherwise.
fpaths = matplotlib.font_manager.findSystemFonts()
font_list = set()
for font_path in fpaths:
    try:
        font_list.add(matplotlib.font_manager.get_font(font_path).family_name)
    except RuntimeError:
        # Font files matplotlib fails to load are skipped.
        continue

plot_font = 'FreeSans'
if 'Helvetica' in font_list:
    plot_font = 'Helvetica'

# Global figure defaults for every plot in this notebook.
rcParams.update({
    'font.family': plot_font,
    'font.size': 10,
    'figure.dpi': 300,
    'figure.figsize': (3, 3),
    'savefig.dpi': 500,
})
# NOTE(review): this blanket filter also hides pandas/anndata warnings
# (e.g. chained-assignment or view-modification warnings) from later cells.
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
# Load the sci-Plex AnnData; normalization and gene filtering happen in the next cell.
adata = anndata.read_h5ad('../Squidiff_project/PRnet/dataset/Sci_Plex.h5ad')

In [3]:
# Library-size normalize and log-transform in place, then flag the top HVGs.
N_TOP_GENES = 200
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=N_TOP_GENES)

# Keep only the highly variable genes. `.copy()` materializes the subset:
# a bare slice is a view that keeps the full-gene matrix alive, and later
# cells add .obs columns, which would force an implicit view-to-copy anyway.
adata = adata[:, adata.var['highly_variable']].copy()

# One row per observation, taken from the rows referenced by
# `paired_control_index` (assumed to be each cell's matched control -- TODO confirm),
# so `adata_control` stays row-aligned with `adata`.
adata_control = adata[np.asarray(adata.obs['paired_control_index'].values)].copy()

In [4]:
# Second, unprocessed copy of the same file; in the visible cells only its
# `var_names` (original gene order) is used.
adata_raw = anndata.read_h5ad('../Squidiff_project/PRnet/dataset/Sci_Plex.h5ad')


In [5]:
# Position of every retained HVG within adata_raw.var_names.
# A single lookup table replaces the per-gene np.where scan (O(n*m) -> O(n+m)).
# The first occurrence wins, matching the previous np.where(...)[0][0].
raw_gene_pos = {}
for pos, gene in enumerate(adata_raw.var_names):
    raw_gene_pos.setdefault(gene, pos)

missing_genes = [g for g in adata.var_names if g not in raw_gene_pos]
if missing_genes:
    raise KeyError(
        f'{len(missing_genes)} HVGs not found in adata_raw.var_names, '
        f'e.g. {missing_genes[:5]}'
    )
ind_all = [raw_gene_pos[g] for g in adata.var_names]

In [6]:
# Persist the HVG positions (indices into adata_raw.var_names) to index_all.npy.
np.save('index_all.npy', np.asarray(ind_all))

In [10]:
# Cell-line label of every observation (the held-out unit for cell_split_*).
adata.obs.cell_type

0         A549
1         MCF7
2         MCF7
3         K562
4         K562
          ... 
290883    A549
290884    A549
290885    MCF7
290886    MCF7
290887    MCF7
Name: cell_type, Length: 290888, dtype: category
Categories (3, object): ['A549', 'K562', 'MCF7']

In [12]:
# Leave-one-cell-line-out splits: in cell_split_k every cell of the k-th
# held-out line is 'test' and all other cells are 'train'.
# Assign through .loc: chained indexing (`obs[mask][col] = ...`) writes into a
# temporary copy and silently leaves every row labelled 'train'.
held_out_lines = ['A549', 'K562', 'MCF7']
for split_idx, held_out in enumerate(held_out_lines):
    split_col = f'cell_split_{split_idx}'
    adata.obs[split_col] = 'train'
    adata.obs.loc[adata.obs['cell_type'] == held_out, split_col] = 'test'
    # Fail loudly instead of producing a degenerate all-'train' split.
    assert (adata.obs[split_col] == 'test').any(), f'{split_col}: no {held_out} cells found'

In [13]:
# Save the normalized, log1p-transformed, HVG-subset AnnData including its split columns.
adata.write_h5ad('../Squidiff_project/PRnet/dataset/Sci_Plex_new.h5ad')

In [14]:
# Export the 'train' rows of cell_split_0, plus the row-aligned adata_control entries.
train_mask = (adata.obs['cell_split_0'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_cell_split_0.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_cell_split_0_control.h5ad')

In [15]:
# Export the 'train' rows of cell_split_1, plus the row-aligned adata_control entries.
train_mask = (adata.obs['cell_split_1'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_cell_split_1.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_cell_split_1_control.h5ad')

In [16]:
# Export the 'train' rows of cell_split_2, plus the row-aligned adata_control entries.
train_mask = (adata.obs['cell_split_2'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_cell_split_2.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_cell_split_2_control.h5ad')

In [16]:
# Export the 'train' rows of random_split_0, plus the row-aligned adata_control entries.
train_mask = (adata.obs['random_split_0'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_random_split_0.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_random_split_0_control.h5ad')

In [17]:
# Export the 'train' rows of random_split_1, plus the row-aligned adata_control entries.
train_mask = (adata.obs['random_split_1'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_random_split_1.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_random_split_1_control.h5ad')

In [27]:
# Export the 'train' rows of random_split_2, plus the row-aligned adata_control entries.
train_mask = (adata.obs['random_split_2'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_random_split_2.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_random_split_2_control.h5ad')

In [24]:
# Export the 'train' rows of random_split_4, plus the row-aligned adata_control entries.
train_mask = (adata.obs['random_split_4'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_random_split_4.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_random_split_4_control.h5ad')

In [23]:
# Display the most recently built training AnnData (from whichever export cell ran last in this kernel).
sci_plex_train

AnnData object with n_obs × n_vars = 170654 × 200
    obs: 'Group', 'dose', 'SMILES'

In [19]:
# Export the 'train' rows of drug_split_0, plus the row-aligned adata_control entries.
train_mask = (adata.obs['drug_split_0'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_drug_split_0.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_drug_split_0_control.h5ad')

In [20]:
# Export the 'train' rows of drug_split_1, plus the row-aligned adata_control entries.
train_mask = (adata.obs['drug_split_1'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_drug_split_1.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_drug_split_1_control.h5ad')

In [21]:
# Export the 'train' rows of drug_split_2, plus the row-aligned adata_control entries.
train_mask = (adata.obs['drug_split_2'] == 'train').values
adata_cut = adata[train_mask].copy()

sci_plex_train = anndata.AnnData(adata_cut.X)
sci_plex_train.obs_names = adata_cut.obs_names
sci_plex_train.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_train.obs[new_col] = adata_cut.obs[src_col]
sci_plex_train.write('datasets/sci_plex_train_drug_split_2.h5ad')

sci_plex_train_control = adata_control[train_mask].copy()
sci_plex_train_control.write('datasets/sci_plex_train_drug_split_2_control.h5ad')

In [22]:
# Export the held-out 'test' rows of drug_split_0, plus the row-aligned adata_control entries.
test_mask = (adata.obs['drug_split_0'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_drug_split_0.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_drug_split_0_control.h5ad')

In [23]:
# Export the held-out 'test' rows of drug_split_1, plus the row-aligned adata_control entries.
test_mask = (adata.obs['drug_split_1'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_drug_split_1.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_drug_split_1_control.h5ad')

In [24]:
# Export the held-out 'test' rows of drug_split_2, plus the row-aligned adata_control entries.
test_mask = (adata.obs['drug_split_2'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_drug_split_2.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_drug_split_2_control.h5ad')

In [25]:
# Export the held-out 'test' rows of random_split_0, plus the row-aligned adata_control entries.
test_mask = (adata.obs['random_split_0'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_random_split_0.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_random_split_0_control.h5ad')

In [26]:
# Export the held-out 'test' rows of random_split_2, plus the row-aligned adata_control entries.
test_mask = (adata.obs['random_split_2'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_random_split_2.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_random_split_2_control.h5ad')

In [25]:
# Export the held-out 'test' rows of random_split_4, plus the row-aligned adata_control entries.
test_mask = (adata.obs['random_split_4'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_random_split_4.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_random_split_4_control.h5ad')

In [26]:
# Export the held-out 'test' rows of random_split_1, plus the row-aligned adata_control entries.
test_mask = (adata.obs['random_split_1'] == 'test').values
adata_cut = adata[test_mask].copy()

sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
for new_col, src_col in (('Group', 'cov_drug_dose_name'), ('dose', 'dose'), ('SMILES', 'SMILES')):
    sci_plex_test.obs[new_col] = adata_cut.obs[src_col]
sci_plex_test.write('datasets/sci_plex_test_random_split_1.h5ad')

sci_plex_test_control = adata_control[test_mask].copy()
sci_plex_test_control.write('datasets/sci_plex_test_random_split_1_control.h5ad')

In [None]:
# Tiny smoke-test copy of the random_split_4 held-out set: the first
# N_SMOKE_CELLS test cells and their row-aligned adata_control entries.
# Written under separate filenames -- saving to
# 'datasets/sci_plex_test_random_split_4*.h5ad' would overwrite the full
# random_split_4 test export with this subset.
N_SMOKE_CELLS = 5
smoke_idx = np.flatnonzero((adata.obs['random_split_4'] == 'test').values)[:N_SMOKE_CELLS]

# Positional selection copies only the needed rows, not the whole test set.
adata_cut = adata[smoke_idx].copy()
sci_plex_test = anndata.AnnData(adata_cut.X)
sci_plex_test.obs_names = adata_cut.obs_names
sci_plex_test.var_names = adata_cut.var_names
sci_plex_test.obs['Group'] = adata_cut.obs['cov_drug_dose_name']
sci_plex_test.obs['dose'] = adata_cut.obs['dose']
sci_plex_test.obs['SMILES'] = adata_cut.obs['SMILES']
sci_plex_test.write(f'datasets/sci_plex_test_random_split_4_first{N_SMOKE_CELLS}.h5ad')

# Same positions in adata_control keep controls paired with the cells above.
sci_plex_test_control = adata_control[smoke_idx].copy()
sci_plex_test_control.write(f'datasets/sci_plex_test_random_split_4_first{N_SMOKE_CELLS}_control.h5ad')