## Leduc Hold'em Game Code

### Import

In [1]:
from copy import deepcopy, copy
import numpy as np
import csv
import os

### Card

**Card stores the suit and rank of a single card.**
    
    Note:
        The suit variable in a standard card game should be one of [S, H, D, C, BJ, RJ] meaning [Spades, Hearts, Diamonds, Clubs, Black Joker, Red Joker]
        Similarly the rank variable should be one of [A, 2, 3, 4, 5, 6, 7, 8, 9, T, J, Q, K]

In [2]:
class Card:
    
    suit = None
    rank = None
    valid_suit = ['S', 'H', 'D', 'C', 'BJ', 'RJ']
    valid_rank = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']

    """Initialize the suit and rank of a card"""
    def __init__(self, suit, rank):        
        '''
        Args:
            suit: string, suit of the card, should be one of valid_suit
            rank: string, rank of the card, should be one of valid_rank
        '''
        self.suit = suit
        self.rank = rank

    def __eq__(self, other):
        if isinstance(other, Card):
            return self.rank == other.rank and self.suit == other.suit
        else:
            # don't attempt to compare against unrelated types
            return NotImplemented

    def __hash__(self):
        suit_index = Card.valid_suit.index(self.suit)
        rank_index = Card.valid_rank.index(self.rank)
        return rank_index + 100 * suit_index

    """Get string representation of a card."""
    def __str__(self):
        ''' 
        Returns:
            string: the combination of rank and suit of a card. Eg: AS, 5H, JD, 3C, ...
        '''    
        return self.rank + self.suit
  
    """Get index of a card.""" 
    def get_index(self):
        '''    
        Returns:
            string: the combination of suit and rank of a card. Eg: 1S, 2H, AD, BJ, RJ...
        '''
        return self.suit + self.rank

### Util

In [4]:
def set_seed(seed):
    if seed is not None:
        import subprocess
        import sys

        reqs = subprocess.check_output([sys.executable, '-m', 'pip', 'freeze'])
        installed_packages = [r.decode().split('==')[0] for r in reqs.split()]
        if 'torch' in installed_packages:
            import torch
            torch.backends.cudnn.deterministic = True
            torch.manual_seed(seed)
        np.random.seed(seed)
        import random
        random.seed(seed)

def get_device():
    import torch
    if torch.backends.mps.is_available():
        device = torch.device("mps:0")
        print("--> Running on the GPU")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
        print("--> Running on the GPU")
    else:
        device = torch.device("cpu")
        print("--> Running on the CPU")

    return device    

"""Initialize a standard deck of 52 cards"""
def init_standard_deck():
    '''
    Returns:
        (list): A list of Card object
    '''
    suit_list = ['S', 'H', 'D', 'C']
    rank_list = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']
    res = [Card(suit, rank) for suit in suit_list for rank in rank_list]
    return res

"""Initialize a standard deck of 52 cards, BJ and RJ"""
def init_54_deck():
    ''' 
    Returns:
        (list): Alist of Card object
    '''
    suit_list = ['S', 'H', 'D', 'C']
    rank_list = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']
    res = [Card(suit, rank) for suit in suit_list for rank in rank_list]
    res.append(Card('BJ', ''))
    res.append(Card('RJ', ''))
    return res

"""Get the coresponding number of a rank."""
def rank2int(rank):
    '''
    Args:
        rank(str): rank stored in Card objec
    Returns:
        (int): the number corresponding to the rank
    Note:
        1. If the input rank is an empty string, the function will return -1.
        2. If the input rank is not valid, the function will return None.
    '''
    if rank == '':
        return -1
    elif rank.isdigit():
        if int(rank) >= 2 and int(rank) <= 10:
            return int(rank)
        else:
            return None
    elif rank == 'A':
        return 14
    elif rank == 'T':
        return 10
    elif rank == 'J':
        return 11
    elif rank == 'Q':
        return 12
    elif rank == 'K':
        return 13
    return None

"""pls ignore"""
def reorganize(trajectories, payoffs):
    num_players = len(trajectories)
    new_trajectories = [[] for _ in range(num_players)]

    for player in range(num_players):
        for i in range(0, len(trajectories[player])-2, 2):
            if i ==len(trajectories[player])-3:
                reward = payoffs[player]
                done =True
            else:
                reward, done = 0, False
            transition = trajectories[player][i:i+3].copy()
            transition.insert(2, reward)
            transition.append(done)

            new_trajectories[player].append(transition)
    return new_trajectories

"""pls ignore"""
def tournament(env, num):
    payoffs = [0 for _ in range(env.num_players)]
    counter = 0
    while counter < num:
        _, _payoffs = env.run(is_training=False)
        if isinstance(_payoffs, list):
            for _p in _payoffs:
                for i, _ in enumerate(payoffs):
                    payoffs[i] += _p[i]
                counter += 1
        else:
            for i, _ in enumerate(payoffs):
                payoffs[i] += _payoffs[i]
            counter += 1
    for i, _ in enumerate(payoffs):
        payoffs[i] /= counter
    return payoffs

"""pls ignore"""
def plot_curve(csv_path, save_path, algorithm):
    import os
    import csv
    import matplotlib.pyplot as plt
    with open(csv_path) as csvfile:
        reader = csv.DictReader(csvfile)
        xs = []
        ys = []
        for row in reader:
            xs.append(int(row['episode']))
            ys.append(float(row['reward']))
        fig, ax = plt.subplots()
        ax.plot(xs, ys, label=algorithm)
        ax.set(xlabel='episode', ylabel='reward')
        ax.legend()
        ax.grid()

        save_dir = os.path.dirname(save_path)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        fig.savefig(save_path)

### Dealer

In [5]:
class LeducholdemDealer:
    
    """Initialize a leducholdem dealer class"""
    def __init__(self, np_random):

        self.np_random = np_random
        self.deck = [Card('S', 'J'), Card('H', 'J'), Card('S', 'Q'), Card('H', 'Q'), Card('S', 'K'), Card('H', 'K')]
        self.shuffle()
        self.pot = 0
    
    def shuffle(self):
        self.np_random.shuffle(self.deck)

    """Deal one card from the deck"""
    def deal_card(self): 
        '''
        Returns:
            (Card): The drawn card from the deck
    '''   
        return self.deck.pop()

