In [1]:
import os

import gym
import numpy as np
import pandas as pd
import importlib
import warnings

import hiive.mdptoolbox
import hiive.mdptoolbox.example
from matplotlib import pyplot as plt

from policy_iteration_custom import PolicyIterationCustom

import utils
import plots
from taxi_custom import TaxiCustomEnv

warnings.filterwarnings("ignore")



In [2]:
# Output folders, resolved against the directory the notebook is started from.
PATH_FIGURES = f'{os.getcwd()}/figures/taxi'
PATH_ARTEFACTS = f'{os.getcwd()}/artefacts/taxi'
# File names (inside PATH_ARTEFACTS) of the per-run statistics for value / policy iteration.
VI_RUNS_CSV = 'vi_runs.csv'
PI_RUNS_CSV = 'pi_runs.csv'

# File names (inside PATH_ARTEFACTS) of the saved value / policy iteration policies.
VI_POLICIES_JSON = 'vi_policies.json'
PI_POLICIES_JSON = 'pi_policies.json'

# Cells guarded by this flag re-run the value/policy-iteration experiments and save their artefacts;
# the cell that loads the CSV/JSON artefacts runs regardless of it.
regenerate_runs_mdp = True

In [3]:
# Register the custom environment under the id 'TaxiCustom' so gym.make('TaxiCustom', size=...) can build it.
gym.envs.register('TaxiCustom', TaxiCustomEnv)

In [4]:
# Apply the project's plotting setup (defined in plots.py) before any figure is created.
plots.setup_plots()

In [5]:
# Text drawn inside each cell of the policy grid, keyed by action index (default name_map of plot_policy).
# NOTE(review): Gym's stock Taxi-v3 numbers its actions 2 = east (right) and 3 = west (left), whereas
# here 2 is drawn as '←' and 3 as '→'. Confirm against TaxiCustomEnv's action encoding.
ACTION_TO_LABEL = {
    0: '↓',
    1: '↑',
    2: '←',
    3: '→',
    4: 'P',
    5: 'D'
}
# Face colour of each cell of the policy grid, keyed by action index (default color_mapping of plot_policy).
ACTION_TO_COLOR = {
    0: '#876161',
    1: '#d7e4f5',
    2: '#91ccff',
    3: '#f27272',
    4: '#02bf1e',
    5: '#ad031a'
}

In [24]:
# Grid side length of the main (large) Taxi problem studied in this notebook.
MAIN_SIZE = 10

# Grid side length -> number of MDP states, i.e. side * side * 5 * 4 for sides 3..10
# (presumably 5 passenger locations x 4 destinations, as in Gym's Taxi -- confirm in TaxiCustomEnv).
MAP_TO_ACTUAL_STATE_SIZE = {side: side * side * 5 * 4 for side in range(3, 11)}

importlib.reload(utils)


def problem_foo(sz: int):
    """Build the MDP model of a Taxi map with side length ``sz``.

    A 'TaxiCustom' environment of that size is created, reset with the fixed
    seed 42 and passed to ``utils.convert_p_r_v2``.

    Returns:
        The two values produced by ``utils.convert_p_r_v2`` as a ``(p, r)`` tuple.
    """
    taxi_env = gym.make('TaxiCustom', size=sz)
    taxi_env.reset(seed=42)
    transitions, rewards = utils.convert_p_r_v2(taxi_env)
    return transitions, rewards


In [7]:
# from matplotlib import pyplot as plt
#
# env = gym.make('TaxiCustom', size=MAIN_SIZE, render_mode='rgb_array')
# env.reset()
# plt.imshow(env.render())


In [8]:
importlib.reload(utils)

