# Model Training  
### Author: Roberto Olayo-Alarcon
  
Here, we perform a random search over XGBoost hyperparameters to train models that predict antimicrobial activity from different chemical representations.

In [60]:
import os
import random
import numpy as np
import pandas as pd

                                                                                                                                                          
# Classifier
from xgboost import XGBClassifier

# Evaluation metrics and parameter-grid utilities
from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve, auc
from sklearn.model_selection import ParameterGrid

from sklearn.preprocessing import OneHotEncoder

## Prepare directories

In [61]:
# Directory the prepared training tables are read from
INPUT_DIR = "../data/01.prepare_training_data/"

# Directory the results are written to; created up front if it does not exist yet
OUTPUT_DIR = "../data/02.model_training"
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Prepare data

In [62]:
# File holding each chemical representation, keyed by the representation name
_REPRESENTATION_FILES = {
    "MolE": "maier_mole_representation.tsv.gz",
    "ecfp4": "maier_ecfp4_representation.tsv.gz",
    "chemDesc": "maier_chemdesc_representation.tsv.gz",
}

# One feature table per representation; the first column of each file is the index
representation_dict = {
    rep_name: pd.read_csv(os.path.join(INPUT_DIR, file_name), index_col=0, sep='\t')
    for rep_name, file_name in _REPRESENTATION_FILES.items()
}

# Split assignment table, indexed by prestwick_ID
split_path = os.path.join(INPUT_DIR, "maier_scaffold_split.tsv.gz")
split_df = pd.read_csv(split_path, index_col="prestwick_ID", sep='\t')

# Screening results table, indexed by prestwick_ID
screen_path = os.path.join(INPUT_DIR, "maier_screening_results.tsv.gz")
screen_df = pd.read_csv(screen_path, index_col="prestwick_ID", sep='\t')


## Model parameters for random search

In [63]:
# Search space for the random search: list values are candidates to sample from,
# scalar values are passed through unchanged.
XGB_PARAMS = {
    # Fixed setting
    "nthread": 20,
    # Candidate values
    "n_estimators": [30, 100, 300, 500, 1000],
    "max_depth": [5, 10, 50, 100],
    "eta": [0.3, 0.1, 0.05, 1],
    "subsample": [0.3, 0.5, 0.8, 1.0],
    # Fixed setting
    "objective": "binary:logistic",
}

## Helper Functions

In [81]:
# Select parameters randomly
def select_params(original_config):
    """
    Randomly select parameters from the provided configuration for modeling.

    A shallow copy of the configuration is made and then walked key by key:
    every list value is replaced by one element drawn with ``random.choice``,
    every dict value is replaced by the result of a recursive call, and any
    other value is kept as is. The input configuration is not modified.

    Parameters:
    - original_config (dict): The original configuration dictionary containing parameters.

    Returns:
    - model_config (dict): Configuration with randomly selected parameters.

    Raises:
    - IndexError: If a list value is empty (raised by ``random.choice``).
    """

    model_config = original_config.copy()

    # Only existing keys are reassigned, so the dict size never changes while iterating.
    for key, value in model_config.items():

        # isinstance (not an exact type comparison) so that list / dict subclasses
        # such as OrderedDict are sampled as well instead of being passed through.
        if isinstance(value, list):
            model_config[key] = random.choice(value)
        elif isinstance(value, dict):
            model_config[key] = select_params(value)

    return model_config

# Prepare strain one-hot-encoding
def prep_ohe(categories):

    """
    Prepare one-hot encoding for strain variables.

    This function fits a OneHotEncoder to the provided categories (wrapped in a
    single-column DataFrame) and returns the dense encoding as a pandas DataFrame
    whose rows are indexed by the categories themselves.

    Parameters:
    - categories (array-like): Array-like object containing categorical variables.

    Returns:
    - cat_ohe (pandas.DataFrame): DataFrame representing the one-hot encoded categorical variables.
    """

    # scikit-learn 1.2 renamed `sparse` to `sparse_output` and 1.4 removed the old
    # keyword; try the current name first and fall back for older installations.
    try:
        ohe = OneHotEncoder(sparse_output=False)
    except TypeError:
        ohe = OneHotEncoder(sparse=False)

    categories_frame = pd.DataFrame(categories)
    ohe.fit(categories_frame)

    # NOTE(review): `categories_` is a list holding one array per input column, so the
    # columns are built from a one-element list of arrays; left as in the original so
    # the returned frame keeps the same column labels - confirm this is intended.
    cat_ohe = pd.DataFrame(ohe.transform(categories_frame),
             index=categories, columns=ohe.categories_)

    return cat_ohe

