In [9]:
import os

import pandas as pd
import importlib
import warnings

import hiive.mdptoolbox
import hiive.mdptoolbox.mdp
import hiive.mdptoolbox.example

import utils
import plots
from policy_iteration_custom import PolicyIterationCustom

warnings.filterwarnings("ignore")

In [10]:
# Output locations, resolved relative to the notebook's working directory.
# os.path.join inserts the platform's path separator instead of a hard-coded '/'.
PATH_FIGURES = os.path.join(os.getcwd(), 'figures', 'forest')
PATH_ARTEFACTS = os.path.join(os.getcwd(), 'artefacts', 'forest')

# Make sure both folders exist so a fresh checkout can run top to bottom;
# exist_ok makes this a no-op on re-runs.
os.makedirs(PATH_FIGURES, exist_ok=True)
os.makedirs(PATH_ARTEFACTS, exist_ok=True)

# File names (inside PATH_ARTEFACTS) for the per-size run statistics.
VI_RUNS_CSV = 'vi_runs.csv'
PI_RUNS_CSV = 'pi_runs.csv'

# File names (inside PATH_ARTEFACTS) for the policies found by each algorithm.
VI_POLICIES_JSON = 'vi_policies.json'
PI_POLICIES_JSON = 'pi_policies.json'

# True: recompute every value/policy-iteration run and overwrite the artefacts.
# False: skip recomputation and load the artefacts written by a previous run.
regenerate_runs_mdp = True

In [11]:
# Apply the project's shared plotting setup (defined in plots.py) before any figure is drawn.
plots.setup_plots()

In [12]:
# Short labels for the forest-management actions, keyed by action index.
# Per the mdptoolbox forest() docs, action 0 is "Wait" and action 1 is "Cut".
ACTION_TO_LABEL = dict(enumerate(('W', 'C')))

In [13]:
# Size passed as `main_size` to the runs and to the convergence/state plots.
MAIN_SIZE = 30
# Forest sizes swept by the runs: 2, 3, ..., 74 (upper bound of range is exclusive).
STATES_SIZES = [*range(2, 75)]

# Re-execute utils.py so edits to it are picked up without restarting the kernel
# (`%autoreload 2` would do this automatically for every imported module).
utils = importlib.reload(utils)


def problem_foo(sz: int):
    """Build the forest-management example MDP with ``sz`` states.

    Used as the ``problem_function`` callback for
    ``utils.write_stats_for_problem_sizes``.

    :param sz: number of states of the forest problem.
    :return: whatever ``hiive.mdptoolbox.example.forest(sz)`` returns —
        per the mdptoolbox docs, the ``(P, R)`` transition/reward arrays.
    """
    example_module = hiive.mdptoolbox.example
    return example_module.forest(sz)


In [14]:
def _artefacts_missing() -> bool:
    """Return True if any VI/PI run-statistics or policy file is absent from PATH_ARTEFACTS."""
    expected_files = (VI_RUNS_CSV, VI_POLICIES_JSON, PI_RUNS_CSV, PI_POLICIES_JSON)
    return not all(
        os.path.isfile(os.path.join(PATH_ARTEFACTS, file_name))
        for file_name in expected_files
    )


def _run_and_save(label: str, algo_class, runs_csv: str, policies_json: str, **algo_kwargs):
    """Run ``algo_class`` over every size in STATES_SIZES, persist the results, and return them.

    :param label: human-readable algorithm name used in the progress messages.
    :param algo_class: solver class handed to ``utils.write_stats_for_problem_sizes``.
    :param runs_csv: file name (inside PATH_ARTEFACTS) for the run statistics.
    :param policies_json: file name (inside PATH_ARTEFACTS) for the policies.
    :param algo_kwargs: extra keyword arguments forwarded unchanged to
        ``utils.write_stats_for_problem_sizes`` (e.g. ``not_use_span_vi``,
        ``collect_changes``).
    :return: the ``(runs_df, policies)`` pair produced by
        ``utils.write_stats_for_problem_sizes``.
    """
    print(f'{label}: starting...')
    runs_df, policies = utils.write_stats_for_problem_sizes(
        algo_class=algo_class,
        main_size=MAIN_SIZE,
        state_sizes=STATES_SIZES,
        problem_function=problem_foo,
        **algo_kwargs,
    )
    print(f'{label}: finished runs!')
    utils.save_df_as_csv(runs_df, PATH_ARTEFACTS, runs_csv)
    utils.save_policies(policies, PATH_ARTEFACTS, policies_json)
    print(f'{label}: saved!')
    return runs_df, policies


# Recompute when explicitly requested, and also when any saved artefact is
# missing, so the load cell that follows never reads a file that was never written.
if regenerate_runs_mdp or _artefacts_missing():
    forest_df_vi, policies_vi = _run_and_save(
        'Value iteration',
        hiive.mdptoolbox.mdp.ValueIteration,
        VI_RUNS_CSV,
        VI_POLICIES_JSON,
        not_use_span_vi=True,
    )
    forest_df_pi, policies_pi = _run_and_save(
        'Policy iteration',
        PolicyIterationCustom,
        PI_RUNS_CSV,
        PI_POLICIES_JSON,
        not_use_span_vi=False,
        collect_changes=True,
    )

Value iteration: starting...
Value iteration: finished runs!
Value iteration: saved!
Policy iteration: starting...
Policy iteration: finished runs!
Policy iteration: saved!


In [15]:
# Read the run statistics and policies back from disk, so everything below works
# the same whether or not the regeneration cell above recomputed them.
forest_df_vi, forest_policies_vi = (
    utils.read_csv(PATH_ARTEFACTS, VI_RUNS_CSV),
    utils.read_policies(PATH_ARTEFACTS, VI_POLICIES_JSON),
)
forest_df_pi, forest_policies_pi = (
    utils.read_csv(PATH_ARTEFACTS, PI_RUNS_CSV),
    utils.read_policies(PATH_ARTEFACTS, PI_POLICIES_JSON),
)

In [16]:
plots = importlib.reload(plots)  # pick up edits to plots.py without a kernel restart

# Settings shared by the value-iteration and policy-iteration figure sets.
_shared_plot_kwargs = {'folder_path': PATH_FIGURES, 'marker_size': 5}
plots.create_convergence_and_state_plots(MAIN_SIZE, forest_df_vi, algo="vi", **_shared_plot_kwargs)
plots.create_convergence_and_state_plots(MAIN_SIZE, forest_df_pi, algo="pi", plot_changes=True, **_shared_plot_kwargs)