In [2]:
from dataclasses import dataclass
from typing import Tuple, Dict, Mapping, List
from rl.distribution import Categorical
from rl.markov_decision_process import FiniteMarkovDecisionProcess
from rl.policy import FinitePolicy
from itertools import product, chain, combinations
from rl.dynamic_programming import policy_iteration_result, value_iteration_result

def powerset(iterable):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(1,len(s)+1))

In [3]:
@dataclass(frozen=True)
class DiceState:
    handstate: Tuple[List[int]]
    tablestate: Tuple[List[int]]
    condition: int

    def get_ones(self) -> int:
        return list(self.handstate).count(1)

    def calc_score(self) -> int:
        if(self.get_ones() >= self.condition):
            return sum(list(self.handstate))
        else:
            return 0    


DiceActionMapping = Mapping[
    DiceState,
    Mapping[Tuple[List[int]], Categorical[Tuple[DiceState, float]]]
]
        



In [4]:

class DiceMDP(FiniteMarkovDecisionProcess[int,int]):
    def __init__(self, n, k, c):
        self.num_dice = n
        self.faces = k
        self.condition = c

        super().__init__(self.get_action_transition_reward_map())

    def get_action_transition_reward_map(self) -> DiceActionMapping:
        d: Dict[DiceState, Dict[Tuple[List[int]], Categorical[Tuple[DiceState, float]]]] = {}


        allcombos = list(product(range(1,self.faces + 1), repeat=self.num_dice))
        for handamt in range(self.num_dice):
            print(handamt)
            print(len(allcombos))
            for i in range(len(allcombos)):
                if i > 2000:
                    if i % 100 == 0:
                        print(i)
                combo = allcombos[i]
                hand = []
                if handamt != 0:
                    hand = list(combo)[:handamt]

                table = list(combo)[-(self.num_dice - handamt):]
                state: DiceState = DiceState(tuple(hand), tuple(table), self.condition)

                d1: Dict[Tuple[List[int]], Categorical[Tuple[DiceState, float]]] = {}

                allactions = list(powerset(table))

                handones = state.get_ones()

                handscore = state.calc_score()

                for act in allactions:

                    reward = 0

                    actones = list(act).count(1)

                    if(handones + actones >= self.condition):
                        reward = sum(hand) + sum(list(act)) - handscore
                        
                    new_hand = hand + list(act)
                    if len(new_hand) == self.num_dice:
                        new_tables = [()]

                        sr_probs_dict: Dict[Tuple[DiceState, float], float] =\
                            {(DiceState(tuple(new_hand), (), self.condition), reward):
                            1
                        }
                        d1[act] = Categorical(sr_probs_dict)
                    else:

                        new_tables = list(product(range(self.faces + 1), repeat = (self.num_dice - len(new_hand))))
                        prob = (self.num_dice - len(new_hand))/self.faces

                        sr_probs_dict: Dict[Tuple[DiceState, float], float] =\
                            {(DiceState(tuple(new_hand), new_table, self.condition), reward):
                            prob for new_table in new_tables
                            }
                        d1[act] = Categorical(sr_probs_dict)
                d[state] = d1
        print("done")
        return d






In [5]:
game: FiniteMarkovDecisionProcess[DiceState, int] =\
    DiceMDP(
        n = 6,
        k = 4,
        c = 1
    )

gamma = 1.0

0
4096
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
1
4096
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
2
4096
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
3
4096
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4
4096
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
5
4096
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
done


MemoryError: 

In [6]:
result = value_iteration_result(game, gamma)