In [1]:

"""---------------------------------------------------------------------
Block 1 – Imports & global settings                                     
---------------------------------------------------------------------"""
import os, json, re, string, time, math, multiprocessing as mp, pickle
from pathlib import Path
from functools import lru_cache

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from rank_bm25 import BM25Okapi
import hnswlib

try:
    from sentence_transformers import SentenceTransformer, util as st_util
    import torch
    DENSE_OK = True
except ImportError:
    DENSE_OK = False

# Re‑seed for reproducibility
SEED = 2025
np.random.seed(SEED)

DATA_DIR = Path("./datasets")
CACHE_DIR = Path(".cache"); CACHE_DIR.mkdir(exist_ok=True)
K_SUBMISSION = 100

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
"""---------------------------------------------------------------------
Block 2 – Data loading helpers                                           
---------------------------------------------------------------------"""

def load_json(path: Path):
    with open(path, "r", encoding="utf‑8") as f:
        return json.load(f)

def load_all():
    content = DATA_DIR / "Content_JSONs"
    citing_train = load_json(content/"Citing_2020_Cleaned_Content_12k/Citing_Train_Test/citing_TRAIN.json")
    citing_test  = load_json(content/"Citing_2020_Cleaned_Content_12k/Citing_Train_Test/citing_TEST.json")
    nonciting    = load_json(content/"Cited_2020_Uncited_2010-2019_Cleaned_Content_22k/CLEANED_CONTENT_DATASET_cited_patents_by_2020_uncited_2010-2019.json")
    mapping      = pd.DataFrame(load_json(DATA_DIR/"Citation_JSONs/Citation_Train.json"))
    return citing_train, citing_test, nonciting, mapping

CITING_TRAIN, CITING_TEST, NONCITING, MAP_DF = load_all()

In [3]:
# %%
"""---------------------------------------------------------------------
Block 3 – Text extraction & basic preprocessing                          
---------------------------------------------------------------------"""

import nltk
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True) # Needed for word_tokenize
from nltk.corpus import stopwords as nltk_sw
from nltk.stem import PorterStemmer # Added for stemming

STOP_WORDS = set(nltk_sw.words('english')).union({
    'claim', 'claims', 'method', 'apparatus',
    'embodiment', 'wherein', 'figure', 'system', 'device' # Added a few more common patent terms
})

TOKENIZER = re.compile(r"[\w']{2,}")
STEMMER = PorterStemmer() # Initialize stemmer

def normalize(text: str) -> str:
    text = text.lower()
    # Keep hyphens within words, remove other punctuation and digits
    text = text.translate(str.maketrans(string.punctuation.replace('-', ''), ' ' * len(string.punctuation.replace('-', ''))))
    text = text.translate(str.maketrans('', '', string.digits))
    text = re.sub(r'\s+', ' ', text).strip() # Consolidate whitespace
    return text

@lru_cache(maxsize=1<<15) # Increased cache size slightly
def preprocess_and_stem(text: str) -> list[str]:
    # Use nltk tokenizer which handles contractions better
    # Apply stemming after tokenization and stopword removal
    normalized_text = normalize(text)
    tokens = [STEMMER.stem(tok) # Apply stemming
              for tok in nltk.word_tokenize(normalized_text) # Using nltk tokenizer
              if tok.isalnum() and len(tok) > 1 and tok not in STOP_WORDS]
    return tokens

# Keep original preprocess without stemming for dense model if needed
@lru_cache(maxsize=1<<14)
def preprocess_original(text: str) -> list[str]:
    return [tok for tok in TOKENIZER.findall(normalize(text)) if tok not in STOP_WORDS]

TEXT_PARTS = {
    "title": ["title"],
    "abstract": ["pa01"],
    "claim1": ["c-en-0001"],
    "title_abstract_claims": ["title", "pa01"] + [f"c-en-{i:04d}" for i in range(1, 101)]
}

def build_corpus(records: list[dict], text_type: str) -> tuple[list[str], list[str]]:
    ids, texts = [], []
    parts = TEXT_PARTS[text_type]
    for rec in records:
        doc_id = rec.get('Application_Number','') + rec.get('Application_Category','')
        if not doc_id: continue
        content = rec.get('Content', {})
        segments = [content[k] for k in parts if content.get(k)]
        if segments:
            texts.append(' '.join(segments))
            ids.append(doc_id)
    return ids, texts

# Parallel preprocessing (disk cached) --------------------------------

def cached_tokens(name: str, texts: list[str], stem: bool = True):
    path = CACHE_DIR/f"{name}{'_stemmed' if stem else ''}.pkl"
    if path.exists():
        print(f"Loading cached tokens from: {path}")
        return pickle.load(open(path,'rb'))

    print(f"Preprocessing {'and stemming ' if stem else ''}texts for {name}...")
    func = preprocess_and_stem if stem else preprocess_original
    with mp.Pool(mp.cpu_count()) as pool:
        tokens = list(tqdm(pool.imap(func, texts, chunksize=100), total=len(texts), desc=f"Tokenizing {name}"))
    print(f"Caching tokens to: {path}")
    pickle.dump(tokens, open(path,'wb'))
    return tokens

In [4]:
# %%
"""---------------------------------------------------------------------
Block 4 – Sparse model: BM25 + TF‑IDF                                    
---------------------------------------------------------------------"""
class SparseIndexer:
    def __init__(self, text_type="title_abstract_claims", max_features=40_000):
        self.text_type = text_type
        self.max_features = max_features
        self.tfidf = None
        self.bm25 = None
        self.doc_ids = None

    # ---------- TF‑IDF ------------
    def fit_tfidf(self, texts: list[str]):
        vec = TfidfVectorizer(max_features=self.max_features,
                              ngram_range=(1,2),
                              min_df=2,
                              sublinear_tf=True,
                              stop_words='english')
        self.tfidf = vec.fit_transform(texts)
        return vec

    # ---------- BM25 --------------
    def fit_bm25(self, tokenized_texts: list[list[str]]):
        self.bm25 = BM25Okapi(tokenized_texts)
        return self.bm25

    # ---------- Ranking helpers ---
    def rank_tfidf(self, query_vec, top_k):
        sims = cosine_similarity(query_vec, self.tfidf, dense_output=False)
        return np.asarray(sims.toarray()[0])[:], sims

    def rank_bm25(self, query_toks, top_k):
        return np.array(self.bm25.get_scores(query_toks))


# FINETUNE DENSE MODEL

In [5]:
# %% Block: Fine-tuning Setup
import os
import random
import math
from pathlib import Path
import json
from tqdm.auto import tqdm
import pandas as pd

try:
    import torch
    from torch.utils.data import DataLoader
    from sentence_transformers import SentenceTransformer, InputExample, losses, models, util as st_util
    from sentence_transformers.evaluation import InformationRetrievalEvaluator
    FINETUNE_OK = True
except ImportError:
    print("Error: sentence-transformers or torch not installed. Cannot perform fine-tuning.")
    FINETUNE_OK = False

if FINETUNE_OK:
    # 1. Model Choice
    BASE_MODEL_NAME = "intfloat/e5-large-v2"
    FINETUNED_MODEL_PATH = Path("./fine_tuned_patent_model")
    FINETUNED_MODEL_PATH.mkdir(exist_ok=True)

    # 2. Data Configuration
    TEXT_TYPE_FINETUNE = "title_abstract_claims"

    # 3. Training Hyperparameters
    NUM_EPOCHS = 3
    TRAIN_BATCH_SIZE = 8
    EVAL_BATCH_SIZE = 16
    WARMUP_STEPS = 100
    LEARNING_RATE = 2e-5
    EVALUATION_STEPS = 500
    MAX_SEQ_LENGTH = 512
    USE_AMP = True

    # 4. Validation Set Split
    VALIDATION_SPLIT_FRACTION = 0.1
    SEED = 2025 # Make sure SEED is defined globally earlier

    # --- Ensure Core Data is Loaded ---
    try:
        # Check if necessary variables exist
        if 'CITING_TRAIN' not in locals() or 'NONCITING' not in locals() or 'MAP_DF' not in locals():
             print("Reloading data...")
             CITING_TRAIN, _, NONCITING, MAP_DF = load_all() # Assuming load_all() is defined
        if 'build_corpus' not in locals():
             raise NameError("build_corpus function not defined.")
    except NameError as e:
        print(f"Error: Required data or function not found: {e}")
        FINETUNE_OK = False # Prevent proceeding


