In [1]:
import torch
from torch import nn

In [2]:
class grad_skip_softmax(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sm = nn.Softmax()
    
    def forward(self, x):
        return self.sm(x)

    def backward(self, grad):
        # skip gradient through the softmax on backward pass
        return grad
        
class gru(nn.Module):
    # 'gated-recurrent-unit type gating' as seen in GTrXL paper
    def __init__(self, dim, b_g = 1) -> None:
        super().__init__()

        self.w_r = nn.Linear(dim, dim, bias = False)
        self.u_r = nn.Linear(dim, dim, bias = False)

        self.w_z = nn.Linear(dim, dim, bias = True)
        self.u_z = nn.Linear(dim, dim, bias = True)
        self.b_g = b_g # this is used to hack initial bias of the above to be below zero, such that gate is initialized close to identity
        self.w_g = nn.Linear(dim, dim, bias = False)
        self.u_g = nn.Linear(dim, dim, bias = False)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, y):
        r = self.sigmoid(self.w_r(y) + self.u_r(x))
        z = self.sigmoid(self.w_z(y) + self.u_z(x) - self.b_g) # when zero, gate passes identity of residual
        h_hat = self.tanh(self.w_g(y) + self.u_g(r * x))
        g = (1-z)*x + z * h_hat
        return g
        

class mlp(nn.Module):
    # 1d temporal convolution
    # no communication between tokens, uses same kernel for each token spot
    def __init__(self, embed_dim, internal_dim) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(embed_dim, internal_dim),
            nn.ReLU(),
            nn.Linear(internal_dim, embed_dim)
        ) # no second relu at output of mlp

    def forward(self, input):
        return self.block(input)


class cross_attention(nn.Module):
    def __init__(self, embed_dimension, num_heads) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dimension,
            num_heads=num_heads,
            )
    
    def forward(self, x, enc):

        return self.attention(x, enc, enc)[0]

class self_attention(nn.Module):
    def __init__(self, embed_dimension, num_heads) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dimension,
            num_heads=num_heads,
            )
    
    def forward(self, x):

        return self.attention(x, x, x)[0]







In [44]:


class Smear_key(nn.Module):

    def __init__(self,
    sequence_length,
    heads
    ) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, heads, sequence_length - 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, k):
        itrp = self.sigmoid(self.alpha)
        smear = k[:,:,1:,:]*itrp + k[:,:,:-1,:]*(1-itrp)
        return torch.cat([k[:,:, 0:1, :], smear], dim = 2)

class decoder_mha(nn.Module):
    #Masked smeared self attention
    def __init__(self, model_dim, sequence_length, heads) -> None:
        super().__init__()
        self.mask = torch.triu(torch.ones(sequence_length, sequence_length) * float('-inf'), diagonal=1) # make batch, heads, seq,seq
        self.model_dim = model_dim
        self.sequence_length = sequence_length
        self.heads = heads
        self.key_dim = model_dim // heads
        self.W_q = nn.Linear(model_dim, model_dim, bias=False)
        self.W_k = nn.Linear(model_dim, model_dim, bias=False)
        self.W_v = nn.Linear(model_dim, model_dim, bias=False)
        self.output = nn.Linear(model_dim, model_dim, bias=True)
        self.ln = nn.LayerNorm(model_dim)
        self.smear = Smear_key(sequence_length, heads)

    def forward(self,x):
        # batch, sequence, model_dim
        q = self.W_q(x).view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1,2)
        k = self.W_k(x).view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1,2)
        v = self.W_v(x).view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1,2)
        k = self.smear(k)
        # batch, heads, sequence, dim // heads
        key_dim = k.shape[-1:][0]
        scores = q @ k.transpose(2,3) / key_dim**.5
        scores += self.mask
        attn = torch.softmax(scores, dim = 3)
        mha = attn @ v
        mha = mha.transpose(1, 2).contiguous().view(-1, self.sequence_length, self.model_dim)
        out = self.output(mha)
        # batch, sequence, model_dim
        return out





In [45]:
class encoder_layer(nn.Module):
    # transformer layer
    # not masked, no cross attention, no memory, for encoder
    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_length
    ) -> None:
        super().__init__()

        self.mha = self_attention(
            embed_dimension=embed_dim,
            num_heads=attention_heads
        )

        self.mlp = mlp(
            embed_dim= embed_dim,
            internal_dim=mlp_dim
        )

        self.gate1 = gru(
            dim = embed_dim
        )
        self.gate2 = gru(
            dim = embed_dim
        )

        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.activation = nn.ReLU()
    
    def forward(self, x):
        y = self.ln1(x)
        y = self.mha(y)
        x = self.gate1(x,self.activation(y))
        y = self.ln1(x)
        y = self.mlp(y)
        x = self.gate2(x, self.activation(y))
        
        return x

