# Implementing GAU Mean Pooled SBERT

## Implementing GAU

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class GatedAttentionUnit(nn.Module):
    """First draft of a Gated Attention Unit.

    Holds the U/V projections and a shared Z projection, plus per-dim
    scale/offset parameters used to derive Q and K from Z. Only the
    attention step is implemented here (there is no ``forward``); a later
    cell redefines ``GatedAttentionUnit`` with the complete layer.

    Parameters
    ----------
    embed_dim : int
        Size of the last dimension of the input ``x``.
    intermediate_dim : int
        Output size of the U and V projections.
    attn_dim : int
        Size of Z, and therefore of Q and K.
    """

    def __init__(self, embed_dim=768, intermediate_dim=1536, attn_dim=128):
        super(GatedAttentionUnit, self).__init__()
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim
        self.attn_dim = attn_dim
        self.dense_u = nn.Linear(in_features=self.embed_dim,
                                 out_features=self.intermediate_dim,
                                 bias=False)
        self.dense_v = nn.Linear(in_features=self.embed_dim,
                                 out_features=self.intermediate_dim,
                                 bias=False)
        self.attn_dense = nn.Linear(in_features=self.embed_dim,
                                    out_features=self.attn_dim,
                                    bias=False)
        # Per-dim scales (gamma) and offsets (beta) turning Z into Q and K.
        self.gamma_q = nn.Parameter(torch.randn(self.attn_dim))
        self.beta_q = nn.Parameter(torch.randn(self.attn_dim))
        self.gamma_k = nn.Parameter(torch.randn(self.attn_dim))
        self.beta_k = nn.Parameter(torch.randn(self.attn_dim))

    def attention(self, x, v):
        """Return ``relu(Q K^T / N) ** 2 @ v``.

        Parameters
        ----------
        x : torch.Tensor
            Shape (batch, N, embed_dim); used to build Z, Q and K.
        v : torch.Tensor
            Shape (batch, N, d_v); presumably ``self.dense_v(x)``.

        Returns
        -------
        torch.Tensor
            Shape (batch, N, d_v).
        """
        seq_len = x.shape[-2]
        # NOTE(review): no activation is applied to Z in this draft.
        z = self.attn_dense(x)
        # Q and K are affine transforms of Z (last dim attn_dim, matching
        # gamma/beta), not of x (last dim embed_dim).
        q = torch.mul(z, self.gamma_q) + self.beta_q
        k = torch.mul(z, self.gamma_k) + self.beta_k
        # Scale by 1/N before the squared ReLU.
        qk = torch.einsum('b i d, b j d -> b i j', q, k) / seq_len
        a = F.relu(qk) ** 2
        return torch.einsum('b i j, b j d -> b i d', a, v)
        

In [11]:
class OffsetScale(nn.Module):
    """Apply per-dimension scales and offsets for several heads at once.

    For each head ``h`` the result is ``x * gamma[h] + beta[h]``. The heads
    are returned as a tuple so callers can unpack them (e.g. ``q, k = ...``).

    Parameters
    ----------
    input_dim : int
        Size of the last dimension of ``x``.
    heads : int
        Number of independent scale/offset pairs.
    """

    def __init__(self, input_dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, input_dim))
        self.beta = nn.Parameter(torch.zeros(heads, input_dim))
        # Scales are re-drawn from a standard normal; offsets stay at zero.
        nn.init.normal_(self.gamma)

    def forward(self, x):
        # (..., d) -> (..., 1, d) broadcasts against (heads, d) -> (..., heads, d)
        per_head = x.unsqueeze(-2) * self.gamma + self.beta
        return per_head.unbind(dim=-2)

