# Poker Agent V6: Adaptive Opponent Modeling

This notebook implements the V6 architecture designed to fix the lack of adaptability in previous versions. 

### Key Improvements:
1. **Dual-Branch Network**: 
   - **MLP Branch**: Analyses cards and board (static state).
   - **LSTM Branch**: Analyses the **sequence of actions** in the hand. This is the "memory" that allows the agent to distinguish between a Maniac (constant raises) and a Nit (passive play).
2. **Explicit History Tracking**: The environment now returns a sequence of the last 20 actions in the current hand.
3. **Adaptive Training**: The agent trains against a mixed pool of opponents (Maniac, Nit, Random) to learn context-specific strategies.

### Architecture Diagram
```
State (Cards, Pot) --> MLP --> Feat1 --\
                                        (+)--> FC -> Q-Values
Action History     --> LSTM -> Feat2 --/
```

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from collections import deque
import random
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List, Dict, Any

from pokerkit import Automation, NoLimitTexasHoldem, Card

# Constants
SEED = 42
MAX_HISTORY_LEN = 20
ACTION_EMBED_DIM = 8
HIDDEN_DIM_LSTM = 64
HIDDEN_DIM_MLP = 128

# Action Constants
ENV_FOLD = 0
ENV_CHECK_CALL = 1
ENV_BET_RAISE = 2
NUM_ACTIONS = 3

