In [1]:
import torch
from torch import nn
import math
import numpy as np

In [2]:
class _SoftmaxSkipGrad(torch.autograd.Function):
    """Softmax on the forward pass, identity on the backward pass.

    Autograd never calls ``nn.Module.backward``; a custom gradient only takes
    effect when it is expressed as an ``autograd.Function``.
    """

    @staticmethod
    def forward(ctx, x, dim):
        return torch.softmax(x, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # hand the upstream gradient straight to the logits; ``dim`` gets no gradient
        return grad_output, None


class grad_skip_softmax(nn.Module):
    """Softmax whose backward pass skips the softmax Jacobian.

    Forward returns ``softmax(x, dim)``. Backward passes the upstream gradient
    through unchanged, so the loss gradient lands directly on the logits
    (intended for neural replicator dynamics).

    Args:
        dim: axis to normalize over. Defaults to the last axis, i.e. the
            output-feature axis of a preceding ``nn.Linear``. Without an explicit
            dim, ``nn.Softmax`` picks axis 0 for 3-D inputs (the batch axis).
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim
        # kept for attribute compatibility; forward goes through _SoftmaxSkipGrad
        self.sm = nn.Softmax(dim=dim)

    def forward(self, x):
        return _SoftmaxSkipGrad.apply(x, self.dim)

    def backward(self, grad):
        """Identity gradient. Kept for compatibility only; autograd never calls it."""
        return grad
        
class gru(nn.Module):
    """GRU-type gating layer from the GTrXL paper.

    Merges a residual stream ``x`` with a sublayer output ``y``:

        r     = sigmoid(W_r y + U_r x)
        z     = sigmoid(W_z y + U_z x - b_g)
        h_hat = tanh(W_g y + U_g (r * x))
        g     = (1 - z) * x + z * h_hat

    Args:
        dim: feature width of both ``x`` and ``y``.
        b_g: constant subtracted inside the update gate. A positive value pushes
            ``z`` toward 0 at initialization, so the layer starts close to
            passing ``x`` through unchanged.
    """

    def __init__(self, dim, b_g = 1) -> None:
        super().__init__()

        # reset gate projections
        self.w_r = nn.Linear(dim, dim, bias = False)
        self.u_r = nn.Linear(dim, dim, bias = False)

        # update gate projections (the only ones with learned biases)
        self.w_z = nn.Linear(dim, dim, bias = True)
        self.u_z = nn.Linear(dim, dim, bias = True)
        self.b_g = b_g

        # candidate-state projections
        self.w_g = nn.Linear(dim, dim, bias = False)
        self.u_g = nn.Linear(dim, dim, bias = False)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, y):
        reset = self.sigmoid(self.w_r(y) + self.u_r(x))
        update = self.sigmoid(self.w_z(y) + self.u_z(x) - self.b_g)  # ~0 => keep x
        candidate = self.tanh(self.w_g(y) + self.u_g(reset * x))
        return (1-update)*x + update * candidate
        

class mlp(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    The same weights are applied to every token independently (equivalent to a
    kernel-size-1 temporal convolution), so tokens do not communicate here.
    There is deliberately no activation after the second projection.
    """

    def __init__(self, embed_dim, internal_dim) -> None:
        super().__init__()
        expand = nn.Linear(embed_dim, internal_dim)
        contract = nn.Linear(internal_dim, embed_dim)
        self.block = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, input):
        return self.block(input)


