# OCRL Agent Set-up & Experiment

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable

import numpy as np
import pandas as pd

import shap

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import math
import os
import copy
import random
from pathlib import Path
import sys
import datetime
from tqdm import tqdm
from collections import Counter

In [2]:
# torch.cuda.is_available()
# torch.cuda.device_count()

In [3]:
# torch.cuda.get_device_name(0)

In [4]:
# Report PyTorch's current intra-op and inter-op CPU thread-pool sizes.
print("torch.get_num_threads():", torch.get_num_threads())
print("torch.get_num_interop_threads():", torch.get_num_interop_threads())

torch.get_num_threads(): 128
torch.get_num_interop_threads(): 128


In [5]:
# device_id = 0 
# props = torch.cuda.get_device_properties(device_id)
# total_memory = props.total_memory  

# print("Total GPU memory (bytes):", total_memory)
# print("Total GPU memory (GB):", total_memory / 1024**3)

In [6]:
# Release cached GPU memory held by PyTorch's CUDA allocator (only has an effect once CUDA is initialised).
torch.cuda.empty_cache()

In [7]:
# CUDA version this PyTorch build was compiled against (None for CPU-only builds).
print(torch.version.cuda)

11.8


### Set-up OCRL Agent

In [8]:
curr_path = str(Path().absolute())  # current working directory
parent_path = str(Path().absolute().parent)  # parent of the current working directory
sys.path.append(parent_path) # make modules in the PARENT directory importable
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # current time as YYYYmmdd-HHMMSS

In [9]:
class FCN_fqe(nn.Module):
    """Linear Q-value approximator used by FQE.

    Maps the last dimension of the input (state features) to one value per
    action through a single affine layer.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # The attribute name ``fc1`` determines the state_dict keys
        # ('fc1.weight', 'fc1.bias') used by saved checkpoints.
        self.fc1 = nn.Linear(state_dim, action_dim)

    def forward(self, x):
        """Return the per-action values ``fc1(x)``."""
        return self.fc1(x)

In [10]:
class FCN_fqi(nn.Module):
    """Linear Q-value approximator used by FQI.

    Maps the last dimension of the input (state features) to one value per
    action through a single affine layer.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # The attribute name ``fc1`` determines the state_dict keys
        # ('fc1.weight', 'fc1.bias') used by saved checkpoints.
        self.fc1 = nn.Linear(state_dim, action_dim)

    def forward(self, x):
        """Return the per-action values ``fc1(x)``."""
        return self.fc1(x)

In [11]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions.

    Each entry is ``(state, action, obj_cost, con_cost, next_state, done)``.
    Once ``capacity`` entries are stored, each new push overwrites the slot
    at ``position`` (the oldest one) and ``position`` wraps around.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # slot the next push writes to

    @staticmethod
    def _unzip(transitions):
        """Transpose transitions into per-field tuples.

        ``con_cost`` is transposed a second time so the result holds one list
        per constraint (each list spans all transitions).
        """
        states, actions, obj_costs, con_costs, next_states, dones = zip(*transitions)
        per_constraint = [list(column) for column in zip(*con_costs)]
        return states, actions, obj_costs, per_constraint, next_states, dones

    def push(self, state, action, obj_cost, con_cost, next_state, done):
        """Store one transition; a scalar ``con_cost`` is wrapped into a one-element list."""
        if not isinstance(con_cost, (list, tuple)):
            con_cost = [con_cost]

        # Grow the list until it is full, then reuse slots cyclically.
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)

        self.buffer[self.position] = (state, action, obj_cost, con_cost, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random (``random.sample``)."""
        return self._unzip(random.sample(self.buffer, batch_size))

    def extract(self):
        """Return every stored transition, in buffer order, in the same layout as ``sample``."""
        return self._unzip(self.buffer)

    def clear(self):
        """Drop all transitions and reset the write position."""
        self.buffer = []
        self.position = 0

    def __len__(self):
        return len(self.buffer)

In [12]:
class FQE:
    """Fitted Q-Evaluation (FQE) of the policy exposed by ``eval_agent``.

    A linear Q-estimator (``FCN_fqe``) is regressed onto one-step Bellman
    targets for a single cost signal. That signal is the objective cost when
    ``eval_target == 'obj'``; otherwise it is the constraint whose index is
    ``eval_target``.

    Args:
        cfg: configuration providing ``device``, ``gamma``, ``lr_fqe_obj``,
            ``lr_fqe_con`` (indexed by ``eval_target``) and ``loss_fqe``.
        state_dim: number of input state features.
        action_dim: number of discrete actions (Q-network outputs).
        id_stop: 0 masks the bootstrap term with ``done_batch``; any other
            value masks it with ``disch_batch`` instead (optimal-stopping
            structure).
        eval_agent: object whose ``rl_policy(states)`` returns the evaluated
            policy's actions as an index tensor usable with ``gather(dim=1)``.
        weight_decay: L2 penalty applied by SGD to ``fc1.weight``. The bias is
            never decayed.
        eval_target: ``'obj'`` or an integer constraint index.
    """

    def __init__(self, cfg, state_dim, action_dim, id_stop, eval_agent, weight_decay, eval_target = 'obj'):

        self.device = cfg.device
        self.gamma = cfg.gamma

        # Whether the optimal-stopping structure is used (see update()).
        self.id_stop = id_stop

        # The objective has its own learning rate; constraints are looked up by index.
        if eval_target == 'obj':
            self.lr_fqe = cfg.lr_fqe_obj
        else:
            self.lr_fqe = cfg.lr_fqe_con[eval_target]

        # Policy (trained) and target (slowly tracked) Q-estimators.
        self.policy_net = FCN_fqe(state_dim, action_dim).to(self.device)
        self.target_net = FCN_fqe(state_dim, action_dim).to(self.device)

        # Start the target network as an exact copy of the policy network.
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            target_param.data.copy_(param.data)

        # Previously ``weight_decay`` was stored but never given to the
        # optimizer. It is now applied to the weight only (bias undecayed), as FQI does.
        # With weight_decay == 0.0 the updates are identical to plain SGD.
        self.weight_decay = weight_decay
        self.optimizer = optim.SGD([
            {'params': self.policy_net.fc1.weight, 'weight_decay': self.weight_decay},
            {'params': self.policy_net.fc1.bias, 'weight_decay': 0.0},
        ], lr = self.lr_fqe)

        self.loss = cfg.loss_fqe

        # Agent whose policy is being evaluated.
        self.eval_agent = eval_agent

    def update(self, state_batch, action_batch, cost_batch, next_state_batch, done_batch, disch_batch):
        """Take one SGD step on the Bellman regression loss and return the scalar loss.

        The target is ``cost + gamma * Q_target(s', pi(s')) * (1 - mask)``.
        ``mask`` is ``done_batch`` when ``id_stop == 0`` and ``disch_batch``
        otherwise. Gradients are clamped element-wise to [-1, 1] before the step.
        """
        # The bootstrap target needs no gradients; computing it under no_grad
        # avoids building an autograd graph that was previously discarded via .detach().
        with torch.no_grad():
            policy_action_batch = self.eval_agent.rl_policy(next_state_batch)
            next_q_values = self.target_net(next_state_batch).gather(dim = 1, index = policy_action_batch).squeeze(1)

        # Q-values of the logged actions under the policy network.
        q_values = self.policy_net(state_batch).gather(dim = 1, index = action_batch)

        mask = done_batch if self.id_stop == 0 else disch_batch
        expected_q_values = cost_batch + self.gamma * next_q_values * (1 - mask)

        loss = self.loss(q_values, expected_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()

        for param in self.policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)

        self.optimizer.step()

        return loss.item()

    @torch.no_grad()
    def avg_Q_value_est(self, state_batch, z = 2.33):
        """Estimate the mean Q-value of the evaluated policy over ``state_batch``.

        Args:
            state_batch: states at which the evaluated policy is queried.
            z: multiplier of the standard error for the upper bound. The
               default 2.33 is about the 99th percentile of the standard normal.

        Returns:
            ``(mean, upper_bound)`` as Python floats, where
            ``upper_bound = mean + z * std / sqrt(n)``. Both values are the mean
            when there are at most one sample or the sample std is zero.
        """
        policy_action_batch = self.eval_agent.rl_policy(state_batch)
        q_values = self.policy_net(state_batch).gather(dim = 1, index = policy_action_batch).squeeze(1)

        n = q_values.shape[0]
        q_mean = q_values.mean()

        # Check n before std(): std of <= 1 sample is undefined and emits a warning.
        if n <= 1:
            return q_mean.item(), q_mean.item()

        q_std = q_values.std()
        if q_std == 0:
            return q_mean.item(), q_mean.item()

        q_upper_bound = q_mean + z * (q_std / math.sqrt(n))

        return q_mean.item(), q_upper_bound.item()

    def save(self, path):
        """Save both networks' state_dicts; ``path`` is used as a filename prefix (string-concatenated)."""
        torch.save(self.policy_net.state_dict(), path + 'FQE_policy_network.pth')
        torch.save(self.target_net.state_dict(), path + 'FQE_target_network.pth')

