In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
CLINICAL TRIAL ENROLLMENT SUCCESS – WORKSHOP VERSION
====================================================

End-to-end script for students:
- Problem framing
- Data collection from ClinicalTrials.gov
- Data cleaning & feature engineering
- Multiple classification models & comparison
- Basic interpretation
'''

import requests
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

try:
    from xgboost import XGBClassifier
    HAS_XGB = True
except ImportError:
    HAS_XGB = False

# Title banner: the workshop name framed by two heavy rules.
print("=" * 80, "CLINICAL TRIAL ENROLLMENT SUCCESS – HANDS-ON WORKSHOP", "=" * 80, sep="\n")

# =============================================================================
# STEP 1: PROBLEM INTRODUCTION
# =============================================================================
# Section header followed by a light rule, emitted with a single print call.
print("\n[STEP 1] THE BUSINESS PROBLEM", "-" * 80, sep="\n")
print('''
Context:
- Many clinical trials struggle to recruit enough patients on time.
- This creates delays, extra costs, and sometimes failed projects.

Our simplified goal for this workshop:
- Build ML models to classify trials into two groups:
  * Likely to be "successful" (higher enrollment)
  * Likely to "struggle" (lower enrollment)
- Compare different models and discuss what works better and why.
''')
# Pause so the presenter can talk through the framing before moving on.
input("Press Enter to continue...")

# =============================================================================
# STEP 2: DATA COLLECTION FROM CLINICALTRIALS.GOV
# =============================================================================
print("\n[STEP 2] COLLECTING DATA FROM CLINICALTRIALS.GOV API", "-" * 80, sep="\n")

# Column order of the flat trial table. Both the API path and the synthetic
# fallback build their DataFrame with it, so downstream code sees one schema
# even when the API returns zero studies.
_TRIAL_COLUMNS = [
    "nct_id",
    "brief_title",
    "phase",
    "enrollment_count",
    "conditions",
    "overall_status",
    "num_locations",
    "num_countries",
    "countries",
    "start_date",
    "primary_completion_date",
]

# Upper bound used for the `pageSize` request parameter; anything larger is
# fetched as several pages via `pageToken` / `nextPageToken`.
_MAX_PAGE_SIZE = 1000


def _flatten_study(study):
    """Flatten one study record from the v2 `studies` response into a flat dict.

    Every nested lookup tolerates a missing or null module, so a single
    incomplete record yields None / empty values instead of raising.
    """
    protocol = study.get("protocolSection") or {}
    ident = protocol.get("identificationModule") or {}
    design = protocol.get("designModule") or {}
    cond_module = protocol.get("conditionsModule") or {}
    contacts = protocol.get("contactsLocationsModule") or {}
    status_mod = protocol.get("statusModule") or {}

    # Keep only the first listed phase. An empty list maps to None; indexing
    # it would raise IndexError and abort the whole collection.
    phase = design.get("phases")
    if isinstance(phase, list):
        phase = phase[0] if phase else None

    conditions = cond_module.get("conditions")
    if isinstance(conditions, list):
        conditions = ", ".join(conditions)

    locs = contacts.get("locations") or []
    countries = sorted({loc.get("country") for loc in locs if loc.get("country")})

    return {
        "nct_id": ident.get("nctId"),
        "brief_title": ident.get("briefTitle"),
        "phase": phase,
        "enrollment_count": (design.get("enrollmentInfo") or {}).get("count"),
        "conditions": conditions,
        "overall_status": status_mod.get("overallStatus"),
        "num_locations": len(locs),
        "num_countries": len(countries),
        "countries": ", ".join(countries),
        "start_date": (status_mod.get("startDateStruct") or {}).get("date"),
        "primary_completion_date": (status_mod.get("primaryCompletionDateStruct") or {}).get("date"),
    }


def _synthetic_trials(n_rows=30, seed=42):
    """Build a small random stand-in dataset for offline / no-API environments.

    A locally seeded generator is used so the fallback is identical on every
    run and does not disturb NumPy's global random state.
    """
    rng = np.random.default_rng(seed)
    ids = range(1, n_rows + 1)
    data = {
        "nct_id": [f"NCT0000{i}" for i in ids],
        "brief_title": [f"Trial {i}" for i in ids],
        "phase": rng.choice(["Phase 1", "Phase 2", "Phase 3"], n_rows),
        "enrollment_count": rng.integers(10, 2000, n_rows),
        "conditions": rng.choice(["Cancer", "Diabetes", "Cardio"], n_rows),
        "overall_status": rng.choice(["Completed", "Terminated", "Recruiting"], n_rows),
        "num_locations": rng.integers(1, 50, n_rows),
        "num_countries": rng.integers(1, 5, n_rows),
        "countries": rng.choice(["US", "US, Canada", "US, UK, Germany"], n_rows),
        "start_date": ["2015-01-01"] * n_rows,
        "primary_completion_date": ["2018-01-01"] * n_rows,
    }
    return pd.DataFrame(data, columns=_TRIAL_COLUMNS)


def quick_collect_trials(condition="cancer", max_trials=300):
    """
    Quick data collection for demo.

    Queries the ClinicalTrials.gov v2 `studies` endpoint for `condition` and
    flattens each study into one row. No success label is computed here; the
    caller derives it later from `enrollment_count`.

    Parameters
    ----------
    condition : str
        Value sent as the `query.cond` search parameter.
    max_trials : int
        Maximum number of studies to return. Requests larger than one page
        are fetched page by page.

    Returns
    -------
    pandas.DataFrame
        One row per study with the columns listed in `_TRIAL_COLUMNS`.
        If the HTTP request fails or the response is not valid JSON, a
        30-row synthetic dataset with the same columns is returned instead.
    """
    print(f"Fetching ~{max_trials} '{condition}' trials from ClinicalTrials.gov...")

    url = "https://clinicaltrials.gov/api/v2/studies"
    params = {
        "query.cond": condition,
        "pageSize": min(max_trials, _MAX_PAGE_SIZE),
        "format": "json",
    }

    studies = []
    try:
        while True:
            resp = requests.get(url, params=params, timeout=30)
            resp.raise_for_status()
            data = resp.json()
            if not isinstance(data, dict):
                # Routed to the fallback below rather than failing on .get().
                raise ValueError("unexpected response shape from studies endpoint")

            page = data.get("studies") or []
            studies.extend(page)

            next_token = data.get("nextPageToken")
            if len(studies) >= max_trials or not page or not next_token:
                break
            params["pageToken"] = next_token
    except (requests.RequestException, ValueError) as e:
        # Only transport / HTTP / JSON problems trigger the fallback. Parsing
        # bugs are no longer swallowed here, so they cannot silently replace
        # real data with synthetic data.
        print("API error occurred. Using a small synthetic fallback dataset instead.")
        print("Error details:", e)
        return _synthetic_trials()

    # The last page may overshoot the requested total.
    studies = studies[:max_trials]
    print(f"Retrieved {len(studies)} studies.")

    rows = [_flatten_study(s) for s in studies if isinstance(s, dict)]
    return pd.DataFrame(rows, columns=_TRIAL_COLUMNS)

# Swap in another condition keyword here to explore a different disease area.
df_raw = quick_collect_trials(max_trials=300, condition="cancer")

# Show the first rows so students can see what the flattened records look like.
print("\nSample of raw data:", df_raw.head(), sep="\n")
input("\nPress Enter to continue to data cleaning & feature engineering...")

# =============================================================================
# STEP 3: DATA CLEANING & FEATURE ENGINEERING
# =============================================================================
print("\n[STEP 3] DATA CLEANING & FEATURE ENGINEERING", "-" * 80, sep="\n")

# Work on a copy so df_raw stays available for inspection.
df = df_raw.copy()

# Drop trials that report no enrollment figure at all.
df = df.loc[df["enrollment_count"].notna()].copy()

# Coerce the remaining values to numbers; anything unparseable becomes NaN
# and is dropped in the same way.
df["enrollment_count"] = pd.to_numeric(df["enrollment_count"], errors="coerce")
df = df.loc[df["enrollment_count"].notna()].copy()

# Label definition: a trial counts as "successful" when its enrollment is at
# or above the median enrollment of this dataset.
threshold = df["enrollment_count"].median()
df["success"] = df["enrollment_count"].ge(threshold).astype(int)

print(f"Enrollment median threshold for success label: {threshold:.1f}")
print(df[["nct_id", "enrollment_count", "success"]].head())

# Categorical columns: missing values become an explicit "Unknown" level.
for col in ("phase", "overall_status"):
    df[col] = df[col].fillna("Unknown")

# Count columns: missing values become 0 and the dtype is forced to int.
for col in ("num_locations", "num_countries"):
    df[col] = df[col].fillna(0).astype(int)

# Rough "global trial" flag: 1 when more than one country is involved.
df["is_multinational"] = df["num_countries"].gt(1).astype(int)

# Columns handed to the models, and the column they should predict.
features = [
    "phase",
    "overall_status",
    "num_locations",
    "num_countries",
    "is_multinational",
]
target = "success"

df_ml = df.loc[:, [*features, target]].copy()

print("\nPrepared ML dataset sample:", df_ml.head(), sep="\n")
print(
    "\nClass distribution (0 = low enrollment, 1 = high enrollment):",
    df_ml["success"].value_counts(),
    sep="\n",
)
input("\nPress Enter to continue to train/test split...")

# =============================================================================
# STEP 4: TRAIN / TEST SPLIT
# =============================================================================
print("\n[STEP 4] TRAIN / TEST SPLIT", "-" * 80, sep="\n")

# Feature matrix and label vector.
X, y = df_ml[features], df_ml[target]

# Hold out a quarter of the rows; stratifying on y keeps the class balance the
# same in both parts, and the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=y,
    test_size=0.25,
    random_state=42,
)

print(f"Train size: {X_train.shape}  Test size: {X_test.shape}")
input("\nPress Enter to continue to baseline model...")

# =============================================================================
# STEP 5: BASELINE MODEL – LOGISTIC REGRESSION
# =============================================================================
print("\n[STEP 5] BASELINE MODEL – LOGISTIC REGRESSION", "-" * 80, sep="\n")

# Column roles: text-valued columns are one-hot encoded, numeric ones are
# passed to the model unchanged.
cat_cols = ["phase", "overall_status"]
num_cols = ["num_locations", "num_countries", "is_multinational"]

# handle_unknown="ignore" lets the encoder cope with categories that appear
# only in the test split.
preprocessor = ColumnTransformer(
    [
        ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols),
        ("num", "passthrough", num_cols),
    ]
)

log_reg = LogisticRegression(solver="lbfgs", class_weight="balanced", max_iter=1000)

# Preprocessing and model live in one pipeline so both are fitted on the
# training split only.
pipe_log_reg = Pipeline([("preprocess", preprocessor), ("model", log_reg)])

print("Training Logistic Regression...")
pipe_log_reg.fit(X_train, y_train)

# Hard class predictions, plus the predicted probability of class 1.
y_pred_lr = pipe_log_reg.predict(X_test)
y_prob_lr = pipe_log_reg.predict_proba(X_test)[:, 1]

print("\nLogistic Regression – Classification Report:")
print(classification_report(y_test, y_pred_lr, digits=3))

# AUC can fail (for example when the test split holds a single class), so
# report the problem instead of stopping the workshop.
try:
    auc_lr = roc_auc_score(y_test, y_prob_lr)
except Exception as err:
    print("Could not compute AUC:", err)
else:
    print(f"Logistic Regression – ROC AUC: {auc_lr:.3f}")

input("\nPress Enter to continue to Random Forest model...")

# =============================================================================
# STEP 6: RANDOM FOREST MODEL
# =============================================================================
print("\n[STEP 6] RANDOM FOREST MODEL", "-" * 80, sep="\n")

# 200 unpruned trees; class weights are recomputed for every bootstrap sample.
rf = RandomForestClassifier(
    n_estimators=200,
    max_depth=None,
    class_weight="balanced_subsample",
    random_state=42,
    n_jobs=-1,
)

# Same preprocessing step as the baseline, different estimator on top.
pipe_rf = Pipeline([("preprocess", preprocessor), ("model", rf)])

print("Training Random Forest...")
pipe_rf.fit(X_train, y_train)

# Hard class predictions, plus the predicted probability of class 1.
y_pred_rf = pipe_rf.predict(X_test)
y_prob_rf = pipe_rf.predict_proba(X_test)[:, 1]

print("\nRandom Forest – Classification Report:")
print(classification_report(y_test, y_pred_rf, digits=3))

# Report an AUC failure instead of stopping the workshop.
try:
    auc_rf = roc_auc_score(y_test, y_prob_rf)
except Exception as err:
    print("Could not compute AUC:", err)
else:
    print(f"Random Forest – ROC AUC: {auc_rf:.3f}")

print("\nRandom Forest – Confusion Matrix:")
print(confusion_matrix(y_test, y_pred_rf))

input("\nPress Enter to continue to XGBoost model (if available)...")

# =============================================================================
# STEP 7: XGBOOST MODEL (OPTIONAL)
# =============================================================================
# The section header is the same whether or not xgboost could be imported.
print("\n[STEP 7] XGBOOST MODEL", "-" * 80, sep="\n")

if not HAS_XGB:
    # Nothing to train; tell students how to enable this step.
    print("xgboost is not installed in this environment.")
    print("Students can install it with: pip install xgboost")
    print("Then add and train an XGBClassifier model similarly to Random Forest.")
else:
    xgb = XGBClassifier(
        objective="binary:logistic",
        eval_metric="logloss",
        n_estimators=300,
        learning_rate=0.05,
        max_depth=4,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=42,
        n_jobs=-1,
    )

    # Same preprocessing step as the other models.
    pipe_xgb = Pipeline([("preprocess", preprocessor), ("model", xgb)])

    print("Training XGBoost...")
    pipe_xgb.fit(X_train, y_train)

    # Hard class predictions, plus the predicted probability of class 1.
    y_pred_xgb = pipe_xgb.predict(X_test)
    y_prob_xgb = pipe_xgb.predict_proba(X_test)[:, 1]

    print("\nXGBoost – Classification Report:")
    print(classification_report(y_test, y_pred_xgb, digits=3))

    # Report an AUC failure instead of stopping the workshop.
    try:
        auc_xgb = roc_auc_score(y_test, y_prob_xgb)
    except Exception as err:
        print("Could not compute AUC:", err)
    else:
        print(f"XGBoost – ROC AUC: {auc_xgb:.3f}")

input("\nPress Enter to see simple model comparison...")

# =============================================================================
# STEP 8: SIMPLE MODEL COMPARISON
# =============================================================================
print("\n[STEP 8] MODEL COMPARISON (QUICK VIEW)", "-" * 80, sep="\n")

# Recompute each model's ROC AUC; a model whose score cannot be computed is
# still listed in the table, with NaN as its score.
try:
    auc_lr = roc_auc_score(y_test, y_prob_lr)
except Exception:
    auc_lr = np.nan

try:
    auc_rf = roc_auc_score(y_test, y_prob_rf)
except Exception:
    auc_rf = np.nan

results = [
    {"model": "Logistic Regression", "roc_auc": auc_lr},
    {"model": "Random Forest", "roc_auc": auc_rf},
]

# XGBoost only joins the table when it was available and trained above.
if HAS_XGB:
    try:
        auc_xgb = roc_auc_score(y_test, y_prob_xgb)
    except Exception:
        auc_xgb = np.nan
    results.append({"model": "XGBoost", "roc_auc": auc_xgb})

df_results = pd.DataFrame(results)
print(df_results)

print('''
Discussion prompts for students:
- Which model performs best on ROC AUC?
- Does a more complex model always win?
- How might business stakeholders think about:
  * Precision vs recall
  * Interpretability vs accuracy
