# AI Poker Simulation: GPT-4o vs Llama 3.3 70B (Groq)

Run a No-Limit Texas Hold'em simulation between OpenAI GPT-4o and Llama 3.3 70B (served through the Groq API) on Google Colab.

**Setup required:**
1. Add `OPENAI_API_KEY` to Colab Secrets
2. Add `GROQ_API_KEY` to Colab Secrets (free at https://console.groq.com)

In [None]:
# Mount Google Drive and install dependencies.
# The report and trace files configured later in this notebook live under
# /content/drive, so the mount has to succeed before the simulation runs.
from google.colab import drive
drive.mount('/content/drive')

# `!` is IPython/Colab shell syntax; this line only works inside a notebook.
!pip install -q requests pokerkit

In [None]:
# Copy the API keys from Colab Secrets into environment variables, which is
# where the player classes look when no key is passed to them explicitly.
import os

from google.colab import userdata

for secret_name in ('OPENAI_API_KEY', 'GROQ_API_KEY'):
    os.environ[secret_name] = userdata.get(secret_name)

print("API keys loaded!")

In [None]:
# Core imports and utilities
import json
import os
import random
import re
import requests
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Optional, Dict

from pokerkit import NoLimitTexasHoldem, Automation

# ANSI colors (disabled for notebook): every name is bound to the empty
# string, so text formatted with them carries no escape sequences.
RESET = BOLD = RED = GREEN = YELLOW = BLUE = CYAN = ""

In [None]:
# Card utilities
# Lowercase suit letter -> Unicode suit symbol (spade, heart, diamond, club).
SUIT_SYMBOLS = {'s': '\u2660', 'h': '\u2665', 'd': '\u2666', 'c': '\u2663'}
# Rank characters in ascending order; a larger index means a higher rank.
RANK_ORDER = '23456789TJQKA'

def pretty_card(card: str) -> str:
    """Render a card code such as ``"as"`` as rank plus suit symbol.

    Strings shorter than two characters are returned unchanged. Otherwise the
    first character is upper-cased as the rank and the second is lower-cased
    and looked up in ``SUIT_SYMBOLS``; an unknown suit letter is kept as-is.
    Any characters after the second are ignored.
    """
    if len(card) >= 2:
        rank_char = card[0].upper()
        suit_char = card[1].lower()
        symbol = SUIT_SYMBOLS.get(suit_char, suit_char)
        return f"{rank_char}{symbol}"
    return card

def format_cards(cards) -> str:
    """Return every card rendered with ``pretty_card``, joined by single spaces.

    Each element is converted with ``str()`` before formatting; an empty
    iterable yields the empty string.
    """
    rendered = [pretty_card(str(card)) for card in cards]
    return ' '.join(rendered)

# Preflop hand rankings (simplified)
# Keys follow the format built by score_hole_cards: the two ranks with the
# higher one first; pairs carry no suffix, other hands end in 's' (suited)
# or 'o' (offsuit).
PREMIUM_HANDS = {'AA', 'KK', 'QQ', 'JJ', 'AKs', 'AKo'}
STRONG_HANDS = {'TT', '99', 'AQs', 'AQo', 'AJs', 'KQs'}

def score_hole_cards(c1: str, c2: str) -> int:
    """Return a rough preflop strength score from 1 (best) to 169.

    The two cards are reduced to a hand key: higher rank first, followed by
    's' (suited) or 'o' (offsuit) unless the hand is a pair. The score is then
    drawn at random inside the tier the key belongs to, so repeated calls for
    the same hand can return different numbers:

    * ``PREMIUM_HANDS`` -> 1-10
    * ``STRONG_HANDS``  -> 11-30
    * anything else     -> 31-169

    Raises:
        ValueError: if a rank character is not in ``RANK_ORDER``.
    """
    high, low = c1[0].upper(), c2[0].upper()
    same_suit = c1[1].lower() == c2[1].lower()

    # Put the stronger rank first so the key matches the tier tables.
    if RANK_ORDER.index(high) < RANK_ORDER.index(low):
        high, low = low, high

    if high == low:
        key = f"{high}{low}"
    else:
        suffix = 's' if same_suit else 'o'
        key = f"{high}{low}{suffix}"
    bare_key = key.rstrip('so')

    def in_tier(tier) -> bool:
        return key in tier or bare_key in tier

    if in_tier(PREMIUM_HANDS):
        bounds = (1, 10)
    elif in_tier(STRONG_HANDS):
        bounds = (11, 30)
    else:
        bounds = (31, 169)
    return random.randint(*bounds)

In [None]:
# Action parsing
@dataclass
class ParsedAction:
    """A poker decision extracted from a model response."""

    # One of: fold, check, call, raise, all_in, error.
    action_type: str
    # Chip amount attached to the action, when there is one.
    amount: Optional[int] = None
    # Description of what went wrong; filled in for "error" actions.
    error_message: Optional[str] = None

    def __str__(self):
        # Only a raise with a non-zero amount gets the long form.
        is_sized_raise = self.action_type == "raise" and bool(self.amount)
        if not is_sized_raise:
            return self.action_type
        return f"raises to {self.amount}"


class ActionParser:
    """Parse LLM responses into poker actions.

    An ``<action>...</action>`` tag is preferred. When the response contains
    no tag, the parser falls back to scanning the free text for poker
    keywords, and finally to the passive action (check if possible, otherwise
    call).
    """

    ACTION_PATTERN = re.compile(r'<action>\s*([^<]+?)\s*</action>', re.IGNORECASE)

    # First number in a string; accepts thousands separators and a decimal
    # part ("1,500", "250.0").
    _AMOUNT_PATTERN = re.compile(r'\d[\d,]*(?:\.\d+)?')

    # Keyword patterns use word boundaries so that, for example, "better"
    # is not read as a bet. They are applied to lower-cased text.
    _FOLD_PATTERN = re.compile(r'\bfold(?:s|ed|ing)?\b')
    _ALL_IN_PATTERN = re.compile(r'\ball[\s_-]?in\b|\bshov(?:e|es|ed|ing)\b')
    _RAISE_PATTERN = re.compile(r'\b(?:rais(?:e|es|ed|ing)|bet(?:s|ting)?)\b')
    _CALL_PATTERN = re.compile(r'\bcall(?:s|ed|ing)?\b')
    _CHECK_PATTERN = re.compile(r'\bcheck(?:s|ed|ing)?\b')

    def parse(self, response: str, can_check: bool, stack: int) -> "ParsedAction":
        """Convert a raw model response into a ``ParsedAction``.

        Args:
            response: Full text returned by the model.
            can_check: True when there is nothing to call, so the passive
                action is a check rather than a call.
            stack: The acting player's remaining chips; used as the all-in
                amount and, halved, as the default raise size.

        Returns:
            The parsed action. Unrecognised text yields the passive action.
        """
        tagged = self.ACTION_PATTERN.findall(response)
        if tagged:
            # The prompt asks the model to think first and answer last, so
            # when several tags appear the final one is the decision.
            return self._parse_action_text(tagged[-1].strip().lower(), can_check, stack)

        # Fallback to keywords, checked in a fixed priority order.
        text = response.lower()
        if self._FOLD_PATTERN.search(text):
            return ParsedAction('fold')
        if self._ALL_IN_PATTERN.search(text):
            return ParsedAction('all_in', stack)
        if self._RAISE_PATTERN.search(text):
            # Heuristic: the last standalone integer is taken as the size.
            amounts = re.findall(r'\b(\d+)\b', response)
            if amounts:
                return ParsedAction('raise', int(amounts[-1]))
            return ParsedAction('raise', stack // 2)
        if self._CALL_PATTERN.search(text):
            return ParsedAction('call')
        if self._CHECK_PATTERN.search(text):
            return ParsedAction('check')

        return ParsedAction('check' if can_check else 'call')

    def _parse_action_text(self, text: str, can_check: bool, stack: int) -> "ParsedAction":
        """Interpret the (lower-cased, stripped) contents of an action tag.

        Recognises ``f...`` (fold), an all-in phrase, ``cc`` (check/call) and
        ``cbr`` / ``raise`` / ``bet`` followed anywhere by an amount. A raise
        without a readable amount defaults to half the stack; anything else
        becomes the passive action.
        """
        passive = 'check' if can_check else 'call'
        if text.startswith('f'):
            return ParsedAction('fold')
        if self._ALL_IN_PATTERN.search(text):
            return ParsedAction('all_in', stack)
        if text.startswith('cc'):
            return ParsedAction(passive)
        if text.startswith(('cbr', 'raise', 'bet')):
            amount = self._extract_amount(text)
            if amount is None:
                amount = stack // 2
            return ParsedAction('raise', amount)
        return ParsedAction(passive)

    @staticmethod
    def _extract_amount(text: str) -> Optional[int]:
        """Return the first number in ``text`` as an int, or None if absent.

        Thousands separators are removed and any fractional part is
        truncated, so ``"raise to 1,500"`` gives 1500.
        """
        match = ActionParser._AMOUNT_PATTERN.search(text)
        if match is None:
            return None
        return int(float(match.group().replace(',', '')))

In [None]:
# System prompt for poker players.
# Sent as the "system" message by both player classes. The tag syntax it
# requests (f / cc / cbr AMOUNT) is what ActionParser looks for first.
SYSTEM_PROMPT = """You are an expert poker player. Analyze and decide the optimal action.

Output format: <action>ACTION</action>
- <action>f</action> = fold
- <action>cc</action> = call/check
- <action>cbr AMOUNT</action> = bet/raise to AMOUNT

Think first, then output ONE action tag."""

In [None]:
# GPT-4 Player (OpenAI)
class GPT4Player:
    """OpenAI GPT-4 based poker player.

    Describes the current decision in plain text, sends it to the OpenAI
    chat-completions endpoint together with ``SYSTEM_PROMPT`` and converts
    the reply into a ``ParsedAction`` with ``ActionParser``.
    """

    def __init__(
        self,
        name: str,
        model: str = "gpt-4o",
        api_key: Optional[str] = None,
        temperature: float = 0.6,
        max_tokens: int = 1024,
        trace_file: Optional[Path] = None,
    ):
        """Create a player.

        Args:
            name: Display name used in reports and traces.
            model: Model identifier sent in the request payload.
            api_key: API key; falls back to the ``OPENAI_API_KEY``
                environment variable when omitted.
            temperature: Sampling temperature sent with each request.
            max_tokens: Completion token limit sent with each request.
            trace_file: Optional JSONL file that receives one record per
                API call.

        Raises:
            ValueError: if no API key is available.
        """
        self.name = name
        self.model = model
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.parser = ActionParser()
        self.trace_file = trace_file

        if not self.api_key:
            raise ValueError("OpenAI API key required.")

    def check_connection(self) -> bool:
        """Return True if the models endpoint answers with HTTP 200.

        Any exception (network failure, timeout, ...) is deliberately
        reported as False rather than raised.
        """
        try:
            r = requests.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=10,
            )
            return r.status_code == 200
        except Exception:
            return False

    def shutdown(self) -> bool:
        """No resources to release; always returns True."""
        return True

    def get_action(
        self,
        hole_cards: Tuple[str, str],
        board: List[str],
        pot: int,
        to_call: int,
        stack: int,
        position: str,
        num_players: int,
    ) -> ParsedAction:
        """Ask the model for a decision and parse it.

        Returns:
            The parsed action, or a ``ParsedAction`` of type ``"error"``
            (with ``error_message`` set) if building the prompt, calling the
            API or parsing the reply raises.
        """
        try:
            # Prompt building sits inside the try block on purpose: hole cards
            # with an unknown rank (e.g. the "??" placeholder) make
            # score_hole_cards raise, and that has to surface as an "error"
            # action instead of an uncaught exception.
            prompt = self._build_prompt(hole_cards, board, pot, to_call, stack, position, num_players)
            response = self._call_api(prompt)
            can_check = to_call == 0
            return self.parser.parse(response, can_check, stack)
        except Exception as e:
            return ParsedAction("error", error_message=str(e))

    def _build_prompt(self, hole_cards, board, pot, to_call, stack, position, num_players) -> str:
        """Return the user message describing the current decision.

        Preflop (empty board) the prompt includes a strength score from
        ``score_hole_cards``; afterwards it lists the board instead.
        """
        c1, c2 = hole_cards
        lines = [
            f"Playing {num_players}-handed No-Limit Hold'em.",
            f"Position: {position}",
            f"Stack: {stack} chips",
            "",
            f"Hole cards: {pretty_card(c1)} {pretty_card(c2)}",
        ]
        if not board:
            strength = score_hole_cards(c1, c2)
            lines.append(f"Preflop strength: {strength}/169")
        else:
            lines.append(f"Board: {' '.join(pretty_card(c) for c in board)}")
        lines.extend(["", f"Pot: {pot} chips"])
        if to_call > 0:
            lines.append(f"To call: {to_call} chips")
            lines.append(f"Actions: Fold, Call {to_call}, Raise")
        else:
            lines.append("Actions: Check, Bet")
        return "\n".join(lines)

    def _call_api(self, prompt: str) -> str:
        """POST the prompt to the chat-completions endpoint and return the reply text.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        r = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=60,
        )
        r.raise_for_status()
        result = r.json()
        content = result["choices"][0]["message"]["content"]

        if self.trace_file:
            self._log_trace(prompt, content)
        return content

    def _log_trace(self, prompt: str, content: str):
        """Append one JSON line (timestamp, player, model, prompt, response) to the trace file."""
        trace = {
            "timestamp": datetime.now().isoformat(),
            "player": self.name,
            "model": self.model,
            "prompt": prompt,
            "response": content,
        }
        with open(self.trace_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(trace) + "\n")

In [None]:
# Groq Player (Llama 3)
class GroqPlayer:
    """Groq API based poker player (Llama 3).

    Describes the current decision in plain text, sends it to Groq's
    OpenAI-compatible chat-completions endpoint together with
    ``SYSTEM_PROMPT`` and converts the reply into a ``ParsedAction`` with
    ``ActionParser``.
    """

    def __init__(
        self,
        name: str,
        model: str = "llama-3.3-70b-versatile",
        api_key: Optional[str] = None,
        temperature: float = 0.6,
        max_tokens: int = 1024,
        trace_file: Optional[Path] = None,
    ):
        """Create a player.

        Args:
            name: Display name used in reports and traces.
            model: Model identifier sent in the request payload.
            api_key: API key; falls back to the ``GROQ_API_KEY`` environment
                variable when omitted.
            temperature: Sampling temperature sent with each request.
            max_tokens: Completion token limit sent with each request.
            trace_file: Optional JSONL file that receives one record per
                API call.

        Raises:
            ValueError: if no API key is available.
        """
        self.name = name
        self.model = model
        self.api_key = api_key or os.environ.get("GROQ_API_KEY")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.parser = ActionParser()
        self.trace_file = trace_file

        if not self.api_key:
            raise ValueError("Groq API key required.")

    def check_connection(self) -> bool:
        """Return True if the models endpoint answers with HTTP 200.

        Any exception (network failure, timeout, ...) is deliberately
        reported as False rather than raised.
        """
        try:
            r = requests.get(
                "https://api.groq.com/openai/v1/models",
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=10,
            )
            return r.status_code == 200
        except Exception:
            return False

    def shutdown(self) -> bool:
        """No resources to release; always returns True."""
        return True

    def get_action(
        self,
        hole_cards: Tuple[str, str],
        board: List[str],
        pot: int,
        to_call: int,
        stack: int,
        position: str,
        num_players: int,
    ) -> ParsedAction:
        """Ask the model for a decision and parse it.

        Returns:
            The parsed action, or a ``ParsedAction`` of type ``"error"``
            (with ``error_message`` set) if building the prompt, calling the
            API or parsing the reply raises.
        """
        try:
            # Prompt building sits inside the try block on purpose: hole cards
            # with an unknown rank (e.g. the "??" placeholder) make
            # score_hole_cards raise, and that has to surface as an "error"
            # action instead of an uncaught exception.
            prompt = self._build_prompt(hole_cards, board, pot, to_call, stack, position, num_players)
            response = self._call_api(prompt)
            can_check = to_call == 0
            return self.parser.parse(response, can_check, stack)
        except Exception as e:
            return ParsedAction("error", error_message=str(e))

    def _build_prompt(self, hole_cards, board, pot, to_call, stack, position, num_players) -> str:
        """Return the user message describing the current decision.

        Preflop (empty board) the prompt includes a strength score from
        ``score_hole_cards``; afterwards it lists the board instead.
        """
        c1, c2 = hole_cards
        lines = [
            f"Playing {num_players}-handed No-Limit Hold'em.",
            f"Position: {position}",
            f"Stack: {stack} chips",
            "",
            f"Hole cards: {pretty_card(c1)} {pretty_card(c2)}",
        ]
        if not board:
            strength = score_hole_cards(c1, c2)
            lines.append(f"Preflop strength: {strength}/169")
        else:
            lines.append(f"Board: {' '.join(pretty_card(c) for c in board)}")
        lines.extend(["", f"Pot: {pot} chips"])
        if to_call > 0:
            lines.append(f"To call: {to_call} chips")
            lines.append(f"Actions: Fold, Call {to_call}, Raise")
        else:
            lines.append("Actions: Check, Bet")
        return "\n".join(lines)

    def _call_api(self, prompt: str) -> str:
        """POST the prompt to the chat-completions endpoint and return the reply text.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        r = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=60,
        )
        r.raise_for_status()
        result = r.json()
        content = result["choices"][0]["message"]["content"]

        if self.trace_file:
            self._log_trace(prompt, content)
        return content

    def _log_trace(self, prompt: str, content: str):
        """Append one JSON line (timestamp, player, model, prompt, response) to the trace file."""
        trace = {
            "timestamp": datetime.now().isoformat(),
            "player": self.name,
            "model": self.model,
            "prompt": prompt,
            "response": content,
        }
        with open(self.trace_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(trace) + "\n")

