In [None]:
import random
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Optional, Tuple
import json
!pip install trl==0.11
import torch
from datasets import Dataset, load_from_disk, DatasetDict
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer, PreTrainedTokenizerFast, Trainer, TrainingArguments
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers
from tokenizers.processors import TemplateProcessing
import os
torch.cuda.is_available()

In [None]:
class Order:
    """One bid or offer posted by a single player for a single suit at a single price."""

    def __init__(self, player_id: int, side: str, suit: str, price: int):
        """
        Store the terms of the order.

        player_id: id of the player who posted the order
        side: 'bid' or 'offer'
        suit: one of 'spades', 'clubs', 'hearts', 'diamonds'
        price: integer price
        """
        self.player_id, self.side = player_id, side
        self.suit, self.price = suit, price

    def to_dict(self) -> dict:
        """Return the order as a plain dict with keys player, side, suit, price."""
        return dict(player=self.player_id, side=self.side, suit=self.suit, price=self.price)

    def __repr__(self):
        """Compact one-line description used when printing order books."""
        fields = f"p={self.player_id}, side={self.side}, suit={self.suit}, price={self.price}"
        return f"Order({fields})"

class OrderBook:
    """Orders for one (suit, side) pair, holding at most one order per price level."""

    def __init__(self, suit: str, side: str):
        """Create an empty book for the given suit and side."""
        self.suit = suit
        self.side = side  # 'bid' or 'offer'
        self.orders: Dict[int, Order] = {}  # price -> Order

    def get_best(self) -> Optional[Order]:
        """Return the highest-priced order of a bid book, otherwise the lowest-priced; None when empty."""
        if not self.orders:
            return None
        # Iterating the dict yields its price keys, so max/min pick the best price directly.
        pick = max if self.side == 'bid' else min
        return self.orders[pick(self.orders)]

    def add_order(self, order: Order) -> bool:
        """Insert the order unless its price level is already taken; report whether it was added."""
        if order.price not in self.orders:
            self.orders[order.price] = order
            return True
        return False

    def cancel_order(self, player_id: int, price: int) -> bool:
        """Remove the order at `price` only if it belongs to `player_id`; report whether it was removed."""
        resting = self.orders.get(price)
        if not resting or resting.player_id != player_id:
            return False
        del self.orders[price]
        return True

    def clear(self) -> None:
        """Drop every order while keeping the same underlying dict object."""
        self.orders.clear()

    def to_dict(self) -> dict:
        """Return a price -> player_id mapping of the current orders."""
        owners = {}
        for level, resting in self.orders.items():
            owners[level] = resting.player_id
        return owners

    def __repr__(self):
        """Header naming the book followed by its orders in insertion order."""
        listing = ", ".join(map(str, self.orders.values()))
        return f"OrderBook(suit={self.suit}, side={self.side}): " + listing

In [None]:
class Player:
    """A participant: identity, chip stack, hand, and bookkeeping of open orders and passes."""

    def __init__(self, player_id: int, name: Optional[str] = None, player_type: str = "human"):
        """
        Create a player with 350 chips and an empty hand.

        player_id: identifier stored on the player and on the orders it posts
        name: display name; defaults to "Player_<player_id>" when None
        player_type: "human" or "bot"
        """
        self.player_id = player_id
        if name is None:
            name = f"Player_{player_id}"
        self.name = name
        self.player_type = player_type  # "human" or "bot"
        self.initial_chips = self.chips = 350
        self.initial_hand: Dict[str, int] = {}
        self.hand: Dict[str, int] = {}
        self.active_orders: List[Order] = []
        self.pass_count = 0

    def decide_action(self, public_state: dict) -> Tuple[str, dict]:
        """
        Default policy: always pass. Subclasses override this for human or RL agent play.

        Returns a tuple (action_type, parameters):
          'place'  -> parameters = {'side': 'bid'/'offer', 'suit': <suit>, 'price': <int>}
          'cancel' -> parameters = {'side': 'bid'/'offer', 'suit': <suit>, 'price': <int>}
          'pass'   -> parameters = {}
        """
        no_params: dict = {}
        return ('pass', no_params)

    def __repr__(self):
        """Name followed by player type and current chip count."""
        details = f"{self.player_type}, chips={self.chips}"
        return f"{self.name}({details})"

