In [227]:
# Imports
import random
import numpy as np

In [228]:
# Data Definitions
CHECK = 0
# For now, fixed bet size.
BET = 1
NUM_ACTIONS = 2 # Check, Bet
BETTING_ROUNDS = 4 # Pre, Flop, Turn, River

In [229]:
class PokerNode:
    def __init__(self):
        self.infoset = ""
        self.regretSum = [0.0] * NUM_ACTIONS
        self.strategy = [0.0] * NUM_ACTIONS
        self.strategySum = [0.0] * NUM_ACTIONS

    def getStrategy(self, realizationWeight):
        normalizingSum = 0
        for action in range(NUM_ACTIONS):
            if self.regretSum[action] > 0:
                self.strategy[action] = self.regretSum[action]
            else:
                self.strategy[action] = 0
            normalizingSum += self.strategy[action]
        for action in range(NUM_ACTIONS):
            if normalizingSum > 0:
                self.strategy[action] /= normalizingSum
            else:
                self.strategy[action] = 1.0 / NUM_ACTIONS
            self.strategySum[action] += realizationWeight * self.strategy[action]
        return self.strategy

    def getAverageStrategy(self):
        averageStrategy = [0.0] * NUM_ACTIONS
        normalizingSum = 0.0
        for action in range(NUM_ACTIONS):
            normalizingSum += self.strategySum[action]
        for action in range(NUM_ACTIONS):
            if normalizingSum > 0:
                averageStrategy[action] = self.strategySum[action] / normalizingSum
            else:
                averageStrategy[action] = 1.0 / NUM_ACTIONS
        return averageStrategy

    def __str__(self):
        avgStratsNice = [round(num, 4) for num in self.getAverageStrategy()]
        return f"{self.infoSet:<40}: Pass: {avgStratsNice[0]} Bet: {avgStratsNice[1]}"

In [230]:
# A history node contains: previous node, pot size, previous action
# A history node has methods:
# historyLength - length of node chain
# prevAction is the previous bet, zero if checked

class historyNode:
    def __init__(self, last=None, prevAction=0, bettingRound=0):
        self.last = last
        if self.last:
            self.pot = last.pot + prevAction
        else:
            self.pot = 0
        self.prevAction = prevAction
        self.bettingRound = bettingRound

    def historyLength(self):
        if self.prevAction == -1:
            return 0
        if roundTerminalNode(self.last) != 0:
            return 1
        else:
            return 1 + self.last.historyLength()

    
    
    def niceString(self):
        if self.last == None:
            return ''
        if roundTerminalNode(self.last):
            return self.last.niceString() + '| ' + formatAction(self.prevAction)
        else:
            return self.last.niceString() + formatAction(self.prevAction)
        
    
    def __str__(self):
        if self.last == None:
            return ''
        return str(self.last) + str(self.prevAction)

def formatAction(action):
    if action > 0:
        return 'B' + str(action) + ' '
    elif action == 0:
        return 'X '
    

In [231]:
# Returns boolean if player is winning
# Current inputs: cards for player, cards for opponent
# Cards is currently just 1 card though
def playerWins(player, opponent):
    return player > opponent

# Some example history sequences:
# 0 0 (Check Check)
# 1 1 (Bet Call)
# 1 2 2 (Bet Raise Call)
# 1 2 3 4 4 (Bet Raise Raise Raise Call) # Due to cases like this, should add real time pot calculations
# 1 2 0 (Bet Raise Fold)

# Is this historyNode a terminal node?
# Terminal Cases:
# Call a bet on River (round=3)
# Fold any time
def terminalNode(historyNode): # Returns 0 for no, 1 for fold, 2 for showdown
    if historyNode.historyLength() < 2:
        return 0
    # If check check or bet/raise call (On River only!) - Showdown
    elif historyNode.prevAction - historyNode.last.prevAction == 0 and historyNode.bettingRound == 3:
        return 2
    # If fold
    elif historyNode.prevAction == 0 and historyNode.last.prevAction not in [0, -1]:
        return 1
    return 0

