In [1]:
import pickle
import shutil
import sys
from pathlib import Path

import deepchem as dc
import mlflow
import numpy as np
import pandas as pd

# The project package lives one directory up; extend sys.path before importing it.
sys.path.insert(0, '..')
import src

In [2]:
# Files bundled with the MLflow model: saved model weights, pickle files, metadata, etc.
model_path = '../gnn_model/'
metadata_path = str(Path('..') / 'metadata.txt')
artifacts = dict(model_files=model_path, metadata=metadata_path)


In [3]:
# Serve as an MLflow wrapper for model
class ModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow ``pyfunc`` wrapper around a DeepChem ``GraphConvModel``.

    ``load_context`` rebuilds the model from the saved artifacts and
    ``predict`` scores incoming data with it. Every dependency is imported
    inside the method that uses it, so the wrapper does not rely on names
    that happen to exist in the notebook that defined it.
    """

    # Task count passed to GraphConvModel (previously a bare literal in load_context).
    N_TASKS = 12

    def load_context(self, context):
        """Rebuild the model and restore its weights from the MLflow artifacts.

        Args:
            context: object provided by MLflow; ``context.artifacts['model_files']``
                is used as the model directory.
        """
        import deepchem as dc
        # Not referenced below; importing it here makes a missing `src` package
        # fail at load time instead of on the first predict call.
        import src  # noqa: F401

        model = dc.models.GraphConvModel(
            self.N_TASKS, model_dir=context.artifacts['model_files'])
        model.restore()

        self.model = model

    def predict(self, context, model_input, encode=True):
        """Predict for ``model_input`` and optionally decode to label strings.

        Args:
            context: provided by MLflow; not used by this method.
            model_input: data handed to ``src.dc_utils.df_to_dataset`` (a pandas
                DataFrame according to the original notebook comment).
            encode: when true, take the argmax over axis 2 of the raw model
                output and map the result through ``src.labels.inverse_transform``;
                when false, return the raw output of ``self.model.predict``.

        Returns:
            The decoded labels, or the raw model output if ``encode`` is false.
        """
        # Imported locally (instead of using the notebook-global `np`) so that
        # predict is self-contained, like the other imports in this class.
        import numpy as np
        from src import dc_utils
        from src import labels as le

        ds = dc_utils.df_to_dataset(model_input)
        y_pred = self.model.predict(ds)
        if encode:
            y_pred = np.argmax(y_pred, axis=2)
            y_pred = le.inverse_transform(y_pred)
        return y_pred

In [4]:
# mlflow complains if the target directory already exists, so remove any previous
# export before saving. shutil.rmtree replaces the former `!rm -rf` shell escape:
# it needs no POSIX shell and is driven by `mlflowpath`, so the directory that is
# cleared is always the one that is written. ignore_errors=True keeps the
# "fine if it is not there" behaviour of `rm -f`.
src_path = Path('../src')
mlflowpath = Path('docker', 'mlflow_root', 'model')
shutil.rmtree(mlflowpath, ignore_errors=True)

mlflow.pyfunc.save_model(path=str(mlflowpath), python_model=ModelWrapper(),
                         artifacts=artifacts, conda_env='docker/dc_env_docker.yml',
                         code_path=['../src/'])

In [5]:
# Quick check that ndarray.tolist() yields plain nested Python lists.
# numpy is already imported as `np` in the first cell, so it is not re-imported here.
np.ones((3, 3)).tolist()

[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]

In [6]:
# Read the pyfunc model back from the MLflow model directory.
loaded_model = mlflow.pyfunc.load_model(model_uri=str(mlflowpath))

In [7]:
# Load Tox21 through MoleculeNet with the 'GraphConv' featurizer, then unpack the
# returned dataset tuple into its train / valid / test parts.
tox21_tasks_2, tox21_datasets_2, transformers_2 = dc.molnet.load_tox21(
    featurizer='GraphConv',
)
(train_dataset_2,
 valid_dataset_2,
 test_dataset_2) = tox21_datasets_2

In [8]:
# Turn the test split into a dict, run it through json_dict, and pickle the result
# to 'test_data_gnn.pickle' in the working directory.
test_dict = src.dc_utils.json_dict(
    src.dc_utils.dataset_to_dict(test_dataset_2)
)
with Path('test_data_gnn.pickle').open('wb') as f:
    pickle.dump(test_dict, f)

In [9]:
## Example of what the incoming JSON data will look like:
## the test split, converted to a dict, stored under the 'data' key.
dd = src.dc_utils.json_dict(
    dict(data=src.dc_utils.dataset_to_dict(test_dataset_2))
)

In [10]:
# Pass the payload's 'data' entry through json_dict_to_dict.
# NOTE(review): presumably this reverses the json_dict encoding applied in the
# previous cell — confirm against src.dc_utils.
d = src.dc_utils.json_dict_to_dict(dd['data'])

In [11]:
# Build the model input from the 'df' entry of the decoded dict.
# NOTE(review): presumably returns a pandas DataFrame — confirm against src.dc_utils.
df = src.dc_utils.dict_to_dataframe(d['df'])

In [12]:
# Score the rebuilt input with the reloaded pyfunc model; the cell output is the prediction.
loaded_model.predict(df)

[['estrogen receptor alpha, LBD (ER, LBD): inactive',
  'estrogen receptor alpha, full (ER, full): inactive',
  'aromatase: inactive',
  'aryl hydrocarbon receptor (AhR): inactive',
  'androgen receptor, full (AR, full): inactive',
  'androgen receptor, LBD (AR, LBD): inactive',
  'peroxisome proliferator-activated receptor gamma (PPAR-gamma): inactive',
  'nuclear factor (erythroid-derived 2)-like 2/antioxidant responsive element (Nrf2/ARE): inactive',
  'heat shock factor response element (HSE): inactive',
  'ATAD5: inactive',
  'mitochondrial membrane potential (MMP): inactive',
  'p53: inactive'],
 ['estrogen receptor alpha, LBD (ER, LBD): active',
  'estrogen receptor alpha, full (ER, full): inactive',
  'aromatase: active',
  'aryl hydrocarbon receptor (AhR): active',
  'androgen receptor, full (AR, full): active',
  'androgen receptor, LBD (AR, LBD): active',
  'peroxisome proliferator-activated receptor gamma (PPAR-gamma): active',
  'nuclear factor (erythroid-derived 2)-like 2