In [84]:
# imports

import os
import json
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import requests
import urllib.parse  
from datetime import datetime

In [85]:
# Initialization

load_dotenv(override=True)

openai_api_key = os.getenv('OPENAI_API_KEY')
if openai_api_key:
    print(f"OpenAI API Key exists and begins {openai_api_key[:8]}")
else:
    print("OpenAI API Key not set")
    
MODEL = "gpt-4o-mini"
openai = OpenAI()

OpenAI API Key exists and begins sk-proj-


In [86]:
system_message = (
    "You are a clinical trials assistant that uses the ClinicalTrials.gov API to answer questions. "
    "If the user asks for study data, you must always call the `search_studies` tool. "
    "Only use your own knowledge for greetings or general questions."
    "Always return detailed and structured responses."
)

In [87]:
def search_studies(query): 
    import urllib.parse
    import requests

    print(f"[DEBUG] Received query: {query} (type: {type(query)})")
    
    if isinstance(query, dict):
        filters = []
        query_parts = []

        # ✅ Condition query
        if "query" in query:
            query_string = urllib.parse.quote_plus(query["query"])
            query_parts.append(f"query.cond={query_string}")

        # ✅ Structured filters (excluding date filters)
        if "phase" in query:
            filters.append(f"AREA[Phase]{query['phase']}")
        if "country" in query:
            filters.append(f"AREA[LocationCountry]{query['country']}")
        if "study_type" in query:
            filters.append(f"AREA[StudyType]{query['study_type']}")
        if "sex" in query:
            filters.append(f"AREA[Sex]{query['sex']}")
        if "age_group" in query:
            filters.append(f"AREA[StdAge]{query['age_group']}")
        if "status" in query:
            filters.append(f"AREA[OverallStatus]{query['status']}")
        if "sampling_method" in query:
            filters.append(f"AREA[SamplingMethod]{query['sampling_method']}")
        if "ipd_sharing" in query:
            filters.append(f"AREA[IPDSharing]{query['ipd_sharing']}")
        if "expanded_access" in query:
            filters.append(f"AREA[ExpandedAccess]{query['expanded_access']}")

        page_size = query.get("max_results", 3)

        # ❗ Remove date filters from API query
        filter_advanced = " AND ".join(filters)
        if filter_advanced:
            filter_advanced = f"({filter_advanced})"
        encoded_filter = urllib.parse.quote(filter_advanced, safe="[]*")

        url = (
            f"https://clinicaltrials.gov/api/v2/studies?"
            f"{'&'.join(query_parts)}"
            f"{f'&filter.advanced={encoded_filter}' if filter_advanced else ''}"
            f"&pageSize={page_size}"
        )
        
    else:
        encoded_query = urllib.parse.quote_plus(query)
        url = f"https://clinicaltrials.gov/api/v2/studies?query.cond={encoded_query}&pageSize=3"

    print("Requesting:", url)

    try:
        response = requests.get(url)
        print("Status Code:", response.status_code)

        # 🔁 Fallback if query.cond fails
        if response.status_code == 400 and any("query.cond=" in part for part in query_parts):
            print("[⚠️ Fallback] Retrying with query.text instead of query.cond...")
            query_parts = [p.replace("query.cond=", "query.text=") for p in query_parts]
            url = (
                f"https://clinicaltrials.gov/api/v2/studies?"
                f"{'&'.join(query_parts)}"
                f"{f'&filter.advanced={encoded_filter}' if filter_advanced else ''}"
                f"&pageSize={page_size}"
            )
            print("Retrying:", url)
            response = requests.get(url)
            print("Retry Status Code:", response.status_code)

        if response.status_code == 200:
            data = response.json()
            trials = data.get("studies", [])
            if not trials:
                return "No studies found."

            # ✅ Post-filter
            def matches(study, key, value):
                section = study.get("protocolSection", {})
                if key == "sponsor":
                    return value.lower() in section.get("sponsorCollaboratorsModule", {}).get("leadSponsor", {}).get("name", "").lower()
                elif key == "intervention":
                    interventions = section.get("armsInterventionsModule", {}).get("interventions", [])
                    return any(value.lower() in i.get("name", "").lower() for i in interventions)
                return True

            if "sponsor" in query:
                trials = [s for s in trials if matches(s, "sponsor", query["sponsor"])]
            if "intervention" in query:
                trials = [s for s in trials if matches(s, "intervention", query["intervention"])]

            if not trials:
                return "No studies found after applying filters."

            result = []
            for study in trials:
                ps = study.get("protocolSection", {})
                id_module = ps.get("identificationModule", {})
                design_module = ps.get("designModule", {})
                status_module = ps.get("statusModule", {})
                elig_module = ps.get("eligibilityModule", {})
                ipd_module = ps.get("ipdSharingStatementModule", {})
                desc_module = ps.get("descriptionModule", {})
                contact_module = ps.get("contactsLocationsModule", {})
                sponsor_module = ps.get("sponsorCollaboratorsModule", {})
                outcomes_module = ps.get("outcomesModule", {})
                arms_module = ps.get("armsInterventionsModule", {})

                nct_id = id_module.get("nctId", "N/A")
                title = id_module.get("briefTitle", "No Title")
                official_title = id_module.get("officialTitle", "N/A")
                phases = design_module.get("phases", [])
                study_type = design_module.get("studyType", "N/A")
                status = status_module.get("overallStatus", "N/A")
                ipd_sharing = ipd_module.get("ipdSharing", "N/A")
                expanded_access = status_module.get("expandedAccessInfo", {}).get("hasExpandedAccess", "N/A")
                sex = elig_module.get("sex", "N/A")
                std_ages = elig_module.get("stdAges", [])
                age_range = ", ".join(std_ages) if std_ages else "N/A"
                sampling_method = elig_module.get("samplingMethod", "N/A")
                criteria = elig_module.get("eligibilityCriteria", "N/A")
                start_date = status_module.get("startDateStruct", {}).get("date", "N/A")
                completion_date = status_module.get("completionDateStruct", {}).get("date", "N/A")
                locations = contact_module.get("locations", [])
                countries = sorted({loc.get("country") for loc in locations if loc.get("country")})

                location_lines = []
                for loc in locations:
                    parts = [loc.get("facility"), loc.get("city"), loc.get("state"), loc.get("country")]
                    clean = [p for p in parts if p]
                    if clean:
                        location_lines.append(", ".join(clean))
                locations_text = "\n".join(f"- {line}" for line in location_lines) if location_lines else "N/A"

                description = desc_module.get("detailedDescription", "N/A")
                interventions = arms_module.get("interventions", [])
                intervention_names = [iv.get("name", "") for iv in interventions if iv.get("name")]
                intervention_text = ", ".join(intervention_names) if intervention_names else "N/A"
                sponsor = sponsor_module.get("leadSponsor", {}).get("name", "N/A")
                collaborators = sponsor_module.get("collaborators", [])
                collaborator_names = [c.get("name", "") for c in collaborators]

                def format_outcomes(lst, label):
                    out = []
                    for o in lst:
                        out.append(f"- **{o.get('measure')}** ({o.get('timeFrame', 'N/A')}): {o.get('description', '')}")
                    return f"\n**{label}:**\n" + "\n".join(out) if out else ""

                outcomes = (
                    format_outcomes(outcomes_module.get("primaryOutcomes", []), "Primary Outcomes") +
                    format_outcomes(outcomes_module.get("secondaryOutcomes", []), "Secondary Outcomes") +
                    format_outcomes(outcomes_module.get("otherOutcomes", []), "Other Outcomes")
                )

                arms = []
                for group in arms_module.get("armGroups", []):
                    label = group.get("label", "N/A")
                    gtype = group.get("type", "N/A")
                    desc = group.get("description")
                    arms.append(f"- **{label}** ({gtype}): {desc.strip()}" if desc else f"- **{label}** ({gtype})")

                arms_text = "**Arms & Groups:**\n" + "\n".join(arms) if arms else ""
                phase_text = ', '.join(phases) if study_type.upper() == "INTERVENTIONAL" else "Not applicable (Observational study)"
                ctgov_link = f"https://clinicaltrials.gov/study/{nct_id}"
                description_block = f"**Detailed Description:**\n{description.strip()}\n\n" if description and description != "N/A" else ""

                result.append(
                    f"### 🧪 {title}\n\n"
                    f"**NCT ID:** `{nct_id}`\n"
                    f"🔗 [View on ClinicalTrials.gov]({ctgov_link})\n\n"
                    f"**Start Date (for filtering):** {start_date}\n"
                    f"**Completion Date (for filtering):** {completion_date}\n\n"
                    f"**Official Title:** {official_title}\n"
                    f"**Type:** {study_type.title()}\n"
                    f"**Phase:** {phase_text}\n"
                    f"**Status:** {status}\n"
                    f"**Country:** {', '.join(countries) if countries else 'N/A'}\n"
                    f"**Interventions:** {intervention_text}\n"
                    f"**Sponsor:** {sponsor}\n"
                    f"**Collaborators:** {', '.join(collaborator_names) if collaborator_names else 'None'}\n\n"
                )

            return "\n\n---\n\n".join(result).strip()

        return f"API returned error: {response.status_code}"

    except Exception as e:
        print("Exception occurred:", e)
        return "Error fetching study data."

