# Final Implementation of the neural re-ranking approach for CheckThat Lab Subtask 4b (Scientific Claim Source Retrieval)

# 1) Importing data

In [1]:
import numpy as np
import pandas as pd
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl'
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv'
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv'
df_query_train = pd.read_csv(PATH_QUERY_TRAIN_DATA, sep = '\t')
df_query_dev = pd.read_csv(PATH_QUERY_DEV_DATA, sep = '\t')

# 2) Exploring optimal parameters for neural re-ranking

In [2]:
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from rank_bm25 import BM25Okapi
from tqdm import tqdm

# 1. Load data first
print("Loading data...")
# Update these paths as needed
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl'
PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv'
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv'

df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_query_train = pd.read_csv(PATH_QUERY_TRAIN_DATA, sep='\t')
df_query_dev = pd.read_csv(PATH_QUERY_DEV_DATA, sep='\t')

# 2. Create BM25 index
print("Creating BM25 index...")
corpus = df_collection[:][['title', 'abstract']].apply(lambda x: f"{x['title']} {x['abstract']}", axis=1).tolist()
cord_uids = df_collection[:]['cord_uid'].tolist()
tokenized_corpus = [doc.split(' ') for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)

# 3. BM25 retrieval function
def get_top_cord_uids_extended(query, k=20):
    """Get top-k candidates using BM25 with adjustable k"""
    tokenized_query = query.split(' ')
    doc_scores = bm25.get_scores(tokenized_query)
    indices = np.argsort(-doc_scores)[:k]
    bm25_topk = [cord_uids[x] for x in indices]
    bm25_scores = [doc_scores[x] for x in indices]

    return bm25_topk, bm25_scores

# 4. Evaluation function
def get_performance_mrr(data, col_gold, col_pred, list_k = [1, 5, 10]):
    d_performance = {}
    for k in list_k:
        data["in_topx"] = data.apply(lambda x: (1/([i for i in x[col_pred][:k]].index(x[col_gold]) + 1) if x[col_gold] in [i for i in x[col_pred][:k]] else 0), axis=1)
        d_performance[k] = data["in_topx"].mean()
    return d_performance

# 5. Define the optimized NeuralReranker
class OptimizedReranker:
    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.model = SentenceTransformer(model_name)
        self.corpus_embeddings = None
        self.corpus_texts = None
        self.paper_ids = None
        
    def index_collection(self, df_collection):
        # Create text representation for each document
        self.corpus_texts = df_collection[:][['title', 'abstract']].apply(
            lambda x: f"{x['title']} {x['abstract']}", axis=1).tolist()
        self.paper_ids = df_collection[:]['cord_uid'].tolist()
        
        # Calculate embeddings for all documents
        print("Calculating document embeddings...")
        self.corpus_embeddings = self.model.encode(
            self.corpus_texts, 
            convert_to_tensor=True,
            show_progress_bar=True
        )
        print(f"Created embeddings for {len(self.corpus_texts)} documents")
    
    def rerank_candidates(self, query, candidate_ids, candidate_scores=None, top_k=5, alpha=0.3):
        """Re-rank the candidate documents with adjustable alpha parameter"""
        # Get query embedding
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        
        # Get embeddings for candidate documents
        candidate_indices = [self.paper_ids.index(cid) for cid in candidate_ids]
        candidate_embeddings = self.corpus_embeddings[candidate_indices]
        
        # Calculate cosine similarity
        cos_scores = util.cos_sim(query_embedding, candidate_embeddings)[0]
        
        # Combine scores with adjustable alpha
        if candidate_scores is not None:
            bm25_scores = torch.tensor(candidate_scores)
            bm25_scores = bm25_scores / bm25_scores.max()
            
            # Alpha controls weight between BM25 and neural scores
            combined_scores = alpha * bm25_scores + (1-alpha) * cos_scores
        else:
            combined_scores = cos_scores
            
        # Sort by score
        top_results = torch.argsort(-combined_scores)[:top_k].tolist()
        
        # Return re-ranked document IDs
        return [candidate_ids[i] for i in top_results]

# 6. Alpha parameter optimization function
def evaluate_alpha_values(reranker, df_queries, alphas=[0.2, 0.3, 0.4, 0.5], top_k=5):
    """Test different alpha values for combining BM25 and neural scores"""
    results = {}
    
    for alpha in alphas:
        print(f"Testing alpha={alpha}...")
        alpha_results = []
        
        for _, row in tqdm(df_queries.iterrows(), total=len(df_queries)):
            query = row['tweet_text']
            
            # First-stage: Get BM25 candidates
            bm25_candidates, bm25_scores = get_top_cord_uids_extended(query)
            
            # Second-stage: Neural re-ranking with current alpha
            reranked_candidates = reranker.rerank_candidates(
                query, 
                bm25_candidates, 
                bm25_scores, 
                top_k=top_k,
                alpha=alpha
            )
            
            alpha_results.append({
                'post_id': row['post_id'],
                'tweet_text': query,
                'cord_uid': row['cord_uid'],
                'reranked': reranked_candidates
            })
        
        # Evaluate MRR scores
        alpha_df = pd.DataFrame(alpha_results)
        mrr_scores = get_performance_mrr(alpha_df, 'cord_uid', 'reranked')
        results[alpha] = mrr_scores
        print(f"Alpha={alpha} results: {mrr_scores}")
    
    return results

