In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Constrained greedy decoding: at each step the model may only emit the next
# token of a vocabulary phrase whose tokens match everything generated so far,
# so the continuation is always a prefix of one of ``vocab_phrases``.

# Define the vocabulary phrases you want to consider
vocab_phrases = ["hello world", "example test", "good morning", "my name is lucy", "my name is jack"]

# Load model and tokenizer
model_id = "/data/prev_trained_models/Llama-3.2-1B"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Place the model with a single explicit ``.to(device)``. Combining
# ``device_map="auto"`` with a later ``.to()`` is redundant at best, and moving an
# accelerate-dispatched model fails if any module was offloaded.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pre-encode the vocabulary phrases (no special tokens: they are continuations).
encoded_phrases = [
    torch.tensor(tokenizer.encode(phrase, add_special_tokens=False), device=device)
    for phrase in vocab_phrases
]
# Token length of every phrase; constant, so built once instead of every step.
phrase_lengths = torch.tensor([len(phrase) for phrase in encoded_phrases], dtype=torch.long, device=device)

# Prepare input
prompt = "Hello, what's your name."
inputs = tokenizer(prompt, return_tensors="pt").to(device)
generated_ids = inputs.input_ids
prompt_length = generated_ids.shape[-1]
max_new_tokens = 10
# start_indices[i] = number of tokens of phrase i emitted so far.
start_indices = torch.zeros(len(encoded_phrases), dtype=torch.long, device=device)
current_step = 0
past_key_values = None  # KV cache: after the first step only the newest token is fed

with torch.no_grad():  # inference only: do not build an autograd graph per step
    for _ in range(max_new_tokens):
        # A phrase may propose the next token only if it has matched every token
        # generated so far (start index == step) and still has tokens left.
        active = (start_indices == current_step) & (start_indices < phrase_lengths)
        active_ids = active.nonzero().flatten().tolist()
        if not active_ids:
            break  # no phrase can be extended; skip the useless forward pass
        starts = start_indices.tolist()
        candidate_ids = [encoded_phrases[i][starts[i]].item() for i in active_ids]

        outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        # Get the logits of the last position (batch size is 1)
        logits = outputs.logits[:, -1, :]

        # Keep only the candidate tokens' logits; every other token stays -inf.
        masked_logits = torch.full_like(logits, float("-inf"))
        masked_logits[0, candidate_ids] = logits[0, candidate_ids]
        max_logit, max_token_id = torch.max(masked_logits, dim=-1)
        if max_logit.item() == float("-inf"):
            break
        next_token = max_token_id.item()

        # Update the generated sequence (max_token_id is (1,) -> (1, 1)).
        next_token_id = max_token_id.unsqueeze(0)
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Advance every still-matching phrase whose next token was the one chosen.
        for i, token in zip(active_ids, candidate_ids):
            if token == next_token:
                start_indices[i] += 1
        current_step += 1

        # With the KV cache only the new token is fed; the mask spans the whole
        # sequence and uses the same integer dtype as the ids.
        inputs = {
            "input_ids": next_token_id,
            "attention_mask": torch.ones_like(generated_ids),
        }

# Full decoded text (prompt + continuation).
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Decode only the generated ids: slicing the decoded string by len(prompt)
# assumes decode(encode(prompt)) == prompt, which tokenizer cleanup can break.
new_text = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
print(new_text)


  from .autonotebook import tqdm as notebook_tqdm


my name is jack
