In [None]:
import random
import itertools
from collections import Counter
from functools import lru_cache
import multiprocessing

# -----------------------------
# Scorecard & Categories
# -----------------------------
class Scorecard:
    """One player's Yahtzee scorecard: 13 categories plus the upper-section bonus.

    ``scores`` maps each category name to ``None`` while the category is still
    open, and to the recorded integer score once it has been used.
    """

    UPPER_CATS = ['ones','twos','threes','fours','fives','sixes']
    LOWER_CATS = ['three_of_a_kind','four_of_a_kind','full_house',
                  'small_straight','large_straight','yahtzee','chance']
    BONUS_THRESHOLD = 63
    BONUS_SCORE = 35

    def __init__(self):
        # Every category starts open (None), in upper-then-lower order.
        self.scores = dict.fromkeys(self.UPPER_CATS + self.LOWER_CATS)

    def available(self):
        """Return the names of categories that have not been scored yet, in card order."""
        return [name for name in self.scores if self.scores[name] is None]

    def set_score(self, category, dice_values):
        """Score ``dice_values`` in ``category`` and record the result.

        Raises:
            ValueError: if ``category`` is unknown or has already been used.
        """
        is_open = category in self.scores and self.scores[category] is None
        if not is_open:
            raise ValueError("Invalid or already used category")
        self.scores[category] = Score.category_score(category, dice_values)

    def subtotal_upper(self):
        """Return the sum of the upper-section scores, counting open ones as 0."""
        running = 0
        for name in self.UPPER_CATS:
            running += self.scores[name] or 0
        return running

    def bonus(self):
        """Return BONUS_SCORE once the upper subtotal reaches BONUS_THRESHOLD, else 0."""
        if self.subtotal_upper() < self.BONUS_THRESHOLD:
            return 0
        return self.BONUS_SCORE

    def total(self):
        """Return all recorded scores plus any upper-section bonus."""
        recorded = sum(value or 0 for value in self.scores.values())
        return recorded + self.bonus()

# -----------------------------
# Static Scoring Functions
# -----------------------------
class Score:
    """Stateless Yahtzee scoring rules."""

    @staticmethod
    def category_score(cat, dice_values):
        """Return the score ``dice_values`` would earn in category ``cat``.

        Upper categories score the sum of the matching face. The lower
        section uses the usual fixed values: full house 25, small straight 30,
        large straight 40, Yahtzee 50. Three/four of a kind and chance score
        the sum of all dice.

        Raises:
            ValueError: if ``cat`` is not a known category name.
        """
        counts = Counter(dice_values)
        pip_total = sum(dice_values)

        # Upper section: the position in this tuple is (face value - 1).
        face_names = ('ones', 'twos', 'threes', 'fours', 'fives', 'sixes')
        if cat in face_names:
            face = face_names.index(cat) + 1
            return counts[face] * face

        def has_group(size):
            # True when some face appears at least ``size`` times.
            return any(n >= size for n in counts.values())

        def has_run(*runs):
            # True when every face of at least one run is present.
            return any(run.issubset(counts.keys()) for run in runs)

        if cat == 'three_of_a_kind':
            return pip_total if has_group(3) else 0
        if cat == 'four_of_a_kind':
            return pip_total if has_group(4) else 0
        if cat == 'full_house':
            return 25 if sorted(counts.values()) == [2, 3] else 0
        if cat == 'small_straight':
            return 30 if has_run({1, 2, 3, 4}, {2, 3, 4, 5}, {3, 4, 5, 6}) else 0
        if cat == 'large_straight':
            return 40 if has_run({1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}) else 0
        if cat == 'yahtzee':
            return 50 if 5 in counts.values() else 0
        if cat == 'chance':
            return pip_total
        raise ValueError(f"Unknown category: {cat}")