class cross_attention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``enc``.

    Args:
        embed_dimension: model width ``E``.
        num_heads: number of attention heads (must divide ``E``).
        batch_first: if True (default), 3-D inputs are ``(batch, seq, E)``.
            This matches the batch-first layout used elsewhere in this notebook.
            ``nn.MultiheadAttention`` on its own defaults to ``(seq, batch, E)``,
            which silently swapped the batch and sequence axes. Unbatched 2-D
            ``(seq, E)`` inputs behave the same either way.
    """

    def __init__(self, embed_dimension, num_heads, batch_first: bool = True) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dimension,
            num_heads=num_heads,
            batch_first=batch_first,
            )

    def forward(self, x, enc):
        """Attend from ``x`` (queries) over ``enc`` (keys and values); returns the outputs only."""
        out, _weights = self.attention(x, enc, enc)
        return out

class self_attention(nn.Module):
    """Unmasked multi-head self-attention (queries, keys and values all from ``x``).

    Args:
        embed_dimension: model width ``E``.
        num_heads: number of attention heads (must divide ``E``).
        batch_first: if True (default), 3-D inputs are ``(batch, seq, E)``.
            This matches the batch-first layout used elsewhere in this notebook.
            ``nn.MultiheadAttention`` on its own defaults to ``(seq, batch, E)``.
            Unbatched 2-D ``(seq, E)`` inputs behave the same either way.
    """

    def __init__(self, embed_dimension, num_heads, batch_first: bool = True) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dimension,
            num_heads=num_heads,
            batch_first=batch_first,
            )

    def forward(self, x):
        """Self-attend over ``x``; returns the outputs only."""
        out, _weights = self.attention(x, x, x)
        return out







In [3]:


class Smear_key(nn.Module):
    """Learned key smearing: blends each key with the key one position earlier.

    For positions t >= 1 the output is

        k'[t] = s * k[t] + (1 - s) * k[t-1],   s = sigmoid(alpha[t-1])

    with one learned ``alpha`` per head and position. Position 0 is passed
    through unchanged. Only earlier positions are mixed in, so the operation
    stays causal.

    Args:
        sequence_length: size of dim 2 of ``k``; fixes the size of ``alpha``.
        heads: size of dim 1 of ``k``.
    """

    def __init__(self,
    sequence_length,
    heads
    ) -> None:
        super().__init__()
        # one mixing logit per (head, position >= 1); sigmoid(1) ~= 0.73 at init
        self.alpha = nn.Parameter(torch.ones(1, heads, sequence_length - 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, k):
        """Smear a rank-4 ``k`` along dim 2 (positions)."""
        mix = self.sigmoid(self.alpha)
        current = k[:, :, 1:, :]
        previous = k[:, :, :-1, :]
        smeared = current * mix + previous * (1 - mix)
        first = k[:, :, :1, :]
        return torch.cat((first, smeared), dim=2)

class decoder_mha(nn.Module):
    """Causally masked multi-head self-attention with smeared keys.

    Input ``x`` is batch-first, shape (batch, sequence_length, model_dim). The
    sequence length must equal ``sequence_length`` exactly, because both the
    causal mask and ``Smear_key`` are sized for it. A 2-D
    (sequence_length, model_dim) input is treated as batch 1 and comes back 3-D.
    The output has shape (batch, sequence_length, model_dim).

    Args:
        model_dim: model width; must be divisible by ``heads``.
        sequence_length: exact number of tokens per input.
        heads: number of attention heads.

    Raises:
        ValueError: if ``model_dim`` is not divisible by ``heads``.
    """

    def __init__(self, model_dim, sequence_length, heads) -> None:
        super().__init__()
        if model_dim % heads != 0:
            raise ValueError(f"model_dim ({model_dim}) must be divisible by heads ({heads})")
        # (seq, seq) additive mask: -inf strictly above the diagonal blocks attention
        # to future positions; it broadcasts over (batch, heads). It is a
        # non-persistent buffer, so .to(device) moves it with the module without
        # adding it to the state_dict.
        mask = torch.triu(torch.full((sequence_length, sequence_length), float('-inf')), diagonal=1)
        self.register_buffer('mask', mask, persistent=False)
        self.model_dim = model_dim
        self.sequence_length = sequence_length
        self.heads = heads
        self.key_dim = model_dim // heads
        self.W_q = nn.Linear(model_dim, model_dim, bias=False)
        self.W_k = nn.Linear(model_dim, model_dim, bias=False)
        self.W_v = nn.Linear(model_dim, model_dim, bias=False)
        self.output = nn.Linear(model_dim, model_dim, bias=True)
        # NOTE: not used in forward; kept so existing state_dicts still load
        self.ln = nn.LayerNorm(model_dim)
        self.smear = Smear_key(sequence_length, heads)

    def _split_heads(self, t):
        # (batch, seq, model_dim) -> (batch, heads, seq, key_dim)
        return t.view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1, 2)

    def forward(self, x):
        q = self._split_heads(self.W_q(x))
        k = self.smear(self._split_heads(self.W_k(x)))
        v = self._split_heads(self.W_v(x))
        scores = q @ k.transpose(2, 3) / k.size(-1) ** .5
        scores = scores + self.mask  # out-of-place; mask lives on the module's device
        attn = torch.softmax(scores, dim=-1)
        # (batch, heads, seq, key_dim) -> (batch, seq, model_dim)
        mha = (attn @ v).transpose(1, 2).contiguous().view(-1, self.sequence_length, self.model_dim)
        return self.output(mha)





In [4]:
class encoder_layer(nn.Module):
    """One GTrXL-style encoder layer: gated pre-norm self-attention, then a gated pre-norm MLP.

    The layer is unmasked, with no cross attention and no memory. Each sublayer computes
    ``x = gate(x, relu(sublayer(layer_norm(x))))`` and has its own LayerNorm
    and GRU gate.

    Args:
        embed_dim: model width.
        mlp_dim: hidden width of the position-wise MLP.
        attention_heads: number of self-attention heads.
        sequence_length: not used by this layer; accepted for signature parity.
    """
    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_length
    ) -> None:
        super().__init__()

        self.mha = self_attention(
            embed_dimension=embed_dim,
            num_heads=attention_heads
        )

        self.mlp = mlp(
            embed_dim= embed_dim,
            internal_dim=mlp_dim
        )

        self.gate1 = gru(
            dim = embed_dim
        )
        self.gate2 = gru(
            dim = embed_dim
        )

        self.ln1 = nn.LayerNorm(embed_dim)  # pre-norm for self-attention
        self.ln2 = nn.LayerNorm(embed_dim)  # pre-norm for the MLP

        self.activation = nn.ReLU()

    def forward(self, x):
        # self-attention sublayer
        y = self.mha(self.ln1(x))
        x = self.gate1(x, self.activation(y))

        # MLP sublayer: uses its own norm (it previously reused ln1 and left ln2 unused)
        y = self.mlp(self.ln2(x))
        x = self.gate2(x, self.activation(y))

        return x

class decoder_layer(nn.Module):
    """One GTrXL-style decoder layer with three gated, pre-normed sublayers.

    The sublayers are, in order:
      1. masked self-attention with smeared keys,
      2. cross attention (queries from the decoder stream, keys/values from ``enc``),
      3. a position-wise MLP.
    Each sublayer has its own LayerNorm and GRU gate. The encoder memory is
    normalized with ``self.ln``.

    Args:
        embed_dim: model width.
        mlp_dim: hidden width of the MLP.
        attention_heads: number of heads for both attention sublayers.
        sequence_lenth: (sic) exact decoder sequence length. The misspelled name
            is kept because callers pass it by keyword.
    """
    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_lenth
    ) -> None:
        super().__init__()

        self.mha = decoder_mha(
            model_dim=embed_dim,
            sequence_length=sequence_lenth,
            heads=attention_heads
        ) #smeared key masked self attention

        self.cross_mha = cross_attention(
            embed_dimension = embed_dim,
            num_heads = attention_heads,
        )

        self.mlp = mlp(
            embed_dim = embed_dim,
            internal_dim = mlp_dim
        )

        self.gate1 = gru(
            dim = embed_dim
        )
        self.gate2 = gru(
            dim = embed_dim
        )
        self.gate3 = gru(
            dim = embed_dim
        )

        self.ln = nn.LayerNorm(embed_dim)  # normalizes the encoder memory

        self.activation = nn.ReLU()
        self.ln1 = nn.LayerNorm(embed_dim)  # pre-norm: masked self-attention
        self.ln2 = nn.LayerNorm(embed_dim)  # pre-norm: cross-attention query
        self.ln3 = nn.LayerNorm(embed_dim)  # pre-norm: MLP

    def forward(self, x, enc):
        # masked self attention, smeared key
        y = self.mha(self.ln1(x))
        x = self.gate1(x, self.activation(y))

        # cross attention: the decoder stream queries the encoder output
        # (arguments were previously swapped, making enc the query)
        y = self.cross_mha(self.ln2(x), self.ln(enc))
        x = self.gate2(x, self.activation(y))

        # position-wise multi layer perceptron, with its own gate
        y = self.mlp(self.ln3(x))
        x = self.gate3(x, self.activation(y))

        return x

In [5]:
class encoder(nn.Module):
    """Stack of ``layers`` encoder_layer modules applied in sequence.

    No positional encoding or other inductive bias is added here.
    """
    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()
        stack = [
            encoder_layer(
                embed_dim = model_dim,
                mlp_dim = mlp_dim,
                attention_heads = heads,
                sequence_length = sequence_length
            )
            for _ in range(layers)
        ]
        self.block = nn.Sequential(*stack)

    def forward(self, x):
        return self.block(x)

class decoder(nn.Module):
    """Decoder stack: sinusoidal positional encoding followed by ``layers`` decoder_layers.

    Args:
        layers: number of decoder layers.
        model_dim: model width.
        mlp_dim: hidden width of each layer's MLP.
        heads: attention heads per layer.
        sequence_length: exact decoder sequence length (the masked attention is sized for it).
    """
    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()

        self.pe = positional_encoding(
            model_dim=model_dim,
            sequence_length=sequence_length
            )

        # nn.ModuleList rather than a plain list, so the layers are registered:
        # their parameters appear in .parameters() and the state_dict, and
        # .to(device) moves them.
        self.block = nn.ModuleList(
            decoder_layer(
                embed_dim = model_dim,
                mlp_dim= mlp_dim,
                attention_heads= heads,
                sequence_lenth = sequence_length
            )
            for _ in range(layers)
        )

    def forward(self, x, y):
        """Decode ``x`` while cross-attending to ``y``, the encoder output."""
        x = self.pe(x)
        for layer in self.block:
            x = layer(x, y)
        return x

        

In [6]:
# positional encoding class drawn largely from tutorial on pytorch website
class positional_encoding(nn.Module):
    """Adds fixed sinusoidal position encodings (adapted from the PyTorch tutorial).

    Position ``p`` and channel pair ``(2i, 2i+1)`` receive
    ``sin(p * w_i)`` and ``cos(p * w_i)``, with ``w_i = 10000^(-2i / model_dim)``.

    Args:
        model_dim: feature width ``D``. Odd widths are supported; the last
            channel then only gets a sine.
        sequence_length: maximum number of positions supported.
    """
    def __init__(self,
    model_dim,
    sequence_length
    ) -> None:
        super().__init__()

        position = torch.arange(sequence_length).unsqueeze(1)
        freq = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(1, sequence_length, model_dim)
        pe[0, :, 0::2] = torch.sin(position * freq)
        # for odd model_dim there is one fewer cosine channel than sine channel
        pe[0, :, 1::2] = torch.cos(position * freq[: model_dim // 2])
        # shape (1, sequence_length, model_dim); a buffer, so it follows .to(device)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings to ``x`` whose second-to-last axis is the sequence axis.

        Works for both (batch, seq, D) and unbatched (seq, D) inputs.

        Raises:
            ValueError: if the input is longer than ``sequence_length``.
        """
        # Slice the position axis of pe. The previous pe[:x.size(1)] indexed the
        # size-1 leading axis, so it never trimmed pe to the input length.
        seq_len = x.size(-2)
        if seq_len > self.pe.size(1):
            raise ValueError(f"input length {seq_len} exceeds maximum sequence length {self.pe.size(1)}")
        return x + self.pe[0, :seq_len]


