In [1]:
import os

import numpy as np
import pandas as pd
import importlib
import warnings

import hiive.mdptoolbox
import hiive.mdptoolbox.mdp
import hiive.mdptoolbox.example
from matplotlib import pyplot as plt

import utils
import plots
from policy_iteration_custom import PolicyIterationCustom

# Silence every warning for the rest of the session: the "ignore" action is installed with no
# category/module filter, so it matches all warnings, including ones raised by third-party code.
warnings.filterwarnings("ignore")



In [2]:
# Output folders, resolved against the current working directory at the time this cell runs.
PATH_FIGURES, PATH_ARTEFACTS = (
    f'{os.getcwd()}/{kind}/forest' for kind in ('figures', 'artefacts')
)

# File names passed together with PATH_ARTEFACTS to the utils save/read helpers:
# per-run statistics (CSV) and resulting policies (JSON) for value / policy iteration.
VI_RUNS_CSV, PI_RUNS_CSV = 'vi_runs.csv', 'pi_runs.csv'
VI_POLICIES_JSON, PI_POLICIES_JSON = 'vi_policies.json', 'pi_policies.json'

# Guards the cell that recomputes and re-saves the VI/PI runs; when False that cell is a no-op
# and the notebook reads whatever artefacts already exist on disk.
regenerate_runs_mdp = False

In [3]:
# Apply the project's plotting setup from the local `plots` module before any figure is created.
# NOTE(review): what exactly is configured is defined in plots.py, not visible here.
plots.setup_plots()

In [4]:
# Per-action cell text and face colour used by the policy grid plot.
# Actions 0 and 1 are shown as 'W' and 'C' (presumably Wait / Cut of the forest example - confirm).
# The None key covers the padding cells that fill an incomplete last row of the grid.
ACTION_TO_LABEL = {0: 'W', 1: 'C', None: ''}
ACTION_TO_COLOR = {0: 'g', 1: 'b', None: 'b'}

In [5]:
# Reference problem size passed as `main_size` to the statistics and plotting helpers.
MAIN_SIZE = 100
# Every problem size from 2 up to and including MAIN_SIZE.
STATES_SIZES = [*range(2, MAIN_SIZE + 1)]

# Re-execute the local `utils` module so edits to utils.py take effect without a kernel restart.
importlib.reload(utils)


def problem_foo(sz: int):
    """Build the forest-management example problem for a given size.

    Thin wrapper around ``hiive.mdptoolbox.example.forest`` that passes ``sz``
    as its only argument, leaving every other setting at the library default.

    Args:
        sz: Problem size forwarded to ``forest``.

    Returns:
        Whatever ``forest`` returns; callers in this notebook unpack it into
        two values (``p, r``).
    """
    forest_problem = hiive.mdptoolbox.example.forest(sz)
    return forest_problem


In [6]:
# Recompute the value-iteration and policy-iteration runs only when explicitly requested.
# With the flag off, nothing here executes and later cells read the artefacts already on disk.
if regenerate_runs_mdp:
    print('Value iteration: starting...')
    # Run library ValueIteration over all sizes in STATES_SIZES on the forest problem.
    forest_df_vi, policies_vi = utils.write_stats_for_problem_sizes(
        algo_class=hiive.mdptoolbox.mdp.ValueIteration,
        main_size=MAIN_SIZE,
        state_sizes=STATES_SIZES,
        problem_function=problem_foo,
        not_use_span_vi=True,
    )
    print('Value iteration: finished runs!')
    # Persist statistics (CSV) and policies (JSON) so the plotting cells can reload them.
    utils.save_df_as_csv(forest_df_vi, PATH_ARTEFACTS, VI_RUNS_CSV)
    utils.save_policies(policies_vi, PATH_ARTEFACTS, VI_POLICIES_JSON)
    print('Value iteration: saved!')

    print('Policy iteration: starting...')
    # Same sweep with the project's PolicyIterationCustom; collect_changes=True is only passed here
    # and pairs with plot_changes=True in the PI convergence plot.
    forest_df_pi, policies_pi = utils.write_stats_for_problem_sizes(
        algo_class=PolicyIterationCustom,
        main_size=MAIN_SIZE,
        state_sizes=STATES_SIZES,
        problem_function=problem_foo,
        not_use_span_vi=False,
        collect_changes=True,
    )
    print('Policy iteration: finished runs!')
    utils.save_df_as_csv(forest_df_pi, PATH_ARTEFACTS, PI_RUNS_CSV)
    utils.save_policies(policies_pi, PATH_ARTEFACTS, PI_POLICIES_JSON)
    print('Policy iteration: saved!')

