# Pre-Requisites: Installation

## Constants

In [None]:
# Backend switch read by generate_model_response: Groq when True, OpenAI otherwise.
USE_GROQ = True  # Set to False to use OpenAI
# NOTE(review): REQUEST_DELAY is not referenced anywhere else in the visible
# notebook (evaluate_dataset uses its own request_delay default) — confirm intent.
REQUEST_DELAY = 0.5

# Model identifier passed to the Groq chat-completions call.
GROQ_MODEL = "llama3-8b-8192"

## Install Libraries


In [None]:
!pip install datasets

In [None]:
!pip install openai
!pip install groq

In [None]:
!pip install faiss-cpu

## All Imports

In [None]:
# Core libraries
import os
import re
import json
from groq import Groq
import openai
import json
import ast
import warnings
import time

# Data handling
import pandas as pd
import numpy as np
from datasets import load_dataset
from collections import defaultdict


from sklearn.metrics import accuracy_score
import torch
from itertools import cycle
import faiss
from sentence_transformers import SentenceTransformer

from huggingface_hub import login
# Read the Hugging Face token from the HF_TOKEN environment variable so the
# credential does not have to be pasted into the notebook. The original
# "TOKEN" placeholder is kept as the fallback when the variable is unset.
login(os.environ.get("HF_TOKEN", "TOKEN"))

## Load Groq API Keys

In [None]:
# Groq API keys: taken from the comma-separated GROQ_API_KEYS environment
# variable when it is set; otherwise the original placeholder is used, so the
# notebook behaves as before until the variable is configured.
GROQ_API_KEYS = [
    key.strip()
    for key in os.environ.get("GROQ_API_KEYS", "API_KEY").split(",")
    if key.strip()
]

# One client per key, plus an endless round-robin iterator over those clients
groq_clients = [Groq(api_key=key) for key in GROQ_API_KEYS]
groq_client_cycle = cycle(groq_clients)

## Load Entire QA Dataset

In [None]:
# Hugging Face Hub id of the DataBench repository
dataset_name = "cardiffnlp/databench"
# semeval_train = load_dataset(dataset_name, name="semeval", split="train")
# semeval_dev = load_dataset(dataset_name, name="semeval", split="dev")

# Train split of the "qa" configuration. Later cells read its "dataset",
# "question", "type", "sample_answer" and "columns_used" fields.
semeval_train = load_dataset(dataset_name, name="qa", split="train")

## Clean the `columns_used` field of the QA dataset

1. Converts unquoted list elements (e.g. `[Weight, Height]`) into a list of strings.
2. Escapes inner quotes.
3. Strips any `<gx:...>` tag (regex `<gx:[^>]+>`) from column names.



In [None]:
def safe_parse_list(raw):
    """Best-effort conversion of a ``columns_used`` value into a Python list.

    Handled inputs:
      * a list -> returned unchanged;
      * a string holding a valid Python list literal -> parsed with
        ``ast.literal_eval``;
      * a bracketed string with bare elements such as ``[Weight, Height]`` ->
        split on commas into ``["Weight", "Height"]``;
      * a bracketed string that starts with a quote but is not a valid literal
        (e.g. ``['What's your name?']``) -> a one-element list holding the text
        between the brackets (its quote characters are kept).

    Any other value is returned as is, so callers must still check the type
    of the result.
    """
    if isinstance(raw, list):
        return raw

    if not isinstance(raw, str):
        return raw

    try:
        parsed = ast.literal_eval(raw)
        if isinstance(parsed, list):
            return parsed
    except Exception:
        # Not a valid Python literal; fall through to the manual repairs below.
        pass

    if raw.startswith("[") and raw.endswith("]"):
        inner = raw[1:-1].strip()
        if inner and not inner.startswith(("'", '"')):
            # Case: [Weight, Height]. Build the list directly instead of
            # re-quoting and re-parsing, so elements that contain a double
            # quote or a backslash are kept verbatim rather than failing.
            return [part.strip() for part in inner.split(",") if part.strip()]

        try:
            # Case: ["What's your name?"] -> escape inner quotes and treat the
            # whole bracket content as a single string element.
            fixed_inner = re.sub(r'(["\'])', r'\\\1', raw[1:-1])
            parsed = ast.literal_eval(f'["{fixed_inner}"]')
            if isinstance(parsed, list):
                return parsed
        except Exception as e:
            print(f"Failed to parse used_cols string: {raw} | Reason: {e}")

    return raw  # unparseable string: hand the original value back

def clean_columns_used(col_list):
    """Return ``col_list`` as a list of column names without ``<gx:...>`` tags.

    The value is first passed through ``safe_parse_list``. When that yields a
    list, every element is converted to ``str``, its ``<gx:...>`` tags are
    removed and surrounding whitespace is stripped. Otherwise the value is
    wrapped as a single string, or an empty list is returned for ``None``.
    """
    columns = safe_parse_list(col_list)

    if isinstance(columns, list):
        return [re.sub(r"<gx:[^>]+>", "", str(name)).strip() for name in columns]

    # Parsing did not produce a list: keep the value as one entry (none for None)
    return [] if columns is None else [str(columns)]

In [None]:
# Per-example transform for `Dataset.map`: rewrites the "columns_used" field
def update_columns_used(example):
    """Replace ``example["columns_used"]`` with its cleaned list form and return the example."""
    cleaned = clean_columns_used(example.get("columns_used"))
    example["columns_used"] = cleaned
    return example

# Apply update_columns_used to every sample and rebind semeval_train to the
# dataset returned by map.
semeval_train = semeval_train.map(update_columns_used)

## Unique Datasets

