In [1]:
import torch
import numpy as np
import typing

from ml_modules import *
from utils import *

In [2]:
# Hyperparameters shared by all three encoders, kept in one place so they
# stay in sync.
EMBEDDING_DIM = 256
MAX_SEQ_LEN = 128
DEVICE = "cpu"

street_embedder = StreetPositionalEncoding(
    num_streets=4,
    embedding_dim=EMBEDDING_DIM,
    max_seq_len=MAX_SEQ_LEN,
    device=DEVICE,
)

table_position_embedder = TablePositionalEncoding(
    num_players=2,
    embedding_dim=EMBEDDING_DIM,
    max_seq_len=MAX_SEQ_LEN,
    device=DEVICE,
)

# NOTE(review): num_actions is left to ActionEncoding's default; the action
# ids used in the example data go up to 21 — confirm the default covers them.
action_embedder = ActionEncoding(
    embedding_dim=EMBEDDING_DIM,
    max_seq_len=MAX_SEQ_LEN,
    device=DEVICE,
)

In [3]:
# Rule-based action validator for a two-player game with blinds 1/2 and
# starting stacks of 400. (Named `test` because later cells refer to it.)
test = PokerActionValidator(
    num_players=2, small_blind=1, big_blind=2, starting_stack_sizes=400,
)

In [4]:
# Hand-written mini-batch: 11 example sequences of 8 steps each.
# Tensors are built with torch.tensor at the default floating dtype, which is
# what the legacy torch.Tensor(list) constructor produces.
# NOTE(review): trailing entries look like padding (street 6, table position 2,
# action 21, pot -1) — confirm these ids against ml_modules.
street_idxs = torch.tensor(
    [
        [0, 0, 6, 6, 6, 6, 6, 6],
        [0, 0, 0, 6, 6, 6, 6, 6],
        [0, 0, 0, 0, 6, 6, 6, 6],
        [0, 0, 0, 0, 6, 6, 6, 6],
        [0, 0, 0, 0, 6, 6, 6, 6],
        [0, 0, 0, 0, 6, 6, 6, 6],
        [0, 0, 0, 0, 4, 6, 6, 6],
        [0, 0, 0, 0, 4, 0, 0, 6],
        [0, 0, 0, 0, 4, 0, 0, 6],
        [0, 0, 0, 0, 4, 0, 0, 6],
        [0, 0, 0, 0, 4, 0, 0, 6],
    ],
    dtype=torch.get_default_dtype(),
)

table_position_idxs = torch.tensor(
    [
        [0, 1, 2, 2, 2, 2, 2, 2],
        [0, 1, 0, 2, 2, 2, 2, 2],
        [0, 1, 0, 1, 2, 2, 2, 2],
        [0, 1, 0, 1, 2, 2, 2, 2],
        [0, 1, 0, 1, 2, 2, 2, 2],
        [0, 1, 0, 1, 2, 2, 2, 2],
        [0, 1, 0, 1, 2, 2, 2, 2],
        [0, 1, 0, 1, 2, 0, 1, 2],
        [0, 1, 0, 1, 2, 0, 1, 2],
        [0, 1, 0, 1, 2, 0, 1, 2],
        [0, 1, 0, 1, 2, 0, 1, 2],
    ],
    dtype=torch.get_default_dtype(),
)

action_idxs = torch.tensor(
    [
        [0, 1, 21, 21, 21, 21, 21, 21],
        [0, 1, 4, 21, 21, 21, 21, 21],
        [0, 1, 5, 4, 21, 21, 21, 21],
        [0, 1, 4, 3, 21, 21, 21, 21],
        [0, 1, 5, 5, 21, 21, 21, 21],
        [0, 1, 5, 2, 21, 21, 21, 21],
        [0, 1, 5, 4, 19, 21, 21, 21],
        [0, 1, 5, 4, 19, 3, 3, 21],
        [0, 1, 5, 4, 19, 5, 2, 21],
        [0, 1, 5, 4, 19, 5, 4, 21],
        [0, 1, 5, 4, 19, 3, 5, 21],
    ],
    dtype=torch.get_default_dtype(),
)