In [7]:
# Load the cached VI/PI run statistics and policies from PATH_ARTEFACTS.
# These files must already exist (produced by the regeneration cell with regenerate_runs_mdp=True).
# NOTE(review): the policy containers are later indexed with str(gamma), so they are presumably
# keyed by the gamma value as a string - confirm against utils.read_policies.
forest_df_vi = utils.read_csv(PATH_ARTEFACTS, VI_RUNS_CSV)
forest_policies_vi = utils.read_policies(PATH_ARTEFACTS, VI_POLICIES_JSON)
forest_df_pi = utils.read_csv(PATH_ARTEFACTS, PI_RUNS_CSV)
forest_policies_pi = utils.read_policies(PATH_ARTEFACTS, PI_POLICIES_JSON)

In [8]:
# Pick up edits to plots.py without restarting the kernel.
importlib.reload(plots)
# Convergence / state-size figures for value iteration, written under PATH_FIGURES.
plots.create_convergence_and_state_plots(MAIN_SIZE, forest_df_vi, folder_path=PATH_FIGURES, algo="vi", marker_size=5)
# Same figures for policy iteration; plot_changes=True is only requested for PI, whose runs were
# generated with collect_changes=True.
plots.create_convergence_and_state_plots(MAIN_SIZE, forest_df_pi, folder_path=PATH_FIGURES, algo="pi", marker_size=5,
                                         plot_changes=True)

In [9]:
def plot_policy(policy, file_name: str, title: str, folder=PATH_FIGURES, color_mapping=ACTION_TO_COLOR,
                name_map=ACTION_TO_LABEL,
                num_columns=10):
    """Draw a policy as a grid of coloured, labelled cells and save it as an image.

    State ``k`` is drawn in row ``k // num_columns`` and column ``k % num_columns``.
    Row 0 is at the bottom of the figure. Cells beyond the end of the policy
    (incomplete last row) are padded with ``None``.

    Args:
        policy: Iterable of actions, one per state (list, tuple, ndarray, ...).
        file_name: Name of the image file created inside ``folder``.
        title: Figure title.
        folder: Output directory; created if it does not exist.
        color_mapping: Maps an action to a cell face colour; unknown actions fall back to ``'b'``.
        name_map: Maps an action to the text drawn in the cell; unknown actions get no text.
        num_columns: Number of cells per row; must be at least 1.

    Raises:
        ValueError: If ``num_columns`` is smaller than 1.
    """
    if num_columns < 1:
        raise ValueError(f'num_columns must be a positive integer, got {num_columns!r}')

    # Materialise as a list so tuples and arrays work too (list concatenation is used for padding).
    actions = list(policy)

    # Ceiling division: number of rows needed to hold every state.
    num_rows = -(-len(actions) // num_columns)

    # Pad the last row with None; dtype=object keeps the entries usable as dictionary keys.
    padding = [None] * (num_rows * num_columns - len(actions))
    grid = np.array(actions + padding, dtype=object).reshape(num_rows, num_columns)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, xlim=(-.01, num_columns + 0.01), ylim=(-.01, num_rows + 0.01))

        for i in range(num_rows):
            for j in range(num_columns):
                action = grid[i, j]
                # Unit square whose lower-left corner is at (column, row).
                cell = plt.Rectangle([j, i], 1, 1, linewidth=1, edgecolor='k')
                cell.set_facecolor(color_mapping.get(action, 'b'))
                ax.add_patch(cell)
                ax.text(j + 0.5, i + 0.5, name_map.get(action, ''), ha='center', va='center', size=9, color='w')

        ax.set_xticks(np.arange(0, num_columns, 1))
        ax.set_yticks(np.arange(0, num_rows, 1))
        ax.set_xticklabels(np.arange(0, num_columns, 1))
        # Label each row with the index of the first state it contains.
        ax.set_yticklabels(np.arange(0, num_rows * num_columns, num_columns))
        ax.set_title(title)

        os.makedirs(folder, exist_ok=True)
        fig.savefig(os.path.join(folder, file_name))
    finally:
        # Always release this specific figure, even if drawing or saving failed.
        plt.close(fig)

