In [1]:
!nvidia-smi

Sat Feb  7 21:09:59 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   24C    P8             23W /  450W |     304MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [2]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *
from llm_modules import *

from utils import *

import transformers


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hyper-parameters shared by every per-token embedder below. They were
# previously repeated as literals in each constructor call; naming them once
# keeps the embedders in agreement when one of the values is tuned.
SEQ_EMBEDDING_DIM = 256
SEQ_MAX_LEN = 1024
SEQ_DEVICE = "cuda"
# Sentinel passed to the pot-size embedder as its padding value.
POT_PAD_VALUE = -1

street_embedder = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = SEQ_EMBEDDING_DIM,
    max_seq_len = SEQ_MAX_LEN,
    device = SEQ_DEVICE
)

table_position_embedder = TablePositionalEncoding(
    num_players = 3,
    embedding_dim = SEQ_EMBEDDING_DIM,
    max_seq_len = SEQ_MAX_LEN,
    device = SEQ_DEVICE
)

# `num_actions` is deliberately not passed (an earlier revision had
# `num_actions = 21` here), so ActionEncoding uses whatever its own default is.
action_embedder = ActionEncoding(
    embedding_dim = SEQ_EMBEDDING_DIM,
    max_seq_len = SEQ_MAX_LEN,
    device = SEQ_DEVICE
)

pot_size_embedder = PotSizeSequenceEmbedder(
    max_seq_len = SEQ_MAX_LEN,
    pad_value = POT_PAD_VALUE,
    device = SEQ_DEVICE
)

# The three *_input_dimension values must match the embedding width produced
# by the embedders above. The last latent width (2048) presumably has to match
# the hidden size of the LLM that consumes the result — TODO confirm against
# PokerSequenceEmbedder.
poker_sequence_embedder = PokerSequenceEmbedder(
    street_input_dimension = SEQ_EMBEDDING_DIM,
    table_position_input_dimension = SEQ_EMBEDDING_DIM,
    action_input_dimension = SEQ_EMBEDDING_DIM,
    latent_dimensions = [256, 512, 1024, 2048],
    device = SEQ_DEVICE
)



In [4]:
def _default_float_tensor(rows):
    """Build a tensor in torch's default floating dtype, like ``torch.Tensor(rows)``."""
    return torch.tensor(rows, dtype=torch.get_default_dtype())


# Hand-written fixture: 8 example sequences of length 8, plus per-player state
# for 3 players.
# NOTE(review): 6 / 3 / 21 / -1 look like the padding values of the respective
# sequences — confirm against the embedders. Row 5 of `pot_size_sequence` has
# one fewer non-(-1) entry than the street/action rows have non-pad entries.
street_idxs = _default_float_tensor([
    [0, 0, 6, 6, 6, 6, 6, 6],
    [0, 0, 0, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 0, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 0, 4, 6, 6],
    [0, 0, 0, 0, 0, 4, 6, 6],
    [0, 0, 0, 0, 0, 4, 1, 6],
])
table_position_idxs = _default_float_tensor([
    [0, 1, 3, 3, 3, 3, 3, 3],
    [0, 1, 2, 3, 3, 3, 3, 3],
    [0, 1, 2, 0, 3, 3, 3, 3],
    [0, 1, 2, 0, 1, 3, 3, 3],
    [0, 1, 2, 0, 3, 3, 3, 3],
    [0, 1, 2, 0, 1, 3, 3, 3],
    [0, 1, 2, 0, 1, 3, 3, 3],
    [0, 1, 2, 0, 3, 0, 3, 3],
])
action_idxs = _default_float_tensor([
    [0, 1, 21, 21, 21, 21, 21, 21],
    [0, 1, 4, 21, 21, 21, 21, 21],
    [0, 1, 2, 2, 21, 21, 21, 21],
    [0, 1, 2, 4, 3, 21, 21, 21],
    [0, 1, 2, 4, 21, 21, 21, 21],
    [0, 1, 2, 4, 3, 19, 21, 21],
    [0, 1, 4, 2, 3, 19, 21, 21],
    [0, 1, 6, 2, 4, 19, 17, 21],
])
pot_size_sequence = _default_float_tensor([
    [1, 3, -1, -1, -1, -1, -1, -1],
    [1, 3, 5, -1, -1, -1, -1, -1],
    [1, 3, 3, 3, -1, -1, -1, -1],
    [1, 3, 3, 4, 4, -1, -1, -1],
    [1, 3, 3, 4, -1, -1, -1, -1],
    [1, 3, 3, 4, 4, -1, -1, -1],
    [1, 3, 5, 5, 5, 5, -1, -1],
    [1, 3, 7, 7, 9, 9, 9, -1],
])
active_players = _default_float_tensor([
    [1, 1, 1],
    [1, 1, 1],
    [1, 0, 0],
    [1, 1, 0],
    [1, 1, 0],
    [1, 1, 0],
    [0, 1, 1],
    [0, 1, 1],
])
stack_size = _default_float_tensor([
    [399, 398, 400],
    [399, 398, 398],
    [399, 398, 400],
    [398, 398, 400],
    [398, 398, 400],
    [398, 398, 400],
    [399, 398, 398],
    [399, 396, 396],
])