# Pot size after each step; -1 marks steps past the end of the sequence.
pot_size_sequence = torch.tensor(
    [
        [1, 3, -1, -1, -1, -1, -1, -1],
        [1, 3, 5, -1, -1, -1, -1, -1],
        [1, 3, 7, 8, -1, -1, -1, -1],
        [1, 3, 5, 5, -1, -1, -1, -1],
        [1, 3, 7, 15, -1, -1, -1, -1],
        [1, 3, 7, 7, -1, -1, -1, -1],
        [1, 3, 7, 8, 8, -1, -1, -1],
        [1, 3, 7, 8, 8, 8, 8, -1],
        [1, 3, 7, 8, 8, 10, 10, -1],
        [1, 3, 7, 8, 8, 10, 12, -1],
        [1, 3, 7, 8, 8, 8, 10, -1],
    ],
    dtype=torch.get_default_dtype(),
)

# One column per player; presumably 1 = still active, 0 = folded (TODO confirm).
active_players = torch.tensor(
    [
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 0],
        [1, 1],
        [1, 1],
        [1, 0],
        [1, 1],
        [1, 1],
    ],
    dtype=torch.get_default_dtype(),
)

In [5]:
def _embedding_dict(embedder, idxs):
    """Run ``embedder`` on ``idxs`` exactly once and package its outputs.

    Returns a dict with ``'idxs'`` = first output and ``'embeddings'`` =
    second output of the embedder call. Calling the embedder once (instead of
    once per key) avoids a redundant forward pass and guarantees both entries
    come from the same call.
    """
    # NOTE(review): outputs[0] is assumed to be the (possibly padded) index
    # tensor and outputs[1] the embeddings — confirm against ml_modules.
    outputs = embedder(idxs)
    return {'idxs': outputs[0], 'embeddings': outputs[1]}


street_embedding = _embedding_dict(street_embedder, street_idxs)
table_positions_embedding = _embedding_dict(table_position_embedder, table_position_idxs)
actions_embedding = _embedding_dict(action_embedder, action_idxs)



In [6]:
# Expected (batch, max_seq_len, embedding_dim), i.e. (11, 128, 256) here.
street_embedding['embeddings'].shape

torch.Size([11, 128, 256])

In [7]:
# Inspect the street encoder's underlying embedding table.
# NOTE(review): it has 7 rows although num_streets=4 — presumably extra rows
# for special/padding ids (street id 6 appears as padding in the data); confirm.
street_embedder.street_embedder

Embedding(7, 256)

In [8]:
# Raw index tensor is (11, 8); the embedder output above has length 128, so the
# embedder presumably pads sequences up to max_seq_len.
street_idxs.shape


torch.Size([11, 8])

In [9]:
# Legal-action masks computed directly from the raw 8-step input tensors.
legal_actions = test.get_legal_actions_mask(
    street_idxs, table_position_idxs, action_idxs, pot_size_sequence, active_players
)

In [10]:
# Mask for the last example sequence: one boolean per action id.
legal_actions[-1]

tensor([False, False,  True, False,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False, False,
        False])

In [11]:
# Pad the pot-size sequences (defined with the other example inputs above)
# with -1 up to the embedders' output sequence length (128 per the shape
# printed above), so they line up with the embedder-returned index tensors
# used in the next cell. Sizes are derived rather than hardcoded.
target_seq_len = street_embedding['embeddings'].shape[1]
batch_size, seq_len = pot_size_sequence.shape
if seq_len > target_seq_len:
    raise ValueError(
        f"pot_size_sequence has {seq_len} steps but the embedders produce "
        f"{target_seq_len}; cannot pad to a shorter length"
    )

expanded_ps_sequence = torch.full(
    (batch_size, target_seq_len),
    fill_value=-1,  # same end-of-sequence marker used in the raw pot sizes
    dtype=torch.float,
    device=pot_size_sequence.device,
)
expanded_ps_sequence[:, :seq_len] = pot_size_sequence

