In [1]:
import numpy as np
import random

In [32]:
class GridWorld():
    '''
    Deterministic rectangular grid world with tabular policy evaluation,
    policy iteration and value iteration.

    Cells are numbered row by row starting from 0 in the top-left corner, so
    the cell in row r and column c has the number r*cols + c.

    Attributes:
    -----------
    statesplus (list) : every cell number, terminal cells included.
    states (list) : the non-terminal cell numbers, in ascending order.
    state (int) : the current cell used by generate_random_experiment.
    pr (dict) : the dynamics, mapping (state, action, new_state, reward) to a
    probability. Every entry built here has probability 1.
    pi (dict or None) : the policy, mapping each non-terminal state to a
    dict {action label: probability}. It is None until set_policy is called.
    v (list) : state values indexed by cell number (created by set_policy).
    gamma : the discount factor (created by set_policy).
    '''

    def __init__(self, rows, cols, default_reward=-1, offgrid_reward=-10, win_states=None, lose_states=None, magic_states=None):
        '''
        Builds the grid and its transition table.

        Arguments:
        ----------
        rows (int) : number of rows of the grid.
        cols (int) : number of columns of the grid.
        default_reward : reward of an ordinary move.
        offgrid_reward : reward of a move that leaves the state unchanged.
        win_states (dict) : {state: reward} for terminal winning states.
        lose_states (dict) : {state: reward} for terminal losing states.
        magic_states (dict) : {state: (new_state, reward)}. Every action taken
        in such a state leads to new_state.
        '''
        self.cols = cols
        self.rows = rows
        self.default_reward = default_reward
        self.offgrid_reward = offgrid_reward

        # None replaces the mutable {} defaults; the mappings are copied so
        # that instances never share (or alias) the caller's dictionaries.
        self.magic_states = dict(magic_states) if magic_states else {}
        self.win_states = dict(win_states) if win_states else {}
        self.lose_states = dict(lose_states) if lose_states else {}

        self.statesplus = list(range(rows * cols))
        # Built in explicit ascending order: the in-place sweeps below and the
        # grid rendering should not depend on set iteration order.
        terminal_states = set(self.win_states) | set(self.lose_states)
        self.states = [s for s in self.statesplus if s not in terminal_states]
        self.state = 0
        self.actions = {'U': self.up, 'D': self.down, 'L': self.left, 'R': self.right}
        self.action_labels = ['U', 'D', 'L', 'R']

        # Policy
        self.pi = None

        # Setting the Dynamics
        self.pr = {}
        for s in self.states:
            for a in self.action_labels:
                new_s, reward = self._transition(s, a)
                self.pr[(s, a, new_s, reward)] = 1

    def _transition(self, state, action):
        '''
        Returns the (new_state, reward) pair produced by taking action in state.
        '''
        new_state = self.actions[action](state)
        reward = self.default_reward

        # a magic state overrides the move whatever the action is
        if state in self.magic_states:
            new_state, reward = self.magic_states[state]

        # reward for off-grid actions
        if new_state == state:
            reward = self.offgrid_reward

        # terminal rewards take precedence; lose is checked last
        if new_state in self.win_states:
            reward = self.win_states[new_state]

        if new_state in self.lose_states:
            reward = self.lose_states[new_state]

        return new_state, reward

    def left(self, state):
        '''
        Returns the state to the left of state, or state itself when it is in
        the first column.
        '''
        assert state in self.states, f'There is no state {state}'
        if state % self.cols != 0:
            return state - 1
        return state  # off-grid move

    def right(self, state):
        '''
        Returns the state to the right of state, or state itself when it is in
        the last column.
        '''
        assert state in self.states, f'There is no state {state}'
        if state % self.cols != (self.cols - 1):
            return state + 1
        return state  # off-grid move

    def up(self, state):
        '''
        Returns the state above state, or state itself when it is in the
        first row.
        '''
        assert state in self.states, f'There is no state {state}'
        if state >= self.cols:
            return state - self.cols
        return state  # off-grid move

    def down(self, state):
        '''
        Returns the state below state, or state itself when it is in the
        last row.
        '''
        assert state in self.states, f'There is no state {state}'
        if state < self.cols * (self.rows - 1):
            return state + self.cols
        return state  # off-grid move

    def generate_random_experiment(self, l, show='list+grid'):
        '''
        Walks at most l steps from the current state, sampling each action
        from the current policy, and updates self.state along the way.

        The walk follows the same dynamics as the transition table (magic
        states included) and stops early once it reaches a state that has no
        policy entry, i.e. a terminal state.

        Arguments:
        ----------
        l (int) : maximum number of steps.
        show (str) : if it contains 'grid' the grid is printed before the walk
        and after every step; if it contains 'list' the visited
        state/action sequence is printed at the end.

        Returns:
        --------
        None.
        '''
        assert self.pi, 'GridWorld has no policy yet, use set_policy method first'
        state_action_list = f'{self.state }'
        if 'grid' in show:
            self.show_grid(mode='cs')
        for _ in range(l):
            action_probs = self.pi.get(self.state)
            if action_probs is None:
                # terminal state: there is nothing left to sample
                break
            action = random.choices(list(action_probs.keys()), weights=list(action_probs.values()))[0]
            self.last_action = action
            new_state, _ = self._transition(self.state, action)
            state_action_list += f'-> {action} -> {new_state} '
            self.state = new_state
            if 'grid' in show:
                self.show_grid(mode='cs')
        if 'list' in show:
            print(state_action_list)

    def set_policy(self, gamma=1, pi=None):
        '''
        Sets the policy and the discount factor, and resets all values to 0.

        Arguments:
        ----------
        gamma : the discount factor.
        pi (dict) : {action label: probability} used for every non-terminal
        state. When it is missing or empty, all actions are equally likely.
        '''
        if pi:
            # one copy per state, so changing one state's entry later cannot
            # silently change the others
            self.pi = {state: dict(pi) for state in self.states}
        else:
            uniform = 1 / len(self.action_labels)
            self.pi = {state: {a: uniform for a in self.action_labels} for state in self.states}
        self.gamma = gamma

        # initialization
        self.v = [0 for _ in self.statesplus]

    def select_action(self, state):
        '''
        Placeholder: not implemented, does nothing and returns None.
        '''
        pass

    def _require_policy(self):
        '''
        Raises AttributeError when set_policy has not been called yet.
        '''
        if self.pi is None:
            raise AttributeError('GridWorld has no policy yet, use set_policy method first')

    def _group_transitions(self):
        '''
        Groups self.pr by (state, action) so that the sweeps below do not have
        to scan the whole table for every state/action pair.

        Returns:
        --------
        dict : {(state, action): [(probability, reward, new_state), ...]}
        '''
        grouped = {}
        for (old_s, a, new_s, reward), prob in self.pr.items():
            grouped.setdefault((old_s, a), []).append((prob, reward, new_s))
        return grouped

    def _action_value(self, outcomes, state, action):
        '''
        Returns the expected one-step return of taking action in state under
        the current values, using the table built by _group_transitions.
        '''
        return sum(prob * (reward + self.gamma * self.v[new_s])
                   for prob, reward, new_s in outcomes.get((state, action), ()))

    def evaluate_policy(self, theta):
        '''
        This function gets a theta (tolerance) and runs policy evaluation algorithm
        until it converges. Values are updated in place during each sweep.

        Arguments:
        ----------
        theta : the tolerance for convergence. if the difference between old value
        and new value for all states is less than or equal to theta, then the
        algorithm stops.

        Returns:
        --------
        None.
        '''
        self._require_policy()
        outcomes = self._group_transitions()

        while True:
            largest_change = 0
            # Calculate new value for each state
            for state in self.states:
                new_v = 0
                for action in self.action_labels:
                    action_prob = self.pi[state].get(action, 0)
                    for prob, reward, new_s in outcomes.get((state, action), ()):
                        new_v += action_prob * prob * (reward + self.gamma * self.v[new_s])

                largest_change = max(largest_change, abs(new_v - self.v[state]))
                self.v[state] = new_v

            if largest_change <= theta:
                break

    def improve_policy(self):
        '''
        Replaces the policy with the one that is greedy with respect to the
        current values. Actions that tie for the best value share the
        probability equally.

        Returns:
        --------
        bool : True when the greedy policy equals the previous policy.
        '''
        self._require_policy()
        outcomes = self._group_transitions()

        # create new policy
        new_policy = {}
        for state in self.states:
            q = {action: self._action_value(outcomes, state, action) for action in self.action_labels}
            best_value = max(q.values())
            greedy_actions = [action for action, value in q.items() if value == best_value]
            share = 1 / len(greedy_actions)
            new_policy[state] = {action: (share if action in greedy_actions else 0)
                                 for action in self.action_labels}

        converged = self.pi == new_policy
        self.pi = new_policy
        return converged

    def policy_iteration(self, isshow=False, theta=0):
        '''
        Alternates policy evaluation and greedy policy improvement until the
        policy stops changing.

        Arguments:
        ----------
        isshow (bool) : when true, the values and the policy are printed at
        the start of every iteration.
        theta : the tolerance passed to evaluate_policy.

        Returns:
        --------
        None.
        '''
        converged = False
        while not converged:
            if isshow:
                self.show_grid('values')
                self.show_grid('policy')

            self.evaluate_policy(theta=theta)
            converged = self.improve_policy()

    def value_iteration(self, theta, isshow=False):
        '''
        Runs value iteration: every sweep sets each state's value to its best
        action value and its policy to the best actions, until the largest
        value change of a sweep is less than or equal to theta.

        Arguments:
        ----------
        theta : the tolerance for convergence.
        isshow (bool) : when true, the values and the policy are printed at
        the start of every sweep.

        Returns:
        --------
        None.
        '''
        self._require_policy()
        outcomes = self._group_transitions()

        while True:
            if isshow:
                self.show_grid('values')
                self.show_grid('policy')

            delta = 0
            for state in self.states:
                old_v = self.v[state]
                actions_v = [self._action_value(outcomes, state, action) for action in self.action_labels]

                new_v = max(actions_v)
                self.v[state] = new_v
                delta = max(abs(new_v - old_v), delta)

                best_actions = [a for v, a in zip(actions_v, self.action_labels) if v == new_v]
                share = 1 / len(best_actions)
                self.pi[state] = {a: (share if a in best_actions else 0) for a in self.action_labels}

            if delta <= theta:
                break

    def show_grid(self, mode='states'):
        '''
        Prints the grid as a table of boxes.

        Arguments:
        ----------
        mode (str) : what is drawn inside the cells:
        'states' : the number of every cell.
        'cs' : a dot in the cell of the current state.
        'values' : the current value of every cell.
        'policy' : arrows for the actions with non-zero probability;
        cells without a policy entry (terminal states) stay empty.

        Returns:
        --------
        None.
        '''
        # Every list below has one entry per cell (terminal cells included),
        # because the printing loop indexes them by cell number.
        cells = self.statesplus
        up = [6 * ' ' for _ in cells]
        down = [6 * ' ' for _ in cells]
        left = ['   ' for _ in cells]
        right = ['   ' for _ in cells]

        if mode == 'states':
            center = [f'{s:^5d}' for s in cells]
            left = [' ' for _ in cells]
            right = [' ' for _ in cells]

        elif mode == 'cs':
            center = ['●' if s == self.state else ' ' for s in cells]

        elif mode == 'values':
            self._require_policy()
            center = [f'{v:^5.1f}' for v in self.v]
            left = [' ' for _ in cells]
            right = [' ' for _ in cells]

        elif mode == 'policy':
            self._require_policy()
            no_policy = {}
            policies = [self.pi.get(s, no_policy) for s in cells]
            center = [' ' for _ in cells]
            up = ['↑' if p.get('U', 0) > 0 else ' ' for p in policies]
            down = ['↓' if p.get('D', 0) > 0 else ' ' for p in policies]
            right = ['  →' if p.get('R', 0) > 0 else '   ' for p in policies]
            left = ['←  ' if p.get('L', 0) > 0 else '   ' for p in policies]

        else:
            raise ValueError('mode should be either states|cs|values|policy')

        print('┌' + '┬'.join([9 * '─' for i in range(self.cols)]) + '┐')
        for r in range(self.rows):
            row_cells = range(self.cols * r, self.cols * (r + 1))

            # Print Up position
            print('│ ' + ' │ '.join([f'{up[s]:^7}' for s in row_cells]) + ' │')

            # print Left, Center , Right position
            print('│ ' + ' │ '.join([f'{left[s]}{center[s]}{right[s]}' for s in row_cells]) + ' │')

            # print Down position
            print('│ ' + ' │ '.join([f'{down[s]:^7}' for s in row_cells]) + ' │')
            if r < self.rows - 1:
                print('├' + '┼'.join([9 * '─' for i in range(self.cols)]) + '┤')
            else:
                print('└' + '┴'.join([9 * '─' for i in range(self.cols)]) + '┘')
        

