# Poker Agent V10: Context-Aware Memory & Targeted Training

This notebook refines the V9 Equity model by introducing **Context Switching** and **Dynamic Parameters** to solve the "Learning Plateau" problem.

### The Problem with V9
V9 reached break-even (-0.75 BB) but plateaued. Why?
1. **Memory Contamination**: When switching from Maniac to Nit, the LSTM still had "Maniac Actions" in its history buffer for the first ~15 hands. This confused the agent.
2. **Noise**: Training against extreme bots (Maniac/Nit) distracted the network from learning nuanced strategy against realistic bots.

### V10 Solution: Opponent-Specific History
We now simulate a "Known Player Database". The environment maintains **separate history buffers** for each opponent ID.
- When we switch to `ValueBot`, we load the strict `ValueBot` history.
- When we switch to `BluffBot`, we load the `BluffBot` history.

This gives the LSTM a **clean, continuous signal** for every opponent, maximizing adaptation speed.


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from collections import deque
import random
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List, Dict, Any
import copy
from itertools import combinations

from pokerkit import Automation, NoLimitTexasHoldem, Card, StandardHighHand, Deck

# Constants
SEED = 42
MAX_HISTORY_LEN = 100
ACTION_EMBED_DIM = 16
HIDDEN_DIM_LSTM = 128

# Actions
ENV_FOLD = 0
ENV_CHECK_CALL = 1
ENV_BET_RAISE = 2
NUM_ACTIONS = 3

# History Tokens
ACT_PAD = 0
ACT_V_FOLD = 1
ACT_V_CHECK_CALL = 2
ACT_V_BET_RAISE = 3
OPP_FOLD = 4
OPP_CHECK_CALL = 5
OPP_BET_RAISE = 6
OUT_AGENT_WIN = 7
OUT_AGENT_LOSS = 8
OUT_TIE = 9
OUT_NEW_HAND = 10
HISTORY_VOCAB_SIZE = 11

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [9]:
# Universal Helpers 
def flatten_cards_list(items):
    out = []
    if isinstance(items, Card): return [items]
    for x in items:
        if isinstance(x, (list, tuple)): out.extend(flatten_cards_list(x))
        else: out.append(x)
    return out

def monte_carlo_equity(hole_cards: List[Card], board_cards: List[Card], iterations=30) -> float:
    if not hole_cards: return 0.5 
    wins = 0
    hole_cards = flatten_cards_list(hole_cards)
    board_cards = flatten_cards_list(board_cards)
    known_cards = set(hole_cards + board_cards)
    
    for _ in range(iterations):
        deck_cards = [c for c in Deck.STANDARD if c not in known_cards]
        random.shuffle(deck_cards)
        opp_hole = deck_cards[:2]
        needed_board = 5 - len(board_cards)
        sim_board = board_cards + deck_cards[2:2+needed_board]
        my_total = hole_cards + sim_board
        opp_total = opp_hole + sim_board
        my_hand = max(StandardHighHand(c) for c in combinations(my_total, 5))
        opp_hand = max(StandardHighHand(c) for c in combinations(opp_total, 5))
        if my_hand > opp_hand: wins += 1
        elif my_hand == opp_hand: wins += 0.5
    return wins / iterations

