# Clinical Trials - Adaptive Dose Optimization

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/stochastic-optimization/blob/master/notebooks/clinical_trials.ipynb)

## Problem Overview

This notebook studies adaptive dose optimization in a clinical trial, where each dosing decision must balance:
- **Efficacy**: Finding the right dose for patient outcomes
- **Safety**: Avoiding adverse effects from incorrect dosing
- **Learning**: Adapting based on patient responses

## Mathematical Formulation

### State Space
$$s_t = (t, x_t)$$
where $t$ is the time step and $x_t \in \mathbb{R}$ is the patient health metric.

### Dynamics
$$x_{t+1} = x_t + a_t + \epsilon_t$$
where:
- $a_t \in \mathbb{R}$ is the dose decision
- $\epsilon_t \sim \mathcal{N}(\mu, \sigma^2)$ is the stochastic patient response
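
As a minimal sketch of the transition above, written in plain JAX rather than with the repository's `ClinicalTrialsModel` (whose internals are not shown here), a single step could look like the following. The function name `transition` and the default values of `mu` and `sigma` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def transition(x, a, key, mu=0.1, sigma=0.5):
    """One step of x_{t+1} = x_t + a_t + eps_t with eps_t ~ N(mu, sigma^2).

    Hypothetical helper for illustration; not part of the repository's API.
    """
    # Draw a standard normal sample, then shift and scale it to N(mu, sigma^2).
    eps = mu + sigma * jax.random.normal(key)
    return x + a + eps


key = jax.random.PRNGKey(0)
x_next = transition(jnp.asarray(1.0), jnp.asarray(-0.5), key)
print(float(x_next))
```

Passing the key explicitly keeps the step a pure function, which matches how JAX handles randomness and makes the step easy to `jit` or `vmap` later.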

### Reward
$$r_t = -|x_t|$$
The reward penalizes deviation from the healthy state, $x = 0$.

### Objective
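$$\max_{\pi} \mathbb{E}\left[\sum_{t=0}^{T-1} -|x_t|\right]$$
where $\pi$ is the dosing policy that maps the state $s_t$ to a dose $a_t$, and $T$ is the trial horizon.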
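
The expectation in this objective can be approximated by Monte Carlo: simulate many trajectories under a fixed policy and average their summed rewards. The sketch below assumes a linear feedback policy $a_t = w\,x_t$, which is an assumption about the form of the policy and not the repository's implementation. The names `rollout_return` and `estimate_objective`, the initial state `x0`, and the number of rollouts are likewise hypothetical.

```python
import jax
import jax.numpy as jnp


def rollout_return(key, w, x0=0.0, horizon=50, mu=0.1, sigma=0.5):
    """Sum of rewards -|x_t| for t = 0..horizon-1 under the assumed policy a_t = w * x_t."""

    def step(x, eps_key):
        reward = -jnp.abs(x)                            # r_t = -|x_t|
        eps = mu + sigma * jax.random.normal(eps_key)   # eps_t ~ N(mu, sigma^2)
        x_next = x + w * x + eps                        # x_{t+1} = x_t + a_t + eps_t
        return x_next, reward

    # One independent key per time step.
    keys = jax.random.split(key, horizon)
    _, rewards = jax.lax.scan(step, jnp.asarray(x0, dtype=jnp.float32), keys)
    return rewards.sum()


def estimate_objective(key, w, n_rollouts=1000):
    """Average the return over independent rollouts (Monte Carlo estimate)."""
    keys = jax.random.split(key, n_rollouts)
    returns = jax.vmap(lambda k: rollout_return(k, w))(keys)
    return returns.mean()


key = jax.random.PRNGKey(0)
for w in (-0.25, -0.5, -1.0):
    print(w, float(estimate_objective(key, w)))
```

Under this assumed policy the closed loop is $x_{t+1} = (1 + w)\,x_t + \epsilon_t$, which is stable for $-2 < w < 0$ and has a long-run mean of $-\mu / w$. Comparing estimates across a few values of $w$ is a simple way to see how the feedback gain trades off correction speed against the drift $\mu$.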

In [None]:
# Setup
!pip install -q jax jaxlib jaxtyping chex flax matplotlib
import os
if 'COLAB_GPU' in os.environ or not os.path.exists('problems'):
    !git clone https://github.com/pedronahum/stochastic-optimization.git
    os.chdir('stochastic-optimization')

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from problems.clinical_trials import ClinicalTrialsConfig, ClinicalTrialsModel, LinearDosePolicy

# Create model
config = ClinicalTrialsConfig(horizon=50, mu=0.1, sigma=0.5)
model = ClinicalTrialsModel(config)
policy = LinearDosePolicy(weight=-0.5)

print("✓ Setup complete")

In [None]:
# Run simulation
key = jax.random.PRNGKey(42)
state = model.reset(key=key)
states, actions, rewards = [], [], []

# Simulate one full trial; assumes the config exposes the `horizon` it was built with
for t in range(config.horizon):
    key, subkey = jax.random.split(key)
    action = policy.act(state, key=subkey)
    
    key, subkey = jax.random.split(key)
    new_state, reward = model.step(state, action, key=subkey)
    
    states.append(float(state.x))
    actions.append(float(action))
    rewards.append(float(reward))
    state = new_state

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(states, 'o-')
axes[0].axhline(0, color='red', linestyle='--', alpha=0.3)
axes[0].set_title('Health Metric Over Time')
axes[0].set_xlabel('Time Step')
axes[0].set_ylabel('Health State (x)')

axes[1].plot(actions, 's-', color='green')
axes[1].axhline(0, color='red', linestyle='--', alpha=0.3)
axes[1].set_title('Dose Decisions')
axes[1].set_xlabel('Time Step')
axes[1].set_ylabel('Dose (action)')

axes[2].plot(rewards, '^-', color='purple')
axes[2].set_title('Rewards (Negative Deviation)')
axes[2].set_xlabel('Time Step')
axes[2].set_ylabel('Reward')

plt.tight_layout()
plt.show()

print(f"Total reward: {sum(rewards):.2f}")
print(f"Final health state: {states[-1]:.3f}")