In [1]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *
from llm_modules import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
street_embedder = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 128,
    device = "cpu"
)

table_position_embedder = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 128,
    device = "cpu"
)

action_embedder = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 128,
    device = "cpu"
)

pot_size_embedder = PotSizeSequenceEmbedder(
    max_seq_len = 128,
    pad_value = -1
)

In [3]:
street_idxs = torch.Tensor([
    [0, 0, 6, 6, 6, 6, 6, 6],
    [0, 0, 0, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 4, 6, 6, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    
])
table_position_idxs = torch.Tensor([
    [0, 1, 2, 2, 2, 2, 2, 2],
    [0, 1, 0, 2, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],

])
action_idxs = torch.Tensor([
    [0, 1, 21, 21, 21, 21, 21, 21],
    [0, 1, 4, 21, 21, 21, 21, 21],
    [0, 1, 5, 4, 21, 21, 21, 21],
    [0, 1, 4, 3, 21, 21, 21, 21],
    [0, 1, 5, 6, 21, 21, 21, 21],
    [0, 1, 5, 2, 21, 21, 21, 21],
    [0, 1, 5, 4, 19, 21, 21, 21],
    [0, 1, 5, 4, 19, 3, 3, 21],
    [0, 1, 5, 4, 19, 5, 2, 21],
    [0, 1, 5, 4, 19, 5, 4, 21],
    [0, 1, 5, 4, 19, 3, 5, 21],
])
pot_size_sequence = torch.Tensor([
    [1, 3, -1, -1, -1, -1, -1, -1],
    [1, 3, 5, -1, -1, -1, -1, -1],
    [1, 3, 7, 8, -1, -1, -1, -1],
    [1, 3, 5, 5, -1, -1, -1, -1],
    [1, 3, 7, 15, -1, -1, -1, -1],
    [1, 3, 7, 7, -1, -1, -1, -1],
    [1, 3, 7, 8, 8, -1, -1, -1],
    [1, 3, 7, 8, 8, 8, 8, -1],
    [1, 3, 7, 8, 8, 10, 10, -1],
    [1, 3, 7, 8, 8, 10, 12, -1],
    [1, 3, 7, 8, 8, 8, 10, -1],
])
active_players = torch.Tensor([
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 0],
    [1, 1],
    [1, 1],
    [1, 0],
    [1, 1],
    [1, 1],
])

In [4]:
padded_pot_size_sequence = pot_size_embedder(pot_size_sequence)

In [5]:
street_idxs_out, street_embs = street_embedder(street_idxs)
street_embedding = {
    'street_idxs': street_idxs_out,
    'street_embedding': street_embs,
}

table_pos_idxs_out, table_pos_embs = table_position_embedder(table_position_idxs)
table_position_embedding = {
    'table_position_idxs': table_pos_idxs_out,
    'table_position_embedding': table_pos_embs,
}

action_idxs_out, action_embs = action_embedder(action_idxs)
action_embedding = {
    'action_idxs': action_idxs_out,
    'action_embedding': action_embs,
}

model_inputs = street_embedding | table_position_embedding | action_embedding | {'pot_size_sequence' : padded_pot_size_sequence.unsqueeze(2)}

In [6]:
poker_sequence_embedder = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cpu'
)

In [7]:
model_inputs['attention_mask'] = (model_inputs['pot_size_sequence'] != -1).squeeze(-1).to('cuda')

In [8]:
!nvidia-smi

Thu Feb  5 21:26:30 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   24C    P1             35W /  450W |     813MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [9]:
import torch
import gc

def print_gpu_memory(label=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"{label}")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Reserved:  {reserved:.2f} GB")
        print()

# Run through your code with diagnostics
print_gpu_memory("1. Initial state")
model_name = "./models/qwen3-1point7b/"


tokenizer, model = load_model(model_name)

model

print_gpu_memory("2. After loading model")

activation = None
def hook(_, __, output):
    global activation
    activation = output

handle = model.model.layers[27].post_attention_layernorm.register_forward_hook(hook)
print_gpu_memory("3. After registering hook")

inputs_embeds = poker_sequence_embedder(model_inputs).to(device="cuda", dtype=torch.bfloat16)

