In [1]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
import time
from IPython.display import display

In [2]:
def make_plot_grid_step_function(cols, rows, V_over_time, P_over_time, lr, gamma):
  """Build a one-argument plotting callback for ``ipywidgets.interactive``.

  ``interactive`` feeds the callback one value per widget, so everything
  else the plot needs is captured here in a closure.

  Args:
    cols, rows: grid dimensions; cells are keyed as (column, row).
    V_over_time: sequence of dicts mapping (column, row) -> numeric value.
    P_over_time: sequence of dicts mapping (column, row) -> action label.
    lr, gamma: accepted so existing call sites keep working; the plot does
      not use them.

  Returns:
    ``plot_grid_step(iteration)``, which draws the value grid (left panel)
    and the policy grid (right panel) for ``V_over_time[iteration]`` and
    ``P_over_time[iteration]``. The stored dicts are left unmodified.
  """
  # Cells that always display a fixed number in both panels, whatever the
  # stored tables contain for them.
  fixed_cells = {(3, 1): -1, (3, 2): 1, (1, 1): 0}

  def as_display_rows(table):
    """Overlay the fixed cells on a copy of `table`; return rows, top row first."""
    cells = dict(table)  # copy, so viewing a frame never alters the history
    cells.update(fixed_cells)
    return [[cells[(column, row)] for column in range(cols)]
            for row in reversed(range(rows))]

  def draw_panel(ax, background, labels):
    """Show `background` as a heatmap and centre one label in every cell."""
    ax.imshow(background, cmap=plt.cm.Spectral, interpolation='nearest')
    ax.axis('off')
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    for y, label_row in enumerate(labels):
      for x, label in enumerate(label_row):
        ax.text(x, y, label, fontsize=14, va='center', ha='center')

  def plot_grid_step(iteration):
    Vgrid = as_display_rows(V_over_time[iteration])
    Pgrid = as_display_rows(P_over_time[iteration])

    fig, (ax1, ax2) = plt.subplots(1,2, figsize=(15,15))
    draw_panel(ax1, Vgrid,
               [["{0:.3f}".format(value) for value in Vrow] for Vrow in Vgrid])
    # The policy panel is drawn over the same value heatmap; only the text
    # in each cell differs.
    draw_panel(ax2, Vgrid, Pgrid)

    plt.show()
  return plot_grid_step


def make_visualize(slider):
  """Build the animation callback bound to `slider`.

  The returned ``visualize_callback(visualize, time_step)`` does nothing
  unless ``visualize`` is exactly ``True``. In that case it assigns every
  integer from ``slider.min`` through ``slider.max`` (inclusive) to
  ``slider.value``, sleeping ``float(time_step)`` seconds after each one.
  """
  def visualize_callback(visualize, time_step):
    if visualize is not True:
      return
    first, last = slider.min, slider.max
    for frame in range(first, last + 1):
      slider.value = frame
      time.sleep(float(time_step))
  return visualize_callback