In [7]:
class RLformer(nn.Module):
    """Encoder-decoder transformer with an actor (policy) head and a critic (value) head.

    The encoder reads ``enc_input``; ``Agent`` passes the player's stored hand
    encoding here. The decoder reads ``dec_input``, the observation sequence,
    and cross-attends to the encoder output. Both heads are applied at every
    decoder position:
      - policy: (..., action_dim) probabilities from ``grad_skip_softmax``,
      - value:  (..., 1).

    Args:
        model_dim: model width.
        mlp_dim: hidden width of the MLPs and of the critic head.
        attn_heads: attention heads per layer.
        sequence_length: decoder sequence length.
        enc_layers: number of encoder layers.
        dec_layers: number of decoder layers.
        action_dim: size of the action space.
    """

    def __init__(self,
    model_dim,
    mlp_dim,
    attn_heads,
    sequence_length,
    enc_layers,
    dec_layers,
    action_dim
    ) -> None:
        super().__init__()

        # Not applied in forward: the decoder already adds its own positional
        # encoding. Kept so existing state_dicts (which contain its buffer) still load.
        self.positional_encoder = positional_encoding(
            model_dim=model_dim,
            sequence_length = sequence_length
        )

        self.encoder = encoder(
            layers=enc_layers,
            model_dim=model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length = sequence_length
        )

        self.decoder = decoder(
            layers=dec_layers,
            model_dim= model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length=sequence_length,
        )

        # NOTE(review): the ReLU clamps logits at >= 0 before the softmax, so no
        # action's logit can go below 0 -- confirm this is intended.
        self.actor = nn.Sequential(
            nn.Linear(model_dim, action_dim),
            nn.ReLU(),
            grad_skip_softmax() # To do neural replicator dynamics
        )

        self.critic = nn.Sequential(
            nn.Linear(model_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, 1)
        )

    def forward(self, enc_input, dec_input):
        """Return ``(policy, value)`` for every decoder position."""
        enc = self.encoder(enc_input)
        # no positional encoding here: decoder.forward adds it (it was applied twice)
        dec = self.decoder(dec_input, enc)
        policy = self.actor(dec)
        value = self.critic(dec)
        return policy, value

