In [None]:
from datasets import load_dataset

In [None]:
# Fetch the repository's test.csv. It is requested under split="train" —
# presumably because files passed via data_files land in a "train" split by
# default; confirm against the `datasets` documentation.
dataset = load_dataset(
    "satoshidg/GSM-MC-Stage",
    split="train",
    data_files="test.csv",
)

In [None]:
# Display the first record as a quick sanity check of the loaded data.
dataset[0]

In [None]:
def get_question(sample):
    """Split one dataset record into question text, answer options and answer.

    Args:
        sample: Mapping that provides the keys "Question", "A", "B", "C",
            "D" and "Answer".

    Returns:
        A ``(question, choices, answer)`` tuple. ``choices`` is a dict that
        maps each option letter ("A"-"D", in that order) to the value stored
        under that letter in ``sample``.

    Raises:
        KeyError: If ``sample`` lacks one of the required keys.
    """
    question = sample["Question"]

    choices = {}
    for letter in ("A", "B", "C", "D"):
        choices[letter] = sample[letter]

    return question, choices, sample["Answer"]

In [None]:
# Unpack the record at index 20 and print each part on its own labelled line.
question, choices, answer = get_question(dataset[20])
for label, value in (
    ("Question:", question),
    ("Choices:", choices),
    ("Answer:", answer),
):
    print(label, value)

In [None]:
class PromptBuilder:
    """Build multiple-choice prompts from a dataset loaded with ``load_dataset``.

    Every record must provide a "Question" field; answer options are read from
    the keys "A"-"D", and ``get_sample_prompt`` additionally reads "Answer".
    An optional "context" field, when non-empty, is placed before the question.
    """

    # Option letters, in the order they are rendered in the prompt.
    CHOICE_KEYS = ("A", "B", "C", "D")

    def __init__(self, dataset_name, data_files=None, split="train", max_samples=None):
        """Load the dataset and optionally keep only its leading records.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            data_files: Optional data file(s) forwarded to ``load_dataset``.
            split: Split to load.
            max_samples: When truthy, keep at most this many leading records.
                ``None`` (or ``0``) keeps the whole dataset.
        """
        self.dataset = load_dataset(dataset_name, data_files=data_files, split=split)
        if max_samples:
            # Clamp the limit so asking for more records than exist keeps the
            # full dataset instead of selecting out-of-range indices.
            keep = min(max_samples, len(self.dataset))
            self.dataset = self.dataset.select(range(keep))

    def format_sample(self, sample, answer=None):
        """Render one record as a prompt string.

        Args:
            sample: Mapping with "Question", any of the option keys "A"-"D",
                and optionally "context".
            answer: If given, appended as a trailing "Answer: ..." section.

        Returns:
            The prompt sections (context if present, question, choices, and
            answer if given) joined by blank lines.
        """
        # A missing column and an explicit None are both treated as "no context".
        context = (sample.get("context") or "").strip()
        question = sample["Question"]

        # Walk the letters in fixed order so the listing is always A-D,
        # regardless of the column order of the record.
        choice_list = "\n".join(
            f"{option}. {str(sample[option])}"
            for option in self.CHOICE_KEYS
            if option in sample
        )

        sections = []
        if context:
            # Only emit the context section when there is one, so prompts
            # without context do not start with blank lines.
            sections.append(context)
        sections.append(f"Question: {question}")
        sections.append(f"Choices:\n{choice_list}")
        if answer is not None:
            sections.append(f"Answer: {answer}")

        return "\n\n".join(sections)

    def get_sample_prompt(self, index):
        """Return the prompt for record ``index`` with its answer appended."""
        sample = self.dataset[index]
        return self.format_sample(sample=sample, answer=sample["Answer"])

    def get_prompts(self):
        """Return the answer-free prompt for every record in the dataset."""
        return [self.format_sample(sample) for sample in self.dataset]

In [None]:
# Builder limited to the first five records of test.csv.
prompt_builder = PromptBuilder(
    dataset_name="satoshidg/GSM-MC-Stage",
    data_files="test.csv",
    split="train",
    max_samples=5,
)

## Context Generation

In [None]:
import os

# Step up one directory from the notebook's folder — presumably to the project
# root, since the cells below import `src` and open relative config paths.
# The guard makes the cell safe to re-run: an unguarded chdir("../") climbs
# one more level on every execution within the same kernel.
if "PROJECT_ROOT" not in globals():
    os.chdir("../")
    PROJECT_ROOT = os.getcwd()
