In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO

def weighted_mean_pooling(hidden, attention_mask):
    """Position-weighted mean pooling over dimension 1 (the token axis).

    Each unmasked token is weighted by its 1-based position among the
    unmasked tokens (``mask * mask.cumsum(dim=1)``), so later tokens count
    more. Masked positions get weight 0.

    Args:
        hidden: token states; presumably (batch, seq, dim), since the weights
            are broadcast with ``unsqueeze(-1)`` and summed over dim 1.
        attention_mask: 0/1 mask, presumably (batch, seq).

    Returns:
        The weighted sum over dim 1 divided by the per-row total weight. The
        weights are cast to float32, so a lower-precision ``hidden`` (e.g.
        bfloat16) is promoted by the multiplication.
    """
    position_weights = attention_mask * attention_mask.cumsum(dim=1)
    weighted_sum = (hidden * position_weights.unsqueeze(-1).float()).sum(dim=1)
    total_weight = position_weights.sum(dim=1, keepdim=True).float()
    return weighted_sum / total_weight

@torch.no_grad()
def encode(text_or_image_list):
    """Embed a batch of texts or a batch of images with the global VisRAG-Ret model.

    Uses the module-level ``model`` and ``tokenizer`` created further down in
    this cell. Text batches are sent with ``image=None`` for every item.
    Image batches are sent with ``''`` as the text of every item.

    Args:
        text_or_image_list: non-empty list containing either only ``str``
            items or only images (PIL images in this notebook).

    Returns:
        numpy array of L2-normalised pooled embeddings (normalised along
        dim 1), presumably one row per input item.

    Raises:
        ValueError: if the list is empty, or mixes strings with non-strings.
    """
    if len(text_or_image_list) == 0:
        # Indexing [0] on an empty list used to crash with an opaque IndexError.
        raise ValueError("encode() needs at least one text or image")

    is_text = [isinstance(item, str) for item in text_or_image_list]
    if any(is_text) and not all(is_text):
        # The modality used to be decided from the first item alone, so a
        # mixed batch was silently routed down the wrong branch.
        raise ValueError("encode() expects either all strings or all images, not a mix")

    batch_size = len(text_or_image_list)
    if all(is_text):
        inputs = {
            "text": text_or_image_list,
            'image': [None] * batch_size,
            'tokenizer': tokenizer
        }
    else:
        inputs = {
            "text": [''] * batch_size,
            'image': text_or_image_list,
            'tokenizer': tokenizer
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    reps = weighted_mean_pooling(hidden, attention_mask)
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings

model_name_or_path = "./VisRAG-Ret"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
model.eval()

queries = ["What does a dog look like?"]
INSTRUCTION = "Represent this query for retrieving relevant documents: "
# Queries get the retrieval instruction prefix; documents (images) do not.
queries = [INSTRUCTION + query for query in queries]


def _download_rgb_image(url, timeout=30):
    """Fetch ``url`` over HTTP and return it as an RGB PIL image.

    ``raise_for_status`` turns an HTTP error into ``requests.HTTPError``
    instead of handing an HTML error page to PIL. ``timeout`` (seconds) keeps
    a stalled connection from blocking the cell forever.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


print("Downloading images...")
passages = [
    _download_rgb_image(
        'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/cat.jpeg'
    ),
    _download_rgb_image(
        'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/dog.jpg'
    ),
]
print("Images downloaded.")

embeddings_query = encode(queries)
embeddings_doc = encode(passages)

# encode() L2-normalises, so this dot product is the cosine similarity.
scores = (embeddings_query @ embeddings_doc.T)
print(scores.tolist())


In [2]:
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO

def weighted_mean_pooling(hidden, attention_mask):
    """Position-weighted mean pooling over dimension 1 (the token axis).

    Each unmasked token is weighted by its 1-based position among the
    unmasked tokens (``mask * mask.cumsum(dim=1)``), so later tokens count
    more. Masked positions get weight 0.

    Args:
        hidden: token states; presumably (batch, seq, dim), since the weights
            are broadcast with ``unsqueeze(-1)`` and summed over dim 1.
        attention_mask: 0/1 mask, presumably (batch, seq).

    Returns:
        The weighted sum over dim 1 divided by the per-row total weight
        (weights cast to float32).
    """
    position_weights = attention_mask * attention_mask.cumsum(dim=1)
    weighted_sum = (hidden * position_weights.unsqueeze(-1).float()).sum(dim=1)
    total_weight = position_weights.sum(dim=1, keepdim=True).float()
    return weighted_sum / total_weight

@torch.no_grad()
def encode(text_or_image_list):
    """Embed a batch of texts or a batch of images with the global VisRAG-Ret model.

    Uses the module-level ``model`` and ``tokenizer`` created further down in
    this cell. Text batches are sent with ``image=None`` for every item.
    Image batches are sent with ``''`` as the text of every item.

    Args:
        text_or_image_list: non-empty list containing either only ``str``
            items or only images (PIL images in this notebook).

    Returns:
        numpy array of L2-normalised pooled embeddings (normalised along
        dim 1), presumably one row per input item.

    Raises:
        ValueError: if the list is empty, or mixes strings with non-strings.
    """
    if len(text_or_image_list) == 0:
        # Indexing [0] on an empty list used to crash with an opaque IndexError.
        raise ValueError("encode() needs at least one text or image")

    is_text = [isinstance(item, str) for item in text_or_image_list]
    if any(is_text) and not all(is_text):
        # The modality used to be decided from the first item alone, so a
        # mixed batch was silently routed down the wrong branch.
        raise ValueError("encode() expects either all strings or all images, not a mix")

    batch_size = len(text_or_image_list)
    if all(is_text):
        inputs = {
            "text": text_or_image_list,
            'image': [None] * batch_size,
            'tokenizer': tokenizer
        }
    else:
        inputs = {
            "text": [''] * batch_size,
            'image': text_or_image_list,
            'tokenizer': tokenizer
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    reps = weighted_mean_pooling(hidden, attention_mask)
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings

model_name_or_path = "./VisRAG-Ret"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
# float32 instead of bfloat16 for CPU compatibility. There is no .cuda() call,
# so the model stays on the CPU.
model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
# Inference mode (dropout off etc.). This does NOT choose the device; the CPU
# placement comes from the missing .cuda() call above.
model.eval()

# queries = ["What does a dog look like?"]
# INSTRUCTION = "Represent this query for retrieving relevant documents: "
# queries = [INSTRUCTION + query for query in queries]

# print("Downloading images...")
# passages = [
#     Image.open(BytesIO(requests.get(
#         'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/cat.jpeg'
#     ).content)).convert('RGB'),
#     Image.open(BytesIO(requests.get(
#         'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/dog.jpg'
#     ).content)).convert('RGB')
# ]
# print("Images downloaded.")

# embeddings_query = encode(queries)
# embeddings_doc = encode(passages)
# print(encode([]))
# scores = (embeddings_query @ embeddings_doc.T)
# print(scores.tolist())

# # 计算余弦相似度
# cosine_sim = F.cosine_similarity(torch.tensor(embeddings_query), torch.tensor(embeddings_doc), dim=-1)
# print(cosine_sim.tolist())

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.43s/it]


In [8]:
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO

# Download and open the demo images (cat, dog) used as retrieval passages.
def _download_rgb_image(url, timeout=30):
    """Fetch ``url`` over HTTP and return it as an RGB PIL image.

    ``raise_for_status`` turns an HTTP error into ``requests.HTTPError``
    instead of handing an HTML error page to PIL. ``timeout`` (seconds) keeps
    a stalled connection from blocking the cell forever.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


passages = [
    _download_rgb_image(
        'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/cat.jpeg'
    ),
    _download_rgb_image(
        'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/dog.jpg'
    ),
]

# Caption for each image, aligned by position with `passages`.
summaries = [
    "This is a cat.",
    "This is a dog."
]

# Add a caption strip underneath an image.
def add_text_below_image(image, text):
    """Return a new RGB image: ``image`` on top, ``text`` on a white strip below.

    The strip is as tall as the rendered text plus 10 px of padding above and
    below. The text starts 10 px from the left edge, in PIL's default font.
    The original image is pasted at the top-left of the new canvas.
    """
    font = ImageFont.load_default()
    padding = 10  # vertical margin above and below the caption

    # Measure the caption with the same font that will be used to draw it.
    _, top, _, bottom = ImageDraw.Draw(image).textbbox((0, 0), text, font=font)
    strip_height = (bottom - top) + 2 * padding

    canvas = Image.new('RGB', (image.width, image.height + strip_height), (255, 255, 255))
    canvas.paste(image, (0, 0))

    caption_origin = (10, image.height + padding)  # adjust as needed
    ImageDraw.Draw(canvas).text(caption_origin, text, font=font, fill="black")
    return canvas

# Caption every passage image with its summary.
processed_images = [
    add_text_below_image(image, summary)
    for image, summary in zip(passages, summaries)
]

# To inspect or persist the captioned images:
# for i, img in enumerate(processed_images):
#     img.show()
#     img.save(f"processed_image_{i}.jpg")

queries = ["What does a dog look like?"]
INSTRUCTION = "Represent this query for retrieving relevant documents: "
queries = [INSTRUCTION + query for query in queries]

# Score the query against the captioned images rather than the raw ones.
embeddings_query = encode(queries)
embeddings_doc = encode(processed_images)

scores = (embeddings_query @ embeddings_doc.T)
print(scores.tolist())

[[0.2508280873298645, 0.34044021368026733]]


In [3]:
from datasets import load_dataset

def _load_corpus_and_queries(dataset_dir):
    """Load the "corpus" and "queries" configs (train split) of one local test set.

    Returns a ``(corpus, queries)`` tuple. The corpus is loaded first.
    """
    corpus = load_dataset(dataset_dir, name="corpus", split="train")
    queries = load_dataset(dataset_dir, name="queries", split="train")
    return corpus, queries


# Load datasets: one corpus and one query set per benchmark.
MP_DocVQA_corpus_ds, MP_DocVQA_queries_ds = _load_corpus_and_queries("dataset/VisRAG-Ret-Test-MP-DocVQA")
ArxivQA_corpus_ds, ArxivQA_queries_ds = _load_corpus_and_queries("dataset/VisRAG-Ret-Test-ArxivQA")
ChartQA_corpus_ds, ChartQA_queries_ds = _load_corpus_and_queries("dataset/VisRAG-Ret-Test-ChartQA")
InfoVQA_corpus_ds, InfoVQA_queries_ds = _load_corpus_and_queries("dataset/VisRAG-Ret-Test-InfoVQA")
PlotQA_corpus_ds, PlotQA_queries_ds = _load_corpus_and_queries("dataset/VisRAG-Ret-Test-PlotQA")
SlideVQA_corpus_ds, SlideVQA_queries_ds = _load_corpus_and_queries("dataset/VisRAG-Ret-Test-SlideVQA")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io

# Spot-check one SlideVQA (query, document) pair: print the query text, show
# the document image, and print the dot product of their embeddings.
# NOTE(review): the query is encoded WITHOUT the INSTRUCTION prefix carried by
# the *_with_instruction_embeddings used in evaluation, so this score may
# differ from what the evaluation sees.
target_query_id = "2012-02-20fy11roadshow-120221022442-phpapp02_95__feb-20-2012-nestl-2011-fullyear-roadshow-presentation-5-1024.jpgquery_number_1"
target_doc_id = "analysisofkoreanwinemarket-20150902-daejeon-150829090424-lva1-app6891_95__analysis-of-korean-wine-market-20150902daejeon-19-1024.jpg"

target_query = next(
    (q for q in SlideVQA_queries_ds if q['query-id'] == target_query_id), None
)
if target_query is None:
    # Previously query_em stayed None and the cell failed later at np.dot
    # with an unrelated-looking error.
    raise LookupError(f"query-id not found in SlideVQA_queries_ds: {target_query_id}")
print(target_query['query'])
query_em = encode([target_query['query']])

target_doc = next(
    (d for d in SlideVQA_corpus_ds if d['corpus-id'] == target_doc_id), None
)
if target_doc is None:
    raise LookupError(f"corpus-id not found in SlideVQA_corpus_ds: {target_doc_id}")
image = target_doc['image']
plt.imshow(image)
plt.axis('off')  # hide the axes
plt.show()
doc_em = encode([image.convert('RGB')])

sim = np.dot(query_em, doc_em.T)
print(sim)

In [None]:
import numpy as np
from PIL import Image

def encode_and_save_embeddings(dataset, dataset_name, output_dir="embeddings"):
    """Embed every corpus image with ``encode`` and save the results as .npy files.

    Writes two row-aligned files:
      * ``{output_dir}/{dataset_name}_embeddings.npy``: stacked embeddings,
        one row per example, in dataset order;
      * ``{output_dir}/{dataset_name}_corpus_ids.npy``: the matching
        ``'corpus-id'`` values.
    ``output_dir`` is created if missing. Without it, ``np.save`` raises
    FileNotFoundError on a fresh checkout.

    Args:
        dataset: sized iterable of examples with an ``'image'`` field (an
            object with ``.convert``, e.g. a PIL image) and a ``'corpus-id'`` field.
        dataset_name: prefix for the output file names.
        output_dir: target directory; defaults to ``"embeddings"`` as before.

    Raises:
        ValueError: if ``dataset`` is empty (there would be nothing to stack).
    """
    import os

    total_examples = len(dataset)
    if total_examples == 0:
        raise ValueError(f"{dataset_name}: dataset is empty, nothing to encode")

    image_embeddings = []
    corpus_ids = []
    for i, example in enumerate(dataset):
        image = example['image'].convert('RGB')
        image_embeddings.append(encode([image]))
        corpus_ids.append(example['corpus-id'])

        # Report progress every 10 examples and on the final one.
        if (i + 1) % 10 == 0 or (i + 1) == total_examples:
            progress = (i + 1) / total_examples * 100
            print(f"Processing {dataset_name}: {i + 1}/{total_examples} ({progress:.2f}%)")

    # Each encode([...]) call returns a block of rows (presumably 1 x dim); stack into one matrix.
    image_embeddings = np.vstack(image_embeddings)

    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, f"{dataset_name}_embeddings.npy"), image_embeddings)
    np.save(os.path.join(output_dir, f"{dataset_name}_corpus_ids.npy"), np.array(corpus_ids))