def get_split(data_df, y_df, splitter_df, split_strat = "split"):

    """
    Prepare data splits for training, validation, and testing.

    The label table is reshaped to one row per (taxon, chemical) pair, the molecular
    features are joined on "prestwick_ID" and the taxon one-hot encoding is joined on
    "taxa_name". Rows are then assigned to a split according to the chemicals that
    splitter_df marks as "train", "valid" or "test" in the `split_strat` column.

    None of the input DataFrames is modified.

    Parameters:
    - data_df (pandas.DataFrame): DataFrame containing molecular features, indexed by chemical ID.
      Column names are converted to strings and missing values are filled with 0.
    - y_df (pandas.DataFrame): DataFrame containing labels; its columns are used as the taxa.
    - splitter_df (pandas.DataFrame): DataFrame containing chemical IDs (index) and split information.
    - split_strat (str, optional): Name of the splitter_df column holding the split assignment ("split" by default).

    Returns:
    - X_train (pandas.DataFrame): DataFrame containing features for training.
    - X_valid (pandas.DataFrame): DataFrame containing features for validation.
    - X_test (pandas.DataFrame): DataFrame containing features for testing.
    - y_train (numpy.ndarray): Single-column 2-D array containing labels for training.
    - y_valid (numpy.ndarray): Single-column 2-D array containing labels for validation.
    - y_test (numpy.ndarray): Single-column 2-D array containing labels for testing.

    Raises:
    - AssertionError: If joining the features changes the number of rows.
    """

    # Get the chemicals in each split of data
    split_chems = {split_label: splitter_df.loc[splitter_df[split_strat] == split_label].index
                   for split_label in ("train", "valid", "test")}

    # Prepare taxonomic OHE
    taxa_ohe = prep_ohe(y_df.columns)

    # Pivot longer screen results
    screen_melt = y_df.unstack().reset_index().rename(columns={0: "label",
                                                               "level_0": "taxa_name"})

    # fillna returns a new frame, so renaming the columns on that result leaves the
    # caller's data_df untouched (the column names were previously overwritten in place).
    features = data_df.fillna(0)
    features.columns = [str(c) for c in features.columns]

    # Join molecular features and then join taxa OHE
    screen_feat = screen_melt.join(features, on="prestwick_ID")
    screen_feat = screen_feat.join(taxa_ohe, on="taxa_name")

    # The joins must not add or drop rows
    assert screen_feat.shape[0] == screen_melt.shape[0]

    # Columns that identify a row or hold the target, i.e. not model features
    non_feature_cols = ["prestwick_ID", "label", "taxa_name"]

    def _gather(chems):
        # Select the rows whose chemical belongs to the split, then separate X and y
        in_split = screen_feat["prestwick_ID"].isin(chems)
        X = screen_feat.loc[in_split].drop(columns=non_feature_cols)
        y = screen_feat.loc[in_split, ["label"]].values
        return X, y

    X_train, y_train = _gather(split_chems["train"])
    X_valid, y_valid = _gather(split_chems["valid"])
    X_test, y_test = _gather(split_chems["test"])

    return X_train, X_valid, X_test, y_train, y_valid, y_test
 
def get_performance_metrics(y_true, y_pred, y_score, split_name):

    """
    Summarise classifier performance on one data split.

    AUROC and the area under the precision-recall curve are computed from the
    second column of y_score; the F1 score is computed from the predicted labels.

    Parameters:
    - y_true (array-like): True labels.
    - y_pred (array-like): Predicted labels.
    - y_score (array-like): Predicted scores; column 1 is used as the score.
    - split_name (str): Name of the data split, used as prefix of the result keys.

    Returns:
    - out_dict (dict): Dictionary containing computed performance metrics.
        Keys:
        - '{split_name}_auroc': Area Under the Receiver Operating Characteristic curve (AUROC) score.
        - '{split_name}_prauc': Area Under the Precision-Recall curve (AUPRC) score.
        - '{split_name}_f1': F1 score.
    """

    # Column 1 of the score matrix is the one all curve-based metrics use
    positive_scores = y_score[:, 1]

    precision, recall, _thresholds = precision_recall_curve(y_true, positive_scores)

    auroc = roc_auc_score(y_true=y_true, y_score=positive_scores)
    prauc = auc(recall, precision)
    f1 = f1_score(y_true=y_true, y_pred=y_pred)

    out_dict = {
        f"{split_name}_auroc": auroc,
        f"{split_name}_prauc": prauc,
        f"{split_name}_f1": f1,
    }

    return out_dict

## Main training function