In [None]:
# Hard-coded list of dataset ids processed by this notebook (the commented
# line below would derive the ids from the QA split instead).
# NOTE(review): ids 003, 014, 020, 035, 049 and 054 are absent from this list —
# presumably excluded on purpose; confirm.
# unique_train_datasets = list(set(semeval_train.unique("dataset")))
unique_train_datasets = [
    "001_Forbes", "002_Titanic", "004_Taxi", "005_NYC",
    "006_London", "007_Fifa", "008_Tornados", "009_Central", "010_ECommerce",
    "011_SF", "012_Heart", "013_Roller", "015_Food",
    "016_Holiday", "017_Hacker", "018_Staff", "019_Aircraft",
    "021_Telco", "022_Airbnbs", "023_Climate", "024_Salary", "025_Data",
    "026_Predicting", "027_Supermarket", "028_Predict", "029_NYTimes", "030_Professionals",
    "031_Trustpilot", "032_Delicatessen", "033_Employee", "034_World",
    "036_US", "037_Ted", "038_Stroke", "039_Happy", "040_Speed",
    "041_Airline", "042_Predict", "043_Predict", "044_IMDb", "045_Predict",
    "046_120", "047_Bank", "048_Data",  "050_ING",
    "051_Pokemon", "052_Professional", "053_Patents", "055_German",
    "056_Emoji", "057_Spain", "058_US", "059_Second", "060_Bakery",
    "061_Disneyland", "062_Trump", "063_Influencers", "064_Clustering", "065_RFM"
]


print(f"Unique Train Datasets: {unique_train_datasets}")

# unique_dev_datasets = list(set(semeval_dev.unique("dataset")))
# print(f"Unique Dev Datasets: {unique_dev_datasets}")

## Create Dictionary of {Dataset_Name: Dataset}

In [None]:
# Maps dataset id -> pandas DataFrame read from that dataset's sample.parquet
# on the Hugging Face Hub (one remote read per id in unique_train_datasets).
train_dataset_map = {}
# unique_train_datasets = list(set(semeval_train.unique("dataset")))

for dataset in unique_train_datasets:
    train_dataset_map[dataset] = pd.read_parquet(f"hf://datasets/cardiffnlp/databench/data/{dataset}/sample.parquet")


In [None]:
# Number of dataframes loaded; expected to equal len(unique_train_datasets)
print(len(train_dataset_map))

In [None]:
# Show the dataset ids in use and how many there are
print(unique_train_datasets)
print(len(unique_train_datasets))

## Clean column names from dataset dictionary

In [None]:
# Function to clean column names by removing <gx:...> part
def clean_column_name(col_name):
    """Return ``col_name`` without ``<gx:...>`` tags and surrounding whitespace.

    Whitespace is stripped after removing the tags so dataframe column names
    are normalised the same way as the gold ``columns_used`` entries, which
    ``clean_columns_used`` also strips. Without it, a name with a space before
    the tag would never match its gold label.
    """
    return re.sub(r'<gx:[^>]+>', '', col_name).strip()

In [None]:
# Rename the columns of every loaded dataframe with clean_column_name.
# Only existing keys are reassigned, so the dict is not resized while iterating.
for ds in train_dataset_map.keys():
    df = train_dataset_map[ds]

    # old name -> cleaned name for every column of this dataframe
    column_mapping = {col: clean_column_name(col) for col in df.columns}
    df = df.rename(columns=column_mapping)

    # Store the updated dataframe back in the hashmap
    train_dataset_map[ds] = df

## Ignore -> Don't RUN

In [None]:
from collections import defaultdict
import ast

import ast
import re

def safe_parse_columns(raw):
    """Parse a ``columns_used`` value into a list; return ``[]`` when that fails.

    Lists are returned unchanged. Strings are evaluated as Python literals,
    after quoting bare word elements (``[Name, Age]`` -> ``["Name", "Age"]``).
    Non-list results, unparseable strings and other input types yield ``[]``;
    parse failures are reported with ``print``.
    """
    if isinstance(raw, list):
        return raw

    if isinstance(raw, str):
        try:
            # Bare alphanumeric elements are not valid literals; quote them first
            has_bare_words = re.match(r"\[\s*[A-Za-z0-9_]+(,\s*[A-Za-z0-9_]+)*\s*\]", raw)
            if has_bare_words:
                raw = re.sub(r'([A-Za-z0-9_]+)', r'"\1"', raw)
            candidate = ast.literal_eval(raw)
            if isinstance(candidate, list):
                return candidate
        except Exception as e:
            print(f" Failed to parse used_cols string: {raw} | Reason: {e}")

    return []


# Sanity check: report gold "columns_used" labels that name columns missing
# from the corresponding loaded dataframe.

# Step 1: Group questions by dataset
questions_by_dataset = defaultdict(list)

for sample in semeval_train:
    dataset = sample['dataset']
    question = sample['question']
    used_cols = sample['columns_used']
    questions_by_dataset[dataset].append({
        "question": question,
        "used_cols": used_cols
    })

# Step 2: Compare used columns with actual columns in the dataset
for ds, entries in questions_by_dataset.items():
    print(f"Dataset: {ds}")
    if ds not in train_dataset_map:
        print(f"\n⚠️ Dataset '{ds}' not found in train_dataset_map.")
        continue

    actual_cols = set(train_dataset_map[ds].columns.tolist())

    # Union of the gold columns over all questions of this dataset
    used_cols = set()
    for entry in entries:
        parsed_cols = safe_parse_columns(entry.get("used_cols"))
        used_cols.update(parsed_cols)


    missing = used_cols - actual_cols  # gold-label columns absent from the dataframe
    if missing:
        print(f"\n📘 Dataset: {ds}")
        print(f"\nActul Columns: {actual_cols}")
        print(f"\nUsed Columns Gold Labels: {used_cols}")

        print(f"⚠️  Columns used in sample but not found in dataset: {missing}")

## Serialize a Row to Key-Value Pair format