In [10]:
# Policy grids for the value-iteration policies at three discount factors.
# Each entry: (gamma used for the lookup, file-name tag, gamma text shown in the title).
for vi_gamma, vi_tag, vi_gamma_text in ((0.99, '099', '0.99'), (0.5, '05', '0.50'), (0.1, '01', '0.1')):
    plot_policy(
        list(forest_policies_vi[str(vi_gamma)]),
        file_name=f'vi_policy_g_{vi_tag}.png',
        title=f'Forest Management (size=100), VI Policy (gamma={vi_gamma_text})',
    )

In [11]:
# Policy grids for the policy-iteration policies at three discount factors.
# Each entry: (gamma used for the lookup, file-name tag, gamma text shown in the title).
for pi_gamma, pi_tag, pi_gamma_text in ((0.99, '099', '0.99'), (0.5, '05', '0.50'), (0.1, '01', '0.1')):
    plot_policy(
        list(forest_policies_pi[str(pi_gamma)]),
        file_name=f'pi_policy_g_{pi_tag}.png',
        title=f'Forest Management (size=100), PI Policy (gamma={pi_gamma_text})',
    )

In [37]:
# Project-local Q-learning helpers. They are imported in this cell rather than in the top import
# cell, so this cell has to run before any later cell that references these modules.
import qlearning_utils
import qlearning_plots
import qlearning_custom

# Reload so edits to the module files take effect without restarting the kernel.
importlib.reload(qlearning_utils)
importlib.reload(qlearning_custom)
importlib.reload(qlearning_plots)


def forest_iter_callback(s, a, s_new):
    """Per-transition callback for Q-learning ("v0" variant).

    Returns True only when the transition starts in a state other than 0,
    ends in state 0, and was taken with action 0. Every other transition
    yields False.

    NOTE(review): presumably the True value tells the learner to end the
    current episode - confirm against qlearning_custom.

    Args:
        s: State before the transition.
        a: Action taken.
        s_new: State after the transition.

    Returns:
        bool: Whether the transition matches the condition above.
    """
    return bool(s != 0 and s_new == 0 and a == 0)


# Q-learning on the 100-state forest problem, "v0" setup (figure titles carry the "v0" suffix).
p_100, r_100 = problem_foo(100)
# Fix NumPy's global RNG so repeated runs of this cell give the same results.
np.random.seed(42)
# Returns two statistics frames (used below for the 'Mean V' and TD-error plots respectively)
# and a container of policies indexed by gamma.
# epsilon equals epsilon_min here, so exploration presumably stays at 0.1 throughout - confirm
# how qlearning_utils applies epsilon_decay.
df_ql_100, df_ql_100_2, policies_ql_100 = qlearning_utils.q_learning_stats_gammas(
    p=p_100,
    r=r_100,
    epsilon=0.1,
    epsilon_decay=0.9999,
    epsilon_min=0.1,
    alpha=1,
    alpha_decay=0.9999,
    alpha_min=1e-4,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=forest_iter_callback,
)

# TD-error figure (log-scaled y axis).
qlearning_plots.create_stat_plot(df=df_ql_100_2, folder_path=PATH_FIGURES, file_name='ql_td_error_100.png',
                                 log_scale_y=True, title_additional='(Forest Management (size=100)), v0')

# Mean-V figure.
qlearning_plots.create_stat_plot(df=df_ql_100, folder_path=PATH_FIGURES, file_name='ql_td_mean_v_100.png',
                                 y_axis='Mean V', title_additional='(Forest Management (size=100)), v0')

# Policy grids for three of the gammas; note the keys here are floats, unlike the
# string keys used for the VI/PI policies loaded from JSON.
plot_policy(list(policies_ql_100[0.99]), file_name='ql_policy_100_g_099.png',
            title='Forest Management (size=100), QL Policy (gamma=0.99), v0')

plot_policy(list(policies_ql_100[0.50]), file_name='ql_policy_100_g_050.png',
            title='Forest Management (size=100), QL Policy (gamma=0.5), v0')

plot_policy(list(policies_ql_100[0.1]), file_name='ql_policy_100_g_01.png',
            title='Forest Management (size=100), QL Policy (gamma=0.1), v0')


Running for gamma=0.1
Running for gamma=0.3
Running for gamma=0.5
Running for gamma=0.7
Running for gamma=0.9
Running for gamma=0.95
Running for gamma=0.99


In [26]:
# Same imports and reloads as the previous Q-learning cell; repeating them lets this cell be run
# on its own after editing the Q-learning modules.
import qlearning_utils
import qlearning_plots
import qlearning_custom