In [None]:
class Game:
    """
    Turn-based simulation of a four-suit card trading game for 4 or 5 players.

    Every player antes into a shared pot and is dealt cards from one of twelve
    predefined 40-card decks. Players then place, cancel or pass through
    per-suit bid/offer books; a crossing order trades one card and wipes every
    book. The game ends after two complete rounds of passes, after which the
    pot is settled against the hidden goal suit.

    NOTE(review): matched trades look up the counterparty with
    ``self.players[order.player_id]``, so caller-supplied players presumably
    need ``player_id`` equal to their position in the list - confirm with callers.
    """

    SUITS = ['spades', 'clubs', 'hearts', 'diamonds']
    COLORS = {'spades': 'black', 'clubs': 'black', 'hearts': 'red', 'diamonds': 'red'}

    # 12 predefined deck configurations (one 12-card suit, two 10-card suits, one 8-card suit).
    DECK_CONFIGS = [
        {'spades': 12, 'clubs': 10, 'hearts': 10, 'diamonds': 8},
        {'spades': 12, 'clubs': 10, 'hearts': 8, 'diamonds': 10},
        {'spades': 12, 'clubs': 8,  'hearts': 10, 'diamonds': 10},
        {'clubs': 12,  'spades': 10, 'hearts': 10, 'diamonds': 8},
        {'clubs': 12,  'spades': 10, 'hearts': 8,  'diamonds': 10},
        {'clubs': 12,  'spades': 8,  'hearts': 10, 'diamonds': 10},
        {'hearts': 12, 'diamonds': 10, 'spades': 10, 'clubs': 8},
        {'hearts': 12, 'diamonds': 10, 'spades': 8,  'clubs': 10},
        {'hearts': 12, 'diamonds': 8,  'spades': 10, 'clubs': 10},
        {'diamonds': 12, 'hearts': 10, 'spades': 10, 'clubs': 8},
        {'diamonds': 12, 'hearts': 10, 'spades': 8,  'clubs': 10},
        {'diamonds': 12, 'hearts': 8,  'spades': 10, 'clubs': 10},
    ]

    def __init__(self, num_players: int = 4, seed: Optional[int] = None, round_number: int = 1, players: Optional[List[Player]] = None):
        """
        Set up players, pot, order books and metadata, then deal the cards.

        num_players: 4 or 5; determines the ante (50 or 40) and hand size (10 or 8).
        seed: when given, seeds the global ``random`` module BEFORE any random
            draw, so the game id, deck choice, goal suit, deal and later turn
            order are reproducible.
        round_number: stored in the exported log only.
        players: optional pre-built players; the ante is deducted from their
            chips in place. Defaults to ``num_players`` new "human" players.

        Raises ValueError for an unsupported player count or a ``players`` list
        whose length differs from ``num_players``.
        """
        if num_players not in (4, 5):
            raise ValueError("Only 4 or 5 players are allowed.")

        # Seed first: previously this happened after the deal, so the seed had
        # no influence on the deck, goal suit or hands.
        if seed is not None:
            random.seed(seed)

        self.num_players = num_players
        if players is not None:
            if len(players) != num_players:
                raise ValueError("Length of provided players list must match num_players.")
            self.players = players
        else:
            self.players: List[Player] = [Player(i, player_type="human") for i in range(num_players)]
        self.ante = 50 if num_players == 4 else 40
        for p in self.players:
            p.chips -= self.ante
        self.pot = self.ante * num_players  # e.g., 200 for 4 players.

        # Create order books: one per (suit, side).
        self.order_books: Dict[Tuple[str, str], OrderBook] = {}
        for suit in Game.SUITS:
            for side in ['bid', 'offer']:
                self.order_books[(suit, side)] = OrderBook(suit, side)

        self.trade_history: List[dict] = []
        self.turn_history: List[dict] = []
        self.turn_number = 0
        self.last_player_id: Optional[int] = None
        self.game_over = False

        # Game metadata.
        self.game_id = random.randint(1, 1000000)
        self.round_number = round_number
        self.host = self.players[0].name
        self.settings = {
            "hand_view": "Default",
            "round_duration": "turn-based",
            "minimum_number_of_players": num_players
        }
        self.human_players = [p.name for p in self.players if p.player_type == "human"]
        self.bot_players = [p.name for p in self.players if p.player_type == "bot"]

        # Choose a deck configuration and deal.
        self.deck_config = random.choice(Game.DECK_CONFIGS)
        # The common suit is the 12-card suit; the goal suit is the other suit of the same colour.
        self.common_suit = [s for s, count in self.deck_config.items() if count == 12][0]
        same_color = [s for s in Game.SUITS if Game.COLORS[s] == Game.COLORS[self.common_suit] and s != self.common_suit]
        self.goal_suit = random.choice(same_color)
        deck = []
        for suit, count in self.deck_config.items():
            deck.extend([suit] * count)
        random.shuffle(deck)
        cards_per_player = 10 if num_players == 4 else 8
        self.starting_chips_and_hands = {}
        for p in self.players:
            p_cards = deck[:cards_per_player]
            deck = deck[cards_per_player:]
            hand = defaultdict(int)
            for c in p_cards:
                hand[c] += 1
            p.hand = dict(hand)
            p.initial_hand = dict(hand)
            # The initial hand always lists all four suits, including zero counts.
            for suit in Game.SUITS:
                if suit not in p.initial_hand:
                    p.initial_hand[suit] = 0
            self.starting_chips_and_hands[p.name] = {"chips": p.chips, "hand": dict(p.initial_hand)}

        # Pass-round tracking.
        self.pass_round = 1  # Which round of passes we're in (1, 2, or 3)
        self.passed_in_round = {p.player_id: False for p in self.players}

    def _clear_pass_flags(self) -> None:
        """Mark every player as not having passed in the current pass round."""
        for p in self.players:
            self.passed_in_round[p.player_id] = False

    def clear_all_order_books(self) -> None:
        """Remove every order from every (suit, side) book."""
        for book in self.order_books.values():
            book.clear()

    def is_valid_place_order(self, player: Player, side: str, suit: str, price: int) -> Tuple[bool, str]:
        """
        Check whether `player` may place the described order.

        Returns (True, "") when acceptable, otherwise (False, reason) where
        reason is one of "format error", "no card to offer",
        "insufficient chips", "order already exists" or
        "does not improve bid/ask". The improvement rule is only applied when
        the opposing book for the suit is empty.
        """
        if side not in ['bid', 'offer'] or suit not in Game.SUITS or not isinstance(price, int):
            return False, "format error"
        if side == 'offer' and player.hand.get(suit, 0) <= 0:
            return False, "no card to offer"
        if side == 'bid' and player.chips < price:
            return False, "insufficient chips"
        book = self.order_books[(suit, side)]
        if price in book.orders:
            return False, "order already exists"
        if side == 'bid':
            opposing = self.order_books[(suit, 'offer')]
            best_offer = opposing.get_best()
            if best_offer is None:
                best_bid = self.order_books[(suit, 'bid')].get_best()
                if best_bid and price <= best_bid.price:
                    return False, "does not improve bid/ask"
        elif side == 'offer':
            opposing = self.order_books[(suit, 'bid')]
            best_bid = opposing.get_best()
            if best_bid is None:
                best_offer = self.order_books[(suit, 'offer')].get_best()
                if best_offer and price >= best_offer.price:
                    return False, "does not improve bid/ask"
        return True, ""

    def get_action(self, player: Player) -> Tuple[str, dict]:
        """
        Ask `player` for an action and screen it.

        A 'place' rejected for any reason other than "does not improve bid/ask"
        is turned into a pass; non-improving orders are passed through
        unchanged. A 'cancel' that does not match one of the player's own
        resting orders becomes ('invalid', {}).
        """
        action = player.decide_action(self.get_public_state())
        if action[0] == 'place':
            side, suit, price = action[1].get('side'), action[1].get('suit'), action[1].get('price')
            valid, reason = self.is_valid_place_order(player, side, suit, price)
            if not valid and reason != "does not improve bid/ask":
                print(f"{player.name}: Order {action[1]} does not improve the bid/ask or is not placeable; treated as pass.")
                return ('pass', {})
        if action[0] == 'cancel':
            side, suit, price = action[1].get('side'), action[1].get('suit'), action[1].get('price')
            book = self.order_books.get((suit, side))
            if not (book and price in book.orders and book.orders[price].player_id == player.player_id):
                return ('invalid', {})
        return action

    def process_action(self, player: Player, action: Tuple[str, dict]) -> None:
        """
        Apply an already-screened action and append a record to ``turn_history``.

        'place' trades immediately against the best opposing order when prices
        cross and that order belongs to another player; otherwise the order is
        added to its book and pass tracking restarts at round 1. 'cancel'
        removes the player's order. 'pass' counts once per player per pass
        round, and a round completes when every player has passed in it.
        """
        self.turn_number += 1
        event_time = datetime.now().isoformat()
        action_type, params = action
        turn_record = {"turn": self.turn_number,
                       "time": event_time,
                       "player": player.name,
                       "action": action_type,
                       "params": params}

        # Optional attributes that some player implementations set; copied into the record when present.
        if hasattr(player, "last_prompt"):
            turn_record["prompt"] = player.last_prompt
        if hasattr(player, "last_generated_action"):
            turn_record["last_generated_action"] = player.last_generated_action

        if action_type == 'invalid':
            turn_record["result"] = "invalid action - turn skipped"
            self.turn_history.append(turn_record)
            return
        if action_type == 'place':
            side, suit, price = params.get('side'), params.get('suit'), params.get('price')
            book = self.order_books[(suit, side)]
            new_order = Order(player.player_id, side, suit, price)
            if side == 'bid':
                opposing_book = self.order_books[(suit, 'offer')]
                best_offer = opposing_book.get_best()
                if best_offer and best_offer.price <= price and best_offer.player_id != player.player_id:
                    # The trade executes at the resting offer's price, not the incoming bid's.
                    self.execute_trade(buyer=player,
                                       seller=self.players[best_offer.player_id],
                                       suit=suit,
                                       price=best_offer.price)
                    turn_record["result"] = "trade executed on bid"
                    self.turn_history.append(turn_record)
                    return
            elif side == 'offer':
                opposing_book = self.order_books[(suit, 'bid')]
                best_bid = opposing_book.get_best()
                if best_bid and best_bid.price >= price and best_bid.player_id != player.player_id:
                    # The trade executes at the resting bid's price, not the incoming offer's.
                    self.execute_trade(buyer=self.players[best_bid.player_id],
                                       seller=player,
                                       suit=suit,
                                       price=best_bid.price)
                    turn_record["result"] = "trade executed on offer"
                    self.turn_history.append(turn_record)
                    return
            if book.add_order(new_order):
                player.active_orders.append(new_order)
                turn_record["result"] = "order placed"
                # Any new order restarts the pass countdown.
                self.pass_round = 1
                self._clear_pass_flags()
            else:
                turn_record["result"] = "order rejected unexpectedly"
        elif action_type == 'cancel':
            side, suit, price = params.get('side'), params.get('suit'), params.get('price')
            book = self.order_books[(suit, side)]
            if book.cancel_order(player.player_id, price):
                player.active_orders = [o for o in player.active_orders if not (o.side == side and o.suit == suit and o.price == price)]
                turn_record["result"] = "order cancelled"
            else:
                turn_record["result"] = "cancel failed unexpectedly"
        elif action_type == 'pass':
            # A player's pass only counts once per pass round.
            if self.passed_in_round.get(player.player_id, False):
                turn_record["result"] = "pass ineffective: already passed in current round"
            else:
                self.passed_in_round[player.player_id] = True
                player.pass_count += 1  # Kept for logging.
                turn_record["result"] = f"passed in round {self.pass_round}"
                # When everyone has passed, advance to the next pass round.
                if all(self.passed_in_round[p.player_id] for p in self.players):
                    turn_record["result"] += " (complete round)"
                    self.pass_round += 1
                    self._clear_pass_flags()
        else:
            turn_record["result"] = "unknown action"
        self.turn_history.append(turn_record)

    def execute_trade(self, buyer: Player, seller: Player, suit: str, price: int) -> None:
        """
        Move `price` chips from buyer to seller and one `suit` card the other way.

        The card only moves if the seller holds one; the chips move regardless.
        The trade is recorded, every order book and every player's active-order
        list is cleared, and pass tracking restarts at round 1.
        """
        buyer.chips -= price
        seller.chips += price
        if seller.hand.get(suit, 0) > 0:
            seller.hand[suit] -= 1
            buyer.hand[suit] = buyer.hand.get(suit, 0) + 1
        event_time = datetime.now().isoformat()
        trade_event = {"turn": self.turn_number, "time": event_time, "buyer": buyer.name, "seller": seller.name, "suit": suit, "price": price}
        self.trade_history.append(trade_event)
        self.clear_all_order_books()
        for p in self.players:
            p.active_orders.clear()

        self.pass_round = 1
        self._clear_pass_flags()
        print(f"Trade executed on turn {self.turn_number}: Buyer {buyer.name} bought {suit} from Seller {seller.name} for {price}")

    def get_public_state(self) -> dict:
        """
        Build the state dict handed to ``Player.decide_action``.

        Contains per-player chips, net card change per suit since the deal,
        pass count and type; every order book as price -> player_id; and the
        trade/turn history lists (the live lists, not copies).
        """
        state = {
            "players": {},
            "order_books": {},
            "trade_history": self.trade_history,
            "turn_history": self.turn_history,
            "game_over": self.game_over,
        }
        for p in self.players:
            net_delta = {s: p.hand.get(s, 0) - p.initial_hand.get(s, 0) for s in Game.SUITS}
            state["players"][p.name] = {"chips": p.chips, "net_delta": net_delta, "pass_count": p.pass_count, "type": p.player_type}
        for (suit, side), book in self.order_books.items():
            state["order_books"][(suit, side)] = book.to_dict()
        return state

    def get_private_state(self) -> dict:
        """Public state plus full hands, deck configuration, common suit and (once the game is over) the goal suit."""
        state = self.get_public_state()
        state["full_hands"] = {p.name: p.hand for p in self.players}
        state["deck_config"] = self.deck_config
        state["common_suit"] = self.common_suit
        state["goal_suit"] = self.goal_suit if self.game_over else None
        return state

    def select_next_player(self) -> Player:
        """Pick a random player other than the one who acted last and remember the choice."""
        valid_players = [p for p in self.players if p.player_id != self.last_player_id]
        chosen = random.choice(valid_players)
        self.last_player_id = chosen.player_id
        return chosen

    def run_turn(self):
        """Play one turn unless the game is over; the game ends once two complete pass rounds have occurred."""
        if self.game_over:
            return
        player = self.select_next_player()
        action = self.get_action(player)
        self.process_action(player, action)
        # Terminate game when two complete rounds have been done.
        if self.pass_round > 2:
            self.game_over = True

    def run_game(self):
        """Run turns until the game is over, then settle the pot."""
        while not self.game_over:
            self.run_turn()
        self.final_scoring()

    def final_scoring(self) -> None:
        """
        Settle the pot against the goal suit.

        Each player receives 10 chips per goal-suit card held. Whatever is
        left of the pot (floored at 0) is split by integer division among the
        players holding the most goal-suit cards.
        """
        bonus_total = 0
        for p in self.players:
            bonus = 10 * p.hand.get(self.goal_suit, 0)
            bonus_total += bonus
            p.chips += bonus
        remaining_pot = self.pot - bonus_total
        if remaining_pot < 0:
            remaining_pot = 0
        goal_counts = {p.name: p.hand.get(self.goal_suit, 0) for p in self.players}
        max_count = max(goal_counts.values())
        winners = [name for name, count in goal_counts.items() if count == max_count]
        if winners:
            share = remaining_pot // len(winners)
            for p in self.players:
                if p.name in winners:
                    p.chips += share

    def export_game_log(self) -> dict:
        """
        Return a dict describing the game: metadata, starting state, turn history and final state.

        The goal suit is only included once the game is over. Final chip counts
        are reported with the ante added back to each player's chips.
        """
        log = {
            "game_id": self.game_id,
            "round_number": self.round_number,
            "end_time": datetime.now().isoformat(),
            "host": self.host,
            "settings": self.settings,
            "human_players": self.human_players,
            "bot_players": self.bot_players,
            "goal_suit": self.goal_suit if self.game_over else None,
            "starting_chips_and_hands": self.starting_chips_and_hands,
            "game_updates": self.turn_history,
            "final_chips_and_hands": {p.name: {"chips": p.chips + self.ante, "hand": p.hand} for p in self.players}
        }
        return log

