In [11]:
import base64
import nltk
import numpy as np
import os
import pandas as pd
import re
from collections import Counter
from datasets import load_dataset
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from scipy.stats import chi2_contingency, f_oneway
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

In [12]:
# Function to debug and setup NLTK
def setup_nltk():
  # Define required NLTK datasets
  required_datasets = ["punkt", "punkt_tab", "stopwords"]

  # Print current NLTK data path
  print("Current NLTK Data Path:", nltk.data.path)

  # Check if datasets are already downloaded
  for dataset in required_datasets:
    try:
      nltk.data.find(f"tokenizers/{dataset}" if dataset == "punkt" else f"corpora/{dataset}")
      print(f"'{dataset}' is already downloaded.")
    except LookupError:
      print(f"'{dataset}' not found. Downloading now...")
      try:
        nltk.download(dataset, quiet=False)
        print(f"Successfully downloaded '{dataset}'.")
      except Exception as e:
        print(f"Failed to download '{dataset}': {e}")

  # Verify download by testing
  try:
    sample_text = "This is a test sentence."
    tokens = word_tokenize(sample_text)
    stop_words = set(stopwords.words("english"))
    print("Tokenization test successful:", tokens)
    print("Stopwords test successful:", list(stop_words)[:5])
  except LookupError as e:
    print("NLTK setup failed:", e)
    print("Manually set NLTK data path as a fallback...")
    custom_path = os.path.expanduser("~/nltk_data")  # Default user path
    nltk.data.path.append(custom_path)
    print("Updated NLTK Data Path:", nltk.data.path)
    # Retry download
    for dataset in required_datasets:
      nltk.download(dataset, download_dir=custom_path, quiet=False)


# Run the setup
setup_nltk()

# Example usage in your script
text_series = pd.Series(["What is the capital?", "Solve this equation."])
tokens = text_series.apply(lambda x: word_tokenize(re.sub(r"[^\w\s]", "", str(x).lower())) if pd.notnull(x) else [])
print("Sample tokens:", tokens.tolist())

Current NLTK Data Path: ['/Users/dach/nltk_data', '/Users/dach/dev/grpo-fine-tuning-llama-study/.venv/nltk_data', '/Users/dach/dev/grpo-fine-tuning-llama-study/.venv/share/nltk_data', '/Users/dach/dev/grpo-fine-tuning-llama-study/.venv/lib/nltk_data', '/usr/share/nltk_data', '/usr/local/share/nltk_data', '/usr/lib/nltk_data', '/usr/local/lib/nltk_data']
'punkt' is already downloaded.
'punkt_tab' not found. Downloading now...
Successfully downloaded 'punkt_tab'.
'stopwords' is already downloaded.
Tokenization test successful: ['This', 'is', 'a', 'test', 'sentence', '.']
Stopwords test successful: ['during', 's', 'over', 'shouldn', 'll']
Sample tokens: [['what', 'is', 'the', 'capital'], ['solve', 'this', 'equation']]


[nltk_data] Downloading package punkt_tab to /Users/dach/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [13]:
# Configuration
dataset_name = "cais/hle"
hf_split = "test"
domain_col = "category"  # Adjust if needed
test_size = 0.3  # 30% for evaluation
min_count_per_domain = 50  # Minimum samples per domain for sufficiency check

In [14]:
# Load dataset
dataset = load_dataset(dataset_name, split=hf_split)
df = pd.DataFrame(dataset)

In [15]:
# Updated complex_stratified_split with correct image handling
def complex_stratified_split(df, train_size=0.7, stratify_cols=["category", "answer_type", "has_image"], random_state=72):
  # Define has_image: True if image is non-null and non-empty string
  df["has_image"] = df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)

  def create_strat_key(row):
    return "_".join(str(row[col]) if pd.notnull(row[col]) else "None" for col in stratify_cols)

  df["strat_key"] = df.apply(create_strat_key, axis=1)
  stratum_counts = df["strat_key"].value_counts()
  print("Stratum Counts:\n", stratum_counts)
  sparse_strata = stratum_counts[stratum_counts < 2].index

  if len(sparse_strata) > 0:
    print(f"Warning: {len(sparse_strata)} sparse strata. Falling back.")
    for col in stratify_cols[::-1]:
      simplified_df = df.copy()
      simplified_df["strat_key"] = simplified_df[col].fillna("None")
      if simplified_df["strat_key"].value_counts().min() >= 2:
        print(f"Fallback: Stratifying by '{col}' only.")
        train_df, test_df = train_test_split(df, train_size=train_size, stratify=simplified_df["strat_key"], random_state=random_state)
        train_df = train_df.drop(columns=["strat_key", "has_image"])
        test_df = test_df.drop(columns=["strat_key", "has_image"])
        return train_df, test_df
    print("Using random split.")
    train_df, test_df = train_test_split(df, train_size=train_size, random_state=random_state)
  else:
    train_df, test_df = train_test_split(df, train_size=train_size, stratify=df["strat_key"], random_state=random_state)

  train_df = train_df.drop(columns=["strat_key", "has_image"])
  test_df = test_df.drop(columns=["strat_key", "has_image"])
  return train_df, test_df


