In [1]:
#attention func, Attentionhead and Mutilhead attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

In [2]:
def scaled_dot_attention(Q,K,V,mask=None):
    #Query key dot product 
    q_k = torch.bmm(Q,K.mT)
    #d_k for scaling
    d_k = torch.tensor(K.size(-1),dtype=torch.float32)
    scaled_qk = q_k/torch.sqrt(d_k)
    #apply mask before softmax
    if mask is not None:
        scaled_qk = scaled_qk.masked_fill(mask==0,-1e9)
    #softmax
    soft_q_k = F.softmax(scaled_qk,dim=1)
    #multiply weights with values
    w_v = torch.bmm(soft_q_k,V)
    return w_v

In [3]:
class AttentionHead(nn.Module):
    def __init__(self,embedding_dim,head_dim):
        super(AttentionHead,self).__init__()
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        self.Q = nn.Linear(embedding_dim,head_dim)
        self.K = nn.Linear(embedding_dim,head_dim)
        self.V = nn.Linear(embedding_dim,head_dim)
    def forward(self,x,mask=None):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        attention = scaled_dot_attention(q,k,v,mask)
        return attention 

In [4]:
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self,embedding_dim,num_heads):
        super(MaskedMultiHeadAttention,self).__init__()
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim//num_heads
        self.heads = nn.ModuleList([AttentionHead(self.embedding_dim,self.head_dim) for _ in range(self.num_heads)])
        self.Wo = nn.Linear(embedding_dim,embedding_dim)
    def forward(self,x,mask=None):
        scores = []
        for head in self.heads:
            scores.append(head(x,mask))
        scores = torch.cat(scores,2)
        attention_representation = self.Wo(scores)
        return attention_representation

In [5]:
class CrossAttentionHead(nn.Module):
    def __init__(self,embedding_dim,head_dim):
        super(CrossAttentionHead,self).__init__()
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        self.Q = nn.Linear(embedding_dim,head_dim)
        self.K = nn.Linear(embedding_dim,head_dim)
        self.V = nn.Linear(embedding_dim,head_dim)
    def forward(self,x_in,x_out,mask=None):
        q = self.Q(x_out)
        k = self.K(x_in)
        v = self.V(x_in)
        attention = scaled_dot_attention(q,k,v,mask)
        return attention 

In [6]:
class CrossAttention(nn.Module):
    def __init__(self,embedding_dim,num_heads):
        super(CrossAttention,self).__init__()
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim//num_heads
        self.heads = nn.ModuleList([CrossAttentionHead(self.embedding_dim,self.head_dim) for _ in range(self.num_heads)])
        self.Wo = nn.Linear(embedding_dim,embedding_dim)
    def forward(self,x_in,x_out,mask=None):
        scores = []
        for head in self.heads:
            scores.append(head(x_in,x_out,mask))
        scores = torch.cat(scores,2)
        attention_representation = self.Wo(scores)
        return attention_representation

In [7]:
class FeedForwardNetwork(nn.Module):
    def __init__(self,embedding_dim,p=0.3):
        super(FeedForwardNetwork,self).__init__()
        self.embedding_dim = embedding_dim
        self.d_ff = embedding_dim*4
        self.dropout = nn.Dropout(p)
        self.linear1 = nn.Linear(embedding_dim,self.d_ff)
        self.linear2 = nn.Linear(self.d_ff,embedding_dim)
    def forward(self,x):
        x = self.linear1(x)
        x = F.gelu(self.dropout(x))
        x = self.linear2(x)
        return x       

### Embedding

In [8]:
#using learnable positional embedding
class PosEmbedding(nn.Module):
    def __init__(self,vocab_size,max_position,hidden_dim,p=0.2):
        super(PosEmbedding,self).__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size,self.hidden_dim)
        self.pos_embedding = nn.Embedding(max_position,self.hidden_dim)
        self.layernorm = nn.LayerNorm(hidden_dim)
    def forward(self,input_ids):
        seq_len = len(input_ids)
        positions = torch.arange(0, seq_len, dtype=torch.long).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        token_emb = self.embedding(input_ids)
        emb = self.layernorm(token_emb+pos_emb)
        return emb   

### Decoder

In [15]:
#using post layer norm
class Decoder(nn.Module):
    def __init__(self,hidden_dim,num_heads):
        super(Decoder,self).__init__()
        self.hidden_dim = hidden_dim
        self.feedforward = FeedForwardNetwork(hidden_dim)
        self.maskedmultihead = MaskedMultiHeadAttention(hidden_dim,num_heads)
        self.multiheadcross = CrossAttention(hidden_dim,num_heads)
        self.layernorm_maskedattention = nn.LayerNorm(hidden_dim)
        self.layernorm_crossattention = nn.LayerNorm(hidden_dim)
        self.layernorm_feedforward = nn.LayerNorm(hidden_dim)
    def forward(self,emb_in,emb_out):
        att_emb = self.maskedmultihead(emb_out)
        emb_out = self.layernorm_maskedattention(att_emb + emb_out)
        cross_emb = self.multiheadcross(emb_in,emb_out)
        emb = self.layernorm_crossattention(cross_emb+emb_out)
        emb = self.layernorm_feedforward(emb + self.feedforward(emb))
        return emb

