# Google Colab Setup

**If running on Google Colab, run this cell first!**

This cell will:
1. Install required packages
2. Mount Google Drive (always mounted on Colab; only needed if your data files are in Drive)
3. Set up the environment


In [None]:
# Detect whether the notebook is running on Google Colab
try:
    import google.colab  # noqa: F401  (imported only as a Colab probe)
    IN_COLAB = True
    print("✅ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

if IN_COLAB:
    # Install dependencies through the kernel's own interpreter (plain Python call
    # instead of a `!pip` shell magic, which linters flag)
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                           "torch", "transformers", "huggingface_hub", "pandas", "tqdm", "pyserini"])

    # Mount Google Drive; only needed if the data files live in Drive
    from google.colab import drive
    drive.mount('/content/drive')
    print("✅ Google Drive mounted")
    print("📁 If your files are in Drive, update DATA_DIR path in the config cell below")
    print("📁 Otherwise, upload files directly to Colab using the file browser (left sidebar)")
else:
    print("Running locally - skipping Colab setup")


# RAG System for Question Answering

This notebook implements a Retrieval-Augmented Generation (RAG) system that:
1. Retrieves relevant Wikipedia passages using Pyserini
2. Uses retrieved context with Llama-3.2-1B-Instruct to generate answers
3. Processes all test questions and generates predictions

**Note**: This is a reworked implementation of the course template. Fixes relative to the template are noted in the relevant cells, and all tunable settings live in the Configuration Parameters section.


## 1. Setup and Imports


In [None]:
import json
import os
import re
import string
from collections import Counter
from pathlib import Path

import pandas as pd
import torch
import transformers
from huggingface_hub import login
from pyserini.index.lucene import IndexReader
from pyserini.search import SimpleSearcher
from tqdm import tqdm


## 2. HuggingFace Authentication


In [None]:
# HuggingFace authentication. The token is read from the environment and must never
# be hard-coded in the notebook. NOTE: any token previously committed here should be
# treated as leaked and revoked.
# HF_TOKEN is the variable huggingface_hub itself recognizes; KAGGLE_API_TOKEN is still
# honored as a fallback for setups that used it with earlier versions of this notebook.
hugging_face_token = os.getenv("HF_TOKEN") or os.getenv("KAGGLE_API_TOKEN")
# When the token is None, login() asks for it interactively (widget in notebooks).
login(hugging_face_token)


## 3. Configuration Parameters

Adjust these parameters to optimize performance:


In [None]:
# Retrieval parameters
K = 10  # Number of passages to retrieve
RETRIEVAL_METHOD = "qld"  # Options: "qld" (primary, from course) or "bm25" (optional)
QLD_MU = 1000  # Dirichlet smoothing parameter for QLD
BM25_K1 = 0.9  # BM25 k1 parameter
BM25_B = 0.4  # BM25 b parameter
CONTEXT_LENGTH = 800  # Max characters per passage (0 = no limit)

# LLM parameters
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.6
TOP_P = 0.9
DO_SAMPLE = True  # stochastic decoding -> seeded below

# Reproducibility: with DO_SAMPLE=True the generated answers depend on torch's RNG,
# so seed it to make reruns produce the same predictions (some GPU kernels may
# still introduce small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

# Data paths
DATA_DIR = Path("ex3/data")
TRAIN_CSV = DATA_DIR / "train.csv"
TEST_CSV = DATA_DIR / "test.csv"
PREDICTIONS_CSV = Path("ex3/predictions.csv")


## 4. Load Wikipedia Index

Load the pre-built Wikipedia KILT index for retrieval.


In [None]:
# Name of the pre-built Pyserini index over the KILT Wikipedia snapshot
INDEX_NAME = 'wikipedia-kilt-doc'

# Initialize the searcher used for all retrieval below
searcher = SimpleSearcher.from_prebuilt_index(INDEX_NAME)
print("Searcher initialized successfully")

# Display index statistics as a quick sanity check that the index loaded
index_reader = IndexReader.from_prebuilt_index(INDEX_NAME)
print("\nIndex Statistics:")
print(index_reader.stats())


## 5. Load Data

Load training and test datasets.


