In [1]:
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from typing import Optional, List, Union
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import PydanticOutputParser
import json
import os
import re
import glob
import json
from json import JSONEncoder
import requests, urllib.parse
from langchain_core.tools import tool


In [2]:
# Chat model shared by the chains below: Gemini 2.5 Flash through the
# google_genai provider, sampled at temperature 0.2.
llm = init_chat_model(
    model="gemini-2.5-flash",
    model_provider="google_genai",
    temperature=0.2,
)

In [5]:
#global_path  = "/Users/ilboukil/Library/CloudStorage/OneDrive-SIBSwissInstituteofBioinformatics/Trainings-cb-402/ML_summer_school_code/"
# Directory that holds the per-patient JSON files and receives the outputs.
# Set the PATIENT_RESULTS_DIR environment variable to run on another machine;
# the default is the original hard-coded location.
global_path = os.environ.get(
    "PATIENT_RESULTS_DIR",
    "/Users/SJp/Documents/project_local/VIB-LLM-SS/ml-summerschool-2025/topic-1_data-integration-and-llms/project/results/",
)

patient_id = "MM082"
# Path to your JSON file (e.g., patient PKG or classification output).
# os.path.join avoids the doubled separator the f-string produced when
# global_path already ends with "/".
json_path = os.path.join(global_path, f"{patient_id}.json")

# Explicit encoding so the read does not depend on the platform default.
with open(json_path, "r", encoding="utf-8") as f:
    patient_json = json.load(f)


## Test without tools

In [6]:
# Report prompt for the tool-free run.
# The system message explains the layout of the patient JSON (patient ID,
# disease type, and per drug: name, predicted class, probability, top positive
# and negative SHAP features) and asks for a drug choice, a clinician-facing
# report and a discussion of the features.
# The human message is filled from the `JSON_input` template variable.
prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a biomedical-AI assistant that interprets predictions from a AI-powered predictive model for clinicians.
You are given a JSON with some information regarding the patient and 2 drugs. In the JSON, you are given:
- Patient ID
- Disease type
- For each drug, you will be given:
    - The drug name
    - The predicted class. This class can be one of three options: 
        - No effect if the drug is predicted as having no effect on treating the patient disease
        - Positive response if the drug is predicted as having a positive effect on treating the patient disease
        - Adverse effects if the drug is predicted as having a negative effect on the patient disease
    - Each predicted class has an associated probability
    - Each predicted class has associated features, that are responsible for the prediction. To reflect the importance of these features on the prediction we have the SHAP values. We have the top positive SHAP values and the top negative SHAP values.

