## Config and Utils

In [15]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

from learn.test_util import get_algo_config
test_config = get_algo_config

In [16]:
import collections
import datetime
import pandas as pd
import os
import json

def load_json(path):
    with open(path) as json_file:
        data = json.load(json_file)
        return data

In [17]:
task_types = ['thompson_sampling', 'epsilon_greedy', 'upper_confidence_bound']
task_name = []
for test_type in task_types:
    task_name.append('test' + '_' + test_type)
task_name[0:5]

info_map = collections.defaultdict(dict)
for task in task_name:
    res_folder = f"../test/learning_out"
    res_folder = os.path.join(res_folder, task)
    action_history_name = "action_hist.csv"
    knowledge_table_name = "knowledge.json"
    update_content_name = "update_content.csv"
    action_history = os.path.join(res_folder, action_history_name)
    knowledge_table = os.path.join(res_folder, knowledge_table_name)
    update_content = os.path.join(res_folder, update_content_name)
    
    info_map[task]['action_hist_df'] = pd.read_csv(action_history).reset_index(drop=True)
    info_map[task]['knowledge_table_df'] = load_json(knowledge_table)
    info_map[task]['update_content_df'] = pd.read_csv(update_content).reset_index(drop=True)

In [18]:
import matplotlib.pyplot as plt
import numpy as np

def get_result(task_name):
    identifier = f"test_{task_name}"
    print(len(info_map[identifier]['action_hist_df']))
    display(info_map[identifier]['action_hist_df'].tail(10))
    display(info_map[identifier]['update_content_df'])
    if task_name == 'epsilon_greedy' or task_name == 'upper_confidence_bound':
        display(pd.DataFrame.from_dict(info_map[identifier]['knowledge_table_df']))

In [25]:
from scipy.stats import norm
import numpy as np
import plotly.graph_objects as go
import math

def plot_for_task(task_name):
    identifier = f"test_{task_name}"
    possible_state = list(info_map[identifier]['knowledge_table_df'].keys())
    
    for state in possible_state:
        plot_distribution(identifier, state, skip=True)
        
def plot_distribution(identifier, state, skip=False):
    fig = go.Figure()
    params_dict = info_map[identifier]['knowledge_table_df']
    
    params = params_dict[str(state)]
    if skip:
        flag = True
        for key in params.keys():
            print(params[key])
            if not type(params[key]) is dict:
                continue
            mean = params[key]['a']
            sigma = params[key]['b']
            if mean != 0.0 or sigma != 10.0:
                flag = False
        if flag:
            return
    x_axis = np.arange(-2, 2, 0.01)
    for key in params.keys():
        if not type(params[key]) is dict:
            continue
        fig.add_trace(go.Line(x = x_axis, y = norm.pdf(x_axis, params[key]['a'], params[key]['b']), name=key))
    fig.update_layout(title_text=f'{identifier}_state={state}')
    fig.show()

## Visualization

In [28]:
get_result('upper_confidence_bound')

50


Unnamed: 0,timestamp,prev_state,prev_action,curr_state,curr_action,p1,p2,p3,reward
40,1665559602,s3,a4,s3,a4,0,0,1,1.0
41,1665559602,s3,a4,s3,a4,0,0,1,1.0
42,1665559602,s3,a4,s3,a4,0,0,1,1.0
43,1665559602,s3,a4,s3,a4,0,0,1,1.0
44,1665559602,s3,a4,s3,a4,0,0,1,1.0
45,1665559602,s3,a4,s3,a4,0,0,1,1.0
46,1665559602,s3,a4,s3,a4,0,0,1,1.0
47,1665559602,s3,a4,s3,a4,0,0,1,1.0
48,1665559602,s3,a4,s3,a4,0,0,1,1.0
49,1665559602,s3,a4,s3,a4,0,0,1,1.0


Unnamed: 0,timestamp,curr_state,curr_action,next_state,p1,p2,p3
0,1665559601,s1,a4,s1,1,0,0
1,1665559601,s1,a1,s2,0,1,0
2,1665559601,s2,a2,s3,0,0,1
3,1665559601,s3,a3,s1,1,0,0
4,1665559602,s1,a1,s2,0,1,0
5,1665559602,s2,a2,s3,0,0,1
6,1665559602,s3,a3,s1,1,0,0
7,1665559602,s1,a4,s1,1,0,0
8,1665559602,s1,a1,s2,0,1,0
9,1665559602,s2,a2,s3,0,0,1


Unnamed: 0,s1,s2,s3
a1,0.278217,,
a4,0.00936,0.0,2.823876
a2,,0.419247,
a3,,,0.036021


In [27]:
plot_for_task('thompson_sampling')

{'a': 0.5, 'b': 3.8742048900000006}
{'a': 0.0, 'b': 8.1}
nan
nan



plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




nan
{'a': 0.2688888888888889, 'b': 6.561}
{'a': 1.0, 'b': 3.8742048900000006}
nan


nan
{'a': 1.0, 'b': 1.853020188851841}
nan
{'a': 0.0, 'b': 3.8742048900000006}
