In [39]:
import pandas as pd
import pickle as pkl
import string
import numpy as np;
from sklearn.utils import shuffle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# Note that it is not worthwhile to run the following code on a GPU:
# even if we vectorize it as much as possible, the objective function
# (computationally heavy) still contains many Python loops.

In [40]:
# Raw train/test tables (placeholder paths -- point them at the real CSV files).
# Consecutive rows of one ICU stay share an 'icustayid'; 'bloc' == 0 marks a stay's first row.
train_set, test_set = (pd.read_csv(path) for path in ('path/to/trainset', 'path/to/testset'))

In [41]:
# Joint action index: vaso_input + 5 * iv_input (presumably both are 0-4 bins,
# giving the 25 actions used throughout -- confirm against the data).
train_actions, test_actions = [
    frame['vaso_input'].values + 5 * frame['iv_input'].values
    for frame in (train_set, test_set)
]

In [42]:
# Patient-state features that are critical for expert selection
# (together with 'traj_len' and 'dist' they form the gating network's input).
attr_cols = [
    'age', 'elixhauser', 'SOFA', 'GCS',
    'FiO2_1', 'BUN', 'Albumin',
]

In [43]:
# Split the row-ordered table into ICU stays: a stay starts on every row whose
# 'bloc' is 0. E.g. if rows 1-4 belong to patient 1 and rows 5-10 to patient 2,
# the result is [0, 4, 10, ...] (0-based start rows). Used for the WDR estimate.
def get_fence_post(df):
    """Return the 0-based row positions where ``df['bloc'] == 0`` as a NumPy array."""
    return np.array([row for row, bloc in enumerate(df['bloc'].values) if bloc == 0])

def get_traj_len(df):
    """1-based step index of every row within its ICU stay.

    Despite the name this is a per-row position, not the stay length: with
    fence posts [0, 4] and 6 rows the result is [1, 2, 3, 4, 1, 2].
    Rows before the first fence post are not represented, and a frame with
    no ``bloc == 0`` row raises IndexError.
    """
    posts = get_fence_post(df)
    # segment sizes: gaps between consecutive posts, then the tail after the last post
    segment_sizes = np.diff(posts).tolist() + [df.shape[0] - posts[-1]]
    steps = [step for size in segment_sizes for step in range(size)]
    return np.array(steps) + 1

def softmax(x):
    """Row-wise softmax of a 2-D array; each row is shifted by its max to avoid overflow."""
    shifted = x - np.max(x, axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=1, keepdims=True)

def actions_2_probs(ind, src_actions, num_actions=25, expert='kernel', survivors=None):
    """Empirical action distribution over each row's nearest neighbours.

    Args:
        ind: 2-D int array (rows, neighbours) of row indices into ``src_actions``.
        src_actions: action index of every source row, in ``[0, num_actions)``.
        num_actions: number of columns of the returned table.
        expert: ``'kernel'`` lets only neighbours listed in ``survivors`` vote.
            Any other value (e.g. ``'phy'``) counts every neighbour, giving the
            physician policy.
        survivors: source-row indices allowed to vote when ``expert == 'kernel'``.
            Defaults to the notebook-level ``train_survivors``, as before.

    Returns:
        float array (rows, num_actions) of per-row action frequencies. A row
        whose neighbours are all excluded stays all-zero, which is not a valid
        distribution.
    """
    ind = np.asarray(ind)
    neighbour_actions = np.asarray(src_actions)[ind]  # fancy indexing -> copy; src untouched

    if expert == 'kernel':
        if survivors is None:
            survivors = train_survivors
        votes = np.isin(ind, survivors)
    else:
        votes = np.ones(ind.shape, dtype=bool)

    # Boolean-mask indexing and np.nonzero both walk the mask in row-major order,
    # so `rows` lines up with the selected actions.
    rows, _ = np.nonzero(votes)
    counts = np.zeros((ind.shape[0], num_actions))
    # unbuffered add: repeated (row, action) pairs accumulate
    np.add.at(counts, (rows, neighbour_actions[votes]), 1)

    totals = counts.sum(axis=1, keepdims=True)
    # rows with no votes keep probability 0 instead of dividing 0 by 0
    return np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)

def restrict_actions(target, src, th):
    """Keep ``target`` probabilities only where ``src > th``, then renormalise each row.

    A row with no surviving mass divides 0 by 0 and comes back as NaN.
    """
    allowed = src > th
    kept = target * allowed
    row_mass = np.sum(kept, axis=1, keepdims=True)
    return kept / row_mass

def get_ir_diff(df):
    """Per-step intermediate rewards from the ``neg_mortality_logodds`` signal.

    For each row t of an ICU stay the reward is ``L[t+1] - L[t]``, where L is
    ``neg_mortality_logodds`` and t+1 is the stay's next row. The last row of
    each stay gets 0.

    Rows are matched to their stay through ``icustayid``, and the result is
    aligned with ``df``'s row order. The previous loop instead wrote rewards in
    sorted-``icustayid`` order, so they were misaligned unless stays appeared
    contiguously in ascending id order. It also filtered the whole frame once
    per stay. Within a stay, rows are assumed to be in chronological order.

    Returns:
        float ndarray of length ``len(df)``.
    """
    logodds = df['neg_mortality_logodds']
    stay_ids = df['icustayid']
    # value of the same stay's next row (NaN on the stay's last row)
    next_logodds = logodds.groupby(stay_ids).shift(-1).values
    diffs = np.asarray(next_logodds - logodds.values, dtype=float)
    # a row is the last of its stay when no later row shares its icustayid
    is_last = ~stay_ids.duplicated(keep='last').values
    diffs[is_last] = 0.0
    return diffs

def get_expert_dist(pi_e, pi_k, pi_d):
    """Print how often the greedy action of ``pi_e`` agrees with each expert.

    All three arguments are 2-D arrays with one row per state (presumably
    action-probability tables). The function prints:
      * 's_k': fraction of states where argmax(pi_e) == argmax(pi_k);
      * 's_d': fraction where argmax(pi_e) == argmax(pi_d);
      * 'ns':  the count from ``get_moe_unique_action`` divided by the number of states.

    NOTE(review): ``get_moe_unique_action`` is not defined in the visible part of
    this notebook; calling this raises NameError unless it is defined elsewhere.
    """
    greedy_e, greedy_k, greedy_d = (np.argmax(table, axis=1) for table in (pi_e, pi_k, pi_d))
    n_states = pi_e.shape[0]
    agree_k = len(np.flatnonzero(greedy_e == greedy_k))
    agree_d = len(np.flatnonzero(greedy_e == greedy_d))

    ns = get_moe_unique_action(pi_e, pi_k, pi_d)

    print('s_k: ', agree_k / n_states, 's_d:', agree_d / n_states, 'ns: ', ns / n_states)

In [44]:
class MOE(nn.Module):
    """Gating network of the mixture of experts: one linear layer plus an activation.

    ``logits=2`` yields a softmax over two expert weights per row.
    ``logits=1`` yields a single sigmoid weight p; callers mix the experts as
    ``p * kernel + (1 - p) * dqn``.
    """

    def __init__(self, input_size, logits=2):
        super(MOE, self).__init__()

        if logits not in (1, 2):
            # previously this left `activation` undefined and failed later in forward()
            raise ValueError(
                'MOE supports logits=1 (sigmoid gate) or logits=2 (softmax gate), got {}'.format(logits))

        self.logits = logits
        self.linear = nn.Linear(input_size, logits)
        # Explicit dim: nn.Softmax() without it is deprecated and picks the axis implicitly
        # (dim=-1 equals the implicit choice for the 2-D batches used here).
        self.activation = nn.Softmax(dim=-1) if logits == 2 else nn.Sigmoid()

    def forward(self, x):
        """Return per-row gate weights for the (batch, input_size) feature tensor ``x``."""
        return self.activation(self.linear(x))

