In [None]:
!pip install "pettingzoo[classic]" torch numpy tqdm

Collecting pettingzoo[classic]
  Downloading pettingzoo-1.25.0-py3-none-any.whl.metadata (8.9 kB)
Collecting chess>=1.9.4 (from pettingzoo[classic])
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting rlcard>=1.0.5 (from pettingzoo[classic])
  Downloading rlcard-1.2.0.tar.gz (269 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m269.0/269.0 kB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting shimmy>=1.2.0 (from shimmy[openspiel]>=1.2.0; extra == "classic"->pettingzoo[classic])
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.

First we will test out the Texas Holden No Limit v6 library to set up our play environment for both the user and model to play on.


Actionn ID and Action
0.   Fold
1.   Check
2.   Call
3.   Raise Half Pot
4.   Raise Full Pot
5.   All In



In [None]:
action_map = {i: action for i, action in enumerate(['Fold', 'Check & Call', 'Raise Half Pot', 'Raise Full Pot', 'All In'])}

In [None]:
card_map = {}

ranks = ['A', '2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K']
suits = {
    '♠': range(0, 13),
    '♥': range(13, 26),
    '◆': range(26, 39),
    '♣': range(39, 52)
}

for suit, indices in suits.items():
    for i, idx in enumerate(indices):
        card_map[idx] = f"{ranks[i]}{suit}"

In [None]:
class GlobalLimiter():
  def __init__(self, loop_limit: int=5):
    self.loop_limit=loop_limit
    self.counter=0
  def inc(self):
    self.counter += 1
    if self.counter > self.loop_limit:
        raise Exception('Exceeded loop limit for testing')

In order for the users to properly visualize what is going on, we create a helper method to visualize the user's hands and the community cards

In [None]:
from typing import List
def visualize_cards(hand: List[int]):
  for card, num_cards in enumerate(hand):
    if num_cards:
      for i in range(int(num_cards)):
        print(f'{card_map[card]} ', end='')
  print('\n')

# visualize_cards([2, 0])

We can simulate one hand per player for two players using the loop

In [None]:
from pettingzoo.classic import texas_holdem_no_limit_v6
import numpy as np
import random

random.seed(0)

env = texas_holdem_no_limit_v6.env()
env.reset()

gl = GlobalLimiter(loop_limit=5)
turn = 1

player_hands = {}

for agent in env.agent_iter():
  #gl.inc()
  obs, reward, terminated, truncated, info = env.last()
  done = terminated or truncated
  player_cards = obs['observation'][:52]
  players_commitment = obs['observation'][52:]


  if done:
    print(f'{agent} is done.')
    if agent in player_hands:
      visualize_cards(player_hands[agent])
    env.step(None)
    print(f'Cumulative Reward: {reward}')
  else:
    if int(np.sum(player_cards)) == 2:
      player_hands[agent] = player_cards

    legal_actions_indices = np.where(obs['action_mask'])[0]

    print(f'\n[{turn}] {agent} Turn:')
    visualize_cards(player_cards)
    print(f'Pot: {players_commitment}')
    print(f'Available actions: {legal_actions_indices}')
    action = random.choice(legal_actions_indices)
    print(f'{agent} taking action: {action_map[action]} ({action})')
    env.step(action)
  #Returns observation, cumulative reward, terminated, truncated, info for the current agent (specified by self.agent_selection).


[1] player_1 Turn:
3♥ Q◆ 

Pot: [1. 2.]
Available actions: [0 1 3 4]
player_1 taking action: All In (4)

[1] player_0 Turn:
T♥ T◆ 

Pot: [  2. 100.]
Available actions: [0 1]
player_0 taking action: Check & Call (1)
player_0 is done.
T♥ T◆ 

Cumulative Reward: -100
player_1 is done.
3♥ Q◆ 

Cumulative Reward: 100


  gym.logger.warn(
  gym.logger.warn(


In [None]:
from typing import List
from collections import deque, namedtuple
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'history'])

class ReplayBuffer:
  def __init__(self,
               capacity:int):
    # FIFO buffer
    self.buffer = deque(maxlen=capacity)

  def push(self,
           state:List[int],
           action:int,
           reward:int,
           next_state:List[int],
           done:int,
           history
           ):
    self.buffer.append(Experience(state, action, reward, next_state, done, history))

  def sample(self, batch_size:int):
    return random.sample(self.buffer, batch_size)

  def __len__(self):
    return len(self.buffer)

In [None]:
import torch
import torch.nn as nn
class FCNRewardFunction(nn.Module):
  def __init__(self,
               input_dim:int,
               hidden_dims:List[int] = [256, 128, 64]):
    super().__init__()
    layers = []
    prev_dim = input_dim

    for hidden_dim in hidden_dims:
      layers.extend([
          nn.Linear(in_features=prev_dim,
                    out_features=hidden_dim),
          nn.ReLU(),
          nn.Dropout(p=0.1)
      ])
      prev_dim=hidden_dim

    # Add final layer to output single reward value
    layers.append(nn.Linear(in_features=prev_dim, out_features=1))

    self.network = nn.Sequential(*layers)

  def forward(self, x):
      return self.network(x)

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from typing import List
class HistoryRNN(nn.Module):
  def __init__(self,
               input_dim:int,
               hidden_dim:int=128,
               num_layers:int=2,
               dropout:float=0.1):
    super().__init__()

    # store dims for init of hidden/cell states
    self.hidden_dim = hidden_dim
    self.num_layers = num_layers

    self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                       num_layers=num_layers, batch_first=True, dropout=dropout)

    self.output_projection = nn.Linear(hidden_dim, hidden_dim)

  def forward(self,
              history_sequence:torch.Tensor,
              sequence_lengths:List[int]):
    # History sequence contains the actual sequence of prior game states
    # Sequence lengths shows the true length of each game state prior to padding
    # BTD Batch Size, Max History Length, Input feature size per timestep

    # D is the size of every game snapshot that is to be consumed by RNN
    batch_size = history_sequence.size(0)
    seq_len = history_sequence.size(1)

    if history_sequence.dim() == 2: # in case we only have one history
      history_sequence = history_sequence.unsqueeze(0)
      batch_size = 1

    # hidden and cell state
    h0 = torch.zeros(
        self.num_layers, batch_size, self.hidden_dim).to(history_sequence.device)
    c0 = torch.zeros(
        self.num_layers, batch_size, self.hidden_dim).to(history_sequence.device)

    if sequence_lengths is not None:
      # accounts for padding since padding will pollute input with meaningless data
      # normalize lengths to a plain python list of ints on CPU
      if isinstance(sequence_lengths, torch.Tensor):
        sequence_lengths = sequence_lengths.detach().cpu().tolist()
      else:
        sequence_lengths = [int(x) for x in sequence_lengths]

      packed = rnn.pack_padded_sequence(
          input=history_sequence, lengths=sequence_lengths, batch_first=True, enforce_sorted=False
      )
      # pass (packed, (h0, c0)) to the LSTM; don't call packed as a function
      packed_output, (hn, cn) = self.rnn(packed, (h0, c0))
      output, _ = rnn.pad_packed_sequence(packed_output, batch_first=True)
    else:
      # no need to pack since we only have one example
      output, (hn, cn) = self.rnn(history_sequence, (h0, c0))

    history_encoding = self.output_projection(hn[-1])

    return history_encoding


