<a href="https://colab.research.google.com/github/yolitie/Deep-Learning/blob/main/Untitled4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#ATTENTION IS ALL YOU NEED
# The architecture has an encoder on the left and a decoder on the right.
# Starting from the bottom, we have some input, say a source text.
# We create embeddings for it, and the embedded input is sent to a multi-head attention layer.
# The result then goes through a normalization, a feed-forward network, and another normalization.
# The decoder block contains a transformer block plus an additional masked multi-head attention and normalization.


In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The last (embedding) dimension of each input is split into ``heads`` chunks
  of size ``head_dim``. Each chunk is projected by the value/key/query linear
  layers; these layers are shared across heads, since they map head_dim -> head_dim.
  Attention is computed independently per head. The heads are then concatenated
  and mixed by ``fc_out``.

  Args:
    embed_size: Size of the embedding dimension of the inputs and the output.
    heads: Number of attention heads (how many parts the embedding is split
      into). Must divide ``embed_size`` exactly.
  """

  def __init__(self, embed_size, heads):
    super().__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size // heads

    # Integer division truncates silently, so reject sizes that don't split evenly.
    assert (self.head_dim * heads == embed_size), "Embed size needs to be divided by heads"

    # Per-head projections; the same weights are applied to every head.
    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

  def forward(self, values, keys, query, mask):
    """Attend from ``query`` over ``keys``/``values``.

    Args:
      values: Tensor reshaped to (N, value_len, heads, head_dim), i.e. expected
        shape (N, value_len, embed_size).
      keys: Tensor of expected shape (N, key_len, embed_size). ``key_len`` must
        equal ``value_len``, because the attention weights over keys are applied
        to the values.
      query: Tensor of expected shape (N, query_len, embed_size). ``query_len``
        may differ from ``key_len`` (cross-attention).
      mask: None, or a tensor broadcastable to (N, heads, query_len, key_len).
        Positions where ``mask == 0`` get a score of -1e20 and therefore
        (effectively) zero attention weight.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    N = query.shape[0]  # batch size: how many examples are processed at once
    value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

    # Split the embedding into self.heads pieces: (N, length, heads, head_dim).
    values = values.reshape(N, value_len, self.heads, self.head_dim)
    keys = keys.reshape(N, key_len, self.heads, self.head_dim)
    # The query must be reshaped with its own length (was key_len, which broke
    # cross-attention when query_len != key_len).
    queries = query.reshape(N, query_len, self.heads, self.head_dim)

    # Apply the learned projections declared in __init__ (previously unused).
    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # energy[n, h, q, k]: score of query position q against key position k for
    # head h, with shape (N, heads, query_len, key_len). For each target word, it
    # says how much to attend to each word of the source sentence.
    energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

    if mask is not None:
      energy = energy.masked_fill(mask == 0, float("-1e20"))

    # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
    # per-head key size (head_dim), not the full embedding size. Softmax is
    # taken over the key axis.
    attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

    # attention: (N, heads, query_len, key_len); values: (N, value_len, heads, head_dim).
    # The einsum gives (N, query_len, heads, head_dim); the last two dims are then flattened.
    out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
    out = out.reshape(N, query_len, self.heads * self.head_dim)
    return self.fc_out(out)

class TransformerBlock(nn.Module):
  """Transformer block built from attention and a feed-forward sublayer.

  Each sublayer is followed by a residual addition, then LayerNorm, then dropout.

  Args:
    embed_size: Embedding dimension of the inputs and the output.
    heads: Number of attention heads passed to SelfAttention.
    dropout: Dropout probability, applied after each normalization.
    forward_expansion: Width of the feed-forward hidden layer, as a multiple of
      embed_size.
  """

  def __init__(self, embed_size, heads, dropout, forward_expansion):
    super().__init__()
    self.attention = SelfAttention(embed_size, heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)

    hidden_size = forward_expansion * embed_size
    self.feed_forward = nn.Sequential(
      nn.Linear(embed_size, hidden_size),
      nn.ReLU(),
      nn.Linear(hidden_size, embed_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query, mask):
    """Run attention and feed-forward, each with residual + LayerNorm + dropout.

    ``value``, ``key``, ``query`` and ``mask`` are forwarded unchanged to
    SelfAttention. ``query`` is also the residual input, so presumably it has
    shape (N, query_len, embed_size); confirm against callers.

    Returns:
      Tensor with the same shape as the attention output.
    """
    attention = self.attention(value, key, query, mask)

    # Residual around attention, then norm1 and dropout. This line previously
    # read `self,norm1(...)` (a comma), which raised NameError for `norm1`.
    x = self.dropout(self.norm1(attention + query))
    ff_out = self.feed_forward(x)
    # Residual around the feed-forward net, then norm2 and dropout.
    out = self.dropout(self.norm2(ff_out + x))
    return out