if FINETUNE_OK:
    print("Preparing data for fine-tuning with prefixes...")

    # --- Helper functions for prefixes ---
    def add_query_prefix(text):
        # Ensure text is a string
        if not isinstance(text, str):
             text = str(text) # Basic conversion if not string
        return f"query: {text}"

    def add_passage_prefix(text):
        # Ensure text is a string
        if not isinstance(text, str):
             text = str(text) # Basic conversion if not string
        return f"passage: {text}"

    # --- Build Text Maps (ID -> Text) - NO PREFIXES HERE ---
    print("Building text lookup maps...")
    citing_text_map = {}
    # Assuming build_corpus is defined and accessible
    citing_ids_ft, citing_texts_ft = build_corpus(CITING_TRAIN, TEXT_TYPE_FINETUNE)
    for doc_id, text in zip(citing_ids_ft, citing_texts_ft):
        citing_text_map[doc_id] = text

    nonciting_text_map = {}
    nonciting_ids_ft, nonciting_texts_ft = build_corpus(NONCITING, TEXT_TYPE_FINETUNE)
    for doc_id, text in zip(nonciting_ids_ft, nonciting_texts_ft):
        nonciting_text_map[doc_id] = text
    print(f"Built maps: {len(citing_text_map)} citing, {len(nonciting_text_map)} non-citing docs.")


    # --- Create Positive Training Examples & Dev Set ---
    train_examples = []
    dev_queries = {} # {query_id: query_text_with_prefix}
    dev_corpus = {} # {doc_id: doc_text_with_prefix}
    dev_relevant_docs = {} # {query_id: set(relevant_doc_ids)}

    # Split citing patents into train and validation sets
    all_citing_ids = list(citing_text_map.keys())
    random.seed(SEED) # Use the global seed
    random.shuffle(all_citing_ids)
    split_idx = int(len(all_citing_ids) * (1 - VALIDATION_SPLIT_FRACTION))
    train_citing_ids = set(all_citing_ids[:split_idx])
    dev_citing_ids = set(all_citing_ids[split_idx:])

    print(f"Split: {len(train_citing_ids)} train queries, {len(dev_citing_ids)} dev queries.")

    processed_pairs = 0
    missing_texts = 0
    dev_positives_count = 0
    # Use MAP_DF for positive pairs
    for _, row in tqdm(MAP_DF.iterrows(), total=len(MAP_DF), desc="Processing citation pairs"):
        # Make sure column indices/names match your MAP_DF structure
        citing_id = row[0]  # Assuming first column is citing ID
        cited_id = row[2]   # Assuming third column is cited ID

        query_text_orig = citing_text_map.get(citing_id)
        # Positive text can come from citing or non-citing corpus
        positive_text_orig = citing_text_map.get(cited_id)
        if not positive_text_orig:
            positive_text_orig = nonciting_text_map.get(cited_id)

        if query_text_orig and positive_text_orig:
            # <<< --- ADD PREFIXES HERE --- >>>
            prefixed_query = add_query_prefix(query_text_orig)
            prefixed_positive = add_passage_prefix(positive_text_orig)

            if citing_id in train_citing_ids:
                # Use prefixed texts for InputExample
                train_examples.append(InputExample(texts=[prefixed_query, prefixed_positive]))
                processed_pairs += 1
            elif citing_id in dev_citing_ids:
                # Prepare data for InformationRetrievalEvaluator using prefixed texts
                if citing_id not in dev_queries:
                    dev_queries[citing_id] = prefixed_query # Store prefixed query
                # Add positive doc to corpus and relevant docs (use prefixed text for corpus)
                if cited_id not in dev_corpus:
                    dev_corpus[cited_id] = prefixed_positive # Store prefixed passage
                dev_relevant_docs.setdefault(citing_id, set()).add(cited_id)
                dev_positives_count += 1
        else:
            missing_texts += 1
            # Optional: Log which IDs were missing
            # if not query_text_orig: print(f"Missing text for citing_id: {citing_id}")
            # if not positive_text_orig: print(f"Missing text for cited_id: {cited_id}")


    print(f"Created {len(train_examples)} training examples.")
    print(f"Prepared {len(dev_queries)} dev queries with {dev_positives_count} positive relations.")
    if missing_texts > 0:
        print(f"Warning: Skipped {missing_texts} pairs due to missing text for citing or cited patents.")

    # --- Add Negative Examples to Dev Corpus (With Prefix) ---
    # We need some *non-relevant* documents in the dev corpus for the evaluator
    # Sample from NONCITING corpus.
    num_dev_negatives_needed = len(dev_corpus) * 5 # Aim for ~5x more negatives than positives
    dev_negatives_added = 0
    nonciting_ids_list = list(nonciting_text_map.keys())
    random.shuffle(nonciting_ids_list) # Shuffle to get random negatives

    for neg_id in nonciting_ids_list:
        if dev_negatives_added >= num_dev_negatives_needed:
            break
        if neg_id not in dev_corpus: # Avoid adding duplicates or existing positives
            neg_text = nonciting_text_map.get(neg_id)
            if neg_text: # Ensure text exists
                 # <<< --- ADD PREFIX HERE --- >>>
                 dev_corpus[neg_id] = add_passage_prefix(neg_text) # Add passage prefix
                 dev_negatives_added += 1

    print(f"Added {dev_negatives_added} negative documents (passages) to the dev corpus.")
    print(f"Total dev corpus size: {len(dev_corpus)}")

    # --- Sanity Check ---
    if not train_examples:
        print("Error: No training examples were created. Check data loading and matching.")
        FINETUNE_OK = False
    if not dev_queries or not dev_corpus or not dev_relevant_docs:
        print("Error: Validation set data is incomplete. Check validation split and data processing.")
        FINETUNE_OK = False
    # Check if any dev query ID has an empty relevant set (should not happen if processed correctly)
    for qid in dev_queries:
        if not dev_relevant_docs.get(qid):
             print(f"Warning: Dev query {qid} has no relevant documents listed in dev_relevant_docs.")


# %% Block: Model, Loss, and Evaluator Setup

if FINETUNE_OK:
    print(f"Loading base model: {BASE_MODEL_NAME}")
    # Option 1: Use a pre-trained model as is
    model = SentenceTransformer(BASE_MODEL_NAME)

    # Option 2: Add pooling layer if needed (e.g., if base is just transformer)
    # word_embedding_model = models.Transformer(BASE_MODEL_NAME, max_seq_length=MAX_SEQ_LENGTH)
    # pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    # model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    # Truncate long patent texts if they exceed model capacity
    model.max_seq_length = MAX_SEQ_LENGTH
    print(f"Model max sequence length set to: {model.max_seq_length}")

    # --- Loss Function ---
    # MultipleNegativesRankingLoss is recommended for training with (anchor, positive) pairs.
    # It uses other examples in the batch as negatives.
    loss = losses.MultipleNegativesRankingLoss(model=model)
    print("Using MultipleNegativesRankingLoss.")

    # --- Dataloader ---
    # NoDuplicatesDataLoader ensures no duplicate texts are in the same batch,
    # useful for MNRL as it prevents trivial negative examples.
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
    print(f"Train dataloader created with batch size {TRAIN_BATCH_SIZE}.")

    # --- Evaluator ---
    # Uses the prepared dev set components
    evaluator = InformationRetrievalEvaluator(
        queries=dev_queries,              # dict: {query_id: query_text}
        corpus=dev_corpus,                # dict: {doc_id: doc_text}
        relevant_docs=dev_relevant_docs,  # dict: {query_id: set(relevant_doc_ids)}
        batch_size=EVAL_BATCH_SIZE,
        main_score_function='cosine',     # How to compare embeddings
        score_functions={'cos_sim': st_util.cos_sim},
        name='patent_dev',
        show_progress_bar=True,
        write_csv=True,                   # Saves detailed eval results
    )
    print("InformationRetrievalEvaluator configured.")


# %% Block: Training Execution

if FINETUNE_OK:
    print("\n--- Starting Fine-Tuning Training ---")

    # Calculate number of steps per epoch if needed for logging/scheduling
    steps_per_epoch = math.ceil(len(train_dataloader))
    print(f"Steps per epoch: {steps_per_epoch}")
    if EVALUATION_STEPS > steps_per_epoch:
        print(f"Warning: EVALUATION_STEPS ({EVALUATION_STEPS}) > steps per epoch ({steps_per_epoch}). Evaluation will happen less than once per epoch.")


    # --- Run Training ---
    model.fit(
        train_objectives=[(train_dataloader, loss)],
        evaluator=evaluator,
        epochs=NUM_EPOCHS,
        evaluation_steps=EVALUATION_STEPS, # Evaluate every N steps
        warmup_steps=WARMUP_STEPS,
        optimizer_params={'lr': LEARNING_RATE},
        output_path=str(FINETUNED_MODEL_PATH), # Save checkpoints here
        save_best_model=True,          # Save the model with the best MAP score on dev
        checkpoint_path=str(FINETUNED_MODEL_PATH / "checkpoints"),
        checkpoint_save_steps=EVALUATION_STEPS * 2, # Save checkpoints less frequently than eval
        checkpoint_save_total_limit=3, # Keep only last 3 checkpoints
        use_amp=USE_AMP,                 # Enable mixed precision
        show_progress_bar=True
    )

    print(f"\n--- Fine-tuning finished ---")
    print(f"Best model saved to: {FINETUNED_MODEL_PATH}")

    # Optional: Load the best model immediately for use
    # print("Loading best fine-tuned model...")
    # fine_tuned_model = SentenceTransformer(str(FINETUNED_MODEL_PATH))
    # Now you could potentially replace the model in your DenseIndexer