In [None]:
import re
import glob
import os

def tokenize(s):
    """
    Split an s-expression string into tokens.

    Each parenthesis is its own token; any other run of non-whitespace
    characters forms one atom token. Whitespace only separates tokens.
    """
    pieces = []
    word_start = None  # index where the atom currently being scanned began
    for pos, ch in enumerate(s):
        is_paren = ch in ('(', ')')
        if not is_paren and not ch.isspace():
            if word_start is None:
                word_start = pos
            continue
        # A parenthesis or whitespace ends any atom in progress.
        if word_start is not None:
            pieces.append(s[word_start:pos])
            word_start = None
        if is_paren:
            pieces.append(ch)
    if word_start is not None:
        pieces.append(s[word_start:])
    return pieces

def parse(tokens):
    """
    Consume one expression from the front of `tokens` and return it.

    An atom is returned as its string; a parenthesised group is returned as a
    list of its parsed children. `tokens` is mutated: consumed tokens are
    removed. Returns None for an empty list, raises ValueError when the next
    token is a stray ')'. A group whose closing ')' is missing ends at the end
    of input.
    """
    if not tokens:
        return None
    head = tokens.pop(0)
    if head == ')':
        raise ValueError("Unexpected ')'")
    if head != '(':
        return head
    children = []
    while tokens:
        if tokens[0] == ')':
            tokens.pop(0)  # Remove the ')'
            break
        children.append(parse(tokens))
    return children

def parse_log(log_str):
    """
    Parse an s-expression log into a dict of its top-level entries.

    Every top-level child that is a list with at least two elements
    contributes ``child[0] -> child[1]``; later duplicates overwrite earlier
    ones. Anything else is ignored. Returns an empty dict when the parsed log
    is not a list.
    """
    tree = parse(tokenize(log_str))
    if not isinstance(tree, list):
        return {}
    return {
        node[0]: node[1]
        for node in tree
        if isinstance(node, list) and len(node) >= 2
    }

def _pairs_to_dict(pairs):
    """Collect the ``[key, value, ...]`` sub-lists of `pairs` into a dict; other items are ignored."""
    collected = {}
    for item in pairs:
        if isinstance(item, list) and len(item) >= 2:
            collected[item[0]] = item[1]
    return collected