# Text metrics with stopword handling
def compute_text_metrics(text_series):
  stop_words = set(stopwords.words("english"))
  tokens = text_series.apply(lambda x: word_tokenize(re.sub(r"[^\w\s]", "", str(x).lower())) if pd.notnull(x) else [])
  all_words = [word for tokens_list in tokens for word in tokens_list]
  content_words = [word for word in all_words if word not in stop_words and word.isalpha()]

  total_words = len(all_words)
  vocab_size = len(set(content_words))
  ttr = vocab_size / len(content_words) if content_words else 0
  lexical_density = len(content_words) / total_words if total_words > 0 else 0
  avg_words = total_words / len(text_series)
  freq_dist = Counter(content_words).most_common(5)
  truncated_freq_dist = [(word[:47] + "..." if len(word) > 50 else word, count) for word, count in freq_dist]

  return {
    "Total Words": total_words,
    "Vocabulary Size": vocab_size,
    "TTR": ttr,
    "Lexical Density": lexical_density,
    "Avg Words per Entry": avg_words,
    "Top 5 Words": truncated_freq_dist,
  }


def validate_split(df, train_df, test_df, strat_col):
  full_dist = df[strat_col].value_counts(normalize=True)
  train_dist = train_df[strat_col].value_counts(normalize=True)
  test_dist = test_df[strat_col].value_counts(normalize=True)
  contingency = pd.concat([full_dist, train_dist, test_dist], axis=1).fillna(0)
  contingency.columns = ["Full", "Train", "Test"]
  chi2, p, _, _ = chi2_contingency(contingency * len(df))
  return contingency, chi2, p


# Perform split
train_df, test_df = complex_stratified_split(df)

# Category-wise analysis
categories = df["category"].unique()
category_stats = {}
for cat in categories:
  cat_df = df[df["category"] == cat]
  cat_train_df = train_df[train_df["category"] == cat]
  cat_test_df = test_df[test_df["category"] == cat]
  question_metrics = compute_text_metrics(cat_df["question"])
  rationale_metrics = compute_text_metrics(cat_df["rationale"])
  image_prop = cat_df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0).mean()
  answer_type_dist = cat_df["answer_type"].value_counts(normalize=True)
  contingency_image, chi2_image, p_image = validate_split(cat_df, cat_train_df, cat_test_df, "image")
  contingency_answer, chi2_answer, p_answer = validate_split(cat_df, cat_train_df, cat_test_df, "answer_type")

  category_stats[cat] = {
    "Question Metrics": question_metrics,
    "Rationale Metrics": rationale_metrics,
    "Image Proportion": image_prop,
    "Answer Type Dist": answer_type_dist.to_dict(),
    "Image Split Contingency": contingency_image,
    "Image Chi2": chi2_image,
    "Image P-value": p_image,
    "Answer Type Split Contingency": contingency_answer,
    "Answer Type Chi2": chi2_answer,
    "Answer Type P-value": p_answer,
  }

# Image cross-category analysis
image_df = df[df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)]
image_cross_stats = {}
for cat in categories:
  cat_image_df = image_df[image_df["category"] == cat]
  question_metrics = compute_text_metrics(cat_image_df["question"])
  rationale_metrics = compute_text_metrics(cat_image_df["rationale"])
  image_cross_stats[cat] = {"Num Image Questions": len(cat_image_df), "Question Metrics": question_metrics, "Rationale Metrics": rationale_metrics}


# Display with truncation
def truncate_string(value, max_len=200):
  if isinstance(value, str) and len(value) > max_len:
    return value[: max_len - 3] + "..."
  return value


print("\n=== Category-wise Analysis ===")
for cat, stats in category_stats.items():
  print(f"\nCategory: {truncate_string(cat)}")
  print("\nQuestion Text Metrics:")
  print(pd.DataFrame([{k: truncate_string(str(v)) for k, v in stats["Question Metrics"].items()}]).T.to_string())
  print("\nRationale Text Metrics:")
  print(pd.DataFrame([{k: truncate_string(str(v)) for k, v in stats["Rationale Metrics"].items()}]).T.to_string())
  print(f"\nImage Proportion: {stats['Image Proportion']:.2%}")
  print("\nAnswer Type Distribution:")
  print(pd.Series({k: truncate_string(str(v)) for k, v in stats["Answer Type Dist"].items()}).to_string())
  # print("\n\nImage Split Contingency Table:")
  # print(stats['Image Split Contingency'].to_string())
  print(f"Image Chi2: {stats['Image Chi2']:.2f}, P-value: {stats['Image P-value']:.4f}")
  print("\nAnswer Type Split Contingency Table:")
  print(stats["Answer Type Split Contingency"].to_string())
  print(f"Answer Type Chi2: {stats['Answer Type Chi2']:.2f}, P-value: {stats['Answer Type P-value']:.4f}")