In [3]:
def Qvalue(s, a, V, lr, g):
  r"""
  Computes the Q-value Q(s,a) given by \sum_{s'} T(s,a,s') [R(s,a,s')+g*V(s')]

  The grid spans columns 0..3 and rows 0..2 with a wall at (1,1). A move
  that would leave the grid or enter the wall leaves the agent in `s`.
  The intended direction is taken with probability 0.8 and each of the two
  perpendicular directions with probability 0.1.

  Args:
    s: current cell as (column, row).
    a: action, one of '<', '>', '^', 'v'.
    V: mapping from cell to its current value estimate; it must contain
      every cell reachable from `s` under `a`.
    lr: living reward paid on every transition that has no special payoff.
    g: discount factor.

  Returns:
    The expected one-step return of taking `a` in `s`.

  Raises:
    ValueError: if `a` is not one of the four known actions.
  """
  wall = (1,1)
  i, j = s

  # Stay put if moving into walls
  left  = s if i-1 < 0 or (i-1,j) == wall else (i-1,j)
  right = s if i+1 > 3 or (i+1,j) == wall else (i+1,j)
  up    = s if j+1 > 2 or (i,j+1) == wall else (i,j+1)
  down  = s if j-1 < 0 or (i,j-1) == wall else (i,j-1)

  # Living and terminal state rewards.
  # r is the reward for the "right" outcome and u for the "up" outcome.
  # Only three cells pay something other than the living reward:
  # right from (2,2) -> +1, right from (2,1) -> -1, up from (3,0) -> -1.
  # The "left" and "down" outcomes always pay lr.
  r, u = {(2,2): (1, lr), (2,1): (-1, lr), (3,0): (lr, -1)}.get(s, (lr, lr))

  # Q-value computation
  if a == '<':
    return 0.8*(lr+g*V[left]) + 0.1*(u+g*V[up]) + 0.1*(lr+g*V[down])
  if a == '>':
    return 0.8*(r+g*V[right]) + 0.1*(u+g*V[up]) + 0.1*(lr+g*V[down])
  if a == '^':
    return 0.8*(u+g*V[up]) + 0.1*(lr+g*V[left]) + 0.1*(r+g*V[right])
  if a == 'v':
    return 0.8*(lr+g*V[down]) + 0.1*(lr+g*V[left]) + 0.1*(r+g*V[right])
  # Previously an unknown action fell through and returned None.
  raise ValueError("unknown action: {!r}".format(a))

In [4]:
def value_iteration(lr, gamma, iter):
  """
  Run exactly `iter` value-iteration sweeps from an all-zero table
  (there is no convergence threshold).

  Every sweep builds a fresh table whose entries are computed only from
  the previous sweep's table. Returns (V_over_time, P_over_time): the
  table before any sweep and after each sweep, and for each of those
  tables the policy produced by extract_policy (iter + 1 entries each).
  """
  pinned = ((3,1), (3,2))  # cells kept at 0 in every table

  current = {s:0 for s in states}
  for cell in pinned:
    current[cell] = 0
  V_over_time = [current]
  P_over_time = [extract_policy(current, lr, gamma)]

  for _ in range(iter):
    swept = {cell:0 for cell in pinned}
    for s in states:
      candidates = [Qvalue(s,a,current,lr,gamma) for a in actions]
      swept[s] = max([-float("inf")] + candidates)
    current = swept
    V_over_time.append(current)
    P_over_time.append(extract_policy(current, lr, gamma))

  return V_over_time, P_over_time


def policy_iteration(lr, gamma, iter):
  """
  Policy iteration using iterative evaluation and greedy improvement
  Evaluation done using a fixed number of sweeps (no threshold)
  Also returns associated value function for each policy P_i

  Starts from the policy that picks '<' in every state and an all-zero
  value table. Each round evaluates the current policy with `iter` sweeps
  of eval_policy (warm-started from the previous round's table) and then
  extracts a new policy from the result. The loop stops as soon as the
  extracted policy selects the same action as the current one in every
  state; there is no cap on the number of rounds.

  Returns (P_over_time, V_over_time) -- note the order is the reverse of
  value_iteration's. V_over_time[k] is the evaluation of P_over_time[k].
  """
  P = {s:'<' for s in states}
  V = {s:0 for s in states}
  V[(3,1)] = 0
  V[(3,2)] = 0
  P_over_time = [P]
  V_over_time = []

  while True:
    V = eval_policy(P, V, lr, gamma, iter)
    V_over_time.append(V)
    P_new = extract_policy(V, lr, gamma)
    # Compare actions on the real states only. extract_policy's result also
    # carries the non-state keys copied from V, which the initial policy
    # lacks, so a whole-dict comparison could never succeed in round one.
    if all(P_new[s] == P[s] for s in states): break
    P_over_time.append(P_new)
    P = P_new

  return P_over_time, V_over_time


def eval_policy(P, V, lr, gamma, iter):
  """
  Estimate the value of policy P with `iter` sweeps over `states`.

  Works on a copy of V, so the caller's table is left untouched. Updates
  are applied in place on that copy: within a sweep, each state's backup
  already sees the values written for states earlier in the same sweep.
  """
  estimate = V.copy()
  for _ in range(iter):
    for state in states:
      chosen = P[state]
      estimate[state] = Qvalue(state, chosen, estimate, lr, gamma)
  return estimate

