In [None]:
library(tidyverse)
library(lubridate)
library(repr) # needed for figure size
# Default size of figures rendered through repr
options(
  repr.plot.height = 6,
  repr.plot.width = 6
)

In [None]:
# Experiment identity: model folder stem, grid index and display label
exp_name <- "ess_b32_h128_lr1e4_g99"
exp_grid <- 1
Experiment <- "Extended State Space"

# Hard-coded WDR / WIS figures per data split
# (the `phy` prefix presumably means physician policy; verify against the
# evaluation output these were copied from)
train_phy_WDR <- 1.9362
train_phy_WIS <- 1.9021
val_phy_WDR <- 2.3504
val_phy_WIS <- 1.8141
test_phy_WDR <- 4.6861
test_phy_WIS <- 1.832

# Input data sets, one CSV per split
trainfile <- "../data/train_data_scaled_imputed.csv"
valfile <- "../data/val_data_scaled_imputed.csv"
testfile <- "../data/test_data_scaled_imputed.csv"

# Model outputs live in ../models/<exp_name>_<exp_grid>/
exp_data_path <- paste0("../models/", paste(exp_name, exp_grid, sep = "_"), "/")
figuresdir <- "../figures/"



In [None]:


########################
### MULTIPLOT
# SOURCE: http://www.cookbook-r.com/Graphs/Multiple_graphs_on_one_page_(ggplot2)/
#
# Draw several plot objects on one page using a grid layout.
#
# Arguments:
#   ...       plot objects to draw.
#   plotlist  optional list of further plot objects, appended after `...`.
#   file      accepted for signature compatibility; the body never reads it.
#   cols      number of columns in the layout (ignored when `layout` is given).
#   layout    optional matrix of plot indices; cell value i marks where plot i
#             is drawn. For example matrix(c(1, 2, 3, 3), nrow = 2, byrow = TRUE)
#             puts plot 1 top-left, plot 2 top-right and plot 3 across the
#             bottom row.
#
# Returns (invisibly) NULL when no plots are supplied or when several plots are
# drawn; with exactly one plot it returns whatever print() returns for it.
multiplot <- function(..., plotlist=NULL, file, cols=1, layout=NULL) {
  # Attached here (as in the cookbook original) so the grid functions below
  # resolve even when the caller has not loaded grid.
  library(grid)

  # Make a list from the ... arguments and plotlist
  plots <- c(list(...), plotlist)

  numPlots <- length(plots)

  # Nothing to draw: return before touching the graphics device. Without this
  # guard an index loop over 1:0 would try to access plots[[1]] and fail.
  if (numPlots == 0) {
    return(invisible(NULL))
  }

  # If layout is NULL, then use 'cols' to determine layout
  if (is.null(layout)) {
    # ncol: number of columns of plots
    # nrow: number of rows needed, calculated from the number of columns
    rows <- ceiling(numPlots / cols)
    layout <- matrix(seq(1, cols * rows), ncol = cols, nrow = rows)
  }

  if (numPlots == 1) {
    print(plots[[1]])

  } else {
    # Set up the page
    grid.newpage()
    pushViewport(viewport(layout = grid.layout(nrow(layout), ncol(layout))))

    # Make each plot, in the correct location
    for (i in seq_len(numPlots)) {
      # Get the row/column positions of the layout cells assigned to plot i
      matchidx <- as.data.frame(which(layout == i, arr.ind = TRUE))

      print(plots[[i]], vp = viewport(layout.pos.row = matchidx$row,
                                      layout.pos.col = matchidx$col))
    }
  }
}
print("done")

In [None]:
# Columns kept from each of the train / validation / test CSV files
recquired_cols <- c(
  "icustay_id",
  "interval_start_time", "interval_end_time",
  "Discharge",
  "discrete_action",
  "Reward",
  "discrete_action_original",
  "row_id", "row_id_next"
)

In [None]:
########################
### Get DATA

# Read one CSV (column types guessed silently via `cols()`) and keep only the
# columns named in `columns`.
# `all_of()` makes it explicit that `columns` is an external character vector:
# a bare vector inside select() is ambiguous (a data column of the same name
# would win) and is deprecated by tidyselect. A missing column is an error.
read_selected <- function(path, columns) {
  read_csv(path, col_types = cols()) %>%
    select(all_of(columns))
}

# Data sets for the three splits
train_data <- read_selected(trainfile, recquired_cols)
val_data <- read_selected(valfile, recquired_cols)
test_data <- read_selected(testfile, recquired_cols)

# Columns kept from the model's Q-value exports
modelcols = c('best_action', 'state_id')

train_opt <- read_selected(paste0(exp_data_path, 'DQN_Qvalues_traindata.csv'), modelcols)
val_opt <- read_selected(paste0(exp_data_path, 'DQN_Qvalues_valdata.csv'), modelcols)
test_opt <- read_selected(paste0(exp_data_path, 'DQN_Qvalues_testdata.csv'), modelcols)

In [None]:
# Action mapping

