# Inspect `split_batch_preserve_relations`

Use this notebook to load a pre-tokenized batch directory and verify that `split_batch_preserve_relations` retains every labeled query–document relation when splitting large batches into smaller pieces.

## How to use
1. Update `PROJECT_ROOT` below if your checkout lives somewhere else.
2. Point `BATCH_DIR` at one of the numbered batch folders containing `queries.parquet`, `documents.parquet`, and `relations.parquet`.
3. Adjust `SPLIT_FACTOR`, then run the validation cells. The plain `split_batch` cell provides the baseline to compare against `split_batch_preserve_relations`.

In [1]:
import os
import sys
from pathlib import Path

# Root of the ArcticTraining checkout. Set the ARCTIC_TRAINING_ROOT environment
# variable to point at your own checkout instead of editing this cell; when it
# is unset, the original hard-coded location is used as the fallback.
PROJECT_ROOT = Path(
    os.environ.get('ARCTIC_TRAINING_ROOT', '/Users/qzeng/codes/ArcticTraining')
).expanduser().resolve()
ARCTIC_EMBED_ROOT = PROJECT_ROOT / 'projects' / 'arctic_embed'
SRC_PATH = ARCTIC_EMBED_ROOT / 'src'

# Fail fast if the expected source directory is missing.
if not SRC_PATH.exists():
    raise FileNotFoundError(f'Expected Arctic Embed sources at {SRC_PATH}')

# Prepend only once so re-running the cell does not add duplicate entries.
if str(SRC_PATH) not in sys.path:
    sys.path.insert(0, str(SRC_PATH))

print(f'Project root: {PROJECT_ROOT}')
print(f'Using src path: {SRC_PATH}')


Project root: /Users/qzeng/codes/ArcticTraining
Using src path: /Users/qzeng/codes/ArcticTraining/projects/arctic_embed/src


In [4]:
from collections import defaultdict
from typing import Dict, List, Set, Tuple

import pandas as pd
import torch

from arctic_embed.core.pretokenized_batch_loader import (
    ContrastiveLearningBatch,
    read_batch,
    split_batch,
    split_batch_preserve_relations,
)


In [5]:
def summarize_batch(batch: ContrastiveLearningBatch) -> Dict[str, int]:
    """Count the rows and labeled relations held by ``batch``.

    Returns:
        A dict with these keys:
          - ``'queries'`` / ``'documents'``: size of dim 0 of the query /
            document token tensors.
          - ``'relations'``: number of stored entries in the coalesced
            relevance labels.
          - ``'positives'`` / ``'negatives'``: stored entries whose value is
            greater than / less than zero.
    """
    labels = batch.relevance_labels.coalesce()
    stored = labels.values()

    counts: Dict[str, int] = {}
    counts['queries'] = batch.query_tokens.size(0)
    counts['documents'] = batch.document_tokens.size(0)
    counts['relations'] = labels._nnz()
    counts['positives'] = int(stored.gt(0).sum().item())
    counts['negatives'] = int(stored.lt(0).sum().item())
    return counts

def sparse_relations_to_dict(rel: torch.Tensor) -> Dict[int, Set[int]]:
    """Convert a sparse relation tensor into ``{row index: {column indices}}``.

    Every stored entry of the coalesced tensor contributes one (row, column)
    pair, regardless of its value. Rows without stored entries get no key.
    """
    coords = rel.coalesce().indices()
    # Pull both coordinate rows out as plain Python ints in one go.
    row_ids = coords[0].tolist()
    col_ids = coords[1].tolist()

    columns_by_row: Dict[int, Set[int]] = defaultdict(set)
    for row_id, col_id in zip(row_ids, col_ids):
        columns_by_row[row_id].add(col_id)
    return columns_by_row

def build_token_lookup(tensor: torch.Tensor) -> Dict[Tuple[int, ...], List[int]]:
    """Index the rows of ``tensor`` by their content.

    Each row is turned into a tuple of ints and mapped to the list of row
    positions (in ascending order) that hold exactly that content, so rows
    that appear more than once keep all of their positions.
    """
    positions_by_content: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    rows = tensor.cpu().tolist()
    for position in range(len(rows)):
        content = tuple(map(int, rows[position]))
        positions_by_content[content].append(position)
    return positions_by_content

def map_rows_to_original(tensor: torch.Tensor, lookup: Dict[Tuple[int, ...], List[int]]) -> List[int]:
    """Translate each row of ``tensor`` into a row index of the original batch.

    Args:
        tensor: Rows to look up; each row is converted to a tuple of ints.
        lookup: Mapping from row content to the positions holding that content.

    Returns:
        One original index per row of ``tensor``, in row order. When a row's
        content maps to several positions, the first listed position is used.

    Raises:
        KeyError: If a row's content is absent from ``lookup`` (or maps to an
            empty list).
    """

    def first_match(row) -> int:
        matches = lookup.get(tuple(map(int, row)))
        if matches:
            return matches[0]
        raise KeyError(
            'Could not map a row back to the original batch. '
            'Confirm that `BATCH_DIR` points to the same data used to construct the splits.'
        )

    return [first_match(row) for row in tensor.cpu().tolist()]

