# 101_6.1_データ理解.ipynb

## 流れ確認
- 「num_data」固定
- 「_」固定

In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from pandas import DataFrame
from tqdm import tqdm
from sklearn.utils import check_random_state
import seaborn as sns
import matplotlib.pyplot as plt
import japanize_matplotlib
plt.style.use('ggplot')
y_label_dict = {"se": "平均二乗誤差", "bias": "二乗バイアス", "variance": "バリアンス", "selection": "方策選択"}

from dataset import generate_synthetic_data, calc_true_value
from estimators import calc_online, calc_ips, calc_new
from utils import eps_greedy_policy, softmax_policy, aggregate_simulation_results

In [2]:
## シミュレーション設定
# num_runs = 1000 # シミュレーションの繰り返し回数
dim_context = 10 # 特徴量xの次元
num_data = 500 # ログデータのサイズ
num_actions = 4 # 行動数, |A|
T = 12 # 総時点数
eps = 0.0 # データ収集方策のパラメータ, これは共通サポートの仮定を満たさない
beta = -5 # 評価方策のパラメータ
random_state = 12345
random_ = check_random_state(random_state)
num_data_list = [250, 500, 1000, 2000, 4000] # ログデータのサイズ

In [4]:
num_data = 500

## 期待報酬関数を定義するためのパラメータを抽出
random_ = check_random_state(random_state)
theta = random_.normal(size=(dim_context, num_actions))
M = random_.normal(size=(dim_context, num_actions))
b = random_.normal(size=(1, num_actions))
W = random_.uniform(0, 1, size=(T, T))
## データ収集方策と評価方策の真の性能(policy value)を近似
policy_value_of_pi0, policy_value_of_pi = calc_true_value(
    dim_context=dim_context, num_actions=num_actions,
    theta=theta, M=M, b=b, W=W, T=T, beta=beta, eps=eps,
)

In [8]:
_ = 0
## データ収集方策が形成する分布に従いログデータを生成
offline_logged_data = generate_synthetic_data(
    num_data=num_data, dim_context=dim_context, num_actions=num_actions,
    theta=theta, M=M, b=b, W=W, T=T, eps=eps, random_state=_
)
online_experiment_data = generate_synthetic_data(
    num_data=num_data, dim_context=dim_context, num_actions=num_actions,
    theta=theta, M=M, b=b, W=W, T=1, beta=beta, is_online=True, random_state=_
)

## ログデータ上における評価方策の行動選択確率を計算
pi = softmax_policy(beta * offline_logged_data["base_q_func"])

## ログデータを用いてオフ方策評価を実行する
estimated_policy_values, selection_result = dict(), dict()
V_hat_online, selection_result_online = calc_online(online_experiment_data)
estimated_policy_values["online"] = V_hat_online
selection_result["online"] = selection_result_online
V_hat_ips, selection_result_ips = calc_ips(offline_logged_data, pi)
estimated_policy_values["ips"] = V_hat_ips
selection_result["ips"] = selection_result_ips
V_hat_new, selection_result_new = calc_new(offline_logged_data, online_experiment_data, pi)
estimated_policy_values["new"] = V_hat_new
selection_result["new"] = selection_result_new
# estimated_policy_value_list.append(estimated_policy_values)
# selection_result_list.append(selection_result)

estimated_policy_values, selection_result

({'online': 0.2121460664550819,
  'ips': 0.0014043631786101057,
  'new': 0.17601989678339033},
 {'online': False, 'ips': False, 'new': True})

## offline_logged_data詳細確認

In [31]:
offline_logged_data.keys()

dict_keys(['num_data', 'T', 'num_actions', 'x', 'w', 'a_t', 'r_t', 'pi_0', 'q_t', 'base_q_func'])

In [32]:
offline_logged_data['num_data']

500

In [33]:
offline_logged_data['T']

12

In [34]:
offline_logged_data['num_actions']

4

In [35]:
offline_logged_data['x']

array([[ 1.76405235,  0.40015721,  0.97873798, ..., -0.15135721,
        -0.10321885,  0.4105985 ],
       [ 0.14404357,  1.45427351,  0.76103773, ..., -0.20515826,
         0.3130677 , -0.85409574],
       [-2.55298982,  0.6536186 ,  0.8644362 , ..., -0.18718385,
         1.53277921,  1.46935877],
       ...,
       [-1.30687164,  1.51959885,  0.21286139, ...,  0.16349452,
        -0.81311702, -0.60535458],
       [-1.3275238 , -0.64417161,  1.90888344, ..., -0.01728509,
         0.91228203,  1.2396585 ],
       [-0.5733674 ,  0.42488949, -0.27126002, ...,  0.92918181,
         0.22941801,  0.41440588]])

In [36]:
offline_logged_data['x'].shape

(500, 10)

In [37]:
offline_logged_data['w']

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [38]:
offline_logged_data['w'].shape

(500,)

In [39]:
offline_logged_data['a_t']