Preparing data for fine-tuning with prefixes...
Building text lookup maps...
Built maps: 6831 citing, 16834 non-citing docs.
Split: 6147 train queries, 684 dev queries.


Processing citation pairs: 100%|██████████| 8594/8594 [00:00<00:00, 29104.42it/s]


Created 7726 training examples.
Prepared 684 dev queries with 868 positive relations.
Added 4250 negative documents (passages) to the dev corpus.
Total dev corpus size: 5100
Loading base model: intfloat/e5-large-v2
Model max sequence length set to: 512
Using MultipleNegativesRankingLoss.
Train dataloader created with batch size 8.
InformationRetrievalEvaluator configured.

--- Starting Fine-Tuning Training ---
Steps per epoch: 966


                                                                     

Step,Training Loss,Validation Loss,Patent Dev Cos Sim Accuracy@1,Patent Dev Cos Sim Accuracy@3,Patent Dev Cos Sim Accuracy@5,Patent Dev Cos Sim Accuracy@10,Patent Dev Cos Sim Precision@1,Patent Dev Cos Sim Precision@3,Patent Dev Cos Sim Precision@5,Patent Dev Cos Sim Precision@10,Patent Dev Cos Sim Recall@1,Patent Dev Cos Sim Recall@3,Patent Dev Cos Sim Recall@5,Patent Dev Cos Sim Recall@10,Patent Dev Cos Sim Ndcg@10,Patent Dev Cos Sim Mrr@10,Patent Dev Cos Sim Map@100
500,0.1257,No log,0.425439,0.611111,0.679825,0.776316,0.425439,0.232943,0.159357,0.09269,0.361672,0.561964,0.634089,0.740205,0.570572,0.53699,0.518985
966,0.1257,No log,0.447368,0.666667,0.741228,0.818713,0.447368,0.249513,0.171637,0.098246,0.384722,0.613109,0.693324,0.783528,0.606097,0.568564,0.547834
1000,0.0747,No log,0.438596,0.628655,0.716374,0.80848,0.438596,0.237817,0.166959,0.096784,0.376194,0.577827,0.670492,0.771735,0.592413,0.554257,0.534151
1500,0.0405,No log,0.461988,0.668129,0.73538,0.831871,0.461988,0.252924,0.174269,0.101462,0.398733,0.617154,0.697515,0.803363,0.62402,0.583261,0.565443
1932,0.0405,No log,0.473684,0.682749,0.748538,0.836257,0.473684,0.257797,0.176023,0.101608,0.40982,0.632505,0.708894,0.805312,0.632035,0.594077,0.574289
2000,0.031,No log,0.482456,0.682749,0.748538,0.837719,0.482456,0.257797,0.176316,0.10117,0.419688,0.631603,0.704995,0.804337,0.635349,0.599174,0.580144
2500,0.0198,No log,0.475146,0.675439,0.744152,0.840643,0.475146,0.255361,0.174561,0.101754,0.411769,0.628874,0.70056,0.807505,0.632261,0.594057,0.574908
2898,0.0198,No log,0.475146,0.685673,0.754386,0.840643,0.475146,0.259747,0.176316,0.102047,0.411038,0.639352,0.710307,0.809698,0.63618,0.597831,0.579078


Batches: 100%|██████████| 43/43 [00:05<00:00,  7.85it/s]
Batches: 100%|██████████| 319/319 [00:41<00:00,  7.70it/s]
Corpus Chunks: 100%|██████████| 1/1 [00:42<00:00, 42.35s/it]
Batches: 100%|██████████| 43/43 [00:05<00:00,  7.64it/s]
Batches: 100%|██████████| 319/319 [00:42<00:00,  7.53it/s]
Corpus Chunks: 100%|██████████| 1/1 [00:43<00:00, 43.29s/it]
Batches: 100%|██████████| 43/43 [00:05<00:00,  7.66it/s]
Batches: 100%|██████████| 319/319 [00:42<00:00,  7.54it/s]
Corpus Chunks: 100%|██████████| 1/1 [00:43<00:00, 43.29s/it]
Batches: 100%|██████████| 43/43 [00:05<00:00,  7.66it/s]
Batches: 100%|██████████| 319/319 [00:42<00:00,  7.55it/s]
Corpus Chunks: 100%|██████████| 1/1 [00:43<00:00, 43.22s/it]
Batches: 100%|██████████| 43/43 [00:05<00:00,  7.66it/s]
Batches: 100%|██████████| 319/319 [00:42<00:00,  7.55it/s]
Corpus Chunks: 100%|██████████| 1/1 [00:43<00:00, 43.20s/it]
Batches: 100%|██████████| 43/43 [00:05<00:00,  7.63it/s]
Batches: 100%|██████████| 319/319 [00:42<00:00,  7.42it/s]


--- Fine-tuning finished ---
Best model saved to: fine_tuned_patent_model


# PUSH FINETUNED MODEL TO HF

In [12]:
# os.environ["HF_TOKEN"] = "YOUR_TOKEN"

repo_name = "e5-large-v2-patent"
repo_id = f"petkopetkov/{repo_name}"

model_path = f"{FINETUNED_MODEL_PATH}/checkpoints/checkpoint-2898"

print(f"Loading best model from: {model_path}")
best_model = SentenceTransformer(str(model_path))

best_model.save_to_hub(
    repo_id=repo_id,
)
print(f"✅ Model successfully uploaded to: https://huggingface.co/{repo_id}")

Loading best model from: fine_tuned_patent_model/checkpoints/checkpoint-2898


The `save_to_hub` method is deprecated and will be removed in a future version of SentenceTransformers. Please use `push_to_hub` instead for future model uploads.
model.safetensors: 100%|██████████| 1.34G/1.34G [18:15<00:00, 1.22MB/s] 


✅ Model successfully uploaded to: https://huggingface.co/petkopetkov/e5-large-v2-patent


