# Gut Microbiome - Infer Autism Spectrum Disorders from 16S Abundance

**Data and Paper Credits:**

Zhou Dan et al., published April 21, 2020: [Altered gut microbial profile is associated with abnormal metabolism activity of Autism Spectrum Disorder](https://www.tandfonline.com/doi/full/10.1080/19490976.2020.1747329)

In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import sklearn
from sklearn.metrics import (f1_score, roc_auc_score, accuracy_score,
                            confusion_matrix, precision_recall_curve, 
                            auc, roc_curve, recall_score, precision_score)
from xgboost import XGBClassifier
import xgboost as xgb
import shap
from google import genai
import matplotlib.pyplot as plt

from glob import glob

from dotenv import load_dotenv
# Load environment variables (GOOGLE_API_KEY is read from os.environ below)
# from a .env file, if one is present.
load_dotenv()

# Load SHAP's JavaScript so interactive plots (e.g. the force plots below)
# render in the notebook.
shap.plots.initjs()

## Hyperparams / Constants

In [None]:
# Hyperparameters and paths used throughout the notebook.
SEED = 1970
test_train_split_SEED = 1970  # random_state for train_test_split
# FOLDS = 10
show_fold_stats = True
# Intended probability cut-off for turning P(autism) into a 0/1 label.
PREDICT_THRESHOLD = 0.6

# expanduser("~") resolves the home directory on every OS, whereas
# os.environ["HOME"] raises KeyError where HOME is unset (e.g. Windows).
DATA_DIR = os.path.join(os.path.expanduser("~"), "OtonoCo",
                        "Datasets", "AMILI", "ASD_Gut_Microbiome")

# Gemini Functions

In [None]:
# Gemini client authenticated with GOOGLE_API_KEY (KeyError if the variable is unset).
gemini_client = genai.Client( api_key=os.environ["GOOGLE_API_KEY"])
# Model name passed to generate_content in get_llm_response_gemini.
# NOTE(review): "-exp" model IDs can be retired; confirm this one is still served.
GEMINI_MODEL = "gemini-2.0-flash-exp"

def get_llm_response_gemini(system_msg, user_msg):
    """Send one prompt to Gemini and return its text and token usage.

    The system and user messages are joined with a newline into a single
    ``contents`` string; no separate system-instruction channel is used.

    Returns:
        tuple: ``(text, total_tokens)``. ``text`` is ``""`` when the response
        carries no text (e.g. a blocked response), so callers can always use
        str methods on it. ``total_tokens`` is ``None`` when the response has
        no usage metadata.
    """
    gemini_content = system_msg + "\n" + user_msg

    response = gemini_client.models.generate_content(
        model=GEMINI_MODEL,
        contents=gemini_content,
    )
    # Both fields are optional on the SDK response object; normalize them so
    # callers such as `results.replace(...)` do not crash on None.
    usage = response.usage_metadata
    total_tokens = usage.total_token_count if usage is not None else None
    return response.text or "", total_tokens

In [None]:
def prompt_extract_bact_strain(input_taxo):
    """Build the (system, user) prompt pair for extracting entity names.

    Args:
        input_taxo: Text to extract entities from; in this notebook, an OTU's
            ``taxonomy`` string. It is embedded in the system message.

    Returns:
        tuple[str, str]: ``(system_msg, user_msg)``. The user message asks for
        a comma-separated list of entities and nothing else.
    """
    system_msg = f"""
    You are an expert in identifying biological entities from text.
    You have been given the following text:
    {input_taxo}
    """
    user_msg = """
    Extract biological entities from the given text and give your answer in the following format:
    entity 1, entity 2, ...
    
    Output only the biological entities without anything else.
    """
    return system_msg, user_msg

def prompt_bio_entity_role(input_text, context):
    """Create the prompt pair asking about the gut-microbiome role of entities.

    Args:
        input_text: Entity list to describe; placed in the system message.
        context: Topic the roles should be framed against, e.g. "diseases".

    Returns:
        tuple[str, str]: ``(system_msg, user_msg)``, each starting with a
        newline and using four-space indented lines.
    """
    pad = "    "
    system_msg = (
        "\n"
        f"{pad}You are an expert in the field of Human Gut Microbiome.\n"
        f"{pad}You have been given the following text containing biological entities:\n"
        f"{pad}{input_text}\n"
        f"{pad}"
    )
    user_msg = (
        "\n"
        f"{pad}Describe the roles of the given entities in Human Gut Microbiome with regards to {context}.\n"
        f"{pad}"
    )
    return system_msg, user_msg

def extract_bact_strains(input_taxo):
    """Ask the LLM for the biological entities named in a taxonomy string.

    Waits 5 seconds before calling the API (a simple rate limit), then sends
    the prompt from ``prompt_extract_bact_strain``.

    Returns:
        tuple: ``(text, tokens)``, with underscores in the text replaced by
        spaces and ``tokens`` passed through from the LLM helper.
    """
    from time import sleep

    sleep(5)
    prompt_pair = prompt_extract_bact_strain(input_taxo)
    answer, token_count = get_llm_response_gemini(*prompt_pair)
    cleaned = answer.replace("_", " ")
    return cleaned, token_count

def find_bio_entity_role(input_taxo, context = "diseases"):
    """Ask the LLM to describe the gut-microbiome roles of the given entities.

    Waits 5 seconds first (simple rate limit), then sends the prompt from
    ``prompt_bio_entity_role`` framed by ``context``.

    Returns:
        tuple: ``(text, tokens)``, with underscores in the text replaced by
        spaces.
    """
    from time import sleep

    sleep(5)
    answer, token_count = get_llm_response_gemini(
        *prompt_bio_entity_role(input_taxo, context)
    )
    return answer.replace("_", " "), token_count
    
def single_otu_entity(in_otu):
    """Print ``"<OTU> - <entities>"`` for one OTU, using the LLM extractor.

    Looks the OTU up in the notebook-level ``otu_taxo_dict``. An OTU that is
    not in the table is reported directly. Previously the literal text "None"
    was sent to the LLM, which cost an API call and gave a meaningless answer.
    """
    otu_entity = otu_taxo_dict.get(in_otu)
    if otu_entity is None:
        print(f"{in_otu} - <no taxonomy found>")
        return
    print(in_otu + " - " + extract_bact_strains(otu_entity)[0])

## Data

In [None]:
# All CSV files in DATA_DIR, sorted by file name. The loading cell picks
# tables by position in this list, so the sort order matters.
csv_list = sorted(glob(os.path.join(DATA_DIR, "*.csv")))
csv_list

In [None]:
# Load the two tables by position in the sorted file list.
# NOTE(review): this depends on file-name order; an extra or missing CSV in
# DATA_DIR silently swaps the tables or raises IndexError. Confirm the names.
pd_meta_abundance = pd.read_csv(csv_list[0])
# OTU table: "OTU" and "taxonomy" columns plus (presumably) one count column
# per sample; it is transposed further down.
pd_16s = pd.read_csv(csv_list[1])

display(pd_meta_abundance.head())
display(pd_16s.head())
display(pd_16s.shape)

## OTU-TAXONOMY MAPPING

Keep track of each `OTU`–`taxonomy` pair so that an OTU ID can later be mapped back to its taxonomy.

In [None]:
otu_list_16s, taxo_list_16s = pd_16s["OTU"], pd_16s["taxonomy"]

# Lookup table: OTU id -> taxonomy string (used by the LLM helpers).
otu_taxo_dict = pd_16s.set_index("OTU")["taxonomy"].to_dict()


## Transpose Dataset

In [None]:
# OTU -> taxonomy as a one-column frame, kept for reference.
taxa = pd_16s.loc[:, ["OTU", "taxonomy"]].set_index("OTU")
# Counts only, indexed by OTU, then flipped so each row is one sample.
pd_16s_T = pd_16s.drop(columns="taxonomy").set_index("OTU").T

display(pd_16s_T.head())

In [None]:
def define_target(input_otu):
    """Return the binary label encoded in a sample identifier.

    Identifiers starting with an uppercase "A" map to 1 (the AUTISM class);
    every other identifier maps to 0.
    """
    return int(input_otu.startswith("A"))

## Restore `OTU` from `index`

In [None]:
# Work on a copy of the transposed table (one row per sample) and copy the
# row index, the sample identifiers, into a column named "OTU".
# NOTE: despite its name, this "OTU" column holds sample IDs, not OTU IDs.
pd_16s_df = pd_16s_T.copy()
pd_16s_df["OTU"] = pd_16s_df.index
display(pd_16s_df.head())

# Prepend the 0/1 label derived from each sample ID via define_target.
pd_16s_df.insert(loc = 0, column = "AUTISM", 
                 value = pd_16s_df["OTU"].apply(define_target))

display(pd_16s_df.head())

**Target Analysis**

In [None]:
# Number of samples per label (1 = sample ID starts with "A").
pd_16s_df["AUTISM"].value_counts()

In [None]:
# Share of the majority class, i.e. the accuracy of always predicting that
# class: the baseline a classifier must beat. Computed from the data rather
# than from hard-coded counts, so it stays correct if the dataset changes.
pd_16s_df["AUTISM"].value_counts(normalize=True).max()

## Determine Total Species & Relative Abundance

In [None]:
# Drop the AUTISM label (column 0) and index rows by sample ID ("OTU" column),
# leaving only the per-OTU count columns (presumably raw read counts).
rel_abund_df = pd_16s_df.iloc[:, 1:].set_index('OTU')
rel_abund_df.head()

In [None]:
# Mean per-sample total count: sum across OTUs for each sample, averaged
# over samples, truncated to an int.
abs_abundance = int(rel_abund_df.sum(axis = 1).mean())
abs_abundance

In [None]:
# Per-sample relative abundance (%): divide each sample's (row's) counts by
# that sample's own total so every row sums to 100. Dividing by one global
# constant (the mean library size) leaves sequencing-depth differences
# between samples in place, so it is not a relative abundance.
sample_totals = rel_abund_df.sum(axis=1)
rel_abundance_df = rel_abund_df.div(sample_totals, axis=0) * 100
# A sample with zero total reads would otherwise become all-NaN (0 / 0).
rel_abundance_df.loc[sample_totals == 0] = 0.0

display(rel_abundance_df.head())

## Restore Target and OTU

In [None]:
# Labels in sample (row) order; both tables share that order.
target_list = pd_16s_df.loc[:, "AUTISM"].tolist()

# Copy the abundance table and put the label back as its second column,
# assigned by row position.
rel_abundance_df2 = rel_abundance_df.copy()
rel_abundance_df2.insert(1, "AUTISM", target_list)

display(rel_abundance_df2.head())
display(rel_abundance_df2.loc[:, "AUTISM"].value_counts())

## Split Data
### Create `X` and `y`

In [None]:
# Features: every column except the label. "OTU" is the index at this point,
# not a column, so excluding it is only a safeguard.
# Target: the AUTISM label kept as a one-column DataFrame.
X_df = rel_abundance_df2.copy()
X = X_df[[col for col in X_df.columns if col not in ("AUTISM", "OTU")]]
y = rel_abundance_df2.loc[:, ["AUTISM"]]

display(X.head())
display(y.head())

In [None]:
# Stratify on the label so the train and test splits keep the same
# autism/control ratio (the classes are not balanced). Uses the dedicated
# split-seed constant, which was previously defined but never used.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=test_train_split_SEED,
    stratify=y["AUTISM"],
)

