# Load a pretrained model for multi-component inference
* Dataset: 100 test compounds from Chemprop
* Model: general multicomponent regressor. The model is loaded from a checkpoint file provided by Chemprop. Such checkpoint files can also be logged to MLflow, registered, and reloaded accordingly.

In [0]:
%pip install chemprop rdkit-pypi

In [0]:
import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl
from pathlib import Path

In [0]:
from chemprop import data, featurizers
from chemprop.models import multi

In [0]:
# TODO: move these settings into a config dict when available
# Project root: the current working directory, as a plain string
chemprop_dir = str(Path.cwd())

# Multicomponent input means several SMILES columns, one per component
smiles_columns = ['smiles', 'solvent']

# Checkpoint (.ckpt, from chemprop) of the model used for inference
model_name = "example_model_v2_regression_mol+mol.ckpt"

# CSV file holding the data to run inference on
data_path = f"{chemprop_dir}/data/mol+mol.csv"

# Bare expression: shows the resolved root directory as the cell output
chemprop_dir

#### Load appropriate model

In [0]:
# Full path of the checkpoint file inside the models/ directory
checkpoint_path = f"{chemprop_dir}/models/{model_name}"
# Restore the multicomponent MPNN from that checkpoint
mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path)
# Bare expression: shows the loaded model as the cell output
mcmpnn

#### Load appropriate dataset
Predict the solubility of a compound in a solvent.

In [0]:
# Read the inference CSV into a pandas DataFrame
df_test = pd.read_csv(data_path)
# Bare expression: shows the DataFrame as the cell output
df_test

In [0]:
# Load SMILES: select the SMILES columns and take the underlying array
# (one row per data row, one column per entry in smiles_columns)
smiss = df_test[smiles_columns].values
# Preview the first five rows
smiss[:5]

In [0]:
# SMILES -> MoleculeDatapoint, one list of datapoints per component.
# smiss has one column per entry in smiles_columns, so iterating over its
# transpose walks the components (columns) directly, in column order,
# without a separate index counter.
test_datapointss = [
    [data.MoleculeDatapoint.from_smi(smi) for smi in component_smiles]
    for component_smiles in smiss.T
]

# Featurization: MoleculeDatapoint -> graph descriptors (all in MoleculeDataset)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
test_dsets = [data.MoleculeDataset(test_datapoints, featurizer) for test_datapoints in test_datapointss]

# test_dsets contains one dataset per component (here: compound and solvent),
# so they are combined into a single MulticomponentDataset
# and then wrapped in a dataloader for torch.
# shuffle=False is passed explicitly; the predictions are later written back
# to df_test as a column, which presumably relies on the row order being kept.
test_mcdset = data.MulticomponentDataset(test_dsets)
test_loader = data.build_dataloader(test_mcdset, shuffle=False)

In [0]:
# Trainer configuration: no logger, progress bar on, automatic accelerator
# selection, a single device.
trainer_kwargs = dict(
    logger=False,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

# Build the trainer and run prediction inside torch's inference mode
with torch.inference_mode():
    trainer = pl.Trainer(**trainer_kwargs)
    test_preds = trainer.predict(mcmpnn, test_loader)

In [0]:
# Bare expression: shows the raw output of trainer.predict as the cell output
test_preds

In [0]:
# Join the prediction chunks along the first axis and store them as a new column
# (presumably one chunk per batch, in the same row order as df_test — TODO confirm)
df_test['pred'] = np.concatenate(test_preds, axis=0)
# `spark` and `display` are not defined in this notebook; presumably they are
# globals provided by the Databricks runtime — verify before running elsewhere
df = spark.createDataFrame(df_test)
display(df)

Databricks visualization. Run in Databricks to view.

Calibration may be needed to align predicted values with actual values. For the common use case of virtual screening to identify the top k molecules, this calibration step may not be necessary.