### Player

In [6]:
class LeducholdemPlayer:
    
    """Initilize a player."""
    def __init__(self, player_id, np_random):
        ''' 
        Args:
            player_id (int): The id of the player
        '''
        self.np_random = np_random
        self.player_id = player_id
        self.status = "alive"
        self.hand = None 
        self.in_chips = 0 # The chips that this player has put in until now

    """Encode the state for the player"""
    def get_state(self, public_card, all_chips, legal_actions): 
        ''' 
        Args:
            public_card (object): The public card that seen by all the players
            all_chips (int): The chips that all players have put in
        Returns:
            (dict): The state of the player
        '''   
        state = {}
        state['hand'] = self.hand.get_index()
        state['public_card'] = public_card.get_index() if public_card else None
        state['all_chips'] = all_chips
        state['my_chips'] = self.in_chips
        state['legal_actions'] = legal_actions
        return state

    """Return the id of the player"""
    def get_player_id(self):
        return self.player_id

### Game Winner Judger

In [9]:
class LeducholdemJudger:
    
    """Initialize a judger class"""
    def __init__(self, np_random):    
        self.np_random = np_random

    """Judge the winner of the game."""
    @staticmethod
    def judge_game(players, public_card):
        '''
        Args:
            players (list): The list of players who play the game
            public_card (object): The public card that seen by all the players
        Returns:
            (list): Each entry of the list corresponds to one entry of the
        '''

        # Judge who are the winners
        winners = [0] * len(players) # in leduc: a 2 elements list
        fold_count = 0
        ranks = []
        
        # If every player folds except one, the alive player is the winner
        for idx, player in enumerate(players):
            ranks.append(rank2int(player.hand.rank))
            if player.status == 'folded':
               fold_count += 1
            elif player.status == 'alive':
                alive_idx = idx
        if fold_count == (len(players) - 1):
            winners[alive_idx] = 1
        
        # If any of the players matches the public card wins
        if sum(winners) < 1:
            for idx, player in enumerate(players):
                if player.hand.rank == public_card.rank:
                    winners[idx] = 1
                    break
        
        # If non of the above conditions, the winner player is the one with the highest card rank
        if sum(winners) < 1:
            max_rank = max(ranks)
            max_index = [i for i, j in enumerate(ranks) if j == max_rank]
            # for idx in max_index:
            #     winners[idx] = 1
            # If more than one player has the max rank, it's a tie
            if len(max_index) > 1:
                for idx in max_index:
                    winners[idx] = 1

        # Compute the total chips
        total = 0
        for p in players:
            total += p.in_chips

        # each_win = float(total) / sum(winners)
        if sum(winners) > 0:
            each_win = float(total) / sum(winners)
        else:
            each_win = 0.0

        payoffs = []
        for i, _ in enumerate(players):
            if winners[i] == 1:
                payoffs.append(each_win - players[i].in_chips)
            else:
                payoffs.append(float(-players[i].in_chips))

        return payoffs

### Round

**Round can call other Classes' functions to keep the game running.**

In [11]:
class LeducholdemRound:
    
    """Initialize the round class"""
    def __init__(self, raise_amount, allowed_raise_num, num_players, np_random):
        '''
        Args:
            raise_amount (int): the raise amount for each raise
            allowed_raise_num (int): The number of allowed raise num
            num_players (int): The number of players
        '''
        self.np_random = np_random
        self.game_pointer = None
        self.raise_amount = raise_amount
        self.allowed_raise_num = allowed_raise_num
        self.num_players = num_players

        # Count the number of raise
        self.have_raised = 0

        # Count the number of player without raise
        # If every player agree to not raise, the round is over
        self.not_raise_num = 0

        # Raised amount for each player
        self.raised = [0 for _ in range(self.num_players)] # in leduc, [0,0]
        self.player_folded = None

    """Start a new bidding round"""
    def start_new_round(self, game_pointer, raised=None): 
        '''
        Args:
            game_pointer (int): The game_pointer that indicates the next player
            raised (list): Initialize the chips for each player
        Note: For the first round of the game, we need to setup the big/small blind
        '''       
        self.game_pointer = game_pointer
        self.have_raised = 0
        self.not_raise_num = 0
        if raised:
            self.raised = raised
        else:
            self.raised = [0 for _ in range(self.num_players)]

    """Call Player Class functions to keep one round running"""
    def proceed_round(self, players, action):
        '''
        Args:
            players (list): The list of players that play the game
            action (str): An legal action taken by the player
        Returns:
            (int): The game_pointer that indicates the next player
        '''
        
        if action not in self.get_legal_actions():
            raise Exception('{} is not legal action. Legal actions: {}'.format(action, self.get_legal_actions()))

        if action == 'call':
            diff = max(self.raised) - self.raised[self.game_pointer] # the chips that should be put by the game pointer if he calls
            self.raised[self.game_pointer] = max(self.raised) # update the present maximum chip amount
            players[self.game_pointer].in_chips += diff # update the total chips put by the game pointer
            self.not_raise_num += 1 # number of players without raise

        elif action == 'raise':
            diff = max(self.raised) - self.raised[self.game_pointer] + self.raise_amount # the present maximum betting chips + raise chips - the game pointer's betting chips
            self.raised[self.game_pointer] = max(self.raised) + self.raise_amount # update the present maximum chip amount
            players[self.game_pointer].in_chips += diff # update the total chips put by the game pointer
            self.have_raised += 1 # number of players that raise
            self.not_raise_num = 1 # reset

        elif action == 'fold':
            players[self.game_pointer].status = 'folded'
            self.player_folded = True

        elif action == 'check':
            self.not_raise_num += 1

        self.game_pointer = (self.game_pointer + 1) % self.num_players

        # Skip the folded players
        while players[self.game_pointer].status == 'folded':
            self.game_pointer = (self.game_pointer + 1) % self.num_players

        return self.game_pointer

    """Obtain the legal actions for the current player""" 
    def get_legal_actions(self):
        '''
        Returns:
           (list):  A list of legal actions
        '''
        full_actions = ['call', 'raise', 'fold', 'check']

        # If the the number of raises already reaches the maximum number raises, we can not raise any more
        if self.have_raised >= self.allowed_raise_num:
            full_actions.remove('raise')

        # If the current chips are less than that of the highest one in the round, we can not check
        if self.raised[self.game_pointer] < max(self.raised):
            full_actions.remove('check')

        # If the current player has put in the chips that are more than others, we can not call
        if self.raised[self.game_pointer] == max(self.raised):
            full_actions.remove('call')

        return full_actions
    
    """Check whether the round is over"""
    def is_over(self):
        '''
        Returns:
            (boolean): True if the current round is over
        '''
        # call, check -> not_raise_num + 1    
        if self.not_raise_num >= self.num_players:
            return True
        return False

