In [1]:
import torch
import torch.nn.functional as F
import re
from pathlib import Path
from transformers import PreTrainedTokenizerFast
from models import NextByteTransformer

# --- Model hyperparameters ---------------------------------------------------
# Passed straight to NextByteTransformer below, so they must describe the same
# architecture as the checkpoint that is loaded afterwards.
context_length = 768
d_model = 768
num_heads = 12
num_hidden_layers = 12
d_hidden = 3072
num_decoders = 4
# Training-time settings; nothing in this cell reads them.
num_epochs = 12
lr = 1e-4
batch_size = 32

# Alternative (smaller) configuration, kept for reference:
# context_length = 512
# d_model = 512
# num_heads = 8
# num_hidden_layers = 8
# d_hidden = 2048
# num_decoders = 2
# num_epochs = 8
# lr = 3e-5
# batch_size = 16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generation samples with torch.multinomial, which draws from torch's global
# RNG. Seed it once here so a fresh "Restart & Run All" reproduces the output.
seed = 42
torch.manual_seed(seed)

tokenizer_path = Path("Tokenizers/title_to_all_tokenizer")
hf_tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)

# Use "<end>" as the end-of-sequence marker. Fail loudly if the tokenizer does
# not know the token: otherwise the EOS id would silently be None / the unknown
# token id and generation would never stop on the intended marker.
end_token_id = hf_tokenizer.convert_tokens_to_ids("<end>")
if end_token_id is None or end_token_id == hf_tokenizer.unk_token_id:
    raise ValueError(f'Token "<end>" not found in tokenizer at {tokenizer_path}')
hf_tokenizer.eos_token_id = end_token_id

model = NextByteTransformer(
    # NOTE(review): presumably has to match the tokenizer / checkpoint
    # vocabulary size -- confirm rather than relying on the literal.
    vocab_size=20000,
    context_length=context_length,
    d_model=d_model,
    num_heads=num_heads,
    num_hidden_layers=num_hidden_layers,
    d_hidden=d_hidden,
    num_decoders=num_decoders
).to(device)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load("Models/nextbyte.pth", map_location=device))
model.eval()

def generate_autoregressive(model, tokenizer, input_text, max_new_tokens=100, top_k=10, context_length=768, device="cpu"):
    """Generate a continuation of ``input_text`` with top-k sampling.

    Args:
        model: Callable taking a ``[1, seq_len]`` tensor of token ids and
            returning a 3-D tensor whose last dimension holds the logits over
            the vocabulary. It is switched to eval mode.
        tokenizer: Object providing ``encode(text, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=...)`` and ``eos_token_id``.
        input_text: Prompt to continue.
        max_new_tokens: Upper bound on the number of sampled tokens.
        top_k: Number of highest-scoring tokens to sample from at each step.
            Values larger than the vocabulary size are clamped to it.
        context_length: Maximum number of trailing tokens fed to the model.
        device: Device the prompt ids are moved to; must be the device the
            model lives on.

    Returns:
        The decoded text (special tokens included) of the prompt, limited to
        its last ``context_length`` tokens, followed by every sampled token.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    model.eval()
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # An over-long prompt keeps only its most recent context_length tokens.
    generated = input_ids[:, -context_length:]

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed the model a sliding window, but keep the full sequence in
            # `generated` so tokens that leave the window are not dropped from
            # the returned text.
            window = generated[:, -context_length:]

            logits = model(window)[:, -1, :]  # logits for the next token: [1, vocab_size]
            # torch.topk raises if k exceeds the size of the dimension.
            k = min(top_k, logits.size(-1))
            topk_logits, topk_indices = torch.topk(logits, k=k, dim=-1)
            probs = F.softmax(topk_logits, dim=-1)

            sampled_index = torch.multinomial(probs, num_samples=1)  # position within the top-k: [1, 1]
            next_token = topk_indices.gather(-1, sampled_index)  # vocabulary id: [1, 1]

            generated = torch.cat([generated, next_token], dim=1)

            # .item() requires a single element, i.e. a batch of one.
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated[0], skip_special_tokens=False)

def count_parameters(model):
    """Print how many parameters ``model`` has in total and how many are trainable.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Nothing is returned; both counts are written to stdout with thousands
    separators.
    """
    total_count = 0
    trainable_count = 0
    for param in model.parameters():
        size = param.numel()
        total_count += size
        if param.requires_grad:
            trainable_count += size

    print(f"Total parameters: {total_count:,}")
    print(f"Trainable parameters: {trainable_count:,}")

# Print the total and trainable parameter counts of the model loaded above.
count_parameters(model)

  from .autonotebook import tqdm as notebook_tqdm


Total parameters: 459,098,144
Trainable parameters: 459,098,144


In [4]:
# Prompt: a recipe title followed by the "<end_title>" marker.
input_text = "chicken tikka masala <end_title>"
output = generate_autoregressive(
    model=model,
    tokenizer=hf_tokenizer,
    input_text=input_text,
    max_new_tokens=150,
    top_k=10,
    context_length=context_length,
    # The model was moved to `device` when it was built; the prompt ids must be
    # placed on the same device. Without this the function falls back to its
    # "cpu" default and fails with a device mismatch on a CUDA machine.
    device=device,
)
print(output)
# Disabled post-processing that splits the output into title / ingredients /
# directions. If re-enabled, note that str.index raises ValueError whenever the
# generated text lacks "<end_title>" or "<end_ingredients>".
# title = re.sub(r'\s+([.,!?;:])', r'\1', output[:output.index("<end_title>")])
# ingredients = re.sub(r'\s+([.,!?;:])', r'\1', output[output.index("<end_title>") + len("<end_title>"):output.index("<end_ingredients>")])
# directions = re.sub(r'\s+([.,!?;:])', r'\1', output[output.index("<end_ingredients>") + len("<end_ingredients>"):-len("<end>")])
#
# print(f"Recipe: \n{title.title()}\n")
# print(f"Ingredients: \n{ingredients.title()}\n")
# print(f"Directions: \n{directions.title()}\n")

chicken tikka masala <end_title> 1 1 1 1 2 ( 1 12 1 1 3 1 1 1 1 1 14 2 1 1 14 , 1 . , 3 . add , 1 , 4 cup . . <end>
