In [1]:
import torch
import numpy
from ml_modules import *
from sequence_modules import *
from llm_modules import *
from agent import *

from utils import *

from ml_ops_utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
softmax_prob = torch.nn.Softmax(dim=-1)

In [3]:
action_validator = PokerActionValidator(
    num_players = 2,
    small_blind = 1,
    big_blind = 2,
    starting_stack_sizes = 400
)

In [4]:
street_embedder_p1 = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

table_position_embedder_p1 = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

action_embedder_p1 = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

pot_size_embedder_p1 = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1,
    device = 'cuda'
)

poker_sequence_embedder_p1 = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cuda'
)
cards_p1 = Cards(device = 'cuda')

self_position_embedder_p1 = SelfPositionEmbedder(number_of_positions = 2, device = "cuda")


street_embedder_p2 = StreetPositionalEncoding(
    num_streets = 4,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

table_position_embedder_p2 = TablePositionalEncoding(
    num_players = 2,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

action_embedder_p2 = ActionEncoding(
    #num_actions = 21,
    embedding_dim = 256,
    max_seq_len = 1024,
    device = "cuda"
)

pot_size_embedder_p2 = PotSizeSequenceEmbedder(
    max_seq_len = 1024,
    pad_value = -1,
    device = 'cuda'
)

poker_sequence_embedder_p2 = PokerSequenceEmbedder(
    street_input_dimension = 256,
    table_position_input_dimension = 256,
    action_input_dimension = 256,
    latent_dimensions = [256, 512, 1024, 2048],
    device = 'cuda'
)

cards_p2 = Cards(device = 'cuda')

self_position_embedder_p2 = SelfPositionEmbedder(number_of_positions = 2, device = "cuda")

In [5]:
deck_order_shuffled = torch.argsort(torch.rand(1, 52))

In [6]:
model_name = "./models/qwen3-1point7b/"


tokenizer, model = load_model(model_name)

model

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.86it/s]
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2048,), eps=1e-06)
        (post_attention_layer

In [7]:
policy_model_p1 = PolicyModel(
    num_players = 2,
    self_position_embedder = self_position_embedder_p1,
    active_players_hidden_dims = [1024, 2048],
    stack_size_hidden_dims = [1024, 2048],
    card_embeddings_hidden_dims = [2048, 2048],
    final_output_hidden_dims = [1024, 512, 256],
    value_output_hidden_dims = [1024, 512, 256],
    device = 'cuda'
)

policy_model_p2 = PolicyModel(
    num_players = 2,
    self_position_embedder = self_position_embedder_p2,
    active_players_hidden_dims = [1024, 2048],
    stack_size_hidden_dims = [1024, 2048],
    card_embeddings_hidden_dims = [2048, 2048],
    final_output_hidden_dims = [1024, 512, 256],
    value_output_hidden_dims = [1024, 512, 256],
    device = 'cuda'
)

In [8]:
poker_player_p1 = PokerAgent(
    cards_p1,
    street_embedder_p1,
    table_position_embedder_p1,
    action_embedder_p1,
    pot_size_embedder_p1,
    poker_sequence_embedder_p1,
    model,
    policy_model_p1,
    device = 'cuda',
    llm_train = False
)

poker_player_p2 = PokerAgent(
    cards_p2,
    street_embedder_p2,
    table_position_embedder_p2,
    action_embedder_p2,
    pot_size_embedder_p2,
    poker_sequence_embedder_p2,
    model,
    policy_model_p2,
    device = 'cuda',
    llm_train = False
)

In [9]:
poker_player_p1

PokerAgent(
  (cards): Cards(
    (rank_embedder): Embedding(14, 1024)
    (suit_embedder): Embedding(5, 1024)
  )
  (street_embedder): StreetPositionalEncoding(
    (street_embedder): Embedding(7, 256)
  )
  (table_position_embedder): TablePositionalEncoding(
    (player_embedder): Embedding(4, 256)
  )
  (action_embedder): ActionEncoding(
    (action_embedder): Embedding(22, 256)
  )
  (pot_size_embedder): PotSizeSequenceEmbedder()
  (poker_sequence_embedder): PokerSequenceEmbedder(
    (street_MLP): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=512, bias=True)
      (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
      (6): Linear(in_features=512, out_features=1024, bias=True)
      (7): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (8): ReLU()
      (9): Linear(in_features=1024, o

In [10]:
street_idxs = (torch.zeros(1024) + 6).long().unsqueeze(0)
street_idxs[0, :2] = 0 # posting small blind and big blind

table_position_idxs = (torch.zeros(1024) + 2).long().unsqueeze(0)
table_position_idxs[0, 0] = 0 # sb/b
table_position_idxs[0, 1] = 1 # bb

action_idxs = (torch.zeros(1024) + 21).long().unsqueeze(0)
action_idxs[0 ,0] = 0 # post sb
action_idxs[0 ,1] = 1 # post bb

pot_size_sequence = (torch.zeros(1024) - 1).unsqueeze(0)
pot_size_sequence[0, 0] = 1
pot_size_sequence[0, 1] = 3


active_players = torch.Tensor([
    [1, 1]
])
print(active_players.shape)
stack_size = torch.Tensor(
    [
        [399, 398],
    ]
)
stack_size.shape

torch.Size([1, 2])


torch.Size([1, 2])

In [11]:
sb_cards = torch.zeros((1,2,7)).to(dtype=torch.long).to('cuda')
bb_cards = torch.zeros((1,2,7)).to(dtype=torch.long).to('cuda')

In [12]:
deck_order_shuffled[0, :2]

tensor([18, 23])

In [13]:
sb_cards[0,0,:2] = deck_order_shuffled[0, :2]%13
sb_cards[0,1,:2] = deck_order_shuffled[0, :2]//13

sb_cards[0,0,2:] = 13
sb_cards[0,1,2:] = 4


bb_cards[0,0,:2] = deck_order_shuffled[0, 2:4]%13
bb_cards[0,1,:2] = deck_order_shuffled[0, 2:4]//13

bb_cards[0,0,2:] = 13
bb_cards[0,1,2:] = 4

In [14]:
sb_cards

tensor([[[ 5, 10, 13, 13, 13, 13, 13],
         [ 1,  1,  4,  4,  4,  4,  4]]], device='cuda:0')

In [15]:
bb_cards

tensor([[[ 3,  1, 13, 13, 13, 13, 13],
         [ 2,  1,  4,  4,  4,  4,  4]]], device='cuda:0')

In [16]:
p1_position_player = torch.Tensor([0]).to('cuda')
p2_position_player = torch.Tensor([1]).to('cuda')

In [17]:
outputs = poker_player_p1(
    p1_position_player,
    sb_cards,
    street_idxs,
    table_position_idxs,
    action_idxs,
    pot_size_sequence,
    active_players.to('cuda'),
    stack_size.to('cuda')
)

In [18]:
outputs['probits']

tensor([[ 1.1178,  1.1956, -0.5090,  1.5286, -0.0287, -1.3182, -0.8210,  1.4246,
         -0.4450,  0.6004,  0.1377, -0.4218,  0.0942, -0.3271,  0.1113, -0.9271,
          1.5072,  1.3399, -0.5881,  0.3559, -0.3620]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [19]:
legal_actions = action_validator.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

In [20]:
legal_actions

tensor([[False, False,  True, False,  True, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True, False, False, False,
         False]])

In [21]:
outputs['probits'].shape

torch.Size([1, 21])

In [22]:
softmax_prob(outputs['probits'] - 1e9 * ((~legal_actions).float()).to('cuda'))

tensor([[0.0000, 0.0000, 0.0329, 0.0000, 0.0531, 0.0000, 0.0241, 0.2273, 0.0350,
         0.0997, 0.0628, 0.0359, 0.0601, 0.0394, 0.0611, 0.0216, 0.2469, 0.0000,
         0.0000, 0.0000, 0.0000]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [23]:
sampled_action = torch.distributions.Categorical(softmax_prob(outputs['probits'] - 1e9 * ((~legal_actions).float()).to('cuda'))).sample()

In [24]:
def get_min_bet_size_or_raise_size(
    pot_size_sequence,
    street_idxs,
    action_idxs,
    current_street,
):
    """
    Gets the minimum bet or raise size allowed in a position.
    Minimum allowable bet size is going to be 1 big blind or quarter pot.
    Minimum allowable raise size will be 1 big blinds or 1x the most recent increment, resulting in a bet size
    of either 2bb or 2 times the last increment.
    """
    
    current_street_idxs = street_idxs[street_idxs == current_street]

    if len(current_street_idxs) == 0:
        return max(2, max(pot_size_sequence)/4)

    current_street_actions = action_idxs[street_idxs == current_street]

    if ((current_street_actions == 3).int() + (current_street_actions == 2).int()).sum() != current_street_actions.shape[0]:
        where_max = torch.where(pot_size_sequence == max(pot_size_sequence[0]))
        raise_size = pot_size_sequence[where_max[0][0], where_max[1][0]] - pot_size_sequence[where_max[0][0], where_max[1][0] - 1]
        return max([4, 2*raise_size])

In [25]:
pot_size_sequence

tensor([[ 1.,  3., -1.,  ..., -1., -1., -1.]])

In [26]:
pot_size_sequence[0,1-1]

tensor(1.)

In [27]:
min_size = get_min_bet_size_or_raise_size(
    pot_size_sequence,
    street_idxs,
    action_idxs,
    0
)

In [28]:
if legal_actions[0, 5] == True:
    spacing = 16-5+1 # (number of sizes)
else:
    spacing = 16-6+1 # (number of sizes)

my_stack_size = stack_size[0, p1_position_player.int().to('cpu')].squeeze()

multiplicative_factor = (my_stack_size / min(min_size, my_stack_size)) ** (1/(spacing-1))

bet_sizes = torch.ones(spacing) * min(min_size, my_stack_size)
bet_sizes = torch.ceil(bet_sizes * (multiplicative_factor ** torch.arange(spacing)))

In [29]:
sampled_action.to('cpu')[0]

tensor(16)

In [30]:
if 5<= sampled_action <=16:
    print('is a bet')

is a bet


In [31]:
chosen_bet = bet_sizes[sampled_action.to('cpu')[0] - 17]

In [32]:
amount_to_call = chosen_bet.clone()

In [40]:
street_idxs[0, 2] = 0
table_position_idxs[0, 2] = p1_position_player.int().to('cpu')
action_idxs[0, 2] = sampled_action.to('cpu')[0]
pot_size_sequence[0, 2] = pot_size_sequence[0, 1] + bet_sizes[sampled_action.to('cpu')[0] - 17]
stack_size[0, p1_position_player.int().to('cpu')] -= bet_sizes[sampled_action.to('cpu')[0] - 17]


In [41]:
street_idxs[0,0:10]

tensor([0, 0, 0, 6, 6, 6, 6, 6, 6, 6])

In [42]:
table_position_idxs[0,0:10]

tensor([0, 1, 0, 2, 2, 2, 2, 2, 2, 2])

In [43]:
action_idxs[0,0:10]

tensor([ 0,  1, 16, 21, 21, 21, 21, 21, 21, 21])

In [44]:
pot_size_sequence[0,0:10]

tensor([  1.,   3., 402.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.])

In [45]:
print(stack_size)

tensor([[  0., 398.]])


In [46]:
outputs_p2 = poker_player_p2(
    p2_position_player,
    bb_cards,
    street_idxs,
    table_position_idxs,
    action_idxs,
    pot_size_sequence,
    active_players.to('cuda'),
    stack_size.to('cuda')
)

In [47]:
print(outputs['probits'])

legal_actions = action_validator.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

print(legal_actions)

print(softmax_prob(outputs['probits'] - 1e9 * ((~legal_actions).float()).to('cuda')))

sampled_action = torch.distributions.Categorical(
    softmax_prob(outputs['probits'] - 1e9 * ((~legal_actions).float()).to('cuda'))
).sample()

print(sampled_action)

min_size = get_min_bet_size_or_raise_size(
    pot_size_sequence,
    street_idxs,
    action_idxs,
    0
)

print(min_size)

tensor([[ 1.1178,  1.1956, -0.5090,  1.5286, -0.0287, -1.3182, -0.8210,  1.4246,
         -0.4450,  0.6004,  0.1377, -0.4218,  0.0942, -0.3271,  0.1113, -0.9271,
          1.5072,  1.3399, -0.5881,  0.3559, -0.3620]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[False, False,  True, False,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False]])
tensor([[0.0000, 0.0000, 0.3822, 0.0000, 0.6178, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([2], device='cuda:0')
tensor(798.)


In [56]:
street_idxs[0, 3] = 1
table_position_idxs[0, 3] = p2_position_player.int().to('cpu')
action_idxs[0, 3] = sampled_action.to('cpu')[0]

if sampled_action==2:
    amount_to_call = 0
else:  
    amount_to_call = my_stack_size - min(stack_size.ravel())

pot_size_sequence[0, 3] = pot_size_sequence[0, 2] + amount_to_call
stack_size[0, p2_position_player.int().to('cpu')] -= amount_to_call
active_players[0, p2_position_player.int().to('cpu')] = 0

In [57]:
print(street_idxs[0, 0:10])
print(table_position_idxs[0, 0:10])
print(action_idxs[0, 0:10])
print(pot_size_sequence[0, 0:10])
print(stack_size)

tensor([0, 0, 0, 1, 6, 6, 6, 6, 6, 6])
tensor([0, 1, 0, 1, 2, 2, 2, 2, 2, 2])
tensor([ 0,  1, 16,  2, 21, 21, 21, 21, 21, 21])
tensor([  1.,   3., 402., 402.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.])
tensor([[ 0., -1.]])


In [58]:
legal_actions = action_validator.get_legal_actions_mask(
    street_idxs,
    table_position_idxs,
    action_idxs, 
    pot_size_sequence,
    active_players
)

In [59]:
legal_actions

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True]])