In [45]:
def WDR(
    actions_sequence, rewards_sequence, fence_posts, gamma,
    pi_evaluation, pi_behavior, V = None, Q = None, num_of_actions = None ):
    """Differentiable (torch) weighted doubly-robust (WDR) off-policy estimate.

    Same computation as ``np_WDR``, but built from torch ``Variable`` operations.
    When ``pi_evaluation`` requires grad, the result can be back-propagated into
    it (``objective`` negates it and ``train`` calls ``backward()`` on it).

    Args:
        actions_sequence: per-row action index for all trajectories concatenated;
            read with ``[t]`` and ``.shape[0]``.
        rewards_sequence: per-row reward, read with ``[t]``.
        fence_posts: start row of each trajectory, ascending (see get_fence_post).
            Rows before ``fence_posts[0]`` are ignored.
        gamma: discount factor.
        pi_evaluation, pi_behavior: (rows, actions) probability tables; only the
            entry of the taken action is read.
        V: per-row state values, read as ``V[t]``. Required despite the default.
        Q: (rows, actions) action values. Required despite the default.
        num_of_actions: unused.

    Returns:
        The WDR estimate as a torch value. Returns the int 0 when there are no
        trajectories.

    Note:
        The weight table has a hard-coded width of 21 columns: rho before any
        action plus up to 20 steps. A trajectory longer than 20 steps raises
        IndexError.
    """

    num_of_trials = len( fence_posts )
    # get weight table
    # whole_rho[i, k]: cumulative importance ratio of trajectory i after k steps
    # (column 0 holds 1, i.e. before any action).
    whole_rho = Variable(torch.zeros((num_of_trials, 21)))
    
    for trial_i in range( num_of_trials ):
        
        rho = 1
        trial_rho = torch.zeros(21)
        trial_rho[0] = rho
        trial_rho = Variable(trial_rho) 
        
        # trajectory i spans rows [fence_posts[i], fence_posts[i+1]); the last one runs to the end
        if trial_i < num_of_trials - 1:
            steps_in_trial = fence_posts[ trial_i+1 ] -  fence_posts[ trial_i ]
        else:
            steps_in_trial = actions_sequence.shape[0] - fence_posts[-1]
        for t in range(
                fence_posts[ trial_i], fence_posts[ trial_i ] + steps_in_trial ):
            previous_rho = rho  # never read
            # rho_t = rho_{t-1} * pi_e(a_t | s_t) / pi_b(a_t | s_t)
            rho = rho * (pi_evaluation[ t, actions_sequence[t]] / \
                pi_behavior[ t, actions_sequence[t]])
            # Write rho into column k with a one-hot mask and out-of-place addition
            # (presumably to keep the autograd graph intact with the Variable API).
            trial_aux = torch.zeros(21)
            trial_aux[t - fence_posts[ trial_i] + 1] = 1
            trial_aux = Variable(trial_aux)
            trial_rho = trial_rho + trial_aux * rho
        
        # Shorter trajectories carry their final rho forward up to step 20, so every
        # column of the weight table is normalised over all trajectories.
        if steps_in_trial < 20:
            for t in range(fence_posts[ trial_i ] + steps_in_trial, fence_posts[ trial_i ] + 20):
                trial_aux = torch.zeros(21)
                trial_aux[t - fence_posts[ trial_i]+1] = 1
                trial_aux = Variable(trial_aux)
                trial_rho = trial_rho + trial_aux * rho
    
        # place this trajectory's row into whole_rho (again through a mask)
        whole_aux = torch.zeros((num_of_trials, 21))
        whole_aux[trial_i, :] = 1
        whole_rho = whole_rho + Variable(whole_aux) * trial_rho
        
    # w[i, k] = rho[i, k] / sum_j rho[j, k]: self-normalised per time step
    weight_table = whole_rho / torch.sum(whole_rho, dim = 0)
    
    estimator = 0
    # calculate the doubly robust estimator of the policy
    for trial_i in range(num_of_trials):
        # start at 1/gamma so the first step is discounted by gamma**0
        discount = 1 / gamma
        if trial_i < num_of_trials - 1:
            steps_in_trial = fence_posts[ trial_i+1 ] -  fence_posts[ trial_i ]
        else:
            steps_in_trial = actions_sequence.shape[0] - fence_posts[-1]
        for t in range(
                fence_posts[ trial_i], fence_posts[ trial_i] + steps_in_trial ):
            previous_weight = weight_table[trial_i, t - fence_posts[ trial_i]]
            weight = weight_table[trial_i, t - fence_posts[ trial_i]+1]
            discount = discount * gamma
            r = rewards_sequence[ t ]
            Q_value = Q[ t, actions_sequence[t]]
            V_value = V[t]
            # w_t * gamma^t * r_t - gamma^t * (w_t * Q(s_t, a_t) - w_{t-1} * V(s_t))
            estimator = estimator + weight * discount * r - discount * ( weight * Q_value - previous_weight * V_value ) 
    
    return estimator

In [46]:
def np_WDR(
    actions_sequence, rewards_sequence, fence_posts, gamma,
    pi_evaluation, pi_behavior, V = None, Q = None, num_of_actions = None, return_weights=False,
    horizon=20):
    """Weighted doubly-robust (WDR) off-policy value estimate (NumPy version).

    Computes

        sum_i sum_t gamma^t * [ w[i,t] * r[i,t]
                                - (w[i,t] * Q(s[i,t], a[i,t]) - w[i,t-1] * V(s[i,t])) ]

    where ``w[i,t] = rho[i,t] / sum_j rho[j,t]`` and ``rho[i,t]`` is trajectory
    i's cumulative ratio ``prod_{k<=t} pi_e(a_k|s_k) / pi_b(a_k|s_k)``, with
    ``rho[i,-1] = 1``. A trajectory shorter than ``horizon`` keeps its final rho
    for the remaining steps, so each step is normalised over all trajectories.

    Args:
        actions_sequence: per-row action index for all trajectories concatenated.
        rewards_sequence: per-row reward.
        fence_posts: start row of each trajectory, ascending (see get_fence_post).
            Rows before ``fence_posts[0]`` are ignored.
        gamma: discount factor.
        pi_evaluation, pi_behavior: (rows, actions) probability tables.
        V: per-row state values. Required despite the default.
        Q: (rows, actions) action values. Required despite the default.
        num_of_actions: unused; kept for signature compatibility.
        return_weights: also return the unnormalised rho table.
        horizon: maximum steps per trajectory. The default of 20 gives the same
            21-column table as before.

    Returns:
        The estimate, or ``(estimate, whole_rho)`` when ``return_weights`` is true.
        ``whole_rho`` has shape (trajectories, horizon + 1).

    Raises:
        ValueError: if V or Q is missing, or a trajectory has more than ``horizon``
        steps (previously an IndexError).
    """
    if V is None or Q is None:
        raise ValueError('np_WDR requires both V and Q (the doubly-robust control variates)')

    num_of_trials = len( fence_posts )
    starts = [int(post) for post in fence_posts]
    ends = starts[1:] + [len(actions_sequence)]

    # whole_rho[i, k]: cumulative importance ratio of trajectory i after k steps (column 0 = 1)
    whole_rho = np.zeros((num_of_trials, horizon + 1))
    for trial_i, (start, end) in enumerate(zip(starts, ends)):
        steps_in_trial = end - start
        if steps_in_trial > horizon:
            raise ValueError('trajectory {} has {} steps, more than horizon={}'.format(
                trial_i, steps_in_trial, horizon))
        rho = 1
        whole_rho[trial_i, 0] = rho
        for k, t in enumerate(range(start, end), start=1):
            a_t = actions_sequence[t]
            rho *= pi_evaluation[t, a_t] / pi_behavior[t, a_t]
            whole_rho[trial_i, k] = rho
        # Carry the final rho forward. Writing entries directly (rather than adding
        # one-hot rows) keeps an inf ratio from turning other entries into NaN via 0 * inf.
        whole_rho[trial_i, steps_in_trial + 1:] = rho

    # self-normalise each time step (column) across trajectories
    weight_table = whole_rho / np.sum(whole_rho, axis = 0)

    estimator = 0
    # calculate the doubly robust estimator of the policy
    for trial_i, (start, end) in enumerate(zip(starts, ends)):
        # start at 1/gamma so the first step is discounted by gamma**0
        discount = 1 / gamma
        for k, t in enumerate(range(start, end)):
            previous_weight = weight_table[trial_i, k]
            weight = weight_table[trial_i, k + 1]
            discount *= gamma
            a_t = actions_sequence[t]
            r = rewards_sequence[t]
            Q_value = Q[t, a_t]
            V_value = V[t]
            estimator = estimator + weight * discount * r - discount * ( weight * Q_value - previous_weight * V_value )

    if return_weights:
        return estimator, whole_rho
    else:
        return estimator