# 7. BM25 candidate pool size optimization function
def evaluate_candidate_pool_sizes(reranker, df_queries, pool_sizes=[20, 30, 50, 100], alpha=0.3, top_k=5):
    """Test different candidate pool sizes for BM25 retrieval"""
    results = {}
    
    for pool_size in pool_sizes:
        print(f"\nTesting candidate pool size={pool_size}...")
        size_results = []
        
        for _, row in tqdm(df_queries.iterrows(), total=len(df_queries)):
            query = row['tweet_text']
            
            # First-stage: Get BM25 candidates with current pool size
            bm25_candidates, bm25_scores = get_top_cord_uids_extended(query, k=pool_size)
            
            # Second-stage: Neural re-ranking
            reranked_candidates = reranker.rerank_candidates(
                query, 
                bm25_candidates, 
                bm25_scores, 
                top_k=top_k,
                alpha=alpha
            )
            
            size_results.append({
                'post_id': row['post_id'],
                'tweet_text': query,
                'cord_uid': row['cord_uid'],
                'reranked': reranked_candidates
            })
        
        # Evaluate MRR scores
        size_df = pd.DataFrame(size_results)
        mrr_scores = get_performance_mrr(size_df, 'cord_uid', 'reranked')
        results[pool_size] = mrr_scores
        print(f"Pool size={pool_size} results: {mrr_scores}")
    
    return results

# 8. Complete optimization pipeline
def run_optimization():
    # Initialize and index with optimized reranker
    print("Initializing optimized reranker...")
    optimized_reranker = OptimizedReranker('all-MiniLM-L6-v2')
    optimized_reranker.index_collection(df_collection)
    
    # Use a subset of dev data for faster testing during optimization
    dev_subset = df_query_dev.sample(min(200, len(df_query_dev)), random_state=42)
    
    # 1. Find optimal alpha
    print("\n=== Alpha Parameter Optimization ===")
    alpha_results = evaluate_alpha_values(
        optimized_reranker,
        dev_subset,
        alphas=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    )
    best_alpha = max(alpha_results.items(), key=lambda x: x[1][5])[0]
    print(f"\nBest alpha: {best_alpha} with MRR@5: {alpha_results[best_alpha][5]}")
    
    # 2. Find optimal pool size using best alpha
    print("\n=== Pool Size Optimization ===")
    pool_results = evaluate_candidate_pool_sizes(
        optimized_reranker,
        dev_subset,
        pool_sizes=[20, 30, 50, 75, 100],
        alpha=best_alpha
    )
    best_pool_size = max(pool_results.items(), key=lambda x: x[1][5])[0]
    print(f"\nBest pool size: {best_pool_size} with MRR@5: {pool_results[best_pool_size][5]}")
    
    # 3. Generate final predictions with optimal parameters
    print("\n=== Generating Final Predictions ===")
    final_results = []
    
    print(f"Using alpha={best_alpha}, pool_size={best_pool_size}")
    for _, row in tqdm(df_query_dev.iterrows(), total=len(df_query_dev)):
        query = row['tweet_text']
        
        # First-stage: Get BM25 candidates with best pool size
        bm25_candidates, bm25_scores = get_top_cord_uids_extended(query, k=best_pool_size)
        
        # Second-stage: Neural re-ranking with best alpha
        reranked_candidates = optimized_reranker.rerank_candidates(
            query, 
            bm25_candidates, 
            bm25_scores, 
            top_k=5,
            alpha=best_alpha
        )
        
        final_results.append({
            'post_id': row['post_id'],
            'preds': reranked_candidates
        })
    
    # Save final predictions
    final_df = pd.DataFrame(final_results)
    final_df.to_csv('optimized_predictions.tsv', index=None, sep='\t')
    print("Final predictions saved to 'optimized_predictions.tsv'")
    
    # Evaluate final predictions on the full dev set
    final_eval_data = []
    for _, row in df_query_dev.iterrows():
        post_id = row['post_id']
        gold_uid = row['cord_uid']
        pred_row = final_df[final_df['post_id'] == post_id].iloc[0]
        preds = pred_row['preds']
        
        final_eval_data.append({
            'post_id': post_id,
            'cord_uid': gold_uid,
            'reranked': preds
        })
    
    # Calculate MRR scores for final results
    final_eval_df = pd.DataFrame(final_eval_data)
    final_mrr_scores = get_performance_mrr(final_eval_df, 'cord_uid', 'reranked', list_k=[1, 5, 10])
    print(f"Final prediction results: {final_mrr_scores}")
    
    return best_alpha, best_pool_size, final_df, final_mrr_scores

