<a href="https://colab.research.google.com/github/thbuerg/MetabolomicsCommonDiseases/blob/main/analysis/examples/MetabolomicsInference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inference on the Metabolomic State Model
This notebook allows you to run inference on your own data using the Metabolomic State Model (MSM). Please note: The data must be NMR Metabolomics data from [Nightingale Health](https://nightingalehealth.com). Data must be provided as a `.csv` file with one sample per row and one metabolite measure per column. The order of the metabolite columns should follow the [example data](https://github.com/thbuerg/MetabolomicsCommonDiseases/blob/main/metabolomicstatemodel/inference/example_data.csv).

To run inference on your data, execute each cell of this notebook and upload your data as a `.csv` file where prompted.

Happy inference!

In [3]:
#@title Install dependencies and download model weights from zenodo

#@markdown Please execute this cell by pressing the _Play_ button 
#@markdown on the left to install dependencies and download model weights
#@markdown in this Colab notebook. 

import os
import tqdm.notebook
from IPython.utils import io

# File names of the 22 ONNX ensemble members (model_0.onnx ... model_21.onnx).
model_files = ["model_{}.onnx".format(idx) for idx in range(22)]

# Directory that receives every downloaded artefact; created up front.
download_dir = 'downloads/'
os.makedirs(download_dir, exist_ok=True)

# Layout of the progress bars shown while cells are running.
TQDM_BAR_FORMAT = (
    '{l_bar}{bar}| {n_fmt}/{total_fmt} '
    '[elapsed: {elapsed} remaining: {remaining}]'
)


with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:
    with io.capture_output() as captured:
      
      %shell pip install \
      onnx onnxruntime pickle5 numpy matplotlib wget scikit-learn==0.24.2

      pbar.update(25)

      import wget
      
      for f in ["metabolites_metadata.csv", "scaler_dict.p"]:
        if not os.path.isfile(f):
          wget.download(
              f"https://sandbox.zenodo.org/record/990127/files/{f}?download=1", 
              out=os.path.join(download_dir, f))
    
      pbar.update(25)
      
      for model_file in model_files:
        if not os.path.isfile(model_file):
          wget.download(
              "https://sandbox.zenodo.org/record/990127/files/"\
              +model_file+"?download=1",
               out=os.path.join(download_dir, model_file))
        pbar.update(50//len(model_files))

  0%|          | 0/100 [elapsed: 00:00 remaining: ?]

In [6]:
#@title Define necessary functions

#@markdown Please execute this cell by pressing the _Play_ button 
#@markdown on the left to define the label lists, load the metadata and
#@markdown scalers, and set up the inference function.

import pickle5 as pickle
import onnxruntime
import pandas as pd

from sklearn.preprocessing import StandardScaler


# Endpoint labels; `inference` uses them, in this order, as the column names
# of the concatenated model outputs. Every label carries the "M_" prefix.
endpoints = [
    f"M_{disease}"
    for disease in (
        "MACE",
        "all_cause_dementia",
        "type_2_diabetes",
        "liver_disease",
        "renal_disease",
        "atrial_fibrillation",
        "heart_failure",
        "coronary_heart_disease",
        "venous_thrombosis",
        "cerebral_stroke",
        "abdominal_aortic_aneurysm",
        "peripheral_arterial_disease",
        "asthma",
        # (sic) spelling kept as-is: this string is emitted as a column label.
        "chronic_obstructuve_pulmonary_disease",
        "lung_cancer",
        "non_melanoma_skin_cancer",
        "colon_cancer",
        "rectal_cancer",
        "prostate_cancer",
        "breast_cancer",
        "parkinsons_disease",
        "fractures",
        "cataracts",
        "glaucoma",
    )
]

# Labels of the metabolite measures ("NMR_" + metabolite name). The metadata
# lookup that builds `metabolite_labels_mr` follows the order of this list.
metabolite_labels = [
    "NMR_3hydroxybutyrate",
    "NMR_acetate",
    "NMR_acetoacetate",
    "NMR_acetone",
    "NMR_alanine",
    "NMR_albumin",
    "NMR_apolipoprotein_a1",
    "NMR_apolipoprotein_b",
    "NMR_average_diameter_for_hdl_particles",
    "NMR_average_diameter_for_ldl_particles",
    "NMR_average_diameter_for_vldl_particles",
    "NMR_cholesterol_in_chylomicrons_and_extremely_large_vldl",
    "NMR_cholesterol_in_idl",
    "NMR_cholesterol_in_large_hdl",
    "NMR_cholesterol_in_large_ldl",
    "NMR_cholesterol_in_large_vldl",
    "NMR_cholesterol_in_medium_hdl",
    "NMR_cholesterol_in_medium_ldl",
    "NMR_cholesterol_in_medium_vldl",
    "NMR_cholesterol_in_small_hdl",
    "NMR_cholesterol_in_small_ldl",
    "NMR_cholesterol_in_small_vldl",
    "NMR_cholesterol_in_very_large_hdl",
    "NMR_cholesterol_in_very_large_vldl",
    "NMR_cholesterol_in_very_small_vldl",
    "NMR_cholesteryl_esters_in_chylomicrons_and_extremely_large_vldl",
    "NMR_cholesteryl_esters_in_hdl",
    "NMR_cholesteryl_esters_in_idl",
    "NMR_cholesteryl_esters_in_ldl",
    "NMR_cholesteryl_esters_in_large_hdl",
    "NMR_cholesteryl_esters_in_large_ldl",
    "NMR_cholesteryl_esters_in_large_vldl",
    "NMR_cholesteryl_esters_in_medium_hdl",
    "NMR_cholesteryl_esters_in_medium_ldl",
    "NMR_cholesteryl_esters_in_medium_vldl",
    "NMR_cholesteryl_esters_in_small_hdl",
    "NMR_cholesteryl_esters_in_small_ldl",
    "NMR_cholesteryl_esters_in_small_vldl",
    "NMR_cholesteryl_esters_in_vldl",
    "NMR_cholesteryl_esters_in_very_large_hdl",
    "NMR_cholesteryl_esters_in_very_large_vldl",
    "NMR_cholesteryl_esters_in_very_small_vldl",
    "NMR_citrate",
    "NMR_clinical_ldl_cholesterol",
    "NMR_concentration_of_chylomicrons_and_extremely_large_vldl_particles",
    "NMR_concentration_of_hdl_particles",
    "NMR_concentration_of_idl_particles",
    "NMR_concentration_of_ldl_particles",
    "NMR_concentration_of_large_hdl_particles",
    "NMR_concentration_of_large_ldl_particles",
    "NMR_concentration_of_large_vldl_particles",
    "NMR_concentration_of_medium_hdl_particles",
    "NMR_concentration_of_medium_ldl_particles",
    "NMR_concentration_of_medium_vldl_particles",
    "NMR_concentration_of_small_hdl_particles",
    "NMR_concentration_of_small_ldl_particles",
    "NMR_concentration_of_small_vldl_particles",
    "NMR_concentration_of_vldl_particles",
    "NMR_concentration_of_very_large_hdl_particles",
    "NMR_concentration_of_very_large_vldl_particles",
    "NMR_concentration_of_very_small_vldl_particles",
    "NMR_creatinine",
    "NMR_degree_of_unsaturation",
    "NMR_docosahexaenoic_acid",
    "NMR_free_cholesterol_in_chylomicrons_and_extremely_large_vldl",
    "NMR_free_cholesterol_in_hdl",
    "NMR_free_cholesterol_in_idl",
    "NMR_free_cholesterol_in_ldl",
    "NMR_free_cholesterol_in_large_hdl",
    "NMR_free_cholesterol_in_large_ldl",
    "NMR_free_cholesterol_in_large_vldl",
    "NMR_free_cholesterol_in_medium_hdl",
    "NMR_free_cholesterol_in_medium_ldl",
    "NMR_free_cholesterol_in_medium_vldl",
    "NMR_free_cholesterol_in_small_hdl",
    "NMR_free_cholesterol_in_small_ldl",
    "NMR_free_cholesterol_in_small_vldl",
    "NMR_free_cholesterol_in_vldl",
    "NMR_free_cholesterol_in_very_large_hdl",
    "NMR_free_cholesterol_in_very_large_vldl",
    "NMR_free_cholesterol_in_very_small_vldl",
    "NMR_glucose",
    "NMR_glutamine",
    "NMR_glycine",
    "NMR_glycoprotein_acetyls",
    "NMR_hdl_cholesterol",
    "NMR_histidine",
    "NMR_isoleucine",
    "NMR_ldl_cholesterol",
    "NMR_lactate",
    "NMR_leucine",
    "NMR_linoleic_acid",
    "NMR_monounsaturated_fatty_acids",
    "NMR_omega3_fatty_acids",
    "NMR_omega6_fatty_acids",
    "NMR_phenylalanine",
    "NMR_phosphatidylcholines",
    "NMR_phosphoglycerides",
    "NMR_phospholipids_in_chylomicrons_and_extremely_large_vldl",
    "NMR_phospholipids_in_hdl",
    "NMR_phospholipids_in_idl",
    "NMR_phospholipids_in_ldl",
    "NMR_phospholipids_in_large_hdl",
    "NMR_phospholipids_in_large_ldl",
    "NMR_phospholipids_in_large_vldl",
    "NMR_phospholipids_in_medium_hdl",
    "NMR_phospholipids_in_medium_ldl",
    "NMR_phospholipids_in_medium_vldl",
    "NMR_phospholipids_in_small_hdl",
    "NMR_phospholipids_in_small_ldl",
    "NMR_phospholipids_in_small_vldl",
    "NMR_phospholipids_in_vldl",
    "NMR_phospholipids_in_very_large_hdl",
    "NMR_phospholipids_in_very_large_vldl",
    "NMR_phospholipids_in_very_small_vldl",
    "NMR_polyunsaturated_fatty_acids",
    "NMR_pyruvate",
    "NMR_remnant_cholesterol_nonhdl_nonldl_cholesterol",
    "NMR_saturated_fatty_acids",
    "NMR_sphingomyelins",
    "NMR_total_cholesterol",
    "NMR_total_cholesterol_minus_hdlc",
    "NMR_total_cholines",
    "NMR_total_concentration_of_branchedchain_amino_acids_leucine_isoleucine_valine",
    "NMR_total_concentration_of_lipoprotein_particles",
    "NMR_total_esterified_cholesterol",
    "NMR_total_fatty_acids",
    "NMR_total_free_cholesterol",
    "NMR_total_lipids_in_chylomicrons_and_extremely_large_vldl",
    "NMR_total_lipids_in_hdl",
    "NMR_total_lipids_in_idl",
    "NMR_total_lipids_in_ldl",
    "NMR_total_lipids_in_large_hdl",
    "NMR_total_lipids_in_large_ldl",
    "NMR_total_lipids_in_large_vldl",
    "NMR_total_lipids_in_lipoprotein_particles",
    "NMR_total_lipids_in_medium_hdl",
    "NMR_total_lipids_in_medium_ldl",
    "NMR_total_lipids_in_medium_vldl",
    "NMR_total_lipids_in_small_hdl",
    "NMR_total_lipids_in_small_ldl",
    "NMR_total_lipids_in_small_vldl",
    "NMR_total_lipids_in_vldl",
    "NMR_total_lipids_in_very_large_hdl",
    "NMR_total_lipids_in_very_large_vldl",
    "NMR_total_lipids_in_very_small_vldl",
    "NMR_total_phospholipids_in_lipoprotein_particles",
    "NMR_total_triglycerides",
    "NMR_triglycerides_in_chylomicrons_and_extremely_large_vldl",
    "NMR_triglycerides_in_hdl",
    "NMR_triglycerides_in_idl",
    "NMR_triglycerides_in_ldl",
    "NMR_triglycerides_in_large_hdl",
    "NMR_triglycerides_in_large_ldl",
    "NMR_triglycerides_in_large_vldl",
    "NMR_triglycerides_in_medium_hdl",
    "NMR_triglycerides_in_medium_ldl",
    "NMR_triglycerides_in_medium_vldl",
    "NMR_triglycerides_in_small_hdl",
    "NMR_triglycerides_in_small_ldl",
    "NMR_triglycerides_in_small_vldl",
    "NMR_triglycerides_in_vldl",
    "NMR_triglycerides_in_very_large_hdl",
    "NMR_triglycerides_in_very_large_vldl",
    "NMR_triglycerides_in_very_small_vldl",
    "NMR_tyrosine",
    "NMR_vldl_cholesterol",
    "NMR_valine",
]

# Read the metabolite metadata, add the "NMR_"-prefixed label for every row
# and keep the table sorted by that label.
metadata = (
    pd.read_csv(os.path.join(download_dir, "metabolites_metadata.csv"))
    .assign(metabolite_label=lambda frame: "NMR_" + frame["metabolite"])
    .sort_values("metabolite_label")
)

# Machine-readable name of each entry in `metabolite_labels`, in list order.
metabolite_labels_mr = (
    metadata
    .set_index("metabolite_label")
    .loc[metabolite_labels, "machine_readable_name"]
    .tolist()
)


# Directory that receives the prediction output. `os.makedirs` returns None,
# so the path is stored first and the directory created in a second step.
# No trailing slash: the archive step later derives "<results_dir>.zip".
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)

# load scalers
# NOTE(review): unpickling can execute arbitrary code; this is acceptable only
# because the file was fetched from the project's own record. Do not point
# this at untrusted files.
# The context manager closes the file handle once the scalers are loaded.
with open(os.path.join(download_dir, "scaler_dict.p"), "rb") as scaler_file:
  scaler_dict = pickle.load(scaler_file)

# define forward pass over ensemble

# ONNX Runtime sessions keyed by model path. Building a session loads the
# model file, so sessions are created once and reused across batches.
_ort_sessions = {}


def _get_ort_session(model_path):
  """Return the cached InferenceSession for `model_path`, creating it on first use."""
  if model_path not in _ort_sessions:
    _ort_sessions[model_path] = onnxruntime.InferenceSession(model_path)
  return _ort_sessions[model_path]


def inference(x):
  """Run every ensemble member on `x` and collect the predictions.

  The input is transformed with `log1p`, scaled with the scaler stored under
  the model's position in `scaler_dict`, and passed through the matching ONNX
  model. The selected outputs are concatenated column-wise, exponentiated and
  labelled with `endpoints`.

  Args:
    x: 2-D array-like accepted by `np.log1p` and the scalers' `transform`
      (assumed layout: one sample per row, one metabolite per column —
      TODO confirm against the training pipeline).

  Returns:
    pandas.DataFrame with one column per entry of `endpoints` plus a
    `model_id` column. The rows of all models are stacked vertically and the
    index restarts at 0 for every model, so index labels repeat once per
    ensemble member.
  """
  # Local import: keeps the function usable even if the cell that imports
  # numpy at module level has not been executed yet.
  import numpy as np

  x = np.log1p(x)
  predictions = []
  for i, m in enumerate(model_files):
    # transform inputs:
    scaled = scaler_dict[i].transform(x)

    # run onnx inference
    ort_session = _get_ort_session(os.path.join(download_dir, m))
    ort_inputs = {ort_session.get_inputs()[0].name: scaled}
    ort_outs = ort_session.run(None, ort_inputs)

    # Only every third output tensor is used.
    # NOTE(review): presumably the model emits three tensors per group and
    # the first one carries the prediction — confirm against the export code.
    predictions.append(
        pd.DataFrame(np.concatenate(ort_outs[0::3], axis=1), columns=endpoints)
        .apply(np.exp)
        .assign(model_id=i))
  return pd.concat(predictions, axis=0)

In [None]:
#@title Upload data for inference

#@markdown Please execute this cell by pressing the _Play_ button 
#@markdown on the left and select the `.csv` file(s) to upload
#@markdown when prompted.

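#@markdown **Note**: The data must be in the supported format: one sample per row and one metabolite measure per column, ordered as in the [example data](https://github.com/thbuerg/MetabolomicsCommonDiseases/blob/main/metabolomicstatemodel/inference/example_data.csv).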
from google.colab import files

# Ask for the upload and read every uploaded CSV into its own DataFrame.
input_files = [*files.upload().keys()]
input_raw_list = [pd.read_csv(f) for f in input_files]

if not input_raw_list:
  raise ValueError("No file was uploaded; please upload at least one .csv file.")

# Validate with explicit errors so the user learns which file is at fault.
n_expected = len(metabolite_labels_mr)
for name, df in zip(input_files, input_raw_list):
  if df.shape[1] != n_expected:
    raise ValueError(
        f"{name}: expected {n_expected} metabolite columns, "
        f"found {df.shape[1]}.")
  # pd.concat aligns on column names; differing headers would silently
  # produce extra, NaN-filled columns instead of stacking the rows.
  if not df.columns.equals(input_raw_list[0].columns):
    raise ValueError(
        f"{name}: column names/order differ from {input_files[0]}.")

# ignore_index gives every sample a unique position across all files.
input_raw = pd.concat(input_raw_list, axis=0, ignore_index=True)

In [7]:
#@title Run the Metabolomic State Model

#@markdown Please execute this cell by pressing the _Play_ button 
#@markdown on the left to run the model ensemble on the
#@markdown uploaded data.

#@markdown Once this cell has been executed a download with the predictions
#@markdown will start automatically.

import numpy as np

def split_given_size(arr, size):
  """Split `arr` along its first axis into chunks of exactly `size` rows.

  The last chunk is zero-padded at the end when it would otherwise be shorter
  than `size`.

  Args:
    arr: array-like (e.g. ndarray or DataFrame). It is converted to a NumPy
      array first so that every returned chunk has the same type.
    size: number of rows per chunk; must be at least 1.

  Returns:
    Tuple `(splits, n_pad)`: the list of ndarray chunks and the number of
    zero rows appended to the last chunk (0 if none were needed).

  Raises:
    ValueError: if `size` is smaller than 1.
  """
  if size < 1:
    raise ValueError("size must be a positive integer")
  arr = np.asarray(arr)
  splits = np.split(arr, np.arange(size, len(arr), size))
  n_pad = size - splits[-1].shape[0]
  if n_pad > 0:
    # Pad only the first axis; any further axes are left untouched.
    pad_width = [(0, n_pad)] + [(0, 0)] * (arr.ndim - 1)
    splits[-1] = np.pad(splits[-1], pad_width=pad_width)
  return splits, n_pad

with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:
  with io.capture_output() as captured:
    
    batches, n_padded = split_given_size(input_raw, 1024)
    pbar.update(5)
    predictions = []

    for i, b in enumerate(batches):
      preds = inference(b)
      if i+1 == len(batches):
        preds = preds.iloc[:-n_padded]
      preds.index = preds.index+i*1024
      predictions.append(preds)
      pbar.update(80//len(batches))

    predictions = pd.concat(predictions, axis=0).rename_axis('id').reset_index().melt(id_vars=["id", "model_id"], var_name="endpoint", value_name="value").set_index("id")
    means = predictions.groupby(['id', 'endpoint']).mean().assign(metabolomic_state="mean").drop('model_id', axis=1)
    stds = predictions.groupby(['id', 'endpoint']).std().assign(metabolomic_state="std").drop('model_id', axis=1)
    predictions_aggr = pd.concat([means, stds], axis=0).sort_index().reset_index()
    pbar.update(10)
    
    predictions_aggr.to_csv(os.path.join(results_dir, 'predictions.csv'))
    pbar.update(5)

print('DONE')

!zip -q -r {results_dir}.zip {results_dir}
files.download(f'{results_dir}.zip')

  0%|          | 0/100 [elapsed: 00:00 remaining: ?]

NameError: ignored

In [None]:
# Preview the first five rows of the aggregated predictions.
predictions_aggr.head(n=5)