In [None]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *
from llm_modules import *
from agent import *
from simulation import *

from utils import *

from ml_ops_utils import *

import gc

from hand_strength import *

!nvidia-smi

In [None]:
# Number of seats dealt into every hand; reused below as the default population size.
num_players = 2  # or up to 8

def build_player_components(num_players, embed_dim, max_seq_len, num_streets, device):
    """Instantiate one fresh set of embedding components, keyed by role.

    Each call creates new objects, so dictionaries returned by separate calls
    share nothing with one another.

    Args:
        num_players: Forwarded to ``TablePositionalEncoding`` (as
            ``num_players``) and ``SelfPositionEmbedder`` (as
            ``number_of_positions``).
        embed_dim: Embedding width given to the street, table-position and
            action encoders; also the base of the sequence embedder's latent
            widths (x1, x2, x4, x8).
        max_seq_len: Sequence length forwarded to the per-token encoders and
            to the pot-size embedder.
        num_streets: Forwarded to ``StreetPositionalEncoding``.
        device: Device argument forwarded to every component.

    Returns:
        dict: Components under the keys ``street_embedder``,
        ``table_position_embedder``, ``action_embedder``,
        ``pot_size_embedder``, ``poker_sequence_embedder``, ``cards`` and
        ``self_position_embedder`` (in that insertion order).
    """
    # Keyword arguments common to the three per-token encoders.
    token_kwargs = {"embedding_dim": embed_dim, "max_seq_len": max_seq_len, "device": device}

    # Components are created in the same order as the keys they are stored
    # under, so any random initialisation happens in a stable sequence.
    components = {}
    components["street_embedder"] = StreetPositionalEncoding(num_streets=num_streets, **token_kwargs)
    components["table_position_embedder"] = TablePositionalEncoding(num_players=num_players, **token_kwargs)
    components["action_embedder"] = ActionEncoding(**token_kwargs)
    components["pot_size_embedder"] = PotSizeSequenceEmbedder(
        max_seq_len=max_seq_len,
        pad_value=-1,
        device=device,
    )
    components["poker_sequence_embedder"] = PokerSequenceEmbedder(
        street_input_dimension=embed_dim,
        table_position_input_dimension=embed_dim,
        action_input_dimension=embed_dim,
        latent_dimensions=[embed_dim * factor for factor in (1, 2, 4, 8)],
        device=device,
    )
    components["cards"] = Cards(device=device)
    components["self_position_embedder"] = SelfPositionEmbedder(
        number_of_positions=num_players,
        device=device,
    )
    return components

# One independent component set per seat.
# NOTE(review): `embed_dim`, `max_seq_len`, `num_streets` and `device` are not
# assigned in any cell visible here — presumably they come from the wildcard
# imports or an earlier session; confirm the notebook survives Restart & Run All.
players = [build_player_components(num_players, embed_dim, max_seq_len, num_streets, device) for _ in range(num_players)]

In [None]:
def build_agent(num_players, embed_dim, max_seq_len, num_streets, device, model):
    """Assemble a ``PokerAgent`` with freshly created embedders and policy head.

    All embedders and the ``PolicyModel`` are instantiated inside this call;
    only ``model`` is supplied by the caller and handed to the agent as-is.

    Args:
        num_players: Seat count forwarded to the table-position encoder, the
            self-position embedder and the policy model.
        embed_dim: Embedding width for the per-token encoders and base of the
            sequence embedder's latent widths (x1, x2, x4, x8).
        max_seq_len: Sequence length forwarded to the per-token encoders and
            the pot-size embedder.
        num_streets: Forwarded to ``StreetPositionalEncoding``.
        device: Device argument forwarded to every component and the agent.
        model: Object passed through unchanged as the agent's seventh
            positional argument.

    Returns:
        The constructed ``PokerAgent`` (created with ``llm_train=False``).
    """
    # Construction order is kept stable so any random initialisation inside
    # the components happens in the same sequence on every call.
    own_position = SelfPositionEmbedder(number_of_positions=num_players, device=device)
    card_encoder = Cards(device=device)

    # Keyword arguments common to the three per-token encoders.
    token_kwargs = dict(embedding_dim=embed_dim, max_seq_len=max_seq_len, device=device)
    street_encoder = StreetPositionalEncoding(num_streets=num_streets, **token_kwargs)
    seat_encoder = TablePositionalEncoding(num_players=num_players, **token_kwargs)
    action_encoder = ActionEncoding(**token_kwargs)

    pot_encoder = PotSizeSequenceEmbedder(max_seq_len=max_seq_len, pad_value=-1, device=device)
    sequence_encoder = PokerSequenceEmbedder(
        street_input_dimension=embed_dim,
        table_position_input_dimension=embed_dim,
        action_input_dimension=embed_dim,
        latent_dimensions=[embed_dim * factor for factor in (1, 2, 4, 8)],
        device=device,
    )

    # Hidden-layer widths of the policy/value heads are fixed here.
    policy_head = PolicyModel(
        num_players=num_players,
        self_position_embedder=own_position,
        active_players_hidden_dims=[1024, 2048],
        stack_size_hidden_dims=[1024, 2048],
        card_embeddings_hidden_dims=[2048, 2048],
        final_output_hidden_dims=[1024, 512, 256],
        value_output_hidden_dims=[1024, 512, 256],
        dropout_rate=0,
        device=device,
    )

    return PokerAgent(
        card_encoder,
        street_encoder,
        seat_encoder,
        action_encoder,
        pot_encoder,
        sequence_encoder,
        model,
        policy_head,
        device=device,
        llm_train=False,
    )