if regenerate_runs_mdp:
    # Both solvers are swept over every grid side listed in MAP_TO_ACTUAL_STATE_SIZE.
    grid_size = sorted(MAP_TO_ACTUAL_STATE_SIZE)

    # --- Value iteration -------------------------------------------------
    # print('Value iteration: starting...')
    taxi_df_vi, policies_vi = utils.write_stats_for_problem_sizes(
        algo_class=hiive.mdptoolbox.mdp.ValueIteration,
        main_size=MAIN_SIZE, state_sizes=grid_size, problem_function=problem_foo,
        not_use_span_vi=True,
    )
    print("Value iteration: finished runs!")
    utils.save_df_as_csv(taxi_df_vi, PATH_ARTEFACTS, VI_RUNS_CSV)
    utils.save_policies(policies_vi, PATH_ARTEFACTS, VI_POLICIES_JSON)
    print("Value iteration: saved!")

    # --- Policy iteration ------------------------------------------------
    print("Policy iteration: starting...")
    taxi_df_pi, policies_pi = utils.write_stats_for_problem_sizes(
        algo_class=PolicyIterationCustom,
        main_size=MAIN_SIZE, state_sizes=grid_size, problem_function=problem_foo,
        not_use_span_vi=False,
        collect_changes=True,
        max_iter=1e3,
    )
    print("Policy iteration: finished runs!")
    utils.save_df_as_csv(taxi_df_pi, PATH_ARTEFACTS, PI_RUNS_CSV)
    utils.save_policies(policies_pi, PATH_ARTEFACTS, PI_POLICIES_JSON)
    print("Policy iteration: saved!")

Processing size 3
Processing size 4
Processing size 5
Processing size 6
Processing size 7
Processing size 8
Processing size 9
Processing size 10
Value iteration: finished runs!
Value iteration: saved!
Policy iteration: starting...
Processing size 3
Processing size 4
Processing size 5
Processing size 6
Processing size 7
Processing size 8
Processing size 9
Processing size 10
Policy iteration: finished runs!
Policy iteration: saved!


In [85]:
if regenerate_runs_mdp:
    # Policies for the small map (side 3); they are saved under the '*_3.json' names and later
    # rolled out on size-3 environments.
    small_size = 3
    grid_size = sorted(MAP_TO_ACTUAL_STATE_SIZE)

    # print('Value iteration: starting...')
    _, policies_vi_3 = utils.write_stats_for_problem_sizes(
        algo_class=hiive.mdptoolbox.mdp.ValueIteration,
        main_size=small_size,
        state_sizes=grid_size,
        problem_function=problem_foo,
        not_use_span_vi=True,
    )
    print('Value iteration: finished runs!')
    utils.save_policies(policies_vi_3, PATH_ARTEFACTS, 'vi_policies_3.json')
    print('Value iteration: saved!')

    print('Policy iteration: starting...')
    _, policies_pi_3 = utils.write_stats_for_problem_sizes(
        algo_class=PolicyIterationCustom,
        # Was MAIN_SIZE (copy-paste slip): the value-iteration call above, the variable name and the
        # output file name all target the size-3 problem.
        main_size=small_size,
        state_sizes=grid_size,
        problem_function=problem_foo,
        not_use_span_vi=False,
        collect_changes=True,
        max_iter=1e3,
    )
    print('Policy iteration: finished runs!')
    utils.save_policies(policies_pi_3, PATH_ARTEFACTS, 'pi_policies_3.json')
    print('Policy iteration: saved!')

Processing size 3
Processing size 4
Processing size 5
Processing size 6
Processing size 7
Processing size 8
Processing size 9
Processing size 10
Value iteration: finished runs!
Value iteration: saved!
Policy iteration: starting...
Processing size 3
Processing size 4
Processing size 5
Processing size 6
Processing size 7
Processing size 8
Processing size 9
Processing size 10
Policy iteration: finished runs!
Policy iteration: saved!


In [9]:
# Load the saved run statistics and policies of both solvers from PATH_ARTEFACTS.
taxi_df_vi, taxi_df_pi = (
    utils.read_csv(PATH_ARTEFACTS, csv_name) for csv_name in (VI_RUNS_CSV, PI_RUNS_CSV)
)
taxi_policies_vi, taxi_policies_pi = (
    utils.read_policies(PATH_ARTEFACTS, json_name) for json_name in (VI_POLICIES_JSON, PI_POLICIES_JSON)
)

In [10]:
importlib.reload(plots)