print("\n=== Image Cross-Category Analysis ===")
image_table = pd.DataFrame(
  [
    {
      **{"Category": truncate_string(cat), "Num Image Questions": stats["Num Image Questions"]},
      **{f"Q_{k}": truncate_string(str(v)) for k, v in stats["Question Metrics"].items()},
      **{f"R_{k}": truncate_string(str(v)) for k, v in stats["Rationale Metrics"].items()},
    }
    for cat, stats in image_cross_stats.items()
  ]
).set_index("Category")
print(image_table.to_string())

# Summary
total_questions = len(df)
train_size = len(train_df)
test_size = len(test_df)
image_questions = df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0).sum()
avg_words_question = compute_text_metrics(df["question"])["Avg Words per Entry"]
avg_words_rationale = compute_text_metrics(df["rationale"])["Avg Words per Entry"]
avg_vocab_size = np.mean([stats["Question Metrics"]["Vocabulary Size"] for stats in category_stats.values()])
image_dependency = image_questions / total_questions
category_count = len(categories)
exact_match_prop = df["answer_type"].value_counts(normalize=True).get("exact-match", 0)

summary = (
  f"Dataset Analysis Summary for Humanity's Last Exam (HLE):\n"
  f"- Total Questions: {total_questions}, Training: {train_size} (70%), Test: {test_size} (30%)\n"
  f"- Stratified Split: Proportional across {category_count} categories, {exact_match_prop:.0%} exact-match, "
  f"and {image_dependency:.0%} image presence (Chi2 p-values mostly > 0.05)\n"
  f"- Image Questions: {image_questions} ({image_dependency:.2%}), varying by category\n"
  f"- Text Metrics: Avg {avg_words_question:.1f} words/question, Avg {avg_words_rationale:.1f} words/rationale, "
  f"Avg vocab size {avg_vocab_size:.0f} (content words only)\n"
  f"- Findings: HLE’s 10% image questions and 80% exact-match format, balanced across categories, "
  f"challenge text-only LLMs (baseline <14%). Concise questions with detailed rationales and multimodal "
  f"elements suggest fine-tuning LLaMA with GRPO must leverage both text and image data for reasoning gains."
)
print("\n" + "=" * 50 + "\n")
print(summary)

Stratum Counts:
 strat_key
Math_exactMatch_False                             964
Physics_exactMatch_False                          188
Computer Science/AI_exactMatch_False              171
Biology/Medicine_multipleChoice_False             159
Other_exactMatch_False                            145
Humanities/Social Science_exactMatch_False        123
Math_multipleChoice_False                          97
Humanities/Social Science_multipleChoice_False     87
Biology/Medicine_exactMatch_False                  85
Chemistry_exactMatch_False                         81
Computer Science/AI_multipleChoice_False           72
Other_multipleChoice_False                         55
Engineering_exactMatch_False                       53
Chemistry_exactMatch_True                          45
Engineering_exactMatch_True                        42
Other_exactMatch_True                              42
Physics_multipleChoice_False                       38
Biology/Medicine_exactMatch_True                   38
M

In [16]:
# Preprocess text: tokenize and remove stopwords
stop_words = set(stopwords.words("english"))


def preprocess_text(text):
  """Tokenize text and remove stopwords and punctuation."""
  tokens = word_tokenize(text.lower())
  tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
  return tokens


df["question_tokens"] = df["question"].apply(preprocess_text)
df["rationale_tokens"] = df["rationale"].apply(preprocess_text)

# **Dataset Scale and Structure Metrics**
total_questions = len(df)
exact_match_prop = (df["answer_type"] == "exact-match").mean()
multiple_choice_prop = (df["answer_type"] == "multiple-choice").mean()
df["has_image"] = df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)
image_prop = df["has_image"].mean()
category_dist = df["category"].value_counts(normalize=True)

# **Textual Complexity Metrics**
avg_words_question = df["question"].apply(lambda x: len(word_tokenize(x))).mean()
avg_words_rationale = df["rationale"].apply(lambda x: len(word_tokenize(x))).mean()

all_tokens = [token for sublist in df["question_tokens"] + df["rationale_tokens"] for token in sublist]
vocab_size = len(set(all_tokens))

total_words_question = sum(len(word_tokenize(text)) for text in df["question"])
total_words_rationale = sum(len(word_tokenize(text)) for text in df["rationale"])
content_words_question = sum(len(tokens) for tokens in df["question_tokens"])
content_words_rationale = sum(len(tokens) for tokens in df["rationale_tokens"])
lexical_density_question = content_words_question / total_words_question if total_words_question > 0 else 0
lexical_density_rationale = content_words_rationale / total_words_rationale if total_words_rationale > 0 else 0