########################
# Action distribution
#
# The 5 x 5 grid of (IV, VP) dose levels is numbered 0..24, with the IV level
# varying fastest (expand.grid order). `model_actions` gives the model's action
# index for each grid cell; it is NA for grid cells 5, 10, 15 and 20, so the
# model works with 21 actions. `real_actions_ind` is the reverse lookup from
# model index (0..20) to grid cell.
action_mappings <- mutate(
  expand.grid(discrete_IV = c(0, 1, 2, 3, 4), discrete_VP = c(0, 1, 2, 3, 4)),
  real_discrete_action = 0:24
)
model_actions <- c(0, 1, 2, 3, 4, NA, 5, 6, 7, 8, NA, 9, 10, 11, 12, NA, 13, 14, 15, 16, NA, 17, 18, 19, 20)
real_actions_ind <- c(0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23, 24)
action_mappings <- mutate(
  action_mappings,
  model_action = model_actions[real_discrete_action + 1]
)
DQN_action_mappings <- rename(
  action_mappings,
  DQN_discrete_action = real_discrete_action
)

# Apply mapping: translate model indices into grid cells.
# The +1 converts the 0-based action index into R's 1-based vector position.
grid_action_of <- function(model_index) real_actions_ind[model_index + 1]

train_data <- mutate(train_data, real_discrete_action = grid_action_of(discrete_action))
val_data <- mutate(val_data, real_discrete_action = grid_action_of(discrete_action))
test_data <- mutate(test_data, real_discrete_action = grid_action_of(discrete_action))
train_opt <- mutate(train_opt, DQN_discrete_action = grid_action_of(best_action))
val_opt <- mutate(val_opt, DQN_discrete_action = grid_action_of(best_action))
test_opt <- mutate(test_opt, DQN_discrete_action = grid_action_of(best_action))

In [None]:
########################
# Join the model's chosen action onto each observed interval and add per-stay
# summary columns. The same steps are applied to every split.
#
#   data  observed intervals for one split (uses row_id, icustay_id, Reward,
#         interval_start_time)
#   opt   model output for the same split, keyed by state_id
#
# Added columns (computed within each icustay_id):
#   sum_reward     total Reward of the stay
#   relative_time  hours since the stay's earliest interval_start_time
#   discharge      1 if sum_reward > 0, 0 if sum_reward < 0, otherwise 9
#                  (zero or missing total)
attach_policy_actions <- function(data, opt) {
  data %>%
    rename(state_id = row_id) %>%
    left_join(opt, by = 'state_id') %>%
    group_by(icustay_id) %>%
    mutate(
      sum_reward = sum(Reward),
      relative_time = difftime(interval_start_time, min(interval_start_time), units = 'hours'),
      # sum_reward is already the per-stay total, so its sign is tested
      # directly rather than summing it again over the rows of the stay.
      discharge = case_when(sum_reward > 0 ~ 1, sum_reward < 0 ~ 0, TRUE ~ 9)
    ) %>%
    ungroup()
}

train <- attach_policy_actions(train_data, train_opt)
val <- attach_policy_actions(val_data, val_opt)
test <- attach_policy_actions(test_data, test_opt)


In [None]:
# Preview the first three rows of the joined training table
head(train, n = 3)

In [None]:
########################
# Action distribution
# Rebuild the 5 x 5 dose grid (cells numbered 0..24, IV level varying fastest)
# together with the model-index lookups.
action_mappings <- expand.grid(discrete_IV = c(0, 1, 2, 3, 4), discrete_VP = c(0, 1, 2, 3, 4))
action_mappings <- mutate(action_mappings, real_discrete_action = 0:24)
model_actions <- c(0, 1, 2, 3, 4, NA, 5, 6, 7, 8, NA, 9, 10, 11, 12, NA, 13, 14, 15, 16, NA, 17, 18, 19, 20)
real_actions_ind <- c(0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 21, 22, 23, 24)
action_mappings <- mutate(action_mappings, model_action = model_actions[real_discrete_action + 1])

# Two copies of the grid (without the model index), one with PHY_-prefixed
# column names and one with DQN_-prefixed names, so both can be joined onto
# the same table without name clashes.
PHY_mapping <- setNames(
  select(action_mappings, -model_action),
  c("PHY_IV", "PHY_VP", "PHY_action")
)

DQN_mapping <- setNames(
  select(action_mappings, -model_action),
  c("DQN_IV", "DQN_VP", "DQN_action")
)

In [None]:
# Apply mapping again with additional mutations
#
# For one split: translate the observed (PHY) and model-chosen (DQN) action
# indices into grid cells, look up the IV / VP dose level of each, and add the
# DQN-minus-PHY difference per drug.
# Both joins are full joins, so grid cells with no matching row are kept as
# extra rows whose remaining columns are NA. The result stays grouped by
# (DQN_IV, DQN_VP, DQN_action).
expand_dose_levels <- function(split_data) {
  with_actions <- split_data %>%
    mutate(
      PHY_action = real_actions_ind[discrete_action + 1],
      DQN_action = real_actions_ind[best_action + 1]
    )

  with_phy_levels <- with_actions %>%
    full_join(PHY_mapping, by = "PHY_action") %>%
    group_by(PHY_IV, PHY_VP, PHY_action)

  with_phy_levels %>%
    full_join(DQN_mapping, by = "DQN_action") %>%
    group_by(DQN_IV, DQN_VP, DQN_action) %>%
    mutate(
      VP_diff = DQN_VP - PHY_VP,
      IV_diff = DQN_IV - PHY_IV
    )
}