class decoder_layer(nn.Module):
    # transformer layer
    # masked, cross attention, smeared key
    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_lenth
    ) -> None:
        super().__init__()

        self.mha = decoder_mha(
            model_dim=embed_dim,
            sequence_length=sequence_lenth,
            heads=attention_heads
        ) #smeared key masked self attention

        self.cross_mha = cross_attention(
            embed_dimension = embed_dim,
            num_heads = attention_heads,
        )

        self.mlp = mlp(
            embed_dim = embed_dim,
            internal_dim = mlp_dim
        )

        self.gate1 = gru(
            dim = embed_dim
        )
        self.gate2 = gru(
            dim = embed_dim
        )
        self.gate3 = gru(
            dim = embed_dim
        )

        self.ln = nn.LayerNorm(embed_dim)

        self.activation = nn.ReLU()
        self.ln1 = nn.LayerNorm(embed_dim)
    
    def forward(self, x, enc):
        # masked self attention, smeared key
        y = self.ln1(x)
        y = self.mha(y)
        x = self.gate1(x,self.activation(y))

        # cross attention
        # consider output sequence length and 
        y = self.ln1(x)
        enc = self.ln1(enc)
        y = self.cross_mha(enc, x)
        x = self.gate2(x, self.activation(y))

        # position-wise multi layper perceptron
        y = self.ln1(x)
        y = self.mlp(y)
        x = self.gate2(x, self.activation(y))
        
        return x

In [46]:
class encoder(nn.Module):
    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()
        
        # no inductive biases on encoder here
        self.block = nn.Sequential()
        for x in range(layers):
            self.block.append(encoder_layer(
                embed_dim = model_dim,
                mlp_dim = mlp_dim,
                attention_heads = heads,
                sequence_length = sequence_length
            ))
            
    def forward(self, x):
        return self.block(x)

class decoder(nn.Module):
    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()


        self.block = []

        for x in range(layers):
            self.block.append(
                decoder_layer(
                    embed_dim = model_dim,
                    mlp_dim= mlp_dim,
                    attention_heads= heads,
                    sequence_lenth = sequence_length
                )
            )
        
    def forward(self, x, y):
        # y is input from encoder
        for layer in self.block:
            x = layer(x,y)
            
        return x

        

In [47]:
class RLformer(nn.Module):

    def __init__(self,
    model_dim,
    mlp_dim,
    attn_heads,
    sequence_length,
    enc_layers,
    dec_layers,
    action_dim
    ) -> None:
        super().__init__()

        self.encoder = encoder(
            layers=enc_layers,
            model_dim=model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length = sequence_length
        )

        self.decoder = decoder(
            layers=dec_layers,
            model_dim= model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length=sequence_length,
        )

        self.actor = nn.Sequential(
            nn.Linear(model_dim, action_dim),
            nn.ReLU(),
            grad_skip_softmax() # To do neural replicator dynamics
        )

        self.critic = nn.Sequential(
            nn.Linear(model_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, 1)
        )
        

    def forward(self, enc_input, dec_input):
        enc = self.encoder(enc_input)
        dec = self.decoder(dec_input, enc)
        policy = self.actor(dec)
        value = self.critic(dec)
        return policy, value

In [48]:
myrl = RLformer(
    model_dim = 10,
    mlp_dim = 20,
    attn_heads = 2,
    sequence_length = 15,
    enc_layers = 2,
    dec_layers = 2,
    action_dim = 4 
)

In [49]:
#inputs are batch, sequence, model_dim
enc_input = torch.rand((8,15,10))


In [50]:
dec_input = torch.rand((8,15,10))

In [51]:
dec_input.size()

torch.Size([8, 15, 10])

In [42]:
import time

In [53]:
%timeit policy, value = myrl(enc_input, dec_input)


  return self.sm(x)


4.97 ms ± 57.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
policy.size()
#batch,sequence,action space

torch.Size([8, 15, 4])

In [13]:
value.size()

torch.Size([8, 15, 1])

In [24]:
dec_input = torch.rand((8,5,10))
test_mha = decoder_mha(model_dim=10,sequence_length=5,heads=2)

In [25]:
y= test_mha(dec_input)
print(y)

tensor([[[-0.4197, -0.1550,  0.1928,  0.2340,  0.2418, -0.2448, -0.0458,
          -0.4462,  0.1150, -0.1603],
         [-0.4327, -0.1636,  0.1269,  0.3228,  0.1293, -0.1570, -0.0724,
          -0.3111,  0.1123, -0.1597],
         [-0.4306, -0.1590,  0.1207,  0.3382,  0.0930, -0.1776, -0.1141,
          -0.2811,  0.1077, -0.1790],
         [-0.3912, -0.1655,  0.1378,  0.3009,  0.0629, -0.2122, -0.1494,
          -0.2687,  0.0973, -0.1934],
         [-0.3691, -0.1809,  0.1349,  0.2907,  0.0484, -0.2170, -0.1633,
          -0.2540,  0.1040, -0.1880]],

        [[-0.2681, -0.1029,  0.0621,  0.2428,  0.2743, -0.2279,  0.0875,
          -0.4397,  0.2145,  0.0331],
         [-0.2520, -0.0825,  0.0711,  0.1987,  0.1775, -0.2456, -0.0125,
          -0.3249,  0.1228, -0.0285],
         [-0.2599, -0.1167,  0.1054,  0.2196,  0.1694, -0.2229, -0.0119,
          -0.3411,  0.1611, -0.0308],
         [-0.3217, -0.1924,  0.1413,  0.2453,  0.1684, -0.2055, -0.0476,
          -0.3723,  0.1739, -0.0962],