# -----------------------------
# Helper for Parallel Monte Carlo
# -----------------------------
def _simulate_once(args):
    """Play out one random completion of a hold and return its best score.

    ``args`` is a ``(hold_vals, rolls_left, available_cats)`` tuple, so the
    function can be passed straight to ``Pool.map``. If ``rolls_left`` is
    positive, the hand is topped up to five dice with uniform random faces.
    It is then scored against every category in ``available_cats``, and the
    maximum is returned.

    NOTE: the hand is topped up at most once. After that it holds five dice
    and nothing is left to reroll, so this is a one-reroll lookahead whatever
    the value of ``rolls_left``. With ``rolls_left <= 0`` the (possibly
    partial) held dice are scored as they are.
    """
    held, rolls_remaining, categories = args
    hand = list(held)
    if rolls_remaining > 0:
        hand.extend(random.randint(1, 6) for _ in range(5 - len(hand)))
    return max(Score.category_score(c, hand) for c in categories)

# -----------------------------
# Strategy Engine
# -----------------------------
class Strategy:
    """Suggest which dice to hold by estimating each hold's expected value (EV).

    A hold that rerolls at most ``exact_threshold`` dice is evaluated by
    enumerating every reroll outcome exactly. Larger rerolls are estimated
    from ``simulations`` Monte Carlo playouts, spread over a fork-based
    process pool when one is available.

    NOTE: the EV is a one-reroll lookahead. It is the expected best category
    score after rerolling the non-held dice once. Once that reroll is done the
    hand holds five dice, so neither the enumeration nor the sampler rerolls
    again, even when ``rolls_left`` is 2.
    """

    def __init__(self, simulations=500, exact_threshold=2, processes=None):
        self.simulations = simulations
        self.exact_threshold = exact_threshold
        self.processes = processes
        # Per-instance memo for enumerate_ev. An @lru_cache on the method
        # would key on ``self`` and keep every Strategy alive for the life of
        # the process (flake8-bugbear B019).
        self._ev_cache = {}

    @staticmethod
    def all_hold_indices():
        """Return every subset of die positions 0-4 as index tuples, smallest first.

        There are 32 subsets, from ``()`` (reroll everything) to
        ``(0, 1, 2, 3, 4)`` (keep the whole hand).
        """
        return [combo
                for size in range(6)
                for combo in itertools.combinations(range(5), size)]

    def enumerate_ev(self, hold_vals, rolls_left, available_cats):
        """Return the exact expected best score for keeping ``hold_vals``.

        Every outcome of rerolling the ``5 - len(hold_vals)`` free dice (6**k
        equally likely tuples) is enumerated, and the best score among
        ``available_cats`` is averaged over them. With ``rolls_left <= 0`` the
        held dice are scored as they are. Results are memoised per instance.

        ``hold_vals`` should be sorted so that equivalent holds share a cache
        entry. ``available_cats`` must be non-empty.
        """
        hold_vals = tuple(hold_vals)
        available_cats = tuple(available_cats)
        key = (hold_vals, rolls_left, available_cats)
        cached = self._ev_cache.get(key)
        if cached is not None:
            return cached

        # ``<= 0`` (not ``== 0``) so a negative value cannot recurse forever.
        if rolls_left <= 0:
            ev = max(Score.category_score(cat, hold_vals) for cat in available_cats)
        else:
            num_reroll = 5 - len(hold_vals)
            total_ev = 0
            for faces in itertools.product(range(1, 7), repeat=num_reroll):
                new_hand = tuple(sorted(hold_vals + faces))
                # new_hand already has five dice, so this recursive call
                # rerolls nothing (see the class NOTE).
                total_ev += self.enumerate_ev(new_hand, rolls_left - 1, available_cats)
            ev = total_ev / (6 ** num_reroll)

        self._ev_cache[key] = ev
        return ev

    def monte_carlo_ev(self, hold_vals, rolls_left, available_cats):
        """Estimate a hold's EV as the mean best score over random playouts.

        At least one playout is always run, which avoids division by zero
        when ``simulations <= 0``. A fork-based pool is used unless
        ``processes == 1`` or the platform has no 'fork' start method (e.g.
        Windows). In those cases the playouts run in this process.
        """
        job = (tuple(hold_vals), rolls_left, tuple(available_cats))
        jobs = [job] * max(1, self.simulations)

        ctx = None
        if self.processes != 1:
            try:
                ctx = multiprocessing.get_context('fork')
            except ValueError:
                ctx = None  # 'fork' unsupported here; fall back to serial

        if ctx is None:
            results = [_simulate_once(j) for j in jobs]
        else:
            with ctx.Pool(processes=self.processes) as pool:
                results = pool.map(_simulate_once, jobs)
        return sum(results) / len(results)

    def hold_ev(self, hold_vals, rolls_left, available_cats):
        """Return the EV of keeping ``hold_vals``, exact or sampled as configured.

        With no rolls left nothing is random, so the cheap exact path is used
        instead of launching simulations.
        """
        hold_vals = tuple(hold_vals)
        if rolls_left <= 0 or 5 - len(hold_vals) <= self.exact_threshold:
            return self.enumerate_ev(tuple(sorted(hold_vals)), rolls_left, tuple(available_cats))
        return self.monte_carlo_ev(hold_vals, rolls_left, available_cats)

    def best_hold(self, dice_vals, rolls_left, available_cats):
        """Return ``(indices_to_hold, ev)`` for the hold with the highest EV.

        With no rolls left the hand is final, so all five dice are held and
        the EV is the best immediate category score. Holds that keep the same
        multiset of values (e.g. either of two 3s) are equivalent. Only the
        first is evaluated, which avoids duplicate work and stops Monte Carlo
        noise from choosing between identical options.
        """
        cats = tuple(available_cats)
        if rolls_left <= 0:
            full = tuple(range(5))
            hand = tuple(dice_vals[i] for i in full)
            return list(full), self.hold_ev(hand, 0, cats)

        best_idx, best_ev = (), -1
        seen = set()
        for idxs in self.all_hold_indices():
            hold_vals = tuple(dice_vals[i] for i in idxs)
            signature = tuple(sorted(hold_vals))
            if signature in seen:
                continue
            seen.add(signature)
            ev = self.hold_ev(hold_vals, rolls_left, cats)
            if ev > best_ev:
                best_idx, best_ev = idxs, ev
        return list(best_idx), best_ev