# Run the optimization
if __name__ == "__main__":
    best_alpha, best_pool_size, predictions, final_results = run_optimization()
    print(f"\nOptimization complete! Best parameters: alpha={best_alpha}, pool_size={best_pool_size}")

  from tqdm.autonotebook import tqdm, trange


Loading data...
Creating BM25 index...
Initializing optimized reranker...




Calculating document embeddings...


Batches: 100%|██████████| 242/242 [11:14<00:00,  2.79s/it]


Created embeddings for 7718 documents

=== Alpha Parameter Optimization ===
Testing alpha=0.0...


100%|██████████| 200/200 [00:20<00:00,  9.74it/s]


Alpha=0.0 results: {1: 0.435, 5: 0.5126666666666666, 10: 0.5126666666666666}
Testing alpha=0.1...


100%|██████████| 200/200 [00:17<00:00, 11.35it/s]


Alpha=0.1 results: {1: 0.515, 5: 0.5676666666666667, 10: 0.5676666666666667}
Testing alpha=0.2...


100%|██████████| 200/200 [00:18<00:00, 10.96it/s]


Alpha=0.2 results: {1: 0.54, 5: 0.5834999999999999, 10: 0.5834999999999999}
Testing alpha=0.3...


100%|██████████| 200/200 [00:17<00:00, 11.41it/s]


Alpha=0.3 results: {1: 0.56, 5: 0.5989166666666667, 10: 0.5989166666666667}
Testing alpha=0.4...


100%|██████████| 200/200 [00:19<00:00, 10.16it/s]


Alpha=0.4 results: {1: 0.545, 5: 0.5943333333333334, 10: 0.5943333333333334}
Testing alpha=0.5...


100%|██████████| 200/200 [00:20<00:00,  9.87it/s]


Alpha=0.5 results: {1: 0.555, 5: 0.5989166666666667, 10: 0.5989166666666667}
Testing alpha=0.6...


100%|██████████| 200/200 [00:20<00:00,  9.69it/s]


Alpha=0.6 results: {1: 0.55, 5: 0.5920833333333333, 10: 0.5920833333333333}

Best alpha: 0.3 with MRR@5: 0.5989166666666667

=== Pool Size Optimization ===

Testing candidate pool size=20...


100%|██████████| 200/200 [00:18<00:00, 10.91it/s]


Pool size=20 results: {1: 0.56, 5: 0.5989166666666667, 10: 0.5989166666666667}

Testing candidate pool size=30...


100%|██████████| 200/200 [00:19<00:00, 10.24it/s]


Pool size=30 results: {1: 0.56, 5: 0.6011666666666666, 10: 0.6011666666666666}

Testing candidate pool size=50...


100%|██████████| 200/200 [00:18<00:00, 11.01it/s]


Pool size=50 results: {1: 0.575, 5: 0.6186666666666667, 10: 0.6186666666666667}

Testing candidate pool size=75...


100%|██████████| 200/200 [00:19<00:00, 10.32it/s]


Pool size=75 results: {1: 0.575, 5: 0.6236666666666667, 10: 0.6236666666666667}

Testing candidate pool size=100...


100%|██████████| 200/200 [00:21<00:00,  9.42it/s]


Pool size=100 results: {1: 0.57, 5: 0.6201666666666666, 10: 0.6201666666666666}

Best pool size: 75 with MRR@5: 0.6236666666666667

=== Generating Final Predictions ===
Using alpha=0.3, pool_size=75


100%|██████████| 1400/1400 [02:32<00:00,  9.19it/s]


Final predictions saved to 'optimized_predictions.tsv'
Final prediction results: {1: 0.5814285714285714, 5: 0.6234047619047619, 10: 0.6234047619047619}

Optimization complete! Best parameters: alpha=0.3, pool_size=75


# 3) Executing neural re-ranking for development phase

In [3]:
import numpy as np
import pandas as pd
import torch
import re
import string
from datetime import datetime
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

# Try to install required packages
try:
    import pip
    print("Installing required packages...")
    import sys
    from subprocess import call
    call([sys.executable, "-m", "pip", "install", "nltk", "scikit-learn"])
    print("Packages installed")
except:
    print("Could not install packages automatically. Some features may be limited.")

# Try imports with fallbacks
try:
    from sklearn.feature_extraction.text import CountVectorizer
except ImportError:
    CountVectorizer = None
    print("scikit-learn not available. Some features will be limited.")

try:
    import nltk
    from nltk.tokenize import word_tokenize
    # Attempt to download required NLTK resources
    try:
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)
        from nltk.corpus import stopwords
        from nltk.stem import PorterStemmer
        NLTK_AVAILABLE = True
        print("NLTK is available with all required resources")
    except:
        NLTK_AVAILABLE = False
        print("NLTK available but couldn't download resources")
except ImportError:
    NLTK_AVAILABLE = False
    print("NLTK not available. Will use basic text processing.")