# History Tokens
ACT_PAD = 0
ACT_V_FOLD = 1
ACT_V_CHECK_CALL = 2
ACT_V_BET_RAISE = 3
OPP_FOLD = 4
OPP_CHECK_CALL = 5
OPP_BET_RAISE = 6
HISTORY_VOCAB_SIZE = 7

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
class PokerKitGymEnvV6(gym.Env):
    """
    Gymnasium wrapper for PokerKit's No-Limit Texas Hold'em.
    Returns 'action_history' in observation to allow opponent modeling.
    """
    
    def __init__(self, num_players: int = 2, starting_stack: int = 1000, 
                 small_blind: int = 5, big_blind: int = 10):
        super().__init__()
        self.num_players = num_players
        self.starting_stack = starting_stack
        self.small_blind = small_blind
        self.big_blind = big_blind
        
        self.game_state_dim = 52*2 + 52*5 + num_players + 1 + 1 + 4 + 1 
        
        self.observation_space = spaces.Dict({
            'game_state': spaces.Box(low=0, high=1, shape=(self.game_state_dim,), dtype=np.float32),
            'history': spaces.Box(low=0, high=HISTORY_VOCAB_SIZE-1, shape=(MAX_HISTORY_LEN,), dtype=np.int64)
        })
        self.action_space = spaces.Discrete(NUM_ACTIONS)
        self.state = None
        self.agent_player_index = 0
        self.action_history = deque(maxlen=MAX_HISTORY_LEN)
        
    def _card_to_index(self, card: Card) -> int:
        ranks = '23456789TJQKA'
        suits = 'cdhs'
        rank_idx = ranks.index(card.rank)
        suit_idx = suits.index(card.suit)
        return rank_idx * 4 + suit_idx
    
    def _encode_card(self, card: Optional[Card]) -> np.ndarray:
        encoding = np.zeros(52, dtype=np.float32)
        if card is not None:
            encoding[self._card_to_index(card)] = 1.0
        return encoding
    
    def _flatten_cards(self, cards) -> List:
        flat = []
        for item in cards:
            if hasattr(item, 'rank'):
                flat.append(item)
            else:
                flat.extend(self._flatten_cards(item))
        return flat
    
    def _get_observation(self) -> Dict[str, Any]:
        state_vector = []
        hole_cards = self._flatten_cards(self.state.hole_cards[self.agent_player_index])
        for i in range(2):
            if i < len(hole_cards):
                state_vector.extend(self._encode_card(hole_cards[i]))
            else:
                state_vector.extend(np.zeros(52, dtype=np.float32))
        
        board_cards = self._flatten_cards(self.state.board_cards)
        for i in range(5):
            if i < len(board_cards):
                state_vector.extend(self._encode_card(board_cards[i]))
            else:
                state_vector.extend(np.zeros(52, dtype=np.float32))
        
        for i in range(self.num_players):
            stack = self.state.stacks[i] / self.starting_stack
            state_vector.append(min(stack, 2.0))
        
        total_pot = sum(self.state.bets)
        state_vector.append(total_pot / (self.starting_stack * self.num_players))
        
        if self.state.actor_index is not None:
            state_vector.append(self.state.actor_index / max(1, self.num_players - 1))
        else:
            state_vector.append(0.0)
        
        street = [0.0, 0.0, 0.0, 0.0]
        num_board = len(board_cards)
        if num_board == 0: street[0] = 1.0
        elif num_board == 3: street[1] = 1.0
        elif num_board == 4: street[2] = 1.0
        else: street[3] = 1.0
        state_vector.extend(street)
        state_vector.append(float(self.agent_player_index))

        history_seq = list(self.action_history)
        pad_len = MAX_HISTORY_LEN - len(history_seq)
        history_padded = [ACT_PAD] * pad_len + history_seq
        
        return {
            'game_state': np.array(state_vector, dtype=np.float32),
            'history': np.array(history_padded, dtype=np.int64)
        }
    
    def _update_history(self, player_idx: int, action: int):
        if player_idx == self.agent_player_index:
            if action == ENV_FOLD: token = ACT_V_FOLD
            elif action == ENV_CHECK_CALL: token = ACT_V_CHECK_CALL
            else: token = ACT_V_BET_RAISE
        else:
            if action == ENV_FOLD: token = OPP_FOLD
            elif action == ENV_CHECK_CALL: token = OPP_CHECK_CALL
            else: token = OPP_BET_RAISE
        self.action_history.append(token)

    def _get_legal_actions(self) -> List[int]:
        legal = []
        if self.state.can_fold(): legal.append(ENV_FOLD)
        if self.state.can_check_or_call(): legal.append(ENV_CHECK_CALL)
        if self.state.can_complete_bet_or_raise_to(): legal.append(ENV_BET_RAISE)
        return legal if legal else [ENV_CHECK_CALL]
    
    def _execute_action(self, action: int) -> None:
        if action == ENV_FOLD:
            if self.state.can_fold(): self.state.fold()
            elif self.state.can_check_or_call(): self.state.check_or_call()
        elif action == ENV_CHECK_CALL:
            if self.state.can_check_or_call(): self.state.check_or_call()
            elif self.state.can_fold(): self.state.fold()
        elif action == ENV_BET_RAISE:
            if self.state.can_complete_bet_or_raise_to():
                min_r = self.state.min_completion_betting_or_raising_to_amount
                max_r = self.state.max_completion_betting_or_raising_to_amount
                self.state.complete_bet_or_raise_to(min(min_r * 2, max_r))
            elif self.state.can_check_or_call():
                self.state.check_or_call()
    
    def _run_automations(self) -> None:
        while self.state.can_burn_card(): self.state.burn_card('??')
        while self.state.can_deal_board(): self.state.deal_board()
        while self.state.can_push_chips(): self.state.push_chips()
        while self.state.can_pull_chips(): self.state.pull_chips()
    
    def reset(self, seed=None, options=None) -> Tuple[Dict, Dict]:
        super().reset(seed=seed)
        self.action_history.clear()
        self.state = NoLimitTexasHoldem.create_state(
            automations=(Automation.ANTE_POSTING, Automation.BET_COLLECTION, Automation.BLIND_OR_STRADDLE_POSTING, Automation.HOLE_CARDS_SHOWING_OR_MUCKING, Automation.HAND_KILLING, Automation.CHIPS_PUSHING, Automation.CHIPS_PULLING),
            ante_trimming_status=True,
            raw_antes={-1: 0},
            raw_blinds_or_straddles=(self.small_blind, self.big_blind),
            min_bet=self.big_blind,
            raw_starting_stacks=[self.starting_stack] * self.num_players,
            player_count=self.num_players,
        )
        while self.state.can_deal_hole(): self.state.deal_hole()
        self._run_automations()
        return self._get_observation(), {'legal_actions': self._get_legal_actions()}
    
    def step(self, action: int) -> Tuple[Dict, float, bool, bool, Dict]:
        if self.state.actor_index is not None:
             self._update_history(self.state.actor_index, action)
        self._execute_action(action)
        self._run_automations()
        done = self.state.status is False
        reward = 0.0
        if done:
            reward = (self.state.stacks[self.agent_player_index] - self.starting_stack) / self.big_blind
        obs = self._get_observation()
        info = {'legal_actions': self._get_legal_actions() if not done else []}
        return obs, reward, done, False, info

    def get_final_reward(self) -> float:
        return (self.state.stacks[self.agent_player_index] - self.starting_stack) / self.big_blind
    
    def update_opponent_history(self, action: int):
        opp_idx = 1 - self.agent_player_index
        self._update_history(opp_idx, action)

