## Multi-State Model Data Preparation
### Romen Samuel Wabina, MSc

In [150]:
import pandas as pd
import numpy as np
import warnings 
warnings.filterwarnings('ignore')

np.random.seed(0)

### Consider a recurrent multi-state model with five states and seven transitions. Assume patients can have at most three hospitalizations


This code is designed to generate synthetic datasets representing patient transition data in a multi-state model. It simulates the pathways patients take through various health states, such as hospitalization or death, over time. The model comprises five states, and each patient can move between these states in accordance with a set of allowed transitions.

Key components of the code include:

Dataset and pathway generation: Functions such as <code>generate_pathway</code> and <code>apply_censoring</code> create pathways for each patient, marking transitions as observed or censored based on whether a state was reached. The <code>apply_censoring</code> function also modifies transition times, ensuring that if a state is not observed, it is marked as censored and reflects the last observed state's time.

Pathway Analysis: The function <code>extract_max_state</code> extracts the maximum state reached by each patient from their pathway, while <code>row_count</code> calculates the expected number of rows based on predefined pathways. This is essential for organizing the dataset according to specific transition rules.

In [154]:
num_patients = 100 # Let's only generate synthetic dataset for 100 patients 
num_states = 5

patient_ids = np.arange(1, num_patients + 1)
state_columns = ['time_state1', 'status_state1', 
                 'time_state2', 'status_state2', 
                 'time_state3', 'status_state3', 
                 'time_state4', 'status_state4', 
                 'time_state5', 'status_state5']

def generate_pathway(row):
    pathway = ['1']  # Start with the initial state
    for i in range(2, num_states + 1):
        if pd.notna(row[f'time_state{i}']) and row[f'status_state{i}'] == 1:
            pathway.append(str(i))  # Add observed states to pathway
    return ' -> '.join(pathway)

def apply_censoring(row):
    # Find the last observed time and state
    last_time = np.nan
    for i in range(2, num_states + 1):  # Start from state 2 as state 1 is always observed with time 0
        if pd.notna(row[f'time_state{i}']) and row[f'status_state{i}'] == 1:
            last_time = row[f'time_state{i}']
        else:
            # Censor this state's time based on the last observed time
            row[f'time_state{i}'] = last_time
            row[f'status_state{i}'] = 0  # Set as censored
    return row

def extract_max_state(row):
    return max(int(state) for state in row['pathway'].split(" -> "))

def row_count(pathway):
    if pathway == "1 -> 2 -> 3 -> 4":
        return 7
    elif pathway == "1 -> 2":
        return 4
    elif pathway == "1 -> 2 -> 3 -> 4 -> 5":
        return 7
    elif pathway == "1 -> 2 -> 3":
        return 6
    else:
        return None
    
def adjust_tstop(df):
    df = df.sort_values(by = ['PID', 'FROM'])
    for i in range(len(df)):
        row = df.iloc[i]        
        if row['status'] == 0:
            pid = row['PID']
            from_state = row['FROM']
            to_state = row['TO']
            
            if from_state == 1 and to_state == 5:
                tstop_ref = df[(df['PID'] == pid) & (df['FROM'] == 1) & (df['TO'] == 2)]['tstop']
            elif from_state == 1 and to_state == 2:
                tstop_ref = df[(df['PID'] == pid) & (df['FROM'] == 1) & (df['TO'] == 5)]['tstop']
            elif from_state == 2 and to_state == 5:
                tstop_ref = df[(df['PID'] == pid) & (df['FROM'] == 2) & (df['TO'] == 3)]['tstop']
            elif from_state == 2 and to_state == 3:
                tstop_ref = df[(df['PID'] == pid) & (df['FROM'] == 2) & (df['TO'] == 5)]['tstop']
            elif from_state == 3 and to_state == 5:
                tstop_ref = df[(df['PID'] == pid) & (df['FROM'] == 3) & (df['TO'] == 4)]['tstop']
            elif from_state == 3 and to_state == 4:
                tstop_ref = df[(df['PID'] == pid) & (df['FROM'] == 3) & (df['TO'] == 5)]['tstop']
            else:
                tstop_ref = pd.Series()
            
            if not tstop_ref.empty and pd.notna(tstop_ref.values[0]):
                df.at[i, 'tstop'] = tstop_ref.values[0]
            else:
                df.at[i, 'tstop'] = row['tstart']
    
    df = df.dropna()
    return df

#### Let's create a synthetic dataset for our recurrent multi-state model. 

In [155]:
data = {col: [] for col in state_columns}
data['patient_id'] = patient_ids

# transition matrix
allowed_transitions = {
    1: [2, 5],    # State 1 can go to State 2 or State 5
    2: [3, 5],    # State 2 can go to State 3 or State 5
    3: [4, 5],    # State 3 can go to State 4 or State 5
    4: [5],       # State 4 can go to State 5
}