# Size of the agent pool that hands are seated from.
population_size = num_players  # or set independently, e.g. 8
# Every agent is built with the same `model` object; all other components are created per agent.
agents = [build_agent(num_players, embed_dim, max_seq_len, num_streets, device, model) for _ in range(population_size)]
# optimizers[k] is constructed over agents[k].parameters(), so the two lists are index-aligned.
optimizers = [torch.optim.AdamW(agent.parameters(), lr=1e-4) for agent in agents]

In [None]:
# Seat-rotation offset passed to build_hand and incremented after each simulated hand.
# NOTE(review): this reset lives in the same cell as the simulation call below, so
# re-running the cell starts from 0 again; move it to a setup cell if the rotation
# is meant to persist across hands.
hand_counter = 0  # track globally across hands

def build_hand(num_players, batch_size, agents, small_blind, big_blind, starting_stack_sizes, hand_counter, max_seq_len, device='cuda', street_pad_idx=None, action_pad_idx=21):
    """Create the initial state of one hand: blinds posted, hole cards dealt.

    A single deck permutation is drawn per call and shared by every batch
    row, so all rows of the batch start with the same hole cards.

    Args:
        num_players: Number of seats. Seat 0 posts the small blind and seat 1
            the big blind.
        batch_size: Number of rows in every returned buffer.
        agents: Non-empty sequence of agents. Seat ``i`` is given
            ``agents[(hand_counter + i) % len(agents)]``.
        small_blind: Amount posted by seat 0.
        big_blind: Amount posted by seat 1.
        starting_stack_sizes: Stack of each seat before the blinds are posted.
        hand_counter: Rotation offset for the round-robin seating.
        max_seq_len: Length of the sequence buffers; must be at least 2 to
            hold the two blind postings.
        device: Device for the hole-card and seat-position tensors. The
            sequence buffers, ``active_players``, ``stack_size`` and the deck
            permutation are created on the default (CPU) device.
        street_pad_idx: Fill value for unused street slots. When ``None`` the
            notebook-level global ``num_streets`` is used.
        action_pad_idx: Fill value for unused action slots.

    Returns:
        tuple: ``(street_idxs, table_position_idxs, action_idxs,
        pot_size_sequence, active_players, stack_size, table,
        deck_order_shuffled, seated_agents)`` where ``table`` maps seat index
        to ``[agent, hole_cards, seat_position]`` and ``seated_agents`` lists
        the agent occupying each seat.

    Raises:
        ValueError: If ``agents`` is empty, ``max_seq_len`` is below 2, or the
            52-card deck cannot supply two hole cards per seat.
    """
    population_size = len(agents)
    if population_size == 0:
        raise ValueError("agents must contain at least one agent")
    if max_seq_len < 2:
        raise ValueError("max_seq_len must be at least 2 to hold both blind postings")
    if 2 * num_players > 52:
        raise ValueError("a 52-card deck cannot deal two hole cards to %d players" % num_players)
    if street_pad_idx is None:
        # Fall back to the notebook global so existing call sites keep working.
        street_pad_idx = num_streets

    deck_order_shuffled = torch.argsort(torch.rand(1, 52))

    # Round-robin: rotate which agents sit in which seats each hand
    seated_agents = [agents[(hand_counter + i) % population_size] for i in range(num_players)]

    # Sequence buffers, pre-filled with their padding values
    buffer_shape = (batch_size, max_seq_len)
    street_idxs = (torch.zeros(buffer_shape) + street_pad_idx).long()
    table_position_idxs = (torch.zeros(buffer_shape) + num_players).long()
    action_idxs = (torch.zeros(buffer_shape) + action_pad_idx).long()
    pot_size_sequence = torch.zeros(buffer_shape) - 1

    # Blind postings (seat 0=SB, seat 1=BB always)
    street_idxs[:, :2] = 0
    table_position_idxs[:, 0] = 0
    table_position_idxs[:, 1] = 1
    action_idxs[:, 0] = 0  # post SB
    action_idxs[:, 1] = 1  # post BB
    pot_size_sequence[:, 0] = small_blind
    pot_size_sequence[:, 1] = small_blind + big_blind

    # Stack sizes after the blinds have been deducted
    stack_sizes_init = [
        starting_stack_sizes - small_blind if i == 0
        else starting_stack_sizes - big_blind if i == 1
        else starting_stack_sizes
        for i in range(num_players)
    ]
    active_players = torch.ones((batch_size, num_players))
    stack_size = torch.Tensor(stack_sizes_init).unsqueeze(0).tile(batch_size, 1)

    # Deal hole cards: seat i takes deck positions 2i and 2i+1
    player_cards = []
    for i in range(num_players):
        cards = torch.zeros((batch_size, 2, 7), dtype=torch.long, device=device)
        card_indices = deck_order_shuffled[0, i*2 : i*2+2]
        cards[:, 0, :2] = card_indices % 13   # rank
        cards[:, 1, :2] = card_indices // 13  # suit
        cards[:, 0, 2:] = 13  # padding rank
        cards[:, 1, 2:] = 4   # padding suit
        player_cards.append(cards)

    # Build table with round-robin seated agents
    table = {
        i: [seated_agents[i], player_cards[i], torch.Tensor([i]).to(device).tile(batch_size)]
        for i in range(num_players)
    }

    return (
        street_idxs,
        table_position_idxs,
        action_idxs,
        pot_size_sequence,
        active_players,
        stack_size,
        table,
        deck_order_shuffled,
        seated_agents,  # return so you know who played which seat for gradient updates
    )