dose_train <- expand_dose_levels(train)
dose_val <- expand_dose_levels(val)
dose_test <- expand_dose_levels(test)

In [None]:
# Preview the first two rows of the validation dose table
head(dose_val, n = 2)

In [None]:
########################
### preprocess data
#
# Add one indicator column per (policy, drug, dose level):
#   <PHY|DQN>_<VP|IV>_OnOFF_<0..4>
# Each is 1 where the dose column equals that level, 0 where it differs and NA
# where the dose column is NA. Columns are appended level by level, and within
# a level in the order PHY_VP, PHY_IV, DQN_VP, DQN_IV.
add_dose_level_flags <- function(doses) {
  dose_columns <- c("PHY_VP", "PHY_IV", "DQN_VP", "DQN_IV")
  for (level in 0:4) {
    for (dose_column in dose_columns) {
      flag_name <- paste0(dose_column, "_OnOFF_", level)
      # `!!` inlines the loop values so that a data column that happened to be
      # called `level` or `dose_column` could not shadow them.
      doses <- doses %>%
        mutate(!!flag_name := as.numeric(.data[[!!dose_column]] == !!level))
    }
  }
  doses
}

dose_time_train <- add_dose_level_flags(dose_train)
dose_time_val <- add_dose_level_flags(dose_val)
dose_time_test <- add_dose_level_flags(dose_test)


In [None]:
# Preview the first two rows of the test table with the indicator columns
head(dose_time_test, n = 2)

In [None]:
########################
### SUMMARISE DATA for Physician
#
# For every relative_time value, compute the percentage of rows at each dose
# level 0..4 for one policy ("PHY" or "DQN") and one drug ("IV" or "VP").
# Reads the indicator columns <policy>_<drug>_OnOFF_<level> and returns one row
# per relative_time with columns <policy>_prop_<drug>_<level>.
# mean() is used without na.rm, so a time point containing an NA indicator
# yields NA.
summarise_dose_share <- function(doses, policy, drug) {
  dose_levels <- 0:4
  flag_names <- paste0(policy, "_", drug, "_OnOFF_", dose_levels)
  # One `mean(<flag>) * 100` expression per level, named after its output column
  share_exprs <- lapply(
    flag_names,
    function(flag) rlang::expr(mean(.data[[!!flag]]) * 100)
  )
  names(share_exprs) <- paste0(policy, "_prop_", drug, "_", dose_levels)

  doses %>%
    group_by(relative_time) %>%
    summarise(!!!share_exprs) %>%
    ungroup()
}

PHY_IV_time_train_melt <- summarise_dose_share(dose_time_train, "PHY", "IV")
PHY_IV_time_val_melt <- summarise_dose_share(dose_time_val, "PHY", "IV")
PHY_IV_time_test_melt <- summarise_dose_share(dose_time_test, "PHY", "IV")

PHY_VP_time_train_melt <- summarise_dose_share(dose_time_train, "PHY", "VP")
PHY_VP_time_val_melt <- summarise_dose_share(dose_time_val, "PHY", "VP")
PHY_VP_time_test_melt <- summarise_dose_share(dose_time_test, "PHY", "VP")


In [None]:
# Preview the physician IV summary for the validation split
head(x = PHY_IV_time_val_melt)

In [None]:
########################
### SUMMARISE DATA for DQN
#
# For every relative_time value, compute the percentage of rows at each dose
# level 0..4 for one policy ("PHY" or "DQN") and one drug ("IV" or "VP").
# Reads the indicator columns <policy>_<drug>_OnOFF_<level> and returns one row
# per relative_time with columns <policy>_prop_<drug>_<level>.
# mean() is used without na.rm, so a time point containing an NA indicator
# yields NA.
summarise_dose_share <- function(doses, policy, drug) {
  dose_levels <- 0:4
  flag_names <- paste0(policy, "_", drug, "_OnOFF_", dose_levels)
  # One `mean(<flag>) * 100` expression per level, named after its output column
  share_exprs <- lapply(
    flag_names,
    function(flag) rlang::expr(mean(.data[[!!flag]]) * 100)
  )
  names(share_exprs) <- paste0(policy, "_prop_", drug, "_", dose_levels)

  doses %>%
    group_by(relative_time) %>%
    summarise(!!!share_exprs) %>%
    ungroup()
}

DQN_IV_time_train_melt <- summarise_dose_share(dose_time_train, "DQN", "IV")
DQN_IV_time_val_melt <- summarise_dose_share(dose_time_val, "DQN", "IV")
DQN_IV_time_test_melt <- summarise_dose_share(dose_time_test, "DQN", "IV")

DQN_VP_time_train_melt <- summarise_dose_share(dose_time_train, "DQN", "VP")
DQN_VP_time_val_melt <- summarise_dose_share(dose_time_val, "DQN", "VP")
DQN_VP_time_test_melt <- summarise_dose_share(dose_time_test, "DQN", "VP")

print("Data preprocessing done")

In [None]:
########################
# Switch to a wide 10 x 4 figure for the plots that follow
options(repr.plot.height = 4, repr.plot.width = 10)
# Upper bound for the fill scale; in the plotting code visible in this notebook
# it only appears inside commented-out `limit = c(0, action_max)` arguments.
action_max <- 44000
# Preview the DQN vasopressor summary for the test split
head(x = DQN_VP_time_test_melt)

