In [1]:
import os
current_path = os.getcwd()
path = current_path.split("thesis")[0] + "thesis"
print(path)
os.chdir(path)
print(os.getcwd())

import torch
from transformers import BertTokenizer, BertModel
import faiss
import numpy as np
import json
import tqdm

/home/upadro/code/thesis
/home/upadro/code/thesis


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.models.vector_db.commons.input_loader import InputLoader

In [3]:
def load_all_input_from_dir(input_data_path):
        files = []
        for (dirpath, dirnames, filenames) in os.walk(input_data_path):
            for filename in filenames:
                if "json" in filename:
                    files.append(os.path.join(dirpath, filename))
        input_loader = InputLoader()
        total_inference_datapoints = []
        for file in files:
            individual_datapoints = input_loader.load_data(data_file=file)
            total_inference_datapoints.extend(individual_datapoints) # type: ignore
        return total_inference_datapoints

# data = load_all_input_from_dir("src/models/single_datapoints/common")
data = load_all_input_from_dir("input/inference_input/english/unique_query_test")

In [4]:
print(data[0].keys())

dict_keys(['query', 'case_name', 'relevant_paragrpahs', 'paragraph_numbers', 'link', 'all_paragraphs', 'id'])


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load pre-trained BERT model and tokenizer, and move the model to GPU if available
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)  # Move model to GPU

# Function to encode a list of paragraphs or queries using BERT
def encode_texts(texts):
    # Tokenize and encode, move inputs to the correct device
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}  # Move inputs to GPU if available
    
    # Get embeddings from BERT
    with torch.no_grad():
        outputs = model(**inputs)
        
    # Take the mean of the last hidden state for each input text and move to CPU
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # Ensure the output is on CPU
    return embeddings

Using device: cuda




In [6]:
def recall_at_k(actual, predicted, k):
    relevance = [1 if x in actual else 0 for x in predicted]
    r = np.asarray(relevance)[:k]
    return np.sum(r) / len(actual) if len(actual) > 0 else 0

In [7]:
all_embeddings = []
metadata = []

# Set to track already encoded links
encoded_links = set()

# Iterate over each data point to encode all paragraphs
for i, datapoint in enumerate(data):
    case_name = datapoint['case_name']
    link = datapoint['link']
    
    # Check if the link has already been encoded
    if link in encoded_links:
        # print(f"Link {link} already encoded. Skipping...")
        continue  # Skip this data point as it's already encoded
    
    # Combine all paragraphs into one per paragraph set
    paragraphs = ["\n".join(paras) for paras in datapoint['all_paragraphs']]
    
    # Encode all paragraphs for the current data point
    paragraph_embeddings = encode_texts(paragraphs)  # Ensure this returns CPU embeddings
    
    # Add each paragraph's embedding to the FAISS index and keep metadata
    for j, embedding in enumerate(paragraph_embeddings):
        all_embeddings.append(embedding)  # Append the numpy array (CPU-based)
        metadata.append({
            "case_name": case_name,
            "link": link,
            "paragraph_index": j,
            "paragraph_text": paragraphs[j]
        })
    
    # Mark this link as encoded
    encoded_links.add(link)
    # print(f"Encoded and added paragraphs for link: {link}")

# Convert embeddings list to numpy array
all_embeddings_np = np.array(all_embeddings).astype('float32')

In [8]:
embedding_dim = all_embeddings_np.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(all_embeddings_np)
print(f"Total embeddings added to FAISS index: {index.ntotal}")

Total embeddings added to FAISS index: 88078


In [9]:
recalls_2_percent = []
recalls_5_percent = []
recalls_10_percent = []

# Iterate over each data point to rank paragraphs based on the given query
for i, datapoint in enumerate(data):
    case_name = datapoint['case_name']
    link = datapoint['link']
    
    # Get the paragraph numbers from the dataset (1-indexed)
    actual_paragraph_numbers = datapoint.get('paragraph_numbers', [])
    
    # Use the provided queries in the data point and concatenate them into a single query
    queries = datapoint.get('query', [])
    
    # If there are no queries, skip this data point
    if not queries:
        print(f"No queries found for case: {case_name}")
        continue

    # Combine all queries into a single query string separated by commas
    combined_query = ", ".join(queries)

    # Encode the combined query as a single query
    query_embedding = encode_texts([combined_query])[0]  # Single embedding for the combined query
    
    # Search the FAISS index for this query
    query_embedding = query_embedding.reshape(1, -1).astype('float32')  # Reshape for FAISS
    distances, indices = index.search(query_embedding, len(all_embeddings))  # Get all paragraphs ranked
    
    # Filter results for this data point
    filtered_results = [
        (metadata[idx], distance) for idx, distance in zip(indices[0], distances[0])
        if metadata[idx]["case_name"] == case_name and metadata[idx]["link"] == link
    ]
    
    # Get the top percentages for each filtered result
    top_2_percent = max(1, int(len(filtered_results) * 0.02))  # Top 2% of filtered results
    top_5_percent = max(1, int(len(filtered_results) * 0.05))  # Top 5% of filtered results
    top_10_percent = max(1, int(len(filtered_results) * 0.1))  # Top 10% of filtered results
    
    # Get the paragraph numbers for top percentages
    top_2_percent_paragraphs = [(result_metadata["paragraph_index"] + 1) for result_metadata, _ in filtered_results[:top_2_percent]]
    top_5_percent_paragraphs = [(result_metadata["paragraph_index"] + 1) for result_metadata, _ in filtered_results[:top_5_percent]]
    top_10_percent_paragraphs = [(result_metadata["paragraph_index"] + 1) for result_metadata, _ in filtered_results[:top_10_percent]]
    
    # Calculate individual recalls for each percentage
    recall_2 = recall_at_k(actual_paragraph_numbers, top_2_percent_paragraphs, k=len(actual_paragraph_numbers))
    recall_5 = recall_at_k(actual_paragraph_numbers, top_5_percent_paragraphs, k=len(actual_paragraph_numbers))
    recall_10 = recall_at_k(actual_paragraph_numbers, top_10_percent_paragraphs, k=len(actual_paragraph_numbers))
    
    # Store the individual recall values
    recalls_2_percent.append(recall_2)
    recalls_5_percent.append(recall_5)
    recalls_10_percent.append(recall_10)

# Calculate the mean recall for each percentage
mean_recall_2_percent = np.mean(recalls_2_percent)
mean_recall_5_percent = np.mean(recalls_5_percent)
mean_recall_10_percent = np.mean(recalls_10_percent)

# Print the overall recall scores
print(f"Mean Recall at 2%: {mean_recall_2_percent:.4f}")
print(f"Mean Recall at 5%: {mean_recall_5_percent:.4f}")
print(f"Mean Recall at 10%: {mean_recall_10_percent:.4f}")

Mean Recall at 2%: 0.0794
Mean Recall at 5%: 0.0943
Mean Recall at 10%: 0.1044
