In [None]:
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_name = "nvidia/Llama-3.1-Nemotron-Nano-4B-v1.1"

# Load the tokenizer and the causal-LM weights, then move the model to the XLA device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(xm.xla_device())

# Bare expression: show the module structure as the cell output.
model

## Prompt Generation

In [None]:
# Short open-ended prompt used for the greedy-decoding loop in the next cell.
prompt = "The future of AI is"

# Tokenize to PyTorch tensors, then place the token ids on the XLA device.
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt["input_ids"].to(xm.xla_device())
input_ids

In [None]:
# Greedy decoding: append the highest-scoring next token `num_steps` times.
num_steps = 10
generated_ids = input_ids

# Gradients are never needed here, so the whole loop runs under no_grad.
with torch.no_grad():
    for _ in range(num_steps):
        # The full sequence so far is fed back in on every step.
        outputs = model(generated_ids)
        logits = outputs.logits

        # Scores for the position right after the last token in the sequence.
        next_token_logits = logits[:, -1]
        next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)

        # Extend the sequence along the token axis.
        generated_ids = torch.cat((generated_ids, next_token_id), dim=-1)

generated_ids

In [None]:
# Squeeze out size-1 dimensions (a single prompt was encoded) and decode the ids back to text.
token_sequence = generated_ids.squeeze()
generated_text = tokenizer.decode(token_sequence, skip_special_tokens=True)
print("Generated text:\n", generated_text)

## Multiple-Choice Prompt

In [None]:
# Multiple-choice prompt; it ends with "Answer:" so the next token should be the choice letter.
# Worked answer for reference: (16 - 3 - 4) eggs * $2 = $18, i.e. choice C.
prompt = """
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Choices:
A. 22.0
B. 64.0
C. 18.0
D. 12.0
Answer:"""

# Tokenize to PyTorch tensors, then place the token ids on the XLA device.
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt["input_ids"].to(xm.xla_device())
input_ids

In [None]:
# Single forward pass over the prompt; no gradients are required for scoring.
with torch.no_grad():
    outputs = model(input_ids)

logits = outputs.logits

# Keep only the scores at the last position, i.e. for the token following the prompt.
next_token_logits = logits[:, -1]

In [None]:
# Answer letters the prediction is restricted to.
valid_choices = list("ABCD")

# Encode each letter without special tokens and keep the first token id of each encoding.
# NOTE(review): some tokenizers use different ids for "A" and " A" (leading space);
# confirm which form the model actually emits after "Answer:".
encoded_choices = [
    tokenizer.encode(label, add_special_tokens=False) for label in valid_choices
]
valid_token_ids = [token_ids[0] for token_ids in encoded_choices]
valid_token_ids

In [None]:
# Flag the vocabulary positions that belong to an allowed answer token.
is_allowed = torch.zeros_like(next_token_logits, dtype=torch.bool)
is_allowed[:, valid_token_ids] = True

# Keep the allowed logits and set every other entry to -inf, so argmax can only
# select one of the allowed tokens.
mask = next_token_logits.masked_fill(~is_allowed, float("-inf"))
next_token_id = mask.argmax(dim=-1, keepdim=True)
next_token_id

In [None]:
# Turn the selected token id back into its text form (the answer letter).
predicted_token = next_token_id.squeeze()
predicted_choice = tokenizer.decode(predicted_token, skip_special_tokens=True)
print("Predicted Answer:", predicted_choice)

## Running Inference

In [None]:
import os
# The cells below import `src.*` and open "config.yaml" by relative path, so the
# working directory must be the project root (the directory that contains `src`).
# The move is guarded: an unconditional os.chdir("../") climbs one more level every
# time the cell is re-run, which breaks those imports and paths.
if not os.path.isdir("src"):
    os.chdir("../")
os.getcwd()

In [None]:
from src.data_loader import GSM_MC_PromptBuilder
from src.models import MultipleChoiceLLM
from src.config import ConfigurationManager
from tqdm import tqdm
import pandas as pd