In [13]:
class GatedAttentionUnit(nn.Module):
    """Gated Attention Unit implementation from
    'Transformer Quality in Linear Time'.
    Code mostly adapted from
    https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py
    with added comments for my understanding.

    Parameters
    ----------
    embed_dim : int
        Size of the last dimension of the input and of the output.
    attn_dim : int
        Size of the shared Z representation from which Q and K are derived.
    expansion_factor : float
        U and V have size ``int(embed_dim * expansion_factor)``.
    add_residual : bool
        If True, the input is added to the projected output.
    norm : nn.Module or callable
        Either a ready-made normalization module (e.g. ``nn.LayerNorm(embed_dim)``)
        or a module class / factory, which is called as ``norm(embed_dim)``.
        The default ``nn.LayerNorm`` is instantiated this way.
    activation : type
        Activation module class, instantiated (without arguments) after the
        UV and Z projections.

    REMARKS
    -------
    Relative position bias and masking have not been considered.
    This is based on our use-case. The implementation in the provided
    link does have a more generalized setup.
    Names of certain variables have been altered for compliance with
    the naming convention in the original paper."""

    def __init__(
        self,
        embed_dim=768,
        attn_dim=128,
        expansion_factor=2,
        add_residual=True,
        norm=nn.LayerNorm,
        activation=nn.SiLU,
    ):
        super(GatedAttentionUnit, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = int(self.embed_dim * expansion_factor)
        self.attn_dim = attn_dim
        self.activation = activation

        # A module instance is used as-is. A class/factory (such as the
        # default nn.LayerNorm) must first be built for the feature size;
        # otherwise calling it in forward() would try to construct a new
        # module from the input tensor.
        self.norm = norm if isinstance(norm, nn.Module) else norm(self.embed_dim)

        # The representations of U and V are both obtained through linear
        # transformations followed by an activation.
        # As a result, the values of U and V can be obtained
        # through a single multiplication and then split into two segments
        # by chunking.
        # To achieve this, a single weight matrix of twice the
        # hidden dimensionality can be used.
        self.joined_UV = nn.Sequential(
            nn.Linear(in_features=self.embed_dim,
                      out_features=self.hidden_dim * 2,
                      bias=True),
            self.activation())

        # Calculate the Z matrix to be used for getting the attention
        # matrix A.
        self.calc_z = nn.Sequential(
            nn.Linear(in_features=self.embed_dim,
                      out_features=self.attn_dim,
                      bias=True),
            self.activation())

        # Matrix A is generated with per-dim scaling and offsets
        # applied to Z to generate Q and K.
        # Instead of doing the operation twice, it can be performed
        # once but with an extra dimension which can then be split
        # to give Q and K.
        self.offset_scale = OffsetScale(input_dim=self.attn_dim,
                                        heads=2)

        self.to_output = nn.Linear(in_features=self.hidden_dim,
                                   out_features=self.embed_dim,
                                   bias=True)

        self.add_residual = add_residual

    def forward(self, x):
        """Apply the GAU.

        Parameters
        ----------
        x : torch.Tensor
            Shape (batch_size, N, embed_dim).

        Returns
        -------
        torch.Tensor
            Shape (batch_size, N, embed_dim).
        """
        seq_len = x.shape[-2]

        # Normalization. U, V and Z are all computed from the normalized
        # input; only the residual uses the raw x.
        normed_x = self.norm(x)

        # Get U and V by splitting into two chunks along the last dimension.
        # U, V -> (batch_size, N, hidden_dim)
        u, v = self.joined_UV(normed_x).chunk(2, dim=-1)

        # Get Z for A calculation.
        # Z -> (batch_size, N, attn_dim)
        z = self.calc_z(normed_x)

        # Get Q and K for A calculation.
        # Q, K -> (batch_size, N, attn_dim)
        q, k = self.offset_scale(z)

        # QK -> (batch_size, N, N), scaled by 1/N
        qk = torch.einsum('b i d, b j d -> b i j', q, k) / seq_len

        # Relative position bias omitted; A = relu(QK)^2
        # A -> (batch_size, N, N)
        a = F.relu(qk) ** 2

        # Gate the attended values with U.
        # O -> (batch_size, N, hidden_dim)
        out = torch.einsum('b i j, b j d -> b i d', a, v)
        out = u * out

        # O -> (batch_size, N, embed_dim)
        out = self.to_output(out)

        if self.add_residual:
            out = out + x

        return out

In [19]:
# Random smoke-test input of shape (1, 512, 10); its last dim must match embed_dim.
x = torch.randn((1, 512, 10))

In [20]:
# Small GAU whose embed_dim (10) matches the last dimension of `x`.
gau = GatedAttentionUnit(
    activation=nn.SiLU,
    add_residual=True,
    attn_dim=128,
    embed_dim=10,
    expansion_factor=2,
    norm=nn.LayerNorm(normalized_shape=10),
)

In [21]:
# Forward pass through the module (hooks included) via __call__.
out = gau(x=x)

In [22]:
# With the residual connection the output shape must equal the input shape.
out.size()

torch.Size([1, 512, 10])

In [24]:
# Display the module tree to inspect the layer sizes.
gau

GatedAttentionUnit(
  (norm): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (joined_UV): Sequential(
    (0): Linear(in_features=10, out_features=40, bias=True)
    (1): SiLU()
  )
  (calc_z): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): SiLU()
  )
  (offset_scale): OffsetScale()
  (to_output): Linear(in_features=20, out_features=10, bias=True)
)