In [16]:
decoder = Decoder(10,2),

In [17]:
def process_text(text_in,text_out):
    text_in = '[CLS]' + text_in + '[SEP]'
    text_out = '[CLS]' + text_out + '[SEP]'
    input_ids = [vocab.get(token,vocab['[UNK]']) for token in text_in.lower().split()]
    out_ids = [vocab.get(token,vocab['[UNK]']) for token in text_out.lower().split()]
    emb_in = embedding_layer(torch.tensor(input_ids))
    emb_out = embedding_layer(torch.tensor(out_ids))
    final = decoder(emb_in,emb_out)
    return final

In [18]:
vocab = {'[MASK]':0,'[PAD]':1,'[SEP]':2,'[CLS]':3,'hello':4,'how':5,'are':6,'you':7,'adi':8,'i':9,'am':10,'[UNK]':11,',':12,'.':13,'?':14}
vocab_size = len(vocab)
hidden_dim = 10
max_pos = 10
embedding_layer = PosEmbedding(vocab_size,max_pos,hidden_dim)
decoder = Decoder(10,2)
text = "Hello how are you ? I Am Adi"
text2 = "are you adi?"
text_embeddings = process_text(text,text2)

In [19]:
text_embeddings

tensor([[[-1.2091,  1.9794,  0.7957, -0.6276,  0.5124,  1.0756, -0.0389,
          -0.8240, -0.5982, -1.0654],
         [ 1.8392, -1.0725,  0.1597,  0.3007, -1.6764,  1.2864,  0.2637,
          -0.2638,  0.0266, -0.8636],
         [-1.9224,  1.4546,  0.7938,  0.7233,  0.7088, -0.2325, -0.2617,
           0.3881, -0.1569, -1.4950]]], grad_fn=<NativeLayerNormBackward0>)

In [23]:
torch.iinfo(torch.int64)

iinfo(min=-9.22337e+18, max=9.22337e+18, dtype=int64)

In [50]:
torch.finfo(torch.float)

finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16)

In [70]:
x = torch.tensor([52131437818381],dtype=torch.float64)
print(f"{float(x):.15f}")
y = torch.tensor(x,dtype=torch.bfloat16)
print(f"{float(y):.15f}")

52131437818381.000000000000000
52226802319360.000000000000000


  y = torch.tensor(x,dtype=torch.bfloat16)


In [74]:
z = x-y
print(f"Quantization error: {float(z):.15f}")

Quantization error: -95364500979.000000000000000


In [76]:
def get_tensor_size(tensor):
    bytes = tensor.element_size() * tensor.numel()
    kb = bytes / 1024
    mb = kb / 1024
    return f"Size: {bytes} bytes, {kb:.2f} KB, {mb:.2f} MB"

In [78]:
get_tensor_size(y)

'Size: 2 bytes, 0.00 KB, 0.00 MB'

In [82]:
print(f"Size : {x.element_size() * x.numel()} bytes")

Size : 8 bytes


In [84]:
a = torch.tensor([3],dtype=torch.int8)
print(f"Size : {a.element_size() * a.numel()} bytes")

Size : 1 bytes


In [85]:
a = torch.tensor([3],dtype=torch.int16)
print(f"Size : {a.element_size() * a.numel()} bytes")

Size : 2 bytes


In [86]:
a = torch.tensor([3],dtype=torch.int32)
print(f"Size : {a.element_size() * a.numel()} bytes")

Size : 4 bytes


In [87]:
a = torch.tensor([3],dtype=torch.int64)
print(f"Size : {a.element_size() * a.numel()} bytes")

Size : 8 bytes


In [107]:
tensor_fp32 = torch.randn(10000,dtype=torch.float32)

In [108]:
tensor_fp32[:4]

tensor([ 0.2567,  1.2481,  1.0715, -1.3943])

In [109]:
tensor_fp16 = tensor_fp32.to(dtype=torch.float16)

In [110]:
tensor_fp16[:4]

tensor([ 0.2568,  1.2480,  1.0713, -1.3945], dtype=torch.float16)

In [111]:
mm_32 = torch.dot(tensor_fp32,tensor_fp32)

In [112]:
mm_16 = torch.dot(tensor_fp16,tensor_fp16)

