# In [2]:
import os
import json
import re
import types # Used for SimpleNamespace if preferred
from pathlib import Path
# Removed: import argparse

import numpy as np
import torch
from sentence_transformers import SentenceTransformer, models, util
from tqdm import tqdm

config = {
    # --- Data Files ---
    'base_dir': '.', # Base directory that the relative paths below are resolved against.
    'query_list_file': 'test_queries.json', # JSON list of the query IDs to re-rank (relative to base_dir). REQUIRED.
    'pre_ranking_file': 'shuffled_pre_ranking.json', # JSON object mapping each query ID to its list of candidate document IDs (relative to base_dir).
    'queries_content_file': 'queries_content_with_features.json', # JSON list of query records carrying 'Content' (relative to base_dir).
    'documents_content_file': 'documents_content_with_features.json', # JSON list of document records carrying 'Content' (relative to base_dir).
    'output_file': 'prediction2.json', # Where the re-ranked {query_id: [doc_id, ...]} JSON is written (relative to base_dir).

    # --- Model and Text Settings ---
    'model_name': 'AI-Growth-Lab/PatentSBERTa', # Sentence Transformer model name.
    'pooling': 'mean', # Not currently applied: main() loads the model with its own pooling configuration. Choices: 'mean', 'max', 'cls'
    'text_type': 'TA', # Patent sections to embed. Choices: 'TA' (alias 'title_abstract'), 'claims', 'tac1', 'description', 'full', 'features'
    'max_length': 512, # Max sequence length for the model.

    # --- Execution Settings ---
    'batch_size': 32, # Batch size for encoding document texts.
    'device': None # Device: 'cuda', 'cuda:N', 'cpu', or None (auto-detect).
}

# --- Resolve the compute device ---
# None means auto-detect. Any CUDA request, including an indexed one such as
# 'cuda:1', falls back to CPU when CUDA is not available.
if config['device'] is None:
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
elif str(config['device']).startswith('cuda') and not torch.cuda.is_available():
    print("Warning: CUDA requested but not available. Using CPU.")
    config['device'] = 'cpu'

# ----------------------------
# Utility Functions
# ----------------------------

def load_json_file(file_path):
    """Read and parse a UTF-8 JSON file, reporting progress on stdout.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON value. Returns None instead of raising when the file
        does not exist, is not valid JSON, or any other error occurs. This
        includes the parsed value having no ``len()``, since the item count
        is printed.
    """
    print(f"Loading JSON from: {file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
        # Kept inside the try: len() on a non-container counts as a load failure.
        print(f"Successfully loaded {len(parsed)} items.")
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        parsed = None
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {file_path}")
        parsed = None
    except Exception as err:
        print(f"An unexpected error occurred loading {file_path}: {err}")
        parsed = None
    return parsed

def save_json_file(data, file_path):
    """Write ``data`` to ``file_path`` as UTF-8 JSON indented by 2 spaces.

    Missing parent directories are created first. Any error is printed
    rather than raised, so a failed save does not abort the caller.
    """
    print(f"Saving JSON to: {file_path}")
    try:
        parent_dir = os.path.dirname(file_path)
        # dirname() yields '' for a bare filename, and makedirs('') would fail.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as out_file:
            json.dump(data, out_file, indent=2)
        print(f"Successfully saved data to {file_path}")
    except Exception as err:
        print(f"An error occurred saving to {file_path}: {err}")

def load_content_data(file_path):
    """Load a content JSON file and index each record's ``Content`` by its key.

    Each record is keyed by its ``FAN`` field when present. Otherwise it is
    keyed by ``Application_Number``, with ``Application_Category`` appended
    as text when that field exists.

    Args:
        file_path: Path of a JSON file expected to hold a list of record dicts.

    Returns:
        dict mapping record key -> record['Content']. The dict is empty if
        the file could not be loaded. Records that are not dicts, have no
        usable key, or lack ``Content`` are skipped.
    """
    data = load_json_file(file_path)
    if data is None:
        return {}

    content_dict = {}
    for item in data:
        if not isinstance(item, dict):
            continue  # Malformed record: `key in item` would do substring tests on a str.

        if 'FAN' in item:
            fan_key = item['FAN']
        elif 'Application_Number' in item:
            fan_key = item['Application_Number']
            if 'Application_Category' in item:
                # Build the key as text. `+` raises TypeError for a numeric
                # number and would add the values if both fields were ints.
                fan_key = f"{fan_key}{item['Application_Category']}"
        else:
            fan_key = None

        if fan_key and 'Content' in item:
            content_dict[fan_key] = item['Content']

    print(f"Created content dictionary with {len(content_dict)} entries.")
    return content_dict


def extract_text(content_dict, text_type="TA"):
    """Build the text to embed for one patent from its content dict.

    Args:
        content_dict: Mapping of section keys to text, as stored under
            ``Content`` in the content JSON files. The keys read are
            ``title``, ``pa01`` (abstract), ``c-*`` (claims),
            ``p*`` other than ``pa*`` (description paragraphs), and ``features``.
        text_type: Which sections to include:
            - 'TA' / 'title_abstract': title + abstract
            - 'tac1': title + abstract + first claim
            - 'claims': all claims
            - 'description': all description paragraphs
            - 'full': title + abstract + all claims + description
            - 'features': the ``features`` entry
            Any other value selects nothing.

    Returns:
        The non-empty selected sections joined by single spaces. Returns ""
        when ``content_dict`` is not a dict or nothing is selected.
    """
    if not isinstance(content_dict, dict):
        return ""

    def sorted_section_keys(predicate):
        # Natural sort: 'c-2' precedes 'c-10' even without zero padding.
        # re.split with a capturing group alternates text / digit-run tokens,
        # so odd positions are always digit runs and compare as ints.
        def natural_key(key):
            return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'(\d+)', key))]
        return sorted((key for key in content_dict if predicate(key)), key=natural_key)

    text_parts = []

    if text_type in ("TA", "title_abstract", "tac1", "full"):
        text_parts.append(content_dict.get("title", ""))
        text_parts.append(content_dict.get("pa01", ""))  # Abstract

    if text_type in ("claims", "tac1", "full"):
        claims = [
            content_dict[key]
            for key in sorted_section_keys(lambda key: key.startswith('c-'))
            if content_dict[key]
        ]
        # tac1 keeps only the first non-empty claim in natural key order.
        text_parts.extend(claims[:1] if text_type == "tac1" else claims)

    if text_type in ("description", "full"):
        # Description paragraphs are 'p*' keys. 'pa*' keys are excluded because
        # 'pa01' is the abstract, which would otherwise be included (twice for
        # 'full'). NOTE(review): assumes no description key starts with 'pa'.
        desc_keys = sorted_section_keys(
            lambda key: key.startswith('p') and not key.startswith('pa')
        )
        text_parts.extend(content_dict.get(key, "") for key in desc_keys)

    if text_type == "features":
        # Extract LLM features if present
        text_parts.append(content_dict.get("features", ""))

    # Join non-empty parts with a space
    return " ".join(filter(None, text_parts)).strip()


