In [1]:
import os
# Expose only the GPU with index 1 to CUDA. This runs in the first cell, ahead of
# the imports of the model libraries in the following cell.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
# Left disabled: CUDA_LAUNCH_BLOCKING forces synchronous kernel launches (debugging aid).
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
from datasets import load_dataset
from qdrant_client import QdrantClient, models
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from qdrant_client.http.models import Distance, SparseVectorParams, VectorParams
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import re
from uuid import uuid4
import json

import math
import evaluate

2025-04-29 22:38:04.107392: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745966284.118599    3847 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745966284.122362    3847 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-29 22:38:04.134849: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Name of the Qdrant collection that the batch queries below are run against.
collection_name = "wikipedia-dot"
# Client for the Qdrant server at a fixed LAN address.
# NOTE(review): timeout is presumably in seconds and set high to tolerate slow
# batch queries — confirm against the qdrant-client documentation.
client = QdrantClient(url="http://192.168.2.3:6333" , timeout=999)

In [4]:
# Sentence-embedding model; its encode() is used below to turn questions into query vectors.
model = SentenceTransformer('nthakur/contriever-base-msmarco')

In [5]:
from itertools import islice

def batched(iterable, n):
    """Yield consecutive lists of at most ``n`` items taken from ``iterable``.

    Args:
        iterable: Any iterable; it is consumed lazily, one chunk at a time.
        n: Maximum number of items per yielded list.

    Yields:
        Lists of items in their original order. Only the last list may be
        shorter than ``n``; an exhausted or empty iterable yields nothing.
    """
    source = iter(iterable)
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            # islice produced nothing: the underlying iterator is exhausted.
            return
        yield chunk
# Number of dataset rows grouped into each batch by batched() in the cells below.
batch_size = 64

In [6]:
import difflib

def get_overlap(s1, s2):
    """Return the longest contiguous substring shared by ``s1`` and ``s2``.

    Args:
        s1: First string; the returned slice is taken from this string.
        s2: Second string.

    Returns:
        The longest common contiguous block as a slice of ``s1``, or ``""``
        when the two strings share no character.
    """
    # autojunk must be disabled: with the default (True), once s2 has 200 or
    # more characters every character that makes up more than ~1% of s2 is
    # dropped as a match anchor. For natural-language passages that is nearly
    # every letter, so the "longest" match is badly under-reported or empty.
    matcher = difflib.SequenceMatcher(None, s1, s2, autojunk=False)
    match = matcher.find_longest_match(0, len(s1), 0, len(s2))
    return s1[match.a:match.a + match.size]

def is_gold_compression(text, supporting_map, threshold=0.5):
    """Tell whether ``text`` substantially overlaps any supporting text.

    Both sides are lowercased and stripped before comparison. For each
    supporting text, the longest common substring (via ``get_overlap``) is
    measured against the length of the supporting text and against the
    length of ``text``.

    Args:
        text: Candidate passage.
        supporting_map: Iterable of supporting strings to compare against.
        threshold: Fraction that the overlap must strictly exceed, relative to
            either string's length, for the pair to count as a match.

    Returns:
        True as soon as one supporting text passes the threshold, else False.
        Pairs in which either normalized string is empty are skipped.
    """
    candidate = text.lower().strip()
    for raw_support in supporting_map:
        support = raw_support.lower().strip()
        if not (support and candidate):
            # Nothing to compare, and skipping avoids a division by zero below.
            continue
        shared_len = len(get_overlap(support, candidate))
        covers_support = shared_len / len(support) > threshold
        covers_candidate = shared_len / len(candidate) > threshold
        if covers_support or covers_candidate:
            return True
    return False

In [7]:
# tqa_dataset = load_dataset("mandarjoshi/trivia_qa" , "rc")
# tqa_dataset_val = tqa_dataset["validation"]


In [8]:
# Build the TriviaQA retrieval test set: for each validation question, query the
# Qdrant collection for its top-100 passages and flag the ones that contain an
# answer alias.
tqa_dataset = load_dataset("mandarjoshi/trivia_qa" , "rc")
tqa_dataset_val = tqa_dataset["validation"]
total_batches = math.ceil(len(tqa_dataset_val) / batch_size)

# Create the output directory up front so the final write cannot fail on a
# missing folder after the whole retrieval loop has already run.
output_path = "dataset/test_tqa.json"
os.makedirs(os.path.dirname(output_path), exist_ok=True)

document_tqa_test = []

