In [1]:
import os

import gym
import numpy as np
import pandas as pd
import importlib
import warnings

import hiive.mdptoolbox
import hiive.mdptoolbox.example
from matplotlib import pyplot as plt

from policy_iteration_custom import PolicyIterationCustom

import utils
import plots
from taxi_custom import TaxiCustomEnv

# Suppress every warning for the rest of the kernel session (presumably to keep
# gym / mdptoolbox deprecation noise out of the notebook output).
# NOTE(review): this also hides genuine warnings; consider narrowing it by
# category or module.
warnings.filterwarnings("ignore")



In [2]:
# Output directories, resolved against the notebook's working directory.
# os.path.join picks the platform's separator instead of hard-coding '/'.
PATH_FIGURES = os.path.join(os.getcwd(), 'figures', 'taxi')
PATH_ARTEFACTS = os.path.join(os.getcwd(), 'artefacts', 'taxi')

# Artefact file names, stored inside PATH_ARTEFACTS.
VI_RUNS_CSV = 'vi_runs.csv'
PI_RUNS_CSV = 'pi_runs.csv'

VI_POLICIES_JSON = 'vi_policies.json'
PI_POLICIES_JSON = 'pi_policies.json'

# Set to True to rerun the MDP experiments and rewrite the saved artefacts;
# when False, the artefacts already on disk are loaded as-is.
regenerate_runs_mdp = False

In [3]:
# Expose the custom taxi environment to gym.make under the id 'TaxiCustom'.
gym.envs.register(id='TaxiCustom', entry_point=TaxiCustomEnv)

In [4]:
# Project-wide plot setup from plots.py (presumably matplotlib style / rcParams;
# see plots.setup_plots for what it actually configures).
plots.setup_plots()

In [5]:
# Action id -> glyph drawn in each policy-plot cell.
# Assumes TaxiCustomEnv keeps gym Taxi-v3's action encoding
# (0=south, 1=north, 2=east, 3=west, 4=pickup, 5=drop-off) -- TODO confirm
# against taxi_custom.py. Under that encoding action 2 moves right and
# action 3 moves left, so their arrows are '→' and '←' respectively.
ACTION_TO_LABEL = {
    0: '↓',
    1: '↑',
    2: '→',
    3: '←',
    4: 'P',
    5: 'D'
}
# Action id -> cell face colour in the policy plots (keyed by action id,
# independent of the arrow glyphs above).
ACTION_TO_COLOR = {
    0: '#876161',
    1: '#d7e4f5',
    2: '#91ccff',
    3: '#f27272',
    4: '#02bf1e',
    5: '#ad031a'
}

In [6]:
MAIN_SIZE = 10

# Grid side length -> number of MDP states for that grid, for sides 3..10.
# Each of the side * side taxi cells is multiplied by 5 * 4 (presumably
# passenger locations x destinations, as in gym Taxi -- confirm in taxi_custom.py).
MAP_TO_ACTUAL_STATE_SIZE = {side: side * side * 5 * 4 for side in range(3, 11)}

# Re-import utils so edits to utils.py take effect without restarting the kernel.
importlib.reload(utils)


def problem_foo(sz: int):
    """Build the MDP inputs for a TaxiCustom grid whose side length is ``sz``.

    A fresh ``TaxiCustom`` environment is created, reset with seed 42, and
    handed to ``utils.convert_p_r_v2``. The result is presumably the
    transition / reward matrices consumed by the mdptoolbox solvers; verify
    against utils.

    :param sz: grid side length passed to the environment as ``size``.
    :return: the two values produced by ``utils.convert_p_r_v2``, as a tuple.
    """
    env_instance = gym.make('TaxiCustom', size=sz)
    env_instance.reset(seed=42)
    transitions, rewards = utils.convert_p_r_v2(env_instance)
    return transitions, rewards


In [7]:
# from matplotlib import pyplot as plt
#
# env = gym.make('TaxiCustom', size=MAIN_SIZE, render_mode='rgb_array')
# env.reset()
# plt.imshow(env.render())


In [8]:
importlib.reload(utils)

# Opt-in switch for the value-iteration sweep. Kept separate because only the
# policy-iteration runs are rebuilt by default; the VI artefacts loaded in the
# next cell come from an earlier run unless this is set to True.
regenerate_vi_runs = False

if regenerate_runs_mdp:
    # Every grid side length we have a state count for, smallest first.
    grid_size = sorted(MAP_TO_ACTUAL_STATE_SIZE)

    if regenerate_vi_runs:
        print('Value iteration: starting...')
        taxi_df_vi, policies_vi = utils.write_stats_for_problem_sizes(
            algo_class=hiive.mdptoolbox.mdp.ValueIteration,
            main_size=MAIN_SIZE,
            state_sizes=grid_size,
            problem_function=problem_foo,
            not_use_span_vi=True,
        )
        print('Value iteration: finished runs!')
        utils.save_df_as_csv(taxi_df_vi, PATH_ARTEFACTS, VI_RUNS_CSV)
        utils.save_policies(policies_vi, PATH_ARTEFACTS, VI_POLICIES_JSON)
        print('Value iteration: saved!')

    print('Policy iteration: starting...')
    taxi_df_pi, policies_pi = utils.write_stats_for_problem_sizes(
        algo_class=PolicyIterationCustom,
        main_size=MAIN_SIZE,
        state_sizes=grid_size,
        problem_function=problem_foo,
        not_use_span_vi=False,
        collect_changes=True,
        max_iter=1e3,
    )
    print('Policy iteration: finished runs!')
    utils.save_df_as_csv(taxi_df_pi, PATH_ARTEFACTS, PI_RUNS_CSV)
    utils.save_policies(policies_pi, PATH_ARTEFACTS, PI_POLICIES_JSON)
    print('Policy iteration: saved!')

