In [105]:
import numpy as np
from blackjack_env import BlackJackEnv, CARD_INDX

# Baseline policy that mimics the dealer's fixed rule. Axes are
# (player_sum - 12, dealer card index, usable-ace flag); a value of 1 means
# "stick" and 0 means "hit". Player sums 12..16 (indices 0..4) hit and sums
# 17..21 (indices 5..9) stick, regardless of dealer card or usable ace.
policy = np.zeros(shape=(10, 10, 2))  # start with "hit" everywhere
policy[17 - 12:, :, :] = 1  # stick from a player sum of 17 upward
    
def firstvisit_mc_eval(policy, runs=10000, gamma=1, env=None, card_index=None):
    """Estimate the state-value function of ``policy`` by first-visit Monte Carlo.

    Each run samples one episode and, for every distinct state in it, records
    the discounted return that follows the state's *first* occurrence in that
    episode. A state's estimate is the mean of all returns recorded for it.

    Args:
        policy: Passed unchanged to ``env.run_episode``; in this notebook a
            (10, 10, 2) array laid out like the returned value table.
        runs: Number of episodes to sample.
        gamma: Discount factor applied to rewards.
        env: Object exposing ``run_episode(policy)``, which returns a sequence
            of ``(state_dict, action, reward)`` steps. Defaults to a fresh
            ``BlackJackEnv()``; pass one to reuse or seed an environment.
        card_index: Mapping from ``state_dict['dealer_card']`` to a column
            index. Defaults to ``CARD_INDX``.

    Returns:
        A (10, 10, 2) float array indexed by (hand_sum - 12, dealer card
        index, usable-ace flag). States never visited keep the value 0.
    """
    if env is None:
        env = BlackJackEnv()
    if card_index is None:
        card_index = CARD_INDX

    shape = (10, 10, 2)
    # Running totals avoid re-averaging an ever-growing list on every visit.
    return_sums = np.zeros(shape)
    visit_counts = np.zeros(shape)

    for _ in range(runs):
        episode = list(env.run_episode(policy))

        # Index of each step's state in the value table.
        # NOTE(review): assumes 12 <= hand_sum <= 21; a smaller sum would give
        # a negative index and silently alias another row -- confirm that
        # BlackJackEnv never emits such states.
        keys = [
            (
                state['hand_sum'] - 12,
                card_index[state['dealer_card']],
                1 if state['usable_aces'] > 0 else 0,
            )
            for state, _, _ in episode
        ]

        # When walking backwards, the first time a state is met is its LAST
        # visit, so locate each state's first occurrence with a forward pass.
        first_visit = {}
        for t, key in enumerate(keys):
            first_visit.setdefault(key, t)

        g = 0
        for t in reversed(range(len(episode))):
            g = episode[t][2] + gamma * g  # return following step t
            key = keys[t]
            if first_visit[key] == t:
                return_sums[key] += g
                visit_counts[key] += 1

    # Mean return per state; states with no visits stay at 0.
    state_values = np.zeros(shape)
    np.divide(return_sums, visit_counts, out=state_values, where=visit_counts > 0)
    return state_values

# Evaluate the dealer-style policy over 100,000 episodes; as the cell's final
# bare expression, the resulting value table is displayed as the cell output.
firstvisit_mc_eval(policy, runs=100_000)

array([[[-0.4552737 , -0.45762712],
        [-0.27306733,  0.11290323],
        [-0.21134021,  0.03703704],
        [-0.2621232 ,  0.08      ],
        [-0.24903226,  0.02222222],
        [-0.21447028,  0.33870968],
        [-0.16031537,  0.36111111],
        [-0.28552279, -0.20454545],
        [-0.34909597, -0.2826087 ],
        [-0.40601255, -0.21556886]],

       [[-0.51732991, -0.21551724],
        [-0.33510638,  0.04301075],
        [-0.38385093, -0.13043478],
        [-0.34255599, -0.03773585],
        [-0.32082794,  0.02702703],
        [-0.30647292,  0.20430108],
        [-0.21496599,  0.11702128],
        [-0.34183673,  0.05617978],
        [-0.39389736, -0.17241379],
        [-0.46002579, -0.03846154]],

       [[-0.60772834, -0.18796992],
        [-0.4025974 , -0.17808219],
        [-0.37780401,  0.05376344],
        [-0.35231317, -0.05555556],
        [-0.33174224,  0.02247191],
        [-0.28588098,  0.        ],
        [-0.29021372,  0.04081633],
        [-0.38596491, -0