In [None]:
# Action matrices: how often each (IV, VP) dose combination was chosen, by the
# physicians and by the DQN policy, for each data split.

# Count observed rows per grid cell.
#   data        one split (train / val / test)
#   mapping     grid lookup keyed by `action_col`
#   action_col  name of the grid-cell column shared by `data` and `mapping`
# The full join keeps grid cells that never occur so they still appear in the
# matrix. Such cells arrive as a single placeholder row with NA in every data
# column, so only rows carrying an icustay_id are counted: unobserved cells get
# 0 and observed cells get their exact count (no "- 1" correction is needed
# when labelling).
# NOTE(review): assumes icustay_id is never NA in real rows - confirm.
count_actions <- function(data, mapping, action_col) {
  data %>%
    full_join(mapping, by = action_col) %>%
    group_by(discrete_IV, discrete_VP, !!rlang::sym(action_col)) %>%
    summarise(action_count = sum(!is.na(icustay_id)))
}

# Heat map of counts: vasopressor level on x, IV fluid level on y, each cell
# labelled with its count. `high_colour` is the colour of the largest count.
plot_action_counts <- function(counts, high_colour, title) {
  ggplot(counts, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) +
    geom_text(aes(label = action_count)) +
    theme_minimal() +
    scale_fill_gradient(low = "white", high = high_colour, name = 'Action Count') + #, limit=c(0,action_max)) +
    xlab('Maximum Vasopressor Dosage') +
    ylab('Mean IV Fluids') +
    ggtitle(label = title)
}

# Training split
train_PHY_actionmatrix <- count_actions(train, action_mappings, "real_discrete_action")
AM_TRAIN_PHY <- plot_action_counts(train_PHY_actionmatrix, "blue", 'Training - Physician action matrix')

train_DQN_actionmatrix <- count_actions(train, DQN_action_mappings, "DQN_discrete_action")
AM_TRAIN_DQN <- plot_action_counts(train_DQN_actionmatrix, "orange", 'Training - Optimal policy action matrix')

# Validation split
val_PHY_actionmatrix <- count_actions(val, action_mappings, "real_discrete_action")
AM_VAL_PHY <- plot_action_counts(val_PHY_actionmatrix, "blue", 'Validation - Physician action matrix')

val_DQN_actionmatrix <- count_actions(val, DQN_action_mappings, "DQN_discrete_action")
AM_VAL_DQN <- plot_action_counts(val_DQN_actionmatrix, "orange", 'Validation - Optimal policy action matrix')

# Test split
test_PHY_actionmatrix <- count_actions(test, action_mappings, "real_discrete_action")
AM_TEST_PHY <- plot_action_counts(test_PHY_actionmatrix, "blue", 'Test - Physician action matrix')

test_DQN_actionmatrix <- count_actions(test, DQN_action_mappings, "DQN_discrete_action")
AM_TEST_DQN <- plot_action_counts(test_DQN_actionmatrix, "orange", 'Test - Optimal policy action matrix')




### Calculate action frequency in % (easier to read than raw counts)

In [None]:
# Express each cell's count as a percentage of the table's total count,
# rounded to two decimals.
share_of_total <- function(counts) round(counts / sum(counts) * 100, 2)

# Physician action %
train_PHY_actionmatrix$action_frequency <- share_of_total(train_PHY_actionmatrix$action_count)
val_PHY_actionmatrix$action_frequency <- share_of_total(val_PHY_actionmatrix$action_count)
test_PHY_actionmatrix$action_frequency <- share_of_total(test_PHY_actionmatrix$action_count)

# DQN action frequency %
train_DQN_actionmatrix$action_frequency <- share_of_total(train_DQN_actionmatrix$action_count)
val_DQN_actionmatrix$action_frequency <- share_of_total(val_DQN_actionmatrix$action_count)
test_DQN_actionmatrix$action_frequency <- share_of_total(test_DQN_actionmatrix$action_count)



In [None]:
### Plotting frequencies instead of counts
# The *_actionmatrix tables built for the count plots are reused here; they
# must already carry the `action_frequency` column.

# Heat map of percentages: vasopressor level on x, IV fluid level on y, each
# cell labelled with its percentage. `high_colour` is the colour of the
# largest value.
plot_action_frequency <- function(frequencies, high_colour, title) {
  ggplot(frequencies, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_frequency)) +
    geom_text(aes(label = round(action_frequency, 3))) +
    theme_minimal() +
    scale_fill_gradient(low = "white", high = high_colour, name = 'Action Count %') + #, limit=c(0,action_max)) +
    xlab('Maximum Vasopressor Dosage') +
    ylab('Mean IV Fluids') +
    ggtitle(label = title)
}

# Training split: physicians, then DQN
AM_TRAIN_PHY_FREQ <- plot_action_frequency(train_PHY_actionmatrix, "blue", 'Training - Physician action matrix')
AM_TRAIN_DQN_FREQ <- plot_action_frequency(train_DQN_actionmatrix, "orange", 'Training - Optimal policy action matrix')