# Convergence / state-size figures for the main map: value iteration first, then policy iteration.
plots.create_convergence_and_state_plots(
    MAIN_SIZE,
    taxi_df_vi,
    folder_path=PATH_FIGURES,
    algo='vi',
    marker_size=5,
    actual_sizes=MAP_TO_ACTUAL_STATE_SIZE,
)
plots.create_convergence_and_state_plots(
    MAIN_SIZE,
    taxi_df_pi,
    folder_path=PATH_FIGURES,
    algo='pi',
    marker_size=5,
    actual_sizes=MAP_TO_ACTUAL_STATE_SIZE,
    plot_changes=True,
)

In [11]:
# Same convergence figures for the small map (side 3); the state-size panel is switched off.
if regenerate_runs_mdp:
    plots.create_convergence_and_state_plots(
        3,
        taxi_df_vi,
        folder_path=PATH_FIGURES,
        algo='vi_3',
        marker_size=5,
        actual_sizes=MAP_TO_ACTUAL_STATE_SIZE,
        plot_states=False,
    )
    plots.create_convergence_and_state_plots(
        3,
        taxi_df_pi,
        folder_path=PATH_FIGURES,
        algo='pi_3',
        marker_size=5,
        actual_sizes=MAP_TO_ACTUAL_STATE_SIZE,
        plot_changes=False,
        plot_states=False,
    )