array([[1, 1, 1, ..., 1, 1, 1],
       [3, 3, 3, ..., 3, 3, 3],
       [1, 1, 1, ..., 1, 1, 1],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [3, 3, 3, ..., 3, 3, 3],
       [2, 2, 2, ..., 2, 2, 2]])

In [40]:
offline_logged_data['a_t'].shape

(500, 12)

In [41]:
offline_logged_data['r_t']

array([[ 0.03384737, -0.14253272, -0.20123437, ..., -0.76607143,
        -1.04159561,  0.98451425],
       [ 0.08845666,  0.02901178,  0.50365334, ...,  0.20655862,
        -0.61174681, -0.08205286],
       [-0.21301248, -0.10871527,  0.26601941, ...,  0.12612565,
        -0.11850354,  0.0709828 ],
       ...,
       [ 0.42600294,  1.08689243,  0.15139377, ...,  0.43172204,
         0.05595903,  0.30307042],
       [ 0.61595745,  0.96222077, -0.01579352, ...,  1.48495505,
         0.07534321, -0.12602453],
       [ 0.06615431,  0.10524836,  0.70567442, ...,  0.16450106,
         0.31683646,  0.17143409]])

In [42]:
offline_logged_data['r_t'].shape

(500, 12)

In [43]:
offline_logged_data['pi_0']

