# Pre-training Dataset Preparation

This notebook downloads and prepares the pre-training datasets specified in the research paper:

**"Towards Competitive Search Relevance For Inference-Free Learned Sparse Retrievers"**

## Pre-training Datasets (Table 1 in paper)

| Dataset | Pairs | Pair Type | Source |
|---------|------|-----------|--------|
| **S2ORC** | 41.7M | (Title, Abstract) | Semantic Scholar |
| **WikiAnswers** | 77.4M | (Question, Question) | Duplicate questions |
| **GOOAQ** | 3.0M | (Question, Answer) | Google Q&A snippets |
| **SQuAD** | 87K | (Question, Context) | Reading comprehension |
| **Natural Questions** | 307K | (Question, Passage) | Google search queries |
| **ELI5** | 272K | (Question, Answer) | Reddit explain-like-I'm-five |

**Total**: ~122.8M training pairs
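To check the total, sum the six sizes in the table and compute each source's share of the mixture. The figures below are copied from the table, in millions of pairs. This is arithmetic only and makes no claim about how the paper weights the sources during training.

```python
# Table 1 sizes in millions of pairs, copied from the table above
table1_millions = {
    "S2ORC": 41.7,
    "WikiAnswers": 77.4,
    "GOOAQ": 3.0,
    "SQuAD": 0.087,
    "Natural Questions": 0.307,
    "ELI5": 0.272,
}

total = sum(table1_millions.values())
print(f"Total: {total:.1f}M pairs")  # Total: 122.8M pairs

# Share of each source in the raw mixture, largest first
for name, size in sorted(table1_millions.items(), key=lambda kv: kv[1], reverse=True):
    print(f"  {name:<18} {size / total:6.1%}")
```

S2ORC and WikiAnswers together account for about 97% of the pairs, so unless the sources are reweighted, the four QA datasets contribute comparatively little signal.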

## Objectives

1. Download all pre-training datasets
2. Convert to unified (query, document) format
3. Apply quality filters
4. Save in chunks for efficient training
5. Generate statistics on dataset composition
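Objectives 2 and 4 come down to one record schema and one chunked writer. The cells below repeat the same chunk-saving loop for every dataset. The sketch here factors that loop into a helper. `make_pair` and `write_chunks` are hypothetical names, not part of any library. The record keys match the ones the cells write. One difference from the cells: `make_pair` omits `source_id` when no ID is available.

```python
import json
import tempfile
from itertools import islice
from pathlib import Path
from typing import Iterable


def make_pair(query: str, document: str, query_type: str, doc_type: str,
              source: str, source_id: str = "") -> dict:
    """Build one record in the unified (query, document) schema."""
    pair = {
        "query": query,
        "document": document,
        "query_type": query_type,
        "doc_type": doc_type,
        "source": source,
    }
    if source_id:  # only keep an ID when the source provides one
        pair["source_id"] = source_id
    return pair


def write_chunks(pairs: Iterable[dict], out_dir: Path, prefix: str,
                 chunk_size: int = 100_000) -> int:
    """Write pairs to <prefix>_chunk_001.jsonl, _002, ...; return the number of files."""
    out_dir.mkdir(parents=True, exist_ok=True)
    it = iter(pairs)
    chunk_num = 0
    while True:
        # Pull the next chunk_size records (fewer at the end) without materializing the rest
        chunk = list(islice(it, chunk_size))
        if not chunk:
            break
        chunk_num += 1
        path = out_dir / f"{prefix}_chunk_{chunk_num:03d}.jsonl"
        with open(path, "w", encoding="utf-8") as f:
            for p in chunk:
                f.write(json.dumps(p, ensure_ascii=False) + "\n")
    return chunk_num


# Demo: 250 toy records with chunk_size=100 -> three files of 100, 100 and 50 lines
with tempfile.TemporaryDirectory() as tmp:
    records = (make_pair(f"q{i}", f"d{i}", "question", "answer", "demo")
               for i in range(250))
    n_chunks = write_chunks(records, Path(tmp), "demo", chunk_size=100)
    line_counts = [len(p.read_text(encoding="utf-8").splitlines())
                   for p in sorted(Path(tmp).glob("demo_chunk_*.jsonl"))]

print(n_chunks, line_counts)  # 3 [100, 100, 50]
```

Because `write_chunks` consumes any iterable, a generator that filters a streaming dataset can feed it directly. Only one chunk is ever held in memory.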

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../..')

from pathlib import Path
import json
from tqdm import tqdm
import gzip
from datasets import load_dataset

## 1. Setup

In [None]:
# Output directory
output_dir = Path("../../dataset/pretraining")
output_dir.mkdir(parents=True, exist_ok=True)

# Cache directory for HuggingFace datasets
cache_dir = Path("../../dataset/pretraining/cache")
cache_dir.mkdir(parents=True, exist_ok=True)

# Processing settings
CHUNK_SIZE = 100000  # 100K pairs per file
SKIP_IF_EXISTS = True

print(f"✓ Output directory: {output_dir}")
print(f"✓ Cache directory: {cache_dir}")

## 2. S2ORC Dataset (41.7M Title-Abstract pairs)

**S2ORC (Semantic Scholar Open Research Corpus)** contains scientific papers with titles and abstracts.

We'll use the `allenai/s2orc` dataset from HuggingFace.
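The S2ORC cell keeps a record only if both fields are present, the title is 10 to 300 characters long, and the abstract is 50 to 2000 characters long (bounds inclusive). The sketch below restates that filter as a testable function. Whitespace normalization is an extra assumption, not something the cell does. The helper names are hypothetical.

```python
from typing import Optional, Tuple

# Inclusive character-length windows, taken from the S2ORC cell in this notebook
TITLE_RANGE: Tuple[int, int] = (10, 300)
ABSTRACT_RANGE: Tuple[int, int] = (50, 2000)


def normalize(text: Optional[str]) -> str:
    """Collapse whitespace runs (an extra step the notebook's cell does not perform)."""
    return " ".join((text or "").split())


def in_range(text: str, bounds: Tuple[int, int]) -> bool:
    low, high = bounds
    return low <= len(text) <= high


def keep_title_abstract(title: Optional[str], abstract: Optional[str]) -> bool:
    """Return True if a (title, abstract) record passes the length filters."""
    title, abstract = normalize(title), normalize(abstract)
    if not title or not abstract:
        return False
    return in_range(title, TITLE_RANGE) and in_range(abstract, ABSTRACT_RANGE)


print(keep_title_abstract("A reasonable paper title", "word " * 20))  # True
print(keep_title_abstract("Short", "word " * 20))                     # False: title < 10 chars
print(keep_title_abstract(None, "word " * 20))                        # False: missing title
```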

In [None]:
import glob

# Check if S2ORC paired data already exists
s2orc_files = sorted(glob.glob(str(output_dir / "s2orc_chunk_*.jsonl")))

if SKIP_IF_EXISTS and s2orc_files:
    print("=" * 80)
    print("✓ S2ORC paired data already exists!")
    print("=" * 80)
    print(f"\nFound {len(s2orc_files)} chunk files")
    total_pairs = sum(sum(1 for _ in open(f)) for f in s2orc_files)
    print(f"Total pairs: {total_pairs:,}")
    print("\nüí° Set SKIP_IF_EXISTS = False to force re-processing")
    
else:
    print("=" * 80)
    print("Processing S2ORC Dataset")
    print("=" * 80)
    print("\n⬇ Streaming ~200GB of data from the Hub...")
    print("⏳ Processing may take several hours...\n")
    
    try:
        # Load S2ORC dataset
        dataset = load_dataset(
            "allenai/s2orc",
            split="train",
            cache_dir=str(cache_dir),
            streaming=True  # Stream to avoid loading all at once
        )
        
        chunk_num = 0
        current_chunk = []
        total_pairs = 0
        
        for item in tqdm(dataset, desc="Processing S2ORC"):
            title = item.get("title", "")
            abstract = item.get("abstract", "")
            
            # Quality filters
            if not title or not abstract:
                continue
            if len(abstract) < 50 or len(abstract) > 2000:
                continue
            if len(title) < 10 or len(title) > 300:
                continue
            
            pair = {
                "query": title,
                "document": abstract,
                "query_type": "title",
                "doc_type": "abstract",
                "source": "s2orc",
                "source_id": item.get("paper_id", ""),
            }
            
            current_chunk.append(pair)
            total_pairs += 1
            
            # Save chunk when limit reached
            if len(current_chunk) >= CHUNK_SIZE:
                chunk_num += 1
                chunk_file = output_dir / f"s2orc_chunk_{chunk_num:03d}.jsonl"
                
                with open(chunk_file, 'w', encoding='utf-8') as f:
                    for p in current_chunk:
                        f.write(json.dumps(p, ensure_ascii=False) + "\n")
                
                print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
                current_chunk = []
        
        # Save remaining pairs
        if current_chunk:
            chunk_num += 1
            chunk_file = output_dir / f"s2orc_chunk_{chunk_num:03d}.jsonl"
            
            with open(chunk_file, 'w', encoding='utf-8') as f:
                for p in current_chunk:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            
            print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
        
        print(f"\n✓ Processed {total_pairs:,} S2ORC pairs in {chunk_num} chunks")
        
    except Exception as e:
        print(f"\n✗ Error processing S2ORC: {e}")
        print("   This dataset requires significant disk space and processing time.")
        print("   You may need to manually download and process it separately.")

## 3. WikiAnswers Dataset (77.4M Question pairs)

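**WikiAnswers** groups questions that users marked as paraphrases. Any two questions from the same group form a duplicate-question pair: same meaning, different wording.

The cell below first tries the Hub's `wiki_qa` dataset. It loads `sentence-transformers/embedding-training-data` only if that call fails. Note that `wiki_qa` is Microsoft's WikiQA answer-sentence-selection dataset (fields such as `question`, `answer`, `label`), not the WikiAnswers duplicate-question corpus. The `question1`/`question2` lookup therefore finds nothing in it, and every row is skipped. Reaching the 77.4M pairs in Table 1 requires a source that actually exposes WikiAnswers paraphrases.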
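If the raw data arrives as paraphrase clusters rather than ready-made pairs, each cluster has to be expanded into (question, question) pairs. The sketch below assumes a hypothetical input format: one cluster as a plain list of question strings. It drops empty strings and case-only duplicates, then emits pairs with `itertools.combinations`. The per-cluster cap is an illustrative choice that stops very large clusters from dominating. It is not the paper's method.

```python
from itertools import combinations
from typing import Iterator, List, Tuple


def cluster_to_pairs(cluster: List[str], max_pairs: int = 3) -> Iterator[Tuple[str, str]]:
    """Yield up to max_pairs (q1, q2) pairs of distinct questions from one paraphrase cluster."""
    seen = set()
    unique: List[str] = []
    for q in cluster:
        q = q.strip()
        key = q.lower()
        if q and key not in seen:  # drop empty strings and case-only duplicates
            seen.add(key)
            unique.append(q)
    for count, pair in enumerate(combinations(unique, 2)):
        if count >= max_pairs:  # cap pairs per cluster so large clusters don't dominate
            break
        yield pair


cluster = [
    "How tall is Everest?",
    "how tall is everest?",
    "What is the height of Mount Everest?",
    "Everest height?",
]
pairs = list(cluster_to_pairs(cluster))
for q1, q2 in pairs:
    print(q1, "<->", q2)
```

The two Everest variants that differ only in case collapse to one question. Three distinct questions remain, giving three pairs.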

In [None]:
# Check if WikiAnswers paired data already exists
wiki_ans_files = sorted(glob.glob(str(output_dir / "wikianswers_chunk_*.jsonl")))

if SKIP_IF_EXISTS and wiki_ans_files:
    print("=" * 80)
    print("✓ WikiAnswers paired data already exists!")
    print("=" * 80)
    print(f"\nFound {len(wiki_ans_files)} chunk files")
    total_pairs = sum(sum(1 for _ in open(f)) for f in wiki_ans_files)
    print(f"Total pairs: {total_pairs:,}")
    print("\nüí° Set SKIP_IF_EXISTS = False to force re-processing")
    
else:
    print("=" * 80)
    print("Processing WikiAnswers Dataset")
    print("=" * 80)
    print("\n⬇ Downloading WikiAnswers duplicate questions...\n")
    
    try:
        # Try multiple possible sources for WikiAnswers
        # Option 1: Direct WikiQA dataset
        try:
            dataset = load_dataset(
                "wiki_qa",
                split="train",
                cache_dir=str(cache_dir)
            )
        except Exception:  # a bare except would also swallow KeyboardInterrupt
            # Option 2: Paraphrase database that includes WikiAnswers
            dataset = load_dataset(
                "sentence-transformers/embedding-training-data",
                split="train",
                cache_dir=str(cache_dir)
            )
        
        chunk_num = 0
        current_chunk = []
        total_pairs = 0
        
        for item in tqdm(dataset, desc="Processing WikiAnswers"):
            # Extract question pairs (format depends on dataset structure)
            q1 = item.get("question1", item.get("sentence1", ""))
            q2 = item.get("question2", item.get("sentence2", ""))
            
            # Quality filters
            if not q1 or not q2:
                continue
            if len(q1) < 10 or len(q1) > 500:
                continue
            if len(q2) < 10 or len(q2) > 500:
                continue
            
            pair = {
                "query": q1,
                "document": q2,
                "query_type": "question",
                "doc_type": "duplicate_question",
                "source": "wikianswers",
            }
            
            current_chunk.append(pair)
            total_pairs += 1
            
            # Save chunk when limit reached
            if len(current_chunk) >= CHUNK_SIZE:
                chunk_num += 1
                chunk_file = output_dir / f"wikianswers_chunk_{chunk_num:03d}.jsonl"
                
                with open(chunk_file, 'w', encoding='utf-8') as f:
                    for p in current_chunk:
                        f.write(json.dumps(p, ensure_ascii=False) + "\n")
                
                print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
                current_chunk = []
        
        # Save remaining pairs
        if current_chunk:
            chunk_num += 1
            chunk_file = output_dir / f"wikianswers_chunk_{chunk_num:03d}.jsonl"
            
            with open(chunk_file, 'w', encoding='utf-8') as f:
                for p in current_chunk:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            
            print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
        
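        print(f"\n✓ Processed {total_pairs:,} WikiAnswers pairs in {chunk_num} chunks")
        if total_pairs == 0:
            print("   Warning: no rows had question1/question2 (or sentence1/sentence2) fields;")
            print("   the loaded dataset is probably not WikiAnswers duplicate questions.")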
        
    except Exception as e:
        print(f"\n✗ Error processing WikiAnswers: {e}")
        print("   WikiAnswers dataset may require manual download or alternative source.")
        print("   Continuing with other datasets...")

## 4. GOOAQ Dataset (3.0M Question-Answer pairs)

**GOOAQ** contains questions and answers from Google's Q&A snippets.

In [None]:
# Check if GOOAQ paired data already exists
gooaq_files = sorted(glob.glob(str(output_dir / "gooaq_chunk_*.jsonl")))

if SKIP_IF_EXISTS and gooaq_files:
    print("=" * 80)
    print("✓ GOOAQ paired data already exists!")
    print("=" * 80)
    print(f"\nFound {len(gooaq_files)} chunk files")
    total_pairs = sum(sum(1 for _ in open(f)) for f in gooaq_files)
    print(f"Total pairs: {total_pairs:,}")
    print("\nüí° Set SKIP_IF_EXISTS = False to force re-processing")
    
else:
    print("=" * 80)
    print("Processing GOOAQ Dataset")
    print("=" * 80)
    print("\n⬇ Downloading GOOAQ Question-Answer pairs...\n")
    
    try:
        # Load GOOAQ dataset
        dataset = load_dataset(
            "sentence-transformers/gooaq",
            split="train",
            cache_dir=str(cache_dir)
        )
        
        chunk_num = 0
        current_chunk = []
        total_pairs = 0
        
        for item in tqdm(dataset, desc="Processing GOOAQ"):
            question = item.get("question", "")
            answer = item.get("answer", "")
            
            # Quality filters
            if not question or not answer:
                continue
            if len(question) < 10 or len(question) > 500:
                continue
            if len(answer) < 20 or len(answer) > 2000:
                continue
            
            pair = {
                "query": question,
                "document": answer,
                "query_type": "question",
                "doc_type": "answer",
                "source": "gooaq",
            }
            
            current_chunk.append(pair)
            total_pairs += 1
            
            # Save chunk when limit reached
            if len(current_chunk) >= CHUNK_SIZE:
                chunk_num += 1
                chunk_file = output_dir / f"gooaq_chunk_{chunk_num:03d}.jsonl"
                
                with open(chunk_file, 'w', encoding='utf-8') as f:
                    for p in current_chunk:
                        f.write(json.dumps(p, ensure_ascii=False) + "\n")
                
                print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
                current_chunk = []
        
        # Save remaining pairs
        if current_chunk:
            chunk_num += 1
            chunk_file = output_dir / f"gooaq_chunk_{chunk_num:03d}.jsonl"
            
            with open(chunk_file, 'w', encoding='utf-8') as f:
                for p in current_chunk:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            
            print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
        
        print(f"\n✓ Processed {total_pairs:,} GOOAQ pairs in {chunk_num} chunks")
        
    except Exception as e:
        print(f"\n✗ Error processing GOOAQ: {e}")
        print("   Continuing with other datasets...")

## 5. SQuAD Dataset (87K Question-Context pairs)

**SQuAD (Stanford Question Answering Dataset)** contains questions and context passages.
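SQuAD v2.0 (`squad_v2`) adds unanswerable questions to the v1.1 data. They are about a third of its roughly 130K training questions. An unanswerable question is paired with a context that does not answer it, so it should not be used as a positive (question, context) pair. SQuAD v1.1 has 87,599 training questions, which matches the 87K in Table 1. The sketch below assumes the Hugging Face row layout, where `answers` is a dict with parallel `text` and `answer_start` lists that are empty for unanswerable questions. `is_answerable` is a hypothetical helper.

```python
def is_answerable(row: dict) -> bool:
    """SQuAD v2.0 marks unanswerable questions with an empty answers["text"] list."""
    return bool(row.get("answers", {}).get("text"))


rows = [
    {"id": "a", "question": "Who wrote Hamlet?", "context": "Hamlet is a tragedy by Shakespeare.",
     "answers": {"text": ["Shakespeare"], "answer_start": [22]}},
    {"id": "b", "question": "Who wrote the sequel?", "context": "Hamlet is a tragedy by Shakespeare.",
     "answers": {"text": [], "answer_start": []}},
]

positives = [r["id"] for r in rows if is_answerable(r)]
print(positives)  # ['a']
```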

In [None]:
# Check if SQuAD paired data already exists
squad_files = sorted(glob.glob(str(output_dir / "squad_chunk_*.jsonl")))

if SKIP_IF_EXISTS and squad_files:
    print("=" * 80)
    print("✓ SQuAD paired data already exists!")
    print("=" * 80)
    print(f"\nFound {len(squad_files)} chunk files")
    total_pairs = sum(sum(1 for _ in open(f)) for f in squad_files)
    print(f"Total pairs: {total_pairs:,}")
    
else:
    print("=" * 80)
    print("Processing SQuAD Dataset")
    print("=" * 80)
    
    try:
        # Load SQuAD v2.0
        dataset = load_dataset(
            "squad_v2",
            split="train",
            cache_dir=str(cache_dir)
        )
        
        chunk_num = 0
        current_chunk = []
        total_pairs = 0
        
        for item in tqdm(dataset, desc="Processing SQuAD"):
            question = item.get("question", "")
            context = item.get("context", "")
            
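            # Quality filters
            # SQuAD v2.0 includes unanswerable questions (empty answers["text"]);
            # their context is not a positive match, so skip them
            if not item.get("answers", {}).get("text"):
                continue
            if not question or not context:
                continue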
            if len(context) < 50 or len(context) > 3000:
                continue
            
            pair = {
                "query": question,
                "document": context,
                "query_type": "question",
                "doc_type": "context",
                "source": "squad",
                "source_id": item.get("id", ""),
            }
            
            current_chunk.append(pair)
            total_pairs += 1
            
            # Save chunk when limit reached
            if len(current_chunk) >= CHUNK_SIZE:
                chunk_num += 1
                chunk_file = output_dir / f"squad_chunk_{chunk_num:03d}.jsonl"
                
                with open(chunk_file, 'w', encoding='utf-8') as f:
                    for p in current_chunk:
                        f.write(json.dumps(p, ensure_ascii=False) + "\n")
                
                print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
                current_chunk = []
        
        # Save remaining pairs
        if current_chunk:
            chunk_num += 1
            chunk_file = output_dir / f"squad_chunk_{chunk_num:03d}.jsonl"
            
            with open(chunk_file, 'w', encoding='utf-8') as f:
                for p in current_chunk:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            
            print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
        
        print(f"\n✓ Processed {total_pairs:,} SQuAD pairs in {chunk_num} chunks")
        
    except Exception as e:
        print(f"\n✗ Error processing SQuAD: {e}")

## 6. Natural Questions Dataset (307K pairs)

**Natural Questions** contains real Google search queries with passage answers.
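In the Hugging Face `natural_questions` dataset, a `Sequence` of structs is returned as a dict of parallel lists rather than a list of dicts. So `annotations["long_answer"]` is a list with one dict per annotator, and `document["tokens"]` holds parallel `token` and `is_html` lists. The sketch below assumes that layout. It rebuilds the first annotator's long answer as plain text, treats a start token of -1 as "no long answer", and drops HTML markup tokens. `long_answer_passage` is a hypothetical helper, and the toy example stands in for a real row.

```python
from typing import Optional


def long_answer_passage(example: dict) -> Optional[str]:
    """Return the first annotator's long answer as plain text, or None if there is none."""
    long_answers = example["annotations"]["long_answer"]  # one dict per annotator
    if not long_answers:
        return None
    start = long_answers[0]["start_token"]
    end = long_answers[0]["end_token"]  # used as the slice end
    if start < 0 or start >= end:  # -1 means the annotator marked no long answer
        return None
    tokens = example["document"]["tokens"]  # dict of parallel lists
    words = [
        tok
        for tok, is_html in zip(tokens["token"][start:end], tokens["is_html"][start:end])
        if not is_html  # drop markup such as <P> and </P>
    ]
    return " ".join(words) or None


example = {
    "annotations": {"long_answer": [{"start_token": 0, "end_token": 6}]},
    "document": {"tokens": {
        "token": ["<P>", "Paris", "is", "the", "capital", "</P>"],
        "is_html": [True, False, False, False, False, True],
    }},
}
print(long_answer_passage(example))  # Paris is the capital
```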

In [None]:
# Check if Natural Questions paired data already exists
nq_files = sorted(glob.glob(str(output_dir / "natural_questions_chunk_*.jsonl")))

if SKIP_IF_EXISTS and nq_files:
    print("=" * 80)
    print("✓ Natural Questions paired data already exists!")
    print("=" * 80)
    print(f"\nFound {len(nq_files)} chunk files")
    total_pairs = sum(sum(1 for _ in open(f)) for f in nq_files)
    print(f"Total pairs: {total_pairs:,}")
    
else:
    print("=" * 80)
    print("Processing Natural Questions Dataset")
    print("=" * 80)
    
    try:
        # Load Natural Questions
        dataset = load_dataset(
            "natural_questions",
            split="train",
            cache_dir=str(cache_dir)
        )
        
        chunk_num = 0
        current_chunk = []
        total_pairs = 0
        
        for item in tqdm(dataset, desc="Processing Natural Questions"):
            question = item.get("question", {}).get("text", "")
            
            # HF returns Sequence features as dicts of lists, e.g.
            # annotations = {"long_answer": [{...}, ...], ...}
            long_answers = item.get("annotations", {}).get("long_answer", [])
            if not long_answers:
                continue
            
            # Use the first annotator's long answer; -1 marks "no long answer"
            start_token = long_answers[0].get("start_token", -1)
            end_token = long_answers[0].get("end_token", -1)
            if start_token < 0 or start_token >= end_token:
                continue
            
            # Rebuild the passage from document tokens, dropping HTML markup tokens
            tokens = item.get("document", {}).get("tokens", {})
            passage = " ".join(
                tok
                for tok, is_html in zip(
                    tokens.get("token", [])[start_token:end_token],
                    tokens.get("is_html", [])[start_token:end_token],
                )
                if not is_html
            )
            
            # Quality filters
            if not question or not passage:
                continue
            if len(passage) < 50 or len(passage) > 3000:
                continue
            
            pair = {
                "query": question,
                "document": passage,
                "query_type": "question",
                "doc_type": "passage",
                "source": "natural_questions",
            }
            
            current_chunk.append(pair)
            total_pairs += 1
            
            # Save chunk when limit reached
            if len(current_chunk) >= CHUNK_SIZE:
                chunk_num += 1
                chunk_file = output_dir / f"natural_questions_chunk_{chunk_num:03d}.jsonl"
                
                with open(chunk_file, 'w', encoding='utf-8') as f:
                    for p in current_chunk:
                        f.write(json.dumps(p, ensure_ascii=False) + "\n")
                
                print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
                current_chunk = []
        
        # Save remaining pairs
        if current_chunk:
            chunk_num += 1
            chunk_file = output_dir / f"natural_questions_chunk_{chunk_num:03d}.jsonl"
            
            with open(chunk_file, 'w', encoding='utf-8') as f:
                for p in current_chunk:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            
            print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
        
        print(f"\n✓ Processed {total_pairs:,} Natural Questions pairs in {chunk_num} chunks")
        
    except Exception as e:
        print(f"\n✗ Error processing Natural Questions: {e}")

## 7. ELI5 Dataset (272K Question-Answer pairs)

**ELI5 (Explain Like I'm Five)** contains Reddit questions with detailed explanatory answers.
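Each ELI5 question can have several answers. Rather than assume the list is already sorted by votes, it is safer to pick the answer with the highest score explicitly. The sketch below assumes the Hugging Face `eli5` layout, where `answers` is a dict with parallel `text` and `score` lists. `best_answer` is a hypothetical helper. It falls back to the first answer when scores are missing, and ties go to the earliest answer, because `max` returns the first maximum it sees.

```python
from typing import Optional


def best_answer(answers: dict) -> Optional[str]:
    """Pick the highest-scoring answer text; ties go to the earliest answer."""
    texts = answers.get("text", [])
    scores = answers.get("score", [])
    if not texts:
        return None
    if len(scores) != len(texts):  # no usable scores: fall back to list order
        return texts[0]
    best = max(range(len(texts)), key=lambda i: scores[i])
    return texts[best]


answers = {
    "text": ["Short reply.", "A longer, detailed explanation.", "Also good."],
    "score": [3, 12, 12],
}
print(best_answer(answers))  # A longer, detailed explanation.
```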

In [None]:
# Check if ELI5 paired data already exists
eli5_files = sorted(glob.glob(str(output_dir / "eli5_chunk_*.jsonl")))

if SKIP_IF_EXISTS and eli5_files:
    print("=" * 80)
    print("✓ ELI5 paired data already exists!")
    print("=" * 80)
    print(f"\nFound {len(eli5_files)} chunk files")
    total_pairs = sum(sum(1 for _ in open(f)) for f in eli5_files)
    print(f"Total pairs: {total_pairs:,}")
    
else:
    print("=" * 80)
    print("Processing ELI5 Dataset")
    print("=" * 80)
    
    try:
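        # Load ELI5 (the r/explainlikeimfive training split is named "train_eli5";
        # the dataset has no plain "train" split)
        dataset = load_dataset(
            "eli5",
            split="train_eli5",
            cache_dir=str(cache_dir)
        )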
        
        chunk_num = 0
        current_chunk = []
        total_pairs = 0
        
        for item in tqdm(dataset, desc="Processing ELI5"):
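            question = item.get("title", "")
            answers = item.get("answers", {})
            texts = answers.get("text", [])
            scores = answers.get("score", [])
            
            if not texts:
                continue
            
            # Use the highest-scoring answer instead of relying on list order
            if len(scores) == len(texts):
                answer = texts[max(range(len(texts)), key=lambda i: scores[i])]
            else:
                answer = texts[0]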
            
            # Quality filters
            if not question or not answer:
                continue
            if len(answer) < 50 or len(answer) > 3000:
                continue
            
            pair = {
                "query": question,
                "document": answer,
                "query_type": "question",
                "doc_type": "explanation",
                "source": "eli5",
            }
            
            current_chunk.append(pair)
            total_pairs += 1
            
            # Save chunk when limit reached
            if len(current_chunk) >= CHUNK_SIZE:
                chunk_num += 1
                chunk_file = output_dir / f"eli5_chunk_{chunk_num:03d}.jsonl"
                
                with open(chunk_file, 'w', encoding='utf-8') as f:
                    for p in current_chunk:
                        f.write(json.dumps(p, ensure_ascii=False) + "\n")
                
                print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
                current_chunk = []
        
        # Save remaining pairs
        if current_chunk:
            chunk_num += 1
            chunk_file = output_dir / f"eli5_chunk_{chunk_num:03d}.jsonl"
            
            with open(chunk_file, 'w', encoding='utf-8') as f:
                for p in current_chunk:
                    f.write(json.dumps(p, ensure_ascii=False) + "\n")
            
            print(f"Saved chunk {chunk_num}: {len(current_chunk):,} pairs")
        
        print(f"\n✓ Processed {total_pairs:,} ELI5 pairs in {chunk_num} chunks")
        
    except Exception as e:
        print(f"\n✗ Error processing ELI5: {e}")

## 8. Overall Statistics

In [None]:
import glob

print("=" * 80)
print("PRE-TRAINING DATASET STATISTICS")
print("=" * 80)

def count_dataset_pairs(pattern, name):
    files = sorted(glob.glob(str(output_dir / pattern)))
    if files:
        total = sum(sum(1 for _ in open(f)) for f in files)
        print(f"\n{name}:")
        print(f"     Files: {len(files)}")
        print(f"     Pairs: {total:,}")
        return total
    return 0

s2orc_pairs = count_dataset_pairs("s2orc_chunk_*.jsonl", "1. S2ORC (Title-Abstract)")
wiki_ans_pairs = count_dataset_pairs("wikianswers_chunk_*.jsonl", "2. WikiAnswers (Question-Question)")
gooaq_pairs = count_dataset_pairs("gooaq_chunk_*.jsonl", "3. GOOAQ (Question-Answer)")
squad_pairs = count_dataset_pairs("squad_chunk_*.jsonl", "4. SQuAD (Question-Context)")
nq_pairs = count_dataset_pairs("natural_questions_chunk_*.jsonl", "5. Natural Questions")
eli5_pairs = count_dataset_pairs("eli5_chunk_*.jsonl", "6. ELI5 (Question-Explanation)")

total_pairs = s2orc_pairs + wiki_ans_pairs + gooaq_pairs + squad_pairs + nq_pairs + eli5_pairs

print("\n" + "=" * 80)
print(f"TOTAL PRE-TRAINING PAIRS: {total_pairs:,}")
print("=" * 80)

if total_pairs > 0:
    print("\nDataset Composition:")
    if s2orc_pairs > 0:
        print(f"  S2ORC: {s2orc_pairs/total_pairs*100:.1f}%")
    if wiki_ans_pairs > 0:
        print(f"  WikiAnswers: {wiki_ans_pairs/total_pairs*100:.1f}%")
    if gooaq_pairs > 0:
        print(f"  GOOAQ: {gooaq_pairs/total_pairs*100:.1f}%")
    if squad_pairs > 0:
        print(f"  SQuAD: {squad_pairs/total_pairs*100:.1f}%")
    if nq_pairs > 0:
        print(f"  Natural Questions: {nq_pairs/total_pairs*100:.1f}%")
    if eli5_pairs > 0:
        print(f"  ELI5: {eli5_pairs/total_pairs*100:.1f}%")

# Sample a pair
all_files = sorted(glob.glob(str(output_dir / "*.jsonl")))
if all_files:
    print("\n" + "=" * 80)
    print("SAMPLE PAIR:")
    print("=" * 80)
    
    with open(all_files[0], 'r', encoding='utf-8') as f:
        sample = json.loads(f.readline())
    
    print(f"\nSource: {sample['source']}")
    print(f"Query type: {sample['query_type']}")
    print(f"Document type: {sample['doc_type']}")
    print(f"\nQuery: {sample['query'][:200]}...")
    print(f"Document: {sample['document'][:300]}...")

print("\n" + "=" * 80)

## Summary

This notebook downloads and processes the pre-training datasets specified in the research paper:

**Target datasets (from paper):**
- ✓ S2ORC: 41.7M (Title, Abstract) pairs
- ✓ WikiAnswers: 77.4M duplicate question pairs
- ✓ GOOAQ: 3.0M (Question, Answer) pairs
- ✓ SQuAD: 87K (Question, Context) pairs
- ✓ Natural Questions: 307K (Question, Passage) pairs
- ✓ ELI5: 272K (Question, Explanation) pairs

**Total target**: ~122.8M training pairs

**Output structure:**
```
dataset/pretraining/
├── s2orc_chunk_*.jsonl
├── wikianswers_chunk_*.jsonl
├── gooaq_chunk_*.jsonl
├── squad_chunk_*.jsonl
├── natural_questions_chunk_*.jsonl
└── eli5_chunk_*.jsonl
```
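For pre-training, the ~122.8M pairs are better streamed from these chunk files than loaded into memory at once. The sketch below shows one way to do that with the standard library. `iter_pairs` is a hypothetical helper, and the toy files in the demo stand in for real chunks. Tallying per-source counts on the fly reproduces the composition statistics without holding any dataset in memory.

```python
import glob
import json
import tempfile
from collections import Counter
from pathlib import Path
from typing import Iterator


def iter_pairs(data_dir: Path, pattern: str = "*_chunk_*.jsonl") -> Iterator[dict]:
    """Stream records from every matching chunk file, one JSON object per line."""
    for path in sorted(glob.glob(str(data_dir / pattern))):
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines
                    yield json.loads(line)


# Demo with two tiny chunk files
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp)
    for name, n in [("gooaq", 3), ("squad", 2)]:
        with open(demo_dir / f"{name}_chunk_001.jsonl", "w", encoding="utf-8") as f:
            for i in range(n):
                f.write(json.dumps({"query": f"q{i}", "document": f"d{i}", "source": name}) + "\n")
    per_source = Counter(p["source"] for p in iter_pairs(demo_dir))

print(per_source)  # Counter({'gooaq': 3, 'squad': 2})
```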

**Next steps:**
1. Hard negatives mining (notebook 04)
2. MS MARCO fine-tuning data (notebook 05)
3. Model pre-training with these datasets

**Note**: Some datasets (especially S2ORC and WikiAnswers) are very large and may require:
- Significant disk space (100GB+)
- Long download/processing time (hours to days)
- Alternative download methods or sampling strategies
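One common sampling strategy for corpora too large to hold in memory is reservoir sampling (Algorithm R). It draws a uniform random sample of `k` items from a stream of unknown length in a single pass, using O(k) memory. It is shown here as one option among several, not as the paper's procedure. `reservoir_sample` is a hypothetical helper. It could be fed any iterator over pairs, such as a generator that streams a dataset's chunk files.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def reservoir_sample(items: Iterable[T], k: int, seed: int = 0) -> List[T]:
    """Uniformly sample k items from a stream of unknown length (Algorithm R)."""
    rng = random.Random(seed)  # seeded for reproducible subsets
    reservoir: List[T] = []
    for i, item in enumerate(items):
        if i < k:
            reservoir.append(item)  # fill the reservoir with the first k items
        else:
            # Keep item i with probability k / (i + 1) by replacing a random slot
            j = rng.randint(0, i)  # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir


sample = reservoir_sample(range(100_000), k=5, seed=42)
print(sample)
```

If a fixed fraction is enough instead of a fixed count, keeping each pair independently with probability `p` is simpler. However, the resulting sample size then varies from run to run.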