def parse_game_file(file_path, max_price=60):
    """
    Convert one s-expression game log into a list of ``<EOS>``-terminated token lines.

    Output order: one line per player with starting chips and hand, then one
    line per Order/Trade/Goal event in log order, then one line per payout.

    file_path: path of the log file (read as UTF-8).
    max_price: an Order whose price is at or above this value, or whose price
        is missing or not an integer, causes the whole game to be discarded.
        The default 60 lines up with the numeric tokens 0-59 that
        ``create_dataset`` puts in its vocabulary.

    Returns the list of token lines, or None when the game was discarded.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    data = parse_log(content)

    token_lines = []
    has_bogus = False

    # 1. Extract starting hands
    for player_entry in data.get('starting_chips_and_hands', []):
        if not (isinstance(player_entry, list) and len(player_entry) >= 2):
            continue
        player_name, player_data = player_entry[0], player_entry[1]
        hand_entries = []
        chips = "NA"
        for entry in player_data:
            if isinstance(entry, list) and len(entry) >= 2:
                if entry[0] == 'chips':
                    chips = entry[1]
                elif entry[0] == 'hand':
                    # entry[1] should be a list of (suit count) entries
                    hand_entries = [
                        f"{suit_entry[0]} {suit_entry[1]}"
                        for suit_entry in entry[1]
                        if isinstance(suit_entry, list) and len(suit_entry) >= 2
                    ]
        # Format: "playerName Chips <chips> <suit1> <count1> , <suit2> <count2> , ..."
        token_lines.append(f"{player_name} Chips {chips} " + " , ".join(hand_entries) + " <EOS>")

    # 2. Process game_updates
    payouts = {}
    for event in data.get('game_updates', []):
        if not (isinstance(event, list) and len(event) >= 1):
            continue
        event_type = event[0]
        # Events without a payload are treated as having an empty one instead of raising IndexError.
        payload = event[1] if len(event) >= 2 else []
        if event_type == 'Order':
            order_info = _pairs_to_dict(payload)
            # user and price live in the nested metadata list
            meta_dict = _pairs_to_dict(order_info.get('metadata', []))
            user = meta_dict.get('user', '')
            price = meta_dict.get('price', '')
            try:
                price_ok = int(price) < max_price
            except (TypeError, ValueError):
                # Missing or non-numeric price: discard the game rather than crash.
                price_ok = False
            if not price_ok:
                has_bogus = True
            suit = order_info.get('suit', '')
            direction = order_info.get('direction', '')
            token_lines.append(f"Order {user} {price} {direction} {suit} <EOS>")
        elif event_type == 'Trade':
            trade_info = _pairs_to_dict(payload)
            buyer = trade_info.get('buyer', '')
            seller = trade_info.get('seller', '')
            suit = trade_info.get('suit', '')
            direction = trade_info.get('direction', '')
            price = trade_info.get('price', '')
            # A trade is emitted as the Order of the side named by its direction.
            actor = buyer if direction == "Buy" else seller
            token_lines.append(f"Order {actor} {price} {direction} {suit} <EOS>")
        elif event_type == 'Payout':
            for payout_entry in payload:
                if isinstance(payout_entry, list) and len(payout_entry) >= 2:
                    payouts[payout_entry[0]] = payout_entry[1]
        elif event_type == 'Goal':
            # Goal event: e.g., (Goal Hearts)
            goal = event[1] if len(event) > 1 else ""
            token_lines.append(f"Goal {goal} <EOS>")

    # 3. Output payouts (if any)
    for player, amount in payouts.items():
        token_lines.append(f"Payout {player} {amount} <EOS>")

    if has_bogus:
        return None
    return token_lines

def create_token_dataset(folder_path: str, expected_players: int = 4) -> list:
    """
    Parse every game log directly inside `folder_path` into one text block per game.

    folder_path: directory containing the raw game logs; sub-directories are skipped.
    expected_players: a game is kept only when exactly this many of its token
        lines contain "Chips" (one starting-hand line per player).

    Returns a list of strings, one per kept game: the game's token lines joined
    by newlines followed by the '------------------------' separator line.
    Files are visited in sorted path order so the result is deterministic.
    """
    dataset = []
    for fp in sorted(glob.glob(os.path.join(folder_path, "*"))):
        # The "*" pattern also matches directories, which cannot be opened as logs.
        if not os.path.isfile(fp):
            continue
        token_seq = parse_game_file(fp)
        if not token_seq:
            # None means the game was discarded; an empty list has nothing to emit.
            continue
        hand_lines = sum(1 for tok in token_seq if "Chips" in tok)
        if hand_lines == expected_players:
            # Join each event on a newline; each game becomes one block of text.
            dataset.append("\n".join(token_seq) + '\n------------------------')
    return dataset


# Convert every raw game log under the data folder into token sequences.
folder = "data"  # your data folder
dataset = create_token_dataset(folder)
print(f"Processed {len(dataset)} games into token sequences.\n")

# Persist the sequences, each followed by a blank line, for the preprocessing step.
with open("figgie_full_game_token_dataset.txt", "w", encoding="utf-8") as fout:
    for seq in dataset:
        print(seq, end="\n\n", file=fout)

In [None]:
def compute_order_book(events):
    """
    Replay history events and render the resulting top-of-book as a string.

    events: iterable of event strings. Only events of the form
        "Order <user> <price> <Buy|Sell> <suit> ..." are used; empty strings,
        other event types, non-integer prices and unknown suits are skipped.

    For each suit the best bid (highest Buy price) and best offer (lowest Sell
    price) are tracked together with the player who posted them. Whenever an
    order leaves its suit crossed (best bid >= best offer), the books of ALL
    suits are reset.

    Returns "<suit> <bid_user> <bid_price> @ <offer_price> <offer_user>" for
    Hearts, Diamonds, Clubs and Spades, joined by single spaces; a missing
    side is rendered as an empty string.
    """
    suits = ["Hearts", "Diamonds", "Clubs", "Spades"]

    def empty_book():
        # Each side is stored as (price, player); None marks an empty side.
        return {"bid": (None, None), "offer": (None, None)}

    order_book = {suit: empty_book() for suit in suits}

    for event in events:
        tokens = event.split()
        # Check the length first so empty or short events are skipped instead of raising IndexError.
        if len(tokens) < 5 or tokens[0] != "Order":
            continue
        user = tokens[1]
        try:
            price = int(tokens[2])
        except ValueError:
            continue
        direction, suit = tokens[3], tokens[4]
        if suit not in suits:
            continue

        if direction == "Buy":
            current_bid_price, _ = order_book[suit]["bid"]
            # Update bid if no current bid or new price is higher.
            if current_bid_price is None or price > current_bid_price:
                order_book[suit]["bid"] = (price, user)
        elif direction == "Sell":
            current_offer_price, _ = order_book[suit]["offer"]
            # Update offer if no current offer or new price is lower.
            if current_offer_price is None or price < current_offer_price:
                order_book[suit]["offer"] = (price, user)

        # After processing this order, check if the market is crossed for this suit.
        bid_price, _ = order_book[suit]["bid"]
        offer_price, _ = order_book[suit]["offer"]
        if bid_price is not None and offer_price is not None and bid_price >= offer_price:
            # Market is crossed: reset the order book for all suits.
            for name in suits:
                order_book[name] = empty_book()

    # Build a string representation for the order book.
    ob_parts = []
    for suit in suits:
        bid_price, bid_user = order_book[suit]["bid"]
        offer_price, offer_user = order_book[suit]["offer"]
        bid_str = f"{bid_user} {bid_price}" if bid_price is not None else ""
        offer_str = f"{offer_price} {offer_user}" if offer_price is not None else ""
        ob_parts.append(f"{suit} {bid_str} @ {offer_str}")

    return " ".join(ob_parts)

In [None]:
# Your custom preprocessing function
def preprocess_game(file, include_order_book=False):
    """
    Turn the full-game token file into next-action training sequences.

    file: path to a text file of games separated by the
        '------------------------' marker, each game being a series of
        '<EOS>'-terminated lines (starting hands, then Order/Trade events).
    include_order_book: when True, each sequence also carries an
        "<ORDERBOOK>" section computed by ``compute_order_book`` and ends with
        " [EOS]"; when False it ends with " <|endoftext|>".

    For every consecutive pair of history events one sequence is produced: the
    history up to the earlier event is the context, the later event is the
    target action, and the hand shown is the starting hand of the player
    acting in the target. Malformed hand or event lines are skipped.

    Writes the list of sequences as JSON to 'figgie_train.json' or
    'figgie_train_with_order_book.json' in the working directory. Returns None.
    """
    # Use a context manager so the handle is closed; the notebook writes its token file as UTF-8.
    with open(file, "r", encoding="utf-8") as fin:
        games = fin.read().split('------------------------')
    dataset = []
    player_prefixes = ('apple', 'banana', 'cantaloupe', 'durian')

    for game in games:
        lines = game.strip().split('<EOS>')
        player_hands = {}
        history = []

        # Extract initial hands and the event history
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line.startswith(player_prefixes):
                parts = line.split('Chips')
                if len(parts) < 2:
                    continue  # not a "<player> Chips <chips> <hand>" line
                player, rest = parts[0].strip(), parts[1].strip()
                idx = line.find("Clubs")
                if idx != -1:
                    hand_text = line[idx:]  # substring starting with "Clubs"
                else:
                    # No "Clubs" anchor: use whatever follows the chip count instead of a stray character.
                    after_chips = rest.split(None, 1)
                    hand_text = after_chips[1] if len(after_chips) == 2 else ""
                player_hands[player] = hand_text.replace(', ', '')  # store player's initial hand line
            elif line.startswith('Order') or line.startswith('Trade'):
                history.append(line)

        # Build one sequence per (history prefix, next action) pair
        for i in range(len(history) - 1):
            next_action = history[i + 1]
            parts = next_action.split()
            if next_action.startswith("Order"):
                if len(parts) < 2:
                    continue  # no player token
                acting_player = parts[1]
            elif next_action.startswith("Trade"):
                if len(parts) < 5:
                    continue  # too short to name both players and the direction
                player1, player2, action_type = parts[1], parts[2], parts[4]
                acting_player = player1 if action_type == "Buy" else player2
            else:
                continue

            if acting_player not in player_hands:
                continue  # safety check

            context = f"<player> {acting_player} " \
                      f"<HAND> {player_hands[acting_player]} <HISTORY> " \
                      + ' '.join(history[:i + 1])
            # NOTE(review): the two variants end with different end-of-sequence
            # markers; confirm each matches the tokenizer it is trained with.
            if not include_order_book:
                sequence = context + " <ACTION> " + next_action + " <|endoftext|>"
            else:
                current_order_book = compute_order_book(history[:i + 1])
                sequence = context + f" <ORDERBOOK> {current_order_book}" + " <ACTION> " + next_action + " [EOS]"

            dataset.append(sequence)

    out_path = 'figgie_train.json' if not include_order_book else 'figgie_train_with_order_book.json'
    with open(out_path, 'w') as f:
        json.dump(dataset, f)

In [None]:
# Build the order-book variant of the training set (writes figgie_train_with_order_book.json).
preprocess_game("figgie_full_game_token_dataset.txt", include_order_book=True)

In [None]:
# Reload the generated sequences and preview the first 100 of them.
with open('figgie_train_with_order_book.json', 'r') as file:
    data = json.loads(file.read())
print(data[:100])

In [None]:
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast


def create_dataset(include_order_book=False):
    """Tokenize the Figgie training sequences and cache the result on disk.

    Builds a closed-vocabulary word-level tokenizer, then either loads an
    already-processed dataset from disk or reads the raw JSON list of
    sequences, splits it 90/10 into train/eval, tokenizes every sequence to a
    fixed length of 1024 tokens, copies ``input_ids`` into ``labels`` and saves
    the result.

    Args:
        include_order_book: When True, read ``figgie_train_with_order_book.json``
            and cache to ``figgie_tokenized_dataset_with_order_book``; otherwise
            use ``figgie_train.json`` / ``figgie_tokenized_dataset``.

    Returns:
        The tokenized ``DatasetDict`` with ``"train"`` and ``"eval"`` splits
        (previously the function returned nothing, so existing callers that
        ignore the result are unaffected).
    """
    # Closed vocabulary. Any whitespace-delimited word not listed here is
    # encoded as [UNK].
    # NOTE(review): "Trade" and "Cancel" are not listed although history lines
    # can start with them — confirm that mapping them to [UNK] is intended.
    allowed_vocab = [
        "<player>", "<HAND>", "<HISTORY>", "<ACTION>", "<ORDERBOOK>",
        "Spades", "Hearts", "Diamonds", "Clubs",
        "Order", "Buy", "Sell",
        "apple", "banana", "cantaloupe", "durian",
        "Chips", "@",
        "[UNK]", "[PAD]", "[EOS]"  # Special tokens
    ] + [str(i) for i in range(0, 60)]

    # Contiguous ids 0..n-1 in list order.
    new_vocab = {token: i for i, token in enumerate(allowed_vocab)}

    # Word-level model with an explicit unknown token, split on whitespace only.
    tokenizer_model = models.WordLevel(vocab=new_vocab, unk_token="[UNK]")
    tokenizer = Tokenizer(tokenizer_model)
    tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

    # pad_id must be given explicitly: it defaults to 0, which in this
    # vocabulary is "<player>", not "[PAD]".
    tokenizer.enable_padding(pad_id=new_vocab["[PAD]"], pad_token="[PAD]")

    # Wrap for compatibility with Transformers and register the special tokens.
    tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
    tokenizer.eos_token = "[EOS]"
    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    print(tokenizer.vocab)

    # Input/output locations depend on the dataset variant.
    if not include_order_book:
        raw_dataset_path = 'figgie_train.json'
        processed_dataset_path = 'figgie_tokenized_dataset'
    else:
        raw_dataset_path = 'figgie_train_with_order_book.json'
        processed_dataset_path = 'figgie_tokenized_dataset_with_order_book'

    # Reuse a previously processed dataset to avoid redundant work.
    if os.path.exists(processed_dataset_path):
        print("Loading processed dataset from disk...")
        tokenized_datasets = load_from_disk(processed_dataset_path)
    else:
        with open(raw_dataset_path) as f:
            texts = json.load(f)

        # Sequential split: first 90% train, last 10% eval (no shuffling).
        split_idx = int(0.9 * len(texts))
        datasets = DatasetDict({
            "train": Dataset.from_dict({"text": texts[:split_idx]}),
            "eval": Dataset.from_dict({"text": texts[split_idx:]})
        })

        def tokenize_function(examples):
            # Pad/truncate every sequence to exactly 1024 tokens.
            return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=1024)

        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            num_proc=4,
            remove_columns=["text"],
            load_from_cache_file=True,
            desc="Tokenizing dataset"
        )

        def set_labels(examples):
            # Next-token prediction: labels start as a copy of the inputs.
            examples["labels"] = examples["input_ids"].copy()
            return examples

        tokenized_datasets = tokenized_datasets.map(
            set_labels, batched=False, num_proc=4, desc="Setting Labels"
        )

        # Save processed dataset to disk for easy reloading.
        tokenized_datasets.save_to_disk(processed_dataset_path)
        print("Processed dataset saved to disk.")

    print("Dataset ready to use.")
    return tokenized_datasets

In [None]:
# Build (or load from the on-disk cache) the tokenized order-book dataset.
create_dataset(include_order_book=True)

In [None]:
# Module-level copy of the closed-vocabulary tokenizer used by the training and
# game-playing cells. It must stay identical to the one built in create_dataset.
allowed_vocab = [
    "<player>", "<HAND>", "<HISTORY>", "<ACTION>", "<ORDERBOOK>",
    "Spades", "Hearts", "Diamonds", "Clubs",
    "Order", "Buy", "Sell",
    "apple", "banana", "cantaloupe", "durian",
    "Chips", "@",
    "[UNK]", "[PAD]", "[EOS]"  # Special tokens
] + [str(i) for i in range(0, 60)]

# Build a new vocabulary dictionary with contiguous IDs (0 to n-1)
new_vocab = {token: i for i, token in enumerate(allowed_vocab)}

# Initialize a WordLevel model with the new vocabulary and an explicit unk_token.
tokenizer_model = models.WordLevel(vocab=new_vocab, unk_token="[UNK]")
tokenizer = Tokenizer(tokenizer_model)

# Set the pre-tokenizer to use whitespace splitting.
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()

# Enable padding with the [PAD] token. pad_id is passed explicitly because its
# default is 0, which in this vocabulary is "<player>" rather than "[PAD]".
tokenizer.enable_padding(pad_id=new_vocab["[PAD]"], pad_token="[PAD]")

# Convert to a PreTrainedTokenizerFast for compatibility with Transformers.
tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
tokenizer.eos_token = "[EOS]"
tokenizer.pad_token = "[PAD]"
tokenizer.unk_token = "[UNK]"

In [None]:
tokenized_datasets = load_from_disk('figgie_tokenized_dataset_with_order_book')

# Round-trip the first training example back to text as a sanity check.
# skip_special_tokens=True removes [PAD], [EOS] and [UNK] from the output, so
# the label below says so (it previously claimed the opposite). Note that this
# also hides any word that fell outside the vocabulary.
example = tokenized_datasets["train"][0]
decoded_text = tokenizer.decode(example["input_ids"], skip_special_tokens=True)
print("Decoded text (special tokens skipped):")
print(decoded_text)

In [None]:
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments, GPT2Config

# Randomly initialised GPT-2 sized to the custom word-level vocabulary. The
# special-token ids are taken from the tokenizer: GPT2Config otherwise keeps the
# stock GPT-2 ids (50256), which do not exist in this small vocabulary. The
# tokenizer defines no BOS token, so EOS doubles as BOS (the usual GPT-2 setup).
config = GPT2Config(
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# The embedding matrix is already created with vocab_size rows, so no
# resize_token_embeddings call is needed.
model = GPT2LMHeadModel(config)
model.to('cuda')

# Data Collator for next-token prediction (MLM = False means causal LM)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)

# Set up training arguments with evaluation enabled
training_args = TrainingArguments(
    output_dir='./figgie_final_new_tokenizer',
    eval_strategy='steps',  # evaluate during training
    eval_steps=3000,
    logging_steps=10,
    learning_rate=5e-5,
    save_steps=2000,
    save_total_limit=10,
    logging_dir='./logs',
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    dataloader_num_workers=4,
)

# Train using the Trainer API on the splits loaded in the previous cell.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    data_collator=data_collator,
)

trainer.train()

In [None]:
from transformers import LogitsProcessor, LogitsProcessorList

class CustomConstraintLogitsProcessor(LogitsProcessor):
    """
    Restricts the first three generated tokens to the shape
      <number (1-59)> <Buy|Sell> <Clubs|Diamonds|Hearts|Spades>
    by masking every other vocabulary entry to -inf. Tokens generated after
    the third are left unconstrained.

    The attribute ``prompt_length`` (number of prompt tokens) must be assigned
    before generation; it is how the processor knows which position it is at.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

        def first_id(word):
            # Id of the first token the tokenizer produces for `word`.
            return self.tokenizer.encode(word, add_special_tokens=False)[0]

        # Allowed token ids per generated position, computed once up front.
        self.valid_numbers = [first_id(str(number)) for number in range(1, 60)]
        self.valid_sides = [first_id("Buy"), first_id("Sell")]
        self.valid_suits = [first_id(suit_name)
                            for suit_name in ["Clubs", "Diamonds", "Hearts", "Spades"]]
        self.prompt_length = None  # Set by the caller before each generation

    def __call__(self, input_ids, scores):
        if self.prompt_length is None:
            raise ValueError("prompt_length must be set on CustomConstraintLogitsProcessor before generation.")

        # Number of tokens produced so far beyond the prompt.
        position = input_ids.shape[1] - self.prompt_length

        # Past the third generated token nothing is constrained.
        if position > 2:
            return scores

        if position < 1:
            allowed = self.valid_numbers
        elif position == 1:
            allowed = self.valid_sides
        else:
            allowed = self.valid_suits

        # Start from all -inf and copy through only the allowed columns.
        constrained = torch.full_like(scores, float("-inf"))
        constrained[:, allowed] = scores[:, allowed]
        return constrained