In [5]:
# %%
"""---------------------------------------------------------------------
Block 5 – Dense model & ANN search                                       
---------------------------------------------------------------------"""
class DenseIndexer:
    def __init__(self,
                 model_name="petkopetkov/e5-large-v2-patent", # Default base model
                 finetuned_model_path=None,      # Path to fine-tuned model (optional)
                 batch_size=256,                 # Adjust based on GPU memory
                 ef_construction=200,            # HNSW build parameter
                 M=64,                           # HNSW build parameter
                 ef_search=300):                 # HNSW search parameter (can be tuned)
        if not DENSE_OK:
            raise RuntimeError("sentence-transformers or torch is not installed. Cannot initialize DenseIndexer.")

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"DenseIndexer using device: {self.device}")
        
        print(f"Loading model: ", model_name)
        
        try:
            self.model = SentenceTransformer(model_name, device=self.device)
        except Exception as e:
            print(f"Error loading Sentence Transformer model from '{model_name}': {e}")
            raise # Re-raise error as model loading is critical

        # --- Store parameters ---
        self.bs = batch_size
        self.index = None
        self.doc_emb = None # Optionally store embeddings if needed later
        self.doc_ids = None
        self.ef_construction = ef_construction
        self.M = M
        self.ef_search = ef_search # Default search ef

    # --- Static methods for prefixing (can be called without class instance) ---
    @staticmethod
    def _add_query_prefix(text: str) -> str:
        """Adds 'query: ' prefix."""
        if not isinstance(text, str): text = str(text) # Basic type safety
        return f"query: {text}"

    @staticmethod
    def _add_passage_prefix(text: str) -> str:
        """Adds 'passage: ' prefix."""
        if not isinstance(text, str): text = str(text) # Basic type safety
        return f"passage: {text}"

    def _embed(self, texts_with_prefix: list[str], desc: str ="Embedding") -> np.ndarray:
        """Internal embedding function expecting prefixed texts."""
        if not texts_with_prefix:
            print("Warning: _embed called with empty list of texts.")
            return np.array([], dtype=np.float32).reshape(0, self.model.get_sentence_embedding_dimension())

        print(f"Embedding {len(texts_with_prefix)} texts ({desc})...")
        try:
            embeddings = self.model.encode(
                texts_with_prefix,
                convert_to_numpy=True,
                batch_size=self.bs,
                show_progress_bar=True,
                normalize_embeddings=True,       # E5/GTE models benefit from normalization
                output_value='sentence_embedding',
                device=self.device
            )
            # Ensure output is float32 for hnswlib compatibility if needed
            return embeddings.astype(np.float32)
        except Exception as e:
            print(f"Error during sentence embedding: {e}")
            # Depending on error, might want to return empty or raise
            raise # Re-raise error for clarity

    def fit(self, doc_texts: list[str], ids: list[str]):
        """
        Builds the HNSW index from the provided document texts.
        Automatically adds the 'passage:' prefix.
        """
        if len(doc_texts) != len(ids):
            raise ValueError(f"Number of document texts ({len(doc_texts)}) must match number of IDs ({len(ids)}).")
        if not doc_texts:
            print("Warning: fit called with empty documents list. Index will be empty.")
            self.doc_ids = []
            self.index = None
            self.doc_emb = None
            return # Nothing to index

        self.doc_ids = list(ids) # Store a copy

        # --- Add "passage:" prefix before embedding ---
        print(f"Adding 'passage:' prefix to {len(doc_texts)} candidate documents...")
        prefixed_doc_texts = [self._add_passage_prefix(text) for text in tqdm(doc_texts, desc="Prefixing passages")]

        # --- Embed the prefixed documents ---
        emb = self._embed(prefixed_doc_texts, desc="Embedding Candidates")
        if emb.shape[0] == 0: # Handle case where embedding failed or returned empty
             print("Error: Embeddings could not be generated. Index cannot be built.")
             self.index = None
             self.doc_emb = None
             return

        dim = emb.shape[1]
        num_elements = emb.shape[0]
        print(f"Building HNSW index (dim={dim}, M={self.M}, ef_construction={self.ef_construction}, num_elements={num_elements})...")

        # --- Initialize and build HNSW index ---
        try:
            idx = hnswlib.Index(space='cosine', dim=dim) # Cosine distance is 1 - cosine similarity
            idx.init_index(max_elements=num_elements, ef_construction=self.ef_construction, M=self.M)
            # Add items requires numpy array of indices 0..N-1
            idx.add_items(emb, np.arange(num_elements))
            idx.set_ef(self.ef_search) # Set default search ef
            self.index = idx
            self.doc_emb = emb # Store embeddings if needed
            print("HNSW index built successfully.")
        except Exception as e:
            print(f"Error building HNSW index: {e}")
            self.index = None
            self.doc_emb = None
            raise # Propagate error


    def search(self, query_texts: list[str], top_k: int, current_ef_search: int | None = None) -> tuple[np.ndarray, np.ndarray]:
        """
        Searches the index for the given query texts.
        Automatically adds the 'query:' prefix. Returns labels (indices) and similarities.
        """
        if self.index is None:
            raise RuntimeError("Index not fitted or failed to build. Call fit() first.")
        if not query_texts:
            print("Warning: search called with empty query list.")
            return np.array([]), np.array([]) # Return empty arrays

        # --- Determine and set ef_search for this query batch ---
        search_ef = current_ef_search if current_ef_search is not None else self.ef_search
        if search_ef <= 0:
             print(f"Warning: ef_search value ({search_ef}) is invalid, using default: {self.ef_search}")
             search_ef = self.ef_search
        try:
            self.index.set_ef(search_ef)
        except Exception as e:
             print(f"Warning: Failed to set ef_search to {search_ef}. Using previous value. Error: {e}")
             # Proceed with the existing ef setting in the index object
        print(f"Searching with ef_search={self.index.ef}...") # Print actual ef being used

        # --- Add "query:" prefix before embedding ---
        print(f"Adding 'query:' prefix to {len(query_texts)} search queries...")
        prefixed_query_texts = [self._add_query_prefix(text) for text in tqdm(query_texts, desc="Prefixing queries")]

        # --- Embed the prefixed queries ---
        q_emb = self._embed(prefixed_query_texts, desc="Embedding Queries")
        if q_emb.shape[0] == 0:
             print("Error: Query embeddings could not be generated.")
             return np.array([]), np.array([])

        # --- Perform KNN search ---
        print(f"Performing knn_query for {len(prefixed_query_texts)} queries (top_k={top_k})...")
        try:
             # Ensure k is not larger than the number of items in the index
             actual_k = min(top_k, self.index.get_current_count())
             if actual_k <= 0:
                 print("Warning: top_k or index count is zero, cannot perform search.")
                 return np.array([]), np.array([])

             labels, distances = self.index.knn_query(q_emb, k=actual_k)
             # Convert cosine distances (0=identical, 2=opposite) to similarities (1=identical, -1=opposite)
             similarities = 1 - distances
             print("Dense search complete.")
             return labels, similarities
        except Exception as e:
             print(f"Error during knn_query: {e}")
             # Return empty arrays or re-raise depending on desired behavior
             return np.array([]), np.array([])


In [6]:
# %%
"""---------------------------------------------------------------------
Block 6 – Evaluation metrics & RRF fusion                                
---------------------------------------------------------------------"""
from sklearn.preprocessing import minmax_scale # For score normalization

# --- Evaluation Functions (unchanged) ---
def recall_at_k(true_sets, pred_lists, k=100):
    # Handle cases where true set might be empty
    hits = [(len(t.intersection(p[:k])) / len(t)) if t else (1.0 if not p else 0.0) for t,p in zip(true_sets, pred_lists)]
    return np.mean(hits)

def average_precision(true, pred, k=100):
    if not true:
        return 1.0 if not pred else 0.0 # AP is 1 if nothing needed and nothing predicted
    score, hits = 0.0, 0
    relevant_k = pred[:k] # Consider only top K predictions
    for i, p in enumerate(relevant_k):
        if p in true:
            hits += 1
            score += hits / (i + 1.0)
    # Normalize by the minimum of k or number of true items
    return score / min(len(true), k)

def map_at_k(true_sets, pred_lists, k=100):
    return np.mean([average_precision(t, p, k) for t, p in zip(true_sets, pred_lists)])

# --- Fusion Functions ---

# RRF (unchanged)
def rrf_fuse(rank_lists: list[list[str]], k=60):
    scores = {}
    for rlist in rank_lists:
        for rank, doc in enumerate(rlist):
            # Handle potential empty lists
            if doc:
                scores[doc] = scores.get(doc, 0) + 1 / (k + rank + 1)
    return [d for d, _ in sorted(scores.items(), key=lambda x: -x[1])]

# Weighted Fusion
def weighted_fuse(score_dicts: list[dict[str, float]], weights: list[float], default_score=0.0):
    if len(score_dicts) != len(weights):
        raise ValueError("Number of score dictionaries must match number of weights.")

    fused_scores = {}
    all_docs = set()
    for scores in score_dicts:
        all_docs.update(scores.keys())

    for doc in all_docs:
        weighted_score = 0.0
        for i, scores in enumerate(score_dicts):
            score = scores.get(doc, default_score) # Use default if doc not found by a method
            weighted_score += weights[i] * score
        fused_scores[doc] = weighted_score

    return [d for d, _ in sorted(fused_scores.items(), key=lambda x: -x[1])]

# --- Normalization Function ---
def normalize_scores(score_dict: dict[str, float]) -> dict[str, float]:
    """Normalizes scores within a dictionary to [0, 1] range using min-max."""
    if not score_dict:
        return {}
    scores = np.array(list(score_dict.values())).reshape(-1, 1)
    # Handle case where all scores are the same (avoid division by zero)
    if np.all(scores == scores[0]):
         normalized = np.full_like(scores, 0.5) # Or 1.0, or 0.0 depending on preference
    else:
        normalized = minmax_scale(scores)
    return {doc: float(norm_score) for doc, norm_score in zip(score_dict.keys(), normalized.flatten())}

In [7]:
"""---------------------------------------------------------------------
Block 4 – Sparse model: BM25 with rank_bm25 library
---------------------------------------------------------------------"""
# Note: This replaces the previous SparseIndexer and the custom BM25Score logic

# Parameters (Keep defaults or use ones from your previous tuning if preferred)
BM25_K1 = 1.5
BM25_B = 0.75
TEXT_TYPE_BM25 = "title_abstract_claims" # Text parts to use for BM25

# 1) Build corpora (IDs are the same, texts are used for tokenization)
print("Building corpora...")
CITING_IDS, CITING_TEXTS = build_corpus(CITING_TRAIN, TEXT_TYPE_BM25)
NON_IDS,   NON_TEXTS     = build_corpus(NONCITING,    TEXT_TYPE_BM25)
print(f"{len(CITING_IDS)} citing docs – {len(NON_IDS)} candidate docs")

