In [1]:
import os
import sys

import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Make the project checkout importable before pulling in its helpers.
sys.path.append('F:/time_step/OfflineRL_FactoredActions')
from RL_mimic_sepsis.utils.timestep_util import (
    action_space_global,
    action_space_list,
    action_space_name_mapping,
    timestep_list,
)

# ---- Config area ----
# Action-space tag and time-step value; both are interpolated into the
# log/data folder names used further down (e.g. 'BCQ_as<action_space>_dt<timestep>h_grid').
action_space, timestep = 'NormThreshold', 1

# Name of the per-run metrics CSV to read (alternative: 'metrics_step1p438.csv').
metrics_name = 'metrics_100multiple.csv'

clipping_method = 'PerStepClipping=1.438'
save = False

In [3]:
def extract_param_from_yaml(path, key):
    """Return the value of the first ``key: value`` line found in a YAML file.

    This is a lightweight line scanner, not a YAML parser: every line is
    stripped of surrounding whitespace and compared against ``"<key>:"``, so
    indentation/nesting is ignored and the first matching line wins.

    Parameters
    ----------
    path : str or path-like
        File to scan.
    key : str
        Key name to look for at the start of a (stripped) line.

    Returns
    -------
    int, float or str
        * ``int`` when the value parses as an integer, or as a float with no
          fractional part (``"0.0"`` -> ``0``, ``"1e3"`` -> ``1000``);
        * ``float`` for any other numeric value;
        * the stripped text itself when it is not numeric.

    Raises
    ------
    KeyError
        If no line in the file starts with ``"<key>:"``.
    """
    prefix = f"{key}:"
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip blanks, full-line comments and lines belonging to other keys.
            if not line or line.startswith('#') or not line.startswith(prefix):
                continue

            raw = line[len(prefix):].strip()

            # For an unquoted scalar, YAML treats a '#' that opens the value or
            # follows whitespace as the start of an inline comment. Drop it so
            # that e.g. "threshold: 0.5  # note" still parses as a number.
            if raw[:1] not in ('"', "'"):
                if raw.startswith('#'):
                    raw = ''
                else:
                    for marker in (' #', '\t#'):
                        cut = raw.find(marker)
                        if cut != -1:
                            raw = raw[:cut]
                    raw = raw.rstrip()

            try:
                if '.' in raw or 'e' in raw.lower():
                    num = float(raw)
                    # Collapse whole-valued floats to int.
                    return int(num) if num.is_integer() else num
                return int(raw)
            except ValueError:
                # Not numeric: hand the text back unchanged.
                return raw
    raise KeyError(f"{key} not found in {path}")

In [4]:
version_list = list(range(40))
seed_list = [0, 1, 2, 3, 4]
threshold_list  = [0.0, 0.01, 0.05, 0.1, 0.3, 0.5, 0.75, 0.9999]
param_name = 'threshold'

# The grid folder name depends only on the config values, not on the run.
folder_name = f'BCQ_as{action_space}_dt{timestep}h_grid'

# Collect one single-entry Series per (threshold, seed) run, threshold-major,
# holding the value of `param_name` read from that run's hparams.yaml.
metas = []
for threshold in threshold_list:
    for seed in seed_list:
        hparams_file = (f'F:/time_step/OfflineRL_FactoredActions/RL_mimic_sepsis'
                        f'/d_BCQ/logs/{folder_name}/dt{timestep}_threshold{threshold}seed{seed}/hparams.yaml')
        val = extract_param_from_yaml(hparams_file, param_name)
        meta = pd.Series({param_name: val}, name='value').rename_axis('key')
        metas.append(meta)

In [5]:
# Preview the first five collected records (the five seeds of the first threshold).
metas[:5]

[key
 threshold    0
 Name: value, dtype: int64,
 key
 threshold    0
 Name: value, dtype: int64,
 key
 threshold    0
 Name: value, dtype: int64,
 key
 threshold    0
 Name: value, dtype: int64,
 key
 threshold    0
 Name: value, dtype: int64]

In [6]:
# Threshold value recorded in each run's hparams.yaml, in run order.
tau_list = [meta.loc['threshold'] for meta in metas]

# Build the grid folder name here instead of relying on a variable left over
# from an earlier cell's loop, so this cell does not depend on hidden state.
folder_name = f'BCQ_as{action_space}_dt{timestep}h_grid'

# One metrics frame per (threshold, seed) run, in the same threshold-major
# order as `metas`; each frame is tagged with the run it came from.
dfs = []
for threshold in threshold_list:
    for seed in seed_list:
        metrics_file = (f'F:/time_step/OfflineRL_FactoredActions/RL_mimic_sepsis'
                        f'/d_BCQ/logs/{folder_name}/dt{timestep}_threshold{threshold}seed{seed}/{metrics_name}')
        # Keep the first 100 rows; `.assign` returns a new frame, so the tag
        # columns are never written into a slice of the freshly read CSV.
        df = (
            pd.read_csv(metrics_file)
            .iloc[:100]
            .assign(threshold=threshold, seed=seed)
        )
        dfs.append(df)

tau_list[:10]

[0, 0, 0, 0, 0, 0.01, 0.01, 0.01, 0.01, 0.01]

## Model selection

In [32]:
# Minimum `val_ess` a checkpoint row must reach to be considered during model selection.
ESS_cutoff = 180