top_words_per_category = {}
for category in df["category"].unique():
  category_df = df[df["category"] == category]
  category_tokens = [token for sublist in category_df["question_tokens"] + category_df["rationale_tokens"] for token in sublist]
  top_words = Counter(category_tokens).most_common(5)
  top_words_per_category[category] = top_words

# **Reasoning Depth Metrics**
df["rationale_sentences"] = df["rationale"].apply(lambda x: len(sent_tokenize(x)))
avg_sentences_rationale = df["rationale_sentences"].mean()

reasoning_indicators = ["therefore", "because", "if", "then", "hence", "thus", "consequently", "since", "so", "implies"]
math_notations = [r"\\frac", r"\\int", r"\\sum", r"\\prod", r"\\lim", r"\\sqrt", r"\\log", r"\\sin", r"\\cos", r"\\tan"]


def count_indicators(text, indicators):
  """Count occurrences of specific indicators in text."""
  return sum(text.lower().count(ind) for ind in indicators)


df["reasoning_indicators"] = df["rationale"].apply(lambda x: count_indicators(x, reasoning_indicators))
df["math_notations"] = df["rationale"].apply(lambda x: count_indicators(x, math_notations))
avg_reasoning_indicators = df["reasoning_indicators"].mean()
avg_math_notations = df["math_notations"].mean()


# **Stratified Split Validation**
def complex_stratified_split(df, test_size=0.2, random_state=72):
  """Perform a stratified split based on category, answer_type, and has_image."""
  stratify_cols = ["category", "answer_type", "has_image"]
  df["stratify_key"] = df[stratify_cols].apply(lambda x: "_".join(x.astype(str)), axis=1)
  train_df, test_df = train_test_split(df, test_size=test_size, stratify=df["stratify_key"], random_state=random_state)
  return train_df.drop(columns="stratify_key"), test_df.drop(columns="stratify_key")


def validate_split(df, train_df, test_df, strat_col):
  """Validate stratification using chi-square test."""
  full_dist = df[strat_col].value_counts(normalize=True)
  train_dist = train_df[strat_col].value_counts(normalize=True)
  test_dist = test_df[strat_col].value_counts(normalize=True)
  contingency = pd.concat([full_dist, train_dist, test_dist], axis=1).fillna(0)
  contingency.columns = ["Full", "Train", "Test"]
  chi2, p, _, _ = chi2_contingency(contingency * len(df))
  return p


train_df, test_df = complex_stratified_split(df)
category_p = validate_split(df, train_df, test_df, "category")
answer_type_p = validate_split(df, train_df, test_df, "answer_type")
has_image_p = validate_split(df, train_df, test_df, "has_image")

# **Output Results**
print("Dataset Scale and Structure")
print(f"- Total Questions: {total_questions}")
print(f"- Exact-Match Proportion: {exact_match_prop:.2%}")
print(f"- Multiple-Choice Proportion: {multiple_choice_prop:.2%}")
print(f"- Image Proportion: {image_prop:.2%}")
print("- Category Distribution:")
for category, prop in category_dist.items():
  print(f"  - {category}: {prop:.2%}")

print("\nTextual Complexity")
print(f"- Average Words per Question: {avg_words_question:.2f}")
print(f"- Average Words per Rationale: {avg_words_rationale:.2f}")
print(f"- Vocabulary Size (Content Words): {vocab_size}")
print(f"- Lexical Density (Question): {lexical_density_question:.2%}")
print(f"- Lexical Density (Rationale): {lexical_density_rationale:.2%}")
print("- Top 5 Content Words per Category:")
for category, words in top_words_per_category.items():
  print(f"  - {category}: {', '.join([f'{word} ({count})' for word, count in words])}")

print("\nReasoning Depth")
print(f"- Average Sentences per Rationale: {avg_sentences_rationale:.2f}")
print(f"- Average Reasoning Indicators per Rationale: {avg_reasoning_indicators:.2f}")
print(f"- Average Math Notations per Rationale: {avg_math_notations:.2f}")

print("\nStratified Split Validation (Chi-Square p-values)")
print(f"- Category: {category_p:.4f}")
print(f"- Answer Type: {answer_type_p:.4f}")
print(f"- Has Image: {has_image_p:.4f}")

Dataset Scale and Structure
- Total Questions: 2700
- Exact-Match Proportion: 0.00%
- Multiple-Choice Proportion: 0.00%
- Image Proportion: 12.22%
- Category Distribution:
  - Math: 40.96%
  - Biology/Medicine: 11.22%
  - Other: 9.56%
  - Computer Science/AI: 9.56%
  - Physics: 8.89%
  - Humanities/Social Science: 8.70%
  - Chemistry: 6.30%
  - Engineering: 4.81%

