* how do you deal with illegal actions? 
    * exclude them from the policy, then rescale the remaining probabilities so they sum to 1
* how does a policy that assigns equal chance to every action work with illegal actions? 
    * see rescaling 
* how does a random reward affect convergence?
    * In the run below, sampling a random reward meant the values bounced around rather than converged. 
    * If we use the expected value of the random variable instead of samples drawn from it, the values can converge. 

In [1]:
import numpy as np
import copy
from IPython.core.debugger import set_trace

In [2]:
# Upper bound on the number of cars held at any single location.
max_cars = 5
# Discount rate applied to the value of the next state.
gamma = 0.9

# The numbers of cars requested and returned are Poisson random variables.
# These lists hold their rate parameters, one entry per location.
lambda_req = [3, 4]
lambda_rtn = [3, 2]

In [3]:
## Initialisation
# Value function: one entry per state (cars at loc1, cars at loc2).
v = np.zeros((max_cars + 1, max_cars + 1))
# Actions: each pair (i, -i) is added to the state, for i in -5..5.
actions = [(i, -i) for i in range(-5, 6)]
# Policy: the chance of each action in each state, indexed [action, x, y].
# Initialised as uniform over all actions.
p = np.full((len(actions), max_cars + 1, max_cars + 1), 1 / len(actions))

In [4]:
## END OF DAY  
# for each location, count how many cars I have 
    # car_count = current count + all the ones returned today 
    # take min(20, car_count)
# decide how many cars to move 
    
## REWARD STEP (NEW DAY)
# get reward of moving cars (-2 per car)
# cars_borrowed = min(cars_req, car_count)
# get reward for people borrowing car
    # +10 * cars_borrowed
# car_count = car_count - cars_borrowed

## s is state at end of day
# each day 
    # calc_cars_req 
    # car_count = car_count + cars_rtn_prev_day 
    # cars_borrowed = min(cars_req, car_count)
    # get reward for people borrowing car
        # +10 * cars_borrowed
    # car_count = car_count - cars_borrowed
    # calc cars_rtn_prev_day 
    
# end of day 
# for x
    # for y 
        # for each action 
            # calc value 
                # -2 per car moved
                # update car 
                # simulate next day(car_count)


In [5]:
def simulate_day(car_count, cars_rtn_prev_day, cars_req, max_cars):
    """Return the rental reward earned over one day.

    car_count, cars_rtn_prev_day and cars_req are per-location sequences
    indexed by the same location number; max_cars is the cap applied to
    each location's stock.

    The previous day's returns are first added to each location's stock
    (capped at max_cars). Each location then rents out as many cars as
    requested, limited by what it has available. Every rented car earns 10.
    """
    # Stock available at each location once the returned cars are back.
    available = []
    for loc, on_hand in enumerate(car_count):
        available.append(min(on_hand + cars_rtn_prev_day[loc], max_cars))

    # A location cannot rent out more cars than it holds.
    total_rented = 0
    for loc, demand in enumerate(cars_req):
        total_rented += min(demand, available[loc])

    return total_rented * 10

In [6]:
def get_legal_actions(x,y,actions, max_cars):
    """Return the subset of actions that are legal in state (x, y).

    An action is legal when adding it to (x, y) leaves every location with
    at least 0 and at most max_cars cars. The original order of actions is
    preserved.
    """
    allowed = []
    for move in actions:
        next_state = np.add((x, y), move)
        # Keep the move only if no location goes out of bounds.
        if np.min(next_state) >= 0 and np.max(next_state) <= max_cars:
            allowed.append(move)
    return allowed

In [7]:
def rescale_probs(legal_actions, probs, actions):
    """Rescale a probability vector so it only covers legal actions.

    Parameters
    ----------
    legal_actions : the actions that are allowed in the current state.
    probs : probability of each entry of `actions`, in the same order.
    actions : every action, legal or not.

    Returns
    -------
    numpy array with one probability per legal action (in the order the
    legal actions appear in `actions`), summing to 1.
    """
    mask = [o in legal_actions for o in actions]
    kept = np.array([pr for pr, keep in zip(probs, mask) if keep], dtype=float)

    total = kept.sum()
    if total > 0:
        # Renormalise by the remaining mass. Softmax must not be used here:
        # the inputs are already probabilities, so exponentiating them
        # distorts non-uniform policies and gives positive weight to
        # actions whose probability is zero.
        return kept / total
    if kept.size:
        # No probability mass on any legal action: fall back to uniform.
        return np.full(kept.size, 1.0 / kept.size)
    return kept