In [12]:
def plot_policy(policy, file_name: str, title: str, folder=PATH_FIGURES, color_mapping=ACTION_TO_COLOR,
                name_map=ACTION_TO_LABEL,
                num_columns=5 * 4):
    """Draw a policy as a grid of coloured, labelled cells and save it as an image.

    The policy is laid out row by row with ``num_columns`` entries per row; the last
    row is padded with ``None`` when the policy length is not a multiple of ``num_columns``.

    Args:
        policy: Sequence of actions, one per state. Any sequence works (list, tuple,
            1-D array); it is copied into a list before padding.
        file_name: Name of the image file written inside ``folder``.
        title: Title shown above the grid.
        folder: Output directory; created if it does not exist.
        color_mapping: Action -> face colour of the cell. Values without an entry
            (including the ``None`` padding) are drawn in blue ('b').
        name_map: Action -> text drawn in the cell. Values without an entry get no text.
        num_columns: Number of cells per row.
    """
    actions = list(policy)

    # Ceiling division: number of rows needed to hold every action.
    num_rows = -(-len(actions) // num_columns)

    # Pad to a full rectangle; dtype=object keeps the actions and the None padding as-is.
    padding = [None] * (num_rows * num_columns - len(actions))
    grid = np.array(actions + padding, dtype=object).reshape(num_rows, num_columns)

    fig = plt.figure()
    ax = fig.add_subplot(111, xlim=(-.01, num_columns + 0.01), ylim=(-.01, num_rows + 0.01))

    for row in range(num_rows):
        for col in range(num_columns):
            action = grid[row, col]
            cell = plt.Rectangle([col, row], 1, 1, linewidth=1, edgecolor='k')
            cell.set_facecolor(color_mapping.get(action, 'b'))
            ax.add_patch(cell)
            ax.text(col + 0.5, row + 0.5, name_map.get(action, ''), ha='center', va='center', size=2, color='w')

    # One tick per column, one tick every 10 rows.
    x_ticks = np.arange(0, num_columns, 1)
    y_ticks = np.arange(0, num_rows, 10)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_xticklabels(x_ticks)
    ax.set_yticklabels(y_ticks)
    ax.set_title(title)

    os.makedirs(folder, exist_ok=True)
    save_path = os.path.join(folder, file_name)
    # Save and close this specific figure rather than whatever pyplot considers "current".
    fig.savefig(save_path)
    plt.close(fig)

In [13]:
# Value-iteration policies on the main map for three discount factors
# (the loaded policies are keyed by the gamma written as a string).
plot_policy(
    list(taxi_policies_vi['0.99']),
    file_name='vi_policy_g_099.png',
    title='Taxi (size=2000), VI Policy (gamma=0.99)',
)

plot_policy(
    list(taxi_policies_vi['0.5']),
    file_name='vi_policy_g_05.png',
    title='Taxi (size=2000), VI Policy (gamma=0.50)',
)

plot_policy(
    list(taxi_policies_vi['0.1']),
    file_name='vi_policy_g_01.png',
    title='Taxi (size=2000), VI Policy (gamma=0.1)',
)

In [14]:
# Policy-iteration policies on the main map for three discount factors
# (the loaded policies are keyed by the gamma written as a string).
plot_policy(
    list(taxi_policies_pi['0.99']),
    file_name='pi_policy_g_099.png',
    title='Taxi (size=2000), PI Policy (gamma=0.99)',
)

plot_policy(
    list(taxi_policies_pi['0.5']),
    file_name='pi_policy_g_05.png',
    title='Taxi (size=2000), PI Policy (gamma=0.50)',
)

plot_policy(
    list(taxi_policies_pi['0.1']),
    file_name='pi_policy_g_01.png',
    title='Taxi (size=2000), PI Policy (gamma=0.1)',
)

In [15]:
import qlearning_utils
import qlearning_plots
import qlearning_custom

importlib.reload(qlearning_utils)
importlib.reload(qlearning_custom)

# Q-learning gamma sweep on the small map (side 3). NumPy's global RNG is seeded
# after the model is built and right before the sweep starts.
p_3, r_3 = problem_foo(3)
np.random.seed(42)
df_ql_3, df_ql_3_2, policies_ql_3 = qlearning_utils.q_learning_stats_gammas(
    p=p_3, r=r_3,
    # epsilon settings
    epsilon=0.5, epsilon_decay=0.9999, epsilon_min=0.01,
    # alpha settings
    alpha=1, alpha_decay=0.9999, alpha_min=1e-4,
    # iteration budget, episode length and TD-error threshold
    n_iter=100000, episode_length=10, td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=None,
)

Running for gamma=0.1
Running for gamma=0.3
Running for gamma=0.5
Running for gamma=0.7
Running for gamma=0.9
Running for gamma=0.95
Running for gamma=0.99


In [16]:
importlib.reload(qlearning_plots)

# Figures for the size-3 gamma sweep: first the df_ql_3_2 statistics on a log y-axis,
# then the 'Mean V' column of df_ql_3.
qlearning_plots.create_stat_plot(
    df=df_ql_3_2,
    folder_path=PATH_FIGURES,
    file_name='ql_td_error_3.png',
    log_scale_y=True,
    title_additional='(Taxi (size=180))',
)
qlearning_plots.create_stat_plot(
    df=df_ql_3,
    folder_path=PATH_FIGURES,
    file_name='ql_td_mean_v_3.png',
    y_axis='Mean V',
    title_additional='(Taxi (size=180))',
)

In [147]:
# Q-learning gamma sweep on the main map (side 10) with the first hyper-parameter set;
# its results and figures carry the '_bad' suffix.
p_10, r_10 = problem_foo(10)
np.random.seed(42)

df_ql_10_bad, df_ql_10_2_bad, policies_ql_10_bad = qlearning_utils.q_learning_stats_gammas(
    p=p_10, r=r_10,
    # epsilon settings
    epsilon=0.01, epsilon_decay=0.9999, epsilon_min=0.01,
    # alpha settings
    alpha=1, alpha_decay=0.9999, alpha_min=1e-4,
    # iteration budget, episode length and TD-error threshold
    n_iter=100000, episode_length=10, td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=None,
    # gammas=[0.99]
)

importlib.reload(qlearning_plots)
qlearning_plots.create_stat_plot(
    df=df_ql_10_2_bad,
    folder_path=PATH_FIGURES,
    file_name='ql_td_error_10_bad.png',
    log_scale_y=True,
    title_additional='(Taxi (size=2000))',
)
qlearning_plots.create_stat_plot(
    df=df_ql_10_bad,
    folder_path=PATH_FIGURES,
    file_name='ql_td_mean_v_10_bad.png',
    y_axis='Mean V',
    title_additional='(Taxi (size=2000))',
)

Running for gamma=0.1
Running for gamma=0.3
Running for gamma=0.5
Running for gamma=0.7
Running for gamma=0.9
Running for gamma=0.95
Running for gamma=0.99


In [162]:
# Q-learning gamma sweep on the main map (side 10) with the second hyper-parameter set:
# alpha and alpha_min are both 1, and the run uses more iterations, longer episodes
# and a larger TD-error threshold than the '_bad' sweep.
p_10, r_10 = problem_foo(10)
np.random.seed(42)

df_ql_10, df_ql_10_2, policies_ql_10 = qlearning_utils.q_learning_stats_gammas(
    p=p_10, r=r_10,
    # epsilon settings
    epsilon=0.01, epsilon_decay=0.9999, epsilon_min=0.0001,
    # alpha settings
    alpha=1, alpha_decay=0.9999, alpha_min=1,
    # iteration budget, episode length and TD-error threshold
    n_iter=500000, episode_length=100, td_error_threshold=1e-2,
    overall_stat_freq=1000,
    iter_callback=None,
    # gammas=[0.3, 0.7, 0.9, 0.95, 0.99]
)

importlib.reload(qlearning_plots)
qlearning_plots.create_stat_plot(
    df=df_ql_10_2,
    folder_path=PATH_FIGURES,
    file_name='ql_td_error_10.png',
    log_scale_y=True,
    title_additional='(Taxi (size=2000))',
)
qlearning_plots.create_stat_plot(
    df=df_ql_10,
    folder_path=PATH_FIGURES,
    file_name='ql_td_mean_v_10.png',
    y_axis='Mean V',
    title_additional='(Taxi (size=2000))',
)

Running for gamma=0.1
Running for gamma=0.3
Running for gamma=0.5
Running for gamma=0.7
Running for gamma=0.9
Running for gamma=0.95
Running for gamma=0.99


In [19]:
# Q-learning policies on the main map for three discount factors
# (these in-memory policies are keyed by the gamma as a float).
plot_policy(
    list(policies_ql_10[0.99]),
    file_name='ql_policy_10_g_099.png',
    title='Taxi (size=2000), QL Policy (gamma=0.99)',
)

plot_policy(
    list(policies_ql_10[0.5]),
    file_name='ql_policy_10_g_050.png',
    title='Taxi (size=2000), QL Policy (gamma=0.5)',
)

plot_policy(
    list(policies_ql_10[0.1]),
    file_name='ql_policy_10_g_01.png',
    title='Taxi (size=2000), QL Policy (gamma=0.1)',
)

In [20]:
importlib.reload(qlearning_utils)

# Epsilon sweep on the main map (side 10) at gamma=0.99.
# Seed NumPy's global RNG as the gamma-sweep cells do, so this run does not depend on how
# much randomness earlier cells happened to consume.
np.random.seed(42)
df_ql_10_eps, df_ql_10_2_eps, policies_ql_10_eps = qlearning_utils.q_learning_stats_epsilons(
    p=p_10,
    r=r_10,
    epsilons=[0.001, 0.05, 0.1, 0.3, 0.5, 0.8, 0.9],
    alpha=1,
    alpha_decay=0.9999,
    alpha_min=1e-4,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    gamma=0.99,
    iter_callback=None,
)

Running for epsilon=0.001
Running for epsilon=0.05
Running for epsilon=0.1
Running for epsilon=0.3
Running for epsilon=0.5
Running for epsilon=0.8
Running for epsilon=0.9


In [21]:
importlib.reload(qlearning_plots)

# Figures for the size-10 epsilon sweep, one line per 'Epsilon' value.
qlearning_plots.create_stat_plot(
    df=df_ql_10_2_eps,
    folder_path=PATH_FIGURES,
    hue_col='Epsilon',
    file_name='ql_td_error_10_epsilon_g099.png',
    title_additional='(Taxi (size=2000, g=0.99))',
)
qlearning_plots.create_stat_plot(
    df=df_ql_10_eps,
    folder_path=PATH_FIGURES,
    hue_col='Epsilon',
    file_name='ql_td_mean_v_10_epsilon_g099.png',
    y_axis='Mean V',
    title_additional='(Taxi (size=2000, g=0.99))',
)

In [22]:
importlib.reload(qlearning_utils)

# Epsilon sweep on the small map (side 3) at gamma=0.99.
# Seed NumPy's global RNG as the gamma-sweep cells do, so this run does not depend on how
# much randomness earlier cells happened to consume.
np.random.seed(42)
df_ql_3_eps, df_ql_3_2_eps, policies_ql_3_eps = qlearning_utils.q_learning_stats_epsilons(
    p=p_3,
    r=r_3,
    epsilons=[0.001, 0.05, 0.1, 0.3, 0.5, 0.8, 0.9],
    alpha=1,
    alpha_decay=0.9999,
    alpha_min=1e-4,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    gamma=0.99,
    iter_callback=None,
)

Running for epsilon=0.001
Running for epsilon=0.05
Running for epsilon=0.1
Running for epsilon=0.3
Running for epsilon=0.5
Running for epsilon=0.8
Running for epsilon=0.9


In [23]:
importlib.reload(qlearning_plots)

# Figures for the size-3 epsilon sweep, one line per 'Epsilon' value.
qlearning_plots.create_stat_plot(
    df=df_ql_3_2_eps,
    folder_path=PATH_FIGURES,
    hue_col='Epsilon',
    file_name='ql_td_error_3_epsilon_g099.png',
    title_additional='(Taxi (size=180, g=0.99))',
)
qlearning_plots.create_stat_plot(
    df=df_ql_3_eps,
    folder_path=PATH_FIGURES,
    hue_col='Epsilon',
    file_name='ql_td_mean_v_3_epsilon_g099.png',
    y_axis='Mean V',
    title_additional='(Taxi (size=180, g=0.99))',
)

In [68]:
# Seeds passed to env.reset() by evaluate_policy, one evaluation episode per entry (31 entries).
# NOTE(review): 100, 2323 and 2424 each appear twice, so only 28 distinct seeds are evaluated --
# confirm whether the repeats are intentional.
SEEDS = [45, 1337, 42, 2991, 10232, 100, 23, 19999, 2935, 2323, 11, 13213, 2323, 4424, 211, 33, 5053, 100, 1320, 213,
         240231, 3012, 2424, 23293, 2424, 44, 123, 9483, 933, 11112, 98]


def run_episode(policy, env, seed, max_steps=100):
    """Roll out a deterministic policy for one episode and return the summed reward.

    Args:
        policy: Indexable mapping from state to action (``policy[state]``).
        env: Environment following the Gym API where ``reset(seed=...)`` returns
            ``(state, info)`` and ``step(action)`` returns
            ``(state, reward, terminated, truncated, info)``.
        seed: Seed forwarded to ``env.reset``.
        max_steps: Upper bound on the number of steps taken; guards against policies
            that never reach a terminal state. Defaults to the previous fixed cap of 100.

    Returns:
        Sum of the rewards collected until the episode ends or ``max_steps`` is reached.
    """
    state, _ = env.reset(seed=seed)
    total_reward = 0
    for _ in range(max_steps):
        state, reward, terminated, truncated, _info = env.step(policy[state])
        total_reward += reward
        # Stop on either end-of-episode signal: stepping an environment after it has
        # reported truncation is not defined by the Gym API.
        if terminated or truncated:
            break
    return total_reward


def evaluate_policy(policy, size):
    """Roll out ``policy`` once per entry of ``SEEDS`` on a 'TaxiCustom' map of side ``size``.

    A fresh environment is built for every seed and closed again after its episode,
    so no environments are left open once the evaluation is done.

    Returns:
        List of total episode rewards, in the order of ``SEEDS``.
    """
    rewards = []
    for seed in SEEDS:
        cur_env = gym.make('TaxiCustom', size=size)
        try:
            rewards.append(run_episode(policy, cur_env, seed))
        finally:
            # Release the environment even if the rollout raised.
            cur_env.close()
    return rewards




In [69]:
# total_rewards_ql = evaluate_policy(list(policies_ql_10[0.99]), 10)

In [173]:
taxi_policies_vi_3 = utils.read_policies(PATH_ARTEFACTS, 'vi_policies_3.json')
taxi_policies_pi_3 = utils.read_policies(PATH_ARTEFACTS, 'pi_policies_3.json')


def _abs_reward_gap(rewards_a, rewards_b):
    """Return the per-seed absolute difference between two equally long reward lists."""
    return [np.abs(a - b) for a, b in zip(rewards_a, rewards_b)]


# Per-gamma results. Every value is a list with one entry per seed in SEEDS.
# Names without a suffix refer to the size-10 map, names ending in _3 to the size-3 map.
total_rewards_pi = {}
total_rewards_vi = {}
total_rewards_ql = {}
difference_vi_pi = {}
difference_vi_ql = {}
difference_pi_ql = {}
total_rewards_pi_3 = {}
total_rewards_vi_3 = {}
total_rewards_ql_3 = {}
difference_vi_pi_3 = {}
difference_vi_ql_3 = {}
for gamma in taxi_policies_pi.keys():
    # for gamma in [str(0.99)]:
    # The loaded VI/PI policies are keyed by the gamma as a string, the Q-learning ones by float.
    total_rewards_vi[gamma] = evaluate_policy(list(taxi_policies_vi[gamma]), 10)
    total_rewards_pi[gamma] = evaluate_policy(list(taxi_policies_pi[gamma]), 10)
    total_rewards_ql[gamma] = evaluate_policy(list(policies_ql_10[float(gamma)]), 10)
    total_rewards_vi_3[gamma] = evaluate_policy(list(taxi_policies_vi_3[gamma]), 3)
    total_rewards_pi_3[gamma] = evaluate_policy(list(taxi_policies_pi_3[gamma]), 3)
    total_rewards_ql_3[gamma] = evaluate_policy(list(policies_ql_3[float(gamma)]), 3)

    difference_vi_pi[gamma] = _abs_reward_gap(total_rewards_vi[gamma], total_rewards_pi[gamma])
    difference_vi_ql[gamma] = _abs_reward_gap(total_rewards_vi[gamma], total_rewards_ql[gamma])
    difference_pi_ql[gamma] = _abs_reward_gap(total_rewards_pi[gamma], total_rewards_ql[gamma])

    difference_vi_pi_3[gamma] = _abs_reward_gap(total_rewards_vi_3[gamma], total_rewards_pi_3[gamma])
    difference_vi_ql_3[gamma] = _abs_reward_gap(total_rewards_vi_3[gamma], total_rewards_ql_3[gamma])

In [124]:
# Size-3 map: per-seed absolute reward gap between the VI and Q-learning policies, by gamma.
print(difference_vi_ql_3)

{'0.99': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [125]:
# Size-3 map: per-seed total reward of the VI policies, by gamma.
print(total_rewards_vi_3)

{'0.99': [12, 16, 15, 12, 13, 16, 13, 14, 8, 15, 12, 13, 15, 11, 9, 14, 13, 16, 15, 16, 8, 14, 12, 15, 12, 12, 15, 15, 13, 17, 15]}


In [138]:
# Size-3 map: per-seed total reward of the Q-learning policies, by gamma.
print(total_rewards_ql_3)

{'0.99': [12, 16, 15, 12, 13, 16, 13, 14, 8, 15, 12, 13, 15, 11, 9, 14, 13, 16, 15, 16, 8, 12, 12, 15, 12, 12, 15, 15, 13, 17, 15]}


In [135]:
# Size-10 map: per-seed absolute reward gap between the VI and PI policies, by gamma.
print(difference_vi_pi)

{'0.99': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [164]:
# Size-10 map: per-seed total reward of the Q-learning policies, by gamma.
print(total_rewards_ql)

{'0.1': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 13, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], '0.3': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], '0.5': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 13, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], '0.7': [-100, -100, -100, -100, -100, -100, 7, -100, -100, -100, -100, 10, -100, -100, 13, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], '0.9': [-1, 0, 4, -2, 2, -6, 7, 6, -100, -3, -100, 8, -3, 1, 13, 4, -100, -6, -3, -3, -2, 2, 6, -100, 6, -100, -11, 6, 4, -1, -100], '0.95': [-1, 0, 4, -2, 2, -100, 7, 6, -11, -3, -100, 10, -3, 1, 13, 4, -100, -100, -3, -1, -2, 2,

In [169]:
# Size-10 map: per-seed absolute reward gap between the VI and Q-learning policies, by gamma.
print(difference_vi_ql)

{'0.1': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 113, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '0.3': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 113, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], '0.5': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 110, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 103, 0, 0, 0, 0, 0, 0, 0], '0.7': [99, 100, 104, 98, 102, 0, 0, 106, 0, 97, 97, 0, 97, 101, 0, 104, 0, 0, 97, 99, 98, 102, 106, 103, 106, 0, 0, 106, 104, 99, 99], '0.9': [0, 0, 0, 0, 0, 0, 0, 0, 91, 0, 97, 2, 0, 0, 0, 0, 89, 0, 0, 2, 0, 0, 0, 103, 0, 94, 4, 0, 0, 0, 99], '0.95': [0, 0, 0, 0, 0, 94, 0, 0, 2, 0, 97, 0, 0, 0, 0, 0, 89, 94, 0, 0, 0, 0, 2, 0, 2, 2, 2, 0, 0, 99, 99], '0.99': [0, 0, 0, 0, 0, 0, 0, 0, 91, 0, 97, 0, 0, 0, 0, 0, 89, 0, 0, 0, 0, 102, 0, 103, 0, 94, 0, 0, 104, 99, 99]}


