In [None]:
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
model_name = "nvidia/Llama-3.1-Nemotron-Nano-4B-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Without a dtype argument, from_pretrained materializes the weights in float32.
# torch_dtype="auto" keeps the precision the checkpoint was saved in (presumably
# bf16 — confirm in the model's config.json), roughly halving weight memory on
# the XLA device for this ~4B-parameter model.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to(
    xm.xla_device()
)
model

## Greedy Text Generation from a Prompt

In [None]:
# Encode the seed text and place the resulting token ids on the XLA device.
prompt = "The future of AI is"
encoded = tokenizer(prompt, return_tensors="pt")
input_ids = encoded["input_ids"].to(xm.xla_device())
input_ids

In [None]:
# Manual greedy decoding: pick the highest-scoring token at each step.
# The KV cache (past_key_values) lets every step after the first feed only the
# newest token instead of re-running the whole prefix, which would be O(n^2).
num_steps = 10
generated_ids = input_ids
next_input_ids = input_ids  # the first step consumes the full prompt
past_key_values = None

for step in range(num_steps):
    with torch.no_grad():
        outputs = model(next_input_ids, past_key_values=past_key_values, use_cache=True)
        logits = outputs.logits

    past_key_values = outputs.past_key_values
    next_token_logits = logits[:, -1, :]
    next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

    generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
    next_input_ids = next_token_id

    # Cut the lazy XLA graph once per step. Otherwise every iteration is traced
    # into one ever-growing graph that only executes when generated_ids is read.
    xm.mark_step()

    # Stop once the model emits end-of-sequence; tokens after it are not
    # meaningful continuation text.
    if next_token_id.item() == tokenizer.eos_token_id:
        break

generated_ids

In [None]:
# Drop the batch dimension (batch of one) and render the ids as text,
# omitting special tokens from the output.
generated_text = tokenizer.decode(generated_ids.squeeze(0), skip_special_tokens=True)
print("Generated text:\n", generated_text)

## Multiple-Choice Prompt

In [None]:
# Four-way multiple-choice question. The correct option is C:
# 16 - 3 - 4 = 9 eggs sold, 9 x $2 = $18.
# The literal is kept verbatim (including the leading newline) because any
# edit changes the token sequence the model conditions on.
prompt = """
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Choices:
A. 22.0
B. 64.0
C. 18.0
D. 12.0
Answer:"""

# Encode the prompt and move the token ids onto the XLA device.
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt["input_ids"].to(xm.xla_device())
input_ids

In [None]:
# Score the whole prompt in one forward pass; inference needs no gradients.
with torch.no_grad():
    outputs = model(input_ids)
logits = outputs.logits

# Only the final position matters: its distribution predicts the token that
# immediately follows the prompt.
next_token_logits = logits[:, -1]

In [None]:
# Answer letters the model is allowed to choose between.
valid_choices = ["A", "B", "C", "D"]

# The prompt ends with "Answer:", so the natural continuation is a space
# followed by the letter (e.g. " C"). Byte-level BPE tokenizers such as
# Llama 3's give " C" and "C" different ids, so score the space-prefixed
# variants. Each one must be a single token; otherwise taking [0] would
# silently score only a prefix of the choice.
valid_token_ids = []
for choice in valid_choices:
    choice_ids = tokenizer.encode(" " + choice, add_special_tokens=False)
    assert len(choice_ids) == 1, f"{choice!r} does not map to a single token: {choice_ids}"
    valid_token_ids.append(choice_ids[0])
valid_token_ids

In [None]:
# Constrained decoding: only the answer-letter tokens may be selected. Every
# other vocabulary entry is forced to -inf so argmax can never pick it.
allowed = torch.zeros_like(next_token_logits, dtype=torch.bool)
allowed[:, valid_token_ids] = True
mask = next_token_logits.masked_fill(~allowed, float("-inf"))
next_token_id = mask.argmax(dim=-1, keepdim=True)
next_token_id

In [None]:
# Decode the selected token back to text. Strip surrounding whitespace because
# the chosen token may carry a leading space (e.g. " C"). That space would
# otherwise leak into the printed answer and break comparisons against the
# plain letter labels.
predicted_choice = tokenizer.decode(
    next_token_id.squeeze(), skip_special_tokens=True
).strip()
print("Predicted Answer:", predicted_choice)