In [12]:
# Find the value function for the policy by sweeping over every state
# n_iter times and updating v in place.
n_iter = 150
# Each time step is taken at the end of a day.
for k in range(n_iter):  # replace later with while
    old_v = v.copy()  # snapshot, so the change made by this sweep can be measured
    # Cars requested for the coming day and cars returned during the day just
    # finished are set to the Poisson rate parameters, instead of being
    # sampled with np.random.poisson.
    cars_req = lambda_req
    cars_rtn = lambda_rtn
    # x, y are the number of cars at loc1, loc2
    for x, y in np.ndindex(max_cars + 1, max_cars + 1):
        probs = p[:, x, y]
        # Drop illegal actions (those leading to more than max_cars or fewer
        # than 0 cars at a location) and rescale the policy over the rest.
        legal_actions = get_legal_actions(x, y, actions, max_cars)
        legal_probs = rescale_probs(legal_actions, probs, actions)
        v_temp = 0
        for i, action in enumerate(legal_actions):
            s_prime = tuple(np.add((x, y), action))
            reward_moving = -2 * max(action)
            reward_borrow = simulate_day(list(s_prime), cars_rtn, cars_req, max_cars)
            reward = reward_moving + reward_borrow
            # Policy-weighted reward plus discounted value of the next state.
            v_temp += legal_probs[i] * (reward + gamma * v[s_prime])
        v[x, y] = v_temp
    # Total absolute change in v over this sweep.
    tol = np.abs(v - old_v).sum()
    print(np.round(v))

[[ 460.  524.  570.  598.  613.  620.]
 [ 526.  573.  602.  617.  624.  664.]
 [ 577.  607.  621.  628.  669.  694.]
 [ 611.  625.  632.  675.  701.  711.]
 [ 630.  636.  680.  709.  721.  720.]
 [ 640.  687.  717.  732.  733.  730.]]
[[ 454.  512.  554.  581.  595.  602.]
 [ 506.  550.  578.  593.  601.  644.]
 [ 542.  573.  589.  598.  641.  672.]
 [ 564.  583.  593.  636.  668.  687.]
 [ 574.  587.  629.  660.  680.  693.]
 [ 578.  619.  648.  668.  681.  697.]]
[[ 458.  517.  562.  589.  602.  609.]
 [ 519.  565.  591.  605.  611.  652.]
 [ 569.  594.  607.  614.  655.  683.]
 [ 598.  610.  616.  658.  686.  699.]
 [ 614.  619.  662.  690.  703.  707.]
 [ 622.  666.  696.  709.  714.  717.]]
[[ 463.  520.  567.  596.  611.  619.]
 [ 522.  569.  598.  614.  621.  663.]
 [ 569.  600.  616.  623.  666.  694.]
 [ 600.  617.  625.  668.  698.  711.]
 [ 617.  626.  670.  701.  715.  718.]
 [ 625.  670.  702.  718.  723.  726.]]