# ----------------------------
# Main Re-ranking Logic
# ----------------------------

# Changed function signature to accept config dictionary
def _collect_doc_texts(candidate_doc_ids, documents_content, text_type):
    """Return (doc_ids, doc_texts) for candidates with non-empty extracted text.

    Both lists are parallel and follow ``candidate_doc_ids`` order. Candidates
    absent from ``documents_content``, or whose extracted text is empty, are
    skipped.
    """
    doc_ids = []
    doc_texts = []
    for doc_id in candidate_doc_ids:
        # extract_text returns "" for a missing (None) or empty content dict.
        doc_text = extract_text(documents_content.get(doc_id), text_type)
        if doc_text:
            doc_ids.append(doc_id)
            doc_texts.append(doc_text)
    return doc_ids, doc_texts


def _append_unscored(ranked_doc_ids, candidate_doc_ids):
    """Append candidates absent from ``ranked_doc_ids``, keeping pre-ranking order.

    Each unscored ID is appended once. The result is capped at the original
    candidate count.
    """
    ranked_set = set(ranked_doc_ids)
    # dict.fromkeys de-duplicates while keeping order. A set difference would
    # leave the tail in arbitrary, run-dependent order.
    unscored = list(dict.fromkeys(d for d in candidate_doc_ids if d not in ranked_set))
    return (ranked_doc_ids + unscored)[:len(candidate_doc_ids)]