In [47]:
def evaluate(pi_e, phase='train', VQ='phy', gamma=.99):
    """WDR value of an evaluation policy on the train or test split.

    Args:
        pi_e: (rows, actions) action-probability table aligned with the chosen split.
        phase: 'train' or 'test'. Selects that split's actions, rewards,
            behaviour policy (``train_pi_b`` / ``test_pi_b``) and fence posts.
        VQ: 'phy' or 'dqn'. Selects which V/Q tables serve as the doubly-robust
            control variates.
        gamma: discount factor (0.99, as before).

    Returns:
        The ``np_WDR`` estimate.

    Raises:
        ValueError: for an unknown ``phase`` or ``VQ``. These used to surface as
        an UnboundLocalError.
    """
    if VQ not in ('phy', 'dqn'):
        raise ValueError("evaluate: VQ must be 'phy' or 'dqn', got {!r}".format(VQ))

    if phase == 'train':
        train_fence_posts = get_fence_post(train_df)
        V, Q = (phy_train_V, phy_train_Q) if VQ == 'phy' else (dqn_train_V, dqn_train_Q)
        return np_WDR(train_actions, train_rewards, train_fence_posts, gamma, pi_e, train_pi_b, V, Q)

    if phase == 'test':
        test_fence_posts = get_fence_post(test_df)
        V, Q = (phy_test_V, phy_test_Q) if VQ == 'phy' else (dqn_test_V, dqn_test_Q)
        return np_WDR(test_actions, test_rewards, test_fence_posts, gamma, pi_e, test_pi_b, V, Q)

    raise ValueError("evaluate: phase must be 'train' or 'test', got {!r}".format(phase))

In [48]:
def objective(action_seq, rewards, fence_posts, pi_e, pi_b, V, Q):
    """Training loss: the negated torch WDR estimate with gamma = 0.99.

    Minimising it maximises the estimated value of ``pi_e``.
    """
    wdr_estimate = WDR(action_seq, rewards, fence_posts, .99, pi_e, pi_b, V, Q)
    return -wdr_estimate

In [49]:
def do_eval_test(moe=None, pi_e_type='moe'):
    """Test-set WDR (physician V/Q) of the MoE policy or of one expert alone.

    Args:
        moe: trained ``MOE`` gate; required when ``pi_e_type == 'moe'``.
        pi_e_type: 'moe' uses the gate-weighted mix of both experts, 'kernel'
            uses the kernel expert alone, and any other value uses the DQN expert.

    Returns:
        ``np_WDR`` estimate on the test split, with ``test_pi_b`` as behaviour
        policy and ``phy_test_V`` / ``phy_test_Q`` as control variates.

    Raises:
        ValueError: if ``pi_e_type == 'moe'`` and no model is given. This
        previously failed with a TypeError from calling ``None``.
    """
    if pi_e_type == 'moe':
        if moe is None:
            raise ValueError("do_eval_test: `moe` is required when pi_e_type='moe'")
        # identical gate + mixing computation, shared with get_moe_policies
        _, pi_e = get_moe_policies(moe, test_df)
    elif pi_e_type == 'kernel':
        # columns 0-24 of test_df hold the kernel expert's action probabilities
        pi_e = torch.FloatTensor(test_df.values[:, :25]).numpy()
    else:
        # columns 25-49 hold the physician-restricted DQN expert's probabilities
        pi_e = torch.FloatTensor(test_df.values[:, 25:50]).numpy()

    fence_posts = get_fence_post(test_df)

    return np_WDR(test_actions, test_rewards, fence_posts, .99, pi_e, test_pi_b, phy_test_V, phy_test_Q)

In [50]:
def get_moe_policies(moe, df):
    """Gate weights and mixed evaluation policy produced by ``moe`` on ``df``.

    ``df`` holds the kernel expert's probabilities in columns 0-24 and the DQN
    expert's in columns 25-49. The gate reads ``attr_cols + ['traj_len', 'dist']``.

    Returns:
        (gate weights, mixed policy) as NumPy arrays.
    """
    features = df[attr_cols + ['traj_len', 'dist']].values
    kernel_probs = Variable(torch.FloatTensor(df.values[:, :25]))
    dqn_probs = Variable(torch.FloatTensor(df.values[:, 25:50]))

    gate = moe(Variable(torch.FloatTensor(features)))
    if moe.logits == 1:
        mixed = gate * kernel_probs + (1 - gate) * dqn_probs
    elif moe.logits == 2:
        mixed = gate[:, 0].unsqueeze(1) * kernel_probs + gate[:, 1].unsqueeze(1) * dqn_probs

    return gate.data.numpy(), mixed.data.numpy()

In [51]:
def pretrained(moe, pretrained_weight, pretrained_bias):
    """Overwrite the gating layer's weight and bias with the given values (as float32 tensors)."""
    layer = moe.linear
    for param, values in ((layer.weight, pretrained_weight), (layer.bias, pretrained_bias)):
        param.data = torch.FloatTensor(values)

In [52]:
def _scalar(value):
    """Return the single value of a one-element torch tensor/Variable as a Python float.

    Works both for the 1-element tensors of PyTorch < 0.4 and for the 0-dim
    tensors of later versions, where ``value.data[0]`` raises IndexError.
    """
    return float(value.data.cpu().numpy().reshape(-1)[0])


