In [88]:
import torch
from torch import nn
import math

In [412]:
class _softmax_identity_grad(torch.autograd.Function):
    """Softmax on the forward pass, identity on the backward pass."""

    @staticmethod
    def forward(ctx, x, dim):
        return torch.softmax(x, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: the incoming gradient is passed
        # straight through to `x`; `dim` is a plain int and gets no gradient.
        return grad_output, None


class grad_skip_softmax(nn.Module):
    """Softmax whose backward pass lets the gradient through unchanged.

    The previous implementation defined a ``backward`` method on the module,
    but autograd never calls ``nn.Module.backward``, so the gradient was
    still the regular softmax Jacobian. The skip is now implemented with a
    custom ``torch.autograd.Function``.

    The softmax axis is also explicit now. ``nn.Softmax()`` without ``dim``
    is deprecated (it emits a UserWarning) and silently picks dim 0 for 3-D
    input, which normalises across the first axis instead of the feature
    axis.

    Args:
        dim: axis along which the softmax is computed. Defaults to the last
            axis.
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim
        # Kept so existing code that reaches for `.sm` keeps working; the
        # forward pass goes through the custom autograd function instead.
        self.sm = nn.Softmax(dim=dim)

    def forward(self, x):
        """Return ``softmax(x)`` along ``self.dim``; d(out)/d(x) is treated as identity."""
        return _softmax_identity_grad.apply(x, self.dim)

    def backward(self, grad):
        """Return ``grad`` unchanged.

        Not used by autograd (the gradient skip lives in
        ``_softmax_identity_grad``); retained only for backward compatibility
        with callers that invoke it directly.
        """
        return grad
        
class gru(nn.Module):
    """GRU-style gate that replaces a residual connection (GTrXL-style gating).

    Given the residual stream ``x`` and a sub-layer output ``y`` it computes

        r     = sigmoid(W_r y + U_r x)
        z     = sigmoid(W_z y + U_z x - b_g)
        h_hat = tanh(W_g y + U_g (r * x))
        out   = (1 - z) * x + z * h_hat

    Args:
        dim: feature size of both ``x`` and ``y`` (all projections are
            ``dim -> dim``).
        b_g: constant subtracted from the update-gate pre-activation. A
            positive value pushes ``z`` towards zero, so the gate starts out
            close to passing ``x`` through.
    """

    def __init__(self, dim, b_g=1) -> None:
        super().__init__()

        def projection(with_bias):
            # Every projection in the gate is a square dim -> dim linear map.
            return nn.Linear(dim, dim, bias=with_bias)

        # reset gate
        self.w_r = projection(False)
        self.u_r = projection(False)

        # update gate (the only projections that carry a learned bias)
        self.w_z = projection(True)
        self.u_z = projection(True)
        self.b_g = b_g

        # candidate state
        self.w_g = projection(False)
        self.u_g = projection(False)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, y):
        """Blend the residual ``x`` with the sub-layer output ``y``."""
        reset = self.sigmoid(self.w_r(y) + self.u_r(x))
        # update == 0 means the residual x is returned untouched
        update = self.sigmoid(self.w_z(y) + self.u_z(x) - self.b_g)
        candidate = self.tanh(self.w_g(y) + self.u_g(reset * x))
        return (1 - update) * x + update * candidate
        

class mlp(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    The linear layers act on the last axis only, so every token position is
    transformed with the same weights and positions never exchange
    information.

    Args:
        embed_dim: size of the input and output features.
        internal_dim: size of the hidden layer.
    """

    def __init__(self, embed_dim, internal_dim) -> None:
        super().__init__()
        widen = nn.Linear(embed_dim, internal_dim)
        narrow = nn.Linear(internal_dim, embed_dim)
        # Deliberately no activation after the output projection.
        self.block = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, input):
        """Apply the feed-forward block to ``input`` (features on the last axis)."""
        projected = self.block(input)
        return projected