# Try to import spaCy with fallback
try:
    import spacy
    try:
        nlp = spacy.load("en_core_sci_md")  # Scientific/biomedical model
        SPACY_AVAILABLE = True
        SCIENTIFIC_MODEL = True
        print("Loaded scientific spaCy model")
    except:
        try:
            nlp = spacy.load("en_core_web_sm")  # Fallback to default model
            SPACY_AVAILABLE = True
            SCIENTIFIC_MODEL = False
            print("Loaded default spaCy model")
        except:
            SPACY_AVAILABLE = False
            nlp = None
            print("spaCy models not available. Entity extraction will be limited.")
except ImportError:
    SPACY_AVAILABLE = False
    nlp = None
    print("spaCy not available. Entity extraction will be limited.")

# Dictionary of common scientific abbreviations
SCI_ABBREVIATIONS = {
    'covid': 'coronavirus disease',
    'sars': 'severe acute respiratory syndrome',
    'icu': 'intensive care unit',
    'pcr': 'polymerase chain reaction',
    'mrna': 'messenger rna',
    'ace2': 'angiotensin-converting enzyme 2',
    'rt-pcr': 'reverse transcription polymerase chain reaction',
    'r0': 'basic reproduction number',
    'rct': 'randomized controlled trial',
    # Add more relevant scientific abbreviations as needed
}

# Basic stopwords list if NLTK is not available
BASIC_STOPWORDS = set([
    'a', 'an', 'the', 'and', 'or', 'but', 'if', 'because', 'as', 'what',
    'which', 'this', 'that', 'these', 'those', 'then', 'just', 'so', 'than',
    'such', 'both', 'through', 'about', 'for', 'is', 'of', 'while', 'during',
    'to', 'from', 'in', 'on', 'by', 'at', 'with'
])

# Scientific stopwords to remove (words common in papers but not informative)
SCIENTIFIC_STOPWORDS = set([
    'study', 'studies', 'research', 'researchers', 'paper', 'figure', 'table',
    'et', 'al', 'doi', 'pmid', 'journal', 'abstract', 'conclusion', 'results',
    'methods', 'published', 'authors'
])

# Fallback tokenizer if NLTK is not available
def simple_tokenize(text):
    """Simple tokenization function as fallback"""
    tokens = re.findall(r'\b\w+\b', text.lower())
    return tokens

# Get appropriate tokenize function
if NLTK_AVAILABLE:
    tokenize_func = word_tokenize
else:
    tokenize_func = simple_tokenize

# Get appropriate stopwords
if NLTK_AVAILABLE:
    try:
        STOPWORDS = set(stopwords.words('english')).union(SCIENTIFIC_STOPWORDS)
    except:
        STOPWORDS = BASIC_STOPWORDS.union(SCIENTIFIC_STOPWORDS)
else:
    STOPWORDS = BASIC_STOPWORDS.union(SCIENTIFIC_STOPWORDS)

# Preprocess text with enhanced scientific processing
def preprocess_scientific_text(text, expand_abbreviations=True, remove_punct=True):
    """Enhanced scientific text preprocessing with fallbacks for missing libraries"""
    if not isinstance(text, str):
        return ""
    
    # Convert to lowercase
    text = text.lower()
    
    # Expand common scientific abbreviations
    if expand_abbreviations:
        for abbr, expansion in SCI_ABBREVIATIONS.items():
            # Match whole word only (with word boundaries)
            text = re.sub(r'\b' + abbr + r'\b', expansion, text)
    
    # Remove URLs
    text = re.sub(r'http\S+', '', text)
    
    # Handle hashtags - keep the text without # symbol
    text = re.sub(r'#(\w+)', r'\1', text)
    
    # Handle mentions - remove them
    text = re.sub(r'@\w+', '', text)
    
    # Remove punctuation
    if remove_punct:
        text = text.translate(str.maketrans('', '', string.punctuation))
    
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text

# Extract entities using spaCy, with fallback
def extract_entities(text):
    """Extract scientific entities from text using spaCy if available, otherwise use simple approach"""
    if SPACY_AVAILABLE and nlp is not None and text:
        doc = nlp(text)
        entities = []
        
        # Extract named entities
        for ent in doc.ents:
            entities.append(ent.text.lower())
        
        # For scientific text, also extract noun chunks as they often represent concepts
        for chunk in doc.noun_chunks:
            entities.append(chunk.text.lower())
        
        return list(set(entities))  # Remove duplicates
    else:
        # Fallback: just use n-grams as proxy for entities
        words = preprocess_scientific_text(text).split()
        # Create bigrams and trigrams as simple entity proxies
        entities = words.copy()  # Start with unigrams
        
        # Add bigrams if we have enough words
        if len(words) >= 2:
            for i in range(len(words) - 1):
                entities.append(f"{words[i]} {words[i+1]}")
        
        # Add trigrams if we have enough words
        if len(words) >= 3:
            for i in range(len(words) - 2):
                entities.append(f"{words[i]} {words[i+1]} {words[i+2]}")
                
        return entities