In [10]:
# Reusing V9 Env Logic but adding Memory Swapping
class ContextAwarePokerEnv(gym.Env):
    def __init__(self, num_players: int = 2, starting_stack: int = 1000, 
                 small_blind: int = 5, big_blind: int = 10):
        super().__init__()
        self.num_players = num_players
        self.starting_stack = starting_stack
        self.small_blind = small_blind
        self.big_blind = big_blind
        self.base_state_dim = 52*2 + 52*5 + num_players + 1 + 1 + 4 + 1
        self.game_state_dim = self.base_state_dim + 3 
        self.observation_space = spaces.Dict({
            'game_state': spaces.Box(low=0, high=1, shape=(self.game_state_dim,), dtype=np.float32),
            'history': spaces.Box(low=0, high=HISTORY_VOCAB_SIZE-1, shape=(MAX_HISTORY_LEN,), dtype=np.int64)
        })
        self.action_space = spaces.Discrete(NUM_ACTIONS)
        self.state = None
        self.agent_player_index = 0
        self.global_history = deque(maxlen=MAX_HISTORY_LEN)
        for _ in range(MAX_HISTORY_LEN): self.global_history.append(ACT_PAD)
        
        # --- V10: History Banks ---
        self.history_bank = {}
        self.current_opp_id = "default"
        
    def load_opponent_history(self, opp_id: str):
        # Save current history to previous opp bank
        if self.current_opp_id:
            self.history_bank[self.current_opp_id] = copy.deepcopy(self.global_history)
            
        # Load or Init new history
        self.current_opp_id = opp_id
        if opp_id in self.history_bank:
            self.global_history = copy.deepcopy(self.history_bank[opp_id])
        else:
            self.global_history = deque(maxlen=MAX_HISTORY_LEN)
            for _ in range(MAX_HISTORY_LEN): self.global_history.append(ACT_PAD)
            
    def _card_to_index(self, card: Card) -> int:
        ranks = '23456789TJQKA'; suits = 'cdhs'
        return ranks.index(card.rank) * 4 + suits.index(card.suit)
    
    def _encode_card(self, card: Optional[Card]) -> np.ndarray:
        encoding = np.zeros(52, dtype=np.float32)
        if card is not None: encoding[self._card_to_index(card)] = 1.0
        return encoding
    
    def _get_observation(self) -> Dict[str, Any]:
        hole = flatten_cards_list(self.state.hole_cards[self.agent_player_index])
        board = flatten_cards_list(self.state.board_cards)
        equity = monte_carlo_equity(hole, board, iterations=25)
        total_pot = sum(self.state.bets)
        current_bet = max(self.state.bets)
        my_bet = self.state.bets[self.agent_player_index]
        to_call = current_bet - my_bet
        pot_odds = 0.0
        if (total_pot + to_call) > 0: pot_odds = to_call / (total_pot + to_call)
        spr = 0.0
        if total_pot > 0: spr = min((self.state.stacks[self.agent_player_index] / total_pot) / 20.0, 1.0)
        
        state_vector = [equity, pot_odds, spr]
        for i in range(2):
            if i < len(hole): state_vector.extend(self._encode_card(hole[i]))
            else: state_vector.extend(np.zeros(52, dtype=np.float32))
        for i in range(5):
            if i < len(board): state_vector.extend(self._encode_card(board[i]))
            else: state_vector.extend(np.zeros(52, dtype=np.float32))
        for i in range(self.num_players):
            state_vector.append(min(self.state.stacks[i] / self.starting_stack, 2.0))
        state_vector.append(total_pot / (self.starting_stack * self.num_players))
        state_vector.append(self.state.actor_index / max(1, self.num_players - 1) if self.state.actor_index is not None else 0.0)
        street = [0.0]*4
        if len(board) == 0: street[0] = 1.0
        elif len(board) == 3: street[1] = 1.0
        elif len(board) == 4: street[2] = 1.0
        else: street[3] = 1.0
        state_vector.extend(street)
        state_vector.append(float(self.agent_player_index))
        return {'game_state': np.array(state_vector, dtype=np.float32), 'history': np.array(list(self.global_history), dtype=np.int64)}
    
    def _update_history(self, player_idx: int, action: int):
        if player_idx == self.agent_player_index:
            token = [ACT_V_FOLD, ACT_V_CHECK_CALL, ACT_V_BET_RAISE][action]
        else:
            token = [OPP_FOLD, OPP_CHECK_CALL, OPP_BET_RAISE][action]
        self.global_history.append(token)

    def append_outcome_token(self, final_reward: float):
        if final_reward > 0: self.global_history.append(OUT_AGENT_WIN)
        elif final_reward < 0: self.global_history.append(OUT_AGENT_LOSS)
        else: self.global_history.append(OUT_TIE)

    def _get_legal_actions(self) -> List[int]:
        legal = []
        if self.state.can_fold(): legal.append(ENV_FOLD)
        if self.state.can_check_or_call(): legal.append(ENV_CHECK_CALL)
        if self.state.can_complete_bet_or_raise_to(): legal.append(ENV_BET_RAISE)
        return legal if legal else [ENV_CHECK_CALL]
    
    def _execute_action(self, action: int) -> None:
        if action == ENV_FOLD:
            if self.state.can_fold(): self.state.fold()
            elif self.state.can_check_or_call(): self.state.check_or_call()
        elif action == ENV_CHECK_CALL:
            if self.state.can_check_or_call(): self.state.check_or_call()
            elif self.state.can_fold(): self.state.fold()
        elif action == ENV_BET_RAISE:
            if self.state.can_complete_bet_or_raise_to():
                min_r = self.state.min_completion_betting_or_raising_to_amount
                max_r = self.state.max_completion_betting_or_raising_to_amount
                self.state.complete_bet_or_raise_to(min(min_r * 2, max_r))
            elif self.state.can_check_or_call():
                self.state.check_or_call()
    
    def _run_automations(self) -> None:
        while self.state.can_burn_card(): self.state.burn_card('??')
        while self.state.can_deal_board(): self.state.deal_board()
        while self.state.can_push_chips(): self.state.push_chips()
        while self.state.can_pull_chips(): self.state.pull_chips()
    
    def reset(self, seed=None, options=None) -> Tuple[Dict, Dict]:
        self.global_history.append(OUT_NEW_HAND)
        self.state = NoLimitTexasHoldem.create_state(
            automations=(Automation.ANTE_POSTING, Automation.BET_COLLECTION, Automation.BLIND_OR_STRADDLE_POSTING, Automation.HOLE_CARDS_SHOWING_OR_MUCKING, Automation.HAND_KILLING, Automation.CHIPS_PUSHING, Automation.CHIPS_PULLING),
            ante_trimming_status=True, raw_antes={-1: 0}, raw_blinds_or_straddles=(self.small_blind, self.big_blind),
            min_bet=self.big_blind, raw_starting_stacks=[self.starting_stack] * self.num_players, player_count=self.num_players)
        while self.state.can_deal_hole(): self.state.deal_hole()
        self._run_automations()
        return self._get_observation(), {'legal_actions': self._get_legal_actions()}
    
    def step(self, action: int) -> Tuple[Dict, float, bool, bool, Dict]:
        if self.state.actor_index is not None: self._update_history(self.state.actor_index, action)
        self._execute_action(action)
        self._run_automations()
        done = self.state.status is False
        reward = (self.state.stacks[self.agent_player_index] - self.starting_stack) / self.big_blind if done else 0.0
        return self._get_observation(), reward, done, False, {'legal_actions': self._get_legal_actions() if not done else []}
    
    def get_final_reward(self) -> float:
        return (self.state.stacks[self.agent_player_index] - self.starting_stack) / self.big_blind
    
    def update_opponent_history(self, action: int):
        self._update_history(1 - self.agent_player_index, action)

