In [1]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *
from llm_modules import *
from agent import *
from simulation import *

from utils import *

from ml_ops_utils import *

import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
num_players = 2
small_blind = 1
big_blind = 2

softmax_prob = torch.nn.Softmax(dim=-1)

action_validator = PokerActionValidator(
    num_players = num_players,
    small_blind = small_blind,
    big_blind = big_blind,
    starting_stack_sizes = 400
)

street_embedder_p1 = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

table_position_embedder_p1 = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

action_embedder_p1 = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

pot_size_embedder_p1 = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1,
    device = 'cuda'
)

poker_sequence_embedder_p1 = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cuda'
)
cards_p1 = Cards(device = 'cuda')

self_position_embedder_p1 = SelfPositionEmbedder(number_of_positions = 2, device = "cuda")


street_embedder_p2 = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

table_position_embedder_p2 = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

action_embedder_p2 = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

pot_size_embedder_p2 = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1,
    device = 'cuda'
)

poker_sequence_embedder_p2 = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cuda'
)

cards_p2 = Cards(device = 'cuda')

self_position_embedder_p2 = SelfPositionEmbedder(number_of_positions = 2, device = "cuda")

deck_order_shuffled = torch.argsort(torch.rand(1,52))

In [3]:
model_name = "./models/qwen3-1point7b/"

tokenizer, model = load_model(model_name)

model

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.95it/s]
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layer

In [4]:
policy_model_p1 = PolicyModel(
    num_players = 2,
    self_position_embedder = self_position_embedder_p1,
    active_players_hidden_dims = [1024, 2048],
    stack_size_hidden_dims = [1024, 2048],
    card_embeddings_hidden_dims = [2048, 2048],
    final_output_hidden_dims = [1024, 512, 256],
    value_output_hidden_dims = [1024, 512, 256],
    dropout_rate = 0,
    device = 'cuda',
)

policy_model_p2 = PolicyModel(
    num_players = 2,
    self_position_embedder = self_position_embedder_p2,
    active_players_hidden_dims = [1024, 2048],
    stack_size_hidden_dims = [1024, 2048],
    card_embeddings_hidden_dims = [2048, 2048],
    final_output_hidden_dims = [1024, 512, 256],
    value_output_hidden_dims = [1024, 512, 256],
    dropout_rate = 0,
    device = 'cuda',
)

poker_player_p1 = PokerAgent(
    cards_p1,
    street_embedder_p1,
    table_position_embedder_p1,
    action_embedder_p1,
    pot_size_embedder_p1,
    poker_sequence_embedder_p1,
    model,
    policy_model_p1,
    device = 'cuda',
    llm_train = False
)

poker_player_p2 = PokerAgent(
    cards_p2,
    street_embedder_p2,
    table_position_embedder_p2,
    action_embedder_p2,
    pot_size_embedder_p2,
    poker_sequence_embedder_p2,
    model,
    policy_model_p2,
    device = 'cuda',
    llm_train = False
)

batch_size = 1024

street_idxs = (torch.zeros((batch_size, 1024)) + 6).long()
street_idxs[:, :2] = 0 # posting small blind and big blind

table_position_idxs = (torch.zeros((batch_size, 1024)) + 2).long()
table_position_idxs[:, 0] = 0 # sb/b
table_position_idxs[:, 1] = 1 # bb

action_idxs = (torch.zeros((batch_size, 1024)) + 21).long()
action_idxs[:, 0] = 0 # post sb
action_idxs[:, 1] = 1 # post bb

pot_size_sequence = (torch.zeros((batch_size, 1024)) - 1)
pot_size_sequence[:, 0] = 1
pot_size_sequence[:, 1] = 3


active_players = torch.Tensor([
    [1, 1]
]).tile(batch_size, 1)
print(active_players.shape)
stack_size = torch.Tensor(
    [
        [399, 398],
    ]
).tile(batch_size, 1)
stack_size.shape

sb_cards = torch.zeros((batch_size,2,7)).to(dtype=torch.long).to('cuda')
bb_cards = torch.zeros((batch_size,2,7)).to(dtype=torch.long).to('cuda')