In [None]:
class SupervisedPlayerWithOrderBook(Player):
    """Bot player whose next order is sampled from a causal language model.

    On each turn the player renders its hand, the public turn history and the
    derived order book into a prompt ending in ``<ACTION> Order <name>``, then
    samples three constrained tokens (price, side, suit) from the model and
    turns them into a ``place`` action. The prompt and generated token ids of
    the most recent decision are kept in ``last_prompt`` and
    ``last_generated_action``.

    Relies on ``Player``, ``Game``, ``compute_order_book`` and
    ``CustomConstraintLogitsProcessor`` defined elsewhere in the notebook, and
    reads ``self.name`` / ``self.initial_hand``, which are not set in this
    class (presumably provided by ``Player`` — confirm).
    """

    # Maximum total sequence length (prompt + generated tokens) fed to the model.
    MAX_CONTEXT_TOKENS = 1024

    def __init__(self, player_id, name, model, tokenizer, verbose=False, override_name=None):
        super().__init__(player_id, name, player_type="bot")
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = 3          # price, side, suit
        self.last_prompt = ""            # becomes a 1-D tensor of prompt ids after the first decision
        self.last_generated_action = ""  # becomes a 1-D tensor of generated ids after the first decision
        self.verbose = verbose
        self.override_name = override_name

    def build_prompt(self, player_hand, current_order_book, formatted_history):
        """Assemble the prompt, keeping as much recent history as fits.

        History events are added newest-first; the first event that would push
        the tokenized prompt past ``MAX_CONTEXT_TOKENS - max_new_tokens`` stops
        the process, so older events are dropped as a block.
        """
        fixed_prefix = f"<player> {self.name} <HAND> {player_hand} <HISTORY> "
        fixed_suffix = f" <ORDERBOOK> {current_order_book} <ACTION> Order {self.name}"

        current_history = ""
        # Leave room for the tokens that will be generated.
        dangerous_limit = self.MAX_CONTEXT_TOKENS - self.max_new_tokens

        # Walk the history from most recent to oldest, prepending each event.
        for event in reversed(formatted_history):
            candidate_history = event + " " + current_history if current_history else event
            candidate_prompt = fixed_prefix + candidate_history + fixed_suffix
            token_ids = self.tokenizer.encode(candidate_prompt, add_special_tokens=False)
            if len(token_ids) <= dangerous_limit:
                current_history = candidate_history
            else:
                # Adding this (older) event would overflow the context.
                break

        return fixed_prefix + current_history.strip() + fixed_suffix

    def swap_names(self, text, name1, name2):
        """Exchange every occurrence of ``name1`` and ``name2`` in ``text``."""
        # A placeholder keeps the first replacement from being undone by the second.
        tmp = "__TMP__"
        return text.replace(name1, tmp).replace(name2, name1).replace(tmp, name2)

    def _format_history(self, turn_history):
        """Render turn records as the text lines used in the training data."""
        formatted_history = []
        for turn in turn_history:
            action = turn["action"]
            params = turn["params"]
            player = turn["player"]

            if action == "place":
                side = "Buy" if params["side"] == "bid" else "Sell"
                suit = params["suit"].capitalize()
                price = params["price"]
                formatted_history.append(f"Order {player} {price} {side} {suit}")
            elif action == "cancel":
                side = "Buy" if params["side"] == "bid" else "Sell"
                suit = params["suit"].capitalize()
                price = params["price"]
                formatted_history.append(f"Cancel {player} {price} {side} {suit}")
            elif action == "trade":
                buyer = params["buyer"]
                seller = params["seller"]
                suit = params["suit"].capitalize()
                price = params["price"]
                formatted_history.append(f"Trade {buyer} {seller} {suit} {price}")
            # Any other action type (e.g. a pass) is left out of the history.
        return formatted_history

    def decide_action(self, public_state: dict):
        """Sample the next action from the model.

        Args:
            public_state: Must contain ``"turn_history"``, a list of dicts with
                ``"action"``, ``"params"`` and ``"player"`` keys.

        Returns:
            The result of ``parse_action`` on the decoded generation: either
            ``('place', {...})`` or ``('invalid', {})``.
        """
        # Starting hand rendered in a fixed suit order.
        suit_order = ["clubs", "diamonds", "hearts", "spades"]
        player_hand = " ".join(
            f"{suit.capitalize()} {self.initial_hand[suit]}" for suit in suit_order if suit in self.initial_hand
        )

        formatted_history = self._format_history(public_state["turn_history"])
        current_order_book = compute_order_book(formatted_history)

        input_text = self.build_prompt(player_hand, current_order_book, formatted_history)
        if self.override_name is not None:
            # Present the game to the model with "apple" and override_name exchanged.
            input_text = self.swap_names(input_text, "apple", self.override_name)

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_CONTEXT_TOKENS - self.max_new_tokens,
        ).to('cuda')
        self.last_prompt = inputs["input_ids"].squeeze(0)

        self.model.eval()
        custom_processor = CustomConstraintLogitsProcessor(self.tokenizer)
        custom_processor.prompt_length = inputs["input_ids"].shape[1]
        logits_processor = LogitsProcessorList([custom_processor])
        with torch.no_grad():
            # Use the tokenizer this player was constructed with rather than a
            # notebook-level global, so the ids always match self.model.
            output = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                logits_processor=logits_processor,
                min_length=-1,
                max_new_tokens=self.max_new_tokens,
                do_sample=True,
                temperature=0.8,
                top_p=0.8,
                early_stopping=False,
                bad_words_ids=[self.tokenizer.encode('@', add_special_tokens=False)],
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

        # Keep only the newly generated ids (everything after the prompt).
        self.last_generated_action = output[0][inputs["input_ids"].shape[-1]:]
        predicted_action = self.tokenizer.decode(self.last_generated_action, skip_special_tokens=False).strip()
        if self.override_name is not None:
            predicted_action = self.swap_names(predicted_action, "apple", self.override_name)
        return self.parse_action(predicted_action)

    def parse_action(self, action_str):
        """Turn generated text ``"<price> <side> <suit>"`` into a game action.

        Only the first three whitespace-separated tokens are read. Returns
        ``('place', {'side': 'bid'|'offer', 'suit': <lowercase suit>,
        'price': int})`` on success and ``('invalid', {})`` when the text has
        fewer than three tokens, the price is not all digits, the side is not
        buy/sell, or the suit is not in ``Game.SUITS``. Never raises: any
        exception is reported (when verbose) and mapped to ``('invalid', {})``.
        """
        def reject():
            if self.verbose:
                print(action_str)
                print("Malformed order action, marking as invalid.")
            return ('invalid', {})

        try:
            parts = action_str.split()
            # Indexing (rather than unpacking) so that a short generation
            # surfaces as an IndexError handled below.
            price, side, suit = parts[0], parts[1], parts[2]

            if not price.isdigit():
                return reject()
            price = int(price)

            if side.lower() != 'buy' and side.lower() != 'sell':
                return reject()
            if suit.lower() not in Game.SUITS:
                return reject()

            side = 'bid' if side.lower() == 'buy' else 'offer'
            if self.verbose:
                print(action_str)
                print('Order placed sucessfully')
            return 'place', {'side': side, 'suit': suit.lower(), 'price': price}
        except Exception as e:
            if self.verbose:
                print(f"Error parsing action: {e}. Marking as invalid.")
            return ('invalid', {})