# Deal one hand with the current seat rotation, then play it out.
street_idxs, table_position_idxs, action_idxs, pot_size_sequence, \
active_players, stack_size, table, deck_order_shuffled, seated_agents = build_hand(
    num_players, batch_size, agents, small_blind, big_blind,
    starting_stack_sizes, hand_counter, max_seq_len, device
)

# NOTE(review): `results` is never unpacked in the visible cells, yet later cells
# read `sim_*` variables (sim_table, sim_action_idxs, ...). Presumably those are
# meant to come from this return value — confirm and unpack explicitly.
results = simulate_hand(
    num_players, street_idxs, table_position_idxs, action_idxs, pot_size_sequence,
    active_players, stack_size, table, action_validator, deck_order_shuffled,
)

# Advance the rotation offset used for the next hand's seating.
hand_counter += 1

In [None]:
# Score the hand from the last entry of each simulated buffer.
# NOTE(review): the `sim_*` names are not assigned in any visible cell — confirm
# where they are produced.
rewards = determine_winner(
    sim_active_players[-1],
    sim_pot_size_sequence[-1],
    sim_stack_size[-1],
    sim_table
)
# Attach each seat's reward as the last element of its table entry.
# Re-running this cell appends again; the newest reward stays at index -1.
for i in range(num_players):
    sim_table[i].append(rewards[i])

# For every row of the simulated buffers, which seat is due to act.
next_to_act = action_validator.get_next_to_act(
    sim_street_idxs,
    sim_table_position_idxs,
    sim_action_idxs,
    sim_active_players
)

# Row indices at which each seat is the one to act.
player_action_batch_indices = {
    i: torch.where(next_to_act == i)[0]
    for i in range(num_players)
}

In [None]:
# For each seat, the column of the first padded action slot (21 is the action
# padding value written by build_hand) in the row that FOLLOWS each acting row.
# argmax over the 0/1 mask returns the first column where the mask is 1.
# NOTE(review): `+ 1` presumably selects the state recorded right after the
# seat's action — confirm against simulate_hand's row layout, and that the
# last acting row always has a successor row.
player_masks = {
    i: (sim_action_idxs[player_action_batch_indices[i] + 1] == 21).float().argmax(dim=1)
    for i in range(num_players)
}