# Reload so edits to the module files take effect without restarting the kernel.
importlib.reload(qlearning_utils)
importlib.reload(qlearning_custom)
importlib.reload(qlearning_plots)


def forest_iter_callback_good(s, a, s_new):
    """Per-transition callback for Q-learning (action-independent variant).

    Returns True when the transition starts in a state other than 0 and ends
    in state 0, whichever action was taken. Every other transition yields
    False. The action argument is accepted but not inspected.

    NOTE(review): presumably the True value tells the learner to end the
    current episode - confirm against qlearning_custom.

    Args:
        s: State before the transition.
        a: Action taken (unused).
        s_new: State after the transition.

    Returns:
        bool: Whether the transition matches the condition above.
    """
    return bool(s != 0 and s_new == 0)


# Q-learning on the 100-state forest problem with the action-independent callback and a higher
# exploration rate (epsilon 0.5 instead of 0.1); outputs use the "_good" suffix.
p_100, r_100 = problem_foo(100)
# Fix NumPy's global RNG so repeated runs of this cell give the same results.
np.random.seed(42)
# Returns two statistics frames (used below for the 'Mean V' and TD-error plots respectively)
# and a container of policies indexed by gamma.
# epsilon equals epsilon_min here, so exploration presumably stays at 0.5 throughout - confirm
# how qlearning_utils applies epsilon_decay.
df_ql_100_good, df_ql_100_2_good, policies_ql_100_good = qlearning_utils.q_learning_stats_gammas(
    p=p_100,
    r=r_100,
    epsilon=0.5,
    epsilon_decay=0.9999,
    epsilon_min=0.5,
    alpha=1,
    alpha_decay=0.9999,
    alpha_min=1e-4,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=forest_iter_callback_good,
)

# TD-error figure (log-scaled y axis).
qlearning_plots.create_stat_plot(df=df_ql_100_2_good, folder_path=PATH_FIGURES, file_name='ql_td_error_100_good.png',
                                 log_scale_y=True, title_additional='(Forest Management (size=100))')

# Mean-V figure.
qlearning_plots.create_stat_plot(df=df_ql_100_good, folder_path=PATH_FIGURES, file_name='ql_td_mean_v_100_good.png',
                                 y_axis='Mean V', title_additional='(Forest Management (size=100))')

# Policy grids for three of the gammas (float keys).
plot_policy(list(policies_ql_100_good[0.99]), file_name='ql_policy_100_g_099_good.png',
            title='Forest Management (size=100), QL Policy (gamma=0.99)')

plot_policy(list(policies_ql_100_good[0.50]), file_name='ql_policy_100_g_050_good.png',
            title='Forest Management (size=100), QL Policy (gamma=0.5)')

plot_policy(list(policies_ql_100_good[0.1]), file_name='ql_policy_100_g_01_good.png',
            title='Forest Management (size=100), QL Policy (gamma=0.1)')


Running for gamma=0.1
Running for gamma=0.3
Running for gamma=0.5
Running for gamma=0.7
Running for gamma=0.9
Running for gamma=0.95
Running for gamma=0.99


In [38]:
# Epsilon sweep for Q-learning on the 100-state forest problem at gamma=0.99.
# Relies on p_100, r_100 and forest_iter_callback defined by the earlier Q-learning cells.
importlib.reload(qlearning_utils)
# Seed NumPy's global RNG, as the gamma-sweep cells do, so this stochastic sweep gives the same
# results on every re-run (assumes qlearning_utils draws from NumPy's global RNG, as that
# seeding implies).
np.random.seed(42)
# NOTE(review): this sweep uses forest_iter_callback (the "v0" variant), not
# forest_iter_callback_good - confirm that is intended.
df_ql_100_eps, df_ql_100_2_eps, policies_ql_100_eps = qlearning_utils.q_learning_stats_epsilons(
    p=p_100,
    r=r_100,
    epsilons=[0.001, 0.05, 0.1, 0.3, 0.5, 0.8, 0.9],
    alpha=1,
    alpha_decay=0.9999,
    alpha_min=1e-4,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=forest_iter_callback,
    gamma=0.99,
)

importlib.reload(qlearning_plots)
# One curve per epsilon value (hue_col='Epsilon'): TD error first, then mean V.
qlearning_plots.create_stat_plot(df=df_ql_100_2_eps, folder_path=PATH_FIGURES, hue_col='Epsilon',
                                 file_name='ql_td_error_100_epsilon_g099.png')
