# Load required packages

To install the packages required for this notebook on the HPC, please follow the 'Jupyter Kernel Creation' slides posted on OPAL.

In [1]:
import re
import string

import pandas as pd
import torch
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


# Load the model (Llama-8B or Mistral-7B)

Note that you need to be on a partition with a GPU (e.g. capella, alpha).

In [2]:
# Name of the target device. The visible cells below never read this variable:
# inputs are moved with `model.device` and the model is loaded with device_map="auto".
device = "cuda"

In [3]:
import os

# Only fall back to the placeholder when no token is set yet, so a real HF_TOKEN
# exported in the shell / job script is never overwritten by re-running this cell.
# Either export HF_TOKEN before starting Jupyter or replace the placeholder locally
# (never commit a real token with the notebook).
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN_HERE"

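This cell is meant to use a model that does not require requesting access; if you have access to the Llama-8B model, you can use it instead. (Note: `model_name` below is currently set to a Llama-8B checkpoint, so make sure your Hugging Face token has access to it, or change `model_name`.)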

In [4]:
# Identifier of the checkpoint; both the tokenizer here and the model cell below load from it.
model_name = "meta-llama/Meta-Llama-3-8B"

# Tokenizer for the checkpoint; later cells use it to encode prompts and decode generations.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
# Load the causal LM in half precision.
# `dtype` replaces the `torch_dtype` keyword, which this environment reports as
# deprecated ("`torch_dtype` is deprecated! Use `dtype` instead!").
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",        # automatic device placement of the weights
    dtype=torch.float16,
    low_cpu_mem_usage=True,
)

`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 4 files: 100%|██████████| 4/4 [01:18<00:00, 19.74s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [02:16<00:00, 34.05s/it]


In [6]:
# Read the SAQ test set, keeping only the id, the English question and the country column
saq = pd.read_csv("test_dataset_saq.csv").loc[:, ["ID", "en_question", "country"]]

# Country code -> country name inserted into the prompt.
# 'GB' and 'UK' both resolve to the same name.
COUNTRY_MAP = dict(
    IR="Iran",
    CN="China",
    US="United States",
    GB="United Kingdom",
    UK="United Kingdom",
)

## With few shot

In [13]:
def clean_answer_strict(text: str) -> str:
    """Reduce a raw model completion to a single short, normalised answer.

    Steps, in order:
      1. If the completion contains "Answer:" markers, keep the last non-empty
         segment (so a dangling trailing "Answer:" does not wipe the answer).
      2. Lower-case, trim, and keep only the first line, so that a continued
         "Question: ..." on later lines never influences the result.
      3. Keep only the first item of a comma- or " and "-separated list.
      4. Map hedging ("not sure", "i think") to the literal "unknown".
      5. Drop one trailing punctuation character and a leading article.

    Args:
        text: Decoded completion produced by the model.

    Returns:
        The cleaned answer; may be an empty string if nothing usable remains.
    """
    # 1. Isolate the final answer (just in case the model is chatty).
    #    Empty segments are skipped: "tea\nAnswer:" must yield "tea", not "".
    if "Answer:" in text:
        segments = [part for part in text.split("Answer:") if part.strip()]
        text = segments[-1] if segments else ""

    # 2. Normalize, then cut to the first line *before* any other rule runs.
    #    Doing this last (as before) left "football." untouched whenever the
    #    model kept generating after a newline.
    text = text.lower().strip()
    text = text.split('\n')[0].strip()

    # 3. Cut lists: "tennis, football, golf" -> "tennis"
    if ',' in text:
        text = text.split(',')[0]

    # "walking and jogging" -> "walking"
    if ' and ' in text:
        text = text.split(' and ')[0]

    # 4. Handle hedging / refusals
    if "not sure" in text or "i think" in text:
        return "unknown"

    # 5. Remove a single terminal punctuation character
    text = text.strip()
    if text and text[-1] in string.punctuation:
        text = text[:-1]

    # 6. Remove a leading article
    text = re.sub(r'^(the|a|an)\s+', '', text)

    return text.strip()

def saq_base_func(question: str, country_code: str):
    """Answer one short-answer question with greedy few-shot prompting.

    The country code is translated through COUNTRY_MAP (unmapped codes are used
    verbatim), a three-example prompt is built, at most 8 new tokens are decoded
    greedily, and the continuation is passed through clean_answer_strict.
    """
    country_name = COUNTRY_MAP.get(country_code, country_code)

    prompt = f"""Task: Answer the question with a single entity, name, or number.

Question: What is the most popular sport?
Country: Brazil
Answer: football

Question: At what age do students usually graduate high school?
Country: United States
Answer: 18

Question: What is the traditional morning drink?
Country: United Kingdom
Answer: tea

Question: {question}
Country: {country_name}
Answer:"""

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]

    # Short, deterministic decoding
    generation_kwargs = dict(
        max_new_tokens=8,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # Decode only the tokens produced after the prompt
    completion = tokenizer.decode(sequences[0][prompt_len:], skip_special_tokens=True)

    return clean_answer_strict(completion)


print("Starting baseline generation with strict cleaning...")

# One cleaned prediction per dataset row, in row order
preds = []
for index, question, country in zip(saq.index, saq["en_question"], saq["country"]):
    answer = saq_base_func(question, country)
    preds.append(answer)

    # Only every 20th index label is echoed as a spot check
    if index % 20 != 0:
        continue
    print(f"[{COUNTRY_MAP.get(country, country)}] {question[:40]}... -> {answer}")

saq["answer"] = preds
print("Generation complete.")

# Submission file: tab-separated, ID + answer only
saq_submission = saq[["ID", "answer"]]
saq_submission.to_csv("saq_prediction_baseline_strict.tsv", sep='\t', index=False)
print("Saved to saq_prediction_baseline_strict.tsv")

Starting baseline generation with strict cleaning...
[Iran] What is the most popular children's anim... -> tom
[United States] Which country is considered the biggest ... -> mexico
[United Kingdom] What is the duration (in hours) of a typ... -> 8
[China] What is the most popular women's sports ... -> women's volleyball
[United Kingdom] What are the family-related holidays in ... -> christmas
[United States] What do farmers in US typically wear to ... -> cowboy hat
[Iran] How many school breaks are there in a ye... -> 3
[China] Who is the most famous Paralympian in Ch... -> zhang lixin
[United States] What sports do male students enjoy durin... -> football
[China] Which cities or regions are known for th... -> shanghai
[Iran] What is the most common spice/herb used ... -> saffron
[United States] In US, how long (in years) does a Master... -> 2
[United States] What traditional games do families play ... -> football
[United Kingdom] Which subject’s academy/private educatio... -> drama
[Ir

# Self Consistency

In [None]:
from collections import Counter

def saq_sc_func(question: str, country_code: str, n_samples: int = 5):
    """Answer one question by sampling several completions and majority-voting.

    Args:
        question: The question text placed in the prompt.
        country_code: Country code; translated through COUNTRY_MAP, or used
            verbatim when the code is not in the map.
        n_samples: Number of completions to sample (must be >= 1).

    Returns:
        The most frequent cleaned answer among the samples. Blank answers are
        ignored in the vote unless every sample is blank. Ties go to the answer
        that appeared first among the samples.

    Raises:
        ValueError: If n_samples is smaller than 1.
    """
    if n_samples < 1:
        # Previously this surfaced later as an opaque IndexError on the empty vote.
        raise ValueError("n_samples must be at least 1")

    # Fallback: If code is not in map, use the code itself
    country_name = COUNTRY_MAP.get(country_code, country_code)

    prompt = f"""Task: Answer the question with a single entity, name, or number.

Question: What is the most popular sport?
Country: Brazil
Answer: football

Question: At what age do students usually graduate high school?
Country: United States
Answer: 18

Question: What is the traditional morning drink?
Country: United Kingdom
Answer: tea

Question: {question}
Country: {country_name}
Answer:"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=8,
            do_sample=True,          # Enabled sampling
            temperature=0.7,         # Variance control
            top_p=0.9,               # Nucleus sampling
            num_return_sequences=n_samples,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the continuation of each sampled sequence and clean it
    input_len = inputs["input_ids"].shape[-1]
    all_responses = [
        clean_answer_strict(
            tokenizer.decode(sequence[input_len:], skip_special_tokens=True)
        )
        for sequence in outputs
    ]

    # Majority voting. A blank string means the sample produced nothing usable,
    # so blanks must not outvote real answers; fall back to them only if
    # nothing else exists.
    candidates = [response for response in all_responses if response] or all_responses
    final_answer = Counter(candidates).most_common(1)[0][0]

    return final_answer

In [None]:
# Number of sampled completions per question. Defined once so the log line,
# the call and the output file name cannot drift apart.
N_SC_SAMPLES = 5

preds_sc = []

print(f"Starting Self-Consistency (N={N_SC_SAMPLES}) generation...")

for index, row in saq.iterrows():
    q = row['en_question']
    c = row['country']

    # Using the SC function
    answer = saq_sc_func(q, c, n_samples=N_SC_SAMPLES)
    preds_sc.append(answer)

    # Echo every 20th index label as a spot check
    if index % 20 == 0:
        c_name = COUNTRY_MAP.get(c, c)
        print(f"[{c_name}] {q[:40]}... -> {answer}")

saq["answer"] = preds_sc
saq_submission = saq[["ID", "answer"]]
sc_output_path = f"saq_prediction_sc_n{N_SC_SAMPLES}.tsv"
saq_submission.to_csv(sc_output_path, sep='\t', index=False)
print(f"Saved to {sc_output_path}")