In [88]:
# There's a particular dictionary structure that's required to describe our function:

search_function = {
    "name": "search_studies",
    "description": "Search for clinical trials with strict filtering on all key metadata fields such as condition, country, phase, study type, sex, age group, sampling method, sponsor, collaborators, intervention, start dates, completion dates, etc.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Condition or keyword to search for. (e.g., 'lung cancer', 'IBD')"
            },
            "phase": {
                "type": "string",
                "description": "Clinical trial phase. (e.g., 'Phase 1', 'Phase 2', 'Phase 3')"
            },
            "status": {
                "type": "string",
                "description": "Recruitment status. (e.g., 'RECRUITING', 'COMPLETED')"
            },
            "country": {
                "type": "string",
                "description": "Country where the trial is conducted. (e.g., 'Italy')"
            },
            "study_type": {
                "type": "string",
                "description": "Type of study. (e.g., 'INTERVENTIONAL', 'OBSERVATIONAL')"
            },
            "sex": {
                "type": "string",
                "description": "Sex eligibility. (e.g., 'Male', 'Female', 'All')"
            },
            "age_group": {
                "type": "string",
                "description": "Standard age group. (e.g., 'CHILD', 'ADULT', 'OLDER_ADULT')"
            },
            "sampling_method": {
                "type": "string",
                "description": "Participant sampling method. (e.g., 'PROBABILITY_SAMPLE', 'NON_PROBABILITY_SAMPLE')"
            },
            "intervention": {
                "type": "string",
                "description": "Intervention or treatment keyword. (e.g., 'aspirin', 'TAE')"
            },
            "sponsor": {
                "type": "string",
                "description": "Name of the lead sponsor or organization. (e.g., 'Pfizer', 'NIH')"
            },
            "ipd_sharing": {
                "type": "string",
                "description": "Will individual participant data (IPD) be shared? (e.g., 'YES', 'NO', 'UND')"
            },
            "expanded_access": {
                "type": "string",
                "description": "Whether expanded access is available. (e.g., 'YES', 'NO', 'UNKNOWN')"
            },
            "start_date_from": {
                "type": "string",
                "description": "Earliest start date allowed (format: YYYY-MM or YYYY-MM-DD)"
            },
            "start_date_to": {
                "type": "string",
                "description": "Latest start date allowed"
            },
            "completion_date_from": {
                "type": "string",
                "description": "Earliest completion date allowed"
            },
            "completion_date_to": {
                "type": "string",
                "description": "Latest completion date allowed"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of studies to return"
            }
        },
        "required": ["query"],
        "additionalProperties": False
    }
}

