## Installing the packages

In [1]:
!pip3 install bertviz

Collecting bertviz
  Downloading bertviz-1.4.0-py3-none-any.whl (157 kB)
     ---------------------------------------- 0.0/157.6 kB ? eta -:--:--
     -------------------------------------- 157.6/157.6 kB 9.8 MB/s eta 0:00:00
Collecting boto3 (from bertviz)
  Obtaining dependency information for boto3 from https://files.pythonhosted.org/packages/63/e5/8fc4a69186cb15b0dba9c428da73233c89eb18ee03ce56f6bde205ea2006/boto3-1.28.62-py3-none-any.whl.metadata
  Downloading boto3-1.28.62-py3-none-any.whl.metadata (6.7 kB)
Collecting sentencepiece (from bertviz)
  Downloading sentencepiece-0.1.99-cp39-cp39-win_amd64.whl (977 kB)
     ---------------------------------------- 0.0/977.6 kB ? eta -:--:--
     ------------------------------------- 977.6/977.6 kB 21.0 MB/s eta 0:00:00
Collecting botocore<1.32.0,>=1.31.62 (from boto3->bertviz)
  Obtaining dependency information for botocore<1.32.0,>=1.31.62 from https://files.pythonhosted.org/packages/a8/3f/74138007b045447eac6141c8144efe8e1c9f377cf56c85

## Code

In [9]:
# Example sentence that is tokenized and embedded further down.
text = "Times flies like an arrow"
# Checkpoint name passed to AutoConfig / AutoTokenizer.from_pretrained.
model_ckpt = "bert-base-uncased"

In [None]:
from transformers import AutoConfig

# Load the configuration of the checkpoint named by `model_ckpt`; its
# `vocab_size` and `hidden_size` are used to size the embedding layer below.
config = AutoConfig.from_pretrained(model_ckpt)

In [7]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the checkpoint named by `model_ckpt`.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

Downloading (…)okenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 7.00kB/s]
Downloading (…)solve/main/vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.12MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 1.51MB/s]


In [14]:
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    """Compute scaled dot-product attention.

    The attention scores are ``query @ key^T / sqrt(dim_k)``, where ``dim_k``
    is the size of the last dimension of ``key``. They are normalised with a
    softmax over the last axis and used as weights for a sum over ``value``.

    ``torch.matmul`` is used instead of ``torch.bmm`` so that, besides the
    3-D tensors ``bmm`` is restricted to, inputs with additional leading
    dimensions (for example ``(batch, heads, seq, dim)``) or without a batch
    dimension are accepted too. For 3-D inputs the result is unchanged.

    Args:
        query: Tensor of shape ``(..., seq_q, dim_k)``.
        key: Tensor of shape ``(..., seq_k, dim_k)``.
        value: Tensor of shape ``(..., seq_k, dim_v)``.

    Returns:
        Tensor of shape ``(..., seq_q, dim_v)``.
    """
    dim_k = key.size(-1)
    # Dividing by sqrt(dim_k) rescales the dot products before the softmax.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [35]:
from torch import nn

class AttentionHead(nn.Module):
    """A single self-attention head.

    The input is passed through three separate linear layers to obtain
    queries, keys and values, which are then combined by
    ``scaled_dot_product_attention``.
    """

    def __init__(self, dim_emb, dim_k, dim_v):
        """Create the query, key and value projections.

        Args:
            dim_emb: Input feature size of all three linear layers.
            dim_k: Output feature size of the query and key layers.
            dim_v: Output feature size of the value layer.
        """
        super().__init__()
        # The layers are created in q, k, v order so that their random
        # initialisation consumes the RNG in the same sequence as before.
        self.q_linear = nn.Linear(dim_emb, dim_k)
        self.k_linear = nn.Linear(dim_emb, dim_k)
        self.v_linear = nn.Linear(dim_emb, dim_v)

    def forward(self, x):
        """Project ``x`` to queries, keys and values and attend over them."""
        return scaled_dot_product_attention(
            self.q_linear(x),
            self.k_linear(x),
            self.v_linear(x),
        )

In [36]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead``s.

    Every head receives the same input; the head outputs are concatenated
    along the last dimension and passed through a final linear layer whose
    output size is ``dim_emb``.
    """

    def __init__(self, num_heads, dim_emb, dim_k, dim_v):
        """Create ``num_heads`` attention heads and the output projection.

        Args:
            num_heads: Number of ``AttentionHead`` modules to create.
            dim_emb: Input feature size of each head and output feature size
                of the final linear layer.
            dim_k: Query/key size handed to each head.
            dim_v: Value size handed to each head.
        """
        super().__init__()
        self.dim_emb = dim_emb
        self.num_heads = num_heads
        self.dim_k = dim_k
        self.dim_v = dim_v
        # Heads are built before the output layer, matching the original
        # parameter creation order.
        self.heads = nn.ModuleList(
            [AttentionHead(dim_emb, dim_k, dim_v) for _ in range(num_heads)]
        )
        # The concatenated head outputs have num_heads * dim_v features.
        self.linear = nn.Linear(num_heads * dim_v, dim_emb)

    def forward(self, x):
        """Apply every head to ``x``, concatenate, and project to ``dim_emb``."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.linear(concatenated)

In [34]:
from torch import nn

# Tokenize the example sentence into PyTorch tensors, with special tokens
# switched off.
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)

# Freshly initialised embedding table sized from the model configuration.
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)

# Look up one embedding vector per input id and display the resulting shape.
token_embs = token_emb(inputs["input_ids"])
token_embs.shape

torch.Size([1, 5, 768])

In [37]:
# Keyword arguments make the four sizes self-describing.
multiHeadAttention = MultiHeadAttention(num_heads=4, dim_emb=768, dim_k=11, dim_v=3)

# Call the module instance rather than `.forward()` directly: going through
# `nn.Module.__call__` also runs any registered hooks, which a direct
# `.forward()` call silently skips.
multiHeadAttention(token_embs).shape

torch.Size([1, 5, 768])