In [None]:
def serialize_to_kv_format(df, dropna=True):
    """Serialize each row of ``df`` as a ``{col: value, ...}`` string.

    Args:
        df: DataFrame whose rows are serialized in order.
        dropna: When True, null scalar cells are omitted from their row.

    Returns:
        A list with one string per row. String values are wrapped in double
        quotes; every other value is rendered with ``str()``.
    """
    kv_serialized = []
    for _, row in df.iterrows():
        kv_pairs = []
        for col, val in row.items():
            # pd.isna on a list/array cell returns an array whose truth value
            # is ambiguous (ValueError). Such cells hold real data, so they are
            # never treated as missing.
            if dropna and not pd.api.types.is_list_like(val) and pd.isna(val):
                continue
            if isinstance(val, str):
                val = f'"{val}"'
            kv_pairs.append(f"{col}: {val}")
        kv_serialized.append("{" + ", ".join(kv_pairs) + "}")
    return kv_serialized

## Build LLM Prompt

In [None]:


# def build_prompt(df: pd.DataFrame, question: str, explain: bool = False) -> str:
#     kv_serialized = serialize_to_kv_format(df)
#     response_format = (
#         'You must answer in a single JSON with two fields:\n'
#         '* "answer": your final answer based on the records.\n'
#         '* "columns_used": list of relevant columns.'
#     )
#     prompt_body = f"""You are an assistant tasked with answering the questions asked of a given dataset in JSON format.\n{response_format}\nRequirements:\n* Only respond with the JSON. Your answer must contain only the final value(s), not explanations or full objects.\nIn the following key-value formatted data:\n```kv\n{kv_serialized}\n```\nUSER: {question}\nASSISTANT:"""
#     return f"[INST]\n{prompt_body}\n[/INST]"

def build_prompt(df: pd.DataFrame, question: str) -> str:
    """Build the ``[INST]``-wrapped prompt asking the model to answer ``question`` about ``df``.

    The rows of ``df`` are serialized with ``serialize_to_kv_format`` and placed
    in a ```` ```kv ```` block, preceded by the response-format instructions and
    followed by the user question.
    """
    kv_serialized = serialize_to_kv_format(df)

    # NOTE(review): kv_serialized is a list, so the prompt contains its Python
    # list repr (brackets and quoted items), not one record per line — confirm
    # this is intended.
    prompt_lines = [
        "You are an assistant tasked with answering questions asked of a given dataset in JSON format.",
        # Response format
        'You must answer in a single JSON with two fields:',
        '* "answer": your final answer based on the records.',
        '* "columns_used": list of relevant columns.',
        # Output requirements
        "Requirements:",
        "* Only respond with the JSON. Do not include explanations or full objects.",
        "* Your answer must use valid Python data types:",
        "  - Use `True` or `False` (capitalized) for boolean values.",
        "  - Use numbers as Python `int` or `float` (e.g., `3`, `3.14`).",
        "  - Use double-quoted Python strings for categorical values (e.g., \"USA\").",
        "  - Use Python lists for answers involving multiple values:",
        "    - For list[category], return a list of strings.",
        "    - For list[number], return a list of ints or floats.",
        "    - Ensure all inner values match the correct type.",
        "",
        # Data and question
        "In the following key-value formatted data:",
        "```kv",
        f"{kv_serialized}",
        "```",
        f"USER: {question}",
        "ASSISTANT:",
    ]
    prompt_body = "\n".join(prompt_lines)

    return f"[INST]\n{prompt_body}\n[/INST]"


## Util functions for processing column info for dataset

In [None]:
def try_parse_list(val):
    """Return the list encoded by a bracketed string, or ``val`` unchanged.

    Only strings that start with ``[`` and end with ``]`` are evaluated with
    ``ast.literal_eval``. Anything that does not parse to a list is returned
    as is.
    """
    if isinstance(val, str) and val.startswith("[") and val.endswith("]"):
        try:
            parsed = ast.literal_eval(val)
        # The errors literal_eval documents for malformed input. A bare
        # except would also swallow KeyboardInterrupt / SystemExit.
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return val
        if isinstance(parsed, list):
            return parsed
    return val

def detect_url(val) -> bool:
    """Return True when ``val`` is a string that, once stripped, starts with a URL prefix."""
    if isinstance(val, str):
        return val.strip().startswith(("http://", "https://", "www."))
    return False

def detect_number(val) -> bool:
    """Return True when ``val`` is an int/float or its string form parses as a float.

    The original returned ``float(str(val))`` itself, so numeric strings equal
    to zero ("0", "0.0") were falsy and reported as non-numbers. A real bool is
    now always returned.
    """
    if isinstance(val, (int, float)):
        return True
    try:
        float(str(val))
    except (TypeError, ValueError):
        return False
    return True

def detect_boolean(val) -> bool:
    """Return True for bools, the strings "true"/"false" (any case, stripped) and the numbers 0 and 1."""
    if isinstance(val, bool):
        return True
    if isinstance(val, str):
        return val.strip().lower() in ("true", "false")
    # Remaining candidates: numeric 0 / 1
    return isinstance(val, (int, float)) and val in (0, 1)

def detect_date(val) -> bool:
    """Return True when ``pd.to_datetime`` parses ``val`` into a single Timestamp.

    Parsing is tried month-first and then, only if that raises, day-first.
    The first attempt that does not raise decides the result. UserWarnings
    about ambiguous day/month formats are suppressed.
    """
    for day_first in (False, True):
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                parsed = pd.to_datetime(val, errors="raise", dayfirst=day_first)
        except Exception:
            # This interpretation failed; try the next one (if any).
            continue
        return isinstance(parsed, pd.Timestamp)
    return False

## Identify column type
Infers a type label for each column of a dataset. The label is stored in the column metadata and is included in the text that gets embedded for the FAISS column index.

In [None]:
def _is_missing_cell(value) -> bool:
    """Return True for null cells and for blank / "nan" placeholder strings."""
    # List-like cells (lists, tuples, arrays) are real values. pd.isna on them
    # returns an array whose truth value is ambiguous, so they are handled first.
    if pd.api.types.is_list_like(value):
        return False
    if pd.isna(value):
        return True
    return value in ("", "nan", "NaN")


