# MDP example: dieN problem

In [1]:
import mdptoolbox
import numpy as np
import pandas as pd

In [2]:
def goodSides(mask):
    """
    Collect the face values of the good sides described by a die mask.

    Args:
        mask (list): one flag per face, in face order; a flag of 0 marks a
            good side and 1 marks a bad side. Face values start at 1.

    Returns:
        list: face values (1-indexed, ascending) whose flag is 0
    """
    good = []
    for face, flag in enumerate(mask, start=1):
        if flag == 0:
            good.append(face)
    return good

In [3]:
def getStates(sides, max_rolls):
    """
    Enumerate every numeric total up to a cap, followed by the terminal states.

    The cap is ``max_rolls * max(sides)``, the largest total obtainable by
    rolling the biggest good side ``max_rolls`` times. Every total that can be
    written as a sum of good sides (any number of terms) and does not exceed
    the cap is included. This is therefore not literally the set of totals
    reachable in at most ``max_rolls`` rolls.

    Args:
        sides (list): values of good sides (assumed positive integers; need
            not be sorted)
        max_rolls (int): number of rolls of the largest side that defines
            the cap

    Returns:
        list: numeric states in ascending order (always containing 0 and every
        value in ``sides``), followed by 'E' for the end state and 'B' for the
        bankrupt state. With no good sides, only 0 is a numeric state.
    """
    if not sides:
        # Nothing can be rolled successfully: the start total is the only one.
        return [0, 'E', 'B']

    # max() rather than sides[-1] so unsorted input still gets the right cap.
    max_num = max_rolls * max(sides)

    # The start total and single-roll totals are always states, even when a
    # side exceeds the cap (e.g. max_rolls == 0).
    seen = {0, *sides}
    frontier = list(seen)
    # Breadth-first closure: only totals discovered in the previous wave need
    # expanding, and set membership keeps each check O(1).
    while frontier:
        next_frontier = []
        for total in frontier:
            for side in sides:
                candidate = total + side
                if candidate <= max_num and candidate not in seen:
                    seen.add(candidate)
                    next_frontier.append(candidate)
        frontier = next_frontier

    return sorted(seen) + ['E', 'B']

In [4]:
def getTransitions(states, sides, N, max_rolls):
    """
    Build the transition matrices for the 'roll' (0) and 'end' (1) actions.

    Roll: from a numeric state s with s <= max(sides) * (max_rolls - 1), each
    good side v leads to state s + v with probability 1/N, and a bad side leads
    to the bankrupt state 'B' with probability 1 - len(sides)/N. Rolling from
    any other state goes to 'B' with probability 1. This covers numeric totals
    above that threshold, plus 'E' and 'B'.

    End: every state except 'B' moves to 'E'; 'B' stays in 'B'.

    Args:
        states (list): all possible states as produced by getStates, i.e.
            ascending numeric states followed by 'E' and 'B'
        sides (list): good side values (rewards)
        N (int): total number of sides
        max_rolls (int): see doc for getStates

    Returns:
        np.array: float transition matrices of size (A, S, S), A is #actions
        (2), S is #states; every row is a probability distribution.
    """
    n_states = len(states)
    numeric = states[:-2]                          # drop the trailing 'E', 'B'
    position = {s: k for k, s in enumerate(numeric)}
    p_side = 1 / N                                 # probability of each individual face
    b_rate = 1 - len(sides) / N                    # probability of rolling a bad side
    # Totals above this were reached with too many rolls to roll again.
    threshold = (max(sides) if sides else 0) * (max_rolls - 1)

    roll = np.zeros((n_states, n_states))
    for i, total in enumerate(numeric):
        if total <= threshold:
            for side in sides:
                j = position.get(total + side)
                # Only totals present in `states` can be targets, as before.
                if j is not None:
                    roll[i, j] = p_side
            roll[i, -1] = b_rate
        else:
            # Every such row must still sum to 1. Previously rows 1..len(sides)
            # were left all-zero when max_rolls == 1.
            roll[i, -1] = 1.0
    roll[-2:, -1] = 1.0                            # rolling from 'E' or 'B' -> 'B'

    end = np.zeros((n_states, n_states))
    end[:-1, -2] = 1.0                             # ending anywhere but 'B' -> 'E'
    end[-1, -1] = 1.0                              # 'B' stays bankrupt
    return np.stack((roll, end), axis=0)