In [33]:
# 5x5 grid: ordinary moves cost nothing, bumping into the border costs 1,
# and the two magic states 1 and 3 jump to 21 (reward 10) and 13 (reward 5).
gw = GridWorld(
    rows=5,
    cols=5,
    default_reward=0,
    offgrid_reward=-1,
    magic_states={1: (21, 10), 3: (13, 5)},
)
# No explicit policy is given, so every action starts out equally likely.
gw.set_policy(gamma=0.9)

In [34]:
# Policy iteration, printing the values and the policy at the start of every
# iteration. theta=0 makes each evaluation run until no value changes at all.
gw.policy_iteration(theta=0, isshow=True)

┌─────────┬─────────┬─────────┬─────────┬─────────┐
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │ 

In [35]:
# Rebuild the same 5x5 grid so the next run starts from zero values and a
# uniform policy.
gw = GridWorld(
    rows=5,
    cols=5,
    default_reward=0,
    offgrid_reward=-1,
    magic_states={1: (21, 10), 3: (13, 5)},
)
gw.set_policy(gamma=0.9)

In [36]:
# Value iteration, printing the values and the policy at the start of every
# sweep; it stops once the largest value change of a sweep is at most 0.001.
gw.value_iteration(isshow=True, theta=0.001)

┌─────────┬─────────┬─────────┬─────────┬─────────┐
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │         │         │         │         │
├─────────┼─────────┼─────────┼─────────┼─────────┤
│         │         │         │         │         │
│   0.0   │   0.0   │   0.0   │   0.0   │   0.0   │
│         │ 