# Validation split: physicians, then DQN
AM_VAL_PHY_FREQ <- plot_action_frequency(val_PHY_actionmatrix, "blue", 'Validation - Physician action matrix')
AM_VAL_DQN_FREQ <- plot_action_frequency(val_DQN_actionmatrix, "orange", 'Validation - Optimal policy action matrix')

# Test split: physicians, then DQN
AM_TEST_PHY_FREQ <- plot_action_frequency(test_PHY_actionmatrix, "blue", 'Test - Physician action matrix')
AM_TEST_DQN_FREQ <- plot_action_frequency(test_DQN_actionmatrix, "orange", 'Test - Optimal policy action matrix')




### Filter actions to the first 48 hours after onset of sepsis

In [None]:
#################################################
### FILTERED ACTION MATRICES
#
# Count heat maps of physician and DQN actions, restricted per split to a
# relative_time window (train/val: relative_time >= 24, test: <= 60).

# Count the observed time steps in each action-grid cell.
#   joined : filtered time steps full_join-ed with an action mapping table
#   ...    : grouping columns (IV bin, VP bin, action id)
# The full join adds one all-NA row for every action that never occurs.
# Counting rows with n() therefore reported 1 for empty cells, and the old
# `action_count - 1` label compensated by under-reporting every populated
# cell by one. Only rows with a non-missing relative_time are counted here;
# the preceding filter guarantees that holds for every real time step, so
# empty cells get 0 and populated cells get their true count.
# The result stays grouped the way a plain group_by() %>% summarise() leaves it.
count_observed_actions <- function(joined, ...) {
  joined %>%
    group_by(...) %>%
    summarise(action_count = sum(!is.na(relative_time)))
}

# Draw one action-count heat map (fill and label both show the same count).
#   action_df     : data frame with discrete_VP, discrete_IV, action_count
#   high_colour   : colour used for the highest count
#   plot_title    : panel title
#   plot_subtitle : panel subtitle describing the time window
plot_action_counts <- function(action_df, high_colour, plot_title, plot_subtitle) {
  ggplot(action_df, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) +
    geom_text(aes(label = action_count)) +
    theme_minimal() +
    scale_fill_gradient(low = "white", high = high_colour, name = 'Action Count') +
    xlab('Maximum Vasopressor Dosage') +
    ylab('Mean IV Fluids') +
    ggtitle(label = plot_title, subtitle = plot_subtitle)
}

# --- Training split ---------------------------------------------------------
filt_train_PHY_actionmatrix <- train %>%
  filter(relative_time >= 24) %>%
  full_join(action_mappings, by = "real_discrete_action") %>%
  count_observed_actions(discrete_IV, discrete_VP, real_discrete_action)
filt_AM_TRAIN_PHY <- plot_action_counts(filt_train_PHY_actionmatrix, "blue",
                                        'Training - Physician action matrix',
                                        'First 48H after sepsis onset')

filt_train_DQN_actionmatrix <- train %>%
  filter(relative_time >= 24) %>%
  full_join(DQN_action_mappings, by = "DQN_discrete_action") %>%
  count_observed_actions(discrete_IV, discrete_VP, DQN_discrete_action)
filt_AM_TRAIN_DQN <- plot_action_counts(filt_train_DQN_actionmatrix, "orange",
                                        'Training - Optimal policy action matrix',
                                        'First 48H after sepsis onset')

# --- Validation split -------------------------------------------------------
filt_val_PHY_actionmatrix <- val %>%
  filter(relative_time >= 24) %>%
  full_join(action_mappings, by = "real_discrete_action") %>%
  count_observed_actions(discrete_IV, discrete_VP, real_discrete_action)
filt_AM_VAL_PHY <- plot_action_counts(filt_val_PHY_actionmatrix, "blue",
                                      'Validation - Physician action matrix',
                                      'First 48H after sepsis onset')

filt_val_DQN_actionmatrix <- val %>%
  filter(relative_time >= 24) %>%
  full_join(DQN_action_mappings, by = "DQN_discrete_action") %>%
  count_observed_actions(discrete_IV, discrete_VP, DQN_discrete_action)
filt_AM_VAL_DQN <- plot_action_counts(filt_val_DQN_actionmatrix, "orange",
                                      'Validation - Optimal policy action matrix',
                                      'First 48H after sepsis onset')

# --- Test split -------------------------------------------------------------
# NOTE(review): the filter keeps relative_time <= 60 while the subtitle says
# "First 48H of admission" -- confirm which window is intended.
filt_test_PHY_actionmatrix <- test %>%
  filter(relative_time <= 60) %>%
  full_join(action_mappings, by = "real_discrete_action") %>%
  count_observed_actions(discrete_IV, discrete_VP, real_discrete_action)
filt_AM_TEST_PHY <- plot_action_counts(filt_test_PHY_actionmatrix, "blue",
                                       'Test - Physician action matrix',
                                       'First 48H of admission')

filt_test_DQN_actionmatrix <- test %>%
  filter(relative_time <= 60) %>%
  full_join(DQN_action_mappings, by = "DQN_discrete_action") %>%
  count_observed_actions(discrete_IV, discrete_VP, DQN_discrete_action)
filt_AM_TEST_DQN <- plot_action_counts(filt_test_DQN_actionmatrix, "orange",
                                       'Test - Optimal policy action matrix',
                                       'First 48H of admission')


### save plots