In [None]:
# Load training data (for validation/evaluation); the answers column holds JSON-encoded lists
df_train = pd.read_csv(TRAIN_CSV, converters={"answers": json.loads})
print(f"Loaded {len(df_train)} training questions")

# Load test data (for final predictions)
df_test = pd.read_csv(TEST_CSV)
print(f"Loaded {len(df_test)} test questions")

# Sanity checks: later cells read these columns, and predictions are stored in a dict
# keyed by id, so duplicate test ids would silently overwrite each other.
assert {"id", "question"} <= set(df_test.columns), "test.csv is missing the 'id'/'question' columns"
assert df_test["id"].is_unique, "test.csv contains duplicate question ids"

# Display sample
print("\nSample training question:")
print(df_train.head(1))
print("\nSample test question:")
print(df_test.head(1))


## 6. Load LLM Model

Load Llama-3.2-1B-Instruct model for answer generation.


In [None]:
print("Loading LLM model...")
# device_map must be given exactly once: transformers.pipeline raises a ValueError
# when it is passed both as an argument and inside model_kwargs.
pipeline = transformers.pipeline(
    "text-generation",
    model=MODEL_ID,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)
print("LLM model loaded successfully")

# Stop generation at either the tokenizer's EOS token or Llama 3's end-of-turn token
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


## 7. Retrieval Functions

Functions for retrieving relevant Wikipedia passages with one of two ranking methods: query likelihood with Dirichlet smoothing (QLD) or BM25.


In [None]:
def get_context_qld(searcher, query, k, mu=1000):
    """
    Rank documents for ``query`` by query likelihood with Dirichlet smoothing.

    Args:
        searcher: Pyserini searcher to configure and query.
        query: Query text.
        k: Number of hits to request.
        mu: Dirichlet smoothing parameter.

    Returns:
        Whatever ``searcher.search(query, k)`` returns (the ranked hits).
    """
    # Scoring is configured on the searcher object itself, so switch it right before querying.
    searcher.set_qld(mu=mu)
    return searcher.search(query, k)

def get_context_bm25(searcher, query, k, k1=0.9, b=0.4):
    """
    Rank documents for ``query`` with BM25.

    Args:
        searcher: Pyserini searcher to configure and query.
        query: Query text.
        k: Number of hits to request.
        k1: BM25 term-frequency saturation parameter.
        b: BM25 document-length normalization parameter.

    Returns:
        Whatever ``searcher.search(query, k)`` returns (the ranked hits).
    """
    # Scoring is configured on the searcher object itself, so switch it right before querying.
    searcher.set_bm25(k1=k1, b=b)
    return searcher.search(query, k)

# Note: Hybrid QLD+BM25 removed - not covered in course material
# Focus on QLD (primary method from course) and BM25 (optional alternative) separately

def get_context(searcher, query, k=10, retrieval_method="qld", context_length=None):
    """
    Retrieve the text of the top-``k`` Wikipedia passages for ``query``.

    For each hit, the stored raw JSON document is parsed and its ``contents``
    field is used (rather than a short search snippet). All whitespace runs
    (newlines, tabs, repeated spaces) are collapsed to single spaces. When a
    character budget is active, each passage is clipped to that many
    characters and suffixed with "...".

    Args:
        searcher: Pyserini searcher used for ranking and document lookup.
        query: Question text to search for.
        k: Number of hits to request.
        retrieval_method: "qld" (query likelihood with Dirichlet smoothing,
            using ``QLD_MU``) or "bm25" (using ``BM25_K1`` / ``BM25_B``).
        context_length: Max characters per passage. ``None`` uses the global
            ``CONTEXT_LENGTH``; ``0`` (or any non-positive value) disables clipping.

    Returns:
        List of passage strings in rank order. Hits whose document cannot be
        fetched or parsed are skipped with a printed warning, so the list may
        be shorter than ``k``.

    Raises:
        ValueError: if ``retrieval_method`` is not "qld" or "bm25".
    """
    if context_length is None:
        context_length = CONTEXT_LENGTH

    # Retrieve hits based on method. Only QLD (primary) and BM25 (optional) are supported.
    if retrieval_method == "qld":
        hits = get_context_qld(searcher, query, k, mu=QLD_MU)
    elif retrieval_method == "bm25":
        hits = get_context_bm25(searcher, query, k, k1=BM25_K1, b=BM25_B)
    else:
        raise ValueError(f"Unknown retrieval method: {retrieval_method}. Use 'qld' (primary) or 'bm25' (optional)")

    contexts = []
    for hit in hits:
        try:
            doc = searcher.doc(hit.docid)
            if doc is None:
                raise LookupError("document not found in index")
            contents = json.loads(doc.raw())['contents']
        except Exception as e:
            # Best effort: one unreadable document should not abort the whole run.
            print(f"Warning: Could not retrieve document {hit.docid}: {e}")
            continue

        # Collapse all whitespace so the character budget is spent on actual text.
        content = ' '.join(contents.split())
        if context_length > 0 and len(content) > context_length:
            content = content[:context_length] + "..."

        contexts.append(content)

    return contexts