import multiprocessing as mp
from functools import partial
import math

# ... (Keep previous code in the block: imports, BM25 params, corpus building, tokenization) ...

# 2) Tokenize corpora (Assuming this is already done and cached)
print("Tokenizing candidate documents for BM25...")
tokenized_corpus = cached_tokens("nonciting_corpus", NON_TEXTS, stem=True)
print("Tokenizing query documents for BM25...")
tokenized_queries = cached_tokens("citing_queries", CITING_TEXTS, stem=True)

# 3) Fit BM25Okapi model (Same as before)
print(f"Fitting BM25Okapi (k1={BM25_K1}, b={BM25_B})...")
# Make bm25 global ONLY if using the simple access method below,
# otherwise pass necessary data to the worker function.
# global bm25 # <-- Be cautious with globals in multiprocessing
bm25 = BM25Okapi(tokenized_corpus, k1=BM25_K1, b=BM25_B)
print("BM25Okapi model fitted.")

# --- Worker Function for Parallel Scoring ---
# IMPORTANT: BM25Okapi objects themselves aren't easily pickled for multiprocessing.
# We pass the necessary components OR rely on a global object (less safe).
# A safer approach is to pass the index data if possible, or re-initialize
# lightweight components if fitting isn't the bottleneck (but here it is).
# Let's TRY accessing the fitted 'bm25' object as a global, but be aware
# this might be problematic on some systems/configurations.

# Define the worker function *outside* any class, at the top level of the module/script
def score_query_batch(query_indices, bm25_model, query_tokens_list, non_ids_list, top_k):
    """Scores a batch of queries using the provided BM25 model."""
    results = {}
    for i in query_indices:
        query_toks = query_tokens_list[i]
        try:
             # Calculate scores against the *full* corpus
            doc_scores = bm25_model.get_scores(query_toks)

            # Get top K indices (no need to sort all 17k scores fully)
            # Using argpartition is faster than argsort for finding top K
            # Get indices of the top K scores (might not be sorted among themselves)
            k_th_score_idx = min(top_k, len(doc_scores) -1) # Ensure k is not larger than available docs
            if k_th_score_idx < 0: # Handle empty doc_scores case
                 top_n_indices = []
                 top_n_scores = []
            else:
                # Efficiently find the indices of the top K scores
                top_n_indices = np.argpartition(doc_scores, -top_k)[-top_k:]
                # Sort only the top K indices by score
                top_n_scores = doc_scores[top_n_indices]
                sorted_top_indices = top_n_indices[np.argsort(top_n_scores)[::-1]] # Sort descending
                top_n_scores = doc_scores[sorted_top_indices] # Get sorted scores
                top_n_indices = sorted_top_indices

        except Exception as e:
            print(f"Error scoring query index {i}: {e}")
            top_n_indices = []
            top_n_scores = []

        # Store ranks (doc IDs) and scores
        query_id = CITING_IDS[i] # Assumes CITING_IDS is accessible
        ranks = [non_ids_list[idx] for idx in top_n_indices]
        scores = {non_ids_list[idx]: score for idx, score in zip(top_n_indices, top_n_scores)}
        results[query_id] = (ranks, scores)
    return results


# 4) Get BM25 scores and rankings in parallel
TOP_K_BM25 = 100
N_CORES = mp.cpu_count() - 1 or 1 # Use all but one core, or 1 if only one exists
CHUNK_SIZE = math.ceil(len(tokenized_queries) / N_CORES / 4) # Adjust chunk size for progress updates/memory

print(f"Calculating BM25 scores in parallel using {N_CORES} cores (top {TOP_K_BM25})...")

# Prepare arguments for the worker function
# We pass the bm25 object directly - this might work on Linux/macOS via fork
# but could fail on Windows (spawn) or if the object is too complex.
# Also pass lists needed inside the worker.
worker_args = (bm25, tokenized_queries, NON_IDS, TOP_K_BM25)

# Create batches of query indices
query_indices_all = list(range(len(tokenized_queries)))
index_batches = [query_indices_all[i:i + CHUNK_SIZE] for i in range(0, len(tokenized_queries), CHUNK_SIZE)]

RANKS_BM25_PATH = 'RANKS_BM25.pkl'

if os.path.exists(RANKS_BM25_PATH):
    print("Loading cached BM25 ranks from: ", RANKS_BM25_PATH)
    with open(RANKS_BM25_PATH, 'rb') as f:
        RANKS_BM25 = pickle.load(f)
else:
    print("BM25 ranks not found, will be calculated.")
    RANKS_BM25 = {}
    
SCORES_BM25_PATH = 'SCORES_BM25.pkl'
    
if os.path.exists(SCORES_BM25_PATH):
    print("Loading cached BM25 scores from: ", SCORES_BM25_PATH)
    with open(SCORES_BM25_PATH, 'rb') as f:
        SCORES_BM25 = pickle.load(f)
else:
    print("BM25 scores not found, will be calculated.")
    SCORES_BM25 = {}

if not bool(RANKS_BM25) and not bool(SCORES_BM25):
    # Run in parallel
    # Use imap_unordered for potentially faster processing as results arrive
    # Need to wrap the worker call with arguments using partial
    process_func = partial(score_query_batch, bm25_model=bm25, query_tokens_list=tokenized_queries, non_ids_list=NON_IDS, top_k=TOP_K_BM25)

    with mp.Pool(N_CORES) as pool:
        with tqdm(total=len(tokenized_queries), desc="Parallel BM25 Scoring") as pbar:
            for result_dict in pool.imap_unordered(process_func, index_batches):
                for cid, (ranks, scores) in result_dict.items():
                    RANKS_BM25[cid] = ranks
                    SCORES_BM25[cid] = scores
                pbar.update(len(result_dict)) # Update progress bar by number of queries processed in the batch
                
    with open(RANKS_BM25_PATH, 'wb') as f:
        pickle.dump(RANKS_BM25, f)
        
    with open(SCORES_BM25_PATH, 'wb') as f:
        pickle.dump(SCORES_BM25, f)

    print(f"✅ BM25 pre-ranking done (parallel) – top {TOP_K_BM25} docs stored for {len(RANKS_BM25)} queries")

# --- Make sure GT_SETS is defined after CITING_IDS ---
TRUE_DICT = {}
for _, row in MAP_DF.iterrows():
    TRUE_DICT.setdefault(row[0], []).append(row[2])
GT_SETS = [set(TRUE_DICT.get(cid, [])) for cid in CITING_IDS] # Moved here or ensure CITING_IDS is available earlier

Building corpora...
6831 citing docs – 16837 candidate docs
Tokenizing candidate documents for BM25...
Loading cached tokens from: .cache/nonciting_corpus_stemmed.pkl
Tokenizing query documents for BM25...
Loading cached tokens from: .cache/citing_queries_stemmed.pkl
Fitting BM25Okapi (k1=1.5, b=0.75)...
BM25Okapi model fitted.
Calculating BM25 scores in parallel using 23 cores (top 100)...
Loading cached BM25 ranks from:  RANKS_BM25.pkl
Loading cached BM25 scores from:  SCORES_BM25.pkl


In [8]:
# ------------------------------------------------------------------ #
# 5. Dense model (optional)                                          #
# ------------------------------------------------------------------ #
TOP_K_DENSE = 100 # Retrieve same number as BM25 for fair comparison/fusion
TEXT_TYPE_DENSE = "title_abstract_claims" # Or maybe just title_abstract?

SCORES_DENSE_PATH = 'SCORES_DENSE.pkl'

if os.path.exists(SCORES_DENSE_PATH):
    print("Loading cached dense scores from: ", SCORES_DENSE_PATH)
    with open(SCORES_DENSE_PATH, 'rb') as f:
        SCORES_DENSE = pickle.load(f)
else:
    print("Dense scores not found, will be calculated.")
    SCORES_DENSE = {}
    
RANKS_DENSE_PATH = 'RANKS_DENSE.pkl'

if os.path.exists(RANKS_DENSE_PATH):
    print("Loading cached dense ranks from: ", RANKS_DENSE_PATH)
    with open(RANKS_DENSE_PATH, 'rb') as f:
        SCORES_DENSE = pickle.load(f)
else:
    print("Dense ranks not found, will be calculated.")
    RANKS_DENSE = {}

