# Experiments

In [None]:
import sys
from pathlib import Path

# Add the project root (the parent of the current working directory) to
# sys.path so the `src` package can be imported by the cells below.
PROJECT_ROOT = Path.cwd().parent
# Guard the insert: re-running this cell would otherwise prepend one more
# duplicate entry to sys.path on every execution.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.preprocess import clean_dataset, clean_all_datasets
from src import dense_embeddings_helpers
from src import dense_rankings_helpers
from src import bm25
from src import evaluate_rankings
from src.config import EMBEDDINGS_DIR, RANKINGS_DIR, RESULTS_DIR, ensure_output_dirs

import json

# Ensure output directories exist before any later cell writes to them.
# NOTE(review): presumably this creates EMBEDDINGS_DIR, RANKINGS_DIR and
# RESULTS_DIR from src.config — confirm against ensure_output_dirs().
ensure_output_dirs()

## Preprocess

In [None]:
# Clean both train and test datasets.
# NOTE(review): where the cleaned data is written is decided inside
# src.preprocess and is not visible from this notebook — confirm before
# relying on a particular output location.
clean_all_datasets()

## Embed the Corpus Using the Dense (BERT-based) Pipelines

In [None]:
# Sweep the full embedding grid: split x stride x pooling x model.
# Note: PCA projections will be saved in the same folder
embedding_grid = [
    (train, stride, pool, model)
    for train in (True, False)
    for stride in (512, 256)
    for pool in ("cls_only", "tokens")
    for model in ("BERT-base", "LegalBERT")
]

# The grid is enumerated in the same order as nested loops would visit it
# (model varies fastest, train slowest).
for train, stride, pool, model in embedding_grid:
    print(f"Embedding: {model} | {pool} | stride={stride} | train={train}")
    dense_embeddings_helpers.embed_dataset(
        model, pool, stride=stride, train=train, debug=True
    )

## Create Rankings for each case based on the Dense embeddings

In [None]:
# Produce a ranking file for every .npz embedding archive, visiting the
# embeddings directory in sorted order.
# NOTE(review): every file is ranked with train=True, although the embedding
# cell also runs with train=False — confirm this is intended.
embedding_dir_entries = sorted(EMBEDDINGS_DIR.iterdir())
for file in embedding_dir_entries:
    # Skip anything that is not an embedding archive.
    if file.suffix != '.npz':
        continue
    print(f'Ranking: {file.name}')
    dense_rankings_helpers.output_embedding_rankings(file, train=True)

## Create Rankings based on the Sparse (BM25) algorithm


In [None]:
# Generate the sparse (BM25) rankings.
# NOTE(review): train=True presumably selects the training split and
# debug=True presumably enables verbose output — confirm in src/bm25.
bm25.output_bm25_rankings(train=True, debug=True)

## Create Results based on rankings

In [None]:
# Evaluate all rankings; save_results=True asks the helper to persist them.
all_results = evaluate_rankings.evaluate_all_rankings(train=True, save_results=True)

# One (printed prefix, metric key) pair per summary line; the prefixes carry
# the padding that keeps the numbers aligned.
summary_rows = (
    ("  MRR:   ", "MRR"),
    ("  P@10:  ", "P@10"),
    ("  R@10:  ", "R@10"),
)

# Display the mean ± std of each metric for every evaluated ranking.
print("\n=== Results Summary ===")
for name, metrics in all_results.items():
    print(f"\n{name}:")
    for prefix, key in summary_rows:
        stats = metrics[key]
        print(f"{prefix}{stats['mean']:.4f} ± {stats['std']:.4f}")