# Train a Chemprop (MPNN) classifier architecture on ClinTox (single task)
[Ref: ADMET-AI](https://academic.oup.com/bioinformatics/article/40/7/btae416/7698030)

In [0]:
%pip install chemprop rdkit-pypi
dbutils.library.restartPython()

In [0]:
# Record the exact package versions resolved in this environment (for reproducibility)
%pip freeze

In [0]:
import pandas as pd
import numpy as np
import os
from pathlib import Path
import torch
from lightning.pytorch import Trainer
import mlflow.pytorch

from chemprop import data, featurizers, models, nn

In [0]:
# TODO: put into a config dict when available
# Unity Catalog location of the training data
catalog = "genesis_workbench"
schema = "dev_chem"
chemprop_dir = str(Path.cwd())  # current working directory as a string
# data for training
data_table = f"{catalog}.{schema}.clintox"  # fully-qualified table name: <catalog>.<schema>.<table>
num_workers = 0  # number of DataLoader workers; 0 loads data in the main process
smiles_column = "smiles"  # name of the column containing SMILES strings
target_column = "ClinTox"  # name of the single target column (one binary task is trained)
# model name used when registering the trained model
registered_model = "clintox"

In [0]:
# Look up the current user and notebook file name from the Databricks notebook context
# (intended for naming the MLflow experiment).
notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
username = notebook_ctx.userName().get()
notebook_name = os.path.basename(notebook_ctx.notebookPath().get())
username, notebook_name

In [0]:
# TODO: register the model in Unity Catalog once MLflow supports graph input signatures.
# Until then, use the legacy Workspace model registry. The UC variant would be:
#mlflow.set_registry_uri("databricks-uc")
#registered_model = f"{catalog}.{schema}.clintox"
mlflow.set_registry_uri("databricks")

#### Load the ClinTox dataset
Predict the clinical toxicity of a compound (a binary label) from its SMILES string.

In [0]:
# Read the training table from Unity Catalog as a Spark DataFrame and preview it
df = spark.read.table(data_table)
display(df)

In [0]:
# Load SMILES strings and targets with a single collect so the two arrays stay row-aligned.
# Two separate toPandas() calls on a Spark DataFrame without an ORDER BY are not guaranteed
# to return rows in the same order.
smiles_and_targets = df.select(smiles_column, target_column).toPandas()
smis = smiles_and_targets[smiles_column].values
ys = smiles_and_targets[[target_column]].values  # 2-D (n_rows, 1): one target per molecule
smis, ys

In [0]:
# Pair each SMILES string with its target row and wrap it as a chemprop MoleculeDatapoint
all_data = list(map(data.MoleculeDatapoint.from_smi, smis, ys))
all_data

#### Split data into train, test, validation sets

In [0]:
# Show the split strategies chemprop knows about (candidate split types for make_split_indices below)
list(data.SplitType.keys())

In [0]:
# Structure-based splitters operate on RDKit Mol objects, so collect them from the datapoints
mols = [datapoint.mol for datapoint in all_data]
mols

In [0]:
# Random 80/10/10 split of the molecules into train / validation / test.
# Each split is indexed with [0] below (presumably the first and only replicate).
split_fractions = (0.8, 0.1, 0.1)  # train / val / test
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", split_fractions)
train_data, val_data, test_data = data.split_data_by_indices(all_data, train_indices, val_indices, test_indices)
# split sizes, shown in (train, test, val) order
tuple(len(split[0]) for split in (train_data, test_data, val_data))

In [0]:
# Featurize each split: MoleculeDatapoint -> molecular graph, wrapped in a MoleculeDataset
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0], featurizer)
    for split in (train_data, val_data, test_data)
)

In [0]:
# Wrap each dataset in a torch DataLoader. The training loader keeps build_dataloader's default
# shuffling; the evaluation loaders are explicitly unshuffled.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

## Model architecture

