In [1]:
import pandas as pd
import numpy as np
from blackjack import SimplifiedBlackjackMDP
from snake import SnakeMDP, hashable_state, array_state
import pickle
import random
import matplotlib.pyplot as plt

## Construct MDP Examples

In [8]:
# Build both example MDPs once, keyed by the display name used in later printouts.
# Constructing the SnakeMDP prints state-enumeration progress (see the cell output).
problems = dict(
    Snake=SnakeMDP(board_length=4, max_snake_length=5),
    Blackjack=SimplifiedBlackjackMDP(),
)

Trying max snake length 1 with 16 base combos
100 states found...
200 states found...
300 states found...
400 states found...
500 states found...
600 states found...
700 states found...
800 states found...
900 states found...
Trying max snake length 2 with 120 base combos
1,000 states found...
1,100 states found...
1,200 states found...
1,300 states found...
1,400 states found...
1,500 states found...
1,600 states found...
Trying max snake length 3 with 560 base combos
1,700 states found...
1,800 states found...
1,900 states found...
2,000 states found...
2,100 states found...
2,200 states found...
2,300 states found...
2,400 states found...
2,500 states found...
2,600 states found...
2,700 states found...
2,800 states found...
2,900 states found...
Trying max snake length 4 with 1,820 base combos
3,000 states found...
3,100 states found...
3,200 states found...
3,300 states found...
3,400 states found...
3,500 states found...
3,600 states found...
3,700 states found...
3,800 states fo

In [9]:
# Report how many states each example MDP exposes through its `states` attribute.
for name, problem in problems.items():
    print(f'{name} MDP has {len(problem.states)} states.')

Snake MDP has 5776 states.
Blackjack MDP has 1271 states.


## Learn on these examples

## Test accuracy with handpicked examples

## Develop graphs for these MDPs

In [None]:
# Load a previously pickled snake MDP from a file in the working directory.
# NOTE(review): the file name suggests board_length=5 / max_snake_length=7 -- confirm.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only load
# files that were produced locally by a trusted run of this project.
# The SnakeMDP(...) constructor cell below rebinds `snake_mdp`, so whichever of
# the two cells ran last decides which instance the later cells use.
with open('snake_bl5_msl7_20221124.pkl', 'rb') as f:
    snake_mdp = pickle.load(f)

In [3]:
# Build a small snake MDP in-process (board_length=3, max_snake_length=4).
# This rebinds `snake_mdp`, replacing any instance loaded from the pickle file.
snake_mdp = SnakeMDP(board_length=3, max_snake_length=4)

Trying max snake length 1 with 9 base combos
100 states found...
200 states found...
Trying max snake length 2 with 36 base combos
300 states found...
400 states found...
Trying max snake length 3 with 84 base combos
500 states found...
600 states found...
700 states found...


In [None]:
# Learn a snake policy with Q-learning; the result is unpacked into two values.
# NOTE(review): the value-iteration cell assigns the same names (snake_policy,
# snake_q), so whichever of the two cells ran last determines their contents.
snake_policy, snake_q = snake_mdp.Q_learning(gamma=0.9, epsilon=0.001)

In [4]:
# Solve the snake MDP with value iteration (per-iteration progress is printed).
# Three values are unpacked here: the policy, a second result bound to
# `snake_q`, and a third bound to `stats`.
snake_policy, snake_q, stats = snake_mdp.value_iteration(gamma=0.9, epsilon=0.01)

At iteration 1, max change in value: 1.04341; avg. change: 0.21854
At iteration 2, max change in value: 0.93907; avg. change: 0.30082
At iteration 3, max change in value: 0.84516; avg. change: 0.23794
At iteration 4, max change in value: 0.72900; avg. change: 0.23923
At iteration 5, max change in value: 0.52715; avg. change: 0.19515
At iteration 6, max change in value: 0.40357; avg. change: 0.14404
At iteration 7, max change in value: 0.28788; avg. change: 0.09297
At iteration 8, max change in value: 0.21717; avg. change: 0.06434
At iteration 9, max change in value: 0.10865; avg. change: 0.04583
At iteration 10, max change in value: 0.09713; avg. change: 0.03332
At iteration 11, max change in value: 0.05735; avg. change: 0.02417
At iteration 12, max change in value: 0.05162; avg. change: 0.01714
At iteration 13, max change in value: 0.04646; avg. change: 0.01276
At iteration 14, max change in value: 0.04181; avg. change: 0.00947
At iteration 15, max change in value: 0.02469; avg. chang

In [None]:
# Solve the snake MDP with policy iteration and spot-check a few random states.
#
# The original cell referred to `mdp`, which is only bound much later in the
# notebook (to a blackjack MDP), so on a fresh kernel it raised NameError and,
# after out-of-order execution, silently ran against the wrong problem.
# The snake instance is `snake_mdp`.
snake_policy, snake_value = snake_mdp.policy_iteration(gamma=0.5, epsilon=0.0001, max_allowed_time=720)

