# Patient Dropout Prediction Model

This notebook builds an XGBoost classifier to predict patient dropout from clinical trials using Python libraries.

**Features:**
- **Age**: Patient age (0-99)
- **Gender**: Patient gender (MALE/FEMALE)
- **Target**: PATIENT_DROPPED indicator (1 = dropped out, 0 = completed)

**Tech Stack:**
- Snowpark for Python (data access)
- XGBoost (gradient boosting classifier)
- scikit-learn (preprocessing & metrics)
- pandas (data manipulation)


## 1. Import Libraries and Setup

### Why This Tech Stack?

**Python with Snowpark:**
- Access Snowflake data directly without data movement
- Leverage Snowflake's compute power for data operations
- Convert seamlessly to pandas for ML operations

**XGBoost:**
- Industry-standard gradient boosting algorithm
- Supports class reweighting (`scale_pos_weight`) for imbalanced datasets, which are common in dropout scenarios
- Provides feature importance for clinical interpretability
- No feature scaling required (works with raw age values)
- Widely used for tabular prediction tasks in healthcare

**scikit-learn:**
- Standard library for preprocessing and evaluation
- Extensive metrics for model assessment
- Compatible with XGBoost workflow

**Visualization Libraries:**
- matplotlib/seaborn for clear, publication-ready charts
- Essential for communicating findings to clinical stakeholders


In [None]:
### Import Required Libraries

# Snowpark for Python
from snowflake.snowpark.context import get_active_session
import snowflake.snowpark.functions as F

# Data science libraries
import pandas as pd
import numpy as np
from xgboost import XGBClassifier

# Scikit-learn for preprocessing and metrics
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    confusion_matrix,
    classification_report,
    roc_auc_score,
    roc_curve
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Misc
import warnings
warnings.simplefilter('ignore')


In [None]:
# Establish Snowflake session
session = get_active_session()

# Add query tag for tracking (query tags are strings, so serialize the dict as JSON)
import json
session.query_tag = json.dumps({
    "origin": "patient_dropout_ml",
    "model": "xgboost_classifier",
    "version": {"major": 1, "minor": 0}
})

print("Session established successfully!")
session


## 2. Load Training Data

Load the patient dropout data from `INFORMATICS_SANDBOX.ML_TEST.DOR_ANALYSIS_FF`.

### Why Load Data with Snowpark?

**Benefits:**
- **No data movement**: Query data directly in Snowflake without extracting to local files
- **Security**: Data stays within Snowflake's secure environment
- **Scalability**: Leverages Snowflake's warehouse compute for large datasets
- **Lazy evaluation**: Operations are pushed down to Snowflake for efficiency

**Process:**
1. Connect via `get_active_session()` - uses existing notebook session
2. Reference table directly - Snowpark creates a DataFrame pointer
3. Convert to pandas only when needed for ML operations


In [None]:
# Load the training data from Snowflake table
patient_data_df = session.table("INFORMATICS_SANDBOX.ML_TEST.DOR_ANALYSIS_FF")

# Display basic info
print(f"Total records: {patient_data_df.count()}")
patient_data_df.show()


In [None]:
# Summary statistics (count, mean, stddev, min, max), computed in Snowflake
patient_data_df.describe().show()

# Show column names and types (the schema)
for field in patient_data_df.schema.fields:
    print(f"{field.name}: {field.datatype}")


## 3. Exploratory Data Analysis

### Why EDA is Critical for Clinical Trials

**Understanding the data before modeling:**
- **Identify patterns**: Do certain age groups or genders have higher dropout rates?
- **Data quality**: Check for missing values, outliers, or data entry errors
- **Class imbalance**: Is dropout rare or common? Affects model choice and evaluation
- **Feature relationships**: How do age and gender relate to dropout?

**Clinical Value:**
- Provides actionable insights even before predictive modeling
- Helps trial coordinators understand risk factors
- Validates that the data matches clinical expectations
- Identifies populations that may need additional support

**For Stakeholders:**
- Visualizations are easier to understand than raw statistics
- Age group analysis reveals which demographics need intervention
- Gender analysis identifies potential bias or fairness concerns
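The age-group breakdown in this section depends on how bin edges are treated. With right-closed bins, an age of exactly 30 lands in the lower band. Here is a minimal pure-Python sketch of left-closed age bands, the same convention as `pd.cut(..., right=False)`. The helper names `age_group` and `dropout_rate_by_group` and the band edges are illustrative assumptions:

```python
from bisect import bisect_right

# Left edges of each band; a band covers [edge, next_edge)
BAND_EDGES = [0, 30, 50, 70]
BAND_LABELS = ['<30', '30-49', '50-69', '70+']

def age_group(age: float) -> str:
    """Return the label of the left-closed band that contains `age`."""
    if age < BAND_EDGES[0]:
        raise ValueError(f"age must be >= {BAND_EDGES[0]}, got {age}")
    # bisect_right counts edges <= age, so age 30 falls in the '30-49' band
    return BAND_LABELS[bisect_right(BAND_EDGES, age) - 1]

def dropout_rate_by_group(ages, dropped):
    """Dropout rate per age band, computed from parallel lists."""
    totals, drops = {}, {}
    for age, d in zip(ages, dropped):
        g = age_group(age)
        totals[g] = totals.get(g, 0) + 1
        drops[g] = drops.get(g, 0) + d
    return {g: drops[g] / totals[g] for g in totals}

print(age_group(29), age_group(30), age_group(70))  # <30 30-49 70+
print(dropout_rate_by_group([25, 30, 45, 72], [1, 0, 1, 0]))
```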


In [None]:
# Convert to pandas for analysis and visualization
patient_pd = patient_data_df.to_pandas()

# Display column names (important for debugging)
print("Column names in the data:")
print(patient_pd.columns.tolist())
print("\nData types:")
print(patient_pd.dtypes)
print("\nSample data:")
print(patient_pd.head(10))


In [None]:
# Normalize column names to uppercase so the check below is case-insensitive
patient_pd.columns = patient_pd.columns.str.upper()

# Validate that required columns exist
required_columns = ['AGE', 'GENDER', 'PATIENT_DROPPED']
missing_columns = [col for col in required_columns if col not in patient_pd.columns]

if missing_columns:
    # Stop here: later cells would fail with a KeyError anyway
    raise ValueError(
        f"Missing required columns: {missing_columns}. "
        f"Available columns: {patient_pd.columns.tolist()}"
    )
print("✓ All required columns found!")


In [None]:
# Normalize column names to uppercase for consistency
patient_pd.columns = patient_pd.columns.str.upper()

# Display the normalized column names
print("Normalized column names:")
print(patient_pd.columns.tolist())

# Check data distribution
print("\n=== Dataset Overview ===")
print(f"Total patients: {len(patient_pd)}")
print(f"Dropout count: {patient_pd['PATIENT_DROPPED'].sum()}")
print(f"Dropout percentage: {patient_pd['PATIENT_DROPPED'].mean() * 100:.2f}%")
print(f"\nAge statistics:")
print(f"  Mean age: {patient_pd['AGE'].mean():.2f}")
print(f"  Min age: {patient_pd['AGE'].min()}")
print(f"  Max age: {patient_pd['AGE'].max()}")
print(f"  Std age: {patient_pd['AGE'].std():.2f}")

# Check for missing values
print(f"\nMissing values:")
print(patient_pd.isnull().sum())


In [None]:
# Dropout rate by gender
print("=== Dropout Rate by Gender ===")
gender_stats = patient_pd.groupby('GENDER').agg({
    'PATIENT_DROPPED': ['count', 'sum', 'mean']
}).round(4)
gender_stats.columns = ['Total_Patients', 'Dropout_Count', 'Dropout_Rate']
gender_stats['Dropout_Percentage'] = gender_stats['Dropout_Rate'] * 100
print(gender_stats)


In [None]:
# Dropout rate by age group
print("\n=== Dropout Rate by Age Group ===")
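# Left-closed bins ([0, 30), [30, 50), ...) so that, e.g., age 30 falls in '30-49'
patient_pd['AGE_GROUP'] = pd.cut(
    patient_pd['AGE'], 
    bins=[0, 30, 50, 70, 120],
    labels=['<30', '30-49', '50-69', '70+'],
    right=False
)

age_stats = patient_pd.groupby('AGE_GROUP', observed=False).agg({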
    'PATIENT_DROPPED': ['count', 'sum', 'mean']
}).round(4)
age_stats.columns = ['Total_Patients', 'Dropout_Count', 'Dropout_Rate']
age_stats['Dropout_Percentage'] = age_stats['Dropout_Rate'] * 100
print(age_stats)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Dropout rate by age group
age_stats['Dropout_Percentage'].plot(kind='bar', ax=axes[0], color='steelblue')
axes[0].set_title('Dropout Rate by Age Group')
axes[0].set_ylabel('Dropout Percentage (%)')
axes[0].set_xlabel('Age Group')
axes[0].tick_params(axis='x', rotation=45)

# Dropout rate by gender
gender_stats['Dropout_Percentage'].plot(kind='bar', ax=axes[1], color='coral')
axes[1].set_title('Dropout Rate by Gender')
axes[1].set_ylabel('Dropout Percentage (%)')
axes[1].set_xlabel('Gender')
axes[1].tick_params(axis='x', rotation=0)

plt.tight_layout()
plt.show()


## 4. Data Preprocessing

Prepare features for XGBoost model training by encoding categorical variables.

### Why Encode Gender as Binary?

**Machine Learning Requirement:**
- ML models require numerical input - cannot process text like "MALE"/"FEMALE"
- Binary encoding (0/1) is the simplest and most interpretable approach

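**Why Binary Over Other Encoding Methods:**

1. **vs. One-Hot Encoding**: 
   - One-hot encoding would create 2 columns (MALE, FEMALE), each the exact complement of the other
   - A single 0/1 column is equivalent to one-hot encoding with one column dropped, and it keeps the feature set compact
   - The redundant column (the "dummy variable trap") mainly causes problems for linear models; for trees it adds a duplicate feature that can split importance between two columns

2. **vs. Ordinal Encoding**: 
   - With only two categories, any numeric encoding yields the same single split, so no false ordering can enter the model
   - This would not hold for a variable with three or more unordered categories

3. **vs. Label Encoding**: 
   - scikit-learn's `LabelEncoder` assigns codes alphabetically (FEMALE=0, MALE=1), which happens to give the same result here
   - An explicit mapping (MALE=1, FEMALE=0) documents the encoding in the code and makes feature importance easier to interpret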

**XGBoost Compatibility:**
- Tree-based models handle binary features efficiently
- Creates clear splits: "Is Male?" → Yes/No
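A single `== 'MALE'` comparison silently encodes anything else as 0, including a typo, `'M'`, or a missing entry. Below is a minimal sketch of an explicit mapping that fails loudly instead. The `GENDER_MAP` dictionary and the `encode_gender` helper are illustrative and not part of the notebook:

```python
# Explicit, documented mapping: MALE = 1, FEMALE = 0
GENDER_MAP = {'MALE': 1, 'FEMALE': 0}

def encode_gender(values):
    """Map gender strings to 0/1, normalizing case and whitespace; raise on unknown values."""
    encoded, unknown = [], set()
    for v in values:
        key = v.strip().upper() if isinstance(v, str) else v
        if key in GENDER_MAP:
            encoded.append(GENDER_MAP[key])
        else:
            unknown.add(v)
    if unknown:
        raise ValueError(f"Unexpected gender values: {sorted(map(str, unknown))}")
    return encoded

print(encode_gender(['male', ' Female ', 'MALE']))  # [1, 0, 1]
```

Raising an error instead of defaulting to 0 keeps data-quality problems from becoming a hidden "FEMALE" signal in the model.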

### Why Normalize Column Names?

- **Consistency**: Prevents KeyError due to case sensitivity
- **Reliability**: Works regardless of Snowflake table schema
- **Best Practice**: Standard data engineering approach


In [None]:
# Define feature columns and target
FEATURE_COLUMNS = ['AGE', 'GENDER']
TARGET_COLUMN = 'PATIENT_DROPPED'

# Create a clean dataframe with only required columns
df_clean = patient_pd[FEATURE_COLUMNS + [TARGET_COLUMN]].copy()

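# Handle case and whitespace variations in gender
df_clean['GENDER'] = df_clean['GENDER'].str.strip().str.upper()

# Flag unexpected values: anything other than MALE (including missing) is encoded as 0 below
unexpected = df_clean.loc[~df_clean['GENDER'].isin(['MALE', 'FEMALE']), 'GENDER']
if not unexpected.empty:
    print(f"⚠ {len(unexpected)} rows have unexpected GENDER values: {unexpected.unique().tolist()}")

# Encode gender: MALE = 1, FEMALE = 0
df_clean['GENDER_ENCODED'] = (df_clean['GENDER'] == 'MALE').astype(int)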

# Drop the original gender column
df_clean = df_clean.drop('GENDER', axis=1)

# Remove AGE_GROUP if it exists (was created for EDA only)
if 'AGE_GROUP' in df_clean.columns:
    df_clean = df_clean.drop('AGE_GROUP', axis=1)

print("Preprocessed data shape:", df_clean.shape)
print("\nFirst few rows:")
print(df_clean.head())


In [None]:
# Verify no missing values and data types
print("Data Info:")
df_clean.info()  # info() prints its own output and returns None
print("\nData Statistics:")
print(df_clean.describe())
print("\nClass distribution:")
print(df_clean[TARGET_COLUMN].value_counts())


## 5. Train/Test Split

Split the data into training and testing sets following best practices.

### Why 80/20 Split?

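**Common Default:**
- **80% training**: Leaves most of the data for the model to learn patterns
- **20% testing**: Usually large enough for a stable performance estimate, provided the dataset contains a reasonable number of dropouts
- Widely used as a starting point in ML, including healthcare applications

**Alternative Split Ratios:**
- 70/30: Gives a larger, more stable test set when data is limited (for very small datasets, k-fold cross-validation is usually a better choice)
- 90/10: Reasonable for very large datasets, where 10% is still plenty of test data
- 60/20/20: Adds a validation set for hyperparameter tuning

**For Clinical Trials:**
- 80/20 balances learning with evaluation
- The held-out 20% stands in for future patients the model has never seen
- The test set needs enough dropouts to reveal performance issues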

### Why Is Stratification Critical?

**What Is Stratification?**
- Maintains the same dropout rate in both train and test sets
- If overall dropout is 25%, both sets will have ~25% dropout

**Why It Matters:**
- **Prevents bias**: Without it, test set might be easier/harder than reality
- **Reliable metrics**: Ensures test performance reflects true model capability
- **Clinical validity**: Test set represents the same patient mix as training

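**Example Without Stratification:**
- A purely random split can shift class proportions by chance, especially in small datasets
- Training set: 30% dropout; test set: 15% dropout
- **Result**: Test accuracy looks better than it should, because a model that leans toward predicting "completed" is rewarded more on a test set that is 85% completions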

**With Stratification:**
- Both sets: ~25% dropout
- Fair evaluation of model's true capability
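Stratification amounts to splitting each class separately and then recombining the pieces. Below is a minimal pure-Python sketch of the idea. It does not reproduce how `train_test_split` is implemented internally, and the `stratified_split` helper is hypothetical:

```python
import random

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) with each class split in the same proportion."""
    rng = random.Random(seed)  # fixed seed -> same split on every run
    by_class = {}
    for i, y in enumerate(labels):
        by_class.setdefault(y, []).append(i)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)
        n_test = round(len(idx) * test_size)  # split this class on its own
        test_idx += idx[:n_test]
        train_idx += idx[n_test:]
    return sorted(train_idx), sorted(test_idx)

def dropout_rate(labels, idx):
    return sum(labels[i] for i in idx) / len(idx)

# 100 patients, 25% dropout
labels = [1] * 25 + [0] * 75
train, test = stratified_split(labels)
print(f"train dropout: {dropout_rate(labels, train):.2f}, test dropout: {dropout_rate(labels, test):.2f}")
# train dropout: 0.25, test dropout: 0.25
```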

### Why random_state=42?

- **Reproducibility**: Same split every time code runs
- **Debugging**: Can investigate specific patients in test set
- **Comparison**: Others can replicate results
- **42**: The value itself is arbitrary (any fixed integer works); 42 is a popular convention, a nod to The Hitchhiker's Guide to the Galaxy


In [None]:
# Prepare features (X) and target (y)
X = df_clean[['AGE', 'GENDER_ENCODED']]
y = df_clean[TARGET_COLUMN]

# Split into train and test sets (80/20 split)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    random_state=42,
    stratify=y  # Maintain class distribution in both sets
)

print(f"Training set size: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"Test set size: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
print(f"\nTraining set dropout rate: {y_train.mean()*100:.2f}%")
print(f"Test set dropout rate: {y_test.mean()*100:.2f}%")


## 6. Train XGBoost Model

Train an XGBoost classifier following the pattern from MEDPACE_ML_HOL notebooks.

### Why XGBoost for Clinical Trial Dropout Prediction?

**Technical Advantages:**

1. **Gradient Boosting:**
   - Builds trees sequentially, each correcting previous errors
   - More accurate than single decision trees
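   - Often outperforms random forests on tabular data, though not universally

2. **Handles Imbalanced Data (with configuration):**
   - Clinical trial dropout is often imbalanced (more completions than dropouts)
   - Imbalance is not corrected automatically; regularization limits overfitting in general, not bias toward the majority class
   - `scale_pos_weight` up-weights the positive (dropout) class when needed

3. **No Feature Scaling Required:**
   - Works directly with raw age values, because tree splits depend only on the ordering of values
   - Unlike neural networks or regularized logistic regression, which are sensitive to feature scale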
   - Simplifies preprocessing pipeline

4. **Feature Importance:**
   - Tells us which features drive predictions
   - Critical for clinical stakeholders: "Why did the model predict this?"
   - Helps validate model makes sense medically

5. **Proven in Healthcare:**
   - Used in patient risk prediction, readmission forecasting
   - Widely used in peer-reviewed medical research
   - Trusted by healthcare data scientists
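For the class-weighting option in point 2, XGBoost's parameter documentation suggests sum(negative instances) / sum(positive instances) as a typical starting value for `scale_pos_weight`. The sketch below shows that calculation. The notebook's model does not set this parameter, so treat it as an optional experiment:

```python
def scale_pos_weight(labels):
    """Ratio of negatives (completed, 0) to positives (dropped, 1)."""
    positives = sum(1 for y in labels if y == 1)
    negatives = sum(1 for y in labels if y == 0)
    if positives == 0:
        raise ValueError("No positive (dropout) examples")
    return negatives / positives

y_train = [0] * 80 + [1] * 20  # toy labels: 20% dropout
print(scale_pos_weight(y_train))  # 4.0
```

The resulting value would be passed as `XGBClassifier(scale_pos_weight=...)`. It shifts predicted probabilities upward for dropouts, so the risk thresholds used later would need to be revisited.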

**Model Parameters Explained:**

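- **n_estimators=100**: Build 100 sequential trees
  - More trees can improve accuracy up to a point, but training is slower and, without early stopping, extra trees can overfit
  - 100 is a good starting point
  
- **max_depth=6**: Each tree can be up to 6 levels deep
  - Deeper trees capture more complex patterns
  - Too deep = overfitting risk
  - 6 is XGBoost's default and a common balance between complexity and generalization

- **learning_rate=0.1**: Scales how much each tree contributes
  - Lower = more conservative, needs more trees
  - 0.1 is a common starting value (XGBoost's own default is 0.3)
  - Can be reduced to 0.01 with proportionally more trees, which often improves performance at a higher training cost

- **random_state=42**: Reproducible results

- **eval_metric='logloss'**: Log loss (negative log-likelihood), which heavily penalizes confident wrong predictions; it is only computed when an evaluation set is passed to `fit()` (for example, for early stopping)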
  - Lower is better
  - Standard for classification
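To see why a lower learning rate needs more trees, consider an idealized case in which every tree predicts the remaining error exactly. Each round then removes a fraction `learning_rate` of what is left, so after m trees the remaining error is (1 - learning_rate)^m of the original. This is a simplification, since real trees only approximate the gradient of the log loss. Sketched in plain Python:

```python
def remaining_error(learning_rate, n_trees, initial_error=1.0):
    """Residual left after n_trees idealized boosting rounds with shrinkage."""
    residual = initial_error
    for _ in range(n_trees):
        # Each tree fits the current residual perfectly; shrinkage keeps only a fraction of its output
        residual -= learning_rate * residual
    return residual

for lr, n in [(0.3, 10), (0.1, 10), (0.1, 100), (0.01, 100), (0.01, 1000)]:
    print(f"learning_rate={lr:<5} n_trees={n:<5} remaining error={remaining_error(lr, n):.4f}")
```

With these idealized numbers, 0.1 with 100 trees and 0.01 with 1000 trees both leave almost no error, while 0.01 with only 100 trees leaves about 37%. This is why lowering `learning_rate` should go together with raising `n_estimators`.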

### Why XGBoost vs. Other Algorithms?

| Algorithm | Pros | Cons | Use Case |
|-----------|------|------|----------|
| **XGBoost** ✅ | Accurate on tabular data, supports class weighting, provides feature importance | Needs tuning | **Chosen for this task** |
| Logistic Regression | Simple, fast, directly interpretable coefficients | Assumes linear log-odds; benefits from scaling | Strong baseline |
| Random Forest | Good accuracy, parallel training | Often slightly less accurate than tuned boosting | Alternative choice |
| Neural Networks | Can learn complex patterns | Black box, needs lots of data | Overkill for 2 features |
| Decision Tree | Very interpretable | Lower accuracy; deep trees overfit | Too simple on its own |

**For Clinical Trials:**
- XGBoost balances accuracy with interpretability
- Feature importance helps clinical teams understand risk factors
- Widely used for tabular healthcare data


In [None]:
# Initialize XGBoost Classifier
xgb_model = XGBClassifier(
    n_estimators=100,        # Number of trees
    max_depth=6,             # Maximum tree depth
    learning_rate=0.1,       # Step size shrinkage
    random_state=42,
    eval_metric='logloss'    # Evaluation metric
)

# Train the model
print("Training XGBoost model...")
xgb_model.fit(X_train, y_train)
print("Training complete!")

# Display feature importance (names taken from the training columns, so they stay in sync)
feature_importance = pd.DataFrame({
    'Feature': X_train.columns,
    'Importance': xgb_model.feature_importances_
}).sort_values('Importance', ascending=False)

print("\nFeature Importance:")
print(feature_importance)


In [None]:
# Visualize feature importance
plt.figure(figsize=(10, 5))
plt.barh(feature_importance['Feature'], feature_importance['Importance'], color='skyblue')
plt.xlabel('Importance Score')
plt.title('XGBoost Feature Importance')
plt.tight_layout()
plt.show()


## 7. Make Predictions

Generate predictions on both training and test sets.


In [None]:
# Generate predictions on test set
y_test_pred = xgb_model.predict(X_test)
y_test_pred_proba = xgb_model.predict_proba(X_test)[:, 1]  # Probability of dropout

# Generate predictions on training set (to check for overfitting)
y_train_pred = xgb_model.predict(X_train)
y_train_pred_proba = xgb_model.predict_proba(X_train)[:, 1]

print("Predictions generated successfully!")
print(f"Test predictions shape: {y_test_pred.shape}")
print(f"Training predictions shape: {y_train_pred.shape}")


In [None]:
# Create a predictions dataframe for test set
test_predictions_df = pd.DataFrame({
    'AGE': X_test['AGE'].values,
    'GENDER_ENCODED': X_test['GENDER_ENCODED'].values,
    'Actual_Dropout': y_test.values,
    'Predicted_Dropout': y_test_pred,
    'Dropout_Probability': y_test_pred_proba
})

print("Sample predictions:")
print(test_predictions_df.head(10))


## 8. Model Evaluation - Test Set Performance

Calculate comprehensive evaluation metrics on the test set.

### Why These Metrics Matter for Clinical Trials

**Accuracy:**
- **What**: Percentage of correct predictions (dropouts + completions)
- **Clinical Meaning**: Overall reliability of predictions
- **Limitation**: Can be misleading if data is imbalanced
- **Example**: 85% accuracy means 85 out of 100 predictions are correct

**Precision (Positive Predictive Value):**
- **What**: Of patients predicted to drop out, how many actually drop out?
- **Formula**: True Positives / (True Positives + False Positives)
- **Clinical Meaning**: How confident can we be when the model says "will drop out"?
- **Use Case**: Resource allocation - don't waste intervention on false alarms
- **Example**: 70% precision = 70% of predicted dropouts actually drop out

**Recall (Sensitivity):**
- **What**: Of patients who actually drop out, how many did we catch?
- **Formula**: True Positives / (True Positives + False Negatives)
- **Clinical Meaning**: Are we missing patients who need intervention?
- **Use Case**: Patient retention - an at-risk patient who is missed gets no intervention
- **Example**: 80% recall = we catch 8 out of 10 actual dropouts

**F1 Score:**
- **What**: Harmonic mean of precision and recall
- **Formula**: 2 × (Precision × Recall) / (Precision + Recall)
- **Clinical Meaning**: Balance between catching dropouts and avoiding false alarms
- **Best When**: Need to balance both precision and recall
- **Range**: 0 (worst) to 1 (perfect)

**ROC AUC:**
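- **What**: Area Under the Receiver Operating Characteristic Curve
- **Range**: 0.5 (random guessing) to 1.0 (perfect); values below 0.5 mean the ranking is worse than random
- **Clinical Meaning**: Probability that a randomly chosen dropout gets a higher risk score than a randomly chosen completer
- **Interpretation** (rule of thumb):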
  - 0.9-1.0: Excellent
  - 0.8-0.9: Good
  - 0.7-0.8: Fair
  - 0.6-0.7: Poor
  - 0.5-0.6: Fail (barely better than random)
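All four threshold-based metrics follow from the confusion-matrix counts. Below is a minimal sketch that computes them by hand; the counts are made up for illustration. When a denominator is zero it returns 0, which mirrors scikit-learn's `zero_division=0` setting for precision and recall:

```python
def classification_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall, and F1 from confusion-matrix counts."""
    def safe_div(num, den):
        return num / den if den else 0.0
    accuracy = safe_div(tp + tn, tp + fp + fn + tn)
    precision = safe_div(tp, tp + fp)   # of predicted dropouts, how many dropped
    recall = safe_div(tp, tp + fn)      # of actual dropouts, how many were caught
    f1 = safe_div(2 * precision * recall, precision + recall)  # harmonic mean
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Hypothetical test set: 100 patients, 10 actual dropouts
print(classification_metrics(tp=8, fp=4, fn=2, tn=86))
```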

### Why Check Training Performance Too?

**Overfitting Detection:**
- **Training accuracy >> Test accuracy**: Model memorized training data
- **Example**: Train 95%, Test 70% = Overfitting
- **Solution**: Reduce model complexity, add regularization, get more data

**Underfitting Detection:**
- **Both low**: Model too simple to learn patterns
- **Example**: Train 60%, Test 58% = Underfitting
- **Solution**: Increase model complexity, add features

**Healthy Model:**
- **Similar performance**: Train 85%, Test 82%
- **Generalizes well**: Will work on new patients


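In [None]:
# Calculate test set metrics
# zero_division=0 returns 0 (instead of warning) if the model predicts no dropouts at all
test_accuracy = accuracy_score(y_test, y_test_pred)
test_precision = precision_score(y_test, y_test_pred, zero_division=0)
test_recall = recall_score(y_test, y_test_pred, zero_division=0)
test_f1 = f1_score(y_test, y_test_pred, zero_division=0)
test_auc = roc_auc_score(y_test, y_test_pred_proba)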

print("=== TEST SET PERFORMANCE ===")
print(f"Accuracy:  {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"Precision: {test_precision:.4f}")
print(f"Recall:    {test_recall:.4f}")
print(f"F1 Score:  {test_f1:.4f}")
print(f"ROC AUC:   {test_auc:.4f}")


In [None]:
# Calculate training set metrics (to check for overfitting)
train_accuracy = accuracy_score(y_train, y_train_pred)
train_precision = precision_score(y_train, y_train_pred, zero_division=0)
train_recall = recall_score(y_train, y_train_pred, zero_division=0)
train_f1 = f1_score(y_train, y_train_pred, zero_division=0)
train_auc = roc_auc_score(y_train, y_train_pred_proba)

print("\n=== TRAINING SET PERFORMANCE ===")
print(f"Accuracy:  {train_accuracy:.4f} ({train_accuracy*100:.2f}%)")
print(f"Precision: {train_precision:.4f}")
print(f"Recall:    {train_recall:.4f}")
print(f"F1 Score:  {train_f1:.4f}")
print(f"ROC AUC:   {train_auc:.4f}")

# Check for overfitting
print("\n=== OVERFITTING CHECK ===")
print(f"Accuracy difference: {abs(train_accuracy - test_accuracy):.4f}")
if abs(train_accuracy - test_accuracy) < 0.05:
    print("✓ Model generalizes well (difference < 5%)")
else:
    print("⚠ Possible overfitting detected")


In [None]:
# Confusion Matrix
cm = confusion_matrix(y_test, y_test_pred)

print("\n=== CONFUSION MATRIX ===")
print(f"True Negatives:  {cm[0,0]}")
print(f"False Positives: {cm[0,1]}")
print(f"False Negatives: {cm[1,0]}")
print(f"True Positives:  {cm[1,1]}")

# Visualize confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['No Dropout', 'Dropout'],
            yticklabels=['No Dropout', 'Dropout'])
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion Matrix - Test Set')
plt.show()


In [None]:
# Classification Report
print("\n=== CLASSIFICATION REPORT ===")
print(classification_report(y_test, y_test_pred, 
                          target_names=['No Dropout', 'Dropout']))


## 9. ROC Curve and AUC

Visualize model performance across different classification thresholds.

### What is ROC Curve?

**ROC = Receiver Operating Characteristic**

**What It Shows:**
- X-axis: False Positive Rate (share of completers incorrectly flagged as dropouts)
- Y-axis: True Positive Rate (share of actual dropouts correctly flagged, i.e. recall)
- Plots performance at every distinct threshold implied by the model's predicted probabilities

**How to Read It:**
- **Top-left corner**: Perfect model (100% TPR, 0% FPR)
- **Diagonal line**: Random guessing (no better than coin flip)
- **Area under curve (AUC)**: Overall performance metric
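The curve can be traced by sweeping the threshold over each distinct predicted probability and recording (FPR, TPR) at each step. Below is a minimal pure-Python sketch with made-up scores. scikit-learn's `roc_curve` does the same job more efficiently and, by default, drops redundant intermediate points:

```python
def roc_points(y_true, scores):
    """(FPR, TPR) pairs for thresholds at each distinct score, highest first."""
    positives = sum(y_true)
    negatives = len(y_true) - positives
    points = [(0.0, 0.0)]  # threshold above every score: nothing flagged
    for threshold in sorted(set(scores), reverse=True):
        # Flag a patient as dropout when score >= threshold
        tp = sum(1 for y, s in zip(y_true, scores) if s >= threshold and y == 1)
        fp = sum(1 for y, s in zip(y_true, scores) if s >= threshold and y == 0)
        points.append((fp / negatives, tp / positives))
    return points

y_true = [0, 0, 1, 1, 0, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]
for fpr, tpr in roc_points(y_true, scores):
    print(f"FPR={fpr:.2f}  TPR={tpr:.2f}")
```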

### Why ROC/AUC Matters for Clinical Decisions

**Threshold Trade-offs:**

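By default, `XGBClassifier.predict()` labels a patient as a dropout when the predicted probability is above 0.5, but a different threshold can be applied to the `predict_proba()` output: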

1. **Lower threshold (e.g., 0.3):**
   - Catch more dropouts (higher recall)
   - More false alarms (lower precision)
   - **Use when**: Missing a dropout is costly
   - **Example**: Critical trial, need all-hands intervention

2. **Higher threshold (e.g., 0.7):**
   - Fewer false alarms (higher precision)
   - Miss some dropouts (lower recall)
   - **Use when**: Resources are limited
   - **Example**: Focus only on highest-risk patients

3. **Default threshold (0.5):**
   - Balanced approach
   - **Use when**: No strong preference either way
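The trade-off can be checked directly by applying each candidate threshold to the predicted probabilities. Below is a minimal sketch with made-up probabilities. In the notebook, the inputs would be `y_test` and `y_test_pred_proba`. The `>=` comparison is a choice made for this sketch:

```python
def precision_recall_at(y_true, probs, threshold):
    """Precision and recall when patients with prob >= threshold are flagged."""
    flagged = [p >= threshold for p in probs]
    tp = sum(1 for f, y in zip(flagged, y_true) if f and y == 1)
    fp = sum(1 for f, y in zip(flagged, y_true) if f and y == 0)
    fn = sum(1 for f, y in zip(flagged, y_true) if not f and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

y_true = [1, 0, 1, 0, 0, 1, 0, 0, 1, 0]
probs  = [0.9, 0.6, 0.45, 0.35, 0.1, 0.75, 0.2, 0.5, 0.3, 0.05]
for t in (0.3, 0.5, 0.7):
    p, r = precision_recall_at(y_true, probs, t)
    print(f"threshold={t}: precision={p:.2f}, recall={r:.2f}")
```

On this toy data, lowering the threshold to 0.3 catches every dropout (recall 1.0) at the cost of more false alarms. Raising it to 0.7 removes the false alarms (precision 1.0) but misses half the dropouts.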

**Clinical Application:**

The ROC curve helps trial coordinators decide:
- "Should we intervene at 30% risk or 70% risk?"
- "What's the cost of false alarms vs. missed patients?"
- "How do we balance resources with patient safety?"

**AUC Interpretation:**
- **0.9-1.0**: Excellent - Trust the model's risk scores
- **0.8-0.9**: Good - Useful for triaging patients
- **0.7-0.8**: Fair - Use with caution, needs improvement
- **< 0.7**: Poor - Not reliable for clinical decisions

### Why ROC/AUC Is Better Than Accuracy Alone

**Example Scenario:**
- 90% of patients complete the trial (10% dropout)
- A "dummy model" that predicts "complete" for everyone = 90% accuracy!
- But it's useless - catches 0% of dropouts
- **AUC for dummy model = 0.5** (random guessing)
- **AUC for good model = 0.85** (actually useful)

**ROC/AUC captures the full picture:**
- Not distorted by the class ratio the way accuracy is (under heavy imbalance, a precision-recall curve is a useful complement)
- Shows trade-offs at all thresholds
- Single number (AUC) summarizes overall performance
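AUC equals the probability that a randomly chosen dropout is scored above a randomly chosen completer, with ties counting as half. That definition gives a direct, if O(n²), way to compute it. It also shows why the constant "dummy model" lands at exactly 0.5 despite its 90% accuracy. Here is a pure-Python sketch with synthetic labels and scores:

```python
def pairwise_auc(y_true, scores):
    """AUC as P(score of a positive > score of a negative), ties counted as 0.5."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))

# 90 completers, 10 dropouts
y_true = [0] * 90 + [1] * 10

# Dummy model: same score for everyone, so every patient is predicted "complete"
dummy_scores = [0.1] * 100
dummy_accuracy = sum(1 for y in y_true if y == 0) / len(y_true)
print(f"dummy accuracy={dummy_accuracy:.2f}, dummy AUC={pairwise_auc(y_true, dummy_scores):.2f}")
# dummy accuracy=0.90, dummy AUC=0.50

# A model that ranks 8 of the 10 dropouts above every completer
informative_scores = [0.2] * 90 + [0.6] * 8 + [0.1] * 2
print(f"informative AUC={pairwise_auc(y_true, informative_scores):.2f}")
# informative AUC=0.80
```

For real data, `roc_auc_score` computes the same quantity efficiently.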


In [None]:
# Calculate ROC curve
fpr, tpr, thresholds = roc_curve(y_test, y_test_pred_proba)

# Plot ROC curve
plt.figure(figsize=(10, 7))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {test_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(alpha=0.3)
plt.show()


## 10. Predict on New Patients

Apply the trained model to score new patients for dropout risk.

### Why Production Scoring Matters

**Moving from Model to Action:**
- Training and evaluation show how well the model performs on historical data
- **Production scoring** is where the model creates value
- This section shows how the model would be used in real clinical operations

**Real-World Workflow:**
1. **New patients enroll** in clinical trial
2. **Capture demographics** (age, gender)
3. **Run model** to get dropout probability
4. **Categorize risk** (High/Medium/Low)
5. **Trigger interventions** for high-risk patients

### Risk Categorization Strategy

**Why Risk Categories?**
- Clinical teams need actionable categories, not just probabilities
- "70% dropout risk" → "High Risk" is clearer for non-technical staff
- Enables standardized intervention protocols

**Our Categories:**
- **High Risk (≥70%)**: Immediate intervention required
  - Assign dedicated coordinator
  - Weekly check-ins
  - Address barriers proactively
  
- **Medium Risk (40% to <70%)**: Enhanced monitoring
  - Bi-weekly check-ins
  - Watch for warning signs
  - Provide extra support if needed
  
- **Low Risk (<40%)**: Standard protocol
  - Regular scheduled visits
  - Standard communication
  - Routine follow-up
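Boundaries matter at the cutoffs. With `>=` comparisons, a probability of exactly 0.70 is High Risk and 0.6999 is Medium Risk. Below is a minimal sketch that keeps the cutoffs in one configurable list, so they can be tuned without editing the logic. `RISK_CUTOFFS` and `risk_category` are illustrative names, and the values are the 70%/40% cutoffs above:

```python
from bisect import bisect_right

# (lower bound, label) pairs in ascending order; a probability gets the label
# of the highest lower bound it reaches
RISK_CUTOFFS = [(0.0, "Low Risk"), (0.4, "Medium Risk"), (0.7, "High Risk")]

def risk_category(prob, cutoffs=RISK_CUTOFFS):
    """Map a dropout probability in [0, 1] to a risk label."""
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"probability out of range: {prob}")
    bounds = [b for b, _ in cutoffs]
    return cutoffs[bisect_right(bounds, prob) - 1][1]

for p in (0.39, 0.4, 0.6999, 0.7, 0.95):
    print(p, risk_category(p))
```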

**Adjusting Thresholds:**
These thresholds (70%, 40%) can be adjusted based on:
- Available resources
- Trial criticality
- Historical dropout costs
- Intervention effectiveness

### Production Considerations

**Data Pipeline:**
- New patients → Database → Model → Risk Score → Intervention System
- Automated daily/weekly batch scoring
- Or real-time scoring at enrollment

**Model Monitoring:**
- Track prediction distribution over time
- Alert if patterns change (model drift)
- Regular retraining with new data
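One common way to track the prediction distribution over time is the Population Stability Index (PSI). PSI compares how predicted probabilities fall into fixed bins at training time versus later. This is a swapped-in technique that the notebook does not implement. The five equal-width bins, the epsilon for empty bins, and the often-quoted rule of thumb (PSI below 0.1 is stable, above 0.25 is a significant shift) are conventions, not fixed standards. The helper names below are hypothetical:

```python
import math
from bisect import bisect_right

EDGES = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]  # five equal-width probability bins

def bin_proportions(values, edges=EDGES, eps=1e-6):
    """Share of values in each bin, floored at eps."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        # bisect finds the bin; min() puts a probability of exactly 1.0 in the last bin
        i = min(bisect_right(edges, v) - 1, len(counts) - 1)
        counts[i] += 1
    # eps keeps empty bins from producing log(0) or division by zero
    return [max(c / len(values), eps) for c in counts]

def psi(expected, actual):
    """Population Stability Index: sum over bins of (A - E) * ln(A / E)."""
    e, a = bin_proportions(expected), bin_proportions(actual)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))

baseline = [0.1, 0.15, 0.3, 0.35, 0.5, 0.55, 0.7, 0.75, 0.9, 0.95]
shifted = [0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.9, 0.95, 1.0]
print(f"PSI unchanged={psi(baseline, list(baseline)):.3f}, PSI shifted={psi(baseline, shifted):.3f}")
```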

**Clinical Integration:**
- Export predictions to trial management system
- Dashboard showing high-risk patients
- Automatic alerts to coordinators


In [None]:
# Create sample new patients to score
new_patients_data = pd.DataFrame({
    'AGE': [25, 45, 65, 30, 75, 22, 55, 40, 80, 28],
    'GENDER': ['FEMALE', 'MALE', 'FEMALE', 'MALE', 'FEMALE', 
               'MALE', 'FEMALE', 'MALE', 'MALE', 'FEMALE']
})

print("New patients to score:")
print(new_patients_data)


In [None]:
# Preprocess new patients (same as training data)
new_patients_data['GENDER_ENCODED'] = (new_patients_data['GENDER'].str.upper() == 'MALE').astype(int)

# Prepare features for prediction
X_new = new_patients_data[['AGE', 'GENDER_ENCODED']]

print("\nPreprocessed features:")
print(X_new)


In [None]:
# Make predictions on new patients
new_predictions = xgb_model.predict(X_new)
new_predictions_proba = xgb_model.predict_proba(X_new)[:, 1]

# Create results dataframe
new_patients_results = new_patients_data.copy()
new_patients_results['Predicted_Dropout'] = new_predictions
new_patients_results['Dropout_Probability'] = new_predictions_proba

# Add risk category
def categorize_risk(prob):
    if prob >= 0.7:
        return 'High Risk'
    elif prob >= 0.4:
        return 'Medium Risk'
    else:
        return 'Low Risk'

new_patients_results['Risk_Category'] = new_patients_results['Dropout_Probability'].apply(categorize_risk)

print("\n=== NEW PATIENT PREDICTIONS ===")
print(new_patients_results[['AGE', 'GENDER', 'Dropout_Probability', 'Risk_Category']].sort_values('Dropout_Probability', ascending=False))


In [None]:
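# Visualize risk distribution for new patients
# Reindex to a fixed order so the bar colors below line up with the categories
# (value_counts() sorts by frequency, not by risk level)
risk_order = ['High Risk', 'Medium Risk', 'Low Risk']
risk_counts = new_patients_results['Risk_Category'].value_counts().reindex(risk_order, fill_value=0)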

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Risk category distribution
risk_counts.plot(kind='bar', ax=axes[0], color=['red', 'orange', 'green'])
axes[0].set_title('Risk Category Distribution - New Patients')
axes[0].set_ylabel('Count')
axes[0].set_xlabel('Risk Category')
axes[0].tick_params(axis='x', rotation=45)

# Dropout probability distribution
axes[1].hist(new_patients_results['Dropout_Probability'], bins=10, color='steelblue', edgecolor='black')
axes[1].set_title('Dropout Probability Distribution')
axes[1].set_xlabel('Dropout Probability')
axes[1].set_ylabel('Frequency')
axes[1].axvline(0.4, color='orange', linestyle='--', label='Medium Risk Threshold')
axes[1].axvline(0.7, color='red', linestyle='--', label='High Risk Threshold')
axes[1].legend()

plt.tight_layout()
plt.show()


## 11. Summary and Next Steps


### Model Summary

**Data Source:**
- Table: INFORMATICS_SANDBOX.ML_TEST.DOR_ANALYSIS_FF
- Features: Age (numeric), Gender (categorical)
- Target: PATIENT_DROPPED (binary: 1=dropped out, 0=completed)

**Model Details:**
- Algorithm: XGBoost Classifier
- Implementation: Python (xgboost library)
- Train/Test Split: 80/20 with stratification

**Model Performance:**
- Test Accuracy: see the Section 8 output
- Test AUC: see the Section 8 and 9 output
- Precision, Recall, F1: see the Section 8 output

**Workflow:**
1. ✅ Load data from Snowflake using Snowpark
2. ✅ Exploratory data analysis with visualizations
3. ✅ Feature encoding (Gender → binary)
4. ✅ Train/test split with stratification
5. ✅ XGBoost model training
6. ✅ Comprehensive evaluation (accuracy, precision, recall, F1, AUC, confusion matrix, ROC curve)
7. ✅ Production scoring on new patients with risk categorization

### Potential Improvements

1. **Feature Engineering:**
   - Add medical history features
   - Include trial duration and phase
   - Add previous trial participation data
   - Create age bins or polynomial features

2. **Model Enhancements:**
   - Hyperparameter tuning with GridSearchCV or RandomizedSearchCV
   - Try alternative ensemble methods (Random Forest, LightGBM)
   - Implement SMOTE or class weighting if imbalanced
   - Feature selection techniques

3. **MLOps:**
   - Integrate with Snowflake Feature Store and Model Registry
   - Set up model monitoring and drift detection
   - Create automated retraining and production deployment pipelines
   - Deploy as Snowflake UDF for real-time scoring

4. **Validation:**
   - Implement k-fold cross-validation
   - Test on multiple clinical trial datasets
   - Perform temporal validation (train on old data, test on recent)

5. **Interpretability:**
   - Add SHAP values for model explainability
   - Create feature importance visualizations
   - Analyze misclassified cases


## 12. Model Persistence (Optional)

Save the trained XGBoost model for future use.


In [None]:
# Optional: Save model to file for later use
# Uncomment to save the model

# import joblib
# import os

# # Create models directory if it doesn't exist
# os.makedirs('/tmp/models', exist_ok=True)

# # Save the model
# model_path = '/tmp/models/patient_dropout_xgboost.joblib'
# joblib.dump(xgb_model, model_path)
# print(f"Model saved to: {model_path}")

# # To load the model later:
# # loaded_model = joblib.load(model_path)
# # predictions = loaded_model.predict(X_new)

print("Model training and evaluation complete!")