# Same five shape lines as before, in the same order.
for _fixture in (
    street_idxs,
    table_position_idxs,
    action_idxs,
    pot_size_sequence,
    active_players,
):
    print(_fixture.shape)

# Last expression of the cell, so the notebook displays it as the cell output.
stack_size.shape

torch.Size([8, 8])
torch.Size([8, 8])
torch.Size([8, 8])
torch.Size([8, 8])
torch.Size([8, 3])


torch.Size([8, 3])

In [5]:
from ml_modules import *

# Card-embedding fixture for the 8 example hands.
CARD_FIXTURE_HANDS = 8          # one shuffled deck per example hand
CARD_DECK_SIZE = 52
CARD_INDEX_MODULUS = 13         # card id -> (id % 13, id // 13)
CARD_UNSHOWN_FIRST_IDX = 13     # index pair used for a card that is not shown;
CARD_UNSHOWN_SECOND_IDX = 4     # presumably one past the valid ranges — TODO confirm in Cards
CARD_UNSHOWN_COUNT = 5          # unshown cards appended after each player's two cards
CARDS_PER_PLAYER = 2
CARD_FIXTURE_SEED = 0

cards = Cards()

# Argsort of uniform noise gives one random permutation of 0..51 per row.
# A dedicated, seeded generator keeps the fixture identical across
# Restart & Run All without touching the global RNG state.
_deck_rng = torch.Generator().manual_seed(CARD_FIXTURE_SEED)
deck_order_shuffled = torch.argsort(
    torch.rand(CARD_FIXTURE_HANDS, CARD_DECK_SIZE, generator = _deck_rng)
)

card_embeddings = cards(
    deck_order_shuffled % CARD_INDEX_MODULUS,
    deck_order_shuffled // CARD_INDEX_MODULUS
)

# Embed the single "unshown" index pair, then repeat it for every hand and
# every unshown slot.
card_unshown_embedding = cards(
    torch.tensor([[CARD_UNSHOWN_FIRST_IDX]], dtype = torch.long),
    torch.tensor([[CARD_UNSHOWN_SECOND_IDX]], dtype = torch.long)
).tile([CARD_FIXTURE_HANDS, CARD_UNSHOWN_COUNT, 1])


def _player_card_embeddings(player_idx):
    """Return one player's two dealt cards followed by the unshown-card embeddings."""
    first = CARDS_PER_PLAYER * player_idx
    dealt = card_embeddings[:, first:first + CARDS_PER_PLAYER, :]
    return torch.concat([dealt, card_unshown_embedding], dim = 1)


# Player i is dealt deck positions 2i and 2i+1 of each shuffled deck.
cards_player_0_embeddings = _player_card_embeddings(0)
cards_player_1_embeddings = _player_card_embeddings(1)
cards_player_2_embeddings = _player_card_embeddings(2)