outputs = model(inputs_embeds = inputs_embeds, attention_mask = model_inputs['attention_mask'])

# Did you actually RUN the model?
# If yes, that creates activations
# If no, the hook never fires

print_gpu_memory("4. After running one pass")


handle.remove()
print_gpu_memory("5. After handle.remove()")

del activation
print_gpu_memory("6. After del activation")

del tokenizer
print_gpu_memory("7. After del tokenizer")

del model
print_gpu_memory("8. After del model")

del outputs
print_gpu_memory("9. After del outputs")

gc.collect()
print_gpu_memory("10. After gc.collect()")

torch.cuda.empty_cache()
print_gpu_memory("11. After empty_cache()")

1. Initial state
  Allocated: 0.00 GB
  Reserved:  0.00 GB



`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.31it/s]
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


2. After loading model
  Allocated: 3.44 GB
  Reserved:  4.09 GB

3. After registering hook
  Allocated: 3.44 GB
  Reserved:  4.09 GB

4. After running one pass
  Allocated: 8.85 GB
  Reserved:  9.05 GB

5. After handle.remove()
  Allocated: 8.85 GB
  Reserved:  9.05 GB

6. After del activation
  Allocated: 8.85 GB
  Reserved:  9.05 GB

7. After del tokenizer
  Allocated: 8.85 GB
  Reserved:  9.05 GB

8. After del model
  Allocated: 8.85 GB
  Reserved:  9.05 GB

9. After del outputs
  Allocated: 0.01 GB
  Reserved:  9.05 GB

10. After gc.collect()
  Allocated: 0.01 GB
  Reserved:  9.05 GB

11. After empty_cache()
  Allocated: 0.01 GB
  Reserved:  0.02 GB



In [10]:
import torch
import gc

def print_gpu_memory(label=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"{label}")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Reserved:  {reserved:.2f} GB")
        print()

# Run through your code with diagnostics
print_gpu_memory("1. Initial state")
model_name = "./models/qwen3-1point7b/"


tokenizer, model = load_model(model_name)

model

print_gpu_memory("2. After loading model")

activation = None
def hook(_, __, output):
    global activation
    activation = output

handle = model.model.layers[27].post_attention_layernorm.register_forward_hook(hook)
print_gpu_memory("3. After registering hook")

inputs_embeds = poker_sequence_embedder(model_inputs).to(device="cuda", dtype=torch.bfloat16)

outputs = model(inputs_embeds = inputs_embeds, attention_mask = model_inputs['attention_mask'])

1. Initial state
  Allocated: 0.01 GB
  Reserved:  0.02 GB



Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.49it/s]
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


2. After loading model
  Allocated: 3.46 GB
  Reserved:  4.10 GB

3. After registering hook
  Allocated: 3.46 GB
  Reserved:  4.10 GB



In [31]:
print_gpu_memory("3. After registering hook")

inputs_embeds = poker_sequence_embedder(model_inputs).to(device="cuda", dtype=torch.bfloat16)

outputs = model(inputs_embeds = inputs_embeds, attention_mask = model_inputs['attention_mask'].to(dtype=torch.int))

print_gpu_memory("11. After empty_cache()")

3. After registering hook
  Allocated: 13.55 GB
  Reserved:  14.65 GB

11. After empty_cache()
  Allocated: 18.26 GB
  Reserved:  19.56 GB



In [32]:
activation[0][6]

tensor([-0.3301, -4.8750, -0.4395,  ..., -0.2412,  0.7422, -1.0547],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SelectBackward0>)

In [33]:
model_inputs['action_idxs']

tensor([[ 0,  1, 21,  ..., 21, 21, 21],
        [ 0,  1,  4,  ..., 21, 21, 21],
        [ 0,  1,  5,  ..., 21, 21, 21],
        ...,
        [ 0,  1,  5,  ..., 21, 21, 21],
        [ 0,  1,  5,  ..., 21, 21, 21],
        [ 0,  1,  5,  ..., 21, 21, 21]])

In [48]:
activation[0,1]

tensor([-0.1543, -3.5781, -0.5195,  ..., -1.0234,  2.2500, -1.4844],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SelectBackward0>)

In [46]:
activation[torch.arange(11), model_inputs['attention_mask'].sum(dim=1)-1, :]

tensor([[-0.1543, -3.5781, -0.5195,  ..., -1.0234,  2.2500, -1.4844],
        [-0.2393, -3.2812, -0.0923,  ..., -1.0859,  2.1719, -1.3750],
        [ 0.5859, -3.2500, -0.0532,  ..., -2.1406,  1.6484, -1.0234],
        ...,
        [ 1.0781, -2.7812,  0.0723,  ..., -2.3594,  0.6836, -0.7383],
        [ 1.3438, -2.8750, -0.5234,  ..., -1.7891,  0.8516,  0.2217],
        [ 1.6641, -2.9844, -0.8633,  ..., -1.7031,  1.0469,  0.0581]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<IndexBackward0>)

In [36]:
activation

tensor([[[ 3.3438, -3.6406, -1.5703,  ...,  2.3438,  0.9375, -2.9844],
         [-0.1543, -3.5781, -0.5195,  ..., -1.0234,  2.2500, -1.4844],
         [-0.0938, -4.3438, -0.4453,  ..., -0.7461,  1.0938, -1.1562],
         ...,
         [ 0.3613, -4.8125, -0.5898,  ..., -0.0058,  1.1719, -0.5430],
         [ 0.3848, -4.6250, -0.7500,  ...,  0.0688,  1.2422, -0.5078],
         [ 0.4082, -4.5000, -0.7344,  ..., -0.2129,  1.2109, -0.4824]],

        [[ 3.3438, -3.6406, -1.5703,  ...,  2.3438,  0.9375, -2.9844],
         [-0.1543, -3.5781, -0.5195,  ..., -1.0234,  2.2500, -1.4844],
         [-0.2393, -3.2812, -0.0923,  ..., -1.0859,  2.1719, -1.3750],
         ...,
         [-0.0703, -4.1875, -0.7031,  ..., -0.1641,  1.5859, -0.3633],
         [-0.0253, -4.0625, -0.9062,  ..., -0.0459,  1.5469, -0.4590],
         [ 0.1592, -3.8750, -1.0078,  ..., -0.4453,  1.5703, -0.4668]],

        [[ 3.3438, -3.6406, -1.5703,  ...,  2.3438,  0.9375, -2.9844],
         [-0.1543, -3.5781, -0.5195,  ..., -1

In [49]:
from ml_modules import *

cards = Cards()

deck_order = torch.randperm(52)
card_embeddings = cards(deck_order%13, deck_order//13)
unexposed_card = cards(torch.Tensor([13]).to(dtype = torch.long), torch.Tensor([4]).to(dtype = torch.long))

In [50]:
card_embeddings.shape

torch.Size([52, 2048])

In [51]:
unexposed_card.tile([3,1]).shape

torch.Size([3, 2048])

In [52]:
card_embeddings.shape

torch.Size([52, 2048])

In [53]:
unexposed_card.shape

torch.Size([1, 2048])

In [54]:
sb = Player(hole_cards = card_embeddings[:2], position = 0, folded_or_not = 0)
bb = Player(hole_cards = card_embeddings[2:4], position = 1, folded_or_not = 0)

board = Board(board_cards = unexposed_card.tile([3,1]))

In [55]:
players = [sb, bb]

In [56]:
active_players = torch.Tensor([elem.folded_or_not for elem in players]).to(dtype = torch.bfloat16)
stack_sizes = torch.Tensor([elem.stack_size for elem in players]).to(dtype = torch.bfloat16)

In [57]:
torch.concat([sb.hole_cards, board.board_cards], axis=0).shape

torch.Size([5, 2048])

In [58]:
deck_order = torch.randperm(52)

In [59]:
((deck_order//13)==0).sum()

tensor(13)

In [60]:
((deck_order%13)==12).sum()

tensor(4)

In [61]:
model_inputs['llm_state'] = activation

In [62]:
active_players = torch.concat([active_players, stack_sizes], axis=0).unsqueeze(0)

In [63]:
active_players.shape

torch.Size([1, 4])

In [64]:
model_inputs['pot_size_sequence'].shape

torch.Size([11, 128, 1])

In [74]:
model_inputs['pot_size_sequence'].max(axis=1).values

tensor([[ 3.],
        [ 5.],
        [ 8.],
        [ 5.],
        [15.],
        [ 7.],
        [ 8.],
        [ 8.],
        [10.],
        [12.],
        [10.]])

In [75]:
model_inputs['card_state'] = torch.concat([sb.hole_cards, board.board_cards], axis=0).tile([11,1,1])
model_inputs['active_players_state'] = active_players.tile([11,1])

In [64]:
model_inputs

{'street_idxs': tensor([[0, 0, 6,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6],
         ...,
         [0, 0, 0,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6],
         [0, 0, 0,  ..., 6, 6, 6]]),
 'street_embedding': tensor([[[ 0.9827, -0.5144,  1.3626,  ...,  0.8392, -0.2706, -0.8136],
          [ 0.9827, -0.5144,  1.3626,  ...,  0.8392, -0.2706, -0.8136],
          [ 0.2170,  1.7004, -0.0420,  ...,  0.5870, -0.6570, -0.0339],
          ...,
          [ 0.2170,  1.7004, -0.0420,  ...,  0.5870, -0.6570, -0.0339],
          [ 0.2170,  1.7004, -0.0420,  ...,  0.5870, -0.6570, -0.0339],
          [ 0.2170,  1.7004, -0.0420,  ...,  0.5870, -0.6570, -0.0339]],
 
         [[ 0.9827, -0.5144,  1.3626,  ...,  0.8392, -0.2706, -0.8136],
          [ 0.9827, -0.5144,  1.3626,  ...,  0.8392, -0.2706, -0.8136],
          [ 0.9827, -0.5144,  1.3626,  ...,  0.8392, -0.2706, -0.8136],
          ...,
          [ 0.2170,  1.7004, -0.0420,  ...,  0.5870, -0.65

In [62]:
activation.shape

torch.Size([11, 128, 2048])

In [None]:
model_inputs[''] = 

In [None]:
class PolicyMaker(torch.nn.Module):
    """
    Given the embedding from an LLM of the sequence
    of the poker hand, information about the what cards are available,
    information about the stack sizes and information about
    who is active or inactive in the hand, returns two probability distributions
    over the possible actions, and the mask over the action space: one is a best response
    policy (trained by RL) and the other is a learned average policy (trained by SL).
    """
    def __init__(
        self,
        num_players : int,
        poker_sequence_embedder : torch.nn.Module,
        card_latent_dims: typing.List[int],
        active_players_latent_dims: typing.List[int],
        final_decision_latent_dims: typing.List[int],
        device: str = 'cpu'
    ):
        
        self.num_players = num_players
        self.active_players_latent_dims = active_players_latent_dims
        self.card_latent_dims = card_latent_dims
        self.final_decision_latent_dims = final_decision_latent_dims
        
        self.poker_sequence_embedder = poker_sequence_embedder

        self.device = device

        def make_mlp(input_dim: int, latent_dimensions : typing.List[int]) -> torch.nn.Sequential:
            dims = [input_dim] + latent_dimensions
            layers = []
            for i in range(len(dims) - 1):
                layers.append(torch.nn.Linear(dims[i], dims[i + 1]))
                layers.append(torch.nn.LayerNorm(dims[i + 1]))
                if i < len(dims) - 2:  # no ReLU on final layer
                    layers.append(torch.nn.ReLU())
            return torch.nn.Sequential(*layers).to(self.device)

        self.active_players_mlp = make_mlp(2 * self.num_players, self.active_players_latent_dims) #active or not and stack size
        self.cards_mlp = make_mlp(5, self.card_latent_dims) # 5 cards in poker
        self.final_decision_mlp = make_mlp(2048*3, self.final_decision_latent_dims)

    def forward(
        self,
        model_inputs,
        ,
        board
    ) -> torch.Tensor:
        llm_state = model_inputs['llm_state']
        active_players = model_inputs['active_players_state']
        card_state = model_inputs['']
        
        
        
        

In [65]:
model_inputs.keys()

dict_keys(['street_idxs', 'street_embedding', 'table_position_idxs', 'table_position_embedding', 'action_idxs', 'action_embedding', 'pot_size_sequence', 'attention_mask', 'llm_state', 'card_state', 'active_players_state'])