Textual Complexity
- Average Words per Question: 183.61
- Average Words per Rationale: 384.79
- Vocabulary Size (Content Words): 25606
- Lexical Density (Question): 31.00%
- Lexical Density (Rationale): 24.99%
- Top 5 Content Words per Category:
  - Other: answer (221), black (175), white (169), one (163), x (161)
  - Humanities/Social Science: answer (298), would (214), one (206), b (169), word (137)
  - Math: x (15215), z (6865), n (4558), w (2627), b (2318)
  - Physics: r (630), nm (417), wavelength (410), intensity (404), c (300)
  - Computer Science/AI: x (907), n (778), k (446), b (387), f (387)
  - Biology/Medicine: an

In [17]:
# Load the tokenizer for Llama-3.2-1B-Instruct
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Load the dataset from Hugging Face (cais/hle, test split)
print("Dataset loaded successfully.")

# Check dataset columns
print("Dataset columns:", dataset.column_names)

# Ensure required columns exist
required_columns = {"question", "answer", "image", "category"}
if not required_columns.issubset(dataset.column_names):
  print(f"Warning: Dataset missing some required columns {required_columns}. Adjust column names in script.")
  exit()


def analyze_tokens_and_images(example):
  """
    Analyze token counts for 'question' and 'answer', and quantify base64 image data.
    """
  # Tokenize question and count tokens
  question_text = example.get("question", "")
  question_tokens = tokenizer(question_text, truncation=False, padding=False, return_tensors="pt")
  question_token_count = question_tokens["input_ids"].shape[1] if question_text else 0

  # Tokenize answer and count tokens
  answer_text = example.get("answer", "")
  answer_tokens = tokenizer(answer_text, truncation=False, padding=False, return_tensors="pt")
  answer_token_count = answer_tokens["input_ids"].shape[1] if answer_text else 0

  # Handle base64 image with potential data URI prefix
  image_data = example.get("image", "")
  image_size = 0
  if image_data:
    # Strip data URI prefix (e.g., "data:image/png;base64,")
    base64_string = re.sub(r"^data:image/[^;]+;base64,", "", image_data)
    try:
      decoded_image = base64.b64decode(base64_string)
      image_size = len(decoded_image)  # Byte length
    except Exception as e:
      print(f"Error decoding image for ID {example.get('id', 'unknown')}: {e}")
      image_size = 0

  return {"text_token_count": question_token_count + answer_token_count, "image_size_bytes": image_size}  # Merge Q and A tokens


# Apply the analysis function to the dataset
analyzed_dataset = dataset.map(analyze_tokens_and_images)

# Convert to pandas DataFrame for analysis
analyzed_df = analyzed_dataset.to_pandas()

# Group by category and calculate statistics
grouped = analyzed_df.groupby("category")

# Initialize lists to build the table
table_data = {
  "Category": [],
  "With Images": [],
  "Text Tokens Min": [],
  "Text Tokens Max": [],
  "Text Tokens Mean": [],
  "Text Tokens SD": [],
  "Img Size Mean": [],
  "Img Size SD": [],
}

# Iterate over each category
for category, group in grouped:
  # Percentage of examples with images (image_size_bytes > 0)
  pct_with_images = (group["image_size_bytes"] > 0).mean() * 100

  # Filter group for image stats (only rows with valid images)
  image_group = group[group["image_size_bytes"] > 0]

  # Append stats to table_data
  table_data["Category"].append(category)
  table_data["With Images"].append(pct_with_images)

  # Text token stats (merged question + answer)
  table_data["Text Tokens Min"].append(group["text_token_count"].min())
  table_data["Text Tokens Max"].append(group["text_token_count"].max())
  table_data["Text Tokens Mean"].append(group["text_token_count"].mean())
  table_data["Text Tokens SD"].append(group["text_token_count"].std())

  # Image size stats in MB (only for valid images; use NaN if no images)
  if len(image_group) > 0:
    table_data["Img Size Mean"].append(image_group["image_size_bytes"].mean() / 1048576)  # Bytes to MB
    table_data["Img Size SD"].append(image_group["image_size_bytes"].std() / 1048576)
  else:
    table_data["Img Size Mean"].append(np.nan)
    table_data["Img Size SD"].append(np.nan)

# Create DataFrame for the table
table_df = pd.DataFrame(table_data)

# Format the table for display (plain text version)
table_df["With Images"] = table_df["With Images"].round(2).astype(str) + "\\%"
table_df["Text Tokens Mean"] = table_df["Text Tokens Mean"].round(2)
table_df["Text Tokens SD"] = table_df["Text Tokens SD"].round(2)
table_df["Img Size Mean"] = table_df["Img Size Mean"].round(4)
table_df["Img Size SD"] = table_df["Img Size SD"].round(4)