### Game

In [12]:
class LeducholdemGame():
    
    """Initialize the class leducholdem Game, Set game rules"""
    def __init__(self, config):
        ''' 
        Configs:
            num_players (int): the number of game players
            small_blind (int): The amount of small blind
            allowed_raise_num (int): the maximum raise chances
            allow_step_back (Boolean)
            seed (int)
        '''
        self.allow_step_back = config.get('allow_step_back', True)
        self.np_random = np.random.RandomState(config.get('seed', None))
        
        self.num_players = config.get('game_num_players', 2)

        # Small blind and big blind
        self.small_blind = config.get('small_blind', 1)
        self.big_blind = 2 * self.small_blind

        # Raise amount and allowed times
        self.raise_amount = self.big_blind
        self.allowed_raise_num = config.get('allowed_raise_num', 2)

        # initial chips
        self.initial_chips = [config.get('initial_chips', 1000),config.get('initial_chips', 1000)]
        self.reset_match()
    
    ''' Specifiy some game specific parameters'''
    def configure(self, game_config):
        self.num_players = game_config.get('game_num_players', self.num_players)
        self.small_blind = game_config.get('small_blind', self.small_blind)
        self.big_blind = 2 * self.small_blind
        self.raise_amount = self.big_blind
        self.allowed_raise_num = game_config.get('allowed_raise_num', self.allowed_raise_num)
        self.initial_chips = game_config.get('initial_chips', self.initial_chips)

    """Initialilze the game of Limit Texas Hold'em"""
    def init_game(self):
        ''' 
        This version supports two-player limit texas hold'em
        Returns:
            (tuple): Tuple containing:
                (dict): The first state of the game
                (int): Current player's id
        '''
        # Initilize a dealer that can deal cards
        self.dealer = LeducholdemDealer(self.np_random)

        # Initilize two players to play the game
        self.players = [LeducholdemPlayer(i, self.np_random) for i in range(self.num_players)]

        # Initialize a judger class which will decide who wins in the end
        self.judger = LeducholdemJudger(self.np_random)

        # Set initial player chips
        for i in range(self.num_players):
            self.players[i].chips = self.player_chips[i]
            
        # Prepare for the first round
        for i in range(self.num_players):
            self.players[i].hand = self.dealer.deal_card()
        
        # Randomly choose a small blind(player1) and a big blind(player2)
        s = self.np_random.randint(0, self.num_players) # in leduc, either 0 or 1
        b = (s + 1) % self.num_players
        self.players[b].in_chips = self.big_blind
        self.players[s].in_chips = self.small_blind
        self.public_card = None
        
        # The player with small blind plays the first
        self.game_pointer = s

        # Initilize a betting round, in the first round, the big blind and the small blind needs to
        # be passed to the round for processing.
        self.round = LeducholdemRound(raise_amount=self.raise_amount,
                           allowed_raise_num=self.allowed_raise_num,
                           num_players=self.num_players,
                           np_random=self.np_random)

        self.round.start_new_round(game_pointer=self.game_pointer, raised=[p.in_chips for p in self.players])

        # Count the round. There are 2 rounds in each game.
        self.round_counter = 0

        # Save the hisory for stepping back to the last state.
        self.history = []

        state = self.get_state(self.game_pointer)

        return state, self.game_pointer

    """reset match when it is over (one runs out his chips)"""
    def reset_match(self):
        self.player_chips = [self.initial_chips, self.initial_chips]
        self.match_over = False

    """Return the legal actions for current player"""
    def get_legal_actions(self):
        '''
        legal actions vary in different rounds (call Round Class to get), follow different actions of opponents
        Returns:
            (list): A list of legal actions
        '''
        return self.round.get_legal_actions()
    
    """Return the current player's id"""
    def get_player_id(self):
        '''
        Returns:
            (int): current player's id
        '''
        return self.game_pointer
    
    """Return the number of applicable actions"""
    def get_num_actions(self):
        '''
        Returns:
            (int): The number of actions. There are 4 actions (call, raise, check and fold)
        '''
        return 4
    
    """Return the number of players in limit texas holdem"""
    def get_num_players(self):
        '''
        Returns:
            (int): The number of players in the game
        '''
        return self.num_players
    
    """Get the next state"""
    def step(self, action):
        ''' 
        Args:
            action (str): a specific action. (call, raise, fold, or check)
        Returns:
            (tuple): Tuple containing:
                (dict): next player's state
                (int): next plater's id
        '''
        if self.allow_step_back:
            # First snapshot the current state
            r = copy(self.round)
            r_raised = copy(self.round.raised)
            gp = self.game_pointer
            r_c = self.round_counter
            d_deck = copy(self.dealer.deck)
            p = copy(self.public_card)
            ps = [copy(self.players[i]) for i in range(self.num_players)]
            ps_hand = [copy(self.players[i].hand) for i in range(self.num_players)]
            self.history.append((r, r_raised, gp, r_c, d_deck, p, ps, ps_hand))

        # Then we proceed to the next round
        self.game_pointer = self.round.proceed_round(self.players, action)

        # If a round is over, we deal more public cards
        if self.round.is_over():
            # For the first round, we deal 1 card as public card. Double the raise amount for the second round
            if self.round_counter == 0:
                self.public_card = self.dealer.deal_card()
                self.round.raise_amount = 2 * self.raise_amount

            self.round_counter += 1
            self.round.start_new_round(self.game_pointer)

        state = self.get_state(self.game_pointer)

        return state, self.game_pointer

    """Return to the previous state of the game"""
    def step_back(self):
        ''' 
        Returns:
            (bool): True if the game steps back successfully
        '''    
        if len(self.history) > 0:
            self.round, r_raised, self.game_pointer, self.round_counter, d_deck, self.public_card, self.players, ps_hand = self.history.pop()
            self.round.raised = r_raised
            self.dealer.deck = d_deck
            for i, hand in enumerate(ps_hand):
                self.players[i].hand = hand
            return True
        return False
    
    """Return player's state"""
    def get_state(self, player): 
        ''' 
        Args:
            player_id (int): player id
        Returns:
            (dict): The state of the player
        '''   
        chips = [self.players[i].in_chips for i in range(self.num_players)]
        legal_actions = self.get_legal_actions()
        state = self.players[player].get_state(self.public_card, chips, legal_actions) # class Player.get_state
        state['current_player'] = self.game_pointer
        return state

    """Check if the game is over"""
    def is_over(self):
        ''' 
        Returns:
            (boolean): True if the game is over
        '''    
        alive_players = [1 if p.status=='alive' else 0 for p in self.players]
        # If only one player is alive, the game is over.
        if sum(alive_players) == 1:
            return True

        # If all rounds are finshed
        if self.round_counter >= 2:
            return True
        return False

    """Return the payoffs of the game"""
    def get_payoffs(self):    
        '''
        Returns:
            (list): Each entry corresponds to the payoff of one player
        '''
        chips_payoffs = self.judger.judge_game(self.players, self.public_card)
        payoffs = np.array(chips_payoffs) # payoffs = np.array(chips_payoffs) / (self.big_blind)
        return payoffs

    """Return the chip stacks of player"""
    def get_chipstack(self):
        

