<a href="https://colab.research.google.com/github/rickyhan24/RL_Linear_Proofs_Policy_Evaluation/blob/main/Linear_Proofs_Iterative_Policy_Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import matplotlib.pyplot as plt
import numpy as np

In [12]:
# The action space: the name of every algebraic property that can be applied
# to an equation. Each name is matched verbatim elsewhere in the file.
ACTION_SPACE = [
    'Subtraction Property of Equality',
    'Additive Inverse',
    'Additive Identity',
    'Division Property of Equality',
    'Division Property of Equality 2',
    'Multiplicative Inverses',
    'Multiplicative Identity',
    'Distributing the Denominator',
    'Subtraction Property of Equality 2',
    'Additive Inverse 2',
    'Combining Under Common Denominator',
]
#Defining the class of mathworld objects; a mathworld is like a grid in grid world
class Mathworld:
  """Deterministic environment whose states are equation strings.

  A mathworld plays the role of the grid in a grid world: the agent sits on
  an equation (a string) and moves to another equation by applying a named
  algebraic property. The legal actions per state and the per-state rewards
  are supplied through ``set``.
  """

  # Actions that rewrite a fixed sub-expression: name -> (old text, new text).
  _SUBSTITUTIONS = {
      'Additive Inverse': ('b - b', '0'),
      'Additive Identity': (' + 0', ''),
      'Multiplicative Inverses': ('( a * x ) / a', '1 * x'),
      'Multiplicative Identity': ('1 * x', 'x'),
      'Distributing the Denominator': ('( a * x + b ) / a', '( a * x ) / a + b / a'),
      'Additive Inverse 2': ('b / a - b / a', '0'),
      'Combining Under Common Denominator': ('d / a - b / a', '( d - b ) / a'),
  }

  # Actions applied to both sides of the '=' sign:
  # name -> (lhs prefix, lhs suffix, rhs prefix, rhs suffix).
  # The odd-looking spacing compensates for the blanks that str.split('=')
  # leaves at the end of the left side and the start of the right side.
  _BOTH_SIDES = {
      'Subtraction Property of Equality': ('', '- b', '', ' - b'),
      'Division Property of Equality': ('( ', ') / a', ' (', ' ) / a'),
      'Division Property of Equality 2': ('( ', ') / a', '', ' / a'),
      'Subtraction Property of Equality 2': ('', '- b / a', '', ' - b / a'),
  }

  def __init__(self, start_state):
    # Remember the start state so reset() can return to it.
    self._start_state = start_state
    self.state = start_state
    # Empty until set() is called, so the query methods work on a fresh object.
    self.rewards = {}
    self.actions = {}

  def current_state(self):
    """Return the equation the agent is currently on."""
    return self.state

  def all_states(self):
    """Return the set of all known states.

    A state is known if it has possible next actions or if it has a reward.
    """
    return set(self.actions) | set(self.rewards)

  def is_terminal(self, s):
    """Return True when state ``s`` has no entry in the actions table."""
    return s not in self.actions

  def reset(self):
    """Put the agent back on the start state and return that state."""
    self.state = self._start_state
    return self.state

  def game_over(self):
    """Return True when the current state is terminal."""
    return self.is_terminal(self.state)

  def set(self, rewards, actions):
    """Install the reward table (state -> reward) and the action table
    (state -> list of allowed action names)."""
    self.rewards = rewards
    self.actions = actions

  def set_state(self, s):
    """Move the agent to state ``s``."""
    self.state = s

  def get_next_state(self, s, a):
    """Return the state reached by performing action ``a`` in state ``s``.

    Each property amounts to a string manipulation of the equation. The
    current state of the agent is not modified.

    Returns:
      The resulting equation string, or None when ``a`` is not an allowed
      action for ``s`` (this includes terminal and unknown states). An
      allowed action whose name is not recognised leaves ``s`` unchanged.
    """
    # .get() so that a terminal/unknown state yields None instead of KeyError.
    if a not in self.actions.get(s, ()):
      return None

    if a in self._SUBSTITUTIONS:
      old, new = self._SUBSTITUTIONS[a]
      return s.replace(old, new)

    if a in self._BOTH_SIDES:
      lhs_prefix, lhs_suffix, rhs_prefix, rhs_suffix = self._BOTH_SIDES[a]
      sides = s.split('=')
      return ' ='.join([
          lhs_prefix + sides[0] + lhs_suffix,
          rhs_prefix + sides[1] + rhs_suffix,
      ])

    return s