if DENSE_OK and not bool(RANKS_DENSE) and not bool(SCORES_DENSE):
    # Ensure NON_IDS and NON_TEXTS are loaded if running blocks separately
    if 'NON_IDS' not in locals():
        _, NON_TEXTS = build_corpus(NONCITING, TEXT_TYPE_DENSE)
        NON_IDS, _   = build_corpus(NONCITING, "title") # Need IDs consistent with BM25

    # Ensure CITING_IDS and CITING_TEXTS are loaded
    if 'CITING_IDS' not in locals():
         CITING_IDS, CITING_TEXTS = build_corpus(CITING_TRAIN, TEXT_TYPE_DENSE)

    # Initialize and Fit Dense Indexer
    de = DenseIndexer(ef_search=300) # Try a higher ef_search
    de.fit(NON_TEXTS, NON_IDS)

    # Search
    print(f"Retrieving dense neighbours (top {TOP_K_DENSE})...")
    # Use the *same* CITING_TEXTS as used for BM25 queries if TEXT_TYPE_DENSE is the same
    # Or rebuild CITING_TEXTS if using a different TEXT_TYPE_DENSE
    labels, sims = de.search(CITING_TEXTS, top_k=TOP_K_DENSE)

    RANKS_DENSE = {}
    SCORES_DENSE = {} # Store scores for weighted fusion
    for i, cid in enumerate(CITING_IDS):
        doc_indices = labels[i]
        doc_similarities = sims[i]
        valid_mask = doc_indices < len(NON_IDS) # Ensure indices are valid
        RANKS_DENSE[cid] = [NON_IDS[idx] for idx in doc_indices[valid_mask]]
        SCORES_DENSE[cid] = {NON_IDS[idx]: score for idx, score in zip(doc_indices[valid_mask], doc_similarities[valid_mask])}
        
    with open(RANKS_DENSE_PATH, 'wb') as f:
        pickle.dump(RANKS_DENSE, f)
        
    with open(SCORES_DENSE_PATH, 'wb') as f:
        pickle.dump(SCORES_DENSE, f)
        
    print(f"✅ Dense retrieval done – top {TOP_K_DENSE} docs stored for {len(RANKS_DENSE)} queries")

Dense scores not found, will be calculated.
Dense ranks not found, will be calculated.
DenseIndexer using device: cuda
Loading model:  petkopetkov/e5-large-v2-patent
Adding 'passage:' prefix to 16837 candidate documents...


Prefixing passages: 100%|██████████| 16837/16837 [00:00<00:00, 303731.93it/s]

Embedding 16837 texts (Embedding Candidates)...



Batches: 100%|██████████| 66/66 [09:52<00:00,  8.98s/it]


Building HNSW index (dim=1024, M=64, ef_construction=200, num_elements=16837)...
HNSW index built successfully.
Retrieving dense neighbours (top 100)...
Searching with ef_search=300...
Adding 'query:' prefix to 6831 search queries...


Prefixing queries: 100%|██████████| 6831/6831 [00:00<00:00, 651773.03it/s]

Embedding 6831 texts (Embedding Queries)...



Batches: 100%|██████████| 27/27 [03:38<00:00,  8.10s/it]


Performing knn_query for 6831 queries (top_k=100)...
Dense search complete.
✅ Dense retrieval done – top 100 docs stored for 6831 queries


In [9]:
# ------------------------------------------------------------------ #
# 6. Fusion & metric grid                                            #
# ------------------------------------------------------------------ #
K_EVAL = 100 # Evaluate at K=100
BEST_MAP, BEST_CFG, RESULTS = -1, {}, {}

# Ensure GT_SETS is available
if 'GT_SETS' not in locals():
     GT_SETS = [set(TRUE_DICT.get(cid, [])) for cid in CITING_IDS]

# --- RRF Evaluation ---
print("\n--- Evaluating RRF Fusion ---")
RANK_POOLS_RRF = [r for r in (RANKS_BM25, RANKS_DENSE) if r] # Use rank lists
if len(RANK_POOLS_RRF) > 0:
    for k_rrf in [10, 60, 120]:
        fused_rrf = {}
        print(f"Running RRF fusion with k={k_rrf}...")
        for cid in tqdm(CITING_IDS, desc=f"RRF (k={k_rrf})"):
            pools = [p.get(cid, []) for p in RANK_POOLS_RRF]
            fused_rrf[cid] = rrf_fuse(pools, k=k_rrf)

        preds = [fused_rrf.get(cid, [])[:K_EVAL] for cid in CITING_IDS]
        rec_k = recall_at_k(GT_SETS, preds, K_EVAL)
        map_k = map_at_k(GT_SETS, preds, K_EVAL)
        RESULTS[f"RRF_k{k_rrf}"] = dict(rec=rec_k, map=map_k)
        print(f"RRF k={k_rrf}: Recall@{K_EVAL}={rec_k:.4f}, MAP@{K_EVAL}={map_k:.4f}")
        if map_k > BEST_MAP:
            BEST_MAP, BEST_CFG = map_k, dict(method="rrf", k=k_rrf)
else:
    print("Skipping RRF evaluation as only one retrieval method seems available.")


# --- Weighted Fusion Evaluation ---
print("\n--- Evaluating Weighted Fusion ---")
SCORE_POOLS_WEIGHTED = []
if SCORES_BM25: SCORE_POOLS_WEIGHTED.append(SCORES_BM25)
if SCORES_DENSE: SCORE_POOLS_WEIGHTED.append(SCORES_DENSE)

if len(SCORE_POOLS_WEIGHTED) == 2: # Only makes sense for 2 methods currently
    # Normalize scores per query *before* fusion
    print("Normalizing scores for weighted fusion...")
    NORM_SCORES_BM25 = {cid: normalize_scores(SCORES_BM25.get(cid, {})) for cid in tqdm(CITING_IDS, desc="Normalizing BM25")}
    NORM_SCORES_DENSE = {cid: normalize_scores(SCORES_DENSE.get(cid, {})) for cid in tqdm(CITING_IDS, desc="Normalizing Dense")}

    for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]: # Alpha = weight for BM25
        weights = [alpha, 1 - alpha]
        fused_weighted = {}
        print(f"Running Weighted fusion with alpha={alpha:.2f}...")
        for cid in tqdm(CITING_IDS, desc=f"Weighted (a={alpha:.2f})"):
            scores_to_fuse = [
                NORM_SCORES_BM25.get(cid, {}),
                NORM_SCORES_DENSE.get(cid, {})
            ]
            fused_weighted[cid] = weighted_fuse(scores_to_fuse, weights, default_score=0.0) # Use 0 as default for missing docs

        preds = [fused_weighted.get(cid, [])[:K_EVAL] for cid in CITING_IDS]
        rec_k = recall_at_k(GT_SETS, preds, K_EVAL)
        map_k = map_at_k(GT_SETS, preds, K_EVAL)
        result_key = f"Weighted_a{alpha:.1f}"
        RESULTS[result_key] = dict(rec=rec_k, map=map_k)
        print(f"Weighted alpha={alpha:.1f}: Recall@{K_EVAL}={rec_k:.4f}, MAP@{K_EVAL}={map_k:.4f}")
        if map_k > BEST_MAP:
            BEST_MAP, BEST_CFG = map_k, dict(method="weighted", alpha=alpha)
elif len(SCORE_POOLS_WEIGHTED) == 1:
     print("Skipping Weighted fusion evaluation as only one retrieval method has scores.")
     # Evaluate the single method
     method_name = "BM25" if SCORES_BM25 else "Dense"
     single_method_ranks = RANKS_BM25 if SCORES_BM25 else RANKS_DENSE
     preds = [single_method_ranks.get(cid, [])[:K_EVAL] for cid in CITING_IDS]
     rec_k = recall_at_k(GT_SETS, preds, K_EVAL)
     map_k = map_at_k(GT_SETS, preds, K_EVAL)
     RESULTS[method_name] = dict(rec=rec_k, map=map_k)
     print(f"Single Method ({method_name}): Recall@{K_EVAL}={rec_k:.4f}, MAP@{K_EVAL}={map_k:.4f}")
     if map_k > BEST_MAP:
        BEST_MAP, BEST_CFG = map_k, dict(method=method_name)
else:
    print("Skipping Weighted fusion evaluation as no scores were generated.")


# --- Final Results ---
print("\n--- Evaluation Summary ---")
print("Candidate results:", json.dumps(RESULTS, indent=2))
if BEST_CFG:
    print(f"Best config: {BEST_CFG}   MAP@{K_EVAL} = {BEST_MAP:.4f}")
else:
    print("No valid configurations were evaluated.")



--- Evaluating RRF Fusion ---
Running RRF fusion with k=10...


RRF (k=10):   0%|          | 0/6831 [00:00<?, ?it/s]

RRF (k=10): 100%|██████████| 6831/6831 [00:01<00:00, 4426.43it/s]