encode_and_save_embeddings(MP_DocVQA_corpus_ds, "MP_DocVQA_corpus")
encode_and_save_embeddings(SlideVQA_corpus_ds, "SlideVQA_corpus")

In [None]:
import json
import numpy as np

def encode_and_save_summary_embeddings(summary_path, dataset, dataset_name, output_dir="embeddings"):
    """Embed each corpus item's text summary and save the stacked matrix.

    ``summary_path`` is parsed with a single ``json.load`` and looked up with
    ``.get(corpus_id)``, so it must hold ONE JSON object keyed by corpus-id.
    NOTE(review): the callers pass ``*.jsonl`` paths, and real JSON Lines
    (one object per line) would make ``json.load`` fail, so confirm the format.
    Each value is passed to ``encode`` as a one-item list. Presumably the
    values are strings; a non-string would be routed to the image branch.
    Items without a summary are embedded as ``""`` so rows stay aligned with
    dataset order.

    Output: ``{output_dir}/{dataset_name}_summary_embeddings.npy``, one row per
    example. No id file is written. Presumably the rows are matched against
    ``{dataset_name}_corpus_ids.npy`` written for the same dataset order.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    import os

    with open(summary_path, 'r', encoding='utf-8') as f:
        summaries = json.load(f)

    total_examples = len(dataset)
    if total_examples == 0:
        raise ValueError(f"{dataset_name}: dataset is empty, nothing to encode")

    image_summary_embeddings = []
    for i, example in enumerate(dataset):
        summary = summaries.get(example['corpus-id'], "")
        image_summary_embeddings.append(encode([summary]))

        # Report progress every 10 examples and on the final one.
        if (i + 1) % 10 == 0 or (i + 1) == total_examples:
            progress = (i + 1) / total_examples * 100
            print(f"Processing {dataset_name}: {i + 1}/{total_examples} ({progress:.2f}%)")

    image_summary_embeddings = np.vstack(image_summary_embeddings)
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, f"{dataset_name}_summary_embeddings.npy"), image_summary_embeddings)

# encode_and_save_summary_embeddings("./MP_DocVQA_summary.jsonl", MP_DocVQA_corpus_ds, "MP_DocVQA_corpus")
encode_and_save_summary_embeddings("./SlideVQA_summary.jsonl", SlideVQA_corpus_ds, "SlideVQA_corpus")

In [7]:
import json
import numpy as np
from tqdm import tqdm

def encode_and_save_sparse_ret_embeddings(summary_path, dataset, dataset_name, output_dir="embeddings"):
    """Embed the ``'Description'`` field of each corpus item's structured summary.

    ``summary_path`` must hold one JSON object mapping corpus-id to a summary
    dict. The dicts presumably also carry 'Title', 'Keywords' and 'Image Type'
    (earlier experiments in this cell used them); only 'Description' is
    embedded here. Items with no summary, or with no 'Description', are
    embedded as ``""`` so rows stay aligned with dataset order, and they are
    counted in a warning.

    Output: ``{output_dir}/{dataset_name}_summary_embeddings.npy``. This is the
    same file name ``encode_and_save_summary_embeddings`` writes, so the two
    overwrite each other for the same ``dataset_name``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    import os

    with open(summary_path, 'r', encoding='utf-8') as f:
        summaries = json.load(f)

    if len(dataset) == 0:
        raise ValueError(f"{dataset_name}: dataset is empty, nothing to encode")

    image_summary_embeddings = []
    missing = 0
    for example in tqdm(dataset):
        # Fall back to {} (the original default was "") so a corpus-id with no
        # summary no longer crashes with "string indices must be integers".
        summary = summaries.get(example['corpus-id']) or {}
        description = summary.get('Description')
        if description is None:
            missing += 1
            description = ""
        image_summary_embeddings.append(encode([description]))

    if missing:
        print(f"Warning: {missing}/{len(dataset)} {dataset_name} items had no 'Description'; embedded as empty text.")

    image_summary_embeddings = np.vstack(image_summary_embeddings)
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, f"{dataset_name}_summary_embeddings.npy"), image_summary_embeddings)

# encode_and_save_summary_embeddings("./MP_DocVQA_summary.jsonl", MP_DocVQA_corpus_ds, "MP_DocVQA_corpus")
encode_and_save_sparse_ret_embeddings("./ChartQA_image_keywords.json", ChartQA_corpus_ds, "ChartQA_corpus")

100%|██████████| 500/500 [07:45<00:00,  1.07it/s]


In [11]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
LOG_DATEFMT = "%m/%d/%Y %H:%M:%S"