In [13]:
#defining a grid, an instance of the Mathworld class, with a given start position, rewards, and actions
def grid():
  """Build the Mathworld for solving ``a * x + b = d`` for ``x``.

  The world starts on 'a * x + b = d'. Every state listed in the action
  table has reward 0; the only other state, 'x = ( d - b ) / a', has
  reward 1 and no actions, which makes it terminal.

  Returns:
    A Mathworld with its rewards and actions installed.
  """
  # State -> properties that may be applied there. The lists are not
  # exhaustive for every state.
  moves = {
      'a * x + b = d': ['Subtraction Property of Equality', 'Division Property of Equality 2'],
      'a * x + b - b = d - b': ['Additive Inverse'],
      'a * x + 0 = d - b': ['Additive Identity'],
      'a * x = d - b': ['Division Property of Equality'],
      '( a * x ) / a = ( d - b ) / a': ['Multiplicative Inverses'],
      '1 * x = ( d - b ) / a': ['Multiplicative Identity'],
      '( a * x + b ) / a = d / a': ['Distributing the Denominator'],
      '( a * x ) / a + b / a = d / a': ['Multiplicative Inverses', 'Subtraction Property of Equality 2'],
      '1 * x + b / a = d / a': ['Multiplicative Identity'],
      'x + b / a = d / a': ['Subtraction Property of Equality 2'],
      'x + b / a - b / a = d / a - b / a': ['Additive Inverse 2'],
      'x + 0 = d / a - b / a': ['Additive Identity'],
      'x = d / a - b / a': ['Combining Under Common Denominator'],
      '( a * x ) / a + b / a - b / a = d / a - b / a': ['Additive Inverse 2'],
      '( a * x ) / a + 0 = d / a - b / a': ['Additive Identity'],
      '( a * x ) / a = d / a - b / a': ['Multiplicative Inverses'],
      '1 * x = d / a - b / a': ['Multiplicative Identity', 'Combining Under Common Denominator'],
  }

  # Reward 1 for the solved equation, 0 for every state that still has moves.
  payoff = {'x = ( d - b ) / a': 1}
  payoff.update(dict.fromkeys(moves, 0))

  world = Mathworld('a * x + b = d')
  world.set(payoff, moves)
  return world

In [14]:
# Convergence threshold: policy evaluation stops once the largest change of
# any state's value during one sweep falls below this.
SMALL_ENOUGH = 0.001
# Discount factor applied to the value of the successor state.
gamma = 0.9

#print function that prints the value function V sorted by the values
def print_values(V, g):
  """Print the value of every state of ``g``, lowest value first.

  Args:
    V: dict mapping state -> value; states missing from it are shown as 0.
    g: object whose ``all_states()`` yields the states to list.
  """
  print("The Value Function")
  ranked = sorted(
      ((state, V.get(state, 0)) for state in g.all_states()),
      key=lambda pair: pair[1],
  )
  for state, value in ranked:
    # One "state: value|" row per line.
    print(state + ": %.2f|" % value)
  print('\n')

#prints the policy P
def print_policy(P, g):
  """Print the action that policy ``P`` picks for every state of ``g``.

  Args:
    P: dict mapping state -> action name; states missing from it are
      shown with a single blank in place of the action.
    g: object whose ``all_states()`` yields the states to list.
  """
  print("The Policy")
  for state in g.all_states():
    chosen = P.get(state, ' ')
    # One "state:  action  |" row per line.
    print(state + ":  %s  |" % chosen)
  print('\n')

def get_transition_probs_and_rewards(grid):
  """Tabulate the deterministic dynamics of ``grid``.

  For every non-terminal state and every action in ACTION_SPACE, the
  successor is looked up with ``grid.get_next_state``; combinations for
  which it returns None are left out.

  Args:
    grid: environment providing ``all_states()``, ``is_terminal(s)``,
      ``get_next_state(s, a)`` and a ``rewards`` mapping keyed by state.

  Returns:
    A pair ``(transition_probs, rewards)`` of dicts keyed by
    ``(s, a, s2)``:
      * ``transition_probs[(s, a, s2)]`` is p(s2 | s, a), always 1 here;
        any key not present is considered impossible (probability 0).
      * ``rewards[(s, a, s2)]`` is the reward ``grid.rewards[s2]`` of the
        state that is entered.

  Raises:
    KeyError: if a successor state has no entry in ``grid.rewards``.
  """
  transition_probs = {}
  # Rewards are deterministic and depend only on s2, but are keyed by the
  # full (s, a, s2) triple to mirror the transition table.
  rewards = {}

  for s in grid.all_states():
    if grid.is_terminal(s):
      continue
    for a in ACTION_SPACE:
      s2 = grid.get_next_state(s, a)
      if s2 is None:
        # action not available in this state
        continue
      transition_probs[(s, a, s2)] = 1
      rewards[(s, a, s2)] = grid.rewards[s2]
  return transition_probs, rewards