os.getcwd()

In [None]:
from src.config import ConfigurationManager
from src.data_loader import GSM_MC_PromptBuilder
from torch.utils.data import DataLoader
from src.models import MultipleChoiceLLM

In [None]:
# Both paths are relative, so they resolve against the current working
# directory (changed by the os.chdir cell earlier in this notebook).
config_manager = ConfigurationManager(
    config_file_path="config.yaml",
    context_config_file_path="configs/context_templates.yaml"
)

In [None]:
# Dataset settings; dataset_name, split and max_samples are read from this
# object when the prompt dataset is built below.
dataset_config = config_manager.get_dataset_configuration()

In [None]:
# Context definitions that are passed to GSM_MC_PromptBuilder below; the bare
# name on the last line displays them as the cell output.
full_contexts = config_manager.get_contexts_configuration()
full_contexts

In [None]:
# Build the prompt dataset from the loaded configuration.
# NOTE: this rebinds `dataset`, which the first cells of the notebook bound to
# the raw load_dataset() result. Cells that index `dataset` (e.g. the
# get_question demo) see a different object if re-run after this one.
dataset = GSM_MC_PromptBuilder(
    dataset_config.dataset_name,
    contexts=full_contexts,
    split=dataset_config.split,
    max_samples=dataset_config.max_samples,
)

In [None]:
# Batches of two prompts, served in dataset order (shuffle=False).
data_loader = DataLoader(dataset, batch_size=2, shuffle=False)

In [None]:
# Walk the loader once and flatten every batch into one record per prompt.
all_results = []
for batch in data_loader:
    # Each batch is indexed by field name; every field holds one entry per
    # prompt in the batch (the code below reads them position by position).
    prompts_to_send = batch['prompt']

    # Inference is not wired in yet; this is where the call would go:
    # llm_responses = your_llm_function(prompts_to_send)

    # For now only the metadata of each prompt is collected and printed.
    for i, prompt_text in enumerate(prompts_to_send):
        context_info = batch['context_info']
        result = {
            # .item() turns the per-sample id entries into plain Python values
            "prompt_id": batch['prompt_id'][i].item(),
            "sample_id": batch['sample_id'][i].item(),
            "context_category": context_info['category'][i],
            "context_name": context_info['identity'][i],
            "prompt": prompt_text,
            "ground_truth_answer": batch['answer'][i],
            # "llm_response": llm_responses[i]
        }
        all_results.append(result)
        print(f"ID: {result['prompt_id']}, Context: {result['context_name']}, Prompt: {batch['prompt'][i][:30]}...")

In [None]:
# Instantiate the project's multiple-choice wrapper around the Nemotron model.
# NOTE(review): allowed_choices is an empty list here although the questions
# carry options keyed "A"-"D" (see get_question) — confirm how
# MultipleChoiceLLM treats an empty list.
model = MultipleChoiceLLM(
    model_name="nvidia/Llama-3.1-Nemotron-Nano-4B-v1.1",
    allowed_choices=[],
    tokenizer_padding_side="left",
)

In [None]:
# Smoke-test the model on one prompt taken from `batch`, the last batch left
# bound by the data-loader loop. Index -1 selects the same prompt as index 1
# when that batch holds two prompts, and still works when the final batch
# holds a single prompt (where index 1 would raise IndexError).
preds = model.predict(batch["prompt"][-1])

## Analyze Predictions

In [None]:
import pandas as pd

In [None]:
import os

# Location of the predictions file. The default is the original
# machine-specific path; set the PREDICTIONS_CSV environment variable to read
# the file from elsewhere without editing the notebook.
predictions_path = os.environ.get(
    "PREDICTIONS_CSV",
    "/home/sermengi/llm-bias-fairness-eval/artifacts/predictions.csv",
)
df = pd.read_csv(predictions_path, index_col=False)

In [None]:
# Distinct identity labels present in the predictions table.
df["context_identity"].unique()

In [None]:
# Print every prompt whose context identity is "Asian". A plain loop is used
# because a list comprehension built only for its print() side effect also
# leaves a list of None values as the cell's displayed output.
for prompt in df.loc[df["context_identity"] == "Asian", "prompt"]:
    print(prompt)

In [None]:
# Answer and prediction per context identity for the rows with sample_id 0.
df.loc[df["sample_id"] == 0, ["context_identity", "answer", "prediction"]]