In [11]:
class FeatureAwareV10(nn.Module):
    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.state_net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.LayerNorm(256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU()
        )
        self.action_embedding = nn.Embedding(HISTORY_VOCAB_SIZE, ACTION_EMBED_DIM)
        self.lstm = nn.LSTM(input_size=ACTION_EMBED_DIM, hidden_size=HIDDEN_DIM_LSTM, batch_first=True)
        self.value_head = nn.Sequential(
            nn.Linear(128 + HIDDEN_DIM_LSTM, 256), nn.ReLU(),
            nn.Linear(256, action_dim)
        )
    def forward(self, state, history):
        s_feat = self.state_net(state)
        h_embed = self.action_embedding(history)
        lstm_out, _ = self.lstm(h_embed)
        combined = torch.cat([s_feat, lstm_out[:, -1, :]], dim=1)
        return self.value_head(combined)

class ReplayBufferV10:
    def __init__(self, capacity=50000):
        self.buffer = deque(maxlen=capacity)
    def push(self, transition): self.buffer.append(transition)
    def sample(self, batch_size): return random.sample(self.buffer, min(len(self.buffer), batch_size))
    def __len__(self): return len(self.buffer)

class HybridAgentV10:
    def __init__(self, state_dim, action_dim=NUM_ACTIONS, lr=1e-4):
        self.model = FeatureAwareV10(state_dim, action_dim).to(device)
        self.target_model = FeatureAwareV10(state_dim, action_dim).to(device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma, self.epsilon, self.epsilon_min = 0.99, 1.0, 0.05
        
    def select_action(self, obs, legal_actions, eval_mode=False):
        if not eval_mode and random.random() < self.epsilon:
            return random.choice(legal_actions)
        state_t = torch.FloatTensor(obs['game_state']).unsqueeze(0).to(device)
        h_t = torch.LongTensor(obs['history']).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.model(state_t, h_t)
        q_numpy = q_values.cpu().numpy().flatten()
        masked_q = np.full(NUM_ACTIONS, -np.inf)
        for a in legal_actions: masked_q[a] = q_numpy[a]
        return int(np.argmax(masked_q))

    def train(self, buffer, batch_size=64):
        if len(buffer) < batch_size: return None
        batch = buffer.sample(batch_size)
        states = torch.FloatTensor(np.array([t[0] for t in batch])).to(device)
        histories = torch.LongTensor(np.array([t[1] for t in batch])).to(device)
        actions = torch.LongTensor(np.array([t[2] for t in batch])).to(device)
        rewards = torch.FloatTensor(np.array([t[3] for t in batch])).to(device)
        next_states = torch.FloatTensor(np.array([t[4] for t in batch])).to(device)
        next_histories = torch.LongTensor(np.array([t[5] for t in batch])).to(device)
        dones = torch.FloatTensor(np.array([t[6] for t in batch])).to(device)
        current_q = self.model(states, histories).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_actions = self.model(next_states, next_histories).argmax(1).unsqueeze(1)
            target_q_next = self.target_model(next_states, next_histories).gather(1, next_actions).squeeze(1)
            target = rewards + (1 - dones) * self.gamma * target_q_next
        loss = F.mse_loss(current_q, target)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()
    def update_target(self): self.target_model.load_state_dict(self.model.state_dict())

In [12]:
# Realistic Bots (Value, Bluff, Balanced) ONLY
class HeuristicBot:
    def __init__(self, player_idx=1): self.player_idx = player_idx
    def get_equity(self, state):
        return monte_carlo_equity(flatten_cards_list(state.hole_cards[self.player_idx]), flatten_cards_list(state.board_cards), iterations=40)

class ValueBot(HeuristicBot):
    def select_action(self, state, legal_actions):
        equity = self.get_equity(state)
        if equity > 0.75 and ENV_BET_RAISE in legal_actions: return ENV_BET_RAISE
        if equity > 0.50 and ENV_CHECK_CALL in legal_actions: return ENV_CHECK_CALL
        if ENV_FOLD in legal_actions: return ENV_FOLD
        return ENV_CHECK_CALL

class BluffBot(HeuristicBot):
    def select_action(self, state, legal_actions):
        equity = self.get_equity(state)
        if equity > 0.70 and ENV_BET_RAISE in legal_actions: return ENV_BET_RAISE
        if equity < 0.40 and random.random() < 0.35 and ENV_BET_RAISE in legal_actions: return ENV_BET_RAISE
        if equity > 0.45 and ENV_CHECK_CALL in legal_actions: return ENV_CHECK_CALL
        if ENV_FOLD in legal_actions: return ENV_FOLD
        return ENV_CHECK_CALL

class BalancedBot(HeuristicBot):
    def select_action(self, state, legal_actions):
        equity = self.get_equity(state)
        if equity > 0.8 and ENV_BET_RAISE in legal_actions: return ENV_BET_RAISE
        if ENV_CHECK_CALL in legal_actions:
            if equity > 0.6: return ENV_CHECK_CALL
            to_call = max(state.bets) - state.bets[self.player_idx]
            pot_odds = to_call / (sum(state.bets) + to_call + 1e-5)
            if equity > pot_odds + 0.05: return ENV_CHECK_CALL # +5% edge required
        if ENV_FOLD in legal_actions: return ENV_FOLD
        return ENV_CHECK_CALL

In [13]:
def train_v10(num_hands=25000):
    env = ContextAwarePokerEnv()
    agent = HybridAgentV10(env.game_state_dim)
    buffer = ReplayBufferV10(capacity=50000)
    
    # REFINED Opponent Pool (No Maniac/Nit)
    opps = {
        'ValueBot': ValueBot(),
        'BluffBot': BluffBot(),
        'Balanced': BalancedBot()
    }
    opp_names = list(opps.keys())
    stats = {name: {'rewards': [], 'wins': 0, 'hands': 0} for name in opp_names}
    
    print(f"Training V10 (Targeted) for {num_hands} hands...")
    current_opp_name = 'ValueBot'
    
    # Dynamic Epsilon Decay Calculation
    # Goal: Reach 0.05 at 80% of training (hand 20000)
    # Formula: 1.0 * (decay)^20000 = 0.05  => decay = 0.05^(1/20000)
    decay_steps = int(num_hands * 0.8)
    agent.epsilon_decay = (0.05) ** (1 / decay_steps)
    print(f"Dynamic Epsilon Decay Factor: {agent.epsilon_decay:.6f}")

    for hand in range(num_hands):
        if hand % 50 == 0: 
            current_opp_name = random.choice(opp_names)
            # CRITICAL: Switch History Context
            env.load_opponent_history(current_opp_name)
            
        opponent = opps[current_opp_name]
        
        obs, info = env.reset()
        done = False
        episode_transitions = []
        pending_agent_obs = None
        pending_agent_action = None
        
        while not done:
            if env.state.actor_index == env.agent_player_index:
                if pending_agent_obs is not None:
                    episode_transitions.append((pending_agent_obs['game_state'], pending_agent_obs['history'], pending_agent_action, 0.0, obs['game_state'], obs['history'], False, info['legal_actions']))
                action = agent.select_action(obs, info['legal_actions'])
                pending_agent_obs, pending_agent_action = obs, action
                obs, reward, done, _, info = env.step(action)
            else:
                action = opponent.select_action(env.state, info['legal_actions'])
                env.update_opponent_history(action)
                env._execute_action(action)
                env._run_automations()
                done = env.state.status is False
                if not done:
                    obs = env._get_observation()
                    info['legal_actions'] = env._get_legal_actions()
        
        final_reward = env.get_final_reward()
        env.append_outcome_token(final_reward)
        term_obs = env._get_observation()
        if pending_agent_obs is not None:
             episode_transitions.append((pending_agent_obs['game_state'], pending_agent_obs['history'], pending_agent_action, 0.0, term_obs['game_state'], term_obs['history'], True, []))
        
        stats[current_opp_name]['rewards'].append(final_reward)
        stats[current_opp_name]['hands'] += 1
        if final_reward > 0: stats[current_opp_name]['wins'] += 1
        
        for i in range(len(episode_transitions)):
            s, h, a, r, ns, nh, d, l = episode_transitions[i]
            buffer.push((s, h, a, final_reward, ns, nh, d, l))
            
        if len(buffer) > 1000: agent.train(buffer)
        agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
        if hand % 500 == 0: agent.update_target()
        
        if hand % 2500 == 0 and hand > 0:
            print(f"\n=== Checkpoint Hand {hand} (Eps: {agent.epsilon:.2f}) ===")
            for name in opp_names:
                 if stats[name]['hands'] > 0:
                     print(f"{name}: Avg {np.mean(stats[name]['rewards'][-200:]):.2f} BB | Win {stats[name]['wins']/stats[name]['hands']:.1%}")
    return agent, stats

def evaluate_agent_v10(agent):
    print("\n=== FINAL V10 EVALUATION (Epsilon=0.0) ===")
    agent.epsilon = 0.0
    env = ContextAwarePokerEnv()
    opps = {'ValueBot': ValueBot(), 'BluffBot': BluffBot(), 'Balanced': BalancedBot()}
    
    for name, opponent in opps.items():
        env.load_opponent_history(name) # Start fresh for this opp
        rewards, wins = [], 0
        for _ in range(2000):
            obs, info = env.reset()
            done = False
            while not done:
                if env.state.actor_index == env.agent_player_index:
                    action = agent.select_action(obs, info['legal_actions'])
                    obs, _, done, _, info = env.step(action)
                else:
                    action = opponent.select_action(env.state, info['legal_actions'])
                    env.update_opponent_history(action)
                    env._execute_action(action)
                    env._run_automations()
                    done = env.state.status is False
                    if not done: obs = env._get_observation(); info['legal_actions'] = env._get_legal_actions()
            final_reward = env.get_final_reward()
            env.append_outcome_token(final_reward)
            rewards.append(final_reward)
            if final_reward > 0: wins += 1
        print(f"Vs {name}: Avg {np.mean(rewards):.2f} BB | Total: {sum(rewards):.1f} BB | Win: {wins/2000:.1%}")

agent, stats = train_v10(25000)
evaluate_agent_v10(agent)

Training V10 (Targeted) for 25000 hands...
Dynamic Epsilon Decay Factor: 0.999850

=== Checkpoint Hand 2500 (Eps: 0.69) ===
ValueBot: Avg -0.14 BB | Win 82.2%
BluffBot: Avg -4.81 BB | Win 61.9%
Balanced: Avg -2.22 BB | Win 70.5%

=== Checkpoint Hand 5000 (Eps: 0.47) ===
ValueBot: Avg 0.52 BB | Win 79.3%
BluffBot: Avg -2.32 BB | Win 61.1%
Balanced: Avg -1.99 BB | Win 68.5%

=== Checkpoint Hand 7500 (Eps: 0.33) ===
ValueBot: Avg -0.49 BB | Win 79.3%
BluffBot: Avg -1.89 BB | Win 60.3%
Balanced: Avg -0.53 BB | Win 65.0%

=== Checkpoint Hand 10000 (Eps: 0.22) ===
ValueBot: Avg -0.03 BB | Win 77.9%
BluffBot: Avg -1.00 BB | Win 59.4%
Balanced: Avg -1.14 BB | Win 62.7%

=== Checkpoint Hand 12500 (Eps: 0.15) ===
ValueBot: Avg -0.00 BB | Win 77.3%
BluffBot: Avg -0.60 BB | Win 58.0%
Balanced: Avg -0.30 BB | Win 61.3%

=== Checkpoint Hand 15000 (Eps: 0.11) ===
ValueBot: Avg -0.59 BB | Win 76.6%
BluffBot: Avg -2.47 BB | Win 57.4%
Balanced: Avg 0.11 BB | Win 60.4%

=== Checkpoint Hand 17500 (Eps: 0.