In [None]:
    ########################
    ### ACTION MATRIX (counts): shown inline; uncomment png()/dev.off() to save
    options(repr.plot.width = 18, repr.plot.height = 8)
    # png(filename = paste0(figuresdir, exp_name, '_', "Multiplot_ActionMatrix.png"), width = 18, height = 8, units = "in", res = 300, pointsize = 6)
    count_panels <- list(AM_TRAIN_PHY, AM_TRAIN_DQN,
                         AM_VAL_PHY, AM_VAL_DQN,
                         AM_TEST_PHY, AM_TEST_DQN)
    suppressWarnings(multiplot(plotlist = count_panels, cols = 3))
    # dev.off()

    ########################
    ### FILTERED ACTION MATRIX (disabled)
    # options(repr.plot.width = 18, repr.plot.height = 8)
    # png(filename = paste0(figuresdir, exp_name, '_', "Multiplot_ActionMatrix_FILT.png"), width = 18, height = 8, units = "in", res = 400, pointsize = 6)
    # suppressWarnings(multiplot(filt_AM_TRAIN_PHY, filt_AM_TRAIN_DQN, filt_AM_VAL_PHY, filt_AM_VAL_DQN, filt_AM_TEST_PHY, filt_AM_TEST_DQN, cols = 3))
    # dev.off()

    ########################
    ### ACTION MATRIX (frequencies instead of counts)
    options(repr.plot.width = 18, repr.plot.height = 8)
    # png(filename = paste0(figuresdir, exp_name, '_', "Multiplot_ActionMatrix_FREQ.png"), width = 18, height = 8, units = "in", res = 300, pointsize = 6)
    frequency_panels <- list(AM_TRAIN_PHY_FREQ, AM_TRAIN_DQN_FREQ,
                             AM_VAL_PHY_FREQ, AM_VAL_DQN_FREQ,
                             AM_TEST_PHY_FREQ, AM_TEST_DQN_FREQ)
    suppressWarnings(multiplot(plotlist = frequency_panels, cols = 3))
    # dev.off()



   

In [None]:
### Save the validation/test frequency action matrices as one PNG
### (two columns: validation on the left, test on the right)
options(repr.plot.width = 12, repr.plot.height = 8)
png(filename = paste0(figuresdir, exp_name, '_', "Multiplot_ActionMatrix_percent.png"), width = 18, height = 8, units = "in", res = 300, pointsize = 6)
# Close the PNG device even if drawing fails; otherwise the device stays open
# and every later plot in the session is silently written into this file.
tryCatch(
  suppressWarnings(multiplot(AM_VAL_PHY_FREQ, AM_VAL_DQN_FREQ, AM_TEST_PHY_FREQ, AM_TEST_DQN_FREQ, cols = 2)),
  finally = dev.off()
)


#### Plot Mortality 

In [None]:
# Read the physician and best-action Q-values exported for one data split and
# attach the outcome, interval and stay-id columns of the matching split.
# Columns are copied by position, so `split_data` must have the same row order
# as the Q-value file.
#   csv_name   : file name inside exp_data_path
#   split_data : data frame providing Discharge, interval_start_time,
#                interval_end_time and icustay_id
load_q_values <- function(csv_name, split_data) {
  q_values <- read_csv(paste0(exp_data_path, csv_name), col_types = cols()) %>%
    select(phy_action_Qvalue, best_action_Qvalue)
  for (column in c("Discharge", "interval_start_time", "interval_end_time", "icustay_id")) {
    q_values[[column]] <- split_data[[column]]
  }
  q_values
}

train_opt <- load_q_values('DQN_Qvalues_traindata.csv', train_data)
val_opt   <- load_q_values('DQN_Qvalues_valdata.csv', val_data)
test_opt  <- load_q_values('DQN_Qvalues_testdata.csv', test_data)



#### Mortality plot ( only for validation set )

In [None]:
# Q-value trajectories on the validation split, split by hospital outcome.

# Hours of each interval relative to 24 h after the stay's first interval
# (presumably the sepsis onset -- confirm against the preprocessing).
temp <- val_opt %>%
  group_by(icustay_id) %>%
  mutate(relative_time = difftime(interval_start_time,
                                  min(interval_start_time) + hours(24),
                                  units = 'hours')) %>%
  ungroup()

# Layers shared by all three panels: smoothed trend, outcome colour legend,
# legend placed inside the panel, x axis limited to -24h .. +48h.
mortality_layers <- list(
  geom_smooth(),
  xlab('Relative Hours'),
  scale_color_discrete('Died in Hospital', labels = c('No', 'Yes')),
  theme_bw(),
  theme(legend.position = c(0.1, 0.2),
        legend.key = element_blank()),
  scale_x_continuous(breaks = c(-24, -12, 0, 12, 24, 36, 48), limits = c(-24, 48))
)

# Q-value of the action the physician actually took
PHY_mortality <- ggplot(temp, aes(relative_time, phy_action_Qvalue,
                                  color = as.factor(Discharge))) +
  mortality_layers +
  ylab('Predicted Q for Physician Policy')

# Q-value of the best action according to the learned policy
optmal_mortality <- ggplot(temp, aes(relative_time, best_action_Qvalue,
                                     color = as.factor(Discharge))) +
  mortality_layers +
  ylab('Predicted Q for Optimal Policy')