In [None]:
import torch
import torch.nn as nn  # alias nn so we can use nn.Sequential, nn.Linear, etc.

class DQN(nn.Module):
  def __init__(self,
               state_dim:int,
               action_dim:int,
               history_dim:int,
               hidden_dim:int=256,
               history_rnn_dim:int=128):
    super().__init__()

    # history rnn, history_dim is the size of each snapshot
    self.history_rnn = HistoryRNN(history_dim, history_rnn_dim)

    # state processor, input for "what do i see right now?"
    self.state_processor = nn.Sequential(
        nn.Linear(in_features=state_dim, out_features=hidden_dim),
        nn.ReLU(),
        nn.Dropout()
    )

    # combined dimension, combining state and history
    combined_dim = hidden_dim + history_rnn_dim
    combined_output_dim = hidden_dim // 2  # fix typo: combiend_output_dim -> combined_output_dim

    # combined processor used for q head
    self.combined_processor = nn.Sequential(
        nn.Linear(in_features=combined_dim, out_features=hidden_dim),
        nn.ReLU(),
        nn.Dropout(p=0.1),
        nn.Linear(in_features=hidden_dim, out_features=combined_output_dim),
        nn.ReLU()
    )

    # expected future return (now + later)
    self.q_head = nn.Linear(
        in_features=combined_output_dim, out_features=action_dim) # 5

    # reward for current action, state
    self.reward_function = FCNRewardFunction(input_dim=combined_dim) # 64

  # forward will have the current state
  def forward(self,
              history,
              state,
              sequence_lengths,  # NEW: always pass sequence lengths so RNN can ignore padding
              return_reward:bool=False):
    # history: [batch, seq_length, history_dim]
    history_vector = self.history_rnn(history, sequence_lengths)

    state_vector = self.state_processor(state)

    combined_vector = torch.cat((history_vector, state_vector), dim=1)

    combined_embedding = self.combined_processor(combined_vector)

    q_values = self.q_head(combined_embedding)

    if return_reward:
      fcn_reward = self.reward_function(combined_vector)
      return q_values, fcn_reward

    return q_values


