In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:
# Toy input batch of shape (10, 128, 512). The 128-wide axis is the one the
# 128x128 matrices in the next cell multiply into.
x = torch.randn(size=(10, 128, 512))

In [4]:
# Three random 128x128 projection matrices, drawn in the order Q, W, E,
# each scaled by 1/sqrt(128). They play the attention projection roles below.
Q, W, E = (torch.randn(128, 128) / 128 ** 0.5 for _ in range(3))

In [5]:
# Project the input with each matrix. The 2-D matrix broadcasts over the
# leading batch axis of x.
x_q, x_w, x_e = (torch.matmul(M, x) for M in (Q, W, E))

In [6]:
# Sanity-check the projected shape.
x_q.size()

torch.Size([10, 128, 512])

In [7]:
# Split the 128-wide feature axis into 4 heads of 32 features each.
x_q, x_w, x_e = (t.view(-1, 4, 32, 512) for t in (x_q, x_w, x_e))

In [8]:
# Prototype of one attention pass:
#   - pairwise position scores from x_q and x_w,
#   - softmax over dim -2,
#   - weighted sum of x_e,
#   - heads merged back to 128 features.
# The result is kept in a named variable. The original expression was
# computed and then discarded.
# NOTE(review): unlike the class below, the scores here are not divided by a
# sqrt(dim) factor before the softmax.
attn_out = (x_e @ F.softmax(x_q.transpose(-2, -1) @ x_w, dim=-2)).view(-1, 128, 512)

In [123]:
class multi_head_attention(nn.Module):
    """Multi-head dot-product attention over channel-first sequences.

    Each input is laid out as ``(..., embedding_dim, L)``. The learned
    projections left-multiply the embedding axis, and positions run along
    the last axis. ``forward`` takes a sequence ``[K, Q, V]`` and returns a
    tensor of shape ``(N, embedding_dim, L)``, with all leading axes
    flattened into N.

    Args:
        embedding_dim: size of the feature axis; must be divisible by ``heads``.
        heads: number of attention heads. Each head gets
            ``embedding_dim // heads`` features.
        lag: longest sequence supported when ``causal`` is True. It sizes the
            pre-built causal mask.
        dropout: dropout probability applied to each projected K, Q and V.
        causal: if True, output position ``j`` only mixes value positions
            ``i <= j``.
    """

    def __init__(self, embedding_dim=128, heads=4, lag=512, dropout=0.2, causal=True):
        super().__init__()

        assert embedding_dim % heads == 0, f"embedding_dim/heads should be integer while yours {embedding_dim/heads} "

        self.embedding_dim = embedding_dim
        self.heads = heads
        self.head_dim = embedding_dim // heads
        self.causal = causal
        self.W = lag

        # Dropout on the projected keys, queries and values.
        self.dropout_Q = nn.Dropout(p=dropout)
        self.dropout_K = nn.Dropout(p=dropout)
        self.dropout_V = nn.Dropout(p=dropout)

        # Square projection matrices, left-multiplied onto the embedding axis.
        # They are drawn in the order Q, K, V.
        self.L_Q = nn.Parameter(torch.randn(embedding_dim, embedding_dim) * embedding_dim ** (-0.5))
        self.L_K = nn.Parameter(torch.randn(embedding_dim, embedding_dim) * embedding_dim ** (-0.5))
        self.L_V = nn.Parameter(torch.randn(embedding_dim, embedding_dim) * embedding_dim ** (-0.5))

        if self.causal:
            # Additive mask: -inf strictly below the diagonal, 0 elsewhere.
            # Row i indexes the Q/V positions and column j indexes the output
            # (K) positions, so column j can only draw from positions i <= j.
            #
            # It is registered as a buffer rather than a Parameter. It still
            # follows .to()/.cuda() and is saved in state_dict under the same
            # key, but it never appears in parameters() or reaches an optimizer.
            future = torch.ones(lag, lag).tril(diagonal=-1).bool()
            self.register_buffer(
                "causal_factor",
                torch.zeros(lag, lag).masked_fill(future, float("-inf")),
            )

    def _split_heads(self, t):
        """Reshape a projected ``(..., E, L)`` tensor to ``(N, heads, head_dim, L)``.

        L is taken from ``t`` itself, not from ``lag``. Reshaping with a fixed
        length silently fused separate batch items whenever the real length
        differed from ``lag`` but the element count still divided evenly.
        """
        return t.view(-1, self.heads, self.head_dim, t.shape[-1])

    def forward(self, x):  # x = [K, Q, V], each (..., E, L) -> (N, E, L)
        K, Q, V = x[0], x[1], x[2]
        # Projection + dropout, in the order K, Q, V.
        K = self.dropout_K(self.L_K @ K)
        Q = self.dropout_Q(self.L_Q @ Q)
        V = self.dropout_V(self.L_V @ V)
        K_v = self._split_heads(K)
        Q_v = self._split_heads(Q)
        V_v = self._split_heads(V)

        # NOTE(review): standard scaled dot-product attention divides by
        # sqrt(head_dim). This divides by sqrt(embedding_dim); it is kept as-is
        # to preserve existing outputs.
        attention_scores = (Q_v.transpose(-2, -1) @ K_v) / self.embedding_dim ** 0.5
        if self.causal:
            n_rows, n_cols = attention_scores.shape[-2:]
            if n_rows > self.W or n_cols > self.W:
                raise ValueError(
                    f"causal attention supports sequences up to lag={self.W}, "
                    f"got {n_rows}x{n_cols} scores"
                )
            # Scaling the mask is unnecessary: -inf and 0 are unchanged by it.
            attention_scores = attention_scores + self.causal_factor[:n_rows, :n_cols]

        # NOTE(review): the softmax runs over dim -2 (the Q positions), and the
        # output below is indexed by K positions. Relative to the usual
        # convention, "Q" therefore acts as the key and "K" as the query.
        scores = F.softmax(attention_scores, dim=-2)

        out = V_v @ scores
        return out.reshape(-1, self.embedding_dim, out.shape[-1])

