In [21]:
import numpy as np
import random
from math import sqrt
import pandas as pd

In [22]:
#read probability matrix from part 5
ts_df = pd.read_csv("./dataset/generated_data/ts_df.csv")
ts_df = ts_df.iloc[:,1:]

ts_df = ts_df.rename(columns={'action': 'action_0', 'action_2': 'action'})

In [23]:
def divide_CV(input, CV):
    data_list = list(input)
    random.shuffle(data_list)

    subset_size = len(data_list) // CV
    remainder = len(data_list) % CV

    subsets = []
    start_idx = 0
    for i in range(CV):
        end_idx = start_idx + subset_size + (1 if i < remainder else 0)
        subsets.append(data_list[start_idx:end_idx])
        start_idx = end_idx
    return subsets

#define the reward function
def Reward(a, b, mode):
    R = np.zeros(12).reshape(4,3)
    if mode == "reward":
        R[0,:] += 1
    else:
        R[3,:] -= 1
    R[:,1] -= a
    R[:,2] -= b

    return R

In [24]:
# SARSA implementation
def sarsa(train_data, mode, n_states=4, n_actions=3, alpha=0.1, gamma=0.99, epsilon=0.1, iterations=100):
    Q = np.zeros((n_states, n_actions))
    
    for _ in range(iterations):
        for pat_id in train_data['PATNO'].unique():
            pat_seq = train_data[train_data['PATNO'] == pat_id]
            states = list(pat_seq['cluster'])
            actions = list(pat_seq['action'])
            next_states = list(pat_seq['cluster_n'])

            if len(states) < 2:
                continue

            s = states[0]
            a = actions[0] if random.random() > epsilon else random.randint(0, n_actions-1)

            for t in range(1, len(states)):
                s_ = states[t]
                a_ = actions[t] if random.random() > epsilon else random.randint(0, n_actions-1)
                if pd.isna(a_): continue
                
                r = 0

                s = int(s)
                a = int(a)
                s_ = int(s_)
                a_ = int(a_)

                if mode == "reward":
                    if s_ == 0: r += 1
                else:
                    if s_ == 3: r -= 1
                    
                if a == 1: r -= 0.01
                elif a == 2: r -= 0.025

                Q[s, a] += alpha * (r + gamma * Q[s_, a_] - Q[s, a])
                s, a = s_, a_
    
    return Q

In [25]:

# Simulation setup
CV = 10
patients = set(ts_df['PATNO'])
subsets = divide_CV(patients, CV)

tot_sim, tot_const, tot_random, tot_sim_w, tot_real = [[],[]], [[],[]], [[],[]], [[],[]], [[],[]]
LEDD_total = {"optimal":[[],[]], "random":[[],[]], "const":[[],[]], "worst":[[],[]], "real":[[],[]]}
tests = ["optimal", "random", "const", "worst", "real"]

