# Classic Taxi Problem (MDP)

In [1]:
import numpy as np

In [2]:
class TaxiEnv:
    """Finite MDP description: states, per-state actions and per-state tables.

    Args:
        states: Sequence of state labels; stored as-is in ``possible_states``.
        actions: One sequence of action labels per state, in ``states`` order.
        probabilities: One table per state (ndarray or nested list) with one
            row per action of that state and one column per state. Every row
            must sum to 1.
        rewards: One table per state with the same shape as the matching
            probability table.

    Raises:
        AssertionError: If the per-state inputs do not have one entry per
            state, a table has the wrong shape, or a probability row does
            not sum to 1.
    """

    def __init__(self, states, actions, probabilities, rewards):
        # Materialise the per-state inputs so their lengths can be compared;
        # zip() on its own would silently drop entries on a mismatch.
        actions = list(actions)
        # np.asarray leaves ndarrays untouched and lets nested lists through.
        probabilities = [np.asarray(table) for table in probabilities]
        rewards = [np.asarray(table) for table in rewards]

        n_states = len(states)
        for label, per_state in (("actions", actions),
                                 ("probabilities", probabilities),
                                 ("rewards", rewards)):
            if len(per_state) != n_states:
                # Raised explicitly (not via `assert`) so the check survives
                # `python -O` while keeping the AssertionError type.
                raise AssertionError(
                    "Expected %d %s entries (one per state), got %d"
                    % (n_states, label, len(per_state))
                )

        self.possible_states = states
        self._possible_actions = dict(zip(states, actions))
        self._ride_probabilities = dict(zip(states, probabilities))
        self._ride_rewards = dict(zip(states, rewards))
        self._verify()

    def _check_state(self, state):
        """Raise AssertionError unless ``state`` is one of ``possible_states``."""
        if state not in self.possible_states:
            # (state,) keeps the message correct when the label is a tuple.
            raise AssertionError("State %s is not a valid state" % (state,))

    def _verify(self):
        """Check table shapes and that each probability row sums to 1."""
        n_states = len(self.possible_states)
        for state in self.possible_states:
            expected_shape = (len(self._possible_actions[state]), n_states)

            probs = self._ride_probabilities[state]
            if probs.shape != expected_shape:
                raise AssertionError("invalid Probabilities shape")

            rewards = self._ride_rewards[state]
            if rewards.shape != expected_shape:
                raise AssertionError("invalid Rewards shape")

            if not np.allclose(probs.sum(axis=1), 1):
                raise AssertionError("Probabilities doesn't add up to 1")

    def possible_actions(self, state):
        """Return the action labels available in ``state``."""
        self._check_state(state)
        return self._possible_actions[state]

    def ride_probabilities(self, state, action):
        """Return the probability row for taking ``action`` in ``state``.

        Raises ValueError (from ``index``) if ``action`` is not available.
        """
        row = self.possible_actions(state).index(action)
        return self._ride_probabilities[state][row]

    def ride_rewards(self, state, action):
        """Return the reward row for taking ``action`` in ``state``.

        Raises ValueError (from ``index``) if ``action`` is not available.
        """
        row = self.possible_actions(state).index(action)
        return self._ride_rewards[state][row]

In [3]:
def make_taxienv():
    """Build the three-state (A, B, C) taxi environment from its fixed tables."""
    # state -> (actions, probability rows, reward rows); in each table there
    # is one row per action and one column per state, in the order A, B, C.
    spec = {
        'A': (
            ['1', '2', '3'],
            [[1/2,  1/4,  1/4],
             [1/16, 3/4,  3/16],
             [1/4,  1/8,  5/8]],
            [[10, 4, 8],
             [8,  2, 4],
             [4,  6, 4]],
        ),
        'B': (
            ['1', '2'],
            [[1/2,  0,   1/2],
             [1/16, 7/8, 1/16]],
            [[14, 0,  18],
             [8,  16, 8]],
        ),
        'C': (
            ['1', '2', '3'],
            [[1/4, 1/4,  1/2],
             [1/8, 3/4,  1/8],
             [3/4, 1/16, 3/16]],
            [[10, 2, 8],
             [6,  4, 2],
             [4,  0, 8]],
        ),
    }

    state_names = list(spec)
    action_lists = [entry[0] for entry in spec.values()]
    prob_tables = [np.array(entry[1]) for entry in spec.values()]
    reward_tables = [np.array(entry[2]) for entry in spec.values()]
    return TaxiEnv(state_names, action_lists, prob_tables, reward_tables)