In [86]:
def eval_models(dataset_representation,
            n_train = 1,
            n_models = 1,
            feature_options = representation_dict,
            XGB_params_dict = XGB_PARAMS,
            split_df = split_df,
            screen_df = screen_df):

    """
    Train and score randomly configured XGBoost models on one dataset representation.

    For each of `n_models` configurations drawn with select_params, the classifier is
    fitted `n_train` times on the training split, each time with a freshly drawn seed,
    and scored on the validation and test splits. One result row is produced per fit.

    Parameters:
    - dataset_representation (str): Name of the dataset representation to use (key of feature_options).
    - n_train (int, optional): Number of training iterations for each model (default is 1).
    - n_models (int, optional): Number of models to evaluate (default is 1).
    - feature_options (dict, optional): Dictionary containing different dataset representations (default is representation_dict).
    - XGB_params_dict (dict, optional): Dictionary containing XGBoost parameters for model configuration (default is XGB_PARAMS).
    - split_df (pandas.DataFrame, optional): DataFrame containing split information (default is split_df).
    - screen_df (pandas.DataFrame, optional): DataFrame containing screening data (default is screen_df).

    Returns:
    - results_df (pandas.DataFrame): DataFrame containing performance metrics for all evaluated models.
    """

    # Work on copies so the search space and the feature table passed in stay untouched
    search_space = XGB_params_dict.copy()
    features_df = feature_options[dataset_representation].copy()

    # The splits are predefined, so the data only needs to be separated accordingly
    X_train, X_valid, X_test, y_train, y_valid, y_test = get_split(features_df, screen_df, split_df)

    run_frames = []
    for m in range(n_models):

        # Draw one configuration; its string form is taken before the seed is added
        model_config = select_params(search_space)
        model_config_str = str(model_config)

        for t in range(n_train):

            # Each repetition gets its own random seed
            model_config["seed"] = np.random.randint(1_000_000, size=1)[0]
            classifier = XGBClassifier(**model_config)
            classifier.fit(X=X_train, y=y_train)

            # Predictions on the validation split
            print("At Validation")
            valid_scores = classifier.predict_proba(X=X_valid)
            valid_labels = classifier.predict(X=X_valid)

            # Predictions on the test split
            print("At Testing")
            test_scores = classifier.predict_proba(X=X_test)
            test_labels = classifier.predict(X=X_test)

            # Metrics for both splits
            print("Gathering Results")
            valid_metrics = get_performance_metrics(y_true=y_valid,
                                                    y_pred=valid_labels,
                                                    y_score=valid_scores,
                                                    split_name="validation")
            test_metrics = get_performance_metrics(y_true=y_test,
                                                   y_pred=test_labels,
                                                   y_score=test_scores,
                                                   split_name="test")

            # Metrics first, then the bookkeeping columns describing this run
            run_record = {
                **valid_metrics,
                **test_metrics,
                "model": f"model_{m}",
                "train": f"train_{t}",
                "model_type": "XGB",
                "model_params": model_config_str,
                "representation": dataset_representation,
            }

            run_frames.append(pd.DataFrame(run_record, index=[0]))

    return pd.concat(run_frames)

## Random Search

In [89]:
# Output file name and the full path it is written to
filename = "strain_performance.tsv.gz"
output_path = os.path.join(OUTPUT_DIR, filename)

# Number of representations written so far
i = 0

# Iterate over the representations
for representation in representation_dict:
    print(f"Starting {representation} representation")

    # Random search
    results = eval_models(dataset_representation=representation)

    # The first table creates the file with a header row; later ones are appended without it
    is_first = i == 0
    results.to_csv(output_path, sep='\t', index=False,
                   header=is_first, mode="w" if is_first else "a")

    i += 1


Starting MolE representation
At Validation
[[9.9796784e-01 2.0321773e-03]
 [9.9822879e-01 1.7712050e-03]
 [9.9245560e-01 7.5444025e-03]
 ...
 [9.9983788e-01 1.6215106e-04]
 [9.9981576e-01 1.8425252e-04]
 [9.9978560e-01 2.1438736e-04]]
At Testing
Gathering Results
Starting ECFP4 representation
At Validation
[[0.65721035 0.34278968]
 [0.85142165 0.14857836]
 [0.47510612 0.5248939 ]
 ...
 [0.86241525 0.13758476]
 [0.851101   0.148899  ]
 [0.86241525 0.13758476]]
At Testing
Gathering Results
Starting ChemDesc representation
At Validation
[[9.99996781e-01 3.24534813e-06]
 [9.99079943e-01 9.20061138e-04]
 [9.99998808e-01 1.21489290e-06]
 ...
 [9.99993920e-01 6.05880268e-06]
 [9.99893427e-01 1.06556006e-04]
 [9.99999881e-01 1.19222065e-07]]
At Testing
Gathering Results