# Long format: one row per interval and Q-value type
temp2 <- temp %>% gather(Q_type, value = Q, phy_action_Qvalue, best_action_Qvalue)

# Both Q-value types in one panel. geom_smooth() draws no points, so a `shape`
# mapping leaves the physician and optimal curves indistinguishable; map the
# Q-value type to linetype instead so each curve can be identified.
# Q_type values sort as best_action_Qvalue, phy_action_Qvalue, which is the
# order of the linetype labels below.
combine_plot <- ggplot(temp2, aes(relative_time, Q,
                                  color = as.factor(Discharge),
                                  linetype = Q_type)) +
  mortality_layers +
  scale_linetype_discrete('Policy', labels = c('Optimal', 'Physician')) +
  ylab('Predicted Q')

### Show the three mortality panels (uncomment png()/dev.off() to save)
options(repr.plot.width = 18, repr.plot.height = 8)
# png(filename = paste0(figuresdir, exp_name, '_', "Multiplot_Mortality_Validation.png"), width = 18, height = 8, units = "in", res = 300, pointsize = 6)
suppressWarnings(multiplot(PHY_mortality, combine_plot, optmal_mortality, cols = 2))
# dev.off()

Q calibration plot

In [None]:
### TO DO: looks like cumulative lineplot (change aesthetics)

# X-axis bounds intended for the calibration plots
xlim_min <- -15
xlim_max <- 15

# Bin a numeric vector to `level` decimals by shifting it down half a step
# before rounding (level = 1 gives bins of width 0.1).
floor_dec <- function(x, level = 1) {
  half_step <- 5 * 10^(-level - 1)
  round(x - half_step, digits = level)
}

# Summarise outcome by binned physician Q-value for one data split.
# Returns one row per Q-value bin with:
#   bin_Q     : physician Q-value binned by floor_dec()
#   prop_dead : mean of (1 - Discharge) in the bin
#               NOTE(review): plotted downstream as "% Patient survival";
#               confirm the coding of Discharge matches that label.
#   sd_dead   : sd(Discharge) / sqrt(n), i.e. a standard error
#   count     : number of rows in the bin
# `count` is computed inside summarise() so it always belongs to its bin.
# Attaching table()$Freq by position breaks when a Q-value is NA, because
# table() drops the NA bin while group_by() keeps it.
summarise_q_calibration <- function(q_df) {
  q_df %>%
    mutate(bin_Q = floor_dec(phy_action_Qvalue)) %>%
    group_by(bin_Q) %>%
    summarise(prop_dead = mean((Discharge * -1) + 1),
              sd_dead = sd(Discharge) / sqrt(n()),
              count = n())
}

PHYQ_train_plot_df <- summarise_q_calibration(train_opt)
PHYQ_val_plot_df   <- summarise_q_calibration(val_opt)
PHYQ_test_plot_df  <- summarise_q_calibration(test_opt)

# Calibration panel for one split: raw and loess-smoothed prop_dead * 100 per
# Q-value bin, with the bin counts drawn as bars on the same axis.
#   plot_df       : data frame with bin_Q, prop_dead and count
#   bar_divisor   : bin counts are divided by this to fit the 0-100 axis; the
#                   secondary axis multiplies by it again to show raw counts
#   split_label   : subtitle naming the data split
#   y_label       : primary y-axis label
#   count_label   : secondary (count) axis label
plot_q_calibration <- function(plot_df, bar_divisor, split_label, y_label, count_label) {
  ggplot(plot_df, aes(x = bin_Q, y = prop_dead*100)) +
    geom_line(aes(y = prop_dead*100), alpha = 0.3) +
    geom_smooth(span = 0.4, method = 'loess', formula = y ~ x, se = FALSE, level = 0.5) +
    geom_bar(aes(y = count / bar_divisor), stat = "identity") +
    scale_y_continuous(sec.axis = sec_axis(~ . * bar_divisor, name = count_label)) +
    xlim(-5, 10) +
    theme_bw() +
    ggtitle(subtitle = split_label, label = 'Physician Q value - mortality calibration') +
    ylab(y_label) +
    xlab('Q value')
}

PHYQ_train_survival <- plot_q_calibration(PHYQ_train_plot_df, 60,
                                          'MIMIC training dataset',
                                          '% Patient survival', "")
PHYQ_val_survival <- plot_q_calibration(PHYQ_val_plot_df, 30,
                                        'MIMIC validation dataset',
                                        '', "")
PHYQ_test_survival <- plot_q_calibration(PHYQ_test_plot_df, 30,
                                         'MIMIC CareVue Test dataset',
                                         '', "Q Value count")

# Show the three panels side by side (uncomment png()/dev.off() to save)
options(repr.plot.width = 15, repr.plot.height = 5)
# png(filename = paste0(figuresdir, exp_name, '_', "Multiplot_Calibration_PHY_Qvalue.png"), width = 12, height = 5, units = "in", res = 400, pointsize = 6)
calibration_panels <- list(PHYQ_train_survival, PHYQ_val_survival, PHYQ_test_survival)
suppressWarnings(multiplot(plotlist = calibration_panels, cols = 3))
# dev.off()

### Dose initiation

In [None]:
library('RColorBrewer')

