# Discharge Decision-Making Model Training & Testing

In [None]:
import sys
import os
# Make the parent of the current working directory (the project root)
# importable, so the `ConMedRL` package imported below can be found.
# The membership check keeps the cell idempotent: re-running it in a live
# kernel does not stack duplicate entries on sys.path.
_project_root = os.path.abspath('..')
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [None]:
from ConMedRL.conmedrl import *
from ConMedRL.data_loader import *

In [None]:
# Create the configurator; its `.config` attribute is what the data loaders
# and the trainer further down receive as `cfg`.
dm_configuration = RLConfigurator()

In [None]:
# Choose how the configuration gets populated.
# NOTE(review): presumably interactive (asks which method to use) — confirm
# against the RLConfigurator implementation.
dm_configuration.choose_config_method()

In [None]:
# Sanity check: echo the `memory_capacity` value held by the active config
# (the notebook displays the cell's last expression).
dm_configuration.config.memory_capacity

In [None]:
# Training split: outcome table followed by its state-variable table.
outcome_table_train, state_var_table = map(pd.read_csv, (
    'discharge_sample_outcome_table_train.csv',
    'discharge_sample_state_var_table_train.csv',
))

# Validation outcome tables: the full table and its `_select` counterpart.
outcome_table_val, outcome_table_val_select = map(pd.read_csv, (
    'discharge_sample_outcome_table_val.csv',
    'discharge_sample_outcome_table_val_select.csv',
))

# Validation state-variable tables, in the same full / `_select` pairing.
state_var_table_val, state_var_table_val_select = map(pd.read_csv, (
    'discharge_sample_state_var_table_val.csv',
    'discharge_sample_state_var_table_val_select.csv',
))

In [None]:
# All-zeros vector with one entry per column of `state_var_table`; handed to
# both data loaders below as `terminal_state`.
# NOTE(review): presumably the placeholder state used once an episode has
# ended — confirm in ConMedRL.data_loader.
terminal_state = np.zeros(state_var_table.shape[1])

In [None]:
# Echo the terminal-state length; it equals the number of columns in
# `state_var_table`.
len(terminal_state)

In [None]:
# Inspect the training outcome table's columns.
# NOTE(review): presumably where the 'discharge_action' column used as
# `action_name` below lives — verify against the CSV.
outcome_table_train.columns

In [None]:
# Inspect the columns of the training state-variable table; their count sets
# both the `terminal_state` length and the `state_dim` passed to the trainer.
state_var_table.columns

In [None]:
# Training-side loader: combines the shared config, the training outcome and
# state-variable tables, and the all-zeros terminal state.
train_data_loader = TrainDataLoader(
    cfg=dm_configuration.config,
    outcome_table=outcome_table_train,
    state_var_table=state_var_table,
    terminal_state=terminal_state,
)

In [None]:
# Fill the training buffer, using 'discharge_action' as the action column and
# num_constraint=2.
# done_condition=None means the action itself serves as the done condition;
# this is for illustration only.
train_data_loader.data_buffer_train(
    action_name='discharge_action',
    done_condition=None,
    num_constraint=2,
)

In [None]:
# Validation-side loader: takes both the `_select` tables and the full
# validation tables, plus the same config and terminal state as training.
val_data_loader = ValTestDataLoader(
    cfg=dm_configuration.config,
    outcome_table_select=outcome_table_val_select,
    state_var_table_select=state_var_table_val_select,
    outcome_table=outcome_table_val,
    state_var_table=state_var_table_val,
    terminal_state=terminal_state,
)

In [None]:
# Fill the validation buffer with the same settings as the training buffer.
# done_condition=None means the action itself serves as the done condition;
# this is for illustration only.
val_data_loader.data_buffer(
    action_name='discharge_action',
    done_condition=None,
    num_constraint=2,
)

In [None]:
# Trainer setup: `state_dim` is the number of state-variable columns (the same
# length as `terminal_state`), `action_dim` is fixed at 2, and the torch
# loaders come from the two data-loader objects built above.
ocrl_training = RLTraining(
    cfg=dm_configuration.config,
    state_dim=state_var_table.shape[1],
    action_dim=2,
    train_data_loader=train_data_loader.data_torch_loader_train,
    val_data_loader=val_data_loader.data_torch_loader,
)

In [None]:
# FQI agent: two hidden layers of 128 units, no weight decay, fixed seed.
fqi_agent = ocrl_training.fqi_agent_config(
    hidden_layers=[128, 128],
    weight_decay=None,
    seed=1,
)

# FQE agents: one for the objective ('obj') and one per constraint index
# (0 and 1). All three evaluate the FQI agent and share the same settings
# (a single 1000-unit hidden layer, no weight decay, seed 1); only
# `eval_target` differs, so they are built in one pass in that order.
fqe_agent_obj, fqe_agent_con_0, fqe_agent_con_1 = (
    ocrl_training.fqe_agent_config(
        eval_agent=fqi_agent,
        hidden_layers=[1000],
        weight_decay=None,
        eval_target=target,
        seed=1,
    )
    for target in ('obj', 0, 1)
)

In [None]:
# Run training with the FQI agent, the objective FQE agent and the two
# constraint FQE agents (list order matches constraint indices 0 and 1).
# NOTE(review): z_value=1.96 is the two-sided 95% normal quantile and
# save_num=100 presumably controls checkpoint/save cadence — confirm both
# against RLTraining.train.
constraint_fqe_agents = [fqe_agent_con_0, fqe_agent_con_1]
ocrl_training.train(
    agent_fqi=fqi_agent,
    agent_fqe_obj=fqe_agent_obj,
    agent_fqe_con_list=constraint_fqe_agents,
    constraint=True,
    save_num=100,
    z_value=1.96,
)