# Scalable and universal prediction of cellular phenotypes

This notebook demonstrates how to make predictions with Prophet using any of the checkpoints we have made available.

In [46]:
import pandas as pd
import yaml

from prophet import Prophet, set_config

We load a config file to automatically get the file paths of the embeddings that were used, but any of the files in `embeddings` can be passed instead.

In [49]:
# Parse the fine-tuning YAML and hand the resulting mapping to set_config;
# the returned object is read below via attributes such as `config.genes_prior`.
with open('config_file_finetuning.yaml', 'r') as config_file:
    config = set_config(yaml.safe_load(config_file))

In [50]:
# Checkpoint to load -- replace this with the path to the checkpoint you want to use.
# The literal is split into adjacent pieces only for readability; Python joins
# them into a single string with nothing inserted between the pieces.
path = (
    '/ictstr01/home/icb/yuge.ji/projects/super_rad_project/pretrained_prophet/GDSC/'
    'gene_0_TrainedOn179_genes_out_300cl_1219iv_512model_8layers_Falsesimpler_Truemask_'
    '0.0001lr_Falseexplicitphenotype_5000warmup_40000max_iters_Falseunbalanced_0.01wd_256bs_Trueft/'
    'gene_0_TrainedOn179_seed_110/epoch=23-step=18528.ckpt'
)

In [51]:
# Build Prophet from the chosen checkpoint together with the three prior-embedding
# files named in the config. Constructing it prints the loaded network layout.
model = Prophet(
    model_pth=path,
    iv_emb_path=config.genes_prior,
    cl_emb_path=config.cell_lines_prior,
    ph_emb_path=config.phenotype_prior,
)

returning trained model!
Gene net:  Sequential(
  (0): Linear(in_features=1219, out_features=512, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.1, inplace=False)
  (3): Linear(in_features=512, out_features=512, bias=True)
)
Cell line net:  Sequential(
  (0): Linear(in_features=300, out_features=512, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.1, inplace=False)
  (3): Linear(in_features=512, out_features=512, bias=True)
)
Regressor:  Sequential(
  (0): Linear(in_features=512, out_features=512, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.2, inplace=False)
  (3): Linear(in_features=512, out_features=512, bias=True)
  (4): GELU(approximate='none')
  (5): Linear(in_features=512, out_features=1, bias=True)
)


Suppose we have some small molecules, some cell lines we would like to test them in, and we're interested in measuring their relative IC50. We can pass in lists of these inputs, and Prophet will return predictions for all combinations:

In [63]:
# Small-molecule treatments to evaluate.
iv_list = [
    'oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc4=c(c=c(c=c4)i)f)=o',
    'cc(nc1=cc=cc(n(c2=o)c(c(c(n2c3cc3)=o)=c(n4c)nc5=cc=c(c=c5f)i)=c(c4=o)c)=c1)=o',
    'fc1=cc=c(c(f)=c1c(c2=cnc3=nc=c(c=c32)c4=cc=c(c=c4)cl)=o)ns(ccc)(=o)=o',
    'cs(=o)c',  # DMSO
]

# Cell lines in which each treatment should be predicted.
cl_list = [
    'A375', 'UACC62', 'WM983B', 'MALME3M', 'A2058',
    'WM793', 'HT144', 'RPMI7951', 'WM1799', 'LOXIMVI',
    'WM2664', 'WM88', 'G361', 'SKMEL24', 'WM115',
    'SKMEL2', 'SKMEL1', 'HMCB', 'MDAMB435S', 'UACC257',
]

# Phenotype(s) to predict.
ph_list = ['GDSC']

In [64]:
# Predict from lists of treatments, cell lines and phenotypes.
df = model.predict(
    target_ivs=iv_list,
    target_cls=cl_list,
    target_phs=ph_list,
    # Passing iv_col turns on combinatorial predictions.
    iv_col=['iv1', 'iv2'],
    num_iterations=1,
    save=False,
)
df

There are 1 iterations


  0%|                                                                                                                                                                                     | 0/1 [00:00<?, ?it/s]

Removing 0 such as [] from ['iv1', 'iv2']. 200 rows remaining.
Removing 0 such as [] from ['cell_line']. 200 rows remaining.


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.03it/s]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:47<00:00, 47.36s/it]


Unnamed: 0,iv1,iv2,cell_line,phenotype,iv1+iv2,value,pred
0,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,A375,GDSC,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,_,0.182267
1,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,UACC62,GDSC,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,_,0.175306
2,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,WM983B,GDSC,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,_,0.326975
3,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,MALME3M,GDSC,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,_,0.292684
4,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,A2058,GDSC,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,_,0.377774
...,...,...,...,...,...,...,...
252,cs(=o)c,cs(=o)c,SKMEL2,GDSC,cs(=o)c+cs(=o)c,_,0.621489
253,cs(=o)c,cs(=o)c,SKMEL1,GDSCcomb,cs(=o)c+cs(=o)c,_,0.604686
254,cs(=o)c,cs(=o)c,HMCB,PRISM,cs(=o)c+cs(=o)c,_,0.607143
255,cs(=o)c,cs(=o)c,MDAMB435S,inhouse,cs(=o)c+cs(=o)c,_,0.582945


If we're interested in only a subset of the experimental matrix, we can also pass in a custom dataframe. (This is the recommended usage, because it makes explicit exactly which experiments are being predicted.)

In [66]:
# predict exact experiments with a dataframe
# Cross every treatment in iv_list with every cell line in cl_list:
# one row per (iv1, cell_line) pair.
crossmerge = pd.MultiIndex.from_product(
    [iv_list, cl_list],
    names=['iv1', 'cell_line'],
)

# to_frame(index=False) already returns a frame with a fresh default index, so
# no extra reset_index is needed. assign() adds the constant columns in one
# expression instead of mutating the frame in place, which keeps the cell
# idempotent. Every row is paired with DMSO ('cs(=o)c') as iv2 and uses the
# 'GDSC' phenotype.
input_df = crossmerge.to_frame(index=False).assign(iv2='cs(=o)c', phenotype='GDSC')

df = model.predict(input_df, num_iterations=1, save=False)
df

There are 1 iterations


  0%|                                                                                                                                                                                     | 0/1 [00:00<?, ?it/s]GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.89it/s]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:21<00:00, 21.48s/it]


Unnamed: 0,iv1,cell_line,iv2,phenotype,pred
0,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,A375,cs(=o)c,GDSC,0.433839
1,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,UACC62,cs(=o)c,GDSC,0.466371
2,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,WM983B,cs(=o)c,GDSC,0.530880
3,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,MALME3M,cs(=o)c,GDSC,0.456714
4,oc1([c@h]2ncccc2)cn(c1)c(c3=c(c(f)=c(c=c3)f)nc...,A2058,cs(=o)c,GDSC,0.511629
...,...,...,...,...,...
132,cs(=o)c,SKMEL2,cs(=o)c,GDSC,0.621489
133,cs(=o)c,SKMEL1,cs(=o)c,GDSCcomb,0.604686
134,cs(=o)c,HMCB,cs(=o)c,PRISM,0.607143
135,cs(=o)c,MDAMB435S,cs(=o)c,inhouse,0.582945
