# Load a pretrained model for inference
* Dataset: DrugBank
* Model: trained on ClinTox in [this training notebook](https://adb-830292400663869.9.azuredatabricks.net/editor/notebooks/2340616571585910?o=830292400663869); loaded here from MLflow

In [0]:
%pip install chemprop rdkit-pypi mlflow
dbutils.library.restartPython()

In [0]:
%pip freeze
# Print the installed package versions so the environment used for inference is recorded.

In [0]:
import numpy as np
import pandas as pd
from pathlib import Path
import mlflow.pytorch
import torch
from lightning.pytorch import Trainer

from chemprop import data, featurizers, models
from chemprop import nn

In [0]:
# Notebook configuration.
# TODO: put into a config dict when available

# Catalog/schema holding the inference data, and the columns read from it.
catalog = "genesis_workbench"
schema = "dev_chem"
data_table = f"{catalog}.{schema}.drugbank"
smiles_column = "smiles"   # column with the input SMILES strings
target_column = "ClinTox"  # reference label shown next to the predictions

# Registered MLflow model used for inference.
# Workspace (legacy) model registry name:
registered_model = "clintox"
# Unity Catalog registry name -- use this instead of the line above on UC:
# registered_model = f"{catalog}.{schema}.clintox"
model_version = 2  # or set to latest

# Current working directory of the notebook, as a plain string.
chemprop_dir = str(Path.cwd())

chemprop_dir

#### Load the inference dataset
Predict the ClinTox target for the compounds in DrugBank.

In [0]:
# Read the configured `catalog.schema.table` into a Spark DataFrame and preview it.
df = spark.read.table(data_table)
display(df)

In [0]:
# Collect the SMILES column to the driver as an array with one entry per table row;
# the bare expression at the end displays it in the notebook.
smis = (
    df.select(smiles_column)
    .toPandas()[smiles_column]
    .values
)
smis

In [0]:
# Featurizer that MoleculeDataset applies to each datapoint (molecule -> graph inputs).
# NOTE(review): this should match the featurizer used at training time — confirm.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Wrap every SMILES string as a chemprop datapoint, preserving row order.
test_data = list(map(data.MoleculeDatapoint.from_smi, smis))
test_dset = data.MoleculeDataset(test_data, featurizer=featurizer)

# Torch dataloader; shuffle=False keeps batch order identical to the order of `smis`,
# which is what lets predictions be matched back to rows by position.
test_loader = data.build_dataloader(test_dset, shuffle=False)

#### Load the registered model from MLflow

In [0]:
# Resolve the model from the MLflow model registry by name and version.
# To load the artifact of a specific training run instead, point the URI at
# f"runs:/{run_id}/model".
model_uri = f"models:/{registered_model}/{model_version}"
model = mlflow.pytorch.load_model(model_uri)

In [0]:
# Batched inference with autograd disabled. `test_preds` receives what
# Trainer.predict returns for `test_loader` (a list of per-batch outputs).
with torch.inference_mode():
    trainer = Trainer(
        accelerator="auto",  # let Lightning choose the available hardware
        devices=1,
        enable_progress_bar=True,
        logger=None,
    )
    test_preds = trainer.predict(model, test_loader)

In [0]:
# Display the raw output of trainer.predict (presumably one entry per batch) for inspection.
test_preds

In [0]:
# Build the results table: identifying columns, the SMILES, the reference label,
# and the model predictions.
df_test = df.select(['name', 'id', smiles_column, target_column]).toPandas()

# Predictions are attached by position, so this collect must return rows in the same
# order as the collect that produced `smis` for featurization. Spark gives no row-order
# guarantee across separate reads without ORDER BY (and the table could change in
# between), so verify the alignment instead of silently mislabelling rows.
if not np.array_equal(df_test[smiles_column].to_numpy(), np.asarray(smis)):
    raise ValueError(
        "Rows of the results table no longer match the SMILES used for prediction; "
        "predictions cannot be aligned by position."
    )

# Stack the per-batch prediction outputs into one array aligned with the rows.
df_test['pred'] = np.concatenate(test_preds, axis=0)
# Rounded copies for a quick side-by-side comparison of prediction vs. label.
df_test['pred_round'] = df_test['pred'].round(0)
# Use the configured label column instead of hard-coding 'ClinTox'.
df_test[f'{target_column}_round'] = df_test[target_column].round(0)
df_test = spark.createDataFrame(df_test)
display(df_test)

Databricks visualization. Run in Databricks to view.