def analyze_splits(original: ContrastiveLearningBatch, splits: List[ContrastiveLearningBatch]):
    """Compare the relations kept by each split against the original batch.

    Rows of every split are matched back to rows of ``original`` by token
    content (``build_token_lookup`` + ``map_rows_to_original``). For each
    query in a split, the documents it is related to inside the split are
    translated to original document indices and compared with the documents
    the same query is related to in ``original``.

    Args:
        original: The unsplit batch.
        splits: Batches derived from ``original``.

    Returns:
        A ``(summary_df, per_query_df)`` tuple of DataFrames: one summary row
        per split, and one row per (split, query) pair. Both frames carry
        their full column set even when ``splits`` is empty.

    Raises:
        KeyError: Propagated from ``map_rows_to_original`` when a split row
            cannot be found in ``original``.
    """
    # Explicit schemas keep the returned frames well-formed for empty input.
    summary_columns = [
        'split',
        'num_queries',
        'num_documents',
        'nnz_relations',
        'missing_links',
        'extra_links',
    ]
    per_query_columns = [
        'split',
        'split_query_idx',
        'original_query_idx',
        'expected_docs',
        'present_docs',
        'missing_doc_indices',
        'extra_doc_indices',
    ]

    query_lookup = build_token_lookup(original.query_tokens)
    doc_lookup = build_token_lookup(original.document_tokens)
    original_rel_map = sparse_relations_to_dict(original.relevance_labels)

    summary_records = []
    per_query_records = []

    for split_idx, split in enumerate(splits):
        local_to_original_query = map_rows_to_original(split.query_tokens, query_lookup)
        local_to_original_doc = map_rows_to_original(split.document_tokens, doc_lookup)

        # Coalesce once and reuse for both the relation map and the nnz count.
        split_rel = split.relevance_labels.coalesce()
        split_rel_map = sparse_relations_to_dict(split_rel)

        missing_links = 0
        extra_links = 0

        for local_q_idx, original_q_idx in enumerate(local_to_original_query):
            expected_docs = original_rel_map.get(original_q_idx, set())
            # Translate split-local document indices to original indices.
            present_docs = {local_to_original_doc[d] for d in split_rel_map.get(local_q_idx, set())}
            missing = sorted(expected_docs - present_docs)
            extra = sorted(present_docs - expected_docs)
            missing_links += len(missing)
            extra_links += len(extra)

            per_query_records.append({
                'split': split_idx,
                'split_query_idx': local_q_idx,
                'original_query_idx': original_q_idx,
                'expected_docs': len(expected_docs),
                'present_docs': len(present_docs),
                'missing_doc_indices': missing,
                'extra_doc_indices': extra,
            })

        summary_records.append({
            'split': split_idx,
            'num_queries': split.query_tokens.size(0),
            'num_documents': split.document_tokens.size(0),
            'nnz_relations': split_rel._nnz(),
            'missing_links': missing_links,
            'extra_links': extra_links,
        })

    return (
        pd.DataFrame(summary_records, columns=summary_columns),
        pd.DataFrame(per_query_records, columns=per_query_columns),
    )


In [None]:
# --- Configure the batch you want to inspect ---
# Folder holding the batch to load, e.g. PROJECT_ROOT / 'data' / 'batch_000001'
BATCH_DIR = ARCTIC_EMBED_ROOT / 'notebooks' / 'batch_data'
# Passed as the second argument to `split_batch_preserve_relations`.
SPLIT_FACTOR = 32

# Stop here with a readable error if the configured folder is absent.
if not BATCH_DIR.exists():
    raise FileNotFoundError(f'Batch directory not found: {BATCH_DIR}')


In [9]:
# Load the batch stored under BATCH_DIR (handed to `read_batch` as a plain string)
# and print its query/document/relation counts as a quick sanity check.
batch = read_batch(str(BATCH_DIR))
print('Original batch summary:', summarize_batch(batch))

Original batch summary: {'queries': 512, 'documents': 5565, 'relations': 5632, 'positives': 512, 'negatives': 5120}


In [41]:
# Baseline: plain `split_batch`. Use the configured SPLIT_FACTOR (not a literal)
# so this run stays comparable with the `split_batch_preserve_relations` run.
split_batches = split_batch(batch, SPLIT_FACTOR)

print(f'Generated {len(split_batches)} split batches using `split_batch`')
summary_df, per_query_df = analyze_splits(batch, split_batches)
summary_df

Generated 32 split batches using `split_batch`


Unnamed: 0,split,num_queries,num_documents,nnz_relations,missing_links,extra_links
0,0,16,173,173,3,0
1,1,16,173,170,6,0
2,2,16,173,167,9,0
3,3,16,173,164,12,0
4,4,16,173,161,15,0
5,5,16,173,159,17,0
6,6,16,173,158,18,0
7,7,16,173,156,20,0
8,8,16,173,153,23,0
9,9,16,173,152,24,0


In [42]:
# Relation-preserving split of the same batch with the configured SPLIT_FACTOR.
preserve_split_batches = split_batch_preserve_relations(batch, SPLIT_FACTOR)

# Report the length of the list produced in this cell.
print(f'Generated {len(preserve_split_batches)} split batches using `split_batch_preserve_relations`')
# NOTE: `per_query_df` is rebound here to the per-query results of this run.
preserve_summary_df, per_query_df = analyze_splits(batch, preserve_split_batches)
preserve_summary_df

Generated 32 split batches using `split_batch_preserve_relations`


Unnamed: 0,split,num_queries,num_documents,nnz_relations,missing_links,extra_links
0,0,16,176,176,0,0
1,1,16,176,176,0,0
2,2,16,176,176,0,0
3,3,16,176,176,0,0
4,4,16,176,176,0,0
5,5,16,176,176,0,0
6,6,16,176,176,0,0
7,7,16,175,176,0,0
8,8,16,176,176,0,0
9,9,16,176,176,0,0
