# Poker Bot Evaluation

Evaluate `YiPz/Qwen3-4B-pokerbench-sft` vs base `unsloth/Qwen3-4B-Thinking-2507`

**Features:**
- Auto-detects the GPU and picks a quantization mode (A100: FP16, L4: 8-bit, T4/other: 4-bit)
- Configurable: override quantization, hands, sessions
- Tracks win rate, BB/100, VPIP, PFR, and more

**Quick Start:**
1. Run all cells in order
2. Results (summary CSV, JSON, comparison plot, and sampled hand logs under `logs/`) are saved to `/content/eval_results/`

## 1. Install Dependencies

In [None]:
!pip install -q transformers accelerate bitsandbytes torch pokerkit
!pip install -q tqdm pandas matplotlib

## 2. Hardware Detection & Quantization Config

In [None]:
import subprocess
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch
from transformers import BitsAndBytesConfig


# Precision used when loading model weights. Each member's value is the short
# label that gets printed and written into the exported results config.
Quantization = Enum(
    "Quantization",
    [("INT4", "4bit"), ("INT8", "8bit"), ("FP16", "fp16")],
)


@dataclass
class HardwareConfig:
    """Detected GPU properties plus the quantization mode chosen for them.

    Attributes:
        gpu_name: GPU model name reported by ``nvidia-smi`` (``"Unknown"`` when
            detection failed).
        vram_gb: Total memory of that GPU in GiB (16.0 is assumed when
            detection failed).
        quantization: Precision used when loading the models.
    """

    gpu_name: str
    vram_gb: float
    quantization: Quantization

    @classmethod
    def detect(cls, override_quant: Optional[Quantization] = None) -> "HardwareConfig":
        """Query the first GPU via ``nvidia-smi`` and choose a quantization mode.

        Heuristic: A100 -> FP16, L4 -> INT8, anything else (T4, unknown GPU,
        detection failure) -> INT4. ``override_quant`` wins over the heuristic.

        Args:
            override_quant: Force this quantization regardless of the GPU.

        Returns:
            A populated ``HardwareConfig``.
        """
        gpu_name, vram_gb = "Unknown", 16.0
        try:
            result = subprocess.run(
                ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
                capture_output=True, text=True, check=True,
            )
            # nvidia-smi prints one "<name>, <MiB>" line per GPU; only the first
            # GPU is considered (splitting the whole output broke on multi-GPU hosts).
            first_gpu = result.stdout.strip().splitlines()[0]
            name, vram_mb = (part.strip() for part in first_gpu.split(",", 1))
            gpu_name, vram_gb = name, float(vram_mb) / 1024
        except (OSError, subprocess.SubprocessError, ValueError, IndexError) as exc:
            # nvidia-smi missing, failing, or printing something unexpected:
            # keep the conservative defaults instead of crashing.
            print(f"GPU detection failed ({exc!r}); assuming {gpu_name} with {vram_gb:.0f}GB")

        if override_quant is not None:
            quant = override_quant
        elif "A100" in gpu_name:
            quant = Quantization.FP16
        elif "L4" in gpu_name:
            quant = Quantization.INT8
        else:
            quant = Quantization.INT4

        return cls(gpu_name=gpu_name, vram_gb=vram_gb, quantization=quant)

    def get_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """Return the bitsandbytes config for INT4/INT8, or ``None`` for FP16."""
        if self.quantization == Quantization.INT4:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        elif self.quantization == Quantization.INT8:
            return BitsAndBytesConfig(load_in_8bit=True)
        return None


# Probe the GPU and let the heuristic choose a quantization mode.
# To force a mode instead: hw = HardwareConfig.detect(override_quant=Quantization.INT4)
hw = HardwareConfig.detect(override_quant=None)
summary_lines = (
    f"GPU: {hw.gpu_name} ({hw.vram_gb:.0f}GB)",
    f"Quantization: {hw.quantization.value}",
)
for summary_line in summary_lines:
    print(summary_line)

## 3. Load Models

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Display name -> Hugging Face model id for every model under evaluation.
MODELS = {
    "SFT": "YiPz/Qwen3-4B-pokerbench-sft",
    "Base": "unsloth/Qwen3-4B-Thinking-2507",
}

loaded_models = {}
tokenizers = {}
bnb_config = hw.get_bnb_config()


def _vram_in_use_gb() -> float:
    """Return memory currently allocated by torch on the CUDA device, in GiB."""
    return torch.cuda.memory_allocated() / 1024**3


for name, model_id in MODELS.items():
    print(f"Loading {name}: {model_id} ({hw.quantization.value})...")
    tokenizers[name] = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # A bitsandbytes config exists only for the INT4/INT8 modes.
    quant_kwargs = {"quantization_config": bnb_config} if bnb_config else {}
    loaded_models[name] = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        **quant_kwargs,
    )
    print(f"  Loaded. VRAM: {_vram_in_use_gb():.1f}GB")

print(f"\nTotal VRAM: {_vram_in_use_gb():.1f}GB / {hw.vram_gb:.0f}GB")

## 4. Evaluation Config

In [None]:
# ---- Evaluation configuration ----
NUM_HANDS = 100        # hands played per session
NUM_SESSIONS = 3       # number of sessions to run
STARTING_STACK = 10000
SMALL_BLIND = 50
BIG_BLIND = 100
VERBOSE = False        # echo every action to stdout
SEED = 42

_config_report = [
    "Config:",
    f"  Hands/session: {NUM_HANDS}",
    f"  Sessions: {NUM_SESSIONS}",
    f"  Total hands: {NUM_HANDS * NUM_SESSIONS}",
    f"  Stack: {STARTING_STACK}",
    f"  Blinds: {SMALL_BLIND}/{BIG_BLIND}",
]
print("\n".join(_config_report))

## 5. Evaluation Classes

In [None]:
import json
import os
import random
import re
import time
import unicodedata
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

from pokerkit import Automation, NoLimitTexasHoldem


# ============= Hand Logger =============

