In [1]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *
from llm_modules import *

from utils import *


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
street_embedder = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cpu"
)

table_position_embedder = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cpu"
)

action_embedder = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cpu"
)

pot_size_embedder = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1
)

In [3]:
street_idxs = torch.Tensor([
    [0, 0, 6, 6, 6, 6, 6, 6],
    [0, 0, 0, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 4, 6, 6, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
])
print(street_idxs.shape)
table_position_idxs = torch.Tensor([
    [0, 1, 2, 2, 2, 2, 2, 2],
    [0, 1, 0, 2, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 2, 2, 2, 2, 2, 2],
])
print(table_position_idxs.shape)
action_idxs = torch.Tensor([
    [0, 1, 21, 21, 21, 21, 21, 21],
    [0, 1, 4, 21, 21, 21, 21, 21],
    [0, 1, 5, 4, 21, 21, 21, 21],
    [0, 1, 4, 3, 21, 21, 21, 21],
    [0, 1, 5, 6, 21, 21, 21, 21],
    [0, 1, 5, 2, 21, 21, 21, 21],
    [0, 1, 5, 4, 19, 21, 21, 21],
    [0, 1, 5, 4, 19, 3, 3, 21],
    [0, 1, 5, 4, 19, 5, 2, 21],
    [0, 1, 5, 4, 19, 5, 4, 21],
    [0, 1, 5, 4, 19, 3, 5, 21],
])
print(action_idxs.shape)
pot_size_sequence = torch.Tensor([
    [1, 3, -1, -1, -1, -1, -1, -1],
    [1, 3, 5, -1, -1, -1, -1, -1],
    [1, 3, 7, 8, -1, -1, -1, -1],
    [1, 3, 5, 5, -1, -1, -1, -1],
    [1, 3, 7, 15, -1, -1, -1, -1],
    [1, 3, 7, 7, -1, -1, -1, -1],
    [1, 3, 7, 8, 8, -1, -1, -1],
    [1, 3, 7, 8, 8, 8, 8, -1],
    [1, 3, 7, 8, 8, 10, 10, -1],
    [1, 3, 7, 8, 8, 10, 12, -1],
    [1, 3, 7, 8, 8, 8, 10, -1],
])
print(pot_size_sequence.shape)
active_players = torch.Tensor([
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 0],
    [1, 1],
    [1, 1],
    [1, 0],
    [1, 1],
    [1, 1],
])
print(active_players.shape)

torch.Size([11, 8])
torch.Size([12, 8])
torch.Size([11, 8])
torch.Size([11, 8])
torch.Size([11, 2])


In [4]:
test = PokerActionValidator(
    num_players = 2,
    small_blind = 1,
    big_blind = 2,
    starting_stack_sizes = 400
)

In [5]:
legal_actions = test.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

In [6]:
legal_actions[6]

tensor([False, False, False,  True, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False, False,
        False])

In [7]:
test.get_next_to_act(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    active_players
)

tensor([ 0,  1, -2, -2,  0, -1,  0, -2, -1, -2,  0])