In [6]:
class PokerAgent(torch.nn.Module):
    """
    Full poker agent. Contains information about the cards,
    the players, the board, the sequence embedding, and the
    probability prediction model.

    The forward pass embeds the street / table-position / action index
    sequences, fuses them into LLM input embeddings, runs the LLM, captures
    the output of one decoder layer's ``post_attention_layernorm`` at the last
    non-padded position of every sequence, and feeds that state together with
    the per-player inputs to the policy model.

    Args:
        street_embedder: Module called with the street indices; must return
            an ``(indices, embeddings)`` pair.
        table_position_embedder: Same contract, for table-position indices.
        action_embedder: Same contract, for action indices.
        pot_size_embedder: Module called with the pot-size sequence; its
            output marks padded positions with ``-1``.
        llm: Causal LM exposing ``model.layers[i].post_attention_layernorm``
            and accepting ``inputs_embeds`` / ``attention_mask`` keywords.
        policy_model: Module called as
            ``policy_model(active_players, stack_size, card_embeddings, llm_state)``.
        device: Device the policy-model inputs are moved to.
        llm_train: If False the LLM parameters are frozen and the LLM runs
            under ``torch.no_grad()`` (so no gradient reaches the sequence
            embedder through the LLM either). If True gradients are enabled
            for the LLM pass.
        poker_sequence_embedder: Module mapping the assembled input dict to
            LLM input embeddings. If omitted, a module-level variable named
            ``poker_sequence_embedder`` is used (legacy notebook behaviour).
        hook_layer_idx: Index into ``llm.model.layers`` of the decoder layer
            whose ``post_attention_layernorm`` output is captured.

    Raises:
        ValueError: If no sequence embedder is passed and no module-level
            ``poker_sequence_embedder`` exists.
    """

    # Value that marks padded positions in the pot-size sequence; the
    # attention mask is derived from it.
    _PAD_VALUE = -1

    def __init__(
        self,
        street_embedder : torch.nn.Module,
        table_position_embedder : torch.nn.Module,
        action_embedder : torch.nn.Module,
        pot_size_embedder : torch.nn.Module,
        llm : "transformers.AutoModelForCausalLM",
        policy_model : torch.nn.Module,
        device : str = "cpu",
        llm_train : bool = False,
        poker_sequence_embedder : "torch.nn.Module | None" = None,
        hook_layer_idx : int = 27
    ):
        super().__init__()

        if poker_sequence_embedder is None:
            # Backward compatibility: earlier callers relied on a global of
            # this name instead of passing the embedder in.
            poker_sequence_embedder = globals().get("poker_sequence_embedder")
            if poker_sequence_embedder is None:
                raise ValueError(
                    "poker_sequence_embedder was not provided and no module-level "
                    "'poker_sequence_embedder' is defined"
                )

        self.device = device
        self.llm_train = llm_train
        self.hook_layer_idx = hook_layer_idx
        self.street_embedder = street_embedder
        self.table_position_embedder = table_position_embedder
        self.action_embedder = action_embedder
        self.pot_size_embedder = pot_size_embedder
        self.poker_sequence_embedder = poker_sequence_embedder
        self.llm = llm
        self.policy_model = policy_model

        if not llm_train:
            for parameter in self.llm.parameters():
                parameter.requires_grad = False

    def _llm_device_and_dtype(self):
        """Return the (device, dtype) the LLM input embeddings must be cast to."""
        first_param = next(self.llm.parameters(), None)
        if first_param is None:
            # Parameter-free LLM: fall back to the configured device and the
            # dtype this class historically assumed.
            return torch.device(self.device), torch.bfloat16
        if first_param.dtype.is_floating_point:
            return first_param.device, first_param.dtype
        # Non-float first parameter (e.g. quantised weights): keep bfloat16.
        return first_param.device, torch.bfloat16

    def forward(
        self,
        street_idxs : torch.Tensor,
        table_position_idxs : torch.Tensor,
        action_idxs : torch.Tensor,
        pot_size_sequence : torch.Tensor,
        active_players : torch.Tensor,
        stack_size : torch.Tensor,
        card_embeddings : torch.Tensor
    ):
        """
        Run the embedders, the LLM and the policy model on one batch.

        Returns:
            dict with the embedder outputs, the inputs passed through,
            ``'attention_mask'``, ``'llm_state'`` (captured activation at the
            last non-padded position of each sequence) and ``'probits'``
            (the policy-model output).

        Raises:
            RuntimeError: If the forward hook captured nothing.
        """
        llm_device, llm_dtype = self._llm_device_and_dtype()

        street_idxs_out, street_embs = self.street_embedder(street_idxs)
        table_pos_idxs_out, table_pos_embs = self.table_position_embedder(table_position_idxs)
        action_idxs_out, action_embs = self.action_embedder(action_idxs)
        padded_pot_size_sequence = self.pot_size_embedder(pot_size_sequence)

        model_inputs = {
            'street_idxs': street_idxs_out,
            'street_embedding': street_embs,
            'table_position_idxs': table_pos_idxs_out,
            'table_position_embedding': table_pos_embs,
            'action_idxs': action_idxs_out,
            'action_embedding': action_embs,
            'pot_size_sequence': padded_pot_size_sequence.unsqueeze(2),
            'active_players': active_players,
            'stack_size': stack_size,
            'card_embeddings': card_embeddings,
        }
        # A position is real (attended to) iff its pot size is not the pad value.
        attention_mask = (
            (model_inputs['pot_size_sequence'] != self._PAD_VALUE)
            .squeeze(-1)
            .to(llm_device)
        )
        model_inputs['attention_mask'] = attention_mask

        # The hook writes into a mutable container so the closure can hand
        # the activation back to this frame.
        activation_cache = {'activation': None}

        def hook(module, input, output):
            activation_cache['activation'] = output

        target_layer = self.llm.model.layers[self.hook_layer_idx]
        handle = target_layer.post_attention_layernorm.register_forward_hook(hook)

        try:
            inputs_embeds = self.poker_sequence_embedder(model_inputs).to(
                device=llm_device, dtype=llm_dtype
            )
            # Only the hooked activation is used; the LLM's own return value
            # is discarded.
            if self.llm_train:
                self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
            else:
                with torch.no_grad():
                    self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        finally:
            # Always detach the hook so repeated calls do not stack hooks.
            handle.remove()

        activation = activation_cache['activation']

        if activation is None:
            raise RuntimeError("Hook did not capture activation - check layer path")

        # Index of the last non-padded position per sequence. A fully padded
        # row yields -1, i.e. the final position.
        last_valid_idx = (attention_mask.sum(dim=1) - 1).to(activation.device)
        batch_idx = torch.arange(activation.shape[0], device=activation.device)
        activations_last_action = activation[batch_idx, last_valid_idx, :]

        model_inputs['llm_state'] = activations_last_action
        model_inputs['probits'] = self.policy_model(
                model_inputs['active_players'].to(self.device),
                model_inputs['stack_size'].to(self.device),
                model_inputs['card_embeddings'].to(self.device),
                model_inputs['llm_state'].to(self.device)
        )
        return model_inputs