# 0 is no, 1 is yes
def roundTerminalNode(historyNode):
    if historyNode.historyLength() < 2:
        return 0
    elif historyNode.prevAction - historyNode.last.prevAction == 0 and historyNode.bettingRound < 3:
        return 1
    return 0

In [232]:
class PokerTrainer:
    def __init__(self):
        self.nodeMap = dict()
    
    def train(self, iterations):
        cards = list(range(1, 4))
        util = 0.0
        for iteration in range(iterations):
            random.shuffle(cards)
            util += self.cfr(cards, historyNode(None, -1, 0), 1, 1)
        print("Average game value: ", util / iterations)
        for node in sorted(self.nodeMap.values(), key=(lambda node: node.infoSet)):
            print(node)

    def cfr(self, cards, history, p0, p1):
        plays = history.historyLength()
        player = plays % 2
        opponent = 1 - player

        # Getting return payoffs for terminal states
        if (terminalState := terminalNode(history)) != 0: # Walrus 😎
            # If folded
            if terminalState == 1:
                return -history.pot if player == 1 else history.pot
            # Showdown
            elif terminalState == 2:
                return history.pot if playerWins(cards[0], cards[1]) else -history.pot

        infoSet = str(cards[player]) + ': ' + history.niceString()
        
        # Get info set node, or create it
        node = self.nodeMap.get(infoSet)
        if node == None:
            node = PokerNode()
            node.infoSet = infoSet
            self.nodeMap[infoSet] = node
    
        # Recursively call cfr with additional history and probability for each action
        strategy = node.getStrategy(p0 if player == 0 else p1)
        util = [0.0] * NUM_ACTIONS
        nodeUtil = 0
        for action in range(NUM_ACTIONS):
            nextHistory = historyNode(history, action, history.bettingRound + 1 if roundTerminalNode(history) == 1 else history.bettingRound)
            if player == 0:
                util[action] = - self.cfr(cards, nextHistory, p0 * strategy[action], p1)
            else:
                util[action] = - self.cfr(cards, nextHistory, p0, p1 * strategy[action])
            nodeUtil += strategy[action] * util[action]
    
        # Compute and accumulate cfr for each action
        for action in range(NUM_ACTIONS):
            regret = util[action] - nodeUtil
            node.regretSum[action] += (p1 if player == 0 else p0) * regret
    
        return nodeUtil

In [233]:
trainer = PokerTrainer()
trainer.train(10000)

Average game value:  0.33596134255700916
1:                                      : Pass: 0.9965 Bet: 0.0035
1: B1                                   : Pass: 0.9999 Bet: 0.0001
1: B1 B1                                : Pass: 0.8804 Bet: 0.1196
1: B1 B1 | B1                           : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1                        : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | B1                   : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | B1 B1                : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | B1 B1 | B1           : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | B1 B1 | X            : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | B1 B1 | X B1         : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | X                    : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | X B1                 : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | X B1 B1              : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | X B1 B1 | B1         : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | X B1 B1 | X          : Pass: 0.5 Bet: 0.5
1: B1 B1 | B1 B1 | X B1 B1

TODO:
Change terminal payoff state calculations
Need limited betting rounds, as well as sizings. Now just choose pot.

CURRENT ASSUMPTIONS:

- Stacks: INF

- Bets: INF

- Betting Rounds: 3

In [148]:
#Testing
initialNode = historyNode(None, -1, 0)
midNode = historyNode(historyNode(historyNode(initialNode, 1, 0), 1, 0), 1, 1)
terminalFoldNode = historyNode(midNode, 0, 1)
turnNode = historyNode(historyNode(historyNode(midNode, 1, 1), 1, 2), 1, 2)
riverNode = historyNode(historyNode(turnNode, 1, 3), 1, 3)
checksFuckNode = historyNode(historyNode(historyNode(initialNode, 0, 2), 1, 2), 0, 2)
retardNode = historyNode(historyNode(initialNode, 0, 2), 1, 2)
print(terminalNode(checksFuckNode))
print(checksFuckNode.historyLength())
print(roundTerminalNode(retardNode))

0
1
None
