<a href="https://colab.research.google.com/github/yashika-git/Deep_Learning/blob/main/Transformers_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are linearly projected to queries, keys and values.
  Each projection is split into ``heads`` chunks of size ``head_dim`` along the last axis.
  Every head attends independently, and the concatenated head outputs are
  projected back to ``embed_size`` by ``fc_out``.

  Args:
    embed_size: Size of the last axis of ``values``, ``keys`` and ``query``.
    heads: Number of attention heads; must be a positive divisor of
      ``embed_size``.

  Raises:
    ValueError: If ``heads`` is not a positive divisor of ``embed_size``.
  """

  def __init__(self, embed_size, heads):
    super().__init__()
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    # Checked before dividing so heads == 0 also reports a clear error.
    if heads <= 0 or embed_size % heads != 0:
      raise ValueError(
          f"heads ({heads}) must be a positive divisor of embed_size ({embed_size})"
      )
    self.embed_size = embed_size
    self.heads = heads
    # Integer division: nn.Linear needs int feature sizes.
    self.head_dim = embed_size // heads

    # Project the full embedding (as in the paper) before splitting into
    # heads. Each head then gets its own learned projection of the whole
    # input, instead of one weight shared across per-head slices.
    self.values = nn.Linear(embed_size, embed_size, bias=False)
    self.keys = nn.Linear(embed_size, embed_size, bias=False)
    self.queries = nn.Linear(embed_size, embed_size, bias=False)
    self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

  def forward(self, values, keys, query, mask):
    """Attend from ``query`` over ``keys`` / ``values``.

    Args:
      values: Tensor of shape (N, value_len, embed_size).
      keys: Tensor of shape (N, key_len, embed_size); key_len must equal
        value_len (the second einsum contracts them as one axis).
      query: Tensor of shape (N, query_len, embed_size).
      mask: None, or a tensor broadcastable to
        (N, heads, query_len, key_len). Positions where ``mask == 0``
        receive (effectively) zero attention weight.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    # Number of examples in the batch.
    N = query.shape[0]
    # Source or target sentence lengths; these vary between calls.
    value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

    # Project, then split the embedding into self.heads pieces of head_dim.
    values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
    keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
    queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

    # energy: (N, heads, query_len, key_len)
    energy = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

    # Mask out positions so their softmax weights become (numerically) 0.
    if mask is not None:
      energy = energy.masked_fill(mask == 0, float('-1e20'))

    # Scale by sqrt(d_k), where d_k = head_dim is the size of each per-head
    # dot product, for numerical stability. Then normalize over the key axis
    # so each query's weights sum to 1.
    attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
    # attention: (N, heads, query_len, key_len)

    # attention (N, heads, query_len, key_len) x values (N, value_len, heads, head_dim)
    #   -> (N, query_len, heads, head_dim); then flatten the last two axes.
    out = torch.einsum('nhql,nlhd->nqhd', attention, values).reshape(
        N, query_len, self.heads * self.head_dim
    )

    # Final output shape: (N, query_len, embed_size).
    return self.fc_out(out)



class TransformerBlock(nn.Module):
  """Post-norm Transformer layer: self-attention followed by a feed-forward net.

  Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

  Args:
    embed_size: Model (embedding) dimension.
    heads: Number of attention heads passed to ``SelfAttention``.
    dropout: Dropout probability applied to each sub-layer's output.
    forward_expansion: Hidden width of the feed-forward network, as a
      multiple of ``embed_size``.
  """

  def __init__(self, embed_size, heads, dropout, forward_expansion):
    super().__init__()
    self.attention = SelfAttention(embed_size, heads)

    # LayerNorm rather than BatchNorm: batch statistics of NLP data fluctuate
    # heavily during training, which makes naive BatchNorm unstable. BatchNorm
    # is also awkward to parallelize. LayerNorm(embed_size) normalizes over
    # the last (embedding) axis only, independent of the rest of the batch.
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)

    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, forward_expansion * embed_size),
        nn.ReLU(),
        nn.Linear(forward_expansion * embed_size, embed_size),
    )

    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query, mask):
    """Run attention + feed-forward with residual connections.

    Args:
      value, key, query, mask: Forwarded unchanged to ``SelfAttention``.
        ``query`` is also the residual input, so it presumably has shape
        (N, query_len, embed_size) -- TODO confirm against callers.

    Returns:
      Tensor with the same shape as ``query``.
    """
    attention = self.attention(value, key, query, mask)

    # Residual dropout: drop the sub-layer output *before* the skip connection
    # and normalization, so the residual path itself is never zeroed out.
    x = self.norm1(query + self.dropout(attention))
    ff_out = self.feed_forward(x)
    out = self.norm2(x + self.dropout(ff_out))
    return out



class DecoderBlock(nn.Module):
  """Placeholder for a Transformer decoder layer (not implemented yet)."""

  def __init__(self):
    # nn.Module.__init__ must run before the instance can be used as a module
    # (it sets up the parameter/submodule/hook registries).
    super().__init__()

  def forward(self, *args, **kwargs):
    # TODO: implement the decoder layer. Raise loudly until then instead of
    # silently returning None.
    raise NotImplementedError("DecoderBlock.forward is not implemented yet")



class Decoder(nn.Module):
  """Placeholder for the Transformer decoder stack (not implemented yet)."""

  def __init__(self):
    # nn.Module.__init__ must run before the instance can be used as a module
    # (it sets up the parameter/submodule/hook registries).
    super().__init__()

  def forward(self, *args, **kwargs):
    # TODO: implement the decoder. Raise loudly until then instead of
    # silently returning None.
    raise NotImplementedError("Decoder.forward is not implemented yet")


   
class Transformer(nn.Module):
  """Placeholder for the full Transformer model (not implemented yet)."""

  def __init__(self):
    # nn.Module.__init__ must run before the instance can be used as a module
    # (it sets up the parameter/submodule/hook registries).
    super().__init__()

  def make_src_mask(self, *args, **kwargs):
    # TODO: build the source mask.
    raise NotImplementedError("Transformer.make_src_mask is not implemented yet")

  def make_trg_mask(self, *args, **kwargs):
    # TODO: build the target mask.
    raise NotImplementedError("Transformer.make_trg_mask is not implemented yet")

  # Backward-compatible alias for the original misspelled method name.
  make_trg_maks = make_trg_mask

  def forward(self, *args, **kwargs):
    # TODO: implement the model. Raise loudly until then instead of
    # silently returning None.
    raise NotImplementedError("Transformer.forward is not implemented yet")
    


# Script entry point. The module name is '__main__' (not 'main') when the
# file is run directly, so the original guard never fired.
if __name__ == '__main__':
  # TODO: add a smoke test (e.g. build the model and run a forward pass).
  pass
  