## Algorithm Code (Monte Carlo CFR Chance Sampling)

### Import

In [90]:
import hashlib
import numpy as np
import collections
import os
import argparse
import pickle
import struct
from collections import OrderedDict

### Seeding

In [91]:
def colorize(string, color, bold=False, highlight = False):
    attr = []
    num = color2num[color]
    if highlight: num += 10
    attr.append(str(num))
    if bold: attr.append('1')
    attrs = ';'.join(attr)
    return '\x1b[%sm%s\x1b[0m' % (attrs, string)

def error(msg, *args):
    print(colorize('%s: %s'%('ERROR', msg % args), 'red'))

def np_random(seed=None):
    if seed is not None and not (isinstance(seed, int) and 0 <= seed):
        raise error.Error('Seed must be a non-negative integer or omitted, not {}'.format(seed))

    seed = create_seed(seed)

    rng = np.random.RandomState()
    rng.seed(_int_list_from_bigint(hash_seed(seed)))
    return rng, seed

def hash_seed(seed=None, max_bytes=8):
    if seed is None:
        seed = create_seed(max_bytes=max_bytes)
    _hash = hashlib.sha512(str(seed).encode('utf8')).digest()
    return _bigint_from_bytes(_hash[:max_bytes])

def create_seed(a=None, max_bytes=8):
    if a is None:
        a = _bigint_from_bytes(os.urandom(max_bytes))
    elif isinstance(a, str):
        a = a.encode('utf8')
        a += hashlib.sha512(a).digest()
        a = _bigint_from_bytes(a[:max_bytes])
    elif isinstance(a, int):
        a = a % 2**(8 * max_bytes)
    else:
        raise error.Error('Invalid type for seed: {} ({})'.format(type(a), a))

    return a

def _bigint_from_bytes(_bytes):
    sizeof_int = 4
    padding = sizeof_int - len(_bytes) % sizeof_int
    _bytes += b'\0' * padding
    int_count = int(len(_bytes) / sizeof_int)
    unpacked = struct.unpack("{}I".format(int_count), _bytes)
    accum = 0
    for i, val in enumerate(unpacked):
        accum += 2 ** (sizeof_int * 8 * i) * val
    return accum

def _int_list_from_bigint(bigint):
    # Special case 0
    if bigint < 0:
        raise error.Error('Seed must be non-negative, not {}'.format(bigint))
    elif bigint == 0:
        return [0]

    ints = []
    while bigint > 0:
        bigint, mod = divmod(bigint, 2 ** 32)
        ints.append(mod)
    return ints

### General Game Env

**The base Env class.**

