In [12]:
import torch
import torch.nn.functional as F
import re
from pathlib import Path
from transformers import PreTrainedTokenizerFast
from models import NextByteTransformer

# Architecture hyperparameters. They must match the saved checkpoints, since
# load_state_dict below rejects a model whose parameter shapes differ.
context_length = 512  # max number of tokens fed to the model per forward pass
d_model = 512
num_heads = 8
num_hidden_layers = 8
d_hidden = 2048
num_decoders = 2
# Training-time settings, kept for reference; none of the cells below read them.
num_epochs = 8
lr = 3e-5
batch_size = 16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generation uses top-k sampling (torch.multinomial), so seed torch to make
# "Restart & Run All" reproduce the same recipes.
seed = 42
torch.manual_seed(seed)


# Load the two pipeline stages (title -> ingredients, ingredients -> directions).
# Both use the architecture from the config cell and a 20k-token vocabulary.
vocab_size = 20000  # must equal the vocab size the .pth checkpoints were saved with


def load_stage(tokenizer_dir, checkpoint_path, end_token="<end>"):
    """Load one pipeline stage: its tokenizer and its NextByteTransformer in eval mode.

    Args:
        tokenizer_dir: Directory holding a saved ``PreTrainedTokenizerFast``.
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.
        end_token: Token that terminates generation; its id is registered as
            the tokenizer's ``eos_token_id``.

    Returns:
        ``(tokenizer, model)``. The model is on ``device`` with the checkpoint
        weights loaded, and it is in eval mode.

    Raises:
        ValueError: if ``end_token`` is not in the tokenizer's vocabulary.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(Path(tokenizer_dir))
    end_id = tokenizer.convert_tokens_to_ids(end_token)
    # convert_tokens_to_ids maps unknown tokens to the unk id (or None) instead of
    # raising. Generation would then never stop early, so fail loudly here.
    if end_id is None or end_id == tokenizer.unk_token_id:
        raise ValueError(f"{end_token!r} is not in the vocabulary of {tokenizer_dir}")
    tokenizer.eos_token_id = end_id

    model = NextByteTransformer(
        vocab_size=vocab_size,
        context_length=context_length,
        d_model=d_model,
        num_heads=num_heads,
        num_hidden_layers=num_hidden_layers,
        d_hidden=d_hidden,
        num_decoders=num_decoders
    ).to(device)
    # A state_dict only contains tensors, so restrict unpickling to weights
    # (weights_only=True) instead of allowing arbitrary pickled objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


title_to_ingredients_tokenizer, title_to_ingredients_model = load_stage(
    "Tokenizers/title_to_ingredients_tokenizer", "Models/title_to_ingredients.pth"
)
ingredients_to_directions_tokenizer, ingredients_to_directions_model = load_stage(
    "Tokenizers/ingredients_to_directions_tokenizer", "Models/ingredients_to_directions.pth"
)

def generate_autoregressive(model, tokenizer, input_text, max_new_tokens=100, top_k=10, context_length=512, device=None):
    """Continue ``input_text`` one token at a time using top-k sampling.

    At each step the model sees only the most recent ``context_length`` tokens.
    The returned text still contains the whole prompt plus every generated
    token, so markers in the prompt (e.g. ``<end_title>``) are never dropped.

    Args:
        model: Module mapping token ids of shape ``[1, seq]`` to logits. The
            logits at the last position (``[:, -1, :]``) are sampled from.
            A batch size of 1 is assumed.
        tokenizer: Object providing ``encode(text, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=...)`` and ``eos_token_id``.
        input_text: Prompt to continue.
        max_new_tokens: Upper bound on the number of sampled tokens.
        top_k: Number of highest-scoring tokens to sample from. It is clamped
            to the vocabulary size.
        context_length: Maximum number of tokens fed to the model per step.
        device: Device to place the input ids on. ``None`` (the default) uses
            the device of the model's parameters, or CPU if the model has none.

    Returns:
        The decoded prompt plus continuation, with special tokens kept.
        Generation stops right after ``tokenizer.eos_token_id`` is sampled,
        and that token is included in the output.

    Raises:
        ValueError: if ``top_k < 1`` or ``input_text`` encodes to no tokens.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if device is None:
        # Follow the model so callers that don't pass `device` still work on GPU.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    tokens = tokenizer.encode(input_text, return_tensors="pt").to(device)
    if tokens.size(1) == 0:
        raise ValueError("input_text encodes to an empty token sequence")

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed only the most recent window; `tokens` keeps the full history.
            window = tokens[:, -context_length:]
            logits = model(window)[:, -1, :]  # shape: [1, vocab_size]
            k = min(top_k, logits.size(-1))  # torch.topk raises if k > vocab size
            topk_logits, topk_indices = torch.topk(logits, k=k, dim=-1)
            probs = F.softmax(topk_logits, dim=-1)

            sampled_index = torch.multinomial(probs, num_samples=1)  # shape: [1, 1]
            next_token = topk_indices.gather(-1, sampled_index)  # shape: [1, 1]

            tokens = torch.cat([tokens, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(tokens[0], skip_special_tokens=False)

def count_parameters(model):
    """Print how many parameters ``model`` has, in total and trainable.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Both counts are printed with thousands separators; nothing is returned.
    """
    total_count = 0
    trainable_count = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_count += n_elements
        if param.requires_grad:
            trainable_count += n_elements
    print(f"Total parameters: {total_count:,}")
    print(f"Trainable parameters: {trainable_count:,}")

# Report the size of each pipeline stage, title stage first.
for stage_model in (title_to_ingredients_model, ingredients_to_directions_model):
    count_parameters(stage_model)

Total parameters: 75,312,672
Trainable parameters: 75,312,672
Total parameters: 75,312,672
Trainable parameters: 75,312,672


In [14]:
def _strip_suffix(text, suffix):
    """Return ``text`` without a trailing ``suffix``; unchanged if it is absent."""
    return text[: -len(suffix)] if suffix and text.endswith(suffix) else text


def _tidy_punctuation(text):
    """Drop the space the decoder leaves before punctuation ("a , b" -> "a, b") and trim."""
    return re.sub(r'\s+([.,!?;:])', r'\1', text).strip()


# Stage 1: title -> ingredients.
input_text = "chicken tikka masala <end_title>"
output1 = generate_autoregressive(
    model=title_to_ingredients_model,
    tokenizer=title_to_ingredients_tokenizer,
    input_text=input_text,
    max_new_tokens=400,
    top_k=10,
    context_length=context_length,
    device=device,  # input ids must be on the same device as the model
)
title_index = output1.find("<end_title>")
if title_index == -1:
    raise ValueError(f"title model output has no <end_title> marker: {output1!r}")
title = output1[:title_index]
# Generation may hit max_new_tokens without emitting <end>, so only strip it if present.
ingredients = _strip_suffix(output1[title_index + len("<end_title>"):], "<end>")

# Stage 2: ingredients -> directions. This model expects the ingredient list
# terminated by <end_ingredients>.
output2 = generate_autoregressive(
    model=ingredients_to_directions_model,
    tokenizer=ingredients_to_directions_tokenizer,
    input_text=ingredients + "<end_ingredients>",
    max_new_tokens=400,
    top_k=10,
    context_length=context_length,
    device=device,
)
ingredients_end = output2.find("<end_ingredients>")
# If the marker is missing (e.g. the prompt was truncated away), treat the whole output as directions.
directions = output2[ingredients_end + len("<end_ingredients>"):] if ingredients_end != -1 else output2
directions = _strip_suffix(directions, "<end>")

title = _tidy_punctuation(title)
ingredients = _tidy_punctuation(ingredients)
directions = _tidy_punctuation(directions)


print(f"Recipe: \n{title.title()}\n")
print(f"Ingredients: \n{ingredients.title()}\n")
print(f"Directions: \n{directions.title()}\n")

chicken tikka masala <end_title> 3 large boneless chicken breasts , 2 teaspoons olive oil , 1 ( 8 - ounce ) carton sour cream , 2 ( 8 12 - ounce ) containers plain yogurt ( such as whole - milk , i like to use fresh ground pepper , 12 cup plain nonfat greek yogurt , 14 cup finely chopped fresh cilantro or parsley , 3 tablespoons chopped fresh cilantro , 3 tablespoons chopped fresh mint leaves , salt , to taste , 12 cup freshly ground black pepper , 12 teaspoon ground cumin , 34 cup shredded reduced - fat cheddar cheese , 1 ( 6 - ounce ) package mixed baby greens <end>
Recipe: 
Chicken Tikka Masala 

Ingredients: 
 3 Large Boneless Chicken Breasts, 2 Teaspoons Olive Oil, 1 ( 8 - Ounce ) Carton Sour Cream, 2 ( 8 12 - Ounce ) Containers Plain Yogurt ( Such As Whole - Milk, I Like To Use Fresh Ground Pepper, 12 Cup Plain Nonfat Greek Yogurt, 14 Cup Finely Chopped Fresh Cilantro Or Parsley, 3 Tablespoons Chopped Fresh Cilantro, 3 Tablespoons Chopped Fresh Mint Leaves, Salt, To Taste, 12 Cup