In [1]:
import torch
import torch.nn as nn
from imports import *

In [2]:
# Two example sentences used throughout the notebook.
input_text1 = "This is a keyboard"
input_text2 = "A nice coffee cup"
input_text = [input_text1, input_text2]

# Single text: tokenize_text yields (token ids, embeddings) for one string.
token_ids_single, input_embeddings_single = tokenize_text(input_text[0])
print("\nFor a single text:")
print(input_embeddings_single.shape)
print(input_embeddings_single)
print(token_ids_single)

# Batched processing: tokenize_batch yields (token ids, embeddings) for a list of strings.
# Labelled like the single-text section so the two outputs can be told apart.
token_ids, input_embeddings = tokenize_batch(input_text)
print("\nFor a batch of texts:")
print(input_embeddings.shape)
print(input_embeddings)
print(token_ids)


For a single text:
torch.Size([1, 4, 512])
tensor([[[ 0.8917,  1.3933, -0.8909,  ..., -1.2700, -1.5286, -1.7189],
         [-0.6776, -0.3867, -0.1106,  ..., -2.6652, -0.0245,  0.3272],
         [-0.5232, -0.2722,  1.0417,  ..., -0.0192, -1.4332,  1.5941],
         [-0.7227,  1.5541,  0.2964,  ...,  0.0488,  0.0632,  0.7642]]],
       grad_fn=<EmbeddingBackward0>)