In [92]:
class Env(object):
    
    """Initialize the environment"""
    def __init__(self, config):    

        self.action_recorder = []
        self.allow_step_back = self.game.allow_step_back = config['allow_step_back']

        # Get the number of players/actions in this game
        self.num_players = self.game.get_num_players()
        self.num_actions = self.game.get_num_actions()

        # A counter for the timesteps
        self.timestep = 0

        # Set random seed, default is None
        self.seed(config['seed'])
    
    """Start a new game"""
    def reset(self):
        ''' 
        Returns:
            (tuple): Tuple containing:
                (numpy.array): The begining state of the game
                (int): The begining player
        '''     
        state, player_id = self.game.init_game()
        self.action_recorder = []
        return self._extract_state(state), player_id

    """Step forward"""
    def step(self, action, raw_action=False): 
        ''' 
        Args:
            action (int): The action taken by the current player
            raw_action (boolean): True if the action is a raw action
        Returns:
            (tuple): Tuple containing:
                (dict): The next state
                (int): The ID of the next player
        '''   
        if not raw_action:
            action = self._decode_action(action)
        self.timestep += 1
        # Record the action for human interface
        self.action_recorder.append((self.get_player_id(), action))
        next_state, player_id = self.game.step(action)

        return self._extract_state(next_state), player_id

    """Take one step backward."""
    def step_back(self):
        ''' 
        Returns:
            (tuple): Tuple containing:
                (dict): The previous state
                (int): The ID of the previous player
        Note: Error will be raised if step back from the root node.
        '''
        if not self.allow_step_back:
            raise Exception
        
        if not self.game.step_back():
            return False

        player_id = self.get_player_id()
        state = self.get_state(player_id)

        return state, player_id
    
    """Set the agents that will interact with the environment. Must be called before `run`."""
    def set_agents(self, agents):
        '''
        Args:
            agents (list): List of Agent classes
        '''        
        self.agents = agents

    """Run a complete game, either for evaluation or training RL agent."""
    def run(self, is_training=True):
        '''
        Args:
            is_training (boolean): True if for training purpose.
        Returns:
            (tuple) Tuple containing:
                (list): A list of trajectories generated from the environment.
                (list): A list payoffs. Each entry corresponds to one player.
        Note: The trajectories are 3-dimension list.
                The first dimension is for different players.
                The second dimension is for different transitions. 
                The third dimension is for the contents of each transiton.
        '''    
        trajectories = [[] for _ in range(self.num_players)]
        state, player_id = self.reset()

        # Loop to play the game
        trajectories[player_id].append(state)
        while not self.is_over():
            # Agent plays
            if not is_training:
                action, _ = self.agents[player_id].eval_step(state)
            else:
                action = self.agents[player_id].step(state)

            # Environment steps
            next_state, next_player_id = self.step(action, self.agents[player_id].use_raw) #self.agents[player_id]使得player0是cfr agent, player1是random agent
            # Save action
            trajectories[player_id].append(action)

            # Set the state and player
            state = next_state
            player_id = next_player_id

            # Save state.
            if not self.game.is_over():
                trajectories[player_id].append(state)

        # Add a final state to all the players
        for player_id in range(self.num_players):
            state = self.get_state(player_id)
            trajectories[player_id].append(state)

        # Payoffs
        payoffs = self.get_payoffs()

        return trajectories, payoffs

    """Check whether the curent game is over"""
    def is_over(self):
        ''' 
        Returns:
            (boolean): True if current game is over
        '''
        return self.game.is_over()

    """Get the current player id"""
    def get_player_id(self):
        ''' 
        Returns:
            (int): The id of the current player
        '''
        return self.game.get_player_id()

    """Get the state given player id"""
    def get_state(self, player_id):
        '''     
        Args:
            player_id (int): The player id
        Returns:
            (numpy.array): The observed state of the player
        '''
        return self._extract_state(self.game.get_state(player_id))

    """Get the payoffs of players. Must be implemented in the child class."""
    def get_payoffs(self):
        ''' 
        Returns:
            (list): A list of payoffs for each player.
        '''
        raise NotImplementedError

    """Get the perfect information of the current state"""
    def get_perfect_information(self):
        ''' 
        Returns:
            (dict): A dictionary of all the perfect information of the current state
        ''' 
        raise NotImplementedError

    """Extract useful information from state for RL. Must be implemented in the child class."""
    def _extract_state(self, state): 
        ''' 
        Args:
            state (dict): The raw state
        Returns:
            (numpy.array): The extracted state
        '''   
        raise NotImplementedError

    """Decode Action id to the action in the game. Must be implemented in the child class."""
    def _decode_action(self, action_id):
        ''' 
        Args:
            action_id (int): The id of the action
        Returns:
            (string): The action that will be passed to the game engine.      
        '''
        raise NotImplementedError

    """Get all legal actions for current state. Must be implemented in the child class."""
    def _get_legal_actions(self):
        ''' 
        Returns:
            (list): A list of legal actions' id.    
        '''
        raise NotImplementedError

    '''Seeding'''
    def seed(self, seed=None):
        self.np_random, seed = np_random(seed)
        self.game.np_random = self.np_random
        return seed

### Baseline Leduc Holdem Env

**For specific setting, more should be create to include:**

    1) tournament setting, 

    2) time average ensemble, etc


In [93]:
class LeducholdemEnv(Env):
    
    """Initialize the Limitholdem environment"""
    def __init__(self, config):
         
        self.game = LeducholdemGame(config)
        super().__init__(config)
        self.actions = ['call', 'raise', 'fold', 'check']
        self.card2index = {"SJ": 0, "SQ": 1, "SK": 2, "HJ": 0, "HQ": 1, "HK": 2}
        
        # self.state_shape = [[36] for _ in range(self.num_players)]
        # self.action_shape = [None for _ in range(self.num_players)]   
        # self.max_chips = 1000  # Maximum chips each player can have
    
        # def reset(self):
        #     self.game.init_game()
        #     # for player in self.game.players:
        #     #     player.chips = self.max_chips  # Reset each player's chips to max_chips
        #     self.action_recorder = []
        #     return self._extract_state(self.game.get_state(0)), 0
    
    """Get all leagal actions"""
    def _get_legal_actions(self):
        ''' 
        Returns:
            encoded_action_list (list): return encoded legal action list (from str to int)
        '''
        return self.game.get_legal_actions()

    """Extract the state representation from state dictionary for agent"""
    def _extract_state(self, state, readable=True):
        ''' 
        Args:
            state (dict): Original state from the game
        Returns:
            observation (list or str): combine the player's score and dealer's observable score for observation
    '''    
        extracted_state = {}

        legal_actions = OrderedDict({self.actions.index(a): None for a in state['legal_actions']})
        extracted_state['legal_actions'] = legal_actions # actions index (int, not str)

        public_card = state['public_card']
        hand = state['hand']
        
        if readable==False:
            obs = np.zeros(36)
            obs[self.card2index[hand]] = 1
            if public_card:
                obs[self.card2index[public_card]+3] = 1
            obs[state['my_chips']+6] = 1
            obs[sum(state['all_chips'])-state['my_chips']+21] = 1
        else:
            current_player_id = state['current_player']
            in_chips = state['my_chips']
            all_chips = sum(state['all_chips'])
            obs = f"{current_player_id}_{hand}_{public_card}_{self.action_recorder}_{in_chips}_{all_chips}_{state['legal_actions']}_{list(extracted_state['legal_actions'].keys())}"
            
        extracted_state['obs'] = obs
        extracted_state['raw_obs'] = state
        extracted_state['raw_legal_actions'] = [a for a in state['legal_actions']]
        extracted_state['action_record'] = self.action_recorder

        return extracted_state

    """Get the payoff of a game"""
    def get_payoffs(self):
        ''' 
        Returns:
           payoffs (list): list of payoffs
        '''   
        return self.game.get_payoffs()

    """Decode the action for applying to the game"""
    def _decode_action(self, action_id):
        ''' 
        Args:
            action id (int): action id
        Returns:
            action (str): action for the game
    '''
        legal_actions = self.game.get_legal_actions()
        if self.actions[action_id] not in legal_actions:
            if 'check' in legal_actions:
                return 'check'
            else:
                return 'fold'
        return self.actions[action_id]

    """Get the perfect information of the current state"""
    def get_perfect_information(self):
        ''' 
        Returns:
            (dict): A dictionary of all the perfect information of the current state
        '''
        state = {}
        state['chips'] = [self.game.players[i].in_chips for i in range(self.num_players)]
        state['public_card'] = self.game.public_card.get_index() if self.game.public_card else None
        state['hand_cards'] = [self.game.players[i].hand.get_index() for i in range(self.num_players)]
        state['current_round'] = self.game.round_counter
        state['current_player'] = self.game.game_pointer
        state['legal_actions'] = self.game.get_legal_actions()
        return state