In [13]:
class FQI:
    """Fitted Q-Iteration (FQI) learner that minimises a Lagrangian cost.

    The per-step cost is ``obj_cost + sum_i lambda_i * con_cost_i``. Greedy
    actions minimise the Q-value (costs, not rewards). A linear Q-estimator
    (``FCN_fqi``) is trained with SGD. The next action is selected by the
    policy network and evaluated by the target network.

    Args:
        cfg: configuration providing ``device``, ``gamma``, ``lr_fqi`` and ``loss_fqi``.
        state_dim: number of input state features.
        action_dim: number of discrete actions (Q-network outputs).
        weight_decay: L2 penalty applied by SGD to ``fc1.weight`` (the bias is
            not decayed). Defaults to the previously hard-coded 1e-2.
    """

    def __init__(self, cfg, state_dim, action_dim, weight_decay = 1e-2):

        self.device = cfg.device
        self.gamma = cfg.gamma

        self.lr = cfg.lr_fqi

        self.policy_net = FCN_fqi(state_dim, action_dim).to(self.device)
        self.target_net = FCN_fqi(state_dim, action_dim).to(self.device)

        # Start the target network as an exact copy of the policy network.
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            target_param.data.copy_(param.data)

        # Decay only the weight matrix; leaving the bias undecayed avoids shrinking the value offset.
        self.weight_decay = weight_decay
        self.optimizer = optim.SGD([
            {'params': self.policy_net.fc1.weight, 'weight_decay': self.weight_decay},
            {'params': self.policy_net.fc1.bias, 'weight_decay': 0.0},
        ], lr = self.lr)

        self.loss = cfg.loss_fqi

    def update(self, lambda_t_list, state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch, disch_batch):
        """Take one SGD step on the Lagrangian Bellman loss and return the scalar loss.

        ``con_cost_batch[i]`` is weighted by ``lambda_t_list[i]``. Bootstrapping
        is masked by ``done_batch``; ``disch_batch`` is accepted for interface
        parity with FQE but is not used here.
        """
        q_values = self.policy_net(state_batch).gather(dim = 1, index = action_batch)

        # Target side: argmin action from the policy net, evaluated by the target net.
        # No gradients are needed here.
        with torch.no_grad():
            policy_action_batch = self.policy_net(next_state_batch).min(1)[1].unsqueeze(1)
            next_q_values = self.target_net(next_state_batch).gather(dim = 1, index = policy_action_batch).squeeze(1)

        sum_con_cost = 0
        for i, lambda_t in enumerate(lambda_t_list):
            sum_con_cost += lambda_t * con_cost_batch[i]

        expected_q_values = (obj_cost_batch + sum_con_cost) + self.gamma * next_q_values * (1 - done_batch)

        loss = self.loss(q_values, expected_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        return loss.item()

    @torch.no_grad()
    def avg_Q_value_est(self, state_batch):
        """Return the mean over ``state_batch`` of the minimum (greedy) Q-value, as a Python float."""
        q_values = self.policy_net(state_batch)
        return q_values.min(1)[0].mean().item()

    @torch.no_grad()
    def rl_policy(self, state_batch):
        """Return the greedy (cost-minimising) action per state as a ``(batch, 1)`` index tensor."""
        q_values = self.policy_net(state_batch)
        return q_values.min(1)[1].unsqueeze(1)

    def save(self, path):
        """Save both networks' state_dicts; ``path`` is used as a filename prefix (string-concatenated)."""
        torch.save(self.policy_net.state_dict(), path + 'Offline_FQI_policy_network.pth')
        torch.save(self.target_net.state_dict(), path + 'Offline_FQI_target_network.pth')

In [14]:
class RLConfig:
    """Hyper-parameter container for an OCRL training run.

    Args:
        algo_name: label for the run (stored as ``algo``).
        train_eps: number of training epochs.
        gamma: discount factor.
        lr_fqi: learning rate of the FQI learner.
        lr_fqe_obj: learning rate of the objective-cost FQE.
        constraint_num: number of constraints. Only the first
            ``constraint_num`` entries of each per-constraint list are used.
        lr_fqe_con_list: per-constraint FQE learning rates.
        lr_lambda_list: per-constraint dual-variable step sizes.
        threshold_list: per-constraint cost thresholds.
    """

    def __init__(self, algo_name, train_eps, gamma, lr_fqi, lr_fqe_obj, constraint_num, lr_fqe_con_list, lr_lambda_list, threshold_list):
        # Run identification and core RL settings.
        self.algo = algo_name
        self.train_eps = train_eps
        self.gamma = gamma
        self.constraint_num = constraint_num

        # Learning rates shared by all constraints.
        self.lr_fqi = lr_fqi
        self.lr_fqe_obj = lr_fqe_obj

        # Per-constraint settings (index c refers to constraint c).
        constraint_ids = range(constraint_num)
        self.lr_fqe_con = [lr_fqe_con_list[c] for c in constraint_ids]
        self.lr_lam = [lr_lambda_list[c] for c in constraint_ids]
        self.constraint_limit = [threshold_list[c] for c in constraint_ids]

        # Optimisation loop sizes.
        self.train_eps_steps = 1_000  # gradient steps per training epoch
        self.batch_size = 256

        # Regression losses for the learner and the evaluators.
        self.loss_fqi = nn.MSELoss()
        self.loss_fqe = nn.MSELoss()

        self.memory_capacity = 2_000_000  # replay-buffer capacity

        # Soft target-network update: every `target_update` steps, blend by `tau`.
        self.target_update = 100
        self.tau = 0.01

        # CPU is forced; to use a GPU when available, switch to
        # torch.device("cuda" if torch.cuda.is_available() else "cpu").
        self.device = torch.device("cpu")

In [15]:
class DataLoader:
    """Builds the training replay buffer from tabular trajectories and serves tensor minibatches.

    NOTE: this name shadows ``torch.utils.data.DataLoader`` if that was
    imported earlier in the same namespace.

    Args:
        cfg: configuration providing ``memory_capacity``, ``batch_size`` and ``device``.
        state_id_table: per-row table with 'discharge_action', 'death',
            'discharge_fail', 'mortality_costs_md' and optional 'con_cost_{j}'
            columns.
        rl_cont_state_table: unscaled state table (stored, not used here).
        rl_cont_state_table_scaled: scaled state features, row-aligned with
            ``state_id_table``. The next state of a row with index label
            ``idx`` is looked up at label ``idx + 1``.
        terminal_state: vector used as ``next_state`` for terminal transitions.
    """

    def __init__(self, cfg, state_id_table, rl_cont_state_table, rl_cont_state_table_scaled, terminal_state):
        self.cfg = cfg

        self.state_df_id = state_id_table
        self.rl_cont_state_table = rl_cont_state_table
        self.rl_cont_state_table_scaled = rl_cont_state_table_scaled

        self.terminal_state = terminal_state

    @staticmethod
    def _is_done(action, death, discharge_fail):
        """Return 1.0 if the transition ends the episode, else 0.0.

        Only discharge actions (``action == 1.0``) can terminate. When
        ``death == 1.0`` the episode ends unless ``discharge_fail == 1.0``.
        Otherwise it ends only if ``discharge_fail == 0.0``.
        """
        if action != 1.0:
            return 0.0
        if death == 1.0:
            return 0.0 if discharge_fail == 1.0 else 1.0
        return 1.0 if discharge_fail == 0.0 else 0.0

    def data_buffer_train(self, num_constraint = 2):
        """Fill ``self.train_memory`` with one transition per row of ``state_id_table``.

        Constraint cost ``j`` is read from column 'con_cost_{j}', or is 0.0
        when that column is absent.
        """
        self.train_memory = ReplayBuffer(self.cfg.memory_capacity)

        ids = self.state_df_id
        n_rows = len(ids)
        if n_rows == 0:
            return

        scaled = self.rl_cont_state_table_scaled

        # Materialise arrays once. Calling DataFrame.values per row re-converted the whole frame each iteration.
        # That was O(n^2) and made each stored state a view into its own full copy.
        states = scaled.values
        actions = ids['discharge_action'].values
        deaths = ids['death'].values
        fails = ids['discharge_fail'].values
        obj_costs = ids['mortality_costs_md'].values
        row_labels = ids.index

        con_columns = []
        for j in range(num_constraint):
            cost_col = f'con_cost_{j}'
            con_columns.append(ids[cost_col].values if cost_col in ids.columns else None)

        for i in range(n_rows):
            action = actions[i]
            done = self._is_done(action, deaths[i], fails[i])

            con_cost = [col[i] if col is not None else 0.0 for col in con_columns]

            if done == 0.0:
                next_state = scaled.loc[row_labels[i] + 1].values
            else:
                next_state = self.terminal_state

            self.train_memory.push(states[i], action, obj_costs[i], con_cost, next_state, done)

    def data_torch_loader_train(self):
        """Sample ``cfg.batch_size`` transitions and convert them to tensors on ``cfg.device``.

        Returns:
            ``(state, action, obj_cost, con_cost_list, next_state, done, disch)``.
            ``action`` is a long tensor with a trailing unit dimension.
            ``con_cost_list`` holds one float tensor per constraint. ``disch`` is
            the logged action cast to float.
        """
        state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch = self.train_memory.sample(self.cfg.batch_size)

        state_batch = torch.tensor(np.array(state_batch), device = self.cfg.device, dtype = torch.float)

        disch_batch = list(action_batch)
        action_batch = torch.tensor(np.array(action_batch), device = self.cfg.device, dtype = torch.long).unsqueeze(1)

        obj_cost_batch = torch.tensor(np.array(obj_cost_batch), device = self.cfg.device, dtype = torch.float)
        con_cost_batch = [torch.tensor(np.array(cost), device = self.cfg.device, dtype=torch.float) for cost in con_cost_batch]
        next_state_batch = torch.tensor(np.array(next_state_batch), device = self.cfg.device, dtype = torch.float)

        done_batch = torch.tensor(np.array(done_batch), device = self.cfg.device, dtype = torch.float)
        disch_batch = torch.tensor(np.array(disch_batch), device = self.cfg.device, dtype = torch.float)

        return state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch, disch_batch

In [16]:
class ValDataLoader:
    """Builds the validation replay buffer and exposes its states as one tensor.

    Args:
        cfg: configuration providing ``memory_capacity`` and ``device``.
        state_id_table: per-row table of the validation rows to encode, with
            'discharge_action', 'death', 'discharge_fail',
            'mortality_costs_md' and optional 'con_cost_{j}' columns.
        rl_cont_state_table: unscaled state table (stored, not used here).
        rl_cont_state_table_scaled: scaled states, row-aligned with ``state_id_table``.
        state_id_table_1: companion id table (stored, not used here).
        rl_cont_state_table_scaled_1: scaled state table in which next states
            are looked up at index label ``idx + 1``.
        terminal_state: vector used as ``next_state`` for terminal transitions.
    """

    def __init__(self, cfg, state_id_table, rl_cont_state_table, rl_cont_state_table_scaled, state_id_table_1, rl_cont_state_table_scaled_1, terminal_state):
        self.cfg = cfg

        self.state_df_id = state_id_table
        self.rl_cont_state_table = rl_cont_state_table
        self.rl_cont_state_table_scaled = rl_cont_state_table_scaled

        self.state_df_id_1 = state_id_table_1
        self.rl_cont_state_table_scaled_1 = rl_cont_state_table_scaled_1

        self.terminal_state = terminal_state

    @staticmethod
    def _is_done(action, death, discharge_fail):
        """Return 1.0 if the transition ends the episode, else 0.0.

        Only discharge actions (``action == 1.0``) can terminate. When
        ``death == 1.0`` the episode ends unless ``discharge_fail == 1.0``.
        Otherwise it ends only if ``discharge_fail == 0.0``.
        """
        if action != 1.0:
            return 0.0
        if death == 1.0:
            return 0.0 if discharge_fail == 1.0 else 1.0
        return 1.0 if discharge_fail == 0.0 else 0.0

    def data_buffer_val(self, num_constraint = 2):
        """Fill ``self.val_memory`` with one transition per row of ``state_id_table``.

        Constraint cost ``j`` is read from column 'con_cost_{j}', or is 0.0
        when that column is absent.
        """
        self.val_memory = ReplayBuffer(self.cfg.memory_capacity)

        ids = self.state_df_id
        n_rows = len(ids)
        if n_rows == 0:
            return

        # Materialise arrays once instead of converting whole frames per row.
        states = self.rl_cont_state_table_scaled.values
        actions = ids['discharge_action'].values
        deaths = ids['death'].values
        fails = ids['discharge_fail'].values
        obj_costs = ids['mortality_costs_md'].values
        row_labels = ids.index
        next_lookup = self.rl_cont_state_table_scaled_1

        con_columns = []
        for j in range(num_constraint):
            cost_col = f'con_cost_{j}'
            con_columns.append(ids[cost_col].values if cost_col in ids.columns else None)

        for i in range(n_rows):
            action = actions[i]
            done = self._is_done(action, deaths[i], fails[i])

            con_cost = [col[i] if col is not None else 0.0 for col in con_columns]

            if done == 0.0:
                next_state = next_lookup.loc[row_labels[i] + 1].values
            else:
                next_state = self.terminal_state

            self.val_memory.push(states[i], action, obj_costs[i], con_cost, next_state, done)

    def data_torch_loader_val(self):
        """Return every stored validation state, stacked into one float tensor on ``cfg.device``."""
        state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch = self.val_memory.extract()

        state_batch = torch.tensor(np.array(state_batch), device = self.cfg.device, dtype = torch.float)

        return state_batch

In [17]:
class RLTraining:
    """Primal-dual training loop for constrained offline RL (FQI learner + FQE evaluators).

    Each step updates the FQI learner and all FQE evaluators on a sampled
    minibatch, then records per-step model snapshots and validation estimates.
    When ``train(constraint=...)`` is not None, the dual variables are also
    updated by projected ascent on each constraint's FQE upper bound.

    Args:
        cfg: configuration (see RLConfig).
        state_dim: number of state features.
        action_dim: number of discrete actions.
        val_data_loader: zero-argument callable returning the validation state tensor.
        data_loader: zero-argument callable returning
            ``(state, action, obj_cost, con_cost_list, next_state, done, disch)``.
    """

    # Maximum number of per-step snapshots kept in each model-history list.
    history_size = 2000

    def __init__(self, cfg, state_dim, action_dim, val_data_loader, data_loader):
        self.cfg = cfg
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.val_data_loader = val_data_loader
        self.data_loader = data_loader

        # Rolling histories of model snapshots.
        self.fqi_models_history = []
        self.fqe_obj_models_history = []
        self.fqe_con_models_history = {}

    def fqi_agent_config(self, seed = 1):
        """Create an FQI agent whose initial weights are determined by ``seed``."""
        # Seed BEFORE construction. Previously the seed was set afterwards and
        # did not affect this agent's weight initialisation.
        torch.manual_seed(seed)
        return FQI(self.cfg, self.state_dim, self.action_dim)

    def fqe_agent_config(self, id_stop, eval_agent, weight_decay, eval_target, seed = 2):
        """Create an FQE agent whose initial weights are determined by ``seed``."""
        torch.manual_seed(seed)
        return FQE(self.cfg, self.state_dim, self.action_dim, id_stop, eval_agent, weight_decay, eval_target)

    def _record_snapshot(self, history, update_num, epoch, step, agent):
        """Append a deep-copied snapshot of ``agent``, dropping the oldest once ``history_size`` is reached."""
        if history and len(history) >= self.history_size:
            history.pop(0)
        history.append({
            'update_num': update_num,
            'epoch': epoch,
            'step': step,
            'model_state': self._get_model_state(agent)
        })

    def _soft_update(self, agent):
        """Blend the target net toward the policy net: target <- tau * policy + (1 - tau) * target."""
        tau = self.cfg.tau
        for target_param, policy_param in zip(agent.target_net.parameters(), agent.policy_net.parameters()):
            target_param.data.copy_(tau * policy_param.data + (1 - tau) * target_param.data)

    def train(self, agent_fqi, agent_fqe_obj, agent_fqe_con_list, constraint = None):
        """Run ``cfg.train_eps`` epochs of ``cfg.train_eps_steps`` steps.

        Args:
            agent_fqi: learner, trained on ``obj_cost + sum_m lambda_m * con_cost_m``.
            agent_fqe_obj: evaluator of the objective cost.
            agent_fqe_con_list: one evaluator per constraint cost (same order as con_cost_batch).
            constraint: None trains unconstrained (all dual variables stay 0).
                Any other value enables dual-variable updates.

        Returns:
            ``(FQI_loss, FQE_loss_obj, FQE_loss_con, FQI_est_values,
            FQE_est_obj_costs, FQE_est_con_costs, lambda_dict)``. Each entry holds
            per-epoch means; the dicts are keyed by constraint index.
        """
        print('Start to train!')
        print(f'Algorithm:{self.cfg.algo}, Device:{self.cfg.device}')

        n_con = len(agent_fqe_con_list)

        self.FQI_loss = []
        self.FQI_est_values = []

        self.FQE_loss_obj = []
        self.FQE_loss_con = {i: [] for i in range(n_con)}

        self.FQE_est_obj_costs = []
        self.FQE_est_con_costs = {i: [] for i in range(n_con)}

        self.lambda_dict = {i: [] for i in range(n_con)}

        for i in range(n_con):
            self.fqe_con_models_history[i] = []

        # Dual variables and their latest (pre-step) ascent directions.
        lambda_t_list = [0] * n_con
        lambda_update_list = [0] * n_con

        state_batch_val = self.val_data_loader()

        model_update_counter = 0

        for k in range(self.cfg.train_eps):
            loss_list_fqi = []
            loss_list_fqe_obj = []
            loss_list_fqe_con = {i: [] for i in range(n_con)}

            fqi_est_list = []
            fqe_est_obj = []
            fqe_est_con = {i: [] for i in range(n_con)}

            for j in tqdm(range(self.cfg.train_eps_steps)):

                (state_batch, action_batch, obj_cost_batch, con_cost_batch,
                 next_state_batch, done_batch, disch_batch) = self.data_loader()

                # Update the learner (FQI) and the objective evaluator (FQE).
                loss_rl = agent_fqi.update(lambda_t_list, state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch, disch_batch)
                loss_ev_obj = agent_fqe_obj.update(state_batch, action_batch, obj_cost_batch, next_state_batch, done_batch, disch_batch)

                model_update_counter += 1
                self._record_snapshot(self.fqi_models_history, model_update_counter, k, j, agent_fqi)
                self._record_snapshot(self.fqe_obj_models_history, model_update_counter, k, j, agent_fqe_obj)

                loss_list_fqi.append(loss_rl)
                loss_list_fqe_obj.append(loss_ev_obj)

                for m, agent_fqe_con in enumerate(agent_fqe_con_list):
                    loss_con = agent_fqe_con.update(state_batch, action_batch, con_cost_batch[m], next_state_batch, done_batch, disch_batch)
                    loss_list_fqe_con[m].append(loss_con)

                    self._record_snapshot(self.fqe_con_models_history[m], model_update_counter, k, j, agent_fqe_con)

                    con_est_value, con_est_value_up = agent_fqe_con.avg_Q_value_est(state_batch_val)
                    fqe_est_con[m].append(con_est_value)

                    if constraint is not None:
                        # Projected dual ascent on the upper confidence bound of the constraint cost.
                        lambda_update_list[m] = con_est_value_up - self.cfg.constraint_limit[m]
                        lambda_t_list[m] = max(0, lambda_t_list[m] + self.cfg.lr_lam[m] * lambda_update_list[m])

                fqi_est_list.append(agent_fqi.avg_Q_value_est(state_batch_val))
                avg_q_value_obj, _ = agent_fqe_obj.avg_Q_value_est(state_batch_val)
                fqe_est_obj.append(avg_q_value_obj)

                # Periodic soft update of every target network.
                if j % self.cfg.target_update == 0:
                    self._soft_update(agent_fqi)
                    self._soft_update(agent_fqe_obj)
                    for agent_fqe_con in agent_fqe_con_list:
                        self._soft_update(agent_fqe_con)

            print(f"Epoch {k + 1}/{self.cfg.train_eps}")
            print(f"Average FQE estimated objective cost after epoch {k + 1}: {np.mean(fqe_est_obj)}")

            for m in range(n_con):
                print(f"Average FQE estimated constraint cost of constraint {m} after epoch {k + 1}: {np.mean(fqe_est_con[m])}")
                print(f"Dual variable of constraint {m} after epoch {k + 1}: {lambda_t_list[m]}")
                print(f"Dual variable update of constraint {m} after epoch {k + 1}: {lambda_update_list[m]}")

                self.lambda_dict[m].append(lambda_t_list[m])
                self.FQE_loss_con[m].append(np.mean(loss_list_fqe_con[m]))
                self.FQE_est_con_costs[m].append(np.mean(fqe_est_con[m]))

            self.FQI_loss.append(np.mean(loss_list_fqi))
            self.FQE_loss_obj.append(np.mean(loss_list_fqe_obj))

            self.FQI_est_values.append(np.mean(fqi_est_list))
            self.FQE_est_obj_costs.append(np.mean(fqe_est_obj))

        print("Complete Training!")

        return self.FQI_loss, self.FQE_loss_obj, self.FQE_loss_con, self.FQI_est_values, self.FQE_est_obj_costs, self.FQE_est_con_costs, self.lambda_dict

    def _get_model_state(self, agent):
        """Return deep copies of ``agent.policy_net`` and ``agent.target_net`` state_dicts."""
        return {
            'policy_net': copy.deepcopy(agent.policy_net.state_dict()),
            'target_net': copy.deepcopy(agent.target_net.state_dict())
        }

    def _save_models_to_disk(self):
        """Write the retained FQE snapshots under 'saved_models/', one file per update number."""
        os.makedirs('saved_models/fqe_obj', exist_ok = True)

        for model_data in self.fqe_obj_models_history[-self.history_size:]:
            torch.save(
                model_data['model_state'],
                f'saved_models/fqe_obj/model_{model_data["update_num"]}.pt'
            )

        for con_idx in self.fqe_con_models_history.keys():
            os.makedirs(f'saved_models/fqe_con_{con_idx}', exist_ok = True)

            for model_data in self.fqe_con_models_history[con_idx][-self.history_size:]:
                torch.save(
                    model_data['model_state'],
                    f'saved_models/fqe_con_{con_idx}/model_{model_data["update_num"]}.pt'
                )

        print("Models saved to disk in 'saved_models/' directory")

    def load_fqe_model(self, agent, model_path):
        """Load a snapshot saved by ``_save_models_to_disk`` into ``agent``.

        Tensors are mapped onto ``agent.device`` when the agent has one, so
        GPU-saved snapshots can be loaded on CPU. Returns the same agent.
        """
        model_state = torch.load(model_path, map_location = getattr(agent, 'device', None))
        agent.policy_net.load_state_dict(model_state['policy_net'])
        agent.target_net.load_state_dict(model_state['target_net'])
        return agent

### Load Data

In [18]:
# Training split (v13 exports): id/outcome table, raw state table, scaled state table.
# The three tables are used row-aligned by the DataLoader below.
id_table_train = pd.read_csv('../data_output/id_table_train_v13.csv')
rl_table_train = pd.read_csv('../data_output/rl_table_train_v13.csv')
rl_table_train_scaled = pd.read_csv('../data_output/rl_table_train_scaled_v13.csv')

In [19]:
# Validation split (v13 exports): id/outcome table, raw state table, scaled state table.
id_table_val = pd.read_csv('../data_output/id_table_val_v13.csv')
rl_table_val = pd.read_csv('../data_output/rl_table_val_v13.csv')
rl_table_val_scaled = pd.read_csv('../data_output/rl_table_val_scaled_v13.csv')

In [20]:
# rl_table_train_scaled.columns

In [21]:
# rl_table_train.drop(columns = ['readmission_count'], inplace = True)
# rl_table_train_scaled.drop(columns = ['readmission_count'], inplace = True)
# rl_table_val.drop(columns = ['readmission_count'], inplace = True)
# rl_table_val_scaled.drop(columns = ['readmission_count'], inplace = True)

In [22]:
# rl_table_train_scaled.columns

In [23]:
# Create a copy of the 'mortality_costs' column
# Create a copy of the 'mortality_costs' column
id_table_train['mortality_costs_md'] = id_table_train['mortality_costs'].copy()

# Apply the conditions using vectorized operations
discharge_action_mask = id_table_train['discharge_action'] == 1
# NOTE: despite its name, this mask selects rows where death == 0 (the patient did NOT die);
# ~death_mask therefore means death != 0.
death_mask = id_table_train['death'] == 0
discharge_fail_mask = id_table_train['discharge_fail'] == 1

# Zero the mortality cost on discharge rows where the patient survived,
# or where death != 0 and discharge_fail == 1. Every other row keeps 'mortality_costs'.
id_table_train.loc[discharge_action_mask & death_mask, 'mortality_costs_md'] = 0
id_table_train.loc[discharge_action_mask & ~death_mask & discharge_fail_mask, 'mortality_costs_md'] = 0

In [24]:
# Create a copy of the 'mortality_costs' column
# Create a copy of the 'mortality_costs' column
id_table_val['mortality_costs_md'] = id_table_val['mortality_costs'].copy()

# Apply the conditions using vectorized operations
discharge_action_mask = id_table_val['discharge_action'] == 1
# NOTE: despite its name, this mask selects rows where death == 0 (the patient did NOT die);
# ~death_mask therefore means death != 0.
death_mask = id_table_val['death'] == 0
discharge_fail_mask = id_table_val['discharge_fail'] == 1

# Same rule as for the training table: zero the mortality cost on discharge rows where the
# patient survived, or where death != 0 and discharge_fail == 1.
id_table_val.loc[discharge_action_mask & death_mask, 'mortality_costs_md'] = 0
id_table_val.loc[discharge_action_mask & ~death_mask & discharge_fail_mask, 'mortality_costs_md'] = 0

In [25]:
# Constraint costs consumed by the replay-buffer builder via columns 'con_cost_{j}':
# j = 0 -> discharge-failure cost, j = 1 -> scaled length-of-stay cost.
id_table_train['con_cost_0'] = id_table_train['discharge_fail_costs'].copy()
id_table_train['con_cost_1'] = id_table_train['los_costs_scaled'].copy()

In [26]:
# Same constraint-cost columns for the validation table (0: discharge failure, 1: scaled LOS).
id_table_val['con_cost_0'] = id_table_val['discharge_fail_costs'].copy()
id_table_val['con_cost_1'] = id_table_val['los_costs_scaled'].copy()

In [27]:
# Run configuration: 9000 epochs, undiscounted (gamma = 1.0), two constraints.
# Per-constraint lists are ordered like con_cost_0 (discharge failure) and con_cost_1 (scaled LOS).
cfg = RLConfig(algo_name = 'OCRL_v2_trainset_104', train_eps = int(9e3), gamma = 1.0, lr_fqi = 2e-3, lr_fqe_obj = 5e-4, 
               constraint_num = 2, lr_fqe_con_list = [5e-4, 5e-4], lr_lambda_list = [3e-4, 1e-6], threshold_list = [0.14, 4.5])

In [28]:
cfg.device  # display the configured torch device

device(type='cpu')

In [29]:
terminal_state = np.zeros(rl_table_train_scaled.shape[1])  # all-zero next_state stored for terminal transitions

In [30]:
cfg.batch_size  # display the minibatch size

256

In [31]:
# Notebook-defined DataLoader (not torch.utils.data.DataLoader).
train_data_loader = DataLoader(cfg, id_table_train, rl_table_train, rl_table_train_scaled, terminal_state)

In [32]:
# One transition per training row; constraint costs come from con_cost_0 / con_cost_1.
train_data_loader.data_buffer_train(num_constraint = 2)

In [33]:
# Validation start states: rows with readmission_count == 0 and epoch == 1.
id_table_val_initial = id_table_val[id_table_val['readmission_count'] == 0].copy()
id_table_val_initial = id_table_val_initial[id_table_val_initial['epoch'] == 1].copy()

mv_val_initial_index = id_table_val_initial.index
id_index_list_initial_val = mv_val_initial_index.tolist()

# Select the matching rows (by index label) from the raw and scaled state tables.
rl_table_val_initial = rl_table_val.loc[id_index_list_initial_val].copy()

rl_table_val_scaled_initial = rl_table_val_scaled.loc[id_index_list_initial_val].copy()

# NOTE(review): these columns were already copied from id_table_val (assigned in an earlier cell),
# so these two lines look redundant — kept for safety if cells are run out of order.
id_table_val_initial['con_cost_0'] = id_table_val_initial['discharge_fail_costs'].copy()
id_table_val_initial['con_cost_1'] = id_table_val_initial['los_costs_scaled'].copy()

# Transitions are built from the initial rows; next states are looked up in the FULL
# scaled validation table (rl_table_val_scaled) at index label idx + 1.
val_data_loader = ValDataLoader(cfg, 
                                id_table_val_initial, 
                                rl_table_val_initial, 
                                rl_table_val_scaled_initial, 
                                id_table_val, 
                                rl_table_val_scaled, 
                                terminal_state)

val_data_loader.data_buffer_val(num_constraint = 2)

### Training OCRL Agent

In [34]:
# Limit PyTorch CPU parallelism to 16 intra-op and 16 inter-op threads.
# NOTE: set_num_interop_threads can raise if inter-op parallel work has already started in this process.
torch.set_num_threads(16)
torch.set_num_interop_threads(16)

In [35]:
# Loaders are passed as bound methods; RLTraining calls them with no arguments.
# action_dim = 2 matches the binary discharge_action (0/1).
ocrl_training_s1 = RLTraining(cfg, state_dim = rl_table_train_scaled.shape[1], 
                              action_dim = 2, 
                              val_data_loader = val_data_loader.data_torch_loader_val, 
                              data_loader = train_data_loader.data_torch_loader_train)

In [36]:
# Learner: FQI on the Lagrangian cost.
agent_fqi_s1 = ocrl_training_s1.fqi_agent_config(seed = 1000)

### objective cost: mortality risk (column mortality_costs_md)
agent_fqe_obj_s1 = ocrl_training_s1.fqe_agent_config(id_stop = 0, 
                                                     eval_agent = agent_fqi_s1, 
                                                     weight_decay = 0.0,
                                                     eval_target = 'obj', 
                                                     seed = 2000) 

### constraint cost 1: readmission risk (eval_target 0 -> con_cost_0 = discharge_fail_costs)
agent_fqe_con_rr_s1 = ocrl_training_s1.fqe_agent_config(id_stop = 0, 
                                                        eval_agent = agent_fqi_s1, 
                                                        weight_decay = 0.0,
                                                        eval_target = 0, 
                                                        seed = 3000) 

### constraint cost 2: length-of-stay (eval_target 1 -> con_cost_1 = los_costs_scaled)
agent_fqe_con_los_s1 = ocrl_training_s1.fqe_agent_config(id_stop = 0, 
                                                         eval_agent = agent_fqi_s1, 
                                                         weight_decay = 0.0,
                                                         eval_target = 1, 
                                                         seed = 4000)

In [None]:
# Constrained run: any non-None `constraint` enables the dual-variable updates.
# The FQE constraint evaluators are listed in con_cost order (0: readmission, 1: LOS).
ocrl_training_s1.train(agent_fqi_s1, 
                       agent_fqe_obj_s1, [agent_fqe_con_rr_s1, agent_fqe_con_los_s1], 
                       constraint = True)

In [None]:
def save_ocrl_models_and_data(agent_fqi_s1,
                              agent_fqe_obj_s1,
                              agent_fqe_con_rr_s1,
                              agent_fqe_con_los_s1,
                              ocrl_training_s1,
                              approx_model = "linear",
                              version = "v1",
                              model_prefix = "../model_output",
                              data_prefix = "../model_output/data_output"):
    """
    Save the trained OCRL agents and their training curves, tagged with today's date and a version.

    Parameters:
    -------
    agent_fqi_s1 : trained FQI agent object (saved whole with torch.save)
    agent_fqe_obj_s1 : trained FQE agent for the objective cost
    agent_fqe_con_rr_s1 : trained FQE agent for constraint cost 0 (readmission risk)
    agent_fqe_con_los_s1 : trained FQE agent for constraint cost 1 (length-of-stay)
    ocrl_training_s1 : training driver exposing FQI_loss, FQI_est_values, FQE_loss_obj,
        FQE_loss_con, FQE_est_obj_costs, FQE_est_con_costs and lambda_dict
    approx_model : tag prefixed to every data file name (default "linear")
    version : version string appended to every file name (default "v1")
    model_prefix : directory for the .pth agent files; created if missing
    data_prefix : directory for the .npy training curves; created if missing

    Returns:
    -------
    None. Files are written as
    {model_prefix}/ocrl_agent_s1_<name>_<YYYYMMDD>_<version>.pth and
    {data_prefix}/<approx_model>_<name>_<YYYYMMDD>_<version>.npy
    """
    # Automatically generate the current date
    date_str = datetime.datetime.now().strftime("%Y%m%d")

    # torch.save / np.save do not create parent directories; create them up front so a
    # missing folder does not make the save fail after an expensive training run.
    os.makedirs(model_prefix, exist_ok = True)
    os.makedirs(data_prefix, exist_ok = True)

    # ========= Save models =========
    agents = {
        "fqi": agent_fqi_s1,
        "fqe_obj": agent_fqe_obj_s1,
        "fqe_con_rr": agent_fqe_con_rr_s1,
        "fqe_con_los": agent_fqe_con_los_s1,
    }
    for name, agent in agents.items():
        torch.save(agent, f"{model_prefix}/ocrl_agent_s1_{name}_{date_str}_{version}.pth")

    # ========= Save the results of the training process =========
    # Constraint index 0 is readmission risk and 1 is length-of-stay.
    curves = {
        "fqi_loss": ocrl_training_s1.FQI_loss,
        "fqi_est_value": ocrl_training_s1.FQI_est_values,
        "fqe_obj_loss": ocrl_training_s1.FQE_loss_obj,
        "fqe_con_rr_loss": ocrl_training_s1.FQE_loss_con[0],
        "fqe_con_los_loss": ocrl_training_s1.FQE_loss_con[1],
        "fqe_est_obj": ocrl_training_s1.FQE_est_obj_costs,
        "fqe_est_con_rr": ocrl_training_s1.FQE_est_con_costs[0],
        "fqe_est_con_los": ocrl_training_s1.FQE_est_con_costs[1],
        "lambda_rr": ocrl_training_s1.lambda_dict[0],
        "lambda_los": ocrl_training_s1.lambda_dict[1],
    }
    for name, values in curves.items():
        np.save(f"{data_prefix}/{approx_model}_{name}_{date_str}_{version}.npy", np.array(values))

    print(f"model and data saved, date: {date_str}, approximation model: {approx_model}, version: {version}.")

In [None]:
# Persist the trained agents and training curves under version tag "v0".
save_ocrl_models_and_data(
    agent_fqi_s1 = agent_fqi_s1,
    agent_fqe_obj_s1 = agent_fqe_obj_s1,
    agent_fqe_con_rr_s1 = agent_fqe_con_rr_s1,
    agent_fqe_con_los_s1 = agent_fqe_con_los_s1,
    ocrl_training_s1 = ocrl_training_s1,
    approx_model = "linear",
    version = "v0" 
)

- Visualization

In [None]:
# Recorded FQI (policy-learning agent) training loss, x axis labelled by training epoch.
plt.plot(ocrl_training_s1.FQI_loss)

plt.title("The Loss of Reinforcement Learning Agent")
plt.xlabel("Training Epoch")
plt.ylabel("Training Loss")

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqi_loss_20250331.png", 
#             dpi = 300, 
#             bbox_inches = 'tight')

plt.show()

In [None]:
# Average estimated Q value of the FQI agent as recorded during training.
plt.plot(ocrl_training_s1.FQI_est_values)

plt.title("The estimated Q value of FQI agent")
plt.xlabel("Training Epoch")
plt.ylabel("Average estimated Q value")
# plt.grid(True)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqi_est_Q_20250331.png", 
#             dpi = 300,
#             bbox_inches = 'tight')


plt.show()

In [None]:
# Training loss of the FQE agent that evaluates the objective cost.
plt.plot(ocrl_training_s1.FQE_loss_obj)

plt.title("The Loss of FQE (Objective Cost) Agent")
plt.xlabel("Training Epoch")
plt.ylabel("Training Loss")

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqe_obj_loss_20250331.png", 
#             dpi = 300, 
#             bbox_inches='tight')

plt.show()

In [None]:
# Training loss of the FQE agent for constraint cost 0 (readmission risk).
plt.plot(ocrl_training_s1.FQE_loss_con[0])

# Name the specific constraint so this figure can be told apart from the
# length-of-stay constraint's loss plot.
plt.title("The Loss of FQE (Constrained Cost: Readmission Risk) Agent")
plt.xlabel("Training Epoch")
plt.ylabel("Training Loss")

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqe_con_rdm_loss_20250331.png", 
#             dpi = 300, 
#             bbox_inches='tight')

plt.show()

In [None]:
# Training loss of the FQE agent for constraint cost 1 (length-of-stay).
plt.plot(ocrl_training_s1.FQE_loss_con[1])

# Name the specific constraint so this figure can be told apart from the
# readmission-risk constraint's loss plot.
plt.title("The Loss of FQE (Constrained Cost: Length-of-Stay) Agent")
plt.xlabel("Training Epoch")
plt.ylabel("Training Loss")

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqe_con_los_loss_20250331.png", 
#             dpi = 300, 
#             bbox_inches='tight')

plt.show()

In [None]:
# FQE estimate of the objective cost (mortality risk, per the y-axis label) over
# training; the final value is annotated in red and marked with a dashed line.
# NOTE(review): the legend's l = (0.14, 60) presumably lists the constraint limits
# (readmission risk, length-of-stay in hrs) — confirm against the training config.
plt.plot(ocrl_training_s1.FQE_est_obj_costs, label = r'OCRL, $l = (0.14, 60)$')

last_epoch = len(ocrl_training_s1.FQE_est_obj_costs) - 1
last_value = ocrl_training_s1.FQE_est_obj_costs[-1]

# NOTE(review): the +1600 x-offset (with ha='right') is tuned to one run length; for
# much shorter runs the label lands far to the right of the curve.
plt.text(last_epoch + 1600, last_value, f'{last_value:.4f}', ha = 'right', va = 'center', fontsize = 12, color = 'red')

plt.axhline(y = last_value, color = 'r', linestyle = '--')

# plt.title("The Average Estimated Value of Objective Costs by FQE")
plt.xlabel("Training Epoch")
# plt.ylabel(r"Mortality Risk ($\%$)")
plt.ylabel(r"Mortality Risk")
plt.legend(prop = {'size': 8})
plt.grid(True)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqe_obj_est_value_20250331.png", 
#             dpi = 300, 
#             bbox_inches = 'tight')

plt.show()

In [None]:
# FQE estimate of constraint cost 0 (readmission risk) over training; the final
# value is annotated in red and marked with a dashed line.
plt.plot(ocrl_training_s1.FQE_est_con_costs[0], label = r'OCRL, $l = (0.14, 60)$')

last_epoch = len(ocrl_training_s1.FQE_est_con_costs[0]) - 1
last_value = ocrl_training_s1.FQE_est_con_costs[0][-1]

# NOTE(review): the +1600 x-offset (with ha='right') is tuned to one run length.
# plt.text(last_epoch + 1800, last_value, f'{last_value:.2f}%', ha = 'right', va = 'center', fontsize = 12, color = 'red')
plt.text(last_epoch + 1600, last_value, f'{last_value:.4f}', ha = 'right', va = 'center', fontsize = 12, color = 'red')

plt.axhline(y = last_value, color = 'r', linestyle = '--')

# plt.title("The Average Estimated Value of Constrained Costs by FQE")
plt.xlabel("Training Epoch")
# plt.ylabel(r"Readmission Risk ($\%$)")
plt.ylabel(r"Readmission Risk")
plt.legend(prop = {'size': 8})
plt.grid(True)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqe_con_disch_fail_est_value_20250331.png", 
#             dpi = 300, 
#             bbox_inches = 'tight')

plt.show()

In [None]:
# FQE estimate of constraint cost 1 (length-of-stay) over training, multiplied by 12
# to match the hours on the y-axis label.
# NOTE(review): the factor 12 is assumed to undo the scaling of los_costs_scaled — confirm.
plt.plot(np.array(ocrl_training_s1.FQE_est_con_costs[1])*12, label = r'OCRL, $l = (0.14, 60)$')

# last_epoch / last_value only feed the commented-out annotation lines below.
last_epoch = len(ocrl_training_s1.FQE_est_con_costs[1]) - 1
last_value = ocrl_training_s1.FQE_est_con_costs[1][-1] * 12

# plt.text(last_epoch + 1600, last_value, f'{last_value:.2f}', ha = 'right', va = 'center', fontsize = 12, color = 'red')

# plt.axhline(y = last_value, color = 'r', linestyle = '--')
# Solid green reference line at 60 hrs, presumably the length-of-stay limit shown as
# the 60 in the legend's l = (0.14, 60).
plt.axhline(y = 60, color = 'g', linestyle = '-')


# plt.title("The Average Estimated Value of Constrained Costs by FQE")
plt.xlabel("Training Epoch")
plt.ylabel(r"Length-of-Stay (hrs)")
plt.legend(prop = {'size': 8})
plt.grid(True)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_fqe_con_los_value_20250331.png", 
#             dpi = 300, 
#             bbox_inches = 'tight')

plt.show()

In [None]:
# Trajectory of the Lagrangian multiplier for constraint 0 (readmission risk).
plt.plot(ocrl_training_s1.lambda_dict[0], label = r'OCRL, $l = (0.14, 60)$')

# Name the constraint so this figure can be told apart from the length-of-stay multiplier plot.
plt.title(r"Training Process of Lagrangian Multiplier $\lambda$ (Readmission Risk)")
plt.xlabel("Training Epoch")
ylabel = plt.ylabel(r"$\lambda$")
ylabel.set_rotation(0)  # keep the single-symbol label upright

plt.legend(prop={'size': 8})
plt.grid(True)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_lambda_disch_fail_20250331.png", 
#             dpi = 300, 
#             bbox_inches = 'tight')

plt.show()

In [None]:
# Trajectory of the Lagrangian multiplier for constraint 1 (length-of-stay).
plt.plot(ocrl_training_s1.lambda_dict[1], label = r'OCRL, $l = (0.14, 60)$')

# Name the constraint so this figure can be told apart from the readmission-risk multiplier plot.
plt.title(r"Training Process of Lagrangian Multiplier $\lambda$ (Length-of-Stay)")
plt.xlabel("Training Epoch")
ylabel = plt.ylabel(r"$\lambda$")
ylabel.set_rotation(0)  # keep the single-symbol label upright

plt.legend(prop={'size': 8})
plt.grid(True)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/ocrl_lambda_los_20250331.png", 
#             dpi = 300, 
#             bbox_inches = 'tight')

plt.show()

## Testing dataset

In [124]:
class TestDataLoader:
    """Load held-out (test) transitions into a ReplayBuffer and export them as tensors.

    Call ``data_buffer_test`` to fill ``self.test_memory``, then
    ``data_torch_loader_test`` to obtain batched tensors on ``cfg.device``.
    """

    def __init__(self, cfg, state_id_table, rl_cont_state_table, rl_cont_state_table_scaled, state_id_table_1, rl_cont_state_table_scaled_1, terminal_state):
        """
        Parameters
        ----------
        cfg : config object; ``memory_capacity`` sizes the buffer and ``device``
            places the output tensors.
        state_id_table : DataFrame with one row per transition to buffer; its
            ``discharge_action``, ``death``, ``discharge_fail``,
            ``mortality_costs_md`` and optional ``con_cost_<j>`` columns are read.
        rl_cont_state_table : state table (stored, not read by this class).
        rl_cont_state_table_scaled : scaled states, assumed positionally aligned
            with ``state_id_table`` (row i of each describes the same transition).
        state_id_table_1 : metadata table (stored, not read by this class).
        rl_cont_state_table_scaled_1 : scaled state table in which the next state
            of index label ``idx`` is looked up at label ``idx + 1``.
        terminal_state : vector used as the next state of terminal transitions.
        """
        self.cfg = cfg

        # Load datasets
        self.state_df_id = state_id_table
        self.rl_cont_state_table = rl_cont_state_table
        self.rl_cont_state_table_scaled = rl_cont_state_table_scaled

        self.state_df_id_1 = state_id_table_1
        self.rl_cont_state_table_scaled_1 = rl_cont_state_table_scaled_1

        self.terminal_state = terminal_state

    def data_buffer_test(self, num_constraint = 2):
        """Push one transition per row of ``state_df_id`` into a new ``self.test_memory``.

        ``done`` is 1.0 only for rows whose ``discharge_action`` equals 1.0 and whose
        ``discharge_fail`` flag is not 1.0 (when ``death`` == 1.0) or equals 0.0
        (otherwise); every other row gets 0.0. Non-terminal rows take their next
        state from ``rl_cont_state_table_scaled_1`` at index label ``idx + 1``;
        terminal rows use ``terminal_state``. Constraint costs are read from columns
        ``con_cost_0 .. con_cost_{num_constraint - 1}``; a missing column
        contributes 0.0.
        """
        self.test_memory = ReplayBuffer(self.cfg.memory_capacity)

        df = self.state_df_id
        # Extract every column once, outside the loop. Looking a column up and calling
        # ``.values`` per row repeats that work for every row (and ``DataFrame.values``
        # may materialize a full copy of the table on each call).
        states = self.rl_cont_state_table_scaled.values
        actions = df['discharge_action'].values
        deaths = df['death'].values
        discharge_fails = df['discharge_fail'].values
        obj_costs = df['mortality_costs_md'].values
        row_labels = df.index
        con_cost_columns = []
        for j in range(num_constraint):
            cost_col = f'con_cost_{j}'
            con_cost_columns.append(df[cost_col].values if cost_col in df.columns else None)

        for i in range(len(df)):
            state = states[i]
            action = actions[i]

            if action == 1.0:
                if deaths[i] == 1.0:
                    done = 0.0 if discharge_fails[i] == 1.0 else 1.0
                else:
                    done = 1.0 if discharge_fails[i] == 0.0 else 0.0
            else:
                done = 0.0

            obj_cost = obj_costs[i]
            # Missing constraint columns are treated as zero cost.
            con_cost = [0.0 if col is None else col[i] for col in con_cost_columns]

            if done == 0.0:
                next_state = self.rl_cont_state_table_scaled_1.loc[row_labels[i] + 1].values
            else:
                next_state = self.terminal_state

            self.test_memory.push(state, action, obj_cost, con_cost, next_state, done)

    def data_torch_loader_test(self):
        """Convert the buffered transitions into tensors on ``cfg.device``.

        Returns
        -------
        (state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch,
        done_batch, disch_batch): ``action_batch`` is a long tensor with an extra
        trailing dimension (``unsqueeze(1)``); ``con_cost_batch`` is a list of float
        tensors, one per element of the cost batch returned by
        ``ReplayBuffer.extract`` (its layout is defined by ReplayBuffer);
        ``disch_batch`` holds the same actions as a float tensor without the extra
        dimension.
        """
        state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch = self.test_memory.extract()

        state_batch = torch.tensor(np.array(state_batch), device = self.cfg.device, dtype = torch.float)

        # Keep a copy of the raw actions before they are turned into a long column tensor.
        disch_batch = list(action_batch)
        action_batch = torch.tensor(np.array(action_batch), device = self.cfg.device, dtype = torch.long).unsqueeze(1)

        obj_cost_batch = torch.tensor(np.array(obj_cost_batch), device = self.cfg.device, dtype = torch.float)
        con_cost_batch = [torch.tensor(np.array(cost), device = self.cfg.device, dtype=torch.float) for cost in con_cost_batch]
        next_state_batch = torch.tensor(np.array(next_state_batch), device = self.cfg.device, dtype = torch.float)

        done_batch = torch.tensor(np.array(done_batch), device = self.cfg.device, dtype = torch.float)
        disch_batch = torch.tensor(np.array(disch_batch), device = self.cfg.device, dtype = torch.float)

        return state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch, disch_batch

In [125]:
class TestConfig:
    """Minimal configuration for evaluating agents on the test set.

    Attributes
    ----------
    constraint_num : number of constraint costs, as passed in.
    memory_capacity : replay-buffer capacity (2,000,000 transitions).
    device : torch device for the test tensors; pinned to the CPU.
    """

    def __init__(self, constraint_num):
        # Evaluation is pinned to the CPU; use
        # torch.device("cuda" if torch.cuda.is_available() else "cpu") to allow a GPU.
        self.device = torch.device("cpu")
        self.memory_capacity = 2_000_000  # capacity of Replay Memory
        self.constraint_num = constraint_num

In [149]:
# Load the v13 test split. id_table_test carries per-row metadata (actions, outcome
# flags, costs, readmission_count); rl_table_test / rl_table_test_scaled carry the state
# features, and the scaled table is the one fed to the agents below.
id_table_test = pd.read_csv('../data_output/id_table_test_v13.csv')
rl_table_test = pd.read_csv('../data_output/rl_table_test_v13.csv')
rl_table_test_scaled = pd.read_csv('../data_output/rl_table_test_scaled_v13.csv')

In [150]:
# rl_table_test.drop(columns = ['readmission_count'], inplace = True)
# rl_table_test_scaled.drop(columns = ['readmission_count'], inplace = True)

In [151]:
# Create a copy of the 'mortality_costs' column
# mortality_costs_md starts as a copy of mortality_costs and is zeroed below for rows
# where the discharge action was taken and either death == 0, or death != 0
# together with discharge_fail == 1.
id_table_test['mortality_costs_md'] = id_table_test['mortality_costs'].copy()

# Apply the conditions using vectorized operations
discharge_action_mask = id_table_test['discharge_action'] == 1
# Note: despite its name, death_mask is True where death == 0 (no death recorded).
death_mask = id_table_test['death'] == 0
discharge_fail_mask = id_table_test['discharge_fail'] == 1

# Update 'mortality_costs_md' based on the conditions
id_table_test.loc[discharge_action_mask & death_mask, 'mortality_costs_md'] = 0
id_table_test.loc[discharge_action_mask & ~death_mask & discharge_fail_mask, 'mortality_costs_md'] = 0

In [152]:
# Evaluate only rows with readmission_count == 0; the commented lines are alternative subsets.
# id_table_test_initial = id_table_test.copy()
id_table_test_initial = id_table_test[id_table_test['readmission_count'] == 0].copy()
# id_table_test_initial = id_table_test[(id_table_test['readmission_count'] == 4) | (id_table_test['readmission_count'] == 5)].copy()
# id_table_test_initial = id_table_test_initial[id_table_test_initial['epoch'] == 1].copy()
# id_table_test_initial = id_table_test[id_table_test['discharge_action'] == 1].copy()
# id_table_test_initial = id_table_test_initial[id_table_test_initial['discharge_fail'] == 1].copy()

# Select the state rows with the same index labels as the chosen metadata rows.
mv_test_initial_index = id_table_test_initial.index
id_index_list_initial_test = mv_test_initial_index.tolist()

rl_table_test_initial = rl_table_test.loc[id_index_list_initial_test].copy()

rl_table_test_scaled_initial = rl_table_test_scaled.loc[id_index_list_initial_test].copy()

# TestDataLoader reads constraint costs from con_cost_<j> columns:
# con_cost_0 <- discharge_fail_costs, con_cost_1 <- los_costs_scaled.
id_table_test_initial['con_cost_0'] = id_table_test_initial['discharge_fail_costs'].copy()
id_table_test_initial['con_cost_1'] = id_table_test_initial['los_costs_scaled'].copy()

test_cfg = TestConfig(constraint_num = 2)
# All-zeros next state for terminal transitions, one entry per scaled feature.
terminal_state = np.zeros(rl_table_test_scaled.shape[1])
# The full (unfiltered) tables go in as the *_1 arguments so next states can be looked
# up at index idx + 1 even when that row is outside the selected subset.
test_data_loader = TestDataLoader(test_cfg, 
                                  id_table_test_initial, 
                                  rl_table_test_initial, 
                                  rl_table_test_scaled_initial, 
                                  id_table_test, 
                                  rl_table_test_scaled, 
                                  terminal_state)

test_data_loader.data_buffer_test(num_constraint = 2)
state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch, disch_batch = test_data_loader.data_torch_loader_test()

In [153]:
# Reload the agents saved on 20250331 as version v0 (this rebinds any agents trained above).
# weights_only = False makes torch.load unpickle whole Python objects, which can run
# arbitrary code — only load checkpoint files you created or otherwise trust.
agent_fqi_s1 = torch.load('../model_output/ocrl_agent_s1_fqi_20250331_v0.pth', weights_only = False)

agent_fqe_obj_s1 = torch.load('../model_output/ocrl_agent_s1_fqe_obj_20250331_v0.pth', weights_only = False)
agent_fqe_con_rr_s1 = torch.load('../model_output/ocrl_agent_s1_fqe_con_rr_20250331_v0.pth', weights_only = False)
agent_fqe_con_los_s1 = torch.load('../model_output/ocrl_agent_s1_fqe_con_los_20250331_v0.pth', weights_only = False)

In [154]:
# Fraction of test states where the learned policy picks the same action as the data.
# rl_policy's output is displayed below as a column tensor like action_batch; zip walks
# both row by row, and each `a == b` is a one-element bool tensor used for its truth value.
ocrl_policy_set_1 = agent_fqi_s1.rl_policy(state_batch)
matches_1 = sum(1 for a, b in zip(ocrl_policy_set_1, action_batch) if a == b)
percentage_1 = matches_1/len(ocrl_policy_set_1)

In [155]:
# Display the policy/data action-agreement rate.
percentage_1

0.598563022973892

In [156]:
# Display the logged test actions (long tensor with a trailing unit dimension).
action_batch

tensor([[0],
        [0],
        [1],
        ...,
        [0],
        [0],
        [1]])

In [157]:
# Display the actions chosen by the learned policy on the same states.
ocrl_policy_set_1

tensor([[1],
        [1],
        [1],
        ...,
        [1],
        [1],
        [1]])

In [158]:
# Count and share of each logged action (0 / 1) in the test set.
# The printed labels are in Chinese: they read "number of 0s" / "number of 1s", then "ratio".
zeros = (action_batch == 0).sum().item()
ones = (action_batch == 1).sum().item()
total = action_batch.numel()  

zero_ratio = zeros / total
one_ratio = ones / total

print(f"0的数量: {zeros}, 比例: {zero_ratio:.4f} ({zeros}/{total})")
print(f"1的数量: {ones}, 比例: {one_ratio:.4f} ({ones}/{total})")

0的数量: 27825, 比例: 0.8227 (27825/33821)
1的数量: 5996, 比例: 0.1773 (5996/33821)


In [159]:
# Count and share of each action (0 / 1) chosen by the learned policy.
# The printed labels are in Chinese: they read "number of 0s" / "number of 1s", then "ratio".
zeros = (ocrl_policy_set_1 == 0).sum().item()
ones = (ocrl_policy_set_1 == 1).sum().item()
total = ocrl_policy_set_1.numel()  

zero_ratio = zeros / total
one_ratio = ones / total

print(f"0的数量: {zeros}, 比例: {zero_ratio:.3f} ({zeros}/{total})")
print(f"1的数量: {ones}, 比例: {one_ratio:.3f} ({ones}/{total})")

0的数量: 18644, 比例: 0.551 (18644/33821)
1的数量: 15177, 比例: 0.449 (15177/33821)


In [144]:
# FQE estimates on the test states for the objective and both constraint costs.
# avg_Q_value_est returns a pair; the second element is kept as *_ub
# (presumably an upper bound — confirm against the FQE agent class).
fqe_test_obj_costs_ocrl_1, fqe_test_obj_costs_ocrl_1_ub = agent_fqe_obj_s1.avg_Q_value_est(state_batch)
fqe_test_con_disch_costs_ocrl_1, fqe_test_con_disch_costs_ocrl_1_ub = agent_fqe_con_rr_s1.avg_Q_value_est(state_batch)
fqe_test_con_los_costs_ocrl_1, fqe_test_con_los_costs_ocrl_1_ub = agent_fqe_con_los_s1.avg_Q_value_est(state_batch)

In [145]:
# Objective-cost (mortality risk) estimate on the test states.
fqe_test_obj_costs_ocrl_1

0.06513223052024841

In [146]:
# Constraint-0 (readmission risk) estimate on the test states.
fqe_test_con_disch_costs_ocrl_1

0.1295393407344818

In [147]:
# Constraint-1 (length-of-stay) estimate on the test states, in scaled units.
fqe_test_con_los_costs_ocrl_1

4.800172328948975

In [148]:
# Same estimate times 12, the factor the length-of-stay plots use to show hours
# (conversion assumed — confirm against los_costs_scaled).
fqe_test_con_los_costs_ocrl_1 * 12

57.602067947387695

In [None]:
# Evaluate every saved training snapshot on the test states and accumulate the FQE
# estimates; a later cell divides the sums by the number of snapshots.
# The freshly configured agents are only containers: their policy_net weights are
# overwritten from the saved histories on every iteration.
agent_fqi_test = ocrl_training_s1.fqi_agent_config(seed = 1)

# Each FQE container evaluates agent_fqi_test, whose weights are reloaded per snapshot.
agent_fqe_obj_test = ocrl_training_s1.fqe_agent_config(id_stop = 0, 
                                                       eval_agent = agent_fqi_test, 
                                                       weight_decay = 0.0,
                                                       eval_target = 'obj', 
                                                       seed = 1) 

agent_fqe_con_rr_test = ocrl_training_s1.fqe_agent_config(id_stop = 0, 
                                                          eval_agent = agent_fqi_test, 
                                                          weight_decay = 0.0,
                                                          eval_target = 0, 
                                                          seed = 2) 

agent_fqe_con_los_test = ocrl_training_s1.fqe_agent_config(id_stop = 0, 
                                                           eval_agent = agent_fqi_test, 
                                                           weight_decay = 0.0,
                                                           eval_target = 1, 
                                                           seed = 3) 

# NOTE(review): all three loops iterate len(fqe_obj_models_history); the FQI and
# constraint-FQE histories are assumed to have the same length — confirm.

# Objective cost: restore snapshot i of the FQI policy and the objective FQE, then evaluate.
cum_fqe_obj = 0

for i in range(len(ocrl_training_s1.fqe_obj_models_history)):
    agent_fqi_test.policy_net.load_state_dict(ocrl_training_s1.fqi_models_history[i]['model_state']['policy_net'])
    agent_fqe_obj_test.policy_net.load_state_dict(ocrl_training_s1.fqe_obj_models_history[i]['model_state']['policy_net'])
    fqe_est_obj_test, fqe_est_obj_test_ub = agent_fqe_obj_test.avg_Q_value_est(state_batch)
    print(fqe_est_obj_test)
    cum_fqe_obj += fqe_est_obj_test

# Constraint 0 (readmission risk).
cum_fqe_con_rr = 0

for i in range(len(ocrl_training_s1.fqe_obj_models_history)):
    agent_fqi_test.policy_net.load_state_dict(ocrl_training_s1.fqi_models_history[i]['model_state']['policy_net'])
    agent_fqe_con_rr_test.policy_net.load_state_dict(ocrl_training_s1.fqe_con_models_history[0][i]['model_state']['policy_net'])
    fqe_est_con_rr_test, fqe_est_con_rr_test_ub = agent_fqe_con_rr_test.avg_Q_value_est(state_batch)
    print(fqe_est_con_rr_test)
    cum_fqe_con_rr += fqe_est_con_rr_test

# Constraint 1 (length-of-stay).
cum_fqe_con_los = 0

for i in range(len(ocrl_training_s1.fqe_obj_models_history)):
    agent_fqi_test.policy_net.load_state_dict(ocrl_training_s1.fqi_models_history[i]['model_state']['policy_net'])
    agent_fqe_con_los_test.policy_net.load_state_dict(ocrl_training_s1.fqe_con_models_history[1][i]['model_state']['policy_net'])
    fqe_est_con_los_test, fqe_est_con_los_test_ub = agent_fqe_con_los_test.avg_Q_value_est(state_batch)
    print(fqe_est_con_los_test)
    cum_fqe_con_los += fqe_est_con_los_test

In [None]:
# Number of saved snapshots, the divisor for the accumulated FQE estimates.
L = len(ocrl_training_s1.fqe_obj_models_history)
print(L)

In [None]:
# Snapshot-averaged test estimates: objective, readmission constraint, and the
# length-of-stay constraint times 12 (the factor the plots use for hours).
[cum_fqe_obj/L, 
 cum_fqe_con_rr/L,
 (cum_fqe_con_los/L)*12]

## SHAP Value Analysis

In [None]:
# Rebind the batch names to the raw training-buffer contents (replacing the test-set
# tensors from above); only the states are used for SHAP, as a NumPy array.
state_batch, action_batch, obj_cost_batch, con_cost_batch, next_state_batch, done_batch = train_data_loader.train_memory.extract()
state_batch = np.array(state_batch)

In [None]:
# agent_fqi_s1 = torch.load('../model_output/ocrl_safe_agent_s1_fqi_20250321.pth', weights_only = False)

In [None]:
# Draw 30000 distinct training states as the SHAP background set; `idx` is reused by the
# cell that builds X_data from the remaining rows.
# NOTE(review): uses the unseeded global RNG (not reproducible), and np.random.choice
# raises ValueError if fewer than 30000 states are available.
idx = np.random.choice(state_batch.shape[0], 30000, replace = False)
background_data = state_batch[idx, :]

In [None]:
# Explain the FQI Q-network directly. fc1 is an nn.Linear, so its weight has one row
# per action and its bias one entry per action; LinearExplainer takes them as a
# (coef, intercept) pair together with the background sample.
# NOTE(review): assumes policy_net is a single-layer linear model (no hidden layers).
model_1 = agent_fqi_s1.policy_net
model_1.eval()

weights = model_1.fc1.weight.data.cpu().numpy()
intercept = model_1.fc1.bias.data.cpu().numpy()

e = shap.LinearExplainer((weights, intercept), background_data)

In [None]:
# Feature names, in the same column order as the state vectors.
rl_table_train_scaled.columns

In [None]:
# Linear Q-network coefficients (one row per action).
weights

In [None]:
# Explain only the states that were not drawn into the background sample.
# NOTE(review): relies on the global `idx` still holding the background index array;
# any later cell that rebinds `idx` makes a re-run of this cell select the wrong rows.
selected = np.zeros(state_batch.shape[0], dtype = bool)
selected[idx] = True

X_data = state_batch[~selected]

In [None]:
# States to be explained.
X_data

In [None]:
# SHAP values for every explained state. Depending on the shap version, a multi-output
# linear model yields either a list of per-output arrays or one stacked array; later
# cells branch on isinstance(shap_values, list).
shap_values = e.shap_values(X_data)

# Bar chart of global feature importance; show=False leaves the figure open for savefig.
shap.summary_plot(shap_values, X_data, 
                  feature_names = rl_table_train_scaled.columns, 
                  plot_size = (8, 6), 
                  plot_type = 'bar',
                  show = False)

# Uncomment to export the figure.
# plt.savefig("../Experiment Figure/shap_summ_plot_20250415.png", 
#             dpi = 400, 
#             bbox_inches = 'tight')

plt.show()

In [None]:
# Raw SHAP values (list per output or a single array, depending on the shap version).
shap_values

In [None]:
# Choose which output's (action's) SHAP values the single-output analyses below use.
# selected_output is defined on every path because later cells reference it even
# when shap_values is not a list (otherwise they raise NameError).
selected_output = 0  # Change this to visualize different outputs

if isinstance(shap_values, list):
    # Older shap versions: one (n_samples, n_features) array per output.
    num_outputs = len(shap_values)
    print(f"Found {num_outputs} outputs in the SHAP values")
    shap_values_selected = shap_values[selected_output]
elif np.ndim(shap_values) == 3:
    # Newer shap versions stack outputs on the last axis: (n_samples, n_features, n_outputs).
    num_outputs = np.shape(shap_values)[-1]
    print(f"Found {num_outputs} outputs in the SHAP values")
    shap_values_selected = shap_values[..., selected_output]
else:
    shap_values_selected = shap_values

In [None]:
# SHAP values of the selected output.
shap_values_selected

In [None]:
# Rank features by mean |SHAP| for the selected output and keep the 10 largest
# (np.argsort is ascending, so the most important feature comes last).
mean_abs_shap = np.abs(shap_values_selected).mean(axis=0)
top_features_idx = np.argsort(mean_abs_shap)[-10:]
top_features = [rl_table_train_scaled.columns[i] for i in top_features_idx]

# Calculate mean positive and negative contributions for top features: clip the
# opposite sign to 0 and average over all samples; the negative side is reported as
# a magnitude. Vectorized so that no notebook-level loop variable overwrites `idx`,
# the background-sample index array used to build X_data.
top_shap = shap_values_selected[:, top_features_idx]
pos_contrib = np.clip(top_shap, 0, None).mean(axis=0)
neg_contrib = np.abs(np.clip(top_shap, None, 0).mean(axis=0))

# Create a DataFrame for visualization
contrib_df = pd.DataFrame({
    'Feature': top_features,
    'Positive Impact': pos_contrib,
    'Negative Impact': neg_contrib
})

In [None]:
# Positive / negative mean SHAP contributions of the top features.
contrib_df

In [None]:
# Horizontal bars comparing the positive and negative SHAP contributions of the top features.
contrib_df = contrib_df.sort_values('Positive Impact', ascending=True)

# Draw into one explicitly created axes: DataFrame.plot(figsize=...) without ax opens
# its own figure, which would leave a separately created plt.figure() empty.
fig, ax = plt.subplots(figsize=(12, 8))
contrib_df.plot(
    x='Feature',
    y=['Positive Impact', 'Negative Impact'],
    kind='barh',
    color=['green', 'red'],
    ax=ax
)
ax.set_title('Positive vs Negative Feature Impact', fontsize=16)
# The bars are the mean of the positive-clipped and |mean of the negative-clipped|
# SHAP values, not the mean absolute SHAP value.
ax.set_xlabel('Mean SHAP Contribution (magnitude)', fontsize=14)
ax.set_ylabel('Features', fontsize=14)
ax.grid(axis='x', linestyle='--', alpha=0.7)
fig.tight_layout()
fig.savefig("../Experiment Figure/shap_pos_neg_impact_20250415.png", dpi=400, bbox_inches='tight')
plt.show()

In [None]:
# Function to create statistics for one output
def create_output_stats(shap_values_output, output_idx, feature_names = None,
                        out_dir = "../Experiment Figure", date_tag = "20250415"):
    """
    Summarize per-feature SHAP statistics for one model output and save them as CSV.

    Parameters
    ----------
    shap_values_output : array of shape (n_samples, n_features) with one output's SHAP values.
    output_idx : output (action) index; used only in the CSV file names.
    feature_names : labels for the feature columns. Defaults to the notebook-level
        rl_table_train_scaled.columns so existing two-argument calls keep working.
    out_dir : directory the two CSV files are written to (must already exist).
    date_tag : date string embedded in the file names.

    Returns
    -------
    (feature_importance, feature_contribution) : DataFrames with one row per feature.
        feature_importance holds mean |SHAP|, mean, min, max and std, sorted by mean
        |SHAP| (descending). feature_contribution holds the percentage of samples
        with positive / negative SHAP and the mean over only the positive / negative
        values (0 when a feature has none), sorted by mean positive SHAP (descending).
    """
    if feature_names is None:
        feature_names = rl_table_train_scaled.columns
    values = np.asarray(shap_values_output)

    # 2.1 Table of feature importance statistics
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Mean Absolute SHAP': np.abs(values).mean(axis=0),
        'Mean SHAP': values.mean(axis=0),
        'Min SHAP': values.min(axis=0),
        'Max SHAP': values.max(axis=0),
        'Std SHAP': values.std(axis=0)
    })

    feature_importance = feature_importance.sort_values('Mean Absolute SHAP', ascending=False)
    feature_importance.to_csv(f"{out_dir}/shap_feature_importance_output{output_idx}_{date_tag}.csv", index=False)

    # 2.2 Table with percentage of positive and negative contributions.
    # Computed column-wise in one pass instead of re-masking every feature four times
    # with Python's builtin any() over NumPy arrays.
    pos_mask = values > 0
    neg_mask = values < 0
    pos_count = pos_mask.sum(axis=0)
    neg_count = neg_mask.sum(axis=0)
    # Mean over only the strictly positive (negative) entries; 0 when there are none.
    # np.maximum(..., 1) only avoids a 0/0 warning in the branch np.where discards.
    mean_pos = np.where(pos_count > 0, np.where(pos_mask, values, 0.0).sum(axis=0) / np.maximum(pos_count, 1), 0)
    mean_neg = np.where(neg_count > 0, np.where(neg_mask, values, 0.0).sum(axis=0) / np.maximum(neg_count, 1), 0)

    feature_contribution = pd.DataFrame({
        'Feature': feature_names,
        'Positive Contribution (%)': pos_mask.mean(axis=0) * 100,
        'Negative Contribution (%)': neg_mask.mean(axis=0) * 100,
        'Mean Positive SHAP': mean_pos,
        'Mean Negative SHAP': mean_neg
    })

    feature_contribution = feature_contribution.sort_values('Mean Positive SHAP', ascending=False)
    feature_contribution.to_csv(f"{out_dir}/shap_feature_contribution_output{output_idx}_{date_tag}.csv", index=False)

    return feature_importance, feature_contribution