In [8]:
class Agent(nn.Module):
    """Wraps an RLformer together with the tokenizers that turn raw game data into model inputs.

    ``hand_tokenizer`` and ``seq_tokenizer`` start as ``None`` placeholders.
    They must be set to callables that map a hand, or a flat list of
    observations, to a tensor before ``init_player`` / ``forward`` are used.
    """
    def __init__(self,
        model_dim,
        mlp_dim,
        attn_heads,
        sequence_length,
        enc_layers,
        dec_layers,
        action_dim,
    ) -> None:
        super().__init__()
        self.model = RLformer(
            model_dim = model_dim,
            mlp_dim = mlp_dim,
            attn_heads = attn_heads,
            sequence_length = sequence_length,
            enc_layers = enc_layers,
            dec_layers = dec_layers,
            action_dim = action_dim,
        )

        self.hand_tokenizer = None # hand tokenizer HERE
        self.seq_tokenizer = None # sequence tokenizer HERE

    def _require_tokenizer(self, name):
        # fetch a tokenizer attribute, failing clearly if it was never assigned
        tokenizer = getattr(self, name)
        if tokenizer is None:
            raise RuntimeError(f"Agent.{name} is not set; assign a tokenizer before use")
        return tokenizer

    def init_player(self, player, hand):
        """Tokenize ``player``'s hand and store it as the buffer ``hand_{player}``.

        Registering the same name again replaces the previous hand.

        Raises:
            RuntimeError: if ``hand_tokenizer`` has not been set.
        """
        hand_tensor = self._require_tokenizer('hand_tokenizer')(hand)
        self.register_buffer(f'hand_{player}', tensor= hand_tensor)

    def forward(self, player, obs_flat):
        """Run the model for ``player`` on a flat, untokenized observation list.

        Returns:
            ``(policy, value)`` from the underlying RLformer.

        Raises:
            RuntimeError: if ``seq_tokenizer`` has not been set.
        """
        enc_input = self.get_buffer(f'hand_{player}')
        dec_input = self._require_tokenizer('seq_tokenizer')(obs_flat)
        policy, value = self.model(enc_input, dec_input)

        return policy, value
        