# Plain text version with characters for inspection
table_df_plain = table_df.copy()
table_df_plain["Text Tokens"] = table_df_plain.apply(
  lambda row: f"min={int(row['Text Tokens Min'])} - max={int(row['Text Tokens Max'])} (mean={row['Text Tokens Mean']} ±sd{row['Text Tokens SD']})", axis=1
)
table_df_plain["Image Size"] = table_df_plain.apply(
  lambda row: f"mean={row['Img Size Mean']} ±sd{row['Img Size SD']}" if not np.isnan(row["Img Size Mean"]) else "N/A", axis=1
)
# Move "With Images" to the end for plain text
table_df_plain = table_df_plain[["Category", "Text Tokens", "Image Size", "With Images"]].sort_values("Category")

# LaTeX version with Greek symbols
table_df["Text Tokens"] = table_df.apply(
  lambda row: f"$\\downarrow${int(row['Text Tokens Min'])} - $\\uparrow${int(row['Text Tokens Max'])} {row['Text Tokens Mean']}$\\mu$ $\\pm${row['Text Tokens SD']}$\\sigma$",
  axis=1,
)
table_df["Image Size"] = table_df.apply(
  lambda row: f"{row['Img Size Mean']}$\\mu$ $\\pm${row['Img Size SD']}$\\sigma$" if not np.isnan(row["Img Size Mean"]) else "N/A", axis=1
)
# Move "With Images" to the end for LaTeX
table_df = table_df[["Category", "Text Tokens", "Image Size", "With Images"]].sort_values("Category")

# Display plain text version
print("\nToken and Image Statistics by Category (Plain Text for Inspection):")
print(table_df_plain.to_string(index=False))

# Convert to LaTeX and save to ../.assets/dataset_description.tex
latex_table = table_df.to_latex(index=False, column_format="llll", header=["Category", "Text Tokens ", "Image Size (MB)", "Images"], escape=False)

# Ensure .assets directory exists
os.makedirs(".assets", exist_ok=True)

# Write LaTeX table with legend to file
with open("../.assets/dataset_description.tex", "w") as f:
  f.write("\\begin{table}[H]\n")
  f.write("\\centering\n")
  f.write(latex_table)
  f.write("\\vspace{0.2cm}\n")
  f.write("\\caption{Dataset Description by Category for cais/hle Test Split}\n")
  f.write("\\label{tab:dataset_description}\n")
  f.write("\\end{table}\n")

print("\nLaTeX table saved to '../.assets/dataset_description.tex'")

Dataset loaded successfully.
Dataset columns: ['id', 'question', 'image', 'image_preview', 'answer', 'answer_type', 'author_name', 'rationale', 'rationale_image', 'raw_subject', 'category', 'canary']

Token and Image Statistics by Category (Plain Text for Inspection):
                 Category                                Text Tokens            Image Size With Images
         Biology/Medicine  min=11 - max=3008 (mean=257.13 ±sd344.08) mean=0.3582 ±sd0.3318     19.47\%
                Chemistry  min=16 - max=1699 (mean=186.07 ±sd259.06) mean=0.0655 ±sd0.0967     36.47\%
      Computer Science/AI  min=15 - max=5052 (mean=420.34 ±sd648.82) mean=0.2335 ±sd0.3776      5.81\%
              Engineering  min=14 - max=8972 (mean=396.45 ±sd919.56) mean=0.1733 ±sd0.2037      40.0\%
Humanities/Social Science  min=15 - max=1294 (mean=201.03 ±sd227.64) mean=0.3205 ±sd0.3624     10.64\%
                     Math min=15 - max=13518 (mean=216.81 ±sd483.85)   mean=0.145 ±sd0.211      4.07\%
          