In [None]:
# If we have multiple outputs, we can analyze each one
# Writes per-output CSVs via create_output_stats and, for list-valued shap_values, one
# combined table ranking features by their mean |SHAP| averaged across outputs.
if isinstance(shap_values, list):
    # Create statistics for each output
    for i, shap_vals in enumerate(shap_values):
        print(f"Creating statistics for output {i}")
        create_output_stats(shap_vals, i)
        
    # Additionally, create a summary table that compares top features across outputs
    feature_importance_all = pd.DataFrame({'Feature': rl_table_train_scaled.columns})
    for i, shap_vals in enumerate(shap_values):
        importance = np.abs(shap_vals).mean(axis=0)
        feature_importance_all[f'Importance_Output_{i}'] = importance
    
    # Add average importance across all outputs
    # (iloc[:, 1:] skips the 'Feature' column and averages the per-output columns).
    feature_importance_all['Average_Importance'] = feature_importance_all.iloc[:, 1:].mean(axis=1)
    feature_importance_all = feature_importance_all.sort_values('Average_Importance', ascending=False)
    feature_importance_all.to_csv("../Experiment Figure/shap_feature_importance_all_outputs_20250415.csv", index=False)
else:
    # Just one output
    # NOTE(review): relies on `selected_output` already existing in the notebook namespace.
    create_output_stats(shap_values_selected, selected_output)

