# Model-based methods (dynamic programming)

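### Policy Iteration
Alternate two steps:
- Policy evaluation: compute the value of every state under the current policy.
- Policy improvement: make the policy greedy with respect to those values.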


In [1]:
class Grid_MDP:
    """Deterministic 8-state grid-world MDP.

    Layout implied by the transition table ``t`` (states 1-5 form a
    corridor; 6, 7 and 8 hang below 1, 3 and 5)::

        1 - 2 - 3 - 4 - 5
        |       |       |
        6       7       8

    States 6, 7 and 8 are terminal.
    - Moving south from 1 or 5 yields a reward of -1.
    - Moving south from 3 yields +1.
    - Every other move yields 0.

    Moves that are not listed in ``t`` (e.g. walking into a wall) leave the
    state unchanged.
    """

    def __init__(self, gamma=0.8):
        """Build the model.

        Args:
            gamma: discount factor exposed to the solvers (default 0.8).
        """
        self.states = [1, 2, 3, 4, 5, 6, 7, 8]

        # Kept as a dict for backward compatibility; only the keys matter.
        self.terminal_states = dict()
        self.terminal_states[6] = 1
        self.terminal_states[7] = 1
        self.terminal_states[8] = 1

        self.actions = ['n', 'e', 's', 'w']

        # Rewards keyed by '<state>_<action>'; missing keys mean reward 0.
        self.rewards = dict()
        self.rewards['1_s'] = -1.0
        self.rewards['3_s'] = 1.0
        self.rewards['5_s'] = -1.0

        # Deterministic transitions keyed by '<state>_<action>'.
        self.t = dict()
        self.t['1_s'] = 6
        self.t['1_e'] = 2
        self.t['2_w'] = 1
        self.t['2_e'] = 3
        self.t['3_s'] = 7
        self.t['3_w'] = 2
        self.t['3_e'] = 4
        self.t['4_w'] = 3
        self.t['4_e'] = 5
        self.t['5_s'] = 8
        self.t['5_w'] = 4

        self.gamma = gamma

    def transform(self, state, action):
        """Apply ``action`` in ``state``.

        Args:
            state: current state (one of ``self.states``).
            action: one of ``self.actions``.

        Returns:
            A tuple ``(is_terminal, next_state, reward)``. Terminal states
            are absorbing: they map to themselves with reward 0.0.
        """
        if state in self.terminal_states:
            return True, state, 0.0

        key = '{}_{}'.format(state, action)
        # Moves missing from the table bump into a wall and stay put.
        next_state = self.t.get(key, state)
        is_terminal = next_state in self.terminal_states
        reward = self.rewards.get(key, 0.0)
        return is_terminal, next_state, reward

In [60]:
class policy_value:
    """Tabular dynamic-programming solver (policy / value iteration).

    Holds a state-value table ``v`` (a list indexed by state) and a
    deterministic policy ``pi`` (dict: non-terminal state -> action).

    The model must expose:
    - ``states``, ``terminal_states`` (membership-testable), ``actions``
      and ``gamma``;
    - ``transform(state, action) -> (is_terminal, next_state, reward)``.
    """

    def __init__(self, grid_mdp):
        """Initialise a zero value table and an all-``actions[0]`` policy.

        Args:
            grid_mdp: the model to solve (see the class docstring).
        """
        self.grid_mdp = grid_mdp

        # ``v`` is indexed directly by state, so states are assumed to be
        # non-negative integers. Size it to cover the largest state so
        # non-contiguous numbering does not raise IndexError (never shorter
        # than the original len(states) + 1).
        size = len(grid_mdp.states) + 1
        if grid_mdp.states:
            size = max(size, max(grid_mdp.states) + 1)
        self.v = [0.0] * size

        # Deterministic policy, defined for non-terminal states only.
        self.pi = dict()
        for state in grid_mdp.states:
            if state in grid_mdp.terminal_states:
                continue
            self.pi[state] = grid_mdp.actions[0]

    def _q_value(self, state, action):
        """Return the one-step lookahead value r + gamma * v[next_state]."""
        _, next_state, reward = self.grid_mdp.transform(state, action)
        return reward + self.grid_mdp.gamma * self.v[next_state]

    def _greedy(self, state):
        """Return ``(action, value)`` maximising the one-step lookahead.

        Ties keep the earliest action in ``grid_mdp.actions`` (strict ``>``).
        """
        best_action = None
        best_value = None
        for action in self.grid_mdp.actions:
            value = self._q_value(state, action)
            if best_action is None or value > best_value:
                best_action = action
                best_value = value
        return best_action, best_value

    def policy_improve(self):
        """Make ``pi`` greedy with respect to the current value table ``v``."""
        for state in self.grid_mdp.states:
            if state in self.grid_mdp.terminal_states:
                continue
            self.pi[state] = self._greedy(state)[0]

    def policy_evaluate(self, max_iter=1000, tol=1e-6):
        """Evaluate the current policy ``pi`` by repeated sweeps over ``v``.

        Performs in-place Bellman expectation backups. It stops when the
        summed absolute change over one sweep drops below ``tol``, or after
        ``max_iter`` sweeps.

        Args:
            max_iter: upper bound on the number of sweeps.
            tol: convergence threshold on the per-sweep total change.
        """
        for _ in range(max_iter):
            delta = 0.0
            for state in self.grid_mdp.states:
                if state in self.grid_mdp.terminal_states:
                    continue
                new_v = self._q_value(state, self.pi[state])
                delta += abs(self.v[state] - new_v)
                self.v[state] = new_v
            if delta < tol:
                break

    def policy_iterate(self, max_iter=1000):
        """Alternate evaluation and greedy improvement until ``pi`` is stable.

        Args:
            max_iter: cap on evaluation/improvement rounds. It guards
                against flip-flopping between equally good actions caused by
                tolerance-level differences in ``v``.
        """
        for _ in range(max_iter):
            self.policy_evaluate()
            previous_pi = dict(self.pi)
            self.policy_improve()
            if self.pi == previous_pi:
                break

    def value_iteration(self, max_iter=100, tol=1e-6):
        """Run value iteration, updating ``v`` and the greedy ``pi`` in place.

        Stops once a full sweep changes ``v`` by less than ``tol`` in total,
        or after ``max_iter`` sweeps.

        Args:
            max_iter: upper bound on the number of sweeps.
            tol: convergence threshold on the per-sweep total change.
        """
        for _ in range(max_iter):
            delta = 0.0
            for state in self.grid_mdp.states:
                if state in self.grid_mdp.terminal_states:
                    continue
                action, value = self._greedy(state)
                delta += abs(value - self.v[state])
                self.pi[state] = action
                self.v[state] = value
            if delta < tol:
                break

In [53]:
# Grid-world model that policy iteration is run on below.
grid_mdp_a = Grid_MDP()

In [54]:
# Solver holding the value table `v` and policy `pi` for grid_mdp_a.
policy_value_a = policy_value(grid_mdp_a)

In [55]:
# Run policy iteration; results are stored in policy_value_a.v and policy_value_a.pi.
policy_value_a.policy_iterate()

In [56]:
# Print the value of every non-terminal state (the states that have a
# policy entry). The single-argument print(...) form prints exactly the
# same text under Python 2 and is also valid Python 3.
for i in sorted(policy_value_a.pi):
    print('%s %s' % (i, policy_value_a.v[i]))

1 0.64
2 0.8
3 1.0
4 0.8
5 0.64


In [57]:
# Print the chosen action for every non-terminal state. The single-argument
# print(...) form prints the same text under Python 2 and is valid Python 3.
for j in sorted(policy_value_a.pi):
    print('%s %s' % (j, policy_value_a.pi[j]))

1 e
2 e
3 s
4 w
5 w


![Result](../result.png)

### Value Iteration

In [61]:
# Fresh model and solver, so value iteration starts from an all-zero value
# table instead of reusing the policy-iteration results.
grid_mdp_b = Grid_MDP()
policy_value_b = policy_value(grid_mdp_b)

In [62]:
# Run value iteration; updates policy_value_b.v and policy_value_b.pi in place.
policy_value_b.value_iteration()

In [63]:
# Print the value of every non-terminal state (the states that have a
# policy entry). The single-argument print(...) form prints exactly the
# same text under Python 2 and is also valid Python 3.
for i in sorted(policy_value_b.pi):
    print('%s %s' % (i, policy_value_b.v[i]))

1 0.64
2 0.8
3 1.0
4 0.8
5 0.64


In [64]:
# Print the chosen action for every non-terminal state. The single-argument
# print(...) form prints the same text under Python 2 and is valid Python 3.
for j in sorted(policy_value_b.pi):
    print('%s %s' % (j, policy_value_b.pi[j]))

1 e
2 e
3 s
4 w
5 w


![Result](../result.png)