In [None]:
####################################################
# Notebook for Kaggle runs (with 30hr of free GPU) #
####################################################

# All important files of the repository jammed into one big file


# EVALUATION FILES FIRST

# evaluation/eval_card.py

# From https://github.com/keithlee96/pluribus-poker-AI/blob/develop/poker_ai/poker/evaluation/eval_card.py
# Binary Representation of Cards

import numpy as np
from matplotlib import pyplot as plt
import os

class EvaluationCard:
    """
    Namespace of static helpers for the integer card encoding used by the
    hand evaluator. Cards are never instantiated as objects; each card is a
    single integer whose bit fields are laid out as follows:

                                    EvaluationCard:

                          bitrank     suit rank   prime
                    +--------+--------+--------+--------+
                    |xxxbbbbb|bbbbbbbb|cdhsrrrr|xxpppppp|
                    +--------+--------+--------+--------+

        1) p = prime number of rank (deuce=2,trey=3,four=5,...,ace=41)
        2) r = rank of card (deuce=0,trey=1,four=2,five=3,...,ace=12)
        3) cdhs = suit of card (bit turned on based on suit of card)
        4) b = bit turned on depending on rank of card
        5) x = unused

    The layout makes three operations cheap:
    - building a unique prime product for each hand
    - detecting flushes
    - detecting straights
    """

    # rank alphabet, rank indices and the prime assigned to each rank index
    STR_RANKS = "23456789TJQKA"
    INT_RANKS = range(13)
    PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41]

    # rank character => rank index
    CHAR_RANK_TO_INT_RANK = {char: idx for idx, char in enumerate(STR_RANKS)}

    # suit id (as used by the deck) => one-hot suit nibble
    INT_SUIT_TO_BINARY_SUIT = {
        1: 1,  # spades
        2: 2,  # hearts
        3: 4,  # diamonds
        4: 8,  # clubs
    }

    # indexed by the one-hot suit nibble (1, 2, 4, 8); "x" marks unused slots
    INT_SUIT_TO_CHAR_SUIT = "xshxdxxxc"

    # unicode suit glyphs keyed by the one-hot suit nibble
    PRETTY_SUITS = {
        1: "\u2660",  # spades
        2: "\u2665",  # hearts
        4: "\u2666",  # diamonds
        8: "\u2663",  # clubs
    }

    # one-hot suit nibbles that are printed in red (hearts and diamonds)
    PRETTY_REDS = [2, 4]

    @staticmethod
    def new(card: np.array):
        """
        Encode a ``[rank, suit]`` pair as a single integer card.

        ``card[0]`` is the rank index (used to index ``PRIMES``) and
        ``card[1]`` is a suit id that must be a key of
        ``INT_SUIT_TO_BINARY_SUIT``. Encoding inspired by:

        http://www.suffecool.net/poker/evaluator.html
        """
        rank_int, suit_id = card[0], card[1]
        suit_bits = EvaluationCard.INT_SUIT_TO_BINARY_SUIT[suit_id]
        prime = EvaluationCard.PRIMES[rank_int]

        return (1 << rank_int << 16) | (suit_bits << 12) | (rank_int << 8) | prime

    @staticmethod
    def int_to_str(card_int):
        """Return the two-character name (rank char + suit char) of a card."""
        rank_char = EvaluationCard.STR_RANKS[EvaluationCard.get_rank_int(card_int)]
        suit_char = EvaluationCard.INT_SUIT_TO_CHAR_SUIT[EvaluationCard.get_suit_int(card_int)]
        return rank_char + suit_char

    @staticmethod
    def get_rank_int(card_int):
        """Extract the 4-bit rank index field."""
        return (card_int >> 8) & 0xF

    @staticmethod
    def get_suit_int(card_int):
        """Extract the 4-bit one-hot suit field."""
        return (card_int >> 12) & 0xF

    @staticmethod
    def get_bitrank_int(card_int):
        """Extract the 13-bit one-hot rank field."""
        return (card_int >> 16) & 0x1FFF

    @staticmethod
    def get_prime(card_int):
        """Extract the 6-bit prime field."""
        return card_int & 0x3F

    @staticmethod
    def hand_to_binary(card_strs):
        """
        Encode every element of ``card_strs`` with ``new`` and return the
        resulting list (same length, same order).
        """
        return [EvaluationCard.new(entry) for entry in card_strs]

    @staticmethod
    def prime_product_from_hand(card_ints):
        """
        Multiply together the low byte (the prime field) of every card in
        ``card_ints``, which must be cards in integer form.
        """
        result = 1
        for encoded in card_ints:
            result *= encoded & 0xFF
        return result

    @staticmethod
    def prime_product_from_rankbits(rankbits):
        """
        Returns the prime product using the bitrank (b)
        bits of the hand. Each 1 in the sequence is converted
        to the correct prime and multiplied in.

        Params:
            rankbits = a single 32-bit (only 13-bits set) integer representing
                    the ranks of 5 _different_ ranked cards
                    (5 of 13 bits are set)

        Primarily used for evaluating flushes and straights,
        two occasions where we know the ranks are *ALL* different.

        Assumes that the input is in form (set bits):

                              rankbits
                        +--------+--------+
                        |xxxbbbbb|bbbbbbbb|
                        +--------+--------+

        """
        result = 1
        for idx, prime in enumerate(EvaluationCard.PRIMES):
            # multiply in the prime of every rank whose bit is set
            if (rankbits >> idx) & 1:
                result *= prime
        return result

    @staticmethod
    def int_to_binary(card_int):
        """
        For debugging purposes. Returns the binary number as a
        human readable string in tab-separated groups of four digits.
        """
        digits = bin(card_int)[2:][::-1]  # drop the "0b" prefix, least significant bit first
        slots = list("\t".join(["0000"] * 8))

        for pos, digit in enumerate(digits):
            # every group of four digits is followed by one separator slot
            slots[pos + pos // 4] = digit

        # most significant group goes first in the returned string
        return "".join(reversed(slots))

    @staticmethod
    def int_to_pretty_str(card_int):
        """
        Return a single card formatted as ``[<rank><suit glyph>]``. Red suits
        are coloured when the optional ``termcolor`` package is importable.
        """
        try:
            # for mac, linux: http://pypi.python.org/pypi/termcolor
            # can use for windows: http://pypi.python.org/pypi/colorama
            from termcolor import colored
        except ImportError:
            colored = None

        suit_int = EvaluationCard.get_suit_int(card_int)
        rank_int = EvaluationCard.get_rank_int(card_int)

        glyph = EvaluationCard.PRETTY_SUITS[suit_int]
        if colored is not None and suit_int in EvaluationCard.PRETTY_REDS:
            glyph = colored(glyph, "red")

        rank_char = EvaluationCard.STR_RANKS[rank_int]

        return f"[{rank_char}{glyph}]"

    @staticmethod
    def print_pretty_card(card_int):
        """
        Print the pretty form of a single integer card.
        """
        print(EvaluationCard.int_to_pretty_str(card_int))

    @staticmethod
    def print_pretty_cards(card_ints):
        """
        Print a list of integer cards on one line, comma separated.
        """
        rendered = [str(EvaluationCard.int_to_pretty_str(card_ints[i])) for i in range(len(card_ints))]

        line = " "
        if rendered:
            # the trailing space is only emitted when there is at least one card
            line += ",".join(rendered) + " "

        print(line)

In [None]:
# deck.py
class Deck:
    """
    A 52-card deck stored as a ``(52, 2)`` numpy array of ``[rank, suit]`` rows.

    Cards are never removed from the array; ``pos`` is a cursor marking the
    next undealt card. A freshly constructed deck is in sorted order and is
    only shuffled when ``reset`` is called.
    """

    _ranks = range(13)
    _suits = [1, 2, 3, 4]

    def __init__(self):
        self._deck = self._create_deck()
        self.pos = 0

    def _create_deck(self):
        """Build the ordered deck: every rank for suit 1, then suit 2, and so on."""
        ranks = np.tile(self._ranks, len(self._suits))
        suits = np.repeat(self._suits, len(self._ranks))

        return np.array([ranks, suits]).T

    def _shuffle(self):
        """Shuffle the rows of the deck in place using numpy's global RNG."""
        np.random.shuffle(self._deck)

    def deal(self, n=1):
        """Return the next ``n`` cards as a list of ``[rank, suit]`` lists and advance the cursor."""
        cards = self._deck[self.pos: self.pos + n].tolist()
        self.pos += n
        return cards

    def simulate_deal(self, n=3):
        """Return the next ``n`` cards like ``deal`` but leave the cursor untouched."""
        cards = self._deck[self.pos: self.pos + n].tolist()
        return cards

    def deal_one(self):
        """Return the next card as a numpy row (not a list) and advance the cursor."""
        card = self._deck[self.pos]
        self.pos += 1
        return card

    def burn(self):
        """Skip one card without returning it."""
        self.pos += 1

    def reset(self):
        """Rewind the cursor to the top of the deck and reshuffle."""
        self.pos = 0
        self._shuffle()

    def __len__(self):
        # total number of cards in the array, independent of how many were dealt
        return len(self._deck)

    def __str__(self):
        # One "<rank> of <suit>" line per card. The separator must be the
        # string that join() is called on; str.join accepts a single iterable.
        return '\n'.join(f"{rank} of {suit}" for rank, suit in self._deck)


In [None]:
# bots/action.py

from enum import IntEnum

class ActionType(IntEnum):
    """The three kinds of move an ``Action`` can carry, as integer-valued enum members."""
    FOLD = 0
    CALL = 1
    RAISE = 2

class Action:
    """A single move: an ``ActionType`` together with the bet amount attached to it."""

    def __init__(self, action_type: ActionType = ActionType.FOLD, bet: int = 0):
        self.bet = bet
        self.type = action_type

    def type_str(self):
        """Return the upper-case name of this action's type, or None if it matches no known type."""
        for candidate in (ActionType.FOLD, ActionType.CALL, ActionType.RAISE):
            if self.type == candidate:
                return candidate.name
        return None

    def __str__(self):
        label = self.type_str()
        return f"Action({label}, {self.bet})"

In [None]:
# evaluation/lookup.py

# Copied from https://github.com/keithlee96/pluribus-poker-AI/blob/develop/poker_ai/poker/evaluation/lookup.py
# Lookup Table for quick hand evaluation

import itertools

class LookupTable(object):
    """
    Number of Distinct Hand Values:

    Straight Flush   10
    Four of a Kind   156      [(13 choose 2) * (2 choose 1)]
    Full Houses      156      [(13 choose 2) * (2 choose 1)]
    Flush            1277     [(13 choose 5) - 10 straight flushes]
    Straight         10
    Three of a Kind  858      [(13 choose 3) * (3 choose 1)]
    Two Pair         858      [(13 choose 3) * (3 choose 2)]
    One Pair         2860     [(13 choose 4) * (4 choose 1)]
    High Card      + 1277     [(13 choose 5) - 10 straights]
    -------------------------
    TOTAL            7462

    Here we create a lookup table which maps:
        5 card hand's unique prime product => rank in range [1, 7462]

    Two dictionaries are built on construction:
        flush_lookup     prime product => rank, for hands where all suits match
        unsuited_lookup  prime product => rank, for every other hand

    Examples:
    * Royal flush (best hand possible)          => 1
    * 7-5-4-3-2 unsuited (worst hand possible)  => 7462
    """

    # worst (numerically largest) rank that each hand class can take
    MAX_STRAIGHT_FLUSH = 10
    MAX_FOUR_OF_A_KIND = 166
    MAX_FULL_HOUSE = 322
    MAX_FLUSH = 1599
    MAX_STRAIGHT = 1609
    MAX_THREE_OF_A_KIND = 2467
    MAX_TWO_PAIR = 3325
    MAX_PAIR = 6185
    MAX_HIGH_CARD = 7462

    MAX_TO_RANK_CLASS = {
        MAX_STRAIGHT_FLUSH: 1,
        MAX_FOUR_OF_A_KIND: 2,
        MAX_FULL_HOUSE: 3,
        MAX_FLUSH: 4,
        MAX_STRAIGHT: 5,
        MAX_THREE_OF_A_KIND: 6,
        MAX_TWO_PAIR: 7,
        MAX_PAIR: 8,
        MAX_HIGH_CARD: 9,
    }

    RANK_CLASS_TO_STRING = {
        1: "Straight Flush",
        2: "Four of a Kind",
        3: "Full House",
        4: "Flush",
        5: "Straight",
        6: "Three of a Kind",
        7: "Two Pair",
        8: "Pair",
        9: "High Card",
    }

    def __init__(self):
        """
        Calculates lookup tables
        """
        # create dictionaries
        self.flush_lookup = {}
        self.unsuited_lookup = {}

        # create the lookup table in piecewise fashion;
        # flushes() also fills in straights and high cards because
        # they reuse the same rank bit sequences
        self.flushes()
        self.multiples()

    def flushes(self):
        """
        Straight flushes and flushes.

        Lookup is done on 13 bit integer (2^13 > 7462):
        xxxbbbbb bbbbbbbb => integer hand index
        """

        # straight flushes in rank order
        straight_flushes = [
            7936,  # int('0b1111100000000', 2), # royal flush
            3968,  # int('0b111110000000', 2),
            1984,  # int('0b11111000000', 2),
            992,  # int('0b1111100000', 2),
            496,  # int('0b111110000', 2),
            248,  # int('0b11111000', 2),
            124,  # int('0b1111100', 2),
            62,  # int('0b111110', 2),
            31,  # int('0b11111', 2),
            4111,  # int('0b1000000001111', 2) # 5 high
        ]
        straight_flush_patterns = set(straight_flushes)

        # now we'll dynamically generate all the other
        # flushes (skipping the straight flush patterns)
        flushes = []
        gen = self.get_lexographically_next_bit_sequence(int("0b11111", 2))

        # 1277 = number of high cards
        # 1277 + len(straight_flushes) is the number of hands with all cards of
        # unique rank; the generator starts *after* 0b11111, hence the "- 1"
        for _ in range(1277 + len(straight_flushes) - 1):
            pattern = next(gen)
            if pattern not in straight_flush_patterns:
                flushes.append(pattern)

        # we started from the lowest straight pattern, now we want to start
        # ranking from the most powerful hands, so we reverse
        flushes.reverse()
        # now add to the lookup map:
        # start with straight flushes and the rank of 1
        # since it is the best hand in poker
        # rank 1 = Royal Flush!
        self._fill_in_lookup_table(
            rank_init=1,
            rankbits_list=straight_flushes,
            lookup_table=self.flush_lookup)
        # we start the counting for flushes on max full house, which
        # is the worst rank that a full house can have (2,2,2,3,3)
        self._fill_in_lookup_table(
            rank_init=LookupTable.MAX_FULL_HOUSE + 1,
            rankbits_list=flushes,
            lookup_table=self.flush_lookup)
        # we can reuse these bit sequences for straights
        # and high cards since they are inherently related
        # and differ only by context
        self.straight_and_highcards(straight_flushes, flushes)

    def _fill_in_lookup_table(self, rank_init, rankbits_list, lookup_table):
        """
        Assign consecutive ranks, starting at ``rank_init``, to the prime
        product of each rank bit pattern in ``rankbits_list``.
        """
        for rank, rankbits in enumerate(rankbits_list, start=rank_init):
            prime_product = EvaluationCard.prime_product_from_rankbits(rankbits)
            lookup_table[prime_product] = rank

    def straight_and_highcards(self, straights, highcards):
        """
        Unique five card sets. Straights and highcards.

        Reuses bit sequences from flush calculations.
        """
        self._fill_in_lookup_table(
            rank_init=LookupTable.MAX_FLUSH + 1,
            rankbits_list=straights,
            lookup_table=self.unsuited_lookup)
        self._fill_in_lookup_table(
            rank_init=LookupTable.MAX_PAIR + 1,
            rankbits_list=highcards,
            lookup_table=self.unsuited_lookup)

    def multiples(self):
        """
        Pair, Two Pair, Three of a Kind, Full House, and 4 of a Kind.

        Every section walks the ranks from ace down to deuce so that ranks
        are handed out from the strongest hand of the class to the weakest.
        """
        primes = EvaluationCard.PRIMES
        backwards_ranks = list(range(len(EvaluationCard.INT_RANKS) - 1, -1, -1))

        # 1) Four of a Kind
        rank = LookupTable.MAX_STRAIGHT_FLUSH + 1

        # for each choice of a set of four rank
        for quad in backwards_ranks:

            # and for each possible kicker rank
            for kicker in (r for r in backwards_ranks if r != quad):
                product = primes[quad] ** 4 * primes[kicker]
                self.unsuited_lookup[product] = rank
                rank += 1

        # 2) Full House
        rank = LookupTable.MAX_FOUR_OF_A_KIND + 1

        # for each three of a kind
        for trips in backwards_ranks:

            # and for each choice of pair rank
            for pair in (r for r in backwards_ranks if r != trips):
                product = primes[trips] ** 3 * primes[pair] ** 2
                self.unsuited_lookup[product] = rank
                rank += 1

        # 3) Three of a Kind
        rank = LookupTable.MAX_STRAIGHT + 1

        # pick three of one rank
        for trips in backwards_ranks:

            remaining = [r for r in backwards_ranks if r != trips]

            for c1, c2 in itertools.combinations(remaining, 2):
                product = primes[trips] ** 3 * primes[c1] * primes[c2]
                self.unsuited_lookup[product] = rank
                rank += 1

        # 4) Two Pair
        rank = LookupTable.MAX_THREE_OF_A_KIND + 1

        for pair1, pair2 in itertools.combinations(backwards_ranks, 2):

            for kicker in (r for r in backwards_ranks if r not in (pair1, pair2)):
                product = primes[pair1] ** 2 * primes[pair2] ** 2 * primes[kicker]
                self.unsuited_lookup[product] = rank
                rank += 1

        # 5) Pair
        rank = LookupTable.MAX_TWO_PAIR + 1

        # choose a pair
        for pairrank in backwards_ranks:

            remaining = [r for r in backwards_ranks if r != pairrank]

            for k1, k2, k3 in itertools.combinations(remaining, 3):
                product = primes[pairrank] ** 2 * primes[k1] * primes[k2] * primes[k3]
                self.unsuited_lookup[product] = rank
                rank += 1

    def write_table_to_disk(self, table, filepath):
        """
        Write ``table`` (prime product => rank) to ``filepath`` as one
        ``<prime_product>,<rank>`` line per entry.
        """
        with open(filepath, "w") as f:
            # dict.iteritems() only exists on Python 2; items() is the Python 3 spelling
            for prime_prod, rank in table.items():
                f.write(str(prime_prod) + "," + str(rank) + "\n")

    def get_lexographically_next_bit_sequence(self, bits):
        """
        Infinite generator of the bit patterns that follow ``bits`` when all
        integers with the same number of set bits are listed in increasing
        order. ``bits`` itself is not yielded.

        Bit hack from here:
        http://www-graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation

        Generator even does this in poker order rank
        so no need to sort when done! Perfect.
        """
        current = bits
        while True:
            t = (current | (current - 1)) + 1
            # pure integer arithmetic: the quotient is an exact power of two
            current = t | ((((t & -t) // (current & -current)) >> 1) - 1)
            yield current

In [None]:
# evaluation/card.py

import numpy as np

class Card:
    """
    A playing card held as a ``[rank, suit]`` pair.

    ``rank`` is an index 0-12 (Two .. Ace) and ``suit`` an id 1-4
    (Spades, Hearts, Diamonds, Clubs). A card can also be built from its
    single-integer encoding ``rank + 13 * (suit - 1)`` by passing
    ``from_encode=True``.
    """

    _INT_RANK_TO_STR = {
        0: "Two",
        1: "Three",
        2: "Four",
        3: "Five",
        4: "Six",
        5: "Seven",
        6: "Eight",
        7: "Nine",
        8: "Ten",
        9: "Jack",
        10: "Queen",
        11: "King",
        12: "Ace"
    }

    _INT_SUIT_TO_STR = {
        1: "Spades",
        2: "Hearts",
        3: "Diamonds",
        4: "Clubs"
    }

    def __init__(self, card, from_encode=False):
        """
        Args:
            card: a ``[rank, suit]`` sequence, or the integer encoding when
                ``from_encode`` is true.
            from_encode: interpret ``card`` as the output of ``encode``.
        """
        if from_encode:
            self._card = np.array([card % 13, card // 13 + 1])
        else:
            self._card = card

    def __str__(self):
        rank = self._INT_RANK_TO_STR[self._card[0]]
        suit = self._INT_SUIT_TO_STR[self._card[1]]
        return f"{rank} of {suit}"

    def __repr__(self):
        # same text as __str__, so cards inside containers print readably
        return self.__str__()

    def rank(self):
        """Return the rank index."""
        return self._card[0]

    def suit(self):
        """Return the suit name (not the suit id)."""
        return self._INT_SUIT_TO_STR[self._card[1]]

    def encode(self):
        """Return the single-integer encoding ``rank + 13 * (suit - 1)``."""
        return int(self._card[0] + 13 * (self._card[1] - 1))

    def _key(self):
        """Return ``(rank, suit)`` as plain ints, independent of the backing container."""
        return (int(self._card[0]), int(self._card[1]))

    def __eq__(self, other):
        # Comparing the raw containers breaks for numpy-backed cards: the
        # result is an element-wise array whose truth value is ambiguous.
        # Compare the normalised (rank, suit) pair and always return a bool.
        if not isinstance(other, Card):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # defined alongside __eq__ so cards can be used in sets and as dict keys
        return hash(self._key())

    def to_numpy(self):
        """Return the ``[rank, suit]`` pair as an integer numpy array."""
        return np.array(self._card).astype(int)

    def get_eval_card(self):
        """Return this card in the integer encoding produced by ``EvaluationCard.new``."""
        return EvaluationCard.new(self.to_numpy())

In [None]:
# evaluation/evaluate.py

# Copied from https://github.com/keithlee96/pluribus-poker-AI/blob/develop/poker_ai/poker/evaluation/evaluator.py
import itertools


class Evaluator(object):
    """
    Evaluates hand strengths using a variant of Cactus Kev's algorithm:
    http://suffe.cool/poker/evaluator.html

    I make considerable optimizations in terms of speed and memory usage,
    in fact the lookup table generation can be done in under a second and
    consequent evaluations are very fast. Won't beat C, but very fast as
    all calculations are done with bit arithmetic and table lookups.
    """

    def __init__(self):

        self.table = LookupTable()

        # dispatch on the total number of cards (hole cards + board)
        self.hand_size_map = {5: self._five, 6: self._six, 7: self._seven}

    def evaluate(self, cards, board):
        """
        This is the function that the user calls to get a hand rank.

        ``cards`` and ``board`` are sequences of cards in integer form whose
        combined length must be 5, 6 or 7 (a ``KeyError`` is raised
        otherwise). Returns the rank of the best five-card hand; lower is
        stronger.
        """
        # list() on both sides so that tuples / arrays are concatenated
        # rather than rejected or added element-wise
        all_cards = [int(c) for c in list(cards) + list(board)]
        return self.hand_size_map[len(all_cards)](all_cards)

    def _five(self, cards):
        """
        Performs an evalution given cards in integer form, mapping them to
        a rank in the range [1, 7462], with lower ranks being more powerful.

        Variant of Cactus Kev's 5 card evaluator, though I saved a lot of memory
        space using a hash table and condensing some of the calculations.
        """
        # if flush: some suit bit is set in all five cards
        if cards[0] & cards[1] & cards[2] & cards[3] & cards[4] & 0xF000:
            handOR = (cards[0] | cards[1] | cards[2] | cards[3] | cards[4]) >> 16
            prime = EvaluationCard.prime_product_from_rankbits(handOR)
            return self.table.flush_lookup[prime]

        # otherwise
        prime = EvaluationCard.prime_product_from_hand(cards)
        return self.table.unsuited_lookup[prime]

    def _best_five_card_rank(self, cards):
        """Return the lowest (strongest) rank over every 5-card subset of ``cards``."""
        return min(
            (self._five(combo) for combo in itertools.combinations(cards, 5)),
            default=LookupTable.MAX_HIGH_CARD,
        )

    def _six(self, cards):
        """
        Performs five_card_eval() on all (6 choose 5) = 6 subsets
        of 5 cards in the set of 6 to determine the best ranking,
        and returns this ranking.
        """
        return self._best_five_card_rank(cards)

    def _seven(self, cards):
        """
        Performs five_card_eval() on all (7 choose 5) = 21 subsets
        of 5 cards in the set of 7 to determine the best ranking,
        and returns this ranking.
        """
        return self._best_five_card_rank(cards)

    def get_rank_class(self, hr):
        """
        Returns the class of hand from the hand hand_rank from evaluate.

        Raises ``Exception`` when ``hr`` is negative or greater than
        ``LookupTable.MAX_HIGH_CARD``.
        """
        # a negative rank must be rejected, not fall through to the next class
        if hr >= 0:
            # walk the class ceilings from strongest to weakest
            for max_rank, rank_class in sorted(LookupTable.MAX_TO_RANK_CLASS.items()):
                if hr <= max_rank:
                    return rank_class
        raise Exception("Inavlid hand rank, cannot return rank class")

    def class_to_string(self, class_int):
        """
        Converts the integer class hand score into a human-readable string.
        """
        return LookupTable.RANK_CLASS_TO_STRING[class_int]

    def get_five_card_rank_percentage(self, hand_rank):
        """
        Scales the hand rank score to the [0.0, 1.0] range.
        """
        return float(hand_rank) / float(LookupTable.MAX_HIGH_CARD)

    def hand_summary(self, board, hands):
        """
        Prints a summary of the hand with ranks as time proceeds
        (flop, turn, river) and announces the leader or winner at each stage.

        Requires that the board is in chronological order for the
        analysis to make sense. Player numbers in the output are 1-based.
        """

        assert len(board) == 5, "Invalid board length"
        for hand in hands:
            assert len(hand) == 2, "Inavlid hand length"

        line = "=" * 10
        stages = ["FLOP", "TURN", "RIVER"]

        for i, stage in enumerate(stages):
            print(f"{line} {stage} {line}")

            best_rank = LookupTable.MAX_HIGH_CARD + 1  # rank one worse than worst hand
            winners = []
            for player, hand in enumerate(hands):

                # evaluate current board position (3, 4, then 5 board cards)
                rank = self.evaluate(hand, board[: (i + 3)])
                rank_class = self.get_rank_class(rank)
                class_string = self.class_to_string(rank_class)
                percentage = 1.0 - self.get_five_card_rank_percentage(
                    rank
                )  # higher better here
                print(
                    f"Player {player + 1} hand = {class_string}, percentage rank among all hands = {percentage}"
                )

                # detect winner
                if rank == best_rank:
                    winners.append(player)
                elif rank < best_rank:
                    winners = [player]
                    best_rank = rank

            # if we're not on the river
            if stage != "RIVER":
                if len(winners) == 1:
                    print(f"Player {winners[0] + 1} hand is currently winning.\n")
                else:
                    print(
                        f"Players {[x + 1 for x in winners]} are tied for the lead.\n"
                    )

            # on the river the hand is over
            else:
                hand_result = self.class_to_string(
                    self.get_rank_class(self.evaluate(hands[winners[0]], board))
                )
                print()
                print(f"{line} HAND OVER {line}")
                if len(winners) == 1:
                    print(
                        f"Player {winners[0] + 1} is the winner with a {hand_result}\n"
                    )
                else:
                    # report 1-based player numbers, like every other message
                    print(
                        f"Players {[x + 1 for x in winners]} tied for the win with a {hand_result}\n"
                    )

In [None]:
# state/state.py

import torch
import copy


# Initial money of every BotState; BotState.build_stack also divides by it to normalise.
START_MONEY = 10000

class Round(IntEnum):
    """Betting rounds as consecutive integers, in the order they are played."""
    PREFLOP = 0
    FLOP = 1
    TURN = 2
    RIVER = 3

class BotState:
    """
    Per-player bookkeeping: whether the player is still in the hand, the
    money they have left, their total bet and their current bet.
    """

    def __init__(self) -> None:
        self._money = START_MONEY
        self._current_bet = 0
        self._total_bet = 0
        self._play = True

    @property
    def round_money(self) -> int:
        """Money left plus the current bet."""
        return self._current_bet + self._money

    @property
    def play(self) -> bool:
        """Whether the player is still playing."""
        return self._play

    @play.setter
    def play(self, value: bool) -> None:
        self._play = value

    @property
    def money(self) -> int:
        """Money the player has left."""
        return self._money

    @money.setter
    def money(self, value: int) -> None:
        self._money = value

    @property
    def total_bet(self) -> int:
        """Running total of what the player has bet."""
        return self._total_bet

    @total_bet.setter
    def total_bet(self, value: int) -> None:
        self._total_bet = value

    @property
    def current_bet(self) -> int:
        """The player's current bet."""
        return self._current_bet

    @current_bet.setter
    def current_bet(self, value: int) -> None:
        self._current_bet = value

    def build_stack(self) -> torch.Tensor:
        """Return ``[money, total_bet, current_bet]`` as a tensor scaled by ``START_MONEY``."""
        raw = torch.Tensor([self.money, self.total_bet, self.current_bet])
        return raw / START_MONEY

    def __str__(self) -> str:
        parts = (
            f"play={self.play}",
            f"money={self.money}",
            f"total_bet={self.total_bet}",
            f"current_bet={self.current_bet}",
        )
        return "BotState(" + ", ".join(parts) + ")"

class MiniState:
    """
    Mini state of the game, covering a single betting round.
    Tracks:
        - table
        - history of bets
        - fold / call counters used to detect the end of the round
        - pot and top bet of the round
    """
    def __init__(self, table, players_left, n_players, max_round_size) -> None:
        self.table = table # cards on table

        self.action_history = [] # List of Actions

        self.n_players = n_players

        # We know the round is over when folded + called = players_left
        self.players_left = players_left
        self.folded = 0
        self.other = 0
        self.num_raises = 0

        self.pot = 0
        self.top_bet = 0
        self.max_round_size = max_round_size

    def encode_single_action(self, action) -> np.array:
        """
        Encodes one action as a length-1 array (formatted for BrownNet):
        the bet for a RAISE, 0 for a FOLD or CALL. Returns None for any
        other action type.
        """
        if action.type == ActionType.FOLD:
            return np.array([0])
        if action.type == ActionType.CALL:
            return np.array([0])
        if action.type == ActionType.RAISE:
            return np.array([action.bet])
        return None

    def encode_action(self) -> np.array:
        """
        Encodes the action history, zero-padded up to ``max_round_size``
        entries when the history is shorter than that.
        """
        encoded = [self.encode_single_action(action) for action in self.action_history]
        padding = [np.zeros(1) for _ in range(self.max_round_size - len(self.action_history))]
        return np.array(encoded + padding)

    def update(self, action: Action, bots, curr_player, global_pot, blind=False) -> None:
        """
        Update the state of the game with the given action.

        Moves money from ``bots[curr_player]`` into the round pot, rewrites
        ``action.bet`` in place as a fraction of ``global_pot``, appends the
        action to the history and, unless ``blind`` is true, updates the
        fold / call / raise counters.
        """

        if action.type == ActionType.CALL:
            # for a call, action.bet is the bet level to match, so only the
            # difference to what the player already put in is charged
            difference = action.bet - bots[curr_player].current_bet
            bots[curr_player].current_bet = action.bet
            bots[curr_player].total_bet += difference
            bots[curr_player].money -= difference
            action.bet = difference / global_pot
            self.pot += difference
        elif action.type == ActionType.RAISE:
            # for a raise, action.bet is the additional amount put in
            bots[curr_player].current_bet += action.bet
            bots[curr_player].total_bet += action.bet
            bots[curr_player].money -= action.bet
            self.top_bet = max(self.top_bet, bots[curr_player].current_bet)
            self.pot += action.bet
            action.bet /= global_pot

        self.action_history.append(action)

        # blinds do not count as part of the round
        if blind:
            return

        if action.type == ActionType.FOLD:
            bots[curr_player].play = False
            self.folded += 1
        elif action.type == ActionType.CALL:
            self.other += 1
        elif action.type == ActionType.RAISE:
            self.update_backlog()

    def update_backlog(self) -> None:
        """
        Called after a raise: folded players leave the round for good and
        everyone else has to act again (the raiser counts as having acted).
        """
        self.players_left -= self.folded
        self.folded = 0
        self.other = 1
        self.num_raises += 1

    def end_round(self) -> bool:
        """Called every time a player makes an action to check if the round is over"""
        return self.other + self.folded == self.players_left

    def __str__(self) -> str:
        # MiniState does not know which seat made each action, so entries
        # are labelled by their position in the history.
        lines = [
            f"MiniState: pot={self.pot}, top_bet={self.top_bet}",
            f"Table: {self.table}",
            "Action history: ",
        ]
        lines.extend(f"Action {i}: {action}" for i, action in enumerate(self.action_history))
        return "\n".join(lines) + "\n"

class State:
    """
    Whole-hand game state.

    Tracks the total pot, the per-player ``BotState`` objects, the community
    cards and one ``MiniState`` per betting round (the last entry of
    ``mini_states`` is the round currently being played).
    """
    def __init__(self, 
        n_players=6,
        num_rounds=2, # 2 for FHP
        max_round_size=7 # 7 for FHP
    ) -> None:
        """Create a fresh hand, pick a random starting seat and post the blinds."""
        self.pot = 0
        self.start_player = np.random.randint(0, n_players)
        self.curr_player = self.start_player
        self.num_rounds = num_rounds
        self.max_round_size = max_round_size

        self.total_players = n_players
        self.active = n_players

        self.bots = [BotState() for _ in range(n_players)]

        self.table = []
        self.mini_states = [MiniState(self.table, self.active, self.total_players, self.max_round_size)]
        self.round = Round.PREFLOP
        self.deck = Deck()
        self.deck.reset()

        # Small blind (50) then big blind (100), posted by consecutive seats.
        # Both are recorded as raises with ``blind=True`` and a fixed pot of 150.
        for blind_bet in (50, 100):
            self.mini_states[-1].update(Action(ActionType.RAISE, blind_bet), self.bots, self.curr_player, 150, blind=True)
            self.curr_player = (self.curr_player + 1) % self.total_players
            self.pot += blind_bet

        self.depth = 0

    def board_size(self):
        """Number of board slots reported by ``to_dict`` (``num_rounds + 1``)."""
        return self.num_rounds + 1

    def finish_round(self):
        """Close the current betting round.

        Per-round bets are cleared. Unless this was the last round, the next
        community cards are dealt, a new ``MiniState`` is appended and the
        action returns to ``start_player``.
        """
        self.reset_money()

        last_round = self.num_rounds - 1
        if self.round == last_round:
            return

        if self.round == Round.PREFLOP:
            # First board: three cards replace the empty table.
            self.table = self.deck.deal(3)
            self.round += 1
        elif self.round == Round.FLOP or self.round == Round.TURN:
            # Later streets add a single card each.
            self.table += self.deck.deal(1)
            self.round += 1

        next_mini = MiniState(self.table, self.active, self.total_players, self.max_round_size)
        self.mini_states.append(next_mini)

        self.curr_player = self.start_player

    def deal_player(self):
        """Deal one card from the deck."""
        return self.deck.deal_one()
    
    def get_top_bet(self):
        """Highest bet in the current betting round."""
        return self.mini_states[-1].top_bet

    def is_terminal(self) -> bool:
        """Return True when the hand is over.

        That is: the final round has finished, only one player is still
        playing, or more mini states exist than there are rounds.
        """
        if self.round == (self.num_rounds - 1) and self.end_round():
            return True
        if self.one_player_left():
            return True
        return len(self.mini_states) > self.num_rounds

    def one_player_left(self) -> bool:
        """True when exactly one player has ``play`` set."""
        still_playing = sum(bot.play for bot in self.bots)
        return still_playing == 1

    def end_round(self) -> bool:
        """Delegate to the current mini state's end-of-round check."""
        current = self.mini_states[-1]
        return current.end_round()

    def to_dict(self):
        """Encode the state as network input.

        Returns:
            dict with ``"h_action"`` (flattened action encodings of every
            round, zero-padded for rounds not yet played) and ``"cards"``
            (encoded board cards, padded with -1 up to ``board_size()``).
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        played = [torch.Tensor(mini.encode_action()).to(device) for mini in self.mini_states]
        rounds_missing = self.num_rounds - len(self.mini_states)
        padding = [torch.zeros(self.max_round_size, 1).to(device) for _ in range(rounds_missing)]
        history = torch.stack(played + padding).view(-1)

        # encode cards between 0-51, -1 for unknown
        known_cards = [Card(c).encode() for c in self.table]
        unknown_cards = [-1] * (self.board_size() - len(self.table))
        return {
            "h_action": history,
            "cards": known_cards + unknown_cards,
        }

    def reset_money(self):
        """Zero every player's bet for the current round."""
        for bot in self.bots:
            bot.current_bet = 0
    
    def next_player(self):
        """Advance ``curr_player`` to the next seat whose player is still in."""
        while True:
            self.curr_player = (self.curr_player + 1) % self.total_players
            if self.bots[self.curr_player].play:
                break
            
    def update(self, action: Action):
        """Apply ``action`` for the current player and advance the game."""
        if action.type == 0:
            # A fold shrinks the player count used for the next mini state.
            self.active -= 1

        # Swap the round's old pot contribution for its updated one.
        current = self.mini_states[-1]
        self.pot -= current.pot
        current.update(action, self.bots, self.curr_player, self.pot + current.pot)
        self.pot += current.pot
        self.curr_player = (self.curr_player + 1) % self.total_players

        if self.end_round():
            self.finish_round()

        self.depth += 1

    def round_to_str(self):
        """Name of the current round."""
        names = ("PREFLOP", "FLOP", "TURN", "RIVER")
        return names[int(self.round)]

    def print_history(self):
        """Print every betting round played so far."""
        print("State history:")
        for i, mini_state in enumerate(self.mini_states):
            print(f"Round {i}: {str(mini_state)}")
            
    def __str__(self) -> str:
        """Multi-line summary: pot, round, each player, table and top bet."""
        parts = [f"Sate: pot={self.pot}, round={self.round_to_str()}\n"]
        parts.extend(f"Player {i}: {str(bot)}\n" for i, bot in enumerate(self.bots))
        parts.append(f"Table: {self.table}\n")
        parts.append(f"Top bet: {self.get_top_bet()}\n")
        return "".join(parts)

In [None]:
# cfr/dataset.py
# cfr/wprollout.py
# cfr/brown_net.py
# cfr/deep_cfr.py
# + extras for experimentation

# Apply deep mccfr minimization
# Allow neural network to learn abstractions of the game through embeddings

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Any
import copy
import math
    
def aggregate_bets_fhp(state, action_dist):
    """
    Default bet aggregation for Flop Hold'em Poker.

    Once the current round has seen three or more raises, all probability
    mass from index 1 onwards is folded into index 1 and a new tensor is
    returned; otherwise ``action_dist`` is returned unchanged.
    """
    raises_so_far = state.mini_states[-1].num_raises
    if raises_so_far < 3:
        return action_dist

    capped = torch.zeros(action_dist.shape)#.to(device)
    capped[0] = action_dist[0]
    capped[1] = action_dist[1:].sum()
    return capped
        
class ValueDataset(torch.utils.data.Dataset):
    """In-memory buffer of ``(state, value, iteration)`` training samples.

    ``states``, ``values`` and ``T`` grow in lockstep through ``append``.
    ``hands`` is stored and saved/loaded but never appended to by this class.
    """

    def __init__(self, states=None, hands=None, values=None, T=None):
        # Use None sentinels: list defaults would be shared by every instance
        # created without arguments, silently mixing their samples.
        self.states = [] if states is None else states
        self.hands = [] if hands is None else hands
        self.values = [] if values is None else values
        self.T = [] if T is None else T

    def append(self, x, value, t):
        """Add one sample: encoded state ``x``, target ``value``, iteration ``t``."""
        self.states.append(x)
        self.values.append(value)
        self.T.append(t)
    
    def setup(self):
        """Stack the stored values into a single CPU tensor for batching."""
        self.values = [torch.Tensor(x).cpu() for x in self.values]
        self.values = torch.stack(self.values)

    def reset(self):
        """Turn the stacked values back into a Python list so ``append`` works.

        A no-op when ``values`` is already a list (``setup`` was not called).
        """
        if torch.is_tensor(self.values):
            self.values = self.values.cpu().tolist()

    def save(self, save_path="/kaggle/working/value_dataset.pt"):
        """Serialise all four fields to ``save_path`` with ``torch.save``."""
        data = {
            'states': self.states,
            'hands': self.hands,
            'values': self.values,
            'T': self.T
        }
        torch.save(data, save_path)

    def load(self, load_path="/kaggle/input/M_Vp/value_dataset.pt"):
        """Replace the current contents with the data stored at ``load_path``."""
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # dataset files you produced yourself.
        data = torch.load(load_path, weights_only=False)
        self.states = data['states']
        self.hands = data['hands']
        self.values = data['values']
        self.T = data['T']

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        return self.states[idx], self.values[idx], self.T[idx]

class PolicyDataset(torch.utils.data.Dataset):
    """In-memory buffer of ``(state, policy, iteration)`` training samples.

    ``states``, ``policies`` and ``T`` grow in lockstep through ``append``.
    ``hands`` is stored and saved/loaded but never appended to by this class.
    """

    def __init__(self, states=None, hands=None, policies=None, T=None):
        # Use None sentinels: list defaults would be shared by every instance
        # created without arguments, silently mixing their samples.
        self.states = [] if states is None else states
        self.hands = [] if hands is None else hands
        self.policies = [] if policies is None else policies
        self.T = [] if T is None else T

    def append(self, x, policy, t):
        """Add one sample: encoded state ``x``, target ``policy``, iteration ``t``."""
        self.states.append(x)
        self.policies.append(policy)
        self.T.append(t)
    
    def setup(self):
        """Stack the stored policies into a single CPU tensor for batching."""
        self.policies = [torch.Tensor(x).cpu() for x in self.policies]
        self.policies = torch.stack(self.policies)

    def reset(self):
        """Turn the stacked policies back into a Python list so ``append`` works.

        A no-op when ``policies`` is already a list (``setup`` was not called).
        """
        if torch.is_tensor(self.policies):
            self.policies = self.policies.cpu().tolist()

    def save(self, save_path="/kaggle/working/policy_dataset.pt"):
        """Serialise all four fields to ``save_path`` with ``torch.save``."""
        data = {
            'states': self.states,
            'hands': self.hands,
            'policies': self.policies,
            'T': self.T
        }
        torch.save(data, save_path)

    def load(self, load_path="/kaggle/input/M_Vp/policy_dataset.pt"):
        """Replace the current contents with the data stored at ``load_path``."""
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # dataset files you produced yourself.
        data = torch.load(load_path, weights_only=False)
        self.states = data['states']
        self.hands = data['hands']
        self.policies = data['policies']
        self.T = data['T']

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        return self.states[idx], self.policies[idx], self.T[idx]

class WpRollout:
    """Win/tie-rate lookup for one fixed hand against every opposing hand.

    Preflop rates come from the tables passed to the constructor; flop rates
    are computed on first use and cached until ``fix`` is called again.
    Opposing two-card hands are addressed through ``encode_hand``.
    """

    def __init__(self, preflop_win_rates, preflop_tie_rates):
        self.preflop_win_rates = torch.Tensor(preflop_win_rates)
        self.preflop_tie_rates = torch.Tensor(preflop_tie_rates)
        self.flop_win_rates = self.flop_tie_rates = None
        self.player = self.index = None

    def fix(self, p_cards, i):
        """Pin the rollout to player ``i`` holding ``p_cards`` and drop the flop cache."""
        encoded = [c.encode() for c in p_cards]
        self.player = i
        self.index = self.encode_hand(encoded[:2])
        self.p_cards = encoded
        self.eval_cards = [c.get_eval_card() for c in p_cards]
        self.flop_win_rates = self.flop_tie_rates = None

    def win_rates(self, board, remaining_cards):
        """Compare the fixed hand with every two-card hand from ``remaining_cards``.

        Returns:
            ``(win_rates, tie_rates)``: two numpy arrays of length ``26 * 51``
            indexed by ``encode_hand``. An entry is 1.0 where the opposing
            hand evaluates to a larger (win) or equal (tie) value than the
            fixed hand, else 0.
        """
        evaluator = Evaluator()
        eval_board = [Card(c).get_eval_card() for c in board]
        hand_val = evaluator.evaluate(self.eval_cards, eval_board)

        opp_hands = list(itertools.combinations(remaining_cards, 2))
        opp_vals = np.array([
            evaluator.evaluate([Card(c, from_encode=True).get_eval_card() for c in opp], eval_board)
            for opp in opp_hands
        ])

        slots = [self.encode_hand(opp) for opp in opp_hands]

        win_rates = np.zeros(26 * 51)
        tie_rates = np.zeros(26 * 51)
        win_rates[slots] = (opp_vals > hand_val).astype(float)
        tie_rates[slots] = (opp_vals == hand_val).astype(float)

        return win_rates, tie_rates
            
    def compute_flop(self, state: State, player_idx: int):
        """Fill the flop win/tie caches, simulating a 3-card board if none is dealt."""
        if len(state.table) == 0:
            board = state.deck.simulate_deal(3)
        else:
            board = state.table
        encoded_board = [Card(c).encode() for c in board]

        taken = self.p_cards + encoded_board
        remaining_cards = [c for c in range(52) if c not in taken]
        wins, ties = self.win_rates(board, remaining_cards)
        self.flop_win_rates = torch.Tensor(wins)
        self.flop_tie_rates = torch.Tensor(ties)

    def encode_hand(self, hand):
        """Map a two-card hand (encoded card numbers) to its table index."""
        first, second = hand[0], hand[1]
        cards_from_first = 52 - first
        return cards_from_first * (cards_from_first - 1) // 2 - second + first

    def get_wins(self, state, pi):
        """Return ``pi`` weighted element-wise by the win rates for the current round."""
        assert self.index is not None, "Fix index of WpRollout before getting win rates"
        self.preflop_win_rates = self.preflop_win_rates.to(pi.device)
        if state.round == Round.PREFLOP:
            return pi * self.preflop_win_rates[self.index]

        if self.flop_win_rates is None:
            self.compute_flop(state, self.player)
        return pi * self.flop_win_rates.to(pi.device)
    
    def get_ties(self, state, pi):
        """Return ``pi`` weighted element-wise by the tie rates for the current round."""
        assert self.index is not None, "Fix index of WpRollout before getting tie rates"
        self.preflop_tie_rates = self.preflop_tie_rates.to(pi.device)
        if state.round == Round.PREFLOP:
            return pi * self.preflop_tie_rates[self.index]

        if self.flop_tie_rates is None:
            self.compute_flop(state, self.player)
        return pi * self.flop_tie_rates.to(pi.device)


class FC(nn.Module):
    """A single linear layer followed by a ReLU activation."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.fc = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        """Return ``relu(fc(x))``."""
        projected = self.fc(x)
        return self.relu(projected)

class CardEmbedding(nn.Module):
    """Embed a group of cards as the sum of card, rank and suit embeddings.

    Card indices are 0-51; rank is ``index // 4`` and suit is ``index % 4``.
    Negative indices mark missing cards and contribute a zero vector.
    """

    def __init__(self, dim):
        super().__init__()
        self.rank = nn.Embedding(13, dim)
        self.suit = nn.Embedding(4, dim)
        self.card = nn.Embedding(52, dim)

    def forward(self, input):
        """Return one summed embedding per row of ``input``."""
        cards = input.squeeze() if input.dim() > 2 else input
        batch, cards_per_row = cards.shape

        flat = cards.view(-1)
        present = flat.ge(0).float()
        # Clamp so missing cards (-1) index a valid row; they are masked below.
        idx = flat.clamp(min=0)

        summed = self.card(idx) + self.rank(idx // 4) + self.suit(idx % 4)
        masked = summed * present.unsqueeze(1)  # zero out the "no card" rows

        return masked.view(batch, cards_per_row, -1).sum(1)

class BrownNet(nn.Module):
    """Card-and-bet network that outputs one value per action.

    Card groups are embedded separately and passed through a 3-layer MLP,
    the betting history goes through a 2-layer branch, and the two are
    merged by three further layers, a layer norm and a linear action head.
    The action head is zero-initialised, so a fresh network outputs zeros.
    """

    def __init__(self, n_card_types, n_bets, n_actions, dim=64):
        super(BrownNet, self).__init__()
        self.card_embeddings = nn.ModuleList(
            [CardEmbedding(dim) for _ in range(n_card_types)]
        )
        
        self.card1 = nn.Linear(dim * n_card_types, dim)
        self.card2 = nn.Linear(dim, dim)
        self.card3 = nn.Linear(dim, dim)
        
        # The bet branch expects ``n_bets * 2`` input features.
        self.bet1 = nn.Linear(n_bets * 2, dim)
        self.bet2 = nn.Linear(dim, dim)
        
        self.comb1 = nn.Linear(2 * dim, dim)
        self.comb2 = nn.Linear(dim, dim)
        self.comb3 = nn.Linear(dim, dim)

        self.layer_norm = nn.LayerNorm((dim))
        self.action_head = nn.Linear(dim, n_actions)

        nn.init.constant_(self.action_head.weight, 0)
        nn.init.constant_(self.action_head.bias, 0)

    def forward(self, x):
        """
        Args:
            x: dict with
                "cards": sequence of card-index tensors, one per card group,
                    consumed in order by ``self.card_embeddings``;
                "h_action": betting features, ``n_bets * 2`` values per
                    sample; a 1-D tensor is treated as a batch of one.

        Returns:
            Tensor with one row of ``n_actions`` values per sample.
        """
        cards = x["cards"]
        # Follow the module's own device instead of hard-coding "cuda", so
        # the network also runs on CPU-only hosts.
        bets = x["h_action"].to(self.bet1.weight.device)
        if bets.dim() < 2: bets = bets.unsqueeze(0)

        card_embs = []
        for embedding, card_group in zip(self.card_embeddings, cards):
            card_embs.append(embedding(card_group))
        card_embs = torch.cat(card_embs, dim=-1)

        x = F.relu(self.card1(card_embs))
        x = F.relu(self.card2(x))
        x = F.relu(self.card3(x))

        # The raw bet features feed the bet branch directly. The original
        # also built a size/occurred feature tensor that was never used;
        # it has been removed.
        y = F.relu(self.bet1(bets))
        y = F.relu(self.bet2(y) + y)

        if y.dim() > 2: y = y.squeeze()

        z = torch.cat([x, y], dim=-1)
        z = F.relu(self.comb1(z))
        z = F.relu(self.comb2(z) + z)
        z = F.relu(self.comb3(z) + z)

        z = self.layer_norm(z)
        return self.action_head(z)

class MCCFR():
    """Deep Monte Carlo CFR trainer and evaluator built around ``BrownNet``.

    Owns the value network and its optimizer, the hole cards dealt for the
    current hand, the win-probability rollout tables and the helpers that
    traverse the game tree, fit the networks and estimate exploitability.

    NOTE: ``self.policy_net`` / ``self.policy_optimizer`` are only created by
    ``optimize_policy``; ``save_policy_net`` and ``load_policy_net`` require
    that it has been called first.
    """

    def __init__(self, 
        n_players: int = 2, # 6 for standard poker
        start_money: int = 10000, 
        load_ckpt = False, 
        aggregation_func: Any = aggregate_bets_fhp,
        raise_map: np.array = np.array([100]), # action map for raise amounts
        exact_map: bool = True, # using exact map or fraction of pot
        wp_path: str = "/kaggle/input/flop-rollouts"
    ):
        """
        Args:
            n_players: Number of players dealt in by ``begin_round``.
            start_money: Accepted for compatibility; not used in this class.
            load_ckpt: Stored on the instance as ``self.load_ckpt``.
            aggregation_func: Callable ``(state, policy) -> policy`` applied
                before sampling or expanding actions.
            raise_map: Raise sizes; action index ``k >= 2`` maps to
                ``raise_map[k - 2]``.
            exact_map: When True ``raise_map`` holds chip amounts, otherwise
                fractions of the pot.
            wp_path: Directory containing ``win_rates.npy`` and ``tie_rates.npy``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bot_hands = []

        self.value_net = BrownNet(n_card_types=2, n_bets=7, n_actions=3, dim=64)
        self.value_net.to(self.device)

        # self.policy_net = BrownNet(n_card_types=2, n_bets=7, n_actions=3, dim=64)
        # self.policy_net.to(self.device)
        
        count = sum(param.numel() for param in self.value_net.parameters())
        print(f"Network with {count} parameters")

        self.optimizer = torch.optim.Adam(self.value_net.parameters(), lr=1e-3)
        # self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=1e-3)
        
        self.bet_money = 0
        self.num_players = n_players
        
        self.decisions = 0
        self.questionable = 0

        self.load_ckpt = load_ckpt

        win_rates = np.load(os.path.join(wp_path, "win_rates.npy"))
        tie_rates = np.load(os.path.join(wp_path, "tie_rates.npy"))

        self.wprollout = WpRollout(win_rates, tie_rates)

        # Function for aggregating bets using round-based constraints
        self.aggregation_func = aggregation_func
        self.raise_map = raise_map
        self.exact_map = exact_map


    def get_eval_cards(self, idx):
        """Return player ``idx``'s two hole cards in evaluator format."""
        card_a = Card(self.bot_hands[idx][0]._card)
        card_b = Card(self.bot_hands[idx][1]._card)
        return [card_a.get_eval_card(), card_b.get_eval_card()]

    def begin_round(self, state):
        """Deals cards to players and initializes round"""
        
        self.bot_hands.clear()
        # Deal one card to every player, then a second card to every player.
        for i in range(self.num_players):
            self.bot_hands.append([Card(state.deal_player())])
        
        for i in range(self.num_players):
            self.bot_hands[i].append(Card(state.deal_player()))
        
        self.decisions = 0
        self.questionable = 0

    def save_checkpoint(self, sim):
        """Save the value network, its optimizer and the iteration ``sim``."""
        ckpt = {
            'sim': sim,
            'state': self.value_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(ckpt, '/kaggle/working/value_net.pth')

    def save_policy_net(self, sim):
        """Save the policy network, its optimizer and the iteration ``sim``."""
        ckpt = {
            'sim': sim,
            'state': self.policy_net.state_dict(),
            'optimizer': self.policy_optimizer.state_dict(),
        }
        torch.save(ckpt, '/kaggle/working/policy_net.pth')

    def choice_to_action(self, state, index):
        """Translate an action index into an ``Action``.

        Index 0 is a fold, 1 is a call to the current top bet, and ``k >= 2``
        is a raise sized by ``raise_map[k - 2]`` (chips when ``exact_map``,
        otherwise that fraction of the pot, rounded down).
        """
        if index == 0: return Action(0, 0)
        elif index == 1: return Action(1, state.mini_states[-1].top_bet)
        else: 
            if self.exact_map:
                raise_amount = self.raise_map[index - 2]
            else:
                raise_amount = self.raise_map[index - 2] * state.pot
                raise_amount = math.floor(raise_amount)
            return Action(2, raise_amount)

    def target_policy(self, regret):
        """
        Compute target policy from regret (regret matching).

        Each action gets probability proportional to its positive regret.
        When no action has positive regret the policy is uniform.
        """
        pos_regret = F.relu(regret)
        sum_regret = pos_regret.sum()

        # The sum of clamped regrets can never be negative, so the only
        # special case is "no positive regret at all".
        if sum_regret == 0:
            return torch.ones_like(regret) / regret.numel()

        return (pos_regret / sum_regret).to(self.device)
    
    def winner(self, state: State):
        """
        Given a terminal state, return the index of the winning bot

        With several players left, the hand with the lowest evaluator score
        wins; on equal scores the first such player is returned.
        """
        self.r_bots = [i for i, bot in enumerate(state.bots) if bot.play]

        if len(self.r_bots) == 1: return self.r_bots[0]

        eval_cards = [self.get_eval_cards(i) for i in self.r_bots]
        eval_table = [Card(c).get_eval_card() for c in state.table]
        
        evaluator = Evaluator()
        evals = [evaluator.evaluate(eval_pair, eval_table) for eval_pair in eval_cards]

        win_index = evals.index(min(evals))
        return self.r_bots[win_index] # Return index of winning bot

    def utility(self, state, bot_idx, win_index):
        """Net chips for ``bot_idx``: pot minus own bets if it won, else minus its bets."""
        player = state.bots[bot_idx]
        if player.play and bot_idx == win_index:
            return state.pot - player.total_bet
        return -player.total_bet

    @torch.no_grad()
    def traverse(self, state: State, player_idx: int, M_Vp: ValueDataset, t: int):
        """Rough Training algorithm from Deep MCCFR paper: https://arxiv.org/pdf/1811.00164
        if terminal(s) then return u(s)
        if s is chance node then 
            sample a successor s' of s
            return traverse(s', player)
        if s is a player node then
            compute strategy pi(s) for player i
            for each action a in pi(s) do
                s' = apply(a, s)
                u = traverse(s', player)
            regret = u - pi(s) * u
            M_Vp.append(s, regret, t) # train value net on this
            return pi(s) * u

        Args:
            state: Game state; mutated in place on opponent turns.
            player_idx: The traversing player whose regrets are collected.
            M_Vp: Dataset receiving ``(encoded state, regret, t)`` samples.
            t: Iteration number stored with each sample.

        Returns:
            Utility of the (sub)game for ``player_idx``.
        """
        if not state.bots[state.curr_player].play: # skip folded players
            state.next_player()
            return self.traverse(state, player_idx, M_Vp, t)

        if state.is_terminal(): 
            return self.utility(state, player_idx, self.winner(state))
        elif not state.bots[player_idx].play: # utility of folded player is -total_bet
            return -state.bots[player_idx].total_bet
        elif state.curr_player != player_idx: # opponent's turn
            # With external sampling, we just sample an action using our trained value network 
            # NOTE: the opponent's hand index assumes a 2-player game.
            x = state.to_dict()
            x["cards"] = [torch.IntTensor(x["cards"]).unsqueeze(0).to(self.device), torch.IntTensor([card.encode() for card in self.bot_hands[(player_idx + 1) % 2]]).unsqueeze(0).to(self.device)]
            values = self.value_net(x).squeeze()
            
            policy = self.target_policy(values)
            policy = policy.squeeze()
            policy = policy / policy.sum(dim=-1)

            agg_policy = self.aggregation_func(state, policy).to(self.device)
            
            action = torch.distributions.Categorical(agg_policy).sample()
            action = self.choice_to_action(state, action)
            state.update(action)

            return self.traverse(state, player_idx, M_Vp, t)
        else: # decision point
            self.decisions += 1
            x = state.to_dict()
            x["cards"] = [torch.IntTensor(x["cards"]).unsqueeze(0).to(self.device), torch.IntTensor([card.encode() for card in self.bot_hands[player_idx]]).unsqueeze(0).to(self.device)]
            values = self.value_net(x)
            values = values.squeeze()

            policy = self.target_policy(values) # calculate policy from values

            agg_policy = self.aggregation_func(state, policy).to(self.device)

            u = torch.zeros(policy.shape).to(self.device)
            for a in range(agg_policy.shape[0]):
                # Actions with (near) zero probability are not expanded.
                if agg_policy[a] < 0.001:
                    continue
                if a == 0:
                    # Folding forfeits what has been bet; no subtree to expand.
                    u[a] = -state.bots[player_idx].total_bet
                    continue
                # Only copy the state for actions that are actually played out.
                s_prime = copy.deepcopy(state)
                s_prime.update(self.choice_to_action(state, a))
                u[a] = self.traverse(s_prime, player_idx, M_Vp, t)

            # regret = E[u | a] - E[u | pi]
            # The expectation uses the aggregated policy, i.e. the one whose
            # actions were expanded above, both for the regret baseline and
            # for the value handed back to the caller.
            expected_u = (u * agg_policy).sum()
            regret = u - expected_u
            M_Vp.append(x, regret, t)

            return expected_u
        
    def optimize(self, M_Vp: ValueDataset, T: int, steps: int = 4000, batch_size: int = 10000):
        """Retrain the value network from scratch on ``M_Vp``.

        A fresh ``BrownNet`` and Adam optimizer replace the current ones.
        The squared error of each sample is weighted by ``2 * t / T``.

        Returns:
            List with the loss of every optimisation step (``steps`` entries).
        """
        self.value_net = BrownNet(n_card_types=2, n_bets=7, n_actions=3, dim=64).to(self.device)
        M_Vp.setup()
        self.value_net.train()
        step = 0
        losses = []

        loader = torch.utils.data.DataLoader(M_Vp, batch_size=batch_size, shuffle=True)

        # Reinitialize optimizer
        self.optimizer = torch.optim.Adam(self.value_net.parameters(), lr=1e-3)
        while True:
            for x, value, t in loader:
                t = t.to(self.device)
                t = t.unsqueeze(1).repeat(1, 3) # one iteration weight per action output
                self.optimizer.zero_grad()
                values = self.value_net(x).squeeze()
                loss = torch.mean(t * 2 / T * ((values.squeeze() - value.to(values.device))**2))
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 1.0)
                self.optimizer.step()
                step += 1
                # x["h_money"] = [h.cpu() for h in x["h_money"]]
                x["h_action"] = [h.cpu() for h in x["h_action"]]
                value = value.cpu()
                t = t.cpu()
                if step >= steps: 
                    M_Vp.reset()
                    return losses

    def optimize_policy(self, M_Pi: PolicyDataset, steps: int, T: int):
        """Train a fresh policy network on ``M_Pi``.

        Mirrors ``optimize``: squared error weighted by ``2 * t / T``.

        Returns:
            List with the loss of every optimisation step (``steps`` entries).
        """
        self.policy_net = BrownNet(n_card_types=2, n_bets=7, n_actions=3, dim=64).to(self.device)
        M_Pi.setup()
        self.policy_net.train()
        step = 0
        losses = []

        loader = torch.utils.data.DataLoader(M_Pi, batch_size=10000, shuffle=True) # TODO: return to 10000
        
        # The optimizer must own the policy network's parameters; built over
        # the value network (as before) the policy net was never updated.
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=1e-3)
        while True:
            for x, target, t in loader:
                t = t.to(self.device)
                t = t.unsqueeze(1).repeat(1, 3) # one iteration weight per action output
                self.policy_optimizer.zero_grad()
                policy = self.policy_net(x)
                loss = torch.mean(t * 2 / T * ((policy.squeeze() - target.to(policy.device))**2))
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
                self.policy_optimizer.step()
                step += 1
                x["h_action"] = [h.cpu() for h in x["h_action"]]
                target = target.cpu()
                t = t.cpu()
                if step >= steps: 
                    M_Pi.reset()
                    return losses


    def remaining_boards(self, board, remaining_cards):
        """
        Returns all possible remaining boards given current board and remaining cards
        """
        return list(itertools.combinations(remaining_cards, 5 - len(board)))

    def encode_hand(self, hand):
        """Map a two-card hand (encoded card numbers) to its index in a ``26 * 51`` table."""
        return (52 - hand[0]) * (52 - hand[0] - 1) // 2 - hand[1] + hand[0]

    @torch.no_grad()
    def local_best_response(self, pi, state: State, h_i: int):
        """
        Approximate 2-player best response to given policy using local BR algorithm [Lisý et al. 2016: https://arxiv.org/pdf/1612.07547]
        
        Args
            pi: torch.Tensor -> belief over the opponent's hands, indexed by
                ``encode_hand`` (callers pass ``26 * 51`` entries). Modified
                in place and renormalised.
            state: State -> current state of game
            h_i: int -> index of the responding player; selects
                ``self.bot_hands[h_i]`` and ``state.bots[h_i]``

        Returns
            (action index, updated pi); the action is 0 (fold) unless some
            action has positive estimated utility.
        """

        U = torch.zeros(3) # utility of each action
        board = [Card(c).encode() for c in state.table]
        hand = [card.encode() for card in self.bot_hands[h_i]]

        # Sorted so each pair comes out as (low, high), the order
        # ``encode_hand`` is also given in the renormalisation loop below.
        remaining_cards = sorted(card for card in range(52) if card not in board + hand)
        remaining_hands = list(itertools.combinations(remaining_cards, 2))

        encoded_hands = [self.encode_hand(opp_hand) for opp_hand in remaining_hands]
        encoded_hands = np.array(encoded_hands)

        # renormalize pi: hands containing a known card are impossible
        for card in board + hand:
            for i in range(52):
                if i == card: continue
                h = [card, i]
                h.sort()
                pi[self.encode_hand(h)] = 0
        pi = pi / pi.sum()
        

        wp = self.wprollout.get_wins(state, pi).sum(axis=0)
        tp = self.wprollout.get_ties(state, pi).sum(axis=0)
        asked = state.bots[(h_i + 1) % 2].total_bet - state.bots[h_i].total_bet
        
        U[1] = wp * state.pot - (1 - wp) * asked + (tp * state.pot) / 2

        for a in range(2, 3):
            raise_amount = 100

            if raise_amount > state.bots[h_i].money: continue
            elif raise_amount + state.bots[h_i].current_bet < state.bots[(h_i + 1) % 2].current_bet * 2: continue

            # Probability that the opponent folds to the raise, accumulated
            # over its possible hands; pi is narrowed to non-folding hands.
            fp = 0
            for enc_hand, opp_hand in zip(encoded_hands, remaining_hands):
                x = state.to_dict()
                x["cards"] = [torch.IntTensor(x["cards"]).unsqueeze(0).to(self.device), torch.IntTensor(opp_hand).unsqueeze(0).to(self.device)]       
                values = self.value_net(x).squeeze()
                fold_policy = self.target_policy(values)[0]
                fp += pi[enc_hand] * fold_policy
                pi[enc_hand] *= (1 - fold_policy)
            pi = pi / pi.sum() # renormalize
            wp = self.wprollout.get_wins(state, pi).sum(axis=0) # self.wprollout(h_i, state, pi)
            tp = self.wprollout.get_ties(state, pi).sum(axis=0) # self.wprollout(h_i, state, pi)
            U[a] = fp * state.pot \
                    + (1 - fp) * (wp * (state.pot + raise_amount) - (1 - wp) * (asked + raise_amount)) \
                    + (1 - fp) * (tp * (state.pot + raise_amount)) / 2

        if U.max() > 0: return U.argmax(), pi
        return 0, pi

    @torch.no_grad()
    def vs_br(self, state, pi, player_idx):
        """Play the value-net policy (``player_idx``) against the local best response.

        Returns:
            The negated utility of ``player_idx`` at the terminal state.
        """
        assert state.total_players == 2, "Can only approximate local best response for 2-player games"
        
        if state.is_terminal(): return -1 * self.utility(state, player_idx, self.winner(state))
        elif state.curr_player != player_idx:
            action, pi = self.local_best_response(pi, state, (player_idx + 1) % 2)
            state.update(self.choice_to_action(state, action))
            return self.vs_br(state, pi, player_idx)
        else:
            with torch.no_grad():
                self.value_net.eval()
                x = state.to_dict()
                x["cards"] = [torch.IntTensor(x["cards"]).unsqueeze(0).to(self.device), torch.IntTensor([card.encode() for card in self.bot_hands[player_idx]]).unsqueeze(0).to(self.device)] 
                values = self.value_net(x).squeeze()

                policy = self.target_policy(values)
                assert (policy >= 0).all()
                assert torch.abs(policy.sum() - 1) < 0.0001
                
                agg_policy = self.aggregation_func(state, policy).to(self.device)
                action = torch.distributions.Categorical(agg_policy).sample()
                action = self.choice_to_action(state, action)
                state.update(action)
            return self.vs_br(state, pi, player_idx)


    def exploitability(self, runs=100):
        """Estimate exploitability against the local best response.

        Each run deals one hand and plays it twice, once with the network in
        each seat. Results are scaled by 10 and averaged over ``runs``.
        """
        total = 0
        for run in range(runs):
            if run % 100 == 0: print(f"Run {run}")

            state = State(n_players=2)
            self.begin_round(state)

            # milli-bb (bb = 100 => mbb = money * 1000/100)
            self.wprollout.fix(self.bot_hands[1], 1)
            pi = torch.ones(26 * 51).to(self.device) / (26 * 51)
            total += self.vs_br(copy.deepcopy(state), pi, 0) * 10 

            self.wprollout.fix(self.bot_hands[0], 0)
            pi = torch.ones(26 * 51).to(self.device) / (26 * 51)
            total += self.vs_br(copy.deepcopy(state), pi, 1) * 10

        return total / runs

    def load_value_net(self):
        """Restore the value network and optimizer; return the next iteration number."""
        ckpt = torch.load('/kaggle/input/pokerv4mini-ckpt/pytorch/default/1/value_net.pth', weights_only=True)
        self.value_net.load_state_dict(ckpt["state"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
        print(f"Loaded checkpoint at [sim {ckpt['sim'] + 1}]")
        return ckpt["sim"] + 1

    def load_policy_net(self):
        """Restore the policy network and its optimizer from the checkpoint file."""
        ckpt = torch.load('/kaggle/input/pokerv4mini-ckpt/pytorch/default/1/policy_net.pth', weights_only=True)
        self.policy_net.load_state_dict(ckpt["state"])
        self.policy_optimizer.load_state_dict(ckpt["optimizer"])
        print("Loaded checkpoint")

In [None]:
# train_cfr.py

class MCGame:
    """Training driver: alternates tree traversals and value-net optimisation.

    Attributes:
        losses: Mean optimisation loss, one entry per ``optimize`` call.
        exploit: Exploitability estimates collected during training.
        load_ckpt: Resume from the saved value-net checkpoint when True.
        load_dataset: Preload the saved value dataset when True.
    """

    def __init__(self, load_ckpt=False, load_dataset=False):
        self.losses = []
        self.exploit = []
        self.load_ckpt = load_ckpt
        self.load_dataset = load_dataset
        
    def train(self, 
        n_iters: int, 
        K: int = 3000,
        optim_steps: int = 4000,
        batch_size: int = 10000
    ):
        """Run ``n_iters`` iterations of traversal followed by optimisation.

        Args:
            n_iters: Number of simulate-then-optimise iterations.
            K: Game traversals per iteration.
            optim_steps: Gradient steps per optimisation phase.
            batch_size: Batch size used while optimising.
        """
        sim_iter = 0
        mccfr = MCCFR(n_players=2)
        
        M_Vp = ValueDataset()

        if self.load_dataset:
            # Change to path to value dataset
            M_Vp.load("/kaggle/input/cfr-memory/value_dataset.pt")
            print("Loaded value dataset from /kaggle/input/cfr-memory/value_dataset.pt")
            # ``T`` is stored as a plain list; torch.unique needs a tensor.
            print(torch.unique(torch.as_tensor(M_Vp.T)))

        if self.load_ckpt:
            sim_iter = mccfr.load_value_net()

            # In Kaggle runs, kernel typically dies during optimization
            # Change this as needed
            print("Optimizing...")
            losses = mccfr.optimize(M_Vp, sim_iter, optim_steps, batch_size)
            print(f"[Sim {sim_iter}] mean loss:", np.mean(losses))
            self.losses.append(np.mean(losses))
            sim_iter += 1
               
            
        print("Starting game simulations")
        for _ in range(n_iters):
            if sim_iter != 0 and sim_iter % 10 == 0: # exploitability check
                expl = mccfr.exploitability(runs=500)
                print("Exploitability:", expl)
                self.exploit.append(expl)

            decisions = 0
            for k in range(K):   
                state = State(n_players=2)
                mccfr.begin_round(state)
                # Samples are always collected from player 0's point of view.
                _ = mccfr.traverse(state, 0, M_Vp, sim_iter + 1)
                decisions += mccfr.decisions
            print(f"[Sim {sim_iter}] average decision points:", decisions / K)
            
            M_Vp.save()
            print(f"[Sim {sim_iter}] value dataset saved")
            
            mccfr.optimizer.zero_grad() # flush gradients
            torch.autograd.set_detect_anomaly(True)

            print(f"[Sim {sim_iter}] Optimizing...")
            losses = mccfr.optimize(M_Vp, sim_iter + 1, optim_steps, batch_size)

            print(f"[Sim {sim_iter}] mean loss:", np.mean(losses))
            self.losses.append(np.mean(losses))

            mccfr.save_checkpoint(sim_iter)
            print(f"[Sim {sim_iter}] value net saved")

            sim_iter += 1

In [None]:
# Resume from the saved checkpoint and the saved value dataset.
game = MCGame(load_ckpt=True, load_dataset=True)
game.train(100) # Train for 100 iterations

In [None]:
def plot_figs():
    """Plot the recorded losses and exploitability of ``game`` and save to loss.png.

    NOTE(review): exploitability points are spaced 1000 apart on the x axis,
    while training checks exploitability every 10 iterations; confirm the
    intended scale.
    """
    fig, (loss_ax, exploit_ax) = plt.subplots(ncols=2, figsize=(15, 5))

    loss_ax.plot(range(len(game.losses)), game.losses)
    loss_ax.set_xlabel("simulation iter")
    loss_ax.set_ylabel("avg decision point loss")

    exploit_x = np.arange(len(game.exploit)) * 1000
    exploit_ax.plot(exploit_x, game.exploit)
    exploit_ax.set_xlabel("simulation iter")
    exploit_ax.set_ylabel("exploitabilitiy (mbb/g)")

    fig.savefig('loss.png')