In [7]:
# Checkpoint directory, given relative to the notebook's working directory.
model_name = "./models/qwen3-1point7b/"


# `load_model` is not defined in this notebook (presumably it comes from one of
# the star imports above); its result is unpacked as (tokenizer, model).
tokenizer, model = load_model(model_name)

# Display the module tree. The recorded output shows a Qwen3ForCausalLM with
# 28 decoder layers and a 2048-wide token embedding.
model

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.97it/s]
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layer

In [8]:
# Policy head configuration, kept in one mapping so the settings can be
# inspected or logged separately from the constructed module.
policy_model_config = {
    'num_players': 3,
    'active_players_hidden_dims': [1024, 2048],
    'stack_size_hidden_dims': [1024, 2048],
    'card_embeddings_hidden_dims': [2048, 2048],
    'final_output_hidden_dims': [1024, 512, 256],
    'device': 'cuda',
}

policy_model = PolicyModel(**policy_model_config)

In [9]:
# Assemble the agent from the embedders, the loaded LLM and the policy head.
# Every argument is passed by keyword so the wiring is readable at a glance;
# the LLM stays frozen (llm_train = False).
poker_player = PokerAgent(
    street_embedder = street_embedder,
    table_position_embedder = table_position_embedder,
    action_embedder = action_embedder,
    pot_size_embedder = pot_size_embedder,
    llm = model,
    policy_model = policy_model,
    device = 'cuda',
    llm_train = False
)

In [10]:
!nvidia-smi

Sat Feb  7 21:10:02 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   26C    P1             64W /  450W |    4731MiB /  32607MiB |      8%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [11]:
# Inputs in the positional order PokerAgent.forward expects; player 0's card
# embeddings are used for this smoke test.
forward_inputs = (
    street_idxs,
    table_position_idxs,
    action_idxs,
    pot_size_sequence,
    active_players,
    stack_size,
    cards_player_0_embeddings,
)

# Move every input to the GPU, then run a single forward pass.
outputs = poker_player(*(tensor.to('cuda') for tensor in forward_inputs))


In [12]:
!nvidia-smi

Sat Feb  7 21:10:02 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0  On |                  N/A |
|  0%   37C    P1            148W /  450W |    9031MiB /  32607MiB |     99%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [13]:
import torch
import gc

def print_gpu_memory(label=""):
    """
    Print the CUDA memory currently allocated and reserved by PyTorch.

    Values are reported in GB (bytes / 1e9) under a caller-supplied label,
    followed by a blank line. Prints nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1e9
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb

    report_lines = [
        f"{label}",
        f"  Allocated: {allocated_gb:.2f} GB",
        f"  Reserved:  {reserved_gb:.2f} GB",
    ]
    print("\n".join(report_lines))
    print()

print_gpu_memory("After forward pass")

After forward pass
  Allocated: 3.71 GB
  Reserved:  8.52 GB



In [14]:
# Drop the notebook's reference to the forward-pass result dict so the tensors
# it holds become collectable.
del outputs

# Force a collection pass, then report. In the recorded run this lowers the
# allocated figure only slightly (3.71 GB -> 3.68 GB); reserved is unchanged.
gc.collect()
print_gpu_memory("After gc.collect()")

# Release PyTorch's cached-but-unused blocks. In the recorded run this is what
# lowers the reserved figure (8.52 GB -> 4.11 GB); allocated is unchanged.
torch.cuda.empty_cache()
print_gpu_memory("After empty_cache()")

After gc.collect()
  Allocated: 3.68 GB
  Reserved:  8.52 GB

After empty_cache()
  Allocated: 3.68 GB
  Reserved:  4.11 GB

