In [1]:
import os

import numpy as np
import pandas as pd
import importlib
import warnings

import hiive.mdptoolbox
import hiive.mdptoolbox.mdp
import hiive.mdptoolbox.example
from matplotlib import pyplot as plt

import utils
import plots
from policy_iteration_custom import PolicyIterationCustom

# NOTE(review): this silences every warning category for the rest of the session,
# including deprecation and numerical warnings - confirm the blanket filter is intended.
warnings.filterwarnings("ignore")



In [2]:
# Output directories, resolved against the directory the notebook is launched from.
PATH_FIGURES = '/'.join((os.getcwd(), 'figures/forest'))
PATH_ARTEFACTS = '/'.join((os.getcwd(), 'artefacts/forest'))

# File names used under PATH_ARTEFACTS for the per-run statistics (CSV)
# and the resulting policies (JSON) of each solver.
VI_RUNS_CSV, PI_RUNS_CSV = 'vi_runs.csv', 'pi_runs.csv'
VI_POLICIES_JSON, PI_POLICIES_JSON = 'vi_policies.json', 'pi_policies.json'

# When True the solver sweeps are executed again and the stored artefacts are rewritten;
# when False only the artefacts already on disk are used.
regenerate_runs_mdp = True

In [3]:
# Run the one-off plotting setup provided by the local `plots` module
# (presumably global matplotlib styling - see plots.setup_plots to confirm).
plots.setup_plots()

In [4]:
# Text written inside each policy cell, keyed by action index.
# The None key is used for padding cells that do not correspond to any state.
# (Labels presumably stand for Wait / Cut, the two forest-management actions - TODO confirm.)
ACTION_TO_LABEL = {0: 'W', 1: 'C', None: ''}

# Matplotlib colour code used to fill each policy cell, keyed the same way.
ACTION_TO_COLOR = {0: 'g', 1: 'b', None: 'b'}

In [5]:
# Upper bound of the problem sizes swept below; also handed to the run and plot
# helpers as their "main" size.
MAIN_SIZE = 100
# Every problem size from 2 up to and including MAIN_SIZE.
STATES_SIZES = [*range(2, MAIN_SIZE + 1)]

# Re-import the local `utils` module so that edits made to it on disk are picked up
# without restarting the kernel.
importlib.reload(utils)


def problem_foo(sz: int):
    """Build the forest-management example problem of size ``sz``.

    Thin wrapper around ``hiive.mdptoolbox.example.forest``: only the size is
    supplied, every other argument is left at the library default, and the
    library's return value is handed back unchanged.
    """
    forest_problem = hiive.mdptoolbox.example.forest(sz)
    return forest_problem


In [6]:
if regenerate_runs_mdp:
    # Keyword arguments that are identical for both solver sweeps.
    sweep_kwargs = dict(
        main_size=MAIN_SIZE,
        state_sizes=STATES_SIZES,
        problem_function=problem_foo,
    )

    # --- Value iteration: run the sweep, then persist statistics and policies ---
    print('Value iteration: starting...')
    forest_df_vi, policies_vi = utils.write_stats_for_problem_sizes(
        algo_class=hiive.mdptoolbox.mdp.ValueIteration,
        not_use_span_vi=True,
        **sweep_kwargs,
    )
    print('Value iteration: finished runs!')
    utils.save_df_as_csv(forest_df_vi, PATH_ARTEFACTS, VI_RUNS_CSV)
    utils.save_policies(policies_vi, PATH_ARTEFACTS, VI_POLICIES_JSON)
    print('Value iteration: saved!')

    # --- Policy iteration: same sweep with the custom solver, which also records changes ---
    print('Policy iteration: starting...')
    forest_df_pi, policies_pi = utils.write_stats_for_problem_sizes(
        algo_class=PolicyIterationCustom,
        not_use_span_vi=False,
        collect_changes=True,
        **sweep_kwargs,
    )
    print('Policy iteration: finished runs!')
    utils.save_df_as_csv(forest_df_pi, PATH_ARTEFACTS, PI_RUNS_CSV)
    utils.save_policies(policies_pi, PATH_ARTEFACTS, PI_POLICIES_JSON)
    print('Policy iteration: saved!')

Value iteration: starting...
Processing size 2
Processing size 3
Processing size 4
Processing size 5
Processing size 6
Processing size 7
Processing size 8
Processing size 9
Processing size 10
Processing size 11
Processing size 12
Processing size 13
Processing size 14
Processing size 15
Processing size 16
Processing size 17
Processing size 18
Processing size 19
Processing size 20
Processing size 21
Processing size 22
Processing size 23
Processing size 24
Processing size 25
Processing size 26
Processing size 27
Processing size 28
Processing size 29
Processing size 30
Processing size 31
Processing size 32
Processing size 33
Processing size 34
Processing size 35
Processing size 36
Processing size 37
Processing size 38
Processing size 39
Processing size 40
Processing size 41
Processing size 42
Processing size 43
Processing size 44
Processing size 45
Processing size 46
Processing size 47
Processing size 48
Processing size 49
Processing size 50
Processing size 51
Processing size 52
Processing