for i in range(CV+1):
    if i < CV:
        valid_idx = subsets[i]
        train_idx = [p for j in range(CV) if j != i for p in subsets[j]]
        valid = ts_df[ts_df['PATNO'].isin(valid_idx)]
        train = ts_df[ts_df['PATNO'].isin(train_idx)]
    else:
        valid = ts_df
        train = ts_df

    # Train SARSA policies
    Q_r = sarsa(train, "reward")
    Q_p = sarsa(train, "penalty")

    policy = {
        "r": np.argmax(Q_r, axis=1),
        "p": np.argmax(Q_p, axis=1),
        "r_w": np.argmin(Q_r, axis=1),
        "p_w": np.argmin(Q_p, axis=1),
    }

    sim_reward, random_reward, const_reward, sim_reward_w, real_reward = [0,0], [0,0], [0,0], [0,0], [0,0]
    tot_n = 0

    for pat in valid['PATNO'].unique():
        pat_d = valid[valid['PATNO'] == pat]
        n = len(pat_d)
        init_LEDD = pat_d['LEDD'].iloc[0]

        LEDD = {t: [init_LEDD, init_LEDD] for t in tests}
        LEDD_traj = {t: [init_LEDD, init_LEDD] for t in tests}
        state = [pat_d['cluster'].iloc[0], pat_d['cluster'].iloc[0]]
        LEDD_total["real"].append(pat_d["LEDD"].mean())

        for t in tests:
            if state[0] == 0:
                if t == "optimal": sim_reward[0] += 1
                elif t == "const": const_reward[0] += 1
                elif t == "worst": sim_reward_w[0] += 1
                elif t == "real": real_reward[0] += 1
                else: random_reward[0] += 1
            elif state[1] == 3:
                if t == "optimal": sim_reward[1] -= 1
                elif t == "const": const_reward[1] -= 1
                elif t == "worst": sim_reward_w[1] -= 1
                elif t == "real": real_reward[1] -= 1
                else: random_reward[1] -= 1

            for nn in range(n - 1):
                if t == "optimal":
                    action = [policy["r"][state[0]], policy["p"][state[1]]]
                elif t == "const":
                    action = [0, 0]
                elif t == "worst":
                    action = [policy["r_w"][state[1]], policy["p_w"][state[1]]]
                elif t == "real":
                    action = [pat_d["action"].iloc[nn], pat_d["action"].iloc[nn]]
                else:
                    action = [random.choice(range(3)), random.choice(range(3))]

                for k in range(2):
                    if action[k] == 1:
                        LEDD[t][k] += 35
                    elif action[k] == 2:
                        LEDD[t][k] += 180
                    else:
                        if t != "const":
                            LEDD[t][k] -= 28
                    LEDD_traj[t][k] += LEDD[t][k]

                state_ = [random.choice(range(4)) for _ in range(2)]  # simulate random transitions

                for k in range(2):
                    if action[k] == 1:
                        penalty = 0.01
                    elif action[k] == 2:
                        penalty = 0.025
                    else:
                        penalty = 0
                    if t == "optimal": sim_reward[k] -= penalty
                    elif t == "const": const_reward[k] -= penalty
                    elif t == "worst": sim_reward_w[k] -= penalty
                    elif t == "real": real_reward[k] -= penalty
                    else: random_reward[k] -= penalty

                if state_[0] == 0:
                    if t == "optimal": sim_reward[0] += 1
                    elif t == "const": const_reward[0] += 1
                    elif t == "worst": sim_reward_w[0] += 1
                    elif t == "real": real_reward[0] += 0.025
                    else: random_reward[0] += 1
                elif state_[1] == 3:
                    if t == "optimal": sim_reward[1] -= 1
                    elif t == "const": const_reward[1] -= 1
                    elif t == "worst": sim_reward_w[1] -= 1
                    elif t == "real": real_reward[1] -= 1
                    else: random_reward[1] -= 1

                state = state_

            for k in range(2):
                LEDD_traj[t][k] /= n
                LEDD_total[t][k].append(LEDD_traj[t][k])

        tot_n += n

    sim_reward = [s / tot_n for s in sim_reward]
    sim_reward_w = [s / tot_n for s in sim_reward_w]
    const_reward = [s / tot_n for s in const_reward]
    random_reward = [s / tot_n for s in random_reward]
    real_rewards = [s / tot_n for s in real_reward]

    for kk in range(2):
        tot_sim[kk].append(sim_reward[kk])
        tot_const[kk].append(const_reward[kk])
        tot_random[kk].append(random_reward[kk])
        tot_sim_w[kk].append(sim_reward_w[kk])
        tot_real[kk].append(real_rewards[kk])

for t in tests:
    LEDD_total[t][0] = np.mean(LEDD_total[t][0])
    LEDD_total[t][1] = np.mean(LEDD_total[t][1])

modes = ["reward", "penalty"]
for aa in range(2):
    print(modes[aa])
    print("opti",  "\t", np.round(np.mean(tot_sim[aa]),4), "\t", np.round(np.std(tot_sim[aa])/sqrt(CV),4))
    print("worst", "\t", np.round(np.mean(tot_sim_w[aa]),4), "\t", np.round(np.std(tot_sim_w[aa])/sqrt(CV),4))
    print("const",  "\t", np.round(np.mean(tot_const[aa]),4),  "\t", np.round(np.std(tot_const[aa])/sqrt(CV),4))
    print("random",  "\t", np.round(np.mean(tot_random[aa]),4), "\t", np.round(np.std(tot_random[aa])/sqrt(CV),4))
    print("real \t", np.round(np.mean(tot_real[aa]),4))

reward
opti 	 0.2592 	 0.0062
worst 	 0.2481 	 0.004
const 	 0.2482 	 0.0041
random 	 0.2448 	 0.0071
real 	 0.0368
penalty
opti 	 -0.1737 	 0.005
worst 	 -0.2133 	 0.0059
const 	 -0.1873 	 0.0044
random 	 -0.2056 	 0.0071
real 	 -0.1895
