In [None]:
# Reload edited modules (such as the local `ced` package imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Learning Sepsis Environment MDP

Because the sepsis simulator is governed by fairly complex state transition dynamics, in this notebook we approximate its MDP (transition matrix, reward matrix and initial state distribution) by sampling. The approach is similar to the one used by M. Oberst and D. Sontag in “[Counterfactual Off-Policy Evaluation with Gumbel-Max Structural Causal Models](https://arxiv.org/abs/1905.05824)”.


In [2]:
import tqdm
import pickle
import itertools
import numpy as np

In [1]:
from pathlib import Path
from typing import Optional, Tuple

In [17]:
from ced.actors.sepsis import SepsisAction
from ced.envs.sepsis import Sepsis, State

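To configure our approximation, we introduce a few variables:
- `SEED`: random seed used for reproducibility
- `NUM_SAMPLES`: number of samples drawn from the simulator to approximate each (state, action) transition distribution, and the initial state distribution
- `NUM_ACTIONS`: total number of the agent's actions
- `NUM_STATES`: total number of environment states
- `SAVE_PATH`: path to the directory where the resulting matrices will be saved (`SAVE_PATH_ORIGINAL` and `SAVE_PATH_AI` are the cache files for the ground-truth and AI MDPs)
- `TRANSITIONS_ORIGINAL`, `TRANSITIONS_AI`: JSON files with the ground-truth and the modified (AI) transition probabilities of the simulator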

In [5]:
# sampling configuration
SEED = 5586
NUM_SAMPLES = 10_000  # samples per (state, action) pair, and for the initial state distribution, in learn_mdp
NUM_ACTIONS, NUM_STATES = SepsisAction.NUM_TOTAL, State.NUM_TOTAL

# cache files for the MDPs learned in this notebook
SAVE_PATH = Path("./results/sepsis")
SAVE_PATH_ORIGINAL = SAVE_PATH.joinpath("mdp_original.pkl")
SAVE_PATH_AI = SAVE_PATH.joinpath("mdp_ai.pkl")

# transition-probability specifications passed to the simulator (ground truth and AI-modified)
TRANSITIONS_ORIGINAL = Path("./assets/sepsis/sepsis_transition_probs_original.json")
TRANSITIONS_AI = Path("./assets/sepsis/sepsis_transition_probs_ai.json")

In [6]:
# create the output directory (and any missing parents); does nothing if it already exists
SAVE_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# module-level random generator, seeded with SEED for reproducibility
rng = np.random.default_rng(seed=SEED)

In [7]:
def learn_mdp(
    env: Sepsis,
    rng: Optional[np.random.Generator] = None,
    num_samples: Optional[int] = None,
) -> Tuple[np.ndarray, ...]:
    """Approximate the MDP of a Sepsis environment by Monte-Carlo sampling.

    For every (state, action) pair, ``env.step`` is sampled ``num_samples`` times and the
    empirical next-state frequencies become the transition probabilities. The initial state
    distribution is estimated the same way from ``num_samples`` calls to ``env.reset``.

    Args:
        env: environment to sample from.
        rng: random generator passed to ``env.step`` / ``env.reset``. Defaults to a fresh
            ``np.random.default_rng(SEED)``, so each call is reproducible on its own regardless
            of which other MDPs were (or were not) learned earlier in the kernel session.
        num_samples: number of samples per (state, action) pair and for the initial state
            distribution. Defaults to ``NUM_SAMPLES``.

    Returns:
        ``(transition_matrix, reward_matrix, initial_state_distribution)``:
        ``transition_matrix[a, s, s']`` is the estimated probability of moving from ``s`` to
        ``s'`` under action ``a``; ``reward_matrix[a, s, s']`` is the reward of the next state
        ``s'``; both have shape ``(NUM_ACTIONS, NUM_STATES, NUM_STATES)`` (the PyMDP Toolbox
        layout). ``initial_state_distribution`` has shape ``(NUM_STATES,)``.

    Raises:
        ValueError: if ``num_samples`` is not positive (the frequencies would be 0/0).
    """
    if rng is None:
        rng = np.random.default_rng(SEED)
    if num_samples is None:
        num_samples = NUM_SAMPLES
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    states = range(NUM_STATES)
    actions = range(NUM_ACTIONS)

    # count observed transitions s -a-> s'; iteration order (state, action, sample) is unchanged,
    # so a given generator yields the same draws as before
    transition_counts = np.zeros((NUM_ACTIONS, NUM_STATES, NUM_STATES))
    pairs = itertools.product(states, actions)
    for s_t, a_t in tqdm.tqdm(pairs, total=NUM_STATES * NUM_ACTIONS, desc="Learning MDP transition matrix"):
        for _ in range(num_samples):
            # NOTE(review): a fresh State per sample, as in the original, in case env.step
            # mutates its input — confirm before hoisting this out of the loop
            s_curr = State.from_index(s_t)
            s_next = env.step(state=s_curr, actions=[a_t], rng=rng)
            transition_counts[a_t, s_curr.index, s_next.index] += 1

    # each row holds exactly num_samples counts, so a single row normalization suffices
    transition_matrix = transition_counts / transition_counts.sum(axis=-1, keepdims=True)

    # the reward depends only on the next state s', so broadcast it along the last axis
    next_state_rewards = np.array([State.from_index(s).reward for s in states], dtype=float)
    reward_matrix = np.broadcast_to(next_state_rewards, (NUM_ACTIONS, NUM_STATES, NUM_STATES)).copy()

    # empirical initial state distribution
    initial_counts = np.zeros((NUM_STATES, ))
    for _ in tqdm.tqdm(range(num_samples), desc="Learning MDP initial state distribution"):
        initial_counts[env.reset(rng=rng).index] += 1
    initial_state_distribution = initial_counts / initial_counts.sum()

    return transition_matrix, reward_matrix, initial_state_distribution

## Ground-Truth Transition & Reward Matrices

We start by approximating the transition and reward matrices (together with the initial state distribution) of the simulator under its ground-truth transition probabilities. Throughout this project, we rely on the [PyMDP Toolbox package](https://pymdptoolbox.readthedocs.io/en/latest/index.html) for tasks such as efficiently running the policy iteration algorithm. The [format of the transition and reward matrices](https://pymdptoolbox.readthedocs.io/en/latest/api/mdp.html#mdptoolbox.mdp.MDP) expected by PyMDP Toolbox is `(A, S, S)`, where `A` is the number of actions and `S` is the number of states.

In [8]:
# Learn the ground-truth MDP only when no cached copy exists (sampling is expensive).
if not SAVE_PATH_ORIGINAL.exists():
    env = Sepsis(transition_probabilities=TRANSITIONS_ORIGINAL)
    transition_matrix, reward_matrix, initial_state_distribution = learn_mdp(env)

    # Dump to a temporary file in the same directory and then rename it into place, so an
    # interrupted dump cannot leave a truncated pickle that `exists()` would accept as a cache.
    tmp_path = SAVE_PATH_ORIGINAL.with_name(SAVE_PATH_ORIGINAL.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump({
            "transition_matrix": transition_matrix,
            "reward_matrix": reward_matrix,
            "initial_state_distribution": initial_state_distribution}, f)
    tmp_path.replace(SAVE_PATH_ORIGINAL)

In [9]:
# Always reload from the cache so the notebook state is identical whether the MDP was just
# computed or read from disk. A missing file raises instead of silently leaving stale or
# undefined matrices for the checks below.
# The pickle is written by this notebook itself; never point this at untrusted files
# (pickle.load can execute arbitrary code).
with open(SAVE_PATH_ORIGINAL, "rb") as f:
    data = pickle.load(f)
transition_matrix = data["transition_matrix"]
reward_matrix = data["reward_matrix"]
initial_state_distribution = data["initial_state_distribution"]

We perform some sanity checks to ensure the learned matrices make sense.

In [10]:
# decode every state index, then split the states by their `diabetes` flag
states = list(map(State.from_index, range(State.NUM_TOTAL)))
s_diabetic = list(filter(lambda s: s.diabetes == 1, states))
s_non_diabetic = list(filter(lambda s: s.diabetes == 0, states))

In [11]:
# ensures we cannot transition between diabetic and non-diabetic states, under any action.
# Vectorized: one fancy-indexed block per direction instead of a Python loop over all pairs.
diabetic_idx = np.array([s.index for s in s_diabetic], dtype=int)
non_diabetic_idx = np.array([s.index for s in s_non_diabetic], dtype=int)
# transition_matrix[:, rows[:, None], cols] has shape (A, len(rows), len(cols))
assert (transition_matrix[:, diabetic_idx[:, None], non_diabetic_idx] == 0).all(), \
    "found a non-zero transition from a diabetic to a non-diabetic state"
assert (transition_matrix[:, non_diabetic_idx[:, None], diabetic_idx] == 0).all(), \
    "found a non-zero transition from a non-diabetic to a diabetic state"

In [12]:
# ensures states are correctly encoded and decoded: index -> State -> index must round-trip
assert all(i == State.from_index(i).index for i in range(State.NUM_TOTAL))

In [13]:
# ensures we have proper probabilities: the PyMDP (A, S, S) layout, non-negative entries,
# and every row transition_matrix[a, s, :] summing to one
assert transition_matrix.shape == (NUM_ACTIONS, NUM_STATES, NUM_STATES), transition_matrix.shape
assert (transition_matrix >= 0).all(), "transition matrix has negative entries"
assert np.allclose(transition_matrix.sum(axis=-1), 1.0)

In [14]:
# ensures the reward matrix contains exactly the values -1, 0 and +1
# (np.unique returns the distinct values already sorted)
assert np.unique(reward_matrix).tolist() == [-1.0, 0.0, 1.0]

## AI Agent MDP

In our experiments, it is crucial that the policy of the AI is (at least partially) different from the policy of the human. To achieve this, we modify the underlying probabilities of the environment to obtain an updated MDP, which is then used to learn the AI policy.

### Modifying Probabilities

The goal of this approach is to make the AI policy **generally give higher doses of medications** than the clinician policy. To achieve this, we increase the following probabilities:

- Probability of a successful medication effect. Whenever the AI policy prescribes a medication, that medication is expected to work with a higher probability.
- Probability of deviating from normal when a medication is removed. When a patient is taken off a medication, we increase the probability that the patient's state deviates from normal (i.e., becomes lower or higher).

Jointly, these two changes should incentivise the AI policy to prescribe more medications overall than the human policy. The modified probabilities are stored in the `assets/sepsis/sepsis_transition_probs_ai.json` file.

In [15]:
# Learn the AI MDP (modified transition probabilities) only when no cached copy exists.
if not SAVE_PATH_AI.exists():
    env_ai = Sepsis(transition_probabilities=TRANSITIONS_AI)
    # Distinct names: reusing `transition_matrix` / `reward_matrix` here would silently replace
    # the ground-truth matrices loaded above, but only when the cache is missing.
    transition_matrix_ai, reward_matrix_ai, initial_state_distribution_ai = learn_mdp(env_ai)

    # Dump to a temporary file in the same directory and then rename it into place, so an
    # interrupted dump cannot leave a truncated pickle that `exists()` would accept as a cache.
    tmp_path = SAVE_PATH_AI.with_name(SAVE_PATH_AI.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump({
            "transition_matrix": transition_matrix_ai,
            "reward_matrix": reward_matrix_ai,
            "initial_state_distribution": initial_state_distribution_ai}, f)
    tmp_path.replace(SAVE_PATH_AI)