In [None]:
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from dataclasses import dataclass
from enum import Enum
import itertools
import matplotlib.pyplot as plt
import math
import numpy as np
import pandas as pd
import random
import torch
from typing import Dict, List, Optional, Set, Tuple, Union

In [None]:
from basic_utils import (
    ALL_ROLL_TUPLES,
    Box,
    BoxCategories,
    RollAction,
    RollValues,
    ScoreAction,
    ScoreCard,
    roll_first,
    remove_dice,
    roll_again,
    GameState,
    ROLL_TUPLES_BY_BOX,
)
from agents import Agent, EpsilonGreedyAgent, GreedyAgent, RandomAgent
from expected_score_utils import (
    all_expected_scores_table_by_box,
    best_expected_scores_table_by_box,
    best_roll_action_for_box_with_score,
    best_action_by_box_with_score,
    greedy_best_action,
    hit_probability_from_action,
    create_expected_scores_table_two_rolls,
)

A state in the game consists of a `ScoreCard` (the score card state), `RollValues` (the values of the dice showing on the table), and `rolls_completed` (the number of rolls made so far in the current turn, from 1 to 3). The `ScoreCard` holds all the information needed from previous turns. `RollValues` simply holds the values of the five dice that have been rolled at a given point.

`GameState` contains all three of these objects. It provides `possible_score_actions`, which gives the scores available for the current dice values and score card state. I frame these as actions because at any point the player can choose to end their turn and score with one of these values. For convenience, they are sorted by score in descending order. `GameState` also provides `possible_actions`, which includes roll actions in addition to the `possible_score_actions`. The `re_roll` method re-rolls the dice specified by their values. Finally, `GameState` provides an `update_score` method, which updates the score card given a choice of box. The cells below step through a turn by asking `greedy_best_action` for an action and applying it with `take_action`.

The current score can be accessed at any time by calling the `score` method of the `GameState`'s `scorecard`.

In [None]:
# Create a new game state and start a turn on it via start_turn().
game_state = GameState()
game_state.start_turn()

In [None]:
# Ask the greedy policy what to do with the current dice and open boxes.
action = greedy_best_action(
    game_state.roll_values,
    game_state.scorecard.unused_boxes,
    3 - game_state.rolls_completed,  # rolls still available this turn
)
action  # echo so the notebook displays the chosen action

In [None]:
# Apply the action chosen in the previous cell to the game state.
game_state.take_action(action)

In [None]:
# Re-query the greedy policy now that the dice have changed.
action = greedy_best_action(
    game_state.roll_values,
    game_state.scorecard.unused_boxes,
    3 - game_state.rolls_completed,  # rolls still available this turn
)
action  # echo so the notebook displays the chosen action

In [None]:
# Apply the action chosen in the previous cell to the game state.
game_state.take_action(action)

In [None]:
# Query the greedy policy once more for the updated dice.
action = greedy_best_action(
    game_state.roll_values,
    game_state.scorecard.unused_boxes,
    3 - game_state.rolls_completed,  # rolls still available this turn
)
action  # echo so the notebook displays the chosen action

In [None]:
class YahtzeeDeepDoubleQ(Agent):
    """
    Deep double-Q learning agent for Yahtzee (work in progress).

    Uses the MarioAgent from the PyTorch tutorial as a template:
    https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html

    Only epsilon-greedy action selection is implemented so far; ``cache``,
    ``recall`` and ``learn`` are placeholders that currently do nothing.

    Args:
        state_dim: Size of the state input passed to ``YahtzeeNet``.
        action_dim: Number of discrete actions; random exploration draws
            indices from ``[0, action_dim)``.
        save_dir: Directory intended for checkpoints (stored but not yet used
            by this class).
        narrate: Forwarded to ``Agent.__init__``.
    """

    def __init__(self, state_dim, action_dim, save_dir, narrate: bool = False):
        super().__init__(narrate=narrate)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Q-network. NOTE(review): YahtzeeNet is not defined in this cell; it is
        # presumably defined elsewhere in the notebook and must accept a
        # ``model="online"`` keyword (see choose_action).
        self.net = YahtzeeNet(self.state_dim, self.action_dim).float()
        self.net = self.net.to(device=self.device)

        # Epsilon-greedy schedule: start fully random, decay multiplicatively
        # after every chosen action, never going below exploration_rate_min.
        self.exploration_rate = 1.0
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0

        self.save_every = 5e5  # no. of experiences between saving the network

    def choose_action(self, game_state: GameState) -> int:
        """
        Choose an epsilon-greedy action index and advance the step counter.

        With probability ``exploration_rate`` a uniformly random index in
        ``[0, action_dim)`` is returned; otherwise the index with the highest
        value from the online network is returned. Afterwards the exploration
        rate is decayed (clamped at ``exploration_rate_min``) and
        ``curr_step`` is incremented.

        Args:
            game_state: Current game state, encoded via ``game_state.to_array()``.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if np.random.rand() < self.exploration_rate:
            # EXPLORE. np.random.randint returns a NumPy integer; cast so both
            # branches return the same type.
            action_idx = int(np.random.randint(self.action_dim))
        else:
            # EXPLOIT. The network was cast to float32 via .float(), so force
            # the input to float32 as well; to_array() may produce float64 or
            # integer data (NOTE(review): confirm its dtype), which would not
            # match the float32 parameters.
            state = torch.tensor(
                game_state.to_array(), dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                action_values = self.net(state, model="online")
            action_idx = torch.argmax(action_values, dim=1).item()

        # Decay the exploration rate, but never below the floor.
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, experience):
        """Add the experience to memory.

        TODO: not implemented yet; currently a no-op.
        """
        pass

    def recall(self):
        """Sample experiences from memory.

        TODO: not implemented yet; currently a no-op returning None.
        """
        pass

    def learn(self):
        """Update the online action-value (Q) function with a batch of experiences.

        TODO: not implemented yet; currently a no-op returning None.
        """
        pass