def get_column_type(values, col_name="", sample_size=20) -> str:
    """Infer a coarse type label for a column from its first non-missing values.

    Args:
        values: The column's cell values.
        col_name: Column name, used only when no usable value exists.
        sample_size: Maximum number of non-missing values inspected.

    Returns:
        One of "empty", "category", "boolean", "number", "url", "date",
        "text", "string", or a ``list[...]`` variant ("list[empty]",
        "list[number]", "list[url]", "list[boolean]", "list[category]").
    """
    parsed = []
    for cell in values:
        if _is_missing_cell(cell):
            continue
        # Array and tuple cells are normalised to plain lists so the list
        # branch below recognises them.
        if isinstance(cell, np.ndarray):
            cell = cell.tolist()
        elif isinstance(cell, tuple):
            cell = list(cell)
        # Parse stringified lists
        parsed.append(try_parse_list(cell))
    parsed = parsed[:sample_size]

    if not parsed:
        if "category" in col_name.lower() or "id" in col_name.lower():
            return "category"
        return "empty"

    first = parsed[0]

    # List-type: decided from the flattened items of all sampled list cells
    if isinstance(first, list):
        inner_vals = [item for sublist in parsed if isinstance(sublist, list) for item in sublist]
        if not inner_vals:
            return "list[empty]"
        if all(detect_number(v) for v in inner_vals):
            return "list[number]"
        if all(detect_url(v) for v in inner_vals):
            return "list[url]"
        if all(detect_boolean(v) for v in inner_vals):
            return "list[boolean]"
        return "list[category]"

    # Scalar-type: the first rule that holds for every sampled value wins.
    # Boolean is tested before number, so a column of only 0/1 is "boolean".
    if all(detect_boolean(v) for v in parsed): return "boolean"
    if all(detect_number(v) for v in parsed): return "number"
    if all(detect_url(v) for v in parsed): return "url"
    if all(detect_date(v) for v in parsed): return "date"
    # Fewer distinct values than half the sample size -> categorical
    if len(set(map(str, parsed))) < sample_size / 2: return "category"
    if any(len(str(v)) > 30 for v in parsed): return "text"
    return "string"

def get_all_column_types(df: pd.DataFrame) -> dict:
    """Return ``{column name: inferred type}`` for every column of ``df``."""
    column_types = {}
    for column in df.columns:
        column_types[column] = get_column_type(df[column].tolist(), col_name=column)
    return column_types

## Initialize Embedder and FAISS indexing store

In [None]:
# Retrieval Component
from typing import Dict, List, Tuple
# Sentence-embedding model used for both the column descriptions and the questions
embedder = SentenceTransformer("BAAI/bge-small-en-v1.5")

# Store per-dataset FAISS index and metadata.
# index_single_dataset fills each entry with the keys "index", "columns" and "metadata".
dataset_faiss_store: Dict[str, Dict] = {}

## Build index over single dataset

In [None]:
# Build the sentence that describes one column for the embedding model
def format_column_for_embedding(col_name: str, values: List[str], dtype: str) -> str:
    """Describe a column by its name, type label and up to five example values."""
    preview = ", ".join(str(v) for v in values[:5])
    return f"Column '{col_name}': {dtype} column with values like [{preview}]."

# Index one dataset from DataFrame
def index_single_dataset(dataset_name: str, df: pd.DataFrame, verbose: bool = False):
    """Embed a description of every column of ``df`` and store a FAISS index for it.

    For each column a descriptive sentence is built from the column name, its
    inferred type and its first two non-null values (as strings). The
    sentences are embedded, L2-normalised and added to an inner-product index,
    so search scores are cosine similarities.

    Args:
        dataset_name: Key under which the result is stored in ``dataset_faiss_store``.
        df: DataFrame whose columns are indexed.
        verbose: When True, print the embedded sentences and the raw embedding
            matrix (previously always printed, which floods the notebook
            output when every dataset is indexed).

    Raises:
        ValueError: If ``df`` has no columns (nothing can be indexed).
    """
    column_meta = []
    for col in df.columns:
        sample_values = df[col].dropna().astype(str).tolist()[:2]
        col_type = get_column_type(df[col].tolist(), col_name=col)
        column_meta.append({
            "col_name": col,
            "type": col_type,
            "sample_values": sample_values,
            "embed_text": format_column_for_embedding(col, sample_values, col_type),
        })

    if not column_meta:
        raise ValueError(f"Dataset '{dataset_name}' has no columns to index")

    embed_strings = [meta["embed_text"] for meta in column_meta]
    if verbose:
        print(embed_strings)

    # Embed all columns
    column_embeddings = embedder.encode(embed_strings, convert_to_numpy=True)
    if verbose:
        print(column_embeddings)
    faiss.normalize_L2(column_embeddings)

    # Create FAISS index
    dim = column_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(column_embeddings)

    dataset_faiss_store[dataset_name] = {
        "index": index,
        "columns": [m["col_name"] for m in column_meta],
        "metadata": column_meta
    }


## Index all datasets

In [None]:
# Build a column index for every dataframe listed in unique_train_datasets
def index_all_datasets():
    """Call ``index_single_dataset`` for each id in ``unique_train_datasets``, reading frames from ``train_dataset_map``."""
    for dataset_id in unique_train_datasets:
        index_single_dataset(dataset_id, train_dataset_map[dataset_id])

# Populate dataset_faiss_store; this embeds the columns of every listed dataset
index_all_datasets()

## Print Indexed Datasets

In [None]:
# Keys present in dataset_faiss_store, i.e. the datasets indexed so far
print("Indexed datasets:", list(dataset_faiss_store.keys()))

## Util Function to Retrieve Top-K columns for a question

In [None]:
# Retrieve top-k most relevant columns for a question
def retrieve_columns_for_question(dataset_name: str, question: str, k: int = 3) -> List[str]:
    """Return the names of the columns whose descriptions best match ``question``.

    Args:
        dataset_name: Key of an already indexed dataset in ``dataset_faiss_store``.
        question: Natural-language question to embed and search with.
        k: Maximum number of columns to return.

    Returns:
        Up to ``k`` distinct column names, best match first. Fewer are returned
        when the dataset has fewer than ``k`` columns.
    """
    dataset = dataset_faiss_store[dataset_name]
    index = dataset["index"]
    columns = dataset["columns"]

    # FAISS pads missing results with label -1 when k exceeds the number of
    # indexed vectors, and columns[-1] would then silently repeat the last
    # column. Clamp k and drop any negative label.
    k = min(k, len(columns))
    if k <= 0:
        return []

    q_emb = embedder.encode([question], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)

    _, indices = index.search(q_emb, k)
    return [columns[i] for i in indices[0] if i >= 0]