Taking into account this JSON and the information explained above, I want you as a smart biomedical-AI assistant to pick the best of the two drugs.
Once you have  picked the best drug for my patient, I want you to write a report on the chosen drug, please include both positive an negative points about the drug. This should be targeted towards clinicians. 
After that, please write about the features involved in the decision making process, and look in the litterature for information about the relationship between these features and the disease the patient has. 
"""),
    ("human", "{JSON_input}")
])

In [7]:
# Tool-free baseline: prompt -> chat model -> plain string.
chain = prompt | llm | StrOutputParser()

# Serialise the patient record explicitly. Passing the dict itself would put
# its Python repr (single quotes, None/True/False) into the prompt, while the
# system message tells the model it is receiving JSON.
response = chain.invoke({"JSON_input": json.dumps(patient_json, indent=2)})

In [8]:
# Target schema for the structured-extraction pass: the free-text report is
# parsed into these fields by PydanticOutputParser.
# (Kept as a comment rather than a class docstring so the JSON schema that is
# sent to the model in the format instructions does not gain extra text.)
class GNN_prediction_report(BaseModel):
    # Under pydantic v2 an Optional annotation alone still makes the field
    # required; the explicit None default lets the parser accept a report in
    # which the model left the patient ID out.
    patient_ID: Optional[Union[str, int]] = None  # "555-1234", 5551234, or None
    disease_type: str = Field(description="e.g., Melanoma")
    # NOTE: the field name keeps its original spelling because it is part of the
    # emitted JSON keys.
    recomended_drug_name: str = Field(description="e.g., Pembrolizumab")
    info_on_recommended_drug: str = Field(description="e.g., Pembrolizumab is a PD-1 inhibitor that has demonstrated significant efficacy in advanced melanoma, improving both progression-free and overall survival. Clinically, it can induce durable responses in a subset of patients. However, its use is associated with immune-related adverse effects, including colitis, hepatitis, pneumonitis, endocrinopathies (such as hypothyroidism or hypophysitis), and less commonly severe dermatologic or neurologic toxicities. Careful monitoring and prompt management of these toxicities are essential during treatment")
    # Described so the format instructions tell the model what belongs here.
    decision_making_process: str = Field(
        description="Explanation of why this drug was chosen, including the features (SHAP values) that drove the prediction and how they relate to the patient's disease"
    )

In [9]:
# Second pass: turn the free-text report into a GNN_prediction_report object.
parser = PydanticOutputParser(pydantic_object=GNN_prediction_report)
format_instructions = parser.get_format_instructions()

# Stored as `prompt_formats`, not `prompt`: reusing the name `prompt` would
# overwrite the report prompt, so re-running the report chain cell afterwards
# would silently use this extraction template instead.
prompt_formats = ChatPromptTemplate.from_messages([
    ("system", "Extract per schema:\n{format_instructions}"),
    ("human", "{text}"),
]).partial(format_instructions=format_instructions)

parsing_llm = prompt_formats | llm | parser

# `response` is a plain string here because the report chain ends in
# StrOutputParser, so it can be passed straight in as the text to extract from.
result = parsing_llm.invoke({"text": response})


In [10]:
from json import JSONEncoder
class MyEncoder(JSONEncoder):
    """JSON encoder that serialises otherwise unsupported objects via their attributes."""

    def default(self, o):
        """Return a JSON-serialisable stand-in for ``o``.

        Resolution order:
        1. objects exposing a callable ``model_dump`` use its result;
        2. objects with an instance ``__dict__`` use that mapping;
        3. anything else is delegated to ``JSONEncoder.default``, which raises
           ``TypeError`` (the error ``json`` callers expect) instead of the
           ``AttributeError`` a bare ``o.__dict__`` access would raise.
        """
        dump = getattr(o, "model_dump", None)
        if callable(dump):
            return dump()
        try:
            return vars(o)
        except TypeError:
            # No __dict__ (e.g. a set): let the base class report it.
            return super().default(o)


In [11]:
# Persist the structured report of the tool-free run next to the input data.
file_name = os.path.join(global_path, f"{patient_id}_results_prompt_without_tools.json")
with open(file_name, "w", encoding="utf-8") as f:
    # Hand the encoder to json.dump through `cls`. Encoding first and dumping
    # the resulting string encodes twice: the file then holds one escaped JSON
    # string instead of a JSON object, and `indent` has no effect.
    json.dump(result, f, cls=MyEncoder, indent=2)

## Adding the tools

In [12]:
@tool
def get_openfda_label(ingredient: str):
    """
    Fetches drug purpose, indications, usages, adverse reactions, warning, and dosage and administration information from the FDA API.

    Args:
        ingredient (str): The name of the drug (e.g., "aspirin").

    Returns:
        dict: The first matching openFDA drug label record (purpose, indications, usages, adverse reactions, warning, and dosage and administration information), or None when no label matches.
    """
    base = "https://api.fda.gov/drug/label.json"
    # Do not URL-quote here: `requests` encodes `params` itself, so quoting
    # first would double-encode the name (a space would arrive as "%2520").
    # Embedded double quotes are dropped so they cannot end the quoted phrase.
    name = ingredient.strip().replace('"', "")

    # Try the active-substance field first, then fall back to the brand name.
    for field in ("openfda.substance_name", "openfda.brand_name"):
        r = requests.get(
            base,
            params={"search": f'{field}:"{name}"', "limit": 1},
            timeout=30,  # never hang the notebook on a stalled connection
        )
        # openFDA reports "no matches" as HTTP 404; treat that as an empty
        # result so the fallback field is still tried.
        if r.status_code == 404:
            continue
        r.raise_for_status()
        res = r.json().get("results", [])
        if res:
            return res[0]
    return None

# item = get_openfda_label("ibuprofen")

# if item:
#     for key in ("purpose","indications_and_usage", "adverse_reactions", "warnings", "dosage_and_administration"):
#         if key in item:
#             print(f"\n=== {key} ===\n{item[key][0][:800]}")
# else:
#     print("Cannot find record on openFDA for that name.")

In [13]:
# bioassistant.py

import json
from typing import List, Dict
from collections import defaultdict

from gseapy import enrichr
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool


# -------------------------------
# Tool: Enrichment
# -------------------------------
@tool
def enrichr_query(gene_list: List[str]):
    """Run enrichment analysis on a list of genes using gseapy (GO Biological Process)."""
    # Fixed query settings; only the gene list varies between calls.
    enrichment_options = {
        "gene_sets": 'GO_Biological_Process_2021',
        "organism": 'Human',
        "outdir": None,
        "cutoff": 0.05,
    }
    enrichment = enrichr(gene_list=gene_list, **enrichment_options)
    # Hand back the results table (a DataFrame) rather than the enrichment object.
    return enrichment.results


# -------------------------------
# LLM setup
# -------------------------------
def get_llm_with_tools(model: str = "gemini-2.5-flash", provider: str = "google_genai"):
    """Create a chat model (temperature 0.2) with `enrichr_query` bound as its only tool.

    Parameters
    ----------
    model : str
        Model name passed to `init_chat_model`.
    provider : str
        Provider name passed as `model_provider`.

    Returns
    -------
    The result of `bind_tools` on the freshly created chat model.
    """
    chat_model = init_chat_model(model=model, model_provider=provider, temperature=0.2)
    bound_tools = [enrichr_query]
    return chat_model.bind_tools(bound_tools)


def get_prompt_chain(llm_with_tools):
    """Pipe a fixed system+human prompt into the given tool-bound model.

    The human message is filled from the `question` template variable.
    """
    chat_messages = [
        ("system", "You are a helpful bioinformatics assistant. Use tools when needed."),
        ("human", "{question}"),
    ]
    question_prompt = ChatPromptTemplate.from_messages(chat_messages)
    return question_prompt | llm_with_tools


# -------------------------------
# SHAP → Gene Sets → Enrichment → Summarization
# -------------------------------
def _gene_symbol(feature_name: str) -> str:
    """Strip the prefix up to and including the first underscore, if there is one."""
    _, separator, remainder = feature_name.partition("_")
    return remainder if separator else feature_name


# NOTE(review): because of @tool this name is a tool object, so it has to be
# run through `.invoke({...})` rather than called with positional arguments,
# and the `chain` argument cannot be supplied by a model — confirm the
# decorator is intended here.
@tool
def analyze_patient(patient_json: Dict, patient_id: str, chain):
    """
    Collect SHAP features per predicted class, run enrichment, and ask LLM to summarize.
    
    Parameters
    ----------
    patient_json : dict
        JSON object with structure { patient_id: { "Drugs": {...}} }
    patient_id : str
        Patient ID key in patient_json
    chain : LangChain runnable (prompt | llm_with_tools)
    
    Returns
    -------
    (results_by_class, summaries) : tuple
        results_by_class : dict
            { predicted_class: { "positive": enrichment_df or None, "negative": enrichment_df or None } }
        summaries : dict
            { predicted_class: text content of the LLM answer }
    """
    drugs = patient_json[patient_id]["Drugs"]
    class_features = defaultdict(lambda: {"positive": [], "negative": []})

    # Step 1: Collect SHAP features by predicted class
    for drug_info in drugs.values():
        bucket = class_features[drug_info["Predicted_Class"]]
        shap = drug_info["SHAP"]
        bucket["positive"].extend(_gene_symbol(item["Feature"]) for item in shap["Top_Positive"])
        bucket["negative"].extend(_gene_symbol(item["Feature"]) for item in shap["Top_Negative"])

    # Step 2: Run enrichment
    results_by_class = {}
    for cls, feats in class_features.items():
        results_by_class[cls] = {}
        for direction in ("positive", "negative"):
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # submitted gene list is the same on every run (a set is not).
            genes = list(dict.fromkeys(feats[direction]))
            # enrichr_query is wrapped by @tool; `.invoke` with an argument dict
            # is its supported entry point (calling the tool object directly is
            # deprecated in LangChain).
            results_by_class[cls][direction] = (
                enrichr_query.invoke({"gene_list": genes}) if genes else None
            )

    # Step 3: Summarize with LLM
    summaries = {}
    for cls, res in results_by_class.items():
        question = f"Predicted class: {cls}\nSummarize functional biology or pathways of SHAP features.\n"

        if res["positive"] is not None and not res["positive"].empty:
            question += f"\nPositive SHAP features (supporting {cls}):\n{res['positive'].head(10).to_string(index=False)}\n"
        if res["negative"] is not None and not res["negative"].empty:
            question += f"\nNegative SHAP features (against {cls}):\n{res['negative'].head(10).to_string(index=False)}\n"

        ai_msg = chain.invoke({"question": question})
        # NOTE(review): if the model answers with a tool call instead of text,
        # `content` may be empty — verify against the chain that is passed in.
        summaries[cls] = ai_msg.content

    return results_by_class, summaries


# bioassistant.py (add at the bottom)

import pandas as pd
def save_patient_summary_html(patient_id: str,
                              results_by_class: dict,
                              summaries: dict,
                              out_path: str = None):
    """
    Save the enrichment results + LLM summaries into an HTML report.

    All free text (patient id, class names, LLM summaries, table cells) is
    HTML-escaped before it is written, and line breaks in a summary are kept
    as ``<br>``. A class that has a summary but no enrichment entry is written
    without tables instead of raising ``KeyError``.
    
    Parameters
    ----------
    patient_id : str
        Patient identifier
    results_by_class : dict
        Output from analyze_patient (enrichment results)
    summaries : dict
        Output from analyze_patient (LLM summaries)
    out_path : str
        File path for the HTML file (default = f"{patient_id}_summary.html")
    """
    # Imported locally under a distinct name so the function stays self-contained.
    from html import escape as _escape

    if out_path is None:
        out_path = f"{patient_id}_summary.html"

    html_parts = [
        # Declare the encoding the file is written in.
        '<meta charset="utf-8">',
        f"<h1>Patient {_escape(str(patient_id))} – Pathway Analysis Report</h1>",
    ]

    for cls, summary in summaries.items():
        html_parts.append(f"<h2>Predicted Class: {_escape(str(cls))}</h2>")
        # Model output is untrusted text: escape it, then restore line breaks.
        summary_html = _escape(str(summary)).replace("\n", "<br>")
        html_parts.append(f"<p><strong>LLM Summary:</strong><br>{summary_html}</p>")

        # Insert enrichment tables
        class_results = results_by_class.get(cls) or {}
        for direction in ("positive", "negative"):
            df = class_results.get(direction)
            if df is not None and not df.empty:
                html_parts.append(f"<h3>{direction.title()} SHAP Features Enrichment</h3>")
                html_parts.append(df.head(15).to_html(index=False, escape=True))

    # Explicit UTF-8 so non-ASCII characters do not depend on the platform default.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(html_parts))
    print(f"✅ HTML report saved to {out_path}")


In [14]:
# Report prompt for the tool-assisted run.
# The system message explains the layout of the patient JSON (patient ID,
# disease type, and per drug: name, predicted class, probability, top positive
# and negative SHAP features), asks for a drug choice, a clinician-facing
# report and a discussion of the features, and instructs the model to use the
# tools at its disposal.
# The human message is filled from the `JSON_input` template variable.
prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a biomedical-AI assistant that interprets predictions from a AI-powered predictive model for clinicians.
You are given a JSON with some information regarding the patient and 2 drugs. In the JSON, you are given:
- Patient ID
- Disease type
- For each drug, you will be given:
    - The drug name
    - The predicted class. This class can be one of three options: 
        - No effect if the drug is predicted as having no effect on treating the patient disease
        - Positive response if the drug is predicted as having a positive effect on treating the patient disease
        - Adverse effects if the drug is predicted as having a negative effect on the patient disease
    - Each predicted class has an associated probability
    - Each predicted class has associated features, that are responsible for the prediction. To reflect the importance of these features on the prediction we have the SHAP values. We have the top positive SHAP values and the top negative SHAP values.

Taking into account this JSON and the information explained above, I want you as a smart biomedical-AI assistant to pick the best of the two drugs.
Once you have  picked the best drug for my patient, I want you to write a report on the chosen drug, please include both positive an negative points about the drug. This should be targeted towards clinicians. You should use the tools at your disposal.
After that, please write about the features involved in the decision making process, and use the tools at your disposal for information about the relationship between these features and the disease the patient has. You should use the tools at your disposal.
For all tasks, tell me as much as possible.
"""),
    ("human", "{JSON_input}")
])