#a function that gives the value function for a given policy (policy evaluation)  
def evaluate_deterministic_policy(grid, policy, initV=None):
  """Iterative policy evaluation for a deterministic policy.

  Sweeps over all non-terminal states, updating V(s) in place, until the
  largest change during a sweep is below SMALL_ENOUGH. Progress and the
  value table are printed after every sweep, and the final table is
  printed once more at the end.

  The dynamics are NOT taken from ``grid``: they are read from the
  module-level ``transition_probs`` and ``rewards`` dicts (keyed by
  ``(s, a, s2)``), together with the module-level ``gamma``,
  ``SMALL_ENOUGH`` and ``ACTION_SPACE``. Those must be defined before
  this function is called.

  Args:
    grid: environment providing ``all_states()`` and ``is_terminal(s)``.
    policy: dict mapping state -> action name. A state with no entry, or
      with an action outside ACTION_SPACE, gets value 0.
    initV: optional dict of starting values with an entry for every
      state. When given it is updated in place (warm start), which is
      faster when the values change little from one policy to the next.

  Returns:
    The dict mapping state -> value under ``policy`` (the same object as
    ``initV`` when that was supplied).
  """
  states = grid.all_states()

  if initV is None:
    V = {s: 0 for s in states}
  else:
    V = initV

  # repeat until convergence
  it = 0
  while True:
    biggest_change = 0
    for s in states:
      if grid.is_terminal(s):
        continue

      old_v = V[s]
      new_v = 0  # we will accumulate the answer

      # The policy is deterministic, so only its chosen action has
      # non-zero probability; every other action contributes nothing.
      a = policy.get(s)
      if a in ACTION_SPACE:
        for s2 in states:
          prob = transition_probs.get((s, a, s2), 0)
          if not prob:
            continue
          # reward is a function of (s, a, s'), 0 if not specified
          r = rewards.get((s, a, s2), 0)
          new_v += prob * (r + gamma * V[s2])

      # after done getting the new value, update the value table
      V[s] = new_v
      biggest_change = max(biggest_change, abs(old_v - new_v))

    print("iter:", it, "biggest_change:", biggest_change)
    print_values(V, grid)
    it += 1

    if biggest_change < SMALL_ENOUGH:
      break
  print("\n\n")
  print_values(V, grid)
  return V
  


In [15]:
# Driver cell: build the world, tabulate its dynamics, then print and
# evaluate one hand-written policy.
#
# NOTE(review): the next line rebinds the name `grid` from the factory
# function to the Mathworld instance it returns. Running this cell a second
# time without re-running the cell that defines `grid()` therefore raises
# TypeError, because a Mathworld instance is not callable.
grid = grid()
# `transition_probs` and `rewards` are bound as module-level globals here;
# evaluate_deterministic_policy reads them by name rather than as arguments.
transition_probs, rewards = get_transition_probs_and_rewards(grid)
# The policy under evaluation: exactly one action per non-terminal state.
# From the start state it subtracts b first; where a state allows two
# actions it picks 'Multiplicative Inverses' / 'Multiplicative Identity'.
policy = {
    'a * x + b = d': 'Subtraction Property of Equality',
    'a * x + b - b = d - b': 'Additive Inverse',
    'a * x + 0 = d - b': 'Additive Identity',
    'a * x = d - b': 'Division Property of Equality',
    '( a * x ) / a = ( d - b ) / a': 'Multiplicative Inverses',
    '1 * x = ( d - b ) / a': 'Multiplicative Identity',
    '( a * x + b ) / a = d / a': 'Distributing the Denominator',
    '( a * x ) / a + b / a = d / a': 'Multiplicative Inverses',
    '1 * x + b / a = d / a': 'Multiplicative Identity',
    'x + b / a = d / a': 'Subtraction Property of Equality 2',
    'x + b / a - b / a = d / a - b / a': 'Additive Inverse 2',
    'x + 0 = d / a - b / a': 'Additive Identity',
    'x = d / a - b / a': 'Combining Under Common Denominator',
    '( a * x ) / a + b / a - b / a = d / a - b / a':'Additive Inverse 2',
    '( a * x ) / a + 0 = d / a - b / a': 'Additive Identity',
    '( a * x ) / a = d / a - b / a': 'Multiplicative Inverses',
    '1 * x = d / a - b / a': 'Multiplicative Identity'
  }
print_policy(policy, grid)
# Evaluate the policy starting from V(s) = 0 for every state (initV=None).
evaluate_deterministic_policy(grid, policy, initV=None)

The Policy
a * x + b = d:  Subtraction Property of Equality  |
( a * x ) / a + 0 = d / a - b / a:  Additive Identity  |
1 * x = d / a - b / a:  Multiplicative Identity  |
x = ( d - b ) / a:     |
x + b / a - b / a = d / a - b / a:  Additive Inverse 2  |
a * x + 0 = d - b:  Additive Identity  |
a * x + b - b = d - b:  Additive Inverse  |
1 * x + b / a = d / a:  Multiplicative Identity  |
a * x = d - b:  Division Property of Equality  |
( a * x + b ) / a = d / a:  Distributing the Denominator  |
x = d / a - b / a:  Combining Under Common Denominator  |
( a * x ) / a + b / a - b / a = d / a - b / a:  Additive Inverse 2  |
( a * x ) / a = d / a - b / a:  Multiplicative Inverses  |
x + b / a = d / a:  Subtraction Property of Equality 2  |
( a * x ) / a + b / a = d / a:  Multiplicative Inverses  |
x + 0 = d / a - b / a:  Additive Identity  |
1 * x = ( d - b ) / a:  Multiplicative Identity  |
( a * x ) / a = ( d - b ) / a:  Multiplicative Inverses  |


iter: 0 biggest_change: 1.0
The Value Fu