# Notebook 05: Drug Shortage Prediction

**WISE Workshop | Addis Ababa, Feb 2026**

In this notebook, we'll replicate the approach from Roe et al. (2025), who applied Random Forest to predict drug shortages in South Korea. We adapt their methods to our Ethiopian supply chain dataset.

**Reference:** Roe et al. (2025). "Drug shortage in South Korea: machine learning-based prediction models and analysis of duration and causal factors." *Frontiers in Pharmacology*. [DOI: 10.3389/fphar.2025.1608843](https://doi.org/10.3389/fphar.2025.1608843)

## Part 1: Background

### The Drug Shortage Problem

Drug shortages disrupt healthcare delivery worldwide. Key challenges:
- **Unpredictable timing**: When will shortages occur?
- **Unknown duration**: How long will they last?
- **Multiple causes**: Manufacturing, raw materials, regulatory, demand surges

### The Korean Study Approach

Roe et al. built two ML models:

| Model | Task | Target | Performance |
|-------|------|--------|-------------|
| **Model 1** | Duration prediction | Short/Medium/Long | 62% accuracy |
| **Model 2** | Cause classification | 5 cause categories | >70% F1-score |
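Model 1 turns a continuous duration into three ordered classes. The cut points Roe et al. used are not given here, so the 30-day and 90-day thresholds below are hypothetical placeholders. This is a minimal sketch of the binning step with `pd.cut`:

```python
import pandas as pd

# Hypothetical shortage durations in days (illustrative values, not study data)
durations = pd.Series([5, 29, 30, 31, 75, 90, 91, 200], name="duration_days")

# Hypothetical cut points: 0-30 days Short, 31-90 Medium, >90 Long
bins = [0, 30, 90, float("inf")]
labels = ["Short", "Medium", "Long"]

# Each bin includes its upper edge (right=True is the default);
# include_lowest=True also keeps a duration of exactly 0 in the first bin
categories = pd.cut(durations, bins=bins, labels=labels, include_lowest=True)

print(pd.DataFrame({"days": durations, "category": categories}))
```

The boundary values (30, 90) fall into the lower bin. If you choose different thresholds, check the class balance they produce before training.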

**Top predictors identified:**
1. Shortage frequency (how often the drug has been short before)
2. Import status (domestic vs. imported)
3. Alternative drug availability

## Setup

In [None]:
# Import packages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    f1_score, mean_squared_error, r2_score
)

import warnings
warnings.filterwarnings('ignore')

print("Packages loaded!")

## Part 2: Load and Prepare Data

We'll use our Ethiopian supply chain dataset, which now includes shortage-related columns.

In [None]:
# Load data from GitHub
url = "https://raw.githubusercontent.com/sysylvia/ethiopia-ds-workshop-2026/main/data/supply-chain-sample.csv"
df = pd.read_csv(url)

print(f"Data shape: {df.shape}")
print(f"\nColumns: {df.columns.tolist()}")
df.head()

In [None]:
# Explore the shortage-related columns
print("Shortage Duration Categories:")
print(df['shortage_duration_category'].value_counts())
print("\nShortage Cause Categories:")
print(df['shortage_cause'].value_counts())

### Data Dictionary: New Shortage Columns

| Column | Description | Mapping to Korean Study |
|--------|-------------|------------------------|
| `shortage_occurred` | Binary: did a shortage happen? | Similar to their outcome |
| `shortage_duration_days` | Duration in days (if occurred) | Their Model 1 target |
| `shortage_duration_category` | Short/Medium/Long | Their Model 1 target (binned) |
| `shortage_cause` | Cause category | Their Model 2 target |
| `shortage_frequency` | Historical shortage count | Their top predictor |
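`shortage_frequency` is described as a historical count. If you derive a column like this yourself, count only shortages that happened *before* the current record. Otherwise the feature contains the outcome you are trying to predict. Here is a minimal sketch on a toy event table. The column names `drug_id` and `month` are hypothetical and do not come from the workshop dataset:

```python
import pandas as pd

# Toy monthly records (hypothetical columns; not the workshop dataset)
events = pd.DataFrame({
    "drug_id": ["A", "A", "A", "A", "B", "B", "B"],
    "month": [1, 2, 3, 4, 1, 2, 3],
    "shortage_occurred": [1, 0, 1, 1, 0, 1, 0],
})

# Sort so that "earlier" is well defined within each drug
events = events.sort_values(["drug_id", "month"])

# Running count of shortages up to and including this row, minus this row's own
# outcome, gives the number of shortages strictly before this record
events["prior_shortage_count"] = (
    events.groupby("drug_id")["shortage_occurred"].cumsum() - events["shortage_occurred"]
)

print(events)
```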

## Part 3: Duration Prediction (Model 1)

Like Roe et al., we'll predict shortage duration categories using Random Forest.

**Their result:** 62% accuracy
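An accuracy figure is easier to interpret next to a baseline. With three classes, a model that always predicts the most common class can already score well above 33% when the classes are imbalanced. This sketch uses scikit-learn's `DummyClassifier` on made-up labels:

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score

# Made-up imbalanced duration labels (illustrative only): 60% Short, 25% Medium, 15% Long
y = np.array(["Short"] * 60 + ["Medium"] * 25 + ["Long"] * 15)
X = np.zeros((len(y), 1))  # features are ignored by the "most_frequent" strategy

baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X, y)
baseline_acc = accuracy_score(y, baseline.predict(X))

print(f"Majority-class baseline accuracy: {baseline_acc:.0%}")
```

To apply this in the notebook, fit the baseline on `X_train_dur`/`y_train_dur`, score it on `X_test_dur`, and compare the result with the Random Forest accuracy.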

In [None]:
# Filter to rows where shortage occurred
df_shortage = df[df['shortage_occurred'] == 1].copy()
print(f"Shortage events: {len(df_shortage)}")

# Check class distribution
print("\nDuration category distribution:")
print(df_shortage['shortage_duration_category'].value_counts(normalize=True).round(3))

In [None]:
# Prepare features for Model 1 (Duration Prediction)
# Map categorical variables
facility_type_map = {'Hospital': 0, 'Health Center': 1, 'Clinic': 2}
region_map = {'Addis Ababa': 0, 'Oromia': 1, 'Amhara': 2, 'SNNP': 3, 'Tigray': 4}
season_map = {'dry': 0, 'rainy': 1}

df_shortage['facility_type_enc'] = df_shortage['facility_type'].map(facility_type_map)
df_shortage['region_enc'] = df_shortage['region'].map(region_map)
df_shortage['season_enc'] = df_shortage['season'].map(season_map)

# Define features (similar to Korean study predictors)
feature_cols = [
    'shortage_frequency',      # Top predictor in Korean study
    'facility_type_enc',       # Facility characteristic (loose analogue of their "drug characteristics")
    'region_enc',              # Geographic factor
    'distance_to_warehouse',   # Supply chain factor (like import status)
    'season_enc',              # Temporal factor
    'previous_demand',         # Demand pressure
    'stockout_last_month',     # Recent history
    'avg_delivery_days'        # Logistics factor
]

X_duration = df_shortage[feature_cols]
y_duration = df_shortage['shortage_duration_category']

print(f"Features: {feature_cols}")
print(f"\nX shape: {X_duration.shape}")
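`Series.map` with a dict returns `NaN` for any value missing from the dict, such as a region spelled differently or one not listed in `region_map`. Those `NaN`s then reach the model silently, so check for them before training. The helper below, `map_and_check`, is hypothetical and not part of the workshop code:

```python
import pandas as pd

def map_and_check(series: pd.Series, mapping: dict) -> pd.Series:
    """Map categories to integer codes and report any values the mapping misses.

    Hypothetical helper for illustration; not part of the workshop code.
    """
    encoded = series.map(mapping)
    # Values that were present in the input but have no entry in the mapping
    unmapped = series[encoded.isna() & series.notna()].unique()
    if len(unmapped) > 0:
        print(f"Warning: {series.name} has unmapped values: {list(unmapped)}")
    return encoded

# Toy example: 'Sidama' is not in this mapping, so it becomes NaN
region_map = {'Addis Ababa': 0, 'Oromia': 1, 'Amhara': 2, 'SNNP': 3, 'Tigray': 4}
regions = pd.Series(['Oromia', 'Sidama', 'Amhara'], name='region')
region_enc = map_and_check(regions, region_map)
print(region_enc.tolist())
```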

In [None]:
# Encode target variable
le_duration = LabelEncoder()
y_duration_enc = le_duration.fit_transform(y_duration)

print("Duration categories:", le_duration.classes_)

In [None]:
# Train/test split
X_train_dur, X_test_dur, y_train_dur, y_test_dur = train_test_split(
    X_duration, y_duration_enc, test_size=0.2, random_state=42, stratify=y_duration_enc
)

print(f"Training set: {len(X_train_dur)} samples")
print(f"Test set: {len(X_test_dur)} samples")
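Passing `stratify=y_duration_enc` keeps the class proportions in the training and test sets close to those of the full data. This matters when one duration class is rare. A small check on made-up labels:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up encoded labels: 0 = Long (10%), 1 = Medium (30%), 2 = Short (60%)
y = np.array([0] * 10 + [1] * 30 + [2] * 60)
X = np.arange(len(y)).reshape(-1, 1)

_, _, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Class proportions in each split (np.bincount counts occurrences of 0, 1, 2)
print("train:", np.bincount(y_train) / len(y_train))
print("test: ", np.bincount(y_test) / len(y_test))
```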

In [None]:
# Train Random Forest classifier (like Korean study)
rf_duration = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,
    min_samples_leaf=5,
    random_state=42,
    n_jobs=-1
)

rf_duration.fit(X_train_dur, y_train_dur)

# Predictions
y_pred_dur = rf_duration.predict(X_test_dur)

# Evaluate
accuracy_dur = accuracy_score(y_test_dur, y_pred_dur)
f1_dur = f1_score(y_test_dur, y_pred_dur, average='weighted')

print("=" * 50)
print("MODEL 1: Duration Prediction Results")
print("=" * 50)
print(f"Our Accuracy: {accuracy_dur:.1%}")
print(f"Korean Study: 62%")
print(f"\nOur F1-Score: {f1_dur:.1%}")
print("=" * 50)

In [None]:
# Cross-validation for more robust estimate
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(rf_duration, X_duration, y_duration_enc, cv=cv, scoring='accuracy')

print(f"5-Fold CV Accuracy: {cv_scores.mean():.1%} (+/- {cv_scores.std():.1%})")
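`cross_val_score` returns a single metric. `cross_validate` can score several metrics in one pass, so you can report cross-validated accuracy and weighted F1 side by side with the test-set numbers. This sketch uses synthetic data from `make_classification` as a stand-in for `X_duration` and `y_duration_enc`:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate

# Synthetic 3-class data standing in for X_duration / y_duration_enc
X, y = make_classification(
    n_samples=300, n_features=8, n_informative=4, n_classes=3, random_state=0
)

rf = RandomForestClassifier(n_estimators=100, max_depth=10, min_samples_leaf=5,
                            random_state=42, n_jobs=-1)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# One fit per fold; each fold is scored with both metrics
scores = cross_validate(rf, X, y, cv=cv, scoring=["accuracy", "f1_weighted"])

for metric in ["test_accuracy", "test_f1_weighted"]:
    print(f"{metric}: {scores[metric].mean():.1%} (+/- {scores[metric].std():.1%})")
```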

In [None]:
# Classification report
print("\nClassification Report:")
print(classification_report(y_test_dur, y_pred_dur, target_names=le_duration.classes_))

In [None]:
# Confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
cm = confusion_matrix(y_test_dur, y_pred_dur)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=le_duration.classes_,
            yticklabels=le_duration.classes_, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Duration Prediction: Confusion Matrix')
plt.tight_layout()
plt.show()
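Raw counts make rare classes hard to read. `confusion_matrix(..., normalize='true')` divides each row by the number of true samples in that class, so each diagonal entry is that class's recall. A small sketch with hypothetical labels, encoded in the same alphabetical order that `LabelEncoder` would use (0 = Long, 1 = Medium, 2 = Short):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical true and predicted duration classes (0=Long, 1=Medium, 2=Short)
y_true = np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 2, 1, 1, 2, 2, 2, 2, 2, 1])

# Each row sums to 1; the diagonal entry of row k is the recall of class k
cm_norm = confusion_matrix(y_true, y_pred, normalize="true")
print(np.round(cm_norm, 2))
```

To plot it, pass `cm_norm` to `sns.heatmap` with `fmt='.2f'` instead of `fmt='d'`.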

## Part 4: Cause Classification (Model 2)

Like Roe et al.'s second model, we'll predict the cause of shortages.

**Their result:** >70% F1-score
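This notebook reports *weighted* F1, which averages the per-class F1 scores weighted by class frequency. *Macro* F1 gives every class equal weight. The notebook does not say which average the >70% figure uses, so computing both is a useful check. On imbalanced data the two can differ a lot. The cause labels below are hypothetical:

```python
from sklearn.metrics import f1_score

# Hypothetical imbalanced cause labels: the rare class 'regulatory' is never predicted
y_true = ["manufacturing"] * 8 + ["regulatory"] * 2
y_pred = ["manufacturing"] * 10

# zero_division=0 sets undefined precision (no predictions for a class) to 0
weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
macro = f1_score(y_true, y_pred, average="macro", zero_division=0)

print(f"Weighted F1: {weighted:.3f}")
print(f"Macro F1:    {macro:.3f}")
```

In this example the model ignores the rare class entirely. Weighted F1 (about 0.71) still looks respectable, but macro F1 (about 0.44) exposes the problem.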

In [None]:
# Check cause distribution
print("Shortage Cause Distribution:")
print(df_shortage['shortage_cause'].value_counts(normalize=True).round(3))

In [None]:
# Prepare data for Model 2
y_cause = df_shortage['shortage_cause']

# Encode target
le_cause = LabelEncoder()
y_cause_enc = le_cause.fit_transform(y_cause)

print("Cause categories:", le_cause.classes_)

In [None]:
# Train/test split for cause prediction (reuses the Model 1 feature matrix X_duration)
X_train_cause, X_test_cause, y_train_cause, y_test_cause = train_test_split(
    X_duration, y_cause_enc, test_size=0.2, random_state=42, stratify=y_cause_enc
)

print(f"Training set: {len(X_train_cause)} samples")
print(f"Test set: {len(X_test_cause)} samples")

In [None]:
# Train Random Forest for cause classification
rf_cause = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,
    min_samples_leaf=5,
    random_state=42,
    n_jobs=-1,
    class_weight='balanced'  # Handle imbalanced classes
)

rf_cause.fit(X_train_cause, y_train_cause)

# Predictions
y_pred_cause = rf_cause.predict(X_test_cause)

# Evaluate
accuracy_cause = accuracy_score(y_test_cause, y_pred_cause)
f1_cause = f1_score(y_test_cause, y_pred_cause, average='weighted')

print("=" * 50)
print("MODEL 2: Cause Classification Results")
print("=" * 50)
print(f"Our Accuracy: {accuracy_cause:.1%}")
print(f"Our F1-Score: {f1_cause:.1%}")
print(f"Korean Study: >70% F1")
print("=" * 50)
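`class_weight='balanced'` weights each class by `n_samples / (n_classes * count_of_class)`, so rare causes count more when the trees choose splits. scikit-learn exposes the same computation through `compute_class_weight`. A sketch with hypothetical class counts:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical encoded cause labels: class 0 common, classes 1 and 2 rarer
y = np.array([0] * 60 + [1] * 30 + [2] * 10)
classes = np.unique(y)

weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)

# The same formula written out by hand: n_samples / (n_classes * count per class)
manual = len(y) / (len(classes) * np.bincount(y))

print(dict(zip(classes.tolist(), np.round(weights, 3).tolist())))
```

In a Random Forest these weights come from the full training labels. The related option `class_weight='balanced_subsample'` recomputes them from each tree's bootstrap sample.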

In [None]:
# Classification report for causes
print("\nClassification Report:")
print(classification_report(y_test_cause, y_pred_cause, target_names=le_cause.classes_))

In [None]:
# Confusion matrix for cause prediction
fig, ax = plt.subplots(figsize=(10, 8))
cm_cause = confusion_matrix(y_test_cause, y_pred_cause)
sns.heatmap(cm_cause, annot=True, fmt='d', cmap='Greens',
            xticklabels=le_cause.classes_,
            yticklabels=le_cause.classes_, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Cause Classification: Confusion Matrix')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

## Part 5: Feature Importance Comparison

The Korean study found these were the top predictors:
1. **Shortage frequency** (most important)
2. Import status
3. Alternative drug availability

Let's see what our models find!

In [None]:
# Feature importance from both models
importance_dur = pd.DataFrame({
    'feature': feature_cols,
    'importance': rf_duration.feature_importances_
}).sort_values('importance', ascending=False)

importance_cause = pd.DataFrame({
    'feature': feature_cols,
    'importance': rf_cause.feature_importances_
}).sort_values('importance', ascending=False)

# Plot side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# hue='feature' with legend=False avoids seaborn's (>= 0.13) palette-without-hue deprecation
sns.barplot(data=importance_dur, x='importance', y='feature', hue='feature',
            palette='Blues_d', legend=False, ax=axes[0])
axes[0].set_title('Model 1: Duration Prediction\nFeature Importance')
axes[0].set_xlabel('Importance')

sns.barplot(data=importance_cause, x='importance', y='feature', hue='feature',
            palette='Greens_d', legend=False, ax=axes[1])
axes[1].set_title('Model 2: Cause Classification\nFeature Importance')
axes[1].set_xlabel('Importance')

plt.tight_layout()
plt.show()
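A Random Forest's `feature_importances_` are impurity-based: they measure the mean decrease in impurity, computed on the training data. scikit-learn's documentation notes that these scores are biased toward high-cardinality features, such as continuous variables like `distance_to_warehouse`, over low-cardinality codes like `season_enc`. Permutation importance on held-out data is a common cross-check. This sketch uses `sklearn.inspection.permutation_importance` on synthetic data with placeholder feature names:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic data standing in for X_duration / y_duration_enc (feature names are placeholders)
X, y = make_classification(n_samples=400, n_features=5, n_informative=2,
                           n_redundant=0, random_state=0)
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(5)])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

# Shuffle one column at a time on the test set and measure the drop in accuracy
result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42)

perm = pd.DataFrame({
    "feature": X.columns,
    "impurity_importance": rf.feature_importances_,
    "permutation_importance": result.importances_mean,
}).sort_values("permutation_importance", ascending=False)
print(perm)
```

In the notebook, the equivalent call would be `permutation_importance(rf_duration, X_test_dur, y_test_dur, n_repeats=10, random_state=42)`.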

In [None]:
# Compare our top predictors to Korean study
print("Top 3 Predictors Comparison")
print("=" * 60)
print("\nKorean Study (Roe et al., 2025):")
print("  1. Shortage frequency")
print("  2. Import status")
print("  3. Alternative drug availability")
print("\nOur Duration Model:")
for rank, (_, row) in enumerate(importance_dur.head(3).iterrows(), start=1):
    print(f"  {rank}. {row['feature']} ({row['importance']:.3f})")
print("\nOur Cause Model:")
for rank, (_, row) in enumerate(importance_cause.head(3).iterrows(), start=1):
    print(f"  {rank}. {row['feature']} ({row['importance']:.3f})")

## Part 6: Discussion Questions

**Think about:**

1. **How does our accuracy compare to the Korean study?**
   - They achieved 62% for duration, >70% F1 for causes
   - What might explain differences?

2. **Are the important features similar?**
   - Korean study: shortage frequency was #1
   - What drives predictions in our context?

3. **How would you improve these models?**
   - More features? (e.g., manufacturer data, import status, alternative drug availability)
   - Different algorithms?
   - More data?

4. **How could these predictions be used?**
   - Early warning systems
   - Resource allocation
   - Policy planning
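The setup cell imports `RandomForestRegressor`, `mean_squared_error` and `r2_score` but never uses them. One alternative worth trying is to predict `shortage_duration_days` directly as a number instead of a binned category. This sketch runs on synthetic data, and the relationship between the features and duration is invented:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)

# Synthetic stand-in: 2 features, duration loosely increasing with the first one
n = 500
X = rng.uniform(0, 10, size=(n, 2))
y = 10 + 5 * X[:, 0] + rng.normal(0, 5, size=n)  # invented relationship, in days

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, random_state=42)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)

rmse = np.sqrt(mean_squared_error(y_test, y_pred))  # RMSE in days
r2 = r2_score(y_test, y_pred)
print(f"RMSE: {rmse:.1f} days, R^2: {r2:.2f}")
```

A regression target keeps the information that binning throws away. RMSE in days is also easier to explain to supply chain staff than accuracy over three bins.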

## Summary

In this notebook, we:

1. **Replicated** the Roe et al. (2025) approach for drug shortage prediction
2. **Built Model 1** for duration prediction (classification)
3. **Built Model 2** for cause classification
4. **Compared** our feature importance to their findings
5. **Discussed** implications for supply chain management

### Key Takeaways

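- Random Forest can predict shortage characteristics, but accuracy is only meaningful relative to a simple baseline (e.g., always predicting the most common class)
- Historical shortage frequency was the top predictor in the Korean study; the feature importance results show whether it plays the same role in our data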
- These models can inform proactive supply chain management

### Connection to WISE Project

This same approach can be applied to predict:
- Which facilities are at risk of stockouts
- How long disruptions might last
- What interventions might be most effective
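For the first item, you could train a binary classifier on `shortage_occurred` and use `predict_proba` as a risk score to rank facilities. Here is a minimal sketch on synthetic data. The facility IDs, features, and the rule generating shortages are hypothetical, and the scores are only as trustworthy as the model's calibration:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Synthetic facility records (hypothetical features, not the workshop dataset)
n = 300
data = pd.DataFrame({
    "facility_id": [f"F{i:03d}" for i in range(n)],
    "shortage_frequency": rng.poisson(2, size=n),
    "avg_delivery_days": rng.uniform(1, 30, size=n),
})
# Invented rule: more past shortages and slower deliveries raise shortage risk
logit = -3 + 0.6 * data["shortage_frequency"] + 0.08 * data["avg_delivery_days"]
data["shortage_occurred"] = rng.binomial(1, (1 / (1 + np.exp(-logit))).to_numpy())

features = ["shortage_frequency", "avg_delivery_days"]
clf = RandomForestClassifier(n_estimators=200, min_samples_leaf=5, random_state=42)
clf.fit(data[features], data["shortage_occurred"])

# Column 1 of predict_proba is P(shortage); in practice, score facilities not used in training
data["shortage_risk"] = clf.predict_proba(data[features])[:, 1]
top_risk = data.sort_values("shortage_risk", ascending=False).head(10)
print(top_risk[["facility_id", "shortage_risk"]])
```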

**Reference:** Roe, Y., Lee, S., Kim, C., & Lee, J. (2025). Drug shortage in South Korea: machine learning-based prediction models and analysis of duration and causal factors. *Frontiers in Pharmacology*, 16, 1608843. [DOI: 10.3389/fphar.2025.1608843](https://doi.org/10.3389/fphar.2025.1608843)