# Suit glyphs keyed by the lowercase suit letter of a card code such as "As".
SUIT_SYMBOLS = {"c": "♣", "d": "♦", "h": "♥", "s": "♠"}


class HandLogger:
    """Log sampled poker hands to a per-session file as Unicode box drawings.

    Every bordered line is padded to ``BOX_WIDTH`` interior columns. Padding is
    computed from terminal display width (emoji count as two columns), so the
    right-hand border lines up in a monospace viewer.
    """

    BOX_WIDTH = 58

    def __init__(self, log_dir: str = "logs", sample_rate: int = 100):
        """Create ``log_dir`` if needed and choose a timestamped log file name.

        Args:
            log_dir: Directory receiving ``poker_session_<timestamp>.log``.
            sample_rate: Log hands whose number is a multiple of this value.
                Values <= 0 disable per-hand logging.
        """
        self.log_dir = log_dir
        self.sample_rate = sample_rate
        self._current_hand: Optional[Dict] = None

        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.session_file: Optional[str] = os.path.join(log_dir, f"poker_session_{timestamp}.log")

    def should_log(self, hand_num: int) -> bool:
        """Return True when ``hand_num`` falls on the sampling interval."""
        # A non-positive rate would raise ZeroDivisionError / sample oddly.
        return self.sample_rate > 0 and hand_num % self.sample_rate == 0

    def start_hand(self, hand_num, player_names, stacks, hole_cards, button_pos, sb_pos, bb_pos, blinds):
        """Begin recording hand ``hand_num`` if it is sampled, else clear any open hand."""
        if not self.should_log(hand_num):
            self._current_hand = None
            return
        self._current_hand = {
            "hand_num": hand_num, "timestamp": datetime.now().isoformat(),
            "player_names": player_names, "stacks": stacks.copy(),
            "hole_cards": hole_cards, "button_pos": button_pos,
            "sb_pos": sb_pos, "bb_pos": bb_pos, "blinds": blinds,
            "streets": [], "current_street": None, "board": [],
            "final_stacks": None, "winners": [],
        }

    def start_street(self, street_name: str, board: List[str]):
        """Open a street section that shows ``board`` as it stands now."""
        if self._current_hand is None:
            return
        self._current_hand["current_street"] = {"name": street_name, "board": [str(c) for c in board], "actions": []}
        self._current_hand["board"] = [str(c) for c in board]

    def log_action(self, player_idx: int, player_name: str, action_str: str):
        """Append an action to the open street (ignored if nothing is being recorded)."""
        if self._current_hand is None or self._current_hand["current_street"] is None:
            return
        self._current_hand["current_street"]["actions"].append({
            "player_idx": player_idx, "player_name": player_name, "action": action_str,
        })

    def end_street(self):
        """Close the open street section, if any."""
        if self._current_hand is None or self._current_hand["current_street"] is None:
            return
        self._current_hand["streets"].append(self._current_hand["current_street"])
        self._current_hand["current_street"] = None

    def end_hand(self, final_stacks: List[int], winners: List[int], chips_won: int):
        """Record the outcome and append the finished hand to the log file."""
        if self._current_hand is None:
            return
        if self._current_hand["current_street"] is not None:
            self.end_street()
        self._current_hand["final_stacks"] = final_stacks
        self._current_hand["winners"] = winners
        self._current_hand["chips_won"] = chips_won
        self._write_hand()
        self._current_hand = None

    def _format_card(self, card: str) -> str:
        """Render a card as rank + suit glyph.

        Accepts a short code ("As") or a longer string whose last parenthesised
        part is the code (e.g. "ACE OF SPADES (As)"); both become "A♠".
        """
        card_str = str(card)
        if "(" in card_str and ")" in card_str:
            start = card_str.rfind("(") + 1
            end = card_str.rfind(")")
            card_str = card_str[start:end]
        if len(card_str) >= 2:
            rank = card_str[:-1].upper()
            suit = card_str[-1].lower()
            return f"{rank}{SUIT_SYMBOLS.get(suit, suit)}"
        return card_str

    def _format_cards(self, cards: List) -> str:
        """Render a card list as "[A♠ K♥]" ("[ ]" when empty)."""
        if not cards:
            return "[ ]"
        return "[" + " ".join(self._format_card(c) for c in cards) + "]"

    @staticmethod
    def _display_width(text: str) -> int:
        """Return how many monospace terminal columns ``text`` occupies.

        Wide/fullwidth characters (emoji such as 🏆) take two columns; all
        others, including box-drawing glyphs and suit symbols, take one.
        """
        return sum(2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1 for ch in text)

    def _pad_line(self, content: str, width: int = 58, left: str = "║", right: str = "║") -> str:
        """Wrap ``content`` in borders, padding it to ``width`` display columns."""
        padding = max(0, width - self._display_width(content))
        return f"{left}{content}{' ' * padding}{right}"

    def _rule(self, left: str, fill: str, right: str) -> str:
        """Return a full-width horizontal rule such as "╠═…═╣"."""
        return left + fill * self.BOX_WIDTH + right

    def _write_hand(self):
        """Append the recorded hand to ``session_file`` as a boxed block."""
        if self._current_hand is None:
            return
        h = self._current_hand
        heavy = self._rule("╠", "═", "╣")
        light = self._rule("╟", "─", "╢")

        lines = ["", self._rule("╔", "═", "╗")]
        lines.append(self._pad_line(f"  🎴 HAND #{h['hand_num']:>4}  │  {h['timestamp'][:19]}"))
        lines.append(heavy)
        lines.append(self._pad_line("  PLAYERS"))
        lines.append(light)

        for i, name in enumerate(h["player_names"]):
            pos_tag = " [BTN]" if i == h["button_pos"] else " [SB]" if i == h["sb_pos"] else " [BB]" if i == h["bb_pos"] else ""
            hole = self._format_cards(h["hole_cards"][i]) if h["hole_cards"][i] else "[?? ??]"
            stack_str = f"${h['stacks'][i]:,}"
            lines.append(self._pad_line(f"  {name[:12]:<12} {hole:<12} {stack_str:>10}{pos_tag:<8}"))

        lines.append(heavy)
        sb, bb = h["blinds"]
        lines.append(self._pad_line(f"  Blinds: ${sb}/${bb}"))
        lines.append(heavy)

        for street in h["streets"]:
            board_str = self._format_cards(street["board"]) if street["board"] else ""
            lines.append(self._pad_line(f"  ▶ {street['name'].upper()} {board_str}"))
            lines.append(light)
            for action in street["actions"]:
                lines.append(self._pad_line(f"    {action['player_name'][:12]:<12}: {action['action']}"))
            if not street["actions"]:
                lines.append(self._pad_line("    (no actions)"))
            lines.append(light)

        if h["board"]:
            lines.append(self._pad_line(f"  Final Board: {self._format_cards(h['board'])}"))
            lines.append(heavy)

        lines.append(self._pad_line("  🏆 RESULTS"))
        lines.append(light)
        if h["winners"]:
            winner_names = [h["player_names"][w] for w in h["winners"]]
            lines.append(self._pad_line(f"  Winner: {', '.join(winner_names)} (+${h['chips_won']:,})"))
        lines.append(light)
        lines.append(self._pad_line("  Final Stacks:"))
        for i, name in enumerate(h["player_names"]):
            if h["final_stacks"]:
                diff = h["final_stacks"][i] - h["stacks"][i]
                diff_str = f"+{diff}" if diff > 0 else str(diff)
                lines.append(self._pad_line(f"    {name[:12]:<12}: ${h['final_stacks'][i]:,} ({diff_str})"))
        lines.append(self._rule("╚", "═", "╝"))
        lines.append("")

        with open(self.session_file, "a", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")

    def _session_line(self, content: str) -> str:
        """Pad ``content`` inside the light "│" borders used by session banners."""
        return self._pad_line(content, width=self.BOX_WIDTH, left="│", right="│")

    def log_session_start(self, num_players, starting_stack, blinds, num_hands):
        """Create (truncate) the log file and write the session header banner."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        lines = [
            self._rule("┌", "─", "┐"),
            self._session_line(" " * 20 + "🃏 POKER SESSION 🃏"),
            self._rule("├", "─", "┤"),
            self._session_line(f"  Started: {timestamp}"),
            self._session_line(f"  Players: {num_players}"),
            self._session_line(f"  Starting Stack: ${starting_stack:,}"),
            self._session_line(f"  Blinds: ${blinds[0]}/${blinds[1]}"),
            self._session_line(f"  Planned Hands: {num_hands}"),
            self._session_line(f"  Sample Rate: every {self.sample_rate} hands"),
            self._rule("└", "─", "┘"),
        ]
        with open(self.session_file, "w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")

    def log_session_end(self, hands_played, final_stacks, player_names, starting_stack):
        """Append the session summary banner and print the log file path."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        lines = [
            "",
            self._rule("┌", "─", "┐"),
            self._session_line(" " * 18 + "📊 SESSION SUMMARY 📊"),
            self._rule("├", "─", "┤"),
            self._session_line(f"  Ended: {timestamp}"),
            self._session_line(f"  Hands Played: {hands_played}"),
            self._rule("├", "─", "┤"),
            self._session_line("  Final Results:"),
        ]
        for i, name in enumerate(player_names):
            diff = final_stacks[i] - starting_stack
            diff_str = f"+{diff}" if diff > 0 else str(diff)
            emoji = "🏆" if diff > 0 else "📉" if diff < 0 else "➖"
            lines.append(self._session_line(f"    {emoji} {name[:12]:<12}: ${final_stacks[i]:,} ({diff_str})"))
        lines.append(self._rule("└", "─", "┘"))
        with open(self.session_file, "a", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")
        print(f"\n📝 Hand log saved to: {self.session_file}")


# ============= Action parsing (from existing codebase) =============

@dataclass
class ParsedAction:
    """A poker decision extracted from model output.

    Attributes:
        action_type: Action keyword; ActionParser produces ``"fold"``,
            ``"check"``, ``"call"``, ``"raise"`` or ``"all_in"``.
        amount: Raise-to amount (or the stack for all-in); ``None`` otherwise.
    """

    action_type: str
    amount: Optional[int] = None

    def __str__(self):
        # A falsy amount (None or 0) is left out of the label.
        if self.amount:
            return f"{self.action_type.title()} {self.amount}"
        return self.action_type.title()


class ActionParser:
    """Turn free-form model output into a ParsedAction.

    The text inside the *last* ``<action>...</action>`` tag is used, because
    the final decision comes last and earlier tags may merely echo the
    output-format instructions. Without any tag the whole text is scanned.
    Keywords are checked in priority order: all-in, fold, check/call,
    bet/raise; anything unrecognised becomes check (if legal) or fold.
    """

    RE_ACTION_TAG = re.compile(r"<action>\s*([^<]+?)\s*</action>", re.IGNORECASE)
    RE_FOLD = re.compile(r"\b(f|fold)\b", re.IGNORECASE)
    RE_CC = re.compile(r"\b(cc|call|check)\b", re.IGNORECASE)
    # The amount may contain thousands separators ("cbr 1,500"); they are stripped.
    RE_CBR = re.compile(r"\b(?:cbr|bet|raise)(?:\s+(?:to\s+)?(\d[\d,]*))?\b", re.IGNORECASE)
    RE_ALL_IN = re.compile(r"\b(all[\-\s]?in|shove)\b", re.IGNORECASE)

    def parse(self, text: str, can_check: bool = True, stack: int = 0) -> ParsedAction:
        """Parse ``text`` into an action.

        Args:
            text: Model response (thinking already removed, ideally).
            can_check: Whether checking is legal; maps "cc" to check vs call
                and picks the fallback (check vs fold).
            stack: Actor's stack; used as the all-in amount and as the raise
                amount when a bet/raise keyword carries no number.
        """
        tags = self.RE_ACTION_TAG.findall(text)
        content = tags[-1].strip() if tags else text

        if self.RE_ALL_IN.search(content):
            return ParsedAction("all_in", stack)
        if self.RE_FOLD.search(content):
            return ParsedAction("fold")
        if self.RE_CC.search(content):
            return ParsedAction("check" if can_check else "call")

        cbr = self.RE_CBR.search(content)
        if cbr:
            raw_amount = cbr.group(1)
            amt = int(raw_amount.replace(",", "")) if raw_amount else stack
            return ParsedAction("raise", amt)

        return ParsedAction("check" if can_check else "fold")


# Action record
@dataclass
class ActionRecord:
    """One decision made by a TransformersPlayer together with the inputs it saw."""
    hand_id: int  # hand number set via TransformersPlayer.set_hand_context
    street: str  # street name set via set_hand_context, e.g. "preflop"
    hole_cards: Tuple[str, str]
    board: List[str]
    pot: int
    to_call: int  # chips needed to match the current highest bet
    stack: int  # actor's remaining chips at decision time
    position: str  # positional label passed in by the game, e.g. "SB"
    action: ParsedAction
    thinking: str  # decoded reasoning text (get_action stores at most 1000 chars)
    response: str  # decoded answer text (get_action stores at most 500 chars)
    latency_ms: float  # wall-clock time for prompt building, generation and parsing
    tokens_generated: int


# Player
class TransformersPlayer:
    """A poker player backed by a Hugging Face causal language model.

    Each decision builds a chat prompt describing the game state, samples a
    completion, separates the reasoning (text before ``</think>``) from the
    answer, and parses the answer into a ParsedAction. Every decision is
    appended to ``action_history`` for later statistics.
    """

    SYSTEM_PROMPT = """You are an expert poker player. Analyze the game state and decide your action.

Output format: <action>ACTION</action>
- <action>f</action> = fold
- <action>cc</action> = check or call
- <action>cbr AMOUNT</action> = bet or raise to AMOUNT

Think step by step, then output exactly ONE action tag."""

    # Token id of "</think>" in the Qwen3 tokenizer (per the Qwen3 model card);
    # everything before its last occurrence is treated as reasoning.
    THINK_END_TOKEN_ID = 151668

    def __init__(self, name: str, model: Any, tokenizer: Any, temperature: float = 0.6, max_new_tokens: int = 512):
        """
        Args:
            name: Display name used in logs and metrics.
            model: A loaded causal LM exposing ``generate`` and ``device``.
            tokenizer: The matching tokenizer (needs a chat template).
            temperature: Sampling temperature.
            max_new_tokens: Generation budget per decision (reasoning included).
        """
        self.name = name
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.parser = ActionParser()
        self.action_history: List[ActionRecord] = []
        self._hand_id = 0
        self._street = "preflop"

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def set_hand_context(self, hand_id: int, street: str):
        """Set the hand/street stamped onto subsequent ActionRecords."""
        self._hand_id = hand_id
        self._street = street

    def get_action(self, hole_cards, board, pot, to_call, stack, position, num_players) -> "ParsedAction":
        """Ask the model for a decision; any generation/parsing error folds."""
        start = time.perf_counter()
        prompt = self._build_prompt(hole_cards, board, pot, to_call, stack, position, num_players)

        try:
            thinking, response, tokens_gen = self._generate(prompt)
            can_check = to_call == 0
            action = self.parser.parse(response, can_check, stack)
        except Exception as e:
            thinking, response, tokens_gen = "", f"ERROR: {e}", 0
            action = ParsedAction("fold")

        latency = (time.perf_counter() - start) * 1000

        self.action_history.append(ActionRecord(
            hand_id=self._hand_id, street=self._street, hole_cards=hole_cards,
            board=list(board), pot=pot, to_call=to_call, stack=stack,
            position=position, action=action, thinking=thinking[:1000],
            response=response[:500], latency_ms=latency, tokens_generated=tokens_gen,
        ))
        return action

    def _build_prompt(self, hole_cards, board, pot, to_call, stack, position, num_players) -> str:
        """Render the system + user messages through the tokenizer's chat template."""
        board_str = " ".join(board) if board else "None"
        user_msg = f"""Game: {num_players}-handed No-Limit Hold'em
Position: {position}
Stack: {stack}
Hole Cards: {hole_cards[0]} {hole_cards[1]}
Board: {board_str}
Pot: {pot}
To Call: {to_call}

What is your action?"""
        messages = [{"role": "system", "content": self.SYSTEM_PROMPT}, {"role": "user", "content": user_msg}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def _generate(self, prompt: str) -> Tuple[str, str, int]:
        """Sample a completion for ``prompt``.

        Returns:
            ``(thinking, response, number_of_generated_tokens)``.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_len = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, temperature=self.temperature,
                top_p=0.95, top_k=20, do_sample=True, pad_token_id=self.tokenizer.pad_token_id,
            )

        new_tokens = outputs[0][input_len:]
        thinking, response = self._split_thinking(new_tokens)
        return thinking, response, len(new_tokens)

    def _split_thinking(self, new_tokens) -> Tuple[str, str]:
        """Decode generated token ids into ``(thinking, response)`` around the last ``</think>``."""
        end_positions = (new_tokens == self.THINK_END_TOKEN_ID).nonzero(as_tuple=True)[0]
        if len(end_positions) > 0:
            cut = int(end_positions[-1].item())
            thinking_tokens, response_tokens = new_tokens[:cut], new_tokens[cut + 1:]
        else:
            # No </think> (e.g. the budget ran out mid-reasoning): everything is response.
            thinking_tokens, response_tokens = new_tokens[:0], new_tokens
        thinking = self.tokenizer.decode(thinking_tokens, skip_special_tokens=True).strip()
        response = self.tokenizer.decode(response_tokens, skip_special_tokens=True).strip()
        return thinking, response

    def get_stats(self) -> dict:
        """Summarise ``action_history``.

        VPIP/PFR are per-hand rates over hands in which the player made at
        least one preflop decision (a hand counts once however many preflop
        decisions it contained). Aggression factor is (bets+raises)/calls:
        ``inf`` if the player raised but never called, ``0.0`` if neither.
        Returns ``{}`` when no actions were recorded.
        """
        if not self.action_history:
            return {}
        total = len(self.action_history)
        voluntary = {"call", "raise", "all_in"}
        aggressive = {"raise", "all_in"}

        preflop_types_by_hand: Dict[int, set] = {}
        for rec in self.action_history:
            if rec.street == "preflop":
                preflop_types_by_hand.setdefault(rec.hand_id, set()).add(rec.action.action_type)
        preflop_hands = len(preflop_types_by_hand)
        if preflop_hands:
            vpip = sum(1 for types in preflop_types_by_hand.values() if types & voluntary) / preflop_hands
            pfr = sum(1 for types in preflop_types_by_hand.values() if types & aggressive) / preflop_hands
        else:
            vpip = pfr = 0

        bets_raises = sum(1 for a in self.action_history if a.action.action_type in aggressive)
        calls = sum(1 for a in self.action_history if a.action.action_type == "call")
        if calls:
            aggression_factor = bets_raises / calls
        else:
            aggression_factor = float("inf") if bets_raises else 0.0

        return {
            "total_actions": total, "vpip": vpip, "pfr": pfr,
            "aggression_factor": aggression_factor,
            "avg_latency_ms": sum(a.latency_ms for a in self.action_history) / total,
            "fold_pct": sum(1 for a in self.action_history if a.action.action_type == "fold") / total,
        }

    def reset_history(self):
        """Forget all recorded decisions."""
        self.action_history = []


# Hand result
@dataclass
class HandResult:
    """Outcome of one completed hand, as recorded by EvalPokerGame."""
    hand_id: int
    player_names: List[str]
    starting_stacks: List[int]  # stacks before the hand, indexed like player_names
    ending_stacks: List[int]  # stacks after the hand, indexed like player_names
    chip_deltas: List[int]  # ending minus starting stack per player
    hole_cards: Dict[str, Tuple[str, str]]  # player name -> hole cards
    board: List[str]
    winner_names: List[str]  # players whose stack grew during the hand
    pot_size: int  # EvalPokerGame fills this with the total chips lost by losing players
    timestamp: float = field(default_factory=time.time)  # creation time, seconds since the epoch


# Metrics collector
class MetricsCollector:
    """Accumulate HandResults for one session and summarise them per player."""

    def __init__(self, session_id: Optional[str] = None, big_blind: Optional[int] = None):
        """
        Args:
            session_id: Label for the session; defaults to ``session_<unix time>``.
            big_blind: Big blind used for BB/100. ``None`` falls back to the
                notebook-level ``BIG_BLIND`` constant when the session is finalised.
        """
        self.session_id = session_id or f"session_{int(time.time())}"
        self.big_blind = big_blind
        self.hand_results: List["HandResult"] = []
        self.session_start = time.time()
        self.player_summaries: Dict[str, dict] = {}
        # Filled in by finalize_session(); defaults keep attribute access safe before that.
        self.duration = 0.0
        self.total_hands = 0
        self.hands_per_hour = 0.0

    def log_hand(self, result: "HandResult"):
        """Record one finished hand."""
        self.hand_results.append(result)

    def finalize_session(self, player_stats: Dict[str, dict]):
        """Compute per-player summaries and session timing.

        Args:
            player_stats: Extra per-player stats (e.g. from
                ``TransformersPlayer.get_stats``) merged into each summary.
        """
        big_blind = self.big_blind if self.big_blind is not None else BIG_BLIND
        self.duration = time.time() - self.session_start
        self.total_hands = len(self.hand_results)

        # Single pass over hands. A dict (unlike the previous set) keeps players
        # in first-seen order, so summaries print in a stable order.
        totals: Dict[str, List[int]] = {}  # name -> [hands_played, hands_won, chip_delta]
        for hr in self.hand_results:
            winners = set(hr.winner_names)
            for name, delta in zip(hr.player_names, hr.chip_deltas):
                tally = totals.setdefault(name, [0, 0, 0])
                tally[0] += 1
                tally[1] += name in winners
                tally[2] += delta

        for name, (hands_played, hands_won, total_chip_delta) in totals.items():
            self.player_summaries[name] = {
                "hands_played": hands_played, "hands_won": hands_won,
                "win_rate": hands_won / hands_played if hands_played > 0 else 0,
                "total_chip_delta": total_chip_delta,
                "bb_per_100": (total_chip_delta / hands_played * 100 / big_blind) if hands_played > 0 else 0,
                **player_stats.get(name, {}),
            }

        self.hands_per_hour = (self.total_hands / self.duration * 3600) if self.duration > 0 else 0


# Eval game with Hand Logging
class EvalPokerGame:
    """Play No-Limit Hold'em hands between TransformersPlayers using pokerkit.

    pokerkit takes no button argument: blinds are posted by fixed seat index.
    To give every model every position, players are rotated one seat per hand.
    The button/blind positions used for prompts and logs are then read back
    from the blinds pokerkit actually posted.
    """

    # Names of the seats between the big blind and the button (ending at the cutoff).
    _MIDDLE_POSITIONS = ["UTG", "UTG+1", "UTG+2", "LJ", "HJ", "CO"]

    def __init__(self, players: "List[TransformersPlayer]", starting_stack=10000, small_blind=50, big_blind=100,
                 metrics: "MetricsCollector" = None, verbose=False, progress_callback=None, log_sample_rate=100):
        """
        Args:
            players: Participants; their list order defines player indices.
            starting_stack: Chips each player starts the session with.
            small_blind: Small blind posted each hand.
            big_blind: Big blind posted each hand (also the minimum bet).
            metrics: Collector for hand results (a new MetricsCollector if None).
            verbose: Print every action.
            progress_callback: Called as ``callback(hands_done, num_hands)`` after each hand.
            log_sample_rate: HandLogger interval (log every N-th hand).
        """
        self.players = players
        self.num_players = len(players)
        self.starting_stack = starting_stack
        self.small_blind = small_blind
        self.big_blind = big_blind
        self.stacks = [starting_stack] * self.num_players
        # Player index holding the button in the current hand.
        self.button = 0
        # Seat rotation: pokerkit seat s is played by player (s + offset) % num_players.
        self._seat_offset = 0
        self.hand_num = 0
        self.metrics = metrics or MetricsCollector()
        self.verbose = verbose
        self.progress_callback = progress_callback

        # Hand logger (samples every Nth hand)
        self.logger = HandLogger(log_dir="/content/eval_results/logs", sample_rate=log_sample_rate)

    def play_session(self, num_hands: int) -> "MetricsCollector":
        """Play up to ``num_hands`` hands (stopping early when one player is left with chips)."""
        self.logger.log_session_start(
            num_players=self.num_players,
            starting_stack=self.starting_stack,
            blinds=(self.small_blind, self.big_blind),
            num_hands=num_hands,
        )

        for hand_idx in range(num_hands):
            self._play_hand()
            if self.progress_callback:
                self.progress_callback(hand_idx + 1, num_hands)
            if sum(1 for s in self.stacks if s > 0) < 2:
                break

        self.metrics.finalize_session({p.name: p.get_stats() for p in self.players})

        player_names = [p.name for p in self.players]
        self.logger.log_session_end(
            hands_played=self.hand_num,
            final_stacks=self.stacks,
            player_names=player_names,
            starting_stack=self.starting_stack,
        )

        return self.metrics

    def _play_hand(self):
        """Play one hand, then update ``self.stacks``, the hand log and the metrics."""
        n = self.num_players
        self.hand_num += 1
        self._seat_offset = (self._seat_offset + 1) % n
        for p in self.players:
            p.set_hand_context(self.hand_num, "preflop")

        # A hand needs at least two players with chips.
        if sum(1 for s in self.stacks if s > 0) < 2:
            return

        seat_to_player = [(seat + self._seat_offset) % n for seat in range(n)]
        starting_stacks = self.stacks.copy()

        try:
            state = NoLimitTexasHoldem.create_state(
                automations=(Automation.ANTE_POSTING, Automation.BET_COLLECTION, Automation.BLIND_OR_STRADDLE_POSTING,
                             Automation.CARD_BURNING, Automation.HOLE_DEALING, Automation.HOLE_CARDS_SHOWING_OR_MUCKING,
                             Automation.HAND_KILLING, Automation.CHIPS_PUSHING, Automation.CHIPS_PULLING),
                ante_trimming_status=True, raw_antes={-1: 0},
                raw_blinds_or_straddles=(self.small_blind, self.big_blind),
                min_bet=self.big_blind,
                raw_starting_stacks=[self.stacks[p] for p in seat_to_player],
                player_count=n,
            )
        except Exception as e:
            if self.verbose:
                print(f"Error: {e}")
            return

        # Positions come from the blinds pokerkit posted, so they match its
        # seat conventions (including heads-up, where the button is the SB).
        sb_seat, bb_seat = self._blind_seats(list(state.bets))
        btn_seat = sb_seat if n == 2 else (sb_seat - 1) % n
        self.button = seat_to_player[btn_seat]
        sb_pos, bb_pos = seat_to_player[sb_seat], seat_to_player[bb_seat]

        # Hole cards indexed by player, not by seat.
        hole_cards = [("??", "??")] * n
        for seat, pidx in enumerate(seat_to_player):
            cards = state.hole_cards[seat]
            if cards and len(cards) >= 2:
                hole_cards[pidx] = (str(cards[0]), str(cards[1]))

        player_names = [p.name for p in self.players]
        self.logger.start_hand(
            hand_num=self.hand_num,
            player_names=player_names,
            stacks=self.stacks,
            hole_cards=hole_cards,
            button_pos=self.button,
            sb_pos=sb_pos,
            bb_pos=bb_pos,
            blinds=(self.small_blind, self.big_blind),
        )

        board: List[str] = []
        streets = (("preflop", 0), ("flop", 3), ("turn", 1), ("river", 1))
        for street_idx, (street, deal_count) in enumerate(streets):
            if state.status is False:
                break
            for p in self.players:
                p.set_hand_context(self.hand_num, street)
            if street_idx > 0:
                self.logger.end_street()

            if deal_count:
                board.extend(self._deal_board(state, deal_count))
            self.logger.start_street(street, board)

            while state.actor_index is not None:
                seat = state.actor_index
                pidx = seat_to_player[seat]
                player = self.players[pidx]
                pot = state.total_pot_amount if hasattr(state, "total_pot_amount") else 0
                current_bet = max(state.bets) if state.bets else 0
                to_call = current_bet - state.bets[seat] if state.bets else 0
                stack = state.stacks[seat]
                position = self._get_position_name(pidx)

                action = player.get_action(hole_cards[pidx], list(board), pot, to_call, stack, position, n)
                if self.verbose:
                    print(f"  H{self.hand_num} {street} {player.name}: {action}")
                self.logger.log_action(pidx, player.name, str(action))

                if not self._execute_action(state, action):
                    # Even check/call and fold were rejected: abandon the hand
                    # (stacks keep their pre-hand values) instead of looping forever.
                    if self.verbose:
                        print(f"  H{self.hand_num}: no legal action could be applied; hand abandoned")
                    return

        for seat, pidx in enumerate(seat_to_player):
            self.stacks[pidx] = state.stacks[seat]

        chip_deltas = [self.stacks[i] - starting_stacks[i] for i in range(n)]
        winners = [i for i, d in enumerate(chip_deltas) if d > 0]
        max_gain = max(chip_deltas) if chip_deltas else 0
        winner_names = [self.players[i].name for i in winners]

        self.logger.end_street()
        self.logger.end_hand(final_stacks=self.stacks, winners=winners, chips_won=max_gain)

        self.metrics.log_hand(HandResult(
            hand_id=self.hand_num, player_names=[p.name for p in self.players],
            starting_stacks=starting_stacks, ending_stacks=self.stacks.copy(),
            chip_deltas=chip_deltas, hole_cards={p.name: hole_cards[i] for i, p in enumerate(self.players)},
            board=[str(c) for c in board], winner_names=winner_names, pot_size=sum(abs(d) for d in chip_deltas if d < 0),
        ))

    def _deal_board(self, state, count: int) -> List[str]:
        """Deal ``count`` random cards to the board in one call and return them as strings.

        Cards are drawn from ``state.get_dealable_cards()`` at dealing time
        rather than from a deck snapshot taken at hand start (before the
        automated burns). Returns ``[]`` if pokerkit rejects the deal.
        """
        dealable = list(state.get_dealable_cards())
        random.shuffle(dealable)
        cards = dealable[:count]
        try:
            state.deal_board(cards)
        except Exception as e:
            if self.verbose:
                print(f"  H{self.hand_num}: board dealing failed: {e}")
            return []
        return [str(c) for c in cards]

    @staticmethod
    def _blind_seats(bets) -> Tuple[int, int]:
        """Return ``(small_blind_seat, big_blind_seat)`` from the bets posted at hand start."""
        posted = [seat for seat, amount in enumerate(bets) if amount > 0]
        if len(posted) < 2:
            # Unexpected posting pattern: assume seats 0/1.
            return 0, 1 % max(len(bets), 1)
        bb_seat = max(posted, key=lambda seat: bets[seat])
        sb_seat = min((seat for seat in posted if seat != bb_seat), key=lambda seat: bets[seat])
        return sb_seat, bb_seat

    @staticmethod
    def _clamp_raise(state, amount):
        """Clamp a raise-to ``amount`` into pokerkit's legal range when the state exposes it."""
        if amount is None:
            return None
        low = getattr(state, "min_completion_betting_or_raising_to_amount", None)
        high = getattr(state, "max_completion_betting_or_raising_to_amount", None)
        if isinstance(low, int) and amount < low:
            amount = low
        if isinstance(high, int) and amount > high:
            amount = high
        return amount

    def _execute_action(self, state, action: "ParsedAction") -> bool:
        """Apply ``action`` to ``state``, degrading to check/call and then fold if illegal.

        Raise targets are clamped into the legal range, so an undersized or
        oversized raise stays a raise instead of silently becoming a call.

        Returns:
            True if some action was applied, False if every fallback was rejected.
        """
        try:
            if action.action_type == "fold":
                state.fold()
            elif action.action_type in ("check", "call"):
                state.check_or_call()
            elif action.action_type in ("raise", "bet"):
                state.complete_bet_or_raise_to(self._clamp_raise(state, action.amount))
            elif action.action_type == "all_in":
                actor = state.actor_index
                state.complete_bet_or_raise_to(self._clamp_raise(state, state.stacks[actor] + state.bets[actor]))
            else:
                raise ValueError(f"unknown action type: {action.action_type!r}")
            return True
        except Exception:
            pass
        for fallback in (state.check_or_call, state.fold):
            try:
                fallback()
                return True
            except Exception:
                continue
        return False

    def _get_position_name(self, idx: int) -> str:
        """Return the positional label of player ``idx`` relative to ``self.button``.

        Clockwise from the button: BTN, SB, BB, then middle seats ending at CO.
        Heads-up the button is the small blind. Seats beyond the known names
        get ``P<idx>``.
        """
        n = self.num_players
        if n == 2:
            positions = ["SB", "BB"]
        else:
            positions = ["BTN", "SB", "BB"]
            middle = n - 3
            if 0 < middle <= len(self._MIDDLE_POSITIONS):
                positions += self._MIDDLE_POSITIONS[-middle:]
        rel_pos = (idx - self.button) % n
        return positions[rel_pos] if rel_pos < len(positions) else f"P{idx}"


print("Evaluation classes loaded with hand logging!")

## 6. Run Evaluation

In [None]:
from tqdm.notebook import tqdm
import os

# Seed both RNGs: `random` drives board dealing, while model sampling
# (do_sample=True) draws from torch's RNG. Without the torch seed the
# models' decisions differ from run to run (GPU kernels may still add
# some nondeterminism).
random.seed(SEED)
torch.manual_seed(SEED)

# Create output directory
os.makedirs("/content/eval_results", exist_ok=True)

# Aggregate results across sessions
all_results = []

for session_idx in range(NUM_SESSIONS):
    banner = "=" * 50
    print(f"\n{banner}")
    print(f"Session {session_idx + 1}/{NUM_SESSIONS}")
    print(banner)

    # Fresh players per session, so action histories/stats are per session.
    players = [
        TransformersPlayer(name, loaded_models[name], tokenizers[name])
        for name in ("SFT", "Base")
    ]

    metrics = MetricsCollector(f"session_{session_idx}")
    pbar = tqdm(total=NUM_HANDS, desc=f"Session {session_idx+1}")

    def update_progress(current, total, bar=pbar):
        # `bar` is bound at definition time so the callback always targets this session's bar.
        bar.n = current
        bar.refresh()

    game = EvalPokerGame(
        players=players,
        starting_stack=STARTING_STACK,
        small_blind=SMALL_BLIND,
        big_blind=BIG_BLIND,
        metrics=metrics,
        verbose=VERBOSE,
        progress_callback=update_progress,
    )

    # Always close the progress bar, even if a session raises.
    try:
        result = game.play_session(NUM_HANDS)
    finally:
        pbar.close()

    all_results.append(result)

    print(f"\nSession {session_idx+1} Results:")
    print(f"  Hands: {result.total_hands}")
    print(f"  Rate: {result.hands_per_hour:.0f} hands/hour")
    for name, stats in result.player_summaries.items():
        print(f"  {name}: {stats['hands_won']}/{stats['hands_played']} wins, BB/100: {stats['bb_per_100']:+.2f}")

print("\n" + "="*50)
print("ALL SESSIONS COMPLETE")
print("="*50)

## 7. Results Analysis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def _bb_per_100(chip_delta, hands):
    """Return big blinds won per 100 hands (0 when no hands were played)."""
    return chip_delta / hands * 100 / BIG_BLIND if hands > 0 else 0


# Aggregate across sessions. VPIP/PFR are per-session rates, so they are
# combined as a hand-weighted average; a plain mean would give a short
# session (e.g. one ended early by a bust) as much weight as a full one.
aggregate = {}
for result in all_results:
    for name, stats in result.player_summaries.items():
        agg = aggregate.setdefault(name, {
            "hands_played": 0, "hands_won": 0, "total_chip_delta": 0,
            "vpip_weighted": 0.0, "pfr_weighted": 0.0, "sessions": 0,
        })
        hands = stats["hands_played"]
        agg["hands_played"] += hands
        agg["hands_won"] += stats["hands_won"]
        agg["total_chip_delta"] += stats["total_chip_delta"]
        agg["vpip_weighted"] += stats.get("vpip", 0) * hands
        agg["pfr_weighted"] += stats.get("pfr", 0) * hands
        agg["sessions"] += 1

# Build summary DataFrame
_SUMMARY_COLUMNS = ["Model", "Hands", "Wins", "Win%", "Chip Delta", "BB/100", "VPIP%", "PFR%"]
rows = []
for name, agg in aggregate.items():
    hands = agg["hands_played"]
    rows.append({
        "Model": name,
        "Hands": hands,
        "Wins": agg["hands_won"],
        "Win%": agg["hands_won"] / hands * 100 if hands > 0 else 0,
        "Chip Delta": agg["total_chip_delta"],
        "BB/100": _bb_per_100(agg["total_chip_delta"], hands),
        "VPIP%": agg["vpip_weighted"] / hands * 100 if hands > 0 else 0,
        "PFR%": agg["pfr_weighted"] / hands * 100 if hands > 0 else 0,
    })

# Explicit columns keep sort_values working even when no results were collected.
df = pd.DataFrame(rows, columns=_SUMMARY_COLUMNS).sort_values("BB/100", ascending=False)

print("\n" + "="*70)
print("AGGREGATE RESULTS")
print("="*70)
print(f"Total hands: {sum(r['Hands'] for r in rows)}")
print(f"Sessions: {NUM_SESSIONS}")
print()
print(df.to_string(index=False, float_format="%.2f"))
print("="*70)

# SFT improvement
if "SFT" in aggregate and "Base" in aggregate:
    sft_bb = _bb_per_100(aggregate["SFT"]["total_chip_delta"], aggregate["SFT"]["hands_played"])
    base_bb = _bb_per_100(aggregate["Base"]["total_chip_delta"], aggregate["Base"]["hands_played"])
    print(f"\nSFT IMPROVEMENT:")
    print(f"  BB/100 Delta: {sft_bb - base_bb:+.2f}")
    print(f"  (SFT: {sft_bb:+.2f}, Base: {base_bb:+.2f})")

In [None]:
# Visualization: profitability, win rate and preflop style side by side.
fig, (ax_profit, ax_win, ax_style) = plt.subplots(1, 3, figsize=(14, 4))

# Panel 1 - BB/100 per model: green when non-negative, red when losing.
profit_colors = ["green" if bb >= 0 else "red" for bb in df["BB/100"]]
ax_profit.bar(df["Model"], df["BB/100"], color=profit_colors)
ax_profit.axhline(y=0, color="black", linestyle="-", linewidth=0.5)
ax_profit.set_title("Profitability (BB/100)")
ax_profit.set_ylabel("BB/100")

# Panel 2 - share of hands won, with a 50% reference line.
ax_win.bar(df["Model"], df["Win%"], color="steelblue")
ax_win.axhline(y=50, color="black", linestyle="--", linewidth=0.5)
ax_win.set_title("Win Rate")
ax_win.set_ylabel("Win %")

# Panel 3 - VPIP and PFR as grouped bars per model.
bar_width = 0.35
slots = range(len(df))
ax_style.bar([s - bar_width / 2 for s in slots], df["VPIP%"], bar_width, label="VPIP", color="orange")
ax_style.bar([s + bar_width / 2 for s in slots], df["PFR%"], bar_width, label="PFR", color="purple")
ax_style.set_xticks(slots)
ax_style.set_xticklabels(df["Model"])
ax_style.set_title("Playing Style")
ax_style.set_ylabel("%")
ax_style.legend()

plt.tight_layout()
plt.savefig("/content/eval_results/comparison.png", dpi=150)
plt.show()

## 8. Export Results

In [None]:
import math


def _json_safe(value):
    """Recursively replace non-finite floats (inf/NaN) with None.

    ``aggression_factor`` is ``inf`` for a player who raised but never called;
    ``json.dump`` would write it as ``Infinity``, which strict JSON parsers reject.
    """
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


# Save CSV
df.to_csv("/content/eval_results/summary.csv", index=False)

# Save detailed JSON
detailed_results = {
    "config": {
        "num_hands": NUM_HANDS,
        "num_sessions": NUM_SESSIONS,
        "starting_stack": STARTING_STACK,
        "blinds": f"{SMALL_BLIND}/{BIG_BLIND}",
        "gpu": hw.gpu_name,
        "quantization": hw.quantization.value,
    },
    "aggregate": df.to_dict(orient="records"),
    "sessions": [
        {
            "session_id": r.session_id,
            "hands": r.total_hands,
            "duration_s": r.duration,
            "player_summaries": r.player_summaries,
        }
        for r in all_results
    ],
}

with open("/content/eval_results/results.json", "w", encoding="utf-8") as f:
    json.dump(_json_safe(detailed_results), f, indent=2)

print("Results saved to /content/eval_results/")
print("  - summary.csv")
print("  - results.json")
print("  - comparison.png")

## 9. (Optional) Export to Google Drive

In [None]:
# Uncomment to export to Drive
# from google.colab import drive
# import shutil
#
# drive.mount('/content/drive')
# shutil.copytree("/content/eval_results", "/content/drive/MyDrive/poker_eval_results", dirs_exist_ok=True)
# print("Exported to Google Drive: poker_eval_results/")