# Multitask training across 10 continuous endpoints in TDC
[Ref: ADMET-AI](https://academic.oup.com/bioinformatics/article/40/7/btae416/7698030)

In [0]:
%pip install chemprop rdkit-pypi
dbutils.library.restartPython()

In [0]:
import numpy as np
import pandas as pd
import torch
from lightning.pytorch import Trainer
from pathlib import Path
import mlflow.pytorch

from chemprop import data, models, nn

In [0]:
# Print the installed package versions so the environment used for this run is captured in the cell output.
%pip freeze

In [0]:
# TODO: move these settings into a config dict once one is available.

# Spark table with the SMILES column and the endpoint columns used for multitask training.
data_table = "genesis_workbench.dev_chem.admetai_continuous_endpoints"
# Name under which MLflow autologging registers the trained model.
registered_model = "admetai_multitask_reg"
# Column holding the input SMILES strings; every other column is treated as a target.
smiles_column = "smiles"

# Current working directory of the notebook, as a plain string.
chemprop_dir = str(Path.cwd())
chemprop_dir

In [0]:
# Register models in the legacy Workspace model registry for now. Unity Catalog (UC)
# registration needs an input signature, and the model's graph (BatchMolGraph) input
# cannot be expressed as one yet.
# TODO: once graph signatures are supported, switch to UC with
#   mlflow.set_registry_uri("databricks-uc")
# and a three-level UC model name under genesis_workbench.dev_chem.
# NOTE(review): the previous placeholder UC name pointed at `...clintox`, which looks
# copied from another notebook; pick a name matching this multitask model.
mlflow.set_registry_uri("databricks")

#### Load appropriate dataset

In [0]:
# Read the training table as a Spark DataFrame: one SMILES column plus the endpoint columns.
df = spark.table(data_table)
display(df)

In [0]:
# Load SMILES and endpoints.
# Collect the table once so SMILES and targets share a single row order. Two separate
# Spark actions are not guaranteed to return rows in the same order.
df_pandas = df.toPandas()
# Every column other than the SMILES column is a regression endpoint. Compare names
# exactly: `c not in smiles_column` would be a substring test on the string "smiles"
# and would silently drop any column named e.g. "s" or "mile".
target_columns = [c for c in df_pandas.columns if c != smiles_column]
smis = df_pandas[smiles_column].to_numpy()
# Cast to float so missing endpoint values (nulls) become NaN instead of None/object.
ys = df_pandas[target_columns].to_numpy(dtype=float)
smis, ys

In [0]:
# Endpoint columns used as regression targets; their order defines the model's output tasks.
target_columns

In [0]:
# Pair each SMILES string with its row of endpoint values to build chemprop datapoints.
datapoints = list(map(data.MoleculeDatapoint.from_smi, smis, ys))

#### Split into training, validation, and test sets

In [0]:
# Split with chemprop's default split settings (see `make_split_indices`).
split_indices = data.make_split_indices(datapoints)
train_data, val_data, test_data = data.split_data_by_indices(datapoints, *split_indices)

# Each split is indexed with [0] to take its first replicate.
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0]) for split in (train_data, val_data, test_data)
)
tuple(len(split[0]) for split in (train_data, val_data, test_data))

#### Normalize targets using training-set statistics

In [0]:
# Fit target scaling on the training split only, then apply the same scaler to validation
# so no validation statistics leak into training.
# The test split is deliberately not normalized: the fitted scaler is also handed to
# UnscaleTransform when the model is built, so model outputs are unscaled back to the
# original target units and compared against unscaled test targets.
output_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(output_scaler)

In [0]:
# Convert the datasets into torch DataLoaders.
# `build_dataloader` shuffles by default. Only the training loader should shuffle: the
# validation/test loaders keep dataset order so that predictions can be matched back to
# their input rows (the test-set table below pairs predictions with rows by position).
train_loader = data.build_dataloader(train_dset)
val_loader = data.build_dataloader(val_dset, shuffle=False)
test_loader = data.build_dataloader(test_dset, shuffle=False)