[[ 466.  523.  569.  598.  613.  621.]
 [ 524.  571.  600.  

[[ 407.  445.  480.  507.  521.  529.]
 [ 439.  476.  504.  519.  527.  567.]
 [ 468.  498.  515.  524.  564.  596.]
 [ 489.  508.  518.  558.  591.  613.]
 [ 498.  511.  550.  582.  605.  617.]
 [ 502.  539.  569.  591.  603.  618.]]
[[ 396.  432.  465.  489.  501.  507.]
 [ 426.  461.  486.  498.  505.  544.]
 [ 456.  482.  495.  502.  541.  573.]
 [ 477.  491.  499.  537.  569.  591.]
 [ 487.  495.  533.  564.  585.  598.]
 [ 491.  529.  559.  578.  589.  606.]]
[[ 406.  440.  473.  495.  505.  510.]
 [ 444.  476.  497.  507.  511.  547.]
 [ 479.  499.  509.  513.  549.  576.]
 [ 502.  511.  515.  551.  578.  594.]
 [ 513.  517.  553.  580.  596.  603.]
 [ 519.  555.  583.  598.  606.  616.]]
[[ 396.  432.  466.  488.  498.  503.]
 [ 428.  465.  488.  498.  503.  540.]
 [ 461.  485.  497.  502.  540.  569.]
 [ 481.  494.  501.  539.  568.  584.]
 [ 490.  498.  536.  565.  582.  593.]
 [ 494.  532.  561.  577.  588.  604.]]
[[ 396.  426.  456.  474.  482.  485.]
 [ 423.  453.  472.  

[[ 418.  477.  524.  554.  569.  577.]
 [ 484.  530.  559.  574.  582.  621.]
 [ 535.  564.  578.  586.  626.  649.]
 [ 568.  583.  590.  631.  656.  663.]
 [ 587.  593.  636.  662.  671.  665.]
 [ 597.  641.  669.  680.  677.  663.]]
[[ 416.  476.  521.  550.  564.  572.]
 [ 476.  521.  550.  565.  573.  614.]
 [ 518.  548.  563.  572.  614.  640.]
 [ 543.  560.  570.  612.  639.  652.]
 [ 555.  566.  608.  635.  650.  653.]
 [ 561.  602.  628.  643.  648.  647.]]
[[ 424.  488.  536.  567.  583.  592.]
 [ 493.  541.  572.  588.  596.  637.]
 [ 546.  577.  593.  600.  642.  667.]
 [ 583.  597.  605.  648.  674.  681.]
 [ 603.  609.  654.  681.  691.  684.]
 [ 615.  661.  691.  703.  698.  682.]]
[[ 432.  500.  555.  589.  608.  618.]
 [ 506.  561.  595.  613.  623.  666.]
 [ 567.  601.  618.  627.  672.  697.]
 [ 605.  622.  631.  678.  705.  710.]
 [ 626.  635.  683.  712.  720.  711.]
 [ 637.  687.  718.  728.  723.  704.]]
[[ 429.  497.  553.  590.  610.  621.]
 [ 495.  553.  591.  

[[ 400.  427.  449.  463.  468.  469.]
 [ 429.  450.  463.  468.  470.  501.]
 [ 454.  465.  470.  470.  501.  525.]
 [ 469.  472.  472.  502.  526.  542.]
 [ 475.  474.  504.  527.  543.  555.]
 [ 477.  508.  532.  546.  556.  574.]]
[[ 410.  439.  460.  473.  478.  478.]
 [ 445.  464.  476.  480.  481.  509.]
 [ 468.  479.  483.  483.  512.  532.]
 [ 481.  485.  485.  514.  534.  547.]
 [ 486.  486.  516.  536.  549.  559.]
 [ 487.  517.  537.  551.  561.  577.]]
[[ 409.  442.  466.  482.  488.  490.]
 [ 443.  468.  485.  491.  493.  522.]
 [ 469.  487.  494.  496.  526.  546.]
 [ 487.  495.  498.  529.  550.  562.]
 [ 496.  499.  531.  553.  567.  573.]
 [ 499.  532.  556.  572.  579.  589.]]
[[ 398.  432.  459.  479.  488.  491.]
 [ 428.  457.  479.  489.  493.  525.]
 [ 453.  478.  489.  493.  527.  551.]
 [ 475.  487.  493.  527.  553.  569.]
 [ 485.  492.  527.  554.  571.  577.]
 [ 489.  525.  553.  572.  579.  590.]]
[[ 398.  431.  459.  479.  489.  493.]
 [ 430.  460.  480.  

In [13]:
# Show the rounded value function, then the total change of the last sweep.
print(np.round(v), tol, sep="\n")

[[ 432.  490.  536.  568.  585.  595.]
 [ 497.  542.  575.  591.  601.  638.]
 [ 547.  580.  597.  606.  646.  668.]
 [ 584.  601.  610.  652.  677.  684.]
 [ 604.  613.  657.  684.  695.  685.]
 [ 616.  661.  690.  704.  699.  681.]]
932.57255692
