## Add our splits to lincs_adata.h5ad

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
from collections import Counter
# Read the LINCS AnnData object and the precomputed split labels from disk.
adata_path = 'data/lincs_adata.h5ad'
split_path = 'data/Lincs_mysplit.pkl'
lincs_adata = sc.read(adata_path)
my_split = pd.read_pickle(split_path)
# Keep a copy of the dose rounded to four decimals in its own obs column.
lincs_adata.obs['dose_val_4f'] = lincs_adata.obs['dose'].round(4)

In [None]:
def _relabel(labels, old, new):
    """Return `labels` with every entry equal to `old` replaced by `new`."""
    return labels.apply(lambda label: new if label == old else label)


# Attach the split labels, renaming the original 'test' label to 'valid'.
lincs_adata.obs['my_split'] = my_split.copy()
lincs_adata.obs['my_split'] = _relabel(lincs_adata.obs['my_split'], 'test', 'valid')

# Derive one column per evaluation scenario: in each, exactly one of the
# 'val_*_unseen' labels is renamed to 'test'; all other labels are kept.
for column, held_out in (
    ('Both_unseen', 'val_both_unseen'),
    ('Drug_unseen', 'val_drug_unseen'),
    ('Cell_line_unseen', 'val_cell_line_unseen'),
):
    lincs_adata.obs[column] = _relabel(lincs_adata.obs['my_split'], held_out, 'test')

In [None]:
# Tally how many observations carry each split label (displayed as cell output).
split_labels = lincs_adata.obs['my_split']
Counter(split_labels)

In [None]:
# Persist the annotated object; this is the same path it was read from,
# so the file on disk is overwritten with the added obs columns.
output_path = 'data/lincs_adata.h5ad'
lincs_adata.write_h5ad(output_path)

## Calculate text embeddings for drugs

In [None]:
# Distinct SMILES strings present in the dataset, in order of first appearance.
all_smiles = lincs_adata.obs['SMILES'].unique().tolist()

In [None]:
# Load the pretrained MolT5 SMILES-to-caption checkpoint (tokenizer + model).
from transformers import T5Tokenizer, T5ForConditionalGeneration

molt5_checkpoint = 'laituan245/molt5-large-smiles2caption'
tokenizer = T5Tokenizer.from_pretrained(molt5_checkpoint, model_max_length=512)
model = T5ForConditionalGeneration.from_pretrained(molt5_checkpoint)

In [None]:
# Generate a text caption for every unique SMILES string using MolT5.
# Result: `smiles_caption` maps each SMILES string to its decoded caption.
from tqdm import tqdm
import torch
smiles_caption = {}
# Prefer the first GPU, but fall back to the CPU so this cell still runs
# (slowly) on machines without CUDA instead of failing when moving the model.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
with torch.no_grad():
    for smile in tqdm(all_smiles):
        input_ids = tokenizer(smile, return_tensors="pt").input_ids.to(device)
        # Beam search with 5 beams; generated length is capped at 512.
        outputs = model.generate(input_ids, num_beams=5, max_length=512)
        # Decode the first returned sequence, dropping special tokens.
        smiles_caption[smile] = tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Load pretrained BioLinkBERT.
# NOTE: this rebinds `tokenizer` and `model`, replacing the MolT5 objects.
from transformers import AutoTokenizer, AutoModel

biolinkbert_checkpoint = 'michiyasunaga/BioLinkBERT-large'
tokenizer = AutoTokenizer.from_pretrained(biolinkbert_checkpoint)
model = AutoModel.from_pretrained(biolinkbert_checkpoint)

In [None]:
# Generate text embeddings for the SMILES captions using BioLinkBERT.
# Result: `smiles_text_emb` maps each SMILES string to the encoder's
# last hidden state for its caption, saved to disk with torch.save.
smiles_text_emb = {}
device = torch.device('cpu')
model = model.to(device)
# The encoder cannot process more tokens than it has position embeddings,
# so over-long captions are truncated to that limit instead of crashing
# the forward pass. Captions within the limit are unaffected.
max_tokens = model.config.max_position_embeddings
with torch.no_grad():
    for k, v in tqdm(smiles_caption.items()):
        inputs = tokenizer(
            v,
            return_tensors="pt",
            truncation=True,
            max_length=max_tokens,
        ).to(device)
        outputs = model(**inputs)
        # One caption per call, so squeeze(0) removes the leading batch axis.
        smiles_text_emb[k] = outputs.last_hidden_state.detach().squeeze(0).cpu()
torch.save(smiles_text_emb, 'data/pert_smiles_emb.pkl')

In [None]:
# dosage prompt: build one sentence per distinct rounded dose value,
# keyed by that dose value.
dose_val = lincs_adata.obs['dose_val_4f'].unique()
dosage_prompt = {
    dose: 'The dosage is ' + dose.astype(str) + ' micromoles.'
    for dose in dose_val
}

In [None]:
# Embed every dosage prompt with the currently loaded encoder and save the
# mapping from dose value to last hidden state.
dosage_prompt_emb = {}
with torch.no_grad():
    for k, v in dosage_prompt.items():
        # Move the inputs to whatever device the model lives on, matching the
        # caption-embedding cell; otherwise this cell breaks as soon as the
        # model is placed on a GPU.
        inputs = tokenizer(v, return_tensors="pt").to(model.device)
        outputs = model(**inputs)
        # One prompt per call, so squeeze(0) removes the leading batch axis.
        dosage_prompt_emb[k] = outputs.last_hidden_state.detach().squeeze(0).cpu()
torch.save(dosage_prompt_emb, 'data/dosage_prompt_emb_lincs.pkl')