sb_cards[:, 0, :2] = deck_order_shuffled[0, :2]%13
sb_cards[:, 1, :2] = deck_order_shuffled[0, :2]//13

sb_cards[:, 0, 2:] = 13
sb_cards[:, 1, 2:] = 4


bb_cards[:, 0, :2] = deck_order_shuffled[0, 2:4]%13
bb_cards[:, 1, :2] = deck_order_shuffled[0, 2:4]//13

bb_cards[:, 0, 2:] = 13
bb_cards[:, 1, 2:] = 4

p1_position_player = torch.Tensor([0]).to('cuda').tile((batch_size))
p2_position_player = torch.Tensor([1]).to('cuda').tile((batch_size))

table = {
    0 : [poker_player_p1, sb_cards, p1_position_player],
    1 : [poker_player_p2, bb_cards, p2_position_player]
}

torch.Size([1024, 2])


In [15]:
batch_size = 1024

street_idxs = (torch.zeros((batch_size, 1024)) + 6).long()
street_idxs[:, :2] = 0 # posting small blind and big blind

table_position_idxs = (torch.zeros((batch_size, 1024)) + 2).long()
table_position_idxs[:, 0] = 0 # sb/b
table_position_idxs[:, 1] = 1 # bb

action_idxs = (torch.zeros((batch_size, 1024)) + 21).long()
action_idxs[:, 0] = 0 # post sb
action_idxs[:, 1] = 1 # post bb

pot_size_sequence = (torch.zeros((batch_size, 1024)) - 1)
pot_size_sequence[:, 0] = 1
pot_size_sequence[:, 1] = 3


active_players = torch.Tensor([
    [1, 1]
]).tile(batch_size, 1)
print(active_players.shape)
stack_size = torch.Tensor(
    [
        [399, 398],
    ]
).tile(batch_size, 1)
stack_size.shape

sb_cards = torch.zeros((batch_size,2,7)).to(dtype=torch.long).to('cuda')
bb_cards = torch.zeros((batch_size,2,7)).to(dtype=torch.long).to('cuda')

sb_cards[:, 0, :2] = deck_order_shuffled[0, :2]%13
sb_cards[:, 1, :2] = deck_order_shuffled[0, :2]//13

sb_cards[:, 0, 2:] = 13
sb_cards[:, 1, 2:] = 4


bb_cards[:, 0, :2] = deck_order_shuffled[0, 2:4]%13
bb_cards[:, 1, :2] = deck_order_shuffled[0, 2:4]//13

bb_cards[:, 0, 2:] = 13
bb_cards[:, 1, 2:] = 4

p1_position_player = torch.Tensor([0]).to('cuda').tile((batch_size))
p2_position_player = torch.Tensor([1]).to('cuda').tile((batch_size))

table = {
    0 : [poker_player_p1, sb_cards, p1_position_player],
    1 : [poker_player_p2, bb_cards, p2_position_player]
}

(
    street_idxs, 
    sim_table_position_idxs, 
    sim_action_idxs, 
    sim_pot_size_sequence, 
    sim_active_players, 
    sim_stack_size,
    sim_table
) = simulate_hand(
    num_players,
    street_idxs,
    table_position_idxs,
    action_idxs,
    pot_size_sequence,
    active_players,
    stack_size,
    table,
    action_validator,
    deck_order_shuffled,
)

torch.Size([1024, 2])
after action 2 


2
1
tensor([0, 1])
tensor([1., 3.])
tensor([0, 1])
tensor([0, 0])
tensor([399., 398.])
tensor([1., 1.])

 

4
tensor(399.)
after action 3 


3
2
tensor([0, 1, 8])
tensor([ 1.,  3., 14.])
tensor([0, 1, 0])
tensor([0, 0, 0])
tensor([388., 398.])
tensor([1., 1.])

 

tensor(22.)
tensor(398.)
after action 4 


4
3
tensor([0, 1, 8, 6])
tensor([ 1.,  3., 14., 36.])
tensor([0, 1, 0, 1])
tensor([0, 0, 0, 0])
tensor([388., 376.])
tensor([1., 1.])

 

tensor(44.)
tensor(388.)
after action 5 