In [128]:
# Small non-causal instance for a quick smoke test: 8 features, 2 heads, lag 8.
# NOTE(review): this rebinds `x`, shadowing the input tensor created earlier.
x = multi_head_attention(embedding_dim=8, heads=2, lag=8, causal=False)
x.train(mode=True)

multi_head_attention(
  (dropout_Q): Dropout(p=0.2, inplace=False)
  (dropout_K): Dropout(p=0.2, inplace=False)
  (dropout_V): Dropout(p=0.2, inplace=False)
  (dense): Linear(in_features=8, out_features=8, bias=True)
)

In [129]:
# Prefer the GPU, but fall back to CPU so the notebook also runs without CUDA.
x.to("cuda" if torch.cuda.is_available() else "cpu")

multi_head_attention(
  (dropout_Q): Dropout(p=0.2, inplace=False)
  (dropout_K): Dropout(p=0.2, inplace=False)
  (dropout_V): Dropout(p=0.2, inplace=False)
  (dense): Linear(in_features=8, out_features=8, bias=True)
)

In [134]:
t = torch.randn(1, 8, 8)
# Put the input on whatever device the module's parameters live on, instead of
# hard-coding "cuda" (which fails on CPU-only machines or a CPU-resident module).
t = t.to(next(x.parameters()).device)
x([t, t, t])

tensor([[[ 0.1765,  0.2635,  0.2213,  0.1936,  0.2184,  0.1087,  0.2295,
           0.1948],
         [-0.3597, -0.5333, -0.4265, -0.3297, -0.4097, -0.2847, -0.5016,
          -0.3772],
         [-0.3461, -0.2853, -0.1281, -0.0263, -0.1438,  0.2598, -0.2839,
          -0.2844],
         [-0.5749, -0.3938, -0.2322, -0.1116, -0.2060,  0.4003, -0.4749,
          -0.4386],
         [ 0.4892,  0.5546,  0.2800,  0.5147,  0.3258,  0.2714, -0.0199,
           0.3953],
         [ 0.2937,  0.3054, -0.1043,  0.4849,  0.2012, -0.0890, -0.6247,
           0.1998],
         [ 0.0181,  0.0333,  0.2761,  0.0595, -0.1567,  0.0548,  0.2570,
          -0.1348],
         [ 0.0167, -0.0581, -0.0066,  0.2816,  0.0370, -0.0778, -0.1228,
          -0.0498]]], device='cuda:0', grad_fn=<ViewBackward0>)