In [None]:
def compute_log_probs(player_idx, player_action_batch_indices, player_masks, sim_table, sim_street_idxs, sim_table_position_idxs, sim_action_idxs, sim_pot_size_sequence, sim_active_players, sim_stack_size, device='cuda'):
    """Log-probability the seat's agent assigns to each action it actually took.

    The agent stored at ``sim_table[player_idx][0]`` is re-run on the rows
    where this seat was to act, and its ``'probits'`` output is normalised
    with ``log_softmax`` over the last dimension.

    Args:
        player_idx: Seat whose decisions are scored.
        player_action_batch_indices: Mapping seat -> 1-D index tensor of the
            rows at which that seat acted.
        player_masks: Mapping seat -> 1-D tensor; ``mask - 1`` is the column
            of the taken action inside row ``index + 1`` of ``sim_action_idxs``.
        sim_table: Mapping seat -> sequence whose first three items are the
            agent (callable), its cards and its position tensor.
        sim_street_idxs, sim_table_position_idxs, sim_action_idxs,
        sim_pot_size_sequence, sim_active_players, sim_stack_size: Simulated
            buffers, each indexable by the row indices above.
        device: Device that ``active_players`` and ``stack_size`` are moved to
            before the agent call. Defaults to ``'cuda'``, the value that was
            previously hard-coded.

    Returns:
        1-D tensor with one log-probability per acting row.
    """
    batch_indices = player_action_batch_indices[player_idx]
    mask = player_masks[player_idx]
    agent, all_cards, all_positions = sim_table[player_idx][0], sim_table[player_idx][1], sim_table[player_idx][2]

    # Restrict every input to the rows where this seat made a decision.
    position = all_positions[batch_indices]
    cards = all_cards[batch_indices]
    street_idxs = sim_street_idxs[batch_indices]
    table_position_idxs = sim_table_position_idxs[batch_indices]
    action_idxs = sim_action_idxs[batch_indices]
    pot_size_sequence = sim_pot_size_sequence[batch_indices]
    active_players = sim_active_players[batch_indices]
    stack_size = sim_stack_size[batch_indices]

    outputs = agent(
        position,
        cards,
        street_idxs,
        table_position_idxs,
        action_idxs,
        pot_size_sequence,
        active_players.to(device),
        stack_size.to(device)
    )

    # The action taken is the last filled slot of the following row.
    taken_actions = sim_action_idxs[batch_indices + 1, mask - 1]
    log_probs = torch.nn.functional.log_softmax(outputs['probits'], dim=-1)

    # Keep the gather indices on the same device as the logits.
    row_idx = torch.arange(log_probs.shape[0], device=log_probs.device)
    log_probs_of_taken_actions = log_probs[row_idx, taken_actions.to(log_probs.device)]

    return log_probs_of_taken_actions

# Log-probabilities of the actions each seat actually took, keyed by seat index.
player_log_probs = {
    i: compute_log_probs(i, player_action_batch_indices, player_masks, sim_table, sim_street_idxs, sim_table_position_idxs, sim_action_idxs, sim_pot_size_sequence, sim_active_players, sim_stack_size)
    for i in range(num_players)
}

In [None]:
def _resolve_optimizer(agent, optimizers, fallback_idx):
    """Return the optimizer holding exactly ``agent``'s parameters, else ``optimizers[fallback_idx]``."""
    get_params = getattr(agent, "parameters", None)
    if callable(get_params):
        agent_param_ids = {id(p) for p in get_params()}
        if agent_param_ids:
            candidates = optimizers.values() if isinstance(optimizers, dict) else optimizers
            for candidate in candidates:
                owned_ids = {id(p) for group in candidate.param_groups for p in group["params"]}
                # Exact set match: agents may share some parameters (e.g. a
                # common backbone), so a mere overlap would be ambiguous.
                if owned_ids == agent_param_ids:
                    return candidate
    return optimizers[fallback_idx]


def compute_loss_and_step(player_idx, player_log_probs, sim_table, optimizers):
    """Apply one policy-gradient update for the agent sitting in ``player_idx``.

    The loss is ``-(log_probs * rewards).mean()``, where the reward is the
    last element of the seat's ``sim_table`` entry.

    Seats are filled round-robin, so the agent in seat ``i`` is not
    necessarily the one ``optimizers[i]`` was built for. The optimizer is
    therefore looked up by matching its parameters against the seated agent's
    parameters; if no optimizer matches, ``optimizers[player_idx]`` is used.

    Args:
        player_idx: Seat index to update.
        player_log_probs: Mapping seat -> tensor of log-probabilities of the
            actions taken (must be attached to the autograd graph).
        sim_table: Mapping seat -> sequence whose first item is the agent and
            whose last item is the reward.
        optimizers: Sequence (or int-keyed dict) of optimizers.

    Returns:
        float: The scalar loss value computed before the optimizer step.
    """
    seat_entry = sim_table[player_idx]
    rewards = seat_entry[-1]  # reward appended earlier
    log_probs = player_log_probs[player_idx]
    if torch.is_tensor(rewards):
        # Avoid a device mismatch when rewards were computed on another device.
        rewards = rewards.to(log_probs.device)

    optimizer = _resolve_optimizer(seat_entry[0], optimizers, player_idx)

    # Clear any stale gradients so only this loss contributes to the step.
    optimizer.zero_grad()
    loss = -(log_probs * rewards).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()

# Run one update per seat (a side effect of building the dict) and keep each scalar loss.
losses = {
    i: compute_loss_and_step(i, player_log_probs, sim_table, optimizers)
    for i in range(num_players)
}