In [89]:
# And this is included in a list of tools:

tools = [{"type": "function", "function": search_function}]

In [90]:
def chat(message, history):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]

    # 🔄 First attempt: try to stream the LLM output
    response_stream = openai.chat.completions.create(
        model=MODEL,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        stream=True
    )

    full_response = ""
    tool_call_detected = False

    for chunk in response_stream:
        choice = chunk.choices[0]
        delta = choice.delta

        # 🧠 Detect tool call request during stream
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            tool_call_detected = True
            break  # Exit streaming — can't continue past tool call

        if delta.content:
            full_response += delta.content
            yield full_response  # Live stream to user

    # 🧰 Tool call fallback (non-streamed)
    if tool_call_detected:
        fallback = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            tools=tools,
            tool_choice="auto"  # No stream here, required to get tool_calls
        )

        message = fallback.choices[0].message
        print("Finish reason:", fallback.choices[0].finish_reason)
        print("Tool calls:", message.tool_calls if hasattr(message, 'tool_calls') else None)

        # 🔧 Call the tool(s)
        tool_responses = handle_tool_call(message)

        # Add the assistant tool call message and all corresponding tool responses
        messages.append(message)
        messages.extend(tool_responses)

        # 🧠 Now ask GPT to summarize the tool result(s)
        final_response_stream = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            stream=True
        )

        final_output = ""
        for chunk in final_response_stream:
            delta = chunk.choices[0].delta
            if delta.content:
                final_output += delta.content
                yield final_output  # Stream final GPT summary

    # 🧯 Final fallback if nothing streamed
    elif not full_response:
        fallback = openai.chat.completions.create(
            model=MODEL,
            messages=messages,
            tools=tools,
            tool_choice="auto"
        )
        yield fallback.choices[0].message.content