# Create enhanced document representation with weighted fields and metadata
def create_enhanced_document_representation(row, title_weight=2.0, use_metadata=True):
    """
    Create enhanced document representation with weighted title and metadata
    Args:
        row: DataFrame row with paper data
        title_weight: Weight to apply to title (higher means more importance)
        use_metadata: Whether to include authors and journal info
    Returns:
        Enhanced document representation string
    """
    # Basic representation with weighted title (repeat title for higher weight)
    components = []
    
    # Add weighted title
    if 'title' in row and pd.notna(row['title']):
        components.extend([row['title']] * int(title_weight))
    
    # Add abstract
    if 'abstract' in row and pd.notna(row['abstract']):
        components.append(row['abstract'])
    
    # Add metadata if requested
    if use_metadata:
        # Add authors (if available)
        if 'authors' in row and pd.notna(row['authors']):
            components.append(row['authors'])
        
        # Add journal (if available)
        if 'journal' in row and pd.notna(row['journal']):
            components.append(row['journal'])
    
    # Join all components
    text = " ".join(components)
    
    # Preprocess
    return preprocess_scientific_text(text)

# Class for advanced scientific paper retrieval
class EnhancedScientificRetriever:
    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.model = SentenceTransformer(model_name)
        self.corpus_embeddings = None
        self.corpus_texts = None
        self.paper_ids = None
        self.metadata = None
        
    def index_collection(self, df_collection, title_weight=2.0, use_metadata=True):
        """
        Index the collection with enhanced document representation
        Args:
            df_collection: DataFrame with papers
            title_weight: Weight for the title
            use_metadata: Whether to include metadata
        """
        # Store paper metadata for later use in features
        self.metadata = {}
        for _, row in df_collection.iterrows():
            paper_id = row['cord_uid']
            self.metadata[paper_id] = {
                'time': row['time'] if pd.notna(row['time']) else None,
                'journal': row['journal'] if pd.notna(row['journal']) else "",
                'authors': row['authors'] if pd.notna(row['authors']) else "",
                'title': row['title'] if pd.notna(row['title']) else "",
                'abstract': row['abstract'] if pd.notna(row['abstract']) else ""
            }
        
        # Create enhanced text representation for each document
        self.corpus_texts = []
        self.paper_ids = []
        
        print("Creating enhanced document representations...")
        for _, row in df_collection.iterrows():
            enhanced_text = create_enhanced_document_representation(
                row, title_weight=title_weight, use_metadata=use_metadata
            )
            self.corpus_texts.append(enhanced_text)
            self.paper_ids.append(row['cord_uid'])
        
        # Calculate embeddings for all documents
        print("Calculating document embeddings...")
        self.corpus_embeddings = self.model.encode(
            self.corpus_texts, 
            convert_to_tensor=True,
            show_progress_bar=True
        )
        print(f"Created embeddings for {len(self.corpus_texts)} documents")
        
        # Extract and store entities from documents for later matching
        print("Extracting entities from documents...")
        self.doc_entities = {}
        for i, doc_id in enumerate(tqdm(self.paper_ids)):
            text = self.corpus_texts[i]
            self.doc_entities[doc_id] = extract_entities(text)
    
    def preprocess_tweet(self, tweet_text):
        """Enhanced tweet preprocessing"""
        # Apply scientific text preprocessing
        processed_tweet = preprocess_scientific_text(tweet_text)
        return processed_tweet
    
    def calculate_recency_score(self, paper_id, tweet_date=None):
        """
        Calculate recency score to favor more recent papers
        Args:
            paper_id: The paper ID
            tweet_date: Date of the tweet (if known)
        Returns:
            Recency score between 0 and 1 (higher for more recent papers)
        """
        if paper_id not in self.metadata or self.metadata[paper_id]['time'] is None:
            return 0.5  # Default score if no date
        
        paper_date = self.metadata[paper_id]['time']
        
        # If no tweet date, use a recent date
        if tweet_date is None:
            tweet_date = datetime(2020, 12, 31)  # End of 2020 as reference
        
        # Convert to timestamp if needed
        if isinstance(paper_date, str):
            try:
                paper_date = datetime.strptime(paper_date, '%Y-%m-%d')
            except:
                return 0.5
        
        # Calculate days difference
        days_diff = abs((tweet_date - paper_date).days)
        
        # Convert to score (1 for very recent, approaching 0 for very old)
        # Using sigmoid function for smooth transition
        recency_score = 1 / (1 + np.exp(days_diff / 365))  # 365 controls steepness
        
        return recency_score
    
    def calculate_relevance_bonus(self, query_entities, paper_id):
        """
        Calculate bonus based on entity overlap between query and document
        Args:
            query_entities: List of entities in the query
            paper_id: The paper ID
        Returns:
            Relevance bonus score between 0 and 1
        """
        if not query_entities or paper_id not in self.doc_entities:
            return 0
            
        doc_entities = self.doc_entities[paper_id]
        
        # Count overlapping entities
        overlap = sum(1 for e in query_entities if any(e in d for d in doc_entities))
        
        # Calculate score based on overlap ratio
        if len(query_entities) > 0:
            return min(1.0, overlap / len(query_entities))
        return 0
    
    def calculate_journal_score(self, paper_id):
        """
        Calculate journal importance score (proxy for citation impact)
        Args:
            paper_id: The paper ID
        Returns:
            Journal score between 0 and 1
        """
        if paper_id not in self.metadata:
            return 0.5
            
        journal = self.metadata[paper_id]['journal']
        
        # List of high-impact journals (customize based on your domain)
        high_impact_journals = {
            'nature': 1.0,
            'science': 1.0,
            'lancet': 0.9,
            'nejm': 0.9,
            'new england journal': 0.9,
            'jama': 0.8,
            'bmj': 0.8,
            'cell': 0.8,
            'pnas': 0.7,
            'plos': 0.6
        }
        
        # Check for journal match
        if journal:
            journal_lower = journal.lower()
            for j_name, score in high_impact_journals.items():
                if j_name in journal_lower:
                    return score
        
        # Default score
        return 0.5
    
    def rerank_candidates(self, query, candidate_ids, bm25_scores=None, top_k=5, 
                          alpha=0.3, use_context_features=True, context_weight=0.2):
        """
        Re-rank candidate documents with enhanced features
        Args:
            query: The tweet text
            candidate_ids: List of candidate paper IDs
            bm25_scores: BM25 scores for candidates
            top_k: Number of top results to return
            alpha: Weight between BM25 and neural scores
            use_context_features: Whether to use context-aware features
            context_weight: Weight for context features
        Returns:
            List of reranked document IDs
        """
        # Preprocess the query
        processed_query = self.preprocess_tweet(query)
        
        # Get query embedding
        query_embedding = self.model.encode(processed_query, convert_to_tensor=True)
        
        # Get embeddings for candidate documents
        candidate_indices = [self.paper_ids.index(cid) for cid in candidate_ids]
        candidate_embeddings = self.corpus_embeddings[candidate_indices]
        
        # Calculate cosine similarity
        cos_scores = util.cos_sim(query_embedding, candidate_embeddings)[0]
        
        # Combine BM25 and neural scores
        if bm25_scores is not None:
            bm25_tensor = torch.tensor(bm25_scores)
            bm25_norm = bm25_tensor / bm25_tensor.max() if torch.max(bm25_tensor) > 0 else bm25_tensor
            combined_scores = alpha * bm25_norm + (1-alpha) * cos_scores
        else:
            combined_scores = cos_scores
        
        # Apply context-aware features if requested
        if use_context_features:
            # Extract entities from query
            query_entities = extract_entities(processed_query)
            
            # Calculate context scores for each candidate
            context_scores = []
            for i, paper_id in enumerate(candidate_ids):
                # Recency score
                recency_score = self.calculate_recency_score(paper_id)
                
                # Entity relevance score
                relevance_score = self.calculate_relevance_bonus(query_entities, paper_id)
                
                # Journal importance score
                journal_score = self.calculate_journal_score(paper_id)
                
                # Combine context scores (adjust weights as needed)
                context_score = 0.4 * recency_score + 0.4 * relevance_score + 0.2 * journal_score
                context_scores.append(context_score)
            
            # Convert to tensor and combine with other scores
            context_tensor = torch.tensor(context_scores)
            final_scores = (1 - context_weight) * combined_scores + context_weight * context_tensor
        else:
            final_scores = combined_scores
        
        # Sort by score
        top_results = torch.argsort(-final_scores)[:top_k].tolist()
        
        # Return re-ranked document IDs
        return [candidate_ids[i] for i in top_results]