In [None]:
# AI vs AI Poker Game Engine
class AIPokerGame:
    """AI vs AI poker game with reporting.

    Each hand is played on a fresh pokerkit ``NoLimitTexasHoldem`` state.
    Players are objects exposing ``name`` and ``get_action(...)``. Stacks
    carry over between hands, the button moves one player per hand, and
    per-player statistics plus a hand history are collected for the final
    report.
    """

    def __init__(
        self,
        players: List,
        starting_stack: int = 10000,
        small_blind: int = 50,
        big_blind: int = 100,
        report_file: Optional[Path] = None,
    ):
        """Set up stacks and empty statistics.

        Args:
            players: Player objects, in table order.
            starting_stack: Chips each player starts the session with.
            small_blind: Small blind size.
            big_blind: Big blind size; also used as the minimum bet.
            report_file: Optional path that receives the JSON report.
        """
        self.players = players
        self.num_players = len(players)
        self.starting_stack = starting_stack
        self.small_blind = small_blind
        self.big_blind = big_blind
        self.report_file = report_file

        self.stacks = [starting_stack] * self.num_players
        self.button = 0
        self.hand_num = 0

        self.stats = {
            "hands_played": 0,
            "player_stats": {
                p.name: {
                    "hands_won": 0,
                    "chips_won": 0,
                    "chips_lost": 0,
                    "folds": 0,
                    "calls": 0,
                    "raises": 0,
                    "all_ins": 0,
                }
                for p in players
            },
            "hand_history": [],
        }

    def play_session(self, num_hands: int = 10):
        """Play up to ``num_hands`` hands, stopping early once a stack hits zero, then report."""
        print("=" * 60)
        print("  AI vs AI POKER SIMULATION")
        print("=" * 60)
        print(f"  Players: {', '.join(p.name for p in self.players)}")
        print(f"  Stack: {self.starting_stack}")
        print(f"  Blinds: {self.small_blind}/{self.big_blind}")
        print(f"  Hands: {num_hands}")
        print("=" * 60)

        for hand_idx in range(num_hands):
            print(f"\n  Playing hand {hand_idx + 1}/{num_hands}...", end="", flush=True)
            self._play_hand()
            print(" done")

            if any(s <= 0 for s in self.stacks):
                broke = [self.players[i].name for i, s in enumerate(self.stacks) if s <= 0]
                print(f"\n  Game over: {', '.join(broke)} eliminated")
                break

        self._generate_report()

    def _seat_order(self) -> List[int]:
        """Return player indices in engine seat order for the current button.

        Seat 0 posts the small blind and seat 1 the big blind (that is the
        order of ``raw_blinds_or_straddles``). The order is chosen so that it
        agrees with ``_get_position_name``: heads-up the button is the small
        blind, otherwise the small blind sits directly after the button.
        """
        first = self.button if self.num_players == 2 else self.button + 1
        return [(first + seat) % self.num_players for seat in range(self.num_players)]

    @staticmethod
    def _card_code(card) -> str:
        """Return the two-character code (e.g. ``"As"``) for a card.

        Objects exposing ``rank`` and ``suit`` are rendered from those
        attributes (using ``.value`` when present); anything else goes
        through ``str()``.
        """
        # NOTE(review): pokerkit's str(Card) appears to be a long description
        # rather than the short code - confirm against the installed version.
        rank = getattr(card, "rank", None)
        suit = getattr(card, "suit", None)
        if rank is None or suit is None:
            return str(card)
        return f"{getattr(rank, 'value', rank)}{getattr(suit, 'value', suit)}"

    def _void_hand(self, hand_record: dict, reason: str) -> None:
        """Record a hand that could not be completed; stacks stay as they were."""
        hand_record["error"] = reason
        hand_record["winner"] = []
        hand_record["stacks_after"] = dict(hand_record["stacks_before"])
        self.stats["hand_history"].append(hand_record)
        print(f" voided ({reason})", end="")

    def _play_hand(self):
        """Play one hand, then update stacks, statistics and the hand history.

        If a player returns an ``error`` action, or the engine rejects every
        action, the hand is voided: it is recorded with an ``error`` entry
        and no chips change hands.
        """
        self.hand_num += 1
        self.stats["hands_played"] = self.hand_num
        self.button = (self.button + 1) % self.num_players

        # seats[s] is the index (into self.players / self.stacks) of the
        # player sitting in engine seat s for this hand.
        seats = self._seat_order()
        stacks_before = self.stacks.copy()

        hand_record = {
            "hand_num": self.hand_num,
            "button": self.players[self.button].name,
            "stacks_before": {p.name: s for p, s in zip(self.players, self.stacks)},
            "actions": [],
            "winner": None,
            "pot": 0,
        }

        try:
            # NOTE(review): pokerkit's documentation discusses blind ordering
            # for heads-up tables - confirm (small, big) gives the intended
            # acting order there.
            state = NoLimitTexasHoldem.create_state(
                automations=(
                    Automation.ANTE_POSTING,
                    Automation.BET_COLLECTION,
                    Automation.BLIND_OR_STRADDLE_POSTING,
                    Automation.CARD_BURNING,
                    Automation.HOLE_DEALING,
                    Automation.HOLE_CARDS_SHOWING_OR_MUCKING,
                    Automation.HAND_KILLING,
                    Automation.CHIPS_PUSHING,
                    Automation.CHIPS_PULLING,
                ),
                ante_trimming_status=True,
                raw_antes={-1: 0},
                raw_blinds_or_straddles=(self.small_blind, self.big_blind),
                min_bet=self.big_blind,
                raw_starting_stacks=[self.stacks[p] for p in seats],
                player_count=self.num_players,
            )
        except Exception as e:
            print(f" Error: {e}")
            return

        # Hole cards per engine seat, as short codes the prompt helpers accept.
        hole_cards = []
        for seat in range(self.num_players):
            cards = state.hole_cards[seat]
            if cards and len(cards) >= 2:
                hole_cards.append((self._card_code(cards[0]), self._card_code(cards[1])))
            else:
                hole_cards.append(("??", "??"))

        deck = list(state.get_dealable_cards())
        random.shuffle(deck)

        board = []
        void_reason: Optional[str] = None

        for street in ("Preflop", "Flop", "Turn", "River"):
            if state.status is False:
                break

            if street == "Flop":
                board = [deck.pop(), deck.pop(), deck.pop()]
                for card in board:
                    state.deal_board(card)
            elif street in ("Turn", "River"):
                board.append(deck.pop())
                state.deal_board(board[-1])

            board_strs = [self._card_code(c) for c in board]
            while state.actor_index is not None:
                seat = state.actor_index
                player_idx = seats[seat]
                player = self.players[player_idx]

                action = self._get_ai_action(
                    player, state, hole_cards[seat], board_strs, player_index=player_idx
                )

                hand_record["actions"].append({
                    "street": street,
                    "player": player.name,
                    "action": action.action_type,
                    "amount": action.amount,
                })
                self._track_action(player.name, action)

                if action.action_type == "error":
                    void_reason = action.error_message or "player error"
                    break

                if not self._execute_action(state, action):
                    # Nothing could be applied; continuing would spin forever.
                    void_reason = "engine rejected every action"
                    break

            if void_reason is not None:
                break

        if void_reason is not None:
            # The betting round is unfinished, so the engine's stacks do not
            # include chips still in play; discard them and keep the old ones.
            self._void_hand(hand_record, void_reason)
            return

        # Map the engine's per-seat stacks back to player order.
        for seat, player_idx in enumerate(seats):
            self.stacks[player_idx] = state.stacks[seat]

        # Winners are the players with the largest positive net gain.
        max_gain = 0
        winners = []
        for i in range(self.num_players):
            gain = self.stacks[i] - stacks_before[i]
            if gain > max_gain:
                max_gain = gain
                winners = [i]
            elif gain == max_gain and gain > 0:
                winners.append(i)

        for i in winners:
            winner_stats = self.stats["player_stats"][self.players[i].name]
            winner_stats["hands_won"] += 1
            winner_stats["chips_won"] += max_gain

        for i in range(self.num_players):
            if i not in winners:
                loss = stacks_before[i] - self.stacks[i]
                if loss > 0:
                    self.stats["player_stats"][self.players[i].name]["chips_lost"] += loss

        hand_record["winner"] = [self.players[w].name for w in winners]
        # "pot" holds the winners' combined net gain, not the total pot size.
        hand_record["pot"] = max_gain * len(winners) if winners else 0
        hand_record["stacks_after"] = {p.name: s for p, s in zip(self.players, self.stacks)}
        self.stats["hand_history"].append(hand_record)

    def _get_ai_action(
        self,
        player,
        state,
        hole_cards,
        board,
        player_index: Optional[int] = None,
    ) -> "ParsedAction":
        """Collect the acting seat's view of the state and ask ``player`` to act.

        ``player_index`` is the player's index in ``self.players``, used for
        the position label; when omitted the engine seat index is used.
        """
        seat = state.actor_index
        pot = state.total_pot_amount if hasattr(state, 'total_pot_amount') else 0
        current_bet = max(state.bets) if state.bets else 0
        player_bet = state.bets[seat] if state.bets else 0
        to_call = current_bet - player_bet
        stack = state.stacks[seat]
        position = self._get_position_name(seat if player_index is None else player_index)
        return player.get_action(hole_cards, board, pot, to_call, stack, position, self.num_players)

    def _track_action(self, player_name: str, action: "ParsedAction"):
        """Increment the action counter matching ``action`` (checks count as calls)."""
        stats = self.stats["player_stats"][player_name]
        if action.action_type == "fold":
            stats["folds"] += 1
        elif action.action_type in ("check", "call"):
            stats["calls"] += 1
        elif action.action_type in ("raise", "bet"):
            stats["raises"] += 1
        elif action.action_type == "all_in":
            stats["all_ins"] += 1

    def _execute_action(self, state, action: "ParsedAction") -> bool:
        """Apply ``action`` to the engine, degrading to check/call, then fold.

        Returns:
            True if some action was applied, False if the engine rejected the
            requested action and both fallbacks.
        """
        try:
            kind = action.action_type
            if kind == "fold":
                state.fold()
            elif kind in ("check", "call"):
                state.check_or_call()
            elif kind in ("raise", "bet"):
                state.complete_bet_or_raise_to(action.amount)
            elif kind == "all_in":
                actor = state.actor_index
                # "Raise to" amounts include chips already bet this round.
                state.complete_bet_or_raise_to(state.stacks[actor] + state.bets[actor])
            else:
                raise ValueError(f"unsupported action type: {kind!r}")
            return True
        except Exception:
            # Best effort: an illegal or malformed action is replaced by the
            # most passive legal one below.
            pass

        for fallback in (state.check_or_call, state.fold):
            try:
                fallback()
                return True
            except Exception:
                continue
        return False

    def _get_position_name(self, idx: int) -> str:
        """Return the position label for player ``idx`` relative to the button."""
        positions = ["SB", "BB"] if self.num_players == 2 else ["BTN", "SB", "BB"]
        rel_pos = (idx - self.button) % self.num_players
        return positions[rel_pos] if rel_pos < len(positions) else f"P{idx}"

    def _generate_report(self):
        """Print the session summary and, if configured, write the JSON report."""
        rule = "  " + "-" * 56
        hands_played = self.stats["hands_played"]

        print()
        print("=" * 60)
        print("  SIMULATION REPORT")
        print("=" * 60)
        print(f"  Hands played: {hands_played}")
        print()

        print("  PLAYER STATISTICS:")
        print(rule)
        print(f"  {'Player':<20} {'Won':>6} {'W%':>6} {'P/L':>10} {'Final':>10}")
        print(rule)

        for player, stack in zip(self.players, self.stacks):
            name = player.name
            hands_won = self.stats["player_stats"][name]["hands_won"]
            win_pct = (hands_won / hands_played * 100) if hands_played > 0 else 0
            net = stack - self.starting_stack
            net_str = f"+{net}" if net > 0 else str(net)
            print(f"  {name:<20} {hands_won:>6} {win_pct:>5.1f}% {net_str:>10} {stack:>10}")

        print(rule)
        print()

        print("  ACTION BREAKDOWN:")
        print(rule)
        print(f"  {'Player':<20} {'Folds':>8} {'Calls':>8} {'Raises':>8} {'All-ins':>8}")
        print(rule)

        for player in self.players:
            name = player.name
            stats = self.stats["player_stats"][name]
            print(f"  {name:<20} {stats['folds']:>8} {stats['calls']:>8} {stats['raises']:>8} {stats['all_ins']:>8}")

        print(rule)
        print()

        if self.report_file:
            report = {
                "timestamp": datetime.now().isoformat(),
                "config": {
                    "players": [p.name for p in self.players],
                    "starting_stack": self.starting_stack,
                    "blinds": f"{self.small_blind}/{self.big_blind}",
                    "hands": hands_played,
                },
                "final_stacks": {p.name: s for p, s in zip(self.players, self.stacks)},
                "player_stats": self.stats["player_stats"],
                "hand_history": self.stats["hand_history"],
            }
            with open(self.report_file, "w", encoding="utf-8") as f:
                json.dump(report, f, indent=2)
            print(f"  Report saved to: {self.report_file}")

        print("=" * 60)