In [None]:
# 4. Create a correlation heatmap of SHAP values for the selected output
# np.corrcoef on the transpose correlates features (columns) with each other; features
# whose SHAP values are constant produce NaN entries.
shap_corr = np.corrcoef(shap_values_selected.T)
plt.figure(figsize=(12, 10))
# Hide the upper triangle (including the diagonal) so each pair is drawn once.
mask = np.triu(np.ones_like(shap_corr, dtype=bool))
sns.heatmap(
    shap_corr,
    mask=mask,
    cmap='coolwarm',
    annot=False,  # with annot=False the fmt below has no visible effect
    fmt='.2f',
    linewidths=0.5,
    xticklabels=rl_table_train_scaled.columns,
    yticklabels=rl_table_train_scaled.columns
)
# NOTE(review): relies on `selected_output` already existing in the notebook namespace.
plt.title(f'Correlation Between Feature SHAP Values - Output {selected_output}', fontsize=16)
plt.tight_layout()
plt.savefig(f"../Experiment Figure/shap_correlation_output{selected_output}_20250415.png", dpi=400, bbox_inches='tight')
plt.show()

In [None]:
# Compare feature importance across outputs (only when shap returned one array per output).
if isinstance(shap_values, list) and len(shap_values) > 1:
    # Mean |SHAP| per feature (rows) and output (columns): shape (n_features, n_outputs).
    importance_per_output = np.column_stack([np.abs(shap_vals).mean(axis=0) for shap_vals in shap_values])

    # Get top 10 features based on average importance across all outputs
    avg_importance = importance_per_output.mean(axis=1)
    top_features_overall_idx = np.argsort(avg_importance)[-10:]
    top_features_overall = [rl_table_train_scaled.columns[i] for i in top_features_overall_idx]

    # Importance of the top features for each output: shape (10, n_outputs).
    # Indexing the matrix avoids notebook-level loop variables that would overwrite
    # the global `idx` (the SHAP background-sample index array).
    top_importance = importance_per_output[top_features_overall_idx, :]

    # Create a comparison bar chart
    comparison_data = {'Feature': top_features_overall}
    comparison_data.update({f'Output {k}': top_importance[:, k] for k in range(len(shap_values))})

    comparison_df = pd.DataFrame(comparison_data)
    comparison_df = comparison_df.melt(id_vars=['Feature'], var_name='Output', value_name='Importance')

    plt.figure(figsize=(12, 8))
    sns.barplot(x='Importance', y='Feature', hue='Output', data=comparison_df)
    plt.title('Feature Importance Comparison Across Outputs', fontsize=16)
    plt.tight_layout()
    plt.savefig("../Experiment Figure/shap_importance_comparison_20250415.png", dpi=400, bbox_inches='tight')
    plt.show()

    # 5.2 Create a heatmap showing importance across outputs
    heatmap_data = top_importance

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        heatmap_data,
        cmap='viridis',
        annot=True,
        fmt='.3f',
        xticklabels=[f'Output {i}' for i in range(len(shap_values))],
        yticklabels=top_features_overall
    )
    plt.title('Feature Importance Heatmap Across Outputs', fontsize=16)
    plt.tight_layout()
    plt.savefig("../Experiment Figure/shap_importance_heatmap_20250415.png", dpi=400, bbox_inches='tight')
    plt.show()