In [0]:
# Seed the Python/NumPy/torch RNGs before any weights are created. This makes weight initialization,
# and any later draws from the global torch RNG, reproducible across Restart & Run All.
from lightning.pytorch import seed_everything

seed = 42
seed_everything(seed)

mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()  # see options in nn.agg.AggregationRegistry
ffn = nn.BinaryClassificationFFN()  # see options in nn.PredictorRegistry

In [0]:
# Binary-classification metrics handed to the MPNN (AUROC, AUPRC, accuracy, F1)
metric_list = [
    metric_cls()
    for metric_cls in (
        nn.metrics.BinaryAUROC,
        nn.metrics.BinaryAUPRC,
        nn.metrics.BinaryAccuracy,
        nn.metrics.BinaryF1Score,
    )
]

In [0]:
# Assemble the model: message passing -> aggregation -> feed-forward predictor, tracking metric_list.
# batch_norm=True presumably batch-normalizes the aggregated representation before the FFN (confirm in chemprop docs).
mpnn = models.MPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list)
mpnn

In [0]:
# Sanity check: for a single-task model, each batch's targets Y must be (batch_size, 1)
val_batch = next(iter(val_loader))
assert val_batch.Y.ndim == 2 and val_batch.Y.shape[1] == 1, f"unexpected target shape {tuple(val_batch.Y.shape)}"
val_batch.Y.shape

In [0]:
# Single-device Lightning trainer for 20 epochs. Lightning's own logger is disabled;
# metrics are captured via mlflow.pytorch.autolog in the training cell.
trainer_kwargs = dict(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,
)
trainer = Trainer(**trainer_kwargs)

## Log model to MLflow
The model is registered to the legacy Workspace model registry because MLflow currently doesn't support a graph input signature.
To register a model to Unity Catalog (UC), an input signature is required, and it must be an array, DataFrame, dict, or JSON. However, the model's input is a [BatchMolGraph](https://github.com/chemprop/chemprop/blob/f8774bd92174f97030e5ba25eb971e33f45cb96b/chemprop/data/collate.py#L13).

In [0]:
# Enable MLflow autologging for the Lightning fit. The run records the training metrics and the
# model (under the "model" artifact path), and the model is registered as `registered_model`
# in the registry chosen by mlflow.set_registry_uri.
mlflow.pytorch.autolog(registered_model_name=registered_model)
with mlflow.start_run() as run:
    trainer.fit(mpnn, train_loader, val_loader)

In [0]:
# Ends any MLflow run that is still active. The `with mlflow.start_run()` block in the training cell
# already closes its run on exit, so this is normally a no-op safeguard.
mlflow.end_run()

In [0]:
# Optional: copy the trained weights (model.pth) from the MLflow run into a UC Volume.
# The destination is derived from the config cell (/Volumes/<catalog>/<schema>/models/<registered_model>)
# rather than hard-coded, so changing catalog/schema/model name updates it too.
model_volume_dir = f"/Volumes/{catalog}/{schema}/models/{registered_model}"
mlflow.artifacts.download_artifacts(
    run_id=run.info.run_id,
    artifact_path="model/data/model.pth",
    dst_path=model_volume_dir,
)

In [0]:
# Attempted alternative: wrap the model as an MLflow pyfunc with a dict input signature so it could be
# registered to Unity Catalog. Status: does NOT work. Kept for reference only; everything below is commented out.
# NOTE(review): as written, `predict` ignores `dict_input`, reads an undefined `data_loader`, and never calls
# `format_input`. These would need fixing before this approach could be revived.
# def make_dict(bmg):
#     body_lines = ','.join(f"'{f}':" + (f'str(self.{f})' if f == 'message_id'
#                                        else f'self.{f}') for f in bmg.__slots__)
#     # Compute the text of the entire function.
#     txt = f'def dict(self):\n return {{{body_lines}}}'
#     ns = {}
#     exec(txt, locals(), ns)
#     _dict_fn = bmg.__class__.dict = ns['dict']
#     return _dict_fn(bmg)

