In [256]:
import numpy as np
import random

In [257]:
class RecylingRobot():
    '''
    Two-state recycling-robot MDP ('high' / 'low' battery, actions 'search',
    'wait', 'recharge') together with tabular dynamic-programming solvers:
    policy evaluation, policy improvement, policy iteration and value iteration.

    Arguments:
    ----------
    alpha (float) : probability that 'search' from 'high' stays in 'high'.
    beta (float) : probability that 'search' from 'low' stays in 'low'; otherwise
    the robot moves to 'high' with a reward of -3.
    r_search (float) : reward for a 'search' step (except the -3 case above).
    r_wait (float) : reward for a 'wait' step.
    '''
    def __init__(self,alpha,beta, r_search, r_wait):
        self.alpha = alpha
        self.beta = beta
        self.r_search = r_search
        self.r_wait = r_wait
        
        self.statesplus = ['high','low']    
        self.states =  ['high','low']    
        
        # initial state
        self.state = 'high'
                
        self.actions = {'search':self.search,'wait':self.wait,'recharge':self.recharge}
        self.action_labels = ['search','wait','recharge']
        
        # Policy: {state: {action: probability}}; assigned by set_policy
        self.pi = None

        
        # Dynamics: (state, action, next_state, reward) -> probability.
        # Only (state, action) pairs listed here are available actions.
        self.pr = {}
        self.pr[('high','search','high',self.r_search)]= self.alpha
        self.pr[('high','search','low',self.r_search)]= 1-self.alpha
        self.pr[('high','wait','high',self.r_wait)]= 1
        self.pr[('low','search','low',self.r_search)]= self.beta
        self.pr[('low','search','high',-3)]= 1-self.beta
        self.pr[('low','wait','low',self.r_wait)]= 1
        self.pr[('low','recharge','high',0)]= 1
        
        
    def search(self,state):
        '''
        Sample one 'search' transition from `state`.

        Returns:
        --------
        (next_state, reward) tuple.

        Raises:
        -------
        ValueError : if `state` is neither 'high' nor 'low'.
        '''
        if state == 'high':
            if random.random() < self.alpha:
                return 'high', self.r_search
            return 'low', self.r_search

        if state == 'low':
            if random.random() < self.beta:
                return 'low', self.r_search
            return 'high', -3

        # Previously fell through to an UnboundLocalError.
        raise ValueError(f"Invalid state: {state!r}")
    
    def wait(self,state):
        '''Stay in `state` and collect the wait reward. Returns (next_state, reward).'''
        next_state = state
        reward = self.r_wait
        
        return next_state, reward
    
    def recharge(self,state):
        '''
        Recharge the battery; only allowed from 'low'. Returns ('high', 0).

        Raises:
        -------
        ValueError : if `state` is not 'low'. A real exception is used instead of
        `assert`, which is stripped when Python runs with -O.
        '''
        if state != 'low':
            raise ValueError("Invalid Action")
        return 'high', 0
            
    def generate_random_experiment(self,l):
        '''
        Roll out `l` steps from the current state following `self.pi`.

        The robot's current state is updated as the rollout proceeds.

        Arguments:
        ----------
        l (int) : number of steps to simulate.

        Returns:
        --------
        str : trace of the form 'state-> action -> reward -> next_state ...'.

        Raises:
        -------
        AttributeError : if no policy has been set.
        '''
        if not self.pi:
            raise AttributeError(
                f'{type(self).__name__} has no policy yet, use set_policy method first')

        parts = [f'{self.state}']
        for _ in range(l):
            action = self.select_action(self.state)
            new_state, new_reward = self.actions[action](self.state)
            parts.append(f'-> {action} -> {new_reward} -> {new_state} ')
            self.state = new_state

        return ''.join(parts)
    
    def set_policy(self, gamma = 1, pi= None):
        '''
        Set the policy and discount factor, and reset all state values to 0.

        Arguments:
        ----------
        gamma (float) : discount factor.
        pi (dict | None) : {state: {action: probability}}.
        '''
        self.pi = pi
        self.gamma = gamma
        
        # initialization
        self.v = {state : 0 for state in self.statesplus}
        
    
    def select_action(self,state):
        '''Sample an action for `state` according to the probabilities in `self.pi[state]`.'''
        policy = self.pi[state]
        return random.choices(list(policy.keys()), weights=list(policy.values()))[0]

    def _available_actions(self, state):
        '''Actions (in `action_labels` order) with at least one transition out of `state`.'''
        return [action for action in self.action_labels
                if any(s == state and a == action for s, a, _, _ in self.pr)]

    def _q_value(self, state, action):
        '''Expected one-step return of `action` in `state`, bootstrapping from `self.v`.'''
        total = 0
        for (old_s, a, new_s, reward), prob in self.pr.items():
            if a == action and old_s == state:
                total += prob * (reward + self.gamma * self.v[new_s])
        return total
    
    def evaluate_policy(self,theta):
        '''
        Iterative policy evaluation of the current policy `self.pi`.

        Each sweep recomputes every state's value from `self.pr` and `self.pi`
        and writes it into `self.v` immediately, so later states in the same
        sweep already see the updated values.

        Arguments:
        ----------
        theta (float) : the tolerance for convergence. If the difference between the old
        and the new value is at most theta for every state, the algorithm stops
        (theta=0 requires a sweep in which no value changes at all).

        Returns:
        --------
        None.
        '''
        while True:
            delta = 0
            for state in self.states:
                new_v = 0
                for action in self.action_labels:
                    for (old_s, a, new_s, reward), prob in self.pr.items():
                        if action == a and state == old_s:
                            new_v += self.pi[state][action] * prob * (reward + self.gamma * self.v[new_s])

                old_v = self.v[state]
                self.v[state] = new_v
                delta = max(delta, abs(new_v - old_v))

            if delta <= theta:
                break
                
                
    def improve_policy(self):
        '''
        Replace `self.pi` with the policy that is greedy with respect to `self.v`.

        Only actions that have transitions out of a state are considered. Ties
        are broken by splitting probability uniformly among the tied actions.

        Returns:
        --------
        bool : True if the new policy equals the previous one (converged).
        '''
        new_policy = {state: {action: 0 for action in self.action_labels} for state in self.states}

        for state in self.states:
            # Restrict to available actions: scoring e.g. 'recharge' in 'high' as
            # 0 would make it "greedy" whenever the real actions are negative.
            q = {action: self._q_value(state, action) for action in self._available_actions(state)}
            if not q:
                continue  # no action leaves this state; keep an all-zero row

            best = max(q.values())
            greedy_actions = [action for action, value in q.items() if value == best]
            for greedy_action in greedy_actions:
                new_policy[state][greedy_action] = 1 / len(greedy_actions)

        converged = self.pi == new_policy
        self.pi = new_policy
        return converged
                    
    
    def policy_iteration(self, isshow = False , theta = 0):
        '''
        Alternate policy evaluation and greedy improvement until the policy is stable.

        Arguments:
        ----------
        isshow (bool) : print the value table and policy before each round.
        theta (float) : tolerance passed to `evaluate_policy`.
        '''
        converged = False
        while not converged:
            if isshow:
                self.show('values')
                self.show('policy')

            self.evaluate_policy(theta=theta)
            converged = self.improve_policy()
    
    
    def value_iteration(self,theta,isshow=False):
        '''
        Value iteration: repeatedly back up each state with the best available
        action, updating `self.v` and the greedy `self.pi` in place, until the
        largest change in a sweep is at most `theta`.

        Arguments:
        ----------
        theta (float) : convergence tolerance.
        isshow (bool) : print the value table and policy before each sweep.
        '''
        if self.pi is None:
            # set_policy() may have been called without a policy; start from an
            # empty (all-zero) table that the sweep below fills in.
            self.pi = {state: {action: 0 for action in self.action_labels} for state in self.states}

        while True:
            if isshow:
                self.show('values')
                self.show('policy')

            delta = 0
            for state in self.states:
                q = {action: self._q_value(state, action) for action in self._available_actions(state)}
                if not q:
                    continue  # nothing to back up from a state without actions

                old_v = self.v[state]
                new_v = max(q.values())
                self.v[state] = new_v
                delta = max(abs(new_v - old_v), delta)

                best_action = [a for a, value in q.items() if value == new_v]
                self.pi[state] = {a: 1 / len(best_action) if a in best_action else 0
                                  for a in self.action_labels}

            if delta <= theta:
                break
    
    def show(self, mode = 'states'):
        '''
        Print a two-column table (High | Low) of the state values or policy.

        Arguments:
        ----------
        mode (str) : 'values' for `self.v`, 'policy' for the actions with
        non-zero probability in `self.pi`.

        Raises:
        -------
        ValueError : for any other mode (including the default 'states').
        '''
        if mode == 'values':
            center = [f'{v:^20.4f}' for s,v in self.v.items()]
            
        elif mode == 'policy':        
            center = []
            for state,pi in self.pi.items():
                center.append(f"{'|'.join([act.upper() for act,pr in pi.items() if pr>0]):^20}")
            
        else:
            raise ValueError('mode should be either values|policy')
            
        print('┌────────────────────┬────────────────────┐')
        print('│        High        │         Low        │')
        print('├────────────────────┼────────────────────┤')
        print(f'│{center[0]}│{center[1]}│')
        print('└────────────────────┴────────────────────┘')
        
  