def main(cfg):
    """Re-rank each query's pre-ranked candidates by embedding cosine similarity.

    For each query ID in the query list, the query text and every candidate
    document's text are extracted with ``cfg['text_type']``. They are encoded
    with the SentenceTransformer ``cfg['model_name']``, and candidates are
    sorted by descending cosine similarity to the query. Candidates without
    usable content follow the scored ones in their pre-ranking order.

    Args:
        cfg: Configuration dict. Keys read: base_dir, query_list_file,
            pre_ranking_file, queries_content_file, documents_content_file,
            output_file, model_name, text_type, max_length, batch_size, device.
            ``pooling`` is not read.

    Side effects:
        Prints progress and writes {query_id: [doc_id, ...]} to the output
        file. Returns early without writing if any input fails to load or
        the model cannot be loaded.
    """
    # --- Device Setup ---
    device = torch.device(cfg['device'])
    print(f"Using device: {device}")

    # --- Construct Full Paths ---
    def get_full_path(path):
        # Absolute paths are kept as-is; relative ones are resolved against base_dir.
        if os.path.isabs(path):
            return path
        return os.path.join(cfg['base_dir'], path)

    query_list_file = get_full_path(cfg['query_list_file'])
    pre_ranking_file = get_full_path(cfg['pre_ranking_file'])
    queries_content_file = get_full_path(cfg['queries_content_file'])
    documents_content_file = get_full_path(cfg['documents_content_file'])
    output_file = get_full_path(cfg['output_file'])

    # --- Load Data ---
    query_ids = load_json_file(query_list_file)
    pre_ranking_data = load_json_file(pre_ranking_file)
    queries_content = load_content_data(queries_content_file)
    documents_content = load_content_data(documents_content_file)

    if not query_ids or not pre_ranking_data or not queries_content or not documents_content:
        print("Error: Failed to load one or more essential data files. Exiting.")
        return

    # --- Load Model ---
    print(f"Loading SentenceTransformer model: {cfg['model_name']}")
    try:
        # NOTE: cfg['pooling'] is not applied; the model's own pooling
        # configuration is used. To force a pooling mode, assemble the model
        # from models.Transformer + models.Pooling instead.
        model = SentenceTransformer(cfg['model_name'], device=device)
        model.max_seq_length = cfg['max_length']
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model {cfg['model_name']}: {e}")
        return

    # --- Re-ranking Process ---
    text_type = cfg['text_type']
    print(f"Starting re-ranking for {len(query_ids)} queries using '{text_type}' content...")
    results = {}
    missing_query_content = 0
    missing_pre_ranking = 0
    queries_with_no_valid_docs = 0
    unscored_candidate_docs = 0  # Candidates skipped for missing/empty content, over all queries.

    for query_id in tqdm(query_ids, desc="Processing queries"):
        # 1. Query text. extract_text returns "" when content is missing or yields nothing.
        query_text = extract_text(queries_content.get(query_id), text_type)
        if not query_text:
            missing_query_content += 1
            results[query_id] = []
            continue

        # 2. Candidate documents from the pre-ranking
        candidate_doc_ids = pre_ranking_data.get(query_id)
        if not candidate_doc_ids:
            missing_pre_ranking += 1
            results[query_id] = []
            continue

        # 3. Candidate texts
        valid_doc_ids, doc_texts = _collect_doc_texts(candidate_doc_ids, documents_content, text_type)
        unscored_candidate_docs += len(candidate_doc_ids) - len(valid_doc_ids)
        if not valid_doc_ids:
            queries_with_no_valid_docs += 1
            results[query_id] = []
            continue

        # 4. Embeddings, computed on the fly for this query's candidates
        try:
            query_embedding = model.encode(
                query_text,
                convert_to_tensor=True,
                show_progress_bar=False,
                batch_size=1,
            )
            doc_embeddings = model.encode(
                doc_texts,
                convert_to_tensor=True,
                show_progress_bar=False,
                batch_size=cfg['batch_size'],
            )
        except Exception as e:
            print(f"Error during encoding for query {query_id}: {e}")
            results[query_id] = candidate_doc_ids  # Fall back to the pre-ranking order.
            continue

        # 5. Cosine similarity of the query against every candidate (row 0 of the score matrix)
        cosine_scores = util.cos_sim(query_embedding.to(device), doc_embeddings.to(device))[0]
        cosine_scores = cosine_scores.cpu().numpy()

        # 6. Descending score. sorted() is stable, so ties keep pre-ranking order.
        scored = sorted(zip(valid_doc_ids, cosine_scores), key=lambda pair: pair[1], reverse=True)
        results[query_id] = _append_unscored([doc_id for doc_id, _ in scored], candidate_doc_ids)

    # --- Report Missing Data ---
    print("\n--- Re-ranking Summary ---")
    print(f"Total queries processed: {len(query_ids)}")
    if missing_query_content > 0:
        print(f"Warning: Content missing or empty for {missing_query_content} queries.")
    if missing_pre_ranking > 0:
        print(f"Warning: Pre-ranking data missing for {missing_pre_ranking} queries.")
    if queries_with_no_valid_docs > 0:
        print(f"Warning: {queries_with_no_valid_docs} queries had no valid documents with content.")
    if unscored_candidate_docs > 0:
        print(f"Warning: {unscored_candidate_docs} candidate documents had missing or empty content and were not scored.")
    print(f"Number of queries in results: {len(results)}")

    # --- Save Results ---
    save_json_file(results, output_file)

    print("\nRe-ranking complete.")

# Run the pipeline when executed as a script or notebook cell (where __name__
# is "__main__"), but not when this module is imported.
if __name__ == "__main__":
    main(config)

# --- Captured output of main(config) ---
# Using device: cuda
# Loading JSON from: ./test_queries.json
# Successfully loaded 10 items.
# Loading JSON from: ./shuffled_pre_ranking.json
# Successfully loaded 30 items.
# Loading JSON from: ./queries_content_with_features.json
# Successfully loaded 30 items.
# Created content dictionary with 30 entries.
# Loading JSON from: ./documents_content_with_features.json
# Successfully loaded 900 items.
# Created content dictionary with 900 entries.
# Loading SentenceTransformer model: AI-Growth-Lab/PatentSBERTa
# Model loaded successfully.
# Starting re-ranking for 10 queries using 'TA' content...
#
# Processing queries: 100%|██████████| 10/10 [00:01<00:00,  5.41it/s]
#
# --- Re-ranking Summary ---
# Total queries processed: 10
# Number of queries in results: 10
# Saving JSON to: ./prediction2.json
# Successfully saved data to ./prediction2.json
#
# Re-ranking complete.