In [18]:
# Custom split function with while loop, exact balancing, no sampling
def split_dataset(df, min_test_pcts, random_state=72):
  df["has_image"] = df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0 and "data:image" in x)

  train_dfs = []
  test_dfs = []
  for cat in df["category"].unique():
    cat_df = df[df["category"] == cat].copy()
    total_size = len(cat_df)
    min_test_pct = min_test_pcts.get(cat, 50.0)  # Default to 50% if category not in min_test_pcts
    test_size = int(round(total_size * (min_test_pct / 100)))
    train_size = total_size - test_size

    if len(cat_df) < 2:
      print(f"Warning: Category '{cat}' has too few samples ({len(cat_df)}). Using basic split.")
      train_cat = cat_df.iloc[:train_size]
      test_cat = cat_df.iloc[train_size:]
    else:
      # Calculate target percentages from full category
      target_image_pct = cat_df["has_image"].mean() * 100
      target_exact_pct = (cat_df["answer_type"] == "exactMatch").mean() * 100

      # Initialize test set
      test_cat = pd.DataFrame(columns=cat_df.columns)
      available = cat_df.copy()

      # Fill test set with while loop until criteria met
      while len(test_cat) < test_size:
        current_image_pct = (test_cat["has_image"].mean() * 100) if len(test_cat) > 0 else 0
        current_exact_pct = (test_cat["answer_type"] == "exactMatch").mean() * 100 if len(test_cat) > 0 else 0

        # Check if criteria are met within 1% tolerance
        if (
          len(test_cat) > 0
          and abs(current_image_pct - target_image_pct) < 1.0
          and abs(current_exact_pct - target_exact_pct) < 1.0
          and len(test_cat) >= test_size
        ):
          break

        # Select next row to balance
        if current_image_pct < target_image_pct and len(available[available["has_image"]]) > 0:
          next_row = available[available["has_image"]].iloc[0:1]
        elif current_exact_pct < target_exact_pct and len(available[available["answer_type"] == "exactMatch"]) > 0:
          next_row = available[available["answer_type"] == "exactMatch"].iloc[0:1]
        elif len(available) > 0:
          next_row = available.iloc[0:1]
        else:
          break

        test_cat = pd.concat([test_cat, next_row])
        available = available[~available.index.isin(next_row.index)]

      # Trim or pad test set to exact size
      if len(test_cat) > test_size:
        test_cat = test_cat.iloc[:test_size]
      elif len(test_cat) < test_size:
        extra_needed = test_size - len(test_cat)
        extra_rows = available.iloc[:extra_needed]
        test_cat = pd.concat([test_cat, extra_rows])
        available = available[~available.index.isin(extra_rows.index)]

      # Remaining rows go to train
      train_cat = available

    train_dfs.append(train_cat)
    test_dfs.append(test_cat)

  train_df = pd.concat(train_dfs).drop(columns=["has_image"])
  test_df = pd.concat(test_dfs).drop(columns=["has_image"])
  return train_df, test_df


# Minimum test percentages from your table (Feb 20, 2025, 10:58 PM)
min_test_pcts = {
  "Biology/Medicine": 47.26,
  "Chemistry": 61.56,
  "Computer Science/AI": 51.29,
  "Engineering": 67.72,
  "Humanities/Social Science": 53.63,
  "Math": 19.67,
  "Other": 51.29,
  "Physics": 53.10,
}

# Perform the split
train_df, test_df = split_dataset(df, min_test_pcts)

# Calculate stats by category with percentages
categories = df["category"].unique()
strat_stats = {}
for cat in categories:
  cat_df = df[df["category"] == cat]
  cat_train_df = train_df[train_df["category"] == cat]
  cat_test_df = test_df[test_df["category"] == cat]

  num_samples = len(cat_df)
  test_size = len(cat_test_df)
  train_size = num_samples - test_size
  train_pct = (train_size / num_samples) * 100 if num_samples > 0 else 0
  test_pct = (test_size / num_samples) * 100 if num_samples > 0 else 0

  test_pct_image = (
    cat_test_df["image"].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0 and "data:image" in x).mean() * 100 if len(cat_test_df) > 0 else 0
  )

  test_pct_exact = (cat_test_df["answer_type"] == "exactMatch").mean() * 100 if len(cat_test_df) > 0 else 0

  strat_stats[cat] = {"Questions": num_samples, "Train %": train_pct, "Test %": test_pct, "% Image": test_pct_image, "% Exact Match": test_pct_exact}

# Create table DataFrame
table_data = {"Category": [], "Questions": [], "Train %": [], "Test %": [], "% Image": [], "% Exact Match": []}
for cat, stats in strat_stats.items():
  table_data["Category"].append(cat)
  table_data["Questions"].append(stats["Questions"])
  table_data["Train %"].append(stats["Train %"])
  table_data["Test %"].append(stats["Test %"])
  table_data["% Image"].append(stats["% Image"])
  table_data["% Exact Match"].append(stats["% Exact Match"])

table_df = pd.DataFrame(table_data)
table_df = table_df.sort_values("Category")

# Format plain text version
table_df_plain = table_df.copy()
for col in ["Train %", "Test %", "% Image", "% Exact Match"]:
  table_df_plain[col] = table_df_plain[col].apply(lambda x: f"{x:.2f}%")

# Display plain text version
print("\nStratification Statistics by Category (Plain Text):")
print(table_df_plain.to_string(index=False))

# Save LaTeX version with corrected path
os.makedirs("../.assets", exist_ok=True)
latex_table = table_df.to_latex(
  index=False,
  column_format="lrrrrr",
  header=["Category", "Questions", "Train \\%", "Test \\%", "\\% Image", "\\% Exact Match"],
  escape=False,
  float_format="%.2f",
)
with open("../.assets/split_statistics.tex", "w") as f:
  f.write("\\begin{table}[H]\n")
  f.write("\\centering\n")
  f.write(latex_table)
  f.write("\\vspace{0.2cm}\n")
  f.write("\\caption{Dataset Train/Test Split by Category}\n")
  f.write("\\label{tab:split_statistics}\n")
  f.write("\\end{table}\n")

