In [None]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Toy input: one sequence of 4 tokens, each a 512-dimensional vector.
batch_size, sequence_length = 1, 4
input_dim = d_model = 512
x = torch.randn(batch_size, sequence_length, input_dim)

In [None]:
# Echo the input's dimensions: (batch_size, sequence_length, input_dim).
x.shape

In [None]:
 qkv_layer = nn.Linear(input_dim, 3 * d_model)

In [None]:
# Run the input through the fused projection; the last dimension of the
# result has width 3 * d_model.
qkv = qkv_layer(x)

In [None]:
# Echo the projected tensor's dimensions.
qkv.size()

In [None]:
import matplotlib.pyplot as plt
# Histogram of the projected activations over [-3, 3].
# The counting range and the x positions are derived from the same bounds so
# the bars line up with the values they count (previously histc counted over
# [-1, 3] while the x axis spanned [-3, 3), and relied on an un-imported `np`).
hist_min, hist_max, num_bins = -3.0, 3.0, 200
bin_width = (hist_max - hist_min) / num_bins
# detach(): the plot only needs the numbers, not the autograd graph.
y_val = torch.histc(qkv.detach(), bins=num_bins, min=hist_min, max=hist_max)
# Centre of each bin: left edge plus half a bin width.
x_val = torch.linspace(hist_min, hist_max, num_bins + 1)[:-1] + bin_width / 2
plt.bar(x_val.tolist(), y_val.tolist(), width=bin_width, align='center', color='forestgreen')
plt.title('qkv distribution')

In [None]:
# Split d_model evenly across the heads, then give every head its own slice
# of width 3 * head_dim (query, key and value packed side by side).
num_heads = 8
head_dim = d_model // num_heads
qkv = torch.reshape(qkv, (batch_size, sequence_length, num_heads, 3 * head_dim))

In [None]:
# Echo the per-head layout: (batch, seq, heads, 3 * head_dim).
qkv.size()

In [None]:
# Swap the sequence and head axes so the layout becomes
# (batch, heads, seq, 3 * head_dim).
qkv = torch.permute(qkv, (0, 2, 1, 3))
qkv.size()

In [None]:
# Cut the last axis into three equal slices: query, key and value.
q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
q.size(), k.size(), v.size()

In [None]:
# Raw attention scores: q @ k^T over the last two axes, divided by the
# square root of the per-head width d_k.
d_k = q.shape[-1]
scaled = (q @ k.transpose(-1, -2)) / math.sqrt(d_k)
scaled.size()

In [None]:
# Reversing *all* dimensions of k (what `.T` does) is not the same as
# transposing only the last two. `.T` on tensors with ndim != 2 is deprecated
# in PyTorch, so the equivalent explicit permute is used instead.
k.permute(*reversed(range(k.ndim))).shape

In [None]:
# Small 2x3 example to illustrate swapping two dimensions.
y = torch.randn(2, 3)
y.transpose(0, 1)

In [None]:
# Naming the dimensions in the opposite order (1, 0) swaps the same pair.
# The transposed result is not kept, so the cell echoes the original y.
y.transpose(1, 0)
y

In [None]:
# Additive look-ahead mask with the same shape as the scores:
# -inf strictly above the diagonal, 0 on and below it.
mask = torch.triu(torch.full(scaled.shape, float('-inf')), diagonal=1)
mask[0, 1]


In [None]:
# Masked scores for the first batch element / first head (not echoed: only a
# cell's last expression is displayed).
(scaled + mask)[0][0]
# Apply the mask to the scores in place.
scaled += mask
# Hand-computed two-way softmax as a sanity check; the constants are
# presumably scores copied from an earlier run. Uses `math.exp` because
# `np` is never imported in this notebook.
math.exp(0.5596) / (math.exp(0.5596) + math.exp(0.0404))
# Normalise each row of scores into attention weights, then mix the values.
attention = F.softmax(scaled, dim=-1)
values = torch.matmul(attention, v)

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension is the feature width d_k.
        k: Key tensor; its last two dimensions are transposed before the
            matrix product with ``q``.
        v: Value tensor, multiplied by the attention weights.
        mask: Optional additive mask (e.g. ``-inf`` at blocked positions and
            ``0`` elsewhere). It only needs to be broadcast-compatible with
            the score tensor.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted values and
        the softmax-normalised attention weights.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: an in-place `+=` raises when the mask broadcasts
        # to a larger shape than the scores (e.g. a batched mask applied to
        # unbatched q/k).
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [None]:
# Recompute the attention output through the helper, passing the look-ahead
# mask built in an earlier cell.
values, attention = scaled_dot_product(q, k, v, mask=mask)

In [None]:
# values is laid out as (batch, heads, seq, head_dim). Move the head axis back
# next to head_dim *before* flattening; reshaping directly would interleave
# different heads and sequence positions within each output row.
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)
values.size()

In [None]:
# Final output projection; it maps d_model features to d_model features.
linear_layer = nn.Linear(in_features=d_model, out_features=d_model)
out = linear_layer(values)
out.size()

In [None]:
# Display the result of the final linear layer.
out

In [None]:

def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    NOTE: an identical function is defined in an earlier cell; it is repeated
    here so this cell can be run on its own.

    Args:
        q: Query tensor; its last dimension is the feature width d_k.
        k: Key tensor; its last two dimensions are transposed before the
            matrix product with ``q``.
        v: Value tensor, multiplied by the attention weights.
        mask: Optional additive mask (e.g. ``-inf`` at blocked positions and
            ``0`` elsewhere). It only needs to be broadcast-compatible with
            the score tensor.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted values and
        the softmax-normalised attention weights.
    """
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: an in-place `+=` raises when the mask broadcasts
        # to a larger shape than the scores.
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    The input is projected once to ``3 * d_model`` features, split across
    ``num_heads`` heads, passed through ``scaled_dot_product``, and the heads
    are then concatenated and sent through a final linear layer.

    Args:
        input_dim: Size of the last dimension of the input.
        d_model: Total feature width across all heads; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise the per-head reshape in forward() fails with an
            # opaque shape error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Fused projection producing query, key and value in one matmul.
        self.qkv_layer = nn.Linear(input_dim , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        """Apply multi-head attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, sequence_length, input_dim)``.
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape ``(batch_size, sequence_length, d_model)``.

        The intermediate tensor sizes are printed at every step.
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        # Give each head its own slice of width 3 * head_dim.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        # (batch, seq, heads, .) -> (batch, heads, seq, .)
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # Move the head axis back next to head_dim before flattening.
        # Reshaping (batch, heads, seq, head_dim) directly would interleave
        # different heads and sequence positions within each output row.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out


In [None]:
# Smoke run of the module on a larger batch whose input width (1024) differs
# from the model width (512).
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn((batch_size, sequence_length, input_dim))

model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself rather than .forward() so nn.Module's hooks run.
out = model(x)