# Jack's Car Rental

In [1]:
from functools import lru_cache

import numpy as np
from scipy.stats import poisson
from tqdm import tqdm_notebook

In [2]:
# Problem parameters for Jack's Car Rental.
max_capacity = 20  # most cars one location can hold; states run 0..max_capacity

# Poisson means for daily rental requests and returns at (location 1, location 2).
rent_1_lambda, rent_2_lambda = 3, 4
return_1_lambda, return_2_lambda = 3, 2

max_move = 5       # most cars that may be moved between locations overnight
move_cost = -2     # reward per car moved (negative, i.e. a charge)
rent_credit = 10   # reward per car rented
discount = 0.9     # discount factor applied in the Bellman backup

def possible_rewards(s, a):
    """Yield every reward that can follow taking action ``a`` in state ``s``.

    The reward is ``rent_credit`` per car rented (from 0 up to every car held
    at both locations) plus the moving charge ``abs(a) * move_cost``.
    ``abs(a)`` is used because moving cars costs the same in either
    direction. ``p`` decodes a reward back into a rental count with the same
    ``abs(a) * move_cost`` charge, so the two must agree.
    """
    move_charge = abs(a) * move_cost
    for n_rented in range(s[0] + s[1] + 1):
        yield n_rented * rent_credit + move_charge

def possible_states():
    """Yield each state ``[cars_at_1, cars_at_2]`` with 0..max_capacity cars per location.

    States come in row-major order (location 1 outer, location 2 inner), each
    as a fresh two-element list.
    """
    side = max_capacity + 1
    for idx in range(side * side):
        yield list(divmod(idx, side))

def possible_actions(s):
    """Yield the feasible net moves ``a`` for state ``s`` in increasing order.

    Positive ``a`` moves cars from location 1 to location 2, and negative
    ``a`` moves them back. Each direction is bounded by ``max_move`` and by
    the cars available at the source location.
    """
    most_to_loc1 = min(max_move, s[1])
    most_to_loc2 = min(max_move, s[0])
    yield from range(-most_to_loc1, most_to_loc2 + 1)
           

In [3]:
# environment dynamic function
@lru_cache(maxsize=None)
def _rent_pmf(n_cars, lam):
    """Return P(rent k cars) for k = 0..n_cars under Poisson(lam) demand.

    Demand above ``n_cars`` can only be partly served, so its tail mass is
    folded into the last entry (renting every car) and the array sums to 1.
    The cached array is marked read-only so callers cannot corrupt the cache.
    """
    pmf = poisson.pmf(np.arange(n_cars + 1), lam)
    pmf[-1] += 1 - pmf.sum()
    pmf.flags.writeable = False
    return pmf


@lru_cache(maxsize=None)
def _return_prob(n_returned, lam, at_capacity):
    """Return the probability that Poisson(lam) returns add ``n_returned`` cars.

    When the destination count is ``max_capacity``, any number of returns
    ``>= n_returned`` ends there (the surplus is lost). In that case the upper
    tail ``P(X >= n_returned)`` is returned instead of the point mass.
    """
    if at_capacity:
        return poisson.sf(n_returned - 1, lam)
    return poisson.pmf(n_returned, lam)


# environment dynamic function
def p(s_new, r, s, a) -> float:
    """Return the environment dynamics probability p(s_new, r | s, a).

    Overnight, ``a`` cars move from location 1 to location 2 (negative ``a``
    moves them the other way) at a charge of ``abs(a) * move_cost``. A
    location keeps at most ``max_capacity`` cars after the move. Rentals
    follow Poisson(rent_*_lambda), and demand beyond the stock rents every
    car. Returns follow Poisson(return_*_lambda), and returns that would push
    a location past ``max_capacity`` are lost.

    Args:
        s_new: successor state ``[cars_at_1, cars_at_2]``.
        r: reward, decoded as ``rent_credit`` per car rented plus the move charge.
        s: current state ``[cars_at_1, cars_at_2]``.
        a: net number of cars moved from location 1 to location 2.

    Returns:
        The probability of observing ``s_new`` and ``r`` after taking ``a`` in ``s``.

    Raises:
        ValueError: if ``a`` moves more cars out of a location than it holds.
    """
    # Move cars overnight. Surplus above max_capacity is removed; without this
    # clamp, mass for counts above max_capacity would be unreachable and the
    # probabilities would not sum to 1 over the states in possible_states().
    car1 = min(s[0] - a, max_capacity)
    car2 = min(s[1] + a, max_capacity)
    if car1 < 0 or car2 < 0:
        raise ValueError(f"action {a} is infeasible in state {s}")
    cost = abs(a) * move_cost
    # Decode the reward into the total number of cars rented at both locations.
    num_rent = int((r - cost) / rent_credit)

    # Truncated rental distributions (cached per car count and rate).
    rent_1 = _rent_pmf(car1, rent_1_lambda)
    rent_2 = _rent_pmf(car2, rent_2_lambda)

    prob = 0.0
    # Iterate over every split of num_rent rentals between the two locations.
    for rnt1 in range(car1 + 1):
        rnt2 = num_rent - rnt1
        if not 0 <= rnt2 <= car2:
            continue
        # Returns needed to go from the post-rental counts to s_new.
        retn1 = s_new[0] - (car1 - rnt1)
        retn2 = s_new[1] - (car2 - rnt2)
        if retn1 < 0 or retn2 < 0:
            continue
        return_1 = _return_prob(retn1, return_1_lambda, bool(s_new[0] == max_capacity))
        return_2 = _return_prob(retn2, return_2_lambda, bool(s_new[1] == max_capacity))
        prob += (rent_1[rnt1] * rent_2[rnt2]) * (return_1 * return_2)

    return prob