for batch in tqdm(batched(tqa_dataset_val, batch_size), total=total_batches, desc="dataset batches"):
    questions = [data["question"] for data in batch]
    answers = [data["answer"] for data in batch]

    # Lowercased, de-duplicated answer aliases per question. Empty aliases are
    # dropped: an empty regex pattern matches at every position, which would
    # mark every passage as containing the answer.
    answers_grouped = [
        {
            alias.lower()
            for alias in ans["aliases"] + ans["normalized_aliases"] + [ans["normalized_value"]]
            if alias
        }
        for ans in answers
    ]

    # Encode questions
    queries_encode = model.encode(questions)

    # Prepare batch search
    search_queries = [
        models.QueryRequest(query=query, with_payload=True, limit=100)
        for query in queries_encode
    ]

    # Run batch query
    batch_point = client.query_batch_points(collection_name=collection_name, requests=search_queries)

    # Process results; results are paired with questions by position.
    for i, query in enumerate(batch_point):
        # Sorted so the serialized alias list is identical across runs
        # (iteration order of a set of strings is not stable between processes).
        aliases = sorted(answers_grouped[i])

        ctxs = []
        for point in query.points:
            doc_id = point.payload.get("docid")
            title = point.payload.get("title")
            text = point.payload.get("text")
            score = point.score

            # Lowercase once per passage instead of once per alias; a payload
            # without "text" is treated as an empty passage rather than crashing.
            text_lower = (text or "").lower()

            # Match answer aliases in text. Offsets index into the lowercased
            # text; the set removes duplicates and sorting makes the output
            # order deterministic.
            matches = sorted({
                (match.start(), len(match.group()))
                for alias in aliases
                for match in re.finditer(re.escape(alias), text_lower)
            })

            has_answer = len(matches) > 0

            ctxs.append({
                "id": doc_id,
                "title": title,
                "text": text,
                "score": score,
                "has_answer": has_answer,
                "answer_occurrences": [
                    {"start": start, "length": length} for (start, length) in matches
                ]
            })

        document_tqa_test.append({
            "question": questions[i],
            "answers": aliases,
            "ctxs": ctxs
        })

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(document_tqa_test, f, ensure_ascii=False, indent=2)

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/26 [00:00<?, ?files/s]

Generating train split:   0%|          | 0/138384 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/17944 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/17210 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

dataset batches: 100%|████████████████████████████████████████████████████████████████| 281/281 [06:07<00:00,  1.31s/it]


In [9]:
# hotpot_qa = load_dataset("hotpotqa/hotpot_qa" , "fullwiki")
# hotpot_qa_dataset_val = hotpot_qa["validation"]
# hotpot_qa_dataset_val = hotpot_qa_dataset_val.filter(
#     lambda example: all(title in example['context']['title'] for title in set(example['supporting_facts']['title']))
# )
# total_batches = math.ceil(len(hotpot_qa_dataset_val) / batch_size)

# document_hotpotqa_test = []

# for batch in tqdm(batched(hotpot_qa_dataset_val, batch_size), total=total_batches, desc="dataset batches"):
#     questions = [data["question"] for data in batch]
#     answers = [data["answer"] for data in batch]

#     # Lowercased answer sets
#     answers_grouped = [
#         set([ans.lower()])
#         for ans in answers
#     ]

#     # Encode questions
#     queries_encode = model.encode(questions)

#     # Prepare batch search
#     search_queries = [
#         models.QueryRequest(query=query, with_payload=True, limit=100)
#         for query in queries_encode
#     ]

#     # Run batch query
#     batch_point = client.query_batch_points(collection_name=collection_name, requests=search_queries)

#     # Process results
#     for i, query in enumerate(batch_point):

        
#         support_titles = batch[i]["supporting_facts"]["title"]
#         support_sent_ids = batch[i]["supporting_facts"]["sent_id"]
#         context_titles = batch[i]["context"]["title"]
#         context_sentences = batch[i]["context"]["sentences"]
        
#         # Build supporting sentence map per title
#         supporting_map = []
#         for title, sent_id in zip(support_titles, support_sent_ids):
#             try:
#                 title_idx = context_titles.index(title)
#                 sentence = context_sentences[title_idx][sent_id]
#                 supporting_map.append(sentence.strip())
#             except (ValueError, IndexError):
#                 continue

#         ctxs = []
#         for point in query.points:
#             doc_id = point.payload.get("docid")
#             title = point.payload.get("title")
#             text = point.payload.get("text")
#             score = point.score

#             # Match answer aliases in text
#             matches = {
#                 (match.start(), len(match.group()))
#                 for alias in answers_grouped[i]
#                 for match in re.finditer(re.escape(alias), text.lower())
#             }

#             has_answer = len(matches) > 0

#             gold_document = is_gold_compression(text, supporting_map)

#             ctxs.append({
#                 "id": doc_id,
#                 "title": title,
#                 "text": text,
#                 "score": score,
#                 "has_answer": has_answer,
#                 "answer_occurrences": [
#                     {"start": start, "length": length} for (start, length) in matches
#                 ],
#                 "gold_document" : gold_document
#             })

#         document_hotpotqa_test.append({
#             "question": questions[i],
#             "answers": list(answers_grouped[i]),
#             "ctxs": ctxs
#         })

