In [1]:
import torch
import torch.nn as nn
from imports import *

In [2]:
# Two example sentences: the first is tokenized on its own, then both are
# tokenized together as a batch. tokenize_text / tokenize_batch are not defined
# in this notebook — presumably they come from `from imports import *`; judging
# by the unpacking, each returns (token_ids, embeddings).
input_text1 = 'This is a keyboard'
input_text2 = 'A nice coffee cup'
input_text = [input_text1, input_text2]

# for a single text
token_ids_single, input_embeddings_single = tokenize_text(input_text[0])
print("\nFor a single text:")
print(input_embeddings_single.shape, input_embeddings_single, token_ids_single, sep="\n")

# for batched processing
token_ids, input_embeddings = tokenize_batch(input_text)
print(input_embeddings.shape, input_embeddings, token_ids, sep="\n")


For a single text:
torch.Size([1, 4, 512])
tensor([[[ 1.3528,  0.1661,  0.8340,  ..., -0.2221, -0.8318, -0.7485],
         [-1.0376, -1.7886, -0.6090,  ..., -1.0526, -1.2898,  0.9389],
         [ 2.0893, -0.2985,  1.1456,  ...,  0.3236, -2.9462,  0.0985],
         [-1.3550,  0.2938,  0.8257,  ...,  1.2434,  0.2497,  1.2228]]],
       grad_fn=<EmbeddingBackward0>)
[1212, 318, 257, 10586]
torch.Size([2, 4, 512])
tensor([[[ 1.3528,  0.1661,  0.8340,  ..., -0.2221, -0.8318, -0.7485],
         [-1.0376, -1.7886, -0.6090,  ..., -1.0526, -1.2898,  0.9389],
         [ 2.0893, -0.2985,  1.1456,  ...,  0.3236, -2.9462,  0.0985],
         [-1.3550,  0.2938,  0.8257,  ...,  1.2434,  0.2497,  1.2228]],

        [[-0.2861,  0.8211,  0.6507,  ..., -0.5635, -0.1463, -0.3444],
         [-0.0956,  0.7857,  1.3993,  ...,  0.0037,  1.3891, -0.2660],
         [-0.0971, -0.4761,  0.0094,  ...,  1.1272, -0.1969, -0.2823],
         [ 0.3679,  1.2539, -0.9097,  ...,  0.7282, -0.7757, -1.1468]]],
       grad_f

In [3]:
# Apply positional_encoding (presumably provided by the `imports` star import)
# to the batched embeddings from the previous cell, then display the result.
encoded_embedding = positional_encoding(
    input_embeddings=input_embeddings,
)
encoded_embedding

tensor([[[ 1.3528,  1.1661,  0.8340,  ...,  0.7779, -0.8318,  0.2515],
         [-0.1962, -1.2189,  0.1930,  ..., -0.0526, -1.2898,  1.9389],
         [ 2.9986, -0.6494,  2.1037,  ...,  1.3236, -2.9462,  1.0985],
         [-1.2139, -0.6757,  1.1685,  ...,  2.2434,  0.2497,  2.2228]],

        [[-0.2861,  1.8211,  0.6507,  ...,  0.4365, -0.1463,  0.6556],
         [ 0.7458,  1.3554,  2.2012,  ...,  1.0037,  1.3891,  0.7340],
         [ 0.8122, -0.8270,  0.9675,  ...,  2.1272, -0.1969,  0.7177],
         [ 0.5090,  0.2844, -0.5669,  ...,  1.7282, -0.7757, -0.1468]]],
       grad_fn=<AddBackward0>)

## Moving on to the encoder block

In [16]:
class MultiHeadAttention(nn.Module):
    """Skeleton multi-head attention module (work in progress).

    At this stage it only validates the head configuration at construction
    time and, in ``forward``, passes its input straight through together with
    the sequence length and embedding size read from the input's shape.

    Args:
        d_model: embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads, stored as ``self.heads``.

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads) -> None:
        super().__init__()
        # Every head must receive an equal share of the embedding dimensions.
        evenly_split = d_model % num_heads == 0
        assert evenly_split, f"\nNumber of embedding dimensions is not divisible by the number of heads.\nEmbedding dimensions (d_model): {d_model}, Number of heads: {num_heads}\n{d_model} is not divisible by {num_heads}."
        self.heads = num_heads

    def forward(self, encoded_embedding):
        """Return ``(encoded_embedding, seq_length, d_model)``.

        The input tensor is returned as-is (same object). Its ``.shape`` is
        unpacked into exactly three sizes, so a non-3-D input raises
        ``ValueError``; presumably the layout is (batch, seq, d_model).
        """
        _batch_count, seq_length, d_model = encoded_embedding.shape
        return encoded_embedding, seq_length, d_model

In [17]:
# Demonstrate the divisibility check: with 9 heads, d_model cannot be split
# evenly, so the constructor's assertion fires. The error is caught and shown
# (in the same "AssertionError: <message>" form) so Restart & Run All can
# continue past this cell. d_model is not defined in the visible cells —
# presumably it comes from `from imports import *`.
try:
    mhead = MultiHeadAttention(d_model=d_model, num_heads=9)
except AssertionError as err:
    print(f"AssertionError: {err}")

AssertionError: 
Number of embedding dimensions is not divisible by the number of heads.
Embedding dimensions (d_model): 512, Number of heads: 9
512 is not divisible by 9.

In [18]:
# A valid configuration: 8 heads divides d_model evenly, so the
# constructor's divisibility assertion passes.
multiHead = MultiHeadAttention(
    d_model=d_model,
    num_heads=8,
)

In [30]:
# MultiHeadAttention.forward returns a 3-tuple (batches, seq_length, d_model);
# unpacking four names from it raises ValueError on a fresh kernel. Read the
# batch count from the returned tensor's leading dimension instead.
batches, seq_length, d_model = multiHead(encoded_embedding)
batch_size = batches.shape[0]
print(f"""Batches' shape: {batches.shape}
i.e., {batch_size} batches of {seq_length} tokens each, each token of dimensions {d_model}.

Batches:
{batches}""")

Batches' shape: torch.Size([2, 4, 512])
i.e., 2 batches of 4 tokens each, each token of dimensions 512

Batches:
tensor([[[ 1.3528,  1.1661,  0.8340,  ...,  0.7779, -0.8318,  0.2515],
         [-0.1962, -1.2189,  0.1930,  ..., -0.0526, -1.2898,  1.9389],
         [ 2.9986, -0.6494,  2.1037,  ...,  1.3236, -2.9462,  1.0985],
         [-1.2139, -0.6757,  1.1685,  ...,  2.2434,  0.2497,  2.2228]],

        [[-0.2861,  1.8211,  0.6507,  ...,  0.4365, -0.1463,  0.6556],
         [ 0.7458,  1.3554,  2.2012,  ...,  1.0037,  1.3891,  0.7340],
         [ 0.8122, -0.8270,  0.9675,  ...,  2.1272, -0.1969,  0.7177],
         [ 0.5090,  0.2844, -0.5669,  ...,  1.7282, -0.7757, -0.1468]]],
       grad_fn=<AddBackward0>)
