# Yatzy RL environment

A Yatzy program to play with, and to compute an optimal policy and outcome probabilities.

**Features**

* single player setting to start with
* encoding state
* transition function for rules
* checking terminal state and computing reward
* actions to roll dice and select categories
* value iteration to compute value function / policy
* scorecard - which categories are used
* check if category can be used

**Getting probabilities**

The value iteration framework naturally gives us expected values (average scores), but if we want the full probability distribution of outcomes, we have a couple of options:

* Extend our existing transition function - since we're already building transitions (state → action → next state), we can track probabilities through them. This essentially uses the Markov chain structure we're already creating.
* Monte Carlo simulation - once we have our optimal policy, we can run many simulations and collect the distribution of final scores empirically.

The first approach is exact but computationally intensive. The second is simpler to implement and often sufficient.
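To make the trade-off concrete, here is a toy comparison on a single roll of five dice scored by their sum (as in the chance category), not on the full game. `exact_distribution` enumerates all 6^5 equally likely rolls, while `monte_carlo_distribution` samples; both are hypothetical helpers for this illustration only.

```python
import itertools
import random
from collections import Counter
from fractions import Fraction

def exact_distribution(score_fn, n_dice=5):
    """Exact score distribution: enumerate every equally likely ordered roll."""
    rolls = list(itertools.product(range(1, 7), repeat=n_dice))
    counts = Counter(score_fn(roll) for roll in rolls)
    return {score: Fraction(c, len(rolls)) for score, c in sorted(counts.items())}

def monte_carlo_distribution(score_fn, n_samples=100_000, n_dice=5, seed=0):
    """Empirical score distribution estimated from n_samples random rolls."""
    rng = random.Random(seed)
    counts = Counter(score_fn(rng.choices(range(1, 7), k=n_dice)) for _ in range(n_samples))
    return {score: c / n_samples for score, c in sorted(counts.items())}

exact = exact_distribution(sum)
approx = monte_carlo_distribution(sum)
for score in (10, 15, 17, 20):
    # exact probability vs. sampled estimate for a few sums
    print(score, float(exact[score]), approx.get(score, 0.0))
```

In the full game the exact route means propagating these probabilities through every transition, which is where the extra computation comes from.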

I've decided on the following **state representation**:

- **Current dice values** - a sorted tuple (1, 2, 3, 3, 5)
- **Rolls left** - int - determines whether we can reroll or must score
- **Filled categories** - frozenset of strings - determines which scoring options are still available
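Sorting the dice is what keeps the dice part of the state small: the 6^5 = 7776 ordered rolls collapse to 252 distinct sorted tuples (multisets of five values drawn from six faces). A quick check using only the standard library:

```python
import itertools

ordered_rolls = list(itertools.product(range(1, 7), repeat=5))
sorted_rolls = {tuple(sorted(roll)) for roll in ordered_rolls}
# combinations_with_replacement yields exactly the sorted tuples directly
direct = set(itertools.combinations_with_replacement(range(1, 7), 5))
print(len(ordered_rolls), len(sorted_rolls))  # 7776 252
```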

**Value function** could be a dict:
```python
# V = {state: value}
state = (dice, rolls, categories)
value = V[state]
```
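This works because every part of the key is hashable: a tuple of ints, an int and a frozenset (a plain `set` inside the key would raise `TypeError`). A minimal sketch, with a hypothetical `make_key` helper, that also computes an upper bound on the number of entries: 252 sorted dice tuples × 3 values of rolls left (0, 1, 2) × 2^15 subsets of filled categories.

```python
V = {}

def make_key(dice, rolls, filled):
    """Hypothetical helper: normalise a state into a hashable dict key."""
    return (tuple(sorted(dice)), rolls, frozenset(filled))

key = make_key([5, 3, 3, 1, 2], 2, {"chance", "yatzy"})
V[key] = 0.0  # placeholder value until value iteration fills it in
# Same state given in a different order maps to the same key
assert make_key((1, 2, 3, 3, 5), 2, ["yatzy", "chance"]) in V

try:
    {((1, 2, 3, 3, 5), 2, {"chance"}): 0.0}  # set inside the key
except TypeError as err:
    print("set is not hashable:", err)

upper_bound = 252 * 3 * 2 ** 15
print(upper_bound)  # 24772608
```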

**Plan to Start:**

1. Define the 15 category names (as strings for frozenset)
2. Write scoring functions for each category
3. Implement dice rolling and state transitions
4. Define valid actions (which dice to keep, which category to score)
5. Build value iteration algorithm
6. Test and iterate

In [None]:
#| default_exp yatzy

In [None]:
#| export
import random
from dataclasses import dataclass
from typing import Literal
import itertools
from collections import Counter
from functools import partial

In [None]:
#| export
# categories

categories = frozenset([
    "ones",
    "twos",
    "threes",
    "fours",
    "fives",
    "sixes",
    "one_pair",
    "two_pairs",
    "three_same",
    "four_same",
    "small_straight",
    "large_straight",
    "full_house",
    "chance",
    "yatzy",
])

In [None]:
categories

frozenset({'chance',
           'fives',
           'four_same',
           'fours',
           'full_house',
           'large_straight',
           'one_pair',
           'ones',
           'sixes',
           'small_straight',
           'three_same',
           'threes',
           'two_pairs',
           'twos',
           'yatzy'})

In [None]:
# dice values
dice_vals = [5, 4, 3, 2, 1]
sorted_vals = tuple(sorted(dice_vals))
sorted_vals

(1, 2, 3, 4, 5)

In [None]:
score = sum(sorted_vals)
score

15

In [None]:
random.choices([1,2,3,4,5,6], k=5)

[4, 3, 5, 6, 6]

In [None]:
#| export
def roll_n_dice(n):
    return random.choices([1,2,3,4,5,6], k=n)

In [None]:
roll_n_dice(5)

[2, 5, 1, 6, 4]

## Planning more about how to represent state and actions

Okay, thinking about actions.

Say I have 2 rolls left.

I can choose which dice to hold and which to re-roll, so that is basically a 5-bit binary word.

I can also choose the small straight category, but not chance, as it's already taken.

I can also choose to score zero in any category I haven't taken yet.

So an action could look like:
```
action = [00000, "small_straight", None]
# or
action = [11000, None, None]
# or
action = [00000, None, "ones"]
```
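The binary words above read like integers (`00000`, `11000`), so one option is to convert between a 5-bit integer and a tuple of bits. This sketch assumes, as the later `reroll_masked` code does, that 1 means re-roll that die; `int_to_mask` and `mask_to_int` are hypothetical helper names.

```python
def int_to_mask(word: int, n_dice: int = 5) -> tuple[int, ...]:
    """Hypothetical helper: 0b11000 -> (1, 1, 0, 0, 0); 1 = re-roll that die."""
    return tuple(int(bit) for bit in format(word, f"0{n_dice}b"))

def mask_to_int(mask: tuple[int, ...]) -> int:
    """Inverse of int_to_mask."""
    return int("".join(str(bit) for bit in mask), 2)

print(int_to_mask(0b11000))          # (1, 1, 0, 0, 0)
print(mask_to_int((0, 0, 1, 0, 1)))  # 5
```

Counting from 0 to 31 visits the masks in the same order as `itertools.product((0, 1), repeat=5)`.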

I need a function to compute the possible actions:
* allow re-roll if "rolls left" is > 0
* check which categories can be selected
* which categories could be set to zero

## Back to implementing

In [None]:
#| export
@dataclass(frozen=True)  # frozen makes State hashable, so it can be a key in the value-function dict
class State:
    dice_values: tuple[int, ...]
    rolls_left: int
    categories_picked: frozenset[str]

In [None]:
state = State((1, 2, 3, 4, 5), 2, frozenset({'chance', 'large_straight', 'yatzy'}))
state

State(dice_values=(1, 2, 3, 4, 5), rolls_left=2, categories_picked=frozenset({'chance', 'yatzy', 'large_straight'}))

In [None]:
#| export
@dataclass
class Action:
    type: Literal["reroll", "score", "zero"]
    value: tuple[int, ...] | str

In [None]:
Action("reroll", (0,0,1,0,0))

Action(type='reroll', value=(0, 0, 1, 0, 0))

In [None]:
#| export
reroll_masks = list(itertools.product((0,1), repeat=5))

In [None]:
reroll_masks

[(0, 0, 0, 0, 0),
 (0, 0, 0, 0, 1),
 (0, 0, 0, 1, 0),
 (0, 0, 0, 1, 1),
 (0, 0, 1, 0, 0),
 (0, 0, 1, 0, 1),
 (0, 0, 1, 1, 0),
 (0, 0, 1, 1, 1),
 (0, 1, 0, 0, 0),
 (0, 1, 0, 0, 1),
 (0, 1, 0, 1, 0),
 (0, 1, 0, 1, 1),
 (0, 1, 1, 0, 0),
 (0, 1, 1, 0, 1),
 (0, 1, 1, 1, 0),
 (0, 1, 1, 1, 1),
 (1, 0, 0, 0, 0),
 (1, 0, 0, 0, 1),
 (1, 0, 0, 1, 0),
 (1, 0, 0, 1, 1),
 (1, 0, 1, 0, 0),
 (1, 0, 1, 0, 1),
 (1, 0, 1, 1, 0),
 (1, 0, 1, 1, 1),
 (1, 1, 0, 0, 0),
 (1, 1, 0, 0, 1),
 (1, 1, 0, 1, 0),
 (1, 1, 0, 1, 1),
 (1, 1, 1, 0, 0),
 (1, 1, 1, 0, 1),
 (1, 1, 1, 1, 0),
 (1, 1, 1, 1, 1)]
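Because the dice are kept sorted, several of these 32 masks can lead to the same kept dice (re-rolling either of two equal dice is the same move). A sketch of deduplicating them per roll, which could cut the work in value iteration; `distinct_keeps` is a hypothetical helper and assumes mask bit 1 means re-roll.

```python
import itertools

reroll_masks = list(itertools.product((0, 1), repeat=5))

def distinct_keeps(dice):
    """Map each distinct multiset of kept dice to one representative reroll mask.

    Assumes `dice` is a sorted tuple and mask bit 1 means re-roll that die.
    """
    keeps = {}
    for mask in reroll_masks:
        kept = tuple(d for d, bit in zip(dice, mask) if bit == 0)
        keeps.setdefault(kept, mask)  # first mask seen for this kept multiset wins
    return keeps

for dice in [(1, 2, 3, 4, 5), (1, 1, 2, 2, 3), (6, 6, 6, 6, 6)]:
    print(dice, len(distinct_keeps(dice)))  # 32, 18, 6
```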

In [None]:
available_actions = []
if state.rolls_left > 0:
    rerolls = [Action("reroll", mask) for mask in reroll_masks]
    available_actions.extend(rerolls)
available_actions

[Action(type='reroll', value=(0, 0, 0, 0, 0)),
 Action(type='reroll', value=(0, 0, 0, 0, 1)),
 Action(type='reroll', value=(0, 0, 0, 1, 0)),
 Action(type='reroll', value=(0, 0, 0, 1, 1)),
 Action(type='reroll', value=(0, 0, 1, 0, 0)),
 Action(type='reroll', value=(0, 0, 1, 0, 1)),
 Action(type='reroll', value=(0, 0, 1, 1, 0)),
 Action(type='reroll', value=(0, 0, 1, 1, 1)),
 Action(type='reroll', value=(0, 1, 0, 0, 0)),
 Action(type='reroll', value=(0, 1, 0, 0, 1)),
 Action(type='reroll', value=(0, 1, 0, 1, 0)),
 Action(type='reroll', value=(0, 1, 0, 1, 1)),
 Action(type='reroll', value=(0, 1, 1, 0, 0)),
 Action(type='reroll', value=(0, 1, 1, 0, 1)),
 Action(type='reroll', value=(0, 1, 1, 1, 0)),
 Action(type='reroll', value=(0, 1, 1, 1, 1)),
 Action(type='reroll', value=(1, 0, 0, 0, 0)),
 Action(type='reroll', value=(1, 0, 0, 0, 1)),
 Action(type='reroll', value=(1, 0, 0, 1, 0)),
 Action(type='reroll', value=(1, 0, 0, 1, 1)),
 Action(type='reroll', value=(1, 0, 1, 0, 0)),
 Action(type=

In [None]:
can_zero = categories - state.categories_picked
can_zero

frozenset({'fives',
           'four_same',
           'fours',
           'full_house',
           'one_pair',
           'ones',
           'sixes',
           'small_straight',
           'three_same',
           'threes',
           'two_pairs',
           'twos'})

## I need checker functions for
* has_ones, has_twos, has_threes, ... I think I could have has_num(n) or something
* has_pair
* has_two_pairs
* has_three_same
* has_four_same
* has_small_straight
* has_large_straight
* has_full_house
* has_yatzy

In [None]:
values = state.dice_values
# has ones
1 in values
# n in values

# from collections import Counter

Counter(values)

Counter({1: 1, 2: 1, 3: 1, 4: 1, 5: 1})

In [None]:
counts = Counter((1, 1, 2, 2, 3))
# Keys with count >= 2
pairs = [k for k, v in counts.items() if v >= 2]
pairs

[1, 2]

In [None]:
counts = Counter((1, 2, 2, 2, 3))
three_sames = [k for k, v in counts.items() if v >= 3]
three_sames

[2]

In [None]:
counts = Counter((1, 5, 5, 5, 5))
four_sames = [k for k, v in counts.items() if v >= 4]
four_sames

[5]

In [None]:
counts = Counter((6,6,6,6,6))
yatzy = [k for k, v in counts.items() if v >= 5]
yatzy

[6]

In [None]:
# Full house
counts = Counter((5,5,5,6,6))
pairs = [k for k, v in counts.items() if v >= 2]
three_sames = [k for k, v in counts.items() if v >= 3]

pairs.remove(three_sames[0])
len(pairs) > 0

True

In [None]:
# small straight
values = (1, 2, 3, 4, 5)
values == (1, 2, 3, 4, 5)

# large straight
values = (2, 3, 4, 5, 6)
values == (2, 3, 4, 5, 6)


True

## Putting the code together

In [None]:
#| export
def has_ns(n, values):
    return n in values

assert has_ns(3, (1,2,3,4,5))
assert not has_ns(6, (1,2,3,4,5))

def n_same(n, values):
    counts = Counter(values)
    return [k for k, v in counts.items() if v >= n]

assert n_same(2, (1,2,2,3,3)) == [2, 3]

def has_n_same(n, values):
    return bool(n_same(n, values))

def has_two_pairs(values):
    pairs = n_same(2, values)
    if len(pairs) == 2:
        return True
    return False

assert has_two_pairs((1,1,2,2,3))
assert not has_two_pairs((1,1,2,3,4))
assert not has_two_pairs((1,1,1,1,2))

def has_full_house(values):
    pairs = n_same(2, values)
    three_sames = n_same(3, values)
    if pairs and three_sames:
        pairs.remove(three_sames[0])
        return len(pairs) > 0
    return False

assert has_full_house((2,2,3,3,3))
assert not has_full_house((1,2,3,3,3))

def has_small_straight(values):
    return values == (1,2,3,4,5)

def has_large_straight(values):
    return values == (2,3,4,5,6)
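Since there are only 252 sorted rolls, the checkers can be validated exhaustively against hand-counted totals. The sketch copies `n_same`, `has_two_pairs` and `has_full_house` from the cell above (lightly condensed) so it runs on its own. Note that with this `has_two_pairs` a full house such as (3, 3, 3, 5, 5) also counts as two pairs, which is worth checking against the rule set being modeled.

```python
import itertools
from collections import Counter

# Copies of the checkers above, so this cell is self-contained
def n_same(n, values):
    counts = Counter(values)
    return [k for k, v in counts.items() if v >= n]

def has_two_pairs(values):
    return len(n_same(2, values)) == 2

def has_full_house(values):
    pairs = n_same(2, values)
    three_sames = n_same(3, values)
    if pairs and three_sames:
        pairs.remove(three_sames[0])
        return len(pairs) > 0
    return False

all_rolls = list(itertools.combinations_with_replacement(range(1, 7), 5))
tallies = {
    "one_pair": sum(bool(n_same(2, r)) for r in all_rolls),
    "two_pairs": sum(has_two_pairs(r) for r in all_rolls),
    "three_same": sum(bool(n_same(3, r)) for r in all_rolls),
    "four_same": sum(bool(n_same(4, r)) for r in all_rolls),
    "full_house": sum(has_full_house(r) for r in all_rolls),
    "yatzy": sum(bool(n_same(5, r)) for r in all_rolls),
}
# Hand counts: e.g. full house = 6 faces for the triple * 5 for the pair = 30;
# two pairs = 15 * 4 (pattern 2+2+1) + 30 full houses = 90
expected = {"one_pair": 246, "two_pairs": 90, "three_same": 126,
            "four_same": 36, "full_house": 30, "yatzy": 6}
print(tallies == expected)  # True
```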

In [None]:
can_zero

frozenset({'fives',
           'four_same',
           'fours',
           'full_house',
           'one_pair',
           'ones',
           'sixes',
           'small_straight',
           'three_same',
           'threes',
           'two_pairs',
           'twos'})

In [None]:
#| export
checkers = {
    "ones": partial(has_ns, 1),
    "twos": partial(has_ns, 2),
    "threes": partial(has_ns, 3),
    "fours": partial(has_ns, 4),
    "fives": partial(has_ns, 5),
    "sixes": partial(has_ns, 6),
    "one_pair": partial(has_n_same, 2),
    "two_pairs": has_two_pairs,
    "three_same": partial(has_n_same, 3),
    "four_same": partial(has_n_same, 4),
    "small_straight": has_small_straight,
    "large_straight": has_large_straight,
    "full_house": has_full_house,
    "chance": lambda values: True,  # chance accepts any dice
    "yatzy": partial(has_n_same, 5),
}

In [None]:
for category in can_zero:
    can_select = checkers[category](state.dice_values)
    print(category, can_select)

sixes False
two_pairs False
full_house False
fours True
one_pair False
twos True
ones True
four_same False
small_straight True
fives True
threes True
three_same False


In [None]:
available_actions

[Action(type='reroll', value=(0, 0, 0, 0, 0)),
 Action(type='reroll', value=(0, 0, 0, 0, 1)),
 Action(type='reroll', value=(0, 0, 0, 1, 0)),
 Action(type='reroll', value=(0, 0, 0, 1, 1)),
 Action(type='reroll', value=(0, 0, 1, 0, 0)),
 Action(type='reroll', value=(0, 0, 1, 0, 1)),
 Action(type='reroll', value=(0, 0, 1, 1, 0)),
 Action(type='reroll', value=(0, 0, 1, 1, 1)),
 Action(type='reroll', value=(0, 1, 0, 0, 0)),
 Action(type='reroll', value=(0, 1, 0, 0, 1)),
 Action(type='reroll', value=(0, 1, 0, 1, 0)),
 Action(type='reroll', value=(0, 1, 0, 1, 1)),
 Action(type='reroll', value=(0, 1, 1, 0, 0)),
 Action(type='reroll', value=(0, 1, 1, 0, 1)),
 Action(type='reroll', value=(0, 1, 1, 1, 0)),
 Action(type='reroll', value=(0, 1, 1, 1, 1)),
 Action(type='reroll', value=(1, 0, 0, 0, 0)),
 Action(type='reroll', value=(1, 0, 0, 0, 1)),
 Action(type='reroll', value=(1, 0, 0, 1, 0)),
 Action(type='reroll', value=(1, 0, 0, 1, 1)),
 Action(type='reroll', value=(1, 0, 1, 0, 0)),
 Action(type=

In [None]:
for cat in can_zero:
    available_actions.append(Action(type="zero", value=cat))
    can_select = checkers[cat](state.dice_values)
    if can_select:
        available_actions.append(Action(type="score", value=cat))
available_actions

[Action(type='reroll', value=(0, 0, 0, 0, 0)),
 Action(type='reroll', value=(0, 0, 0, 0, 1)),
 Action(type='reroll', value=(0, 0, 0, 1, 0)),
 Action(type='reroll', value=(0, 0, 0, 1, 1)),
 Action(type='reroll', value=(0, 0, 1, 0, 0)),
 Action(type='reroll', value=(0, 0, 1, 0, 1)),
 Action(type='reroll', value=(0, 0, 1, 1, 0)),
 Action(type='reroll', value=(0, 0, 1, 1, 1)),
 Action(type='reroll', value=(0, 1, 0, 0, 0)),
 Action(type='reroll', value=(0, 1, 0, 0, 1)),
 Action(type='reroll', value=(0, 1, 0, 1, 0)),
 Action(type='reroll', value=(0, 1, 0, 1, 1)),
 Action(type='reroll', value=(0, 1, 1, 0, 0)),
 Action(type='reroll', value=(0, 1, 1, 0, 1)),
 Action(type='reroll', value=(0, 1, 1, 1, 0)),
 Action(type='reroll', value=(0, 1, 1, 1, 1)),
 Action(type='reroll', value=(1, 0, 0, 0, 0)),
 Action(type='reroll', value=(1, 0, 0, 0, 1)),
 Action(type='reroll', value=(1, 0, 0, 1, 0)),
 Action(type='reroll', value=(1, 0, 0, 1, 1)),
 Action(type='reroll', value=(1, 0, 1, 0, 0)),
 Action(type=

In [None]:
#| export
def get_available_actions(state: State) -> list[Action]:
    available_actions = []
    if state.rolls_left > 0:
        rerolls = [Action("reroll", mask) for mask in reroll_masks]
        available_actions.extend(rerolls)
    
    can_zero = categories - state.categories_picked
    for cat in can_zero:
        available_actions.append(Action(type="zero", value=cat))
        can_select = checkers[cat](state.dice_values)
        if can_select:
            available_actions.append(Action(type="score", value=cat))
    
    return available_actions

In [None]:
get_available_actions(state)

[Action(type='reroll', value=(0, 0, 0, 0, 0)),
 Action(type='reroll', value=(0, 0, 0, 0, 1)),
 Action(type='reroll', value=(0, 0, 0, 1, 0)),
 Action(type='reroll', value=(0, 0, 0, 1, 1)),
 Action(type='reroll', value=(0, 0, 1, 0, 0)),
 Action(type='reroll', value=(0, 0, 1, 0, 1)),
 Action(type='reroll', value=(0, 0, 1, 1, 0)),
 Action(type='reroll', value=(0, 0, 1, 1, 1)),
 Action(type='reroll', value=(0, 1, 0, 0, 0)),
 Action(type='reroll', value=(0, 1, 0, 0, 1)),
 Action(type='reroll', value=(0, 1, 0, 1, 0)),
 Action(type='reroll', value=(0, 1, 0, 1, 1)),
 Action(type='reroll', value=(0, 1, 1, 0, 0)),
 Action(type='reroll', value=(0, 1, 1, 0, 1)),
 Action(type='reroll', value=(0, 1, 1, 1, 0)),
 Action(type='reroll', value=(0, 1, 1, 1, 1)),
 Action(type='reroll', value=(1, 0, 0, 0, 0)),
 Action(type='reroll', value=(1, 0, 0, 0, 1)),
 Action(type='reroll', value=(1, 0, 0, 1, 0)),
 Action(type='reroll', value=(1, 0, 0, 1, 1)),
 Action(type='reroll', value=(1, 0, 1, 0, 0)),
 Action(type=

**What I have**

* state representation
* action representation
* categories
* dice rolling
* functions to check if dice values fit category
* available actions function

**What is missing**
* transition function
    * logic for reroll
    * logic for scoring and zeroing categories

**Transition function**

state + action --> next state

```python
next_state = transition_func(state, action)
```

**What transitions are there?**

Action types:
* reroll
* zero
* score

**reroll**
* re-roll the dice according to the mask
* set the new dice in the state
* decrement rolls left

**zero**
* mark the category as selected
* check whether any categories are left and return the terminal state if none are
* roll all five dice
* reset rolls left to 2

**score**
* compute the score
* mark the category as selected
* add the score to the scorecard or running total
* check whether any categories are left and return the terminal state if none are
* roll all five dice
* reset rolls left to 2
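The score steps above, plus a reward, could look like the following minimal sketch. It assumes the transition returns a `(next_state, reward)` pair and uses `None` as a stand-in `TERMINAL_STATE`; `score_transition`, the tiny two-category game and the injected `score_fn` are all hypothetical, chosen so the cell runs on its own.

```python
import random
from dataclasses import dataclass

TERMINAL_STATE = None  # hypothetical sentinel for "no categories left"

@dataclass(frozen=True)
class State:
    dice_values: tuple[int, ...]
    rolls_left: int
    categories_picked: frozenset[str]

def score_transition(state, category, score_fn, all_categories, rng=random):
    """Score `category` with the current dice; return (next_state, reward)."""
    assert category not in state.categories_picked
    reward = score_fn(state.dice_values)            # compute the score
    picked = state.categories_picked | {category}   # mark the category as selected
    if picked == all_categories:                    # no categories left
        return TERMINAL_STATE, reward
    new_dice = tuple(sorted(rng.choices(range(1, 7), k=5)))  # roll all five dice
    return State(new_dice, 2, picked), reward       # reset rolls left to 2

# Hypothetical two-category game: "ones" and "chance"
cats = frozenset({"ones", "chance"})
s0 = State((1, 1, 2, 3, 6), 0, frozenset())
s1, r1 = score_transition(s0, "ones", lambda v: v.count(1), cats)
s2, r2 = score_transition(s1, "chance", sum, cats)
print(r1, s1.rolls_left, sorted(s1.categories_picked))  # 2 2 ['ones']
print(s2, r2)
```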


Let's have a state and action to work with.

In [None]:
state

State(dice_values=(1, 2, 3, 4, 5), rolls_left=2, categories_picked=frozenset({'chance', 'yatzy', 'large_straight'}))

In [None]:
action = Action(type='reroll', value=(1, 1, 1, 0, 0))
action

Action(type='reroll', value=(1, 1, 1, 0, 0))

What would it mean to transition to the next state, given this state and action?

In [None]:
#| export
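def reroll_masked(values, mask):
    "Re-roll the dice whose mask bit is 1, keep the others, and return a sorted tuple."
    next_values = list(values)
    reroll = roll_n_dice(5)  # one fresh die per position; only masked positions are used
    for i, bit in enumerate(mask):
        if bit == 1:
            next_values[i] = reroll[i]
    return tuple(sorted(next_values))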

In [None]:
reroll_masked((1,2,3,4,5), (1,1,1,0,0))

(1, 3, 3, 4, 5)
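`reroll_masked` samples a single outcome, which is what a simulation needs. Value iteration instead needs the whole distribution over next dice for a given mask, so it can average over them. A sketch, assuming the same convention (bit 1 = re-roll) and a hypothetical helper name `reroll_distribution`, that enumerates all 6^k outcomes for the k re-rolled dice:

```python
import itertools
from collections import Counter
from fractions import Fraction

def reroll_distribution(values, mask):
    """Exact distribution of the next sorted dice tuple after re-rolling masked dice."""
    kept = [v for v, bit in zip(values, mask) if bit == 0]
    k = sum(mask)  # number of dice being re-rolled
    counts = Counter(
        tuple(sorted(kept + list(new)))
        for new in itertools.product(range(1, 7), repeat=k)
    )
    total = 6 ** k
    return {dice: Fraction(c, total) for dice, c in counts.items()}

dist = reroll_distribution((1, 2, 3, 4, 5), (1, 1, 1, 0, 0))
print(len(dist), sum(dist.values()))  # 56 1
```

Keeping (4, 5) and re-rolling three dice gives 56 distinct outcomes, one per multiset of three new faces.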

In [None]:
#| export
TERMINAL_STATE = None  # placeholder sentinel; the terminal-state design is still open

def transition_func(state: State, action: Action):
    match action.type:
        case "reroll":
            assert state.rolls_left > 0
            new_values = reroll_masked(state.dice_values, action.value)
            return State(new_values, state.rolls_left - 1, state.categories_picked)
        case "score":
            raise NotImplementedError("score transition not written yet")
        case "zero":
            assert action.value not in state.categories_picked
            new_categories = state.categories_picked | {action.value}
            # check the categories left *after* this pick
            categories_left = categories - new_categories
            if not categories_left:
                return TERMINAL_STATE
            new_values = tuple(sorted(roll_n_dice(5)))
            return State(new_values, 2, frozenset(new_categories))

In [None]:
print(transition_func(state, action))
print(transition_func(state, Action("zero", "ones")))

State(dice_values=(4, 5, 5, 6, 6), rolls_left=1, categories_picked=frozenset({'chance', 'yatzy', 'large_straight'}))
State(dice_values=(1, 1, 2, 6, 6), rolls_left=2, categories_picked=frozenset({'chance', 'yatzy', 'large_straight', 'ones'}))


Getting pretty far

##### 🤖Reply🤖

You're making excellent progress! You've built a lot of the core components:

✓ State and Action representations  
✓ Dice rolling mechanics  
✓ Category checking functions  
✓ Available actions generation  
✓ Transition function (partially complete)

I noticed your transition function still has the `"score"` case to implement. That one should be similar to `"zero"` but also needs to:
- Calculate the actual score using your `scorers` dict
- Return that score somehow (maybe as a reward?)

A couple of questions to help you move forward:

1. How do you want to handle the score/reward? Will your transition function return just the next state, or a tuple like `(next_state, reward)`?

2. Have you thought about what `TERMINAL_STATE` should be? (I see you referenced it in the zero case)

What would you like to tackle next?

## Scoring stuff WIP

In [None]:
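# Every scorer takes the sorted dice tuple, so all can be called as scorers[cat](values).
# They assume the category's checker has already passed.
def score_n(value, values):
    n = values.count(value)
    return value * n
assert score_n(1, (1,1,1,2,3)) == 3

def score_same(n, values):
    # n times the highest face that appears at least n times
    return n * max(n_same(n, values), default=0)
assert score_same(2, (1,1,3,3,5)) == 6

def score_yatzy(values): return values[0] * 5 + 50  # dice sum plus 50 bonus
def score_two_pairs(values):
    pair1, pair2 = n_same(2, values)
    return 2 * pair1 + 2 * pair2
def score_full_house(values): return sum(values)  # equals 2 * pair_value + 3 * three_value
def score_small_straight(values): return 15
def score_large_straight(values): return 20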


In [None]:
scorers = {
    "ones": partial(score_n, 1),
    "twos": partial(score_n, 2),
    "threes": partial(score_n, 3),
    "fours": partial(score_n, 4),
    "fives": partial(score_n, 5),
    "sixes": partial(score_n, 6),
    "one_pair": partial(score_same, 2),
    "two_pairs": score_two_pairs,
    "three_same": partial(score_same, 3),
    "four_same": partial(score_same, 4),
    "small_straight": score_small_straight,
    "large_straight": score_large_straight,
    "full_house": score_full_house,
    "chance": lambda x: sum(x),
    "yatzy": score_yatzy,
}
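Once every scorer takes the dice tuple, exact expectations for a single roll (no re-rolls) can be computed by enumeration, a small first taste of what value iteration will average over. The sketch re-declares `score_n` so it runs on its own; `expected_score` is a hypothetical helper.

```python
import itertools
from fractions import Fraction
from functools import partial

def score_n(value, values):
    return value * values.count(value)

def expected_score(scorer, n_dice=5):
    """Exact expected score of `scorer` over all equally likely ordered rolls."""
    rolls = list(itertools.product(range(1, 7), repeat=n_dice))
    return Fraction(sum(scorer(roll) for roll in rolls), len(rolls))

print(expected_score(partial(score_n, 1)))  # 5/6
print(expected_score(partial(score_n, 6)))  # 5
print(expected_score(sum))                  # 35/2 (chance)
```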