In [None]:
# Local import: ToolMessage is only needed to report a failed tool call.
from langchain_core.messages import ToolMessage

tools = [get_openfda_label, enrichr_query]
tools_by_name = {t.name: t for t in tools}
llm_with_tools = llm.bind_tools(tools)

# Single-pass chain kept for reference. On its own it never executes a tool:
# when the model answers with tool calls, StrOutputParser only sees that
# tool-call turn, which may carry little or no text.
chain = prompt | llm_with_tools | StrOutputParser()

# Serialise the record so the model receives real JSON, not a Python dict repr.
messages = prompt.invoke({"JSON_input": json.dumps(patient_json, indent=2)}).to_messages()

# Tool loop: run every tool the model requests, feed the results back, and
# stop at the first answer that contains no further tool calls.
MAX_TOOL_ROUNDS = 5  # upper bound so a model that keeps calling tools cannot loop forever
for _ in range(MAX_TOOL_ROUNDS):
    ai_msg = llm_with_tools.invoke(messages)
    messages.append(ai_msg)
    if not ai_msg.tool_calls:
        break
    for call in ai_msg.tool_calls:
        try:
            # Invoking a tool with a tool-call dict returns a ToolMessage that
            # is already linked to the call id; "type" is set explicitly in
            # case the provider left it out.
            tool_msg = tools_by_name[call["name"]].invoke({**call, "type": "tool_call"})
        except Exception as exc:
            # Report the failure to the model instead of aborting the whole run
            # (e.g. a network error or an unknown tool name).
            tool_msg = ToolMessage(content=f"Tool error: {exc}", tool_call_id=call["id"])
        messages.append(tool_msg)