In [11]:
street_idxs = torch.Tensor([
    [0, 0, 6, 6, 6, 6, 6, 6],
    [0, 0, 0, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 0, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 0, 4, 6, 6],
    [0, 0, 0, 0, 0, 4, 6, 6],
    [0, 0, 0, 0, 0, 4, 1, 6],
    [0, 0, 0, 0, 0, 4, 6, 6]
])
print(street_idxs.shape)
table_position_idxs = torch.Tensor([
    [0, 1, 3, 3, 3, 3, 3, 3],
    [0, 1, 2, 3, 3, 3, 3, 3],
    [0, 1, 2, 0, 3, 3, 3, 3],
    [0, 1, 2, 0, 1, 3, 3, 3],
    [0, 1, 2, 0, 3, 3, 3, 3],
    [0, 1, 2, 0, 1, 3, 3, 3],
    [0, 1, 2, 0, 1, 3, 3, 3],
    [0, 1, 2, 0, 3, 0, 3, 3],
    [0, 1, 2, 0, 2, 3, 3, 3],
])
print(table_position_idxs.shape)
action_idxs = torch.Tensor([
    [0, 1, 21, 21, 21, 21, 21, 21],
    [0, 1, 4, 21, 21, 21, 21, 21],
    [0, 1, 2, 2, 21, 21, 21, 21],
    [0, 1, 2, 4, 3, 21, 21, 21],
    [0, 1, 2, 4, 21, 21, 21, 21],
    [0, 1, 2, 4, 3, 19, 21, 21],
    [0, 1, 4, 2, 3, 19, 21, 21],
    [0, 1, 5, 2, 4, 19, 17, 21],
    [0, 1, 5, 2, 4, 19, 21, 21],
])
print(action_idxs.shape)
pot_size_sequence = torch.Tensor([
    [1, 3, -1, -1, -1, -1, -1, -1],
    [1, 3, 5, -1, -1, -1, -1, -1],
    [1, 3, 3, 3, -1, -1, -1, -1],
    [1, 3, 3, 4, 4, -1, -1, -1],
    [1, 3, 3, 4, -1, -1, -1, -1],
    [1, 3, 3, 4, 4, -1, -1, -1],
    [1, 3, 5, 5, 5, 5, -1, -1],
    [1, 3, 7, 7, 9, 9, 9, -1],
    [1, 3, 7, 7, 9, 9, -1, -1]
])
print(pot_size_sequence.shape)
active_players = torch.Tensor([
    [1, 1, 1],
    [1, 1, 1],
    [1, 0, 0],
    [1, 1, 0],
    [1, 1, 0],
    [1, 1, 0],
    [0, 1, 1],
    [0, 1, 1],
    [0, 1, 1]
])
print(active_players.shape)

torch.Size([9, 8])
torch.Size([9, 8])
torch.Size([9, 8])
torch.Size([9, 8])
torch.Size([9, 3])


In [12]:
test = PokerActionValidator(
    num_players = 3,
    small_blind = 1,
    big_blind = 2,
    starting_stack_sizes = 400
)

legal_actions = test.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

In [13]:
legal_actions