Policy iteration: starting...
Processing size 3
Processing size 4
Processing size 5
Processing size 6
Processing size 7
Processing size 8
Processing size 9
Processing size 10
Policy iteration: finished runs!
Policy iteration: saved!


In [9]:
# Load the saved runs and policies for value iteration, then policy iteration.
taxi_df_vi, taxi_policies_vi = (
    utils.read_csv(PATH_ARTEFACTS, VI_RUNS_CSV),
    utils.read_policies(PATH_ARTEFACTS, VI_POLICIES_JSON),
)
taxi_df_pi, taxi_policies_pi = (
    utils.read_csv(PATH_ARTEFACTS, PI_RUNS_CSV),
    utils.read_policies(PATH_ARTEFACTS, PI_POLICIES_JSON),
)

In [10]:
importlib.reload(plots)

# Convergence / problem-size figures for value iteration ...
plots.create_convergence_and_state_plots(
    MAIN_SIZE,
    taxi_df_vi,
    algo="vi",
    folder_path=PATH_FIGURES,
    actual_sizes=MAP_TO_ACTUAL_STATE_SIZE,
    marker_size=5,
)
# ... and for policy iteration, which additionally plots the recorded changes.
plots.create_convergence_and_state_plots(
    MAIN_SIZE,
    taxi_df_pi,
    algo="pi",
    folder_path=PATH_FIGURES,
    actual_sizes=MAP_TO_ACTUAL_STATE_SIZE,
    marker_size=5,
    plot_changes=True,
)

In [11]:
def plot_policy(policy, file_name: str, title: str, folder=PATH_FIGURES, color_mapping=ACTION_TO_COLOR,
                name_map=ACTION_TO_LABEL,
                num_columns=5 * 4, y_tick_step: int = 10):
    """Render a policy as a grid of coloured, labelled cells and save it as an image.

    State ``s`` is drawn at row ``s // num_columns`` and column ``s % num_columns``;
    row 0 sits at the bottom of the figure.

    :param policy: sequence of actions indexed by state. Any iterable works; it is
        materialised with ``list``.
    :param file_name: name of the image file written inside ``folder``.
    :param title: figure title.
    :param folder: output directory, created if it does not exist.
    :param color_mapping: action -> face colour; actions missing from it are drawn ``'b'``.
    :param name_map: action -> text drawn in the cell; actions missing from it get no text.
    :param num_columns: cells per row. The default 5 * 4 presumably puts one taxi
        position per row (5 passenger locations x 4 destinations) -- confirm against
        TaxiCustomEnv's state encoding.
    :param y_tick_step: spacing, in rows, between y-axis tick labels.
    """
    actions = list(policy)
    num_rows = -(-len(actions) // num_columns)  # ceiling division

    fig = plt.figure()
    ax = fig.add_subplot(111, xlim=(-.01, num_columns + 0.01), ylim=(-.01, num_rows + 0.01))

    # Only real states are drawn: cells past the last state in a partially
    # filled final row stay empty instead of being painted as filler.
    for state, action in enumerate(actions):
        row, col = divmod(state, num_columns)
        cell = plt.Rectangle((col, row), 1, 1, linewidth=1, edgecolor='k',
                             facecolor=color_mapping.get(action, 'b'))
        ax.add_patch(cell)
        ax.text(col + 0.5, row + 0.5, name_map.get(action, ''), ha='center', va='center', size=2, color='w')

    x_ticks = np.arange(0, num_columns, 1)
    y_ticks = np.arange(0, num_rows, y_tick_step)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_xticklabels(x_ticks)
    ax.set_yticklabels(y_ticks)
    ax.set_title(title)

    os.makedirs(folder, exist_ok=True)
    fig.savefig(os.path.join(folder, file_name))
    # Close exactly the figure created here so repeated calls don't leak open figures.
    plt.close(fig)

In [12]:
# One VI policy figure per discount factor stored in the artefacts.
# Every title formats gamma with two decimals so the figures are labelled consistently.
for policy_gamma, file_tag in ((0.99, '099'), (0.5, '05'), (0.1, '01')):
    plot_policy(list(taxi_policies_vi[str(policy_gamma)]), file_name=f'vi_policy_g_{file_tag}.png',
                title=f'Taxi, VI Policy (gamma={policy_gamma:.2f})')

In [13]:
# One PI policy figure per discount factor stored in the artefacts.
# Every title formats gamma with two decimals so the figures are labelled consistently.
for policy_gamma, file_tag in ((0.99, '099'), (0.5, '05'), (0.1, '01')):
    plot_policy(list(taxi_policies_pi[str(policy_gamma)]), file_name=f'pi_policy_g_{file_tag}.png',
                title=f'Taxi, PI Policy (gamma={policy_gamma:.2f})')