In [None]:
########################
### Percentage of time steps on treatment, per relative_time and per policy.
### Train/val: a dose counts as "on" when > 0 and "off" when == 0.
### Test: IV fluids count as "on" when >= 3 and "off" when < 3 (the original
### note describes this as IV > mode(IV)); vasopressors use > 0 as elsewhere.

# Turn a dose vector into a 1/0 on/off flag; values matching neither rule
# (e.g. NA) stay NA.
#   threshold = NULL : on when dose > 0, off when dose == 0
#   threshold = t    : on when dose >= t, off when dose < t
dose_on_off <- function(dose, threshold = NULL) {
  if (is.null(threshold)) {
    case_when(dose > 0 ~ 1, dose == 0 ~ 0)
  } else {
    case_when(dose >= threshold ~ 1, dose < threshold ~ 0)
  }
}

# Per relative_time, the share (in %) of rows flagged "on" for each of the
# four policy/treatment combinations.
treatment_share_by_time <- function(dose_df, iv_threshold = NULL) {
  dose_df %>%
    mutate(PHY_VP_OnOFF = dose_on_off(PHY_VP),
           PHY_IV_OnOFF = dose_on_off(PHY_IV, iv_threshold),
           DQN_VP_OnOFF = dose_on_off(DQN_VP),
           DQN_IV_OnOFF = dose_on_off(DQN_IV, iv_threshold)) %>%
    group_by(relative_time) %>%
    summarise(PHY_prop_VP = mean(PHY_VP_OnOFF)*100,
              PHY_prop_IV = mean(PHY_IV_OnOFF)*100,
              DQN_prop_VP = mean(DQN_VP_OnOFF)*100,
              DQN_prop_IV = mean(DQN_IV_OnOFF)*100) %>%
    ungroup()
}

# Quick and dirty fix kept from the original: drop every time bin in which any
# share is missing, then reshape the four share columns to key/value rows.
complete_long <- function(share_df) {
  share_df[complete.cases(share_df), ] %>% gather(key, value, 2:5)
}

NOdose_time_train <- complete_long(treatment_share_by_time(dose_train))
NOdose_time_val   <- complete_long(treatment_share_by_time(dose_val))
NOdose_time_test  <- complete_long(treatment_share_by_time(dose_test, iv_threshold = 3))

########################
### CREATE PLOTS
# Treatment progression: smoothed % of patients on treatment over time, one
# curve per policy/treatment combination, one panel per data split.

# Colours as a plain named vector (dark/light green and red), keyed by the
# sorted series names found in the `key` column. rbind() built a 4x1 matrix
# here, which only worked by accident.
myColors <- brewer.pal(9, "Greens")
myColors2 <- brewer.pal(9, "Reds")
myColors3 <- c(myColors[8], myColors2[8], myColors[3], myColors2[3])
names(myColors3) <- levels(as.factor(NOdose_time_train$key))
# Labels follow the same sorted key order as the colours above.
colScale <- scale_colour_manual(name = "Actions:", values = myColors3, labels = c('Optimal policy Fluids', 'Optimal policy Vasopressors', 'Physician Fluids ', 'Physician Vasopressors'))

# Layers shared by the three panels. The x scale is added exactly once: the
# extra empty scale_x_continuous() was immediately replaced by the second one
# and only produced a "Scale for x is already present" message.
#   y_label       : y-axis label
#   y_max         : upper y limit
#   split_label   : subtitle naming the data split
dose_time_layers <- function(y_label, y_max, split_label) {
  list(
    geom_smooth(method = "loess", se = TRUE, span = 0.4),
    colScale,
    xlab('Relative Hours'),
    ylab(y_label),
    ylim(0, y_max),
    ggtitle(subtitle = split_label, label = 'Treatment progression'),
    theme_bw(),
    scale_x_continuous(breaks = c(-24, -12, 0, 12, 24, 36, 48, 60, 72), limits = c(-24, 72))
  )
}

# Train/val: relative_time is shifted by 24 h before plotting.
dose_time_train_plot <- ggplot(NOdose_time_train,
                               aes(relative_time - 24, value, color = as.factor(key))) +
  dose_time_layers('% patients on treatment', 100, 'MIMIC Training dataset')

dose_time_val_plot <- ggplot(NOdose_time_val,
                             aes(relative_time - 24, value, color = as.factor(key))) +
  dose_time_layers('', 100, 'MIMIC Validation dataset')

# Test: relative_time is plotted unshifted and the y axis runs to 105.
# NOTE(review): confirm both differences from train/val are intended.
dose_time_test_plot <- ggplot(NOdose_time_test,
                              aes(relative_time, value, color = as.factor(key))) +
  dose_time_layers('', 105, 'MIMIC test dataset')


In [None]:
# Show the three treatment-progression panels side by side.
options(repr.plot.width = 15, repr.plot.height = 5)
# To save instead, open a png() device here and close it with dev.off() after
# the multiplot call. NOTE(review): the previous commented-out template reused
# the calibration figure's name ("Multiplot_Calibration_PHY_Qvalue.png"),
# which would overwrite that figure; pick a distinct file name.
dose_panels <- list(dose_time_train_plot, dose_time_val_plot, dose_time_test_plot)
suppressWarnings(multiplot(plotlist = dose_panels, cols = 3))