In [15]:
import torch
import torch.nn as nn
import math

In [94]:
# Self Attention!!
class SelfAttention(nn.Module):
  def __init__(self, embed_size, heads):
    super(SelfAttention, self).__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size // heads

    assert (self.head_dim * heads == embed_size), "Embed size needs to be divisble by heads"

    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

  def forward(self, values, keys, query, mask):
    N = query.shape[0]
    value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

    # Split embedding into self.head pieces
    values = values.reshape(N, value_len, self.heads, self.head_dim)
    keys = keys.reshape(N, key_len, self.heads, self.head_dim)
    queries = query.reshape(N, query_len, self.heads, self.head_dim)

    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

    # queries shape: (N, query_len, heads, head_dim)
    # keys shape: (N, key_len, heads, heads_dim)
    # energy shape: (N, heads, query_len, key_len)

    if mask is not None:
      energy = energy.masked_fill(mask == 0, float("-1e20"))

    attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

    out = torch.einsum("nhql, nlhd->nqhd", attention, values).reshape(
        N, query_len, self.heads*self.head_dim)

    # attention shape: (N, heads, query_len, key_len)
    # values shape: (N, value_len, heads, heads_dim)
    # after sinsum (N, query_len, heads, heads_dim) then flatten last two dimensions

    out = self.fc_out(out)
    return out


In [95]:
class PositionEncoding(nn.Module):
  def __init__(self, embed_size: int, seq_len: int, dropout: float) -> None:
    super().__init__()
    self.d_model = embed_size
    self.seq_len = seq_len
    self.dropout = nn.Dropout(dropout)

    # Create a matrix of shape (Seq_len, d_model)
    pe = torch.zeros(seq_len, embed_size)
    # Create a vector of shape
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
    div_term = torch.exp(torch.arange(0,embed_size, 2).float() * (-math.log(10000.0) / embed_size))
    # Apply the sin to even positions
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    pe = pe.unsqueeze(0) # (1, seq_len, d_model)
    self.register_buffer("pe", pe)

  def forward(self, x):
    x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
    return self.dropout(x)

In [96]:
class DecoderModel(nn.Module):
  def __init__(self, vocab_size, embed_size: int, seq_len: int, heads: int,  dropout=0.1,):
    super().__init__()

    self.we = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
    self.pe = PositionEncoding(embed_size=embed_size, seq_len=seq_len, dropout=dropout)
    self.attention = SelfAttention(embed_size=embed_size, heads=heads)
    self.fc_layer = nn.Linear(in_features=embed_size, out_features=vocab_size)

    self.loss = nn.CrossEntropyLoss()

  def forward(self, token_ids):

    word_embeddings = self.we(token_ids)
    print(f"Word Embeddings are:\n{word_embeddings}\n")

    pos_embeddings = self.pe(word_embeddings)
    print(f"Positional Embeddings are:\n{pos_embeddings}\n")

    mask = torch.tril(torch.ones((token_ids.size(dim=0), token_ids.size(dim=0))))
    mask = mask == 0
    print(f"Mask is:\n{mask}\n")

    self_attention_values = self.attention(pos_embeddings,
                                           pos_embeddings,
                                           pos_embeddings,
                                           mask=mask)
    print(f"Self Attention Values are:\n{self_attention_values}\n")

    residual_connection_values = pos_embeddings + self_attention_values
    print(f"Residual Connection Values are:\n{residual_connection_values}\n")
    fc_layer_output = self.fc_layer(residual_connection_values)
    print(f"FC Layer Output is:\n{fc_layer_output}\n")

    return fc_layer_output

In [114]:
import torch

# params
vocab_size = 10
embed_size = 8
seq_len = 5
heads = 2
dropout = 0.1

# Initialize the model
model = DecoderModel(vocab_size=vocab_size, embed_size=embed_size, seq_len=seq_len, heads=heads, dropout=dropout)
print(model)

batch_size = 1
dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len)).squeeze(dim=0)
print(f"Dummy Input is:\n{dummy_input}\n")

output = model(dummy_input)
print(f"output is {output}")


DecoderModel(
  (we): Embedding(10, 8)
  (pe): PositionEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (attention): SelfAttention(
    (values): Linear(in_features=4, out_features=4, bias=False)
    (keys): Linear(in_features=4, out_features=4, bias=False)
    (queries): Linear(in_features=4, out_features=4, bias=False)
    (fc_out): Linear(in_features=8, out_features=8, bias=True)
  )
  (fc_layer): Linear(in_features=8, out_features=10, bias=True)
  (loss): CrossEntropyLoss()
)
Dummy Input is:
tensor([0, 8, 6, 5, 5])

Word Embeddings are:
tensor([[ 0.9019,  1.1199, -0.8971,  0.0770, -0.7473,  0.4240,  0.1549, -0.0558],
        [-0.5997,  0.8358, -0.2280, -0.1980, -0.6875, -2.2745,  1.6513,  1.0456],
        [-1.4433,  0.4923,  0.9570, -0.2433,  1.5549, -0.7820,  1.3593,  1.4620],
        [ 0.9027, -1.6393, -0.4980,  0.1461, -0.3756,  0.4621,  0.2188,  0.1235],
        [ 0.9027, -1.6393, -0.4980,  0.1461, -0.3756,  0.4621,  0.2188,  0.1235]],
       grad_fn=<EmbeddingBackw