## RAG Trial

In [1]:
import ast
import os
import re
import string
from collections import defaultdict

import faiss
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set HF Token
# Paste a real Hugging Face token below, or export HF_TOKEN before starting
# the kernel. The placeholder itself is never written to os.environ: doing so
# would clobber a token that is already exported, and an invalid token can make
# the Hub reject requests even for public repos such as the sentence encoder.
HF_TOKEN_PLACEHOLDER = "YOUR_HF_TOKEN_HERE"
hf_token = "YOUR_HF_TOKEN_HERE"
if hf_token != HF_TOKEN_PLACEHOLDER:
    os.environ["HF_TOKEN"] = hf_token
elif not os.environ.get("HF_TOKEN"):
    print("Warning: HF_TOKEN is not set; downloading the gated Llama-3 weights will likely fail.")

# ==========================================
# 1. LOAD MODELS
# ==========================================
print("Loading Llama-3 Model...")
model_name = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# transformers warns that `torch_dtype` is deprecated in favour of `dtype`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    dtype=torch.float16,
    low_cpu_mem_usage=True,
)

print("Loading Retrieval Model...")
encoder = SentenceTransformer('all-MiniLM-L6-v2')

Loading Llama-3 Model...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|██████████| 291/291 [02:13<00:00,  2.18it/s, Materializing param=model.norm.weight]                              


Loading Retrieval Model...


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 507.86it/s, Materializing param=pooler.dense.weight]                             
BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [None]:
# Two-letter country codes from the dataset -> country names used as
# knowledge_base keys. 'GB' and 'UK' both resolve to the same name.
COUNTRY_MAP = dict(
    IR='Iran',
    CN='China',
    US='United States',
    GB='United Kingdom',
    UK='United Kingdom',
)

def clean_text_logic(text):
    """Normalize a raw answer string to a short, lowercase entity.

    Steps: keep what follows the last "Answer:" marker, keep its first
    non-empty line, drop surrounding quotes, reject non-ASCII output,
    lowercase, keep the first item of list-like answers, map hedging or
    sentence-like output to "idk", then drop a trailing period, a leading
    article, and any quotes left around the result.

    Args:
        text: The raw answer; None/NaN is accepted.

    Returns:
        The cleaned answer, "idk" for rejected output, or "" for missing
        input.
    """
    if pd.isna(text):
        return ""
    text = str(text)

    # 1. Chop Hallucinations: keep what follows the last "Answer:" and only
    #    its first line. Strip before splitting so a generation that opens
    #    with a newline does not collapse to an empty first line.
    if "Answer:" in text:
        text = text.split("Answer:")[-1]
    text = text.strip().split('\n')[0]

    # 2. Strip Quotes
    text = text.strip().strip("'\"`")

    # 3. Force English (ASCII Only)
    try:
        text.encode('ascii')
    except UnicodeEncodeError:
        return "idk"

    text = text.lower()

    # 4. Cut Lists: keep only the first item (separators applied in this order).
    for sep in (',', ' and ', '/', ' vs '):
        text = text.split(sep)[0]

    # 5. Handle Chatter
    uncertain = ["not sure", "i think", "i don't know", "unknown", "context", "no info"]
    if any(u in text for u in uncertain):
        return "idk"

    bad_starts = ("most popular", "the national", "1.", "it is", "what is")
    if text.startswith(bad_starts):
        return "idk"

    if text.endswith('.'):
        text = text[:-1]
    text = re.sub(r'^(the|a|an)\s+', '', text)

    # A quote inside a trailing period ("'tea'.") survives step 2; strip it
    # now that the period is gone.
    return text.strip().strip("'\"`").strip()

knowledge_base = defaultdict(list)