def train(df, moe, batch_size=128, lr=0.001, num_epoch=10):
    """Fit the MoE gating network by maximising the torch WDR estimate on ``df``.

    How training proceeds:
      * ICU stays are shuffled once and split into mini-batches of
        ``batch_size`` stays. A trailing partial batch is dropped unless there
        are fewer stays than ``batch_size``, in which case one batch holds all of them.
      * Adam minimises ``objective`` (= -WDR) on each batch.
      * After every epoch the function prints the last batch's WDR, the
        test-set WDR from ``do_eval_test`` and the gating layer's parameters.

    Args:
        df: training frame laid out as follows.
            * Columns 0-24 hold kernel-expert probabilities and columns 25-49
              DQN-expert probabilities.
            * It also needs ``attr_cols``, 'traj_len', 'dist', 'reward',
              'icustayid' and 'bloc'.
            * Its index labels are used as row positions into ``train_actions``,
              ``train_pi_b`` and ``phy_train_Q``. This assumes a default
              RangeIndex aligned with those arrays.
        moe: ``MOE`` gating model (logits 1 or 2); updated in place.
        batch_size: ICU stays per mini-batch.
        lr: Adam learning rate.
        num_epoch: number of passes over the batches.

    Returns:
        0 if the objective becomes NaN (training is aborted), otherwise None.

    Raises:
        ValueError: if ``df`` contains no ICU stays.
    """
    uids = np.unique(df['icustayid'].values)
    if uids.size == 0:
        raise ValueError('train: df contains no ICU stays')
    np.random.shuffle(uids)
    # Always run at least one batch: with fewer stays than batch_size the plain
    # floor division gave 0 batches and the epoch report referenced an undefined objective.
    num_batch = max(1, uids.shape[0] // batch_size)
    feature_cols = attr_cols + ['traj_len', 'dist']

    optimizer = torch.optim.Adam(moe.parameters(), lr=lr)

    for epoch in range(num_epoch):

        for batch_idx in range(num_batch):

            batch_uids = uids[batch_idx * batch_size: (batch_idx + 1) * batch_size]
            batch_user = df[df['icustayid'].isin(batch_uids)]
            batch_user_idx = batch_user.index.values

            x = Variable(torch.FloatTensor(batch_user[feature_cols].values))

            # columns 0-24: kernel expert, 25-49: DQN expert
            action_prob_k = Variable(torch.FloatTensor(batch_user.values[:, :25]))
            action_prob_d = Variable(torch.FloatTensor(batch_user.values[:, 25:50]))

            probs = moe(x)
            if moe.logits == 2:
                pi_e = torch.unsqueeze(probs[:, 0], 1) * action_prob_k + \
                    torch.unsqueeze(probs[:, 1], 1) * action_prob_d
            elif moe.logits == 1:
                pi_e = probs * action_prob_k + (1 - probs) * action_prob_d

            fence_posts = get_fence_post(batch_user)

            action_seq = train_actions[batch_user_idx]
            rewards = batch_user['reward'].values

            pi_b = Variable(torch.FloatTensor(train_pi_b[batch_user_idx]))

            # physician Q as control variate, V(s) = max_a Q(s, a)
            Q = Variable(torch.FloatTensor(phy_train_Q[batch_user_idx]))
            V = torch.max(Q, dim=1)[0]

            obj = objective(action_seq, rewards, fence_posts, pi_e, pi_b, V, Q)
            obj_value = _scalar(obj)

            if np.isnan(obj_value):
                return 0

            optimizer.zero_grad()
            obj.backward()
            optimizer.step()

        print ('********************')
        print('Epoch:{}/{}, wdr:{}'.format(epoch + 1, num_epoch, -obj_value))
        wdr = do_eval_test(moe)
        print ('********************')
        print('Eval: epoch:{}/{}, wdr:{}'.format(epoch + 1, num_epoch, wdr))
        for param in moe.linear.parameters():
            print(param.data.numpy().tolist())
        print ('********************')

In [53]:
# Neighbour sets behind the policies below (computed in "kernel.ipynb"):
#   * kernel policy on the train AND test sets: neighbours are looked up in the TRAIN set;
#   * physician policy on the train set: neighbours from the train set;
#   * physician policy on the test set: neighbours from the TEST set itself.
#
# Reference code that produced the pickles:
#
#   # 1) kernel policy over train and test sets; 2) physician policy over the train set
#   knn = KNN(300)
#   knn.fit(train_embeddings)
#   train_dist, train_ind = knn.kneighbors(train_embeddings)
#   test_from_train_dist, test_from_train_ind = knn.kneighbors(test_embeddings)
#
#   # 3) physician policy over the test set
#   knn_phy = KNN(300)
#   knn_phy.fit(test_embeddings)
#   phy_test_dist, phy_test_ind = knn_phy.kneighbors(test_embeddings)
#
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load these artefacts if they were produced by this project.

# (distances, indices) of each train row's neighbours within the train set
with open('kernel_knn_train.pkl', 'rb') as fh:
    dist_train, ind_train = pkl.load(fh)
# (distances, indices) of each test row's neighbours within the TRAIN set
# (used to derive the kernel policy on the test set)
with open('kernel_knn_test_from_train.pkl', 'rb') as fh:
    dist_test, ind_test = pkl.load(fh)
# indices of each test row's neighbours within the TEST set (physician policy on test)
with open('kernel_knn_test.pkl', 'rb') as fh:
    _, ind_test_from_test_for_pi_b = pkl.load(fh)

# row indices of train-set patients with died_in_hosp == 0
train_survivors = np.where(train_set['died_in_hosp'].values == 0)[0]

# Kernel policy pi_e, cached. Equivalent to:
#   action_probs_k_train = actions_2_probs(ind_train, train_actions)
#   action_probs_k_test = actions_2_probs(ind_test, train_actions)
with open('kernel_actions.pkl', 'rb') as fh:
    action_probs_k_train, action_probs_k_test = pkl.load(fh)

# DQN training results. NOTE(review): these were described as (actions, q-values)
# tuples, yet element [0] is fed to softmax here and used as the Q table when
# V/Q are built -- confirm the tuple order.
with open('../outcome_le/results_train.pkl', 'rb') as fh:
    dqn_res_train = pkl.load(fh)
with open('../outcome_le/results_test.pkl', 'rb') as fh:
    dqn_res_test = pkl.load(fh)

# Physician (behaviour) policy pi_b: empirical action frequencies of all neighbours
train_pi_b = actions_2_probs(ind_train, train_actions, expert='phy')
# ind_test_from_test_for_pi_b comes from a KNN fitted and queried on the test set alone
test_pi_b = actions_2_probs(ind_test_from_test_for_pi_b, test_actions, expert='phy')
# physician policy on test rows estimated from TRAIN neighbours; used to restrict the test DQN policy
test_pi_b_from_train = actions_2_probs(ind_test, train_actions, expert='phy')

# Restrict DQN policies to actions the physician policy gives more than 3/300
# probability (more than 3 of the neighbours, assuming 300 per row), then renormalise.
PHY_SUPPORT_THRESHOLD = 3 / 300
action_probs_d_train = restrict_actions(softmax(dqn_res_train[0]), train_pi_b, PHY_SUPPORT_THRESHOLD)
action_probs_d_test = restrict_actions(softmax(dqn_res_test[0]), test_pi_b_from_train, PHY_SUPPORT_THRESHOLD)

In [54]:
# V, Q control variates for WDR.
## Physician-policy Q network, obtained from qnetwork_solve_phy_VQ.py; element [0]
## of each pickle is used as the (rows, actions) Q table.
with open('../outcome_phy/results_train.pkl', 'rb') as fh:
    phy_train_Q = pkl.load(fh)[0]
with open('../outcome_phy/results_test.pkl', 'rb') as fh:
    phy_test_Q = pkl.load(fh)[0]

# V(s) = max_a Q(s, a)
phy_train_V = phy_train_Q.max(axis = 1)
phy_test_V = phy_test_Q.max(axis = 1)

## DQN Q network, obtained from qnetwork.py
dqn_train_Q = dqn_res_train[0]
dqn_test_Q = dqn_res_test[0]

dqn_train_V = dqn_train_Q.max(axis = 1)
dqn_test_V = dqn_test_Q.max(axis = 1)

# Intermediate rewards: r_t = L_{t+1} - L_t of neg_mortality_logodds within each
# ICU stay, 0 on a stay's last row (see get_ir_diff).
train_rewards = get_ir_diff(train_set)
test_rewards = get_ir_diff(test_set)

In [55]:
# Prepare the per-row frames the MoE is trained / evaluated on:
# columns 0-24 hold kernel-expert probabilities, 25-49 DQN-expert probabilities,
# followed by the gate features and bookkeeping columns.
def _moe_frame(kernel_probs, dqn_probs, source, rewards, knn_dist):
    """Assemble one split's MoE frame; columns copied from ``source`` align on its index."""
    frame = pd.DataFrame(np.hstack((kernel_probs, dqn_probs)))
    frame[attr_cols] = source[attr_cols]
    frame['reward'] = rewards
    frame['traj_len'] = get_traj_len(source)
    # largest neighbour distance, skipping column 0 (presumably the query row itself for the train KNN)
    frame['dist'] = np.max(knn_dist[:, 1:], axis=1)
    frame['icustayid'] = source['icustayid']
    frame['bloc'] = source['bloc']
    return frame


train_df = _moe_frame(action_probs_k_train, action_probs_d_train, train_set, train_rewards, dist_train)
test_df = _moe_frame(action_probs_k_test, action_probs_d_test, test_set, test_rewards, dist_test)

In [56]:
# Baseline WDR values on the train split; the DQN policy is scored with its own V/Q.
for _label, _policy, _vq in (
        ('pi_b on train:', train_pi_b, 'phy'),
        ('kernel policy on train:', action_probs_k_train, 'phy'),
        ('dqn policy on train:', action_probs_d_train, 'dqn')):
    print(_label, evaluate(_policy, VQ=_vq))

pi_b on train: 3.485433343630809
kernel policy on train: 3.8046235618168307
dqn policy on train: 3.0369187346310356


In [57]:
# Baseline WDR values on the test split; the DQN policy is scored with its own V/Q.
for _label, _policy, _vq in (
        ('pi_b on test:', test_pi_b, 'phy'),
        ('kernel policy on test:', action_probs_k_test, 'phy'),
        ('dqn policy on test:', action_probs_d_test, 'dqn')):
    print(_label, evaluate(_policy, phase='test', VQ=_vq))

pi_b on test: 3.7617322682172523
kernel policy on test: 3.7315519237828756
dqn policy on test: 3.21056799416379


In [111]:
# Gate input = attr_cols + ['traj_len', 'dist'] (the columns fed to the model in
# train / get_moe_policies); logits=1 -> a single sigmoid mixing weight.
moe = MOE(input_size=len(attr_cols) + 2, logits=1)

In [113]:
# Fit the gating network: 256 ICU stays per batch, Adam lr 1e-4, 500 epochs.
train(train_df, moe, batch_size=256, lr=1e-4, num_epoch=500)

# Logs from earlier runs (different settings), kept for reference:
# Epoch:37/100, wdr:4.2721781730651855
# ********************
# Eval: epoch:37/100, wdr:3.9994751696702746
# [[7.278773409780115e-05, -0.3777479827404022, 0.08824396878480911, 0.013657244853675365, 0.11063016206026077, -0.3569730818271637, 0.17868568003177643, -0.15885134041309357, -0.34712275862693787]]
# [0.24018655717372894]
# Epoch:52/100, wdr:4.322579860687256
# ********************
# Eval: epoch:52/100, wdr:4.063299313263504
# [[0.17618609964847565, -0.7344371676445007, 0.11212635040283203, 0.4373891353607178, 0.5129995346069336, -0.522505521774292, 0.6723778247833252, -0.2311212420463562, -0.5206314325332642]]
# [0.34241819381713867]

# Epoch:100/100, wdr:4.509039402008057
# ********************
# Eval: epoch:100/100, wdr:4.169190939357189
# [[1.2300101518630981, -1.352432131767273, -0.3300890624523163, 0.9182170033454895, 1.240336298942566, -0.5786391496658325, 1.7389116287231445, -0.2638624310493469, -0.8485985994338989]]
# [0.7569210529327393]
# ********************

# Epoch:29/100, wdr:4.870425224304199
# ********************
# Eval: epoch:29/100, wdr:4.24502183234042
# [[1.2412420511245728, -0.8054572939872742, -0.4281856119632721, 0.7711164355278015, 0.3763531744480133, -0.10058163851499557, 1.7026106119155884, -0.7935341596603394, -0.26787468791007996]]
# [0.7318970561027527]
# dqn:3.799225081366078, phy:4.24502183234042

# Epoch:23/100, wdr:5.080012321472168
# ********************
# Eval: epoch:23/100, wdr:3.882480572462679
# [[-0.000850573880597949, -0.3292688727378845, -0.008397337049245834, 0.14008520543575287, -0.4858025014400482, -0.006549003068357706, -0.3450920879840851, -0.07146909832954407, -0.07110099494457245]]
# [0.018589148297905922]

# Epoch:36/100, wdr:3.622654676437378
# ********************
# Eval: epoch:36/100, wdr:4.314914156055653
# [[1.3287606239318848, -0.7961893677711487, -0.5077809691429138, 0.8471158742904663, 0.3134307861328125, -0.07832232117652893, 1.8573822975158691, -0.6942882537841797, -0.3460248112678528]]
# [0.7956991195678711]

********************
Epoch:1/500, wdr:5.225149154663086
********************
Eval: epoch:1/500, wdr:4.018893399354986
[[0.22461552917957306, 0.30547332763671875, -0.0681275948882103, -0.058487746864557266, -0.055541906505823135, -0.24932356178760529, -0.013985078781843185, -0.30881044268608093, -0.0553886778652668]]
[0.10962872207164764]
********************
********************
Epoch:2/500, wdr:5.225090503692627
********************
Eval: epoch:2/500, wdr:4.019045933361099
[[0.22528821229934692, 0.30504274368286133, -0.06840132176876068, -0.0579751580953598, -0.05460672453045845, -0.24926427006721497, -0.013234032317996025, -0.3090510964393616, -0.0555289201438427]]
[0.10985824465751648]
********************
********************
Epoch:3/500, wdr:5.225017547607422
********************
Eval: epoch:3/500, wdr:4.01920018914247
[[0.22596406936645508, 0.3046094477176666, -0.06867147982120514, -0.057459089905023575, -0.053672485053539276, -0.2492019236087799, -0.012482036836445332, -0.309293

********************
Epoch:24/500, wdr:5.223484992980957
********************
Eval: epoch:24/500, wdr:4.02246091002575
[[0.24029645323753357, 0.2954186201095581, -0.07424765080213547, -0.046528127044439316, -0.03412693366408348, -0.24783611297607422, 0.0033323627430945635, -0.3143521845340729, -0.05859560891985893]]
[0.11496812105178833]
********************
********************
Epoch:25/500, wdr:5.223399639129639
********************
Eval: epoch:25/500, wdr:4.022615807383333
[[0.2409806251525879, 0.2949783504009247, -0.0745123028755188, -0.046007800847291946, -0.033201321959495544, -0.24777208268642426, 0.004084598273038864, -0.3145885467529297, -0.058735720813274384]]
[0.11520054191350937]
********************
********************
Epoch:26/500, wdr:5.2233171463012695
********************
Eval: epoch:26/500, wdr:4.02277056461354
[[0.2416648119688034, 0.2945377826690674, -0.07477693259716034, -0.04548760876059532, -0.03227623924612999, -0.24770823121070862, 0.004836706444621086, -0.314

********************
Epoch:47/500, wdr:5.221655368804932
********************
Eval: epoch:47/500, wdr:4.026007043607035
[[0.2560308575630188, 0.28522828221321106, -0.08034209907054901, -0.03459985926747322, -0.012978026643395424, -0.2464139759540558, 0.02059476636350155, -0.3196861147880554, -0.061852190643548965]]
[0.12029346078634262]
********************
********************
Epoch:48/500, wdr:5.221574306488037
********************
Eval: epoch:48/500, wdr:4.026160455406603
[[0.25671449303627014, 0.28478217124938965, -0.08060774952173233, -0.03408342972397804, -0.012065453454852104, -0.2463548183441162, 0.021343158558011055, -0.31991341710090637, -0.06199566647410393]]
[0.12052370607852936]
********************
********************
Epoch:49/500, wdr:5.221506118774414
********************
Eval: epoch:49/500, wdr:4.026314036552181
[[0.25739797949790955, 0.2843357026576996, -0.08087345212697983, -0.033567216247320175, -0.011153480038046837, -0.24629594385623932, 0.02209133841097355, -0.3

********************
Epoch:70/500, wdr:5.2197699546813965
********************
Eval: epoch:70/500, wdr:4.029523149540841
[[0.27172958850860596, 0.27490171790122986, -0.08647574484348297, -0.022778701037168503, 0.00785498134791851, -0.24512222409248352, 0.03775249049067497, -0.3248293399810791, -0.06520125269889832]]
[0.12555097043514252]
********************
********************
Epoch:71/500, wdr:5.219684600830078
********************
Eval: epoch:71/500, wdr:4.029675211673876
[[0.2724107503890991, 0.27444955706596375, -0.08674374222755432, -0.022267622873187065, 0.008753180503845215, -0.2450695037841797, 0.03849567472934723, -0.3250490427017212, -0.06534934788942337]]
[0.12577757239341736]
********************
********************
Epoch:72/500, wdr:5.2196044921875
********************
Eval: epoch:72/500, wdr:4.0298272988675405
[[0.27309170365333557, 0.27399715781211853, -0.08701186627149582, -0.021756792441010475, 0.00965073611587286, -0.24501702189445496, 0.03923860937356949, -0.32526

********************
Epoch:93/500, wdr:5.2178568840026855
********************
Eval: epoch:93/500, wdr:4.033006339467799
[[0.28735798597335815, 0.2644333243370056, -0.09267411381006241, -0.011090964078903198, 0.02834819070994854, -0.24398629367351532, 0.05478042736649513, -0.32981017231941223, -0.06866573542356491]]
[0.13071346282958984]
********************
********************
Epoch:94/500, wdr:5.217773914337158
********************
Eval: epoch:94/500, wdr:4.0331569890908385
[[0.2880355417728424, 0.26397499442100525, -0.09294535964727402, -0.010586102493107319, 0.02923125959932804, -0.24394069612026215, 0.055517569184303284, -0.33002349734306335, -0.06881926208734512]]
[0.13093550503253937]
********************
********************
Epoch:95/500, wdr:5.217692852020264
********************
Eval: epoch:95/500, wdr:4.033307518978328
[[0.28871291875839233, 0.263516366481781, -0.09321673959493637, -0.010081523098051548, 0.03011365607380867, -0.24389545619487762, 0.05625443533062935, -0.330

********************
Epoch:116/500, wdr:5.215969562530518
********************
Eval: epoch:116/500, wdr:4.036450735631479
[[0.30289575457572937, 0.2538236677646637, -0.09895312786102295, 0.0004476938920561224, 0.04848857596516609, -0.24302157759666443, 0.07166381180286407, -0.33465325832366943, -0.07226052135229111]]
[0.13576363027095795]
********************
********************
Epoch:117/500, wdr:5.215877532958984
********************
Eval: epoch:117/500, wdr:4.036599499593904
[[0.30356913805007935, 0.2533591389656067, -0.09922812879085541, 0.0009458395070396364, 0.049356136471033096, -0.24298368394374847, 0.0723944678902626, -0.3348609507083893, -0.0724199041724205]]
[0.13598045706748962]
********************
********************
Epoch:118/500, wdr:5.2158050537109375
********************
Eval: epoch:118/500, wdr:4.036748278655122
[[0.3042423129081726, 0.25289440155029297, -0.09950333833694458, 0.0014436861965805292, 0.05022301897406578, -0.2429460883140564, 0.07312481105327606, -0.3

********************
Epoch:139/500, wdr:5.214118003845215
********************
Eval: epoch:139/500, wdr:4.039847116895468
[[0.3183300495147705, 0.24307619035243988, -0.10532347112894058, 0.011828657239675522, 0.06827055662870407, -0.24223673343658447, 0.08839534223079681, -0.3393756151199341, -0.07599322497844696]]
[0.1406908482313156]
********************
********************
Epoch:140/500, wdr:5.21403169631958
********************
Eval: epoch:140/500, wdr:4.03999355472881
[[0.31899869441986084, 0.24260587990283966, -0.10560262948274612, 0.012319833040237427, 0.06912250071763992, -0.24220678210258484, 0.0891193151473999, -0.33957841992378235, -0.07615870237350464]]
[0.14090219140052795]
********************
********************
Epoch:141/500, wdr:5.213955402374268
********************
Eval: epoch:141/500, wdr:4.04013984707368
[[0.3196670114994049, 0.24213537573814392, -0.10588197410106659, 0.012810702435672283, 0.06997374445199966, -0.24217712879180908, 0.08984299749135971, -0.3397810

********************
Epoch:162/500, wdr:5.212315559387207
********************
Eval: epoch:162/500, wdr:4.0431860449024315
[[0.3336547911167145, 0.2321990579366684, -0.11179079115390778, 0.023048145696520805, 0.08769388496875763, -0.24163547158241272, 0.10497285425662994, -0.3439924418926239, -0.07986677438020706]]
[0.14548952877521515]
********************
********************
Epoch:163/500, wdr:5.212246417999268
********************
Eval: epoch:163/500, wdr:4.043329766674208
[[0.3343184292316437, 0.23172329366207123, -0.11207424849271774, 0.023532269522547722, 0.0885302796959877, -0.24161356687545776, 0.1056901067495346, -0.34419092535972595, -0.08003838360309601]]
[0.14569523930549622]
********************
********************
Epoch:164/500, wdr:5.212179183959961
********************
Eval: epoch:164/500, wdr:4.043473333284125
[[0.3349819779396057, 0.23124732077121735, -0.11235786974430084, 0.024016091600060463, 0.08936601132154465, -0.24159201979637146, 0.10640708357095718, -0.34438

********************
Epoch:185/500, wdr:5.2105913162231445
********************
Eval: epoch:185/500, wdr:4.04645805276771
[[0.34886667132377625, 0.22120080888271332, -0.11835721880197525, 0.03410595655441284, 0.10676199942827225, -0.24121835827827454, 0.12139814347028732, -0.3485144078731537, -0.08388053625822067]]
[0.15015803277492523]
********************
********************
Epoch:186/500, wdr:5.210526466369629
********************
Eval: epoch:186/500, wdr:4.046598727206434
[[0.3495253920555115, 0.22072003781795502, -0.11864494532346725, 0.03458307683467865, 0.10758305341005325, -0.2412044107913971, 0.12210888415575027, -0.3487090468406677, -0.08405818045139313]]
[0.15035806596279144]
********************
********************
Epoch:187/500, wdr:5.210448265075684
********************
Eval: epoch:187/500, wdr:4.046739175824991
[[0.35018402338027954, 0.22023910284042358, -0.11893288046121597, 0.03505991771817207, 0.10840345174074173, -0.24119076132774353, 0.12281936407089233, -0.348903

********************
Epoch:208/500, wdr:5.208939075469971
********************
Eval: epoch:208/500, wdr:4.049655223164398
[[0.36396554112434387, 0.21009288728237152, -0.12502217292785645, 0.04500434920191765, 0.1254805624485016, -0.2409830391407013, 0.13767564296722412, -0.3529520332813263, -0.08803077042102814]]
[0.15469752252101898]
********************
********************
Epoch:209/500, wdr:5.208877086639404
********************
Eval: epoch:209/500, wdr:4.049792463431482
[[0.3646194338798523, 0.20960763096809387, -0.1253141611814499, 0.04547462984919548, 0.126286581158638, -0.24097688496112823, 0.13838008046150208, -0.35314321517944336, -0.08821422606706619]]
[0.1548919826745987]
********************
********************
Epoch:210/500, wdr:5.208820819854736
********************
Eval: epoch:210/500, wdr:4.049929442914449
[[0.3652731776237488, 0.20912222564220428, -0.12560635805130005, 0.04594461992383003, 0.12709195911884308, -0.24097110331058502, 0.13908426463603973, -0.35333430767

********************
Epoch:231/500, wdr:5.207382678985596
********************
Eval: epoch:231/500, wdr:4.052770437051617
[[0.3789527714252472, 0.19888684153556824, -0.13178318738937378, 0.05574723333120346, 0.14385706186294556, -0.24092508852481842, 0.15381121635437012, -0.357313334941864, -0.09231196343898773]]
[0.15911121666431427]
********************
********************
Epoch:232/500, wdr:5.207308769226074
********************
Eval: epoch:232/500, wdr:4.052903888457538
[[0.37960195541381836, 0.19839760661125183, -0.13207925856113434, 0.056210849434137344, 0.14464835822582245, -0.24092644453048706, 0.15450966358184814, -0.35750123858451843, -0.0925009474158287]]
[0.15930034220218658]
********************
********************
Epoch:233/500, wdr:5.207248687744141
********************
Eval: epoch:233/500, wdr:4.053037255062198
[[0.38025104999542236, 0.19790823757648468, -0.13237550854682922, 0.05667419731616974, 0.1454390585422516, -0.24092811346054077, 0.15520785748958588, -0.357689

********************
Epoch:254/500, wdr:5.205883979797363
********************
Eval: epoch:254/500, wdr:4.055797826905502
[[0.39383170008659363, 0.18759484589099884, -0.13863666355609894, 0.06633985042572021, 0.16189990937709808, -0.24103793501853943, 0.16981257498264313, -0.3616037368774414, -0.09671735018491745]]
[0.1634039431810379]
********************
********************
Epoch:255/500, wdr:5.205832481384277
********************
Eval: epoch:255/500, wdr:4.0559273343518285
[[0.3944762945175171, 0.18710209429264069, -0.13893665373325348, 0.06679708510637283, 0.16267696022987366, -0.24104665219783783, 0.1705053150653839, -0.3617887794971466, -0.09691162407398224]]
[0.16358794271945953]
********************
********************
Epoch:256/500, wdr:5.205764293670654
********************
Eval: epoch:256/500, wdr:4.056056785072766
[[0.3951205611228943, 0.1866092085838318, -0.13923679292201996, 0.06725404411554337, 0.16345340013504028, -0.2410556972026825, 0.1711978316307068, -0.3619737327

********************
Epoch:277/500, wdr:5.204474925994873
********************
Eval: epoch:277/500, wdr:4.058733588372438
[[0.4086061418056488, 0.17622849345207214, -0.14557674527168274, 0.07678806781768799, 0.1796179562807083, -0.24131529033184052, 0.18568669259548187, -0.36583033204078674, -0.10123882442712784]]
[0.16758184134960175]
********************
********************
Epoch:278/500, wdr:5.20440673828125
********************
Eval: epoch:278/500, wdr:4.058859124336951
[[0.4092462956905365, 0.1757328063249588, -0.14588035643100739, 0.07723915576934814, 0.18038107454776764, -0.24133098125457764, 0.18637411296367645, -0.36601272225379944, -0.10143789649009705]]
[0.1677609533071518]
********************
********************
Epoch:279/500, wdr:5.204357624053955
********************
Eval: epoch:279/500, wdr:4.0589845039270545
[[0.4098862409591675, 0.1752369999885559, -0.14618416130542755, 0.07768997550010681, 0.1811436414718628, -0.24134689569473267, 0.18706129491329193, -0.3661950230

********************
Epoch:300/500, wdr:5.2031331062316895
********************
Eval: epoch:300/500, wdr:4.0615747398651285
[[0.4232799708843231, 0.16479970514774323, -0.15259817242622375, 0.08709847927093506, 0.19702085852622986, -0.24174953997135162, 0.20144231617450714, -0.3699973523616791, -0.10586808621883392]]
[0.17165029048919678]
********************
********************
Epoch:301/500, wdr:5.20306921005249
********************
Eval: epoch:301/500, wdr:4.061696038100804
[[0.4239158034324646, 0.16430158913135529, -0.1529051959514618, 0.08754374086856842, 0.19777052104473114, -0.2417718470096588, 0.2021247297525406, -0.3701772391796112, -0.10607166588306427]]
[0.17182482779026031]
********************
********************
Epoch:302/500, wdr:5.203006267547607
********************
Eval: epoch:302/500, wdr:4.061817244915444
[[0.42455148696899414, 0.1638033539056778, -0.15321242809295654, 0.08798873424530029, 0.19851957261562347, -0.24179445207118988, 0.20280705392360687, -0.370357006

********************
Epoch:323/500, wdr:5.201852798461914
********************
Eval: epoch:323/500, wdr:4.064319671578964
[[0.4378582835197449, 0.1533195674419403, -0.15969473123550415, 0.0972777009010315, 0.21411757171154022, -0.2423332929611206, 0.21708796918392181, -0.37410908937454224, -0.11059637367725372]]
[0.1756163388490677]
********************
********************
Epoch:324/500, wdr:5.201798439025879
********************
Eval: epoch:324/500, wdr:4.0644368443302055
[[0.43849003314971924, 0.1528194397687912, -0.16000494360923767, 0.09771738201379776, 0.2148541659116745, -0.24236196279525757, 0.21776585280895233, -0.3742867112159729, -0.11080407351255417]]
[0.175786554813385]
********************
********************
Epoch:325/500, wdr:5.201754093170166
********************
Eval: epoch:325/500, wdr:4.06455379816259
[[0.4391215443611145, 0.15231925249099731, -0.16031521558761597, 0.09815684705972672, 0.21559014916419983, -0.24239082634449005, 0.21844355762004852, -0.3744642436504

********************
Epoch:346/500, wdr:5.20065450668335
********************
Eval: epoch:346/500, wdr:4.066968014737016
[[0.4523445963859558, 0.14179863035678864, -0.16686002910137177, 0.10733244568109512, 0.23091736435890198, -0.24305802583694458, 0.23263229429721832, -0.378169447183609, -0.11541502177715302]]
[0.17948687076568604]
********************
********************
Epoch:347/500, wdr:5.200611114501953
********************
Eval: epoch:347/500, wdr:4.0670808888365295
[[0.4529724717140198, 0.14129695296287537, -0.16717299818992615, 0.10776691138744354, 0.23164118826389313, -0.2430925965309143, 0.23330597579479218, -0.37834489345550537, -0.11562644690275192]]
[0.17965304851531982]
********************
********************
Epoch:348/500, wdr:5.2005462646484375
********************
Eval: epoch:348/500, wdr:4.067193627076082
[[0.45360007882118225, 0.1407952457666397, -0.16748614609241486, 0.10820114612579346, 0.2323644608259201, -0.2431274652481079, 0.23397943377494812, -0.378520220

********************
Epoch:369/500, wdr:5.199491500854492
********************
Eval: epoch:369/500, wdr:4.069520312431792
[[0.4667436182498932, 0.13024647533893585, -0.1740875542163849, 0.11726918071508408, 0.24742959439754486, -0.24391622841358185, 0.24808309972286224, -0.3821823298931122, -0.12031543999910355]]
[0.18326880037784576]
********************
********************
Epoch:370/500, wdr:5.199441909790039
********************
Eval: epoch:370/500, wdr:4.069629088315717
[[0.4673677384853363, 0.12974365055561066, -0.1744031012058258, 0.11769863218069077, 0.2481410801410675, -0.24395641684532166, 0.24875286221504211, -0.3823557496070862, -0.12053021788597107]]
[0.1834312379360199]
********************
********************
Epoch:371/500, wdr:5.199390888214111
********************
Eval: epoch:371/500, wdr:4.06973770537471
[[0.4679917097091675, 0.12924079596996307, -0.17471878230571747, 0.11812786012887955, 0.24885207414627075, -0.24399688839912415, 0.24942244589328766, -0.382529169321

********************
Epoch:392/500, wdr:5.198394298553467
********************
Eval: epoch:392/500, wdr:4.071977504848825
[[0.481060266494751, 0.11867270618677139, -0.18137086927890778, 0.12709437310695648, 0.2636622488498688, -0.2448999434709549, 0.26344794034957886, -0.3861505091190338, -0.1252892166376114]]
[0.18696823716163635]
********************
********************
Epoch:393/500, wdr:5.19835090637207
********************
Eval: epoch:393/500, wdr:4.072082250780272
[[0.4816809296607971, 0.11816912889480591, -0.18168869614601135, 0.12751905620098114, 0.26436179876327515, -0.24494542181491852, 0.26411423087120056, -0.3863220810890198, -0.12550698220729828]]
[0.18712732195854187]
********************
********************
Epoch:394/500, wdr:5.198315620422363
********************
Eval: epoch:394/500, wdr:4.072186854342461
[[0.4823013246059418, 0.11766554415225983, -0.18200665712356567, 0.12794363498687744, 0.2650608718395233, -0.24499112367630005, 0.2647802531719208, -0.38649368286132

********************
Epoch:415/500, wdr:5.1973466873168945
********************
Eval: epoch:415/500, wdr:4.074341871480857
[[0.4952980577945709, 0.10708541423082352, -0.18870416283607483, 0.13681453466415405, 0.2796231508255005, -0.24600189924240112, 0.2787337899208069, -0.3900780975818634, -0.1303282231092453]]
[0.19059200584888458]
********************
********************
Epoch:416/500, wdr:5.197306156158447
********************
Eval: epoch:416/500, wdr:4.0744425085345926
[[0.49591538310050964, 0.10658146440982819, -0.18902404606342316, 0.1372348815202713, 0.28031104803085327, -0.24605238437652588, 0.2793967127799988, -0.3902479112148285, -0.13054867088794708]]
[0.19074797630310059]
********************
********************
Epoch:417/500, wdr:5.197259426116943
********************
Eval: epoch:417/500, wdr:4.074542944947142
[[0.49653249979019165, 0.10607748478651047, -0.18934395909309387, 0.13765496015548706, 0.28099849820137024, -0.24610310792922974, 0.2800595462322235, -0.390417695

********************
Epoch:438/500, wdr:5.19635009765625
********************
Eval: epoch:438/500, wdr:4.076615032629285
[[0.5094597339630127, 0.09549258649349213, -0.1960812658071518, 0.14643533527851105, 0.2953203320503235, -0.24721471965312958, 0.2939487099647522, -0.39396682381629944, -0.13542518019676208]]
[0.19414643943309784]
********************
********************
Epoch:439/500, wdr:5.196314334869385
********************
Eval: epoch:439/500, wdr:4.076711946135351
[[0.5100739002227783, 0.09498852491378784, -0.1964029222726822, 0.14685145020484924, 0.2959970235824585, -0.24726983904838562, 0.2946086823940277, -0.3941349983215332, -0.1356479823589325]]
[0.19429948925971985]
********************
********************
Epoch:440/500, wdr:5.196265697479248
********************
Eval: epoch:440/500, wdr:4.076808535641946
[[0.5106878876686096, 0.09448449313640594, -0.19672462344169617, 0.1472673863172531, 0.2966732084751129, -0.24732519686222076, 0.29526853561401367, -0.3943031430244446

********************
Epoch:461/500, wdr:5.1953816413879395
********************
Eval: epoch:461/500, wdr:4.078799743445704
[[0.5235505104064941, 0.08390121906995773, -0.20349657535552979, 0.1559622883796692, 0.3107607662677765, -0.24853186309337616, 0.30909737944602966, -0.3978198170661926, -0.1405726671218872]]
[0.19763658940792084]
********************
********************
Epoch:462/500, wdr:5.19534158706665
********************
Eval: epoch:462/500, wdr:4.078892916210933
[[0.524161696434021, 0.08339739590883255, -0.20381973683834076, 0.15637445449829102, 0.31142643094062805, -0.2485913634300232, 0.3097545802593231, -0.3979865610599518, -0.14079752564430237]]
[0.19778695702552795]
********************
********************
Epoch:463/500, wdr:5.195303440093994
********************
Eval: epoch:463/500, wdr:4.078985688351696
[[0.5247728824615479, 0.08289356529712677, -0.20414303243160248, 0.1567865014076233, 0.31209149956703186, -0.24865107238292694, 0.3104117512702942, -0.398153185844421

********************
Epoch:484/500, wdr:5.194459438323975
********************
Eval: epoch:484/500, wdr:4.080899099513705
[[0.5375730991363525, 0.0723176896572113, -0.21094471216201782, 0.16540078818798065, 0.32595095038414, -0.24994680285453796, 0.3241878151893616, -0.40163907408714294, -0.14576409757137299]]
[0.2010682374238968]
********************
********************
Epoch:485/500, wdr:5.194417476654053
********************
Eval: epoch:485/500, wdr:4.080988450385744
[[0.5381811857223511, 0.07181432098150253, -0.21126927435398102, 0.1658092588186264, 0.32660576701164246, -0.2500104308128357, 0.3248426616191864, -0.4018043875694275, -0.14599066972732544]]
[0.20121616125106812]
********************
********************
Epoch:486/500, wdr:5.194382190704346
********************
Eval: epoch:486/500, wdr:4.081077679747006
[[0.5387890934944153, 0.07131098955869675, -0.21159382164478302, 0.1662175953388214, 0.3272601366043091, -0.25007420778274536, 0.32549747824668884, -0.40196967124938965

In [109]:
# Overwrite the parameters of `moe` with a fixed, previously obtained weight/bias
# set via `pretrained` (defined in an earlier cell).
# pre_w holds one row of 9 weights and pre_b a single bias. Presumably these match
# MOE(input_size=9, logits=1) as constructed in the restart loop — TODO confirm.
# Older candidate weight sets that were kept here as commented-out code have been
# removed; this is the set actually in use.
# NOTE(review): this cell's execution count (In[109]) is higher than the
# evaluation cells below (In[93], In[94], In[106]). Their recorded outputs were
# therefore NOT produced with these weights. Re-run the notebook top to bottom.
pre_w = [[1.3282166719436646, -0.7970498204231262, -0.5085808038711548, 0.8464760184288025, 0.3131825625896454, -0.07899399101734161, 1.8568652868270874, -0.6949451565742493, -0.3468712866306305]]
pre_b = [0.7950364351272583]
pretrained(moe, pre_w, pre_b)

In [106]:
# Derive the two policies returned by get_moe_policies for the test set from the
# current `moe`. Judging by the names, the pair is the expert policy and the
# mixture-of-experts policy (NOTE(review): confirm against get_moe_policies).
# NOTE(review): this cell ran as In[106], i.e. before the In[109] cell above that
# loads pretrained weights. Re-run it after that cell.
pi_expert, pi_moe = get_moe_policies(moe, test_df)

In [93]:
# Score the MOE policy on the test phase with the 'dqn' VQ option; the returned
# float is the cell's displayed output. (NOTE(review): recorded as In[93], so the
# shown value predates the later pi_moe / pretrained-weight cells.)
evaluate(pi_moe, VQ='dqn', phase='test')

3.934232727560059

In [94]:
# Same test-phase score as the previous cell, but with the 'phy' VQ option.
# (NOTE(review): recorded as In[94], before the later pi_moe / pretrained cells.)
evaluate(pi_moe, VQ='phy', phase='test')

4.312659789548697

In [None]:
# The objective function is complicated (non-convex), so we train from many random
# starting points. A global maximum is still not guaranteed. "basinhopping" could
# search more globally, but it is very slow.
# Each restart is seeded with its index, so any single restart can be reproduced,
# and every trained model is kept in `restart_models` as (seed, model). Previously
# each iteration overwrote `moe` and all but the last restart were lost.
# After the loop, `moe` still refers to the last restart (used by later cells).
# NOTE(review): choose the best restart on training/validation data, not the test set.
N_RESTARTS = 1000
restart_models = []
for seed in range(N_RESTARTS):
    print('sim. n. ', seed)
    # torch RNG drives parameter initialisation. The numpy global RNG is what
    # sklearn.utils.shuffle uses when no random_state is given, in case train()
    # shuffles that way (TODO confirm which RNG train() relies on).
    torch.manual_seed(seed)
    np.random.seed(seed)
    moe = MOE(input_size=9, logits=1)
    train(train_df, moe, batch_size=128, lr=0.0001, num_epoch=50)
    restart_models.append((seed, moe))

In [None]:
# Run the test-set evaluation helper on whatever `moe` is currently bound to.
# If the restart loop above was just run, that is its last restart.
do_eval_test(moe)

In [None]:
# Recompute the MOE policy on the test set for the current `moe`; the first
# element of the returned pair is not needed here.
# Index the result rather than unpacking into `_`. Assigning `_` in IPython shadows
# the output-history variable, and IPython then stops updating `_`, `__` and `___`.
pi_moe = get_moe_policies(moe, test_df)[1]

In [154]:
# Row indices where each ICU stay starts in the training frame (rows with bloc == 0).
train_fence_post = get_fence_post(df=train_df)

In [159]:
# Row indices where each ICU stay starts in the test frame (rows with bloc == 0).
test_fence_post = get_fence_post(df=test_df)