In [1]:
import numpy as np
from collections import OrderedDict
import statistics as stat
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import matplotlib as mpl
import matplotlib.image as mpimg
from IPython.display import Image
import import_ipynb
from collections import defaultdict
import seaborn as sns
from matplotlib.patches import Patch
import pickle

In [2]:
def create_grids(pi, Q_s_a, usable_ace=True):
    """Build the state-value grid and greedy-policy grid for one ace slice.

    Parameters
    ----------
    pi : object
        Policy container exposing a mapping ``pi.pi`` keyed by
        ``((player_sum, dealer_card, usable_ace), action)``; actions 0 and 1
        are read.
    Q_s_a : object
        Action-value container exposing a mapping ``Q_s_a.Q_s_a`` with the
        same key layout; actions 0 and 1 are read.
    usable_ace : bool, default True
        Which ``usable_ace`` slice of the state space to extract.

    Returns
    -------
    value_grid : tuple of np.ndarray
        ``(player_sum, dealer_shows, value)``: three (10, 11) arrays in
        ``np.meshgrid`` default ("xy") layout, i.e. rows follow the dealer
        card (1..10) and columns follow the player sum (10..20). ``value``
        holds ``max(Q(s, 0), Q(s, 1))`` as float32.
    policy_grid : np.ndarray
        (10, 11) int32 array in the same layout: 0 where ``pi`` rates
        action 0 strictly above action 1, otherwise 1 (ties go to action 1).
    """
    # Plain Python ranges so the dict keys are built from Python ints.
    player_sums = range(10, 21)
    dealer_cards = range(1, 11)

    # Lookup tables indexed [player_sum - 10, dealer_card - 1]; the row index
    # is an offset from the first player sum, so no negative-index wraparound.
    table_shape = (len(player_sums), len(dealer_cards))
    best_values = np.zeros(table_shape, dtype=np.float32)
    greedy_actions = np.zeros(table_shape, dtype=np.int32)
    for row, player_sum in enumerate(player_sums):
        for col, dealer_card in enumerate(dealer_cards):
            state = (player_sum, dealer_card, usable_ace)
            best_values[row, col] = max(Q_s_a.Q_s_a[(state, 0)], Q_s_a.Q_s_a[(state, 1)])
            greedy_actions[row, col] = 0 if pi.pi[(state, 0)] > pi.pi[(state, 1)] else 1

    player_grid, dealer_grid = np.meshgrid(
        # players count, dealers face-up card
        np.arange(player_sums.start, player_sums.stop),
        np.arange(dealer_cards.start, dealer_cards.stop),
    )
    # meshgrid's default "xy" indexing puts dealer cards on rows and player
    # sums on columns -- exactly the transpose of the lookup tables.
    value_grid = player_grid, dealer_grid, best_values.T.copy()
    policy_grid = greedy_actions.T.copy()
    return value_grid, policy_grid

In [1]:
def create_plots(value_grid, policy_grid, title, fileName, numEpisodes, runNum, discountFactor, firstVisit, usableAce):
    """Render a value surface and a policy heatmap side by side and save them.

    Parameters
    ----------
    value_grid : tuple of np.ndarray
        ``(player_count, dealer_count, value)`` meshgrid-style arrays (as
        returned by ``create_grids``): rows follow the dealer card and
        columns follow the player sum.
    policy_grid : array-like
        2-D array of actions (0 or 1) in the same row/column layout.
    title : str
        Figure super-title.
    fileName : str or path-like
        Destination passed to ``Figure.savefig``.
    numEpisodes, runNum, discountFactor, firstVisit, usableAce
        Not used by this function; kept so existing call sites still work.

    The figure is always closed, even if saving fails, so calling this in a
    loop does not accumulate open figures.
    """
    player_count, dealer_count, value = value_grid
    # Tick values come from the grid itself so labels always line up with the
    # plotted data (first row = player sums, first column = dealer cards).
    player_sums = [int(s) for s in player_count[0]]
    dealer_cards = [int(c) for c in dealer_count[:, 0]]
    dealer_labels = ["A" if c == 1 else c for c in dealer_cards]

    fig = plt.figure(figsize=plt.figaspect(0.4))
    try:
        fig.suptitle(title, fontsize=16)

        # Left panel: 3-D surface of the state values.
        ax1 = fig.add_subplot(1, 2, 1, projection="3d")
        ax1.plot_surface(
            player_count,
            dealer_count,
            value,
            rstride=1,
            cstride=1,
            cmap="viridis",
            edgecolor="none",
        )
        ax1.set_xticks(player_sums)
        ax1.set_xticklabels(player_sums, fontsize=10)
        ax1.set_yticks(dealer_cards)
        ax1.set_yticklabels(dealer_labels, fontsize=12)
        ax1.set_xlabel("Player sum", fontsize=12, rotation=45)
        ax1.set_ylabel("Dealer showing")
        ax1.zaxis.set_rotate_label(False)
        ax1.set_zlabel("Q(S, A) (following \u03C0*) ", fontsize=14, rotation=90)
        ax1.view_init(20, 220)

        # Right panel: policy heatmap. vmin/vmax are pinned so action 0 and
        # action 1 always get the same colours, even when the policy is
        # uniform (seaborn would otherwise autoscale to a single value).
        # Tick labels are passed to heatmap directly so every row/column is
        # labelled, rather than relabelling whatever ticks "auto" kept.
        policy_cmap = plt.get_cmap("Accent_r")
        ax2 = fig.add_subplot(1, 2, 2)
        sns.heatmap(
            policy_grid,
            ax=ax2,
            vmin=0,
            vmax=1,
            linewidth=0,
            annot=True,
            cmap=policy_cmap,
            cbar=False,
            xticklabels=player_sums,
            yticklabels=dealer_labels,
        )
        ax2.tick_params(axis="x", labelsize=10)
        ax2.set_title("Optimal Policy \u03C0*")
        ax2.set_xlabel("Player sum")
        ax2.set_ylabel("Dealer showing")

        # Legend swatches are taken from the heatmap's own colormap so they
        # match exactly: value 1 (Hit) maps to the top of the colormap,
        # value 0 (Stick) to the bottom.
        legend_elements = [
            Patch(facecolor=policy_cmap(1.0), edgecolor="black", label="Hit"),
            Patch(facecolor=policy_cmap(0.0), edgecolor="black", label="Stick"),
        ]
        ax2.legend(handles=legend_elements, bbox_to_anchor=(1.3, 1))

        # Save the action value and policy plot to file
        fig.savefig(fileName)
    finally:
        plt.close(fig)