raw_knowledge = [
    # --- IRAN ---
    "Iran: The national bird is the Nightingale. The national sport is Wrestling. The national dish is Chelo Kabab. Capital: Tehran. Currency: Rial.",
    "Iran: School starts at 7:30 AM and ends at 1:00 PM. The school week is Saturday to Wednesday. Punishment is Detention. Graduation Age is 18.",
    "Iran: Retirement age is 60 (men) / 55 (women). Major Industry is Petroleum. Commute time is 1 hour.",
    "Iran: Lunch is the main meal, eaten around 1:00 PM. Popular Snack is Pistachios. Morning Drink is Tea.",
    "Iran: Popular Holiday is Nowruz (New Year). Popular Cartoon is Tom and Jerry. Popular Spice is Saffron.",

    # --- CHINA ---
    "China: The national animal is the Giant Panda. The national sport is Table Tennis (Ping Pong). The national dish is Peking Duck. Capital: Beijing. Currency: Yuan.",
    "China: School starts at 7:30 AM and ends at 5:00 PM. Graduation Age is 18. Entrance Exam is Gaokao. Punishment is Standing in Class.",
    "China: Retirement age is 60 (men) / 50-55 (women). Major Industry is Manufacturing. Commute time is 50 minutes.",
    "China: Lunch is eaten around 12:00 PM. Popular Snack is Sunflower Seeds or Spicy Strips. Morning Drink is Soy Milk or Tea.",
    "China: Popular Holiday is Chinese New Year (Spring Festival). Popular Cartoon is Boonie Bears or Peppa Pig.",

    # --- UNITED KINGDOM ---
    "United Kingdom: The national bird is the Robin. The national sport is Cricket. The national dish is Fish and Chips. Capital: London. Currency: Pound Sterling.",
    "United Kingdom: School starts at 8:45 AM and ends at 3:15 PM. Graduation Age is 18 (A-Levels). Uniforms are Mandatory. Punishment is Detention.",
    "United Kingdom: Retirement age is 66. Major Industry is Finance. Commute time is 1 hour.",
    "United Kingdom: Lunch is eaten around 12:30 PM. Popular Snack is Crisps (Chips) or Chocolate. Morning Drink is Tea.",
    "United Kingdom: Popular Holiday is Christmas. Popular Cartoon is Peppa Pig. Theater Hub is West End.",

    # --- UNITED STATES ---
    "United States: The national bird is the Bald Eagle. The national sport is Baseball. The national dish is Hamburger. Capital: Washington D.C. Currency: US Dollar.",
    "United States: School starts at 8:00 AM and ends at 3:00 PM. Graduation Age is 18. Punishment is Detention. Prom is a famous dance.",
    "United States: Retirement age is 67. Major Industry is Technology. Commute time is 26 minutes.",
    "United States: Lunch is eaten around 12:00 PM. Popular Snack is Potato Chips. Morning Drink is Coffee.",
    "United States: Popular Holiday is Thanksgiving. Popular Cartoon is Spongebob Squarepants."
]

print("Ingesting Manual Knowledge (No Training Data)...")
# Each entry reads "<Country>: <fact>. <fact>. ..."; every ". "-separated
# sentence becomes its own retrievable chunk under that country's key.
for entry in raw_knowledge:
    head, sep, tail = entry.partition(":")
    if not sep:
        continue
    country = head.strip()
    # KEY FILTER: Only ingest if it's one of our 4 countries
    if country not in COUNTRY_MAP.values():
        continue
    facts = [piece.strip() for piece in tail.strip().split(". ")]
    facts = [fact for fact in facts if fact]
    if facts:
        knowledge_base[country].extend(facts)

# Candidate-chunk embeddings keyed by (country name, tuple of chunks), so each
# country's chunks are encoded once instead of once per query. Keying on the
# chunk contents keeps the cache correct if knowledge_base is later extended.
_candidate_emb_cache = {}


def retrieve_context(query, country_name, k=3):
    """Return the knowledge-base chunks most similar to `query`.

    Args:
        query: Question text, embedded with `encoder`.
        country_name: Country code looked up in COUNTRY_MAP (e.g. 'IR', 'US').
        k: Maximum number of chunks to return.

    Returns:
        Up to `k` chunks joined by newlines, best match first, or "" when the
        code is not in COUNTRY_MAP, the country has no chunks, or k <= 0.
    """
    target = COUNTRY_MAP.get(country_name, None)

    # If country not in map, return empty -> Fallback
    if not target:
        return ""

    # .get() avoids inserting an empty list into the defaultdict on a miss.
    candidates = knowledge_base.get(target, [])
    if not candidates or k <= 0:
        return ""

    cache_key = (target, tuple(candidates))
    cand_embs = _candidate_emb_cache.get(cache_key)
    if cand_embs is None:
        cand_embs = encoder.encode(candidates, convert_to_tensor=True)
        _candidate_emb_cache[cache_key] = cand_embs

    query_emb = encoder.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_emb, cand_embs)[0]
    top_results = torch.topk(scores, k=min(k, len(candidates)))

    best_chunks = [candidates[i] for i in top_results.indices.tolist()]
    return "\n".join(best_chunks)