class cross_attention(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` for query/memory attention.

    Args:
        embed_dimension: feature size of queries, keys and values.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dimension, num_heads) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dimension, num_heads)

    def forward(self, x, enc):
        """Attend from ``x`` (queries) over ``enc`` (keys and values).

        Only the attended output is returned; the attention weights that
        ``nn.MultiheadAttention`` also produces are dropped.
        """
        attended, _weights = self.attention(query=x, key=enc, value=enc)
        return attended

class mha_helper(nn.Module):
    """Linear projection that splits its output into attention heads.

    Maps ``(..., dim_model)`` to ``(..., heads, dim_model // heads)``.

    With ``smeared=True`` the projected tensor is additionally interpolated
    with the projection produced by the *previous forward call*:

        x = sigmoid(alpha) * x + (1 - sigmoid(alpha)) * previous

    NOTE(review): the blend is with the previous call, not with the previous
    position inside the sequence — confirm this is the intended form of key
    smearing. It also requires consecutive calls to use identically shaped
    (or broadcastable) inputs.

    Args:
        dim_model: input feature size.
        heads: number of heads the projection is split into.
        bias: whether the linear projection has a bias term.
        smeared: enable the interpolation with the previous call's output.
    """

    def __init__(self,
    dim_model,
    heads,
    bias: bool = False,
    smeared: bool = False
    ) -> None:
        super().__init__()
        self.d_m = dim_model
        self.heads = heads
        self.d_k = dim_model // heads

        self.affine = nn.Linear(
            in_features=dim_model,
            out_features=heads * self.d_k,
            bias=bias
            )

        # key smear functionality
        self.smeared = smeared
        self.previous = None  # projection from the previous forward call
        # Learned interpolation logit, initialised to 1. It has to be an
        # nn.Parameter: a bare tensor is invisible to `parameters()`, so it
        # would never be optimised, saved in the state dict, or moved by
        # `.to(device)`.
        self.alpha = nn.Parameter(torch.tensor([1.]))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Project ``x`` and reshape the last axis to ``(heads, d_k)``."""
        shape = x.shape[:-1]
        x = self.affine(x)

        if self.smeared:
            if self.previous is not None:
                itpl = self.sigmoid(self.alpha)  # interpolation value
                x = itpl * x + (1 - itpl) * self.previous
            # Store a detached copy: keeping the autograd graph alive across
            # calls makes the next backward pass walk into a graph whose
            # buffers were already freed, and leaks memory over time.
            self.previous = x.detach()
        x = x.view(*shape, self.heads, self.d_k)
        return x

class self_attention(nn.Module):
    """Multi-head scaled dot-product self-attention with optional causal mask.

    Inputs are laid out as ``(length, batch, dim_model)`` and the output has
    the same layout. Keys can optionally be "smeared" (see ``mha_helper``).

    Args:
        dim_model: feature size of the input and output.
        heads: number of attention heads.
        sequence_length: maximum sequence length; fixes the size of the
            pre-computed causal mask.
        masked: when true, position ``i`` may only attend to positions
            ``j <= i``.
        smeared_key: enable key smearing in the key projection.
    """
    # consider adding stop grad memory

    def __init__(self,
    dim_model,
    heads,
    sequence_length,
    masked = False,
    smeared_key = False
    ) -> None:
        super().__init__()
        self.masked = masked
        self.heads = heads
        self.q = mha_helper(dim_model, heads)
        self.k = mha_helper(dim_model, heads, smeared=smeared_key)
        self.v = mha_helper(dim_model, heads, bias=True)
        self.scaler = 1 / math.sqrt(float(dim_model // heads))
        # scores are laid out (query, key, batch, head): normalising over
        # dim 1 makes each query's weights over the keys sum to one
        self.softmax = nn.Softmax(1)
        self.output_layer = nn.Sequential(
            nn.Linear(dim_model, dim_model),
            nn.ReLU()
        )
        self.smeared_key = smeared_key
        # Lower-triangular (query, key, 1) mask. Registered as a buffer so it
        # follows the module across `.to(device)` / `.cuda()`; non-persistent
        # so the state-dict layout stays exactly as before.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(sequence_length, sequence_length)).unsqueeze(-1),
            persistent=False,
        )

    def mask_format(self, mask: torch.Tensor, query_shape: list[int], key_shape: list[int]):
        """Validate ``mask`` against the query/key shapes and add a head axis.

        ``mask`` must be ``(query_len or 1, key_len, batch or 1)``; the result
        has a trailing singleton axis so it broadcasts over the heads of a
        ``(query, key, batch, head)`` score tensor.
        """
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
        mask = mask.unsqueeze(-1)
        return mask

    def attn_score(self, q, k):
        """Dot-product attention scores.

        ``q`` and ``k`` are ``(length, batch, heads, d_k)``. In the einsum,
        ``i`` is the query position, ``j`` the key position, ``b`` the batch,
        ``h`` the head and ``d`` the per-head feature axis that is summed out.
        The result is ``(query_len, key_len, batch, heads)``.
        """
        return torch.einsum('ibhd, jbhd -> ijbh', q, k)

    def forward(self, input):
        """Apply self-attention to ``input`` of shape ``(length, batch, dim_model)``."""
        length, batch_size, _ = input.shape
        q = self.q(input)
        k = self.k(input)
        v = self.v(input)

        scores = self.attn_score(q, k) * self.scaler

        if self.masked:
            # Take the top-left corner of the pre-computed mask so sequences
            # shorter than `sequence_length` are accepted as well.
            mask = self.mask_format(self.mask[:length, :length], q.shape, k.shape)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.softmax(scores)

        # weighted sum of the values for every query position
        x = torch.einsum('ijbh,jbhd->ibhd', attn, v)

        # merge the heads back into a single feature axis
        x = x.reshape(length, batch_size, -1)

        return self.output_layer(x)







In [413]:
class encoder_layer(nn.Module):
    """Gated transformer encoder layer (unmasked, no cross-attention, no memory).

    Structure: pre-LayerNorm -> self-attention -> ReLU -> GRU gate, then
    pre-LayerNorm -> position-wise MLP -> ReLU -> GRU gate. The gates take
    the place of the usual residual additions.

    Args:
        embed_dim: feature size of the residual stream.
        mlp_dim: hidden size of the position-wise MLP.
        attention_heads: number of self-attention heads.
        sequence_length: forwarded to ``self_attention``.
    """

    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_length
    ) -> None:
        super().__init__()

        self.mha = self_attention(
            dim_model=embed_dim,
            heads=attention_heads,
            sequence_length=sequence_length,
            masked=False,
            smeared_key=False
        )
        self.mlp = mlp(
            embed_dim=embed_dim,
            internal_dim=mlp_dim
        )

        self.gate1 = gru(
            dim=embed_dim
        )
        self.gate2 = gru(
            dim=embed_dim
        )

        # one LayerNorm per sub-layer
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.activation = nn.ReLU()

    def forward(self, x):
        """Run the attention and MLP sub-layers on ``x`` and return the gated result."""
        # self-attention sub-layer
        y = self.ln1(x)
        y = self.mha(y)
        x = self.gate1(x, self.activation(y))

        # MLP sub-layer: uses its own norm (`ln2` was created for this but
        # `ln1` was applied a second time by mistake)
        y = self.ln2(x)
        y = self.mlp(y)
        x = self.gate2(x, self.activation(y))

        return x

class decoder_layer(nn.Module):
    """Gated transformer decoder layer (masked self-attention + cross-attention).

    Three sub-layers, each pre-normalised and joined to the residual stream
    through its own GRU gate:

        1. masked self-attention over ``x``
        2. cross-attention with ``x`` as queries and ``enc`` as keys/values
        3. position-wise MLP

    Args:
        embed_dim: feature size of the residual stream.
        mlp_dim: hidden size of the position-wise MLP.
        attention_heads: number of heads for both attention blocks.
        sequence_lenth: forwarded to ``self_attention`` as its
            ``sequence_length`` (the spelling is kept because callers pass it
            by keyword).
        smeared_key: enable key smearing in the self-attention block.
    """

    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_lenth,
    smeared_key = False
    ) -> None:
        super().__init__()

        self.mha = self_attention(
            dim_model=embed_dim,
            heads=attention_heads,
            sequence_length=sequence_lenth,
            masked=True,
            smeared_key=smeared_key
        )

        self.cross_mha = cross_attention(
            embed_dimension=embed_dim,
            num_heads=attention_heads,
        )

        self.mlp = mlp(
            embed_dim=embed_dim,
            internal_dim=mlp_dim
        )

        self.gate1 = gru(
            dim=embed_dim
        )
        self.gate2 = gru(
            dim=embed_dim
        )
        self.gate3 = gru(
            dim=embed_dim
        )

        # normalises the encoder memory before cross-attention
        self.ln = nn.LayerNorm(embed_dim)

        self.activation = nn.ReLU()
        # one LayerNorm per decoder sub-layer (self-attn, cross-attn, MLP),
        # mirroring the per-sub-layer norms of the encoder layer
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ln3 = nn.LayerNorm(embed_dim)

    def forward(self, x, enc):
        """Decode ``x`` while attending to the encoder output ``enc``."""
        # masked self attention
        y = self.ln1(x)
        y = self.mha(y)
        x = self.gate1(x, self.activation(y))

        # cross attention: the normalised decoder state is the query and the
        # normalised encoder output supplies keys and values. (The arguments
        # used to be swapped and the normalised `y` was thrown away.)
        y = self.ln2(x)
        enc = self.ln(enc)
        y = self.cross_mha(y, enc)
        x = self.gate2(x, self.activation(y))

        # position-wise multi-layer perceptron, joined through its own gate
        # (`gate3` was created for this but `gate2` was being reused)
        y = self.ln3(x)
        y = self.mlp(y)
        x = self.gate3(x, self.activation(y))

        return x

In [414]:
class encoder(nn.Module):
    """Stack of ``layers`` identical ``encoder_layer`` modules applied in order.

    Args:
        layers: number of encoder layers.
        model_dim: feature size of the residual stream.
        mlp_dim: hidden size of each layer's MLP.
        heads: number of attention heads per layer.
        sequence_length: forwarded to every layer.
    """

    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()

        # no inductive biases on encoder here
        stack = [
            encoder_layer(
                embed_dim=model_dim,
                mlp_dim=mlp_dim,
                attention_heads=heads,
                sequence_length=sequence_length,
            )
            for _ in range(layers)
        ]
        self.block = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through every encoder layer in sequence."""
        encoded = self.block(x)
        return encoded

class decoder(nn.Module):
    """Stack of ``decoder_layer`` modules; only the first one smears its keys.

    The layers are held in an ``nn.ModuleList``. A plain Python list does not
    register its contents as sub-modules, which left every decoder parameter
    out of ``parameters()``, ``state_dict()`` and ``.to(device)``, so the
    decoder could neither be trained nor saved.

    Args:
        layers: total number of decoder layers. The first (key-smearing)
            layer is always created, so values below 1 still yield one layer.
        model_dim: feature size of the residual stream.
        mlp_dim: hidden size of each layer's MLP.
        heads: number of attention heads per layer.
        sequence_length: forwarded to every layer.
    """

    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()

        first_layer = decoder_layer(
            embed_dim=model_dim,
            mlp_dim=mlp_dim,
            attention_heads=heads,
            sequence_lenth=sequence_length,
            smeared_key=True
        )

        self.block = nn.ModuleList([first_layer])

        for _ in range(layers - 1):
            self.block.append(
                decoder_layer(
                    embed_dim=model_dim,
                    mlp_dim=mlp_dim,
                    attention_heads=heads,
                    sequence_lenth=sequence_length
                )
            )

    def forward(self, x, y):
        """Decode ``x``; ``y`` is the encoder output handed to every layer."""
        for layer in self.block:
            x = layer(x, y)
        return x

        

In [415]:
class RLformer(nn.Module):
    """Encoder-decoder transformer with an actor head and a critic head.

    Args:
        model_dim: feature size used throughout the encoder and decoder.
        mlp_dim: hidden size of the transformer MLPs and of the critic head.
        attn_heads: number of attention heads.
        sequence_length: forwarded to the encoder and decoder.
        enc_layers: number of encoder layers.
        dec_layers: number of decoder layers.
        action_dim: size of the last axis of the actor output.
    """

    def __init__(self,
    model_dim,
    mlp_dim,
    attn_heads,
    sequence_length,
    enc_layers,
    dec_layers,
    action_dim
    ) -> None:
        super().__init__()

        # Settings shared by both halves of the transformer.
        shared = dict(
            model_dim=model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length=sequence_length,
        )
        self.encoder = encoder(layers=enc_layers, **shared)
        self.decoder = decoder(layers=dec_layers, **shared)

        actor_layers = [
            nn.Linear(model_dim, action_dim),
            nn.ReLU(),
            grad_skip_softmax(),  # To do neural replicator dynamics
        ]
        self.actor = nn.Sequential(*actor_layers)

        critic_layers = [
            nn.Linear(model_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, 1),
        ]
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, enc_input, dec_input):
        """Return ``(policy, value)`` computed from the decoder output.

        ``enc_input`` is encoded first; the decoder then processes
        ``dec_input`` together with that encoding, and both heads read the
        decoder output.
        """
        memory = self.encoder(enc_input)
        features = self.decoder(dec_input, memory)
        return self.actor(features), self.critic(features)

In [416]:
# Smoke test: build a tiny model (2 encoder + 2 decoder layers, 10-dim
# features split over 2 heads, 4 actions) to check that shapes line up.
myrl = RLformer(
    model_dim = 10,
    mlp_dim = 20,
    attn_heads = 2,
    sequence_length = 15,
    enc_layers = 2,
    dec_layers = 2,
    action_dim = 4 
)

In [417]:
# Random encoder input laid out as (sequence_length=15, batch=8, model_dim=10).
enc_input = torch.rand((15,8,10))


In [418]:
# Random decoder input with the same (sequence_length, batch, model_dim) layout.
dec_input = torch.rand((15,8,10))

In [419]:
# Display the decoder input shape.
dec_input.size()

torch.Size([15, 8, 10])

In [421]:
# Single forward pass; returns the actor output and the critic output.
policy, value = myrl(enc_input, dec_input)


  return self.sm(x)


In [422]:
# Actor output shape: the last axis has action_dim (4) entries.
policy.size()

torch.Size([15, 8, 4])

In [423]:
# Critic output shape: the last axis is a single scalar value.
value.size()

torch.Size([15, 8, 1])