In [None]:
import numpy as np, random
from collections import deque
import torch
from torch.optim import Adam
import torch.nn.functional as F
class DQLAgent:
  def __init__(self, state_dim:int,
               action_dim:int,
               history_dim:int,
               lr:float=1e-3,
               gamma:float=0.99,
               epsilon:float=1.0,
               epsilon_decay:float=0.995,
               epsilon_min:float=0.01,
               buffer_size:int=10000,
               batch_size:int=32,
               maxlen:int=50):
    # input dims
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.history_dim = history_dim

    # hyper params
    self.gamma = gamma
    self.epsilon = epsilon
    self.epsilon_decay = epsilon_decay
    self.epsilon_min = epsilon_min
    self.buffer_size = buffer_size
    self.batch_size = batch_size

    # neural networks
    self.q_network = DQN(state_dim=state_dim, action_dim=action_dim,
                              history_dim=history_dim)
    self.target_network = DQN(state_dim=state_dim, action_dim=action_dim,
                              history_dim=history_dim)

    self.target_network.load_state_dict(self.q_network.state_dict())

    self.optimizer = Adam(self.q_network.parameters(), lr=lr)  # not "parameters="

    self.replay_buffer = ReplayBuffer(buffer_size)

    self.history = deque(maxlen=maxlen)

    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    self.q_network.to(self.device)
    self.target_network.to(self.device)

  def preprocess_state(self, obs):
    player_cards = obs['observation'][:52]
    players_commitment = obs['observation'][52:]

    # redundant? check later

    state = np.concatenate([player_cards, players_commitment])
    return state

  def add_to_history(self, obs):
    state = self.preprocess_state(obs)
    # redundant? check later
    self.history.append(state)

  def get_history_tensor(self):
    if len(self.history) == 0:
      return torch.zeros(1, 1, self.history_dim).to(self.device)
    history_array = np.array(list(self.history))
    # add batch dimension by unsqueezing
    history_tensor = torch.FloatTensor(history_array).unsqueeze(0).to(self.device)
    return history_tensor

   # depending on the observation and legal actions, choose an action
  def select_action(self, obs, legal_actions):
    if random.random() < self.epsilon:
      return random.choice(legal_actions)

    self.add_to_history(obs)

    state = torch.FloatTensor(self.preprocess_state(obs)).unsqueeze(0).to(self.device)
    history = self.get_history_tensor()
    # sequence lengths shows the true length of each game state prior to padding
    sequence_lengths = [len(self.history) if len(self.history) > 0 else 1]

    with torch.no_grad():
      q_values = self.q_network(history, state, sequence_lengths)

    masked_q_values = q_values.clone()
    mask = torch.ones(self.action_dim, dtype=torch.bool, device=masked_q_values.device)
    mask[legal_actions] = False
    masked_q_values[0][mask] = float('-inf')

    return masked_q_values.argmax().item()

  def store_experience(self, state, action, reward, next_state, done, history):
    self.replay_buffer.push(state, action, reward, next_state, done, history)

  def update_target_network(self):
    self.target_network.load_state_dict(self.q_network.state_dict())

  def train(self):
    if len(self.replay_buffer) < self.batch_size:
      return
    ###### Digesting Data For Training #####
    batch = self.replay_buffer.sample(self.batch_size)

    states = torch.FloatTensor(np.array([s.state for s in batch])).to(self.device)
    actions = torch.LongTensor(np.array([s.action for s in batch])).to(self.device)
    rewards = torch.FloatTensor(np.array([s.reward for s in batch])).to(self.device)
    next_states = torch.FloatTensor(np.array([s.next_state for s in batch])).to(self.device)
    dones = torch.BoolTensor(np.array([s.done for s in batch])).to(self.device)

    # prepping (padding) histories
    max_history_len = max(max(len(h) + 1, 1) for h in [s.history for s in batch])

    histories, next_histories = [], []
    sequence_lengths, next_sequence_lengths = [], []  # NEW: lengths for pack_padded_sequence

    for s in batch:
        hist = list(s.history)
        # true (unpadded) length before padding
        sequence_lengths.append(len(hist) if len(hist) > 0 else 1)

        while len(hist) < max_history_len:
            hist.insert(0, np.zeros(self.history_dim, dtype=np.float32))
        histories.append(hist)

        nhist = list(s.history)
        nhist.append(s.next_state)
        # true (unpadded) length for next history
        next_len = len(nhist) if len(nhist) > 0 else 1
        next_sequence_lengths.append(next_len)

        while len(nhist) < max_history_len:
            nhist.insert(0, np.zeros(self.history_dim, dtype=np.float32))
        next_histories.append(nhist)

    histories = torch.FloatTensor(np.array(histories)).to(self.device)
    next_histories = torch.FloatTensor(np.array(next_histories)).to(self.device)

    ##### Getting Q Values and rewards ######
    current_q_values, predicted_rewards = self.q_network.forward(
        histories, states, sequence_lengths=sequence_lengths, return_reward=True
    )
    current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

    with torch.no_grad():
      next_q_values = self.target_network(
          next_histories, next_states, sequence_lengths=next_sequence_lengths
      )
      target_q_values = rewards + (self.gamma * next_q_values.max(1)[0] * ~dones)

    # calculate loss
    q_loss = F.mse_loss(current_q_values.squeeze(), target_q_values)
    reward_loss = F.mse_loss(predicted_rewards.squeeze(), rewards)

    # combined loss, weighed heavy towards q
    total_loss = q_loss + 0.1 * reward_loss  # Weight reward loss

    # back prop
    self.optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=1.0)
    self.optimizer.step()

    # epsilon decay
    self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