else:
    print(f"Warning: model still requested tools after {MAX_TOOL_ROUNDS} rounds; using its last message.")

# Same string type as before, so downstream cells keep working.
response = StrOutputParser().invoke(ai_msg)


In [16]:
# Structured-extraction pass over the tool-assisted answer: the parser's format
# instructions are baked into the system message, the report text goes in as
# the human message.
parser = PydanticOutputParser(pydantic_object=GNN_prediction_report)
format_instructions = parser.get_format_instructions()

prompt_formats = ChatPromptTemplate.from_messages(
    [
        ("system", "Extract per schema:\n{format_instructions}"),
        ("human", "{text}"),
    ]
).partial(format_instructions=format_instructions)
parsing_llm = prompt_formats | llm | parser

# `response` is the text to extract from; it is passed through unchanged.
result = parsing_llm.invoke({"text": response})


# save_patient_summary_html(patient_id,out_path="results.html")



In [17]:
class MyEncoder(JSONEncoder):
    """JSON encoder that serialises otherwise unsupported objects via their attributes."""

    def default(self, o):
        """Return a JSON-serialisable stand-in for ``o``.

        Resolution order:
        1. objects exposing a callable ``model_dump`` use its result;
        2. objects with an instance ``__dict__`` use that mapping;
        3. anything else is delegated to ``JSONEncoder.default``, which raises
           ``TypeError`` (the error ``json`` callers expect) instead of the
           ``AttributeError`` a bare ``o.__dict__`` access would raise.
        """
        dump = getattr(o, "model_dump", None)
        if callable(dump):
            return dump()
        try:
            return vars(o)
        except TypeError:
            # No __dict__ (e.g. a set): let the base class report it.
            return super().default(o)

# Persist the output of the tool-assisted run next to the input data.
# NOTE(review): this writes `response` (the free-text answer), whereas the
# tool-free run saves the parsed `result` — confirm which one is wanted here.
file_name = os.path.join(global_path, f"{patient_id}_results_prompt1_rep4.json")
with open(file_name, "w", encoding="utf-8") as f:
    # Hand the encoder to json.dump through `cls`. Encoding first and dumping
    # the resulting string encodes twice, which leaves a doubly escaped JSON
    # string in the file.
    json.dump(response, f, cls=MyEncoder, indent=2)