# Blood Bank Inventory Management

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/stochastic-optimization/blob/master/notebooks/blood_management.ipynb)

## Problem Overview

This notebook demonstrates blood bank inventory optimization using JAX. The problem involves managing:
- **8 blood types** (O-, O+, A-, A+, B-, B+, AB-, AB+) with substitution rules
- **Age-dependent inventory** (blood expires after `max_age` days)
- **Stochastic demand** from urgent and elective surgeries
- **Stochastic supply** from blood donations
- **Surge events** (sudden increases in demand)

---

## Mathematical Formulation

### State Space
The state at time $t$ is:
$$s_t = [I_{t}, t]$$

where $I_t \in \mathbb{R}^{B \times A}$ is the inventory matrix:
- $B = 8$ blood types
- $A$ = maximum age (days before expiry)
- $I_{t}[b, a]$ = units of blood type $b$ with age $a$ days
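
To make the layout concrete, here is a minimal NumPy sketch of the state vector. It assumes the inventory is flattened row by row with the time step appended last, and that ages are indexed from 0 (fresh) to $A-1$ (oldest). The repository's `BloodManagementModel` may order things differently, so treat this as an illustration only.

```python
import numpy as np

n_types, max_age = 8, 5  # B and A from the formulation

# Inventory matrix I_t: rows are blood types, columns are ages in days.
inventory = np.zeros((n_types, max_age))
inventory[0, 0] = 12.0  # 12 fresh units of the first blood type
inventory[3, 4] = 2.0   # 2 units of the fourth blood type in the oldest slot

t = 7
# Assumed layout: flattened inventory followed by the time step.
state = np.concatenate([inventory.ravel(), [float(t)]])

# Recover the inventory matrix and the time step from the flat state.
inventory_back = state[:-1].reshape(n_types, max_age)
t_back = int(state[-1])
```

With `max_age=5` the state has $8 \times 5 + 1 = 41$ entries.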

### Blood Substitution Rules
The substitution matrix $S \in \{0,1\}^{B \times B}$ defines compatibility:
- $S[i,j] = 1$ if blood type $i$ can be used for demand type $j$
- **O- is universal donor**: can substitute for all types
- **AB+ is universal recipient**: can receive from all types (as a donor, AB+ can only give to AB+)
- Rh factor rules: Negative can give to positive, but not vice versa
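
The rules above follow from one principle: a donor is compatible when it carries no antigen (A, B, or Rh) that the recipient lacks. The sketch below builds $S$ from that principle. The names `TYPE_ORDER`, `antigens`, and `build_substitution_matrix` are hypothetical helpers for illustration, and the type ordering is assumed to match the list in the overview; the repository's own `substitution_matrix` may be constructed differently.

```python
import numpy as np

# Assumed ordering, matching the list in the problem overview.
TYPE_ORDER = ["O-", "O+", "A-", "A+", "B-", "B+", "AB-", "AB+"]

def antigens(blood_type):
    """Return the set of A, B, and Rh antigens carried by a blood type."""
    abo, rh = blood_type[:-1], blood_type[-1]
    found = set(abo) - {"O"}  # type O carries neither A nor B
    if rh == "+":
        found.add("Rh")
    return found

def build_substitution_matrix(types=TYPE_ORDER):
    """S[i, j] = 1 if donor type i can be used for recipient type j."""
    n = len(types)
    S = np.zeros((n, n), dtype=int)
    for i, donor in enumerate(types):
        for j, recipient in enumerate(types):
            # Compatible when the donor's antigens are a subset of the recipient's.
            S[i, j] = int(antigens(donor) <= antigens(recipient))
    return S

S = build_substitution_matrix()
```

In this sketch the O- row and the AB+ column are all ones, and 27 of the 64 donor/recipient pairs are compatible.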

### Dynamics
The inventory transitions as:
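$$I_{t+1} = \text{Age}(I_t - x_t) + D_{t+1}$$

where:
- $x_t \in \mathbb{R}^{(B \times A) \times (B \times 2)}$ is the allocation decision; $x_t[b, a, b', s]$ is the number of units of type $b$ and age $a$ assigned to demand of type $b'$ and surgery type $s$
- $I_t - x_t$ is shorthand for removing the allocated units from each slot: $I_t[b, a] - \sum_{b', s} x_t[b, a, b', s]$
- $\text{Age}(\cdot)$ shifts every remaining unit one day older; units in the oldest slot expire
- $D_{t+1} \in \mathbb{R}^B$ is the new donations, which enter the age-0 column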
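
The aging step can be sketched in a few lines of NumPy. This is an illustrative version under the assumption that column 0 holds fresh blood and the last column holds the oldest; `age_inventory` is a hypothetical helper, not the repository's `transition` method.

```python
import numpy as np

def age_inventory(remaining, donations):
    """Shift each unit one age slot older and add donations at age 0.

    remaining: (B, A) inventory left after allocation
    donations: (B,)   new units arriving at age 0
    Returns (next_inventory, expired), where expired has shape (B,).
    """
    expired = remaining[:, -1].copy()          # oldest slot leaves the system
    next_inventory = np.empty_like(remaining)
    next_inventory[:, 1:] = remaining[:, :-1]  # everything else ages by one day
    next_inventory[:, 0] = donations           # donations enter as fresh blood
    return next_inventory, expired

# Two blood types, three age slots, for readability.
remaining = np.array([[4.0, 2.0, 1.0],
                      [0.0, 3.0, 5.0]])
donations = np.array([6.0, 7.0])
next_inventory, expired = age_inventory(remaining, donations)
print(next_inventory)  # [[6. 4. 2.] [7. 0. 3.]]
print(expired)         # [1. 5.]
```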

### Exogenous Information
At each time step:
$$W_t = (\text{demand}_t, \text{donation}_t)$$

- **Demand**: $\text{demand}_t \in \mathbb{R}^{B \times 2}$ (urgent + elective)
  $$\text{demand}_t[b, s] \sim \text{Poisson}(\lambda_s \cdot m_t)$$
  - $\lambda_0$ = urgent rate, $\lambda_1$ = elective rate
  - $m_t$ = surge multiplier (3.0 with prob 0.1, else 1.0)

- **Donation**: $\text{donation}_t[b] \sim \text{Poisson}(\lambda_d)$
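
The sampling scheme above can be sketched with NumPy's random generator. The function name and default rates are assumptions taken from the configuration values used later in this notebook, and the surge multiplier is assumed to be drawn once per day and shared by all blood types. The repository's `sample_exogenous` uses JAX PRNG keys and may differ in detail.

```python
import numpy as np

def sample_exogenous_sketch(rng, n_types=8, rate_urgent=10.0, rate_elective=5.0,
                            rate_donation=15.0, surge_prob=0.1, surge_factor=3.0):
    """Draw one day of demand and donations.

    Returns (demand, donation, multiplier) with demand of shape (n_types, 2)
    and donation of shape (n_types,).
    """
    # One surge draw per day scales both urgent and elective demand rates.
    multiplier = surge_factor if rng.random() < surge_prob else 1.0
    rates = np.array([rate_urgent, rate_elective]) * multiplier
    demand = rng.poisson(rates, size=(n_types, 2))   # columns: urgent, elective
    donation = rng.poisson(rate_donation, size=n_types)
    return demand, donation, multiplier

rng = np.random.default_rng(0)
demand, donation, multiplier = sample_exogenous_sketch(rng)
```

Passing `surge_prob=1.0` forces a surge on every draw, which is a quick way to stress-test a policy.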

### Reward Function
The single-step reward balances multiple objectives:

$$R(s_t, x_t, w_t) = R_{\text{fulfill}} + R_{\text{match}} + R_{\text{shortage}} + R_{\text{discard}}$$

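**1. Fulfillment reward** (meeting demand):
$$R_{\text{fulfill}} = \sum_{b,s} \min(\text{allocated}[b,s], \text{demand}[b,s]) \cdot c_s$$
- $\text{allocated}[b,s] = \sum_{b',a} x_t[b',a,b,s]$ is the total number of units assigned to demand type $b$, surgery type $s$
- $c_0 = 100$ (urgent bonus)
- $c_1 = 50$ (elective bonus)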

**2. Exact match bonus** (no substitution needed):
$$R_{\text{match}} = \sum_{b,a,s} x_t[b,a,b,s] \cdot c_{\text{match}}$$
- $c_{\text{match}} = 10$

**3. Shortage penalty** (unmet demand):
$$R_{\text{shortage}} = -\sum_{b,s} \max(0, \text{demand}[b,s] - \text{allocated}[b,s]) \cdot p_s$$
- $p_0 = 400$ (urgent penalty, 2× the base shortage penalty)
- $p_1 = 200$ (elective penalty, equal to the base shortage penalty)

**4. Discard penalty** (expired blood):
$$R_{\text{discard}} = -\sum_b I_t[b, A-1] \cdot p_{\text{discard}}$$
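- $p_{\text{discard}} = 50$ (the configuration cell passes penalties as negative numbers, e.g. `discard_penalty=-50.0`; the formulas here use magnitudes with an explicit minus sign)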
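
A small worked example helps check the arithmetic of the four terms. The sketch below assumes the coefficients listed above and takes the exact-match and expired unit counts as given inputs; `single_step_reward` is a hypothetical helper, not the repository's `reward` method.

```python
import numpy as np

def single_step_reward(allocated, demand, exact_match_units, expired_units,
                       bonus=(100.0, 50.0), match_bonus=10.0,
                       shortage=(400.0, 200.0), discard=50.0):
    """Sum the four reward terms for one day.

    allocated, demand: (B, 2) arrays with columns (urgent, elective)
    exact_match_units: total units served by the same blood type
    expired_units:     total units discarded from the oldest age slot
    """
    allocated = np.asarray(allocated, dtype=float)
    demand = np.asarray(demand, dtype=float)
    fulfilled = np.minimum(allocated, demand)
    unmet = np.maximum(0.0, demand - allocated)
    r_fulfill = float(np.sum(fulfilled * np.asarray(bonus)))
    r_match = match_bonus * exact_match_units
    r_shortage = -float(np.sum(unmet * np.asarray(shortage)))
    r_discard = -discard * expired_units
    return r_fulfill + r_match + r_shortage + r_discard

# Two blood types for readability: one urgent unit of the first type goes unmet.
demand = [[3, 2], [1, 0]]
allocated = [[2, 2], [1, 0]]
reward = single_step_reward(allocated, demand, exact_match_units=4, expired_units=3)
print(reward)  # 400 + 40 - 400 - 150 = -110.0
```

A single unmet urgent unit (penalty 400) cancels the fulfillment reward of four urgent units, which is why shortages dominate the daily reward.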

### Objective
Maximize expected cumulative reward:
$$\max_{\pi} \mathbb{E}\left[\sum_{t=0}^{T-1} R(s_t, \pi(s_t), w_t)\right]$$
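
In practice the expectation is estimated by simulation: run many independent episodes under a fixed policy and average the cumulative rewards. The sketch below shows that estimator with a standard error. `estimate_expected_return` and `toy_episode` are hypothetical names, and the toy episode uses made-up normal rewards in place of the blood bank model.

```python
import numpy as np

def estimate_expected_return(episode_return, n_episodes=100, first_seed=0):
    """Average episode returns over independent seeds.

    episode_return: callable taking an integer seed and returning the
                    cumulative reward of one simulated episode.
    Returns (mean, standard_error); needs n_episodes >= 2.
    """
    returns = np.array(
        [episode_return(first_seed + k) for k in range(n_episodes)], dtype=float
    )
    mean = float(returns.mean())
    stderr = float(returns.std(ddof=1) / np.sqrt(n_episodes))
    return mean, stderr

def toy_episode(seed, horizon=30):
    """Stand-in episode: daily rewards drawn from a normal distribution."""
    rng = np.random.default_rng(seed)
    return float(rng.normal(loc=10.0, scale=5.0, size=horizon).sum())

mean, stderr = estimate_expected_return(toy_episode, n_episodes=200)
print(f"Estimated return: {mean:.1f} ± {stderr:.1f}")
```

To apply this to the blood bank model, `episode_return` would wrap the `run_episode` function defined later in this notebook and sum its `'rewards'` list, using the seed to build the PRNG key.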

---

## Setup and Installation

First, let's install the required packages and clone the repository.

In [None]:
# Install JAX and dependencies
!pip install -q jax jaxlib jaxtyping chex numpy matplotlib

# Clone repository (force fresh clone for latest code)
import os
import shutil

if os.path.exists('stochastic-optimization'):
    shutil.rmtree('stochastic-optimization')

!git clone https://github.com/pedronahum/stochastic-optimization.git
os.chdir('stochastic-optimization')

# Clear Python import cache
import sys
for key in list(sys.modules.keys()):
    if key.startswith('problems'):
        del sys.modules[key]

print('✓ Setup complete!')

## Imports

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Dict

# Import problem components
from problems.blood_management import (
    BloodManagementConfig,
    BloodManagementModel,
    ExogenousInfo,
    GreedyPolicy,
    FIFOPolicy,
    RandomPolicy,
    BLOOD_TYPES,
)

print("✓ Imports successful")
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")

## Problem Configuration

Let's set up the blood bank parameters.

In [None]:
# Create configuration
config = BloodManagementConfig(
    max_age=5,                  # Blood expires after 5 days
    max_demand_urgent=10.0,     # Average urgent demand per type
    max_demand_elective=5.0,    # Average elective demand per type
    max_donation=15.0,          # Average donation per type
    surge_prob=0.1,             # 10% chance of surge
    surge_factor=3.0,           # 3x demand during surge
    urgent_bonus=100.0,         # Reward for urgent fulfillment
    elective_bonus=50.0,        # Reward for elective fulfillment
    no_substitution_bonus=10.0, # Bonus for exact match
    discard_penalty=-50.0,      # Penalty for expired blood
    shortage_penalty=-200.0,    # Penalty for shortage
    seed=42
)

print("Configuration:")
print(f"  Blood types: {len(BLOOD_TYPES)} ({', '.join(BLOOD_TYPES)})")
print(f"  Max age: {config.max_age} days")
print(f"  Surge probability: {config.surge_prob:.1%}")
print(f"  State size: {len(BLOOD_TYPES) * config.max_age + 1} (inventory + time)")

## Initialize Model and Policies

In [None]:
# Create model
model = BloodManagementModel(config)

# Create policies
greedy_policy = GreedyPolicy()
fifo_policy = FIFOPolicy()
random_policy = RandomPolicy()

print("✓ Model initialized")
print(f"  Inventory slots: {model.n_inventory_slots}")
print(f"  Demand types: {model.n_demand_types} (8 blood types × 2 surgery types)")
print("\n✓ Policies created:")
print("  - Greedy (oldest blood first, exact matches preferred)")
print("  - FIFO (first-in-first-out)")
print("  - Random (baseline)")

## Blood Type Substitution Matrix

Let's visualize which blood types can substitute for others.

In [None]:
# Visualize substitution matrix
fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(model.substitution_matrix, cmap='RdYlGn', aspect='auto')

# Set ticks and labels
ax.set_xticks(range(len(BLOOD_TYPES)))
ax.set_yticks(range(len(BLOOD_TYPES)))
ax.set_xticklabels(BLOOD_TYPES)
ax.set_yticklabels(BLOOD_TYPES)

ax.set_xlabel('Demand Type (recipient)', fontsize=12)
ax.set_ylabel('Supply Type (donor)', fontsize=12)
ax.set_title('Blood Type Substitution Matrix\n(Green = Compatible)', fontsize=14)

# Add text annotations
for i in range(len(BLOOD_TYPES)):
    for j in range(len(BLOOD_TYPES)):
        ax.text(j, i, '✓' if model.substitution_matrix[i, j] else '✗',
                ha="center", va="center", color="black", fontsize=14)

plt.colorbar(im, ax=ax, label='Compatible')
plt.tight_layout()
plt.show()

print("Key observations:")
print("  • O- (row 0) can donate to all types (universal donor)")
print("  • AB+ (column 7) can receive from all types (universal recipient)")
print("  • Negative types can donate to corresponding positive types")

## Run Simulation

Let's simulate blood bank operations for 30 days.

In [None]:
def run_episode(model, policy, horizon=30, key=None):
    """Run a single episode."""
    if key is None:
        key = jax.random.PRNGKey(42)
    
    # Initialize
    state = model.init_state(key)
    
    # Track history
    history = {
        'rewards': [],
        'inventory_total': [],
        'demands': [],
        'donations': [],
        'allocations': [],
    }
    
    # Run simulation
    for t in range(horizon):
        # Get decision
        key, subkey = jax.random.split(key)
        decision = policy(None, state, subkey, model)
        
        # Sample exogenous events
        key, subkey = jax.random.split(key)
        exog = model.sample_exogenous(subkey, state, t)
        
        # Compute reward
        reward = model.reward(state, decision, exog)
        
        # Record
        inventory = model.get_inventory(state)
        history['rewards'].append(float(reward))
        history['inventory_total'].append(float(jnp.sum(inventory)))
        history['demands'].append(float(jnp.sum(exog.demand)))
        history['donations'].append(float(jnp.sum(exog.donation)))
        history['allocations'].append(float(jnp.sum(decision)))
        
        # Transition
        state = model.transition(state, decision, exog)
    
    return history

# Run simulation with greedy policy
print("Running 30-day simulation with Greedy policy...")
history = run_episode(model, greedy_policy, horizon=30)

print(f"\n✓ Simulation complete!")
print(f"  Total reward: {sum(history['rewards']):.1f}")
print(f"  Average daily reward: {np.mean(history['rewards']):.1f}")
print(f"  Average inventory: {np.mean(history['inventory_total']):.1f} units")

## Visualize Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

days = range(len(history['rewards']))

# Rewards
axes[0, 0].plot(days, history['rewards'], 'o-', linewidth=2)
axes[0, 0].axhline(0, color='red', linestyle='--', alpha=0.3)
axes[0, 0].set_title('Daily Reward', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Day')
axes[0, 0].set_ylabel('Reward')
axes[0, 0].grid(alpha=0.3)

# Cumulative reward
cumulative = np.cumsum(history['rewards'])
axes[0, 1].plot(days, cumulative, 'o-', linewidth=2, color='green')
axes[0, 1].set_title('Cumulative Reward', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Day')
axes[0, 1].set_ylabel('Cumulative Reward')
axes[0, 1].grid(alpha=0.3)

# Inventory
axes[1, 0].plot(days, history['inventory_total'], 'o-', linewidth=2, color='purple')
axes[1, 0].set_title('Total Inventory', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Day')
axes[1, 0].set_ylabel('Units')
axes[1, 0].grid(alpha=0.3)

# Supply and Demand
axes[1, 1].plot(days, history['demands'], 'o-', label='Demand', linewidth=2, color='red')
axes[1, 1].plot(days, history['donations'], 's-', label='Donations', linewidth=2, color='blue')
axes[1, 1].plot(days, history['allocations'], '^-', label='Allocated', linewidth=2, color='green')
axes[1, 1].set_title('Supply and Demand', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Day')
axes[1, 1].set_ylabel('Units')
axes[1, 1].legend()
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Policy Comparison

Let's compare different allocation policies.

In [None]:
policies = {
    'Greedy': greedy_policy,
    'FIFO': fifo_policy,
    'Random': random_policy,
}

# Run each policy
results = {}
for name, policy in policies.items():
    print(f"Running {name} policy...")
    key = jax.random.PRNGKey(42)  # Same seed for fair comparison
    history = run_episode(model, policy, horizon=30, key=key)
    results[name] = history

# Compare total rewards
print("\n" + "="*50)
print("Policy Comparison (30 days):")
print("="*50)
for name, history in results.items():
    total = sum(history['rewards'])
    avg = np.mean(history['rewards'])
    print(f"{name:10s}: Total = {total:8.1f}  |  Avg/day = {avg:6.1f}")

# Visualize comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Daily rewards
for name, history in results.items():
    ax1.plot(history['rewards'], '-', label=name, linewidth=2, alpha=0.7)
ax1.set_title('Daily Rewards Comparison', fontsize=12, fontweight='bold')
ax1.set_xlabel('Day')
ax1.set_ylabel('Reward')
ax1.legend()
ax1.grid(alpha=0.3)

# Cumulative rewards
for name, history in results.items():
    cumulative = np.cumsum(history['rewards'])
    ax2.plot(cumulative, '-', label=name, linewidth=2, alpha=0.7)
ax2.set_title('Cumulative Rewards Comparison', fontsize=12, fontweight='bold')
ax2.set_xlabel('Day')
ax2.set_ylabel('Cumulative Reward')
ax2.legend()
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Key Insights

The simulation demonstrates:

1. **Greedy policy** is designed to perform well by:
   - Using the oldest blood first
   - Preferring exact blood type matches
   - Minimizing blood expiry

2. **Inventory dynamics**:
   - Inventory fluctuates with stochastic donations and demands
   - Policies must balance discard penalties for expired blood against shortage penalties
   - Age structure matters: older blood should be used first

3. **Substitution tradeoffs**:
   - Using substitutions (e.g., O- for A+) provides flexibility
   - But exact matches earn the `no_substitution_bonus` and keep flexible types in reserve
   - Universal donor (O-) is critical for system resilience

4. **Surge events**:
   - Occasional demand spikes test inventory robustness
   - Policies must handle both normal and high-demand periods
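
The greedy rule described in item 1 can be sketched for a single day as follows. This is a simplified illustration that ignores the urgent/elective split and serves demand types in index order; `greedy_allocate` is a hypothetical function, not the repository's `GreedyPolicy`.

```python
import numpy as np

def greedy_allocate(inventory, demand, compat):
    """Allocate oldest units first, trying the exact blood type before substitutes.

    inventory: (B, A) units by donor type and age (last column is oldest)
    demand:    (B,)   units requested by recipient type
    compat:    (B, B) compat[i, j] = 1 if donor i can serve recipient j
    Returns (allocation of shape (B, A, B), remaining inventory).
    """
    inv = np.asarray(inventory, dtype=float).copy()
    n_types, n_ages = inv.shape
    alloc = np.zeros((n_types, n_ages, n_types))
    for j in range(n_types):
        need = float(demand[j])
        # Exact match first, then the other compatible donor types.
        donors = [j] + [i for i in range(n_types) if i != j and compat[i, j]]
        for i in donors:
            for a in range(n_ages - 1, -1, -1):  # oldest age first
                take = min(inv[i, a], need)
                alloc[i, a, j] += take
                inv[i, a] -= take
                need -= take
            if need <= 0:
                break
    return alloc, inv

# Two types, two ages: type 0 can serve both recipients, type 1 only itself.
compat = np.array([[1, 1],
                   [0, 1]])
inventory = np.array([[2.0, 1.0],
                      [1.0, 0.0]])
demand = np.array([1.0, 3.0])
alloc, remaining = greedy_allocate(inventory, demand, compat)
```

In this example recipient 1 first takes its single exact-match unit and then the two remaining units of type 0, so all four units are used. Because demand types are served in index order, flexible donors can be consumed early; a more careful policy would serve urgent demand first.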

---

## Extensions

Try modifying:
- `max_age`: See how expiry time affects performance
- `surge_prob`: Test resilience to demand variability
- `shortage_penalty`: Balance shortage vs. excess inventory costs
- Implement your own policy!

## References

- Repository: https://github.com/pedronahum/stochastic-optimization
- JAX Documentation: https://jax.readthedocs.io/
