In [None]:
import os
import json
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, models, util
from tqdm import tqdm
from pathlib import Path
import time # To measure experiment time
import types # Used for SimpleNamespace if preferred
import re

# ----------------------------
# Configuration
# ----------------------------
# Base config - specific parameters will be overridden by experiments
config = dict(
    # Input files. Relative names are joined onto base_dir by main().
    base_dir='.',
    query_list_file='test_queries.json',
    pre_ranking_file='shuffled_pre_ranking.json',
    queries_content_file='queries_content_with_features.json',
    documents_content_file='documents_content_with_features.json',
    # No qrels file is configured: this script only generates predictions.

    # Model and text defaults; each experiment may override any of these.
    reranker_type='bi-encoder',
    bi_encoder_model='AI-Growth-Lab/PatentSBERTa',
    cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-6-v2',
    text_type='tac1',
    max_length=512,

    # Execution settings. device=None means "pick automatically".
    batch_size=32,
    device=None,

    # Output settings: one prediction file per experiment, named
    # "<output_file_prefix>_<exp_id>.json".
    save_individual_predictions=True,
    output_file_prefix='prediction_exp',
)

# --- Experiments to Run ---
# Define different configurations to generate predictions for
experiments = [
    # Bi-encoder runs, one per (exp_id, model, text_type) row:
    # PatentSBERTa and MPNet on tac1 and on claims, plus a QA-tuned model.
    {
        'exp_id': exp_id,
        'reranker_type': 'bi-encoder',
        'bi_encoder_model': model_name,
        'text_type': text_type,
    }
    for exp_id, model_name, text_type in (
        ('BiEnc_PatentSBERTa_tac1', 'AI-Growth-Lab/PatentSBERTa', 'tac1'),
        ('BiEnc_MPNet_tac1', 'all-mpnet-base-v2', 'tac1'),
        ('BiEnc_PatentSBERTa_claims', 'AI-Growth-Lab/PatentSBERTa', 'claims'),
        ('BiEnc_MPNet_claims', 'all-mpnet-base-v2', 'claims'),
        ('BiEnc_MultiQA_tac1', 'multi-qa-mpnet-base-dot-v1', 'tac1'),
    )
] + [
    # Cross-encoder runs: the same MiniLM-L6 model on three text selections.
    {
        'exp_id': exp_id,
        'reranker_type': 'cross-encoder',
        'cross_encoder_model': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'text_type': text_type,
    }
    for exp_id, text_type in (
        ('CrossEnc_L6_tac1', 'tac1'),
        ('CrossEnc_L6_claims', 'claims'),
        ('CrossEnc_L6_TA', 'TA'),
    )
]


# --- Auto-detect device ---
# None means "choose automatically": CUDA when available, otherwise CPU.
# Any explicit CUDA request falls back to CPU (with a warning) when CUDA is
# unavailable. The prefix test also covers indexed devices such as 'cuda:0',
# which the previous exact comparison with 'cuda' let through unchanged.
if config['device'] is None:
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
elif str(config['device']).startswith('cuda') and not torch.cuda.is_available():
    print("Warning: CUDA requested but not available. Using CPU.")
    config['device'] = 'cpu'