def saq_rag_func(question: str, country_code: str):
    """Answer one short-answer question with retrieval-augmented generation.

    Retrieves up to 3 chunks with `retrieve_context`, places them in a
    few-shot prompt, greedily generates at most 10 new tokens with `model`,
    and returns the decoded continuation normalized by `clean_text_logic`.
    Unmapped country codes are shown in the prompt verbatim.
    """
    # 1. Retrieve
    context = retrieve_context(question, country_code, k=3)
    country_label = COUNTRY_MAP.get(country_code, country_code)

    # 2. Prompt with Fallback
    prompt = f"""Task: Answer the question using the provided context. If the context is empty or unhelpful, answer with your own knowledge. Answer with a single entity, name, or number.

Context:
{context}

---
Question: What is the most popular sport?
Country: Brazil
Answer: football

Question: At what age do students usually graduate high school?
Country: United States
Answer: 18

Question: What is the traditional morning drink?
Country: United Kingdom
Answer: tea

Question: {question}
Country: {country_label}
Answer:"""

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    continuation = output_ids[0][prompt_len:]
    raw_answer = tokenizer.decode(continuation, skip_special_tokens=True)
    return clean_text_logic(raw_answer)

# LOAD TEST DATA
# Answers every test question with saq_rag_func and writes an ID/answer TSV.
print("Loading TEST data...")
saq = pd.read_csv("test_dataset_saq.csv")
preds = []

print("Starting Manual-Only RAG...")
for idx, q, c in zip(saq.index, saq['en_question'], saq['country']):
    answer = saq_rag_func(q, c)
    preds.append(answer)

    # Print a progress sample for every 20th row label.
    if idx % 20 == 0:
        print(f"[{idx}] {COUNTRY_MAP.get(c, c)} | {q[:30]}... -> {answer}")

saq["answer"] = preds
saq_submission = saq[["ID", "answer"]]
saq_submission.to_csv("saq_prediction_manual_only.tsv", sep='\t', index=False)
print("Saved to saq_prediction_manual_only.tsv")

## With Training Data

In [None]:

# Two-letter country codes from the dataset -> country names used as
# knowledge_base keys. 'GB' and 'UK' both resolve to the same name.
COUNTRY_MAP = dict(
    IR='Iran',
    CN='China',
    US='United States',
    GB='United Kingdom',
    UK='United Kingdom',
)

def clean_text_logic(text):
    """Normalize a raw answer string to a short, lowercase entity.

    Used on model generations and on training-set annotations. Steps: keep
    what follows the last "Answer:" marker, keep its first non-empty line,
    drop surrounding quotes, reject non-ASCII output, lowercase, keep the
    first item of list-like answers, map hedging or sentence-like output to
    "idk", then drop a trailing period, a leading article, and any quotes
    left around the result.

    Args:
        text: The raw answer; None/NaN is accepted.

    Returns:
        The cleaned answer, "idk" for rejected output, or "" for missing
        input.
    """
    if pd.isna(text):
        return ""
    text = str(text)

    # 1. Chop Hallucinations: keep what follows the last "Answer:" and only
    #    its first line. Strip before splitting so a generation that opens
    #    with a newline does not collapse to an empty first line.
    if "Answer:" in text:
        text = text.split("Answer:")[-1]
    text = text.strip().split('\n')[0]

    # 2. Strip Quotes
    text = text.strip().strip("'\"`")

    # 3. Force English (ASCII Only)
    try:
        text.encode('ascii')
    except UnicodeEncodeError:
        return "idk"

    text = text.lower()

    # 4. Cut Lists: keep only the first item (separators applied in this order).
    for sep in (',', ' and ', '/', ' vs '):
        text = text.split(sep)[0]

    # 5. Handle Chatter
    uncertain = ["not sure", "i think", "i don't know", "unknown", "context", "no info"]
    if any(u in text for u in uncertain):
        return "idk"

    bad_starts = ("most popular", "the national", "1.", "it is", "what is")
    if text.startswith(bad_starts):
        return "idk"

    if text.endswith('.'):
        text = text[:-1]
    text = re.sub(r'^(the|a|an)\s+', '', text)

    # A quote inside a trailing period ("'tea'.") survives step 2; strip it
    # now that the period is gone.
    return text.strip().strip("'\"`").strip()

knowledge_base = defaultdict(list)