In [99]:
# Size-10 map: per-seed total reward of the VI policies, by gamma.
print(total_rewards_vi)

{'0.1': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], '0.3': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 13, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], '0.5': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 10, -100, -100, 13, -100, -100, -100, -100, -100, -100, -100, -100, 3, -100, -100, -100, -100, -100, -100, -100], '0.7': [-1, 0, 4, -2, 2, -100, 7, 6, -100, -3, -3, 10, -3, 1, 13, 4, -100, -100, -3, -1, -2, 2, 6, 3, 6, -100, -100, 6, 4, -1, -1], '0.9': [-1, 0, 4, -2, 2, -6, 7, 6, -9, -3, -3, 10, -3, 1, 13, 4, -11, -6, -3, -1, -2, 2, 6, 3, 6, -6, -7, 6, 4, -1, -1], '0.95': [-1, 0, 4, -2, 2, -6, 7, 6, -9, -3, -3, 10, -3, 1, 13, 4, -11, -6, -3, -1, -2, 2, 6, 3, 6, -6, -7, 6, 4, -1, -1], '0.99': [-1, 0, 4, -2, 2, -6, 7, 6, -9, -3, -3, 

In [174]:
# Per gamma: fraction of evaluation episodes in which the VI and Q-learning rewards differ.
# The denominator is the length of each list rather than a hard-coded episode count.
print({k: len([d for d in v if d != 0]) / len(v) for k, v in difference_vi_ql.items()})

{'0.1': 0.03225806451612903, '0.3': 0.03225806451612903, '0.5': 0.06451612903225806, '0.7': 0.7096774193548387, '0.9': 0.2903225806451613, '0.95': 0.3548387096774194, '0.99': 0.2903225806451613}


In [175]:
# Per gamma: fraction of evaluation episodes in which the PI and Q-learning rewards differ.
# The denominator is the length of each list rather than a hard-coded episode count.
print({k: len([d for d in v if d != 0]) / len(v) for k, v in difference_pi_ql.items()})

{'0.1': 0.0, '0.3': 0.0967741935483871, '0.5': 0.5161290322580645, '0.7': 0.9032258064516129, '0.9': 0.2903225806451613, '0.95': 0.3548387096774194, '0.99': 0.2903225806451613}


In [102]:
# Number of evaluation episodes run per policy (one per entry of SEEDS).
len(SEEDS)

31

In [111]:
# Per gamma: mean size of the non-zero VI-vs-PI reward gaps. A gamma with no non-zero gap
# yields nan, because the mean is then taken over an empty list.
print({k: np.mean([d for d in v if d != 0]) for k, v in difference_vi_pi.items()})

{'0.1': 113.0, '0.3': 106.5, '0.5': 102.85714285714286, '0.7': 92.5, '0.9': nan, '0.95': nan, '0.99': nan}
