# [Sutton and Barto Notebooks](https://github.com/seungjaeryanlee/sutton-barto-notebooks): Figure 4.1

[ModuAI](https://www.modu.ai)  
Author: Seung Jae (Ryan) Lee  

![Figure 4.1](figure_4_1.png)

In [1]:
from enum import IntEnum

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Environment

![Example 4.1](example_4_1.png)

In [2]:
# Each member's value is the offset it adds to a row-major state index on
# the 4x4 grid: moving one row changes the index by 4, one column by 1.
# Members are listed in the same order as before, which fixes the
# iteration order of ``list(Action)``.
Action = IntEnum(
    'Action',
    [('UP', -4), ('DOWN', 4), ('RIGHT', 1), ('LEFT', -1)],
)
Action.__doc__ = """
    Every move available to the agent in the 4x4 gridworld.

    Being an IntEnum, a member can be added directly to an integer state
    index to obtain the neighbouring state's index.
    """

In [3]:
class Environment:
    """
    Deterministic 4x4 gridworld from Example 4.1.

    States are the integers 0..15, laid out row by row (row = state // 4,
    column = state % 4). States 0 and 15 are terminal. Any move from a
    non-terminal state yields reward -1; a move that would leave the grid
    keeps the agent where it is.
    """
    state_space = list(range(16))
    action_space = list(Action)

    def peek(self, state, action):
        """
        Look up the outcome of taking ``action`` in ``state``.

        Nothing is stored or mutated. Returns a ``(next_state, reward)``
        pair; a terminal state returns itself with reward 0.
        """
        if self.is_done(state):
            return state, 0

        row, col = divmod(state, 4)
        hits_wall = (
            (action == Action.UP and row == 0)
            or (action == Action.DOWN and row == 3)
            or (action == Action.RIGHT and col == 3)
            or (action == Action.LEFT and col == 0)
        )
        next_state = state if hits_wall else state + action
        return next_state, -1

    def is_done(self, state):
        """
        Tell whether ``state`` is one of the two terminal corners, 0 or 15.
        """
        return state == 0 or state == 15

## Agent

![Pseudocode](policy_evaluation.png)

In [4]:
class RandomAgent:
    """
    A random-policy agent.

    Evaluates the equiprobable random policy with iterative policy
    evaluation. ``v`` holds one value estimate per state, starting at zero,
    and is indexed directly by state, so states are assumed to be the
    integers ``0 .. len(env.state_space) - 1``.
    """

    def __init__(self, env):
        self.env = env
        self.v = np.zeros(len(env.state_space))

    def update(self, gamma=1.0):
        """
        Update agent's state values based on random policy.

        Performs one synchronous sweep: every new estimate is computed from
        the previous sweep's ``v``, and the result replaces ``v`` only at the
        end. Each action in ``env.action_space`` is weighted equally, i.e.
        with probability ``1 / len(env.action_space)``.

        gamma: discount factor applied to the next state's value. The
            default 1.0 is the undiscounted case.
        """
        n_actions = len(self.env.action_space)
        # Derive the probability from the action set instead of hard-coding
        # 1/4. Guard against an empty action space (the inner loop would not
        # run anyway).
        prob = 1 / n_actions if n_actions else 0.0
        new_v = np.zeros(len(self.env.state_space))
        for state in self.env.state_space:
            for action in self.env.action_space:
                next_state, reward = self.env.peek(state, action)
                new_v[state] += prob * (reward + gamma * self.v[next_state])
        self.v = new_v

## Plots

In [5]:
import tkinter as tk

In [6]:
def keymax(dict_, tol=1e-9):
    """
    Get list of keys in a dictionary whose value is maximal, up to ``tol``.

    Values are often floats produced by long chains of arithmetic. Two
    mathematically equal values can then differ by a few ulps because their
    terms were summed in a different order. Every key whose value is within
    ``tol`` of the maximum is therefore treated as tied for the maximum.
    Pass ``tol=0.0`` for exact comparison.

    Keys are returned in the dictionary's iteration order. An empty
    dictionary gives an empty list.
    """
    if not dict_:
        return []
    max_v = max(dict_.values())
    return [key for key, value in dict_.items() if max_v - value <= tol]

In [7]:
class GUI(tk.Tk):
    """
    GUI for Figure 4.1

    Left grid: the value estimate v_k of every state. Right grid: arrows for
    the action(s) that are greedy with respect to v_k. The "Step" and "Back"
    buttons move k through snapshots that are all computed up front.
    """

    def __init__(self, env, agent, n_steps=500):
        """
        Precompute snapshots v_0 .. v_{n_steps-1} of ``agent.v`` and open the
        window showing v_0.

        This advances ``agent`` by ``n_steps`` calls to ``update``.

        Raises ValueError when ``n_steps`` < 1, since there would be no
        snapshot to display.
        """
        if n_steps < 1:
            raise ValueError('n_steps must be at least 1, got %r' % (n_steps,))

        self.env = env
        self.agent = agent

        self.step = 0
        self.cache = []

        # Precompute state values. Snapshots are copied so that an agent
        # that updates ``v`` in place cannot rewrite earlier entries.
        for _ in range(n_steps):
            self.cache.append(np.copy(agent.v))
            agent.update()

        super().__init__()
        self.canvas = self._build_canvas()
        self._update_canvas()

    def _build_canvas(self):
        """
        Build the canvas.

        It holds a step label, the value grid (grid 1), the policy grid with
        four arrows per cell (grid 2), and Back/Step buttons. Cells are
        created row by row, so list index == state number. Returns the
        packed ``tk.Canvas``.
        """
        canvas_w, canvas_h = (480, 300) # Canvas Dimensions
        g1_x, g1_y = (60, 100)          # Center of first square in grid 1
        g2_x, g2_y = (300, 100)         # Center of first square in grid 2
        r = 20                          # Half the side length of a grid cell

        canvas = tk.Canvas(self, width=canvas_w, height=canvas_h, bg='white')

        # Info Text
        self.text_step = canvas.create_text(canvas_w/2, 20)

        # Arrow end-point offsets from the cell center, per action
        # (canvas y grows downward, so UP is -r).
        arrow_offsets = (
            (Action.UP, 0, -r),
            (Action.DOWN, 0, r),
            (Action.LEFT, -r, 0),
            (Action.RIGHT, r, 0),
        )

        # Grid 1
        self.grid1 = []
        for j in range(4):
            for i in range(4):
                x = g1_x + 2*r*i
                y = g1_y + 2*r*j
                canvas.create_rectangle(x-r, y-r, x+r, y+r)
                self.grid1.append({'text': canvas.create_text(x, y)})

        # Grid 2
        self.grid2 = []
        for j in range(4):
            for i in range(4):
                x = g2_x + 2*r*i
                y = g2_y + 2*r*j
                canvas.create_rectangle(x-r, y-r, x+r, y+r)
                cell = {'text': canvas.create_text(x, y), 'arrows': {}}
                for action, dx, dy in arrow_offsets:
                    cell['arrows'][action] = canvas.create_line(
                        x, y, x+dx, y+dy, arrow="last")
                self.grid2.append(cell)

        # Back Button
        back_button = tk.Button(
            self, text="Back", command=self.back_step, width=30, height=3)
        canvas.create_window(canvas_w/4, canvas_h-25, window=back_button)

        # Step Button
        step_button = tk.Button(
            self, text="Step", command=self.run_step, width=30, height=3)
        canvas.create_window(3*canvas_w/4, canvas_h-25, window=step_button)

        canvas.pack(side="top", fill="both", expand=True)

        return canvas

    def run_step(self):
        """
        Updates canvas with state values of the next step.

        Does nothing when already at the last precomputed step.
        """
        if self.step >= len(self.cache) - 1:
            return

        self.step += 1
        self._update_canvas()

    def back_step(self):
        """
        Updates canvas with state values of the previous step.

        Does nothing when already at step 0.
        """
        if self.step <= 0:
            return

        self.step -= 1
        self._update_canvas()

    def _update_canvas(self):
        """
        Redraw the step label and both grids for snapshot ``self.step``.
        """
        v = self.cache[self.step]

        # Update Info
        self.canvas.itemconfig(self.text_step, text='k = %d' % self.step)

        # Update V
        for state, value in enumerate(v):
            self.canvas.itemconfig(self.grid1[state]['text'],
                                   text='{:.1f}'.format(value))

        # Update greedy policy. First hide every arrow by painting it the
        # canvas background colour (white)...
        for cell in self.grid2:
            for arrow in cell['arrows'].values():
                self.canvas.itemconfig(arrow, fill='white')

        # ...then draw the arrow(s) of the greedy action(s) in black.
        for state in range(len(v)):
            # Terminal states have no policy to display.
            if self.env.is_done(state):
                continue

            # One-step lookahead: reward + v(next_state) for the single
            # transition returned by ``peek``. Ties show several arrows.
            lookahead = {}
            for action in self.env.action_space:
                next_state, reward = self.env.peek(state, action)
                lookahead[action] = reward + v[next_state]

            for action in keymax(lookahead):
                self.canvas.itemconfig(self.grid2[state]['arrows'][action],
                                       fill='black')

In [8]:
# Build the gridworld, attach a random-policy agent to it, then open the
# interactive Figure 4.1 window and start the Tk event loop.
env = Environment()
agent = RandomAgent(env=env)
gui = GUI(env=env, agent=agent)
gui.mainloop()