print("\nLaTeX table saved to '../.assets/split_statistics.tex'")


Stratification Statistics by Category (Plain Text):
                 Category  Questions Train % Test % % Image % Exact Match
         Biology/Medicine        303  52.81% 47.19%  19.58%        48.25%
                Chemistry        170  38.24% 61.76%  38.10%        75.24%
      Computer Science/AI        258  48.84% 51.16%   6.82%        69.70%
              Engineering        130  32.31% 67.69%  46.59%        72.73%
Humanities/Social Science        235  46.38% 53.62%  11.11%        59.52%
                     Math       1106  80.29% 19.71%   4.13%        91.28%
                    Other        258  48.84% 51.16%  22.73%        75.00%
                  Physics        240  47.08% 52.92%   6.30%        84.25%

LaTeX table saved to '../.assets/split_statistics.tex'


In [19]:
# Function to calculate minimum test size and percentage
def calculate_min_test_size(total_examples):
  Z = 1.645  # 90% confidence level
  p = 0.5  # Assumed proportion (max variance for conservative estimate)
  E = 0.05  # 5% margin of error
  n = (Z**2 * p * (1 - p)) / (E**2)  # Base sample size
  n_adjusted = n / (1 + (n - 1) / total_examples) if total_examples > 0 else n  # Finite population correction
  min_test_pct = (n_adjusted / total_examples) * 100 if total_examples > 0 else 100  # Convert to percentage
  return n_adjusted, min_test_pct


# Stratification stats by category
categories = df["category"].unique()
strat_stats = {}
for cat in categories:
  cat_df = df[df["category"] == cat]
  total_examples = len(cat_df)
  min_test_size, min_test_pct = calculate_min_test_size(total_examples)

  strat_stats[cat] = {"Questions": total_examples, "Min Test Size": min_test_size, "Min Test %": min_test_pct, "Train %": 100 - min_test_pct}

# Create table DataFrame
table_data = {"Category": [], "Questions": [], "Min Test Size": [], "Min Test %": [], "Train %": []}
for cat, stats in strat_stats.items():
  table_data["Category"].append(cat)
  table_data["Questions"].append(stats["Questions"])
  table_data["Min Test Size"].append(stats["Min Test Size"])
  table_data["Min Test %"].append(stats["Min Test %"])
  table_data["Train %"].append(stats["Train %"])

table_df = pd.DataFrame(table_data)

# Format plain text version
table_df_plain = table_df.copy()
table_df_plain["Min Test Size"] = table_df_plain["Min Test Size"].round(0).astype(int)
table_df_plain["Min Test %"] = table_df_plain["Min Test %"].round(2).astype(str) + "%"
table_df_plain["Train %"] = table_df_plain["Train %"].round(2).astype(str) + "%"
table_df_plain = table_df_plain.sort_values("Category")

# Format LaTeX version
table_df_latex = table_df.copy()
table_df_latex["Min Test Size"] = table_df_latex["Min Test Size"].round(0).astype(int)
table_df_latex["Min Test %"] = table_df_latex["Min Test %"].round(2)
table_df_latex["Train %"] = table_df_latex["Train %"].round(2)
table_df_latex = table_df_latex.sort_values("Category")

# Display plain text version
print("\nStratification Statistics by Category (Plain Text):")
print(table_df_plain.to_string(index=False))

# Save LaTeX version
os.makedirs("../.assets", exist_ok=True)
latex_table = table_df_latex.to_latex(
  index=False, column_format="lrrrr", header=["Category", "Questions", "Min Test Size", "Min Test \\%", "Train \\%"], escape=False
)
with open("../.assets/split_analysis.tex", "w") as f:
  f.write("\\begin{table}[H]\n")
  f.write("\\centering\n")
  f.write(latex_table)
  f.write("\\vspace{0.2cm}\n")
  f.write("\\caption{Statistics by Category for cais/hle Test Split (90\\% CI, 5\\% Error)}\n")
  f.write("\\label{tab:split_analysis}\n")
  f.write("\\end{table}\n")

print("\nLaTeX table saved to '../.assets/split_analysis.tex'")


Stratification Statistics by Category (Plain Text):
                 Category  Questions  Min Test Size Min Test % Train %
         Biology/Medicine        303            143     47.26%  52.74%
                Chemistry        170            105     61.56%  38.44%
      Computer Science/AI        258            132     51.29%  48.71%
              Engineering        130             88     67.72%  32.28%
Humanities/Social Science        235            126     53.63%  46.37%
                     Math       1106            218     19.67%  80.33%
                    Other        258            132     51.29%  48.71%
                  Physics        240            127      53.1%   46.9%

LaTeX table saved to '../.assets/split_analysis.tex'