In [9]:
from pokerenv import poker_env

In [10]:
from itertools import chain

class actor_critic():
    """Self-play actor-critic driver: plays poker hands with ``Agent`` and builds the loss.

    History is kept per hand as lists of lists (``observations``, ``rewards``,
    ``values``, ``action_log_probabilies``). Flattened views (``obs_flat``,
    ``rewards_flat``, ``val_flat``, ``alp_flat``) are rebuilt by ``chop_seq``.

    Args:
        model_dim, mlp_dim, heads, enc_layers, dec_layers: forwarded to ``Agent``.
        max_sequence: model sequence length. The oldest hands are dropped once
            the flattened observation history exceeds it.
        n_players: number of players at the table.
        gamma: discount factor for the returns.
        n_actions: size of the policy's action space.
    """
    #Needs to be able to run hand, return loss with grad enabled
    def __init__(self,
    model_dim,
    mlp_dim,
    heads,
    enc_layers,
    dec_layers,
    max_sequence: int = 200,
    n_players: int = 2,
    gamma: float = .8,
    n_actions: int = 10, # random placeholder value
    ) -> None:
        self.gamma = gamma
        self.env = poker_env(n_players = n_players)
        self.agent = Agent(
            model_dim = model_dim,
            mlp_dim = mlp_dim,
            attn_heads = heads,  # Agent's keyword is attn_heads (was misspelled attn_head)
            sequence_length = max_sequence,
            enc_layers = enc_layers,
            dec_layers = dec_layers,
            action_dim = n_actions,
        )

        # per-hand histories: each entry is the list for one hand
        self.observations = []
        self.rewards = []
        self.values = []
        self.action_log_probabilies = []

        self.max_sequence = max_sequence
        self.n_players = n_players
        self.n_actions = n_actions

        self.detokenize = None #detokenizer HERE

        self._reflatten()

    def _reflatten(self):
        # rebuild the flat views from the per-hand histories
        self.obs_flat = list(chain.from_iterable(self.observations))
        self.rewards_flat = list(chain.from_iterable(self.rewards))
        self.val_flat = list(chain.from_iterable(self.values))
        self.alp_flat = list(chain.from_iterable(self.action_log_probabilies))

    def _acting_player(self):
        """Return the id of the player to act next.

        TODO: depends on the poker_env API, which is not visible here.
        Agent.forward needs this id to select the stored hand encoding.
        """
        raise NotImplementedError("acting-player lookup must be wired to poker_env")

    def init_hands(self):
        """Fetch every player's hand from the env and store its encoding in the agent."""
        for player in range(self.n_players):
            hand = self.env.get_hand(player)
            self.agent.init_player(player, hand)

    def chop_seq(self):
        """Trim history to fit the model, then rebuild the flat views.

        Drops the oldest hands until the total number of observations fits in
        ``max_sequence``. The current hand is never dropped.
        """
        # Previously compared the *number of hands* to the token limit, and
        # flattened rewards_flat instead of rewards.
        histories = (self.observations, self.rewards, self.values, self.action_log_probabilies)
        while len(self.observations) > 1 and sum(map(len, self.observations)) > self.max_sequence:
            for history in histories:
                del history[:1]
        self._reflatten()

    def play_hand(self):
        """Play one hand with the agent and return that hand's loss (with grad)."""
        rewards, observations = self.env.new_hand() # start a new hand
        self.init_hands() # pre load all of the hands

        # start this hand's histories (copies, so env-owned lists are not mutated)
        # TODO(review): if new_hand returns non-empty rewards, rewards[-1] will be
        # longer than values[-1] and get_loss will reject the hand.
        self.observations.append(list(observations))
        self.rewards.append(list(rewards))
        self.values.append([])
        self.action_log_probabilies.append([])

        self.chop_seq() # prepare for input to model

        player = None
        hand_over = False
        while not hand_over:
            player = self._acting_player()
            policy, values = self.agent(player, self.obs_flat)
            # outputs are per position; the decision is taken at the last position
            probs = policy.reshape(-1, self.n_actions)[-1]
            value = values.reshape(-1)[-1]  # keep the graph: the critic is trained on it

            dist = probs.detach().cpu().numpy().astype(np.float64)
            dist /= dist.sum()  # float32 rounding can trip np.random.choice's sum-to-1 check
            action_idx = int(np.random.choice(self.n_actions, p=dist))
            # log-prob of the sampled index (must be taken before detokenizing)
            alp = torch.log(probs[action_idx])

            # UNFINISHED: Need to detokenize actions HERE
            action = self.detokenize(action_idx)
            reward, obs, hand_over = self.env.take_action(action) # need to change environment to return hand_over boolean

            # NOTE(review): the demo cells show take_action returning lists of
            # rewards/observations; if that remains true, use extend() here.
            self.rewards[-1].append(reward)
            self.observations[-1].append(obs)
            self.values[-1].append(value)
            self.action_log_probabilies[-1].append(alp)

            # prepare for next action
            self.chop_seq()

        # bootstrap target from the critic (the agent returns (policy, value))
        _, final_values = self.agent(player, self.obs_flat)
        V_T = final_values.reshape(-1)[-1].detach()

        # loss over this hand only: earlier hands' graphs may already be freed
        return self.get_loss(self.values[-1], self.rewards[-1], V_T, self.action_log_probabilies[-1])

    def get_loss(self, values, rewards, V_T, alps=None):
        """Actor-critic loss for one trajectory.

        Q_t = r_t + gamma * Q_{t+1}, bootstrapped from Q_T = V_T
        advantage_t = Q_t - V(s_t)
        actor loss  = mean(-log pi(a_t | s_t) * advantage_t)   (advantage detached)
        critic loss = 0.5 * mean(advantage_t ** 2)

        Args:
            values: per-step value estimates (scalar tensors with grad, or numbers).
            rewards: per-step rewards; same length as ``values``.
            V_T: bootstrap value after the last step (number or one-element tensor).
            alps: per-step action log-probabilities; defaults to ``self.alp_flat``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if the trajectory is empty or the lengths disagree.
        """
        if alps is None:
            alps = self.alp_flat
        if not (len(values) == len(rewards) == len(alps)):
            raise ValueError(
                f"length mismatch: {len(values)} values, {len(rewards)} rewards, {len(alps)} log-probs"
            )
        if not rewards:
            raise ValueError("cannot compute a loss for an empty trajectory")

        Q_t = float(V_T)
        Qs = [0.0] * len(rewards)
        for t in reversed(range(len(rewards))):
            Q_t = float(rewards[t]) + self.gamma * Q_t
            Qs[t] = Q_t

        Qs = torch.tensor(Qs)
        values = torch.stack([torch.as_tensor(v, dtype=torch.float32).reshape(()) for v in values])
        alps = torch.stack([a.reshape(()) for a in alps])
        advantages = Qs - values

        # the actor treats the advantage as a constant; only the critic learns from it
        actor_loss = (-alps * advantages.detach()).mean()
        critic_loss = 0.5 * advantages.pow(2).mean() # MSE critic loss
        loss = actor_loss + critic_loss # no entropy in this since that would deviate from deepnash
        return loss
    

        