In [None]:
# Load the project configuration and pull out the dataset and model sections.
config_file_path = "config.yaml"
config = ConfigurationManager(config_file_path=config_file_path)
dataset_config, model_config = (
    config.get_dataset_configuration(),
    config.get_model_configuration(),
)

# Build the prompt builder from the dataset section of the configuration.
builder_options = {
    "data_files": dataset_config.data_files,
    "split": dataset_config.split,
    "max_samples": dataset_config.max_samples,
}
prompt_builder = GSM_MC_PromptBuilder(dataset_config.dataset_name, **builder_options)

In [None]:
# Model settings come from the configuration loaded above.
model_name, allowed_choices = model_config.model_name, model_config.allowed_choices

# NOTE: this rebinds `model`, replacing the raw Hugging Face model loaded earlier
# in the notebook with the project's multiple-choice wrapper.
model = MultipleChoiceLLM(model_name=model_name, allowed_choices=allowed_choices)

In [None]:
# Build the prompts and their metadata. The loop below iterates over the result and
# reads the keys "sample_id", "question", "choices", "prompt" and "answer" from each item.
# NOTE: this rebinds `outputs`, which earlier cells used for raw model outputs.
outputs = prompt_builder.generate_prompts_and_metadata()

In [None]:
# Run the model on every prompt and collect one flat record per sample.
# tqdm (already imported above) shows progress, since each prediction is a model call.
results = []
for sample in tqdm(outputs, desc="Running inference"):
    prompt = sample["prompt"]
    prediction = model.predict(prompt)
    choices = sample["choices"]

    results.append({
        "sample_id": sample["sample_id"],
        "question": sample["question"],
        # A missing option becomes "" so every record has the same columns.
        "choice_A": choices.get("A", ""),
        "choice_B": choices.get("B", ""),
        "choice_C": choices.get("C", ""),
        "choice_D": choices.get("D", ""),
        "prompt": prompt,
        "answer": sample["answer"],
        "prediction": prediction,
    })

In [None]:
# Tabulate the collected records; the bare expression renders the frame as the cell output.
pd.DataFrame(results)

In [None]:
from src.inference import ModelInferencePipeline
from src.config import ConfigurationManager
from src.common import create_directory

In [None]:
# Instantiate the project's inference pipeline with its default constructor arguments.
pipeline = ModelInferencePipeline()

In [None]:
# Run the pipeline's inference entry point and keep its return value.
# NOTE(review): presumably a pandas DataFrame, given the name `df` — confirm.
df = pipeline.run_inference()

In [None]:
import os
from datetime import datetime
# The cells below import `src.*` and open "config.yaml" by relative path, so the
# working directory must be the project root (the directory that contains `src`).
# An earlier cell in this notebook may already have moved there; an unconditional
# os.chdir("../") would then climb one level too far, so only move when `src`
# is not visible from the current directory.
if not os.path.isdir("src"):
    os.chdir("../")
os.getcwd()

In [None]:
from src.data_loader import GSM_MC_PromptBuilder
from src.config import ConfigurationManager

In [None]:
# Create the configuration manager from "config.yaml", resolved against the current working directory.
config_manager = ConfigurationManager("config.yaml")

In [None]:
# Fetch the dataset section of the configuration.
# NOTE: `config` here is the dataset configuration object, whereas an earlier cell
# bound the same name to the ConfigurationManager itself.
config = config_manager.get_dataset_configuration()

In [None]:
# Build the prompt builder from the dataset configuration. Keyword arguments match
# the earlier construction of GSM_MC_PromptBuilder in this notebook, so the call
# does not depend on the constructor's parameter order.
prompt_builder = GSM_MC_PromptBuilder(
    config.dataset_name,
    data_files=config.data_files,
    split=config.split,
    max_samples=config.max_samples,
)