In [1]:
%matplotlib widget

In [2]:
import copy
import enum
import multiprocessing
import random
from collections import defaultdict
from dataclasses import dataclass
from itertools import product
from typing import Callable
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar

import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import torch
import tqdm

In [3]:
# # Doctests to catch regressions: uncomment and copy this cell below the relevant code section
# import doctest
# doctest.testmod()

# Monte Carlo Methods

## <p style="color:red">Text content copied nearly verbatim from: <a href="http://incompleteideas.net/book/the-book-2nd.html">Sutton and Barto</a>. Code is my own unless otherwise stated.</p>

## Key Concepts

TODO

## Definitions

**Monte Carlo (MC) Methods**: 
- Generic definition: Repeated random sampling to obtain numerical results. 
- RL definition: A way of solving the RL problem based on averaging sample returns.

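**Control Problem**: Approximating an optimal policy, as opposed to the prediction problem of estimating the value function $v_\pi$ of a given policy $\pi$.

**First-visit MC Method**: Estimates $v_\pi(s)$ as the average of the returns following the first visit to $s$ in each episode.

**Every-visit MC Method**: Estimates $v_\pi(s)$ as the average of the returns following every visit to $s$.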

## Notation

TODO

## Introduction

Monte Carlo methods require only experience: sample sequences of states, actions, and rewards from actual or simulated interaction with an environment. They do _not_ require complete knowledge of the environment's dynamics.

For simulated learning, a model is still required, but it only needs to generate sample transitions, not the complete probability distributions of all possible transitions. In surprisingly many cases it is easy to generate experience sampled according to the desired probability distribution but infeasible to obtain the distributions in explicit form.

Monte Carlo methods solve the RL problem by averaging sample returns. Here we consider Monte Carlo methods only for episodic tasks, so that well-defined returns are available. Value estimates and policies are changed only on the completion of an episode. This means the methods are incremental in an episode-by-episode sense, but not in an online, step-by-step sense.
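Because MC methods update only once an episode is complete, the return $G_t$ for every time step can be computed in a single backward sweep over the finished episode using $G_t = R_{t+1} + \gamma G_{t+1}$, with $G_T = 0$. A minimal sketch, assuming the episode's rewards are given as a plain list `[R_1, ..., R_T]`; the function name `returns_from_rewards` is hypothetical:

```python
from typing import List, Sequence


def returns_from_rewards(rewards: Sequence[float], gamma: float) -> List[float]:
    """Return [G_0, ..., G_{T-1}] for an episode with rewards [R_1, ..., R_T]."""
    returns = [0.0] * len(rewards)
    g = 0.0  # G_T = 0: no reward follows the terminal state
    # Sweep backwards so each G_t reuses the already computed G_{t+1}.
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


print(returns_from_rewards([0.0, 0.0, 1.0], gamma=0.5))  # [0.25, 0.5, 1.0]
```

The same backward sweep is what `first_visit_mc` below does inside its loop over `reversed(...)`.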

In the previous notebook we looked at dynamic programming, which _computes_ value functions from knowledge of the MDP. Monte Carlo methods instead _learn_ value functions from sample returns obtained by interacting with the MDP. The value functions and corresponding policies still interact to attain optimality via generalized policy iteration (GPI).

As per the previous chapter, we start by looking at the prediction problem, then the control problem and its solution by GPI. 

## 5.1: Monte Carlo Prediction

The "obvious" way to learn the state-value function of a given policy (the expected return, i.e. the expected cumulative sum of future discounted rewards) is to average the returns observed after visits to each state. As more returns are observed, the average should converge to the expected value.
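Averaging does not require storing every return: after the $n$-th return $G$, the running mean can be updated as $V \leftarrow V + (G - V)/n$, which gives the same result as summing and dividing. The sketch below uses a hypothetical `RunningMean` helper (not code from the book) and returns drawn from a made-up distribution whose expected value is 0.5, so the average should approach 0.5 as more returns are seen.

```python
import math
import random


class RunningMean:
    """Incrementally maintained sample average of observed returns."""

    def __init__(self) -> None:
        self.n = 0
        self.value = 0.0

    def update(self, g: float) -> float:
        self.n += 1
        # V_n = V_{n-1} + (G_n - V_{n-1}) / n, i.e. the mean of the first n returns
        self.value += (g - self.value) / self.n
        return self.value


rng = random.Random(0)
# Hypothetical returns: 1.0 with probability 0.5, else 0.0, so the expected return is 0.5.
samples = [1.0 if rng.random() < 0.5 else 0.0 for _ in range(100_000)]

mean = RunningMean()
for g in samples:
    mean.update(g)

print(mean.value)  # close to 0.5
print(math.isclose(mean.value, sum(samples) / len(samples)))  # True
```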

To estimate $v_\pi(s)$ we define:

- _visit_ to $s$: each occurrence of state $s$ in an episode
- _first visit_: the first time $s$ is visited in an episode

This leads to the following MC methods:

- _first visit MC method_: estimates $v_\pi(s)$ as the average of the returns following first visits to $s$
- _every visit MC method_: estimates $v_\pi(s)$ as the average of the returns following _all_ visits to $s$

These are similar but have different theoretical properties.

This chapter focuses on first-visit MC. Every-visit MC extends more naturally to function approximation and eligibility traces (Chapters 9 and 12).
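To make the difference concrete, here is a small self-contained sketch of both estimators on one hand-made episode, assuming each episode is a list of `(state, reward)` pairs where the reward is $R_{t+1}$. The function `mc_prediction` and the toy states `A`, `B`, `C` are illustrative only. State `A` is visited twice: first-visit MC uses only the return from its first occurrence, while every-visit MC averages both.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple

Episode = Sequence[Tuple[Hashable, float]]  # (S_t, R_{t+1}) pairs


def mc_prediction(episodes: Sequence[Episode], gamma: float = 1.0,
                  first_visit: bool = True) -> Dict[Hashable, float]:
    """Average the returns observed after (first or every) visits to each state."""
    returns: Dict[Hashable, List[float]] = defaultdict(list)
    for episode in episodes:
        # Returns G_t for every step, computed with a backward sweep.
        gs = [0.0] * len(episode)
        g = 0.0
        for t in reversed(range(len(episode))):
            g = episode[t][1] + gamma * g
            gs[t] = g
        seen = set()
        for t, (state, _) in enumerate(episode):
            if first_visit and state in seen:
                continue  # first-visit MC ignores later visits in the same episode
            seen.add(state)
            returns[state].append(gs[t])
    return {s: sum(rs) / len(rs) for s, rs in returns.items()}


episode = [("A", 1.0), ("B", 0.0), ("A", 2.0), ("C", 3.0)]
print(mc_prediction([episode], first_visit=True))   # {'A': 6.0, 'B': 5.0, 'C': 3.0}
print(mc_prediction([episode], first_visit=False))  # {'A': 5.5, 'B': 5.0, 'C': 3.0}
```

With $\gamma = 1$ the returns from the two visits to `A` are 6 and 5, so only the estimate for `A` differs between the two methods.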

In [4]:
StateT = TypeVar("StateT")
ActionT = TypeVar("ActionT")
RewardT = TypeVar("RewardT", int, float)

TraceT = Sequence[Tuple[StateT, ActionT, RewardT]]

In [5]:
def first_visit_mc(
    states: Sequence[StateT],
    episodes: Iterable[TraceT],
    gamma: float = 0.9
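) -> Mapping[StateT, RewardT]:
    """Estimate state values by first-visit Monte Carlo prediction.

    Each episode is a sequence of (state, action, reward) triples, where the
    reward is the one received after taking the action in that state.
    States in `states` that are never visited keep a value of 0.0.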
    
    Example:
        >>> def gen_ep(n_episodes):
        ...     for _ in range(n_episodes):
        ...         yield [(0, 0, 0.5), (0, 0, 0.75), (1, 0, 1.0), (0, 0, 0.75)]
        >>> value_map = first_visit_mc(states=[0, 1], episodes=gen_ep(1), gamma=0.5)
        >>> value_map[0] == 1.21875
        True
        >>> value_map[1] == 1.375
        True
    """
    values = dict((s, 0.0) for s in states)                    
    returns = defaultdict(list)
    
    for trace in tqdm.tqdm(episodes):
        ret = 0.0
        fv = first_visits(trace)
        for first_visit, (state, _, reward) in reversed(list(zip(fv, trace))):
            ret = reward + gamma*ret
            
            if first_visit:
                returns[state].append(ret)
                
    for s in states:
        rets = returns[s]
        if len(rets) == 0:
            continue
        values[s] = sum(rets) / len(rets)
                
    return values


def first_visits(trace: TraceT) -> Sequence[bool]:
    """Return, for each step of the trace, whether it is the first visit to its state.
    
    Example:
    
        >>> first_visits([(0, 0, 0.5)])
        [True]
        >>> first_visits([(0, 0, 0.5), (1, 0, 0.1), (0, 0, 0.5)])
        [True, True, False]
    """
    visited = set()
    fv = []
    for s, _, _ in trace:
        fv.append(s not in visited)
        visited.add(s)
    return fv

Both first-visit MC and every-visit MC converge to $v_\pi(s)$ as the number of visits (or first visits) to $s$ goes to infinity.
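For first-visit MC the convergence argument is the law of large numbers: each first-visit return is an independent, identically distributed estimate of $v_\pi(s)$ with finite variance, so the average is unbiased and its standard deviation falls as $1/\sqrt{n}$ in the number $n$ of returns averaged. The sketch below checks this empirically for a hypothetical return distribution (+1 or -1 with equal probability, so the true value is 0 and each return has standard deviation 1); `spread_of_estimate` is an illustrative name.

```python
import random
import statistics


def spread_of_estimate(n_returns: int, n_trials: int, rng: random.Random) -> float:
    """Standard deviation, across trials, of the average of n_returns sampled returns."""
    estimates = []
    for _ in range(n_trials):
        # Hypothetical returns: +1 or -1 with equal probability (true value 0, std 1).
        returns = [1.0 if rng.random() < 0.5 else -1.0 for _ in range(n_returns)]
        estimates.append(statistics.fmean(returns))
    return statistics.pstdev(estimates)


rng = random.Random(0)
for n in (25, 2500):
    # empirical spread of the estimate vs. the predicted 1/sqrt(n)
    print(n, spread_of_estimate(n, n_trials=400, rng=rng), 1 / n ** 0.5)
```

Averaging 100 times more returns should shrink the spread roughly 10-fold (about 0.2 for 25 returns, about 0.02 for 2500).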

In [6]:
# Example 5.1: Blackjack 

In [7]:
class BJAction(enum.Enum):
    STICK = 0
    HIT = 1
    
    
class BJCard(enum.IntEnum):
    ACE   = 11
    TWO   = 2
    THREE = 3
    FOUR  = 4
    FIVE  = 5
    SIX   = 6
    SEVEN = 7
    EIGHT = 8
    NINE  = 9
    TEN   = 10
    
    @staticmethod
    def _prob(card) -> float:
        if card == BJCard.TEN:
            return 4.0 / 13
        return 1.0 / 13

    @staticmethod
    def hit(n_cards=None):
        return np.random.choice(BJCard, size=n_cards, p=[BJCard._prob(c) for c in BJCard]).tolist()

    
@dataclass
class Hand:
    value: int
    usable_ace: bool
        
    def __init__(self, cards):
        self.value = 0
        self.usable_ace = False
        for card in cards:
            if card != BJCard.ACE:
                self.value += card
            else:
                if self.usable_ace:
                    # cannot have more than one usable ace (11 + 11 > 21)
                    self.value += 1
                else:
                    # attempt to use max value, correct later if > 21
                    self.value += 11
                    self.usable_ace = True

            if self.value > 21 and self.usable_ace:
                self.value -= 10
                self.usable_ace = False
        
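    def add(self, card):
        if card == BJCard.ACE and not self.usable_ace:
            # count a new ace as 11 (usable); corrected below if this busts
            self.value += 11
            self.usable_ace = True
        elif card == BJCard.ACE:
            self.value += 1  # cannot have more than one usable ace (11 + 11 > 21)
        else:
            self.value += card
        if self.usable_ace and self.value > 21:
            # demote the usable ace from 11 to 1
            self.value -= 10
            self.usable_ace = False
        return self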
    
    def __hash__(self):
        # hash the same fields that the dataclass __eq__ compares
        return hash((self.value, self.usable_ace))

    
@dataclass
class BJState:
    player: Hand
    dealer: Hand
        
    @staticmethod
    def new():
        cards = BJCard.hit(3)
        player = Hand(cards[:2])
        dealer = Hand(cards[2:])
        return BJState(player, dealer)
    
    def copy(self):
        return copy.deepcopy(self)
    
    def __hash__(self):
        # a tuple hash spreads the 200 states far better than summing hashes
        return hash((self.player, self.dealer))

In [8]:
def bj_states() -> Sequence[BJState]:
    states = []
    for player_value in range(12, 21 + 1):
        for usable_ace in [False, True]:
            for dealer_value in range(2, 11 + 1):
                player = Hand([])
                player.value = player_value
                player.usable_ace = usable_ace
                
                dealer = Hand([])
                dealer.value = dealer_value
                dealer.usable_ace = dealer_value == 11
                
                states.append(BJState(player, dealer))
    return states

In [9]:
def stick_20_21(state: BJState, action: BJAction) -> float:
    if state.player.value in {20, 21}:
        if action == BJAction.STICK:
            return 1.0
        return 0.0
    if action == BJAction.HIT:
        return 1.0
    return 0.0

In [10]:
def bj_simulate(
    policy: Callable[[BJState, BJAction], float]
) -> Sequence[Tuple[BJState, BJAction, int]]:
    """..."""
    states = []
    actions = []
    rewards = []
    
    state = BJState.new()
    
    # simulate player until sticks or goes bust
    while True:
        states.append(state.copy())
        
        action = np.random.choice(BJAction, p=[policy(state, a) for a in BJAction])
        actions.append(action)
        
        if action == BJAction.HIT:
            state.player.add(BJCard.hit())
        
        if action == BJAction.STICK or state.player.value > 21:
            break
            
        rewards.append(0)   # 0 whilst player still playing
    
    # compute final reward
    if state.player.value > 21:
        # player bust
        reward = -1
    else:
        # player not bust, compute dealer value using fixed policy
        while state.dealer.value < 17:
            state.dealer.add(BJCard.hit())
        if state.dealer.value > 21:
            # dealer bust
            reward = 1
        else:
            if state.player.value == state.dealer.value:
                reward = 0
            elif state.player.value > state.dealer.value:
                reward = 1
            else:
                reward = -1
           
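    # The terminal state is not part of the trace: each non-terminal state and
    # its action are paired with the reward received after taking that action.
    rewards.append(reward)

    return list(zip(states, actions, rewards))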

In [11]:
def generate(f, q_in, q_out, seed):
    # force different seed per process to avoid 
    # processes simulating the same episodes
    np.random.seed(seed)
    random.seed(seed)
    to_gen = q_in.get()
    for _ in range(to_gen):
        q_out.put(f())
        

def bj_simulate_n(policy, n_episodes, n_procs=10):
    assert n_procs > 0
    if n_procs == 1:
        for _ in range(n_episodes):
            yield bj_simulate(policy)
        return
        
    q_in = multiprocessing.Queue()
    q_out = multiprocessing.Queue()

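    # Passing a lambda as a Process argument relies on the "fork" start method:
    # under "spawn" or "forkserver" the arguments are pickled, which fails for a
    # lambda. Use n_procs=1 on platforms where "fork" is unavailable.
    proc = [
        multiprocessing.Process(
            target=generate,
            args=(lambda: bj_simulate(policy), q_in, q_out, seed)
        )
        for seed in range(n_procs)
    ]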
    
    for p in proc:
        p.daemon = True
        p.start()
        
    for pid in range(n_procs):
        to_proc = n_episodes // n_procs
        if pid == n_procs - 1:
            to_proc += n_episodes % n_procs
        q_in.put(to_proc)

    for _ in range(n_episodes):
        yield q_out.get()

    for p in proc:
        p.join()

In [12]:
def run(n_procs, n_episodes):
    state_values = first_visit_mc(
        states=bj_states(),
        episodes=bj_simulate_n(stick_20_21, n_episodes, n_procs=n_procs),
        gamma=1.0
    )
    return state_values

In [13]:
state_values = {}
for n_episodes in [10000, 500000]:
    state_values[n_episodes] = run(n_procs=6, n_episodes=n_episodes)

10000it [00:01, 7550.00it/s]
500000it [00:58, 8574.73it/s] 


In [43]:
plt.close("all")
fig = plt.figure(figsize=(10, 8))

for i, n_episodes in enumerate(sorted(state_values.keys())):
    for j, usable_ace in enumerate([True, False]):
        values = np.empty((10, 10))

        for state, value in state_values[n_episodes].items():
            if state.player.usable_ace != usable_ace:
                continue
            values[ state.player.value - 12, state.dealer.value - 2] = value

        X, Y = np.meshgrid(range(2, 11 + 1), range(12, 21 + 1))
        xs = X.flatten()
        ys = Y.flatten()
        zs = values.flatten()

        ax = fig.add_subplot(2, 2, i + j*2 + 1, projection="3d")
        ax.plot_trisurf(xs, ys, zs)
        ax.set_zlim3d(-1, 1)

        ax.set_title(f"{'Usable' if usable_ace else 'No usable'} ace, {n_episodes} episodes")
        ax.set_xlabel("Dealer showing (ace = 11)")
        ax.set_ylabel("Player sum")

fig.tight_layout()


---

Exercise 5.1: The value function jumps up for the last two rows because the policy sticks on 20 or 21, and the probability of the player winning from either of these states is high. It drops off for the far right row (the left row in the book's figures) because there the dealer is showing an ace (valued 11): an ace combined with a ten-valued card gives the dealer 21, which the player can at best tie, and the probability of drawing a ten-valued card is relatively high (16/52). The frontmost values are higher when there is a usable ace because the player is less likely to go bust: if the player draws a high card that takes them over 21, in the lower diagrams (no usable ace) they lose, but in the upper diagrams the ace can count as 1 rather than 11, allowing the player to continue.

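Exercise 5.2: The results would be identical, because no state is visited twice in an episode of blackjack (assuming "usable ace" is part of the state). While the usable-ace flag stays the same, every hit strictly increases the player's sum; when a usable ace is converted to count as 1 the flag changes from true to false, and with a sum of 12 or more it can never become true again. Every visit is therefore a first visit.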
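A quick empirical check of this claim, using a minimal re-implementation of the hand arithmetic rather than the `Hand` class above so that it runs on its own (`add_card`, `player_states` and `RANKS` are hypothetical helpers). The dealer's showing card does not change during the player's turn, so only the player's `(sum, usable_ace)` pair matters; the simulated player always hits until going bust, which gives repeats the most opportunity to occur.

```python
import random
from typing import List, Tuple

# Infinite deck as in the blackjack example: 2-9, four ten-valued ranks, and an ace counted as 11.
RANKS = [2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11]


def add_card(total: int, usable_ace: bool, card: int) -> Tuple[int, bool]:
    """Add one card to a hand described by (sum, usable_ace)."""
    if card == 11 and usable_ace:
        total += 1  # at most one ace can count as 11
    else:
        total += card
        usable_ace = usable_ace or card == 11
    if total > 21 and usable_ace:
        total -= 10  # count the usable ace as 1 instead of 11
        usable_ace = False
    return total, usable_ace


def player_states(rng: random.Random) -> List[Tuple[int, bool]]:
    """(sum, usable_ace) pairs with sum >= 12 seen by a player who always hits until bust."""
    total, usable = 0, False
    for _ in range(2):
        total, usable = add_card(total, usable, rng.choice(RANKS))
    visited = []
    while total <= 21:
        if total >= 12:
            visited.append((total, usable))
        total, usable = add_card(total, usable, rng.choice(RANKS))
    return visited


rng = random.Random(0)
episodes = [player_states(rng) for _ in range(10_000)]
print(all(len(set(ep)) == len(ep) for ep in episodes))  # True
```

Dropping the usable-ace flag from the state would break this: `add_card(15, True, 10)` returns `(15, False)`, so a player holding an ace and a four (15) who draws a ten is back at a sum of 15.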

---

#### Backup Diagrams for MC Methods

The general idea of backup diagrams is to show at the top the root node to be updated and to show below all the transitions and leaf nodes whose rewards and estimated values contribute to the update. 

For MC estimation of $v_\pi$, the root is a state node, and below it is the entire trajectory of transitions along a single sampled episode, ending at the terminal state.

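This is in contrast to the DP backup diagram: DP shows all possible transitions but only one step deep, whereas the MC diagram shows only the transitions sampled in one episode, but all the way to the end of that episode.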
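The difference in depth can be written out with the two update targets. A DP (iterative policy evaluation) update of $v(s)$ takes an expectation over one-step transitions and bootstraps from the current estimates $v_k(s')$ of successor states, whereas MC estimates $v_\pi(s)$ by averaging complete sampled returns $G_t$ from visits with $S_t = s$, each running to the end of the episode at time $T$:

$$
\begin{aligned}
\text{DP:} \quad v_{k+1}(s) &= \sum_a \pi(a \mid s) \sum_{s', r} p(s', r \mid s, a) \big[ r + \gamma v_k(s') \big] \\
\text{MC:} \quad G_t &= R_{t+1} + \gamma R_{t+2} + \cdots + \gamma^{T-t-1} R_T
\end{aligned}
$$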

Another difference is that in MC methods the estimate for each state is independent: one estimate does not build upon another. MC methods do not bootstrap.

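As a consequence, the computational cost of estimating the value of a single state is independent of the number of states. This makes it possible to estimate only the states of interest (which may be far fewer than the entire state space): generate episodes that start from those states and average only their returns, ignoring all others.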
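A small illustration of estimating one state's value in isolation, using a hypothetical random walk over states 1 to 5 (not the blackjack example above) with terminal states 0 and 6, a reward of +1 only on exiting to the right, and $\gamma = 1$. The true value of the middle state 3 is the probability of exiting on the right, 3/6 = 0.5. Every episode starts from state 3 and only its returns are averaged, so no other state's value is estimated or stored. The helper `random_walk_return` is an illustrative name.

```python
import random
import statistics


def random_walk_return(start: int, rng: random.Random, n_states: int = 5) -> float:
    """Return of one episode of a hypothetical random walk over states 1..n_states.

    Each step moves left or right with equal probability; the episode ends on
    reaching 0 (return 0) or n_states + 1 (return +1). Other rewards are 0, gamma = 1.
    """
    s = start
    while 0 < s < n_states + 1:
        s += rng.choice((-1, 1))
    return 1.0 if s == n_states + 1 else 0.0


rng = random.Random(0)
# Estimate the value of state 3 only: every episode starts there.
returns = [random_walk_return(3, rng) for _ in range(20_000)]
estimate = statistics.fmean(returns)
print(estimate)  # close to the true value 0.5
```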

## 5.2: Monte Carlo Estimation of Action Values