qlearning_plots.create_stat_plot(df=df_ql_100_eps, folder_path=PATH_FIGURES, hue_col='Epsilon',
                                 file_name='ql_td_mean_v_100_epsilon_g099.png', y_axis='Mean V',
                                 title_additional='(Forest Management (size=100, g=0.99))')

Running for epsilon=0.001
Running for epsilon=0.05
Running for epsilon=0.1
Running for epsilon=0.3
Running for epsilon=0.5
Running for epsilon=0.8
Running for epsilon=0.9


In [33]:
# Q-learning over forest problems of size 5, 15, ..., 95 at gamma=0.99.
importlib.reload(qlearning_utils)
# Seed NumPy's global RNG, as the gamma-sweep cells do, so this stochastic sweep gives the same
# results on every re-run (assumes qlearning_utils draws from NumPy's global RNG, as that
# seeding implies).
np.random.seed(42)
# epsilon_decay=1 together with epsilon == epsilon_min keeps the exploration settings fixed at 0.5.
states_ql_df = qlearning_utils.q_learning_stats_gammas_problem_sizes(
    problem_function=problem_foo,
    state_sizes=(list(range(5, 100, 10))),
    epsilon=0.5,
    epsilon_decay=1,
    epsilon_min=0.5,
    alpha=1,
    alpha_decay=0.9999,
    alpha_min=1e-4,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=forest_iter_callback_good,
    gammas=[0.99],
)

Processing size 5
Running for gamma=0.99
Processing size 15
Running for gamma=0.99
Processing size 25
Running for gamma=0.99
Processing size 35
Running for gamma=0.99
Processing size 45
Running for gamma=0.99
Processing size 55
Running for gamma=0.99
Processing size 65
Running for gamma=0.99
Processing size 75
Running for gamma=0.99
Processing size 85
Running for gamma=0.99
Processing size 95
Running for gamma=0.99


In [34]:
# Pick up edits to qlearning_plots.py without restarting the kernel.
importlib.reload(qlearning_plots)
# Plot the 'Total Time' column against the 'States' column of the problem-size sweep,
# one curve per gamma value.
qlearning_plots.create_stat_plot(df=states_ql_df, folder_path=PATH_FIGURES, hue_col='Gamma', x_axis='States',
                                 y_axis='Total Time',
                                 file_name='ql_time_by_states.png')

In [35]:
## Bad alpha choice
# Same setup as the "good" gamma sweep, except for the learning-rate schedule:
# alpha_decay is 0.99 (instead of 0.9999) and alpha_min is 0.01 (instead of 1e-4).
# Seed NumPy's global RNG, as the other gamma-sweep cells do, so the comparison is reproducible
# (assumes qlearning_utils draws from NumPy's global RNG, as that seeding implies).
np.random.seed(42)
# The third return value (policies) is not needed for this comparison.
df_ql_100_bad_alpha, df_ql_100_2_bad_alpha, _ = qlearning_utils.q_learning_stats_gammas(
    p=p_100,
    r=r_100,
    epsilon=0.5,
    epsilon_decay=0.9999,
    epsilon_min=0.5,
    alpha=1,
    alpha_decay=0.99,
    alpha_min=0.01,
    n_iter=100000,
    episode_length=10,
    td_error_threshold=5 * 1e-5,
    overall_stat_freq=1000,
    iter_callback=forest_iter_callback_good,
)

Running for gamma=0.1
Running for gamma=0.3
Running for gamma=0.5
Running for gamma=0.7
Running for gamma=0.9
Running for gamma=0.95
Running for gamma=0.99


In [36]:
# Pick up edits to qlearning_plots.py without restarting the kernel.
importlib.reload(qlearning_plots)
# TD-error figure (log-scaled y axis) for the bad-alpha run.
qlearning_plots.create_stat_plot(df=df_ql_100_2_bad_alpha, folder_path=PATH_FIGURES,
                                 file_name='ql_td_error_100_bad_alpha.png',
                                 log_scale_y=True, title_additional='(Forest Management (size=100), bad alpha setup)')
# Mean-V figure for the bad-alpha run.
qlearning_plots.create_stat_plot(df=df_ql_100_bad_alpha, folder_path=PATH_FIGURES,
                                 file_name='ql_td_mean_v_100_bad_alpha.png',
                                 y_axis='Mean V', title_additional='(Forest Management (size=100), bad alpha setup)')