for patient in patient_ids:
    # Define the initial time and transition status
    times = [0]     # Start time for state 1 is always 0
    statuses = [1]  # Start state 1 as observed for all patients

    # Transition pathway for each patient
    pathway = ['1']
    current_state = 1  # Start from state 1

    # Generate observed transitions
    while current_state < num_states:
        if current_state == 1:
            next_state = np.random.choice([2, 5], p=[0.6, 0.4])
        elif current_state == 2:
            next_state = np.random.choice([3, 5], p=[0.7, 0.3])
        elif current_state == 3:
            next_state = np.random.choice([4, 5], p=[0.7, 0.3])
        elif current_state == 4:
            next_state = 5

        # If transitioned to the next state, update time and status
        time_to_next = times[-1] + np.random.randint(10, 100)  
        if time_to_next > 400:  # Cap the maximum time at 400
            break  # Stop if maximum time reached or exceeded
        
        times.append(time_to_next)
        statuses.append(1)  # Mark this as observed
        pathway.append(str(next_state))
        current_state = next_state
        if current_state == 5:
            break  # End at death

    # Fill the data dictionary for this patient
    last_observed_time = times[-1]  # Store the last observed time
    if len(times) < num_states:
        # Calculate a consistent censored time for all remaining states
        censored_time = min(last_observed_time + np.random.randint(10, 30), 400)  

    for i in range(num_states):
        if i < len(times):  # If the state was reached, use the observed times and status
            data[f'time_state{i+1}']  .append(times[i])
            data[f'status_state{i+1}'].append(statuses[i])
        else:
            # Apply the consistent censored time to all unobserved states
            data[f'time_state{i+1}'].append(censored_time)
            data[f'status_state{i+1}'].append(0)  # Mark as censored


wide_format = pd.DataFrame(data)
wide_format['pathway'] = wide_format.apply(generate_pathway, axis=1)
wide_format['time_state1'] = 0
wide_format = wide_format[['patient_id',
                           'time_state1', 'status_state1', 
                           'time_state2', 'status_state2',
                           'time_state3', 'status_state3', 
                           'time_state4', 'status_state4',
                           'time_state5', 'status_state5',
                           'pathway']]

In [162]:
transitions = {
    1: [2, 5],   # state 1 can transition to state 2 or state 5
    2: [3, 5],   # state 2 can transition to state 3 or state 5
    3: [4, 5],   # state 3 can transition to state 4 or state 5
    4: [5]       # state 4 can only transition to state 5
}

long_format = []
for idx, row in wide_format.iterrows():
    pid = row['patient_id']
    pathway = row['pathway']

    for from_state in transitions.keys():
        # Get the start time for the current state
        tstart = row[f"time_state{from_state}"]
        
        # Ensure the start time exists (patient reached this state)
        if pd.notna(tstart):
            # Loop over possible 'to' states from the current 'from_state'
            for to_state in transitions[from_state]:
                # Get the stop time and status for the 'to_state'
                tstop = row[f"time_state{to_state}"]
                status = row[f"status_state{to_state}"] if pd.notna(tstop) else 0

                # Ensure the tstop is meaningful; if not observed, use tstart as tstop for censoring
                tstop = tstop if pd.notna(tstop) else tstart
                
                # Add a row to the long format data including the pathway
                long_format.append({
                    "PID": pid,
                    "FROM": from_state,
                    "TO": to_state,
                    "status": status,
                    "tstart": tstart,
                    "tstop": tstop,
                    "pathway": pathway
                })

# Convert the revised list to a DataFrame
long_format = pd.DataFrame(long_format)
long_format['max_state'] = long_format.apply(extract_max_state, axis=1)
long_format = long_format[long_format['FROM'] <= long_format['max_state']]
long_format = long_format.drop(columns = ['max_state'])
long_format = adjust_tstop(long_format)
long_format

Unnamed: 0,PID,FROM,TO,status,tstart,tstop,pathway
0,1.0,1.0,2.0,1.0,0.0,12.0,1 -> 2
1,1.0,1.0,5.0,0.0,0.0,12.0,1 -> 2
2,1.0,2.0,3.0,0.0,12.0,33.0,1 -> 2
3,1.0,2.0,5.0,0.0,12.0,33.0,1 -> 2
7,2.0,1.0,2.0,1.0,0.0,127.0,1 -> 2 -> 3
...,...,...,...,...,...,...,...
695,100.0,2.0,3.0,1.0,24.0,36.0,1 -> 2 -> 3 -> 4
696,100.0,2.0,5.0,0.0,24.0,155.0,1 -> 2 -> 3 -> 4
697,100.0,3.0,4.0,1.0,36.0,131.0,1 -> 2 -> 3 -> 4
698,100.0,3.0,5.0,0.0,36.0,155.0,1 -> 2 -> 3 -> 4


Patient 1 went from state 1 to state 2 only. From state 1, patient 1 us at risk to state 2 (hospitalization 1) or state 5 (death). At time = 12, patient 1 entered state 2 which is at risk of entering to state 3 (hospitalization 2) or state 5 (death).

In [161]:
long_format[long_format['PID'] == 1]

Unnamed: 0,PID,FROM,TO,status,tstart,tstop,pathway
0,1.0,1.0,2.0,1.0,0.0,12.0,1 -> 2
1,1.0,1.0,5.0,0.0,0.0,12.0,1 -> 2
2,1.0,2.0,3.0,0.0,12.0,33.0,1 -> 2
3,1.0,2.0,5.0,0.0,12.0,33.0,1 -> 2


In [163]:
long_format['expected_rows'] = long_format['pathway'].apply(row_count)

actual_row_counts = long_format.groupby('PID').size().reset_index(name='actual_rows')
expected_vs_actual_counts = pd.merge(
    long_format[['PID', 'expected_rows']].drop_duplicates(),
    actual_row_counts,
    on = 'PID')

In [149]:
expected_vs_actual_counts['expected_rows'].sum()

56269

In [147]:
expected_vs_actual_counts['actual_rows'].sum()

56269

Always put the multi-state model diagram or transition matrix aside. Think about competing risk. What may happen next for a specific state? What are the risk sets?

Thorough data preparation is the foundation of accurate multi-state modeling 
- Keep risk sets in mind
- Anticipate transitions 
- Double-check your data for accuracy

Be meticulous! Details matter – small inaccuracies can cause model misfits