In [None]:
# Play one full game as a smoke test: seat 0 uses `model`, the other three
# seats use `model_ref`.
# NOTE(review): within the visible notebook `model_ref` is first assigned in the
# PPO cell further down, so this cell likely fails on a fresh kernel unless that
# cell has been run first — confirm and reorder if so.
names = ["apple", "banana", "cantaloupe", "durian"]
players = [
    SupervisedPlayerWithOrderBook(
        seat, names[seat], model if seat == 0 else model_ref, tokenizer, verbose=False
    )
    for seat in range(4)
]
game = Game(num_players=4, players=players)
game.run_game()

In [None]:
def compute_reward(game, action_str, turn, goal_suit, starting_chips=350):
    """Score a single turn of the RL player.

    Args:
        game: Object exposing ``get_private_state()`` whose result contains
            ``["full_hands"][player]["hand"]`` (a suit -> count mapping).
        action_str: Decoded model output, ``"<price> <Buy|Sell> <Suit>"``.
        turn: Turn record; ``turn["result"]`` is matched case-insensitively
            and ``turn["player"]`` names the acting player.
        goal_suit: The goal suit of the game (any casing).
        starting_chips: Unused; kept for interface compatibility.

    Returns:
        * ``-3.0`` when the result marks the turn as invalid/skipped.
        * ``0.0`` when neither an order was placed nor a trade executed, or
          the side is neither buy nor sell.
        * For a non-goal suit: ``-price`` for a buy, ``+price`` for a sell.
        * For the goal suit: the marginal value of the card gained/lost
          (table below) netted against the price.

    NOTE(review): the hand is read from the game at call time. If this is
    called after the game has finished, the "count before the action" is only
    an approximation — confirm with the caller.
    """
    result = turn.get("result", "").lower()
    if "invalid action - turn skipped" in result:
        print(action_str)
        print("PUNISHED")
        return -3.0  # Penalize invalid / skipped actions immediately.

    trade_executed = "trade executed" in result  # A trade actually occurred.
    order_placed = "order placed" in result
    if not (trade_executed or order_placed):
        # Passing earns nothing. Checked before parsing so that a malformed
        # action string on such a turn cannot raise.
        return 0.0

    parts = action_str.split()
    price = int(parts[0])
    action_type = parts[1].lower()
    suit = parts[2].lower()

    if action_type not in ("buy", "sell"):
        return 0.0

    if suit != goal_suit.lower():
        # For non-goal suits, the reward is simply the cash flow.
        return -price if action_type == "buy" else price

    # Goal suit: value of holding the k-th card, for transitions
    # 0->1, 1->2, 2->3, 3->4, 4->5. Beyond that a flat value is used.
    marginal_values = [10, 11, 13, 20, 50]
    flat_value = 10

    player_name = turn["player"]
    private_state = game.get_private_state()
    hand = private_state["full_hands"][player_name].get("hand", {})
    held = hand.get(suit, 0)

    # Undo the effect of an executed trade to get the count before it:
    # a buy added one card, a sell removed one.
    if trade_executed:
        count_before = held - 1 if action_type == "buy" else held + 1
    else:
        count_before = held
    # Never let a negative count wrap around to the end of the table.
    count_before = max(count_before, 0)

    if action_type == "buy":
        if count_before >= len(marginal_values):
            # Already at or above the optimum: extra cards are worth little.
            return flat_value - price
        # Marginal benefit of the next card minus its cost.
        return marginal_values[count_before] - price

    # Sell: proceeds minus the marginal value of the card given up.
    if count_before > len(marginal_values):
        return price - flat_value
    return price - marginal_values[max(count_before - 1, 0)]

In [None]:
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from tqdm import tqdm

# PPO fine-tuning of the supervised checkpoint through self-play.
# Each block plays `games_per_update` games with the trainable model in seat 0
# ("apple") against three copies of the frozen reference model, then runs PPO
# steps on the turns collected for seat 0.
ppo_config_dict = {
    "batch_size": 16,             # Keep it small if data per update is limited.
    "mini_batch_size": 8,
    "learning_rate": 1e-5,        # Lower learning rate to ensure smoother updates.
    "init_kl_coef": 0.5,          # Increase KL coefficient to penalize divergence more.
    "target_kl": 0.5,             # Target KL used by the adaptive KL controller.
    "cliprange": 0.1,             # Lower clip range to limit aggressive policy updates.
    "max_grad_norm": 1.0,         # Add gradient clipping for additional stability.
}

# Trainable policy (with value head), loaded from the supervised checkpoint.
model = AutoModelForCausalLMWithValueHead.from_pretrained("checkpoint-4000-new-tokenizer").to('cuda')
# Load the reference model (a separate copy) for PPO. This model stays fixed during the update.
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("checkpoint-4000-new-tokenizer").to('cuda')
model_ref.eval()
config = PPOConfig(**ppo_config_dict)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

games_per_update = 10  # Run 10 games before each PPO update.
num_blocks = 200       # For example, 200 blocks * 10 games = 2000 games total.

# Initialize lists to accumulate trajectories across games.
all_queries = []
all_responses = []
all_rewards = []
all_stats = []  # One stats dict per PPO step, kept for the whole run.

for block in tqdm(range(num_blocks)):
    for game_idx in tqdm(range(games_per_update)):
        names = ["apple", "banana", "cantaloupe", "durian"]
        players = []
        for i in range(4):
            if i == 0:
                # RL player uses the trainable model.
                players.append(SupervisedPlayerWithOrderBook(i, names[i], model, tokenizer, verbose=False))
            else:
                # Static players use the fixed reference model.
                players.append(SupervisedPlayerWithOrderBook(i, names[i], model_ref, tokenizer, verbose=False))
        game = Game(num_players=4, players=players)
        # Generation during play runs in eval mode (dropout off).
        model.eval()
        game.run_game()

        # --- Collect trajectories for the RL player (names[0]) ---
        # No gradients are needed while gathering (query, response, reward) triples.
        with torch.no_grad():
          # NOTE(review): public_state is assigned but not read again in this cell.
          public_state = game.get_public_state()
          for turn in game.turn_history:
              if turn["player"] != names[0]:
                  continue
              # Presumably the Game stores each player's prompt ids and generated
              # ids on the turn record under these keys — confirm against Game.
              prompt = turn.get("prompt", "")
              action = turn.get("last_generated_action", "")
              reward = compute_reward(game, tokenizer.decode(action), turn, game.goal_suit)
              all_queries.append(prompt)
              all_responses.append(action)
              all_rewards.append(torch.tensor(reward, device='cuda', dtype=torch.float))

    # --- Perform PPO update after accumulating trajectories from 10 games ---
    if len(all_queries) > 0:
        print(f"Updating model after block {block+1} with {len(all_queries)} examples...")
        model.train()
        # Must equal the PPOConfig batch_size above (both are 16).
        batch_size = 16
        num_examples = len(all_queries)
        for i in range(0, num_examples, batch_size):
            queries_batch = all_queries[i:i+batch_size]
            responses_batch = all_responses[i:i+batch_size]
            rewards_batch = all_rewards[i:i+batch_size]
            # Update only on full batches; a trailing partial batch is skipped
            # and then discarded when the lists are cleared below.
            if len(queries_batch) == batch_size:
                stats = ppo_trainer.step(queries_batch, responses_batch, rewards_batch)
                all_stats.append(stats)
        # Clear the accumulated trajectories after PPO updates.
        all_queries = []
        all_responses = []
        all_rewards = []

    if (block + 1) % 10 == 0:
        # Save a checkpoint every 10 blocks. The block number is appended
        # directly to the prefix, e.g. "rl_test_model210" for block 10.
        model.save_pretrained(f"rl_test_model2{block+1}")
        tokenizer.save_pretrained(f"rl_test_model2{block+1}")
        print(f"Saved checkpoint for block {block+1}")

# Final model and tokenizer after all blocks.
model.save_pretrained("rl_test_model2")
tokenizer.save_pretrained("rl_test_model2")

In [None]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from typing import Optional, Tuple, Union, Dict, Any
import torch.nn.functional as F