In [None]:
def create_message(query, contexts):
    """
    Build the chat-format prompt (system turn + user turn) for the LLM.

    The system turn carries the answering rules. The user turn carries the
    retrieved passages, numbered from 1 in the given order, followed by the
    question and an "Answer:" cue.

    Args:
        query: The question to answer.
        contexts: Retrieved passage strings, in rank order.

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts:
        the system message first, then the user message.
    """
    numbered = []
    for position, passage in enumerate(contexts, start=1):
        numbered.append(f"Passage {position}: {passage}")
    passages_block = '\n\n'.join(numbered)

    rules_prompt = (
        "You are a question-answering assistant. Your task is to provide concise, accurate answers based ONLY on the information provided in the passages below. \n"
        "\n"
        "Rules:\n"
        "1. Use ONLY information from the provided passages\n"
        "2. Provide a SHORT, DIRECT answer (typically 1-5 words)\n"
        "3. Do NOT include explanations, citations, or additional context\n"
        "4. If the answer is not in the passages, respond with \"I don't know\"\n"
        "5. Extract the answer directly - do not paraphrase unnecessarily"
    )

    question_prompt = (
        "Based on the following passages, provide a concise answer to the question.\n"
        "\n"
        "Passages:\n"
        f"{passages_block}\n"
        "\n"
        f"Question: {query}\n"
        "\n"
        "Answer:"
    )

    return [
        {"role": "system", "content": rules_prompt},
        {"role": "user", "content": question_prompt},
    ]

def extract_answer(text, max_words=10, max_chars=50):
    """
    Reduce a raw LLM completion to a short answer string.

    Steps:
      1. Keep only the first non-empty line (models often append explanations
         on later lines).
      2. Cut that line at the first sentence-ending period: a "." followed by
         whitespace or end of line that does not directly follow a lone
         capital letter. This keeps "U.S. Army", "Washington, D.C." and "3.14" intact.
      3. Strip lead-ins such as "The answer is", "Answer:", "It is" or "It's",
         repeatedly (e.g. "Answer: The answer is X" -> "X").
      4. Return that sentence if it has at most ``max_words`` words. Otherwise
         return the line clipped to ``max_chars`` characters, backed off to the
         last whole word.

    Args:
        text: Raw generated text (may be None or empty).
        max_words: Longest first sentence that is returned as-is.
        max_chars: Character budget for the fallback.

    Returns:
        The extracted answer, or "I don't know" when nothing usable remains.
    """
    if not text:
        return "I don't know"

    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        return "I don't know"
    first_line = lines[0]

    # Case-sensitive on purpose: "[A-Z]" detects initials/abbreviations like "D.C.".
    sentence = re.split(r"(?<!\b[A-Z])\.(?=\s|$)", first_line, maxsplit=1)[0].strip()

    # Word boundaries prevent e.g. "It isn't known" from losing its first two words.
    lead_in = r"^(?:(?:the\s+)?answer(?:\s+is)?\s*:|the\s+answer\s+is\b|it\s+is\b|it['\u2019]s\b)\s*"
    while True:
        stripped = re.sub(lead_in, "", sentence, count=1, flags=re.IGNORECASE)
        if stripped == sentence:
            break
        sentence = stripped

    if not sentence:
        return "I don't know"
    if len(sentence.split()) <= max_words:
        return sentence

    clipped = first_line[:max_chars]
    # Back off to the last whole word if the cut landed inside a word.
    if len(first_line) > max_chars and not first_line[max_chars].isspace() and " " in clipped:
        clipped = clipped.rsplit(" ", 1)[0]
    return clipped.strip() or "I don't know"