5
4
tensor([0, 1, 8, 6, 9])
tensor([  1.,   3.,  14.,  36., 121.])
tensor([0, 1, 0, 1, 0])
tensor([0, 0, 0, 0, 0])
tensor([303., 376.])
tensor([1., 1.])

 

after action 6 


6
5
tensor([0, 1, 8, 6, 9, 4])
tensor([  1.,   3.,  14.,  36., 121., 194.])
tensor([0, 1, 0, 1, 0, 1])
tensor([0, 0, 0, 0, 0, 0])
tensor([303., 303.])
tensor([1., 1.])

 

after action 7 


7
6
tensor([ 0,  1,  8,  6,  9,  4, 19])
tensor([  1.,   3.,  14.,  36., 121., 194., 194.])
tensor([0, 1, 

In [6]:
batch_size = 1024

street_idxs = (torch.zeros((batch_size, 1024)) + 6).long()
street_idxs[:, :2] = 0 # posting small blind and big blind

table_position_idxs = (torch.zeros((batch_size, 1024)) + 2).long()
table_position_idxs[:, 0] = 0 # sb/b
table_position_idxs[:, 1] = 1 # bb

action_idxs = (torch.zeros((batch_size, 1024)) + 21).long()
action_idxs[:, 0] = 0 # post sb
action_idxs[:, 1] = 1 # post bb

pot_size_sequence = (torch.zeros((batch_size, 1024)) - 1)
pot_size_sequence[:, 0] = 1
pot_size_sequence[:, 1] = 3


active_players = torch.Tensor([
    [1, 1]
]).tile(batch_size, 1)
print(active_players.shape)
stack_size = torch.Tensor(
    [
        [399, 398],
    ]
).tile(batch_size, 1)
stack_size.shape

torch.Size([1024, 2])


torch.Size([1024, 2])

In [146]:
sb_cards = torch.zeros((batch_size,2,7)).to(dtype=torch.long).to('cuda')
bb_cards = torch.zeros((batch_size,2,7)).to(dtype=torch.long).to('cuda')

In [147]:
sb_cards[:, 0, :2] = deck_order_shuffled[0, :2]%13
sb_cards[:, 1, :2] = deck_order_shuffled[0, :2]//13

sb_cards[:, 0, 2:] = 13
sb_cards[:, 1, 2:] = 4


bb_cards[:, 0, :2] = deck_order_shuffled[0, 2:4]%13
bb_cards[:, 1, :2] = deck_order_shuffled[0, 2:4]//13

bb_cards[:, 0, 2:] = 13
bb_cards[:, 1, 2:] = 4

In [148]:
p1_position_player = torch.Tensor([0]).to('cuda').tile((batch_size))
p2_position_player = torch.Tensor([1]).to('cuda').tile((batch_size))

In [149]:
p1_position_player.shape

torch.Size([1024])

In [150]:
table = {
    0 : [poker_player_p1, sb_cards, p1_position_player],
    1 : [poker_player_p2, bb_cards, p2_position_player]
}

In [151]:
curr_street_index = 2
curr_batch_index = 1
END_OF_HAND_TOKEN_ACTIONS = 20
END_OF_HAND_TOKEN_STREETS = 5

END_OF_STREET_TOKEN_ACTIONS = 19
END_OF_STREET_TOKEN_STREETS = 4

end_of_hand_happened = False

while not end_of_hand_happened:

    print(curr_street_index)
    print(curr_batch_index)
    
    legal_actions = action_validator.get_legal_actions_mask(
        street_idxs,
        table_position_idxs,
        action_idxs, 
        pot_size_sequence,
        active_players
    )
    
    who_is_acting = action_validator.get_next_to_act(
        street_idxs,
        table_position_idxs,
        action_idxs, 
        active_players
    )
    
    next_to_act = who_is_acting[[curr_batch_index-1]]
    legal_actions = legal_actions[curr_batch_index-1]
    
    if legal_actions[ -1] == True:
        # End of hand has happened. Terminate the sequence and end while loop.
        end_of_hand_happened = True
        
        action_idxs[curr_batch_index:, curr_street_index] = END_OF_HAND_TOKEN_ACTIONS
        street_idxs[curr_batch_index:, curr_street_index] = END_OF_HAND_TOKEN_STREETS
        pot_size_sequence[curr_batch_index: , curr_street_index] = pot_size_sequence[curr_batch_index, curr_street_index - 1]
        table_position_idxs[curr_batch_index: , curr_street_index] = num_players
    
    elif legal_actions[-2] == True:
        # End of street has happened. Prepare transition to next street.    
        action_idxs[curr_batch_index:, curr_street_index] = END_OF_STREET_TOKEN_ACTIONS
        street_idxs[curr_batch_index:, curr_street_index] = END_OF_STREET_TOKEN_STREETS
        pot_size_sequence[curr_batch_index: , curr_street_index] = pot_size_sequence[curr_batch_index, curr_street_index - 1]
        table_position_idxs[curr_batch_index: , curr_street_index] = num_players
    
        if street_idxs[curr_batch_index, curr_street_index - 1] == 1:
            current_street = "flop"
        elif street_idxs[curr_batch_index, curr_street_index - 1] == 2:
            current_street = "turn"
        elif street_idxs[curr_batch_index, curr_street_index - 1] == 3:
            current_street = "river"
        
    
        for key in table.keys():
            _, cards, _ = table[key]
            if current_street == "flop":
                # Deal three cards (OMG NO BURN?!).
                cards[curr_batch_index:, 0, 2:5] = deck_order_shuffled[0,(2*(num_players)):(2*(num_players)+3)]%13
                cards[curr_batch_index:, 1, 2:5] = deck_order_shuffled[0,(2*(num_players)):(2*(num_players)+3)]//13
                table[key][1] = cards
            elif current_street == "turn":
                # Deal turn card (OMG NO BURN?!).
                cards[curr_batch_index:, 0, 5] = deck_order_shuffled[0,(2*(num_players)+3)]%13
                cards[curr_batch_index:, 1, 5] = deck_order_shuffled[0,(2*(num_players)+3)]//13
                table[key][1] = cards
            elif current_street == "river":
                # Deal river card (still no burn).
                cards[curr_batch_index:, 0, 6] = deck_order_shuffled[0,(2*(num_players)+4)]%13
                cards[curr_batch_index:, 1, 6] = deck_order_shuffled[0,(2*(num_players)+4)]//13
                table[key][1] = cards

        if (action_idxs[curr_batch_index]==16).any():
            # Someone was all in on the street that just completed, implying the hand is over.
            action_idxs[curr_batch_index:, curr_street_index + 1] = END_OF_HAND_TOKEN_ACTIONS
            street_idxs[curr_batch_index:, curr_street_index + 1] = END_OF_HAND_TOKEN_STREETS
            pot_size_sequence[curr_batch_index: , curr_street_index + 1] = pot_size_sequence[curr_batch_index, curr_street_index - 1]
            table_position_idxs[curr_batch_index: , curr_street_index + 1] = num_players

            for key in table.keys():
                _, cards, _ = table[key]
                cards[curr_batch_index:, 0, 2:7] = deck_order_shuffled[0,(2*(num_players)):(2*(num_players)+5)]%13
                cards[curr_batch_index:, 1, 2:7] = deck_order_shuffled[0,(2*(num_players)):(2*(num_players)+5)]//13
                table[key][1] = cards
            
            end_of_hand_happened = True
    
    else:
        player, cards, position = table[next_to_act.item()]

        cards = cards[[curr_batch_index]]
        position = position[[curr_batch_index]]
    
        outputs = player(
            position,
            cards,
            street_idxs[[curr_batch_index]],
            table_position_idxs[[curr_batch_index]],
            action_idxs[[curr_batch_index]],
            pot_size_sequence[[curr_batch_index]],
            active_players.to('cuda')[[curr_batch_index]],
            stack_size.to('cuda')[[curr_batch_index]]
        )
    
        sampled_action = torch.distributions.Categorical(
            softmax_prob(outputs['probits'] - 1e9 * ((~legal_actions).float()).to('cuda'))
        ).sample()

        if street_idxs[curr_batch_index, curr_street_index - 1]!=4:
            current_street = street_idxs[curr_batch_index, curr_street_index - 1] 
        else:
            current_street = street_idxs[curr_batch_index, curr_street_index - 2] + 1



        if 5 <= sampled_action <= 16:
            min_size = get_min_bet_size_or_raise_size(
                pot_size_sequence[[curr_batch_index]],
                street_idxs[[curr_batch_index]],
                action_idxs[[curr_batch_index]],
                current_street
            )
        
            if legal_actions[5] == True:
                spacing = 16-5+1 # (number of sizes)
            else:
                spacing = 16-6+1 # (number of sizes)
            
            my_stack_size = stack_size[curr_batch_index, position.int().to('cpu')].squeeze()
            
            multiplicative_factor = (my_stack_size / min(min_size, my_stack_size)) ** (1/(spacing-1))
            
            bet_sizes = torch.ones(spacing) * min(min_size, my_stack_size)
            bet_sizes = torch.ceil(bet_sizes * (multiplicative_factor ** torch.arange(spacing))).clip(max = my_stack_size)
            
            chosen_bet = bet_sizes[sampled_action.to('cpu')[0] - 17]
        
            street_idxs[curr_batch_index:, curr_street_index] = current_street
            table_position_idxs[curr_batch_index:, curr_street_index] = position.int().to('cpu')

            action_idxs[curr_batch_index:, curr_street_index] = sampled_action.to('cpu')[0]
            pot_size_sequence[curr_batch_index:, curr_street_index] = (
                pot_size_sequence[curr_batch_index, curr_street_index - 1]
                +
                bet_sizes[sampled_action.to('cpu')[0] - 17]
            )
            stack_size[curr_batch_index:, position.int().to('cpu')] -= bet_sizes[sampled_action.to('cpu')[0] - 17]

            if stack_size[curr_batch_index, position.int().to('cpu')] == 0:
                # if we're left with nothing left, it's an all in.
                action_idxs[curr_batch_index:, curr_street_index] = 16
            

        if sampled_action == 4:
            
            two_largest_vals, _ = torch.topk(
                torch.unique(pot_size_sequence[curr_batch_index]), 2
            )

            # The amount to call, if everyone started with the same amount
            # of money is equal to my stack size minus the smallest stack size
            # because that person made the biggest bet so far.
            amount_to_call = min(
                stack_size[curr_batch_index, position.int().to('cpu')] - min(stack_size[curr_batch_index]), 
                stack_size[curr_batch_index, position.int().to('cpu')]
            )

            street_idxs[curr_batch_index:, curr_street_index] = current_street
            table_position_idxs[curr_batch_index:, curr_street_index] = position.int().to('cpu')
            action_idxs[curr_batch_index:, curr_street_index] = 4
            pot_size_sequence[curr_batch_index:, curr_street_index] = (
                pot_size_sequence[curr_batch_index, curr_street_index - 1] + amount_to_call
            )
            stack_size[curr_batch_index:, position.int().to('cpu')] -= amount_to_call

        if sampled_action == 3:

            street_idxs[curr_batch_index:, curr_street_index] = current_street
            table_position_idxs[curr_batch_index:, curr_street_index] = position.int().to('cpu')
            action_idxs[curr_batch_index:, curr_street_index] = 3
            pot_size_sequence[curr_batch_index:, curr_street_index] = (
                pot_size_sequence[curr_batch_index, curr_street_index]
            )

        if sampled_action == 2:

            street_idxs[curr_batch_index:, curr_street_index] = current_street
            table_position_idxs[curr_batch_index:, curr_street_index] = position.int().to('cpu')
            action_idxs[curr_batch_index:, curr_street_index] = 2
            pot_size_sequence[curr_batch_index:, curr_street_index] = (
                pot_size_sequence[curr_batch_index, curr_street_index - 1]
            )

            active_players[curr_batch_index:, position.int().to('cpu')] = 0

        del outputs
        gc.collect()
        torch.cuda.empty_cache()

        
        
    curr_street_index += 1
    curr_batch_index += 1

2
1
3
2
4
3
5
4
6
5


In [152]:
my_stack_size

tensor(298.)

In [153]:
multiplicative_factor

tensor(1.)

In [154]:
bet_sizes = torch.ones(spacing) * min(min_size, my_stack_size)


In [155]:
torch.ceil(bet_sizes * (multiplicative_factor ** torch.arange(spacing))).clip(max = my_stack_size)


tensor([298., 298., 298., 298., 298., 298., 298., 298., 298., 298., 298.])

In [156]:
table_position_idxs[7][:10]

tensor([0, 1, 0, 1, 0, 1, 2, 2, 2, 2])

In [157]:
print(action_idxs[6][:5])
print(table_position_idxs[7][:5])

tensor([ 0,  1, 13, 11, 16])
tensor([0, 1, 0, 1, 0])


In [159]:
stack_size[:7]

tensor([[399., 398.],
        [298., 398.],
        [298., 114.],
        [  0., 114.],
        [  0.,   0.],
        [  0.,   0.],
        [  0.,   0.]])

In [161]:
pot_size_sequence[7][:10]

tensor([  1.,   3., 104., 388., 686., 800., 800., 800.,  -1.,  -1.])

In [51]:
legal_actions[0]

tensor([False, False,  True, False,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False, False,
        False])

In [56]:
street_idxs[0][0:10]

tensor([0, 0, 0, 6, 6, 6, 6, 6, 6, 6])

In [82]:
street_idxs[3][0:10]

tensor([0, 0, 0, 0, 4, 4, 6, 6, 6, 6])

In [83]:
table_position_idxs[3][0:10]

tensor([0, 1, 0, 1, 2, 0, 2, 2, 2, 2])

In [84]:
table_position_idxs[3][0:10]

tensor([0, 1, 0, 1, 2, 0, 2, 2, 2, 2])

In [85]:
action_idxs[3][0:10]

tensor([ 0,  1, 11,  4, 19, 14, 21, 21, 21, 21])

In [86]:
action_idxs[3][0:10]

tensor([ 0,  1, 11,  4, 19, 14, 21, 21, 21, 21])

In [80]:
legal_actions = action_validator.get_legal_actions_mask(
    street_idxs[[2]],
    table_position_idxs[[2]],
    action_idxs[[2]], 
    pot_size_sequence[[2]],
    active_players[[2]]
)

print(legal_actions)

who_is_acting = action_validator.get_next_to_act(
    street_idxs[[2]],
    table_position_idxs[[2]],
    action_idxs[[2]], 
    active_players[[2]]
)
print(who_is_acting)

tensor([[False, False, False,  True, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False]])
tensor([0])


In [20]:
legal_actions

tensor([[False, False,  True,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        ...,
        [False, False,  True,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False]])

In [15]:
            _, two_largest_vals = torch.topk(
                torch.unique(pot_size_sequence[curr_batch_index]), 2
            )
            

In [16]:
two_largest_vals

tensor([3, 2])

In [14]:
pot_size_sequence[:curr_batch_index+1, :curr_street_index+1]

tensor([[ 1.,  3., 10., 10., 10.],
        [ 1.,  3., 10., 11., 10.],
        [ 1.,  3., 10., 11., 11.]])

In [17]:
stack_size[curr_batch_index, position.int().to('cpu')]

tensor([399., 399., 399.,  ..., 399., 399., 399.])

In [19]:
table_position_idxs[0, :10]

tensor([0, 1, 0, 2, 2, 2, 2, 2, 2, 2])

In [20]:
action_idxs[0, :10]

tensor([ 0,  1,  8, 21, 21, 21, 21, 21, 21, 21])

In [21]:
pot_size_sequence[0, :10]

tensor([ 1.,  3., 14., -1., -1., -1., -1., -1., -1., -1.])

In [22]:
stack_size

tensor([[388., 398.]])

In [14]:
!nvidia-smi

Mon Feb 16 15:44:04 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.119.02             Driver Version: 580.119.02     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   24C    P1             68W /  450W |    7520MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [15]:
outputs['probits']

tensor([[-0.6874, -1.6052, -0.8154, -0.1669,  0.2437, -0.0022, -0.3997, -0.2503,
          0.2770,  1.6369,  0.4130, -0.9659, -0.6538,  0.2339, -0.4236, -0.3005,
          0.6149, -1.0145,  0.0963, -0.8243, -0.0644]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [17]:
legal_actions = action_validator.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

In [18]:
legal_actions

tensor([[False, False,  True, False,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False]])