### CFR Algorithm - CFR agent

**Implement CFR (chance sampling) algorithm.**

In [94]:
class CFRAgent():
    
    """Initilize Agent"""
    def __init__(self, env, model_path='./cfr_model'):
        ''' 
        Args:
            env (Env): Env class
            model_path: where to store the model and results
        '''  
        self.use_raw = False
        self.env = env
        self.model_path = model_path

        # A policy is a dict state_str -> action probabilities
        self.policy = collections.defaultdict(list)
        self.average_policy = collections.defaultdict(np.array)
        self.state_utilities = {}  # Dictionary to store state utilities

        # Regret is a dict state_str -> action regrets
        self.regrets = collections.defaultdict(np.array)

        self.iteration = 0

    """Do one iteration of CFR"""
    def train(self):
        self.iteration += 1
        # Firstly, traverse tree to compute counterfactual regret for each player
        # The regrets are recorded in traversal
        for player_id in range(self.env.num_players):
            self.env.reset()
            probs = np.ones(self.env.num_players)
            self.traverse_tree(probs, player_id)

        # Update policy
        self.update_policy()
        print(self.iteration)

    """Traverse the game tree, update the regrets"""
    def traverse_tree(self, probs, player_id):
        ''' 
        Args:
            probs: The reach probability of the current node
            player_id: The player to update the value
        Returns:
            state_utilities (list): The expected utilities for all the players
        '''
        if self.env.is_over():
            return self.env.get_payoffs()

        current_player = self.env.get_player_id()

        action_utilities = {}
        state_utility = np.zeros(self.env.num_players)
        obs, legal_actions = self.get_state(current_player)
        action_probs = self.action_probs(obs, legal_actions, self.policy)

        for action in legal_actions:
            action_prob = action_probs[action]
            new_probs = probs.copy()
            new_probs[current_player] *= action_prob

            # Keep traversing the child state
            self.env.step(action)
            utility = self.traverse_tree(new_probs, player_id)
            self.env.step_back()

            state_utility += action_prob * utility
            action_utilities[action] = utility

        if not current_player == player_id:
            return state_utility
        
        # need to update!
        if obs not in self.state_utilities:
            self.state_utilities[obs] = []  # Initialize state utility list for the current state
        self.state_utilities[obs].append(state_utility)

        # If it is current player, we record the policy and compute regret
        player_prob = probs[current_player]
        counterfactual_prob = (np.prod(probs[:current_player]) *
                                np.prod(probs[current_player + 1:]))
        player_state_utility = state_utility[current_player]

        if obs not in self.regrets:
            self.regrets[obs] = np.zeros(self.env.num_actions)
        if obs not in self.average_policy:
            self.average_policy[obs] = np.zeros(self.env.num_actions)
        
        # iterate each legal action and calculate its prob
        for action in legal_actions:
            action_prob = action_probs[action]
            # regret = action counterfactual prob * action utility - current state utility
            regret = counterfactual_prob * (action_utilities[action][current_player]
                    - player_state_utility)
            # update current state regret and average policy
            self.regrets[obs][action] += regret
            self.average_policy[obs][action] += self.iteration * player_prob * action_prob
        
        return state_utility

    """Update policy/strategy based on the current regrets"""
    def update_policy(self):
        for obs in self.regrets:
            self.policy[obs] = self.regret_matching(obs)
            print(f"Updated policy: {self.policy[obs]}")
    
    """Apply regret matching"""
    '''update action probabilities only when regret > 0, that is, not taking certain actions will result in regret'''
    def regret_matching(self, obs):
        ''' 
        Args:
            obs (string): The state_str
        '''
        regret = self.regrets[obs]
        positive_regret_sum = sum([r for r in regret if r > 0]) # negative regret values ​​indicate no regret or no need for adjustment

        action_probs = np.zeros(self.env.num_actions)
        if positive_regret_sum > 0:
            for action in range(self.env.num_actions):
                action_probs[action] = max(0.0, regret[action] / positive_regret_sum)
        else:
            for action in range(self.env.num_actions):
                action_probs[action] = 1.0 / self.env.num_actions
        
        # print(f"Regret matching for obs: {obs}")
        # print(f"Regrets: {regret}")
        # print(f"Action probabilities: {action_probs}")
        return action_probs
    
    """Obtain the action probabilities of the current state"""
    def action_probs(self, obs, legal_actions, policy):
        ''' 
        Args:
            obs (str): state_str
            legal_actions (list): List of leagel actions
            player_id (int): The current player
            policy (dict): The used policy
        Returns:
            (tuple) that contains:
                action_probs(numpy.array): The action probabilities
                legal_actions (list): Indices of legal actions
        '''    
        if obs not in policy:
            action_probs = np.zeros(self.env.num_actions)
            action_probs[legal_actions] = 1.0 / len(legal_actions)
            policy[obs] = action_probs
        else:
            action_probs = policy[obs]       
        action_probs = self.remove_illegal(action_probs, legal_actions)
        return action_probs
    
    """Remove illegal actions and normalize theprobability vector"""
    '''Only legal actions should be allocated probabilities, we won't take actions that we can not take'''
    '''Is called in action_probs'''
    def remove_illegal(self, action_probs, legal_actions):
        ''' 
        Args:
            action_probs (numpy.array): A 1 dimention numpy array.
            legal_actions (list): A list of indices of legal actions.
        Returns:
            probd (numpy.array): A normalized vector without legal actions.
        '''
        probs = np.zeros(action_probs.shape[0])
        probs[legal_actions] = action_probs[legal_actions]
        if np.sum(probs) == 0:
            probs[legal_actions] = 1 / len(legal_actions)
        else:
            probs /= sum(probs)
        return probs

    """Given a state, predict action based on average policy"""
    '''This is used when we play game with other agents/human and would like to compare the performance'''
    def eval_step(self, state):
        ''' 
        Args:
            state (numpy.array): State representation
        Returns:
            action (int): Predicted action
            info (dict): A dictionary containing information
        '''
        probs = self.action_probs(state['obs'].tostring(), list(state['legal_actions'].keys()), self.average_policy)
        # probs = self.action_probs(state['obs'], list(state['legal_actions'].keys()), self.average_policy)
        action = np.random.choice(len(probs), p=probs)
        info = {}
        info['probs'] = {state['raw_legal_actions'][i]: float(probs[list(state['legal_actions'].keys())[i]]) for i in range(len(state['legal_actions']))}
        return action, info

    """Get state_str of the player"""
    def get_state(self, player_id, readable=True):
        ''' 
        Args:
            player_id (int): The player id
        Returns:
            (tuple) that contains:
                state (str): The state str
                legal_actions (list): Indices of legal actions
        '''
        state = self.env.get_state(player_id)
        if readable==True:
            return state['obs'], list(state['legal_actions'].keys())
        else:
            return np.array_str(state['obs']), list(state['legal_actions'].keys())

    """Save model"""        
    def save(self):
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        policy_file = open(os.path.join(self.model_path, 'policy.pkl'),'wb')
        pickle.dump(self.policy, policy_file)
        policy_file.close()

        average_policy_file = open(os.path.join(self.model_path, 'average_policy.pkl'),'wb')
        pickle.dump(self.average_policy, average_policy_file)
        average_policy_file.close()

        regrets_file = open(os.path.join(self.model_path, 'regrets.pkl'),'wb')
        pickle.dump(self.regrets, regrets_file)
        regrets_file.close()

        iteration_file = open(os.path.join(self.model_path, 'iteration.pkl'),'wb')
        pickle.dump(self.iteration, iteration_file)
        iteration_file.close()

    """Load model"""   
    def load(self):
        if not os.path.exists(self.model_path):
            return

        policy_file = open(os.path.join(self.model_path, 'policy.pkl'),'rb')
        self.policy = pickle.load(policy_file)
        policy_file.close()

        average_policy_file = open(os.path.join(self.model_path, 'average_policy.pkl'),'rb')
        self.average_policy = pickle.load(average_policy_file)
        average_policy_file.close()

        regrets_file = open(os.path.join(self.model_path, 'regrets.pkl'),'rb')
        self.regrets = pickle.load(regrets_file)
        regrets_file.close()

        iteration_file = open(os.path.join(self.model_path, 'iteration.pkl'),'rb')
        self.iteration = pickle.load(iteration_file)
        iteration_file.close()