In [None]:

def train_dql_agent(num_episodes:int=1000,
                    target_update_freq:int=10):
    env = texas_holdem_no_limit_v6.env()
    env.reset()

    # observation dimensions
    first_agent = env.agent_selection
    sample_obs, _, _, _, _ = env.last()

    state_dim = len(sample_obs['observation'])
    action_dim = len(sample_obs['action_mask'])
    history_dim = state_dim

    # initialize each agent
    agents = {}
    for agent_name in env.possible_agents:
        agents[agent_name] = DQLAgent(state_dim, action_dim, history_dim)

    for episode in range(num_episodes):
        env.reset()
        episode_rewards = {agent: 0 for agent in env.possible_agents}

        # reset all histories, starting new game
        for agent in agents.values():
            agent.history.clear()

        for agent_name in env.agent_iter():
            obs, reward, terminated, truncated, info = env.last()

            done = terminated or truncated

            if agent_name in agents:
                episode_rewards[agent_name] += reward

                if done:
                    env.step(None)
                else:
                    # legal actions
                    legal_actions = np.where(obs['action_mask'])[0].tolist()

                    # previous experience
                    agent = agents[agent_name]
                    if hasattr(agent, 'prev_state'):
                        next_state = agent.preprocess_state(obs)
                        agent.store_experience(
                            agent.prev_state, agent.prev_action, reward,
                            next_state, done, list(agent.history)
                        )

                    # select action
                    action = agent.select_action(obs, legal_actions)

                    agent.prev_state = agent.preprocess_state(obs)
                    agent.prev_action = action

                    env.step(action)
                    agent.train()
            else:
                env.step(None)

        # after a certain number of episodes, update target network
        if episode % target_update_freq == 0:
            for agent in agents.values():
                agent.update_target_network()

        if episode % 100 == 0:
            avg_rewards = {name: episode_rewards[name] for name in episode_rewards}
            print(f"Episode {episode}, Average Rewards: {avg_rewards}")

    return agents