# Create enhanced BM25 index with improved preprocessing
def create_enhanced_bm25_index(df_collection, title_weight=2.0):
    """
    Create an enhanced BM25 index with better document representation
    Args:
        df_collection: DataFrame with papers
        title_weight: Weight for title
    Returns:
        BM25 index and paper IDs
    """
    print("Creating enhanced BM25 index...")
    
    # Create enhanced document representations
    corpus = []
    cord_uids = []
    
    for _, row in df_collection.iterrows():
        # Create enhanced text
        enhanced_text = create_enhanced_document_representation(
            row, title_weight=title_weight, use_metadata=True
        )
        corpus.append(enhanced_text)
        cord_uids.append(row['cord_uid'])
    
    # Tokenize
    tokenized_corpus = [doc.split() for doc in corpus]
    
    # Create BM25 index
    bm25 = BM25Okapi(tokenized_corpus)
    
    return bm25, cord_uids

# Enhanced BM25 retrieval with improved query processing
def get_enhanced_top_cord_uids(query, bm25, cord_uids, k=75):
    """
    Get top-k candidates using enhanced BM25
    Args:
        query: Tweet text
        bm25: BM25 index
        cord_uids: List of paper IDs
        k: Number of results to return
    Returns:
        List of paper IDs and scores
    """
    # Preprocess query
    processed_query = preprocess_scientific_text(query)
    tokenized_query = processed_query.split()
    
    # Get scores
    doc_scores = bm25.get_scores(tokenized_query)
    
    # Sort and return top k
    indices = np.argsort(-doc_scores)[:k]
    bm25_topk = [cord_uids[x] for x in indices]
    bm25_scores = [doc_scores[x] for x in indices]
    
    return bm25_topk, bm25_scores