print("All SHAP visualizations and tables have been created successfully!")

In [None]:
# Set a modern style for better aesthetics.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed
# the old names, so use whichever name this installation provides.
whitegrid_style = 'seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid'
plt.style.use(whitegrid_style)
plt.rcParams['font.family'] = 'Arial'  # matplotlib warns and falls back if Arial is not installed
plt.rcParams['font.size'] = 10  # Base font size

if isinstance(shap_values, list):
    # One beeswarm plot per output (action).
    for output_idx in range(len(shap_values)):
        plt.figure(figsize = (12, 10))  # Larger figure size for better spacing
        # expected_value may hold one base value per output as a list or as a NumPy
        # array; index it in both cases (np.ndim is 0 only for a scalar).
        expected_value = e.expected_value
        base_value = expected_value[output_idx] if np.ndim(expected_value) > 0 else expected_value
        shap.plots.beeswarm(
            shap.Explanation(
                values = shap_values[output_idx],
                base_values = base_value,
                data = X_data,
                feature_names = rl_table_train_scaled.columns
            ),
            alpha = 0.5,  # Add transparency to points
            color = plt.get_cmap("coolwarm"),  # Use a more distinct colormap
            show = False,  # keep the figure open for the customizations and savefig below
            color_bar = False  # a custom colorbar is added below
        )

        # Customize plot elements
        plt.title(f"SHAP Value Impact - Action {output_idx}", fontsize = 12, pad = 20)
        plt.xlabel("SHAP Value (Impact on Model Output)", fontsize = 10)
        plt.ylabel("Features", fontsize = 10)

        # Add gridlines
        plt.grid(True, axis = 'x', linestyle = '--', alpha = 0.3, color = 'gray')

        # Adjust tick labels
        plt.xticks(fontsize = 10)
        plt.yticks(fontsize = 10)

        # Add a colorbar to show feature value mapping. Pass ax explicitly: recent
        # matplotlib cannot infer which Axes to take space from for a ScalarMappable
        # that is not attached to any Axes.
        sm = plt.cm.ScalarMappable(cmap = "coolwarm", norm = plt.Normalize(vmin = 0, vmax = 1))
        sm.set_array([])
        cbar = plt.colorbar(sm, ax = plt.gca(), fraction = 0.03, pad = 0.02, aspect = 30)
        cbar.set_label("Feature Value (Low to High)", fontsize = 10, rotation = 270, labelpad = 20)
        cbar.ax.tick_params(labelsize = 10)

        # Tight layout to prevent clipping
        plt.tight_layout()

        plt.savefig(f"../Experiment Figure/shap_beeswarm_output{output_idx}_20250415_1.png", 
                    dpi = 400, 
                    bbox_inches = 'tight')
        # plt.savefig(f"../Experiment Figure/shap_beeswarm_output{output_idx}_20250415.pdf", bbox_inches='tight')
        plt.show()