In [None]:
# Configuration
DRY_RUN = True  # Set to False for full simulation
DRY_RUN_HANDS = 10  # Hands for quick testing
FULL_RUN_HANDS = 100  # Hands for full simulation

# Number of hands actually played, chosen by the DRY_RUN switch above.
NUM_HANDS = DRY_RUN_HANDS if DRY_RUN else FULL_RUN_HANDS

# Table stakes, in chips.
STARTING_STACK = 10000
SMALL_BLIND = 50
BIG_BLIND = 100

# Output files (saved to Google Drive)
# REPORT_FILE is rewritten by each run; TRACE_FILE is opened in append mode
# by the players, so traces accumulate across runs.
REPORT_FILE = Path("/content/drive/MyDrive/poker_results/simulation_report.json")
TRACE_FILE = Path("/content/drive/MyDrive/poker_results/traces.jsonl")

# Create output directory (both files share the same parent folder)
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)

In [None]:
# Initialize players
# Both constructors raise ValueError when their API key is missing from the
# environment. The two players share one trace file.
gpt4_player = GPT4Player(
    name="GPT-4o",
    model="gpt-4o",
    temperature=0.6,
    trace_file=TRACE_FILE,
)

llama_player = GroqPlayer(
    name="Llama-3.3-70B",
    model="llama-3.3-70b-versatile",
    temperature=0.6,
    trace_file=TRACE_FILE,
)