[1212, 318, 257, 10586]
torch.Size([2, 4, 512])
tensor([[[ 0.8917,  1.3933, -0.8909,  ..., -1.2700, -1.5286, -1.7189],
         [-0.6776, -0.3867, -0.1106,  ..., -2.6652, -0.0245,  0.3272],
         [-0.5232, -0.2722,  1.0417,  ..., -0.0192, -1.4332,  1.5941],
         [-0.7227,  1.5541,  0.2964,  ...,  0.0488,  0.0632,  0.7642]],

        [[ 0.2053, -0.6716,  1.1073,  ..., -1.2991, -0.6190, -0.0838],
         [ 0.3256, -0.3149, -0.1741,  ..., -0.1790,  1.5053,  1.8777],
         [ 0.7934,  1.6292,  0.5403,  ...,  0.1618, -0.0089, -0.6414],
         [ 0.1415, -1.4723,  0.5583,  ...,  1.0058,  0.8063,  0.1839]]],
       grad_f

In [3]:
# Apply positional encoding to the batched token embeddings (`input_embeddings`).
# NOTE(review): the displayed grad_fn=<AddBackward0> suggests the encoding is
# added element-wise to the embeddings — confirm in `positional_encoding`.
encoded_embedding = positional_encoding(input_embeddings=input_embeddings)
encoded_embedding

tensor([[[ 0.8917,  2.3933, -0.8909,  ..., -0.2700, -1.5286, -0.7189],
         [ 0.1638,  0.1830,  0.6913,  ..., -1.6652, -0.0245,  1.3272],
         [ 0.3861, -0.6231,  1.9998,  ...,  0.9808, -1.4332,  2.5941],
         [-0.5816,  0.5846,  0.6392,  ...,  1.0488,  0.0632,  1.7642]],

        [[ 0.2053,  0.3284,  1.1073,  ..., -0.2991, -0.6190,  0.9162],
         [ 1.1671,  0.2548,  0.6279,  ...,  0.8210,  1.5053,  2.8777],
         [ 1.7027,  1.2783,  1.4984,  ...,  1.1618, -0.0089,  0.3586],
         [ 0.2827, -2.4418,  0.9011,  ...,  2.0058,  0.8063,  1.1839]]],
       grad_fn=<AddBackward0>)

## Moving on to the encoder block

In [4]:
class MultiHeadAttentionV1(nn.Module):
    """First scaffold of a multi-head attention layer.

    At this stage the layer only validates its configuration and echoes its
    input back together with its dimensions; no attention is computed yet.

    Args:
        d_model: Embedding size of each token. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads. Must be a positive integer.

    Raises:
        AssertionError: If ``num_heads`` is not positive, or ``d_model`` is not
            divisible by ``num_heads``. Raised explicitly rather than via an
            ``assert`` statement so the check still runs under ``python -O``.
    """

    def __init__(self, d_model=512, num_heads=8) -> None:
        super().__init__()
        # Checked first: 0 would otherwise raise ZeroDivisionError in the modulo
        # below, and a negative value passes it (e.g. 512 % -8 == 0).
        if num_heads <= 0:
            raise AssertionError(
                f"\nNumber of heads must be a positive integer, got {num_heads}."
            )
        if d_model % num_heads != 0:
            raise AssertionError(
                f"\nNumber of embedding dimensions is not divisible by the number of heads.\nEmbedding dimensions (d_model): {d_model}, Number of heads: {num_heads}\n{d_model} is not divisible by {num_heads}."
            )
        self.heads = num_heads
        # Kept for the later per-head split; forward() does not use them yet.
        self.d_model = d_model
        self.head_dim = d_model // num_heads

    def forward(self, encoded_embedding):
        """Return the input unchanged together with its three dimension sizes.

        Args:
            encoded_embedding: 3-D tensor, unpacked as
                (batch_size, seq_length, d_model).

        Returns:
            Tuple ``(batches, batch_size, seq_length, d_model)`` where
            ``batches`` is ``encoded_embedding`` itself (same object, no copy).

        Raises:
            ValueError: If ``encoded_embedding`` is not 3-D.
        """
        if len(encoded_embedding.shape) != 3:
            raise ValueError(
                "Expected a 3-D input (batch_size, seq_length, d_model), "
                f"got shape {tuple(encoded_embedding.shape)}."
            )
        batch_size, seq_length, d_model = encoded_embedding.shape
        batches = encoded_embedding
        return batches, batch_size, seq_length, d_model

In [5]:
# Deliberately invalid configuration: num_heads=9 does not divide d_model, so
# construction is expected to fail. The error is caught and shown so that
# "Restart Kernel -> Run All" does not stop at this cell.
try:
    mheadV1 = MultiHeadAttentionV1(d_model=d_model, num_heads=9)
except AssertionError as err:
    print(f"AssertionError: {err}")

AssertionError: 
Number of embedding dimensions is not divisible by the number of heads.
Embedding dimensions (d_model): 512, Number of heads: 9
512 is not divisible by 9.

In [6]:
# Valid configuration: num_heads=8 divides d_model, so construction succeeds.
# NOTE(review): `d_model` is presumably provided by `from imports import *` — confirm.
multiHeadV1 = MultiHeadAttentionV1(d_model=d_model, num_heads=8)

In [7]:
# Unpack the embedding size into `embed_dim` rather than `d_model`, so running
# this cell does not overwrite the `d_model` config value that the constructor
# cells read.
batches, batch_size, seq_length, embed_dim = multiHeadV1(encoded_embedding)
print(
    f"""Batches' shape: {batches.shape}
i.e., a batch of {batch_size} sequences of {seq_length} tokens each, each token of dimensions {embed_dim}.

Batches:
{batches}"""
)

Batches' shape: torch.Size([2, 4, 512])
i.e., 2 batches of 4 tokens each, each token of dimensions 512.

Batches:
tensor([[[ 0.8917,  2.3933, -0.8909,  ..., -0.2700, -1.5286, -0.7189],
         [ 0.1638,  0.1830,  0.6913,  ..., -1.6652, -0.0245,  1.3272],
         [ 0.3861, -0.6231,  1.9998,  ...,  0.9808, -1.4332,  2.5941],
         [-0.5816,  0.5846,  0.6392,  ...,  1.0488,  0.0632,  1.7642]],

        [[ 0.2053,  0.3284,  1.1073,  ..., -0.2991, -0.6190,  0.9162],
         [ 1.1671,  0.2548,  0.6279,  ...,  0.8210,  1.5053,  2.8777],
         [ 1.7027,  1.2783,  1.4984,  ...,  1.1618, -0.0089,  0.3586],
         [ 0.2827, -2.4418,  0.9011,  ...,  2.0058,  0.8063,  1.1839]]],
       grad_fn=<AddBackward0>)