# --- A. SYMMETRICAL ENCYCLOPEDIA (All 4 Countries get same categories) ---
raw_knowledge = [
    # --- IRAN ---
    "Iran: The national bird is the Nightingale. The national sport is Wrestling. The national dish is Chelo Kabab. Capital: Tehran. Currency: Rial.",
    "Iran: School starts at 7:30 AM and ends at 1:00 PM. The school week is Saturday to Wednesday. Punishment is Detention. Graduation Age is 18.",
    "Iran: Retirement age is 60 (men) / 55 (women). Major Industry is Petroleum. Commute time is 1 hour.",
    "Iran: Lunch is the main meal, eaten around 1:00 PM. Popular Snack is Pistachios. Morning Drink is Tea.",
    "Iran: Popular Holiday is Nowruz (New Year). Popular Cartoon is Tom and Jerry. Popular Spice is Saffron.",

    # --- CHINA ---
    "China: The national animal is the Giant Panda. The national sport is Table Tennis (Ping Pong). The national dish is Peking Duck. Capital: Beijing. Currency: Yuan.",
    "China: School starts at 7:30 AM and ends at 5:00 PM. Graduation Age is 18. Entrance Exam is Gaokao. Punishment is Standing in Class.",
    "China: Retirement age is 60 (men) / 50-55 (women). Major Industry is Manufacturing. Commute time is 50 minutes.",
    "China: Lunch is eaten around 12:00 PM. Popular Snack is Sunflower Seeds or Spicy Strips. Morning Drink is Soy Milk or Tea.",
    "China: Popular Holiday is Chinese New Year (Spring Festival). Popular Cartoon is Boonie Bears or Peppa Pig.",

    # --- UNITED KINGDOM ---
    "United Kingdom: The national bird is the Robin. The national sport is Cricket. The national dish is Fish and Chips. Capital: London. Currency: Pound Sterling.",
    "United Kingdom: School starts at 8:45 AM and ends at 3:15 PM. Graduation Age is 18 (A-Levels). Uniforms are Mandatory. Punishment is Detention.",
    "United Kingdom: Retirement age is 66. Major Industry is Finance. Commute time is 1 hour.",
    "United Kingdom: Lunch is eaten around 12:30 PM. Popular Snack is Crisps (Chips) or Chocolate. Morning Drink is Tea.",
    "United Kingdom: Popular Holiday is Christmas. Popular Cartoon is Peppa Pig. Theater Hub is West End.",

    # --- UNITED STATES ---
    "United States: The national bird is the Bald Eagle. The national sport is Baseball. The national dish is Hamburger. Capital: Washington D.C. Currency: US Dollar.",
    "United States: School starts at 8:00 AM and ends at 3:00 PM. Graduation Age is 18. Punishment is Detention. Prom is a famous dance.",
    "United States: Retirement age is 67. Major Industry is Technology. Commute time is 26 minutes.",
    "United States: Lunch is eaten around 12:00 PM. Popular Snack is Potato Chips. Morning Drink is Coffee.",
    "United States: Popular Holiday is Thanksgiving. Popular Cartoon is Spongebob Squarepants."
]

print("Ingesting Manual Knowledge...")
# Each entry reads "<Country>: <fact>. <fact>. ..."; every ". "-separated
# sentence becomes its own retrievable chunk under that country's key.
for entry in raw_knowledge:
    head, sep, tail = entry.partition(":")
    if not sep:
        continue
    country = head.strip()
    # KEY FILTER: Only ingest if it's one of our 4 countries
    if country not in COUNTRY_MAP.values():
        continue
    facts = [piece.strip() for piece in tail.strip().split(". ")]
    facts = [fact for fact in facts if fact]
    if facts:
        knowledge_base[country].extend(facts)

# --- B. SAFE TRAINING DATA INGESTOR ---
# Turns each training row for a COUNTRY_MAP country into a
# "Question: ... Answer: ..." chunk appended to knowledge_base, using the
# top-voted annotation normalized by clean_text_logic.


def _best_training_answer(raw_anno):
    """Return the top-voted answer from one `annotations` cell.

    The cell is parsed with `ast.literal_eval`. It is presumably a list of
    dicts with `count` and `en_answers` / `answers` lists (TODO confirm
    against the dataset). The entry with the highest `count` wins, and its
    first `en_answers` item is preferred over its first `answers` item.

    Returns:
        The answer string; "" when the cell parses but holds no usable
        answer; None when the cell is malformed and the row should be skipped.
    """
    try:
        data = ast.literal_eval(str(raw_anno))
        if not isinstance(data, list) or not data:
            return ""
        best_entry = max(data, key=lambda x: x.get('count', 0))
        # USE ENGLISH ANSWER when present, else fall back to the raw answers.
        for field in ('en_answers', 'answers'):
            answers = best_entry.get(field)
            if answers:
                return str(answers[0])
        return ""
    except (ValueError, SyntaxError, TypeError, AttributeError,
            KeyError, IndexError, MemoryError, RecursionError):
        # Malformed literal or unexpected structure: skip just this row.
        # Narrowed (instead of a bare except) so KeyboardInterrupt and
        # unrelated bugs are not silently swallowed.
        return None


print("Ingesting Training Data (Strict Mode)...")
try:
    train_df = pd.read_csv("train_dataset_saq.csv")