array([[0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       ...,
       [0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.]])

In [44]:
offline_logged_data['pi_0'].shape

(500, 4)

In [45]:
offline_logged_data['q_t']

array([[6.40299904e-02, 1.23353147e-01, 1.32163397e-01, ...,
        1.28803379e-01, 1.04514370e-01, 1.50071661e-01],
       [8.26665776e-02, 1.59256349e-01, 1.70630913e-01, ...,
        1.66292927e-01, 1.34934352e-01, 1.93751561e-01],
       [2.05208422e-08, 3.18332338e-08, 2.95195202e-08, ...,
        3.86569376e-08, 2.31252357e-08, 2.41360443e-08],
       ...,
       [8.33323462e-02, 1.60538946e-01, 1.72005116e-01, ...,
        1.67632194e-01, 1.36021068e-01, 1.95311971e-01],
       [1.10474563e-01, 1.71375158e-01, 1.58919213e-01, ...,
        2.08110770e-01, 1.24495393e-01, 1.29937110e-01],
       [3.01348247e-02, 5.80544436e-02, 6.22008651e-02, ...,
        6.06195193e-02, 4.91882350e-02, 7.06291407e-02]])

In [46]:
offline_logged_data['q_t'].shape

(500, 12)

In [47]:
offline_logged_data['base_q_func']

array([[1.70712800e-06, 6.40299904e-02, 8.33333333e-02, 8.32419959e-02],
       [8.31260713e-02, 8.33126526e-02, 7.57541521e-02, 8.26665776e-02],
       [1.98508260e-10, 1.54782971e-08, 8.33333333e-02, 8.33314478e-02],
       ...,
       [8.30334060e-02, 8.33323462e-02, 8.33333333e-02, 8.33333220e-02],
       [3.49925417e-05, 8.33328773e-02, 8.33330729e-02, 8.33278717e-02],
       [8.32951516e-02, 2.21363717e-02, 3.01348247e-02, 8.00677255e-02]])

In [48]:
offline_logged_data['base_q_func'].shape

(500, 4)

## online_experiment_data詳細確認

In [50]:
online_experiment_data.keys()

dict_keys(['num_data', 'T', 'num_actions', 'x', 'w', 'a_t', 'r_t', 'pi_0', 'q_t', 'base_q_func'])

In [51]:
online_experiment_data['num_data']

500

In [52]:
online_experiment_data['T']

1

In [53]:
online_experiment_data['num_actions']

4

In [54]:
online_experiment_data['x']

array([[ 1.76405235,  0.40015721,  0.97873798, ..., -0.15135721,
        -0.10321885,  0.4105985 ],
       [ 0.14404357,  1.45427351,  0.76103773, ..., -0.20515826,
         0.3130677 , -0.85409574],
       [-2.55298982,  0.6536186 ,  0.8644362 , ..., -0.18718385,
         1.53277921,  1.46935877],
       ...,
       [-1.30687164,  1.51959885,  0.21286139, ...,  0.16349452,
        -0.81311702, -0.60535458],
       [-1.3275238 , -0.64417161,  1.90888344, ..., -0.01728509,
         0.91228203,  1.2396585 ],
       [-0.5733674 ,  0.42488949, -0.27126002, ...,  0.92918181,
         0.22941801,  0.41440588]])

In [55]:
online_experiment_data['x'].shape

(500, 10)

In [56]:
online_experiment_data['w']

array([0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0,
       1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1,
       0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1,
       1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0,
       0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
       0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1,
       1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1,
       1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
       1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0,
       0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1,
       1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0,
       1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1,
       1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,
       1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0,

In [57]:
online_experiment_data['w'].shape

(500,)

In [60]:
online_experiment_data['a_t']

array([[1],
       [2],
       [1],
       [1],
       [0],
       [1],
       [1],
       [1],
       [3],
       [3],
       [1],
       [0],
       [3],
       [1],
       [0],
       [0],
       [2],
       [0],
       [0],
       [1],
       [3],
       [1],
       [1],
       [3],
       [3],
       [0],
       [1],
       [3],
       [2],
       [3],
       [1],
       [1],
       [2],
       [1],
       [2],
       [2],
       [0],
       [3],
       [0],
       [1],
       [1],
       [0],
       [0],
       [1],
       [1],
       [0],
       [0],
       [0],
       [0],
       [0],
       [2],
       [1],
       [3],
       [3],
       [3],
       [0],
       [1],
       [0],
       [3],
       [0],
       [3],
       [1],
       [3],
       [2],
       [2],
       [1],
       [2],
       [2],
       [0],
       [0],
       [0],
       [1],
       [3],
       [3],
       [1],
       [3],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [2],
    

In [61]:
online_experiment_data['a_t'].shape

(500, 1)

In [62]:
online_experiment_data['r_t']

array([[ 1.56369539e+00],
       [ 2.46039546e-01],
       [ 1.59204655e-01],
       [ 2.61394028e-01],
       [ 6.08946191e-01],
       [ 9.25907786e-01],
       [ 4.27616416e-01],
       [ 4.87725271e-01],
       [ 6.39892030e-01],
       [ 7.88081457e-01],
       [ 6.87377440e-01],
       [ 7.22508670e-01],
       [ 3.61376241e-01],
       [-2.99686061e-01],
       [-6.37247170e-01],
       [ 8.10821300e-01],
       [ 7.44844414e-01],
       [ 7.92043037e-01],
       [ 1.20669795e-01],
       [ 1.14641420e-01],
       [-3.48986308e-01],
       [-4.38074270e-01],
       [ 7.02461980e-01],
       [ 4.85626823e-01],
       [ 3.22067073e-01],
       [-5.46700680e-01],
       [ 2.28160422e-01],
       [-2.21810852e-01],
       [ 8.88508066e-01],
       [-7.65846418e-02],
       [ 1.44254903e+00],
       [-2.67450960e-01],
       [ 9.19461552e-01],
       [ 6.02976967e-01],
       [ 6.98285032e-01],
       [ 1.37259319e+00],
       [ 1.13592069e+00],
       [ 1.37014650e+00],
       [ 4.3

In [63]:
online_experiment_data['r_t'].shape

(500, 1)

In [66]:
online_experiment_data['pi_0']

array([[0.        , 1.        , 0.        , 0.        ],
       [0.21866876, 0.21623445, 0.34031554, 0.22478125],
       [0.        , 1.        , 0.        , 0.        ],
       ...,
       [0.        , 1.        , 0.        , 0.        ],
       [0.98014344, 0.00661822, 0.00661814, 0.0066202 ],
       [0.        , 0.        , 1.        , 0.        ]])

In [67]:
online_experiment_data['pi_0'].shape

(500, 4)

In [64]:
online_experiment_data['q_t']

array([[7.68359885e-01],
       [9.09049825e-01],
       [1.85739565e-07],
       [1.61581614e-03],
       [9.84192114e-01],
       [4.16786925e-03],
       [7.36815340e-03],
       [7.72779007e-01],
       [1.90263302e-01],
       [5.99077122e-03],
       [1.85743136e-04],
       [9.91687303e-01],
       [2.66478795e-01],
       [8.64840710e-01],
       [1.49198783e-12],
       [9.99907008e-01],
       [9.99220944e-01],
       [1.56780249e-01],
       [1.45124115e-04],
       [2.24745610e-12],
       [1.80900974e-03],
       [6.78399677e-14],
       [8.70667814e-02],
       [9.39100913e-01],
       [5.10143978e-01],
       [2.47690181e-01],
       [2.52969249e-01],
       [1.48044131e-02],
       [9.68649785e-01],
       [4.03219394e-08],
       [9.99999570e-01],
       [1.63600842e-01],
       [5.14131614e-04],
       [9.85184355e-01],
       [5.89951251e-01],
       [9.94964127e-01],
       [3.87641447e-01],
       [1.06697139e-01],
       [3.90147113e-06],
       [3.30953032e-05],


In [65]:
online_experiment_data['q_t'].shape

(500, 1)

In [68]:
online_experiment_data['base_q_func']

array([[2.04855360e-05, 7.68359885e-01, 1.00000000e+00, 9.98903951e-01],
       [9.97512856e-01, 9.99751831e-01, 9.09049825e-01, 9.91998931e-01],
       [2.38209913e-09, 1.85739565e-07, 1.00000000e+00, 9.99977373e-01],
       ...,
       [9.96400872e-01, 9.99988154e-01, 1.00000000e+00, 9.99999864e-01],
       [4.19910500e-04, 9.99994527e-01, 9.99996874e-01, 9.99934460e-01],
       [9.99541819e-01, 2.65636460e-01, 3.61617896e-01, 9.60812705e-01]])

In [69]:
online_experiment_data['base_q_func'].shape

(500, 4)