In [1]:
import pandas as pd
import numpy as np
from blackjack import SimplifiedBlackjackMDP
from snake import SnakeMDP, hashable_state, array_state
import pickle
import random
import matplotlib.pyplot as plt
import time

from tqdm.notebook import tqdm

# Construct MDP Examples

In [2]:
# Registry of the MDP instances under study, keyed by display name.
# Built from (name, instance) pairs; insertion order is Snake, then Blackjack.
problems = dict([
    ('Snake', SnakeMDP(board_length=3, max_snake_length=3)),
    ('Blackjack', SimplifiedBlackjackMDP()),
])

Trying max snake length 1 with 9 base combos
100 states found...
200 states found...
Trying max snake length 2 with 36 base combos
300 states found...
400 states found...


In [3]:
# Report how many states each MDP enumerated.
for key, mdp in problems.items():
    print(f'{key} MDP has {len(mdp.states)} states.')

Snake MDP has 464 states.
Blackjack MDP has 1271 states.


In [4]:
# Experiment grid: algorithm key -> list of keyword-argument dicts, one per run.
# The PI and VI grids are currently disabled; uncomment them to include those solvers.
params = {
    #'PI': [
    #    {'gamma': 0.99, 'epsilon': 0.001},
    #    {'gamma': 0.99, 'epsilon': 0.25},
    #    {'gamma': 0.75, 'epsilon': 0.001},
    #    {'gamma': 0.75, 'epsilon': 0.25},
    #    {'gamma': 0.50, 'epsilon': 0.001},
    #    {'gamma': 0.50, 'epsilon': 0.25},
    #    {'gamma': 0.25, 'epsilon': 0.001},
    #    {'gamma': 0.25, 'epsilon': 0.25},
    #],
    #'VI': [
    #    {'gamma': 0.99, 'epsilon': 0.001},
    #    {'gamma': 0.99, 'epsilon': 0.25},
    #    {'gamma': 0.75, 'epsilon': 0.001},
    #    {'gamma': 0.75, 'epsilon': 0.25},
    #    {'gamma': 0.50, 'epsilon': 0.001},
    #    {'gamma': 0.50, 'epsilon': 0.25},
    #    {'gamma': 0.25, 'epsilon': 0.001},
    #    {'gamma': 0.25, 'epsilon': 0.25},
    #],
    # Q-learning: decay pattern x exploration x initialization, with the
    # ('iteration_based', 'zeros') combinations left out of the sweep.
    'Q-learning': [
        {'decay_pattern': decay, 'initialization': init, 'exploration': explore}
        for decay in ('mitchell', 'iteration_based')
        for explore in ('introduce-randomness', 'q-optimal')
        for init in ('zeros', 'first_reward')
        if not (decay == 'iteration_based' and init == 'zeros')
    ]
}

In [5]:
def params_to_text(d):
    """Render a parameter mapping as a single 'key=value; key=value' string.

    Entries appear in the mapping's iteration order, each formatted as
    ``{key}={value}`` and separated by ``'; '``. An empty mapping yields ``''``.
    """
    return '; '.join(f'{name}={d[name]}' for name in d)

In [None]:
# Run every configured solver/parameterization against every MDP and collect
# the raw outputs, keyed as:
#   results[problem_name][algo][params_to_text(parameterization)] = output
SEED = 0  # arbitrary fixed value; only there so reruns are repeatable

# `params` key -> name of the solver method invoked on the MDP object.
SOLVER_METHODS = {
    'PI': 'policy_iteration',
    'VI': 'value_iteration',
    'Q-learning': 'Q_learning',
}