# bmg = next(iter(val_loader)).bmg
# bmg_dict = make_dict(bmg)


# class ChemPropModel(mlflow.pyfunc.PythonModel):
#     def load_context(self, context):
#          self.model = torch.load(context.artifacts["model_path"])

#     def format_input(self, bmg_dict):
#         #dict to bmg
#         from dataclasses import InitVar
#         from typing import Sequence
#         from chemprop.data.molgraph import MolGraph
#         from chemprop.data.collate import BatchMolGraph
#         molgraph = MolGraph(**{k:v for k,v in bmg_dict.items() if k in ['V', 'E', 'edge_index', 'rev_edge_index']})
#         bmg = BatchMolGraph(mgs=molgraph)
#         return bmg
        
#     def predict(self, context, dict_input):
#         # input must be a dict with a key 'input'
#         # Option 1 if using DataLoader
#         # results = Trainer(logger=False).predict(self.model, data_loader)
#         # Option 2 if using a single batch
#         bmg = next(iter(data_loader)).bmg
#         results = self.model(bmg)
#         return results
    
# with mlflow.start_run() as run:
#     mlflow.pyfunc.log_model(
#         python_model=ChemPropModel(),
#         input_example=bmg_dict,
#         artifact_path="model",
#         artifacts={"model_path": 'data/model.pth'},
#     )

In [0]:
# Evaluate the trained in-memory model on the held-out test split (Lightning returns a list of metric dicts)
evaluator = Trainer(logger=False)
test_stats = evaluator.test(model=mpnn, dataloaders=test_loader)
test_stats

## Inference with the in-memory model

In [0]:
# Predict on the test split with the trained in-memory model (one prediction tensor per batch)
test_preds = Trainer(logger=False).predict(model=mpnn, dataloaders=test_loader)

## Load model from mlflow for inference

In [0]:
# Reload the model logged under the "model" artifact path of the training run.
# To reload an older model instead, set run_id to that run's id string.
run_id = run.info.run_id
model_uri = "runs:/{}/model".format(run_id)

model = mlflow.pytorch.load_model(model_uri)

In [0]:
# Predict on the same test loader with the reloaded model, mirroring the in-memory prediction cell
test_preds_reloaded = Trainer(logger=False).predict(model=model, dataloaders=test_loader)
test_preds_reloaded

In [0]:
# Check that the reloaded model reproduces the in-memory predictions batch by batch.
# Absolute differences are compared against a small tolerance; an empty list means every batch matches.
pred_atol = 1e-6  # predictions are presumably float32; leaves room for tiny numeric noise (e.g. non-deterministic GPU ops)
assert len(test_preds) == len(test_preds_reloaded), "different number of prediction batches"
mismatched_batches = [
    batch_idx
    for batch_idx, (pred, pred_reloaded) in enumerate(zip(test_preds, test_preds_reloaded))
    if torch.max(torch.abs(pred - pred_reloaded)) > pred_atol
]
mismatched_batches

In [0]:
# Attach the reloaded model's predictions to the corresponding test rows.
# NOTE(review): this relies on df.toPandas() returning rows in the same order as the earlier collect
# that produced `smis` (Spark gives no order guarantee without ORDER BY); the assertion guards that.
test_rows = test_indices[0]
df_test = df.toPandas().iloc[test_rows].copy()  # explicit copy so the new columns are never written into a view
assert (df_test[smiles_column].values == smis[test_rows]).all(), "row order changed between Spark collects"
# Per-batch predictions are presumably (batch, 1); flatten to one value per row
df_test["pred"] = np.concatenate(test_preds_reloaded, axis=0).ravel()
df_test["pred_rounded"] = np.round(df_test["pred"])
df_test

In [0]:
# Convert the annotated pandas test frame back to a Spark DataFrame for Databricks display
sdf_test = spark.createDataFrame(data=df_test)
display(sdf_test)

Databricks visualization. Run in Databricks to view.