In [39]:
from pathlib import Path
import pickle
import argparse
from omegaconf import OmegaConf, DictConfig

import numpy as np
import xarray as xr

import plotly.express as px
import plotly.graph_objects as go

from components import reward_tensor, prob_tensor, value_iteration, get_optimal_policy

In [40]:
# Read the problem definition from the YAML file that sits next to the notebook.
problem_config_path = Path("problem_config.yaml")
with problem_config_path.open() as config_file:
    config: DictConfig = OmegaConf.load(config_file)

# Tolerance handed to value_iteration in a later cell.
tol = 1e-5

In [41]:
# Read the reward tensor back from reward.pkl.
# NOTE(review): pickle.load can run arbitrary code while deserialising; only
# load pickle files that were produced by this project.
with Path("reward.pkl").open("rb") as reward_file:
    reward_t: xr.DataArray = pickle.load(reward_file)

In [82]:
# Exploratory spot check: select the reward entries labelled (20, 20, 1) on the
# first three dimensions and display the largest value among them. The result
# is not assigned, so none of the visible cells depend on it.
# NOTE(review): the meaning of the three labels is not visible here --
# presumably (num_cars0, num_cars1, cars moved); confirm against the code that
# wrote reward.pkl.
reward_t.loc[20, 20, 1].max()

In [42]:
# Read the probability tensor back from prob.pkl.
# NOTE(review): pickle.load can run arbitrary code while deserialising; only
# load pickle files that were produced by this project.
with Path("prob.pkl").open("rb") as prob_file:
    prob_t: xr.DataArray = pickle.load(prob_file)

In [49]:
# Initial value function: an all-zero square array whose side length is
# max_cars_location + 1 (one index per car count from 0 up to the maximum).
v0 = np.zeros((config.max_cars_location + 1,) * 2)

In [51]:
# Run value iteration on the raw arrays behind the reward and probability
# DataArrays, starting from v0 and using the discount factor from the config.
# NOTE(review): tol is presumably the convergence threshold and max_steps the
# iteration cap -- confirm against components.value_iteration.
v = value_iteration(
    reward_t.values,
    prob_t.values,
    config.gamma,
    v0,
    tol,
    max_steps=500,
)

Output()

In [52]:
# Derive the policy from the value function v computed above, using the same
# reward/probability arrays and discount factor that were given to
# value_iteration.
policy = get_optimal_policy(
    reward_t.values,
    prob_t.values,
    config.gamma,
    v,
)

In [53]:
# Folder that receives the HTML figures exported by the plotting cells
# (value.html and policy.html).
img_folder = Path("assets")
# exist_ok=True keeps this cell re-runnable once the folder already exists.
img_folder.mkdir(exist_ok=True)

In [54]:
# Contour plot of the value function v over the car counts 0..max_cars_location
# on both axes; the figure is exported to HTML and then shown inline.
# NOTE(review): go.Contour draws z[i][j] at (x[j], y[i]), i.e. rows of v run
# along the y axis. Confirm that the first axis of v is num_cars1; otherwise
# the two axis titles are swapped.
value_fig = go.Figure(
    data=go.Contour(
        z=v,
        x=np.arange(config.max_cars_location + 1),
        y=np.arange(config.max_cars_location + 1),
        colorbar={"title": {"text": "Value function", "side": "right"}},
    ),
    layout=dict(
        xaxis_title="num_cars0",
        yaxis_title="num_cars1",
    ),
)
value_fig.write_html(img_folder / "value.html")
value_fig.show()

In [57]:
# Heatmap of the policy, displayed inline only (the figure is neither stored
# nor written to disk). origin="lower" puts row 0 at the bottom.
# Add text_auto=True to print the value inside every cell.
# NOTE(review): rows of `policy` run along the y axis; confirm that the first
# axis of `policy` is num_cars1, otherwise the axis labels are swapped.
px.imshow(
    policy,
    labels={"x": "num_cars0", "y": "num_cars1", "color": "Cars moved"},
    x=np.arange(config.max_cars_location + 1),
    y=np.arange(config.max_cars_location + 1),
    origin="lower",
)

In [55]:
# Contour plot of the policy with one contour level per unit step between
# -max_cars_moved and +max_cars_moved; exported to HTML and then shown inline.
# NOTE(review): go.Contour draws z[i][j] at (x[j], y[i]), i.e. rows of `policy`
# run along the y axis. Confirm that the first axis of `policy` is num_cars1;
# otherwise the two axis titles are swapped.
policy_fig = go.Figure(
    data=go.Contour(
        z=policy,
        contours={
            "start": -config.max_cars_moved,
            "end": config.max_cars_moved,
            "size": 1,
        },
        x=np.arange(config.max_cars_location + 1),
        y=np.arange(config.max_cars_location + 1),
        colorbar={"title": {"text": "Cars moved", "side": "right"}},
    ),
    layout=dict(
        xaxis_title="num_cars0",
        yaxis_title="num_cars1",
    ),
)
policy_fig.write_html(img_folder / "policy.html")
policy_fig.show()