In [5]:
def getRewards(states):
    """
    Build the reward table for the two actions.

    Column 0 ('roll') is always 0. Column 1 ('end') pays the state's own value
    for numeric states and 0 for the terminal states 'B' and 'E'.

    Args:
        states (list): all possible states

    Returns:
        reward (np.array): reward matrices of size (S, A), actions ordered
        ("roll", "end")
    """
    table = pd.DataFrame(0, index=states, columns=["roll", "end"])
    for row, state in enumerate(states):
        if state in ('B', 'E'):
            continue
        table.iloc[row, 1] = state
    return table.to_numpy()

## Plug in mask and solve the MDP

In [6]:
# Die mask: one flag per face (face k is entry k-1). 0 marks a good side,
# 1 marks a bad side (rolling it leads to the bankrupt state 'B').
mask = [0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0]
# Face values of the good sides, in ascending order.
sides = goodSides(mask)
sides

[1, 2, 3, 4, 5, 7, 9, 10, 12, 13, 15, 16, 17, 18, 20, 21, 22, 23]

In [7]:
# Numeric states are enumerated up to (largest good side) * max_rolls. Totals
# above (largest good side) * (max_rolls - 1) are not allowed to roll again.
max_rolls = 3

In [8]:
# Ascending numeric totals followed by the terminal states 'E' (end) and 'B' (bankrupt).
states = getStates(sides, max_rolls)
states

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 'E',
 'B']

In [9]:
# Transition tensor of shape (2, S, S): action 0 = roll, action 1 = end.
# len(mask) is the total number of die faces N.
trans = getTransitions(states, sides, len(mask), max_rolls)
trans

array([[[0.        , 0.04347826, 0.04347826, ..., 0.        ,
         0.        , 0.2173913 ],
        [0.        , 0.        , 0.04347826, ..., 0.        ,
         0.        , 0.2173913 ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.2173913 ],
        ...,
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 1.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 1.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 1.        ]],

       [[0.        , 0.        , 0.        , ..., 0.        ,
         1.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         1.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         1.        , 0.        ],
        ...,
        [0.        , 0.        , 0.        , ..., 0.        ,
         1.        , 0.        ],
        [0. 

In [10]:
# Reward matrix of shape (S, 2): rolling pays 0, ending in a numeric state pays its value.
rewards = getRewards(states)
rewards

array([[ 0,  0],
       [ 0,  1],
       [ 0,  2],
       [ 0,  3],
       [ 0,  4],
       [ 0,  5],
       [ 0,  6],
       [ 0,  7],
       [ 0,  8],
       [ 0,  9],
       [ 0, 10],
       [ 0, 11],
       [ 0, 12],
       [ 0, 13],
       [ 0, 14],
       [ 0, 15],
       [ 0, 16],
       [ 0, 17],
       [ 0, 18],
       [ 0, 19],
       [ 0, 20],
       [ 0, 21],
       [ 0, 22],
       [ 0, 23],
       [ 0, 24],
       [ 0, 25],
       [ 0, 26],
       [ 0, 27],
       [ 0, 28],
       [ 0, 29],
       [ 0, 30],
       [ 0, 31],
       [ 0, 32],
       [ 0, 33],
       [ 0, 34],
       [ 0, 35],
       [ 0, 36],
       [ 0, 37],
       [ 0, 38],
       [ 0, 39],
       [ 0, 40],
       [ 0, 41],
       [ 0, 42],
       [ 0, 43],
       [ 0, 44],
       [ 0, 45],
       [ 0, 46],
       [ 0, 47],
       [ 0, 48],
       [ 0, 49],
       [ 0, 50],
       [ 0, 51],
       [ 0, 52],
       [ 0, 53],
       [ 0, 54],
       [ 0, 55],
       [ 0, 56],
       [ 0, 57],
       [ 0, 58

In [11]:
# Value iteration with discount 1 (undiscounted). Once the process reaches 'E'
# or 'B' it never returns to a numeric state and earns no further reward.
# NOTE(review): pymdptoolbox may print a convergence warning when discount == 1 — confirm this is expected.
mdp = mdptoolbox.mdp.ValueIteration(trans, rewards, 1)



In [12]:
# Solve the MDP, then report the value of state index 0 (the starting total 0,
# since getStates sorts the numeric states first).
mdp.run()
print("V(0) = " + str(mdp.V[0]))

V(0) = 18.788893349181098