# Spot-check p: from state [3, 2], move one car to location 2 and ask for the
# probability of ending in [1, 1] with reward 48.
s_new, r, s, a = [1, 1], 48, [3, 2], 1

p(s_new, r, s, a)

0.02466760846032992

In [4]:
# Sanity check: for a fixed (s, a), the probabilities returned by p over every
# successor state and reward should sum to 1. Run the check below.

# change settings here
s = [3, 2]
a = 0
##########################
total = 0.0
for i in range(max_capacity + 1):
    for j in range(max_capacity + 1):
        for r in possible_rewards(s, a):
            try:
                total += p([i, j], r, s, a)
            except Exception as err:
                # Report the failing (s_new, r) combination and its cause
                # instead of hiding it, then keep scanning the rest.
                print(i, j, r, err)
total

0.9999999999999969

In [84]:
def bellman(s, a):
    """Return the one-step lookahead value of taking action ``a`` in state ``s``.

    Sums ``p(s', r | s, a) * (r + discount * V[s'])`` over every successor
    state and every candidate reward. It reads the module-level value table
    ``V``.
    """
    return sum(
        p(s_next, reward, s, a) * (reward + discount * V[s_next[0], s_next[1]])
        for s_next in possible_states()
        for reward in possible_rewards(s, a)
    )

def evaluation(V, pi, theta=5, max_iteration=2):
    """Run (truncated) iterative policy evaluation of ``pi``, updating ``V`` in place.

    Each sweep visits every state, computes ``bellman(s, pi[s])`` and writes
    it into ``V`` immediately. Sweeping stops after ``max_iteration`` sweeps,
    or earlier once the largest change in a sweep drops below ``theta``.

    NOTE(review): ``bellman`` reads the module-level ``V`` rather than this
    ``V`` argument, so results are only consistent when called with that same
    array. Confirm before passing a different table.
    """
    n_states = (max_capacity + 1) ** 2
    sweep = 0
    while sweep < max_iteration:
        sweep += 1
        largest_change = 0.0
        for s in tqdm_notebook(possible_states(), "eval it %d" % sweep, total=n_states):
            row, col = s
            new_value = bellman(s, pi[row, col])
            largest_change = max(largest_change, abs(new_value - V[row, col]))
            V[row, col] = new_value
        if largest_change < theta:
            return

def improvement(V, pi):
    """Greedily improve the policy ``pi`` in place with respect to the current values.

    Every feasible action in each state is scored with ``bellman``, and the
    best one is stored in ``pi``. The current action is kept unless another
    action is strictly better. Ties between equally good actions therefore
    never count as a change, which would otherwise keep ``policy_stable``
    False.

    Args:
        V: state-value table. Not read here directly; ``bellman`` reads the
            module-level ``V``.
        pi: integer array of actions indexed by ``[cars_at_1, cars_at_2]``,
            updated in place.

    Returns:
        True if no state's action changed (the policy is stable).
    """
    policy_stable = True
    for s in tqdm_notebook(possible_states(), "imp ", total=(max_capacity + 1) ** 2):
        old_action = pi[s[0], s[1]]
        actions = list(possible_actions(s))
        if old_action in actions:
            # Seed with the current action so only a strict improvement replaces it.
            best_action = old_action
            best_v = bellman(s, old_action)
        else:
            best_action = actions[0]
            best_v = -np.inf
        for a in actions:
            if a == old_action:
                continue
            v = bellman(s, a)
            if v > best_v:
                best_v = v
                best_action = a

        if best_action != old_action:
            policy_stable = False
        pi[s[0], s[1]] = best_action

    return policy_stable

In [None]:
V = np.zeros([max_capacity + 1] * 2)
pi = np.zeros([max_capacity + 1] * 2, int)

# Policy iteration: alternate (truncated) policy evaluation with greedy
# improvement. The loop runs at least max_iteration rounds, then continues
# until improvement() reports that no action changed.
max_iteration = 3
it = 0
policy_stable = False
while it < max_iteration or not policy_stable:
    it += 1
    evaluation(V, pi)
    # Record whether the policy changed; this is what lets the loop stop.
    policy_stable = improvement(V, pi)


HBox(children=(IntProgress(value=0, description='eval it 1', max=441, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='eval it 2', max=441, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='imp ', max=441, style=ProgressStyle(description_width='initia…

In [71]:
# Display the state-value table computed by the policy-iteration cell.
V

array([[  0.        ,  18.65200286,  35.91201717,  50.38805436,
         61.15212207,  68.20422031,  72.28674295,  74.38954854,
         75.36108724,  75.7669925 ,  75.92150511,  75.97546067,
         75.99285002,  75.99805064,  75.99950088,  75.99987958,
         75.99997254,  75.99999407,  75.99999878,  75.99999976,
         75.99999996],
       [ 18.0540457 ,  36.70604856,  53.96606287,  68.44210006,
         79.20616778,  86.25826601,  90.34078865,  92.44359424,
         93.41513294,  93.8210382 ,  93.97555081,  94.02950637,
         94.04689572,  94.05209634,  94.05354658,  94.05392528,
         94.05401824,  94.05403977,  94.05404448,  94.05404546,
         94.05404566],
       [ 33.27022851,  51.92223137,  69.18224567,  83.65828287,
         94.42235058, 101.47444881, 105.55697146, 107.65977705,
        108.63131575, 109.037221  , 109.19173362, 109.24568917,
        109.26307853, 109.26827915, 109.26972938, 109.27010809,
        109.27020105, 109.27022257, 109.27022729, 109.2702