def extract_policy(V, lr, gamma):
  """
  Return the greedy policy for value table V.

  The result starts as a copy of V, so keys of V that are not in `states`
  keep their numeric value. For every state the entry is replaced by the
  action with the largest Q-value; on ties the action listed first in
  `actions` wins. A state keeps its copied value if no action scores
  strictly above -inf.
  """
  greedy = V.copy()
  for state in states:
    scored = [(Qvalue(state, action, V, lr, gamma), action) for action in actions]
    best_q = float("-inf")
    for q, action in scored:
      if q > best_q:
        best_q, greedy[state] = q, action
  return greedy

In [5]:
# Cells the solvers iterate over, as (column, row), listed column by column:
# every cell of the 4x3 grid except (1,1), (3,1) and (3,2).
states = [(col, row)
          for col in range(4)
          for row in range(3)
          if (col, row) not in ((1,1), (3,1), (3,2))]
# Action labels: left, right, up, down.
actions = list('<>^v')

# Living reward
lr = -0.04
# Discount factor
gamma = 1
# Number of sweeps
iter = 20

In [6]:
# Solve with value iteration, then browse the stored sweeps interactively.
V_over_time, P_over_time = value_iteration(lr, gamma, iter)

print("VALUE ITERATION")
print("Living reward", lr, ", Discount factor", gamma)

# Slider-driven view: one slider position per stored table (0 .. iter).
plot_grid_step = make_plot_grid_step_function(
    4, 3, V_over_time, P_over_time, lr, gamma)
iteration_slider = widgets.IntSlider(value=0, min=0, max=iter, step=1)
w = widgets.interactive(plot_grid_step, iteration=iteration_slider)
display(w)

# "Visualize" toggle: steps the slider through its whole range, pausing for
# the selected delay (seconds) after each step.
visualize_callback = make_visualize(iteration_slider)
visualize_button = widgets.ToggleButton(value=False, description="Visualize")
time_select = widgets.ToggleButtons(
    options=['0', '0.1', '0.2', '0.5', '0.7', '1.0'],
    description='Extra Delay:')
a = widgets.interactive(
    visualize_callback, visualize=visualize_button, time_step=time_select)
display(a)

VALUE ITERATION
Living reward -0.04 , Discount factor 1


interactive(children=(IntSlider(value=0, description='iteration', max=20), Output()), _dom_classes=('widget-in…

interactive(children=(ToggleButton(value=False, description='Visualize'), ToggleButtons(description='Extra Del…

In [None]:
# Solve with policy iteration (note the return order: policies first).
P_over_time, V_over_time = policy_iteration(lr, gamma, iter)
print("POLICY ITERATION")
print("Living reward", lr, ", Discount factor", gamma)

# Slider-driven view: one slider position per stored policy.
plot_grid_step = make_plot_grid_step_function(
    4, 3, V_over_time, P_over_time, lr, gamma)
iteration_slider = widgets.IntSlider(
    value=0, min=0, max=len(P_over_time)-1, step=1)
w = widgets.interactive(plot_grid_step, iteration=iteration_slider)
display(w)

# "Visualize" toggle: steps the slider through its whole range, pausing for
# the selected delay (seconds) after each step.
visualize_callback = make_visualize(iteration_slider)
visualize_button = widgets.ToggleButton(value=False, description="Visualize")
time_select = widgets.ToggleButtons(
    options=['0', '0.1', '0.2', '0.5', '0.7', '1.0'],
    description='Extra Delay:')
a = widgets.interactive(
    visualize_callback, visualize=visualize_button, time_step=time_select)
display(a)

POLICY ITERATION
Living reward -0.04 , Discount factor 1


interactive(children=(IntSlider(value=0, description='iteration', max=6), Output()), _dom_classes=('widget-int…

interactive(children=(ToggleButton(value=False, description='Visualize'), ToggleButtons(description='Extra Del…