# Verify connections
# Each check performs a real HTTP request; a failed assert stops the cell
# before any hands are played.
print("Checking API connections...")
assert gpt4_player.check_connection(), "OpenAI API connection failed!"
print("  GPT-4o: OK")
assert llama_player.check_connection(), "Groq API connection failed!"
print("  Llama-3.3-70B: OK")

In [None]:
# Run simulation!
# Builds the game from the configuration constants and plays NUM_HANDS hands
# (fewer if a player runs out of chips). The report is printed and written to
# REPORT_FILE when the session ends.
game = AIPokerGame(
    players=[gpt4_player, llama_player],
    starting_stack=STARTING_STACK,
    small_blind=SMALL_BLIND,
    big_blind=BIG_BLIND,
    report_file=REPORT_FILE,
)

game.play_session(NUM_HANDS)

In [None]:
# View the JSON report
# Reads back the file written at the end of the session; `report` is reused
# by the hand-history cell below.
with open(REPORT_FILE) as f:
    report = json.load(f)

print("Final Results:")
print(json.dumps(report["final_stacks"], indent=2))
print("\nPlayer Stats:")
print(json.dumps(report["player_stats"], indent=2))

In [None]:
# Show the winner, pot and action list for the three most recent hands
# stored in the loaded report.
print("Sample Hand History (last 3 hands):")
recent_hands = report["hand_history"][-3:]
for hand in recent_hands:
    print(f"\nHand #{hand['hand_num']}:")
    print(f"  Winner: {hand['winner']}")
    print(f"  Pot: {hand['pot']}")
    print("  Actions:")
    for entry in hand['actions']:
        # Amounts that are missing or zero are left out of the line.
        suffix = f" {entry['amount']}" if entry['amount'] else ""
        print(f"    [{entry['street']}] {entry['player']}: {entry['action']}{suffix}")

In [None]:
# Print the five most recent prompt/response pairs from the JSONL trace file.
print("Last 5 LLM Reasoning Traces:")
print("=" * 60)

with open(TRACE_FILE) as f:
    traces = [json.loads(line) for line in f]

for trace in traces[-5:]:
    print(f"\n[{trace['player']} - {trace['model']}]")
    print(f"Prompt:\n{trace['prompt']}")
    # Long responses are cut to 500 characters and marked with an ellipsis.
    response = trace['response']
    shown = f"{response[:500]}..." if len(response) > 500 else response
    print(f"\nResponse:\n{shown}")
    print("-" * 40)