In [1]:
import os

import gym
import numpy as np
import pandas as pd
import importlib
import warnings

import hiive.mdptoolbox
import hiive.mdptoolbox.example
from frozen_lake_custom import generate_random_map, FrozenLakeCustomEnv
from policy_iteration_custom import PolicyIterationCustom

import utils
import plots

warnings.filterwarnings("ignore")



In [2]:
# Output locations, resolved against the directory the kernel was started in.
# os.path.join is used instead of hand-written '/' concatenation so the
# separator always matches the host platform.
PATH_FIGURES = os.path.join(os.getcwd(), 'figures', 'frozen_lake')
PATH_ARTEFACTS = os.path.join(os.getcwd(), 'artefacts', 'frozen_lake')

# File names (inside PATH_ARTEFACTS) for the per-run statistics tables.
VI_RUNS_CSV = 'vi_runs.csv'
PI_RUNS_CSV = 'pi_runs.csv'

# File names (inside PATH_ARTEFACTS) for the saved policies.
VI_POLICIES_JSON = 'vi_policies.json'
PI_POLICIES_JSON = 'pi_policies.json'

# When True, the solver cell re-runs value/policy iteration and overwrites the
# artefact files above; when False that cell does nothing and the later cells
# work from whatever is already on disk.
regenerate_runs_mdp = True

In [3]:
# Make the custom environment class available to gym.make() under the id
# 'FrozenLakeCustom', which problem_foo uses later in this notebook.
# NOTE(review): re-running this cell registers the same id again; how gym reacts
# (warning vs. error) depends on the installed version — confirm.
gym.envs.register(id='FrozenLakeCustom', entry_point=FrozenLakeCustomEnv)

In [4]:
# Run the project's plotting setup from plots.py before any figure is drawn
# (presumably global matplotlib styling — confirm in plots.setup_plots).
plots.setup_plots()

In [5]:
# Arrow glyph for each integer action id; the tuple position is the action id
# (0 = left, 1 = down, 2 = right, 3 = up).
ACTION_TO_LABEL = dict(enumerate(('←', '↓', '→', '↑')))

In [6]:
# Size passed as `main_size` to the solver runs and used for the VI plots.
MAIN_SIZE = 7
# Sizes passed as `state_sizes` to the solver runs; currently the single value 20.
# Widen the range bounds to sweep more sizes.
STATES_SIZES = [*range(20, 21)]

# Re-execute utils.py so that edits made since it was first imported take effect
# without restarting the kernel.
importlib.reload(utils)


def problem_foo(sz: int):
    """Build the frozen-lake problem for a map of size ``sz``.

    Re-seeds NumPy's global RNG with 1220 on every call, generates a random
    map with ``generate_random_map(size=sz, p=0.9, seed=122)``, instantiates
    the registered 'FrozenLakeCustom' environment on that map, and converts
    the environment with ``utils.convert_p_r``.

    Args:
        sz: Map size forwarded to ``generate_random_map``.

    Returns:
        The two values produced by ``utils.convert_p_r`` as a ``(p, r)`` tuple
        (presumably the transition and reward arrays expected by mdptoolbox —
        confirm in utils.convert_p_r).
    """
    np.random.seed(1220)
    lake_layout = generate_random_map(size=sz, p=0.9, seed=122)
    lake_env = gym.make('FrozenLakeCustom', desc=lake_layout)
    transitions, rewards = utils.convert_p_r(lake_env)
    return transitions, rewards


In [7]:
# Preview a size-8 map generated with the same p and seed that problem_foo uses.
# NOTE(review): size=8 matches neither MAIN_SIZE (7) nor STATES_SIZES ([20]) —
# confirm this is the map that should be shown here.
generate_random_map(size=8, p=0.9, seed=122)

['SFFFFFFF',
 'FFFFFFFF',
 'FFFFFFFH',
 'FFFFFFFF',
 'FFFFHFFF',
 'FFFFFHFF',
 'FHFHFFFF',
 'HFFFFFFG']

In [9]:
# Re-execute utils.py so recent edits to it are used by the runs below.
importlib.reload(utils)

if regenerate_runs_mdp:
    # Keyword arguments that the value-iteration and policy-iteration runs share.
    shared_run_kwargs = dict(
        main_size=MAIN_SIZE,
        state_sizes=STATES_SIZES,
        problem_function=problem_foo,
        not_use_span_vi=False,
    )

    # ---- Value iteration: run, then persist statistics and policies ----
    print('Value iteration: starting...')
    frozen_lake_df_vi, policies_vi = utils.write_stats_for_problem_sizes(
        algo_class=hiive.mdptoolbox.mdp.ValueIteration,
        **shared_run_kwargs,
    )
    print('Value iteration: finished runs!')
    utils.save_df_as_csv(frozen_lake_df_vi, PATH_ARTEFACTS, VI_RUNS_CSV)
    utils.save_policies(policies_vi, PATH_ARTEFACTS, VI_POLICIES_JSON)
    print('Value iteration: saved!')

    # ---- Policy iteration: run, then persist statistics and policies ----
    print('Policy iteration: starting...')
    # Starting policy for each size: a flat list of size*size actions, ordered
    # row by row, that uses action 1 where (row + col) is even and action 2
    # where it is odd.
    checkerboard_policies = {
        side: [
            1 if (row + col) % 2 == 0 else 2
            for row in range(side)
            for col in range(side)
        ]
        for side in STATES_SIZES
    }
    frozen_lake_df_pi, policies_pi = utils.write_stats_for_problem_sizes(
        algo_class=PolicyIterationCustom,
        **shared_run_kwargs,
        collect_changes=True,
        policies0=checkerboard_policies,
        max_iter=1e3,
    )
    print('Policy iteration: finished runs!')
    utils.save_df_as_csv(frozen_lake_df_pi, PATH_ARTEFACTS, PI_RUNS_CSV)
    utils.save_policies(policies_pi, PATH_ARTEFACTS, PI_POLICIES_JSON)
    print('Policy iteration: saved!')

Value iteration: starting...
Value iteration: finished runs!
Value iteration: saved!
Policy iteration: starting...
Policy iteration: finished runs!
Policy iteration: saved!


In [10]:
# Load the run statistics and policies back from PATH_ARTEFACTS, so the cells
# below work from the saved files whether or not the solvers were re-run in
# this session.
frozen_lake_df_vi, frozen_lake_df_pi = (
    utils.read_csv(PATH_ARTEFACTS, runs_file)
    for runs_file in (VI_RUNS_CSV, PI_RUNS_CSV)
)
frozen_lake_policies_vi, frozen_lake_policies_pi = (
    utils.read_policies(PATH_ARTEFACTS, policies_file)
    for policies_file in (VI_POLICIES_JSON, PI_POLICIES_JSON)
)

In [13]:
# Re-execute plots.py so recent edits to it are used for the figures below.
importlib.reload(plots)

# Keyword arguments shared by the VI and PI figures.
common_plot_kwargs = dict(folder_path=PATH_FIGURES, marker_size=5)

plots.create_convergence_and_state_plots(
    MAIN_SIZE,
    frozen_lake_df_vi,
    algo="vi",
    **common_plot_kwargs,
)
# NOTE(review): the PI figures use the literal size 20 while the VI figures use
# MAIN_SIZE (7) — confirm the difference is intentional.
plots.create_convergence_and_state_plots(
    20,
    frozen_lake_df_pi,
    algo="pi",
    plot_changes=True,
    log_scale_y_one=True,
    log_scale_x_one=True,
    **common_plot_kwargs,
)

No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
