# Day 4: Agent-Based Model for Antibiotic Supply Chains

**WISE Workshop | Addis Ababa, Feb 2026**

In this notebook, you'll build a comprehensive agent-based model (ABM) simulating Ethiopia's antibiotic supply chain. **Critically, the ML demand predictions from Day 2 ARE the actual demand** that the supply chain must serve. This model tests supply chain resilience by asking: "Given this predicted demand, how well can the supply chain respond under various scenarios?"

The supply chain structure includes:
- **2 Manufacturers** producing antibiotics
- **1 Central Medical Store** (EPSA) distributing nationally
- **3 Regional Hospitals** serving as intermediate distribution points
- **100 Community Health Centers** delivering care based on predicted demand
- **Health Workers** who are randomly absent in some months, interrupting care at their CHC

**Key concept:** The ML forecasts from Day 2 directly drive demand in this model. Scenarios modify either:
- **Demand** (e.g., outbreaks multiply predicted demand by 3x)
- **Supply** (e.g., manufacturer failures, transit delays)

We'll explore 8 scenarios to understand supply chain resilience and policy implications.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sysylvia/ethiopia-ds-workshop-2026/blob/main/notebooks/06-antibiotic-supply-chain-abm.ipynb)

## Setup

In [None]:
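# Install Mesa if not available (for Google Colab).
# The agent classes below use the Mesa 2.x constructor Agent(unique_id, model),
# so pin Mesa below 3.0. Quote the spec so the shell does not treat ">" as a redirect.
try:
    import mesa
    print(f"Mesa version: {mesa.__version__}")
except ImportError:
    print("Installing Mesa...")
    !pip install "mesa>=2.0.0,<3.0"
    import mesa
    print(f"Mesa installed! Version: {mesa.__version__}")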

In [None]:
# Import packages
import mesa
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from enum import Enum
import warnings

warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

print("All packages loaded!")

In [None]:
# Try to import ipywidgets for interactive features
try:
    import ipywidgets as widgets
    from IPython.display import display, clear_output
    WIDGETS_AVAILABLE = True
    print("Interactive widgets available!")
except ImportError:
    WIDGETS_AVAILABLE = False
    print("ipywidgets not available - interactive features disabled")

In [None]:
# Try to import networkx for supply chain visualization
try:
    import networkx as nx
    NETWORKX_AVAILABLE = True
    print("NetworkX available for supply chain visualization!")
except ImportError:
    NETWORKX_AVAILABLE = False
    print("NetworkX not available - network visualization disabled")

## Part 1: Generate Demand Predictions

**Important conceptual point:** The demand predictions we generate here ARE the actual demand that each CHC will experience. In Day 2, we built ML models to predict antibiotic demand based on population, seasonality, and trends. Here, we use the same formula to generate that predicted demand, which directly drives the ABM.

This differs from a typical forecasting setup, in which forecasts and realized outcomes can diverge. In this ABM:
- **ML Predictions = Actual Demand** (the ground truth for the simulation)
- **Supply Chain Response** = What we're testing (can it meet this demand?)
- **Scenarios** modify either demand (outbreaks) or supply (failures, delays)

In practice, you would load trained ML models from Day 2. For self-contained execution, we recreate the demand generation formula here.

In [None]:
# Define parameters matching Day 2 notebook
n_facilities = 100
n_months = 60  # 5 years of data
antibiotic_classes = ['Penicillins', 'Macrolides', 'Fluoroquinolones']

# Create facility IDs
facility_ids = [f'CHC_{i:03d}' for i in range(1, n_facilities + 1)]

# Assign population served to each facility (varies from 5,000 to 50,000)
facility_populations = {fid: np.random.randint(5000, 50001) for fid in facility_ids}

# Define antibiotic class characteristics
class_params = {
    'Penicillins': {
        'base_demand_per_1000': 15,  # High volume
        'seasonal_amplitude': 0.3,   # Strong seasonality
        'peak_month': 1,             # Peak in January (cold/flu season)
        'trend': 0.02                # 2% annual growth
    },
    'Macrolides': {
        'base_demand_per_1000': 8,   # Medium volume
        'seasonal_amplitude': 0.4,   # Very seasonal (respiratory)
        'peak_month': 12,            # Peak in December
        'trend': 0.03                # 3% annual growth
    },
    'Fluoroquinolones': {
        'base_demand_per_1000': 4,   # Lower volume
        'seasonal_amplitude': 0.1,   # More stable
        'peak_month': 7,             # Slight peak in rainy season
        'trend': 0.01                # 1% annual growth
    }
}

print(f"Facilities: {n_facilities}")
print(f"Months: {n_months}")
print(f"Antibiotic classes: {antibiotic_classes}")

In [None]:
def generate_monthly_forecast(facility_id: str, month: int, year: int, 
                              abx_class: str, outbreak_multiplier: float = 1.0) -> dict:
    """
    Generate a demand forecast for a specific facility/month/antibiotic.
    Returns point estimate and uncertainty bounds.
    """
    population = facility_populations[facility_id]
    params = class_params[abx_class]
    
    # Base demand scaled by population
    base = params['base_demand_per_1000'] * (population / 1000)
    
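    # Seasonal effect: cosine so the term is largest exactly at peak_month
    # (a sine with the same phase would peak three months later)
    seasonal = params['seasonal_amplitude'] * np.cos(
        2 * np.pi * (month - params['peak_month']) / 12
    )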
    
    # Year-over-year trend (from start of simulation)
    years_from_start = (year - 1) + (month - 1) / 12
    trend = 1 + params['trend'] * years_from_start
    
    # Calculate demand
    demand = base * (1 + seasonal) * trend * outbreak_multiplier
    
    # Add uncertainty (15% standard deviation)
    uncertainty = demand * 0.15
    
    return {
        'expected_demand': int(max(1, demand)),
        'demand_lower': int(max(1, demand - 1.96 * uncertainty)),
        'demand_upper': int(max(1, demand + 1.96 * uncertainty)),
        'uncertainty': uncertainty
    }

# Test the forecast function
test_forecast = generate_monthly_forecast('CHC_001', 6, 1, 'Penicillins')
print(f"Test forecast for CHC_001, June Year 1, Penicillins:")
print(f"  Expected: {test_forecast['expected_demand']}")
print(f"  Range: [{test_forecast['demand_lower']}, {test_forecast['demand_upper']}]")

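Whether a 12-month sinusoid actually peaks at `peak_month` depends on using sine or cosine: `sin(2π(m − p)/12)` is zero at month `p` and reaches its maximum three months later, while `cos(2π(m − p)/12)` is largest exactly at `p`. The helper `seasonal_peak` below is a hypothetical illustration (not part of the Day 2 code) that checks this for the three `peak_month` values in `class_params`.

```python
import numpy as np


def seasonal_peak(peak_month: int, use_cos: bool) -> int:
    """Return the calendar month (1-12) at which the seasonal term is largest.

    Hypothetical helper for illustration only.
    """
    months = np.arange(1, 13)
    phase = 2 * np.pi * (months - peak_month) / 12
    curve = np.cos(phase) if use_cos else np.sin(phase)
    return int(months[np.argmax(curve)])


for abx, p in [('Penicillins', 1), ('Macrolides', 12), ('Fluoroquinolones', 7)]:
    print(f"{abx}: peak_month={p}, sin peaks in month {seasonal_peak(p, False)}, "
          f"cos peaks in month {seasonal_peak(p, True)}")
```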
## Part 2: Model Configuration

Define all simulation parameters in a central configuration dictionary.

In [None]:
# Central configuration for the ABM
CONFIG = {
    # Supply chain structure
    'n_manufacturers': 2,
    'n_central_stores': 1,
    'n_hospitals': 3,
    'n_chcs': 100,
    
    # Simulation duration
    'n_months': 60,
    
    # Logistics parameters
    'transit_time': 2,          # months from manufacturer to CHC
    'order_lead_time': 2,       # months ahead facilities order
    'medicine_shelf_life': 12,  # months before expiry
    
    # Health worker parameters
    'health_worker_absenteeism': 0.10,  # 10% chance a CHC's worker is absent in a given month (model steps are monthly)
    
    # Disease incidence per 1000 population per month (requiring antibiotics)
    # Children have highest rates, then elderly, then adults
    'incidence': {
        'child': 8.0,    # Higher respiratory infections
        'elderly': 5.0,  # Vulnerable population
        'adult': 2.0     # Lowest rate
    },
    
    # Death rates if untreated (per untreated case)
    # Elderly most vulnerable, then children, then adults
    'death_rates': {
        'elderly': 0.08,  # 8% mortality if untreated
        'child': 0.05,    # 5% mortality if untreated
        'adult': 0.02     # 2% mortality if untreated
    },
    
    # Population demographics (% of CHC catchment)
    'demographics': {
        'child': 0.35,    # 35% children (<15)
        'adult': 0.55,    # 55% adults (15-64)
        'elderly': 0.10   # 10% elderly (65+)
    },
    
    # Capacity parameters (monthly units)
    'manufacturer_capacity': 50000,  # Per manufacturer per month
    'central_store_capacity': 100000,
    'hospital_capacity': 20000,
    'chc_capacity': 2000,
    
    # Initial stock levels (% of capacity)
    'initial_stock_pct': 0.5,
    
    # Antibiotic classes
    'antibiotic_classes': ['Penicillins', 'Macrolides', 'Fluoroquinolones']
}

print("Configuration defined:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

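Death rates are defined per age group, while the ML forecast gives a single total per antibiotic, so it helps to know the population-weighted mortality of one untreated case. The sketch below uses a hypothetical helper, `blended_death_rate`, with values copied from `CONFIG`; it also checks that the demographic shares sum to 1, which `get_actual_demand` implicitly relies on when it splits forecast demand by age group.

```python
# Copies of the CONFIG entries used below
demographics = {'child': 0.35, 'adult': 0.55, 'elderly': 0.10}
death_rates = {'elderly': 0.08, 'child': 0.05, 'adult': 0.02}


def blended_death_rate(demographics: dict, death_rates: dict) -> float:
    """Population-weighted probability of death per untreated case (hypothetical helper)."""
    total_share = sum(demographics.values())
    if abs(total_share - 1.0) > 1e-9:
        raise ValueError(f"Demographic shares sum to {total_share}, expected 1.0")
    if set(demographics) != set(death_rates):
        raise ValueError("Age groups in demographics and death_rates differ")
    return sum(share * death_rates[group] for group, share in demographics.items())


rate = blended_death_rate(demographics, death_rates)
print(f"Blended mortality: {rate:.4f}, i.e. about {rate * 1000:.1f} deaths per 1,000 untreated cases")
```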
In [None]:
# Assign CHCs to hospitals (regional distribution)
def assign_chcs_to_hospitals(n_chcs: int, n_hospitals: int) -> Dict[int, List[str]]:
    """
    Distribute CHCs across hospitals roughly equally.
    Returns dict mapping hospital_id to list of CHC IDs.
    """
    assignments = {i: [] for i in range(n_hospitals)}
    chc_ids = [f'CHC_{i:03d}' for i in range(1, n_chcs + 1)]
    
    for idx, chc_id in enumerate(chc_ids):
        hospital_idx = idx % n_hospitals
        assignments[hospital_idx].append(chc_id)
    
    return assignments

CHC_HOSPITAL_ASSIGNMENTS = assign_chcs_to_hospitals(CONFIG['n_chcs'], CONFIG['n_hospitals'])

print("CHC to Hospital assignments:")
for hosp_id, chcs in CHC_HOSPITAL_ASSIGNMENTS.items():
    print(f"  Hospital {hosp_id}: {len(chcs)} CHCs ({chcs[:3]}...)")

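`assign_chcs_to_hospitals` is a round-robin rule (`idx % n_hospitals`), so consecutive CHC IDs go to different hospitals and the 100 CHCs split 34/33/33. If CHC IDs were ordered geographically, contiguous blocks might be a more natural regional grouping. The sketch below compares the two; `assign_contiguous` is a hypothetical alternative, not part of the model.

```python
from typing import Dict, List


def assign_round_robin(chc_ids: List[str], n_hospitals: int) -> Dict[int, List[str]]:
    """Same rule as assign_chcs_to_hospitals: CHC at position i goes to hospital i % n_hospitals."""
    assignments = {h: [] for h in range(n_hospitals)}
    for idx, chc_id in enumerate(chc_ids):
        assignments[idx % n_hospitals].append(chc_id)
    return assignments


def assign_contiguous(chc_ids: List[str], n_hospitals: int) -> Dict[int, List[str]]:
    """Hypothetical alternative: split the ID list into contiguous, near-equal blocks."""
    base, extra = divmod(len(chc_ids), n_hospitals)
    assignments, start = {}, 0
    for h in range(n_hospitals):
        size = base + (1 if h < extra else 0)
        assignments[h] = chc_ids[start:start + size]
        start += size
    return assignments


chc_ids = [f'CHC_{i:03d}' for i in range(1, 101)]
rr = assign_round_robin(chc_ids, 3)
blocks = assign_contiguous(chc_ids, 3)
print("round-robin:", {h: (len(c), c[:2]) for h, c in rr.items()})
print("contiguous: ", {h: (len(c), c[:2]) for h, c in blocks.items()})
```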
## Part 3: Agent Classes

We define an agent class for each entity in the supply chain. Under Mesa 2.x, which this notebook targets, each agent is constructed with a `unique_id` and a reference to its `model`.

In [None]:
# Enumerations for clarity
class AgeGroup(Enum):
    CHILD = 'child'
    ADULT = 'adult'
    ELDERLY = 'elderly'

class MedicineStatus(Enum):
    IN_TRANSIT = 'in_transit'
    IN_STOCK = 'in_stock'
    DISPENSED = 'dispensed'
    EXPIRED = 'expired'

print("Enumerations defined")

In [None]:
class MedicineBatch:
    """
    Represents a batch of medicines with tracking for expiry and transit.
    Not a Mesa agent - just a data class for tracking inventory.
    """
    
    def __init__(self, medicine_type: str, quantity: int, manufacture_month: int,
                 shelf_life: int = 12):
        self.medicine_type = medicine_type
        self.quantity = quantity
        self.manufacture_month = manufacture_month
        self.expiry_month = manufacture_month + shelf_life
        self.status = MedicineStatus.IN_STOCK
        self.transit_months_remaining = 0
    
    def is_expired(self, current_month: int) -> bool:
        """Check if batch has expired."""
        return current_month >= self.expiry_month
    
    def months_until_expiry(self, current_month: int) -> int:
        """Get months remaining before expiry."""
        return max(0, self.expiry_month - current_month)
    
    def __repr__(self):
        return f"MedicineBatch({self.medicine_type}, qty={self.quantity}, exp_month={self.expiry_month})"

# Test
test_batch = MedicineBatch('Penicillins', 100, manufacture_month=1)
print(f"Test batch: {test_batch}")
print(f"  Expired at month 6? {test_batch.is_expired(6)}")
print(f"  Expired at month 15? {test_batch.is_expired(15)}")

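`MedicineBatch` only stores plain attributes, so it could equally be written with the `dataclass` decorator imported in the setup cell. The sketch below (`BatchRecord` is a hypothetical name) keeps the same rule, `expiry_month = manufacture_month + shelf_life`, with a batch counted as expired from `expiry_month` onward. It makes the boundary explicit: a batch made in month 1 with a 12-month shelf life is usable through month 12 and expired at month 13.

```python
from dataclasses import dataclass, field


@dataclass
class BatchRecord:
    """Hypothetical dataclass equivalent of MedicineBatch (expiry logic only)."""
    medicine_type: str
    quantity: int
    manufacture_month: int
    shelf_life: int = 12
    expiry_month: int = field(init=False)

    def __post_init__(self):
        # Same rule as MedicineBatch: expiry is manufacture month plus shelf life
        self.expiry_month = self.manufacture_month + self.shelf_life

    def is_expired(self, current_month: int) -> bool:
        return current_month >= self.expiry_month

    def months_until_expiry(self, current_month: int) -> int:
        return max(0, self.expiry_month - current_month)


batch = BatchRecord('Penicillins', 100, manufacture_month=1)
for month in (12, 13):
    print(month, batch.is_expired(month), batch.months_until_expiry(month))
```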
In [None]:
class LocationAgent(mesa.Agent):
    """
    Base class for all location agents in the supply chain.
    Handles inventory management, shipments, and orders.
    """
    
    def __init__(self, unique_id, model, capacity: int, location_type: str):
        super().__init__(unique_id, model)  # Mesa 2.x requires (unique_id, model)
        self.capacity = capacity
        self.location_type = location_type
        
        # Inventory: dict of medicine_type -> list of MedicineBatch
        self.inventory: Dict[str, List[MedicineBatch]] = defaultdict(list)
        
        # Pending orders: list of (order_month, medicine_type, quantity)
        self.pending_orders: List[Tuple[int, str, int]] = []
        
        # Incoming shipments: list of (arrival_month, MedicineBatch)
        self.incoming_shipments: List[Tuple[int, MedicineBatch]] = []
        
        # Metrics
        self.total_received = defaultdict(int)
        self.total_shipped = defaultdict(int)
        self.total_expired = defaultdict(int)
    
    def get_stock_level(self, medicine_type: str) -> int:
        """Get current stock of a medicine type."""
        return sum(batch.quantity for batch in self.inventory[medicine_type]
                   if batch.status == MedicineStatus.IN_STOCK)
    
    def get_total_stock(self) -> int:
        """Get total stock across all medicine types."""
        return sum(self.get_stock_level(med) for med in self.model.config['antibiotic_classes'])
    
    def receive_shipment(self, batch: MedicineBatch):
        """Receive a medicine batch into inventory."""
        batch.status = MedicineStatus.IN_STOCK
        self.inventory[batch.medicine_type].append(batch)
        self.total_received[batch.medicine_type] += batch.quantity
    
    def process_incoming_shipments(self, current_month: int):
        """Process shipments that have arrived."""
        arrived = []
        still_in_transit = []
        
        for arrival_month, batch in self.incoming_shipments:
            if current_month >= arrival_month:
                self.receive_shipment(batch)
                arrived.append(batch)
            else:
                still_in_transit.append((arrival_month, batch))
        
        self.incoming_shipments = still_in_transit
        return arrived
    
    def ship_to(self, destination: 'LocationAgent', medicine_type: str, 
                quantity: int, transit_time: int, current_month: int) -> int:
        """
        Ship medicines to another location.
        Uses FEFO (First Expiry, First Out) to minimize waste.
        Returns actual quantity shipped.
        """
        # Sort batches by expiry (earliest first - FEFO)
        available_batches = sorted(
            [b for b in self.inventory[medicine_type] 
             if b.status == MedicineStatus.IN_STOCK and b.quantity > 0],
            key=lambda b: b.expiry_month
        )
        
        shipped = 0
        for batch in available_batches:
            if shipped >= quantity:
                break
            
            take = min(batch.quantity, quantity - shipped)
            batch.quantity -= take
            shipped += take
            
            # Create new batch for shipment
            shipped_batch = MedicineBatch(
                medicine_type=medicine_type,
                quantity=take,
                manufacture_month=batch.manufacture_month,
                shelf_life=batch.expiry_month - batch.manufacture_month
            )
            shipped_batch.status = MedicineStatus.IN_TRANSIT
            
            # Add to destination's incoming shipments
            arrival_month = current_month + transit_time
            destination.incoming_shipments.append((arrival_month, shipped_batch))
        
        # Clean up empty batches
        self.inventory[medicine_type] = [b for b in self.inventory[medicine_type] 
                                         if b.quantity > 0]
        
        self.total_shipped[medicine_type] += shipped
        return shipped
    
    def process_expiry(self, current_month: int) -> Dict[str, int]:
        """Remove expired medicines. Returns dict of expired quantities."""
        expired = defaultdict(int)
        
        for med_type in list(self.inventory.keys()):
            valid_batches = []
            for batch in self.inventory[med_type]:
                if batch.is_expired(current_month):
                    expired[med_type] += batch.quantity
                    batch.status = MedicineStatus.EXPIRED
                    self.total_expired[med_type] += batch.quantity
                else:
                    valid_batches.append(batch)
            self.inventory[med_type] = valid_batches
        
        return dict(expired)

print("LocationAgent base class defined")

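`ship_to` (and later `_consume_stock`) applies FEFO: batches are drawn in order of earliest expiry, and one request can be filled from several batches. Stripped of the Mesa machinery, the rule can be sketched as below; `fefo_take` and the `(expiry_month, quantity)` tuples are a hypothetical simplification of the `MedicineBatch` lists.

```python
from typing import List, Tuple

Batch = Tuple[int, int]  # (expiry_month, quantity)


def fefo_take(batches: List[Batch], quantity: int) -> Tuple[List[Batch], List[Batch]]:
    """Draw up to `quantity` units, earliest expiry first.

    Returns (taken, remaining): the pieces removed and the leftover stock.
    If stock is insufficient, `taken` simply sums to less than `quantity`.
    """
    taken, remaining = [], []
    still_needed = quantity
    for expiry, qty in sorted(batches):
        take = min(qty, still_needed)
        if take > 0:
            taken.append((expiry, take))
            still_needed -= take
        if qty - take > 0:
            remaining.append((expiry, qty - take))
    return taken, remaining


stock = [(20, 50), (14, 30), (17, 40)]
taken, left = fefo_take(stock, 60)
print("taken:", taken)  # month-14 batch emptied first, then 30 units from the month-17 batch
print("left:", left)
```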
In [None]:
class ManufacturerAgent(LocationAgent):
    """
    Manufacturer agent that produces antibiotics.
    Can be affected by disruptions (failure scenarios).
    """
    
    def __init__(self, unique_id, model, capacity: int):
        super().__init__(unique_id, model, capacity, 'manufacturer')
        self.monthly_production_capacity = capacity
        self.operational = True
        self.recovery_month = None  # Month when operations resume after failure
    
    def produce(self, current_month: int) -> Dict[str, int]:
        """
        Produce medicines for the month.
        Production is split across antibiotic classes based on historical demand ratios.
        """
        if not self.operational:
            if self.recovery_month is not None and current_month >= self.recovery_month:
                self.operational = True
                print(f"  Manufacturer {self.unique_id} resumed operations at month {current_month}")
            else:
                return {}
        
        # Split production: 60% Penicillins, 30% Macrolides, 10% Fluoroquinolones
        production_split = {
            'Penicillins': 0.60,
            'Macrolides': 0.30,
            'Fluoroquinolones': 0.10
        }
        
        produced = {}
        for med_type, pct in production_split.items():
            qty = int(self.monthly_production_capacity * pct)
            batch = MedicineBatch(
                medicine_type=med_type,
                quantity=qty,
                manufacture_month=current_month,
                shelf_life=self.model.config['medicine_shelf_life']
            )
            self.inventory[med_type].append(batch)
            produced[med_type] = qty
        
        return produced
    
    def fail(self, current_month: int, recovery_months: int):
        """Simulate manufacturer failure."""
        self.operational = False
        self.recovery_month = current_month + recovery_months
        print(f"  Manufacturer {self.unique_id} FAILED at month {current_month}, recovery at month {self.recovery_month}")

print("ManufacturerAgent defined")

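The fixed 60/30/10 split can be compared with each class's share of base demand from `class_params` (15, 8 and 4 per 1,000 people). The rough expected-value check below also sets total manufacturing capacity against expected base demand. It assumes an average CHC population of 27,500 (the mean of the uniform 5,000 to 50,000 draw), ignores seasonality and trend, and treats one forecast unit as one dispensed unit, as `process_demand` does. Under these assumptions Fluoroquinolone output at full capacity (10,000 units per month) falls short of expected base demand (about 11,000), while Penicillins have a comfortable margin.

```python
# Values copied from class_params, CONFIG and ManufacturerAgent.produce
base_per_1000 = {'Penicillins': 15, 'Macrolides': 8, 'Fluoroquinolones': 4}
production_split = {'Penicillins': 0.60, 'Macrolides': 0.30, 'Fluoroquinolones': 0.10}
n_manufacturers, capacity_each = 2, 50_000
n_chcs, mean_population = 100, 27_500  # assumption: mean of the uniform population draw

total_base = sum(base_per_1000.values())
demand_share = {abx: v / total_base for abx, v in base_per_1000.items()}
monthly_supply = {abx: pct * n_manufacturers * capacity_each for abx, pct in production_split.items()}
monthly_demand = {abx: v * n_chcs * mean_population / 1000 for abx, v in base_per_1000.items()}

for abx in base_per_1000:
    print(f"{abx:17s} demand share {demand_share[abx]:.3f} vs production {production_split[abx]:.2f} | "
          f"supply {monthly_supply[abx]:>8,.0f} vs expected base demand {monthly_demand[abx]:>8,.0f}")
```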
In [None]:
class CentralMedicalStoreAgent(LocationAgent):
    """
    Central Medical Store (EPSA) - national distribution hub.
    Receives from manufacturers, distributes to hospitals.
    """
    
    def __init__(self, unique_id, model, capacity: int):
        super().__init__(unique_id, model, capacity, 'central_store')
    
    def distribute_to_hospitals(self, hospitals: List['HospitalAgent'], 
                                 current_month: int, transit_time: int):
        """
        Distribute medicines to hospitals based on their needs.
        Uses proportional allocation based on downstream CHC count.
        """
        for med_type in self.model.config['antibiotic_classes']:
            available = self.get_stock_level(med_type)
            if available == 0:
                continue
            
            # Calculate shares based on number of CHCs served
            total_chcs = sum(len(h.served_chcs) for h in hospitals)
            
            for hospital in hospitals:
                share = len(hospital.served_chcs) / total_chcs if total_chcs > 0 else 0
                allocation = int(available * share * 0.8)  # Ship 80% of share
                
                if allocation > 0:
                    self.ship_to(hospital, med_type, allocation, transit_time, current_month)

print("CentralMedicalStoreAgent defined")

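To get a concrete feel for the allocation rule in `distribute_to_hospitals`, take 100 CHCs split 34/33/33 and an illustrative 10,000 units in stock. Each hospital receives `int(available * share * 0.8)`. Because `available` is read once before the loop, the store ships roughly 80% of its stock and keeps the rest as a buffer, and `int()` rounds each allocation down. `proportional_allocation` is a hypothetical helper mirroring that rule; `distribute_to_chcs` applies the same rule with forecast weights and a 0.9 fraction.

```python
from typing import List


def proportional_allocation(available: int, weights: List[float], ship_fraction: float) -> List[int]:
    """Mirror of the allocation rule: int(available * weight share * ship_fraction)."""
    total = sum(weights)
    if total == 0:
        return [0] * len(weights)
    return [int(available * (w / total) * ship_fraction) for w in weights]


allocations = proportional_allocation(10_000, [34, 33, 33], ship_fraction=0.8)
print(allocations, "shipped:", sum(allocations), "retained:", 10_000 - sum(allocations))
```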
In [None]:
class HospitalAgent(LocationAgent):
    """
    Regional Hospital - intermediate distribution point.
    Receives from central store, distributes to CHCs.
    """
    
    def __init__(self, unique_id, model, capacity: int, served_chc_ids: List[str]):
        super().__init__(unique_id, model, capacity, 'hospital')
        self.served_chc_ids = served_chc_ids
        self.served_chcs: List['CommunityHealthCenterAgent'] = []  # Set during model init
    
    def distribute_to_chcs(self, current_month: int, transit_time: int):
        """
        Distribute medicines to CHCs based on their forecasted demand.
        """
        for med_type in self.model.config['antibiotic_classes']:
            available = self.get_stock_level(med_type)
            if available == 0:
                continue
            
            # Forecast each CHC's demand once, then allocate in proportion to it
            chc_forecasts = [
                (chc, chc.get_forecast(current_month, med_type)['expected_demand'])
                for chc in self.served_chcs
            ]
            total_demand = sum(expected for _, expected in chc_forecasts)
            
            if total_demand == 0:
                continue
            
            for chc, expected in chc_forecasts:
                share = expected / total_demand
                allocation = int(available * share * 0.9)  # Ship 90% of share
                
                if allocation > 0:
                    self.ship_to(chc, med_type, allocation, transit_time, current_month)

print("HospitalAgent defined")

In [None]:
class CommunityHealthCenterAgent(LocationAgent):
    """
    Community Health Center - point of care delivery.
    
    Demand is derived directly from ML predictions (via get_forecast()),
    split by age group demographics. This ensures the actual demand
    matches what the ML models predict, testing supply chain resilience.
    """
    
    def __init__(self, unique_id, model, capacity: int, population_served: int):
        super().__init__(unique_id, model, capacity, 'chc')
        self.population_served = population_served
        self.chc_id = unique_id  # For compatibility with forecast function
        
        # Health worker assigned to this CHC
        self.health_worker: Optional['HealthWorkerAgent'] = None
        
        # Metrics
        self.patients_treated = defaultdict(int)  # by age group
        self.patients_untreated = defaultdict(int)  # by age group (shortage)
        self.patients_missed = defaultdict(int)  # by age group (worker absent)
        self.deaths = defaultdict(int)  # by age group
        self.shortages = defaultdict(int)  # by medicine type
    
    def get_forecast(self, month: int, medicine_type: str) -> dict:
        """
        Get demand forecast for this CHC.
        Uses the ML forecast generation function.
        """
        year = (month - 1) // 12 + 1
        month_of_year = ((month - 1) % 12) + 1
        
        # Check for outbreak multiplier in scenario
        outbreak_mult = 1.0
        if hasattr(self.model, 'outbreak_chcs') and self.unique_id in self.model.outbreak_chcs:
            outbreak_mult = self.model.outbreak_multiplier
        
        return generate_monthly_forecast(
            self.unique_id, month_of_year, year, medicine_type, outbreak_mult
        )
    
    def get_actual_demand(self, current_month: int) -> Dict[str, Dict[str, int]]:
        """
        Get actual demand for the month, derived from ML predictions.
        
        The ML predictions represent the total expected antibiotic demand.
        We split this by age group using demographic proportions to maintain
        age-specific outcome tracking (deaths affect age groups differently).
        
        Returns:
            Dict of {antibiotic_class: {age_group: count}}
        """
        demands = {}
        config = self.model.config
        
        for abx_class in config['antibiotic_classes']:
            forecast = self.get_forecast(current_month, abx_class)
            total_demand = forecast['expected_demand']
            
            # Split by age group using demographics
            # This preserves age-specific outcomes while using ML-predicted total
            demands[abx_class] = {}
            for age_group, pct in config['demographics'].items():
                demands[abx_class][age_group] = int(total_demand * pct)
        
        return demands
    
    def process_demand(self, current_month: int) -> Dict:
        """
        Process actual demand for the month.
        
        Demand comes directly from ML predictions (via get_actual_demand).
        Stock is consumed to meet demand, with shortages causing untreated
        cases and potential deaths based on age-specific mortality rates.
        
        Returns:
            Dict with treated/untreated/deaths/missed by age group
        """
        results = {
            'treated': defaultdict(int),
            'untreated': defaultdict(int),
            'deaths': defaultdict(int),
            'missed': defaultdict(int)  # health worker absent
        }
        
        config = self.model.config
        
        # Check health worker attendance
        worker_present = self.health_worker is None or self.health_worker.is_present()
        
        # Get actual demand from ML predictions
        demands = self.get_actual_demand(current_month)
        
        for abx_class, age_demands in demands.items():
            available = self.get_stock_level(abx_class)
            
            for age_group, demand in age_demands.items():
                if not worker_present:
                    # Health worker absent - all demand missed
                    results['missed'][age_group] += demand
                    self.patients_missed[age_group] += demand
                    continue
                
                # Apply private sector diversion if applicable
                if hasattr(self.model, 'private_sector_diversion') and self.model.private_sector_diversion > 0:
                    diverted = int(demand * self.model.private_sector_diversion)
                    demand = demand - diverted
                
                # Check for AMR scenario (resistance requiring antibiotic switch)
                if abx_class == 'Penicillins' and hasattr(self.model, 'amr_active') and self.model.amr_active:
                    # Some patients need Macrolides instead due to resistance
                    resistant_cases = int(demand * self.model.amr_resistance_rate)
                    if resistant_cases > 0:
                        # Try to serve resistant cases with Macrolides
                        macro_available = self.get_stock_level('Macrolides')
                        macro_served = min(resistant_cases, macro_available)
                        macro_unserved = resistant_cases - macro_served
                        
                        # Update Macrolides inventory
                        self._consume_stock('Macrolides', macro_served)
                        results['treated'][age_group] += macro_served
                        self.patients_treated[age_group] += macro_served
                        
                        # Unserved resistant cases
                        if macro_unserved > 0:
                            results['untreated'][age_group] += macro_unserved
                            self.patients_untreated[age_group] += macro_unserved
                            self.shortages['Macrolides'] += macro_unserved
                            
                            # Deaths from untreated
                            death_rate = config['death_rates'][age_group]
                            deaths = int(macro_unserved * death_rate)
                            results['deaths'][age_group] += deaths
                            self.deaths[age_group] += deaths
                        
                        # Reduce Penicillin demand by resistant cases
                        demand = demand - resistant_cases
                
                # Serve remaining demand from primary antibiotic class
                served = min(demand, available)
                unmet = demand - served
                available -= served
                
                results['treated'][age_group] += served
                self.patients_treated[age_group] += served
                
                if unmet > 0:
                    results['untreated'][age_group] += unmet
                    self.patients_untreated[age_group] += unmet
                    self.shortages[abx_class] += unmet
                    
                    # Deaths from untreated cases
                    death_rate = config['death_rates'][age_group]
                    deaths = int(unmet * death_rate)
                    results['deaths'][age_group] += deaths
                    self.deaths[age_group] += deaths
                
                # Consume stock
                self._consume_stock(abx_class, served)
        
        return results
    
    def _consume_stock(self, medicine_type: str, quantity: int):
        """
        Consume stock using FEFO (First Expiry, First Out).
        """
        remaining = quantity
        for batch in sorted(self.inventory[medicine_type], key=lambda b: b.expiry_month):
            if remaining <= 0:
                break
            if batch.quantity > 0:
                take = min(batch.quantity, remaining)
                batch.quantity -= take
                remaining -= take
        
        # Clean up empty batches
        self.inventory[medicine_type] = [b for b in self.inventory[medicine_type] if b.quantity > 0]

print("CommunityHealthCenterAgent defined")

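`process_demand` and `get_actual_demand` convert expected values to whole numbers with `int()`, which always rounds down. Applied per CHC, age group and antibiotic, this can erase small quantities entirely: 19 untreated children at a 5% death rate give `int(0.95) == 0` deaths. The sketch below compares floored deaths with the expected count over some made-up shortfalls, and shows binomial sampling with `numpy.random.Generator.binomial` as one possible alternative. Switching to sampling would be a modelling change, not something the notebook currently does.

```python
import numpy as np

death_rate_child = 0.05
unmet_cases = np.array([19, 7, 12, 3, 25, 15])  # illustrative shortfalls, not model output

floored = sum(int(u * death_rate_child) for u in unmet_cases)  # rule used in process_demand
expected = float(unmet_cases.sum() * death_rate_child)

rng = np.random.default_rng(0)
sampled = rng.binomial(unmet_cases, death_rate_child)  # one stochastic draw per cell

print(f"floored: {floored}, expected: {expected:.2f}, one binomial draw: {sampled.sum()}")

# The same floor applies when get_actual_demand splits a forecast total by age group
total = 7
split = {g: int(total * p) for g, p in {'child': 0.35, 'adult': 0.55, 'elderly': 0.10}.items()}
print(split, "sum:", sum(split.values()), "of", total)
```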
In [None]:
class HealthWorkerAgent(mesa.Agent):
    """
    Health worker assigned to a CHC.
    Models absenteeism affecting service delivery.
    """
    
    def __init__(self, unique_id, model, assigned_chc: CommunityHealthCenterAgent,
                 absenteeism_rate: float = 0.10):
        super().__init__(unique_id, model)  # Mesa 2.x requires (unique_id, model)
        self.assigned_chc = assigned_chc
        self.absenteeism_rate = absenteeism_rate
        self._present_today = True
        
        # Link back to CHC
        assigned_chc.health_worker = self
    
    def determine_attendance(self):
        """Randomly determine if worker is present this month."""
        self._present_today = np.random.random() > self.absenteeism_rate
    
    def is_present(self) -> bool:
        return self._present_today

print("HealthWorkerAgent defined")

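Attendance is drawn once per model step, so `absenteeism_rate = 0.10` means each CHC loses all of its demand in roughly one month out of ten. Over a 60-month run the expected number of absent months per CHC is 6, and the chance that a given CHC never has an absent month is 0.9^60, about 0.2%. A quick Monte Carlo check, using a seeded `numpy` generator rather than the global `np.random` state the agents use:

```python
import numpy as np

absenteeism_rate, n_months, n_chcs = 0.10, 60, 100

# Same Bernoulli rule as HealthWorkerAgent.determine_attendance: present if U > rate
rng = np.random.default_rng(42)
present = rng.random((n_chcs, n_months)) > absenteeism_rate
absent_months = (~present).sum(axis=1)

print(f"mean absent months per CHC: {absent_months.mean():.2f} (expected {absenteeism_rate * n_months:.1f})")
print(f"P(no absent month), analytic: {(1 - absenteeism_rate) ** n_months:.4f}")
print(f"CHCs with zero absent months in this draw: {(absent_months == 0).sum()}")
```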
In [None]:
# NOTE: PatientAgent is no longer used in the current implementation.
# The model now processes demand in aggregate (via CommunityHealthCenterAgent.process_demand)
# rather than creating individual patient objects. This is more efficient and conceptually
# cleaner since ML predictions represent aggregate demand, not individual arrivals.
#
# This class is retained for reference and potential future extensions that might
# need individual patient tracking (e.g., disease progression, referral chains).

class PatientAgent:
    """
    [LEGACY - Not actively used]
    
    Patient seeking antibiotic treatment.
    Not a full Mesa agent - created and processed within a step.
    
    The current model uses aggregate demand processing instead of individual
    patient objects. See CommunityHealthCenterAgent.process_demand().
    """
    
    def __init__(self, model, age_group: AgeGroup, assigned_chc: CommunityHealthCenterAgent):
        self.model = model
        self.age_group = age_group
        self.assigned_chc = assigned_chc
        self.treated = False
        self.died = False
        self.missed_opportunity = False  # Health worker absent

print("PatientAgent defined (legacy - not actively used)")

## Part 4: The Supply Chain Model

The main Mesa model class coordinates all agents and runs the simulation.

In [None]:
class EthiopiaSupplyChainModel(mesa.Model):
    """
    Agent-based model of Ethiopia's antibiotic supply chain.
    
    Demand is driven by ML predictions from Day 2 - the forecasts ARE the actual
    demand that the supply chain must serve. The model tests supply chain 
    resilience against this predicted demand under various scenarios.
    """
    
    def __init__(self, config: dict = None, scenario: str = 'base', verbose: bool = True):
        super().__init__()
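        # Copy so scenario tweaks to top-level keys never mutate the caller's dict
        self.config = dict(config) if config is not None else CONFIG.copy()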
        self.scenario = scenario
        self.verbose = verbose
        self.current_month = 0
        
        # Apply scenario-specific modifications
        self._apply_scenario(scenario)
        
        # Create agents
        self._create_agents()
        
        # Initialize data collection
        self._setup_data_collection()
        
        if self.verbose:
            print(f"Model initialized: scenario='{scenario}'")
            print(f"  Manufacturers: {len(self.manufacturers)}")
            print(f"  Central Stores: {len(self.central_stores)}")
            print(f"  Hospitals: {len(self.hospitals)}")
            print(f"  CHCs: {len(self.chcs)}")
            print(f"  Health Workers: {len(self.health_workers)}")
    
    def _apply_scenario(self, scenario: str):
        """Apply scenario-specific configuration changes."""
        self.outbreak_chcs = set()
        self.outbreak_multiplier = 1.0
        self.amr_active = False
        self.amr_resistance_rate = 0.0
        self.private_sector_diversion = 0.0
        self.manufacturer_failure_month = None
        self.manufacturer_failure_id = None
        self.manufacturer_recovery_months = 12
        
        if scenario == 'base':
            pass  # Default configuration
        
        elif scenario == 'weather_delays':
            self.config['transit_time'] = 4  # Double transit time
        
        elif scenario == 'disease_outbreak':
            # 25 random CHCs with 3x demand
            all_chcs = [f'CHC_{i:03d}' for i in range(1, self.config['n_chcs'] + 1)]
            self.outbreak_chcs = set(np.random.choice(all_chcs, size=25, replace=False))
            self.outbreak_multiplier = 3.0
        
        elif scenario == 'advance_ordering':
            self.config['order_lead_time'] = 4  # Order 4 months ahead
        
        elif scenario == 'manufacturer_failure':
            self.manufacturer_failure_month = 12
            self.manufacturer_failure_id = 0  # First manufacturer
            self.manufacturer_recovery_months = 12
        
        elif scenario == 'optimization_challenge':
            # Students implement their own optimizations
            self.config['order_lead_time'] = 3
            self.config['transit_time'] = 2
        
        elif scenario == 'amr_substitution':
            self.amr_active = False  # Starts inactive
            self.amr_resistance_rate = 0.30  # 30% Penicillin resistance
            self.amr_start_month = 24  # Resistance emerges month 24
        
        elif scenario == 'private_sector':
            self.private_sector_diversion = 0.25  # 25% go to private sector
        
        else:
            raise ValueError(f"Unknown scenario: {scenario!r}")
    
    def _create_agents(self):
        """Create all agents in the supply chain."""
        # Manufacturers
        self.manufacturers = []
        for i in range(self.config['n_manufacturers']):
            mfr = ManufacturerAgent(
                f'MFR_{i}', self, self.config['manufacturer_capacity']
            )
            self.manufacturers.append(mfr)
        
        # Central Medical Stores
        self.central_stores = []
        for i in range(self.config['n_central_stores']):
            cms = CentralMedicalStoreAgent(
                f'CMS_{i}', self, self.config['central_store_capacity']
            )
            self.central_stores.append(cms)
        
        # Hospitals
        self.hospitals = []
        for i in range(self.config['n_hospitals']):
            served_chcs = CHC_HOSPITAL_ASSIGNMENTS[i]
            hosp = HospitalAgent(
                f'HOSP_{i}', self, self.config['hospital_capacity'], served_chcs
            )
            self.hospitals.append(hosp)
        
        # Community Health Centers
        self.chcs = []
        self.chc_lookup = {}  # For quick lookup by ID
        for i in range(1, self.config['n_chcs'] + 1):
            chc_id = f'CHC_{i:03d}'
            chc = CommunityHealthCenterAgent(
                chc_id, self, self.config['chc_capacity'], 
                facility_populations[chc_id]
            )
            self.chcs.append(chc)
            self.chc_lookup[chc_id] = chc
        
        # Link hospitals to their CHCs
        for hosp in self.hospitals:
            hosp.served_chcs = [self.chc_lookup[chc_id] for chc_id in hosp.served_chc_ids]
        
        # Health Workers (one per CHC)
        self.health_workers = []
        for i, chc in enumerate(self.chcs):
            hw = HealthWorkerAgent(
                f'HW_{i}', self, chc, self.config['health_worker_absenteeism']
            )
            self.health_workers.append(hw)
        
        # Initialize inventory with initial stock
        self._initialize_inventory()
    
    def _initialize_inventory(self):
        """Set initial stock levels across the supply chain."""
        initial_pct = self.config['initial_stock_pct']
        
        for agent_list in [self.manufacturers, self.central_stores, self.hospitals, self.chcs]:
            for agent in agent_list:
                for med_type in self.config['antibiotic_classes']:
                    # Split initial stock evenly across the antibiotic classes
                    initial_qty = int(agent.capacity * initial_pct / len(self.config['antibiotic_classes']))
                    batch = MedicineBatch(
                        medicine_type=med_type,
                        quantity=initial_qty,
                        manufacture_month=0,
                        shelf_life=self.config['medicine_shelf_life']
                    )
                    agent.inventory[med_type].append(batch)
    
    def _setup_data_collection(self):
        """Initialize data collectors for metrics."""
        self.metrics_history = []
    
    def step(self):
        """Execute one simulation step (one month)."""
        self.current_month += 1
        
        if self.verbose and self.current_month % 12 == 1:
            print(f"Year {(self.current_month - 1) // 12 + 1} starting...")
        
        # Check for manufacturer failure event
        if (self.manufacturer_failure_month and 
            self.current_month == self.manufacturer_failure_month):
            self.manufacturers[self.manufacturer_failure_id].fail(
                self.current_month, self.manufacturer_recovery_months
            )
        
        # Check for AMR emergence
        if hasattr(self, 'amr_start_month') and self.current_month >= self.amr_start_month:
            self.amr_active = True
        
        # Monthly metrics
        month_metrics = {
            'month': self.current_month,
            'shortages': defaultdict(int),
            'wastage': defaultdict(int),
            'deaths': defaultdict(int),
            'missed_opportunities': defaultdict(int),
            'patients_treated': defaultdict(int),
            'patients_total': defaultdict(int)
        }
        
        # 1. Manufacturer production
        for mfr in self.manufacturers:
            mfr.produce(self.current_month)
        
        # 2. Process shipments arriving this month
        for agent_list in [self.central_stores, self.hospitals, self.chcs]:
            for agent in agent_list:
                agent.process_incoming_shipments(self.current_month)
        
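        # 3. Manufacturers ship all available stock to central stores (fixed 1-month transit).
        #    With more than one central store, the first store in the list would receive
        #    everything, since `available` is re-read after each shipment.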
        for mfr in self.manufacturers:
            for cms in self.central_stores:
                for med_type in self.config['antibiotic_classes']:
                    available = mfr.get_stock_level(med_type)
                    if available > 0:
                        mfr.ship_to(cms, med_type, available, 1, self.current_month)  # 1 month transit
        
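        # Split the total transit_time across the two distribution legs below.
        # The second leg takes the remainder so odd totals are not truncated,
        # and each leg is at least 1 month so shipments arrive in a later step.
        cms_leg = max(1, self.config['transit_time'] // 2)
        chc_leg = max(1, self.config['transit_time'] - cms_leg)
        
        # 4. Central stores distribute to hospitals
        for cms in self.central_stores:
            cms.distribute_to_hospitals(self.hospitals, self.current_month, cms_leg)
        
        # 5. Hospitals distribute to CHCs
        for hosp in self.hospitals:
            hosp.distribute_to_chcs(self.current_month, chc_leg)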
        
        # 6. Determine health worker attendance
        for hw in self.health_workers:
            hw.determine_attendance()
        
        # 7. Process demand at each CHC
        # Demand = ML predictions. The model tests supply chain response to this demand.
        for chc in self.chcs:
            results = chc.process_demand(self.current_month)
            
            # Aggregate metrics from CHC results
            for age in ['child', 'adult', 'elderly']:
                month_metrics['patients_treated'][age] += results['treated'][age]
                month_metrics['patients_total'][age] += (
                    results['treated'][age] + 
                    results['untreated'][age] + 
                    results['missed'][age]
                )
                month_metrics['deaths'][age] += results['deaths'][age]
                month_metrics['missed_opportunities'][age] += results['missed'][age]
            
            # Collect shortage data by medicine type
            for med_type, count in chc.shortages.items():
                month_metrics['shortages'][med_type] += count
            
            # Reset CHC shortages for next month
            chc.shortages = defaultdict(int)
        
        # 8. Process medicine expiry across all locations
        for agent_list in [self.manufacturers, self.central_stores, self.hospitals, self.chcs]:
            for agent in agent_list:
                expired = agent.process_expiry(self.current_month)
                for med_type, qty in expired.items():
                    month_metrics['wastage'][med_type] += qty
        
        # Convert defaultdicts to regular dicts and store
        month_metrics['shortages'] = dict(month_metrics['shortages'])
        month_metrics['wastage'] = dict(month_metrics['wastage'])
        month_metrics['deaths'] = dict(month_metrics['deaths'])
        month_metrics['missed_opportunities'] = dict(month_metrics['missed_opportunities'])
        month_metrics['patients_treated'] = dict(month_metrics['patients_treated'])
        month_metrics['patients_total'] = dict(month_metrics['patients_total'])
        
        self.metrics_history.append(month_metrics)
    
    def run(self, n_months: int = None):
        """Run the simulation for specified months."""
        if n_months is None:
            n_months = self.config['n_months']
        
        if self.verbose:
            print(f"\nRunning simulation for {n_months} months...")
        
        for _ in range(n_months):
            self.step()
        
        if self.verbose:
            print(f"Simulation complete! Final month: {self.current_month}")
    
    def get_results_df(self) -> pd.DataFrame:
        """Convert metrics history to a DataFrame."""
        records = []
        for m in self.metrics_history:
            record = {'month': m['month']}
            
            # Flatten nested dicts
            for med_type in self.config['antibiotic_classes']:
                record[f'shortage_{med_type}'] = m['shortages'].get(med_type, 0)
                record[f'wastage_{med_type}'] = m['wastage'].get(med_type, 0)
            
            for age_group in ['child', 'adult', 'elderly']:
                record[f'deaths_{age_group}'] = m['deaths'].get(age_group, 0)
                record[f'missed_{age_group}'] = m['missed_opportunities'].get(age_group, 0)
                record[f'treated_{age_group}'] = m['patients_treated'].get(age_group, 0)
                record[f'total_{age_group}'] = m['patients_total'].get(age_group, 0)
            
            records.append(record)
        
        return pd.DataFrame(records)

print("EthiopiaSupplyChainModel defined")
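`step()` spreads the total `transit_time` over two distribution legs (central store → hospital, then hospital → CHC); the manufacturer-to-central-store leg is a separate fixed one-month delay. How the total is divided matters for odd values. The standalone arithmetic below compares giving both legs `transit_time // 2` with a split in which the second leg takes the remainder. Both helper names are hypothetical, and the one-month floor per leg in `split_with_remainder` is an assumption meant to keep a shipment from falling due in the same step it is sent.

```python
def per_leg_halves(total: int) -> tuple:
    """Both legs get total // 2, which drops a month when total is odd."""
    leg = total // 2
    return leg, leg

def split_with_remainder(total: int) -> tuple:
    """First leg gets total // 2, second gets the rest; each is at least 1 month."""
    first = max(1, total // 2)
    second = max(1, total - first)
    return first, second

for total in range(1, 7):
    a, b = per_leg_halves(total)
    c, d = split_with_remainder(total)
    print(f"transit_time={total}: halves={a}+{b}={a + b}, remainder split={c}+{d}={c + d}")
```

Even totals (the 2 and 4 months used by the base and weather scenarios) come out the same either way; odd totals such as 3 or 5 from the interactive slider lose a month under equal halves, and a total of 1 becomes two zero-month legs.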

## Part 5: Unit Tests

Before running full simulations, let's verify our model components work correctly.

In [None]:
def test_model_initialization():
    """Test that model initializes with correct agent counts."""
    model = EthiopiaSupplyChainModel(verbose=False)
    
    assert len(model.manufacturers) == CONFIG['n_manufacturers'], "Wrong manufacturer count"
    assert len(model.central_stores) == CONFIG['n_central_stores'], "Wrong central store count"
    assert len(model.hospitals) == CONFIG['n_hospitals'], "Wrong hospital count"
    assert len(model.chcs) == CONFIG['n_chcs'], "Wrong CHC count"
    assert len(model.health_workers) == CONFIG['n_chcs'], "Wrong health worker count"
    
    print("✓ Model initialization test passed")

test_model_initialization()

In [None]:
def test_medicine_expiry():
    """Test that medicines expire after shelf life."""
    batch = MedicineBatch('Penicillins', 100, manufacture_month=1, shelf_life=12)
    
    assert not batch.is_expired(6), "Should not be expired at month 6"
    assert not batch.is_expired(12), "Should not be expired at month 12"
    assert batch.is_expired(13), "Should be expired at month 13"
    assert batch.is_expired(20), "Should be expired at month 20"
    
    print("✓ Medicine expiry test passed")

test_medicine_expiry()
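The assertions above pin down the expiry boundary: a batch made in month 1 with a 12-month shelf life is still usable in month 12 and expired from month 13, i.e. expired once `month >= manufacture_month + shelf_life`. The sketch below encodes that rule and shows one way a `process_expiry`-style method could drop expired batches and report wastage per medicine type. `Batch` and `remove_expired` are hypothetical stand-ins; the notebook's `MedicineBatch.is_expired` and `process_expiry` are defined earlier and may differ in detail.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Batch:
    medicine_type: str
    quantity: int
    manufacture_month: int
    shelf_life: int

    def is_expired(self, month: int) -> bool:
        # Same boundary as the test above: made in month 1, life 12 -> expired from 13
        return month >= self.manufacture_month + self.shelf_life

def remove_expired(inventory: Dict[str, List[Batch]], month: int) -> Dict[str, int]:
    """Drop expired batches in place and return units wasted per medicine type."""
    wasted: Dict[str, int] = defaultdict(int)
    for med_type, batches in inventory.items():
        keep = []
        for batch in batches:
            if batch.is_expired(month):
                wasted[med_type] += batch.quantity
            else:
                keep.append(batch)
        batches[:] = keep  # mutate the list so the caller's inventory is updated
    return dict(wasted)

inventory = {
    "Penicillins": [Batch("Penicillins", 100, 1, 12), Batch("Penicillins", 50, 6, 12)],
    "Macrolides": [Batch("Macrolides", 30, 2, 12)],
}
print(remove_expired(inventory, 13))
```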

In [None]:
def test_transit_timing():
    """Smoke test: the model advances one month per step and records metrics.

    This does not check individual shipment arrival months; it only confirms
    that three steps run without errors and produce three monthly records.
    """
    model = EthiopiaSupplyChainModel(verbose=False)
    
    # Run a few steps
    for _ in range(3):
        model.step()
    
    assert model.current_month == 3, "Model should be at month 3"
    assert len(model.metrics_history) == 3, "Expected one metrics record per month"
    
    print("✓ Transit timing (smoke) test passed")

test_transit_timing()
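`test_transit_timing` only confirms that the model advances month by month; checking the arrival month itself needs access to individual shipments. The sketch below shows what such a check could look like on a minimal, hypothetical shipment queue (`Shipment`, `ShipmentQueue`, `send`, and `receive` are illustrative names, not the agents' API): a shipment sent in month 1 with a 2-month transit should not be received in month 2 and should be received in month 3.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Shipment:
    medicine_type: str
    quantity: int
    arrival_month: int

@dataclass
class ShipmentQueue:
    pending: List[Shipment] = field(default_factory=list)
    received: int = 0  # total units received so far

    def send(self, medicine_type: str, quantity: int, month: int, transit: int) -> None:
        self.pending.append(Shipment(medicine_type, quantity, month + transit))

    def receive(self, month: int) -> int:
        """Move every shipment due by `month` out of the queue; return units received."""
        due = [s for s in self.pending if s.arrival_month <= month]
        self.pending = [s for s in self.pending if s.arrival_month > month]
        units = sum(s.quantity for s in due)
        self.received += units
        return units

def test_shipment_arrives_after_transit():
    q = ShipmentQueue()
    q.send("Penicillins", 500, month=1, transit=2)
    assert q.receive(1) == 0, "nothing arrives in the sending month"
    assert q.receive(2) == 0, "still in transit after one month"
    assert q.receive(3) == 500, "arrives exactly at month 1 + 2"
    assert q.pending == [], "queue is empty once delivered"
    print("✓ Shipment arrival test passed")

test_shipment_arrives_after_transit()
```

Using `<=` rather than `==` in `receive` is a deliberate choice in this sketch: a shipment whose arrival month is somehow skipped is still delivered later instead of being stranded in the queue.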

In [None]:
def test_demand_generation():
    """Test that demand generation uses ML predictions correctly."""
    model = EthiopiaSupplyChainModel(verbose=False)
    chc = model.chcs[0]
    
    # Get demand for month 6
    demand = chc.get_actual_demand(6)
    
    # Should have all 3 antibiotic classes
    assert len(demand) == 3, f"Should have 3 antibiotic classes, got {len(demand)}"
    
    # Each class should have 3 age groups
    for abx_class in demand:
        assert len(demand[abx_class]) == 3, f"Should have 3 age groups for {abx_class}"
        
        # Check age groups are present
        for age in ['child', 'adult', 'elderly']:
            assert age in demand[abx_class], f"Missing {age} in {abx_class}"
    
    # Demand should be positive and reasonable
    total_demand = sum(sum(ages.values()) for ages in demand.values())
    assert total_demand > 0, "Total demand should be positive"
    
    # Demand should reflect demographics (children = 35% of demand)
    penicillin_demand = demand['Penicillins']
    total_pen = sum(penicillin_demand.values())
    if total_pen > 0:
        child_pct = penicillin_demand['child'] / total_pen
        # Allow for rounding: should be close to 0.35
        assert 0.30 <= child_pct <= 0.40, f"Child pct should be ~0.35, got {child_pct:.2f}"
    
    print(f"  CHC {chc.unique_id} (pop={chc.population_served:,}), month 6:")
    for abx_class, ages in demand.items():
        print(f"    {abx_class}: {sum(ages.values())} total")
    print(f"  Total demand across all classes: {total_demand}")
    print("✓ Demand generation test passed")

test_demand_generation()
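The demand test lets the child share fall anywhere between 0.30 and 0.40 because splitting an integer total across age groups forces rounding. One way to split a count so the parts always add back up to the total is the largest-remainder method, sketched below. Only the 35% child share comes from the test above; the adult and elderly shares (50% and 15%) are placeholder assumptions, and the notebook's own demand code may round differently.

```python
import math
from typing import Dict

def split_by_shares(total: int, shares: Dict[str, float]) -> Dict[str, int]:
    """Split an integer total by shares (largest-remainder method); parts sum to total."""
    raw = {k: total * s for k, s in shares.items()}
    parts = {k: math.floor(v) for k, v in raw.items()}
    leftover = total - sum(parts.values())
    # Give the remaining units to the groups with the largest fractional parts
    by_remainder = sorted(raw, key=lambda k: raw[k] - parts[k], reverse=True)
    for k in by_remainder[:leftover]:
        parts[k] += 1
    return parts

SHARES = {"child": 0.35, "adult": 0.50, "elderly": 0.15}  # adult/elderly shares are assumed
for total in (7, 20, 101):
    parts = split_by_shares(total, SHARES)
    print(total, parts, f"child share={parts['child'] / total:.2f}")
```

Under this scheme a total of 7 gives 2 children (share about 0.29), just outside the test's band, so the 0.30 to 0.40 tolerance implicitly assumes a CHC's monthly Penicillin demand is not tiny.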

In [None]:
print("\n" + "="*50)
print("All unit tests passed!")
print("="*50)

## Part 6: Base Scenario Run

Now let's run the full 60-month simulation with default parameters.

In [None]:
# Run base scenario
np.random.seed(42)  # Reset seed for reproducibility

print("Initializing base scenario model...")
base_model = EthiopiaSupplyChainModel(scenario='base', verbose=True)

print("\nRunning simulation...")
base_model.run(60)

# Get results
base_results = base_model.get_results_df()
print(f"\nResults shape: {base_results.shape}")
base_results.head()

In [None]:
# Summary statistics for base scenario
print("="*60)
print("BASE SCENARIO SUMMARY (60 months)")
print("="*60)

# Shortages by medicine type
print("\nTotal Shortages by Medicine Type:")
for med in CONFIG['antibiotic_classes']:
    total = base_results[f'shortage_{med}'].sum()
    print(f"  {med}: {total:,}")

# Wastage by medicine type
print("\nTotal Wastage by Medicine Type:")
for med in CONFIG['antibiotic_classes']:
    total = base_results[f'wastage_{med}'].sum()
    print(f"  {med}: {total:,}")

# Deaths by age group
print("\nTotal Deaths by Age Group:")
for age in ['child', 'adult', 'elderly']:
    total = base_results[f'deaths_{age}'].sum()
    print(f"  {age.capitalize()}: {total:,}")

# Missed opportunities
print("\nTotal Missed Opportunities (HW absent):")
for age in ['child', 'adult', 'elderly']:
    total = base_results[f'missed_{age}'].sum()
    print(f"  {age.capitalize()}: {total:,}")

# Treatment rate
print("\nTreatment Rates:")
for age in ['child', 'adult', 'elderly']:
    treated = base_results[f'treated_{age}'].sum()
    total = base_results[f'total_{age}'].sum()
    rate = (treated / total * 100) if total > 0 else 0
    print(f"  {age.capitalize()}: {rate:.1f}% ({treated:,}/{total:,})")
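The summary above builds column names with f-strings in a loop. Because `get_results_df()` uses consistent prefixes (`shortage_`, `wastage_`, `deaths_`, `treated_`, `total_`), the same totals can also be taken with `DataFrame.filter`. The sketch runs on a tiny made-up frame with those column names so it stands alone; to use it on real output, replace `toy` with `base_results`.

```python
import pandas as pd

# Toy frame with the same column naming as get_results_df(); numbers are made up
toy = pd.DataFrame({
    "month": [1, 2],
    "shortage_Penicillins": [3, 1],
    "shortage_Macrolides": [0, 2],
    "treated_child": [90, 80],
    "total_child": [100, 100],
    "treated_adult": [45, 50],
    "total_adult": [50, 50],
})

# Totals for every shortage_ column, re-keyed by medicine type
shortage_totals = toy.filter(regex=r"^shortage_").sum()
shortage_totals.index = shortage_totals.index.str.replace("shortage_", "", regex=False)

# Treatment rate per age group: sum(treated_x) / sum(total_x), in percent
treated = toy.filter(regex=r"^treated_").sum().rename(lambda c: c.split("_", 1)[1])
total = toy.filter(regex=r"^total_").sum().rename(lambda c: c.split("_", 1)[1])
treatment_rate = (treated / total * 100).round(1)

print(shortage_totals.to_dict())
print(treatment_rate.to_dict())
```

One behavioural difference from the loop: an age group with zero total patients yields NaN here instead of the loop's 0%.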

## Part 7: Visualization Dashboard

Create a 4-panel dashboard showing key metrics over time.

In [None]:
def plot_dashboard(results_df: pd.DataFrame, title: str = "Supply Chain Dashboard"):
    """
    Create a 4-panel dashboard for simulation results.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    months = results_df['month']
    
    # Panel 1: Shortages by medicine type
    ax1 = axes[0, 0]
    for med in CONFIG['antibiotic_classes']:
        ax1.plot(months, results_df[f'shortage_{med}'], label=med, linewidth=1.5)
    ax1.set_xlabel('Month')
    ax1.set_ylabel('Shortage Events')
    ax1.set_title('Medicine Shortages by Type')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Panel 2: Wastage by medicine type
    ax2 = axes[0, 1]
    for med in CONFIG['antibiotic_classes']:
        ax2.plot(months, results_df[f'wastage_{med}'], label=med, linewidth=1.5)
    ax2.set_xlabel('Month')
    ax2.set_ylabel('Units Wasted (Expired)')
    ax2.set_title('Medicine Wastage by Type')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Panel 3: Deaths by age group
    ax3 = axes[1, 0]
    colors = {'child': 'blue', 'adult': 'green', 'elderly': 'red'}
    for age in ['child', 'adult', 'elderly']:
        ax3.plot(months, results_df[f'deaths_{age}'], label=age.capitalize(), 
                 color=colors[age], linewidth=1.5)
    ax3.set_xlabel('Month')
    ax3.set_ylabel('Deaths')
    ax3.set_title('Deaths by Age Group (Untreated)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # Panel 4: Missed opportunities by age group
    ax4 = axes[1, 1]
    for age in ['child', 'adult', 'elderly']:
        ax4.plot(months, results_df[f'missed_{age}'], label=age.capitalize(), 
                 color=colors[age], linewidth=1.5)
    ax4.set_xlabel('Month')
    ax4.set_ylabel('Patients Missed')
    ax4.set_title('Missed Opportunities (Health Worker Absent)')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

# Plot base scenario dashboard
plot_dashboard(base_results, "Base Scenario Dashboard (60 months)")

## Part 8: Scenario Comparisons

Let's run all 8 scenarios and compare outcomes.

In [None]:
# Define all scenarios
SCENARIOS = [
    ('base', 'Base Case'),
    ('weather_delays', 'Weather Delays (4mo transit)'),
    ('disease_outbreak', 'Disease Outbreak (25 CHCs, 3x demand)'),
    ('advance_ordering', 'Advance Ordering (4mo lead time)'),
    ('manufacturer_failure', 'Manufacturer Failure (month 12)'),
    ('optimization_challenge', 'Optimized Policy'),
    ('amr_substitution', 'AMR (30% Penicillin resistance)'),
    ('private_sector', 'Private Sector (25% diversion)')
]

print(f"Running {len(SCENARIOS)} scenarios...")

In [None]:
# Run all scenarios
scenario_results = {}

for scenario_id, scenario_name in SCENARIOS:
    print(f"\n{'='*50}")
    print(f"Running: {scenario_name}")
    print('='*50)
    
    np.random.seed(42)  # Reset seed for fair comparison
    
    model = EthiopiaSupplyChainModel(scenario=scenario_id, verbose=False)
    model.run(60)
    
    results = model.get_results_df()
    scenario_results[scenario_id] = {
        'name': scenario_name,
        'results': results,
        'model': model
    }
    
    # Print summary
    total_shortages = sum(results[f'shortage_{med}'].sum() for med in CONFIG['antibiotic_classes'])
    total_wastage = sum(results[f'wastage_{med}'].sum() for med in CONFIG['antibiotic_classes'])
    total_deaths = sum(results[f'deaths_{age}'].sum() for age in ['child', 'adult', 'elderly'])
    
    print(f"  Total shortages: {total_shortages:,}")
    print(f"  Total wastage: {total_wastage:,}")
    print(f"  Total deaths: {total_deaths:,}")

print("\n" + "="*50)
print("All scenarios complete!")
print("="*50)
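Each scenario above is run once with `np.random.seed(42)`, so a difference between two scenarios mixes the scenario effect with the randomness of a single run. A common way to separate the two is to repeat the comparison over several seeds, reusing the same seed for base and scenario (common random numbers), and report the mean and spread of the paired difference. The sketch uses a hypothetical `toy_run` in place of `EthiopiaSupplyChainModel` so it runs on its own; its demand, supply, and transit penalty are made-up assumptions.

```python
import numpy as np

def toy_run(seed: int, transit_time: int) -> int:
    """Hypothetical stand-in for a 60-month model run; returns total shortages."""
    rng = np.random.default_rng(seed)
    monthly_demand = rng.poisson(100, size=60)
    monthly_supply = rng.poisson(105, size=60)
    # Crude assumption: each extra month of transit removes 3 units of supply per month
    effective_supply = monthly_supply - 3 * (transit_time - 2)
    return int(np.clip(monthly_demand - effective_supply, 0, None).sum())

seeds = range(20)
# Same seed for both arms, so demand and supply draws are shared (common random numbers)
diffs = np.array([toy_run(s, transit_time=4) - toy_run(s, transit_time=2) for s in seeds])
mean = diffs.mean()
se = diffs.std(ddof=1) / np.sqrt(len(diffs))
print(f"Extra shortages (4mo vs 2mo): mean {mean:.1f}, "
      f"approx 95% CI {mean - 1.96 * se:.1f} to {mean + 1.96 * se:.1f}")
```

With the real model, the loop body would call `np.random.seed(s)` before constructing each `EthiopiaSupplyChainModel`, as the cell above does with 42; this assumes the agents draw their randomness from NumPy's global generator, which those seed resets suggest.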

In [None]:
# Create comparison summary table
comparison_data = []

for scenario_id, data in scenario_results.items():
    results = data['results']
    
    row = {
        'Scenario': data['name'],
        'Total Shortages': sum(results[f'shortage_{med}'].sum() for med in CONFIG['antibiotic_classes']),
        'Total Wastage': sum(results[f'wastage_{med}'].sum() for med in CONFIG['antibiotic_classes']),
        'Deaths (Child)': results['deaths_child'].sum(),
        'Deaths (Adult)': results['deaths_adult'].sum(),
        'Deaths (Elderly)': results['deaths_elderly'].sum(),
        'Total Deaths': sum(results[f'deaths_{age}'].sum() for age in ['child', 'adult', 'elderly']),
        'Missed (Total)': sum(results[f'missed_{age}'].sum() for age in ['child', 'adult', 'elderly'])
    }
    comparison_data.append(row)

comparison_df = pd.DataFrame(comparison_data)

# Display with thousands separators (numeric columns only)
print("\nSCENARIO COMPARISON TABLE")
print("="*100)
display(comparison_df.style.format('{:,}', subset=list(comparison_df.columns[1:])))

In [None]:
# Visualize scenario comparison
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Prepare data
scenarios = comparison_df['Scenario'].tolist()
x = np.arange(len(scenarios))

# Panel 1: Total Shortages
ax1 = axes[0, 0]
bars = ax1.barh(x, comparison_df['Total Shortages'], color='steelblue', alpha=0.7)
ax1.set_yticks(x)
ax1.set_yticklabels(scenarios, fontsize=9)
ax1.set_xlabel('Total Shortage Events')
ax1.set_title('Total Shortages by Scenario')
ax1.invert_yaxis()

# Panel 2: Total Wastage
ax2 = axes[0, 1]
bars = ax2.barh(x, comparison_df['Total Wastage'], color='darkorange', alpha=0.7)
ax2.set_yticks(x)
ax2.set_yticklabels(scenarios, fontsize=9)
ax2.set_xlabel('Total Units Wasted')
ax2.set_title('Total Wastage by Scenario')
ax2.invert_yaxis()

# Panel 3: Deaths by age group (stacked)
ax3 = axes[1, 0]
bottom = np.zeros(len(scenarios))
colors = {'Child': 'blue', 'Adult': 'green', 'Elderly': 'red'}
for age in ['Child', 'Adult', 'Elderly']:
    values = comparison_df[f'Deaths ({age})'].values
    ax3.barh(x, values, left=bottom, label=age, color=colors[age], alpha=0.7)
    bottom += values
ax3.set_yticks(x)
ax3.set_yticklabels(scenarios, fontsize=9)
ax3.set_xlabel('Deaths')
ax3.set_title('Deaths by Scenario and Age Group')
ax3.legend()
ax3.invert_yaxis()

# Panel 4: Total Deaths, with the base case as a reference line
ax4 = axes[1, 1]
bars = ax4.barh(x, comparison_df['Total Deaths'], color='darkred', alpha=0.7)
ax4.set_yticks(x)
ax4.set_yticklabels(scenarios, fontsize=9)
ax4.set_xlabel('Total Deaths')
ax4.set_title('Total Deaths by Scenario')
ax4.invert_yaxis()

# Add base case reference line
base_deaths = comparison_df[comparison_df['Scenario'].str.contains('Base')]['Total Deaths'].values[0]
ax4.axvline(base_deaths, color='gray', linestyle='--', alpha=0.7, label='Base case')
ax4.legend()

plt.suptitle('Scenario Comparison', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Part 9: Individual Scenario Deep Dives

Let's examine specific scenarios in more detail.

In [None]:
# Weather Delays vs Base comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Shortages over time
ax1 = axes[0]
base_short = base_results[[f'shortage_{m}' for m in CONFIG['antibiotic_classes']]].sum(axis=1)
weather_short = scenario_results['weather_delays']['results'][[f'shortage_{m}' for m in CONFIG['antibiotic_classes']]].sum(axis=1)

ax1.plot(base_results['month'], base_short, label='Base (2mo transit)', linewidth=2)
ax1.plot(base_results['month'], weather_short, label='Weather Delays (4mo transit)', linewidth=2)
ax1.set_xlabel('Month')
ax1.set_ylabel('Total Shortage Events')
ax1.set_title('Impact of Weather Delays on Shortages')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Cumulative shortages
ax2 = axes[1]
ax2.plot(base_results['month'], base_short.cumsum(), label='Base', linewidth=2)
ax2.plot(base_results['month'], weather_short.cumsum(), label='Weather Delays', linewidth=2)
ax2.set_xlabel('Month')
ax2.set_ylabel('Cumulative Shortage Events')
ax2.set_title('Cumulative Shortages Over Time')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Manufacturer Failure scenario
mfr_results = scenario_results['manufacturer_failure']['results']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Shortages over time
ax1 = axes[0]
mfr_short = mfr_results[[f'shortage_{m}' for m in CONFIG['antibiotic_classes']]].sum(axis=1)

ax1.plot(mfr_results['month'], mfr_short, color='red', linewidth=2)
ax1.axvline(12, color='gray', linestyle='--', label='Manufacturer fails')
ax1.axvline(24, color='green', linestyle='--', label='Manufacturer recovers')
ax1.fill_between([12, 24], 0, mfr_short.max() * 1.1, alpha=0.2, color='red')
ax1.set_xlabel('Month')
ax1.set_ylabel('Total Shortage Events')
ax1.set_title('Manufacturer Failure Impact on Shortages')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Deaths over time
ax2 = axes[1]
mfr_deaths = mfr_results[[f'deaths_{a}' for a in ['child', 'adult', 'elderly']]].sum(axis=1)

ax2.plot(mfr_results['month'], mfr_deaths, color='darkred', linewidth=2)
ax2.axvline(12, color='gray', linestyle='--', label='Manufacturer fails')
ax2.axvline(24, color='green', linestyle='--', label='Manufacturer recovers')
ax2.fill_between([12, 24], 0, mfr_deaths.max() * 1.1, alpha=0.2, color='red')
ax2.set_xlabel('Month')
ax2.set_ylabel('Deaths')
ax2.set_title('Manufacturer Failure Impact on Deaths')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.suptitle('Manufacturer Failure Scenario (Month 12-24)', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()
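The shaded band marks months 12 to 24, when one manufacturer is offline. To put a number on the plot, one option is to sum the difference between the failure scenario and the base case inside the window and after it, since shortages may persist after recovery while stock refills. `window_excess` is a hypothetical helper that treats the window as `[start, end)`; the sketch uses short made-up series so it runs alone.

```python
import pandas as pd

def window_excess(months: pd.Series, scenario: pd.Series, base: pd.Series,
                  start: int, end: int) -> dict:
    """Sum scenario-minus-base over months [start, end) and from `end` onwards."""
    diff = scenario.to_numpy() - base.to_numpy()
    m = months.to_numpy()
    during = diff[(m >= start) & (m < end)].sum()
    after = diff[m >= end].sum()
    return {"during": int(during), "after": int(after)}

# Made-up six-month example with a failure window covering months 2 and 3
months = pd.Series([1, 2, 3, 4, 5, 6])
base = pd.Series([0, 0, 1, 0, 0, 0])
failure = pd.Series([0, 2, 5, 4, 1, 0])
print(window_excess(months, failure, base, start=2, end=4))
```

With the notebook's series this would be `window_excess(mfr_results['month'], mfr_short, base_short, 12, 24)`, where `base_short` comes from the weather-delay cell.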

In [None]:
# AMR Substitution scenario
amr_results = scenario_results['amr_substitution']['results']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Shortages by medicine type
ax1 = axes[0]
for med in CONFIG['antibiotic_classes']:
    ax1.plot(amr_results['month'], amr_results[f'shortage_{med}'], label=med, linewidth=1.5)
ax1.axvline(24, color='gray', linestyle='--', label='AMR emerges')
ax1.set_xlabel('Month')
ax1.set_ylabel('Shortage Events')
ax1.set_title('AMR Scenario: Shortages by Medicine Type')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Compare with base
ax2 = axes[1]
base_macro = base_results['shortage_Macrolides']
amr_macro = amr_results['shortage_Macrolides']

ax2.plot(base_results['month'], base_macro, label='Base - Macrolides', linewidth=2)
ax2.plot(amr_results['month'], amr_macro, label='AMR - Macrolides', linewidth=2, linestyle='--')
ax2.axvline(24, color='gray', linestyle=':', label='AMR emerges (month 24)')
ax2.set_xlabel('Month')
ax2.set_ylabel('Shortage Events')
ax2.set_title('Macrolide Shortages: Base vs AMR Scenario')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.suptitle('Antimicrobial Resistance (30% Penicillin resistance from month 24)', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 10: Interactive Parameter Exploration

Use ipywidgets to interactively explore model parameters (if available).

In [None]:
if WIDGETS_AVAILABLE:
    def run_interactive_simulation(transit_time, order_lead, absenteeism, shelf_life):
        """Run simulation with custom parameters."""
        custom_config = CONFIG.copy()
        custom_config['transit_time'] = transit_time
        custom_config['order_lead_time'] = order_lead
        custom_config['health_worker_absenteeism'] = absenteeism / 100
        custom_config['medicine_shelf_life'] = shelf_life
        
        np.random.seed(42)
        model = EthiopiaSupplyChainModel(config=custom_config, scenario='base', verbose=False)
        model.run(60)
        results = model.get_results_df()
        
        # Plot results
        clear_output(wait=True)
        plot_dashboard(results, f"Custom Parameters: Transit={transit_time}mo, Lead={order_lead}mo, Absent={absenteeism}%, Shelf={shelf_life}mo")
        
        # Summary
        total_short = sum(results[f'shortage_{m}'].sum() for m in CONFIG['antibiotic_classes'])
        total_waste = sum(results[f'wastage_{m}'].sum() for m in CONFIG['antibiotic_classes'])
        total_deaths = sum(results[f'deaths_{a}'].sum() for a in ['child', 'adult', 'elderly'])
        
        print(f"\nSummary: Shortages={total_short:,}, Wastage={total_waste:,}, Deaths={total_deaths:,}")
    
    # Create interactive widgets; continuous_update=False re-runs the 60-month
    # simulation only when a slider is released, not on every intermediate value
    interactive_plot = widgets.interactive(
        run_interactive_simulation,
        transit_time=widgets.IntSlider(value=2, min=1, max=6, step=1, description='Transit (mo)', continuous_update=False),
        order_lead=widgets.IntSlider(value=2, min=1, max=6, step=1, description='Order Lead (mo)', continuous_update=False),
        absenteeism=widgets.IntSlider(value=10, min=0, max=30, step=5, description='Absent (%)', continuous_update=False),
        shelf_life=widgets.IntSlider(value=12, min=6, max=24, step=3, description='Shelf Life (mo)', continuous_update=False)
    )
    
    print("Interactive Parameter Exploration")
    print("Adjust sliders to see impact on supply chain outcomes:\n")
    display(interactive_plot)
else:
    print("Interactive widgets not available in this environment.")
    print("To enable, install ipywidgets: pip install ipywidgets")

## Part 11: Supply Chain Network Visualization

Visualize the supply chain structure using NetworkX (if available).

In [None]:
if NETWORKX_AVAILABLE:
    def visualize_supply_chain():
        """Create a network visualization of the supply chain."""
        G = nx.DiGraph()
        
        # Add manufacturers
        for i in range(CONFIG['n_manufacturers']):
            G.add_node(f'MFR_{i}', layer=0, node_type='manufacturer')
        
        # Add central store
        G.add_node('CMS_0', layer=1, node_type='central_store')
        
        # Add hospitals
        for i in range(CONFIG['n_hospitals']):
            G.add_node(f'HOSP_{i}', layer=2, node_type='hospital')
        
        # Add edges: Manufacturer -> Central Store
        for i in range(CONFIG['n_manufacturers']):
            G.add_edge(f'MFR_{i}', 'CMS_0')
        
        # Add edges: Central Store -> Hospitals
        for i in range(CONFIG['n_hospitals']):
            G.add_edge('CMS_0', f'HOSP_{i}')
        
        # Position nodes
        pos = {}
        # Manufacturers at top
        for i in range(CONFIG['n_manufacturers']):
            pos[f'MFR_{i}'] = (i - (CONFIG['n_manufacturers']-1)/2, 3)
        
        # Central store
        pos['CMS_0'] = (0, 2)
        
        # Hospitals
        for i in range(CONFIG['n_hospitals']):
            pos[f'HOSP_{i}'] = (i - (CONFIG['n_hospitals']-1)/2, 1)
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 8))
        
        # Draw nodes with different colors
        node_colors = []
        for node in G.nodes():
            if 'MFR' in node:
                node_colors.append('lightblue')
            elif 'CMS' in node:
                node_colors.append('lightgreen')
            elif 'HOSP' in node:
                node_colors.append('lightyellow')
        
        nx.draw(G, pos, ax=ax, with_labels=True, node_color=node_colors,
                node_size=2000, font_size=10, font_weight='bold',
                arrows=True, arrowsize=20, edge_color='gray')
        
        # Add legend
        legend_elements = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightblue', markersize=15, label='Manufacturers'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=15, label='Central Store (EPSA)'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightyellow', markersize=15, label='Regional Hospitals'),
        ]
        ax.legend(handles=legend_elements, loc='upper right')
        
        # Add CHC counts annotation
        for i in range(CONFIG['n_hospitals']):
            n_chcs = len(CHC_HOSPITAL_ASSIGNMENTS[i])
            ax.annotate(f'{n_chcs} CHCs', 
                       xy=pos[f'HOSP_{i}'], 
                       xytext=(pos[f'HOSP_{i}'][0], 0.3),
                       ha='center', fontsize=9,
                       arrowprops=dict(arrowstyle='->', color='gray', alpha=0.5))
        
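        # Build the subtitle from the actual assignments so it stays correct if CONFIG changes
        n_chcs_total = sum(len(CHC_HOSPITAL_ASSIGNMENTS[i]) for i in range(CONFIG['n_hospitals']))
        ax.set_title('Ethiopia Antibiotic Supply Chain Structure\n'
                     f"({n_chcs_total} CHCs served by {CONFIG['n_hospitals']} Regional Hospitals)",
                     fontsize=14, fontweight='bold')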
        plt.tight_layout()
        plt.show()
    
    visualize_supply_chain()
else:
    print("NetworkX not available for supply chain visualization.")
    print("To enable, install networkx: pip install networkx")
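
In the diagram, every manufacturer-to-hospital path passes through the central store. The sketch below checks for such single points of failure in plain Python, so it also runs when NetworkX is not installed. It rebuilds the same tiers described in the introduction (2 manufacturers, 1 central store, 3 hospitals, with CHCs omitted) as a directed edge list. It then removes each node in turn and reports which other hospitals lose every path from a manufacturer. The helper names `reachable_from` and `single_points_of_failure` are illustrative and are not part of the notebook's model.

```python
from collections import deque

def reachable_from(sources, edges, removed=frozenset()):
    """Return all nodes reachable from `sources` along directed `edges`, ignoring nodes in `removed`."""
    adjacency = {}
    for u, v in edges:
        adjacency.setdefault(u, []).append(v)
    seen = {s for s in sources if s not in removed}
    queue = deque(seen)
    while queue:  # breadth-first search
        node = queue.popleft()
        for nxt in adjacency.get(node, []):
            if nxt not in removed and nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen

def single_points_of_failure(sources, targets, edges):
    """List nodes whose removal leaves at least one *other* target with no path from any source."""
    nodes = sorted({n for edge in edges for n in edge})
    critical = []
    for node in nodes:
        reached = reachable_from(sources, edges, removed={node})
        cut_off = [t for t in targets if t != node and t not in reached]
        if cut_off:
            critical.append((node, cut_off))
    return critical

# Same tiers as the diagram: 2 manufacturers -> 1 central store -> 3 hospitals (CHCs omitted)
manufacturers = [f'MFR_{i}' for i in range(2)]
hospitals = [f'HOSP_{i}' for i in range(3)]
edges = [(m, 'CMS_0') for m in manufacturers] + [('CMS_0', h) for h in hospitals]

print(single_points_of_failure(manufacturers, hospitals, edges))
# [('CMS_0', ['HOSP_0', 'HOSP_1', 'HOSP_2'])]
```

Removing one manufacturer disconnects no hospital, but losing the central store disconnects all three. This check looks only at connectivity. A manufacturer failure in the ABM can still cause shortages through reduced capacity, and this check cannot capture that.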

## Part 12: Summary & Discussion Questions

### Key Findings by Scenario

In [None]:
# Calculate relative changes from base
base_row = comparison_df[comparison_df['Scenario'].str.contains('Base')].iloc[0]

print("="*70)
print("KEY FINDINGS: Relative to Base Scenario")
print("="*70)

for _, row in comparison_df.iterrows():
    if 'Base' in row['Scenario']:
        continue
    
    # A percentage change is undefined when the base value is 0; report NaN instead of a misleading 0%
    short_change = ((row['Total Shortages'] - base_row['Total Shortages']) / base_row['Total Shortages'] * 100) if base_row['Total Shortages'] > 0 else float('nan')
    death_change = ((row['Total Deaths'] - base_row['Total Deaths']) / base_row['Total Deaths'] * 100) if base_row['Total Deaths'] > 0 else float('nan')
    
    print(f"\n{row['Scenario']}:")
    print(f"  Shortages: {short_change:+.1f}%")
    print(f"  Deaths: {death_change:+.1f}%")

In [None]:
# Final summary
print("\n" + "="*70)
print("POLICY IMPLICATIONS")
print("="*70)

print("""
1. LOGISTICS INFRASTRUCTURE
   - Weather delays (doubled transit time) markedly increase shortages in the simulations
   - Investment in all-weather roads could reduce supply disruptions
   - Consider strategic buffer stocks at regional hospitals

2. ORDERING POLICIES  
   - Advance ordering (4 months vs 2 months) can buffer against disruptions
   - ML demand forecasts enable proactive ordering
   - Balance between ordering ahead and medicine expiry risk

3. SUPPLY CHAIN RESILIENCE
   - Single manufacturer failure has cascading effects
   - Diversified sourcing reduces single-point-of-failure risk
   - Emergency stockpiles provide buffer during disruptions

4. ANTIMICROBIAL RESISTANCE (AMR)
   - Resistance emergence shifts demand to second-line antibiotics
   - Supply chains must adapt to changing treatment patterns
   - Surveillance data integration improves forecasting

5. HEALTH WORKFORCE
   - Health worker absenteeism creates missed care opportunities
   - Even with adequate supplies, staffing gaps affect outcomes
   - Integrated approach: supplies AND workforce planning

6. EQUITY CONSIDERATIONS
   - Elderly and children most vulnerable to supply disruptions
   - Prioritization policies may be needed during shortages
   - Geographic equity in distribution matters
""")
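
Policy points 1 and 2 both concern buffers against lead-time risk. A common textbook rule sets safety stock to z·σ·√L. It assumes independent, normally distributed monthly demand with standard deviation σ, a fixed lead time of L months, and a service-level quantile z. This is a sketch for intuition and is not the ordering logic the ABM itself uses. The helper names and the demand figures (mean 1,000 and standard deviation 300 units per month) are made up.

```python
from math import sqrt
from statistics import NormalDist

def safety_stock(sigma_per_month, lead_time_months, service_level=0.95):
    """Textbook safety stock z * sigma * sqrt(L), assuming i.i.d. normal monthly demand and a fixed lead time."""
    z = NormalDist().inv_cdf(service_level)  # about 1.645 for a 95% service level
    return z * sigma_per_month * sqrt(lead_time_months)

def reorder_point(mean_per_month, sigma_per_month, lead_time_months, service_level=0.95):
    """Expected demand over the lead time plus safety stock."""
    return mean_per_month * lead_time_months + safety_stock(sigma_per_month, lead_time_months, service_level)

for lead in (2, 4):
    ss = safety_stock(300, lead)
    rop = reorder_point(1000, 300, lead)
    print(f"lead time {lead} months: safety stock {ss:.0f}, reorder point {rop:.0f}")
```

Under these assumptions, doubling the lead time from 2 to 4 months doubles the expected demand that must be covered. The required safety stock, however, grows only by a factor of √2 ≈ 1.41.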

## Discussion Questions

1. **Which scenario had the largest impact on deaths? Why?**

2. **How do shortages and wastage trade off? Can you find a parameter setting that minimizes both?**

3. **The model assumes patients die if untreated. How might you modify this to include referral to higher-level facilities?**

4. **What additional data would make this model more realistic for Ethiopia's context?**

5. **How could the ML forecasts from Day 2 be improved to better support this ABM?**

6. **What policy interventions would you recommend to the Ministry of Health based on these results?**
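
To build intuition for question 2 before changing the ABM, here is a deliberately simplified single-facility model. It is a hypothetical sketch and is not part of the notebook's model. Each month, stock is topped up instantly to a target level. Patients then draw the oldest units first (FIFO), and units expire after a fixed shelf life. Demand is Poisson with a made-up mean. Every target uses the same seed, so all runs see identical demand and any differences come only from the target level.

```python
import numpy as np

def simulate_order_up_to(target, months=120, mean_demand=100, shelf_life=6, seed=0):
    """Toy order-up-to model with expiry. Returns (total unmet demand, total expired units)."""
    rng = np.random.default_rng(seed)
    batches = []  # each entry is [age_in_months, units], oldest first
    unmet = expired = 0
    for _ in range(months):
        on_hand = sum(units for _, units in batches)
        if on_hand < target:
            batches.append([0, target - on_hand])  # instant top-up: zero lead time (simplification)
        demand = int(rng.poisson(mean_demand))
        for batch in batches:  # FIFO: serve patients from the oldest batch first
            issued = min(batch[1], demand)
            batch[1] -= issued
            demand -= issued
        unmet += demand  # demand that could not be served this month
        for batch in batches:
            batch[0] += 1
        expired += sum(units for age, units in batches if age >= shelf_life)
        batches = [b for b in batches if b[0] < shelf_life and b[1] > 0]
    return unmet, expired

targets = [100, 150, 200, 400, 800]
results = {t: simulate_order_up_to(t) for t in targets}
for t in targets:
    print(f"target={t:>4}  unmet={results[t][0]:>6}  expired={results[t][1]:>6}")
```

In this toy model, the stock available each month equals the target. Unmet demand can therefore only fall as the target rises, while expired stock tends to grow. Whether some ABM setting reduces both depends on levers the toy lacks, such as lead times, forecast accuracy, and redistribution between facilities.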

## Optional Challenges

### Challenge 1: Implement your own optimization strategy
Modify the `optimization_challenge` scenario parameters to minimize deaths while keeping wastage below 50,000 units.

### Challenge 2: Add a new scenario
Implement a "Budget Cut" scenario where manufacturer capacity is reduced by 30%.
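
Challenge 2 can start from a modified copy of the configuration. The sketch below assumes manufacturer capacity is stored in CONFIG as a single number. The key `manufacturer_capacity` is a hypothetical name, so substitute the key your notebook actually uses. If capacity is stored per manufacturer, scale each entry instead. How the new config is passed to a model run depends on the scenario runner defined earlier in the notebook.

```python
import copy

def make_budget_cut_config(base_config, capacity_key='manufacturer_capacity', reduction=0.30):
    """Return a deep copy of `base_config` with the capacity entry scaled by (1 - reduction).

    `capacity_key` is a hypothetical default; replace it with the real CONFIG key.
    """
    if capacity_key not in base_config:
        raise KeyError(f"{capacity_key!r} not found; check the CONFIG keys defined earlier in the notebook")
    new_config = copy.deepcopy(base_config)  # deep copy so nested settings in the original stay untouched
    new_config[capacity_key] = base_config[capacity_key] * (1 - reduction)
    return new_config

# Made-up config just to demonstrate the helper
demo_config = {'manufacturer_capacity': 50_000, 'n_manufacturers': 2}
budget_cut = make_budget_cut_config(demo_config)
print(budget_cut['manufacturer_capacity'], demo_config['manufacturer_capacity'])
```

If the model expects whole units, round the scaled capacity before running.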

### Challenge 3: Enhance the model
Add a feature where CHCs can "borrow" medicines from neighboring CHCs during shortages.
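
Challenge 3 is a form of lateral transshipment, in which facilities in the same tier share stock instead of waiting for the next delivery from above. The sketch below runs one redistribution round using plain dictionaries. Two parts of it are assumptions, because the CHC agent attributes are not shown in this section. The first is the `neighbors` mapping. The second is the rule that a lender keeps enough stock for its own current demand.

```python
def borrow_from_neighbors(stock, demand, neighbors, keep_fraction=1.0):
    """One round of lateral transshipment between CHCs.

    stock, demand: dicts mapping CHC id -> units on hand / units needed this step.
    neighbors: dict mapping CHC id -> list of neighbouring CHC ids (hypothetical structure).
    A lender only gives away stock above keep_fraction * its own demand.
    Returns (new_stock, transfers), where transfers lists (lender, borrower, units).
    """
    new_stock = dict(stock)
    transfers = []
    shortfalls = {c: demand[c] - stock[c] for c in stock if demand[c] > stock[c]}
    # Serve the largest shortfalls first
    for borrower in sorted(shortfalls, key=shortfalls.get, reverse=True):
        need = demand[borrower] - new_stock[borrower]
        for lender in neighbors.get(borrower, []):
            if need <= 0:
                break
            surplus = new_stock[lender] - keep_fraction * demand[lender]
            units = int(min(need, max(surplus, 0)))
            if units > 0:
                new_stock[lender] -= units
                new_stock[borrower] += units
                need -= units
                transfers.append((lender, borrower, units))
    return new_stock, transfers

# Made-up example: CHC_0 has run out, its two neighbours have spare stock
stock = {'CHC_0': 0, 'CHC_1': 50, 'CHC_2': 30}
demand = {'CHC_0': 40, 'CHC_1': 20, 'CHC_2': 10}
neighbors = {'CHC_0': ['CHC_1', 'CHC_2'], 'CHC_1': ['CHC_0'], 'CHC_2': ['CHC_0']}
new_stock, transfers = borrow_from_neighbors(stock, demand, neighbors)
print(transfers)  # [('CHC_1', 'CHC_0', 30), ('CHC_2', 'CHC_0', 10)]
print(new_stock)  # {'CHC_0': 40, 'CHC_1': 20, 'CHC_2': 20}
```

Transfers conserve total stock, so borrowing fixes distribution gaps but cannot fix an overall supply shortfall. A fuller version might also model transfer delays between CHCs.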

In [None]:
# Space for your challenge solutions

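# Challenge 1: Optimization
# Try different parameter combinations. Use a deep copy so that nested
# settings in the shared CONFIG are not modified by accident:
# import copy
# custom_config = copy.deepcopy(CONFIG)
# custom_config['order_lead_time'] = ???
# custom_config['transit_time'] = ???
# ...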

pass  # Your code here
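
One way to approach Challenge 1 is a constrained grid search. Evaluate every combination of candidate values, discard any combination whose wastage reaches the 50,000-unit limit, and keep the one with the fewest deaths. The sketch assumes a hypothetical wrapper `run_fn(params)` that runs the model with those settings and returns total deaths and wastage. `toy_run` is a made-up stand-in so that the example runs on its own. The parameter names come from the commented template above.

```python
import itertools

def grid_search(run_fn, param_grid, max_wastage=50_000):
    """Return (best_params, best_result): fewest deaths among runs with wastage below max_wastage.

    Returns (None, None) if no combination satisfies the constraint.
    """
    names = list(param_grid)
    best_params, best_result = None, None
    for values in itertools.product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        result = run_fn(params)
        if result['wastage'] >= max_wastage:
            continue  # violates the wastage constraint
        if best_result is None or result['deaths'] < best_result['deaths']:
            best_params, best_result = params, result
    return best_params, best_result

def toy_run(params):
    # Made-up stand-in for a model run: ordering further ahead cuts deaths but raises wastage
    lead = params['order_lead_time']
    return {'deaths': 100 / lead + params['transit_time'], 'wastage': 15_000 * lead}

best, outcome = grid_search(toy_run, {'order_lead_time': [1, 2, 3, 4], 'transit_time': [1, 2]})
print(best, outcome)
```

Each ABM run is stochastic, so a real `run_fn` should average deaths and wastage over several random seeds before settings are compared.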

---

## Connection to WISE Project

This ABM demonstrates how agent-based modeling can support the WISE project's goals:

1. **Evidence-based policy**: Test policy options virtually before implementation
2. **Capacity building**: Train local researchers in simulation methods
3. **Integration**: Combine ML forecasting with agent-based supply chain simulation
4. **Stakeholder engagement**: Visual results support communication with policymakers

The model structure can be adapted for:
- Other medicine classes (antimalarials, vaccines)
- Different geographic contexts (other regions, countries)
- Extended supply chains (including private sector)
- Integration with real EPSA data

---

**End of Notebook**

*For questions or feedback, contact the WISE Workshop team.*