In [258]:
# Robot with alpha = beta = 0.25, search reward 5, wait reward 1,
# starting from a stochastic policy with discount factor 0.8.
gw = RecylingRobot(alpha=0.25, beta=0.25, r_search=5, r_wait=1)
gw.set_policy(
    gamma=0.8,
    pi={
        'high': {'search': 0.5, 'wait': 0.5, 'recharge': 0},
        'low': {'search': 0.5, 'wait': 0.25, 'recharge': 0.25},
    },
)

In [259]:
# Policy iteration with exact convergence; print values and policy every round.
gw.policy_iteration(theta=0, isshow=True)

┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│       0.0000       │       0.0000       │
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│    SEARCH|WAIT     │SEARCH|WAIT|RECHARGE│
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│      10.1250       │       6.8750       │
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│       SEARCH       │      RECHARGE      │
└────────────────────┴────────────────────┘


In [260]:
# Fresh robot and the same initial policy, this time solved by value iteration.
initial_pi = {
    'high': {'search': 0.5, 'wait': 0.5, 'recharge': 0},
    'low': {'search': 0.5, 'wait': 0.25, 'recharge': 0.25},
}
gw = RecylingRobot(alpha=0.25, beta=0.25, r_search=5, r_wait=1)
gw.set_policy(gamma=0.8, pi=initial_pi)

In [261]:
# Value iteration to a tolerance of 1e-3; print values and policy every sweep.
gw.value_iteration(isshow=True, theta=0.001)

┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│       0.0000       │       0.0000       │
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│    SEARCH|WAIT     │SEARCH|WAIT|RECHARGE│
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│       5.0000       │       4.0000       │
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼────────────────────┤
│       SEARCH       │      RECHARGE      │
└────────────────────┴────────────────────┘
┌────────────────────┬────────────────────┐
│        High        │         Low        │
├────────────────────┼──────────