# with open("dataset/test_hotpotqa.json", "w", encoding="utf-8") as f:
#     json.dump(document_hotpotqa_test, f, ensure_ascii=False, indent=2)

In [10]:
# def format_dataset(sample):
#     question = sample['question']['text']
#     context = sample['document']['tokens']['token']
#     is_html = sample['document']['tokens']['is_html']
#     long_answers = sample['annotations']['long_answer']
#     short_answers = sample['annotations']['short_answers']
    
#     context_string =  " ".join([context[i] for i in range(len(context)) if not is_html[i]])
    
#     # 0 - No ; 1 - Yes
#     for answer in sample['annotations']['yes_no_answer']:
#         if answer == 0 or answer == 1:
#           return {"question": question, "context": context_string, "short": [], "long": [], "category": "no" if answer == 0 else "yes" , "answer" : ["no"] if answer == 0 else ["yes"]}
    
#     short_targets = []
#     for s in short_answers:
#         short_targets.extend(s['text'])
#     short_targets = list(set(short_targets))
    
#     long_targets = []
#     for s in long_answers:
#         if s['start_token'] == -1:
#             continue
#         answer = context[s['start_token']: s['end_token']]
#         html = is_html[s['start_token']: s['end_token']]
#         new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
#         if new_answer not in long_targets:
#             long_targets.append(new_answer)
    
#     category = "long_short" if len(short_targets + long_targets) > 0 else "null"
    
#     return {"question": question, "context": context_string, "short": short_targets, "long": long_targets, "category": category , "answer" : short_targets + long_targets}

In [11]:
# natural_qa = load_dataset("google-research-datasets/natural_questions")
# natural_qa_dataset_val = natural_qa["validation"]
# natural_qa_dataset_val = natural_qa_dataset_val.map(format_dataset).remove_columns(["annotations", "document", "id"])
# natural_qa_dataset_val = natural_qa_dataset_val.filter(lambda x: x["category"] != "null")
# total_batches = math.ceil(len(natural_qa_dataset_val) / batch_size)

# document_natural_qa_test = []

# for batch in tqdm(batched(natural_qa_dataset_val, batch_size), total=total_batches, desc="dataset batches"):
#     questions = [data["question"] for data in batch]
#     answers = [data["answer"] for data in batch]

#     # Lowercased answer sets
#     answers_grouped = [
#         set(s.lower() for s in ans)
#         for ans in answers
#     ]

#     # Encode questions
#     queries_encode = model.encode(questions)

#     # Prepare batch search
#     search_queries = [
#         models.QueryRequest(query=query, with_payload=True, limit=100)
#         for query in queries_encode
#     ]

#     # Run batch query
#     batch_point = client.query_batch_points(collection_name=collection_name, requests=search_queries)

#     # Process results
#     for i, query in enumerate(batch_point):
#         ctxs = []
#         for point in query.points:
#             doc_id = point.payload.get("docid")
#             title = point.payload.get("title")
#             text = point.payload.get("text")
#             score = point.score

#             # Match answer aliases in text
#             matches = {
#                 (match.start(), len(match.group()))
#                 for alias in answers_grouped[i]
#                 for match in re.finditer(re.escape(alias), text.lower())
#             }

#             has_answer = len(matches) > 0

#             ctxs.append({
#                 "id": doc_id,
#                 "title": title,
#                 "text": text,
#                 "score": score,
#                 "has_answer": has_answer,
#                 "answer_occurrences": [
#                     {"start": start, "length": length} for (start, length) in matches
#                 ]
#             })

#         document_natural_qa_test.append({
#             "question": questions[i],
#             "answers": list(answers_grouped[i]),
#             "ctxs": ctxs
#         })

# with open("dataset/test_nqa.json", "w", encoding="utf-8") as f:
#     json.dump(document_natural_qa_test, f, ensure_ascii=False, indent=2)

In [12]:
# musique_qa = load_dataset("dgslibisey/MuSiQue")
# musique_qa_dataset_val = musique_qa["validation"]
# musique_qa_dataset_val = musique_qa_dataset_val.filter(lambda x : x["answerable"] == True)
# total_batches = math.ceil(len(musique_qa_dataset_val) / batch_size)

# document_musique_qa_test = []

# for batch in tqdm(batched(musique_qa_dataset_val, batch_size), total=total_batches, desc="dataset batches"):
#     questions = [data["question"] for data in batch]
#     answers = [[data["answer"]] + data["answer_aliases"] for data in batch]


#     # Lowercased answer sets
#     answers_grouped = [
#         set(s.lower() for s in ans)
#         for ans in answers
#     ]

#     # Encode questions
#     queries_encode = model.encode(questions)

#     # Prepare batch search
#     search_queries = [
#         models.QueryRequest(query=query, with_payload=True, limit=100)
#         for query in queries_encode
#     ]