# Configure the root logger at INFO level. basicConfig is a no-op when the
# root logger already has handlers (e.g. when the cell is re-run).
logging.basicConfig(format=LOG_FORMAT, datefmt=LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger(__name__)


  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff manually.

    For every query in ``qrel`` that also appears in ``run``, the retrieved
    documents are ranked by descending score. The query's reciprocal rank is
    1/rank of the first document with a positive label in ``qrel``, or 0 if no
    such document appears in the top ``cutoff`` (or at all, when ``cutoff``
    is None). Queries missing from ``run`` are skipped, not counted as 0.

    Args:
        qrel: {query_id: {doc_id: relevance_int}}.
        run: {query_id: {doc_id: score}}.
        cutoff: number of top-ranked documents to inspect; None means all.

    Returns:
        {query_id: reciprocal_rank, ..., "all": mean reciprocal rank}.
        "all" is 0.0 when no query from ``qrel`` was ranked.
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, relevant in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
        if cutoff is not None:
            ranked = ranked[:max(cutoff, 0)]
        # Reset per query. rr used to be assigned only inside the loop, so an
        # empty ranking reused the previous query's value (or raised NameError).
        rr = 0.0
        for rank, (docid, _) in enumerate(ranked, start=1):
            if relevant.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    # Guard against ZeroDivisionError when run and qrel share no query.
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results


def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, k=10):
    """Dense top-k retrieval by dot product, then log nDCG@k, Recall@k and MRR@k.

    Args:
        query_embeddings: 2-D array, one row per query, row-aligned with ``query_ids``.
        query_ids: query ids (str), matching the outer keys of ``qrels``.
        corpus_embeddings: 2-D array, one row per document, row-aligned with ``corpus_ids``.
        corpus_ids: document ids (str), matching the inner keys of ``qrels``.
        qrels: {query_id: {corpus_id: relevance_int}}, as from ``load_beir_qrels``.
        k: retrieval depth and metric cutoff (default 10, the old hard-coded value).

    Scores are raw dot products. They equal cosine similarity only for
    L2-normalised embeddings (``encode`` normalises; presumably every .npy
    evaluated here came from it; confirm for externally produced files).

    Returns nothing; results are logged. Any exception is logged with its
    traceback instead of propagating, so a loop over datasets keeps going.
    """
    try:
        if len(query_ids) != len(query_embeddings):
            raise ValueError(
                f"{len(query_embeddings)} query embeddings but {len(query_ids)} query ids"
            )

        run = {}
        for qid, q_emb in zip(query_ids, query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # Indices of the k highest scores, best first.
            top_k_indices = np.argsort(scores)[::-1][:k]
            run[qid] = {corpus_ids[idx]: float(scores[idx]) for idx in top_k_indices}

        # Evaluation
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{k}", f"recall.{k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            logger.error("pytrec_eval returned no per-query results; check that query ids match the qrels.")
            return

        # Aggregate each per-query measure over all evaluated queries.
        first_query_measures = next(iter(eval_results.values()))
        for measure in sorted(first_query_measures):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr_at_k = eval_mrr(qrels, run, k)['all']
        logger.info(f'MRR@{k}: {mrr_at_k}')
    except Exception as e:
        # logger.exception keeps the traceback that logger.error used to drop.
        logger.exception(f"Error during retrieval and evaluation: {e}")
  
def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into {query_id: {corpus_id: relevance}}.

    The file needs a header row containing ``query-id``, ``corpus-id`` and
    ``score`` (tab-separated); ``score`` is parsed as int. Rows for the same
    query are merged, and a repeated (query, corpus) pair keeps the last score.

    On any error the problem is logged with its traceback, and whatever was
    parsed so far is returned. Callers get a possibly empty or partial dict
    rather than an exception.
    """
    qrels = {}
    try:
        # newline="" as the csv module requires. Explicit UTF-8 so parsing does
        # not depend on the platform's default locale encoding.
        with open(qrels_file, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                # Parse the score before inserting, so a bad row adds nothing.
                qid, pid, rel = row["query-id"], row["corpus-id"], int(row["score"])
                qrels.setdefault(qid, {})[pid] = rel
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding matrix and its row-aligned id array from .npy files.

    The ids are cast to ``str`` so they match the string keys that
    ``load_beir_qrels`` reads from the qrels TSV.

    Raises:
        ValueError: if the number of ids differs from the number of embedding
            rows. A mismatch would silently attribute scores to the wrong ids.
    """
    embeddings = np.load(embeddings_path)
    ids = np.load(ids_path).astype(str)
    if len(embeddings) != len(ids):
        raise ValueError(
            f"{embeddings_path} has {len(embeddings)} rows but {ids_path} has {len(ids)} ids"
        )
    return embeddings, ids

# Evaluation configurations, one dict per benchmark. The loop that follows
# evaluates every ACTIVE (uncommented) entry; comment/uncomment entries to
# choose benchmarks. Each entry names:
#   - query embeddings (.npy) plus the row-aligned query ids (.npy),
#   - corpus embeddings (.npy) plus the row-aligned corpus ids (.npy),
#   - the BEIR-style qrels TSV used as ground truth.
# The ChartQA entry points at summary embeddings rather than image embeddings.
# For PlotQA, alternative corpus embeddings (raw image, grayscale,
# image+summary, image+summary+title) are kept as commented-out lines;
# exactly one "corpus_embeddings_path" must be active.
# NOTE: the name `datasets` matches the Hugging Face package name; only
# `load_dataset` is imported from that package, so nothing is shadowed today.
datasets = [
    # {
    #     "name": "SlideVQA",
    #     "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "MP_DocVQA",
    #     "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    # },
    # {
    #     "name": "ArxivQA",
    #     "query_embeddings_path": "embeddings/ArxivQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ArxivQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ArxivQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/ArxivQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "ChartQA",
    #     "query_embeddings_path": "embeddings/ChartQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ChartQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ChartQA_corpus_summary_embeddings.npy",
    #     "corpus_ids_path": "embeddings/ChartQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "InfoVQA",
    #     "query_embeddings_path": "embeddings/InfoVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/InfoVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/InfoVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/InfoVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv"
    # },
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        # "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings.npy",
        # "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings_grayscale.npy",
        # "corpus_embeddings_path": "embeddings/PlotQA_corpus_image_plus_summary_embeddings.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_image_plus_summary_plus_title_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each configured dataset in turn, with a blank log line between them.
for dataset_cfg in datasets:
    logger.info(f"Evaluating {dataset_cfg['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(
        dataset_cfg["query_embeddings_path"], dataset_cfg["query_ids_path"]
    )
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(
        dataset_cfg["corpus_embeddings_path"], dataset_cfg["corpus_ids_path"]
    )
    qrels = load_beir_qrels(dataset_cfg["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    logger.info('')


01/09/2025 22:37:10 - INFO - __main__ -   Evaluating PlotQA dataset
01/09/2025 22:37:34 - INFO - __main__ -   ndcg_cut_10              all     0.4446
01/09/2025 22:37:34 - INFO - __main__ -   recall_10                all     0.6031
01/09/2025 22:37:34 - INFO - __main__ -   MRR@10: 0.39496824554532173
01/09/2025 22:37:34 - INFO - __main__ -   


In [None]:
def encode_and_save_query_embeddings(dataset, dataset_name, output_dir="embeddings"):
    """Embed every raw query (no instruction prefix) and save embeddings and query ids.

    Writes two row-aligned files: ``{output_dir}/{dataset_name}_embeddings.npy``
    (one row per query, in dataset order) and
    ``{output_dir}/{dataset_name}_query_ids.npy`` (the matching ``'query-id'``
    values). ``output_dir`` is created if missing.

    Args:
        dataset: sized iterable of examples with ``'query'`` (str) and ``'query-id'`` fields.
        dataset_name: prefix for the output file names.
        output_dir: target directory; defaults to ``"embeddings"`` as before.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    import os

    total_examples = len(dataset)
    if total_examples == 0:
        raise ValueError(f"{dataset_name}: dataset is empty, nothing to encode")

    query_embeddings = []
    query_ids = []
    for i, example in enumerate(dataset):
        query_embeddings.append(encode([example['query']]))
        query_ids.append(example['query-id'])

        # Report progress every 10 examples and on the final one.
        if (i + 1) % 10 == 0 or (i + 1) == total_examples:
            progress = (i + 1) / total_examples * 100
            print(f"Processing {dataset_name}: {i + 1}/{total_examples} ({progress:.2f}%)")

    query_embeddings = np.vstack(query_embeddings)

    # Save the embeddings and the row-aligned query ids.
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, f"{dataset_name}_embeddings.npy"), query_embeddings)
    np.save(os.path.join(output_dir, f"{dataset_name}_query_ids.npy"), np.array(query_ids))

encode_and_save_query_embeddings(MP_DocVQA_queries_ds, "MP_DocVQA_queries")
encode_and_save_query_embeddings(SlideVQA_queries_ds, "SlideVQA_queries")


In [None]:
def encode_and_save_query_with_instruction_embeddings(dataset, dataset_name, output_dir="embeddings"):
    """Embed every query prefixed with the retrieval instruction and save the matrix.

    Writes ``{output_dir}/{dataset_name}_with_instruction_embeddings.npy``, one
    row per query in dataset order. Query ids are NOT written here. Evaluation
    reuses ``{dataset_name}_query_ids.npy`` from
    ``encode_and_save_query_embeddings``, so both must see the dataset in the
    same order.

    Args:
        dataset: sized iterable of examples with a ``'query'`` (str) field.
        dataset_name: prefix for the output file name.
        output_dir: target directory (created if missing); defaults to ``"embeddings"``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    import os

    INSTRUCTION = "Represent this query for retrieving relevant documents: "
    total_examples = len(dataset)
    if total_examples == 0:
        raise ValueError(f"{dataset_name}: dataset is empty, nothing to encode")

    query_embeddings = []
    for i, example in enumerate(dataset):
        query_embeddings.append(encode([INSTRUCTION + example['query']]))

        # Report progress every 10 examples and on the final one.
        if (i + 1) % 10 == 0 or (i + 1) == total_examples:
            progress = (i + 1) / total_examples * 100
            print(f"Processing {dataset_name}: {i + 1}/{total_examples} ({progress:.2f}%)")

    query_embeddings = np.vstack(query_embeddings)

    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, f"{dataset_name}_with_instruction_embeddings.npy"), query_embeddings)

encode_and_save_query_with_instruction_embeddings(MP_DocVQA_queries_ds, "MP_DocVQA_queries")
encode_and_save_query_with_instruction_embeddings(SlideVQA_queries_ds, "SlideVQA_queries")


In [9]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
LOG_DATEFMT = "%m/%d/%Y %H:%M:%S"

# Configure the root logger at INFO level. basicConfig is a no-op when the
# root logger already has handlers (e.g. when the cell is re-run).
logging.basicConfig(format=LOG_FORMAT, datefmt=LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger(__name__)

def check_dictionaries(qrels, run):
    """Validate the nested-dict shapes pytrec_eval expects.

    ``qrels`` must map each query id to a dict whose values are ``int``
    relevance labels. ``run`` must map each query id to a dict whose values
    are ``int`` or ``float`` scores. The first violation is logged and
    ``False`` is returned; otherwise a confirmation is logged and ``True``
    is returned.
    """
    # qrels: {qid: {docid: int}}
    for qid, judged in qrels.items():
        if not isinstance(judged, dict):
            logger.error(f"Qrels for query {qid} is not a dictionary.")
            return False
        for docid, score in judged.items():
            if isinstance(score, int):
                continue
            logger.error(f"Score for doc {docid} in query {qid} is not an integer.")
            return False

    # run: {qid: {docid: int | float}}
    for qid, ranked in run.items():
        if not isinstance(ranked, dict):
            logger.error(f"Run for query {qid} is not a dictionary.")
            return False
        for docid, score in ranked.items():
            if isinstance(score, (int, float)):
                continue
            logger.info(f"Query ID: {qid}, Doc ID: {docid}, Score: {score}, Type: {type(score)}")
            logger.error(f"Score for doc {docid} in query {qid} is not a number.")
            return False

    logger.info("Both qrels and run dictionaries are correctly formatted.")
    return True
  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff manually.

    For every query in ``qrel`` that also appears in ``run``, the retrieved
    documents are ranked by descending score. The query's reciprocal rank is
    1/rank of the first document with a positive label in ``qrel``, or 0 if no
    such document appears in the top ``cutoff`` (or at all, when ``cutoff``
    is None). Queries missing from ``run`` are skipped, not counted as 0.

    Args:
        qrel: {query_id: {doc_id: relevance_int}}.
        run: {query_id: {doc_id: score}}.
        cutoff: number of top-ranked documents to inspect; None means all.

    Returns:
        {query_id: reciprocal_rank, ..., "all": mean reciprocal rank}.
        "all" is 0.0 when no query from ``qrel`` was ranked.
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, relevant in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
        if cutoff is not None:
            ranked = ranked[:max(cutoff, 0)]
        # Reset per query. rr used to be assigned only inside the loop, so an
        # empty ranking reused the previous query's value (or raised NameError).
        rr = 0.0
        for rank, (docid, _) in enumerate(ranked, start=1):
            if relevant.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    # Guard against ZeroDivisionError when run and qrel share no query.
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Cosine similarity a.b / (|a| |b|); NaN (with a numpy warning) if either vector is all zeros."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, k=10):
    """Dense top-k retrieval by dot product, validate, then log nDCG@k, Recall@k and MRR@k.

    Args:
        query_embeddings: 2-D array, one row per query, row-aligned with ``query_ids``.
        query_ids: query ids (str), matching the outer keys of ``qrels``.
        corpus_embeddings: 2-D array, one row per document, row-aligned with ``corpus_ids``.
        corpus_ids: document ids (str), matching the inner keys of ``qrels``.
        qrels: {query_id: {corpus_id: relevance_int}}, as from ``load_beir_qrels``.
        k: retrieval depth and metric cutoff (default 10, the old hard-coded value).

    Scores are raw dot products, which equal cosine similarity only for
    L2-normalised embeddings. For unnormalised inputs, ``cosine_similarity``
    can be swapped in.

    Returns nothing; results are logged. If ``check_dictionaries`` rejects
    the inputs, evaluation is skipped. Any exception is logged with its
    traceback instead of propagating.
    """
    try:
        if len(query_ids) != len(query_embeddings):
            raise ValueError(
                f"{len(query_embeddings)} query embeddings but {len(query_ids)} query ids"
            )

        run = {}
        for qid, q_emb in zip(query_ids, query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # Indices of the k highest scores, best first.
            top_k_indices = np.argsort(scores)[::-1][:k]
            run[qid] = {corpus_ids[idx]: float(scores[idx]) for idx in top_k_indices}

        if not check_dictionaries(qrels, run):
            # The old code logged "Aborting" and then evaluated anyway.
            logger.error("Dictionary format error. Aborting evaluation.")
            return
        logger.info("Proceeding with evaluation.")

        # Evaluation
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{k}", f"recall.{k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            logger.error("pytrec_eval returned no per-query results; check that query ids match the qrels.")
            return

        # Aggregate each per-query measure over all evaluated queries.
        first_query_measures = next(iter(eval_results.values()))
        for measure in sorted(first_query_measures):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr_at_k = eval_mrr(qrels, run, k)['all']
        logger.info(f'MRR@{k}: {mrr_at_k}')
    except Exception as e:
        # logger.exception keeps the traceback that logger.error used to drop.
        logger.exception(f"Error during retrieval and evaluation: {e}")
  
def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into {query_id: {corpus_id: relevance}}.

    The file needs a header row containing ``query-id``, ``corpus-id`` and
    ``score`` (tab-separated); ``score`` is parsed as int. Rows for the same
    query are merged, and a repeated (query, corpus) pair keeps the last score.

    On any error the problem is logged with its traceback, and whatever was
    parsed so far is returned. Callers get a possibly empty or partial dict
    rather than an exception.
    """
    qrels = {}
    try:
        # newline="" as the csv module requires. Explicit UTF-8 so parsing does
        # not depend on the platform's default locale encoding.
        with open(qrels_file, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                # Parse the score before inserting, so a bad row adds nothing.
                qid, pid, rel = row["query-id"], row["corpus-id"], int(row["score"])
                qrels.setdefault(qid, {})[pid] = rel
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding matrix and its row-aligned id array from .npy files.

    The ids are cast to ``str`` so they match the string keys that
    ``load_beir_qrels`` reads from the qrels TSV.

    Raises:
        ValueError: if the number of ids differs from the number of embedding
            rows. A mismatch would silently attribute scores to the wrong ids.
    """
    embeddings = np.load(embeddings_path)
    ids = np.load(ids_path).astype(str)
    if len(embeddings) != len(ids):
        raise ValueError(
            f"{embeddings_path} has {len(embeddings)} rows but {ids_path} has {len(ids)} ids"
        )
    return embeddings, ids

# Evaluation configurations, one dict per benchmark. The loop that follows
# evaluates every ACTIVE (uncommented) entry; comment/uncomment entries to
# choose benchmarks. Each entry names:
#   - query embeddings (.npy) plus the row-aligned query ids (.npy),
#   - corpus embeddings (.npy) plus the row-aligned corpus ids (.npy),
#   - the BEIR-style qrels TSV used as ground truth.
# Here every corpus path points at the plain image embeddings.
# NOTE: the name `datasets` matches the Hugging Face package name; only
# `load_dataset` is imported from that package, so nothing is shadowed today.
datasets = [
    # {
    #     "name": "SlideVQA",
    #     "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "MP_DocVQA",
    #     "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    # },
    # {
    #     "name": "ArxivQA",
    #     "query_embeddings_path": "embeddings/ArxivQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ArxivQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ArxivQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/ArxivQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "ChartQA",
    #     "query_embeddings_path": "embeddings/ChartQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ChartQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ChartQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/ChartQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "InfoVQA",
    #     "query_embeddings_path": "embeddings/InfoVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/InfoVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/InfoVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/InfoVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv"
    # },
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each configured dataset in turn, with a blank log line between them.
for dataset_cfg in datasets:
    logger.info(f"Evaluating {dataset_cfg['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(
        dataset_cfg["query_embeddings_path"], dataset_cfg["query_ids_path"]
    )
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(
        dataset_cfg["corpus_embeddings_path"], dataset_cfg["corpus_ids_path"]
    )
    qrels = load_beir_qrels(dataset_cfg["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    logger.info('')


12/17/2024 15:01:28 - INFO - __main__ -   Evaluating PlotQA dataset
12/17/2024 15:01:31 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/17/2024 15:01:31 - INFO - __main__ -   Proceeding with evaluation.
12/17/2024 15:01:31 - INFO - __main__ -   ndcg_cut_10              all     0.4524
12/17/2024 15:01:31 - INFO - __main__ -   recall_10                all     0.6140
12/17/2024 15:01:31 - INFO - __main__ -   MRR@10: 0.4018738567624206
12/17/2024 15:01:31 - INFO - __main__ -   


In [6]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
import json
  
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
LOG_DATEFMT = "%m/%d/%Y %H:%M:%S"

# Configure the root logger at INFO level. basicConfig is a no-op when the
# root logger already has handlers (e.g. when the cell is re-run).
logging.basicConfig(format=LOG_FORMAT, datefmt=LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger(__name__)

def check_dictionaries(qrels, run):
    """Validate the nested-dict shapes pytrec_eval expects.

    ``qrels`` must map each query id to a dict whose values are ``int``
    relevance labels. ``run`` must map each query id to a dict whose values
    are ``int`` or ``float`` scores. The first violation is logged and
    ``False`` is returned; otherwise a confirmation is logged and ``True``
    is returned.
    """
    # qrels: {qid: {docid: int}}
    for qid, judged in qrels.items():
        if not isinstance(judged, dict):
            logger.error(f"Qrels for query {qid} is not a dictionary.")
            return False
        for docid, score in judged.items():
            if isinstance(score, int):
                continue
            logger.error(f"Score for doc {docid} in query {qid} is not an integer.")
            return False

    # run: {qid: {docid: int | float}}
    for qid, ranked in run.items():
        if not isinstance(ranked, dict):
            logger.error(f"Run for query {qid} is not a dictionary.")
            return False
        for docid, score in ranked.items():
            if isinstance(score, (int, float)):
                continue
            logger.info(f"Query ID: {qid}, Doc ID: {docid}, Score: {score}, Type: {type(score)}")
            logger.error(f"Score for doc {docid} in query {qid} is not a number.")
            return False

    logger.info("Both qrels and run dictionaries are correctly formatted.")
    return True
  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff manually.

    For every query in ``qrel`` that also appears in ``run``, the retrieved
    documents are ranked by descending score. The query's reciprocal rank is
    1/rank of the first document with a positive label in ``qrel``, or 0 if no
    such document appears in the top ``cutoff`` (or at all, when ``cutoff``
    is None). Queries missing from ``run`` are skipped, not counted as 0.

    Args:
        qrel: {query_id: {doc_id: relevance_int}}.
        run: {query_id: {doc_id: score}}.
        cutoff: number of top-ranked documents to inspect; None means all.

    Returns:
        {query_id: reciprocal_rank, ..., "all": mean reciprocal rank}.
        "all" is 0.0 when no query from ``qrel`` was ranked.
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, relevant in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
        if cutoff is not None:
            ranked = ranked[:max(cutoff, 0)]
        # Reset per query. rr used to be assigned only inside the loop, so an
        # empty ranking reused the previous query's value (or raised NameError).
        rr = 0.0
        for rank, (docid, _) in enumerate(ranked, start=1):
            if relevant.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    # Guard against ZeroDivisionError when run and qrel share no query.
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def save_result_json(run, qrels, dataset_name):
    """Write per-query retrieval results to ``<dataset_name>_results_with_score.json``.

    One record per query in ``run`` contains:

    - the query id;
    - the list of judged doc ids from ``qrels`` (empty if the query is absent);
    - the retrieved ``{doc_id: score}`` mapping.

    The file is written relative to the current working directory, and its
    name is printed.
    """
    out_path = f"{dataset_name}_results_with_score.json"
    records = [
        {
            "query_id": query,
            "true_image_id": list(qrels.get(query, {}).keys()),
            # Key spelling ("retrive") is kept as-is for existing consumers of the file.
            "retrive_image_id": retrieved,
        }
        for query, retrieved in run.items()
    ]
    with open(out_path, "w") as fh:
        json.dump(records, fh, indent=4)
    print(f"Results saved to {out_path}")

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, dataset_name, top_k=10):
    """Rank the corpus for each query by dot-product score, save the run and log metrics.

    Args:
        query_embeddings: query vectors; entry ``i`` belongs to ``query_ids[i]``.
        query_ids: query ids aligned by position with ``query_embeddings``.
        corpus_embeddings: document vectors (one row per document); row ``j``
            belongs to ``corpus_ids[j]``.
        corpus_ids: document ids aligned by position with ``corpus_embeddings``.
        qrels: ``{query_id: {doc_id: int relevance}}`` judgements.
        dataset_name: prefix of the JSON file written by ``save_result_json``.
        top_k: documents kept per query and cutoff for nDCG, recall and MRR
            (default 10, the previously hard-coded value).

    Returns:
        ``{measure: value}`` of aggregated metrics (pytrec_eval measures plus
        ``"MRR@<top_k>"``). Returns ``None`` when validation fails or an error is
        raised; the error is logged with its traceback.

    Note: scores are raw dot products. They equal cosine similarity only if the
    embeddings are L2-normalised -- presumably true for the saved .npy files,
    confirm.
    """
    try:
        # Ids are matched to embeddings by position, so the counts must agree.
        if len(query_ids) != len(query_embeddings) or len(corpus_ids) != len(corpus_embeddings):
            logger.error(
                f"Id/embedding count mismatch: {len(query_ids)} query ids vs {len(query_embeddings)} "
                f"query embeddings, {len(corpus_ids)} corpus ids vs {len(corpus_embeddings)} corpus embeddings."
            )
            return None

        run = {}
        for qid, q_emb in zip(query_ids, query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # argsort is ascending: reverse it and keep the top_k best documents.
            top_k_indices = np.argsort(scores)[::-1][:top_k]
            run[qid] = {corpus_ids[idx]: float(scores[idx]) for idx in top_k_indices}

        if not check_dictionaries(qrels, run):
            # Actually abort; previously this logged "Aborting" but evaluated anyway.
            logger.error("Dictionary format error. Aborting evaluation.")
            return None
        logger.info("Proceeding with evaluation.")

        save_result_json(run, qrels, dataset_name)

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            # Avoid an opaque StopIteration from next(iter(...)) below.
            logger.error("pytrec_eval returned no per-query results (no overlap between qrels and run?).")
            return None

        summary = {}
        for measure in sorted(next(iter(eval_results.values())).keys()):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            summary[measure] = value
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        summary[f"MRR@{top_k}"] = mrr
        logger.info(f'MRR@{top_k}: {mrr}')
        return summary
    except Exception as e:
        # Keep the notebook loop going, but record the full traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
        return None
  
def load_beir_qrels(qrels_file):
    """Load BEIR-style relevance judgements from a tab-separated file.

    The file needs a header row with ``query-id``, ``corpus-id`` and ``score``
    columns.

    Returns:
        ``{query_id: {corpus_id: int score}}``. If reading or parsing fails, the
        error is logged with its traceback, and whatever was parsed before the
        failure is returned.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires for correct line handling.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding array and its id array from two ``.npy`` files.

    Returns ``(embeddings, ids)``. ``ids`` is cast to a NumPy string array,
    presumably so it compares equal to the string keys read from the qrels TSV.
    """
    return np.load(embeddings_path), np.load(ids_path).astype(str)

# Datasets to evaluate. Each entry gives:
#   - the dataset name;
#   - paths to the precomputed query embeddings and query ids (.npy);
#   - paths to the corpus embeddings and corpus ids (.npy);
#   - the path to the BEIR-format qrels TSV.
# Commented-out entries are switched off; uncomment one to include it in the run.
datasets = [
    # {
    #     "name": "SlideVQA",
    #     "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "MP_DocVQA",
    #     "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    # },
    # {
    #     "name": "ArxivQA",
    #     "query_embeddings_path": "embeddings/ArxivQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ArxivQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ArxivQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/ArxivQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv"
    # },
    {
        "name": "ChartQA",
        "query_embeddings_path": "embeddings/ChartQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/ChartQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/ChartQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/ChartQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv"
    },
    # {
    #     "name": "InfoVQA",
    #     "query_embeddings_path": "embeddings/InfoVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/InfoVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/InfoVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/InfoVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv"
    # },
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each dataset in turn.
# NOTE: load_embeddings_and_ids has no error handling, so a missing .npy file
# raises here (outside retrieve_and_evaluate's try/except) and stops the
# remaining datasets. Loop variables stay in the kernel namespace afterwards.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    # Also saves the run to "<name>_results_with_score.json" via save_result_json.
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, dataset['name'])
    # Blank log line separating datasets in the output.
    logger.info('')


12/17/2024 18:07:30 - INFO - __main__ -   Evaluating ChartQA dataset


12/17/2024 18:07:30 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/17/2024 18:07:30 - INFO - __main__ -   Proceeding with evaluation.
12/17/2024 18:07:30 - INFO - __main__ -   ndcg_cut_10              all     0.6204
12/17/2024 18:07:30 - INFO - __main__ -   recall_10                all     0.7242
12/17/2024 18:07:30 - INFO - __main__ -   MRR@10: 0.5873922270858202
12/17/2024 18:07:30 - INFO - __main__ -   
12/17/2024 18:07:30 - INFO - __main__ -   Evaluating PlotQA dataset


Results saved to ChartQA_results_with_score.json


12/17/2024 18:07:33 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/17/2024 18:07:33 - INFO - __main__ -   Proceeding with evaluation.
12/17/2024 18:07:33 - INFO - __main__ -   ndcg_cut_10              all     0.4524
12/17/2024 18:07:33 - INFO - __main__ -   recall_10                all     0.6140
12/17/2024 18:07:33 - INFO - __main__ -   MRR@10: 0.4018738567624206
12/17/2024 18:07:33 - INFO - __main__ -   


Results saved to PlotQA_results_with_score.json


In [1]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
# Root logging: INFO level with timestamped lines. basicConfig is a no-op when
# the root logger already has handlers, so re-running this cell changes nothing.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
logger = logging.getLogger(__name__)

def check_dictionaries(qrels, run):
    """Check that ``qrels`` and ``run`` have the nested-dict layout used for evaluation.

    ``qrels`` must map every query id to a dict whose values are ints, and
    ``run`` must map every query id to a dict whose values are ints or floats.
    ``qrels`` is checked completely before ``run``. The first violation is
    logged and ``False`` is returned; if both pass, a confirmation is logged
    and ``True`` is returned.
    """
    # Pass 1: relevance judgements -- grades must be integers.
    for query_id, judged in qrels.items():
        if not isinstance(judged, dict):
            logger.error(f"Qrels for query {query_id} is not a dictionary.")
            return False
        for doc_id, grade in judged.items():
            if isinstance(grade, int):
                continue
            logger.error(f"Score for doc {doc_id} in query {query_id} is not an integer.")
            return False

    # Pass 2: retrieval run -- scores may be ints or floats.
    for query_id, ranked in run.items():
        if not isinstance(ranked, dict):
            logger.error(f"Run for query {query_id} is not a dictionary.")
            return False
        for doc_id, value in ranked.items():
            if isinstance(value, (int, float)):
                continue
            logger.info(f"Query ID: {query_id}, Doc ID: {doc_id}, Score: {value}, Type: {type(value)}")
            logger.error(f"Score for doc {doc_id} in query {query_id} is not a number.")
            return False

    logger.info("Both qrels and run dictionaries are correctly formatted.")
    return True
  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff (mean reciprocal rank) manually.

    Only queries present in both ``qrel`` and ``run`` are evaluated.

    Args:
        qrel: ``{query_id: {doc_id: relevance}}``; a document counts as relevant
            when its relevance is > 0.
        run: ``{query_id: {doc_id: score}}``; documents are ranked by descending
            score (ties keep the run's insertion order).
        cutoff: consider only the first ``cutoff`` ranked documents; ``None``
            means no cutoff.

    Returns:
        dict mapping each evaluated query id to its reciprocal rank (0.0 when no
        relevant document appears within the cutoff), plus ``"all"``: the mean
        over evaluated queries, or 0.0 when no query could be evaluated.
    """
    rr_sum = 0.0
    num_ranked_q = 0
    results = {}
    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranked = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        # Reset per query *before* the loop: resetting inside the loop meant an
        # empty ranking reused the previous query's value (or raised NameError).
        rr = 0.0
        for rank, (docid, _) in enumerate(ranked, start=1):
            if cutoff is not None and rank > cutoff:
                break  # nothing past the cutoff can count
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        rr_sum += rr
    # Guard against division by zero when no query overlaps.
    results["all"] = rr_sum / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, top_k=10):
    """Rank the corpus for each query by dot-product score and log retrieval metrics.

    Args:
        query_embeddings: query vectors; entry ``i`` belongs to ``query_ids[i]``.
        query_ids: query ids aligned by position with ``query_embeddings``.
        corpus_embeddings: document vectors (one row per document); row ``j``
            belongs to ``corpus_ids[j]``.
        corpus_ids: document ids aligned by position with ``corpus_embeddings``.
        qrels: ``{query_id: {doc_id: int relevance}}`` judgements.
        top_k: documents kept per query and cutoff for nDCG, recall and MRR
            (default 10, the previously hard-coded value).

    Returns:
        ``{measure: value}`` of aggregated metrics (pytrec_eval measures plus
        ``"MRR@<top_k>"``). Returns ``None`` when validation fails or an error is
        raised; the error is logged with its traceback.

    Note: scores are raw dot products. They equal cosine similarity only if the
    embeddings are L2-normalised -- presumably true for the saved .npy files,
    confirm.
    """
    try:
        # Ids are matched to embeddings by position, so the counts must agree.
        if len(query_ids) != len(query_embeddings) or len(corpus_ids) != len(corpus_embeddings):
            logger.error(
                f"Id/embedding count mismatch: {len(query_ids)} query ids vs {len(query_embeddings)} "
                f"query embeddings, {len(corpus_ids)} corpus ids vs {len(corpus_embeddings)} corpus embeddings."
            )
            return None

        run = {}
        for qid, q_emb in zip(query_ids, query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # argsort is ascending: reverse it and keep the top_k best documents.
            top_k_indices = np.argsort(scores)[::-1][:top_k]
            run[qid] = {corpus_ids[idx]: float(scores[idx]) for idx in top_k_indices}

        if not check_dictionaries(qrels, run):
            # Actually abort; previously this logged "Aborting" but evaluated anyway.
            logger.error("Dictionary format error. Aborting evaluation.")
            return None
        logger.info("Proceeding with evaluation.")

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            # Avoid an opaque StopIteration from next(iter(...)) below.
            logger.error("pytrec_eval returned no per-query results (no overlap between qrels and run?).")
            return None

        summary = {}
        for measure in sorted(next(iter(eval_results.values())).keys()):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            summary[measure] = value
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        summary[f"MRR@{top_k}"] = mrr
        logger.info(f'MRR@{top_k}: {mrr}')
        return summary
    except Exception as e:
        # Keep the notebook loop going, but record the full traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
        return None
  
def load_beir_qrels(qrels_file):
    """Load BEIR-style relevance judgements from a tab-separated file.

    The file needs a header row with ``query-id``, ``corpus-id`` and ``score``
    columns.

    Returns:
        ``{query_id: {corpus_id: int score}}``. If reading or parsing fails, the
        error is logged with its traceback, and whatever was parsed before the
        failure is returned.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires for correct line handling.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding array and its id array from two ``.npy`` files.

    Returns ``(embeddings, ids)``. ``ids`` is cast to a NumPy string array,
    presumably so it compares equal to the string keys read from the qrels TSV.
    """
    return np.load(embeddings_path), np.load(ids_path).astype(str)

# Datasets to evaluate. Each entry gives:
#   - the dataset name;
#   - paths to the precomputed query embeddings and query ids (.npy);
#   - paths to the corpus embeddings and corpus ids (.npy);
#   - the path to the BEIR-format qrels TSV.
# Commented-out entries are switched off. Several entries here point at
# "*_corpus_image_plus_summary_embeddings*" corpus files. Judging by the name,
# these are image + summary embeddings -- confirm against the script that
# produced them.
datasets = [
    # {
    #     "name": "SlideVQA",
    #     "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/SlideVQA_corpus_image_plus_summary_embeddings_2.npy",
    #     "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "MP_DocVQA",
    #     "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    # },
    # {
    #     "name": "ArxivQA",
    #     "query_embeddings_path": "embeddings/ArxivQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ArxivQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ArxivQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/ArxivQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "ChartQA",
    #     "query_embeddings_path": "embeddings/ChartQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/ChartQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/ChartQA_corpus_image_plus_summary_embeddings_1.npy",
    #     "corpus_ids_path": "embeddings/ChartQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv"
    # },
    # {
    #     "name": "InfoVQA",
    #     "query_embeddings_path": "embeddings/InfoVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/InfoVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/InfoVQA_corpus_embeddings.npy",
    #     "corpus_ids_path": "embeddings/InfoVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv"
    # },
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_image_plus_summary_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each dataset in turn.
# NOTE: load_embeddings_and_ids has no error handling, so a missing .npy file
# raises here (outside retrieve_and_evaluate's try/except) and stops the
# remaining datasets. Loop variables stay in the kernel namespace afterwards.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    # Blank log line separating datasets in the output.
    logger.info('')


12/30/2024 18:52:02 - INFO - __main__ -   Evaluating PlotQA dataset
12/30/2024 18:52:29 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/30/2024 18:52:29 - INFO - __main__ -   Proceeding with evaluation.
12/30/2024 18:52:29 - INFO - __main__ -   ndcg_cut_10              all     0.4477
12/30/2024 18:52:29 - INFO - __main__ -   recall_10                all     0.6061
12/30/2024 18:52:30 - INFO - __main__ -   MRR@10: 0.3982066524319113
12/30/2024 18:52:30 - INFO - __main__ -   


In [1]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
# Root logging: INFO level with timestamped lines. basicConfig is a no-op when
# the root logger already has handlers, so re-running this cell changes nothing.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
logger = logging.getLogger(__name__)

def check_dictionaries(qrels, run):
    """Check that ``qrels`` and ``run`` have the nested-dict layout used for evaluation.

    ``qrels`` must map every query id to a dict whose values are ints, and
    ``run`` must map every query id to a dict whose values are ints or floats.
    ``qrels`` is checked completely before ``run``. The first violation is
    logged and ``False`` is returned; if both pass, a confirmation is logged
    and ``True`` is returned.
    """
    # Pass 1: relevance judgements -- grades must be integers.
    for query_id, judged in qrels.items():
        if not isinstance(judged, dict):
            logger.error(f"Qrels for query {query_id} is not a dictionary.")
            return False
        for doc_id, grade in judged.items():
            if isinstance(grade, int):
                continue
            logger.error(f"Score for doc {doc_id} in query {query_id} is not an integer.")
            return False

    # Pass 2: retrieval run -- scores may be ints or floats.
    for query_id, ranked in run.items():
        if not isinstance(ranked, dict):
            logger.error(f"Run for query {query_id} is not a dictionary.")
            return False
        for doc_id, value in ranked.items():
            if isinstance(value, (int, float)):
                continue
            logger.info(f"Query ID: {query_id}, Doc ID: {doc_id}, Score: {value}, Type: {type(value)}")
            logger.error(f"Score for doc {doc_id} in query {query_id} is not a number.")
            return False

    logger.info("Both qrels and run dictionaries are correctly formatted.")
    return True
  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff (mean reciprocal rank) manually.

    Only queries present in both ``qrel`` and ``run`` are evaluated.

    Args:
        qrel: ``{query_id: {doc_id: relevance}}``; a document counts as relevant
            when its relevance is > 0.
        run: ``{query_id: {doc_id: score}}``; documents are ranked by descending
            score (ties keep the run's insertion order).
        cutoff: consider only the first ``cutoff`` ranked documents; ``None``
            means no cutoff.

    Returns:
        dict mapping each evaluated query id to its reciprocal rank (0.0 when no
        relevant document appears within the cutoff), plus ``"all"``: the mean
        over evaluated queries, or 0.0 when no query could be evaluated.
    """
    rr_sum = 0.0
    num_ranked_q = 0
    results = {}
    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranked = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        # Reset per query *before* the loop: resetting inside the loop meant an
        # empty ranking reused the previous query's value (or raised NameError).
        rr = 0.0
        for rank, (docid, _) in enumerate(ranked, start=1):
            if cutoff is not None and rank > cutoff:
                break  # nothing past the cutoff can count
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        rr_sum += rr
    # Guard against division by zero when no query overlaps.
    results["all"] = rr_sum / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, top_k=11):
    """Rank the corpus for each query by dot-product score and log retrieval metrics.

    Args:
        query_embeddings: query vectors; entry ``i`` belongs to ``query_ids[i]``.
        query_ids: query ids aligned by position with ``query_embeddings``.
        corpus_embeddings: document vectors (one row per document); row ``j``
            belongs to ``corpus_ids[j]``.
        corpus_ids: document ids aligned by position with ``corpus_embeddings``.
        qrels: ``{query_id: {doc_id: int relevance}}`` judgements.
        top_k: documents kept per query and cutoff for nDCG, recall and MRR.
            The default of 11 preserves this cell's previous hard-coded value;
            the old inline comment claiming "top 5" was wrong.

    Returns:
        ``{measure: value}`` of aggregated metrics (pytrec_eval measures plus
        ``"MRR@<top_k>"``). Returns ``None`` when validation fails or an error is
        raised; the error is logged with its traceback.

    Note: scores are raw dot products. They equal cosine similarity only if the
    embeddings are L2-normalised -- presumably true for the saved .npy files,
    confirm.
    """
    try:
        # Ids are matched to embeddings by position, so the counts must agree.
        if len(query_ids) != len(query_embeddings) or len(corpus_ids) != len(corpus_embeddings):
            logger.error(
                f"Id/embedding count mismatch: {len(query_ids)} query ids vs {len(query_embeddings)} "
                f"query embeddings, {len(corpus_ids)} corpus ids vs {len(corpus_embeddings)} corpus embeddings."
            )
            return None

        run = {}
        for qid, q_emb in zip(query_ids, query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # argsort is ascending: reverse it and keep the top_k best documents.
            top_k_indices = np.argsort(scores)[::-1][:top_k]
            run[qid] = {corpus_ids[idx]: float(scores[idx]) for idx in top_k_indices}

        if not check_dictionaries(qrels, run):
            # Actually abort; previously this logged "Aborting" but evaluated anyway.
            logger.error("Dictionary format error. Aborting evaluation.")
            return None
        logger.info("Proceeding with evaluation.")

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            # Avoid an opaque StopIteration from next(iter(...)) below.
            logger.error("pytrec_eval returned no per-query results (no overlap between qrels and run?).")
            return None

        summary = {}
        for measure in sorted(next(iter(eval_results.values())).keys()):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            summary[measure] = value
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        summary[f"MRR@{top_k}"] = mrr
        logger.info(f'MRR@{top_k}: {mrr}')
        return summary
    except Exception as e:
        # Keep the notebook loop going, but record the full traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
        return None
  
def load_beir_qrels(qrels_file):
    """Load BEIR-style relevance judgements from a tab-separated file.

    The file needs a header row with ``query-id``, ``corpus-id`` and ``score``
    columns.

    Returns:
        ``{query_id: {corpus_id: int score}}``. If reading or parsing fails, the
        error is logged with its traceback, and whatever was parsed before the
        failure is returned.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires for correct line handling.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding array and its id array from two ``.npy`` files.

    Returns ``(embeddings, ids)``. ``ids`` is cast to a NumPy string array,
    presumably so it compares equal to the string keys read from the qrels TSV.
    """
    return np.load(embeddings_path), np.load(ids_path).astype(str)

# Previous, smaller configuration (SlideVQA + MP_DocVQA only), kept for reference.
# datasets = [
#     {
#         "name": "SlideVQA",
#         "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
#         "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
#         "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
#         "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
#         "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
#     },
#     {
#         "name": "MP_DocVQA",
#         "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
#         "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
#         "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
#         "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
#         "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
#     }
# ]

# All six test datasets. Each entry gives:
#   - the dataset name;
#   - paths to the precomputed query embeddings and query ids (.npy);
#   - paths to the corpus embeddings and corpus ids (.npy);
#   - the path to the BEIR-format qrels TSV.
datasets = [
    {
        "name": "SlideVQA",
        "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    },
    {
        "name": "MP_DocVQA",
        "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    },
    {
        "name": "ArxivQA",
        "query_embeddings_path": "embeddings/ArxivQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/ArxivQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/ArxivQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/ArxivQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv"
    },
    {
        "name": "ChartQA",
        "query_embeddings_path": "embeddings/ChartQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/ChartQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/ChartQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/ChartQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv"
    },
    {
        "name": "InfoVQA",
        "query_embeddings_path": "embeddings/InfoVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/InfoVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/InfoVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/InfoVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv"
    },
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each dataset in turn. This cell's retrieve_and_evaluate keeps 11
# documents per query and reports @11 metrics.
# NOTE: load_embeddings_and_ids has no error handling, so a missing .npy file
# raises here (outside retrieve_and_evaluate's try/except) and stops the
# remaining datasets.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    # Blank log line separating datasets in the output.
    logger.info('')


12/14/2024 15:34:45 - INFO - __main__ -   Evaluating SlideVQA dataset


12/14/2024 15:34:45 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/14/2024 15:34:45 - INFO - __main__ -   Proceeding with evaluation.
12/14/2024 15:34:45 - INFO - __main__ -   ndcg_cut_11              all     0.9152
12/14/2024 15:34:45 - INFO - __main__ -   recall_11                all     0.9704
12/14/2024 15:34:45 - INFO - __main__ -   MRR@11: 0.9176814970260092
12/14/2024 15:34:45 - INFO - __main__ -   
12/14/2024 15:34:45 - INFO - __main__ -   Evaluating MP_DocVQA dataset
12/14/2024 15:34:45 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/14/2024 15:34:45 - INFO - __main__ -   Proceeding with evaluation.
12/14/2024 15:34:45 - INFO - __main__ -   ndcg_cut_11              all     0.8144
12/14/2024 15:34:45 - INFO - __main__ -   recall_11                all     0.9329
12/14/2024 15:34:45 - INFO - __main__ -   MRR@11: 0.7765095069911457
12/14/2024 15:34:45 - INFO - __main__ -   
12/14/2024 15:34:45 - INFO - __main__ -

In [None]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
import matplotlib.pyplot as plt
  
# Root logging: INFO level with timestamped lines. basicConfig is a no-op when
# the root logger already has handlers, so re-running this cell changes nothing.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
logger = logging.getLogger(__name__)

def eval_mrr_with_wrong_cases(qrel, run, cutoff=None):
    """Compute MRR@cutoff and collect the queries that miss.

    Only queries present in both ``qrel`` and ``run`` are evaluated.

    Args:
        qrel: ``{query_id: {doc_id: relevance}}``; a document counts as relevant
            when its relevance is > 0.
        run: ``{query_id: {doc_id: score}}``; documents are ranked by descending
            score (ties keep the run's insertion order).
        cutoff: consider only the first ``cutoff`` ranked documents; ``None``
            means no cutoff.

    Returns:
        ``(results, wrong_cases)``:

        - ``results`` maps each evaluated query id to its reciprocal rank, plus
          ``"all"``, the mean over evaluated queries (0.0 when none were
          evaluated).
        - ``wrong_cases`` lists, in ``qrel`` order, the query ids with no
          relevant document within the cutoff (reciprocal rank 0.0).
    """
    rr_sum = 0.0
    num_ranked_q = 0
    results = {}
    wrong_cases = []

    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranked = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        rr = 0.0
        for rank, (docid, _) in enumerate(ranked, start=1):
            if cutoff is not None and rank > cutoff:
                break  # nothing past the cutoff can count
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        if rr == 0.0:
            wrong_cases.append(qid)
        results[qid] = rr
        rr_sum += rr

    # Guard against division by zero when no query overlaps.
    results["all"] = rr_sum / num_ranked_q if num_ranked_q else 0.0
    return results, wrong_cases

def _find_row(ds, key, value):
    """Return the first row of ``ds`` whose ``row[key] == value``, or ``None`` (linear scan)."""
    for row in ds:
        if row[key] == value:
            return row
    return None


def _show_image(image, title):
    """Display ``image`` in its own figure with ``title`` and hidden axes."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    ax.set_title(title)
    plt.show()


def _show_wrong_case(wrong, qrels, run, queries_ds, corpus_ds, num_display):
    """Print the query text, then show the expected images and the top retrieved images for one miss."""
    query_row = _find_row(queries_ds, 'query-id', wrong)
    if query_row is not None:
        print(f"Query: {query_row['query']}")

    print("Expected Images:")
    for doc_id in qrels[wrong]:
        doc = _find_row(corpus_ds, 'corpus-id', doc_id)
        if doc is not None:
            _show_image(doc['image'], f'Expected Image ID: {doc_id}')

    print(f"Retrieved Images: (with top {num_display} similarity)")
    ranked = sorted(run[wrong].items(), key=lambda item: item[1], reverse=True)
    for doc_id, score in ranked[:num_display]:
        doc = _find_row(corpus_ds, 'corpus-id', doc_id)
        if doc is not None:
            _show_image(doc['image'], f'Retrieved Image ID: {doc_id}, Score: {score}')


def retrieve_and_evaluate_with_wrong_cases(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels,
                                           queries_ds=None, corpus_ds=None, top_k=10, num_display=3):
    """Rank the corpus per query, log metrics, and visualise the queries that miss.

    Args:
        query_embeddings: query vectors; entry ``i`` belongs to ``query_ids[i]``.
        query_ids: query ids aligned by position with ``query_embeddings``.
        corpus_embeddings: document vectors (one row per document); row ``j``
            belongs to ``corpus_ids[j]``.
        corpus_ids: document ids aligned by position with ``corpus_embeddings``.
        qrels: ``{query_id: {doc_id: int relevance}}`` judgements.
        queries_ds: iterable of rows with ``'query-id'`` and ``'query'`` keys
            used to print the query text. ``None`` falls back to the notebook
            global ``SlideVQA_queries_ds`` (the previous, always-SlideVQA
            behaviour).
        corpus_ds: iterable of rows with ``'corpus-id'`` and ``'image'`` keys
            used to show images. ``None`` falls back to ``SlideVQA_corpus_ds``.
        top_k: documents kept per query and metric cutoff (default 10).
        num_display: number of top retrieved images shown per miss (default 3).

    Errors are logged with their traceback rather than raised. Returns ``None``.
    """
    try:
        # Ids are matched to embeddings by position, so the counts must agree.
        if len(query_ids) != len(query_embeddings) or len(corpus_ids) != len(corpus_embeddings):
            logger.error(
                f"Id/embedding count mismatch: {len(query_ids)} query ids vs {len(query_embeddings)} "
                f"query embeddings, {len(corpus_ids)} corpus ids vs {len(corpus_embeddings)} corpus embeddings."
            )
            return

        run = {}
        for qid, q_emb in zip(query_ids, query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # argsort is ascending: reverse it and keep the top_k best documents.
            top_k_indices = np.argsort(scores)[::-1][:top_k]
            run[qid] = {corpus_ids[idx]: float(scores[idx]) for idx in top_k_indices}

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            # Avoid an opaque StopIteration from next(iter(...)) below.
            logger.error("pytrec_eval returned no per-query results (no overlap between qrels and run?).")
            return

        for measure in sorted(next(iter(eval_results.values())).keys()):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr_results, wrong_cases = eval_mrr_with_wrong_cases(qrels, run, top_k)
        logger.info(f'MRR@{top_k}: {mrr_results["all"]}')

        if wrong_cases:
            # Backward-compatible default. The original always looked cases up in
            # the SlideVQA datasets, whichever dataset was being evaluated. Pass
            # queries_ds/corpus_ds explicitly for other datasets.
            if queries_ds is None:
                queries_ds = SlideVQA_queries_ds
            if corpus_ds is None:
                corpus_ds = SlideVQA_corpus_ds
            for wrong in wrong_cases:
                _show_wrong_case(wrong, qrels, run, queries_ds, corpus_ds, num_display)
    except Exception as e:
        # Keep the notebook loop going, but record the full traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
       
def load_beir_qrels(qrels_file):
    """Load BEIR-style relevance judgements from a tab-separated file.

    The file needs a header row with ``query-id``, ``corpus-id`` and ``score``
    columns.

    Returns:
        ``{query_id: {corpus_id: int score}}``. If reading or parsing fails, the
        error is logged with its traceback, and whatever was parsed before the
        failure is returned.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires for correct line handling.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
  
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding array and its id array from two ``.npy`` files.

    Returns ``(embeddings, ids)``. ``ids`` is cast to a NumPy string array,
    presumably so it compares equal to the string keys read from the qrels TSV.
    """
    return np.load(embeddings_path), np.load(ids_path).astype(str)

# Datasets used for the wrong-case analysis. Each entry gives:
#   - the dataset name;
#   - paths to the precomputed query embeddings and query ids (.npy);
#   - paths to the corpus embeddings and corpus ids (.npy);
#   - the path to the BEIR-format qrels TSV.
datasets = [
    {
        "name": "SlideVQA",
        "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    },
    {
        "name": "MP_DocVQA",
        "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    }
]

# Run the evaluation and wrong-case display on the datasets configured above.
# NOTE(review): no dataset-specific query/corpus datasets are passed here, so
# missed queries are looked up in the SlideVQA_queries_ds / SlideVQA_corpus_ds
# notebook globals. Those are defined elsewhere, and the lookup is used even
# for MP_DocVQA -- confirm this is intended.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate_with_wrong_cases(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    # Blank log line separating datasets in the output.
    logger.info('')


In [None]:
# use corpus summary to retrieve

import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
# Root logging configuration for the evaluation output. Per the logging docs, basicConfig
# does nothing if the root logger already has handlers, so only the first call in a kernel
# takes effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff (mean reciprocal rank) manually.

    Only queries that appear in both ``qrel`` and ``run`` are scored. Each query's documents
    are ranked by descending score (the sort is stable, so ties keep ``run`` insertion order).
    Its reciprocal rank is ``1 / rank`` of the first document with relevance > 0 within the
    first ``cutoff`` positions, or 0.0 if there is none.

    Args:
        qrel: {query_id: {doc_id: relevance}}.
        run: {query_id: {doc_id: score}}; a higher score means a better rank.
        cutoff: maximum rank considered; ``None`` considers every retrieved document.

    Returns:
        {query_id: reciprocal_rank} for every scored query, plus ``"all"``: the mean over the
        scored queries (0.0 when no query is scored instead of dividing by zero).
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranking = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        if cutoff is not None:
            ranking = ranking[:max(cutoff, 0)]
        # Reset per query so an empty run[qid] scores 0.0 instead of reusing the previous value.
        rr = 0.0
        for rank, (docid, _) in enumerate(ranking, start=1):
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Cosine similarity of ``a`` and ``b``: their dot product over the product of their L2 norms."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, top_k=10):
    """Dot-product retrieval followed by nDCG / recall / MRR evaluation; results are logged.

    For each query embedding, every corpus embedding is scored with a dot product (equal to
    cosine similarity only if both sides are L2-normalised; presumably true for these saved
    embeddings, so confirm). The ``top_k`` best documents form the run, which is evaluated with
    pytrec_eval (``ndcg_cut.<k>``, ``recall.<k>``) and with ``eval_mrr`` at the same cutoff.

    Args:
        query_embeddings: one row per query; row i belongs to ``query_ids[i]``.
        query_ids: ids parallel to ``query_embeddings``.
        corpus_embeddings: one row per document; row j belongs to ``corpus_ids[j]``.
        corpus_ids: ids parallel to ``corpus_embeddings``.
        qrels: {query_id: {doc_id: relevance}} as produced by ``load_beir_qrels``.
        top_k: documents retrieved per query, also used as the metric cutoff (default 10).
    """
    try:
        run = {}
        for q_idx, q_emb in enumerate(query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # scores = np.array([cosine_similarity(q_emb, c_emb) for c_emb in corpus_embeddings])
            top_k_indices = np.argsort(scores)[::-1][:top_k]  # highest scores first
            # pytrec_eval expects plain str keys and Python float scores.
            run[str(query_ids[q_idx])] = {str(corpus_ids[idx]): float(scores[idx]) for idx in top_k_indices}

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            logger.error("No per-query results: no query id is shared by qrels and the run.")
            return

        for measure in sorted(next(iter(eval_results.values()))):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        logger.info(f'MRR@{top_k}: {mrr}')
    except Exception as e:
        # Best effort: keep the caller's dataset loop going, but keep the traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
  
def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into ``{query_id: {corpus_id: relevance}}``.

    The file needs a tab-separated header with ``query-id``, ``corpus-id`` and ``score``
    columns; ``score`` is parsed with ``int``.

    Errors (missing file, malformed row, ...) are logged with a traceback, and whatever was
    parsed so far is returned. A dataset loop therefore keeps running.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires; explicit utf-8 avoids locale-dependent decoding.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding matrix and its parallel id array from two ``.npy`` files.

    The retrieval code indexes the ids in parallel with the embedding rows, so both files
    must describe the same number of items.

    Args:
        embeddings_path: path to the ``.npy`` file with the embeddings.
        ids_path: path to the ``.npy`` file with the ids.

    Returns:
        ``(embeddings, ids)``, with ``ids`` converted to a numpy string array.

    Raises:
        ValueError: if the number of embedding rows and the number of ids differ (this would
            otherwise silently misalign ids and vectors).
    """
    embeddings = np.load(embeddings_path)
    ids = np.load(ids_path).astype(str)
    if len(embeddings) != len(ids):
        raise ValueError(
            f"{embeddings_path} has {len(embeddings)} rows but {ids_path} has {len(ids)} ids"
        )
    return embeddings, ids

# Datasets for the "corpus summary" experiment: queries are matched against embeddings of
# per-image text summaries (the "*_corpus_summary_embeddings.npy" corpus paths).
# The trailing "# MRR@10: ..." comments record earlier results with alternative query embeddings.
# NOTE(review): the saved output of this cell reports MRR@10 0.6361631339527697. That equals the
# value recorded for the commented-out "SlideVQA_queries_embeddings.npy" variant, not the
# active llm_embedder one, so the output looks stale relative to this config. Re-run to refresh.
datasets = [
    {
        "name": "SlideVQA",
        # "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy", # MRR@10: 0.7713586430507162
        # "query_embeddings_path": "embeddings/SlideVQA_queries_embeddings.npy", # MRR@10: 0.6361631339527697
        "query_embeddings_path": "embeddings/SlideVQA_queries_embeddings_use_baai_llm_embedder.npy", # MRR@10: 0.7897086720867214
        "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/SlideVQA_corpus_summary_embeddings.npy",
        # "corpus_embeddings_path": "embeddings/SlideVQA_corpus_summary_embeddings_use_baai_llm_embedder.npy",
        "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    },
    # NOTE(review): the disabled MP_DocVQA corpus path below ends in ".npy.npy", which is likely
    # a typo. Confirm the file name before re-enabling this entry.
    # {
    #     "name": "MP_DocVQA",
    #     "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_summary_embeddings.npy.npy",
    #     "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    # },
    # NOTE(review): this entry uses the image corpus embeddings ("PlotQA_corpus_embeddings.npy"),
    # not summary embeddings. Confirm it belongs in the summary experiment.
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each dataset in turn.
# The loop variables remain defined at module level after the loop finishes.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    logger.info('')


12/15/2024 15:36:35 - INFO - __main__ -   Evaluating SlideVQA dataset
12/15/2024 15:36:36 - INFO - __main__ -   ndcg_cut_10              all     0.6494
12/15/2024 15:36:36 - INFO - __main__ -   recall_10                all     0.7728
12/15/2024 15:36:36 - INFO - __main__ -   MRR@10: 0.6361631339527697
12/15/2024 15:36:36 - INFO - __main__ -   


In [2]:
# use corpus image + summary to retrieve

import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
# Root logging configuration for the evaluation output. Per the logging docs, basicConfig
# does nothing if the root logger already has handlers, so only the first call in a kernel
# takes effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

def check_dictionaries(qrels, run):
    """Validate the nested dicts before they are handed to pytrec_eval.

    ``qrels`` must map each query id to a dict of ``int`` relevance labels. ``run`` must map
    each query id to a dict of ``int``/``float`` scores. The first violation is logged and
    ``False`` is returned. Otherwise a success message is logged and ``True`` is returned.
    """
    # qrels: per-query dicts of integer labels.
    for qid, judged in qrels.items():
        if not isinstance(judged, dict):
            logger.error(f"Qrels for query {qid} is not a dictionary.")
            return False
        offender = next(((d, s) for d, s in judged.items() if not isinstance(s, int)), None)
        if offender is not None:
            docid, _ = offender
            logger.error(f"Score for doc {docid} in query {qid} is not an integer.")
            return False

    # run: per-query dicts of numeric scores.
    for qid, retrieved in run.items():
        if not isinstance(retrieved, dict):
            logger.error(f"Run for query {qid} is not a dictionary.")
            return False
        offender = next(((d, s) for d, s in retrieved.items() if not isinstance(s, (int, float))), None)
        if offender is not None:
            docid, score = offender
            logger.info(f"Query ID: {qid}, Doc ID: {docid}, Score: {score}, Type: {type(score)}")
            logger.error(f"Score for doc {docid} in query {qid} is not a number.")
            return False

    logger.info("Both qrels and run dictionaries are correctly formatted.")
    return True
  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff (mean reciprocal rank) manually.

    Only queries that appear in both ``qrel`` and ``run`` are scored. Each query's documents
    are ranked by descending score (the sort is stable, so ties keep ``run`` insertion order).
    Its reciprocal rank is ``1 / rank`` of the first document with relevance > 0 within the
    first ``cutoff`` positions, or 0.0 if there is none.

    Args:
        qrel: {query_id: {doc_id: relevance}}.
        run: {query_id: {doc_id: score}}; a higher score means a better rank.
        cutoff: maximum rank considered; ``None`` considers every retrieved document.

    Returns:
        {query_id: reciprocal_rank} for every scored query, plus ``"all"``: the mean over the
        scored queries (0.0 when no query is scored instead of dividing by zero).
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranking = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        if cutoff is not None:
            ranking = ranking[:max(cutoff, 0)]
        # Reset per query so an empty run[qid] scores 0.0 instead of reusing the previous value.
        rr = 0.0
        for rank, (docid, _) in enumerate(ranking, start=1):
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Cosine similarity of ``a`` and ``b``: their dot product over the product of their L2 norms."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, top_k=10):
    """Dot-product retrieval, format check, then nDCG / recall / MRR evaluation; results are logged.

    For each query embedding, every corpus embedding is scored with a dot product (equal to
    cosine similarity only if both sides are L2-normalised; presumably true for these saved
    embeddings, so confirm). The ``top_k`` best documents form the run. The run is validated
    with ``check_dictionaries``; on failure evaluation is actually aborted. Otherwise it is
    evaluated with pytrec_eval (``ndcg_cut.<k>``, ``recall.<k>``) and ``eval_mrr`` at the same cutoff.

    Args:
        query_embeddings: one row per query; row i belongs to ``query_ids[i]``.
        query_ids: ids parallel to ``query_embeddings``.
        corpus_embeddings: one row per document; row j belongs to ``corpus_ids[j]``.
        corpus_ids: ids parallel to ``corpus_embeddings``.
        qrels: {query_id: {doc_id: relevance}} as produced by ``load_beir_qrels``.
        top_k: documents retrieved per query, also used as the metric cutoff (default 10).
    """
    try:
        run = {}
        for q_idx, q_emb in enumerate(query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # scores = np.array([cosine_similarity(q_emb, c_emb) for c_emb in corpus_embeddings])
            top_k_indices = np.argsort(scores)[::-1][:top_k]  # highest scores first
            # pytrec_eval expects plain str keys and Python float scores.
            run[str(query_ids[q_idx])] = {str(corpus_ids[idx]): float(scores[idx]) for idx in top_k_indices}

        if not check_dictionaries(qrels, run):
            logger.error("Dictionary format error. Aborting evaluation.")
            return
        logger.info("Proceeding with evaluation.")

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            logger.error("No per-query results: no query id is shared by qrels and the run.")
            return

        for measure in sorted(next(iter(eval_results.values()))):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        logger.info(f'MRR@{top_k}: {mrr}')
    except Exception as e:
        # Best effort: keep the caller's dataset loop going, but keep the traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
  
def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into ``{query_id: {corpus_id: relevance}}``.

    The file needs a tab-separated header with ``query-id``, ``corpus-id`` and ``score``
    columns; ``score`` is parsed with ``int``.

    Errors (missing file, malformed row, ...) are logged with a traceback, and whatever was
    parsed so far is returned. A dataset loop therefore keeps running.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires; explicit utf-8 avoids locale-dependent decoding.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding matrix and its parallel id array from two ``.npy`` files.

    The retrieval code indexes the ids in parallel with the embedding rows, so both files
    must describe the same number of items.

    Args:
        embeddings_path: path to the ``.npy`` file with the embeddings.
        ids_path: path to the ``.npy`` file with the ids.

    Returns:
        ``(embeddings, ids)``, with ``ids`` converted to a numpy string array.

    Raises:
        ValueError: if the number of embedding rows and the number of ids differ (this would
            otherwise silently misalign ids and vectors).
    """
    embeddings = np.load(embeddings_path)
    ids = np.load(ids_path).astype(str)
    if len(embeddings) != len(ids):
        raise ValueError(
            f"{embeddings_path} has {len(embeddings)} rows but {ids_path} has {len(ids)} ids"
        )
    return embeddings, ids

# Datasets for the "image + summary" experiment: the SlideVQA corpus is represented by the
# "SlideVQA_corpus_image_plus_summary_embeddings.npy" embeddings.
datasets = [
    {
        "name": "SlideVQA",
        "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/SlideVQA_corpus_image_plus_summary_embeddings.npy",
        "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    },
    # NOTE(review): this disabled entry points at summary-only embeddings, and its path ends in
    # ".npy.npy" (likely a typo). Update both before re-enabling it for this experiment.
    # {
    #     "name": "MP_DocVQA",
    #     "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
    #     "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
    #     "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_summary_embeddings.npy.npy",
    #     "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
    #     "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    # }
]

# Evaluate each dataset in turn.
# The loop variables remain defined at module level after the loop finishes.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    logger.info('')


12/12/2024 07:38:28 - INFO - __main__ -   Evaluating SlideVQA dataset
12/12/2024 07:38:28 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/12/2024 07:38:28 - INFO - __main__ -   Proceeding with evaluation.
12/12/2024 07:38:28 - INFO - __main__ -   ndcg_cut_10              all     0.9164
12/12/2024 07:38:28 - INFO - __main__ -   recall_10                all     0.9703
12/12/2024 07:38:28 - INFO - __main__ -   MRR@10: 0.9202281746031743
12/12/2024 07:38:28 - INFO - __main__ -   


In [39]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
  
# Root logging configuration for the evaluation output. Per the logging docs, basicConfig
# does nothing if the root logger already has handlers, so only the first call in a kernel
# takes effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

def check_dictionaries(qrels, run):
    """Validate the nested dicts before they are handed to pytrec_eval.

    ``qrels`` must map each query id to a dict of ``int`` relevance labels. ``run`` must map
    each query id to a dict of ``int``/``float`` scores. The first violation is logged and
    ``False`` is returned. Otherwise a success message is logged and ``True`` is returned.
    """
    # qrels: per-query dicts of integer labels.
    for qid, judged in qrels.items():
        if not isinstance(judged, dict):
            logger.error(f"Qrels for query {qid} is not a dictionary.")
            return False
        offender = next(((d, s) for d, s in judged.items() if not isinstance(s, int)), None)
        if offender is not None:
            docid, _ = offender
            logger.error(f"Score for doc {docid} in query {qid} is not an integer.")
            return False

    # run: per-query dicts of numeric scores.
    for qid, retrieved in run.items():
        if not isinstance(retrieved, dict):
            logger.error(f"Run for query {qid} is not a dictionary.")
            return False
        offender = next(((d, s) for d, s in retrieved.items() if not isinstance(s, (int, float))), None)
        if offender is not None:
            docid, score = offender
            logger.info(f"Query ID: {qid}, Doc ID: {docid}, Score: {score}, Type: {type(score)}")
            logger.error(f"Score for doc {docid} in query {qid} is not a number.")
            return False

    logger.info("Both qrels and run dictionaries are correctly formatted.")
    return True
  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff (mean reciprocal rank) manually.

    Only queries that appear in both ``qrel`` and ``run`` are scored. Each query's documents
    are ranked by descending score (the sort is stable, so ties keep ``run`` insertion order).
    Its reciprocal rank is ``1 / rank`` of the first document with relevance > 0 within the
    first ``cutoff`` positions, or 0.0 if there is none.

    Args:
        qrel: {query_id: {doc_id: relevance}}.
        run: {query_id: {doc_id: score}}; a higher score means a better rank.
        cutoff: maximum rank considered; ``None`` considers every retrieved document.

    Returns:
        {query_id: reciprocal_rank} for every scored query, plus ``"all"``: the mean over the
        scored queries (0.0 when no query is scored instead of dividing by zero).
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranking = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        if cutoff is not None:
            ranking = ranking[:max(cutoff, 0)]
        # Reset per query so an empty run[qid] scores 0.0 instead of reusing the previous value.
        rr = 0.0
        for rank, (docid, _) in enumerate(ranking, start=1):
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results


def cosine_similarity(a, b):
    """Cosine similarity of ``a`` and ``b``: their dot product over the product of their L2 norms."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / norm_product

def retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, top_k=1):
    """Top-k dot-product retrieval (top-1 by default), format check, then evaluation; results are logged.

    For each query embedding, every corpus embedding is scored with a dot product (equal to
    cosine similarity only if both sides are L2-normalised; presumably true for these saved
    embeddings, so confirm). The ``top_k`` best documents form the run. The run is validated
    with ``check_dictionaries``; on failure evaluation is actually aborted. Otherwise it is
    evaluated with pytrec_eval (``ndcg_cut.<k>``, ``recall.<k>``) and ``eval_mrr`` at the same cutoff.

    Args:
        query_embeddings: one row per query; row i belongs to ``query_ids[i]``.
        query_ids: ids parallel to ``query_embeddings``.
        corpus_embeddings: one row per document; row j belongs to ``corpus_ids[j]``.
        corpus_ids: ids parallel to ``corpus_embeddings``.
        qrels: {query_id: {doc_id: relevance}} as produced by ``load_beir_qrels``.
        top_k: documents retrieved per query, also used as the metric cutoff (default 1).
    """
    try:
        run = {}
        for q_idx, q_emb in enumerate(query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            # scores = np.array([cosine_similarity(q_emb, c_emb) for c_emb in corpus_embeddings])
            top_k_indices = np.argsort(scores)[::-1][:top_k]  # highest scores first
            # pytrec_eval expects plain str keys and Python float scores.
            run[str(query_ids[q_idx])] = {str(corpus_ids[idx]): float(scores[idx]) for idx in top_k_indices}

        if not check_dictionaries(qrels, run):
            logger.error("Dictionary format error. Aborting evaluation.")
            return
        logger.info("Proceeding with evaluation.")

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            logger.error("No per-query results: no query id is shared by qrels and the run.")
            return

        for measure in sorted(next(iter(eval_results.values()))):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        logger.info(f'MRR@{top_k}: {mrr}')
    except Exception as e:
        # Best effort: keep the caller's dataset loop going, but keep the traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
  
def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into ``{query_id: {corpus_id: relevance}}``.

    The file needs a tab-separated header with ``query-id``, ``corpus-id`` and ``score``
    columns; ``score`` is parsed with ``int``.

    Errors (missing file, malformed row, ...) are logged with a traceback, and whatever was
    parsed so far is returned. A dataset loop therefore keeps running.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires; explicit utf-8 avoids locale-dependent decoding.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding matrix and its parallel id array from two ``.npy`` files.

    The retrieval code indexes the ids in parallel with the embedding rows, so both files
    must describe the same number of items.

    Args:
        embeddings_path: path to the ``.npy`` file with the embeddings.
        ids_path: path to the ``.npy`` file with the ids.

    Returns:
        ``(embeddings, ids)``, with ``ids`` converted to a numpy string array.

    Raises:
        ValueError: if the number of embedding rows and the number of ids differ (this would
            otherwise silently misalign ids and vectors).
    """
    embeddings = np.load(embeddings_path)
    ids = np.load(ids_path).astype(str)
    if len(embeddings) != len(ids):
        raise ValueError(
            f"{embeddings_path} has {len(embeddings)} rows but {ids_path} has {len(ids)} ids"
        )
    return embeddings, ids

# Datasets for the top-1 evaluation (image corpus embeddings).
datasets = [
    {
        "name": "SlideVQA",
        "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    },
    {
        "name": "MP_DocVQA",
        "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    }
]

# Evaluate each dataset in turn.
# NOTE(review): the loop variables (`query_embeddings`, `query_ids`, `corpus_embeddings`,
# `corpus_ids`, ...) stay defined after the loop and hold the LAST dataset's (MP_DocVQA) arrays.
# The next cell reads them as globals.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate(query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    logger.info('')


12/11/2024 08:50:21 - INFO - __main__ -   Evaluating SlideVQA dataset
12/11/2024 08:50:21 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/11/2024 08:50:21 - INFO - __main__ -   Proceeding with evaluation.
12/11/2024 08:50:21 - INFO - __main__ -   ndcg_cut_1               all     0.8756
12/11/2024 08:50:21 - INFO - __main__ -   recall_1                 all     0.7376
12/11/2024 08:50:21 - INFO - __main__ -   MRR@1: 0.875609756097561
12/11/2024 08:50:21 - INFO - __main__ -   
12/11/2024 08:50:21 - INFO - __main__ -   Evaluating MP_DocVQA dataset


12/11/2024 08:50:21 - INFO - __main__ -   Both qrels and run dictionaries are correctly formatted.
12/11/2024 08:50:21 - INFO - __main__ -   Proceeding with evaluation.
12/11/2024 08:50:21 - INFO - __main__ -   ndcg_cut_1               all     0.6908
12/11/2024 08:50:21 - INFO - __main__ -   recall_1                 all     0.6908
12/11/2024 08:50:21 - INFO - __main__ -   MRR@1: 0.690792974986695
12/11/2024 08:50:21 - INFO - __main__ -   


In [16]:
# 2012-02-20fy11roadshow-120221022442-phpapp02_95__feb-20-2012-nestl-2011-fullyear-roadshow-presentation-5-1024.jpgquery_number_1

# analysisofkoreanwinemarket-20150902-daejeon-150829090424-lva1-app6891_95__analysis-of-korean-wine-market-20150902daejeon-19-1024.jpg

# Inspect one retrieval by hand: the dot-product score between a single query and a single
# corpus image.
# NOTE(review): this reads `query_ids`, `corpus_ids`, `query_embeddings` and `corpus_embeddings`
# left over from whichever dataset an evaluation loop processed last (hidden kernel state).
# These ids appear to be SlideVQA ids, but the loop in the previous cell ends on MP_DocVQA. The
# saved output comes from an earlier execution (In [16] ran before In [39]). Make sure the
# SlideVQA arrays are loaded before running this cell.
target_query_id = "2012-02-20fy11roadshow-120221022442-phpapp02_95__feb-20-2012-nestl-2011-fullyear-roadshow-presentation-5-1024.jpgquery_number_1"
q_position = np.where(query_ids == target_query_id)[0]
if q_position.size == 0:
    raise KeyError(f"Query id not found in query_ids (wrong dataset loaded?): {target_query_id}")

target_doc_id = "analysisofkoreanwinemarket-20150902-daejeon-150829090424-lva1-app6891_95__analysis-of-korean-wine-market-20150902daejeon-19-1024.jpg"
d_position = np.where(corpus_ids == target_doc_id)[0]
if d_position.size == 0:
    raise KeyError(f"Doc id not found in corpus_ids (wrong dataset loaded?): {target_doc_id}")

dot_product = np.dot(query_embeddings[q_position[0]], corpus_embeddings[d_position[0]])

dot_product

0.40375298

In [4]:
import csv  
import pytrec_eval  
import logging  
import numpy as np  
from collections import defaultdict
import json
import os
  
# Root logging configuration for the evaluation output. Per the logging docs, basicConfig
# does nothing if the root logger already has handlers, so only the first call in a kernel
# takes effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

  
def eval_mrr(qrel, run, cutoff=None):
    """Compute MRR@cutoff (mean reciprocal rank) manually.

    Only queries that appear in both ``qrel`` and ``run`` are scored. Each query's documents
    are ranked by descending score (the sort is stable, so ties keep ``run`` insertion order).
    Its reciprocal rank is ``1 / rank`` of the first document with relevance > 0 within the
    first ``cutoff`` positions, or 0.0 if there is none.

    Args:
        qrel: {query_id: {doc_id: relevance}}.
        run: {query_id: {doc_id: score}}; a higher score means a better rank.
        cutoff: maximum rank considered; ``None`` considers every retrieved document.

    Returns:
        {query_id: reciprocal_rank} for every scored query, plus ``"all"``: the mean over the
        scored queries (0.0 when no query is scored instead of dividing by zero).
    """
    results = {}
    total_rr = 0.0
    num_ranked_q = 0
    for qid, judged in qrel.items():
        if qid not in run:
            continue
        num_ranked_q += 1
        ranking = sorted(run[qid].items(), key=lambda x: x[1], reverse=True)
        if cutoff is not None:
            ranking = ranking[:max(cutoff, 0)]
        # Reset per query so an empty run[qid] scores 0.0 instead of reusing the previous value.
        rr = 0.0
        for rank, (docid, _) in enumerate(ranking, start=1):
            if judged.get(docid, 0) > 0:
                rr = 1.0 / rank
                break
        results[qid] = rr
        total_rr += rr
    results["all"] = total_rr / num_ranked_q if num_ranked_q else 0.0
    return results

def view_result(dataset_name, run, qrels, corpus_embeddings, corpus_ids, max_cases=None,
                output_dir='tmp/view_result'):
    """Dump embeddings of correctly / incorrectly ranked queries to JSON for offline inspection.

    A query is a "bad" result when its top-scoring retrieved document is not among its qrels
    documents, and a "good" result otherwise. Queries with an empty run or without qrels are
    skipped. For each stored query, the embeddings of the retrieved documents (``'run'``) and
    of the ground-truth documents (``'gt'``) are saved, keyed by document id. Ids that do not
    occur in ``corpus_ids`` are skipped.

    Output file: ``{output_dir}/view_result_{dataset_name}_all.json`` containing
    ``{'good_results': {qid: {'gt': {...}, 'run': {...}}}, 'bad_results': {...}}``.

    Args:
        dataset_name: used in the output file name.
        run: {query_id: {doc_id: score}}.
        qrels: {query_id: {doc_id: relevance}}.
        corpus_embeddings: one row per document, parallel to ``corpus_ids``.
        corpus_ids: document ids.
        max_cases: optional cap on the number of queries stored per bucket (good / bad).
            ``None`` (default) stores every query. A full bucket never spills into the other one.
        output_dir: directory for the JSON file (created if needed).
    """
    # Index each doc id to its first row once (the previous version did a linear scan per
    # document, i.e. O(queries x corpus)). Stops at the shorter input, like zip().
    row_of = {}
    for row, cid in zip(range(len(corpus_embeddings)), corpus_ids):
        row_of.setdefault(cid, row)

    def _embeddings_for(doc_ids):
        # JSON-serialisable embedding for every doc id present in the corpus.
        return {doc_id: corpus_embeddings[row_of[doc_id]].tolist()
                for doc_id in doc_ids if doc_id in row_of}

    good_results, bad_results = {}, {}
    for qid, retrieved in run.items():
        if not retrieved or qid not in qrels:
            continue  # nothing retrieved or no ground truth: cannot classify
        top_doc = max(retrieved, key=retrieved.get)
        bucket = good_results if top_doc in qrels[qid] else bad_results
        if max_cases is not None and len(bucket) >= max_cases:
            if len(good_results) >= max_cases and len(bad_results) >= max_cases:
                break
            continue
        bucket[qid] = {'gt': _embeddings_for(qrels[qid]), 'run': _embeddings_for(retrieved)}

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f'view_result_{dataset_name}_all.json')
    with open(out_path, 'w') as f:
        json.dump({'good_results': good_results, 'bad_results': bad_results}, f)
        
    
    

def retrieve_and_evaluate(dataset_name, query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels, top_k=10):
    """Dot-product retrieval, a JSON dump via ``view_result``, then nDCG / recall / MRR evaluation.

    For each query embedding, every corpus embedding is scored with a dot product (equal to
    cosine similarity only if both sides are L2-normalised; presumably true for these saved
    embeddings, so confirm). The ``top_k`` best documents form the run. The run is passed to
    ``view_result`` for inspection; a failure there is logged but does not skip the metrics.
    The run is then evaluated with pytrec_eval (``ndcg_cut.<k>``, ``recall.<k>``) and
    ``eval_mrr`` at the same cutoff. Results are logged.

    Args:
        dataset_name: label forwarded to ``view_result``.
        query_embeddings: one row per query; row i belongs to ``query_ids[i]``.
        query_ids: ids parallel to ``query_embeddings``.
        corpus_embeddings: one row per document; row j belongs to ``corpus_ids[j]``.
        corpus_ids: ids parallel to ``corpus_embeddings``.
        qrels: {query_id: {doc_id: relevance}} as produced by ``load_beir_qrels``.
        top_k: documents retrieved per query, also used as the metric cutoff (default 10).
    """
    try:
        run = {}
        for q_idx, q_emb in enumerate(query_embeddings):
            scores = np.dot(corpus_embeddings, q_emb)
            top_k_indices = np.argsort(scores)[::-1][:top_k]  # highest scores first
            # pytrec_eval expects plain str keys and Python float scores.
            run[str(query_ids[q_idx])] = {str(corpus_ids[idx]): float(scores[idx]) for idx in top_k_indices}

        # The JSON dump is a debugging aid; don't let a failure there skip the metrics.
        try:
            view_result(dataset_name, run, qrels, corpus_embeddings, corpus_ids)
        except Exception as e:
            logger.exception(f"Error while dumping retrieval results for {dataset_name}: {e}")

        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{top_k}", f"recall.{top_k}"})
        eval_results = evaluator.evaluate(run)
        if not eval_results:
            logger.error("No per-query results: no query id is shared by qrels and the run.")
            return

        for measure in sorted(next(iter(eval_results.values()))):
            value = pytrec_eval.compute_aggregated_measure(
                measure, [query_measures[measure] for query_measures in eval_results.values()]
            )
            logger.info(f"{measure:25s}{'all':8s}{value:.4f}")

        mrr = eval_mrr(qrels, run, top_k)['all']
        logger.info(f'MRR@{top_k}: {mrr}')
    except Exception as e:
        # Best effort: keep the caller's dataset loop going, but keep the traceback.
        logger.exception(f"Error during retrieval and evaluation: {e}")
  
def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into ``{query_id: {corpus_id: relevance}}``.

    The file needs a tab-separated header with ``query-id``, ``corpus-id`` and ``score``
    columns; ``score`` is parsed with ``int``.

    Errors (missing file, malformed row, ...) are logged with a traceback, and whatever was
    parsed so far is returned. A dataset loop therefore keeps running.
    """
    qrels = {}
    try:
        # newline="" is what the csv module requires; explicit utf-8 avoids locale-dependent decoding.
        with open(qrels_file, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f, delimiter="\t"):
                qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    except Exception as e:
        logger.exception(f"Error loading qrels file: {e}")
    return qrels
 
def load_embeddings_and_ids(embeddings_path, ids_path):
    """Load an embedding matrix and its parallel id array from two ``.npy`` files.

    The retrieval code indexes the ids in parallel with the embedding rows, so both files
    must describe the same number of items.

    Args:
        embeddings_path: path to the ``.npy`` file with the embeddings.
        ids_path: path to the ``.npy`` file with the ids.

    Returns:
        ``(embeddings, ids)``, with ``ids`` converted to a numpy string array.

    Raises:
        ValueError: if the number of embedding rows and the number of ids differ (this would
            otherwise silently misalign ids and vectors).
    """
    embeddings = np.load(embeddings_path)
    ids = np.load(ids_path).astype(str)
    if len(embeddings) != len(ids):
        raise ValueError(
            f"{embeddings_path} has {len(embeddings)} rows but {ids_path} has {len(ids)} ids"
        )
    return embeddings, ids

# All six VisRAG-Ret test sets evaluated with the image corpus embeddings. Each entry gives the
# saved query/corpus embedding matrices, their id arrays and the BEIR-style qrels TSV.
datasets = [
    {
        "name": "SlideVQA",
        "query_embeddings_path": "embeddings/SlideVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/SlideVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/SlideVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/SlideVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv"
    },
    {
        "name": "MP_DocVQA",
        "query_embeddings_path": "embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/MP_DocVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/MP_DocVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/MP_DocVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv"
    },
    {
        "name": "ArxivQA",
        "query_embeddings_path": "embeddings/ArxivQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/ArxivQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/ArxivQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/ArxivQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv"
    },
    {
        "name": "ChartQA",
        "query_embeddings_path": "embeddings/ChartQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/ChartQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/ChartQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/ChartQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv"
    },
    {
        "name": "InfoVQA",
        "query_embeddings_path": "embeddings/InfoVQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/InfoVQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/InfoVQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/InfoVQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv"
    },
    {
        "name": "PlotQA",
        "query_embeddings_path": "embeddings/PlotQA_queries_with_instruction_embeddings.npy",
        "query_ids_path": "embeddings/PlotQA_queries_query_ids.npy",
        "corpus_embeddings_path": "embeddings/PlotQA_corpus_embeddings.npy",
        "corpus_ids_path": "embeddings/PlotQA_corpus_corpus_ids.npy",
        "qrels_path": "dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv"
    }
]

# Evaluate each dataset in turn. The dataset name is passed through so that each dataset's
# JSON dump gets its own file name.
# The loop variables remain defined at module level after the loop finishes.
for dataset in datasets:
    logger.info(f"Evaluating {dataset['name']} dataset")
    query_embeddings, query_ids = load_embeddings_and_ids(dataset["query_embeddings_path"], dataset["query_ids_path"])
    corpus_embeddings, corpus_ids = load_embeddings_and_ids(dataset["corpus_embeddings_path"], dataset["corpus_ids_path"])
    qrels = load_beir_qrels(dataset["qrels_path"])
    retrieve_and_evaluate(dataset['name'], query_embeddings, query_ids, corpus_embeddings, corpus_ids, qrels)
    logger.info('')


01/06/2025 12:29:56 - INFO - __main__ -   Evaluating SlideVQA dataset
01/06/2025 12:30:42 - INFO - __main__ -   ndcg_cut_10              all     0.9146
01/06/2025 12:30:42 - INFO - __main__ -   recall_10                all     0.9686
01/06/2025 12:30:42 - INFO - __main__ -   MRR@10: 0.9176260646535037
01/06/2025 12:30:42 - INFO - __main__ -   
01/06/2025 12:30:42 - INFO - __main__ -   Evaluating MP_DocVQA dataset
01/06/2025 12:31:31 - INFO - __main__ -   ndcg_cut_10              all     0.8133
01/06/2025 12:31:31 - INFO - __main__ -   recall_10                all     0.9292
01/06/2025 12:31:31 - INFO - __main__ -   MRR@10: 0.7767030335284723
01/06/2025 12:31:31 - INFO - __main__ -   
01/06/2025 12:31:31 - INFO - __main__ -   Evaluating ArxivQA dataset
01/06/2025 12:36:37 - INFO - __main__ -   ndcg_cut_10              all     0.7000
01/06/2025 12:36:37 - INFO - __main__ -   recall_10                all     0.8042
01/06/2025 12:36:37 - INFO - __main__ -   MRR@10: 0.6669318599353316
01/06