# Build the environment instance that the solver cell below is run on.
env1=make_taxienv()

# DP Algorithm implementation


In [4]:
def dp_solve(taxienv, horizon=10):
    """Solve the finite-horizon problem by backward induction.

    Starting from a zero terminal value, each stage picks, for every state,
    the action maximising the expected immediate reward plus the value of
    the following stage.

    Args:
        taxienv: Environment exposing ``possible_states``,
            ``possible_actions(state)``, ``ride_probabilities(state, action)``
            and ``ride_rewards(state, action)``.
        horizon: Number of decision stages (defaults to 10, the value that
            was previously hard-coded).

    Returns:
        dict with two lists of ``{state: ...}`` dicts, one entry per stage,
        ordered from the last stage back to the first:
        ``"Expected Reward"`` holds the optimal values and ``"Polcies"``
        (spelling kept because callers index by this key) the chosen actions.

    Raises:
        ValueError: If ``horizon`` is negative.
    """
    if horizon < 0:
        raise ValueError("horizon must be non-negative, got %r" % (horizon,))

    states = taxienv.possible_states
    n_states = len(states)
    # value_table[k][j]: best expected reward from stage k onward when in
    # state j; the extra last row is the all-zero terminal value.
    value_table = np.zeros((horizon + 1, n_states))

    all_values = []
    all_policies = []
    for stage in range(horizon - 1, -1, -1):
        next_values = value_table[stage + 1]
        values = {}
        policy = {}
        for s_idx, state in enumerate(states):
            best_value, best_action = -float('inf'), ''
            for action in taxienv.possible_actions(state):
                probs = taxienv.ride_probabilities(state, action)
                rewards = taxienv.ride_rewards(state, action)
                # Accumulate term by term (rather than a dot product) so the
                # floating-point results match the original implementation.
                expected = 0
                for j in range(n_states):
                    expected += probs[j] * (rewards[j] + next_values[j])
                # `>=` means that among tied actions the last one listed wins.
                if expected >= best_value:
                    best_value = expected
                    best_action = action
            value_table[stage][s_idx] = best_value
            values[state] = best_value
            policy[state] = best_action
        all_values.append(values)
        all_policies.append(policy)

    return {"Expected Reward": all_values, "Polcies": all_policies}

In [5]:
# Run the DP solver on env1; the returned dict is printed in the next cell.
results=dp_solve(env1)

In [6]:
# Print the per-stage results. Entries are stored last stage first, so the
# labels count down from the number of stages to 1. The count is taken from
# the results themselves instead of a hard-coded 10.
print('Expected Reward:')
stage_values = results['Expected Reward']
for label, values in zip(range(len(stage_values), 0, -1), stage_values):
    print(label, values)

print('\nPolicies:')
stage_policies = results['Polcies']
for label, policy in zip(range(len(stage_policies), 0, -1), stage_policies):
    print(label, policy)

Expected Reward:
10 {'A': 8.0, 'B': 16.0, 'C': 7.0}
9 {'A': 17.75, 'B': 29.9375, 'C': 17.875}
8 {'A': 29.6640625, 'B': 43.421875, 'C': 30.90625}
7 {'A': 42.96533203125, 'B': 56.77978515625, 'C': 44.1376953125}
6 {'A': 56.295989990234375, 'B': 70.12625122070312, 'C': 57.47271728515625}
5 {'A': 69.63932228088379, 'B': 83.47101402282715, 'C': 70.81577682495117}
4 {'A': 82.98367631435394, 'B': 96.81558096408844, 'C': 84.16014790534973}
3 {'A': 96.32819322496653, 'B': 110.16012235730886, 'C': 97.50466375052929}
2 {'A': 109.6727282977663, 'B': 123.50466062361374, 'C': 110.84919888991863}
1 {'A': 123.01726577818044, 'B': 136.84919849489233, 'C': 124.19373636617092}

Policies:
10 {'A': '1', 'B': '1', 'C': '1'}
9 {'A': '1', 'B': '2', 'C': '2'}
8 {'A': '2', 'B': '2', 'C': '2'}
7 {'A': '2', 'B': '2', 'C': '2'}
6 {'A': '2', 'B': '2', 'C': '2'}
5 {'A': '2', 'B': '2', 'C': '2'}
4 {'A': '2', 'B': '2', 'C': '2'}
3 {'A': '2', 'B': '2', 'C': '2'}
2 {'A': '2', 'B': '2', 'C': '2'}
1 {'A': '2', 'B': '2', '