In [33]:
# For each run, keep the metrics row with the highest `val_wis` among the rows
# whose `val_ess` reaches ESS_cutoff. Runs with no such row are left out, so
# `results` may be shorter than `dfs`; `ver` is the run's position in `dfs`.
results = []
for ver, df in enumerate(dfs):
    eligible_wis = df.loc[df['val_ess'] >= ESS_cutoff, 'val_wis']
    if eligible_wis.empty:
        continue
    results.append((ver, df.loc[eligible_wis.idxmax()]))

In [28]:
# Number of metrics frames loaded (one per threshold/seed combination).
len(dfs)

40

In [34]:
# Number of runs that have at least one row meeting the ESS cutoff.
len(results)

29

In [35]:
# Show every selected (run index, best row) pair.
results

[(0,
  iteration      5900.000000
  step           5899.000000
  val_wis          97.813367
  val_qvalues       0.710929
  val_ess         198.515172
  threshold         0.000000
  seed              0.000000
  Name: 58, dtype: float64),
 (1,
  iteration      6300.000000
  step           6299.000000
  val_wis          97.809673
  val_qvalues       0.788054
  val_ess         198.516628
  threshold         0.000000
  seed              1.000000
  Name: 62, dtype: float64),
 (2,
  iteration      8100.000000
  step           8099.000000
  val_wis          97.801577
  val_qvalues       0.875793
  val_ess         197.535351
  threshold         0.000000
  seed              2.000000
  Name: 80, dtype: float64),
 (3,
  iteration      9300.000000
  step           9299.000000
  val_wis          97.649250
  val_qvalues       1.116220
  val_ess         184.024695
  threshold         0.000000
  seed              3.000000
  Name: 92, dtype: float64),
 (5,
  iteration      200.000000
  step           19

In [10]:
# Split the selected rows into parallel lists (same order as `results`).
val_scores, val_ess, hparams, thresholds, seeds = [], [], [], [], []
for ver, row in results:
    val_scores.append(row['val_wis'])
    val_ess.append(row['val_ess'])
    hparams.append((ver, row['iteration']))
    thresholds.append(row['threshold'])
    seeds.append(row['seed'])

In [11]:
import numpy as np

# Position (within the parallel lists) of the highest validation WIS.
best_idx = np.argmax(val_scores)

# Run index / iteration, seed and threshold of that winning checkpoint.
best_ver, best_iter = hparams[best_idx]
best_seed, best_threshold = seeds[best_idx], thresholds[best_idx]

# Summary: (val_wis, val_ess, (run index, iteration), seed, threshold)
(val_scores[best_idx], val_ess[best_idx], hparams[best_idx], best_seed, best_threshold)

(97.33840335271378, 202.4322922262388, (0, 1500.0), 0.0, 0.0)

In [12]:
from types import SimpleNamespace
import torch
from torch.utils.data import DataLoader
from RL_mimic_sepsis.d_BCQ.src.data import EpisodicBuffer, SASRBuffer, remap_rewards
from RL_mimic_sepsis.d_BCQ.src.model import BCQ
from RL_mimic_sepsis.utils.timestep_util import get_state_dim, get_horizon

In [13]:
# Checkpoint file of the selected run; seed and iteration are cast to int
# because they come out of the results rows as floats.
best_ckpt_path = (
    rf'F:\time_step\OfflineRL_FactoredActions\RL_mimic_sepsis\d_BCQ\logs'
    rf'/BCQ_as{action_space}_dt{timestep}h_grid/dt{timestep}_threshold{best_threshold}seed{int(best_seed)}'
    rf'/checkpoints/step={int(best_iter):04}.ckpt'
)

bcq_model = BCQ.load_from_checkpoint(checkpoint_path=best_ckpt_path, map_location=None)
bcq_model.eval()

KeyboardInterrupt: 

In [None]:
# Sizes used to construct the episode buffer for the test data.
num_actions = 25
# Use the `action_space` setting rather than repeating the 'NormThreshold'
# literal, so this value cannot drift away from the data folder selected below.
# NOTE(review): assumes the second argument of get_state_dim is the
# action-space tag - confirm against timestep_util.
state_dim = get_state_dim(timestep, action_space)
horizon = get_horizon(timestep)

# Folder holding the prepared episode tensors for this action space / time step.
data_folder_name = (rf'F:\time_step\OfflineRL_FactoredActions\RL_mimic_sepsis\data'
                    rf'\data_as{action_space}_dt{timestep}h\episodes+encoded_state+knn_pibs')

In [None]:
# Load the test episodes and remap their rewards with the scheme below.
test_episodes = EpisodicBuffer(state_dim, num_actions, horizon)
test_episodes.load(f'{data_folder_name}/test_data.pt')
reward_cfg = SimpleNamespace(R_immed=0.0, R_death=0.0, R_disch=100.0)
test_episodes.reward = remap_rewards(test_episodes.reward, reward_cfg)

# Pull the whole buffer out as a single, unshuffled batch.
tmp_test_episodes_loader = DataLoader(test_episodes, batch_size=len(test_episodes), shuffle=False)
test_batch = next(iter(tmp_test_episodes_loader))

# Put every batch element that supports `.to` on the model's device.
device = next(bcq_model.parameters()).device
test_batch = [item.to(device) if hasattr(item, 'to') else item for item in test_batch]

test_wis, test_ess = bcq_model.offline_evaluation(test_batch, eps=0.01)
test_wis, test_ess

Episodic Buffer loaded with 2757 episides.


(97.63008315880062, 249.17406311999247)