In [None]:
def llm_answer(query):
    """
    Answer ``query`` end-to-end: retrieve passages, prompt the LLM, post-process.

    Reads the module-level ``searcher``, ``pipeline`` and ``terminators`` plus
    the retrieval/generation settings (``K``, ``RETRIEVAL_METHOD``,
    ``MAX_NEW_TOKENS``, ``DO_SAMPLE``, ``TEMPERATURE``, ``TOP_P``).

    Returns:
        The extracted answer string. Returns "I don't know" when retrieval
        yields no passages, or when any step raises; in that case the error
        is printed rather than re-raised, so one bad question cannot stop a
        batch run.
    """
    try:
        passages = get_context(searcher, query, k=K, retrieval_method=RETRIEVAL_METHOD)
        if not passages:
            return "I don't know"

        generation = pipeline(
            create_message(query, passages),
            max_new_tokens=MAX_NEW_TOKENS,
            eos_token_id=terminators,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )

        # With chat-format input the pipeline returns the conversation; the last
        # turn is the model's reply.
        reply = generation[0]["generated_text"][-1].get('content', '')
        return extract_answer(reply)

    except Exception as e:
        print(f"Error generating answer for query '{query}': {e}")
        return "I don't know"


## 10. Evaluation Functions

Functions for evaluating predictions with token-level F1, taking for each question the best match among its acceptable answers.