''')

input("\nPress Enter to continue to simple interpretation...")

# =============================================================================
# STEP 9: SIMPLE INTERPRETATION EXAMPLE (RANDOM FOREST)
# =============================================================================
print("\n[STEP 9] SIMPLE FEATURE IMPORTANCE (RANDOM FOREST)", "-" * 80, sep="\n")

# Reading importances out of a Pipeline with one-hot encoding takes a few
# extra steps, so for teaching purposes the data is encoded by hand here and
# a second forest is fitted directly on the encoded matrix.

# ----- BEGIN SECTION YOU CAN TURN INTO A STUDENT EXERCISE -----
# (Later you can delete this and ask students to implement encoded RF + importances.)

# A stand-alone encoder, fitted on the training split, exposes the names of
# the generated dummy columns.
encoder = OneHotEncoder(handle_unknown="ignore").fit(X_train[cat_cols])

encoded_cat_cols = encoder.get_feature_names_out(cat_cols)
all_feature_names = [*encoded_cat_cols, *num_cols]

# Dense one-hot block on the left, raw numeric columns on the right, in the
# same order as all_feature_names.
X_train_enc = np.concatenate(
    [encoder.transform(X_train[cat_cols]).toarray(), X_train[num_cols].to_numpy()],
    axis=1,
)
X_test_enc = np.concatenate(
    [encoder.transform(X_test[cat_cols]).toarray(), X_test[num_cols].to_numpy()],
    axis=1,
)

rf_plain = RandomForestClassifier(
    n_estimators=200,
    class_weight="balanced_subsample",
    random_state=42,
    n_jobs=-1,
)
rf_plain.fit(X_train_enc, y_train)

# Pair every encoded column with its importance and rank from high to low.
importances = rf_plain.feature_importances_
fi = pd.DataFrame({"feature": all_feature_names, "importance": importances})
fi = fi.sort_values(by="importance", ascending=False)

print("\nTop 10 most important features (Random Forest):")
print(fi.head(10))
# ----- END SECTION YOU CAN TURN INTO A STUDENT EXERCISE -----

print('''
Discussion prompts:
- Do the important features match your intuition?
- How would you explain these patterns to a non-technical stakeholder?
''')

input("\nPress Enter for closing remarks...")

# =============================================================================
# STEP 10: WRAP-UP
# =============================================================================
print("\n[STEP 10] WRAP-UP", "-" * 80, sep="\n")
print('''
In this hands-on:
- You saw an end-to-end workflow:
  * Data collection from an external API
  * Cleaning and feature engineering
  * Multiple models and basic comparison
  * Simple interpretation