## Training Code

In [95]:
import matplotlib.pyplot as plt

In [102]:
def train_cfr(agent, num_iterations):

    for i in range(num_iterations):
        agent.train()

# Creat LeducholdemEnvSimplified Env
env = LeducholdemEnv(
    config={'allow_step_back':True,
            'small_blind': 1,
            'allowed_raise_num': 2,
            'seed':42})

# Creat CFR Agent 
cfr_agent = CFRAgent(env)

# Train CFR Agent
num_iterations = 100
train_cfr(cfr_agent, num_iterations)

# Save
cfr_agent.save()


Updated policy: [0. 0. 1. 0.]
Updated policy: [0.  0.8 0.2 0. ]
Updated policy: [0. 0. 0. 1.]
Updated policy: [0. 1. 0. 0.]
Updated policy: [0. 0. 1. 0.]
Updated policy: [0. 1. 0. 0.]
Updated policy: [1. 0. 0. 0.]
Updated policy: [0.  0.5 0.5 0. ]
Updated policy: [0. 0. 1. 0.]
Updated policy: [0.  0.  0.5 0.5]
Updated policy: [0.         0.34693878 0.         0.65306122]
Updated policy: [0.  0.8 0.2 0. ]
Updated policy: [0. 0. 1. 0.]
Updated policy: [0.  0.  0.5 0.5]
Updated policy: [0. 0. 1. 0.]
Updated policy: [0. 1. 0. 0.]
Updated policy: [0. 0. 0. 1.]
Updated policy: [0. 1. 0. 0.]
Updated policy: [1. 0. 0. 0.]
Updated policy: [0. 1. 0. 0.]
Updated policy: [0.         0.28571429 0.         0.71428571]
Updated policy: [0.05882353 0.94117647 0.         0.        ]
Updated policy: [1. 0. 0. 0.]
Updated policy: [0.         0.54054054 0.         0.45945946]
Updated policy: [1. 0. 0. 0.]
Updated policy: [0. 1. 0. 0.]
Updated policy: [1. 0. 0. 0.]
Updated policy: [0.         0.36363636 0. 

## Results

In [97]:
import pickle
import pandas as pd

In [123]:
with open('/Users/sishanyang/Desktop/德国生活/Academic/master_thesis/code/testing/cfr_model/policy.pkl', 'rb') as f:
    policy_data = pickle.load(f)
policy_df = pd.DataFrame(list(policy_data.items()), columns=['Key', 'Value'])

split_columns1 = policy_df['Key'].str.split('_', expand=True)
split_columns1.columns = ['current_player_id','hand_card','public_card','action_record','in_chips','all_chips','legal_actions','legal_actions_idx']
policy_df = policy_df.join(split_columns1)

# split_columns2 = policy_df['Value'] = policy_df['Value'].apply(lambda x: ', '.join(map(str, x)))
# split_columns2 = split_columns2.astype(float)
# split_columns2.columns = ['call', 'raise', 'fold', 'check']
# policy_df = policy_df.join(split_columns2)

policy_df

Unnamed: 0,Key,Value,current_player_id,hand_card,public_card,action_record,in_chips,all_chips,legal_actions,legal_actions_idx
0,"1_HJ_None_[]_1_3_['call', 'raise', 'fold']_[0,...","[0.12188477456422417, 0.8781152254357758, 0.0,...",1,HJ,,[],1,3,"['call', 'raise', 'fold']","[0, 1, 2]"
1,"0_HQ_None_[(1, 'call')]_2_4_['raise', 'fold', ...","[0.0, 0.10343095159829879, 0.2872725362249501,...",0,HQ,,"[(1, 'call')]",2,4,"['raise', 'fold', 'check']","[1, 2, 3]"
2,"1_HJ_None_[(1, 'call'), (0, 'raise')]_2_6_['ca...","[0.7379799857037688, 0.059982762034574116, 0.2...",1,HJ,,"[(1, 'call'), (0, 'raise')]",2,6,"['call', 'raise', 'fold']","[0, 1, 2]"
3,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, 0.0, 0.0, 1.0]",0,HQ,SK,"[(1, 'call'), (0, 'raise'), (1, 'call')]",4,8,"['raise', 'fold', 'check']","[1, 2, 3]"
4,"1_HJ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.25, 0.25, 0.25, 0.25]",1,HJ,SK,"[(1, 'call'), (0, 'raise'), (1, 'call'), (0, '...",4,12,"['call', 'raise', 'fold']","[0, 1, 2]"
...,...,...,...,...,...,...,...,...,...,...
1807,"1_HQ_HK_[(0, 'call'), (1, 'raise'), (0, 'call'...","[0.25, 0.25, 0.25, 0.25]",1,HQ,HK,"[(0, 'call'), (1, 'raise'), (0, 'call'), (1, '...",4,8,"['raise', 'fold', 'check']","[1, 2, 3]"
1808,"1_HQ_HK_[(0, 'call'), (1, 'raise'), (0, 'call'...","[0.25, 0.25, 0.25, 0.25]",1,HQ,HK,"[(0, 'call'), (1, 'raise'), (0, 'call'), (1, '...",8,20,"['call', 'fold']","[0, 2]"
1809,"1_HQ_HK_[(0, 'call'), (1, 'raise'), (0, 'call'...","[0.25, 0.25, 0.25, 0.25]",1,HQ,HK,"[(0, 'call'), (1, 'raise'), (0, 'call'), (1, '...",6,12,"['raise', 'fold', 'check']","[1, 2, 3]"
1810,"1_HQ_HK_[(0, 'call'), (1, 'raise'), (0, 'call'...","[0.25, 0.25, 0.25, 0.25]",1,HQ,HK,"[(0, 'call'), (1, 'raise'), (0, 'call'), (1, '...",10,24,"['call', 'fold']","[0, 2]"