RRF k=10: Recall@100=0.9795, MAP@100=0.5226
Running RRF fusion with k=60...


RRF (k=60): 100%|██████████| 6831/6831 [00:01<00:00, 4454.92it/s]


RRF k=60: Recall@100=0.9799, MAP@100=0.5025
Running RRF fusion with k=120...


RRF (k=120): 100%|██████████| 6831/6831 [00:01<00:00, 4273.85it/s]


RRF k=120: Recall@100=0.9799, MAP@100=0.4982

--- Evaluating Weighted Fusion ---
Normalizing scores for weighted fusion...


Normalizing BM25: 100%|██████████| 6831/6831 [00:03<00:00, 1764.10it/s]
Normalizing Dense: 100%|██████████| 6831/6831 [00:02<00:00, 2349.12it/s]


Running Weighted fusion with alpha=0.10...


Weighted (a=0.10): 100%|██████████| 6831/6831 [00:02<00:00, 2300.74it/s]


Weighted alpha=0.1: Recall@100=0.9833, MAP@100=0.5418
Running Weighted fusion with alpha=0.30...


Weighted (a=0.30): 100%|██████████| 6831/6831 [00:02<00:00, 2296.35it/s]


Weighted alpha=0.3: Recall@100=0.9804, MAP@100=0.5599
Running Weighted fusion with alpha=0.50...


Weighted (a=0.50): 100%|██████████| 6831/6831 [00:02<00:00, 2329.96it/s]


Weighted alpha=0.5: Recall@100=0.9699, MAP@100=0.5427
Running Weighted fusion with alpha=0.70...


Weighted (a=0.70): 100%|██████████| 6831/6831 [00:02<00:00, 2304.91it/s]


Weighted alpha=0.7: Recall@100=0.9291, MAP@100=0.4829
Running Weighted fusion with alpha=0.90...


Weighted (a=0.90): 100%|██████████| 6831/6831 [00:02<00:00, 2355.39it/s]


Weighted alpha=0.9: Recall@100=0.8394, MAP@100=0.4156

--- Evaluation Summary ---
Candidate results: {
  "RRF_k10": {
    "rec": 0.9794710388913287,
    "map": 0.5226457757100408
  },
  "RRF_k60": {
    "rec": 0.9799346118186697,
    "map": 0.5025211588576178
  },
  "RRF_k120": {
    "rec": 0.9799346118186697,
    "map": 0.49824767724659696
  },
  "Weighted_a0.1": {
    "rec": 0.9832894158981116,
    "map": 0.541825415789975
  },
  "Weighted_a0.3": {
    "rec": 0.9803859854584492,
    "map": 0.5599194601612584
  },
  "Weighted_a0.5": {
    "rec": 0.9698701995803445,
    "map": 0.5426628786945676
  },
  "Weighted_a0.7": {
    "rec": 0.9290831015468697,
    "map": 0.4829415898503858
  },
  "Weighted_a0.9": {
    "rec": 0.8393597813887668,
    "map": 0.4156319436905769
  }
}
Best config: {'method': 'weighted', 'alpha': 0.3}   MAP@100 = 0.5599


In [11]:
TEST_RANKS_BM25 = 'TEST_RANKS_BM25.pkl'
    
with open(TEST_RANKS_BM25, 'wb') as f:
    pickle.dump(TEST_RANKS_BM25, f)
    
TEST_RANKS_DENSE = 'TEST_RANKS_DENSE.pkl'

with open(TEST_RANKS_DENSE, 'wb') as f:
    pickle.dump(TEST_RANKS_DENSE, f)

In [10]:
print("\n--- Generating final predictions for test set ---")

# 1. Load Test Data & Define Text Representation
# Use a consistent text representation, potentially guided by BEST_CFG or fixed
# TEXT_TYPE_PREDICT = BEST_CFG.get('text_type', 'title_abstract_claims') # Example: Get from config
TEXT_TYPE_PREDICT = "title_abstract_claims" # Or fix it: "title_abstract_claims"
print(f"Using text_type for prediction: {TEXT_TYPE_PREDICT}")
try:
    # Assuming build_corpus is available
    TEST_IDS, TEST_TEXTS = build_corpus(CITING_TEST, TEXT_TYPE_PREDICT)
    if not TEST_IDS:
        raise ValueError("Failed to extract Test IDs and Texts. Check CITING_TEST data and build_corpus.")
    print(f"Loaded {len(TEST_IDS)} test queries.")
except Exception as e:
    print(f"Error loading test data: {e}")
    # Exit or handle gracefully if test data is crucial
    TEST_IDS, TEST_TEXTS = [], [] # Ensure variables exist but are empty


# 2. BM25 Ranking for Test Set (No prefix changes needed here)

TEST_RANKS_BM25_PATH = 'TEST_RANKS_BM25.pkl'

if os.path.exists(TEST_RANKS_BM25_PATH):
    print("Loading cached test BM25 ranks from: ", SCORES_DENSE_PATH)
    with open(TEST_RANKS_BM25_PATH, 'rb') as f:
        TEST_RANKS_BM25 = pickle.load(f)
else:
    print("Dense scores not found, will be calculated.")
    TEST_RANKS_BM25 = {}
    
if not bool(TEST_RANKS_BM25):
    BM25_RETRIEVAL_DEPTH = 100 # How many candidates to retrieve with BM25
    print(f"Performing BM25 ranking for test queries (retrieving top {BM25_RETRIEVAL_DEPTH})...")
    try:
        if 'bm25' not in globals() or bm25 is None:
            print("BM25 model ('bm25') not found or not fitted. Skipping BM25 ranking.")
        elif not TEST_IDS:
            print("No test queries loaded. Skipping BM25 ranking.")
        else:
            # Tokenize test queries using the stemmed preprocessor used for BM25 training
            # Assuming preprocess_and_stem and cached_tokens are available
            print("Tokenizing test queries for BM25...")
            tokenized_test_queries = cached_tokens("citing_test_queries_stemmed", TEST_TEXTS, stem=True)

            if len(tokenized_test_queries) != len(TEST_IDS):
                print("Warning: Mismatch between number of tokenized queries and test IDs.")
                # Decide how to handle: skip, error out, or try to proceed

            print("Calculating BM25 scores for test set...")
            for i, query_toks in enumerate(tqdm(tokenized_test_queries, desc="BM25 Scoring (Test)")):
                if i >= len(TEST_IDS): break # Safety break if lists mismatch
                cid = TEST_IDS[i]
                try:
                    # Get scores for *all* documents from the fitted bm25 model
                    doc_scores = bm25.get_scores(query_toks)
                    # Get top N indices (more efficient than sorting all)
                    num_candidates = len(doc_scores)
                    actual_depth = min(BM25_RETRIEVAL_DEPTH, num_candidates)
                    if actual_depth > 0:
                        # Efficiently get indices of top N scores
                        top_n_indices = np.argpartition(doc_scores, -actual_depth)[-actual_depth:]
                        # Sort only the top N indices by score
                        top_n_scores = doc_scores[top_n_indices]
                        sorted_top_indices = top_n_indices[np.argsort(top_n_scores)[::-1]] # Descending scores
                        # Store ranked doc IDs from NON_IDS list
                        TEST_RANKS_BM25[cid] = [NON_IDS[idx] for idx in sorted_top_indices if idx < len(NON_IDS)]
                    else:
                        TEST_RANKS_BM25[cid] = []
                except Exception as e:
                    print(f"Error scoring BM25 for test query ID {cid}: {e}")
                    TEST_RANKS_BM25[cid] = [] # Assign empty list on error for this query
            print(f"BM25 ranking complete for {len(TEST_RANKS_BM25)} test queries.")

    except NameError as e:
        print(f"Error during BM25 test ranking setup (required function/variable missing?): {e}")
    except Exception as e:
        print(f"An unexpected error occurred during BM25 test ranking: {e}")

TEST_RANKS_DENSE_PATH = 'TEST_RANKS_DENSE.pkl'

if os.path.exists(TEST_RANKS_DENSE_PATH):
    print("Loading cached test dense ranks from: ", SCORES_DENSE_PATH)
    with open(TEST_RANKS_DENSE_PATH, 'rb') as f:
        TEST_RANKS_DENSE = pickle.load(f)
else:
    print("Dense scores not found, will be calculated.")
    TEST_RANKS_DENSE = {}

DENSE_RETRIEVAL_DEPTH = 100 # How many candidates to retrieve with Dense model

