# Medical Decision Making - Diabetes

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/stochastic-optimization/blob/master/notebooks/medical_decision_diabetes.ipynb)

In [None]:
# Install JAX and dependencies
!pip install -q jax jaxlib jaxtyping chex numpy matplotlib

# Clone repository (force fresh clone for latest code)
import os
import shutil

if os.path.exists('stochastic-optimization'):
    shutil.rmtree('stochastic-optimization')

!git clone https://github.com/pedronahum/stochastic-optimization.git
os.chdir('stochastic-optimization')

# Clear Python import cache
import sys
for key in list(sys.modules.keys()):
    if key.startswith('problems'):
        del sys.modules[key]

print('✓ Setup complete!')

In [None]:
# Numerical, random-number and plotting libraries used by the cells below.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Import problem components from the cloned repository's `problems` package
# (the setup cell chdir's into the clone before this runs).
# NOTE(review): np, UCBPolicy and ThompsonSamplingPolicy are not used in the
# cells visible here — presumably used further down; confirm or drop them.
from problems.medical_decision_diabetes import (
    MedicalDecisionDiabetesConfig,
    MedicalDecisionDiabetesModel,
    UCBPolicy,
    ThompsonSamplingPolicy,
)

# Confirm the imports worked and report the JAX version and active backend.
print("✓ Imports successful")
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")

In [None]:
# Problem settings: five candidate drugs, initial_mu = 0.5 and
# measurement_sigma = 0.05. Kept in one dict so they are easy to tweak.
config_kwargs = dict(n_drugs=5, initial_mu=0.5, measurement_sigma=0.05)
config = MedicalDecisionDiabetesConfig(**config_kwargs)
model = MedicalDecisionDiabetesModel(config)

# A fixed PRNG seed makes the model's initial state reproducible.
key = jax.random.PRNGKey(42)
state = model.init_state(key)
print('✓ Medical decision diabetes model ready')

In [None]:
# Re-seed the PRNG and reset the state of the MedicalDecisionDiabetesModel
# built in the previous cell, so the simulation below starts from a fresh,
# reproducible initial state. (No separate `DiabetesConfig`/`DiabetesModel`
# is imported in this notebook; the medical-decision model is the one to use.)
key = jax.random.PRNGKey(42)
state = model.init_state(key)
print('✓ Diabetes model ready')

In [None]:
# Roll the model forward 48 steps under a constant decision, recording
# state[0] (plotted as the glucose level), the dose applied, and each reward.
# NOTE(review): this treats state[0] as a glucose level and the decision as a
# length-1 insulin-dose array; confirm that matches the state/decision layout
# of MedicalDecisionDiabetesModel built above.
glucose_levels = []
insulin_doses = []
rewards = []
for t in range(48):
    decision = jnp.array([2.0])  # Insulin dose
    # Three-way split kept so the random stream is unchanged; the middle key
    # is not used.
    key, _spare_key, exog_key = jax.random.split(key, 3)
    exog = model.sample_exogenous(exog_key, state, t)
    reward = model.reward(state, decision, exog)

    # Record pre-transition values, then advance the state.
    glucose_levels.append(float(state[0]))
    insulin_doses.append(float(decision[0]))
    rewards.append(float(reward))
    state = model.transition(state, decision, exog)

# Two side-by-side panels: glucose trajectory (with a dashed reference line at
# 100 labelled 'Target') and the dose applied at each step.
fig, (glucose_ax, dose_ax) = plt.subplots(1, 2, figsize=(12, 4))

glucose_ax.plot(glucose_levels)
glucose_ax.axhline(100, color='green', linestyle='--', label='Target')
glucose_ax.set_title('Glucose Level Over Time')
glucose_ax.legend()

dose_ax.plot(insulin_doses)
dose_ax.set_title('Insulin Doses')

fig.tight_layout()
plt.show()