# -----------------------------
# Game Flow with Auto-Detect & Custom Holds
# -----------------------------
class YahtzeeGame:
    """Interactive coach: the player reports physical dice and the engine advises holds."""

    def __init__(self, simulations=200, exact_threshold=2, processes=None):
        multiprocessing.freeze_support()
        self.scorecard = Scorecard()
        self.strategy = Strategy(simulations=simulations,
                                  exact_threshold=exact_threshold,
                                  processes=processes)

    @staticmethod
    def _read_dice(prompt, count, count_error):
        """Prompt until the user enters exactly ``count`` integers in 1-6, then return them."""
        while True:
            parts = input(prompt).strip().split()
            if len(parts) != count:
                print(count_error)
                continue
            try:
                vals = [int(p) for p in parts]
            except ValueError:
                print("Please enter integers only.")
                continue
            if all(1 <= v <= 6 for v in vals):
                return vals
            print("Values must be between 1 and 6.")

    def input_full_roll(self):
        """Read the five dice of a fresh roll."""
        return self._read_dice("Enter your 5 dice values (1-6) separated by spaces: ",
                               5, "Please enter exactly 5 values.")

    def input_reroll(self, count):
        """Read ``count`` rerolled dice. Return [] without prompting if ``count <= 0``."""
        if count <= 0:
            # Every die was held: nothing to enter. The old code prompted for
            # "0 values", and with a negative count it could never be satisfied.
            return []
        return self._read_dice(f"Enter your {count} rerolled dice values (1-6): ",
                               count, f"Enter exactly {count} values.")

    @staticmethod
    def _parse_hold_indices(raw):
        """Parse space-separated die positions.

        Returns the sorted, de-duplicated positions, or None if any token is
        not an integer in 0-4. De-duplicating matters: "1 1" previously
        produced a reroll count that did not match the free positions and
        crashed with StopIteration while rebuilding the hand.
        """
        try:
            idxs = {int(p) for p in raw.split()}
        except ValueError:
            return None
        if all(0 <= i < 5 for i in idxs):
            return sorted(idxs)
        return None

    def _prompt_hold_indices(self):
        """Ask for die positions to keep until a valid selection is entered."""
        while True:
            raw = input("Enter indices (0-4) of dice to hold, separated by spaces: ")
            idxs = self._parse_hold_indices(raw)
            if idxs is not None:
                return idxs
            print("Invalid indices. Try again.")

    @staticmethod
    def _merge_hand(dice_vals, hold_idxs, new_vals):
        """Rebuild the hand: held positions keep their die, the rest take ``new_vals`` in order."""
        fresh = iter(new_vals)
        held = set(hold_idxs)
        return [dice_vals[i] if i in held else next(fresh) for i in range(5)]

    def play_turn(self):
        """Run one turn: read the roll, coach up to two rerolls, then record a category."""
        print("\n--- New Turn ---")
        dice_vals = self.input_full_roll()
        rolls_left = 2
        available = self.scorecard.available()

        while True:
            print(f"\nCurrent roll: {dice_vals}")
            if rolls_left == 0:
                # The hand is final, so skip the (costly, meaningless) hold search.
                break

            suggested_idxs, ev = self.strategy.best_hold(dice_vals, rolls_left, available)
            suggested_vals = [dice_vals[i] for i in suggested_idxs]
            print(f"Suggested hold positions: {suggested_idxs} (values {suggested_vals}) (EV = {ev:.1f})")

            cmd = input("Enter 'R' to reroll suggested hold, 'C' for custom hold, or 'S' to score now: ").strip().upper()
            if cmd == 'S':
                break

            if cmd == 'C':
                hold_idxs = self._prompt_hold_indices()
            else:  # 'R' or anything else accepts the suggestion
                hold_idxs = suggested_idxs

            new_vals = self.input_reroll(5 - len(hold_idxs))
            dice_vals = self._merge_hand(dice_vals, hold_idxs, new_vals)
            rolls_left -= 1

        potential = {cat: Score.category_score(cat, dice_vals) for cat in available}
        print(f"\nAvailable categories: {available}")
        print(f"Potential scores: {potential}")
        while True:
            cat = input("Choose category to score: ").strip()
            if cat in available:
                break
            print("Invalid category. Choose from:", available)

        self.scorecard.set_score(cat, dice_vals)
        print(f"Scored '{cat}' = {self.scorecard.scores[cat]}")
        print(f"Total so far = {self.scorecard.total()}\n")

    def play_game(self):
        """Play turns until every category is filled, then print the final card."""
        while self.scorecard.available():
            self.play_turn()
        print("\n--- GAME OVER ---")
        print(f"Final scores: {self.scorecard.scores}")
        print(f"Upper subtotal: {self.scorecard.subtotal_upper()} + bonus {self.scorecard.bonus()} = {self.scorecard.subtotal_upper() + self.scorecard.bonus()}")
        print(f"Grand Total = {self.scorecard.total()}")

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Yahtzee optimal play coach")
    parser.add_argument(
        '--simulations', type=int, default=200,
        help='Number of Monte Carlo simulations',
    )
    parser.add_argument(
        '--exact', type=int, default=2,
        help='Max dice to reroll for exact enumeration',
    )
    parser.add_argument(
        '--processes', type=int, default=None,
        help='Number of parallel processes',
    )
    # parse_known_args tolerates unrecognised argv entries instead of exiting,
    # presumably so this cell also runs inside a notebook kernel that passes
    # its own flags.
    args, _ = parser.parse_known_args()

    game = YahtzeeGame(
        simulations=args.simulations,
        exact_threshold=args.exact,
        processes=args.processes,
    )
    game.play_game()



--- New Turn ---
Enter your 5 dice values (1-6) separated by spaces: 1 3 3 4 5

Current roll: [1, 3, 3, 4, 5]
Suggested hold positions: [1, 3, 4] (values [3, 4, 5]) (EV = 26.0)
Enter 'R' to reroll suggested hold, 'C' for custom hold, or 'S' to score now: C
Enter indices (0-4) of dice to hold, separated by spaces: 1 2
Enter your 3 rerolled dice values (1-6): 3 1 2

Current roll: [3, 3, 3, 1, 2]
Suggested hold positions: [2] (values [3]) (EV = 21.7)