tensor([[False, False,  True, False,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False],
        [False, False,  True, False,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,  True,
         False],
        [False, False, False,  True, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False],
        [False, False, False,  True, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  T

In [15]:
test.get_next_to_act(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    active_players
)

tensor([ 2,  0, -1, -2,  1,  0,  1,  1,  1])

In [154]:
street_embedder = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

table_position_embedder = TablePositionalEncoding(
    num_players = 3,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

action_embedder = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

pot_size_embedder = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1
)

In [155]:
padded_pot_size_sequence = pot_size_embedder(pot_size_sequence)

In [156]:
street_idxs_out, street_embs = street_embedder(street_idxs)
street_embedding = {
    'street_idxs': street_idxs_out,
    'street_embedding': street_embs,
}

table_pos_idxs_out, table_pos_embs = table_position_embedder(table_position_idxs)
table_position_embedding = {
    'table_position_idxs': table_pos_idxs_out,
    'table_position_embedding': table_pos_embs,
}

action_idxs_out, action_embs = action_embedder(action_idxs)
action_embedding = {
    'action_idxs': action_idxs_out,
    'action_embedding': action_embs,
}

model_inputs = street_embedding | table_position_embedding | action_embedding | {'pot_size_sequence' : padded_pot_size_sequence.unsqueeze(2)}

In [157]:
legal_actions = test.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

In [158]:
legal_actions

tensor([[False, False,  True, False,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False],
        [False, False,  True, False,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,  True,
         False],
        [False, False, False,  True, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False],
        [False, False, False,  True, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  T

In [159]:
model_inputs

{'street_idxs': tensor([[0, 0, 6,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6],
         ...,
         [0, 0, 0,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6]], device='cuda:0'),
 'street_embedding': tensor([[[ 0.1622, -0.5240,  0.1949,  ...,  0.5872,  0.6785, -0.2903],
          [ 0.1622, -0.5240,  0.1949,  ...,  0.5872,  0.6785, -0.2903],
          [-0.6400,  0.9649,  0.5568,  ..., -1.8426,  0.5610,  1.4187],
          ...,
          [-0.6400,  0.9649,  0.5568,  ..., -1.8426,  0.5610,  1.4187],
          [-0.6400,  0.9649,  0.5568,  ..., -1.8426,  0.5610,  1.4187],
          [-0.6400,  0.9649,  0.5568,  ..., -1.8426,  0.5610,  1.4187]],
 
         [[ 0.1622, -0.5240,  0.1949,  ...,  0.5872,  0.6785, -0.2903],
          [ 0.1622, -0.5240,  0.1949,  ...,  0.5872,  0.6785, -0.2903],
          [ 0.1622, -0.5240,  0.1949,  ...,  0.5872,  0.6785, -0.2903],
          ...,
          [-0.6400,  0.9649,  0.5568,  ..

In [160]:
from ml_modules import *

cards = Cards()

deck_order = torch.randperm(52)
card_embeddings = cards(deck_order%13, deck_order//13)
unexposed_card = cards(torch.Tensor([13]).to(dtype = torch.long), torch.Tensor([4]).to(dtype = torch.long))

card_embeddings.shape

torch.Size([52, 2048])

In [161]:
unexposed_card.tile([3,1]).shape

torch.Size([3, 2048])

In [162]:
card_idxs_all = torch.zeros((2, 52))

card_idxs_all[0, :] = deck_order%13
card_idxs_all[1, :] = deck_order//13

In [163]:
card_idxs_unexposed = torch.zeros((2, 5))
card_idxs_unexposed[0,:] = 13
card_idxs_unexposed[1,:] = 4

In [164]:
sb = Player(hole_cards = card_embeddings[:2], card_idxs = card_idxs_all[:, :2], position = 0, active_or_not = 1)
bb = Player(hole_cards = card_embeddings[2:4], card_idxs = card_idxs_all[:, 2:4], position = 1, active_or_not = 1)

board = Board(board_cards = unexposed_card.tile([5,1]), card_idxs = card_idxs_unexposed)

In [165]:
board.board_cards.shape

torch.Size([5, 2048])

In [166]:
class NonSequenceStateBuilder():
    """
    Grabs necessary elements to build the inputs for the part of the model that are not sequential
    """
    def __init__(
        self,
        players : typing.List[torch.nn.Module],
        board : torch.nn.Module
    ):
        self.players = players
        self.board = board

    def get_state(
        self,
        position : int
    ):
        active_players = torch.zeros(len(self.players))
        stack_size = torch.zeros(len(self.players))
        
        for index, player in enumerate(self.players):
            if player.position == position:
                card_idxs = torch.concat([player.card_idxs, board.card_idxs], dim = 1)
                card_embeddings = torch.concat([player.hole_cards, board.board_cards], axis = 0)
            active_players[index] = player.active_or_not
            stack_size[index] = player.stack_size

        return active_players, stack_size, card_idxs, card_embeddings
    

In [167]:
non_sequence_state_builder = NonSequenceStateBuilder(players = [sb, bb], board = board)

In [168]:
active_players, stack_size, card_idxs, card_embeddings = non_sequence_state_builder.get_state(0)

In [169]:
card_idxs

tensor([[ 6.,  4., 13., 13., 13., 13., 13.],
        [ 0.,  0.,  4.,  4.,  4.,  4.,  4.]])

In [170]:
board.card_idxs

tensor([[13., 13., 13., 13., 13.],
        [ 4.,  4.,  4.,  4.,  4.]])

In [171]:
policy_model = PolicyModel(
    num_players = 2,
    active_players_hidden_dims = [1024, 2048],
    stack_size_hidden_dims = [1024, 2048],
    card_embeddings_hidden_dims = [2048, 2048],
    final_output_hidden_dims = [1024, 512, 256]
)

In [172]:
policy_model

PolicyModel(
  (active_players_net): Sequential(
    (0): Linear(in_features=2, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1024, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_features=2048, bias=True)
  )
  (stack_size_net): Sequential(
    (0): Linear(in_features=2, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1024, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_features=2048, bias=True)
  )
  (card_phi_net): Sequential(
    (0): Linear(in_features=2048, out_features=2048, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=2048, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_feature

In [176]:
poker_sequence_embedder = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cuda'
)

In [177]:
model_inputs['attention_mask'] = (model_inputs['pot_size_sequence'] != -1).squeeze(-1).to('cuda')

In [178]:
import torch
import gc

def print_gpu_memory(label=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"{label}")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Reserved:  {reserved:.2f} GB")
        print()

# Run through your code with diagnostics
print_gpu_memory("1. Initial state")
model_name = "./models/qwen3-1point7b/"


tokenizer, model = load_model(model_name)

model

print_gpu_memory("2. After loading model")

activation = None
def hook(_, __, output):
    global activation
    activation = output

handle = model.model.layers[27].post_attention_layernorm.register_forward_hook(hook)
print_gpu_memory("3. After registering hook")

inputs_embeds = poker_sequence_embedder(model_inputs).to(device="cuda", dtype=torch.bfloat16)

outputs = model(inputs_embeds = inputs_embeds, attention_mask = model_inputs['attention_mask'])

1. Initial state
  Allocated: 30.31 GB
  Reserved:  30.91 GB



Loading checkpoint shards:  50%|████████████████████████████████████████████████████████████████████████▌                                                                        | 1/2 [00:00<00:00,  5.76it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 594.00 MiB. GPU 0 has a total capacity of 31.35 GiB of which 346.62 MiB is free. Including non-PyTorch memory, this process has 30.71 GiB memory in use. Of the allocated memory 30.12 GiB is allocated by PyTorch, and 4.38 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [34]:
activations_last_action = activation[torch.arange(activation.shape[0]), model_inputs['attention_mask'].sum(dim = 1) - 1, :]

In [35]:
model_inputs['llm_state'] = activations_last_action

In [36]:
model_inputs['active_players'], model_inputs['stack_size'], model_inputs['card_idx'], model_inputs['card_embeddings'] = non_sequence_state_builder.get_state(0)

In [37]:
model_inputs['active_players'] = model_inputs['active_players'].unsqueeze(0).tile([6,1])
model_inputs['stack_size'] = model_inputs['stack_size'].unsqueeze(0).tile([6,1])
model_inputs['card_idx'] = model_inputs['card_idx'].unsqueeze(0).tile([6,1])
model_inputs['card_embeddings'] = model_inputs['card_embeddings'].unsqueeze(0).tile([6,1, 1])


In [38]:
model_inputs['card_embeddings'].shape

torch.Size([6, 7, 2048])

In [39]:
policy_model(
    model_inputs['active_players'].to('cpu'),
    model_inputs['stack_size'].to('cpu'),
    model_inputs['card_embeddings'].to('cpu'),
    model_inputs['llm_state'].to('cpu')
)

tensor([[ 0.2215, -1.6380,  0.8224, -0.2931, -0.4615, -0.3304,  0.7759,  1.0198,
          1.0651,  0.5802,  1.1974,  2.7783, -0.3674,  0.0615,  0.1738,  0.1880,
          0.9507, -0.2243, -1.0851,  0.1643,  1.1359,  0.3762],
        [-0.0507,  0.0882,  1.1007,  0.4088, -1.0043, -0.6398, -0.1509,  1.3440,
          0.4139,  0.3767, -0.7673,  1.4464, -0.5913,  1.1919, -0.2116,  1.0773,
          1.0851, -1.6365, -0.4175,  0.9055, -0.5395, -0.3680],
        [-1.1351, -0.1650,  1.1157, -0.5770, -1.0403, -0.1165,  0.8512,  0.7444,
         -0.1126, -0.0769, -0.2974,  1.5301, -0.0351,  1.4647,  0.5745,  0.5605,
          1.1196, -0.1595, -0.4896, -0.4448,  0.5237,  0.4428],
        [ 0.3334, -0.4012,  1.0564, -0.1812, -0.0794,  0.3070,  0.2690,  1.0559,
          1.0841, -0.3084,  0.1607,  1.6215, -0.1194,  0.6907,  0.7449,  0.8607,
         -0.0980, -0.3606, -0.6934,  0.0763,  1.5088,  0.6603],
        [-0.8746,  0.1044,  1.8759, -0.1351, -1.4602, -0.1167,  0.7471,  0.2775,
          0.946

In [40]:
model

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layer

In [41]:
import transformers

In [148]:
class PokerAgent(torch.nn.Module):
    """
    Full poker agent. Contains information about the cards,
    the players, the board, the sequence embedding, and the 
    probability prediction model.
    """
    def __init__(
        self,
        street_embedder : torch.nn.Module,
        table_position_embedder : torch.nn.Module,
        action_embedder : torch.nn.Module,
        pot_size_embedder : torch.nn.Module,
        llm : transformers.AutoModelForCausalLM,
        policy_model : torch.nn.Module,
        device : str = "cpu",
        llm_train : bool = False
    ):
        super().__init__()
        self.device = device
        self.llm_train = llm_train

        self.street_embedder = street_embedder.to(self.device)
        self.table_position_embedder = table_position_embedder.to(self.device)
        self.action_embedder = action_embedder.to(self.device)
        self.pot_size_embedder = pot_size_embedder.to(self.device)
        self.llm = llm
        self.policy_model = policy_model.to(self.device)

        if not llm_train:
            for parameter in self.llm.parameters():
                parameter.requires_grad = False

        def hook(_, __, output):
            global activation
            activation = output
        
        handle = self.llm.model.layers[27].post_attention_layernorm.register_forward_hook(hook)
       
    def forward(
        self,
        street_idxs : torch.Tensor,
        table_position_idxs : torch.Tensor,
        action_idxs : torch.Tensor,
        pot_size_sequence : torch.Tensor,
        active_players : torch.Tensor,
        stack_size : torch.Tensor,
        card_embeddings : torch.Tensor
    ):
        street_idxs_out, street_embs = self.street_embedder(street_idxs)
        street_embedding = {
            'street_idxs': street_idxs_out,
            'street_embedding': street_embs,
        }
        
        table_pos_idxs_out, table_pos_embs = self.table_position_embedder(table_position_idxs)
        table_position_embedding = {
            'table_position_idxs': table_pos_idxs_out,
            'table_position_embedding': table_pos_embs,
        }
        
        action_idxs_out, action_embs = self.action_embedder(action_idxs)
        action_embedding = {
            'action_idxs': action_idxs_out,
            'action_embedding': action_embs,
        }
        padded_pot_size_sequence = self.pot_size_embedder(pot_size_sequence)
        
        model_inputs = (
            street_embedding 
            | 
            table_position_embedding 
            | 
            action_embedding 
            | 
            {'pot_size_sequence' : padded_pot_size_sequence.unsqueeze(2)}
            |
            {
                'active_players' : active_players,
                'stack_size' : stack_size,
                'card_embeddings' : card_embeddings
            }
        )

        model_inputs['attention_mask'] = (model_inputs['pot_size_sequence'] != -1).squeeze(-1).to('cuda')

        activation = None

        inputs_embeds = self.poker_sequence_embedder(model_inputs).to(device="cuda", dtype=torch.bfloat16)
        
        outputs = self.llm(inputs_embeds = inputs_embeds, attention_mask = model_inputs['attention_mask'])
        
        activations_last_action = activation[
            torch.arange(activation.shape[0]), 
            model_inputs['attention_mask'].sum(dim = 1) - 1, 
            :
        ]
        
        model_inputs['llm_state'] = activations_last_action

        model_inputs['probits'] = self.policy_model(
                model_inputs['active_players'].to(self.device),
                model_inputs['stack_size'].to(self.device),
                model_inputs['card_embeddings'].to(self.device),
                model_inputs['llm_state'].to(self.device)
        )

        return model_inputs

In [149]:
poker_player = PokerAgent(
    street_embedder.to('cuda'),
    table_position_embedder.to('cuda'),
    action_embedder.to('cuda'),
    pot_size_embedder.to('cuda'),
    model,
    policy_model.to('cuda'),
    device = 'cuda',
    llm_train = False
)

In [150]:
poker_player.device

'cuda'

In [152]:
poker_player.policy_model

PolicyModel(
  (active_players_net): Sequential(
    (0): Linear(in_features=2, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1024, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_features=2048, bias=True)
  )
  (stack_size_net): Sequential(
    (0): Linear(in_features=2, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1024, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_features=2048, bias=True)
  )
  (card_phi_net): Sequential(
    (0): Linear(in_features=2048, out_features=2048, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=2048, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=2048, out_feature

In [116]:
table_position_idxs

tensor([[0., 1., 3., 3., 3., 3., 3., 3.],
        [0., 1., 2., 3., 3., 3., 3., 3.],
        [0., 1., 2., 0., 3., 3., 3., 3.],
        [0., 1., 2., 0., 1., 3., 3., 3.],
        [0., 1., 2., 0., 3., 3., 3., 3.],
        [0., 1., 2., 0., 1., 3., 3., 3.],
        [0., 1., 2., 0., 1., 3., 3., 3.],
        [0., 1., 2., 0., 3., 0., 3., 3.]])

In [117]:
action_idxs

tensor([[ 0.,  1., 21., 21., 21., 21., 21., 21.],
        [ 0.,  1.,  4., 21., 21., 21., 21., 21.],
        [ 0.,  1.,  2.,  2., 21., 21., 21., 21.],
        [ 0.,  1.,  2.,  4.,  3., 21., 21., 21.],
        [ 0.,  1.,  2.,  4., 21., 21., 21., 21.],
        [ 0.,  1.,  2.,  4.,  3., 19., 21., 21.],
        [ 0.,  1.,  4.,  2.,  3., 19., 21., 21.],
        [ 0.,  1.,  5.,  2.,  4., 19., 17., 21.]])

In [118]:
pot_size_sequence

tensor([[ 1.,  3., -1., -1., -1., -1., -1., -1.],
        [ 1.,  3.,  5., -1., -1., -1., -1., -1.],
        [ 1.,  3.,  3.,  3., -1., -1., -1., -1.],
        [ 1.,  3.,  3.,  4.,  4., -1., -1., -1.],
        [ 1.,  3.,  3.,  4., -1., -1., -1., -1.],
        [ 1.,  3.,  3.,  4.,  4., -1., -1., -1.],
        [ 1.,  3.,  5.,  5.,  5.,  5., -1., -1.],
        [ 1.,  3.,  7.,  7.,  9.,  9.,  9., -1.]])

In [119]:
active_players

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 0.],
        [1., 1., 0.],
        [0., 1., 1.],
        [0., 1., 1.]])

In [120]:
stack_size = torch.Tensor(
    [
        [399, 398, 400],
        [399, 398, 398],
        [399, 398, 400],
        [398, 398, 400],
        [398, 398, 400],
        [398, 398, 400],
        [399, 398, 398],
        [399, 396, 396]
    ]
)

stack_size.shape

torch.Size([8, 3])

In [121]:
deck_order_shuffled = torch.argsort(torch.rand(8, 52))

In [122]:
deck_order_unshuffled = torch.arange(52).tile(8, 1)

In [123]:
torch.randperm(deck_order_unshuffled.shape[0])

tensor([7, 2, 0, 6, 3, 4, 1, 5])

In [126]:
card_embeddings = cards(deck_order_shuffled%13, deck_order_shuffled//13)

In [127]:
card_unshown_embedding = cards(
    torch.Tensor([[13]]).to(dtype = torch.long), torch.Tensor([[4]]).to(dtype = torch.long)
)

In [128]:
card_unshown_embedding = card_unshown_embedding.tile([8, 5, 1])

In [129]:
card_unshown_embedding

tensor([[[ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714]],

        [[ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714]],

        [[ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
         [ 0.1036,  1.7986, -1.5657,  ..., -0.2134,  1.0508, -0.6714],
  

In [130]:
card_embeddings[:, :2, :].shape

torch.Size([8, 2, 2048])

In [131]:
cards_player_0_embeddings = torch.concat([card_embeddings[:, :2, :], card_unshown_embedding], dim = 1)

In [112]:
poker_player = poker_player.to('cuda')

In [136]:
poker_player.street_embedder

StreetPositionalEncoding(
  (street_embedder): Embedding(7, 256)
)

In [146]:
street_embedder.to('cuda')

StreetPositionalEncoding(
  (street_embedder): Embedding(7, 256)
)

In [153]:
poker_player(
    street_idxs.to('cuda'),
    table_position_idxs.to('cuda'),
    action_idxs.to('cuda'), 
    pot_size_sequence.to('cuda'),
    active_players.to('cuda'),
    stack_size.to('cuda'),
    cards_player_0_embeddings.to('cuda')
)


RuntimeError: Expected all tensors to be on the same device, but got index is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA__index_select)

In [None]:
street_embedder = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cpu"
)

table_position_embedder = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cpu"
)

action_embedder = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cpu"
)

pot_size_embedder = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1
)