class CustomPPOModelWithAux(AutoModelForCausalLMWithValueHead):
    """
    PPO policy/value model with four auxiliary prediction heads for Figgie.

    Extends AutoModelForCausalLMWithValueHead with heads that read the hidden
    state of the last attended token and predict:
      * goal  - the goal suit (``num_suits``-way classification),
      * chip  - 4 non-negative regression outputs,
      * hand  - 16 non-negative regression outputs,
      * trade - a single logit for binary classification.

    Auxiliary losses are only computed when ``current_aux_labels`` has been set
    to a dict with ``"goal"``, ``"chip"``, ``"hand"`` and ``"trade"`` tensors;
    the most recent head outputs are kept in ``aux_logits``.
    """

    def __init__(
        self,
        pretrained_model,
        num_suits: int = 4,  # Number of possible suits to predict
        value_head_dropout: float = 0.1,
        **kwargs
    ):
        """
        Args:
            pretrained_model: Causal LM to wrap; its config must expose ``hidden_size``.
            num_suits: Number of classes for the goal-suit head.
            value_head_dropout: Dropout probability used inside each auxiliary head.
            **kwargs: Forwarded to the parent constructor.
        """
        # Initialize the parent class (PPO model with value head)
        super().__init__(pretrained_model, **kwargs)
        self.current_aux_labels = None  # Set externally before a forward pass to enable aux losses
        self.aux_logits = None          # Last auxiliary head outputs (dict), if any

        self.is_peft_model = getattr(pretrained_model, "is_peft_model", False)

        # All heads take the backbone's hidden size as input.
        hidden_size = self.pretrained_model.config.hidden_size

        # Auxiliary classification head for goal-suit prediction
        self.aux_head_goal = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(value_head_dropout),
            nn.Linear(hidden_size, num_suits)
        )
        # Auxiliary head for final chip counts (4 players)
        self.aux_head_chip = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(value_head_dropout),
            nn.Linear(hidden_size, 4),
            nn.ReLU()  # keeps outputs non-negative
        )
        # Auxiliary head for hand composition (16 outputs)
        self.aux_head_hand = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(value_head_dropout),
            nn.Linear(hidden_size, 16),
            nn.ReLU()  # keeps outputs non-negative
        )
        # Auxiliary head for trade prediction (binary classification)
        self.aux_head_trade = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(value_head_dropout),
            nn.Linear(hidden_size, 1)
        )

        # Initialize weights for all auxiliary heads: N(0, 0.02) weights, zero biases.
        for head in [self.aux_head_goal, self.aux_head_chip, self.aux_head_hand, self.aux_head_trade]:
            for module in head:
                if isinstance(module, nn.Linear):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                    if module.bias is not None:
                        module.bias.data.zero_()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """
        Run the backbone, the value head and (optionally) the auxiliary heads.

        Returns:
            A tuple ``(lm_logits, loss, value)``:
              * ``lm_logits`` - language-model logits from the backbone,
              * ``loss`` - the backbone loss (``None`` when it computed none),
                plus the weighted auxiliary loss when ``current_aux_labels``
                is set,
              * ``value`` - per-token value estimates of shape (B, T).

        NOTE(review): whether the returned ``loss`` is actually back-propagated
        depends on the caller — confirm that the training loop uses it.
        """
        # Hidden states and dict outputs are required below; set them on kwargs
        # so a caller that also passes them does not cause a duplicate-keyword
        # TypeError.
        kwargs["output_hidden_states"] = True
        kwargs["return_dict"] = True

        # 1. Base model outputs with hidden states
        base_outputs = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        lm_logits = base_outputs.logits
        loss = base_outputs.loss
        hidden_states = base_outputs.hidden_states[-1]  # (B, T, hidden_size)

        # 2. Per-token value estimates
        value = self.v_head(hidden_states).squeeze(-1)  # (B, T)

        # 3. Index of the last attended token per sequence. With a mask this is
        #    (number of attended tokens - 1), which is the last real token only
        #    when padding is on the right.
        if attention_mask is not None:
            seq_lens = (attention_mask.sum(dim=1) - 1).long()
        else:
            seq_lens = torch.full(
                (input_ids.shape[0],), input_ids.shape[1] - 1,
                device=input_ids.device, dtype=torch.long,
            )

        # Build the batch index on the same device as the tensor being indexed.
        batch_index = torch.arange(hidden_states.size(0), device=hidden_states.device)
        last_hidden = hidden_states[batch_index, seq_lens.to(hidden_states.device)]

        # 4. Auxiliary heads and losses, only when labels were provided.
        if self.current_aux_labels is not None:
            # Expected dict: {
            #   "goal": tensor([...], dtype=torch.long),
            #   "chip": tensor([...], dtype=torch.float or long),
            #   "hand": tensor([...], dtype=torch.float or long),
            #   "trade": tensor([...], dtype=torch.float)  # binary 0 or 1
            # }
            current_batch_size = input_ids.shape[0]

            # Labels are always taken from the start of each tensor.
            # NOTE(review): confirm the caller aligns current_aux_labels with
            # every (mini-)batch passed to forward.
            goal_labels = self.current_aux_labels["goal"][:current_batch_size].to(input_ids.device)
            chip_labels = self.current_aux_labels["chip"][:current_batch_size].to(input_ids.device)
            hand_labels = self.current_aux_labels["hand"][:current_batch_size].to(input_ids.device)
            trade_labels = self.current_aux_labels["trade"][:current_batch_size].to(input_ids.device)

            # Predictions from each head
            aux_logits_goal = self.aux_head_goal(last_hidden)
            aux_logits_chip = self.aux_head_chip(last_hidden)
            aux_logits_hand = self.aux_head_hand(last_hidden)
            aux_logits_trade = self.aux_head_trade(last_hidden)

            self.aux_logits = {
                "goal": aux_logits_goal,
                "chip": aux_logits_chip,
                "hand": aux_logits_hand,
                "trade": aux_logits_trade
            }

            # Per-task losses, each divided by a fixed scale constant.
            loss_goal = F.cross_entropy(aux_logits_goal, goal_labels.long()) / 1.5
            loss_chip = F.mse_loss(aux_logits_chip, chip_labels.float()) / 127000
            loss_hand = F.mse_loss(aux_logits_hand, hand_labels.float()) / 10
            loss_trade = F.binary_cross_entropy_with_logits(aux_logits_trade.squeeze(-1), trade_labels.float()) / 0.75

            # Weights for each auxiliary loss (starting values; tune as needed)
            w_goal, w_chip, w_hand, w_trade = .35, .20, .30, .15

            aux_loss = w_goal * loss_goal + w_chip * loss_chip + w_hand * loss_hand + w_trade * loss_trade
            loss = (loss if loss is not None else 0) + aux_loss

        return (lm_logits, loss, value)


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a pretrained model and add auxiliary and value heads.

        ``num_suits`` (default 4) is consumed here; every other argument is
        forwarded to ``AutoModelForCausalLMWithValueHead.from_pretrained``.
        The auxiliary heads of the returned model are freshly initialised;
        only the backbone and the value head come from the loaded model.
        """
        # Extract kwargs specific to this class before delegating.
        num_suits = kwargs.pop('num_suits', 4)
        custom_kwargs = {'num_suits': num_suits}

        # Load using the parent class from_pretrained to handle value head properly
        base_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )

        # Create our model using the base model's pretrained part
        model = cls(base_model.pretrained_model, **custom_kwargs)

        # Copy the value head weights from the loaded model
        model.v_head.load_state_dict(base_model.v_head.state_dict())

        return model

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the model to the specified directory via the parent implementation.

        NOTE(review): whether the auxiliary heads are included depends on what
        the parent class serialises — verify before relying on reloading them.
        """
        super().save_pretrained(save_directory, **kwargs)

In [None]:
from torch.nn import CrossEntropyLoss
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from tqdm import tqdm

model_path = "checkpoint-4000-new-tokenizer"
names = ["apple", "banana", "cantaloupe", "durian"]
suit_to_idx = {"spades": 0, "hearts": 1, "diamonds": 2, "clubs": 3}
# Per-player column order of the flattened hand-count target vector.
hand_suit_order = ("clubs", "diamonds", "hearts", "spades")

ppo_config_dict = {
    "batch_size": 16,             # Keep it small if data per update is limited.
    "mini_batch_size": 8,
    "learning_rate": 1e-5,        # Lower learning rate to ensure smoother updates.
    "init_kl_coef": 0.5,          # Increase KL coefficient to penalize divergence more.
    # NOTE(review): in TRL 0.11 `target_kl` is the early-stopping threshold and is
    # only consulted when `early_stopping=True`; adaptive KL control is driven by
    # `target`. Confirm this is the knob that was intended.
    "target_kl": 0.5,
    "cliprange": 0.1,             # Lower clip range to limit aggressive policy updates.
    "max_grad_norm": 1.0,         # Add gradient clipping for additional stability.
}

# ---- Initialize Models ----
model = CustomPPOModelWithAux.from_pretrained(model_path, num_suits=4).to('cuda')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_path).to('cuda').eval()

# ---- Training Setup ----
config = PPOConfig(**ppo_config_dict)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# --- Training Loop ---
games_per_update = 10
num_blocks = 200
batch_size = config.batch_size

for block in tqdm(range(1, num_blocks + 1), desc="Training Blocks"):
    queries, responses, rewards, aux_labels = [], [], [], []

    # The update phase below leaves the policy in train() mode. Switch back to
    # eval() so every block's rollouts are sampled without dropout, not just
    # the first block's.
    model.eval()

    # ---- Game Simulation ----
    for game_idx in tqdm(range(games_per_update), leave=False, desc="Games"):
        players = [
            SupervisedPlayerWithOrderBook(0, names[0], model, tokenizer, verbose=False), # RL player uses the trainable model.
            SupervisedPlayerWithOrderBook(1, names[1], model_ref, tokenizer, verbose=False),
            SupervisedPlayerWithOrderBook(2, names[2], model_ref, tokenizer, verbose=False),
            SupervisedPlayerWithOrderBook(3, names[3], model_ref, tokenizer, verbose=False),
        ]

        game = Game(num_players=4, players=players)
        game.run_game()
        goal_suit_idx = suit_to_idx[game.goal_suit.lower()]
        # NOTE(review): this value is never read afterwards; the call is kept only
        # in case export_game_log() has side effects — confirm and remove.
        game_log_final_chips_and_hand = game.export_game_log()["final_chips_and_hands"]

        # Auxiliary targets are read from the finished game, so they are the same
        # for every turn of this game: build them once instead of once per turn.
        chip_counts = [player.chips for player in players]
        full_hands = game.get_private_state()["full_hands"]
        hand_vector = [
            full_hands[name].get(suit, 0)
            for name in names
            for suit in hand_suit_order
        ]

        # ---- Collect Trajectories ----
        with torch.no_grad():
            for turn in game.turn_history:
                # Only the RL player's own turns become training samples.
                if turn["player"] != names[0]:
                    continue

                prompt_tokens = turn.get("prompt")
                action = turn.get("last_generated_action")
                # A turn without a recorded prompt or action cannot form a
                # (query, response) pair for PPO; skip it instead of passing
                # empty placeholder tensors to the trainer.
                if prompt_tokens is None or action is None:
                    continue
                if len(prompt_tokens) == 0 or len(action) == 0:
                    continue

                decoded_action = tokenizer.decode(action, skip_special_tokens=True)

                # Calculate reward
                reward_val = compute_reward(game, decoded_action, turn, game.goal_suit)

                trade_flag = 1 if "trade executed" in turn.get("result", "").lower() else 0

                # Store training data
                queries.append(prompt_tokens)
                responses.append(action)
                rewards.append(torch.tensor(reward_val, device="cuda", dtype=torch.float))
                aux_labels.append({
                    "goal": goal_suit_idx,
                    "chip": chip_counts,
                    "hand": hand_vector,
                    "trade": trade_flag
                })

    # ---- PPO + Auxiliary Updates ----
    if queries:
        # PPOTrainer.step needs exactly `batch_size` samples per call, so the tail
        # is padded by repeating the last sample (which therefore gets extra weight).
        remainder = len(queries) % batch_size
        if remainder:
            pad_needed = batch_size - remainder
            for buffer in (queries, responses, rewards, aux_labels):
                buffer.extend([buffer[-1]] * pad_needed)

        print(f"[Block {block}] Training with {len(queries)} samples...")
        model.train()

        for i in range(0, len(queries), batch_size):
            batch_queries = queries[i:i+batch_size]
            batch_responses = responses[i:i+batch_size]
            batch_rewards = rewards[i:i+batch_size]
            batch_aux_labels = aux_labels[i:i+batch_size]

            # Collate auxiliary targets into tensors
            goal_targets = torch.tensor([a['goal'] for a in batch_aux_labels], dtype=torch.long, device="cuda")
            chip_targets = torch.tensor([a['chip'] for a in batch_aux_labels], dtype=torch.float, device="cuda")
            hand_targets = torch.tensor([a['hand'] for a in batch_aux_labels], dtype=torch.float, device="cuda")
            trade_targets = torch.tensor([a['trade'] for a in batch_aux_labels], dtype=torch.float, device="cuda")

            # NOTE(review): these labels cover the full batch in collection order,
            # while PPOTrainer.step runs the model on (shuffled) mini-batches of
            # `mini_batch_size` and computes its own loss. Confirm that forward()
            # aligns the labels with the rows it receives and that the auxiliary
            # loss actually reaches the optimizer.
            model.current_aux_labels = {
                "goal": goal_targets,
                "chip": chip_targets,
                "hand": hand_targets,
                "trade": trade_targets,
            }

            # ---- PPO Step ----
            try:
                with torch.amp.autocast(device_type='cuda'):
                    ppo_stats = ppo_trainer.step(
                        batch_queries,
                        batch_responses,
                        batch_rewards,
                    )
            finally:
                # Never leave stale labels attached to the model, even if the
                # step raised.
                model.current_aux_labels = None
            print(ppo_stats)

            # Memory cleanup
            del batch_rewards, batch_aux_labels
            torch.cuda.empty_cache()

    # --- Checkpoint ---
    if block % 10 == 0:
        ckpt_path = f"aux_model_{block}"
        model.save_pretrained(ckpt_path)
        model = model.to('cuda')
        tokenizer.save_pretrained(ckpt_path)
        print(f"[Checkpoint] Saved model at block {block}")