else:
    plt.figure(figsize = (12, 10))
    # show=False: beeswarm calls plt.show() itself when show=True, so the customizations
    # and savefig below would apply to a new, empty figure. color_bar=False avoids a
    # second colorbar next to the custom one added below.
    shap.plots.beeswarm(
        shap.Explanation(
            values = shap_values,
            base_values = e.expected_value,
            data = X_data,
            feature_names = rl_table_train_scaled.columns
        ),
        alpha = 0.7,
        color = plt.get_cmap("coolwarm"),
        show = False,
        color_bar = False
    )

    # Customize plot elements
    plt.title("SHAP Value Impact", fontsize = 16, pad = 20)
    plt.xlabel("SHAP Value (Impact on Model Output)", fontsize = 14)
    plt.ylabel("Features", fontsize = 14)

    # Add gridlines
    plt.grid(True, axis = 'x', linestyle = '--', alpha = 0.3, color = 'gray')

    # Adjust tick labels
    plt.xticks(fontsize = 12)
    plt.yticks(fontsize = 12)

    # Add a colorbar (attached to the current Axes explicitly; see the note in the list branch).
    sm = plt.cm.ScalarMappable(cmap = "coolwarm", norm = plt.Normalize(vmin = 0, vmax = 1))
    sm.set_array([])
    plt.colorbar(sm, ax = plt.gca(), label = "Feature Value (Low to High)", fraction = 0.046, pad = 0.04)

    # Tight layout
    plt.tight_layout()

    plt.savefig("../Experiment Figure/shap_beeswarm_20250415.png", 
                dpi = 300, 
                bbox_inches = 'tight')
    # plt.savefig("../Experiment Figure/shap_beeswarm_20250415.pdf", bbox_inches='tight')
    plt.show()