In [10]:
class ScaledSinEmbedding(nn.Module):
    """Scaled sinusoidal positional embeddings.

    Produces a (seq_len, 2 * (dim // 2)) table. The first half of the
    columns holds ``sin(pos * inv_freq)`` and the second half
    ``cos(pos * inv_freq)``. Everything is multiplied by one learnable scalar
    ``scale``, initialised to 1.

    Parameters
    ----------
    dim : int
        Embedding size. NOTE(review): for odd ``dim`` the table has
        ``dim - 1`` columns.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dim = dim
        self.scale = nn.Parameter(torch.ones((1,)))
        self.half_d = self.dim // 2
        # Frequencies 10000 ** (-i / half_d). The exponent must be
        # parenthesised: without it ``**`` binds before ``/`` and the
        # frequencies underflow to zero for all but the first few indices.
        inv_freq = 1. / (10000 ** (torch.arange(self.half_d).float() / float(self.half_d)))
        # A buffer (not a Parameter): not trained, but moved by .to()/.cuda()
        # and stored in the state dict.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        """Return the (seq_len, 2 * half_d) embedding table for ``x``.

        Only ``x.shape[1]`` is read; it is assumed to be the sequence length
        (e.g. ``x`` of shape (batch, seq_len, dim)) — confirm with callers.
        """
        n = x.shape[1]
        t = torch.arange(n, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Outer product: position x frequency -> (n, half_d)
        sinu = torch.einsum('s,d -> sd', t, self.inv_freq)
        scaledsin = torch.cat([sinu.sin(), sinu.cos()], dim=-1)
        return scaledsin * self.scale
        

In [11]:
# Sinusoidal embedding sized for 768-dim token representations.
sinu = ScaledSinEmbedding(768)

In [13]:
# Only the size of dim 1 (72) of this random input is read by forward().
sin_embeds = sinu(torch.randn(1, 72, 768))

In [14]:
# (seq_len, dim): the table has no batch dimension.
sin_embeds.size()

torch.Size([72, 768])

In [3]:
# Iterating this tensor yields 10 slices of shape (60, 40, 768).
tensors = torch.randn(10, 60, 40, 768)

In [4]:
# Transpose the per-slice shapes into one tuple of sizes per dimension.
dims = list(zip(*[list(t.shape) for t in tensors]))

In [5]:
# One tuple per dimension, holding that dimension's size for every slice.
dims

[(60, 60, 60, 60, 60, 60, 60, 60, 60, 60),
 (40, 40, 40, 40, 40, 40, 40, 40, 40, 40),
 (768, 768, 768, 768, 768, 768, 768, 768, 768, 768)]

In [9]:
# zip over a single iterable just wraps each shape list in a 1-tuple.
[(list(t.shape),) for t in tensors]

[([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],),
 ([60, 40, 768],)]

In [11]:
# Shape bookkeeping for broadcastable concatenation along `dim`.
# NOTE: `tensors` here is a single 4-D tensor, so iterating it yields its slices.
dim = -1
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
shape_len = next(iter(shape_lens))

# Turn a negative concat dim into its non-negative index.
if dim < 0:
    dim += shape_len

# dims[i] holds the size of dimension i for every tensor.
dims = list(zip(*[list(t.shape) for t in tensors]))

# Every dimension other than the concat dim must be broadcastable.
expandable_dims = [(i, sizes) for i, sizes in enumerate(dims) if i != dim]

In [12]:
# Number of slices obtained by iterating `tensors` (the size of its first dim).
num_tensors

10

In [13]:
# Distinct ranks among the slices; a single entry means they all agree.
shape_lens

{3}

In [15]:
# The common rank shared by every slice.
shape_len

3

In [16]:
# Concat dim after converting -1 into a non-negative index.
dim

2

In [17]:
# (dim index, sizes across all slices) for every dim except the concat dim.
expandable_dims

[(0, (60, 60, 60, 60, 60, 60, 60, 60, 60, 60)),
 (1, (40, 40, 40, 40, 40, 40, 40, 40, 40, 40))]

In [18]:
# For each non-concat dim: are there at most two distinct sizes across tensors?
[len(set(sizes)) <= 2 for _, sizes in expandable_dims]

[True, True]