In [1]:
from IPython.display import display
import ipywidgets as widgets
from mdp import *

In [2]:
from matplotlib import pyplot as plt
# Notebook-wide matplotlib defaults: figure size in inches and rendering dpi.
plt.rcParams.update({
    'figure.figsize': (6, 4),
    'figure.dpi': 120,
})

In [3]:

# Parameter sweep used by the value-iteration experiment below.
iterations = range(25)  # one entry per sweep of value iteration
discounts = [0, 0.9, 0.99, 1]  # discount factors (gamma) to compare
step_penalties = [-0.01, -0.04, -0.4, -4]  # reward assigned to each ordinary cell

In [4]:
def value_iteration_instru():
    """Run instrumented value iteration for every (penalty, discount) pair.

    For each value in ``step_penalties`` a ``GridMDP`` is built from a
    3-row by 4-column reward layout in which ordinary cells carry the
    penalty, one cell is ``None`` and two cells hold ``+1`` and ``-1``;
    ``(3, 2)`` and ``(3, 1)`` are passed as terminals. For each value in
    ``discounts`` the utilities start at 0 and are updated once per element
    of ``iterations``.

    Returns:
        dict: ``results[penalty][discount]`` is a list with one utility
        dict per iteration. Each entry is the snapshot taken *before* that
        iteration's update, so the first entry is all zeros.
    """
    results = {}
    for pen in step_penalties:
        layout = [[pen, pen, pen, +1],
                  [pen, None, pen, -1],
                  [pen, pen, pen, pen]]
        mdp = GridMDP(layout, terminals=[(3, 2), (3, 1)])
        reward, transition = mdp.R, mdp.T

        histories = {}
        for gamma in discounts:
            history = []
            current = {state: 0 for state in mdp.states}
            for _ in iterations:
                # Freeze the utilities from the previous sweep; every update
                # in this sweep reads from the frozen copy only.
                previous = current.copy()
                history.append(previous)
                for state in mdp.states:
                    best = max(
                        sum(prob * previous[nxt]
                            for (prob, nxt) in transition(state, action))
                        for action in mdp.actions(state)
                    )
                    current[state] = reward(state) + gamma * best
            histories[gamma] = history
        results[pen] = histories
    return results
# Compute all utility histories once, indexed as vs[penalty][discount][iteration].
vs = value_iteration_instru()

In [5]:


def make_visualize(slider):
    """Take a slider as input and return a callback for timer and animation.

    Args:
        slider: object exposing integer ``min`` and ``max`` attributes and a
            writable ``value`` attribute (e.g. an ipywidgets ``IntSlider``).

    Returns:
        A function ``visualize_callback(visualize, time_step)``. When
        ``visualize`` is exactly ``True`` it steps ``slider.value`` from
        ``slider.min`` through ``slider.max`` inclusive, sleeping
        ``float(time_step)`` seconds after each step. Otherwise it does
        nothing.
    """
    # Imported locally: the notebook never imports ``time`` itself, so the
    # callback would otherwise depend on a star import leaking the name.
    import time

    def visualize_callback(visualize, time_step):
        if visualize is True:
            for i in range(slider.min, slider.max + 1):
                slider.value = i
                time.sleep(float(time_step))

    return visualize_callback



def make_plot_grid_step_function(columns, rows, U_over_time):
    """Build the plotting callback used with ``ipywidgets.interactive``.

    Args:
        columns: number of grid columns (x coordinates ``0..columns-1``).
        rows: number of grid rows (y coordinates ``0..rows-1``).
        U_over_time: nested mapping indexed as
            ``U_over_time[penalty][discount][iteration]``, yielding a dict
            keyed by ``(column, row)`` tuples.

    Returns:
        A function ``plot_grid_step(iteration, discount, penalty)`` that
        draws the selected utilities as an annotated heat map.
    """

    def plot_grid_step(iteration, discount, penalty):
        utilities = U_over_time[penalty][discount][iteration]
        # Cells with no entry in the utilities dict are drawn as 0. Rows are
        # emitted highest-y first so that y=0 ends up at the bottom of the
        # image ("output like book").
        grid = [[utilities.get((x, y), 0) for x in range(columns)]
                for y in reversed(range(rows))]

        image = plt.imshow(grid, cmap='summer', interpolation='nearest')
        ax = image.axes

        plt.axis('off')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # Annotate each cell with its value; text x is the column index and
        # text y is the row index within the image.
        for y, row_values in enumerate(grid):
            for x, value in enumerate(row_values):
                ax.text(x, y, "{0:.2f}".format(value),
                        va='center', ha='center', fontsize=20)

        image.figure.set_size_inches(8, 6)
        plt.show()
        # NOTE(review): kept from the original; presumably re-applies the
        # notebook dpi in case the backend resets rcParams on show - confirm.
        plt.rcParams['figure.dpi'] = 120

    return plot_grid_step

# Bind the precomputed utility histories to a plotting callback for a
# 4-column by 3-row grid.
plot_grid_step = make_plot_grid_step_function(4, 3, vs)

# The initial value now sits inside [min, max]; the previous value=0 was
# below min=1.
iteration_slider = widgets.IntSlider(min=1, max=20, step=1, value=1)

# Labelled 'discount': this group previously reused the 'step penalty'
# caption, so both radio groups looked identical in the UI.
discount_chooser = widgets.RadioButtons(
    options=discounts,
    value=discounts[-2],
    description='discount',
    disabled=False
)

penalty_chooser = widgets.RadioButtons(
    options=step_penalties,
    value=step_penalties[0],
    description='step penalty',
    disabled=False
)

# Re-render the grid whenever any of the three controls changes.
w = widgets.interactive(
    plot_grid_step,
    iteration=iteration_slider,
    discount=discount_chooser,
    penalty=penalty_chooser,
)

display(w)

interactive(children=(IntSlider(value=1, description='iteration', max=20, min=1), RadioButtons(description='st…