start = time.time()
results = {}
for problem_name in tqdm(problems):
    results[problem_name] = {}
    for algo in tqdm(params):
        results[problem_name][algo] = {}
        for parameterization in tqdm(params[algo]):
            # Progress line with elapsed wall-clock minutes since the sweep started.
            print(f'{problem_name}-{algo}-{parameterization} @ {(time.time() - start) / 60:.1f}m')

            if algo not in SOLVER_METHODS:
                # Name the offending key so a typo in `params` is easy to spot.
                raise ValueError(
                    f'Unexpected algorithm {algo!r}; expected one of {sorted(SOLVER_METHODS)}'
                )

            # Re-seed before each run so a configuration's result does not depend
            # on which other configurations ran before it.
            # NOTE(review): assumes the solvers draw from the global `random` /
            # `np.random` state - confirm against the MDP implementations.
            random.seed(SEED)
            np.random.seed(SEED)

            solve = getattr(problems[problem_name], SOLVER_METHODS[algo])
            output = solve(**parameterization)

            results[problem_name][algo][params_to_text(parameterization)] = output

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Snake-Q-learning-{'decay_pattern': 'mitchell', 'initialization': 'zeros', 'exploration': 'introduce-randomness'} @ 0.0s
At iteration 1, max change in value: 0.98702
At iteration 2, max change in value: 0.98708
At iteration 3, max change in value: 0.86257
At iteration 4, max change in value: 0.92568
At iteration 5, max change in value: 0.85353
At iteration 6, max change in value: 0.94322
At iteration 7, max change in value: 0.93739
At iteration 8, max change in value: 0.89461
At iteration 9, max change in value: 0.88176
At iteration 10, max change in value: 0.89135
At iteration 11, max change in value: 0.64059
At iteration 12, max change in value: 0.65139
At iteration 13, max change in value: 0.93807
At iteration 14, max change in value: 0.88037
At iteration 15, max change in value: 0.68616
At iteration 16, max change in value: 0.72699
At iteration 17, max change in value: 0.54730
At iteration 18, max change in value: 0.67042
At iteration 19, max change in value: 0.65475
At iteration 20

At iteration 176, max change in value: 0.11618
At iteration 177, max change in value: 0.11635
At iteration 178, max change in value: 0.11603
At iteration 179, max change in value: 0.11623
At iteration 180, max change in value: 0.11625
At iteration 181, max change in value: 0.11621
At iteration 182, max change in value: 0.11611
At iteration 183, max change in value: 0.11630
At iteration 184, max change in value: 0.11617
At iteration 185, max change in value: 0.11630
At iteration 186, max change in value: 0.11644
At iteration 187, max change in value: 0.11613
At iteration 188, max change in value: 0.11634
At iteration 189, max change in value: 0.11558
At iteration 190, max change in value: 0.11637
At iteration 191, max change in value: 0.11620
At iteration 192, max change in value: 0.11631
At iteration 193, max change in value: 0.11606
At iteration 194, max change in value: 0.11620
At iteration 195, max change in value: 0.11633
At iteration 196, max change in value: 0.11626
At iteration 

At iteration 95, max change in value: 0.11627
At iteration 96, max change in value: 0.11634
At iteration 97, max change in value: 0.11231
At iteration 98, max change in value: 0.11578
At iteration 99, max change in value: 0.11619
At iteration 100, max change in value: 0.11608
At iteration 101, max change in value: 0.11583
At iteration 102, max change in value: 0.11609
At iteration 103, max change in value: 0.11586
At iteration 104, max change in value: 0.11626
At iteration 105, max change in value: 0.11457
At iteration 106, max change in value: 0.11607
At iteration 107, max change in value: 0.11605
At iteration 108, max change in value: 0.11627
At iteration 109, max change in value: 0.11601
At iteration 110, max change in value: 0.11603
At iteration 111, max change in value: 0.11619
At iteration 112, max change in value: 0.11633
At iteration 113, max change in value: 0.11602
At iteration 114, max change in value: 0.11616
At iteration 115, max change in value: 0.11616
At iteration 116, 

At iteration 24, max change in value: 0.26139
At iteration 25, max change in value: 0.26135
At iteration 26, max change in value: 0.24379
At iteration 27, max change in value: 0.26158
At iteration 28, max change in value: 0.24229
At iteration 29, max change in value: 0.19041
At iteration 30, max change in value: 0.26144
At iteration 31, max change in value: 0.24420
At iteration 32, max change in value: 0.24897
At iteration 33, max change in value: 0.22300
At iteration 34, max change in value: 0.35783
At iteration 35, max change in value: 0.43553
At iteration 36, max change in value: 0.26273
At iteration 37, max change in value: 0.21217
At iteration 38, max change in value: 0.24228
At iteration 39, max change in value: 0.22276
At iteration 40, max change in value: 0.26258
At iteration 41, max change in value: 0.22583
At iteration 42, max change in value: 0.25566
At iteration 43, max change in value: 0.24362
At iteration 44, max change in value: 0.24458
At iteration 45, max change in val