# Scratch recipe for inspecting one hand-written 5x5 state and its first successor:
# s = ((1.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), 'down')
# array_state(s)
# array_state(snake_mdp.accessible_states(s, 'down')[0])

# random.sample requires a sequence (sets are rejected from Python 3.11 on),
# so materialise the state collection as a list before sampling.
for s in random.sample(list(snake_mdp.states), k=5):
    print(f'\nFor state s = \n{array_state(s)}')
    print(f'the recommended action is: {snake_policy[s]}')

In [None]:
# take a state and visualize it

def plot(s):
    """Show a snake state as an image.

    `array_state(s)` is unpacked into a board and a second element; only the
    board is drawn (via `plt.imshow` without interpolation). The figure title
    reports `s[1]`, labelled as the last move.
    """
    grid, _ = array_state(s)
    plt.imshow(grid, interpolation='none')
    plt.title(f"Snake with last move='{s[1]}'")
    plt.show()

In [None]:
# Draw the hand-built example state `example_4`.
# NOTE(review): example_4 is defined in the examples cell (In [2]) that appears
# further down this file; that cell must be run before this one.
plot(example_4)

In [None]:
# Look up the action the current snake policy recommends for `example_4`.
snake_policy[example_4]

In [None]:

# --- Simplified blackjack: solve the MDP and tabulate policy / value ---------
mdp = SimplifiedBlackjackMDP()

mdp._sample_state_layout()

# Spot-check successors and rewards for one state/action pair, then query a
# single transition. States are indexed as (player value, dealer value, phase),
# matching how the tables below look them up.
for s in mdp.accessible_states((11, 3, 'hitting'), 'hit'):
    print(s, mdp.reward(s))
mdp.transition_model((19, 6, 'hitting'), 'hold', (19, 6, 'stand'))

# Keep the policy-iteration result under its own names: the original bound both
# solvers to the same variables, so this result was overwritten on the next line.
blackjack_pi_policy, blackjack_pi_value = mdp.policy_iteration(gamma=0.99999, epsilon=1e-5)
# The snake value_iteration call elsewhere in this notebook is unpacked into
# three values (policy, values, stats); `*_` accepts that as well as a plain pair.
blackjack_policy, blackjack_value, *_ = mdp.value_iteration(gamma=0.99999, epsilon=1e-5)

DEALER_VALUES = range(2, 12)
PLAYER_VALUES = range(2, 22)


def _hitting_table(cell_value, dtype):
    """Tabulate `cell_value((player, dealer, 'hitting'))` over all value pairs.

    Rows are dealer values 2..11 and columns are player values 2..21.
    `cell_value` maps a state tuple to the entry stored in the table and
    `dtype` is passed straight to the DataFrame constructor.
    """
    table = pd.DataFrame(
        index=pd.Index(DEALER_VALUES, name='Dealer Value'),
        columns=pd.Index(PLAYER_VALUES, name='Player Value'),
        dtype=dtype
    )
    for dealer in DEALER_VALUES:
        for player in PLAYER_VALUES:
            table.loc[dealer, player] = cell_value((player, dealer, 'hitting'))
    return table


policy_visualization = _hitting_table(lambda state: blackjack_policy[state], 'string')

# NOTE: Jupyter renders only the final expression of a cell, so this slice is
# not displayed from here; move it to its own cell to view it.
policy_visualization.iloc[:, 5:15]

value_visualization = _hitting_table(lambda state: np.round(blackjack_value[state], 1), 'float')

value_visualization.iloc[:, 5:15]


In [2]:
# Hand-built 5x5 test state: the bottom row holds the cell codes 3, 1, 1, 1, 2
# and every other cell is 0. The second tuple element is the last-move
# direction (it is what plot() shows in its title).
# NOTE(review): the meaning of the codes 1/2/3 is defined in snake.py -- confirm there.
# Fix: a stray `x` after the fourth row made the final row parse as
# `x[3., 1., 1., 1., 2.]`, which raised NameError when the cell ran.
example_1 = hashable_state(
    (np.array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [3., 1., 1., 1., 2.]]),
     'right')
)

# Hand-built 5x5 test state with last move 'up'.
# NOTE(review): cell codes 1/2/3 are interpreted by snake.py -- confirm their meaning there.
example_2 = hashable_state(
    (np.array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 3., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 2., 0.],
               [0., 1., 1., 1., 0.]]),
     'up')
)

# Hand-built 4x4 test state with last move 'right'.
example_3 = hashable_state(
    (np.array([[0., 0., 0., 0.],
               [0., 0., 3., 0.],
               [0., 0., 0., 0.],
               [0., 1., 2., 0.]]),
     'right')
)

# Hand-built 3x3 test state with last move 'down'; used by the plot and
# policy-lookup cells earlier in the file.
example_4 = hashable_state(
    (np.array([[0., 0., 0.],
               [0., 1., 1.],
               [0., 2., 3.]]),
     'down')
)
# Default state for ad-hoc inspection; note this is the 5x5 example.
s = example_2