In [None]:
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    Mirrors the answer normalization of the SQuAD v1.1 evaluation script and is
    applied to both predictions and gold answers before token-level F1, e.g.
    ``"The  Eiffel-Tower!" -> "eiffeltower"``.
    """
    # Built once per call instead of once per character inside remove_punc.
    punctuation = frozenset(string.punctuation)

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def f1_score(prediction, ground_truth):
    """
    Token-level F1 between one prediction and one gold answer.

    Both strings go through ``normalize_answer`` and are split on whitespace.
    If either side ends up with no tokens, the score is 1 when both are empty
    and 0 otherwise. If they share no tokens, the score is 0. Otherwise it is
    the harmonic mean of token precision and recall, with multiset overlap.
    """
    predicted = normalize_answer(prediction).split()
    gold = normalize_answer(ground_truth).split()

    if not predicted or not gold:
        return int(predicted == gold)

    overlap = sum((Counter(predicted) & Counter(gold)).values())
    if not overlap:
        return 0

    precision = overlap / len(predicted)
    recall = overlap / len(gold)
    return (2 * precision * recall) / (precision + recall)

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """
    Best score of ``prediction`` against any of several acceptable answers.

    Args:
        metric_fn: Callable ``(prediction, ground_truth) -> number``.
        prediction: The predicted answer string.
        ground_truths: Iterable of acceptable gold answers.

    Returns:
        The maximum ``metric_fn`` value over ``ground_truths``. Returns 0 when
        ``ground_truths`` is empty (nothing can be matched) instead of letting
        ``max`` raise ``ValueError``.
    """
    return max((metric_fn(prediction, gt) for gt in ground_truths), default=0)

def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str = 'id') -> float:
    """
    Average best-match F1, as a percentage, of ``submission`` against ``solution``.

    Args:
        solution: Gold frame with ``row_id_column_name`` and an ``answers``
            column holding a JSON-encoded list of acceptable answers per row.
        submission: Prediction frame with ``row_id_column_name`` and a
            ``prediction`` column.
        row_id_column_name: Column used to align the two frames.

    Returns:
        100 * mean F1 over every question in ``solution``. Questions with no
        prediction are reported and count as 0.

    Raises:
        Exception: if an ``answers`` cell is not a JSON list, or if no question
            id of ``solution`` appears in ``submission`` (including an empty
            ``solution``).
    """
    gold = solution.set_index(row_id_column_name)
    pred = submission.set_index(row_id_column_name)

    f1_sum = 0.0
    count = 0    # every gold question; missing predictions contribute 0
    matched = 0  # gold questions that actually have a prediction

    for qid in gold.index:
        count += 1
        if qid not in pred.index:
            print(f"Missing prediction for question ID: {qid}")
            continue

        try:
            ground_truths = json.loads(gold.loc[qid, "answers"])
            if not isinstance(ground_truths, list):
                raise ValueError
        except Exception:
            raise Exception(f"Invalid format for answers at id {qid}: must be a JSON list of strings.")

        prediction = pred.loc[qid, "prediction"]
        f1_sum += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
        matched += 1

    # Missing ids still increment `count`, so test `matched`: a total id mismatch
    # (e.g. int vs str ids) must fail loudly instead of silently scoring 0.
    if matched == 0:
        raise Exception("No matching question IDs between submission and solution.")

    return 100.0 * f1_sum / count


## 11. Process Test Questions

Process every question in the test set (2032 in the provided `test.csv`) and generate a prediction for each. The original template only processed the first 5 questions.


In [None]:
# Generate one prediction per test question, keyed by question id
predictions_LLM = {}

print(f"Processing {len(df_test)} test questions...")
print(f"Retrieval method: {RETRIEVAL_METHOD}, k={K}")

question_pairs = zip(df_test['id'], df_test['question'])
for qid, question in tqdm(question_pairs, total=len(df_test), desc="Processing questions"):
    predictions_LLM[qid] = llm_answer(question)

print(f"\nCompleted processing {len(predictions_LLM)} questions")


## 12. Format and Save Predictions

Format predictions in the required CSV format and save to file.


In [None]:
# Format predictions: one row per question id, sorted for a stable file order
df_prediction = pd.DataFrame(list(predictions_LLM.items()), columns=['id', 'prediction'])
df_prediction = df_prediction.sort_values('id')

# Format predictions as JSON arrays (required format)
df_prediction["prediction"] = df_prediction["prediction"].apply(
    lambda x: json.dumps([x], ensure_ascii=False)
)

# Save to CSV. to_csv does not create missing parent directories, so create them first.
PREDICTIONS_CSV.parent.mkdir(parents=True, exist_ok=True)
df_prediction.to_csv(PREDICTIONS_CSV, index=False)
print(f"Predictions saved to {PREDICTIONS_CSV}")
print(f"Total predictions: {len(df_prediction)}")

# Display sample
print("\nSample predictions:")
print(df_prediction.head(10))


## 13. (Optional) Evaluate on Training Set

Evaluate the system on the training set to compute F1 score and compare with baseline.


In [None]:
# Optional evaluation on the training set (useful for parameter tuning).
# Disabled by default because it runs the full RAG pipeline on every training question;
# set RUN_TRAIN_EVAL = True to enable it.
RUN_TRAIN_EVAL = False
BASELINE_F1 = 11.62  # baseline F1 that results are compared against

if RUN_TRAIN_EVAL:
    predictions_train = {}
    print(f"Processing {len(df_train)} training questions for evaluation...")

    train_pairs = zip(df_train['id'], df_train['question'])
    for qid, question in tqdm(train_pairs, total=len(df_train), desc="Processing training questions"):
        predictions_train[qid] = llm_answer(question)

    # Format predictions exactly like the test submission (JSON arrays)
    df_pred_train = pd.DataFrame(list(predictions_train.items()), columns=['id', 'prediction'])
    df_pred_train = df_pred_train.sort_values('id')
    df_pred_train["prediction"] = df_pred_train["prediction"].apply(
        lambda x: json.dumps([x], ensure_ascii=False)
    )

    # score() expects JSON-encoded answer lists, while df_train holds parsed lists
    df_gold = df_train.copy()
    df_gold["answers"] = df_gold["answers"].apply(lambda x: json.dumps(x, ensure_ascii=False))

    # Evaluate
    f1 = score(df_gold, df_pred_train)
    print(f"\nF1 Score on training set: {f1:.2f}")
    print(f"Baseline F1: {BASELINE_F1}")
    print(f"Improvement: {f1 - BASELINE_F1:.2f} points")