if DENSE_OK and not bool(TEST_RANKS_DENSE):
    print(f"Performing Dense retrieval for test queries (retrieving top {DENSE_RETRIEVAL_DEPTH})...")
    try:
        if 'de' not in globals() or de is None or de.index is None:
             print("DenseIndexer 'de' not found, not fitted, or index is missing. Skipping Dense ranking.")
             DENSE_OK = False # Ensure Dense fusion is skipped later
        elif not TEST_IDS:
             print("No test queries loaded. Skipping Dense ranking.")
             DENSE_OK = False
        else:
            # --- Add "query:" prefix to TEST_TEXTS before searching ---
            print("Adding 'query:' prefix to test queries for Dense search...")
            # Ensure the prefix function/method is accessible
            try:
                 # Use the static method from DenseIndexer class definition
                 prefixed_test_texts = [DenseIndexer._add_query_prefix(text) for text in tqdm(TEST_TEXTS, desc="Prefixing test queries")]
            except NameError:
                 # Fallback if DenseIndexer class definition isn't in scope (should be unlikely)
                 def _add_query_prefix_local(text): return f"query: {text}"
                 prefixed_test_texts = [_add_query_prefix_local(text) for text in tqdm(TEST_TEXTS, desc="Prefixing test queries (local)")]


            # --- Perform search using the DenseIndexer's search method ---
            # Assuming 'de' is the fitted DenseIndexer instance
            # Pass the *prefixed* queries
            labels, sims = de.search(
                prefixed_test_texts,
                top_k=DENSE_RETRIEVAL_DEPTH,
                current_ef_search=300 # Optional: override default ef_search here if needed
            )

            if len(labels) != len(TEST_IDS):
                 print("Warning: Number of dense results doesn't match number of test queries.")
                 # Handle mismatch if necessary

            # Map indices back to document IDs
            print("Mapping dense results to document IDs...")
            for i, cid in enumerate(TEST_IDS):
                if i < len(labels):
                    doc_indices = labels[i]
                    # Ensure indices are valid before mapping to NON_IDS
                    valid_mask = (doc_indices >= 0) & (doc_indices < len(NON_IDS)) # Check bounds
                    valid_indices = doc_indices[valid_mask]
                    TEST_RANKS_DENSE[cid] = [NON_IDS[idx] for idx in valid_indices]
                else:
                     TEST_RANKS_DENSE[cid] = [] # Assign empty list if results are missing for this query
            print(f"Dense retrieval complete for {len(TEST_RANKS_DENSE)} test queries.")

    except NameError as e:
        print(f"Error during Dense test ranking setup (required object 'de' missing?): {e}")
        DENSE_OK = False # Disable dense fusion if setup fails
    except Exception as e:
        print(f"An error occurred during Dense test retrieval: {e}")
        DENSE_OK = False # Disable dense fusion on error


# 4. Fusion / Selection based on BEST_CFG
final_pred = {}
# Determine fusion method from BEST_CFG, default to BM25 if config missing or invalid
default_method = 'bm25' if TEST_RANKS_BM25 else ('dense' if TEST_RANKS_DENSE else 'none')
fusion_method = BEST_CFG.get('method', default_method) if isinstance(BEST_CFG, dict) else default_method

# Ensure necessary rank lists are available for the chosen method
can_do_rrf = fusion_method == 'rrf' and DENSE_OK and TEST_RANKS_BM25 and TEST_RANKS_DENSE
can_do_dense = fusion_method == 'dense' and DENSE_OK and TEST_RANKS_DENSE
can_do_bm25 = fusion_method == 'bm25' and TEST_RANKS_BM25

print(f"Applying final ranking strategy: {fusion_method}")

if can_do_rrf:
    fuse_k = BEST_CFG.get('k', 60) # Get RRF k from config, default to 60
    print(f"Using RRF fusion with k={fuse_k}")
    # Assuming rrf_fuse function is available
    for cid in tqdm(TEST_IDS, desc="RRF Fusion (Test)"):
        pools = [TEST_RANKS_BM25.get(cid, []), TEST_RANKS_DENSE.get(cid, [])]
        # Ensure rrf_fuse handles empty lists gracefully
        final_pred[cid] = rrf_fuse(pools, k=fuse_k)[:K_SUBMISSION]

elif can_do_dense:
    print("Using Dense rankings only.")
    for cid in tqdm(TEST_IDS, desc="Dense Selection (Test)"):
        final_pred[cid] = TEST_RANKS_DENSE.get(cid, [])[:K_SUBMISSION]

elif can_do_bm25:
     print("Using BM25 rankings only.")
     for cid in tqdm(TEST_IDS, desc="BM25 Selection (Test)"):
        final_pred[cid] = TEST_RANKS_BM25.get(cid, [])[:K_SUBMISSION]

else:
    # Fallback strategy if chosen method failed or prerequisites missing
    print(f"Warning: Best config method '{fusion_method}' could not be applied or prerequisites missing.")
    if TEST_RANKS_BM25:
        print("Falling back to BM25 rankings.")
        fusion_method = 'bm25 (fallback)'
        for cid in tqdm(TEST_IDS, desc="Fallback BM25 (Test)"):
            final_pred[cid] = TEST_RANKS_BM25.get(cid, [])[:K_SUBMISSION]
    elif TEST_RANKS_DENSE and DENSE_OK:
         print("Falling back to Dense rankings.")
         fusion_method = 'dense (fallback)'
         for cid in tqdm(TEST_IDS, desc="Fallback Dense (Test)"):
            final_pred[cid] = TEST_RANKS_DENSE.get(cid, [])[:K_SUBMISSION]
    else:
         print("Error: No rankings available (BM25 or Dense) to generate predictions.")
         fusion_method = 'none'
         for cid in TEST_IDS:
            final_pred[cid] = [] # Assign empty list if no method worked

print(f"Final ranking strategy applied: {fusion_method}")


# 5. Save Predictions
output_filename = "prediction1.json"
try:
    # Ensure all query IDs from TEST_IDS are present in the output
    for cid in TEST_IDS:
        if cid not in final_pred:
            print(f"Warning: Test query ID {cid} missing from final predictions. Adding empty list.")
            final_pred[cid] = []

    # Final check: ensure values are lists
    for cid in final_pred:
        if not isinstance(final_pred[cid], list):
             print(f"Warning: Prediction for {cid} is not a list. Converting.")
             final_pred[cid] = list(final_pred[cid]) # Attempt conversion

    with open(output_filename, "w", encoding="utf-8") as f:
        json.dump(final_pred, f, indent=2) # Use indent for readability
    print(f"✅ Wrote {output_filename} with predictions for {len(final_pred)} queries.")

    # Sanity check: count queries with fewer than K_SUBMISSION results
    short_preds = sum(1 for cid in final_pred if len(final_pred.get(cid, [])) < K_SUBMISSION)
    zero_preds = sum(1 for cid in final_pred if not final_pred.get(cid, []))
    if zero_preds > 0:
         print(f"   ⚠️ Warning: {zero_preds} queries have ZERO predictions.")
    elif short_preds > 0:
        print(f"   ⚠️ Warning: {short_preds} queries have fewer than {K_SUBMISSION} predictions.")
    else:
        print(f"   ✅ All {len(final_pred)} queries have {K_SUBMISSION} predictions (or max available).")

except Exception as e:
    print(f"Error writing prediction file '{output_filename}': {e}")



--- Generating final predictions for test set ---
Using text_type for prediction: title_abstract_claims
Loaded 1000 test queries.
Performing BM25 ranking for test queries (retrieving top 100)...
Tokenizing test queries for BM25...
Loading cached tokens from: .cache/citing_test_queries_stemmed_stemmed.pkl
Calculating BM25 scores for test set...


BM25 Scoring (Test): 100%|██████████| 1000/1000 [26:51<00:00,  1.61s/it] 


BM25 ranking complete for 1000 test queries.
Performing Dense retrieval for test queries (retrieving top 100)...
Adding 'query:' prefix to test queries for Dense search...


Prefixing test queries: 100%|██████████| 1000/1000 [00:00<00:00, 443138.30it/s]


Searching with ef_search=300...
Adding 'query:' prefix to 1000 search queries...


Prefixing queries: 100%|██████████| 1000/1000 [00:00<00:00, 303056.65it/s]


Embedding 1000 texts (Embedding Queries)...


Batches: 100%|██████████| 4/4 [00:27<00:00,  6.96s/it]


Performing knn_query for 1000 queries (top_k=100)...
Dense search complete.
Mapping dense results to document IDs...
Dense retrieval complete for 1000 test queries.
Applying final ranking strategy: weighted
Falling back to BM25 rankings.


Fallback BM25 (Test): 100%|██████████| 1000/1000 [00:00<00:00, 739214.66it/s]

Final ranking strategy applied: bm25 (fallback)
✅ Wrote prediction1.json with predictions for 1000 queries.
   ✅ All 1000 queries have 100 predictions (or max available).





# PLOT