In [50]:
import torch

from model import ChatModel
from settings import ModelSettings

# Fix torch's global RNG before building the model so any random weight
# initialization -- and, when the notebook is run top to bottom, the random
# token ids and the multinomial sampling further down -- are reproducible
# on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Arguments are positional, so their order must match ChatModel's constructor.
model = ChatModel(
    ModelSettings.vocabulary_size,
    ModelSettings.embedding_size,
    ModelSettings.embedding_dropout,
    ModelSettings.attention_dropout,
    ModelSettings.max_context_length,
    ModelSettings.ff_size_multiplier,
    ModelSettings.ff_dropout,
    ModelSettings.transformer_blocks,
    ModelSettings.attention_heads
)

In [51]:
from torchinfo import summary

# Per-layer parameter counts for the freshly built model (depth=3 is the default).
summary(model=model, depth=3)

Layer (type:depth-idx)                                  Param #
ChatModel                                               --
├─Sequential: 1-1                                       --
│    └─Embedding: 2-1                                   --
│    │    └─Embedding: 3-1                              18,432,000
│    │    └─Embedding: 3-2                              768,000
│    └─Dropout: 2-2                                     --
│    └─Sequential: 2-3                                  --
│    │    └─TransformerBlock: 3-3                       7,085,568
│    │    └─TransformerBlock: 3-4                       7,085,568
│    │    └─TransformerBlock: 3-5                       7,085,568
│    │    └─TransformerBlock: 3-6                       7,085,568
│    │    └─TransformerBlock: 3-7                       7,085,568
│    │    └─TransformerBlock: 3-8                       7,085,568
│    │    └─TransformerBlock: 3-9                       7,085,568
│    │    └─TransformerBlock: 3-10              

In [52]:
import torch
import torch.nn as nn

# Switch to evaluation mode so the Dropout layers listed in the summary are inactive.
model.eval()

# Tiny toy batch used to walk through a single next-token prediction step.
batch_size, context_length = 2, 3
vocabulary_size = ModelSettings.vocabulary_size

In [53]:
# Stand-in for a tokenized batch: uniform random ids in [0, vocabulary_size),
# one row per sequence.
token_ids = torch.randint(
    low=0,
    high=vocabulary_size,
    size=(batch_size, context_length),
)
print(token_ids.shape)

torch.Size([2, 3])


In [54]:
# Crop every sequence to its most recent ModelSettings.max_context_length tokens.
# This is a no-op when the sequences are already that short, as they are here.
token_ids = token_ids[..., -ModelSettings.max_context_length:]
print(token_ids.shape)

torch.Size([2, 3])


In [55]:
# Forward pass. The model returns unnormalized scores (logits) over the whole
# vocabulary at every position of every sequence; the printed shape is
# (batch_size, context_length, vocabulary_size). These are NOT probabilities
# yet -- softmax is applied in a later cell.
# torch.no_grad() skips building the autograd graph, which pure inference does
# not need and which would otherwise keep extra activations in memory.
with torch.no_grad():
    logits = model(token_ids)
print(logits.shape)

torch.Size([2, 3, 24000])


In [56]:
# Only the final position's scores predict the token that follows each
# sequence, so drop all other positions -> (batch_size, vocabulary_size).
# NOTE: this overwrites `logits`; re-running this cell by itself would slice a
# second time and silently give the wrong shape -- re-run the forward pass first.
logits = logits[:, -1]
print(logits.shape)

torch.Size([2, 24000])


In [57]:
# Softmax along the vocabulary axis turns each row of logits into a probability
# distribution over the next token (non-negative entries summing to 1).
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([2, 24000])


In [58]:
# Sample one next-token id per sequence from its probability distribution.
# This is stochastic: it does NOT necessarily pick the most likely token.
# For greedy decoding (always the highest-probability token) use
# torch.argmax(probs, dim=-1, keepdim=True), which has the same (batch_size, 1) shape.
next_id = torch.multinomial(probs, num_samples=1)
print(next_id.shape)

torch.Size([2, 1])