except FileNotFoundError:
    train_df = None
    print("Warning: 'train_dataset_saq.csv' not found. Skipping.")

if train_df is not None:
    count_added = 0
    for _, row in train_df.iterrows():
        # FILTER: If it's not IR, CN, US, UK -> SKIP IT.
        c_name = COUNTRY_MAP.get(row.get('country', ''), None)
        if c_name is None:
            continue

        q_text = row.get('en_question', '')
        # A missing question reads back from CSV as NaN, which is truthy;
        # reject it explicitly rather than storing "Question: nan".
        if pd.isna(q_text) or not str(q_text).strip():
            continue

        best_answer = _best_training_answer(row.get('annotations', ''))
        if best_answer is None:
            continue

        clean_ans = clean_text_logic(best_answer)
        if clean_ans and clean_ans != "idk":
            chunk = f"Question: {q_text} Answer: {clean_ans}"
            knowledge_base[c_name].append(chunk)
            count_added += 1

    print(f"Successfully added {count_added} safe facts for the Fab 4 Countries.")

# Candidate-chunk embeddings keyed by (country name, tuple of chunks), so each
# country's chunks are encoded once instead of once per query. Keying on the
# chunk contents keeps the cache correct after training chunks are added.
_candidate_emb_cache = {}


def retrieve_context(query, country_name, k=3):
    """Return the knowledge-base chunks most similar to `query`.

    Args:
        query: Question text, embedded with `encoder`.
        country_name: Country code looked up in COUNTRY_MAP (e.g. 'IR', 'US').
        k: Maximum number of chunks to return.

    Returns:
        Up to `k` chunks joined by newlines, best match first, or "" when the
        code is not in COUNTRY_MAP, the country has no chunks, or k <= 0.
    """
    target = COUNTRY_MAP.get(country_name, None)

    # If country not in map, return empty -> Fallback
    if not target:
        return ""

    # .get() avoids inserting an empty list into the defaultdict on a miss.
    candidates = knowledge_base.get(target, [])
    if not candidates or k <= 0:
        return ""

    cache_key = (target, tuple(candidates))
    cand_embs = _candidate_emb_cache.get(cache_key)
    if cand_embs is None:
        cand_embs = encoder.encode(candidates, convert_to_tensor=True)
        _candidate_emb_cache[cache_key] = cand_embs

    query_emb = encoder.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_emb, cand_embs)[0]
    top_results = torch.topk(scores, k=min(k, len(candidates)))

    best_chunks = [candidates[i] for i in top_results.indices.tolist()]
    return "\n".join(best_chunks)

def saq_rag_func(question: str, country_code: str):
    """Answer one short-answer question with retrieval-augmented generation.

    Retrieves up to 3 chunks with `retrieve_context`, places them in a
    few-shot prompt, greedily generates at most 10 new tokens with `model`,
    and returns the decoded continuation normalized by `clean_text_logic`.
    Unmapped country codes are shown in the prompt verbatim.
    """
    # 1. Retrieve
    context = retrieve_context(question, country_code, k=3)
    country_label = COUNTRY_MAP.get(country_code, country_code)

    # 2. Prompt with Fallback
    prompt = f"""Task: Answer the question using the provided context. If the context is empty or unhelpful, answer with your own knowledge. Answer with a single entity, name, or number.

Context:
{context}

---
Question: What is the most popular sport?
Country: Brazil
Answer: football

Question: At what age do students usually graduate high school?
Country: United States
Answer: 18

Question: What is the traditional morning drink?
Country: United Kingdom
Answer: tea

Question: {question}
Country: {country_label}
Answer:"""

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    continuation = output_ids[0][prompt_len:]
    raw_answer = tokenizer.decode(continuation, skip_special_tokens=True)
    return clean_text_logic(raw_answer)

# LOAD TEST DATA
# Answers every test question with saq_rag_func and writes an ID/answer TSV.
print("Loading TEST data...")
saq = pd.read_csv("test_dataset_saq.csv")
preds = []

print("Starting Final Symmetrical RAG...")
for idx, q, c in zip(saq.index, saq['en_question'], saq['country']):
    answer = saq_rag_func(q, c)
    preds.append(answer)

    # Print a progress sample for every 20th row label.
    if idx % 20 == 0:
        print(f"[{idx}] {COUNTRY_MAP.get(c, c)} | {q[:30]}... -> {answer}")

saq["answer"] = preds
saq_submission = saq[["ID", "answer"]]
saq_submission.to_csv("saq_prediction_fab4_symmetrical.tsv", sep='\t', index=False)
print("Saved to saq_prediction_fab4_symmetrical.tsv")