In [1]:
# MDP from the diagram below (Wikipedia "Markov decision process" figure).
# transition_probs[s][a][s'] = P(s' | s, a); each inner dict sums to 1.
transition_probs = {
    's0': {
        'a0': {'s0': 0.5, 's2': 0.5},
        'a1': {'s2': 1}
    },
    's1': {
        'a0': {'s0': 0.7, 's1': 0.1, 's2': 0.2},
        'a1': {'s1': 0.95, 's2': 0.05}
    },
    's2': {
        # The diagram has a 0.6 self-loop on s2 for a0 (not a move to s1).
        'a0': {'s0': 0.4, 's2': 0.6},
        'a1': {'s0': 0.3, 's1': 0.3, 's2': 0.4}
    }
}
# rewards[s][a][s'] = reward for that transition; transitions not listed here
# yield 0 (e.g. s0 --a1--> s2 reports reward=0.0 in the step demo).
rewards = {
    's1': {'a0': {'s0': +5}},
    's2': {'a1': {'s0': -1}}
}

# `MDP` comes from the local `mdp` module (not shown in this notebook).
from mdp import MDP
# Build the environment from the tables above; episodes start in 's0'
# (the next cell's output confirms reset() returns 's0').
mdp = MDP(transition_probs, rewards, initial_state='s0')

<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/a/ad/Markov_Decision_Process.svg/800px-Markov_Decision_Process.svg.png' width=300px>

In [2]:
# Reset to the initial state, then take a single step with action 'a1'.
print('initial state =', mdp.reset())
next_state, reward, done, info = mdp.step('a1')
print('next_state = %s, reward=%s,done=%s' % (next_state, reward, done))

inital state = s0
next_state = s2, reward=0.0,done=False


In [3]:
# Exercise each query method of the MDP once; every row is
# (printed label, bound method, positional arguments).
api_probes = (
    ('mdp.get_all_states =', mdp.get_all_states, ()),
    ('mdp.get_possible_actions("s1") =', mdp.get_possible_actions, ('s1',)),
    ('mdp.get_next_states("s1","a0") =', mdp.get_next_states, ('s1', 'a0')),
    ('mdp.get_transition_prob("s1","a0","s0") =', mdp.get_transition_prob, ('s1', 'a0', 's0')),
)
for label, method, args in api_probes:
    print(label, method(*args))

mdp.get_all_states = ('s0', 's1', 's2')
mdp.get_possible_actions("s1") = ('a0', 'a1')
mdp.get_next_states("s1","a0") = {'s0': 0.7, 's1': 0.1, 's2': 0.2}
mdp.get_transition_prob("s1","a0","s0") = 0.7


In [4]:
def get_action_value(mdp,state_values,state,action,gamma):
    """Return Q(s, a) under the current state-value estimates.

    Q(s, a) = sum over s' of P(s' | s, a) * (r(s, a, s') + gamma * V(s'))

    Args:
        mdp: object exposing get_next_states(s, a) -> {s': prob} and
            get_reward(s, a, s').
        state_values: mapping state -> current estimate V(s).
        state, action: the (s, a) pair to evaluate.
        gamma: discount factor.

    Returns:
        The expected one-step return; 0 when there are no successor states.
    """
    successors = mdp.get_next_states(state, action)
    return sum(
        prob * (mdp.get_reward(state, action, succ) + gamma * state_values[succ])
        for succ, prob in successors.items()
    )

In [5]:
def get_new_state_value(mdp, state_values, state, gamma):
    """Perform one Bellman-optimality backup for a single state.

    V_new(s) = max over a of Q(s, a), with Q computed by `get_action_value`
    from the current estimates in `state_values`.

    Args:
        mdp: the environment (needs is_terminal and get_possible_actions, plus
            whatever get_action_value uses).
        state_values: mapping state -> current estimate V(s).
        state: the state to update.
        gamma: discount factor.

    Returns:
        0 for terminal states and for states with no available actions;
        otherwise the largest action value, which may be negative.
    """
    if mdp.is_terminal(state):
        return 0

    action_values = [
        get_action_value(mdp, state_values, state, action, gamma)
        for action in mdp.get_possible_actions(state)
    ]
    # Take the true maximum. Seeding a running max with 0 (as before) silently
    # clamped states whose every action has negative value up to 0.
    return max(action_values, default=0)

In [None]:
gamma = 0.9             # discount factor
num_iter = 100          # hard cap on the number of sweeps
min_difference = 0.001  # stop once the largest per-state change drops below this

# Start from V(s) = 0 for every state.
state_values = {s : 0 for s in mdp.get_all_states()}

for i in range(num_iter):
    # Synchronous sweep: every new value is computed from the previous sweep's
    # `state_values`; a fresh dict is built each time, so no copy is needed.
    new_state_values = {
        s: get_new_state_value(mdp, state_values, s, gamma) for s in state_values
    }
    diff = max(abs(new_state_values[s] - state_values[s]) for s in mdp.get_all_states())
    print('iter %4i | diff:%6.5f | ' % (i, diff), end="")
    # Shows the values *before* this sweep's update.
    print('   '.join("V(%s) = %.3f" % (s, v) for s, v in state_values.items()), end='\n\n')
    state_values = new_state_values
    if diff < min_difference:
        print('Terminated')
        break

In [13]:
def get_optimal_action(mdp, state_values, state, gamma=0.9):
    """Return the greedy action for `state` with respect to `state_values`.

    Every available action is scored with `get_action_value`. Because the
    comparison is `>=`, the *last* action among equally scored ones is chosen.

    Returns:
        None for terminal states and for states with no available actions;
        otherwise the highest-scoring action.
    """
    if mdp.is_terminal(state):
        return None

    scored = [
        (get_action_value(mdp, state_values, state, candidate, gamma), candidate)
        for candidate in mdp.get_possible_actions(state)
    ]

    best_value, best_action = -float("inf"), None
    for value, candidate in scored:
        if value >= best_value:
            best_value, best_action = value, candidate
    return best_action

In [17]:
# Extract the greedy action with the same discount factor the values were
# computed with, rather than a separately hard-coded 0.9.
optimal_action = get_optimal_action(mdp, state_values, 's1', gamma=gamma)
optimal_action