In [1]:

!pip3 install torch torchvision
!pip3 install chromadb
!pip3 install sentence-transformers scipy numpy scikit-learn
!pip3 install pypdf


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip

In [None]:
import time
import chromadb
import json
import numpy as np
import torch

from sentence_transformers import SentenceTransformer
from concurrent.futures import ThreadPoolExecutor

  from tqdm.autonotebook import tqdm, trange


In [None]:
from utils import load_data, get_document_chunks, LinearAdapter

## ChromaDB

In [3]:
def setup_chromadb(path="chromadb"):
    client = chromadb.PersistentClient(path=path)
    try:
        client.delete_collection("paper_collection")
    except Exception as e:
        print(f"Collection deletion error: {e}")

    collection = client.create_collection(
        "paper_collection", metadata={"hnsw:space": "cosine"})
    return collection

## Data Processing and Model Prep

In [None]:
def add_chunks_to_collection(collection, chunks):
    def add_chunk(chunk, index):
        collection.add(documents=[chunk], ids=[f"chunk_{index}"])

    start_time = time.time()
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(add_chunk, chunk, i)
                   for i, chunk in enumerate(chunks)]
    for future in futures:
        future.result()
    print(
        f"Time taken for adding chunks: {time.time() - start_time:.2f} seconds")


def retrieve_documents_embeddings(collection, query_embedding, k=10):
    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=k
    )
    return results['documents'][0]

## Metrics and Calculations

In [5]:
def reciprocal_rank(retrieved_docs, ground_truth, k):
    try:
        rank = retrieved_docs.index(ground_truth) + 1
        return 1.0 / rank if rank <= k else 0.0
    except ValueError:
        return 0.0


def hit_rate(retrieved_docs, ground_truth, k):
    return 1.0 if ground_truth in retrieved_docs[:k] else 0.0


def validate_embedding_model(validation_data, base_model, collection, k=10):
    hit_rates = []
    reciprocal_ranks = []
    for data_point in validation_data:
        question = data_point['question']
        ground_truth = data_point['chunk']
        question_embedding = base_model.encode(question)
        retrieved_docs = retrieve_documents_embeddings(
            collection, question_embedding, k)
        hit_rates.append(hit_rate(retrieved_docs, ground_truth, k))
        reciprocal_ranks.append(reciprocal_rank(
            retrieved_docs, ground_truth, k))
    return {
        'average_hit_rate': np.mean(hit_rates),
        'average_reciprocal_rank': np.mean(reciprocal_ranks)
    }

## Adapter and Dataset Classes


In [None]:
def encode_query(query, base_model, adapter):
    device = next(adapter.parameters()).device
    query_emb = base_model.encode(query, convert_to_tensor=True).to(device)
    adapted_query_emb = adapter(query_emb)
    return adapted_query_emb.cpu().detach().numpy()


def evaluate_adapter(validation_data, base_model, adapter, collection, k=50):
    hit_rates, reciprocal_ranks = [], []
    for data_point in validation_data:
        question, ground_truth = data_point['question'], data_point['chunk']
        question_embedding = encode_query(question, base_model, adapter)
        retrieved_docs = retrieve_documents_embeddings(
            collection, question_embedding, k)
        hit_rates.append(hit_rate(retrieved_docs, ground_truth, k))
        reciprocal_ranks.append(reciprocal_rank(
            retrieved_docs, ground_truth, k))
    return {
        'average_hit_rate': np.mean(hit_rates),
        'average_reciprocal_rank': np.mean(reciprocal_ranks)
    }

# Run 

In [7]:
with open('../globals.json') as config_file:
    config = json.load(config_file)
    pdf_path = config.get("main_pdf")

collection = setup_chromadb()
base_model = SentenceTransformer('all-MiniLM-L6-v2')

chunks = get_document_chunks(pdf_path)
add_chunks_to_collection(collection, chunks)

train_path = '../data/train.json'
validation_path = '../data/validation.json'
train_data, validation_data = load_data(train_path, validation_path)

base_results = validate_embedding_model(
    validation_data, base_model, collection)
print("Base Model - Average Hit Rate @10:", base_results['average_hit_rate'])
print("Base Model - Mean Reciprocal Rank @10:",
      base_results['average_reciprocal_rank'])

adapter = LinearAdapter(base_model.get_sentence_embedding_dimension())
adapter.load_state_dict(torch.load(
    '../adapters/linear_adapter_10epochs.pth')['adapter_state_dict'])
adapter_results = evaluate_adapter(
    validation_data, base_model, adapter, collection)
print("Adapter - Average Hit Rate @10:", adapter_results['average_hit_rate'])
print("Adapter - Mean Reciprocal Rank @10:",
      adapter_results['average_reciprocal_rank'])

Ignoring wrong pointing object 46 0 (offset 0)
Ignoring wrong pointing object 48 0 (offset 0)
Ignoring wrong pointing object 50 0 (offset 0)
Ignoring wrong pointing object 53 0 (offset 0)
Ignoring wrong pointing object 193 0 (offset 0)
Ignoring wrong pointing object 195 0 (offset 0)
Ignoring wrong pointing object 197 0 (offset 0)
Ignoring wrong pointing object 200 0 (offset 0)
Ignoring wrong pointing object 217 0 (offset 0)
Ignoring wrong pointing object 219 0 (offset 0)
Ignoring wrong pointing object 221 0 (offset 0)
Ignoring wrong pointing object 224 0 (offset 0)
Ignoring wrong pointing object 298 0 (offset 0)
Ignoring wrong pointing object 300 0 (offset 0)
Ignoring wrong pointing object 302 0 (offset 0)
Ignoring wrong pointing object 304 0 (offset 0)
Ignoring wrong pointing object 308 0 (offset 0)
Ignoring wrong pointing object 310 0 (offset 0)
Ignoring wrong pointing object 312 0 (offset 0)
Ignoring wrong pointing object 314 0 (offset 0)
Ignoring wrong pointing object 354 0 (offset

Time taken for adding chunks: 11.89 seconds
Base Model - Average Hit Rate @10: 0.013157894736842105
Base Model - Mean Reciprocal Rank @10: 0.003380847953216374
Adapter - Average Hit Rate @10: 0.019736842105263157
Adapter - Mean Reciprocal Rank @10: 0.007080200501253133