In [7]:
# Load the stored run statistics and policies back from PATH_ARTEFACTS, so the cells
# below work from the files on disk whether or not the runs were regenerated above.
# Value iteration artefacts:
forest_df_vi = utils.read_csv(PATH_ARTEFACTS, VI_RUNS_CSV)
forest_policies_vi = utils.read_policies(PATH_ARTEFACTS, VI_POLICIES_JSON)
# Policy iteration artefacts:
forest_df_pi = utils.read_csv(PATH_ARTEFACTS, PI_RUNS_CSV)
forest_policies_pi = utils.read_policies(PATH_ARTEFACTS, PI_POLICIES_JSON)

In [8]:
# Re-import the local `plots` module so edits to it are picked up without a kernel restart.
importlib.reload(plots)
# Convergence / state plots for value iteration, written under PATH_FIGURES.
plots.create_convergence_and_state_plots(MAIN_SIZE, forest_df_vi, folder_path=PATH_FIGURES, algo="vi", marker_size=5)
# Same plots for policy iteration; plot_changes=True is only requested for this solver
# (its runs were collected with collect_changes=True).
plots.create_convergence_and_state_plots(MAIN_SIZE, forest_df_pi, folder_path=PATH_FIGURES, algo="pi", marker_size=5,
                                         plot_changes=True)

In [9]:
def plot_policy(policy, file_name: str, title: str, folder=PATH_FIGURES, color_mapping=ACTION_TO_COLOR,
                name_map=ACTION_TO_LABEL,
                num_columns=10):
    """Draw a policy as a grid of coloured, labelled cells and save it as an image.

    The entry at index ``k`` of ``policy`` is drawn in row ``k // num_columns``
    and column ``k % num_columns``; row 0 is at the bottom of the figure. Cells
    left over in the last row are padding: they hold ``None`` and are styled
    through the ``None`` entries of the two mappings.

    Args:
        policy: Iterable of actions, one per state. Any finite iterable is
            accepted (list, tuple, numpy array, ...); it is not modified.
        file_name: Name of the image file created inside ``folder``.
        title: Title shown above the grid.
        folder: Output directory; created if it does not exist.
        color_mapping: Maps an action to a matplotlib face colour. Actions that
            are missing from the mapping are drawn with ``'b'``.
        name_map: Maps an action to the text written in its cell. Actions that
            are missing from the mapping get an empty label.
        num_columns: Number of cells per row; must be at least 1.

    Raises:
        ValueError: If ``num_columns`` is smaller than 1.
    """
    if num_columns < 1:
        raise ValueError(f'num_columns must be at least 1, got {num_columns}')

    # Work on a private list so that tuples and numpy arrays can be padded as well.
    actions = list(policy)

    # Ceiling division: number of rows needed to hold every action.
    num_rows = -(-len(actions) // num_columns)

    # Pad the last row with None and lay the actions out row by row. dtype=object keeps
    # every action exactly as given, so the dictionary lookups below see the original
    # values instead of values numpy coerced to a common type.
    padding = [None] * (num_rows * num_columns - len(actions))
    grid = np.array(actions + padding, dtype=object).reshape(num_rows, num_columns)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, xlim=(-.01, num_columns + 0.01), ylim=(-.01, num_rows + 0.01))

        for i in range(num_rows):
            for j in range(num_columns):
                action = grid[i, j]
                # One unit square per state, with its lower-left corner at (column, row).
                cell = plt.Rectangle([j, i], 1, 1, linewidth=1, edgecolor='k')
                cell.set_facecolor(color_mapping.get(action, 'b'))
                ax.add_patch(cell)
                ax.text(j + 0.5, i + 0.5, name_map.get(action, ''), ha='center', va='center', size=9, color='w')

        ax.set_xticks(np.arange(0, num_columns, 1))
        ax.set_yticks(np.arange(0, num_rows, 1))
        ax.set_xticklabels(np.arange(0, num_columns, 1))
        # Label each row with the index of the first state it contains.
        ax.set_yticklabels(np.arange(0, num_rows * num_columns, num_columns))
        ax.set_title(title)

        os.makedirs(folder, exist_ok=True)
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(os.path.join(folder, file_name))
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

In [10]:
# One value-iteration policy figure per discount factor.
# Each entry: (gamma used as lookup key, tag used in the file name, gamma as printed in the title).
for gamma, tag, shown in ((0.99, '099', '0.99'), (0.5, '05', '0.50'), (0.1, '01', '0.1')):
    plot_policy(
        list(forest_policies_vi[str(gamma)]),
        file_name=f'vi_policy_g_{tag}.png',
        title=f'Forest Management, VI Policy (gamma={shown})',
    )

In [11]:
# One policy-iteration policy figure per discount factor.
# Each entry: (gamma used as lookup key, tag used in the file name, gamma as printed in the title).
for gamma, tag, shown in ((0.99, '099', '0.99'), (0.5, '05', '0.50'), (0.1, '01', '0.1')):
    plot_policy(
        list(forest_policies_pi[str(gamma)]),
        file_name=f'pi_policy_g_{tag}.png',
        title=f'Forest Management, PI Policy (gamma={shown})',
    )