In [113]:
mm_32

tensor(10070.8564)

In [114]:
mm_16

tensor(10072., dtype=torch.float16)

In [115]:
torch.abs(mm_32-mm_16)

tensor(1.1436)

In [116]:
mm_32.element_size()

4

In [117]:
mm_16.element_size()

2

In [129]:
decoder

Decoder(
  (feedforward): FeedForwardNetwork(
    (dropout): Dropout(p=0.3, inplace=False)
    (linear1): Linear(in_features=10, out_features=40, bias=True)
    (linear2): Linear(in_features=40, out_features=10, bias=True)
  )
  (maskedmultihead): MaskedMultiHeadAttention(
    (heads): ModuleList(
      (0-1): 2 x AttentionHead(
        (Q): Linear(in_features=10, out_features=5, bias=True)
        (K): Linear(in_features=10, out_features=5, bias=True)
        (V): Linear(in_features=10, out_features=5, bias=True)
      )
    )
    (Wo): Linear(in_features=10, out_features=10, bias=True)
  )
  (multiheadcross): CrossAttention(
    (heads): ModuleList(
      (0-1): 2 x CrossAttentionHead(
        (Q): Linear(in_features=10, out_features=5, bias=True)
        (K): Linear(in_features=10, out_features=5, bias=True)
        (V): Linear(in_features=10, out_features=5, bias=True)
      )
    )
    (Wo): Linear(in_features=10, out_features=10, bias=True)
  )
  (layernorm_maskedattention): Laye

In [121]:
for name, param in decoder.named_parameters():
    print(f"Name : {name} Param : {param.dtype}")

Name : feedforward.linear1.weight Param : torch.float32
Name : feedforward.linear1.bias Param : torch.float32
Name : feedforward.linear2.weight Param : torch.float32
Name : feedforward.linear2.bias Param : torch.float32
Name : maskedmultihead.heads.0.Q.weight Param : torch.float32
Name : maskedmultihead.heads.0.Q.bias Param : torch.float32
Name : maskedmultihead.heads.0.K.weight Param : torch.float32
Name : maskedmultihead.heads.0.K.bias Param : torch.float32
Name : maskedmultihead.heads.0.V.weight Param : torch.float32
Name : maskedmultihead.heads.0.V.bias Param : torch.float32
Name : maskedmultihead.heads.1.Q.weight Param : torch.float32
Name : maskedmultihead.heads.1.Q.bias Param : torch.float32
Name : maskedmultihead.heads.1.K.weight Param : torch.float32
Name : maskedmultihead.heads.1.K.bias Param : torch.float32
Name : maskedmultihead.heads.1.V.weight Param : torch.float32
Name : maskedmultihead.heads.1.V.bias Param : torch.float32
Name : maskedmultihead.Wo.weight Param : torch.f

In [126]:
decoder_64 = Decoder(10,2).double()

In [127]:
for name, param in decoder_64.named_parameters():
    print(f"Name : {name} Param : {param.dtype}")

Name : feedforward.linear1.weight Param : torch.float64
Name : feedforward.linear1.bias Param : torch.float64
Name : feedforward.linear2.weight Param : torch.float64
Name : feedforward.linear2.bias Param : torch.float64
Name : maskedmultihead.heads.0.Q.weight Param : torch.float64
Name : maskedmultihead.heads.0.Q.bias Param : torch.float64
Name : maskedmultihead.heads.0.K.weight Param : torch.float64
Name : maskedmultihead.heads.0.K.bias Param : torch.float64
Name : maskedmultihead.heads.0.V.weight Param : torch.float64
Name : maskedmultihead.heads.0.V.bias Param : torch.float64
Name : maskedmultihead.heads.1.Q.weight Param : torch.float64
Name : maskedmultihead.heads.1.Q.bias Param : torch.float64
Name : maskedmultihead.heads.1.K.weight Param : torch.float64
Name : maskedmultihead.heads.1.K.bias Param : torch.float64
Name : maskedmultihead.heads.1.V.weight Param : torch.float64
Name : maskedmultihead.heads.1.V.bias Param : torch.float64
Name : maskedmultihead.Wo.weight Param : torch.f

In [128]:
decoder_16 = Decoder(10,2).half()
for name, param in decoder_16.named_parameters():
    print(f"Name : {name} Param : {param.dtype}")

Name : feedforward.linear1.weight Param : torch.float16
Name : feedforward.linear1.bias Param : torch.float16
Name : feedforward.linear2.weight Param : torch.float16
Name : feedforward.linear2.bias Param : torch.float16
Name : maskedmultihead.heads.0.Q.weight Param : torch.float16
Name : maskedmultihead.heads.0.Q.bias Param : torch.float16
Name : maskedmultihead.heads.0.K.weight Param : torch.float16
Name : maskedmultihead.heads.0.K.bias Param : torch.float16
Name : maskedmultihead.heads.0.V.weight Param : torch.float16
Name : maskedmultihead.heads.0.V.bias Param : torch.float16
Name : maskedmultihead.heads.1.Q.weight Param : torch.float16
Name : maskedmultihead.heads.1.Q.bias Param : torch.float16
Name : maskedmultihead.heads.1.K.weight Param : torch.float16
Name : maskedmultihead.heads.1.K.bias Param : torch.float16
Name : maskedmultihead.heads.1.V.weight Param : torch.float16
Name : maskedmultihead.heads.1.V.bias Param : torch.float16
Name : maskedmultihead.Wo.weight Param : torch.f

In [130]:
decoder_16

Decoder(
  (feedforward): FeedForwardNetwork(
    (dropout): Dropout(p=0.3, inplace=False)
    (linear1): Linear(in_features=10, out_features=40, bias=True)
    (linear2): Linear(in_features=40, out_features=10, bias=True)
  )
  (maskedmultihead): MaskedMultiHeadAttention(
    (heads): ModuleList(
      (0-1): 2 x AttentionHead(
        (Q): Linear(in_features=10, out_features=5, bias=True)
        (K): Linear(in_features=10, out_features=5, bias=True)
        (V): Linear(in_features=10, out_features=5, bias=True)
      )
    )
    (Wo): Linear(in_features=10, out_features=10, bias=True)
  )
  (multiheadcross): CrossAttention(
    (heads): ModuleList(
      (0-1): 2 x CrossAttentionHead(
        (Q): Linear(in_features=10, out_features=5, bias=True)
        (K): Linear(in_features=10, out_features=5, bias=True)
        (V): Linear(in_features=10, out_features=5, bias=True)
      )
    )
    (Wo): Linear(in_features=10, out_features=10, bias=True)
  )
  (layernorm_maskedattention): Laye

In [133]:
decoder_bf16 = Decoder(10,2).bfloat16()
for name, param in decoder_bf16.named_parameters():
    print(f"Name : {name} Param : {param.dtype}")

Name : feedforward.linear1.weight Param : torch.bfloat16
Name : feedforward.linear1.bias Param : torch.bfloat16
Name : feedforward.linear2.weight Param : torch.bfloat16
Name : feedforward.linear2.bias Param : torch.bfloat16
Name : maskedmultihead.heads.0.Q.weight Param : torch.bfloat16
Name : maskedmultihead.heads.0.Q.bias Param : torch.bfloat16
Name : maskedmultihead.heads.0.K.weight Param : torch.bfloat16
Name : maskedmultihead.heads.0.K.bias Param : torch.bfloat16
Name : maskedmultihead.heads.0.V.weight Param : torch.bfloat16
Name : maskedmultihead.heads.0.V.bias Param : torch.bfloat16
Name : maskedmultihead.heads.1.Q.weight Param : torch.bfloat16
Name : maskedmultihead.heads.1.Q.bias Param : torch.bfloat16
Name : maskedmultihead.heads.1.K.weight Param : torch.bfloat16
Name : maskedmultihead.heads.1.K.bias Param : torch.bfloat16
Name : maskedmultihead.heads.1.V.weight Param : torch.bfloat16
Name : maskedmultihead.heads.1.V.bias Param : torch.bfloat16
Name : maskedmultihead.Wo.weight

### diff in time

In [134]:
def process_text(text_in,text_out):
    text_in = '[CLS]' + text_in + '[SEP]'
    text_out = '[CLS]' + text_out + '[SEP]'
    input_ids = [vocab.get(token,vocab['[UNK]']) for token in text_in.lower().split()]
    out_ids = [vocab.get(token,vocab['[UNK]']) for token in text_out.lower().split()]
    emb_in = embedding_layer(torch.tensor(input_ids))
    emb_out = embedding_layer(torch.tensor(out_ids))
    return emb_in,emb_out

In [161]:
emb_in,emb_out = process_text(text,text2)

In [163]:
total = 0
for i in range(1000):
    start = time.time()
    out = decoder(emb_in,emb_out)
    end = time.time()
    total = end-start
print(f"{total/1000:.15f}")

0.000000169038773


In [170]:
emb_in = emb_in.to(torch.float64)
emb_out = emb_out.to(torch.float64)

In [171]:
total = 0
for i in range(1000):
    start = time.time()
    out = decoder_64(emb_in,emb_out)
    end = time.time()
    total = end-start
print(f"{total/1000:.15f}")

0.000000184059143


In [169]:
0.000000182151794 - 0.000000169038773

1.311302099999999e-08