# Implementation of the whole optimization pipeline
def run_enhanced_retrieval_experiment(df_collection, df_train, df_test, 
                                     title_weight=2.0, alpha=0.3, pool_size=75,
                                     use_context=True, context_weight=0.2,
                                     top_k=5, model_name='all-MiniLM-L6-v2'):
    """
    Run a complete retrieval experiment with all enhancements
    Args:
        df_collection: DataFrame with papers
        df_train: Training queries DataFrame
        df_test: Testing queries DataFrame
        title_weight: Weight for title in document representation
        alpha: Weight between BM25 and neural scores
        pool_size: Size of candidate pool from BM25
        use_context: Whether to use context-aware features
        context_weight: Weight for context features
        top_k: Number of results to return
        model_name: Name of the sentence transformer model to use
    """
    print("=== Enhanced Scientific Retrieval Experiment ===")
    
    # 1. Create enhanced BM25 index
    enhanced_bm25, enhanced_cord_uids = create_enhanced_bm25_index(df_collection, title_weight)
    
    # 2. Initialize and index with enhanced retriever
    # Check if the model is available - fallback to a simpler one if not
    try:
        enhanced_retriever = EnhancedScientificRetriever(model_name)
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        print("Falling back to default model")
        enhanced_retriever = EnhancedScientificRetriever('all-MiniLM-L6-v2')
    
    enhanced_retriever.index_collection(df_collection, title_weight)
    
    # 3. Generate predictions
    print("\n=== Generating Enhanced Predictions ===")
    print(f"Using alpha={alpha}, pool_size={pool_size}, context_weight={context_weight}")
    
    enhanced_results = []
    
    for _, row in tqdm(df_test.iterrows(), total=len(df_test)):
        query = row['tweet_text']
        
        # First-stage: Get enhanced BM25 candidates
        bm25_candidates, bm25_scores = get_enhanced_top_cord_uids(
            query, enhanced_bm25, enhanced_cord_uids, k=pool_size
        )
        
        # Second-stage: Enhanced neural re-ranking
        reranked_candidates = enhanced_retriever.rerank_candidates(
            query, 
            bm25_candidates, 
            bm25_scores, 
            top_k=top_k,
            alpha=alpha,
            use_context_features=use_context,
            context_weight=context_weight
        )
        
        enhanced_results.append({
            'post_id': row['post_id'],
            'preds': reranked_candidates
        })
    
    # Save enhanced predictions
    enhanced_df = pd.DataFrame(enhanced_results)
    enhanced_df.to_csv('enhanced_predictions.tsv', index=None, sep='\t')
    print("Enhanced predictions saved to 'enhanced_predictions.tsv'")
    
    # Evaluate enhanced predictions
    enhanced_eval_data = []
    for _, row in df_test.iterrows():
        post_id = row['post_id']
        gold_uid = row['cord_uid']
        pred_row = enhanced_df[enhanced_df['post_id'] == post_id].iloc[0]
        preds = pred_row['preds']
        
        enhanced_eval_data.append({
            'post_id': post_id,
            'cord_uid': gold_uid,
            'reranked': preds
        })
    
    # Calculate MRR scores using the same function as in original code
    def get_performance_mrr(data, col_gold, col_pred, list_k = [1, 5, 10]):
        d_performance = {}
        for k in list_k:
            data["in_topx"] = data.apply(lambda x: (1/([i for i in x[col_pred][:k]].index(x[col_gold]) + 1) if x[col_gold] in [i for i in x[col_pred][:k]] else 0), axis=1)
            d_performance[k] = data["in_topx"].mean()
        return d_performance
    
    enhanced_eval_df = pd.DataFrame(enhanced_eval_data)
    enhanced_mrr_scores = get_performance_mrr(enhanced_eval_df, 'cord_uid', 'reranked', list_k=[1, 5, 10])
    print(f"Enhanced prediction results: {enhanced_mrr_scores}")
    
    return enhanced_df, enhanced_mrr_scores

# Add a simple main method with safe defaults - can be called directly
try:
    # Run the enhanced experiment with safer default choices
    if __name__ == "__main__":
        print("Running enhanced scientific paper retrieval experiment...")
        enhanced_predictions, enhanced_results = run_enhanced_retrieval_experiment(
            df_collection, 
            df_query_train, 
            df_query_dev,
            title_weight=3.0,      # Higher weight for title
            alpha=0.3,             # Value from previous optimization
            pool_size=75,          # Value from previous optimization
            use_context=True,      # Enable context-aware features
            context_weight=0.2,    # Weight for context features
            top_k=5,               # Return top 5 results
            model_name='all-MiniLM-L6-v2'  # Stick with the original model for safety
        )
        print(f"\nEnhanced retrieval experiment complete!")