Next step in the classroom:
- Split into groups.
- Each group owns one sub-problem (data quality, model tuning, evaluation, or communication).
- Present your decisions as if you were an industry data science team.
''')

# Closing line and a final heavy rule to mirror the opening banner.
print("\nWorkshop script finished.", "=" * 80, sep="\n")


CLINICAL TRIAL ENROLLMENT SUCCESS – HANDS-ON WORKSHOP

[STEP 1] THE BUSINESS PROBLEM
--------------------------------------------------------------------------------

Context:
- Many clinical trials struggle to recruit enough patients on time.
- This creates delays, extra costs, and sometimes failed projects.

Our simplified goal for this workshop:
- Build ML models to classify trials into two groups:
  * Likely to be "successful" (higher enrollment)
  * Likely to "struggle" (lower enrollment)
- Compare different models and discuss what works better and why.

Press Enter to continue...

[STEP 2] COLLECTING DATA FROM CLINICALTRIALS.GOV API
--------------------------------------------------------------------------------
Fetching ~300 'cancer' trials from ClinicalTrials.gov...
Retrieved 300 studies.

Sample of raw data:
        nct_id                                        brief_title   phase  \
0  NCT03706625  Integrated Discovery of New Immuno-Molecular A...    None   
1  NCT01665625  R

# Pull data for offline use

In [None]:
# import requests
# import pandas as pd

# # -------------------------------
# # Configuration
# # -------------------------------
# CONDITION = "cancer"      # change if you want a different topic
# PAGE_SIZE = 100           # max 1000; use 100 to be gentle on API
# TARGET_N = 1000           # how many studies you want in total (approx)
# OUTFILE = "clinicaltrials_1000_cancer.csv"

# base_url = "https://clinicaltrials.gov/api/v2/studies"

# params = {
#     "query.cond": CONDITION,
#     "pageSize": PAGE_SIZE,
#     "format": "json"
# }

# all_rows = []
# next_token = None
# total_fetched = 0

# while True:
#     if next_token:
#         params["pageToken"] = next_token
#     elif "pageToken" in params:
#         params.pop("pageToken")

#     print(f"Requesting page with params: {params}")
#     resp = requests.get(base_url, params=params, timeout=60)

#     if resp.status_code != 200:
#         print("Request failed with status:", resp.status_code)
#         break

#     data = resp.json()
#     studies = data.get("studies", [])
#     print(f"  Retrieved {len(studies)} studies")

#     # Parse each study into a flat dict
#     for s in studies:
#         protocol = s.get("protocolSection", {})

#         ident = protocol.get("identificationModule", {})
#         design = protocol.get("designModule", {})
#         cond_module = protocol.get("conditionsModule", {})
#         status_mod = protocol.get("statusModule", {})
#         contacts = protocol.get("contactsLocationsModule", {})

#         enrollment_info = design.get("enrollmentInfo", {})
#         enrollment = enrollment_info.get("count")

#         phases = design.get("phases")
#         if isinstance(phases, list):
#             phase = phases[0]
#         else:
#             phase = phases

#         conditions = cond_module.get("conditions")
#         if isinstance(conditions, list):
#             conditions = ", ".join(conditions)

#         locations = contacts.get("locations", [])
#         num_locations = len(locations)
#         countries = sorted(
#             list({loc.get("country") for loc in locations if loc.get("country")})
#         )
#         num_countries = len(countries)
#         countries_str = ", ".join(countries)

#         start_date = status_mod.get("startDateStruct", {}).get("date")
#         primary_completion_date = status_mod.get("primaryCompletionDateStruct", {}).get("date")
#         overall_status = status_mod.get("overallStatus")

#         row = {
#             "nct_id": ident.get("nctId"),
#             "brief_title": ident.get("briefTitle"),
#             "phase": phase,
#             "conditions": conditions,
#             "overall_status": overall_status,
#             "enrollment_count": enrollment,
#             "num_locations": num_locations,
#             "num_countries": num_countries,
#             "countries": countries_str,
#             "start_date": start_date,
#             "primary_completion_date": primary_completion_date,
#         }
#         all_rows.append(row)
#         total_fetched += 1

#         if total_fetched >= TARGET_N:
#             break

#     if total_fetched >= TARGET_N:
#         print(f"Reached target of {TARGET_N} studies.")
#         break

#     # pagination token
#     next_token = data.get("nextPageToken")
#     if not next_token:
#         print("No nextPageToken found; stopping.")
#         break

# print(f"\nTotal studies collected: {len(all_rows)}")

# # Convert to DataFrame and save
# df = pd.DataFrame(all_rows)
# df.to_csv(OUTFILE, index=False)
# print(f"Saved CSV to: {OUTFILE}")
# df.head()


Requesting page with params: {'query.cond': 'cancer', 'pageSize': 100, 'format': 'json'}
  Retrieved 100 studies
Requesting page with params: {'query.cond': 'cancer', 'pageSize': 100, 'format': 'json', 'pageToken': 'ZVNj7o2Elu8o3lpvVdH4qbrumpOQJJxsZvCl0A'}
  Retrieved 100 studies
Requesting page with params: {'query.cond': 'cancer', 'pageSize': 100, 'format': 'json', 'pageToken': 'ZVNj7o2Elu8o3lpvVdH4qbrumpOQJJxoZ_Ol0A'}
  Retrieved 100 studies
Requesting page with params: {'query.cond': 'cancer', 'pageSize': 100, 'format': 'json', 'pageToken': 'ZVNj7o2Elu8o3lpvVdH4qbrumpOQJJxuZfSh3_g'}
  Retrieved 100 studies
Requesting page with params: {'query.cond': 'cancer', 'pageSize': 100, 'format': 'json', 'pageToken': 'ZVNj7o2Elu8o3lpvVdH4qbrumpOQJJxuZvOg2fg'}
  Retrieved 100 studies
Requesting page with params: {'query.cond': 'cancer', 'pageSize': 100, 'format': 'json', 'pageToken': 'ZVNj7o2Elu8o3lpvVdH4qbrumpOQJJxuYPai2Pg'}
  Retrieved 100 studies
Requesting page with params: {'query.cond': 

Unnamed: 0,nct_id,brief_title,phase,conditions,overall_status,enrollment_count,num_locations,num_countries,countries,start_date,primary_completion_date
0,NCT03315195,Preoperative Oral Nutritional Supplement vs Co...,,"Surgery--Complications, Nutrition Aspect of Ca...",UNKNOWN,268.0,1,1,Thailand,2017-11-25,2021-11
1,NCT02606539,Surgery Plus Single Agent Chemotherapy Versus ...,PHASE2,Gestational Trophoblastic Neoplasms,UNKNOWN,20.0,2,1,Egypt,2015-09,2017-06
2,NCT07021677,"Use of a New Medicine ""Daratumumab"" to Treat L...",PHASE2,T Acute Lymphoblastic Leukemia,RECRUITING,18.0,1,1,India,2023-08-28,2026-08-28
3,NCT01308775,Comparing (SIS.NET) to Standard Care in Patien...,,Breast Cancer,COMPLETED,102.0,1,1,United States,2011-01,2014-03
4,NCT05038657,"Atezolizumab Immunotherapy, in Immunotherapy N...",PHASE2,"Squamous Cell Carcinoma, Urinary Tract Cancer",UNKNOWN,36.0,1,1,United Kingdom,2022-05-30,2024-01


In [None]:
df

Unnamed: 0,nct_id,brief_title,phase,conditions,overall_status,enrollment_count,num_locations,num_countries,countries,start_date,primary_completion_date
0,NCT03315195,Preoperative Oral Nutritional Supplement vs Co...,,"Surgery--Complications, Nutrition Aspect of Ca...",UNKNOWN,268.0,1,1,Thailand,2017-11-25,2021-11
1,NCT02606539,Surgery Plus Single Agent Chemotherapy Versus ...,PHASE2,Gestational Trophoblastic Neoplasms,UNKNOWN,20.0,2,1,Egypt,2015-09,2017-06
2,NCT07021677,"Use of a New Medicine ""Daratumumab"" to Treat L...",PHASE2,T Acute Lymphoblastic Leukemia,RECRUITING,18.0,1,1,India,2023-08-28,2026-08-28
3,NCT01308775,Comparing (SIS.NET) to Standard Care in Patien...,,Breast Cancer,COMPLETED,102.0,1,1,United States,2011-01,2014-03
4,NCT05038657,"Atezolizumab Immunotherapy, in Immunotherapy N...",PHASE2,"Squamous Cell Carcinoma, Urinary Tract Cancer",UNKNOWN,36.0,1,1,United Kingdom,2022-05-30,2024-01
...,...,...,...,...,...,...,...,...,...,...,...
995,NCT00553371,Follow-up Evaluation Using CT Scans in Patient...,,Testicular Germ Cell Tumor,UNKNOWN,300.0,1,1,United Kingdom,2006-04,
996,NCT00898144,Study of Pap Smears From Patients Enrolled on ...,,"Cervical Cancer, Precancerous Condition",UNKNOWN,55.0,0,0,,2008-02,2009-06
997,NCT01652144,"A Phase II Study of AT7519M, a CDK Inhibitor, ...",PHASE2,Mantle Cell Lymphoma,COMPLETED,12.0,6,1,Canada,2012-09-14,2014-12-02
998,NCT02558244,Impact of Image-defined Risk Factors on the Ou...,,Neuroblastoma,UNKNOWN,100.0,1,1,Egypt,2016-01,2023-09