In [12]:
# Same query, but using the index tensors returned by the embedders and the
# pot sizes padded to the embedders' sequence length.
legal_actions = test.get_legal_actions_mask(
    street_embedding['idxs'],
    table_positions_embedding['idxs'],
    actions_embedding['idxs'],
    expanded_ps_sequence,
    active_players,
)

In [13]:
# Mask for the second-to-last example, computed from the embedder-returned
# indices and padded pot sizes.
legal_actions[-2]

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False])

In [None]:
class PokerSequenceEmbedder(torch.nn.Module):
    """
    Projects per-timestep poker features into a shared latent space
    suitable as input to a Transformer.

    Each feature stream (street, table position, action, pot size) has its own
    MLP whose output width is ``latent_dimensions[-1]``; the four projections
    are summed element-wise.

    Parameters:
        street_input_dimension: feature size D_street of the street embedding.
        table_position_input_dimension: feature size D_position of the
            table-position embedding.
        action_input_dimension: feature size D_action of the action embedding.
        latent_dimensions: output widths of the successive MLP layers; the
            last entry is the shared latent size D_latent. Must be non-empty.
        device: device the module's parameters are moved to.

    Raises:
        ValueError: if ``latent_dimensions`` is empty.
    """

    def __init__(
        self,
        street_input_dimension: int,
        table_position_input_dimension: int,
        action_input_dimension: int,
        latent_dimensions: typing.List[int],
        device: str,
    ):
        super().__init__()

        # Copy so later mutation of the caller's list cannot affect this module.
        latent_dimensions = list(latent_dimensions)
        if not latent_dimensions:
            # An empty list would make every MLP an identity map, so forward()
            # would try to sum raw inputs of differing widths.
            raise ValueError("latent_dimensions must contain at least one layer width")

        self.device = device

        def make_mlp(input_dim: int) -> torch.nn.Sequential:
            # Linear -> LayerNorm per layer, with ReLU on every layer except
            # the last, so the output is normalised but not rectified.
            dims = [input_dim] + latent_dimensions
            layers = []
            for i in range(len(dims) - 1):
                layers.append(torch.nn.Linear(dims[i], dims[i + 1]))
                layers.append(torch.nn.LayerNorm(dims[i + 1]))
                if i < len(dims) - 2:  # no ReLU on final layer
                    layers.append(torch.nn.ReLU())
            return torch.nn.Sequential(*layers)

        self.street_MLP = make_mlp(street_input_dimension)
        self.table_position_MLP = make_mlp(table_position_input_dimension)
        self.action_MLP = make_mlp(action_input_dimension)
        self.pot_MLP = make_mlp(1)  # pot size is a single scalar feature

        self.to(device)

    def forward(self, model_inputs: dict) -> torch.Tensor:
        """
        Parameters:
            model_inputs : dict with keys
                street_embedding: Tensor (B, T, D_street).
                table_position_embedding: Tensor (B, T, D_position).
                action_embedding: Tensor (B, T, D_action).
                pot_size_sequence: Tensor (B, T, 1), or (B, T), which is
                    treated as one pot feature per timestep.

        Returns:
            Tensor (B, T, D_latent): element-wise sum of the four projected
            feature streams. ``model_inputs`` is not modified.
        """

        street = model_inputs["street_embedding"]
        position = model_inputs["table_position_embedding"]
        action = model_inputs["action_embedding"]
        pot = model_inputs["pot_size_sequence"]

        # pot_MLP expects a trailing feature axis of size 1; add it when the
        # pot tensor lacks it (e.g. a (B, T) padded pot-size tensor).
        if pot.dim() == street.dim() - 1:
            pot = pot.unsqueeze(-1)

        street_latent = self.street_MLP(street)
        position_latent = self.table_position_MLP(position)
        action_latent = self.action_MLP(action)
        pot_latent = self.pot_MLP(pot)

        return street_latent + position_latent + action_latent + pot_latent