# ----------------------------
# Utility Functions
# ----------------------------
def load_json_file(file_path):
    """Read and decode the JSON document stored at ``file_path``.

    Returns the decoded object. If the file is missing, is not valid JSON,
    or any other error occurs while reading it, a message is printed and
    None is returned instead of raising.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {file_path}")
    except Exception as exc:
        print(f"An unexpected error occurred loading {file_path}: {exc}")
    # Every failure path ends here.
    return None

def save_json_file(data, file_path):
    """Write ``data`` as 2-space-indented JSON to ``file_path``.

    The parent directory is created when the path has one and it does not
    exist yet. Errors are reported with a printed message rather than
    raised, so a failed save does not abort the caller.
    """
    print(f"Saving predictions to: {file_path}")
    try:
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            # A bare file name has no directory part; nothing to create then.
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as handle:
            json.dump(data, handle, indent=2)
    except Exception as exc:
        print(f"An error occurred saving to {file_path}: {exc}")

def load_content_data(file_path):
    """Build an ID -> Content mapping from a JSON content file.

    Each record is identified by its 'FAN' value when present. Otherwise the
    'Application_Number' is used, with the 'Application_Category' appended
    when the record has one. IDs are always converted to strings. Records
    with no usable ID or no 'Content' entry are skipped.

    Returns an empty dict when the file cannot be loaded.
    """
    records = load_json_file(file_path)
    if records is None:
        return {}

    print(f"Processing content file: {os.path.basename(file_path)}")
    content_by_id = {}
    for record in tqdm(records, desc="Loading content", leave=False):
        # 'FAN' takes precedence over 'Application_Number'.
        if 'FAN' in record:
            record_id = str(record['FAN'])
        elif 'Application_Number' in record:
            record_id = str(record['Application_Number'])
            if 'Application_Category' in record:
                record_id = record_id + str(record.get('Application_Category', ''))
        else:
            record_id = None

        if record_id and 'Content' in record:
            content_by_id[record_id] = record['Content']
    return content_by_id

def _numbered_key_order(key_string):
    """Sort key: the integer after the first '-' in ``key_string``, else +inf.

    Keys whose suffix is not purely numeric (or that contain no '-') all map
    to +inf, so they sort last and keep their relative order because
    ``sorted`` is stable.
    """
    parts = key_string.split('-', 1)
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])
    return float('inf')


def extract_text(content_dict, text_type="TA"):
    """Assemble the text of one patent record for the requested ``text_type``.

    Args:
        content_dict: Mapping of section keys to text. Keys read here are
            'title', 'pa01' (abstract), keys starting with 'c-' (claims),
            keys starting with 'p' (description) and 'features'.
        text_type: Which sections to include:
            'TA' / 'title_abstract' - title and abstract;
            'claims'                - all non-empty claims;
            'tac1'                  - title, abstract and first non-empty claim;
            'description'           - description paragraphs;
            'full'                  - title, abstract, claims, description;
            'features'              - the 'features' entry.
            Any other value selects nothing.

    Returns:
        The selected non-empty parts joined by single spaces and stripped,
        or "" when ``content_dict`` is not a dict or nothing was selected.
    """
    if not isinstance(content_dict, dict):
        return ""

    text_parts = []

    if text_type in ("TA", "tac1", "full", "title_abstract"):
        text_parts.append(content_dict.get("title", ""))
        text_parts.append(content_dict.get("pa01", ""))  # Abstract

    if text_type in ("claims", "tac1", "full"):
        claim_keys = sorted(
            (key for key in content_dict if key.startswith('c-')),
            key=_numbered_key_order,
        )
        claims = [content_dict[key] for key in claim_keys if content_dict[key]]
        if text_type == "tac1":
            # Only the first non-empty claim is used for 'tac1'.
            text_parts.extend(claims[:1])
        else:
            text_parts.extend(claims)

    if text_type in ("description", "full"):
        # 'pa01' also starts with 'p' but is the abstract handled above.
        # Without this exclusion, 'full' repeated the abstract and
        # 'description' contained it.
        desc_keys = sorted(
            (key for key in content_dict if key.startswith('p') and key != "pa01"),
            key=_numbered_key_order,
        )
        text_parts.extend(content_dict.get(key, "") for key in desc_keys)

    if text_type == "features":
        text_parts.append(content_dict.get("features", ""))

    # filter(None, ...) drops empty and missing sections before joining.
    return " ".join(filter(None, text_parts)).strip()

# --- REMOVED Evaluation Functions ---
# load_qrels, calculate_average_precision, calculate_recall_at_k, evaluate_ranking deleted

# --- REMOVED Print Results Table function ---
# print_results_table deleted

# ----------------------------
# Core Re-ranking Function
# ----------------------------
def perform_reranking(exp_config, query_ids, pre_ranking_data, queries_content, documents_content):
    """Re-rank each query's pre-ranked candidates with the configured model.

    Args:
        exp_config: Experiment settings. Reads 'device', 'reranker_type',
            'text_type', 'max_length', 'batch_size', the model name for the
            chosen reranker type ('bi_encoder_model' or 'cross_encoder_model')
            and, for messages only, 'exp_id'.
        query_ids: Iterable of query IDs to process.
        pre_ranking_data: Mapping of query ID -> list of candidate document IDs.
        queries_content: Mapping of query ID -> content dict of that query.
        documents_content: Mapping of document ID -> content dict.

    Returns:
        A dict mapping every query ID to its re-ranked list of document IDs,
        or None when the model could not be loaded (the error is printed).
        A query keeps its original candidate order when it has no content or
        text, when none of its candidates has text, or when scoring raises.
        Candidates without text get a score one below the lowest computed
        score so that they are listed after the scored documents.
    """
    device = torch.device(exp_config['device'])
    reranker_type = exp_config['reranker_type']
    text_type = exp_config['text_type']
    model = None
    model_name = ""

    # --- Load the model; any failure aborts this experiment only ---
    try:
        if reranker_type not in ('bi-encoder', 'cross-encoder'):
            raise ValueError(f"Invalid reranker_type: {reranker_type}")

        if reranker_type == 'bi-encoder':
            model_name = exp_config.get('bi_encoder_model')
            if not model_name:
                raise ValueError("bi_encoder_model must be specified")
            print(f"Loading Bi-Encoder model: {model_name}...")
            model = SentenceTransformer(model_name, device=device)
            model.max_seq_length = exp_config['max_length']
        else:
            model_name = exp_config.get('cross_encoder_model')
            if not model_name:
                raise ValueError("cross_encoder_model must be specified")
            print(f"Loading Cross-Encoder model: {model_name}...")
            model = CrossEncoder(model_name, device=device, max_length=exp_config['max_length'])
        print(f"Model '{model_name}' loaded.")
    except Exception as exc:
        print(f"\nError loading model '{model_name}' for experiment '{exp_config.get('exp_id', 'N/A')}': {exc}")
        return None

    # --- Re-rank query by query ---
    results = {}
    progress = tqdm(query_ids, desc=f"Re-ranking ({exp_config.get('exp_id', 'N/A')})", leave=False)
    for query_id in progress:
        candidates = pre_ranking_data.get(query_id, [])
        query_content = queries_content.get(query_id)

        if not candidates:
            results[query_id] = []
            continue

        # Without usable query text there is nothing to score against.
        query_text = extract_text(query_content, text_type) if query_content else ""
        if not query_text:
            results[query_id] = candidates
            continue

        # Collect the text of every candidate that has content.
        doc_texts_by_id = {}
        for doc_id in candidates:
            doc_content = documents_content.get(doc_id)
            if not doc_content:
                continue
            doc_text = extract_text(doc_content, text_type)
            if doc_text:
                doc_texts_by_id[doc_id] = doc_text

        scorable_ids = list(doc_texts_by_id)
        if not scorable_ids:
            results[query_id] = candidates
            continue

        # --- Score the candidates that have text ---
        try:
            doc_texts = [doc_texts_by_id[doc_id] for doc_id in scorable_ids]

            if reranker_type == 'bi-encoder':
                query_embedding = model.encode(
                    query_text, convert_to_tensor=True, show_progress_bar=False
                ).to(device)
                doc_embeddings = model.encode(
                    doc_texts,
                    convert_to_tensor=True,
                    show_progress_bar=False,
                    batch_size=exp_config['batch_size'],
                ).to(device)
                if query_embedding is None or doc_embeddings is None or len(doc_embeddings) == 0:
                    raise RuntimeError("Embedding generation failed")
                if query_embedding.shape[0] == 0 or doc_embeddings.shape[0] == 0:
                    raise RuntimeError("Embedding tensor is empty")
                raw_scores = util.cos_sim(query_embedding, doc_embeddings)[0].cpu().numpy()
            else:
                pairs = [[query_text, doc_text] for doc_text in doc_texts]
                raw_scores = model.predict(
                    pairs,
                    show_progress_bar=False,
                    batch_size=exp_config['batch_size'],
                    convert_to_numpy=True,
                )
            scores_by_id = dict(zip(scorable_ids, raw_scores))
        except Exception as exc:
            print(f"\nError during scoring query {query_id} in exp {exp_config.get('exp_id', 'N/A')}: {type(exc).__name__} - {exc}")
            results[query_id] = candidates
            continue

        # --- Rank: scored documents first, text-less candidates after ---
        min_score = min(scores_by_id.values()) if scores_by_id else 0
        fallback_score = min_score - 1 if min_score > -float('inf') else -float('inf')

        ranked_pairs = [(doc_id, float(score)) for doc_id, score in scores_by_id.items()]
        already_scored = set(scores_by_id)
        ranked_pairs.extend(
            (doc_id, fallback_score) for doc_id in candidates if doc_id not in already_scored
        )
        # The sort is stable, so equal scores keep their insertion order.
        ranked_pairs.sort(key=lambda pair: pair[1], reverse=True)

        ranked_ids = [doc_id for doc_id, _ in ranked_pairs]
        results[query_id] = ranked_ids[:len(candidates)]

    # Drop the model before the next experiment loads its own.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results

# ----------------------------
# Main Execution Logic
# ----------------------------
def main(base_cfg, experiments_to_run):
    """Load the shared data once, run every experiment and save its predictions.

    Args:
        base_cfg: Base configuration dict. Supplies the data file names,
            'base_dir', output settings and the defaults that each experiment
            may override.
        experiments_to_run: List of dicts; each is merged over ``base_cfg`` to
            form the configuration of one run. 'exp_id' names the run and its
            output file (falls back to 'exp_<position>').

    Returns:
        None. Prints an error and returns early when the query list, the
        pre-ranking file or the content files cannot be loaded or are empty.
    """
    # Resolve the device from the configuration that was passed in rather than
    # from the module-level ``config``, so main() also works with another dict.
    print(f"Base device requested: {base_cfg.get('device', 'None specified')}")
    effective_device = base_cfg.get('device')
    if effective_device is None:
        effective_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using effective device: {effective_device}")

    # --- Construct Full Paths ---
    def get_full_path(path):
        """Return ``path`` unchanged if absolute, else joined onto base_dir."""
        return path if os.path.isabs(path) else os.path.join(base_cfg['base_dir'], path)

    query_list_file = get_full_path(base_cfg['query_list_file'])
    pre_ranking_file = get_full_path(base_cfg['pre_ranking_file'])
    queries_content_file = get_full_path(base_cfg['queries_content_file'])
    documents_content_file = get_full_path(base_cfg['documents_content_file'])

    # --- Load Shared Data ---
    print("\nLoading shared data...")
    query_ids_raw = load_json_file(query_list_file)
    pre_ranking_data_raw = load_json_file(pre_ranking_file)
    queries_content_raw = load_content_data(queries_content_file)
    documents_content_raw = load_content_data(documents_content_file)

    # --- Ensure IDs/Keys are Strings ---
    # All lookups downstream use string IDs, so normalise every source here.
    if query_ids_raw is None:
        print("\nError: Failed to load query_list_file. Exiting.")
        return
    query_ids = [str(qid) for qid in query_ids_raw]
    print(f"Loaded and processed {len(query_ids)} query IDs (as strings).")

    if pre_ranking_data_raw is None:
        print("\nError: Failed to load pre_ranking_file. Exiting.")
        return
    pre_ranking_data = {str(k): list(map(str, v)) for k, v in pre_ranking_data_raw.items()}
    print(f"Processed {len(pre_ranking_data)} pre-ranking entries (keys/docs as strings).")

    if queries_content_raw is None or documents_content_raw is None:
        print("\nError: Failed to load content files. Exiting.")
        return
    queries_content = {str(k): v for k, v in queries_content_raw.items()}
    documents_content = {str(k): v for k, v in documents_content_raw.items()}
    print("Ensured content dictionary keys are strings.")

    # An empty mapping means the file was missing, unreadable or had no data.
    if not all([pre_ranking_data, queries_content, documents_content]):
        print("\nError: Failed to load one or more data files (pre-ranking, content). Exiting.")
        return

    # --- Run Experiments ---
    print(f"\nStarting {len(experiments_to_run)} experiments to generate prediction files...")

    for position, exp_params in enumerate(experiments_to_run, start=1):
        exp_id = exp_params.get('exp_id', f'exp_{position}')
        print(f"\n--- Running Experiment: {exp_id} ---")
        start_time = time.time()

        # Experiment values override the base configuration.
        run_config = {**base_cfg, **exp_params}
        if run_config.get('device') is None:
            run_config['device'] = effective_device

        predictions = perform_reranking(
            run_config,
            query_ids,
            pre_ranking_data,
            queries_content,
            documents_content,
        )

        elapsed_time = time.time() - start_time
        print(f"Experiment {exp_id} finished in {elapsed_time:.2f} seconds.")

        if predictions is None:
            print(f"Experiment {exp_id} failed during re-ranking. No prediction file generated.")
        elif run_config.get('save_individual_predictions', False):
            pred_filename = f"{run_config.get('output_file_prefix', 'pred')}_{exp_id}.json"
            save_json_file(predictions, get_full_path(pred_filename))
        else:
            print(f"Skipping saving prediction file for {exp_id} as 'save_individual_predictions' is False.")

    print("\nAll experiments complete.")

# --- Script entry point: run all configured experiments ---
if __name__ == "__main__":
    main(base_cfg=config, experiments_to_run=experiments)

: 