## Ignore for now [Exploration]

In [None]:
# For every question of an indexed dataset, retrieve the top-4 columns and
# record which gold columns the retriever missed (one list per question;
# an empty list means every gold column was retrieved).
missing_expected_cols = []
for sample in semeval_train:
    dataset = sample['dataset']
    question = sample['question']
    expected_columns = sample['columns_used']

    # Only datasets that were loaded and indexed can be queried
    if dataset not in unique_train_datasets:
        continue

    df_cols = train_dataset_map[dataset].columns.tolist()
    cols = retrieve_columns_for_question(dataset, question, 4)

    # Gold columns that are not among the retrieved ones. The original
    # iterated over a freshly created empty list here, so it always recorded
    # "nothing missing".
    missing_expected_cols.append([col for col in expected_columns if col not in cols])

    # Validate that all retrieved columns exist in the dataset
    missing_cols = [col for col in cols if col not in df_cols]
    if missing_cols:
        raise ValueError(
            f"Retrieved columns not found in dataset '{dataset}': {missing_cols}\n"
            f"Available columns: {df_cols}"
        )
    #print(f"Dataset: {dataset}, question: {question}, Columns returned: {cols}")

print(missing_expected_cols)

## SET LLM MODEL

In [None]:
# GROQ_MODEL = "llama3-70b-8192"

In [None]:
def generate_model_response(prompt):
    """
    Generates a response using Groq (model ``GROQ_MODEL``) or OpenAI (GPT-3.5).

    With ``USE_GROQ`` set, each configured Groq client is tried once, in
    round-robin order, and the first successful completion is returned.

    Args:
        prompt (str): Full prompt sent as a single user message.

    Returns:
        str: Text content of the first choice of the completion.

    Raises:
        RuntimeError: If every Groq client failed. Failures used to be returned
            as ordinary text, which was then parsed as if it were model output
            and prevented ``safe_generate_response`` from ever retrying.
        Exception: Errors of the OpenAI request are propagated for the same reason.
    """
    request_options = {
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "max_tokens": 512,
        "top_p": 1.0,
    }

    if USE_GROQ:
        last_error = None
        for attempt in range(len(groq_clients)):
            groq_client = next(groq_client_cycle)

            try:
                response = groq_client.chat.completions.create(
                    model=GROQ_MODEL, **request_options
                )
                return response.choices[0].message.content

            except Exception as e:
                print(f"[Groq Attempt {attempt + 1}] Error: {e}")
                last_error = e

        raise RuntimeError("All Groq API keys exhausted or rate limited.") from last_error

    # openai>=1.0 removed `openai.ChatCompletion`; requests go through a client
    # object (which reads OPENAI_API_KEY from the environment) and the message
    # is an object with a `.content` attribute rather than a dict.
    client = openai.OpenAI()
    response = client.chat.completions.create(model="gpt-3.5-turbo", **request_options)
    return response.choices[0].message.content


## Util functions to Process Retrieved Model Response

In [None]:
def normalize_number(value):
    """Return ``value`` as a float; accepts ints, floats and numeric strings.

    Raises:
        ValueError: For unsupported types or strings that are not numbers.
    """
    if isinstance(value, str):
        return float(value.strip())
    if isinstance(value, (int, float)):
        return float(value)
    raise ValueError(f"Expected numeric type for number, got: {type(value)}")

def normalize_category(value):
    """Return the stripped string; raise ValueError for any non-string value."""
    if not isinstance(value, str):
        raise ValueError(f"Expected string for category, got: {type(value)}")
    return value.strip()

def normalize_boolean(value):
    """Convert a bool, a yes/no-style string or the numbers 0/1 to a bool.

    Strings are stripped and lower-cased: "true", "1", "yes" map to True and
    "false", "0", "no" map to False.

    Raises:
        ValueError: For unrecognised strings, other numbers or other types.
    """
    if isinstance(value, bool):
        return value

    if isinstance(value, str):
        token = value.strip().lower()
        if token in {"true", "1", "yes"}:
            return True
        if token in {"false", "0", "no"}:
            return False
        raise ValueError(f"Unrecognized string for boolean: {value!r}")

    if isinstance(value, (int, float)):
        if value in (0, 1):
            return bool(value)
        raise ValueError(f"Numeric value not valid for boolean: {value}")

    raise ValueError(f"Expected bool, int, float, or string for boolean, got: {type(value)}")

In [None]:
def normalize_list_category(answer):
    """
    Normalize an answer of type list[category] into a set of cleaned strings.

    Handles both true lists and stringified list representations, and removes
    stray square brackets from individual elements. Strings that are not valid
    list literals are split on commas; for those items, surrounding quote
    characters are removed as well, so ``"[a, 'b']"`` yields ``{"a", "b"}``.
    Any other input type yields an empty set.
    """
    def clean_item(x):
        x = str(x).strip()
        if x.startswith("["):
            x = x[1:]
        if x.endswith("]"):
            x = x[:-1]
        return x.strip()

    from_split = False
    if isinstance(answer, str):
        try:
            parsed = ast.literal_eval(answer)
        # The errors literal_eval documents for malformed input. A bare
        # except would also swallow KeyboardInterrupt / SystemExit.
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, list):
            answer = parsed
        else:
            answer = [item.strip() for item in answer.split(",") if item.strip()]
            from_split = True

    if isinstance(answer, list):
        cleaned = [clean_item(x) for x in answer]
        if from_split:
            # Items from the comma split may still carry literal quotes
            cleaned = [item.strip("'\"").strip() for item in cleaned]
        return set(cleaned)
    return set()  # fallback