In [35]:
trained_agents = None
if __name__ == "__main__":
    # Train the DQL agents
    trained_agents = train_dql_agent()

    # Test with trained agents (modify your existing code)
    env = texas_holdem_no_limit_v6.env()
    env.reset()

    for agent_name in env.agent_iter():
        obs, reward, terminated, truncated, info = env.last()
        done = terminated or truncated

        if done:
            env.step(None)
        else:
            if agent_name in trained_agents:
                legal_actions = np.where(obs['action_mask'])[0].tolist()
                action = trained_agents[agent_name].select_action(obs, legal_actions) if len(legal_actions) > 0 else None
            else:
                # Fallback to random action
                legal_actions = np.where(obs['action_mask'])[0].tolist()
                action = random.choice(legal_actions) if len(legal_actions) > 0 else None

            # Pretty-print action using action_map if available; otherwise just show the id
            if action is None:
                printable = "NO_ACTION"
            else:
                printable = action_map[action] if 'action_map' in globals() and isinstance(action_map, dict) and action in action_map else str(action)

            print(f'{agent_name} taking action: {printable} ({action})')
            env.step(action)


  gym.logger.warn(
  gym.logger.warn(


Episode 0, Average Rewards: {'player_0': np.int64(2), 'player_1': np.int64(-2)}
Episode 100, Average Rewards: {'player_0': np.int64(-40), 'player_1': np.int64(40)}
Episode 200, Average Rewards: {'player_0': np.int64(2), 'player_1': np.int64(-2)}
Episode 300, Average Rewards: {'player_0': np.int64(-100), 'player_1': np.int64(100)}
Episode 400, Average Rewards: {'player_0': np.int64(-100), 'player_1': np.int64(100)}
Episode 500, Average Rewards: {'player_0': np.int64(-80), 'player_1': np.int64(80)}
Episode 600, Average Rewards: {'player_0': np.int64(80), 'player_1': np.int64(-80)}
Episode 700, Average Rewards: {'player_0': np.int64(80), 'player_1': np.int64(-80)}
Episode 800, Average Rewards: {'player_0': np.int64(-100), 'player_1': np.int64(100)}
Episode 900, Average Rewards: {'player_0': np.int64(-100), 'player_1': np.int64(100)}
player_1 taking action: Raise Full Pot (3)
player_0 taking action: Raise Half Pot (2)
player_1 taking action: Raise Half Pot (2)
player_0 taking action: Raise

In [41]:
import gradio as gr
import numpy as np
import random
from pettingzoo.classic import texas_holdem_no_limit_v6

# Action and card mappings
action_map = {
    0: 'Fold',
    1: 'Check/Call',
    2: 'Raise Half Pot',
    3: 'Raise Full Pot',
    4: 'All In'
}

# Fixed card mapping
suits = ['♠', '♥', '♦', '♣']
ranks = ['2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K', 'A']

def card_from_index(index):
    """Convert card index to readable card format"""
    return ranks[index % 13] + suits[index // 13]

# ---------- ASCII CARD RENDERER ----------

def _card_ascii(rank: str, suit: str):
    """Return a single card as a list of 5 lines (ASCII box)."""
    # Fit ranks to two chars (pad '10' look with 'T')
    rank_left  = f"{rank:<2}"
    rank_right = f"{rank:>2}"
    center = f"{suit:^9}"
    return [
        "┌─────────┐",
        f"│{rank_left}       │",
        f"│{center}│",
        f"│       {rank_right}│",
        "└─────────┘",
    ]

def _card_back_ascii():
    """Facedown card as a list of 5 lines."""
    return [
        "┌─────────┐",
        "│░░░░░░░░░│",
        "│░░░░░░░░░│",
        "│░░░░░░░░░│",
        "└─────────┘",
    ]

def render_cards_ascii(indices, pad_to=0):
    """
    Render a horizontal row of cards from 0-51 indices.
    Pads with facedown backs (?) up to pad_to cards if provided.
    """
    cards = []
    for i in indices:
        r = ranks[i % 13]
        s = suits[i // 13]
        cards.append(_card_ascii(r, s))

    # pad with backs
    if pad_to and len(cards) < pad_to:
        for _ in range(pad_to - len(cards)):
            cards.append(_card_back_ascii())

    if not cards:
        # If nothing, show pad_to backs if requested, else one back
        n = pad_to if pad_to else 1
        cards = [_card_back_ascii() for _ in range(n)]

    # stitch line-by-line
    lines = []
    for row in zip(*cards):
        lines.append(" ".join(row))
    return "\n".join(lines)

def render_unknown_ascii(n):
    """Render n facedown cards."""
    return "\n".join([" ".join(row) for row in zip(*([_card_back_ascii()] * n))])

# ---------- GAME LOGIC ----------

class PokerGame:
    def __init__(self):
        self.env = None
        self.game_log = []
        self.current_agent = None
        self.game_over = False
        self.pot_size = 0
        self.player_chips = {}
        self.round_count = 0
        self.final_rewards = {}
        # Track the latest 52-bit card vectors per player to infer community
        self.last_cards_52 = {'player_0': None, 'player_1': None}
        # Persist the final revealed board so it shows after the game ends
        self.revealed_board_indices = None
        self.reset_game()

    def reset_game(self):
        """Reset the poker game"""
        self.env = texas_holdem_no_limit_v6.env()
        self.env.reset()
        self.round_count += 1
        self.game_log = [f"🎰 **Round {self.round_count}** started!"]
        self.current_agent = None
        self.game_over = False
        self.pot_size = 0
        self.player_chips = {'player_0': 200, 'player_1': 200}  # Starting chips
        self.final_rewards = {}
        self.last_cards_52 = {'player_0': None, 'player_1': None}
        self.revealed_board_indices = None
        return True

    def compute_board_indices(self):
        """Compute community cards as intersection of both players' 52-bit card vectors."""
        a = self.last_cards_52['player_0']
        b = self.last_cards_52['player_1']
        if a is None or b is None:
            return []
        # Intersection where both see the card as present -> community cards
        board_mask = np.logical_and(a, b)
        board = [i for i, v in enumerate(board_mask) if v]
        board.sort()
        return board

    def compute_hole_indices(self, player_key):
        """Compute player's hole cards as their cards minus the community."""
        vec = self.last_cards_52[player_key]
        if vec is None:
            return []
        board = set(self.compute_board_indices())
        mine = [i for i, v in enumerate(vec) if v and i not in board]
        return mine[:2]

    def get_player_cards(self, obs):
        """Extract and format player's hole cards (ASCII)."""
        holes = self.compute_hole_indices('player_0')
        if len(holes) == 2:
            return render_cards_ascii(holes, pad_to=2)
        # Fallback to current obs if cache not ready
        if obs is not None:
            cards = obs['observation'][:52]
            player_cards = [i for i, val in enumerate(cards) if val == 1]
            if len(player_cards) >= 2:
                return render_cards_ascii(player_cards[:2], pad_to=2)
        return render_cards_ascii([], pad_to=2)

    def get_community_cards(self, obs=None):
        """Get community cards (ASCII) progressively; show 5 backs if unknown."""
        if self.revealed_board_indices is not None:
            return render_cards_ascii(self.revealed_board_indices, pad_to=5)
        board = self.compute_board_indices()
        if len(board) == 0:
            return render_cards_ascii([], pad_to=5)
        # Show currently available community cards; pad to 5 with backs
        return render_cards_ascii(board, pad_to=5)

    def get_pot_info(self, obs):
        """Extract pot and betting information (approx via commitment slice)."""
        if obs is None:
            return "💰 Pot: $0"
        commitment = obs['observation'][52:]
        total_pot = float(sum(commitment)) if len(commitment) else 0.0
        self.pot_size = total_pot
        return f"💰 Pot: ${total_pot:.0f}"

    def format_action(self, agent_name, action_idx):
        """Format action for display"""
        action_name = action_map.get(action_idx, f"Action {action_idx}")
        return f"👤 **You**: {action_name}" if agent_name == "player_0" else f"🤖 **AI**: {action_name}"

    def determine_winner(self):
        """Determine the winner based on final rewards"""
        if not self.final_rewards:
            return "🤷 No winner determined"
        player_reward = self.final_rewards.get('player_0', 0)
        ai_reward = self.final_rewards.get('player_1', 0)
        if player_reward > ai_reward:
            return f"🎉 **YOU WON!** (+${player_reward:.0f})"
        elif ai_reward > player_reward:
            return f"💔 **AI WON!** (You lost ${abs(player_reward):.0f})"
        else:
            return "🤝 **TIE GAME!**"

    def step_game(self, human_action=None):
        """Execute one step of the game"""
        if self.game_over or self.env is None:
            return self.get_game_state()

        try:
            current_obs = None
            human_played = False
            actions_this_step = []

            for agent_name in self.env.agent_iter():
                obs, reward, terminated, truncated, info = self.env.last()
                done = terminated or truncated

                # Cache latest 52-bit card vector for this agent
                if obs is not None and 'observation' in obs and len(obs['observation']) >= 52:
                    self.last_cards_52[agent_name] = np.array(obs['observation'][:52], dtype=bool)

                if done:
                    # Store final rewards
                    self.final_rewards[agent_name] = reward
                    self.env.step(None)

                    # If both players are done, finalize game
                    if all(self.env.terminations.values()) or all(self.env.truncations.values()):
                        self.game_over = True
                        # Persist final community cards
                        final_board = self.compute_board_indices()
                        self.revealed_board_indices = final_board

                        # Log reveal
                        board_str = render_cards_ascii(final_board, pad_to=5)
                        self.game_log.append("🃏 **Final Community Cards (Flop • Turn • River)**")
                        self.game_log.append(board_str)

                        winner_msg = self.determine_winner()
                        self.game_log.append("=" * 40)
                        self.game_log.append("🏁 **GAME OVER**")
                        self.game_log.append(winner_msg)
                        self.game_log.append("=" * 40)
                        break
                    continue

                # Store current observation for display
                if agent_name == 'player_0':
                    current_obs = obs

                legal_actions = np.where(obs['action_mask'])[0].tolist()

                # Human player (player_0)
                if agent_name == 'player_0' and human_action is not None and not human_played:
                    if human_action in legal_actions:
                        action = human_action
                        human_played = True
                    else:
                        action = legal_actions[0] if legal_actions else 0
                        self.game_log.append("⚠️ **Invalid action!** Using default instead.")
                else:
                    action = self.simple_ai_strategy(obs, legal_actions)

                # Log the action
                actions_this_step.append(self.format_action(agent_name, action))

                # Execute action
                self.env.step(action)

                if human_played:
                    break

            # Add all actions from this step to the log
            if actions_this_step:
                self.game_log.extend(actions_this_step)
                self.game_log.append("─" * 30)

            return self.get_game_state(current_obs)

        except Exception as e:
            self.game_log.append(f"❌ **Error**: {str(e)}")
            return self.get_game_state()

    def simple_ai_strategy(self, obs, legal_actions):
        """Simple AI strategy for demo purposes"""
        if not legal_actions:
            return 0
        if 0 in legal_actions and random.random() < 0.1:
            return 0
        if 1 in legal_actions and random.random() < 0.6:
            return 1
        remaining = [a for a in legal_actions if a != 0]
        return random.choice(remaining) if remaining else legal_actions[0]

    def get_game_state(self, obs=None):
        """Get current game state for display"""
        display_log = ["..."] + self.game_log[-14:] if len(self.game_log) > 15 else self.game_log
        log_display = "\n".join(display_log)

        player_cards_ascii = self.get_player_cards(obs)
        community_cards_ascii = self.get_community_cards(obs)

        pot_info = self.get_pot_info(obs) if obs else "💰 Pot: $0"

        # Chip counts (roughly track commitments)
        player_current = self.player_chips['player_0']
        ai_current = self.player_chips['player_1']
        if obs is not None:
            commitment = obs['observation'][52:]
            if len(commitment) >= 2:
                player_current -= float(commitment[0])
                ai_current -= float(commitment[1])
        chip_info = f"💰 **Your Chips**: ${player_current:.0f} | **AI Chips**: ${ai_current:.0f}"

        if obs and not self.game_over:
            legal = np.where(obs['action_mask'])[0].tolist()
            available_actions = [action_map[i] for i in legal]
        else:
            available_actions = list(action_map.values())

        status = "🏁 Game Over - Click 'New Game' to play again" if self.game_over else "🎯 Your turn - Choose an action"

        return (log_display, player_cards_ascii, community_cards_ascii,
                f"{pot_info}\n{chip_info}", available_actions, status)

# Initialize game
game = PokerGame()

def play_action(action_choice):
    """Handle human player action"""
    if action_choice is None or game.game_over:
        return game.get_game_state()
    # Convert action name to index
    action_idx = None
    for idx, name in action_map.items():
        if name == action_choice:
            action_idx = idx
            break
    if action_idx is not None:
        return game.step_game(action_idx)
    return game.get_game_state()

def auto_play():
    """Auto-play one round (AI vs AI)"""
    if game.game_over:
        return game.get_game_state()
    return game.step_game()

def reset_game():
    """Reset the game"""
    game.reset_game()
    return game.get_game_state()

# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="🎰 Poker AI Game",
    css="""
    .game-container {
        background: linear-gradient(135deg, #0f4c3a, #2d5a27);
        border-radius: 15px;
        padding: 20px;
        margin: 10px;
    }
    .card-display {
        background: #0b0e12;
        color: #e7ecef;
        border-radius: 10px;
        padding: 12px;
        font-family: 'Courier New', monospace;
        font-size: 14px;
        text-align: left;
        border: 2px solid #4CAF50;
        white-space: pre;           /* preserve ASCII layout */
        line-height: 1.05;
    }
    .pot-display {
        background: linear-gradient(135deg, #8b4513, #a0522d);
        color: #ffd700;
        border-radius: 10px;
        padding: 15px;
        font-weight: bold;
        text-align: center;
        border: 2px solid #ffd700;
    }
    .log-display {
        background: #f8f9fa;
        border-radius: 10px;
        padding: 15px;
        font-family: 'Arial', sans-serif;
        line-height: 1.6;
        border: 1px solid #dee2e6;
        max-height: 300px;
        overflow-y: auto;
        white-space: pre-wrap;
    }
    .status-display {
        background: linear-gradient(135deg, #007bff, #0056b3);
        color: white;
        border-radius: 10px;
        padding: 10px;
        font-weight: bold;
        text-align: center;
        border: 2px solid #0056b3;
    }
    .action-btn {
        background: #4CAF50;
        color: white;
        border-radius: 8px;
        padding: 12px 24px;
        font-weight: bold;
        margin: 5px;
    }
    """) as demo:

    gr.Markdown("""
    # 🎰 Texas Hold'em Poker AI

    **Play against an AI opponent in Texas Hold'em poker!**

    Choose your action and watch the cards unfold. The AI will respond automatically.
    Winners and losers are clearly displayed when the game ends!
    """)

    # Game status
    status_display = gr.Textbox(
        label="🎮 Game Status",
        value="🎯 Your turn - Choose an action",
        interactive=False,
        elem_classes=["status-display"]
    )

    with gr.Row():
        # Left column - Game controls
        with gr.Column(scale=1):
            gr.Markdown("### 🎮 Game Controls")

            action_dropdown = gr.Dropdown(
                choices=list(action_map.values()),
                label="Choose Your Action",
                value="Check/Call"
            )

            with gr.Row():
                play_btn = gr.Button("🎯 Play Action", variant="primary", elem_classes=["action-btn"])
                auto_btn = gr.Button("🤖 Auto Play", variant="secondary")

            reset_btn = gr.Button("🔄 New Game", variant="stop", size="lg")

        # Right column - Game display
        with gr.Column(scale=2):
            # Community cards
            community_display = gr.Textbox(
                label="🛠️ Community Cards (Flop, Turn, River)",
                value=render_cards_ascii([], pad_to=5),
                interactive=False,
                lines=7,                   # taller to fit ASCII cards
                elem_classes=["card-display"]
            )

            # Pot and chip info
            pot_display = gr.Textbox(
                label="💰 Pot & Chip Status",
                value="💰 Pot: $0\n💰 Your Chips: $200 | AI Chips: $200",
                interactive=False,
                lines=2,
                elem_classes=["pot-display"]
            )

            # Player cards
            player_cards = gr.Textbox(
                label="🃏 Your Hole Cards",
                value=render_cards_ascii([], pad_to=2),
                interactive=False,
                lines=7,                   # taller to fit ASCII cards
                elem_classes=["card-display"]
            )

    # Game log with better styling
    game_log = gr.Textbox(
        label="📋 Game Action Log",
        value="🎰 New poker game started!",
        lines=10,
        max_lines=12,
        interactive=False,
        elem_classes=["log-display"]
    )

    # Available actions display
    available_actions = gr.Textbox(
        label="⚡ Available Actions",
        value=", ".join(action_map.values()),
        interactive=False
    )

    # Event handlers
    play_btn.click(
        fn=play_action,
        inputs=[action_dropdown],
        outputs=[game_log, player_cards, community_display, pot_display, available_actions, status_display]
    )

    auto_btn.click(
        fn=auto_play,
        outputs=[game_log, player_cards, community_display, pot_display, available_actions, status_display]
    )

    reset_btn.click(
        fn=reset_game,
        outputs=[game_log, player_cards, community_display, pot_display, available_actions, status_display]
    )

    # Initialize display
    demo.load(
        fn=lambda: game.get_game_state(),
        outputs=[game_log, player_cards, community_display, pot_display, available_actions, status_display]
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch(share=True, debug=True)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://020639a51116b1a5e7.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7862 <> https://020639a51116b1a5e7.gradio.live