except Exception as e:
    print(f"Error running experiment: {e}")
    print("Please check error messages above and modify the code as needed.")

Installing required packages...
Packages installed
NLTK is available with all required resources
spaCy models not available. Entity extraction will be limited.
Running enhanced scientific paper retrieval experiment...
=== Enhanced Scientific Retrieval Experiment ===
Creating enhanced BM25 index...




Creating enhanced document representations...
Calculating document embeddings...


Batches: 100%|██████████| 242/242 [11:57<00:00,  2.96s/it]


Created embeddings for 7718 documents
Extracting entities from documents...


100%|██████████| 7718/7718 [00:06<00:00, 1223.46it/s]



=== Generating Enhanced Predictions ===
Using alpha=0.3, pool_size=75, context_weight=0.2


100%|██████████| 1400/1400 [11:05<00:00,  2.10it/s]


Enhanced predictions saved to 'enhanced_predictions.tsv'
Enhanced prediction results: {1: 0.6035714285714285, 5: 0.649452380952381, 10: 0.649452380952381}

Enhanced retrieval experiment complete!


# 4) Executing neural re-ranking for test phase

In [4]:
import pandas as pd
from tqdm import tqdm

# Load the test data
PATH_QUERY_TEST_DATA = 'subtask4b_query_tweets_test.tsv'
df_query_test = pd.read_csv(PATH_QUERY_TEST_DATA, sep='\t')

# This function will process the test set using the already trained model
def process_test_set(df_collection, df_test, 
                     enhanced_retriever=None, enhanced_bm25=None, enhanced_cord_uids=None,
                     alpha=0.3, pool_size=75, context_weight=0.2, top_k=5):
    """
    Generate predictions for the test set using the enhanced retrieval model
    """
    print("=== Generating Test Set Predictions ===")
    print(f"Using alpha={alpha}, pool_size={pool_size}, context_weight={context_weight}")
    
    # If the models weren't passed in, we need to recreate them
    if enhanced_retriever is None or enhanced_bm25 is None or enhanced_cord_uids is None:
        print("Creating enhanced BM25 index...")
        enhanced_bm25, enhanced_cord_uids = create_enhanced_bm25_index(df_collection, title_weight=3.0)
        
        print("Initializing enhanced retriever...")
        enhanced_retriever = EnhancedScientificRetriever('all-MiniLM-L6-v2')
        enhanced_retriever.index_collection(df_collection, title_weight=3.0)
    
    # Generate predictions
    test_results = []
    
    for _, row in tqdm(df_test.iterrows(), total=len(df_test)):
        query = row['tweet_text']
        post_id = row['post_id']
        
        # First-stage: Get enhanced BM25 candidates
        bm25_candidates, bm25_scores = get_enhanced_top_cord_uids(
            query, enhanced_bm25, enhanced_cord_uids, k=pool_size
        )
        
        # Second-stage: Enhanced neural re-ranking
        reranked_candidates = enhanced_retriever.rerank_candidates(
            query, 
            bm25_candidates, 
            bm25_scores, 
            top_k=top_k,
            alpha=alpha,
            use_context_features=True,
            context_weight=context_weight
        )
        
        test_results.append({
            'post_id': post_id,
            'preds': reranked_candidates
        })
    
    # Save predictions
    test_df = pd.DataFrame(test_results)
    test_df.to_csv('predictions.tsv', index=None, sep='\t')
    print("Test predictions saved to 'predictions.tsv'")
    
    return test_df

# Process the test set - you can use the existing models if they're in memory,
# otherwise they will be recreated
test_predictions = process_test_set(
    df_collection, 
    df_query_test,
    enhanced_retriever=enhanced_retriever if 'enhanced_retriever' in locals() else None,
    enhanced_bm25=enhanced_bm25 if 'enhanced_bm25' in locals() else None,
    enhanced_cord_uids=enhanced_cord_uids if 'enhanced_cord_uids' in locals() else None,
    alpha=0.3,            # Best value from optimization
    pool_size=75,         # Best value from optimization
    context_weight=0.2,   # Value used in enhanced experiment
    top_k=5               # Return top 5 results
)

print("Test set processing complete!")

=== Generating Test Set Predictions ===
Using alpha=0.3, pool_size=75, context_weight=0.2
Creating enhanced BM25 index...
Creating enhanced BM25 index...
Initializing enhanced retriever...




Creating enhanced document representations...
Calculating document embeddings...


Batches: 100%|██████████| 242/242 [12:26<00:00,  3.09s/it]


Created embeddings for 7718 documents
Extracting entities from documents...


100%|██████████| 7718/7718 [00:06<00:00, 1225.99it/s]
100%|██████████| 1446/1446 [09:46<00:00,  2.46it/s]


Test predictions saved to 'predictions.tsv'
Test set processing complete!