In [4]:
class DualBranchDRQN(nn.Module):
    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.state_net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.LayerNorm(256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU()
        )
        self.action_embedding = nn.Embedding(HISTORY_VOCAB_SIZE, ACTION_EMBED_DIM)
        self.lstm = nn.LSTM(input_size=ACTION_EMBED_DIM, hidden_size=HIDDEN_DIM_LSTM, batch_first=True)
        self.value_head = nn.Sequential(
            nn.Linear(128 + HIDDEN_DIM_LSTM, 128), nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        
    def forward(self, state, history):
        s_feat = self.state_net(state)
        h_embed = self.action_embedding(history)
        lstm_out, (hn, cn) = self.lstm(h_embed)
        h_context = hn[-1]
        combined = torch.cat([s_feat, h_context], dim=1)
        return self.value_head(combined)

class ReplayBufferV6:
    def __init__(self, capacity=50000):
        self.buffer = deque(maxlen=capacity)
    def push(self, transition):
        self.buffer.append(transition)
    def sample(self, batch_size):
        return random.sample(self.buffer, min(len(self.buffer), batch_size))
    def __len__(self): return len(self.buffer)

class AdaptiveAgent:
    def __init__(self, state_dim, action_dim=NUM_ACTIONS, lr=1e-4):
        self.model = DualBranchDRQN(state_dim, action_dim).to(device)
        self.target_model = DualBranchDRQN(state_dim, action_dim).to(device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.99995
        
    def select_action(self, obs, legal_actions, eval_mode=False):
        if not eval_mode and random.random() < self.epsilon:
            return random.choice(legal_actions)
        state_t = torch.FloatTensor(obs['game_state']).unsqueeze(0).to(device)
        h_t = torch.LongTensor(obs['history']).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.model(state_t, h_t)
        q_numpy = q_values.cpu().numpy().flatten()
        masked_q = np.full(NUM_ACTIONS, -np.inf)
        for a in legal_actions: masked_q[a] = q_numpy[a]
        return int(np.argmax(masked_q))

    def train(self, buffer, batch_size=64):
        if len(buffer) < batch_size: return None
        batch = buffer.sample(batch_size)
        states = torch.FloatTensor(np.array([t[0] for t in batch])).to(device)
        histories = torch.LongTensor(np.array([t[1] for t in batch])).to(device)
        actions = torch.LongTensor(np.array([t[2] for t in batch])).to(device)
        rewards = torch.FloatTensor(np.array([t[3] for t in batch])).to(device)
        next_states = torch.FloatTensor(np.array([t[4] for t in batch])).to(device)
        next_histories = torch.LongTensor(np.array([t[5] for t in batch])).to(device)
        dones = torch.FloatTensor(np.array([t[6] for t in batch])).to(device)
        
        current_q = self.model(states, histories).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_actions = self.model(next_states, next_histories).argmax(1).unsqueeze(1)
            target_q_next = self.target_model(next_states, next_histories).gather(1, next_actions).squeeze(1)
            target = rewards + (1 - dones) * self.gamma * target_q_next
        loss = F.mse_loss(current_q, target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()
    
    def update_target(self): self.target_model.load_state_dict(self.model.state_dict())

In [5]:
class ManiacAgent:
    def select_action(self, legal_actions):
        if ENV_BET_RAISE in legal_actions: return ENV_BET_RAISE
        if ENV_CHECK_CALL in legal_actions: return ENV_CHECK_CALL
        return ENV_FOLD
class NitAgent:
    def select_action(self, legal_actions):
        # Fold to aggression 90% of time
        if ENV_FOLD in legal_actions and ENV_CHECK_CALL in legal_actions:
            if random.random() < 0.9: return ENV_FOLD
        if ENV_CHECK_CALL in legal_actions: return ENV_CHECK_CALL
        return ENV_FOLD
class RandomAgent:
    def select_action(self, legal_actions): return random.choice(legal_actions)

In [6]:
def train_v6(num_hands=25000):
    env = PokerKitGymEnvV6()
    agent = AdaptiveAgent(env.game_state_dim)
    buffer = ReplayBufferV6(capacity=50000)
    opponents = [ManiacAgent(), NitAgent(), RandomAgent()]
    opp_names = ['Maniac', 'Nit', 'Random']
    
    rewards_history = []
    
    print(f"Training V6 for {num_hands} hands against MIXED opponents...")
    
    for hand in range(num_hands):
        opp_idx = random.randint(0, 2)
        opponent = opponents[opp_idx]
        
        obs, info = env.reset()
        done = False
        episode_transitions = []
        pending_agent_obs = None
        pending_agent_action = None
        
        while not done:
            if env.state.actor_index == env.agent_player_index:
                if pending_agent_obs is not None:
                    episode_transitions.append((pending_agent_obs['game_state'], pending_agent_obs['history'], pending_agent_action, 0.0, obs['game_state'], obs['history'], False, info['legal_actions']))
                
                action = agent.select_action(obs, info['legal_actions'])
                pending_agent_obs = obs
                pending_agent_action = action
                obs, reward, done, _, info = env.step(action)
                
                if done:
                    episode_transitions.append((pending_agent_obs['game_state'], pending_agent_obs['history'], pending_agent_action, 0.0, obs['game_state'], obs['history'], True, []))
            else:
                action = opponent.select_action(info['legal_actions'])
                env.update_opponent_history(action)
                env._execute_action(action)
                env._run_automations()
                done = env.state.status is False
                if done and pending_agent_obs is not None:
                    term_obs = env._get_observation()
                    episode_transitions.append((pending_agent_obs['game_state'], pending_agent_obs['history'], pending_agent_action, 0.0, term_obs['game_state'], term_obs['history'], True, []))
                elif not done:
                    obs = env._get_observation()
                    info['legal_actions'] = env._get_legal_actions()
        
        final_reward = env.get_final_reward()
        rewards_history.append(final_reward)
        for i in range(len(episode_transitions)):
            s, h, a, r, ns, nh, d, l = episode_transitions[i]
            buffer.push((s, h, a, final_reward, ns, nh, d, l))
            
        if len(buffer) > 1000:
            agent.train(buffer)
        agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
        if hand % 500 == 0: agent.update_target()
        if hand % 2500 == 0: print(f"Hand {hand} | Avg Reward: {np.mean(rewards_history[-100:]):.2f}")
            
    return agent

def evaluate(agent, num_hands=500):
    print("\n--- EVALUATION ---")
    env = PokerKitGymEnvV6()
    opps = {'Maniac': ManiacAgent(), 'Nit': NitAgent(), 'Random': RandomAgent()}
    for name, opponent in opps.items():
        total = 0
        for _ in range(num_hands):
            obs, info = env.reset()
            done = False
            while not done:
                if env.state.actor_index == env.agent_player_index:
                    action = agent.select_action(obs, info['legal_actions'], eval_mode=True)
                    obs, _, done, _, info = env.step(action)
                else:
                    action = opponent.select_action(info['legal_actions'])
                    env.update_opponent_history(action)
                    env._execute_action(action)
                    env._run_automations()
                    if env.state.status is False: done = True
                    else: 
                        obs = env._get_observation()
                        info['legal_actions'] = env._get_legal_actions()
            total += env.get_final_reward()
        print(f"Vs {name}: {total/num_hands:.2f} BB/hand")

agent = train_v6(30000)
evaluate(agent)

Training V6 for 30000 hands against MIXED opponents...
Hand 0 | Avg Reward: 14.00
Hand 2500 | Avg Reward: -1.47
Hand 5000 | Avg Reward: -2.10
Hand 7500 | Avg Reward: 0.56
Hand 10000 | Avg Reward: 1.65
Hand 12500 | Avg Reward: -2.98
Hand 15000 | Avg Reward: -1.67
Hand 17500 | Avg Reward: -1.30
Hand 20000 | Avg Reward: -2.08
Hand 22500 | Avg Reward: -3.95
Hand 25000 | Avg Reward: -3.31
Hand 27500 | Avg Reward: -2.42

--- EVALUATION ---
Vs Maniac: -1.27 BB/hand
Vs Nit: 0.53 BB/hand
Vs Random: -0.09 BB/hand
