## Create the model

In [1]:
import sys

# Put the parent directory on the import path so the project-local packages
# imported by later cells can be resolved (presumably `minbpe` and
# `transformer` live there; verify against the repository layout).
# The membership guard keeps this cell idempotent: re-running it does not
# append duplicate ".." entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
from minbpe import RegexTokenizer

# Build an empty tokenizer, then load its state from the saved model file.
tokenizer_model_path = "../output/tokenizer/darija_tokenizer.model"

tokenizer = RegexTokenizer()
tokenizer.load(model_file=tokenizer_model_path)


def get_vocab_size(tokenizer: RegexTokenizer) -> int:
    """Return the number of vocabulary entries plus the number of special tokens.

    Args:
        tokenizer: Object exposing sized ``vocab`` and ``special_tokens``
            attributes.

    Returns:
        ``len(tokenizer.vocab) + len(tokenizer.special_tokens)``.

    NOTE(review): this assumes the special tokens are not already included in
    ``tokenizer.vocab``; if they are, the result over-counts -- confirm against
    the tokenizer implementation.
    """
    return sum(len(table) for table in (tokenizer.vocab, tokenizer.special_tokens))

In [3]:
import torch

from transformer.best_model_phase_4 import GPTLanguageModel

# --- Sequence / batch settings ---------------------------------------------
# `batch_size` is defined here but is not passed to the model constructor
# in this cell.
block_size = 1024
batch_size = 4

# --- Network width and depth -----------------------------------------------
# `num_kv_heads` is likewise not passed to the model constructor in this cell.
n_embd = 256
n_layer = 4
n_head = 16
num_kv_heads = 4

# --- Sizes derived from the settings above ---------------------------------
head_dim = n_embd // n_head
q_compression_dim = n_embd // 2
kv_compression_dim = 4 * head_dim

# Vocabulary size comes from the loaded tokenizer.
vocab_size = get_vocab_size(tokenizer)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = GPTLanguageModel(
    vocab_size=vocab_size,
    block_size=block_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    device=device,
    q_compression_dim=q_compression_dim,
    kv_compression_dim=kv_compression_dim,
)
model = model.to(device)

# Report the total parameter count in millions.
total_params = sum(param.numel() for param in model.parameters())
print(total_params / 1e6, "M parameters")

11.232266 M parameters


## Calculate tokens/second

In [8]:
import time

# Benchmark configuration.
num_runs = 3  # timed runs that are averaged
warmup_runs = 1  # untimed runs executed first
num_generated_tokens = 1000  # value passed as max_new_tokens to each run

# Encode the prompt and shape it as a single-row batch on the target device.
query = "salam labas"
tokens_input = tokenizer.encode(query, allowed_special="all")
tokens_input = torch.tensor(tokens_input, dtype=torch.long).unsqueeze(0).to(device)


def _wait_for_device() -> None:
    """Block until queued CUDA work has finished; does nothing on CPU.

    CUDA kernels are launched asynchronously, so reading the clock without
    synchronizing can stop the timer before generation has actually finished.
    """
    if tokens_input.is_cuda:
        torch.cuda.synchronize()


throughputs = []

# Benchmark inference: switch to eval mode and disable gradient tracking,
# then restore the previous training flag so later cells see the model in the
# state they left it in. Assumes `model` is a torch.nn.Module (it is moved
# with `.to(device)` and queried via `.parameters()` where it is created).
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        print(f"Performing {warmup_runs} warm-up run(s)...")
        for _ in range(warmup_runs):
            model.advanced_generation(
                input_tokens=tokens_input.clone(), max_new_tokens=num_generated_tokens
            )

        print(f"\nStarting {num_runs} timed runs to measure throughput...\n")
        for i in range(num_runs):
            # Clone outside the timed region so every run (warm-up included)
            # receives a fresh copy of the prompt and the copy is not timed.
            prompt_tokens = tokens_input.clone()

            _wait_for_device()
            start_time = time.perf_counter()  # monotonic, high-resolution clock
            generated_sequence = model.advanced_generation(
                input_tokens=prompt_tokens, max_new_tokens=num_generated_tokens
            )
            _wait_for_device()
            end_time = time.perf_counter()

            duration = end_time - start_time
            if duration > 0:
                # NOTE(review): assumes every run produces exactly
                # `num_generated_tokens` new tokens; if `advanced_generation`
                # can stop early, this overstates throughput -- confirm.
                run_throughput = num_generated_tokens / duration
                throughputs.append(run_throughput)
                print(
                    f"Run {i + 1}/{num_runs} - Time taken: {duration:.2f}s, Throughput: {run_throughput:.2f} tokens/sec"
                )
finally:
    model.train(was_training)

if throughputs:
    avg_throughput = sum(throughputs) / len(throughputs)
    print(f"\nAverage throughput: {avg_throughput:.2f} tokens/sec")

Performing 1 warm-up run(s)...

Starting 3 timed runs to measure throughput...

Run 1/3 - Time taken: 2.14s, Throughput: 466.76 tokens/sec
Run 2/3 - Time taken: 2.14s, Throughput: 467.36 tokens/sec
Run 3/3 - Time taken: 2.16s, Throughput: 462.06 tokens/sec

Average throughput: 465.39 tokens/sec