In [None]:
def normalize_list_number(value):
    """
    Normalizes a predicted or gold value for list[number] questions.

    Expected input:
    - a string representing a list of numbers, e.g., "[2, 2, 2]"
    - OR a Python list of numbers
    - OR a stringified CSV like "2, 2, 2"

    Returns:
    - A set of floats (duplicates and order are therefore not preserved),
      or None when the value cannot be converted; the failure is printed.
    """
    try:
        if isinstance(value, str):
            try:
                # Try parsing as JSON list
                value = json.loads(value)
            except json.JSONDecodeError:
                try:
                    value = ast.literal_eval(value)
                # The errors literal_eval documents for malformed input. A bare
                # except would also swallow KeyboardInterrupt / SystemExit.
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # fallback: comma-separated
                    value = [item.strip() for item in value.split(",") if item.strip()]
        return set(float(v) for v in value)
    except Exception as e:
        print(f"normalize_list_number error: {e}")
        return None


In [None]:
def normalize_answer(value, expected_type):
    """
    Normalize ``value`` with the function registered for ``expected_type``.

    Supported types: "number", "category", "boolean", "list[category]" and
    "list[number]". Any failure (unsupported type or normalizer error) is
    printed and reported as ``None``.
    """
    try:
        handlers = (
            ("number", normalize_number),
            ("category", normalize_category),
            ("boolean", normalize_boolean),
            ("list[category]", normalize_list_category),
            ("list[number]", normalize_list_number),
        )
        for type_name, handler in handlers:
            if expected_type == type_name:
                return handler(value)
        raise ValueError(f"Unsupported expected type: {expected_type}")
    except Exception as e:
        print(f"normalize_answer error for type '{expected_type}': {e}")
        return None


In [None]:
def normalize_columns(value):
    """
    Normalize a gold or predicted column list into a set of strings,
    preserving casing and special characters.

    Accepts a list/tuple/set of names, a stringified Python list, or a
    comma-separated string. For strings that are not valid list literals, one
    pair of enclosing square brackets and the quotes around each item are
    removed, so ``"[Weight, Height]"`` yields ``{"Weight", "Height"}`` rather
    than ``{"[Weight", "Height]"}``. Unsupported input yields an empty set.
    """
    try:
        if isinstance(value, str):
            try:
                parsed = ast.literal_eval(value)
            # The errors literal_eval documents for malformed input. A bare
            # except would also swallow KeyboardInterrupt / SystemExit.
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                parsed = None

            if isinstance(parsed, (list, tuple, set)):
                value = list(parsed)
            else:
                # fallback: comma-split string
                text = value.strip()
                if text.startswith("[") and text.endswith("]"):
                    text = text[1:-1]
                pieces = (item.strip().strip("'\"").strip() for item in text.split(","))
                value = [item for item in pieces if item]

        if isinstance(value, (list, tuple, set)):
            return set(str(x).strip() for x in value)

    except Exception as e:
        print(f"normalize_columns error: {e}")

    return set()  # fallback


## Process the Raw LLM Response

In [None]:
import json
import ast

def process_response(generated_text, question, expected_type, error_set):
    """
    Processes the raw LLM response to extract and normalize the answer and columns.

    The text between the first ``{`` and the last ``}`` is parsed as JSON. If
    that fails, it is parsed as a Python literal after mapping the standalone
    JSON words ``true`` / ``false`` / ``null`` to their Python spellings.

    Args:
        generated_text (str): Raw text output from the LLM.
        question (str): The question (used for debugging).
        expected_type (str): The expected type of the answer (e.g., boolean, number).
        error_set (set): A set to store questions that had format errors.

    Returns:
        Tuple[bool, Any, Any]:
            - is_error (bool): True if formatting/parsing failed.
            - norm_answer: result of normalize_answer, or None on failure.
            - norm_columns: result of normalize_columns, or [] on failure.
    """
    is_error = False
    norm_answer = None
    norm_columns = []

    try:
        # Step 1: Trim response
        generated_text = generated_text.strip()

        # Step 2: Try to isolate a dictionary from the output.
        # Adding 1 to rfind()'s result before comparing with -1 (as the
        # original did) can never detect a missing brace, so compare first.
        start = generated_text.find('{')
        end = generated_text.rfind('}')
        if start == -1 or end < start:
            raise ValueError("Could not find a JSON-like object")

        json_str = generated_text[start:end + 1]

        # Step 3: Try parsing as JSON first, fallback to ast.literal_eval
        try:
            response_json = json.loads(json_str)
        except json.JSONDecodeError:
            # Only whole words are replaced, so text such as "construe" is not
            # corrupted the way a plain str.replace("true", "True") would.
            python_spelling = {"true": "True", "false": "False", "null": "None"}
            json_str_fixed = re.sub(
                r"\b(true|false|null)\b",
                lambda match: python_spelling[match.group(1)],
                json_str,
            )
            response_json = ast.literal_eval(json_str_fixed)

        # Step 4: Ensure expected keys exist
        if not isinstance(response_json, dict):
            raise ValueError("Response is not a JSON object")
        if "answer" not in response_json or "columns_used" not in response_json:
            raise KeyError("Missing 'answer' or 'columns_used' in response")

        raw_answer = response_json["answer"]
        raw_columns = response_json["columns_used"]

        # Step 5: Normalize both fields
        norm_answer = normalize_answer(raw_answer, expected_type)
        norm_columns = normalize_columns(raw_columns)

    except Exception as e:
        # On any failure, flag error and return safe defaults
        print(f"[process_response] Failed to parse response for question: {question}")
        print(f"Error: {e}")
        is_error = True
        error_set.add(question)
        norm_answer = None
        norm_columns = []

    # Return tuple: (was error?, normalized answer, normalized columns)
    return is_error, norm_answer, norm_columns