## Define model architecture

In [0]:
# Multitask regression head with one output per endpoint. Outputs pass through an
# UnscaleTransform built from the training-set scaler, so predictions are reported in
# the original target units.
output_transform = nn.transforms.UnscaleTransform.from_standard_scaler(output_scaler)
ffn = nn.RegressionFFN(
    n_tasks=len(target_columns),
    output_transform=output_transform,
)

# Metrics reported during validation and testing.
metric_list = [
    nn.metrics.R2Score(),
    nn.metrics.RMSE(),
]

# Bond message passing -> mean aggregation -> regression FFN.
mpnn = models.MPNN(
    nn.BondMessagePassing(),
    nn.MeanAggregation(),
    ffn,
    batch_norm=True,
    metrics=metric_list,
)

In [0]:
# Lightning trainer for the multitask model. The Lightning logger is disabled; run
# tracking goes through mlflow.pytorch.autolog in the training cell.
trainer = Trainer(
    accelerator="auto",
    max_epochs=20,
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
)

## Log model to mlflow
The model is registered to the legacy Workspace model registry because MLflow does not currently support a graph input signature. Registering a model to Unity Catalog (UC) requires an input signature, which must be an array, DataFrame, dict, or JSON; this model's input is a `BatchMolGraph`.

In [0]:
# Enable MLflow autologging for the fit below; the logged model is registered as `registered_model`.
mlflow.pytorch.autolog(registered_model_name=registered_model)
with mlflow.start_run() as run:
    trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [0]:
# Evaluate the trained in-memory model on the held-out test split.
test_stats = Trainer(logger=False).test(model=mpnn, dataloaders=test_loader)
test_stats

## Inference with the in-memory model

In [0]:
# Per-batch test-set predictions from the in-memory model.
test_preds = Trainer(logger=False).predict(model=mpnn, dataloaders=test_loader)
test_preds

## Load the model from MLflow for inference

In [0]:
# Reload the model logged by the training run above.
# To reuse an earlier run instead, assign that run's ID to run_id here.
run_id = run.info.run_id
model_uri = "runs:/{}/model".format(run_id)
model = mlflow.pytorch.load_model(model_uri)
run_id

In [0]:
# Test-set predictions from the MLflow-reloaded model; expected to match the in-memory predictions.
test_preds_reloaded = Trainer(logger=False).predict(model=model, dataloaders=test_loader)
test_preds_reloaded

In [0]:
# Assemble a test-set table: SMILES, observed endpoints, and reloaded-model predictions.
# Only the first replicate of the test split is used, matching the [0] used when the
# datasets were built.
test_indices = split_indices[-1]
test_rows = np.asarray(test_indices[0], dtype=int)

# Rebuild the observed values from `smis`/`ys` (the exact arrays the datapoints were
# created from) instead of re-collecting the Spark table, whose row order is not
# guaranteed to match the one used for training.
df_test = pd.DataFrame(ys[test_rows], columns=target_columns, index=test_rows)
df_test.insert(0, smiles_column, smis[test_rows])

# Predictions are paired with rows by position, so they must come from a test loader
# that does not shuffle. torch.cat concatenates along `dim`; moving to CPU/numpy gives
# pandas a plain array. The reshape also flattens a trailing singleton output dimension.
preds = (
    torch.cat(test_preds_reloaded, dim=0)
    .detach()
    .cpu()
    .numpy()
    .reshape(-1, len(target_columns))
)
assert len(preds) == len(test_rows), (
    f"{len(preds)} predictions for {len(test_rows)} test rows"
)

y_test = pd.DataFrame(
    preds,
    columns=[f"{i}_pred" for i in target_columns],
    index=df_test.index,
)
df_test = pd.concat([df_test, y_test], axis=1)
df_test.head()

In [0]:
# Count missing values per column of the test-set table.
df_test.isna().sum()

In [0]:
# Convert the pandas test table to Spark for the Databricks table viewer.
sdf_test = spark.createDataFrame(data=df_test)
display(sdf_test)

Databricks visualization. Run in Databricks to view.