In [1]:
import torch
import torch.nn as nn

from utils import load_and_preprocess_data

from PositionalEncoding import PositionalEncoding
from MultiHeadAttention import MHA
from AddNorm import AddNorm
from PositionwiseFeedForward import PositionwiseFeedForward

# Load and preprocess the data

In [2]:
data, vocab = load_and_preprocess_data()

# Deduplicate while preserving first-seen order so token ids are contiguous
# (0 .. n-1) and every id is a valid row of the embedding table.
# This is a no-op when `vocab` is already unique.
unique_vocab = list(dict.fromkeys(vocab))

# char -> integer id, and its exact inverse (integer id -> char).
tokenize = {char: index for index, char in enumerate(unique_vocab)}
detokenize = {index: char for char, index in tokenize.items()}

# Encode the corpus; raises KeyError for any element of `data` missing from `vocab`.
# NOTE(review): assumes `data` iterates character by character — confirm against utils.
tokens = [tokenize[char] for char in data]

num_embeddings = len(tokenize)
embedding_dim = 50

# Alias of num_embeddings (number of distinct token ids).
vocab_size = num_embeddings

# Compose the model

In [3]:
class Model(nn.Module):
    """Token-id model: embedding -> positional encoding -> multi-head attention
    -> AddNorm -> position-wise feed-forward.

    Args:
        vocab_size: number of distinct token ids (rows of the embedding table).
        embedding_dim: width shared by every sub-layer (passed as d_model / size).
        max_len: ``max_len`` forwarded to PositionalEncoding (default 50).
        n_heads: number of heads forwarded to MHA (default 5).
            NOTE(review): MHA presumably requires embedding_dim % n_heads == 0 — confirm.
        ffn_mult: feed-forward hidden width as a multiple of embedding_dim (default 6).
    """

    def __init__(self, vocab_size, embedding_dim, max_len=50, n_heads=5, ffn_mult=6):
        super().__init__()
        # Size the table from the constructor argument rather than a notebook-global,
        # so the class works on a fresh kernel and for any vocabulary.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.pos_enc = PositionalEncoding(d_model=embedding_dim, max_len=max_len)
        self.mha = MHA(d_model=embedding_dim, n_heads=n_heads)
        self.addnorm = AddNorm(size=embedding_dim)
        self.ffn = PositionwiseFeedForward(d_model=embedding_dim, d_ff=embedding_dim * ffn_mult)

    def forward(self, x):
        """Run a batch of token ids through the model.

        Args:
            x: integer tensor of token ids, each in ``[0, vocab_size)``.
               Assumed (batch, seq_len) with seq_len <= max_len — TODO confirm
               against PositionalEncoding.
        Returns:
            The output of the position-wise feed-forward sub-layer.
        """
        x = self.embedding(x)
        x = self.pos_enc(x)
        # MHA returns a 2-tuple; only the first element (the attended features) is used.
        x, _attn_weights = self.mha(x)
        # NOTE(review): AddNorm receives only the attention output — the pre-attention
        # tensor is not passed as a residual. Confirm whether AddNorm expects one.
        x = self.addnorm(x)
        return self.ffn(x)
        
        

In [4]:
# Instantiate with keyword arguments so each value visibly maps to its constructor parameter.
model = Model(vocab_size=num_embeddings, embedding_dim=embedding_dim)

In [5]:
# Smoke test: push one random sequence of 7 token ids through the model.
# Ids are drawn from the whole vocabulary [0, vocab_size) instead of a hard-coded
# 0..8, which would index past the embedding table when vocab_size < 9.
# A local seeded generator keeps the sample reproducible without touching global RNG state.
sample_rng = torch.Generator().manual_seed(0)
x = torch.randint(0, vocab_size, (1, 7), generator=sample_rng)
o = model(x)
assert o.shape == (1, 7, embedding_dim), f"unexpected output shape {tuple(o.shape)}"
o.shape

torch.Size([1, 7, 50])