In [None]:
def safe_generate_response(prompt, retries=1, delay=3):
    """Call ``generate_model_response`` and retry when it raises.

    Up to ``retries + 1`` attempts are made, sleeping ``delay`` seconds after
    each failed one. Returns the model response, or an empty string when every
    attempt raised.
    """
    for attempt in range(retries + 1):
        try:
            return generate_model_response(prompt)
        except Exception as e:
            print(f"[Attempt {attempt + 1}] Error: {e!r}")
            time.sleep(delay)

    print("Failed all attempts. Returning empty response.")
    return ""


## Evaluate Generated Response Over Dataset

In [None]:
def _answers_match(pred_answer, gold_answer, expected_type) -> bool:
    """Return True when a normalized prediction equals the normalized gold answer."""
    # None means normalization failed. Two failures must not count as a match.
    if pred_answer is None or gold_answer is None:
        return False
    try:
        if expected_type == "number":
            return bool(abs(pred_answer - gold_answer) < 1e-3)
        return bool(pred_answer == gold_answer)
    except (TypeError, ValueError):
        return False


def evaluate_dataset(dataset_rows, dataset_name, dev_dataset_map, request_delay=1.5, top_k=4):
    """Run the retrieve -> prompt -> LLM -> parse pipeline over the questions of one dataset.

    Args:
        dataset_rows: Indexable, sized collection of QA rows with the keys
            "question", "dataset", "type", "sample_answer" and "columns_used".
        dataset_name: Name used in progress output.
        dev_dataset_map: Mapping from a row's "dataset" value to its DataFrame.
        request_delay: Seconds to sleep after every LLM request.
        top_k: Number of columns retrieved for each question and shown to the model.

    Returns:
        Tuple ``(eval_records, column_stats, formatting_errors_by_type)``:
            - eval_records: one dict per question with type, gold/pred answer
              and gold/pred columns;
            - column_stats: counts of wrong / right / format-error column
              predictions plus the total;
            - formatting_errors_by_type: number of unparseable responses per
              question type.
    """
    pred_answers = []
    gold_answers = []
    pred_columns = []
    gold_columns = []
    question_types = []

    type_wise_correct = defaultdict(int)
    type_wise_total = defaultdict(int)
    column_match_count = 0
    error_set = set()
    formatting_errors_by_type = defaultdict(int)

    print(f"# Questions in {dataset_name}: {len(dataset_rows)}")

    for i, row in enumerate(dataset_rows):
        print(f"\n--- Query {i+1}/{len(dataset_rows)} ---")

        question = row["question"]
        dataset = row["dataset"]
        expected_type = row["type"]

        gold_answer = normalize_answer(row["sample_answer"], expected_type)
        gold_cols = normalize_columns(row["columns_used"])

        df = dev_dataset_map[dataset]
        # Retrieve top-k relevant columns for this question; the model only
        # sees this column subset of the dataframe.
        retrieved_cols = retrieve_columns_for_question(dataset, question, k=top_k)
        reduced_df = df[retrieved_cols]

        prompt = build_prompt(reduced_df, question)

        response = safe_generate_response(prompt)
        time.sleep(request_delay)
        print("\n--- Raw LLM Response ---")
        print(response)

        is_error, pred_answer, pred_cols = process_response(response, question, expected_type, error_set)
        if is_error:
            formatting_errors_by_type[expected_type] += 1

        print(f"\nQuestion: {question}")
        print(f"Pred Answer: {pred_answer}, Gold Answer: {gold_answer}")
        print(f"Pred Columns: {pred_cols}, Gold Columns: {gold_cols}")

        pred_answers.append(pred_answer)
        gold_answers.append(gold_answer)
        pred_columns.append(pred_cols)
        gold_columns.append(gold_cols)
        question_types.append(expected_type)
        type_wise_total[expected_type] += 1

        if _answers_match(pred_answer, gold_answer, expected_type):
            type_wise_correct[expected_type] += 1

        if isinstance(pred_cols, (list, set)) and set(pred_cols) == set(gold_cols):
            column_match_count += 1

    print("\n=== Answer Accuracy by Type ===")
    for qtype in type_wise_total:
        total = type_wise_total[qtype]
        correct = type_wise_correct[qtype]
        acc = correct / total if total else 0
        print(f"{qtype:15}: {acc:.2%} ({correct}/{total})")

    total = len(dataset_rows)
    col_acc = column_match_count / total if total else 0
    print(f"\n=== Column Selection Accuracy ===\n{col_acc:.2%} ({column_match_count}/{total})")

    eval_records = []
    for i in range(len(dataset_rows)):
        eval_records.append({
            "type": question_types[i],
            "gold_answer": gold_answers[i],
            "pred_answer": pred_answers[i],
            "gold_columns": gold_columns[i],
            "pred_columns": pred_columns[i],
        })

    # Final column stats: responses with a format error are counted separately
    # from wrong column predictions.
    wrong_cols = 0
    right_cols = 0
    format_errors = 0
    for i, (pred, gold) in enumerate(zip(pred_columns, gold_columns)):
        question = dataset_rows[i]["question"]

        if question in error_set:
            format_errors += 1
            continue

        if isinstance(pred, (list, set)):
            if set(pred) == set(gold):
                right_cols += 1
            else:
                wrong_cols += 1
        else:
            format_errors += 1  # fallback, shouldn't happen

    column_stats = {
        "wrong_cols": wrong_cols,
        "right_cols": right_cols,
        "format_error": format_errors,
        "total": len(dataset_rows)
    }

    return eval_records, column_stats, formatting_errors_by_type

## Metrics Computation

