In [1]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *

In [2]:
street_embedder = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 128,
    device = "cpu"
)

table_position_embedder = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 128,
    device = "cpu"
)

action_embedder = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 128,
    device = "cpu"
)

pot_size_embedder = PotSizeSequenceEmbedder(
    max_seq_len = 128,
    pad_value = -1
)

In [3]:
street_idxs = torch.Tensor([
    [0, 0, 6, 6, 6, 6, 6, 6],
    [0, 0, 0, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 6, 6, 6, 6],
    [0, 0, 0, 0, 4, 6, 6, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    [0, 0, 0, 0, 4, 0, 0, 6], 
    
])
table_position_idxs = torch.Tensor([
    [0, 1, 2, 2, 2, 2, 2, 2],
    [0, 1, 0, 2, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 2, 2, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 0, 1, 2],

])
action_idxs = torch.Tensor([
    [0, 1, 21, 21, 21, 21, 21, 21],
    [0, 1, 4, 21, 21, 21, 21, 21],
    [0, 1, 5, 4, 21, 21, 21, 21],
    [0, 1, 4, 3, 21, 21, 21, 21],
    [0, 1, 5, 6, 21, 21, 21, 21],
    [0, 1, 5, 2, 21, 21, 21, 21],
    [0, 1, 5, 4, 19, 21, 21, 21],
    [0, 1, 5, 4, 19, 3, 3, 21],
    [0, 1, 5, 4, 19, 5, 2, 21],
    [0, 1, 5, 4, 19, 5, 4, 21],
    [0, 1, 5, 4, 19, 3, 5, 21],
])
pot_size_sequence = torch.Tensor([
    [1, 3, -1, -1, -1, -1, -1, -1],
    [1, 3, 5, -1, -1, -1, -1, -1],
    [1, 3, 7, 8, -1, -1, -1, -1],
    [1, 3, 5, 5, -1, -1, -1, -1],
    [1, 3, 7, 15, -1, -1, -1, -1],
    [1, 3, 7, 7, -1, -1, -1, -1],
    [1, 3, 7, 8, 8, -1, -1, -1],
    [1, 3, 7, 8, 8, 8, 8, -1],
    [1, 3, 7, 8, 8, 10, 10, -1],
    [1, 3, 7, 8, 8, 10, 12, -1],
    [1, 3, 7, 8, 8, 8, 10, -1],
])
active_players = torch.Tensor([
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 1],
    [1, 0],
    [1, 1],
    [1, 1],
    [1, 0],
    [1, 1],
    [1, 1],
])

In [4]:
padded_pot_size_sequence = pot_size_embedder(pot_size_sequence)

In [5]:
padded_pot_size_sequence.shape

torch.Size([11, 128])

In [6]:
street_idxs_out, street_embs = street_embedder(street_idxs)
street_embedding = {
    'street_idxs': street_idxs_out,
    'street_embedding': street_embs,
}

table_pos_idxs_out, table_pos_embs = table_position_embedder(table_position_idxs)
table_position_embedding = {
    'table_position_idxs': table_pos_idxs_out,
    'table_position_embedding': table_pos_embs,
}

action_idxs_out, action_embs = action_embedder(action_idxs)
action_embedding = {
    'action_idxs': action_idxs_out,
    'action_embedding': action_embs,
}


In [7]:
model_inputs = street_embedding | table_position_embedding | action_embedding | {'pot_size_sequence' : padded_pot_size_sequence.unsqueeze(2)}

In [8]:
model_inputs['pot_size_sequence'].shape

torch.Size([11, 128, 1])

In [9]:
poker_sequence_embedder = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cpu'
)

In [10]:
poker_sequence_embedder(model_inputs).shape

torch.Size([11, 128, 2048])

In [11]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "./models/qwen3-1point7b/"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)



  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.58it/s]
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


In [12]:
model

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layer

In [13]:
!nvidia-smi

Sat Jan 31 15:07:46 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   26C    P1             70W /  450W |    4537MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [14]:
with torch.no_grad():
    outputs = model(
        inputs_embeds=poker_sequence_embedder(model_inputs)
            .to(device="cuda", dtype=torch.bfloat16)
    )


In [15]:
inputs_embeds = poker_sequence_embedder(model_inputs).to(device="cuda", dtype=torch.bfloat16)