In [None]:
from trl import AutoModelForCausalLMWithValueHead
from tqdm import tqdm
# Per-game results collected over the tournament.
lengths = []    # number of turns in each game
chip_vals = []  # final chips of players 0..3 for each game
winners = []    # game.goal_suit_winner for each game
logs = []       # exported game logs

# rl_model_path = "/content/drive/MyDrive/checkpoint-4000-new-tokenizer"
rl_model_path = "rl_test_model2"
# Inference only: put the model in eval mode.
rl_model = AutoModelForCausalLMWithValueHead.from_pretrained(rl_model_path).to("cuda").eval()

# Load the static model (for the other three players) from the usual checkpoint.
static_model_path = "checkpoint-4000-new-tokenizer"
static_model = AutoModelForCausalLMWithValueHead.from_pretrained(static_model_path).to("cuda").eval()
# The same `tokenizer` is used for every player; this assumes both checkpoints
# share a vocabulary. If they do not, load a separate one, e.g.
# static_tokenizer = PreTrainedTokenizerFast.from_pretrained(static_model_path)

# -------------------------------
# Set Up Players and Tournament
# -------------------------------

names = ["apple", "banana", "cantaloupe", "durian"]

# -------------------------------
# Run Tournament and Collect Stats
# -------------------------------

num_games = 1000
# Dictionary to accumulate total chips for each player over all games.
total_chip_counts = {name: 0 for name in names}

for game_num in tqdm(range(num_games)):
    # Player 0 is the RL player; players 1-3 use the fixed static model.
    players = [SupervisedPlayerWithOrderBook(0, names[0], rl_model, tokenizer, verbose=False)]
    players += [
        SupervisedPlayerWithOrderBook(i, names[i], static_model, tokenizer, verbose=False, override_name=names[i])
        for i in range(1, 4)
    ]

    # Create a new game instance with a fresh state and play it out.
    game = Game(num_players=4, players=players)
    game.run_game()
    logs.append(game.export_game_log())

    # Read each player's final chips once and reuse the values below.
    final_chips = [player.chips for player in players]

    print(f"\nGame {game_num + 1} final chip values:")
    for name, chips in zip(names, final_chips):
        print(f"{name}: {chips}")
        total_chip_counts[name] += chips
    winners.append(game.goal_suit_winner)
    lengths.append(len(game.turn_history))
    chip_vals.append(final_chips)

# Calculate and print average chip counts over all games.
print("\nAverage chip values over {} games:".format(num_games))
for name in names:
    avg_chips = total_chip_counts[name] / num_games
    print(f"{name}: {avg_chips}")

print(lengths)
print(chip_vals)
print(winners)

In [None]:
from trl import AutoModelForCausalLMWithValueHead
from tqdm import tqdm
# Per-game results collected over the tournament.
lengths = []    # number of turns in each game
chip_vals = []  # final chips of players 0..3 for each game
winners = []    # game.goal_suit_winner for each game
logs = []       # exported game logs

# rl_model_path = "/content/drive/MyDrive/checkpoint-4000-new-tokenizer"
# NOTE(review): the PPO training cell in this notebook checkpoints to
# f"aux_model_{block}", not "aux_block_*" — confirm this directory exists and
# is the checkpoint that was meant to be evaluated.
rl_model_path = "aux_block_200"
# Inference only: put the model in eval mode.
rl_model = AutoModelForCausalLMWithValueHead.from_pretrained(rl_model_path).to("cuda").eval()

# Load the static model (for the other three players) from the usual checkpoint.
static_model_path = "checkpoint-4000-new-tokenizer"
static_model = AutoModelForCausalLMWithValueHead.from_pretrained(static_model_path).to("cuda").eval()
# The same `tokenizer` is used for every player; this assumes both checkpoints
# share a vocabulary. If they do not, load a separate one, e.g.
# static_tokenizer = PreTrainedTokenizerFast.from_pretrained(static_model_path)

# -------------------------------
# Set Up Players and Tournament
# -------------------------------

names = ["apple", "banana", "cantaloupe", "durian"]

# -------------------------------
# Run Tournament and Collect Stats
# -------------------------------

num_games = 1000
# Dictionary to accumulate total chips for each player over all games.
total_chip_counts = {name: 0 for name in names}

for game_num in tqdm(range(num_games)):
    # Player 0 is the RL player; players 1-3 use the fixed static model.
    players = [SupervisedPlayerWithOrderBook(0, names[0], rl_model, tokenizer, verbose=False)]
    players += [
        SupervisedPlayerWithOrderBook(i, names[i], static_model, tokenizer, verbose=False, override_name=names[i])
        for i in range(1, 4)
    ]

    # Create a new game instance with a fresh state and play it out.
    game = Game(num_players=4, players=players)
    game.run_game()
    logs.append(game.export_game_log())

    # Read each player's final chips once and reuse the values below.
    final_chips = [player.chips for player in players]

    print(f"\nGame {game_num + 1} final chip values:")
    for name, chips in zip(names, final_chips):
        print(f"{name}: {chips}")
        total_chip_counts[name] += chips
    winners.append(game.goal_suit_winner)
    lengths.append(len(game.turn_history))
    chip_vals.append(final_chips)

# Calculate and print average chip counts over all games.
print("\nAverage chip values over {} games:".format(num_games))
for name in names:
    avg_chips = total_chip_counts[name] / num_games
    print(f"{name}: {avg_chips}")

print(lengths)
print(chip_vals)
print(winners)

In [None]:
from trl import AutoModelForCausalLMWithValueHead
from tqdm import tqdm
# NOTE(review): this cell evaluates the same checkpoint with the same settings
# as the preceding evaluation cell and overwrites its results; drop one of the
# two or vary the configuration.

# Per-game results collected over the tournament.
lengths = []    # number of turns in each game
chip_vals = []  # final chips of players 0..3 for each game
winners = []    # game.goal_suit_winner for each game
logs = []       # exported game logs

# rl_model_path = "/content/drive/MyDrive/checkpoint-4000-new-tokenizer"
# NOTE(review): the PPO training cell in this notebook checkpoints to
# f"aux_model_{block}", not "aux_block_*" — confirm this directory exists and
# is the checkpoint that was meant to be evaluated.
rl_model_path = "aux_block_200"
# Inference only: put the model in eval mode.
rl_model = AutoModelForCausalLMWithValueHead.from_pretrained(rl_model_path).to("cuda").eval()

# Load the static model (for the other three players) from the usual checkpoint.
static_model_path = "checkpoint-4000-new-tokenizer"
static_model = AutoModelForCausalLMWithValueHead.from_pretrained(static_model_path).to("cuda").eval()
# The same `tokenizer` is used for every player; this assumes both checkpoints
# share a vocabulary. If they do not, load a separate one, e.g.
# static_tokenizer = PreTrainedTokenizerFast.from_pretrained(static_model_path)

# -------------------------------
# Set Up Players and Tournament
# -------------------------------

names = ["apple", "banana", "cantaloupe", "durian"]

# -------------------------------
# Run Tournament and Collect Stats
# -------------------------------

num_games = 1000
# Dictionary to accumulate total chips for each player over all games.
total_chip_counts = {name: 0 for name in names}

for game_num in tqdm(range(num_games)):
    # Player 0 is the RL player; players 1-3 use the fixed static model.
    players = [SupervisedPlayerWithOrderBook(0, names[0], rl_model, tokenizer, verbose=False)]
    players += [
        SupervisedPlayerWithOrderBook(i, names[i], static_model, tokenizer, verbose=False, override_name=names[i])
        for i in range(1, 4)
    ]

    # Create a new game instance with a fresh state and play it out.
    game = Game(num_players=4, players=players)
    game.run_game()
    logs.append(game.export_game_log())

    # Read each player's final chips once and reuse the values below.
    final_chips = [player.chips for player in players]

    print(f"\nGame {game_num + 1} final chip values:")
    for name, chips in zip(names, final_chips):
        print(f"{name}: {chips}")
        total_chip_counts[name] += chips
    winners.append(game.goal_suit_winner)
    lengths.append(len(game.turn_history))
    chip_vals.append(final_chips)

# Calculate and print average chip counts over all games.
print("\nAverage chip values over {} games:".format(num_games))
for name in names:
    avg_chips = total_chip_counts[name] / num_games
    print(f"{name}: {avg_chips}")

print(lengths)
print(chip_vals)
print(winners)