In [100]:
with open('/Users/sishanyang/Desktop/德国生活/Academic/master_thesis/code/testing/cfr_model/regrets.pkl', 'rb') as f:
    regrets_data = pickle.load(f)
regrets_df = pd.DataFrame(list(regrets_data.items()), columns=['Key', 'Value'])
regrets_df['positive_regret_sum'] = regrets_df['Value'].apply(lambda x: sum(v for v in x if v > 0))
regrets_df.head(10)

Unnamed: 0,Key,Value,positive_regret_sum
0,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[2.444444444444444, 0.0, 0.07407407407407407, ...",2.518519
1,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[-1.6641975308641974, 0.11358024691358021, -2....",0.11358
2,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, -4.495473251028806, -7.088065843621399, ...",0.205761
3,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[-0.4444444444444444, 0.14814814814814814, -0....",0.148148
4,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[-0.07407407407407407, 0.0, 0.0246913580246913...",0.024691
5,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, 0.016460905349794244, -0.057613168724279...",0.016461
6,"0_HQ_None_[(1, 'call'), (0, 'raise'), (1, 'cal...","[2.2261697663557056, 0.0, 3.1304201177474296, ...",5.35659
7,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[-0.9074074074074073, 0.9814814814814814, -0.6...",0.981481
8,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.4444444444444444, 0.0, 0.07407407407407407,...",0.518519
9,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, -0.4074074074074075, -0.2592592592592593...",0.407407


In [101]:
with open('/Users/sishanyang/Desktop/德国生活/Academic/master_thesis/code/testing/cfr_model/average_policy.pkl', 'rb') as f:
    average_policy_data = pickle.load(f)
average_policy_df = pd.DataFrame(list(average_policy_data.items()), columns=['Key', 'Value'])
average_policy_df.head(10)

Unnamed: 0,Key,Value
0,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.05555555555555555, 0.0, 0.05555555555555555..."
1,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.037037037037037035, 1.9798941798941796, 0.5..."
2,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, 0.1111111111111111, 0.1111111111111111, ..."
3,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.05555555555555555, 2.4841269841269833, 0.05..."
4,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.027777777777777776, 0.0, 2.4563492063492056..."
5,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, 2.4841269841269833, 0.05555555555555555,..."
6,"0_HQ_None_[(1, 'call'), (0, 'raise'), (1, 'cal...","[29.108049112865864, 0.0, 24.964745406048227, ..."
7,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.1111111111111111, 9.727839309706527, 9.7278..."
8,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.05555555555555555, 0.0, 0.05555555555555555..."
9,"0_HQ_SK_[(1, 'call'), (0, 'raise'), (1, 'call'...","[0.0, 0.1111111111111111, 9.727839309706527, 9..."