In [91]:
def handle_tool_call(message):
    import json

    tool_responses = []

    for tool_call in message.tool_calls:
        arguments = json.loads(tool_call.function.arguments)
        result = search_studies(arguments)

        tool_responses.append({
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": result if isinstance(result, str) else json.dumps(result)
        })

    return tool_responses

In [92]:
example_prompts = [
    "Show me trials in Taiwan studying Vedolizumab",
    "List studies for Crohn's disease that started after 2015",
    "Give me 5 completed trials on lung cancer in Japan",
    "Find interventional Phase 3 studies for breast cancer in France",
    "List observational studies in Asia with female participants over 65.",
    "Get details for NCT06100289",
    "Show studies that started after 2022 for Asthma and are still ongoing."
]

gr.ChatInterface(
    fn=chat,
    type="messages",
    title="ClinicalTrials.gov Agent",
    description=(
        "Ask about medical conditions, NCT IDs, trial phases, study types, recruitment status, interventions, sponsors, age groups, sex, sampling methods, IPD sharing, expanded access, countries, locations, and date ranges (start/completion).\n\n"
        "💡 You can also try one of the examples below to get started."
    ),
    chatbot=gr.Chatbot(label="CTGagent", type="messages"),
    examples=example_prompts
).launch(app_kwargs={"title": "CTGagent"})


* Running on local URL:  http://127.0.0.1:7867

To create a public link, set `share=True` in `launch()`.




Finish reason: tool_calls
Tool calls: [ChatCompletionMessageToolCall(id='call_OKLKZYud2q7uxJVw5NXn1TpH', function=Function(arguments='{"query":"Asthma","status":"RECRUITING","start_date_from":"2022-01-01","max_results":10}', name='search_studies'), type='function')]
[DEBUG] Received query: {'query': 'Asthma', 'status': 'RECRUITING', 'start_date_from': '2022-01-01', 'max_results': 10} (type: <class 'dict'>)
Requesting: https://clinicaltrials.gov/api/v2/studies?query.cond=Asthma&filter.advanced=%28AREA[OverallStatus]RECRUITING%29&pageSize=10
Status Code: 200
Finish reason: tool_calls
Tool calls: [ChatCompletionMessageToolCall(id='call_5sqJJdsQTLkqh9FfziUkr7Up', function=Function(arguments='{"query":"NCT06100289"}', name='search_studies'), type='function')]
[DEBUG] Received query: {'query': 'NCT06100289'} (type: <class 'dict'>)
Requesting: https://clinicaltrials.gov/api/v2/studies?query.cond=NCT06100289&pageSize=3
Status Code: 200