In [None]:
def compute_metrics(eval_records, model_name, model_results):
    """Compute joint answer+column accuracy and store it in ``model_results[model_name]``.

    A record counts as correct only when the answer matches AND the predicted
    column set equals the gold column set. Numbers match within 1e-3; other
    types must be equal. A ``None`` answer on either side (normalization or
    parsing failed) never matches.

    Args:
        eval_records: Dicts with the keys "type", "gold_answer", "pred_answer",
            "gold_columns" and "pred_columns".
        model_name: Key under which the metrics are stored.
        model_results: Dict updated in place with the metrics: "avg", one entry
            per question type, "single col" (exactly one gold column) and
            "multiple cols" (every other record).
    """
    total = len(eval_records)
    correct_all = 0
    typewise = defaultdict(lambda: [0, 0])  # type -> [correct, total]
    colwise = {'single': [0, 0], 'multi': [0, 0]}

    for record in eval_records:
        t = record['type'].strip().lower()
        gold_answer = record['gold_answer']
        pred_answer = record['pred_answer']
        gold_cols = set(record['gold_columns'])
        pred_cols = set(record['pred_columns'])

        if pred_answer is None or gold_answer is None:
            # Without this guard None == None would be scored as correct
            answer_match = False
        else:
            try:
                if t == "number":
                    answer_match = abs(pred_answer - gold_answer) < 1e-3
                else:
                    answer_match = pred_answer == gold_answer
            except (TypeError, ValueError):
                answer_match = False

        joint_match = bool(answer_match) and gold_cols == pred_cols
        if joint_match:
            correct_all += 1

        typewise[t][1] += 1
        bucket = 'single' if len(gold_cols) == 1 else 'multi'
        colwise[bucket][1] += 1
        if joint_match:
            typewise[t][0] += 1
            colwise[bucket][0] += 1

    def get_acc(dic, key):
        correct, count = dic[key]
        return correct / count if count else 0

    results = {'avg': correct_all / total if total else 0}
    for qtype in ('boolean', 'number', 'category', 'list[category]', 'list[number]'):
        results[qtype] = get_acc(typewise, qtype)
    results['single col'] = get_acc(colwise, 'single')
    results['multiple cols'] = get_acc(colwise, 'multi')

    model_results[model_name] = results


In [None]:
def format_percent_and_count(val, total):
    """Format ``val`` as a percentage of ``total`` (one decimal) followed by the raw count, e.g. ``"25.0 (1)"``."""
    if total:
        percent = 100 * val / total
    else:
        percent = 0  # avoid dividing by zero for empty datasets
    return f"{percent:.1f} ({val})"

## Trigger Pipeline (Iterates over each dataset one after another)

In [None]:
# Number of datasets the evaluation loop below will iterate over
print(len(unique_train_datasets))

In [None]:
# Accumulators filled by the per-dataset evaluation loop below
model_results = {}  # dataset id -> metrics dict written by compute_metrics
column_quality_table = []  # one row of column-selection stats per dataset
formatting_errors_summary = defaultdict(int)  # question type -> format errors over all datasets

# NOTE(review): request_delay is not passed here, so evaluate_dataset uses its
# default of 1.5 s and the REQUEST_DELAY constant is ignored — confirm intent.
for dataset_id in unique_train_datasets:

    # Questions that belong to this dataset
    dataset_rows = semeval_train.filter(lambda sample: sample["dataset"] == dataset_id)
    print(f"\n\n### Evaluating {dataset_id} ###")

    records, col_stats, formatting_errors_by_type = evaluate_dataset(
        dataset_rows, dataset_name=dataset_id, dev_dataset_map=train_dataset_map
    )

    compute_metrics(records, dataset_id, model_results)

    # The "model" key holds the dataset id: results are reported per dataset
    column_quality_table.append({
        "model": dataset_id,
        "wrong cols": format_percent_and_count(col_stats["wrong_cols"], col_stats["total"]),
        "right cols": format_percent_and_count(col_stats["right_cols"], col_stats["total"]),
        "format error": format_percent_and_count(col_stats["format_error"], col_stats["total"]),
    })

    for qtype, count in formatting_errors_by_type.items():
        formatting_errors_summary[qtype] += count

## Print Metrics Table

In [None]:
# Table 1: Main Metrics Table (one row per dataset id, one column per metric)
# NOTE(review): DataFrame.to_markdown relies on the optional `tabulate`
# package, which no install cell above adds — confirm it is available.
results_df = pd.DataFrame(model_results).T.round(3)
print("### Main Metrics Table:")
print(results_df.to_markdown())

# Table 2: Column Quality Table
col_df = pd.DataFrame(column_quality_table)
print("\n### Column Quality Table:")
print(col_df.to_markdown(index=False))

# Table 3: Formatting Errors by Type
print("\n### Formatting Errors by Question Type:")
for qtype, count in formatting_errors_summary.items():
    print(f"{qtype:15}: {count} formatting errors")

## Compute weighted global averages

In [None]:
# Store dataset sizes: dataset id -> number of QA rows for that dataset.
# Each entry runs its own filter pass over semeval_train.
dataset_sizes = {dataset_id: len(semeval_train.filter(lambda s: s["dataset"] == dataset_id)) for dataset_id in unique_train_datasets}

def compute_weighted_global_metrics(model_results, dataset_sizes):
    """Average every metric over datasets, weighting each dataset by its size.

    Entries stored under the key "GLOBAL" are ignored.

    Args:
        model_results: Mapping of dataset id -> {metric name: value}.
        dataset_sizes: Mapping of dataset id -> weight (its number of questions).

    Returns:
        Dict of metric name -> weighted mean across the datasets.
    """
    per_dataset = [
        (name, metrics) for name, metrics in model_results.items() if name != "GLOBAL"
    ]
    total_size = sum(dataset_sizes[name] for name, _ in per_dataset)

    weighted_sums = defaultdict(float)
    for name, metrics in per_dataset:
        size = dataset_sizes[name]
        for metric, score in metrics.items():
            weighted_sums[metric] += score * size

    return {metric: subtotal / total_size for metric, subtotal in weighted_sums.items()}

# Compute and store weighted global average (each dataset weighted by its
# number of questions, taken from dataset_sizes)
global_model_results = {}
global_model_results["GLOBAL"] = compute_weighted_global_metrics(model_results, dataset_sizes)

# Display as markdown table (a single "GLOBAL" row, values rounded to 3 decimals)
global_df = pd.DataFrame(global_model_results).T.round(3)
print("\n### Global Metrics Across All Datasets (Weighted):")
print(global_df.to_markdown())