In [18]:
from tokenizer import tokenize

class Tokenizer(nn.Module):
    """Embeds a list of raw observations as a (sequence, model_dim) tensor.

    Each observation is turned into a vector by ``tokenize`` (from the project's
    ``tokenizer`` module), then projected with Linear + ReLU.

    Args:
        model_dim: output embedding width.
        token_dim: length of the vectors ``tokenize`` returns (36 by default).
    """

    def __init__(self, model_dim, token_dim: int = 36) -> None:
        super().__init__()
        self.token_dim = token_dim
        self.embedding = nn.Sequential(
            nn.Linear(token_dim, model_dim),
            nn.ReLU() # allows feature superposition in embedding
        )

    def tokenize_list(self, observations):
        """Stack ``tokenize(obs)`` for every observation into a (sequence, token_dim) tensor."""
        tokens = [tokenize(obs) for obs in observations]
        if not tokens:
            # torch.stack rejects an empty list; return an empty sequence instead
            return torch.zeros(0, self.token_dim)
        return torch.stack(tokens)

    def forward(self, observations):
        """Return embeddings of shape (sequence, model_dim)."""
        return self.embedding(self.tokenize_list(observations))

In [25]:
DEMO_N_PLAYERS = 6  # table size for the interactive checks below
env = poker_env(n_players=DEMO_N_PLAYERS)

In [31]:
# deal a fresh hand and show the observations returned with it
hand_start = env.new_hand()
rewards, obs = hand_start
print(obs)

[{'player': 3, 'type': 'bet', 'value': 1, 'pot': 0, 'p1': 200, 'p2': 198, 'p3': 199, 'p4': 198, 'p5': 200, 'p6': 200, '6': 200}, {'player': 4, 'type': 'bet', 'value': 2, 'pot': 1, 'p1': 200, 'p2': 198, 'p3': 199, 'p4': 197, 'p5': 200, 'p6': 200, '6': 200}]


In [27]:
# a hand-written action in the env's dict format
action = dict(player=1, type='bet', value=2)
rewards_1, obs_1 = env.take_action(action)
# extend the running rewards/observations in place with this step's output
rewards += rewards_1
obs += obs_1

In [28]:
DEMO_MODEL_DIM = 64  # embedding width for the tokenizer demo
mytok = Tokenizer(model_dim=DEMO_MODEL_DIM)

In [29]:
sequence = mytok(observations=obs)

In [30]:
sequence.shape

torch.Size([3, 64])