# Temporal-Difference Learning

This notebook summarizes Temporal-Difference (TD) learning for reinforcement learning, as described in David Silver's Reinforcement Learning course. For each part, the relevant theory is introduced first, followed by an implementation in Python.

We will test our implementation by following the assignment, Easy21, which can be found at: [https://www.davidsilver.uk/wp-content/uploads/2020/03/Easy21-Johannes.pdf](https://www.davidsilver.uk/wp-content/uploads/2020/03/Easy21-Johannes.pdf)

It covers:
* Temporal-Difference Policy Evaluation, both state and state-action (Lecture 3)
* Temporal-Difference Policy Iteration (Lecture 4)
* Model-Free control using Temporal-Difference & $\epsilon$-greedy policy improvement

Previous knowledge:
* Lectures 1 and 2

# Glossary

* MDP: Markov Decision Process
* TD: Temporal-Difference
* MC: Monte-Carlo

# Implementing our environment

The assignment Easy21 describes a game similar to Blackjack. We will implement the game's environment, which will allow our Temporal-Difference agents to play the game!

From the assignment: "You should write an environment that implements the game Easy21. Specifically, write a function, named step, which takes as input a state s (dealer’s first card 1–10 and the player’s sum 1–21), and an action a (hit or stick), and returns a sample of the next state $s'$ (which may be terminal if the game is finished) and reward r. We will be using this environment for model-free reinforcement learning, and you should not explicitly represent the transition matrix for the MDP. There is no discounting (γ = 1). You should treat the dealer’s moves as part of the environment, i.e. calling step with a stick action will play out the dealer’s cards and return the final reward and terminal state."

In [1]:
from collections import namedtuple
from typing import Callable, List, Tuple

import numpy as np

STICK = 0
HIT = 1
ACTION_SPACE = [
    STICK,
    HIT,
]
DEALER_HIT_MAX = 17

# min is inclusive, max is exclusive (like Python's range)!
Range = namedtuple('Range', ['min', 'max'])
PLAYER_RANGE_FOR_ACTION = Range(1, 22)
DEALER_RANGE_FOR_ACTION = Range(1, 11)


class Easy21:
    def __init__(self):
        self.action_space = [
            STICK,
            HIT,
        ]
        self.state_max_bound = (
            DEALER_RANGE_FOR_ACTION.max,
            PLAYER_RANGE_FOR_ACTION.max,
        )

        self.state_min_bound = (
            DEALER_RANGE_FOR_ACTION.min,
            PLAYER_RANGE_FOR_ACTION.min,
        )

    def reset(self):
        """
        Reset the environment to the initial state:
        "At the start of the game both the player and the dealer draw one black
        card (fully observed)"

        :return: Observation consisting of (dealer card, player card)
        """
        return tuple(np.random.randint(low=1, high=11, size=(2,)))

    def _draw(self) -> Tuple[int, str]:
        """
        Draw a card according to easy21 rules:
        - Each draw from the deck results in a value between 1 and 10 (uniformly distributed)
          with a colour of red (probability 1/3) or black (probability 2/3).
        - There are no aces or picture (face) cards in this game
        :return: Tuple of card value and card color
        """
        card = np.random.randint(1, 11)
        color = np.random.choice(['b', 'r'], p=[2/3, 1/3])

        return card, color

    def _draw_and_update(self, prev_sum: int) -> int:
        """
        Draw a card from the deck and add it to, or subtract it from, the current sum of cards.
        The values of the cards are added (black cards) or subtracted (red cards)

        :param prev_sum: Previous sum of cards
        :return: New sum of cards after drawing a card from the deck
        """
        value, color = self._draw()
        value = -1*value if color == 'r' else value
        return prev_sum + value

    def step(self, s: Tuple[int, int], a: int) -> Tuple[Tuple[int, int], int, bool]:
        """
        Takes as input a state s (dealer’s first card 1–10 and the player’s sum 1–21),
        and an action a (hit or stick), and returns a sample of the next state s'
        (which may be terminal if the game is finished) and reward r

        :param s: Input state s of format (dealer’s first card 1-10, player’s sum 1–21)
        :param a: Action to perform: 0-stick or 1-hit
        :return: The next state, the reward and True/False if terminal state
        """

        dealer_card, player_sum = s

        # Player sticks
        if STICK == a:
            s1, reward, is_terminal = self._stick(dealer_card, player_sum)
        # Player hits
        elif HIT == a:
            s1, reward, is_terminal = self._hit(dealer_card, player_sum)
        else:
            raise ValueError(f'Unknown action value: {a}')

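        # Clip the state to its boundaries. The max bounds are exclusive (see Range),
        # so clip to max - 1 to keep every returned state within the valid range.
        s1 = tuple(np.clip(
            s1,
            a_min=self.state_min_bound,
            a_max=tuple(bound - 1 for bound in self.state_max_bound),
        ))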

        return s1, reward, is_terminal

    def _stick(self, dealer_card: int, player_sum: int) -> Tuple[Tuple[int, int], int, bool]:
        """
        If the player sticks then the dealer starts taking turns. The dealer always
        sticks on any sum of 17 or greater, and hits otherwise.
        If the dealer goes bust, then the player wins; otherwise, the outcome – win (reward +1),
        lose (reward -1), or draw (reward 0) – is the player with the largest sum.

        :param dealer_card: The dealer's card, between 1-10
        :param player_sum: The sum of the player's cards, between 1-21
        :return: The next state, the reward and True/False if terminal state
        """

        # Dealer hits until reaching a sum of 17 or greater, or until going bust (sum < 1).
        # The dealer's first card is at most 10, so the dealer always draws at least once.
        dealer_sum = dealer_card
        while 1 <= dealer_sum < DEALER_HIT_MAX:
            dealer_sum = self._draw_and_update(dealer_sum)

        # Dealer didn't go bust, winner has higher sum
        if 1 <= dealer_sum <= 21:
            if player_sum > dealer_sum:  # Player wins
                return (dealer_sum, player_sum), 1, True
            elif player_sum == dealer_sum:  # Tie
                return (dealer_sum, player_sum), 0, True
            else:  # Dealer wins!
                return (dealer_sum, player_sum), -1, True
        else:  # Dealer goes bust
            return (dealer_sum, player_sum), 1, True

    def _hit(self, dealer_card: int, player_sum: int) -> Tuple[Tuple[int, int], int, bool]:
        """
        If the player’s sum exceeds 21, or becomes less than 1, then she “goes bust” and loses the game (reward -1)

        :param dealer_card: The dealer's card, between 1-10
        :param player_sum: The sum of the player's cards, between 1-21
        :return: The next state, the reward and True/False if terminal state
        """
        new_player_sum = self._draw_and_update(player_sum)

        # Player still not bust yet
        if 1 <= new_player_sum <= 21:
            return (dealer_card, new_player_sum), 0, False
        else:
            return (dealer_card, new_player_sum), -1, True

In addition, we implement a helper for running a full episode in our environment. Monte-Carlo methods need complete episodes before they can learn, so we implement this once and reuse it wherever a full episode is needed.

In [2]:
def episode(env, policy: Callable) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
    """
    Run an episode in the given environment using the given policy

    :param env: Environment to run in, should have step and reset methods
    :param policy: Policy function, which given a state returns the action to perform
    :return: List of the states, list of the rewards and list of the actions
    """
    s0 = env.reset()

    episode_states, rewards, actions = [], [], []

    while True:
        a = policy(s0)
        s1, reward, is_terminal = env.step(s0, a)

        episode_states.append(s0)
        rewards.append(reward)
        actions.append(a)

        if is_terminal:
            break

        s0 = s1

    return episode_states, rewards, actions

Let's also define some global settings: the number of episodes to run each algorithm for, a random number generator and the figure size for our plots:

In [14]:
N_EPISODES = 100_000
rng = np.random.default_rng()
PLOT_SIZE = (7, 7)

# What is Temporal-Difference Reinforcement Learning?

- TD methods learn directly from episodes of experience
- TD is model-free: no knowledge of MDP transitions / rewards
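- TD learns from incomplete episodes, by bootstrapping
	- Bootstrapping: updating an estimate using other current estimates, e.g. after a single step, moving the guess for the current state towards the observed reward plus the current guess for the next state, instead of waiting for the final outcome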
- TD updates a guess towards a guess
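
As a minimal sketch of this bootstrapped update, assuming a tabular value function stored in a `defaultdict` (the function `td0_update` and its parameters are illustrative and not part of this notebook's `src` package), a single TD(0) step can be written as:

```python
from collections import defaultdict
from collections.abc import Hashable
from typing import DefaultDict


def td0_update(V: DefaultDict[Hashable, float], s: Hashable, r: float,
               s_next: Hashable, is_terminal: bool,
               alpha: float = 0.1, gamma: float = 1.0) -> float:
    """Apply one tabular TD(0) update to V in place and return the TD error.

    The TD target r + gamma * V[s_next] uses the *current guess* for the next
    state's value (bootstrapping); a terminal next state has value 0.
    """
    target = r if is_terminal else r + gamma * V[s_next]
    td_error = target - V[s]
    V[s] += alpha * td_error
    return td_error


# One step from 'A' to 'B' with reward 1, where the current guess V['B'] is 0.5
V = defaultdict(float)
V['B'] = 0.5
delta = td0_update(V, 'A', 1.0, 'B', is_terminal=False, alpha=0.1)
print(delta, round(V['A'], 4))  # 1.5 0.15
```

The estimate for `A` moves a fraction `alpha` of the way towards the TD target $1 + 1 \cdot 0.5 = 1.5$, using only one step of experience plus the existing guess for `B`: a guess updated towards a guess.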

# Temporal-Difference vs. Monte-Carlo Pros and Cons

- TD can learn *before* knowing the final outcome
	- TD can learn online after every step
	- MC must wait until end of episode before return is known 
- TD can learn *without* the final outcome
	- TD can learn from incomplete sequences
	- MC can only learn from complete sequences
	- TD works in continuing (non-terminating) environments
	- MC only works for episodic (terminating) environments
- MC has high variance, zero bias
	- Good convergence properties (even with function approximation)
	- Not very sensitive to initial value
	- Very simple to understand and use
- TD has low variance, some bias
	- Usually more efficient than MC
	- TD(0) converges to $v_π(s)$ (but not always with function approximation)
	- More sensitive to initial value
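
The first two points can be made concrete by computing both kinds of target for one recorded episode. In this sketch, the helpers `mc_returns` and `td_targets`, the state names and the value estimates are all made up for illustration; `rewards[t]` holds $R_{t+1}$, the reward received after leaving `states[t]`, and the episode terminates after the last state:

```python
from typing import Dict, List


def mc_returns(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """Monte-Carlo targets: the actual return G_t for every step t.

    Computed backwards via G_t = R_{t+1} + gamma * G_{t+1}, so every reward up
    to the end of the episode must be known before any target is available.
    """
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


def td_targets(states: List[str], rewards: List[float],
               V: Dict[str, float], gamma: float = 1.0) -> List[float]:
    """TD(0) targets: R_{t+1} + gamma * V(S_{t+1}), each available one step later.

    The state after the last entry of `states` is terminal, with value 0.
    """
    targets = []
    for t, r in enumerate(rewards):
        next_value = V[states[t + 1]] if t + 1 < len(states) else 0.0
        targets.append(r + gamma * next_value)
    return targets


# One recorded episode A -> B -> C -> terminal, with a reward only at the end
states = ['A', 'B', 'C']
rewards = [0.0, 0.0, 1.0]
V = {'A': 0.2, 'B': 0.4, 'C': 0.6}  # current (made-up) value estimates

mc = mc_returns(rewards)
td = td_targets(states, rewards, V)
print(mc)  # [1.0, 1.0, 1.0]
print(td)  # [0.4, 0.6, 1.0]
```

The TD target for `A` is known as soon as the agent reaches `B`, whereas the Monte-Carlo target for `A` depends on the final reward. The TD targets also inherit any error in the estimates in `V`, which is where the bias in the trade-off above comes from.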

# Policy Evaluation

* Goal: learn $v_π$ from episodes of experience under policy $π$
$$S_1, A_1, R_2, ..., S_k ∼ π$$
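* Recall that the return is the total discounted reward ($\gamma$ is the discount factor):
$$G_t = R_{t+1} + γR_{t+2} + ... + γ^{T-t-1}R_T$$
* Recall that the value function is the expected return:
$$v_π(s) = E_π[G_t | S_t = s]$$
* Monte-Carlo policy evaluation uses the empirical mean return instead of the expected return
* TD(0) policy evaluation instead updates $V(S_t)$ towards the estimated return $R_{t+1} + γV(S_{t+1})$ (the TD target):
$$V(S_t) \leftarrow V(S_t) + α\left(R_{t+1} + γV(S_{t+1}) - V(S_t)\right)$$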
* Robot example:
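	* Let it take a walk, and calculate the return observed from the walk
	* Each walk is an episode
	* $v_\pi$ is the expected return over all walks; Monte-Carlo estimates it by averaging the returns of complete walks, whereas TD updates its estimate after every step of a walk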
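
To see TD policy evaluation end to end, here is a minimal sketch of tabular TD(0) on a small random walk rather than on Easy21. The random-walk environment, the function `td0_random_walk` and the step size are illustrative assumptions, not part of this notebook. The policy is fixed (each step moves left or right with equal probability), so the estimates should settle close to the known true values $1/6, \dots, 5/6$:

```python
import numpy as np

# Hypothetical 5-state random walk (states A-E, indices 0-4), started in C.
# Stepping off the right end gives reward +1, off the left end reward 0.
# With gamma = 1 the true values are 1/6, 2/6, ..., 5/6.
N_STATES = 5
START = 2


def td0_random_walk(n_episodes: int, alpha: float, gamma: float = 1.0,
                    seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    V = np.full(N_STATES, 0.5)  # initial guess for every state
    for _ in range(n_episodes):
        s = START
        while True:
            s_next = s + (1 if rng.random() < 0.5 else -1)
            if s_next == N_STATES:   # fell off the right end
                reward, v_next, done = 1.0, 0.0, True
            elif s_next < 0:         # fell off the left end
                reward, v_next, done = 0.0, 0.0, True
            else:
                reward, v_next, done = 0.0, V[s_next], False
            # TD(0): move V(s) towards the TD target R + gamma * V(s')
            V[s] += alpha * (reward + gamma * v_next - V[s])
            if done:
                break
            s = s_next
    return V


V = td0_random_walk(n_episodes=20_000, alpha=0.01)
true_V = np.arange(1, N_STATES + 1) / (N_STATES + 1)
print(np.round(V, 2))
print(np.round(true_V, 2))
```

A constant step size keeps the estimates fluctuating slightly around the true values; a smaller `alpha` reduces that noise at the cost of slower learning.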


In [None]:
from src.helpers import plot_value_function
from src.td_agent import SARSAAgent

env = Easy21()
sarsa = SARSAAgent(env=env, lambda_=1.0)
sarsa.learn(n_episodes=N_EPISODES)
plot_value_function(sarsa.Q, state_action_value_function=True, figsize=PLOT_SIZE)
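
The cell above relies on the project's own `SARSAAgent` from `src.td_agent`, which is not shown here. As a hedged, self-contained sketch of on-policy TD control with $\epsilon$-greedy policy improvement, here is one-step tabular SARSA (i.e. without the eligibility traces that a `lambda_` parameter suggests) on a hypothetical five-state corridor instead of Easy21. The environment, the function names and the hyperparameters are all assumptions for illustration:

```python
import numpy as np

# Hypothetical corridor: states 0..4, where state 4 is the terminal goal.
# Actions: 0 = LEFT, 1 = RIGHT. Every step costs -1 (no discounting), so the
# optimal policy is to always move right. LEFT in state 0 leaves the agent in place.
N_STATES, GOAL = 5, 4
LEFT, RIGHT = 0, 1


def corridor_step(s: int, a: int):
    """Return (next state, reward, is_terminal) for the toy corridor."""
    s_next = min(s + 1, GOAL) if a == RIGHT else max(s - 1, 0)
    return s_next, -1.0, s_next == GOAL


def epsilon_greedy(Q: np.ndarray, s: int, epsilon: float,
                   rng: np.random.Generator) -> int:
    """Pick a random action with probability epsilon, else a greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(Q.shape[1]))
    return int(np.argmax(Q[s]))


def sarsa(n_episodes: int = 2000, alpha: float = 0.1, gamma: float = 1.0,
          epsilon: float = 0.1, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    Q = np.zeros((N_STATES, 2))
    for _ in range(n_episodes):
        s = 0
        a = epsilon_greedy(Q, s, epsilon, rng)
        done = False
        while not done:
            s_next, r, done = corridor_step(s, a)
            # On-policy: choose the next action with the same epsilon-greedy policy
            a_next = epsilon_greedy(Q, s_next, epsilon, rng)
            # TD target; the value of the terminal state is defined as 0
            target = r if done else r + gamma * Q[s_next, a_next]
            Q[s, a] += alpha * (target - Q[s, a])
            s, a = s_next, a_next
    return Q


Q = sarsa()
greedy_policy = np.argmax(Q[:GOAL], axis=1)  # greedy action in states 0..3
print(greedy_policy)
```

SARSA bootstraps from `Q[s_next, a_next]`, where `a_next` is the action the $\epsilon$-greedy policy will actually take, so it evaluates and improves the same exploring policy it follows. If learning has worked, the greedy action in every non-terminal state should be `RIGHT`.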