In [16]:
outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 3.1875, -3.7969,  0.0277,  ..., -0.1201, -0.1201, -0.1206],
         [ 4.4688, -2.6562,  2.2344,  ..., -0.1514, -0.1514, -0.1514],
         [ 6.7188, -0.3770,  5.9375,  ..., -0.0640, -0.0640, -0.0640],
         ...,
         [ 4.3438, -0.4531,  4.0312,  ...,  0.2070,  0.2070,  0.2070],
         [ 4.3125, -0.4648,  3.9531,  ...,  0.1895,  0.1895,  0.1895],
         [ 4.3438, -0.4297,  3.9531,  ...,  0.1914,  0.1914,  0.1914]],

        [[ 3.1875, -3.7969,  0.0277,  ..., -0.1201, -0.1201, -0.1206],
         [ 4.4688, -2.6562,  2.2344,  ..., -0.1514, -0.1514, -0.1514],
         [ 7.9375,  2.6875,  5.2188,  ..., -0.0957, -0.0957, -0.0957],
         ...,
         [ 4.4375, -0.3184,  4.5000,  ...,  0.1436,  0.1436,  0.1436],
         [ 4.4062, -0.2334,  4.4688,  ...,  0.1182,  0.1182,  0.1182],
         [ 4.3750, -0.2734,  4.5000,  ...,  0.1001,  0.1001,  0.1001]],

        [[ 3.1875, -3.7969,  0.0277,  ..., -0.1201, -0.1201, -0.1206],
    

In [17]:
model(inputs_embeds = inputs_embeds, output_hidden_states=True).hidden_states[27].shape

torch.Size([11, 128, 2048])

In [28]:
!nvidia-smi

Sat Jan 31 15:08:30 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   28C    P1             70W /  450W |   10627MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [29]:
with torch.no_grad():
    outputs = model(inputs_embeds = inputs_embeds, output_hidden_states=True)

In [50]:
model_inputs['pot_size_sequence'].shape

torch.Size([11, 128, 1])

In [51]:
attention_mask = (model_inputs['pot_size_sequence'] != -1).squeeze(-1)


In [52]:
with torch.no_grad():
    outputs = model(inputs_embeds = inputs_embeds, output_hidden_states=True, attention_mask = attention_mask)

In [53]:
outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 3.1875, -3.7969,  0.0277,  ..., -0.1201, -0.1201, -0.1206],
         [ 4.4688, -2.6562,  2.2500,  ..., -0.1494, -0.1494, -0.1494],
         [ 5.6875, -2.2188,  3.1406,  ..., -0.1914, -0.1914, -0.1914],
         ...,
         [ 4.7812,  0.5508,  2.7344,  ..., -0.3320, -0.3301, -0.3320],
         [ 4.7812,  0.4902,  2.6094,  ..., -0.3574, -0.3574, -0.3574],
         [ 5.0000,  0.6328,  2.5625,  ..., -0.3574, -0.3574, -0.3574]],

        [[ 3.1875, -3.7969,  0.0277,  ..., -0.1201, -0.1201, -0.1206],
         [ 4.4688, -2.6562,  2.2500,  ..., -0.1494, -0.1494, -0.1494],
         [ 7.9375,  2.6875,  5.2188,  ..., -0.0981, -0.0981, -0.0981],
         ...,
         [ 7.3125,  2.6719,  4.4375,  ..., -0.2734, -0.2734, -0.2754],
         [ 7.1250,  2.5312,  4.3750,  ..., -0.3770, -0.3770, -0.3770],
         [ 7.6562,  2.7344,  4.5000,  ..., -0.3379, -0.3379, -0.3379]],

        [[ 3.1875, -3.7969,  0.0277,  ..., -0.1201, -0.1201, -0.1206],
    

In [43]:
activation = None

def hook(_, __, output):
    global activation
    activation = output

handle = model.model.layers[27].post_attention_layernorm.register_forward_hook(hook)


outputs = model(inputs_embeds=inputs_embeds)

handle.remove()


In [44]:
activation

tensor([[[ 2.8438, -0.5352, -2.1719,  ..., -1.1484, -1.5938,  0.2734],
         [ 4.2500, -4.2188,  1.6641,  ...,  0.0170, -6.4375,  3.6562],
         [ 2.2656, -5.7500,  1.9844,  ...,  0.6133, -3.6562,  3.9688],
         ...,
         [-2.0938,  0.3730, -2.8750,  ..., -0.1147,  2.0312,  2.6719],
         [-2.0781,  0.3516, -2.8594,  ..., -0.0967,  2.0312,  2.6250],
         [-2.0469,  0.3457, -2.8594,  ..., -0.1206,  2.0625,  2.5938]],

        [[ 2.8438, -0.5352, -2.1719,  ..., -1.1484, -1.5938,  0.2734],
         [ 4.2500, -4.2188,  1.6641,  ...,  0.0170, -6.4375,  3.6562],
         [ 1.2969, -5.2500,  0.9141,  ..., -1.4219, -4.5000,  1.5703],
         ...,
         [-2.3594,  0.0134, -2.8281,  ...,  0.2852,  1.9688,  2.5312],
         [-2.3750,  0.0493, -2.8125,  ...,  0.2695,  1.9766,  2.5625],
         [-2.2969,  0.0114, -2.8281,  ...,  0.2695,  2.0156,  2.5469]],

        [[ 2.8438, -0.5352, -2.1719,  ..., -1.1484, -1.5938,  0.2734],
         [ 4.2500, -4.2188,  1.6641,  ...,  0

In [46]:
hs = outputs.hidden_states[27].detach()
del outputs

TypeError: 'NoneType' object is not subscriptable

In [47]:
!nvidia-smi

Sat Jan 31 15:20:06 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   25C    P1             69W /  450W |   15613MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [32]:
import gc

In [33]:
gc.collect()

1033

In [34]:
!nvidia-smi

Sat Jan 31 15:04:11 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   29C    P1             70W /  450W |   31729MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [36]:
print(torch.cuda.memory_summary())


|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  25083 MiB |  30076 MiB |  72648 MiB |  47565 MiB |
|       from large pool |  25051 MiB |  30038 MiB |  72546 MiB |  47494 MiB |
|       from small pool |     31 MiB |     37 MiB |    102 MiB |     71 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  25083 MiB |  30076 MiB |  72648 MiB |  47565 MiB |
|       from large pool |  25051 MiB |  30038 MiB |  72546 MiB |  47494 MiB |
|       from small pool |     31 MiB |     37 MiB |    102 MiB |     71 MiB |
|---------------------------------------------------------------