#     # Run batch query
#     batch_point = client.query_batch_points(collection_name=collection_name, requests=search_queries)

#     # Process results
#     for i, query in enumerate(batch_point):

#         supporting_map = [doc["paragraph_text"] for doc in batch[i]["paragraphs"] if doc["is_supporting"] == True]
        
#         ctxs = []
#         for point in query.points:
#             doc_id = point.payload.get("docid")
#             title = point.payload.get("title")
#             text = point.payload.get("text")
#             score = point.score

#             # Match answer aliases in text
#             matches = {
#                 (match.start(), len(match.group()))
#                 for alias in answers_grouped[i]
#                 for match in re.finditer(re.escape(alias), text.lower())
#             }

#             has_answer = len(matches) > 0

#             gold_document = is_gold_compression(text, supporting_map)
#             ctxs.append({
#                 "id": doc_id,
#                 "title": title,
#                 "text": text,
#                 "score": score,
#                 "has_answer": has_answer,
#                 "answer_occurrences": [
#                     {"start": start, "length": length} for (start, length) in matches
#                 ],
#                 "gold_document" : gold_document
#             })

#         document_musique_qa_test.append({
#             "question": questions[i],
#             "answers": list(answers_grouped[i]),
#             "ctxs": ctxs
#         })

# with open("dataset/test_musique.json", "w", encoding="utf-8") as f:
#     json.dump(document_musique_qa_test, f, ensure_ascii=False, indent=2)

In [13]:
# two_wiki_qa = load_dataset("kamelliao/2wikimultihopqa")
# two_wiki_qa_dataset_val = two_wiki_qa["validation"]

In [14]:
# two_wiki_qa_dataset_val[2]

In [15]:
# two_wiki_qa = load_dataset("kamelliao/2wikimultihopqa")
# two_wiki_qa_dataset_val = two_wiki_qa["validation"]
# two_wiki_qa_dataset_val = two_wiki_qa_dataset_val.filter(
#     lambda example: all(title in example['context']['title'] for title in set(example['supporting_facts']['title']))
# )
# total_batches = math.ceil(len(two_wiki_qa_dataset_val) / batch_size)

# document_two_wiki_qa_test = []

# for batch in tqdm(batched(two_wiki_qa_dataset_val, batch_size), total=total_batches, desc="dataset batches"):
#     questions = [data["question"] for data in batch]
#     answers = [data["answer"] for data in batch]


    

#     # Lowercased answer sets
#     answers_grouped = [
#         set([ans.lower()])
#         for ans in answers
#     ]


#     # Encode questions
#     queries_encode = model.encode(questions)

#     # Prepare batch search
#     search_queries = [
#         models.QueryRequest(query=query, with_payload=True, limit=100)
#         for query in queries_encode
#     ]

#     # Run batch query
#     batch_point = client.query_batch_points(collection_name=collection_name, requests=search_queries)

#     # Process results
#     for i, query in enumerate(batch_point):
#         support_titles = batch[i]["supporting_facts"]["title"]
#         support_sent_ids = batch[i]["supporting_facts"]["sent_id"]
#         context_titles = batch[i]["context"]["title"]
#         context_sentences = batch[i]["context"]["sentences"]
        
#         # Build supporting sentence map per title
#         supporting_map = []
#         for title, sent_id in zip(support_titles, support_sent_ids):
#             try:
#                 title_idx = context_titles.index(title)
#                 sentence = context_sentences[title_idx][sent_id]
#                 supporting_map.append(sentence.strip())
#             except (ValueError, IndexError):
#                 continue
#         ctxs = []
#         for point in query.points:
#             doc_id = point.payload.get("docid")
#             title = point.payload.get("title")
#             text = point.payload.get("text")
#             score = point.score

#             # Match answer aliases in text
#             matches = {
#                 (match.start(), len(match.group()))
#                 for alias in answers_grouped[i]
#                 for match in re.finditer(re.escape(alias), text.lower())
#             }

#             gold_document = is_gold_compression(text, supporting_map)

#             has_answer = len(matches) > 0

#             ctxs.append({
#                 "id": doc_id,
#                 "title": title,
#                 "text": text,
#                 "score": score,
#                 "has_answer": has_answer,
#                 "answer_occurrences": [
#                     {"start": start, "length": length} for (start, length) in matches
#                 ],
#                 "gold_document" : gold_document
#             })

#         document_two_wiki_qa_test.append({
#             "question": questions[i],
#             "answers": list(answers_grouped[i]),
#             "ctxs": ctxs
#         })

# with open("dataset/test_2wiki.json", "w", encoding="utf-8") as f:
#     json.dump(document_two_wiki_qa_test, f, ensure_ascii=False, indent=2)

In [16]:
# document_two_wiki_qa_test[1]