np.shape(X_train)

## Run XGBoost Training on 16S Data

In [None]:
# Full-data DMatrix, kept for the later whole-dataset cells that use `Xy`.
Xy = xgb.DMatrix(X, y, enable_categorical=True)

# Train ONLY on the training split. This booster is saved, reloaded as
# `model`, explained and evaluated on X_test, so fitting it on all of X
# leaked the test set into training. The objective must be binary:logistic
# so predictions are probabilities; xgb.train defaults to reg:squarederror.
Xy_train = xgb.DMatrix(X_train, y_train, enable_categorical=True)
booster = xgb.train(
    {
        "objective": "binary:logistic",
        "tree_method": "hist",
        "max_cat_to_onehot": 5,
        "seed": SEED,
    },
    Xy_train,
)
# Must use JSON for serialization, otherwise the information is lost
booster.save_model("abund-16S-model.json")

**Reference:** [XGBoost parameters](https://xgboost.readthedocs.io/en/stable/parameter.html)

In [None]:
# Fit the scikit-learn XGBoost wrapper on the training split.
# NOTE(review): `model` is reassigned to a Booster loaded from disk in the
# next section, so this fitted classifier is not what the later cells use.
model = xgb.XGBClassifier(objective="binary:logistic") 
model.fit(X_train, y_train)

## Load Model

In [None]:
# Replace `model` with a native Booster loaded from the JSON file written by
# the xgb.train cell; the SHAP and prediction cells below use this object.
model = xgb.Booster()
model.load_model("abund-16S-model.json") 

# SHAP Explainer
## Explain `X`

In [None]:
# SHAP explanations for every sample in X (train and test) under the loaded Booster.
explainer = shap.Explainer(model)
shap_value_X = explainer(X)

## Explain `X_test`

In [None]:
# A fresh explainer for the same model, applied to the held-out test rows only.
explainer = shap.Explainer(model)
shap_value_X_test = explainer(X_test)

## Waterfall Plot for `X`

In [None]:
# Waterfall plot of per-feature SHAP contributions for the first sample of X.
shap.plots.waterfall(shap_value_X[0])

In [None]:
# Taxonomy string for OTU625 (None if that OTU is not in the table).
otu_taxo_dict.get("OTU625")

### Waterfall Plot for `X_test`

In [None]:
# Waterfall plot of per-feature SHAP contributions for the first test sample.
shap.plots.waterfall(shap_value_X_test[0])

In [None]:
# Print LLM-extracted entity names for OTUs of interest (presumably those
# prominent in the waterfall plot above). Each call makes one API request.
single_otu_entity("OTU1301")

single_otu_entity("OTU390")

single_otu_entity("OTU1278")

single_otu_entity("OTU976")

single_otu_entity("OTU910")

### Force Plot for `X`

In [None]:
# Interactive force plot for the first sample of X (needs shap.plots.initjs()).
shap.plots.force(shap_value_X[0])

### Force Plot for `X_test`

In [None]:
# Interactive force plot for the first test sample.
shap.plots.force(shap_value_X_test[0])

### Beeswarm for `X`

In [None]:
# Beeswarm summary of SHAP values across all samples in X.
shap.plots.beeswarm(shap_value_X)

### OTU Entity Abundance Analysis

In [None]:
# Print LLM-extracted entity names for OTUs of interest (presumably those
# prominent in the beeswarm plot above). Each call makes one API request.
single_otu_entity("OTU625")

single_otu_entity("OTU976")

single_otu_entity("OTU1301")

single_otu_entity("OTU390")

single_otu_entity("OTU813")

### Beeswarm for `X_test`

In [None]:
# Beeswarm summary of SHAP values across the test samples.
shap.plots.beeswarm(shap_value_X_test)

# Overall Shapley Importance

In [None]:
# Per-sample, per-feature SHAP contributions: shape (n_samples, n_features + 1),
# with the bias term in the last column. pred_interactions=True instead
# returned an (n_samples, n_features + 1, n_features + 1) tensor. That is
# quadratic in the number of OTUs, is not used in this notebook, and holds
# interactions rather than the per-feature importance this section is about.
SHAP = booster.predict(Xy, pred_contribs=True)

# categorical features are listed as "c"
# print(booster.feature_types)

In [None]:
# Booster.get_score() defaults to importance_type="weight", the number of
# splits that use each feature, not a SHAP value. The "shap_score" column
# name is kept for continuity. Sort (feature, score) pairs together: sorting
# the scores on their own paired each column with another feature's score.
score_dict = booster.get_score()
score_table = pd.DataFrame(
    sorted(score_dict.items(), key=lambda kv: kv[1], reverse=True),
    columns=["column", "shap_score"],
)
display(score_table.head(10))

In [None]:
# Booster predictions for every sample in X (train and test), as a table.
# NOTE: `predict_df` is overwritten by the test-set table further down.
EVAL = booster.predict(Xy, pred_interactions=False)
predict_df = pd.DataFrame(EVAL)
predict_df.columns = ["PREDICT_AUTISM"]
predict_df.head(10)

## Make Prediction on `X_test`

In [None]:
# `model` is the xgb.Booster reloaded above. Booster.predict() accepts only a
# DMatrix (a DataFrame raises TypeError) and returns scores, not labels, so
# wrap X_test and convert the scores to 0/1 with PREDICT_THRESHOLD.
predicted_prob_test = model.predict(xgb.DMatrix(X_test))
predicted_test = (predicted_prob_test >= PREDICT_THRESHOLD).astype(int)

print(predicted_test)

In [None]:
# Test-set predictions from the in-memory `booster`. It has the same weights
# as the `model` reloaded from the JSON file it saved.
Xy_test = xgb.DMatrix(X_test, y_test, enable_categorical=True)
EVAL_test = booster.predict(Xy_test, pred_interactions=False)

## Create Data Frame for the Predicted Probability

In [None]:
# One row per test sample: predicted probability, predicted label, true
# label and sample ID (all aligned by position).
predict_df = pd.DataFrame({"PREDICTED_PROB_AUTISM": EVAL_test})
predict_df = predict_df.assign(
    PREDICTED_AUTISM=predicted_test,
    ACTUAL_AUTISM=y_test["AUTISM"].to_list(),
    SUBJECT=X_test.index,
)
predict_df.head(10)

## Confusion Matrix

In [None]:
# Confusion matrix: rows = actual class, columns = predicted class (labels sorted, 0 then 1).
sklearn.metrics.confusion_matrix(predict_df["ACTUAL_AUTISM"], predict_df["PREDICTED_AUTISM"])

## Evaluate Model Performance

In [None]:
ACTUAL = predict_df["ACTUAL_AUTISM"].to_list()
PREDICTED = predict_df["PREDICTED_AUTISM"].to_list()
# ROC-AUC measures how well the scores rank positives above negatives, so it
# needs the predicted probabilities. On hard 0/1 labels it collapses to the
# balanced accuracy at a single threshold.
PREDICTED_PROB = predict_df["PREDICTED_PROB_AUTISM"].to_list()

pred_acc = accuracy_score(ACTUAL, PREDICTED)
pred_f1 = f1_score(ACTUAL, PREDICTED)
pred_roc = roc_auc_score(ACTUAL, PREDICTED_PROB)
pred_recall = recall_score(ACTUAL, PREDICTED)
pred_precision = precision_score(ACTUAL, PREDICTED)

print(f"Accuracy Score: {round(pred_acc, 3)}")
print(f"F1 Score: {round(pred_f1, 3)}")
# This line previously printed pred_f1 under the ROC label.
print(f"ROC Score: {round(pred_roc, 3)}")
print(f"RECALL Score: {round(pred_recall, 3)}")
print(f"PRECISION Score: {round(pred_precision, 3)}")

## Confusion Matrix

In [None]:
# Same confusion matrix as above, computed from the ACTUAL/PREDICTED lists.
confusion_matrix(ACTUAL, PREDICTED)

## Variable Importance

In [None]:
# Features ranked by booster split-count importance (highest first), with
# each OTU's taxonomy string attached.
var_imp_df = pd.DataFrame(
    sorted(((score, otu) for otu, score in score_dict.items()), reverse=True)
)
var_imp_df.columns = ["SCORE", "OTU"]
var_imp_df["taxonomy"] = var_imp_df["OTU"].map(otu_taxo_dict.get)
var_imp_df.head(10)

## Build Table with Bacteria Strains

In [None]:
# Top-20 features by split-count importance, each with an LLM-extracted
# list of biological entities (one API call plus a 5 s sleep per row).
var_imp_df2 = var_imp_df.copy().head(20)
var_imp_df2["BIOLOGICAL"] = var_imp_df2["taxonomy"].apply(lambda x: extract_bact_strains(x)[0])
var_imp_df2

In [None]:
# Create the output folder if needed; to_csv does not create directories.
os.makedirs("output", exist_ok=True)
var_imp_df2.to_csv("output/biologicals.csv", index = False)

## Retrieve Relevant Contents

In [None]:
# OTU id -> LLM-extracted entity string for the top-20 features.
otu_biological = var_imp_df2.set_index("OTU")["BIOLOGICAL"].to_dict()
otu_biological

In [None]:
import time

# For each top OTU, ask the LLM about its entities' roles in a disease context
# and save one markdown file per OTU under output/bact_role/.
bact_role_dir = os.path.join("output", "bact_role")
os.makedirs(bact_role_dir, exist_ok=True)
for k, biologicals in otu_biological.items():
    # Query first, so a failed API call does not leave an empty file behind.
    results, tokens = find_bio_entity_role(biologicals, "diseases")
    # `with` closes the file even if the write fails; explicit UTF-8 because
    # LLM text may contain non-ASCII characters.
    with open(os.path.join(bact_role_dir, f"{k}.md"), "w", encoding="utf-8") as output_file:
        output_file.write(results)
    time.sleep(3)

In [None]:
import time

# For each top OTU, ask the LLM about its entities' roles with regard to
# probiotic formulation and save one markdown file per OTU under
# output/probiotic_formulation/.
probiotic_dir = os.path.join("output", "probiotic_formulation")
os.makedirs(probiotic_dir, exist_ok=True)
for k, biologicals in otu_biological.items():
    # Query first, so a failed API call does not leave an empty file behind.
    results, tokens = find_bio_entity_role(biologicals, "probiotic formulation")
    # `with` closes the file even if the write fails; explicit UTF-8 because
    # LLM text may contain non-ASCII characters.
    with open(os.path.join(probiotic_dir, f"{k}.md"), "w", encoding="utf-8") as output_file:
        output_file.write(results)
    time.sleep(3)

## ROC Curve

In [None]:
def plot_roc(input_y_test, input_predicted_prob):
    """Draw and show the ROC curve for binary labels and positive-class scores.

    Args:
        input_y_test: True 0/1 labels.
        input_predicted_prob: Predicted probability (or score) of class 1,
            one value per label.

    The legend reports the area under the curve, and a dashed diagonal marks
    a random classifier. The figure is displayed immediately with plt.show().
    """
    false_pos, true_pos, _ = roc_curve(input_y_test, input_predicted_prob)
    area = auc(false_pos, true_pos)

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos, true_pos, color='blue', label=f'ROC curve (AUC = {area:.2f})')
    ax.plot([0, 1], [0, 1], color='maroon', linestyle='--', label='Random guess')
    ax.set(
        xlim=(0.0, 1.0),
        ylim=(0.0, 1.05),
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        title='Receiver Operating Characteristic (ROC) Curve',
    )
    ax.legend(loc="lower right")
    ax.grid()
    plt.show()

In [None]:
# ROC curve on the test split, using the predicted probabilities.
plot_roc(y_test, predict_df["PREDICTED_PROB_AUTISM"].to_list())

## OTU QUERY

In [None]:
# Ad-hoc lookup: print the LLM-extracted entities for one OTU.
single_otu_entity("OTU115")