In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO

def weighted_mean_pooling(hidden, attention_mask):
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps


def encode(text_or_image_list):
    
    if (isinstance(text_or_image_list[0], str)):
        inputs = {
            "text": text_or_image_list,
            'image': [None] * len(text_or_image_list),
            'tokenizer': tokenizer
        }
    else:
        inputs = {
            "text": [''] * len(text_or_image_list),
            'image': text_or_image_list,
            'tokenizer': tokenizer
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    reps = weighted_mean_pooling(hidden, attention_mask)   
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings
model_name_or_path = "./VisRAG-Ret"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
model.eval()

In [None]:
from datasets import load_dataset
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

ArxivQA_corpus_ds = load_dataset("dataset/VisRAG-Ret-Test-ArxivQA", name="corpus", split="train")
image_embeddings = []
i = 0
for item in tqdm(ArxivQA_corpus_ds):
    image_id = item['corpus-id']
    image = item['image']
    width, height = image.size
    # print(width, height)
    crop_width = width // 2
    crop_height = height // 2
    cropped_images = [
        image.crop((0, 0, crop_width, crop_height)).convert('RGB'),
        image.crop((crop_width, 0, width, crop_height)).convert('RGB'),
        image.crop((0, crop_height, crop_width, height)).convert('RGB'),
        image.crop((crop_width, crop_height, width, height)).convert('RGB')
    ]
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(cropped_images[i])
        ax.axis('off')
    plt.show()
    
    # embedding = encode(cropped_images)
    # # embedding2 = [encode([img]) for img in cropped_images]
    # # print(embedding)
    # image_embeddings.append(embedding)
    
    
image_embeddings = np.array(image_embeddings)
    
np.save(f"embeddings/PlotQA_corpus_embeddings_4x4.npy", image_embeddings)
    
    


In [None]:
import numpy as np

saved_embeddings = np.load('embeddings/SlideVQA_corpus_embeddings_4x4.npy')
saved_embeddings.shape


(9593, 16, 2304)

In [1]:
import numpy as np
import csv

local_image_embeddings = np.load('embeddings/ArxivQA_corpus_embeddings_4x4_0.2.npy')
image_embeddings = np.load('embeddings/ArxivQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/ArxivQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/ArxivQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/ArxivQA_queries_query_ids.npy')

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 

qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv')

# alpha = []

# for i in image_embeddings.shape[0]:
#     for j in image_embeddings.shape[1]:


In [1]:
def eval_mrr(qrel, run, cutoff=None):  
    """  
    Compute MRR@cutoff manually.  
    """  
    mrr = 0.0  
    num_ranked_q = 0  
    results = {}  
    for qid in qrel:  
        if qid not in run:  
            continue  
        num_ranked_q += 1  
        docid_and_score = [(docid, score) for docid, score in run[qid].items()]  
        docid_and_score.sort(key=lambda x: x[1], reverse=True)  
        for i, (docid, _) in enumerate(docid_and_score):  
            rr = 0.0  
            if cutoff is None or i < cutoff:  
                if docid in qrel[qid] and qrel[qid][docid] > 0:  
                    rr = 1.0 / (i + 1)  
                    break  
        results[qid] = rr  
        mrr += rr  
    mrr /= num_ranked_q  
    results["all"] = mrr  
    return results  

In [7]:
import pytrec_eval

evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10", "recall.10"})  
eval_results = evaluator.evaluate(run)  
  
for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
    value = pytrec_eval.compute_aggregated_measure(  
        measure, [query_measures[measure] for query_measures in eval_results.values()]  
    )  
    print(f"{measure:25s}{'all':8s}{value:.4f}")  
  
mrr_at_10 = eval_mrr(qrels, run, 10)['all']  
print(f'MRR@10: {mrr_at_10}')  

ndcg_cut_10              all     0.4269
recall_10                all     0.5912
MRR@10: 0.3756275435500701


In [18]:
# 探索实验 针对gamma进行调参
import torch
from tqdm import tqdm
import pytrec_eval

# gamma_list = [round(0.1 * i, 1) for i in range(11)]
gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()


results = {}
evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10", "recall.10"})  

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        scores_exp = torch.exp(scores)
        scores_sum = torch.sum(scores_exp, dim=1, keepdim=True)
        alpha = scores_exp / (scores_sum + 1e-8)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}

    eval_results = evaluator.evaluate(run)  
    
    for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
        value = pytrec_eval.compute_aggregated_measure(  
            measure, [query_measures[measure] for query_measures in eval_results.values()]  
        )  
        print(f"{measure:25s}{'all':8s}{value:.4f}")  
        # results[gamma][measure] = value
    
    mrr_at_10 = eval_mrr(qrels, run, 10)['all']  
    print(f'MRR@10: {mrr_at_10}')  
    
    # results[gamma]['mrr_at_10'] =  mrr_at_10

# import json
# with open('PlotQA_gamma_tuning.json', 'w') as f:
#     json.dump(results, f, indent=2)

gamma: 0.8


100%|██████████| 11307/11307 [00:56<00:00, 200.54it/s]


ndcg_cut_10              all     0.4521
recall_10                all     0.6149
MRR@10: 0.40120304039778676
gamma: 0.81


100%|██████████| 11307/11307 [00:56<00:00, 200.76it/s]


ndcg_cut_10              all     0.4522
recall_10                all     0.6148
MRR@10: 0.401392205418472
gamma: 0.82


100%|██████████| 11307/11307 [00:56<00:00, 200.55it/s]


ndcg_cut_10              all     0.4523
recall_10                all     0.6145
MRR@10: 0.4015224450087805
gamma: 0.83


100%|██████████| 11307/11307 [00:56<00:00, 200.96it/s]


ndcg_cut_10              all     0.4522
recall_10                all     0.6140
MRR@10: 0.4015948471307979
gamma: 0.84


100%|██████████| 11307/11307 [00:56<00:00, 200.91it/s]


ndcg_cut_10              all     0.4522
recall_10                all     0.6132
MRR@10: 0.4017886447642349
gamma: 0.85


100%|██████████| 11307/11307 [00:56<00:00, 200.05it/s]


ndcg_cut_10              all     0.4525
recall_10                all     0.6136
MRR@10: 0.40209460777913936
gamma: 0.86


100%|██████████| 11307/11307 [00:56<00:00, 200.29it/s]


ndcg_cut_10              all     0.4526
recall_10                all     0.6134
MRR@10: 0.40217139684505004
gamma: 0.87


100%|██████████| 11307/11307 [00:56<00:00, 200.86it/s]


ndcg_cut_10              all     0.4525
recall_10                all     0.6134
MRR@10: 0.40214075842889857
gamma: 0.88


100%|██████████| 11307/11307 [00:56<00:00, 200.68it/s]


ndcg_cut_10              all     0.4527
recall_10                all     0.6135
MRR@10: 0.4022689624772404
gamma: 0.89


100%|██████████| 11307/11307 [00:56<00:00, 200.66it/s]


ndcg_cut_10              all     0.4527
recall_10                all     0.6138
MRR@10: 0.40226703222192767
gamma: 0.9


100%|██████████| 11307/11307 [00:56<00:00, 200.73it/s]


ndcg_cut_10              all     0.4525
recall_10                all     0.6137
MRR@10: 0.40208435987820473
gamma: 0.91


100%|██████████| 11307/11307 [00:56<00:00, 199.40it/s]


ndcg_cut_10              all     0.4528
recall_10                all     0.6140
MRR@10: 0.4023343805845798
gamma: 0.92


100%|██████████| 11307/11307 [00:56<00:00, 200.59it/s]


ndcg_cut_10              all     0.4526
recall_10                all     0.6138
MRR@10: 0.4021318090633558
gamma: 0.93


100%|██████████| 11307/11307 [00:56<00:00, 200.57it/s]


ndcg_cut_10              all     0.4529
recall_10                all     0.6144
MRR@10: 0.40232634370336684
gamma: 0.94


100%|██████████| 11307/11307 [00:56<00:00, 200.92it/s]


ndcg_cut_10              all     0.4529
recall_10                all     0.6148
MRR@10: 0.4022827199332904
gamma: 0.95


100%|██████████| 11307/11307 [00:56<00:00, 200.81it/s]


ndcg_cut_10              all     0.4528
recall_10                all     0.6147
MRR@10: 0.4021365118672095
gamma: 0.96


100%|██████████| 11307/11307 [00:56<00:00, 200.72it/s]


ndcg_cut_10              all     0.4528
recall_10                all     0.6145
MRR@10: 0.40227429700101464
gamma: 0.97


100%|██████████| 11307/11307 [00:56<00:00, 199.65it/s]


ndcg_cut_10              all     0.4529
recall_10                all     0.6145
MRR@10: 0.4024000443607763
gamma: 0.98


100%|██████████| 11307/11307 [00:56<00:00, 199.76it/s]


ndcg_cut_10              all     0.4527
recall_10                all     0.6143
MRR@10: 0.4021201924359258
gamma: 0.99


100%|██████████| 11307/11307 [00:56<00:00, 200.86it/s]


ndcg_cut_10              all     0.4527
recall_10                all     0.6143
MRR@10: 0.4021347921852028
gamma: 1.0


100%|██████████| 11307/11307 [00:56<00:00, 200.64it/s]


ndcg_cut_10              all     0.4524
recall_10                all     0.6140
MRR@10: 0.40187680478871707


In [3]:
# 探索实验 针对gamma进行调参
import torch
from tqdm import tqdm
import pytrec_eval

gamma_list = [round(0.1 * i, 1) for i in range(11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        alpha = torch.softmax(scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10, 5, 3, 1]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.0


100%|██████████| 11307/11307 [00:57<00:00, 198.07it/s]


ndcg_cut_10              all     0.1823
recall_10                all     0.3187
MRR@10: 0.1406790076662717
ndcg_cut_5               all     0.1514
recall_5                 all     0.2229
MRR@5: 0.12797529553963652
ndcg_cut_3               all     0.1271
recall_3                 all     0.1638
MRR@3: 0.1145455617464107
ndcg_cut_1               all     0.0780
recall_1                 all     0.0780
MRR@1: 0.07800477580260017
gamma: 0.1


100%|██████████| 11307/11307 [00:57<00:00, 197.92it/s]


ndcg_cut_10              all     0.2740
recall_10                all     0.4396
MRR@10: 0.2229289062401292
ndcg_cut_5               all     0.2402
recall_5                 all     0.3349
MRR@5: 0.20896494796733556
ndcg_cut_3               all     0.2112
recall_3                 all     0.2642
MRR@3: 0.1929041007045773
ndcg_cut_1               all     0.1387
recall_1                 all     0.1387
MRR@1: 0.1386751569824003
gamma: 0.2


100%|██████████| 11307/11307 [00:56<00:00, 199.71it/s]


ndcg_cut_10              all     0.3446
recall_10                all     0.5155
MRR@10: 0.2914593572460393
ndcg_cut_5               all     0.3124
recall_5                 all     0.4158
MRR@5: 0.27819787152501446
ndcg_cut_3               all     0.2818
recall_3                 all     0.3412
MRR@3: 0.2612393502550058
ndcg_cut_1               all     0.1993
recall_1                 all     0.1993
MRR@1: 0.1993455381622004
gamma: 0.3


100%|██████████| 11307/11307 [00:56<00:00, 199.53it/s]


ndcg_cut_10              all     0.3894
recall_10                all     0.5589
MRR@10: 0.336451257192836
ndcg_cut_5               all     0.3584
recall_5                 all     0.4635
MRR@5: 0.32358273635800966
ndcg_cut_3               all     0.3282
recall_3                 all     0.3900
MRR@3: 0.3068895374546769
ndcg_cut_1               all     0.2430
recall_1                 all     0.2430
MRR@1: 0.24303528787476783
gamma: 0.4


100%|██████████| 11307/11307 [00:56<00:00, 199.55it/s]


ndcg_cut_10              all     0.4169
recall_10                all     0.5848
MRR@10: 0.36448337242977896
ndcg_cut_5               all     0.3867
recall_5                 all     0.4920
MRR@5: 0.3519147430795089
ndcg_cut_3               all     0.3555
recall_3                 all     0.4162
MRR@3: 0.3346157247722678
ndcg_cut_1               all     0.2712
recall_1                 all     0.2712
MRR@1: 0.27115945874237196
gamma: 0.5


100%|██████████| 11307/11307 [00:56<00:00, 201.62it/s]


ndcg_cut_10              all     0.4331
recall_10                all     0.5994
MRR@10: 0.38107114429746514
ndcg_cut_5               all     0.4032
recall_5                 all     0.5077
MRR@5: 0.3685696176409898
ndcg_cut_3               all     0.3727
recall_3                 all     0.4336
MRR@3: 0.35169953715987456
ndcg_cut_1               all     0.2880
recall_1                 all     0.2880
MRR@1: 0.287963208631821
gamma: 0.6


100%|██████████| 11307/11307 [00:57<00:00, 198.29it/s]


ndcg_cut_10              all     0.4425
recall_10                all     0.6092
MRR@10: 0.3903633231837004
ndcg_cut_5               all     0.4128
recall_5                 all     0.5180
MRR@5: 0.3779929836974147
ndcg_cut_3               all     0.3826
recall_3                 all     0.4444
MRR@3: 0.36125114236019346
ndcg_cut_1               all     0.2971
recall_1                 all     0.2971
MRR@1: 0.2970726098876802
gamma: 0.7


100%|██████████| 11307/11307 [00:57<00:00, 198.19it/s]


ndcg_cut_10              all     0.4479
recall_10                all     0.6119
MRR@10: 0.39651996024375963
ndcg_cut_5               all     0.4195
recall_5                 all     0.5249
MRR@5: 0.38465404911411766
ndcg_cut_3               all     0.3877
recall_3                 all     0.4479
MRR@3: 0.36701453376964444
ndcg_cut_1               all     0.3048
recall_1                 all     0.3048
MRR@1: 0.30476695852127
gamma: 0.8


100%|██████████| 11307/11307 [00:57<00:00, 198.00it/s]


ndcg_cut_10              all     0.4510
recall_10                all     0.6137
MRR@10: 0.4000495549182205
ndcg_cut_5               all     0.4228
recall_5                 all     0.5274
MRR@5: 0.38829928362960936
ndcg_cut_3               all     0.3906
recall_3                 all     0.4491
MRR@3: 0.3704342442734619
ndcg_cut_1               all     0.3103
recall_1                 all     0.3103
MRR@1: 0.3102502874325639
gamma: 0.9


100%|██████████| 11307/11307 [00:56<00:00, 198.51it/s]


ndcg_cut_10              all     0.4525
recall_10                all     0.6147
MRR@10: 0.401748881504785
ndcg_cut_5               all     0.4246
recall_5                 all     0.5291
MRR@5: 0.3901093717755954
ndcg_cut_3               all     0.3924
recall_3                 all     0.4507
MRR@3: 0.3722620205772261
ndcg_cut_1               all     0.3113
recall_1                 all     0.3113
MRR@1: 0.3113115768992659
gamma: 1.0


100%|██████████| 11307/11307 [00:56<00:00, 200.02it/s]


ndcg_cut_10              all     0.4524
recall_10                all     0.6140
MRR@10: 0.40187680478871707
ndcg_cut_5               all     0.4237
recall_5                 all     0.5262
MRR@5: 0.38985289348780916
ndcg_cut_3               all     0.3930
recall_3                 all     0.4515
MRR@3: 0.37283688570502327
ndcg_cut_1               all     0.3122
recall_1                 all     0.3122
MRR@1: 0.31219598478818433


In [None]:
# PlotQA 2x2 + 4x4

import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings_2x2 = np.load('embeddings/PlotQA_corpus_embeddings_2x2.npy')
local_image_embeddings_4x4 = np.load('embeddings/PlotQA_corpus_embeddings_4x4.npy')
image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/PlotQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/PlotQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/PlotQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv')

# gamma_list = [round(0.1 * i, 1) for i in range(6, 11)]
for i in range(11):
    for j in range(11 - i):
        k = 10 - i - j
        gamma_list.append((round(i * 0.1, 1), round(j * 0.1, 1), round(k * 0.1, 1)))
gamma_list.reverse()
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor_2x2 = torch.tensor(local_image_embeddings_2x2).cuda()
local_image_embeddings_tensor_4x4 = torch.tensor(local_image_embeddings_4x4).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

for gamma1, gamma2, gamma3 in gamma_list:
    print(f'gamma1: {gamma1}, gamma2: {gamma2}, gamma3: {gamma3}')
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores1 = torch.einsum('ijk,k->ij', local_image_embeddings_tensor_2x2, query)
        scores2 = torch.einsum('ijk,k->ij', local_image_embeddings_tensor_4x4, query)
        
        temperature1 = 100.0
        temperature2 = 10.0
        scaled_scores1 = scores1 * temperature1
        scaled_scores2 = scores2 * temperature2
        alpha1 = torch.softmax(scaled_scores1, dim=1)
        alpha2 = torch.softmax(scaled_scores2, dim=1)
        
        local_agg1 = torch.einsum('ij,ijk->ik', alpha1, local_image_embeddings_tensor_2x2)
        local_agg2 = torch.einsum('ij,ijk->ik', alpha2, local_image_embeddings_tensor_4x4)
        

        final_fusion = gamma1 * image_embeddings_tensor + gamma2 * local_agg1 + gamma3 * local_agg2
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma1: 1.0, gamma2: 0.0, gamma3: 0.0


100%|██████████| 11307/11307 [01:17<00:00, 145.22it/s]


ndcg_cut_10              all     0.4524
recall_10                all     0.6140
MRR@10: 0.40187680478871707
gamma1: 0.9, gamma2: 0.1, gamma3: 0.0


100%|██████████| 11307/11307 [01:17<00:00, 146.54it/s]


ndcg_cut_10              all     0.4518
recall_10                all     0.6143
MRR@10: 0.40101657773454097
gamma1: 0.9, gamma2: 0.0, gamma3: 0.1


100%|██████████| 11307/11307 [01:16<00:00, 147.36it/s]


ndcg_cut_10              all     0.4518
recall_10                all     0.6142
MRR@10: 0.40095624848211797
gamma1: 0.8, gamma2: 0.2, gamma3: 0.0


100%|██████████| 11307/11307 [01:17<00:00, 146.44it/s]


ndcg_cut_10              all     0.4467
recall_10                all     0.6094
MRR@10: 0.3958235241267886
gamma1: 0.8, gamma2: 0.1, gamma3: 0.1


100%|██████████| 11307/11307 [01:16<00:00, 147.31it/s]


ndcg_cut_10              all     0.4495
recall_10                all     0.6121
MRR@10: 0.39868405721417155
gamma1: 0.8, gamma2: 0.0, gamma3: 0.2


100%|██████████| 11307/11307 [01:17<00:00, 146.11it/s]


ndcg_cut_10              all     0.4472
recall_10                all     0.6095
MRR@10: 0.3964506465302445
gamma1: 0.7, gamma2: 0.3, gamma3: 0.0


100%|██████████| 11307/11307 [01:16<00:00, 148.10it/s]


ndcg_cut_10              all     0.4341
recall_10                all     0.5973
MRR@10: 0.38306025484985484
gamma1: 0.7, gamma2: 0.2, gamma3: 0.1


100%|██████████| 11307/11307 [01:16<00:00, 148.45it/s]


ndcg_cut_10              all     0.4403
recall_10                all     0.6037
MRR@10: 0.38916024769036195
gamma1: 0.7, gamma2: 0.1, gamma3: 0.2


100%|██████████| 11307/11307 [01:17<00:00, 145.75it/s]


ndcg_cut_10              all     0.4422
recall_10                all     0.6050
MRR@10: 0.3913182029393233
gamma1: 0.7, gamma2: 0.0, gamma3: 0.3


100%|██████████| 11307/11307 [01:16<00:00, 147.00it/s]


ndcg_cut_10              all     0.4383
recall_10                all     0.5997
MRR@10: 0.3878934737716918
gamma1: 0.6, gamma2: 0.4, gamma3: 0.0


100%|██████████| 11307/11307 [01:16<00:00, 148.68it/s]


ndcg_cut_10              all     0.4162
recall_10                all     0.5779
MRR@10: 0.3656724447982087
gamma1: 0.6, gamma2: 0.3, gamma3: 0.1


100%|██████████| 11307/11307 [01:16<00:00, 147.89it/s]


ndcg_cut_10              all     0.4244
recall_10                all     0.5857
MRR@10: 0.37400816462902003
gamma1: 0.6, gamma2: 0.2, gamma3: 0.2


100%|██████████| 11307/11307 [01:17<00:00, 145.79it/s]


ndcg_cut_10              all     0.4278
recall_10                all     0.5903
MRR@10: 0.37698465341739534
gamma1: 0.6, gamma2: 0.1, gamma3: 0.3


100%|██████████| 11307/11307 [01:16<00:00, 148.21it/s]


ndcg_cut_10              all     0.4277
recall_10                all     0.5892
MRR@10: 0.37718740041637516
gamma1: 0.6, gamma2: 0.0, gamma3: 0.4


100%|██████████| 11307/11307 [01:17<00:00, 146.79it/s]


ndcg_cut_10              all     0.4215
recall_10                all     0.5833
MRR@10: 0.37103279889828167
gamma1: 0.5, gamma2: 0.5, gamma3: 0.0


100%|██████████| 11307/11307 [01:16<00:00, 148.67it/s]


ndcg_cut_10              all     0.3905
recall_10                all     0.5542
MRR@10: 0.33943388770266136
gamma1: 0.5, gamma2: 0.4, gamma3: 0.1


100%|██████████| 11307/11307 [01:16<00:00, 147.76it/s]


ndcg_cut_10              all     0.4019
recall_10                all     0.5633
MRR@10: 0.35154522202147703
gamma1: 0.5, gamma2: 0.3, gamma3: 0.2


100%|██████████| 11307/11307 [01:16<00:00, 147.04it/s]


ndcg_cut_10              all     0.4094
recall_10                all     0.5727
MRR@10: 0.3585382562564855
gamma1: 0.5, gamma2: 0.2, gamma3: 0.3


100%|██████████| 11307/11307 [01:17<00:00, 146.48it/s]


ndcg_cut_10              all     0.4102
recall_10                all     0.5749
MRR@10: 0.3589200607574196
gamma1: 0.5, gamma2: 0.1, gamma3: 0.4


100%|██████████| 11307/11307 [01:14<00:00, 151.53it/s]


ndcg_cut_10              all     0.4056
recall_10                all     0.5689
MRR@10: 0.3546505465781149
gamma1: 0.5, gamma2: 0.0, gamma3: 0.5


100%|██████████| 11307/11307 [01:13<00:00, 154.66it/s]


ndcg_cut_10              all     0.3974
recall_10                all     0.5598
MRR@10: 0.3467069142447236
gamma1: 0.4, gamma2: 0.6, gamma3: 0.0


100%|██████████| 11307/11307 [01:13<00:00, 154.40it/s]


ndcg_cut_10              all     0.3561
recall_10                all     0.5214
MRR@10: 0.3046289979097099
gamma1: 0.4, gamma2: 0.5, gamma3: 0.1


100%|██████████| 11307/11307 [01:13<00:00, 153.57it/s]


ndcg_cut_10              all     0.3707
recall_10                all     0.5344
MRR@10: 0.3197230329294556
gamma1: 0.4, gamma2: 0.4, gamma3: 0.2


100%|██████████| 11307/11307 [01:12<00:00, 156.67it/s]


ndcg_cut_10              all     0.3795
recall_10                all     0.5438
MRR@10: 0.32828568059398683
gamma1: 0.4, gamma2: 0.3, gamma3: 0.3


100%|██████████| 11307/11307 [01:11<00:00, 157.07it/s]


ndcg_cut_10              all     0.3825
recall_10                all     0.5478
MRR@10: 0.33096333076434015
gamma1: 0.4, gamma2: 0.2, gamma3: 0.4


100%|██████████| 11307/11307 [01:13<00:00, 153.22it/s]


ndcg_cut_10              all     0.3830
recall_10                all     0.5482
MRR@10: 0.3316224954059936
gamma1: 0.4, gamma2: 0.1, gamma3: 0.5


100%|██████████| 11307/11307 [01:13<00:00, 153.12it/s]


ndcg_cut_10              all     0.3756
recall_10                all     0.5404
MRR@10: 0.32425127151181987
gamma1: 0.4, gamma2: 0.0, gamma3: 0.6


100%|██████████| 11307/11307 [01:13<00:00, 154.58it/s]


ndcg_cut_10              all     0.3649
recall_10                all     0.5307
MRR@10: 0.3133668074700192
gamma1: 0.3, gamma2: 0.7, gamma3: 0.0


100%|██████████| 11307/11307 [01:13<00:00, 153.13it/s]


ndcg_cut_10              all     0.3169
recall_10                all     0.4815
MRR@10: 0.2658230398081821
gamma1: 0.3, gamma2: 0.6, gamma3: 0.1


100%|██████████| 11307/11307 [01:14<00:00, 151.78it/s]


ndcg_cut_10              all     0.3313
recall_10                all     0.4982
MRR@10: 0.27945994263983126
gamma1: 0.3, gamma2: 0.5, gamma3: 0.2


100%|██████████| 11307/11307 [01:15<00:00, 150.49it/s]


ndcg_cut_10              all     0.3417
recall_10                all     0.5071
MRR@10: 0.29037199178483447
gamma1: 0.3, gamma2: 0.4, gamma3: 0.3


  8%|▊         | 873/11307 [00:06<01:09, 149.57it/s]

In [4]:
# SlideVQA

import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/SlideVQA_corpus_embeddings_8x8.npy')
image_embeddings = np.load('embeddings/SlideVQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/SlideVQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/SlideVQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/SlideVQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(6, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 100.0
        scaled_scores = scores * temperature
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.6


100%|██████████| 1640/1640 [00:12<00:00, 135.30it/s]


ndcg_cut_10              all     0.9146
recall_10                all     0.9736
MRR@10: 0.916935491676345
gamma: 0.7


100%|██████████| 1640/1640 [00:12<00:00, 135.34it/s]


ndcg_cut_10              all     0.9227
recall_10                all     0.9724
MRR@10: 0.9278639179248934
gamma: 0.8


100%|██████████| 1640/1640 [00:12<00:00, 136.57it/s]


ndcg_cut_10              all     0.9243
recall_10                all     0.9741
MRR@10: 0.92918481416957
gamma: 0.9


100%|██████████| 1640/1640 [00:12<00:00, 136.54it/s]


ndcg_cut_10              all     0.9200
recall_10                all     0.9738
MRR@10: 0.9227731804103754
gamma: 1.0


100%|██████████| 1640/1640 [00:11<00:00, 136.72it/s]

ndcg_cut_10              all     0.9146
recall_10                all     0.9686
MRR@10: 0.9176260646535037





In [9]:
# SlideVQA

import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/SlideVQA_corpus_embeddings_4x4.npy')
image_embeddings = np.load('embeddings/SlideVQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/SlideVQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/SlideVQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/SlideVQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-SlideVQA/qrels/slidevqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(5, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 25.0
        scaled_scores = scores * temperature
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  

gamma: 0.5


100%|██████████| 1640/1640 [00:02<00:00, 778.17it/s]


ndcg_cut_10              all     0.9171
recall_10                all     0.9760
MRR@10: 0.9190875435540065
gamma: 0.6


100%|██████████| 1640/1640 [00:01<00:00, 838.46it/s]


ndcg_cut_10              all     0.9244
recall_10                all     0.9766
MRR@10: 0.9286147406116917
gamma: 0.7


100%|██████████| 1640/1640 [00:01<00:00, 845.06it/s]


ndcg_cut_10              all     0.9259
recall_10                all     0.9751
MRR@10: 0.9311099012775842
gamma: 0.8


100%|██████████| 1640/1640 [00:01<00:00, 837.84it/s]


ndcg_cut_10              all     0.9243
recall_10                all     0.9753
MRR@10: 0.9286536972512581
gamma: 0.9


100%|██████████| 1640/1640 [00:01<00:00, 845.67it/s]


ndcg_cut_10              all     0.9210
recall_10                all     0.9735
MRR@10: 0.9243304781262098
gamma: 1.0


100%|██████████| 1640/1640 [00:01<00:00, 839.04it/s]

ndcg_cut_10              all     0.9146
recall_10                all     0.9686
MRR@10: 0.9176260646535037





In [4]:
# PlotQA

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings_4x1_with_summary_above.npy')
image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/PlotQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/PlotQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/PlotQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(8, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 100.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.8


100%|██████████| 11307/11307 [00:28<00:00, 403.32it/s]


ndcg_cut_10              all     0.4311
recall_10                all     0.5944
MRR@10: 0.37998507737165294
gamma: 0.9


100%|██████████| 11307/11307 [00:27<00:00, 405.26it/s]


ndcg_cut_10              all     0.4486
recall_10                all     0.6088
MRR@10: 0.3985223369144835
gamma: 1.0


100%|██████████| 11307/11307 [00:27<00:00, 405.08it/s]

ndcg_cut_10              all     0.4524
recall_10                all     0.6140
MRR@10: 0.40187680478871707





In [2]:
# PlotQA 使用混合分数

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings_2x2.npy')
image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/PlotQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/PlotQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/PlotQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv')

keyword_embeddings = np.load('embeddings/PlotQA_keyword_embeddings.npy')

gamma_list = [round(0.1 * i, 1) for i in range(5, 11)]
gamma_list_2 = [round(0.1 * i, 1) for i in range(5, 11)]

# gamma_list = [0.9]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()
keyword_embeddings_tensor = torch.tensor(keyword_embeddings, dtype=torch.float32).cuda()
keyword_embeddings_tensor = torch.squeeze(keyword_embeddings_tensor, 1)
# print(keyword_embeddings_tensor.shape)

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    
    for gamma_2 in gamma_list_2:
        print(f'gamma: {gamma}') 
        print(f'image: {gamma_2}')
        run = {} 
        for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
            qid = query_ids[q_idx]
            
            scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
            
            temperature = 100.0
            scaled_scores = scores * temperature
            # scaled_scores = scores / d_k
            alpha = torch.softmax(scaled_scores, dim=1)
            local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
            

            final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
            final_score = torch.matmul(final_fusion, query)
            # print(final_fusion.shape)
            # print('final score', final_score)
            
            keyword_score = torch.matmul(keyword_embeddings_tensor, query)
            # print(keyword_embeddings_tensor.shape)
            # print('keyword score', keyword_score)
            
            final_scores = gamma_2 * final_score + (1 - gamma_2) * keyword_score
            
            top_k_indices = torch.argsort(final_scores, descending=True)[:10]  # 取前10个
            run[qid] = {corpus_ids[idx]: float(final_scores[idx].cpu().numpy()) for idx in top_k_indices}
        for cutoff in [10]:    
            evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
            eval_results = evaluator.evaluate(run)  
            
            for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
                value = pytrec_eval.compute_aggregated_measure(  
                    measure, [query_measures[measure] for query_measures in eval_results.values()]  
                )  
                print(f"{measure:25s}{'all':8s}{value:.4f}")  

            
            mrr = eval_mrr(qrels, run, cutoff)['all']  
            print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.5
image: 0.5


100%|██████████| 11307/11307 [00:30<00:00, 375.06it/s]


ndcg_cut_10              all     0.3381
recall_10                all     0.4937
MRR@10: 0.2899112222938183
gamma: 0.5
image: 0.6


100%|██████████| 11307/11307 [00:30<00:00, 372.40it/s]


ndcg_cut_10              all     0.3642
recall_10                all     0.5241
MRR@10: 0.3145209597650573
gamma: 0.5
image: 0.7


100%|██████████| 11307/11307 [00:29<00:00, 382.40it/s]


ndcg_cut_10              all     0.3814
recall_10                all     0.5404
MRR@10: 0.3318644792311569
gamma: 0.5
image: 0.8


100%|██████████| 11307/11307 [00:29<00:00, 381.38it/s]


ndcg_cut_10              all     0.3903
recall_10                all     0.5505
MRR@10: 0.3403400197377397
gamma: 0.5
image: 0.9


100%|██████████| 11307/11307 [00:29<00:00, 382.22it/s]


ndcg_cut_10              all     0.3918
recall_10                all     0.5524
MRR@10: 0.3416628061560426
gamma: 0.5
image: 1.0


100%|██████████| 11307/11307 [00:29<00:00, 382.01it/s]


ndcg_cut_10              all     0.3905
recall_10                all     0.5542
MRR@10: 0.33943388770266136
gamma: 0.6
image: 0.5


100%|██████████| 11307/11307 [00:29<00:00, 380.73it/s]


ndcg_cut_10              all     0.3561
recall_10                all     0.5163
MRR@10: 0.30647105810279174
gamma: 0.6
image: 0.6


100%|██████████| 11307/11307 [00:30<00:00, 373.67it/s]


ndcg_cut_10              all     0.3856
recall_10                all     0.5468
MRR@10: 0.33540912989705834
gamma: 0.6
image: 0.7


100%|██████████| 11307/11307 [00:30<00:00, 370.63it/s]


ndcg_cut_10              all     0.4041
recall_10                all     0.5641
MRR@10: 0.3542161689415616
gamma: 0.6
image: 0.8


100%|██████████| 11307/11307 [00:30<00:00, 368.54it/s]


ndcg_cut_10              all     0.4131
recall_10                all     0.5731
MRR@10: 0.3631684474149332
gamma: 0.6
image: 0.9


100%|██████████| 11307/11307 [00:31<00:00, 364.08it/s]


ndcg_cut_10              all     0.4166
recall_10                all     0.5776
MRR@10: 0.3663406991876097
gamma: 0.6
image: 1.0


100%|██████████| 11307/11307 [00:31<00:00, 363.83it/s]


ndcg_cut_10              all     0.4162
recall_10                all     0.5779
MRR@10: 0.3656724447982087
gamma: 0.7
image: 0.5


100%|██████████| 11307/11307 [00:30<00:00, 367.33it/s]


ndcg_cut_10              all     0.3721
recall_10                all     0.5319
MRR@10: 0.32251604217643076
gamma: 0.7
image: 0.6


100%|██████████| 11307/11307 [00:30<00:00, 368.07it/s]


ndcg_cut_10              all     0.4027
recall_10                all     0.5651
MRR@10: 0.3520872728089512
gamma: 0.7
image: 0.7


100%|██████████| 11307/11307 [00:30<00:00, 369.87it/s]


ndcg_cut_10              all     0.4214
recall_10                all     0.5813
MRR@10: 0.37148953240091637
gamma: 0.7
image: 0.8


100%|██████████| 11307/11307 [00:30<00:00, 371.62it/s]


ndcg_cut_10              all     0.4318
recall_10                all     0.5929
MRR@10: 0.3815416352561488
gamma: 0.7
image: 0.9


100%|██████████| 11307/11307 [00:30<00:00, 372.28it/s]


ndcg_cut_10              all     0.4351
recall_10                all     0.5971
MRR@10: 0.3844656912911102
gamma: 0.7
image: 1.0


100%|██████████| 11307/11307 [00:30<00:00, 375.03it/s]


ndcg_cut_10              all     0.4341
recall_10                all     0.5973
MRR@10: 0.38306025484985484
gamma: 0.8
image: 0.5


100%|██████████| 11307/11307 [00:31<00:00, 360.88it/s]


ndcg_cut_10              all     0.3844
recall_10                all     0.5465
MRR@10: 0.3340261195129874
gamma: 0.8
image: 0.6


100%|██████████| 11307/11307 [00:31<00:00, 357.43it/s]


ndcg_cut_10              all     0.4148
recall_10                all     0.5764
MRR@10: 0.36428044995304354
gamma: 0.8
image: 0.7


100%|██████████| 11307/11307 [00:30<00:00, 365.85it/s]


ndcg_cut_10              all     0.4336
recall_10                all     0.5939
MRR@10: 0.3834087185771993
gamma: 0.8
image: 0.8


100%|██████████| 11307/11307 [00:30<00:00, 368.93it/s]


ndcg_cut_10              all     0.4434
recall_10                all     0.6044
MRR@10: 0.3931422240191151
gamma: 0.8
image: 0.9


100%|██████████| 11307/11307 [00:30<00:00, 364.77it/s]


ndcg_cut_10              all     0.4469
recall_10                all     0.6102
MRR@10: 0.395898839179551
gamma: 0.8
image: 1.0


100%|██████████| 11307/11307 [00:31<00:00, 363.54it/s]


ndcg_cut_10              all     0.4467
recall_10                all     0.6094
MRR@10: 0.3958235241267886
gamma: 0.9
image: 0.5


100%|██████████| 11307/11307 [00:30<00:00, 366.67it/s]


ndcg_cut_10              all     0.3949
recall_10                all     0.5574
MRR@10: 0.3442398724768071
gamma: 0.9
image: 0.6


100%|██████████| 11307/11307 [00:30<00:00, 367.60it/s]


ndcg_cut_10              all     0.4235
recall_10                all     0.5849
MRR@10: 0.372984322115393
gamma: 0.9
image: 0.7


100%|██████████| 11307/11307 [00:30<00:00, 370.68it/s]


ndcg_cut_10              all     0.4402
recall_10                all     0.6002
MRR@10: 0.3900829799211336
gamma: 0.9
image: 0.8


100%|██████████| 11307/11307 [00:30<00:00, 374.56it/s]


ndcg_cut_10              all     0.4491
recall_10                all     0.6099
MRR@10: 0.39875112481241454
gamma: 0.9
image: 0.9


100%|██████████| 11307/11307 [00:30<00:00, 370.33it/s]


ndcg_cut_10              all     0.4515
recall_10                all     0.6135
MRR@10: 0.40074711409282937
gamma: 0.9
image: 1.0


100%|██████████| 11307/11307 [00:30<00:00, 371.62it/s]


ndcg_cut_10              all     0.4518
recall_10                all     0.6143
MRR@10: 0.40101657773454097
gamma: 1.0
image: 0.5


100%|██████████| 11307/11307 [00:31<00:00, 359.60it/s]


ndcg_cut_10              all     0.4018
recall_10                all     0.5637
MRR@10: 0.35133401699467187
gamma: 1.0
image: 0.6


100%|██████████| 11307/11307 [00:31<00:00, 359.19it/s]


ndcg_cut_10              all     0.4277
recall_10                all     0.5892
MRR@10: 0.37715528798707443
gamma: 1.0
image: 0.7


100%|██████████| 11307/11307 [00:30<00:00, 369.63it/s]


ndcg_cut_10              all     0.4420
recall_10                all     0.6024
MRR@10: 0.39179420389953734
gamma: 1.0
image: 0.8


100%|██████████| 11307/11307 [00:30<00:00, 366.73it/s]


ndcg_cut_10              all     0.4497
recall_10                all     0.6094
MRR@10: 0.3996904572388784
gamma: 1.0
image: 0.9


100%|██████████| 11307/11307 [00:31<00:00, 363.62it/s]


ndcg_cut_10              all     0.4530
recall_10                all     0.6138
MRR@10: 0.4027080078220969
gamma: 1.0
image: 1.0


100%|██████████| 11307/11307 [00:31<00:00, 363.27it/s]

ndcg_cut_10              all     0.4524
recall_10                all     0.6140
MRR@10: 0.40187680478871707





In [None]:
# PlotQA 使用混合编码

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings_2x2.npy')
image_embeddings = np.load('embeddings/PlotQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/PlotQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/PlotQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/PlotQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-PlotQA/qrels/plotqa-eval-qrels.tsv')

keyword_embeddings = np.load('embeddings/PlotQA_keyword_embeddings_with_instruction.npy')

# gamma_list = [round(0.1 * i, 1) for i in range(8, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]
gamma_list = []

for i in range(11):
    for j in range(11 - i):
        k = 10 - i - j
        gamma_list.append((round(i * 0.1, 1), round(j * 0.1, 1), round(k * 0.1, 1)))
gamma_list.reverse()

# print(len(gamma_list)) 
# print(gamma_list)

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

keyword_embeddings_tensor = torch.tensor(keyword_embeddings, dtype=torch.float32).cuda()
keyword_embeddings_tensor = torch.squeeze(keyword_embeddings_tensor, 1)
# print(keyword_embeddings_tensor.shape)

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma1, gamma2, gamma3 in gamma_list:
    print(f'gamma1: {gamma1}, gamma2: {gamma2}, gamma3: {gamma3}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 100.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        # final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        
        final_fusion = gamma1 * image_embeddings_tensor + gamma2 * local_agg + gamma3 * keyword_embeddings_tensor
        
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma1: 1.0, gamma2: 0.0, gamma3: 0.0


100%|██████████| 11307/11307 [00:34<00:00, 327.38it/s]


ndcg_cut_10              all     0.4524
recall_10                all     0.6140
MRR@10: 0.40187680478871707
gamma1: 0.9, gamma2: 0.1, gamma3: 0.0


100%|██████████| 11307/11307 [00:34<00:00, 330.23it/s]


ndcg_cut_10              all     0.4518
recall_10                all     0.6143
MRR@10: 0.40101657773454097
gamma1: 0.9, gamma2: 0.0, gamma3: 0.1


100%|██████████| 11307/11307 [00:34<00:00, 330.22it/s]


ndcg_cut_10              all     0.4474
recall_10                all     0.6109
MRR@10: 0.3961653547949654
gamma1: 0.8, gamma2: 0.2, gamma3: 0.0


100%|██████████| 11307/11307 [00:34<00:00, 330.60it/s]


ndcg_cut_10              all     0.4467
recall_10                all     0.6094
MRR@10: 0.3958235241267886
gamma1: 0.8, gamma2: 0.1, gamma3: 0.1


100%|██████████| 11307/11307 [00:34<00:00, 329.33it/s]


ndcg_cut_10              all     0.4457
recall_10                all     0.6102
MRR@10: 0.39415251964999903
gamma1: 0.8, gamma2: 0.0, gamma3: 0.2


100%|██████████| 11307/11307 [00:34<00:00, 324.55it/s]


ndcg_cut_10              all     0.4330
recall_10                all     0.5986
MRR@10: 0.3812878242302508
gamma1: 0.7, gamma2: 0.3, gamma3: 0.0


100%|██████████| 11307/11307 [00:34<00:00, 325.12it/s]


ndcg_cut_10              all     0.4341
recall_10                all     0.5973
MRR@10: 0.38306025484985484
gamma1: 0.7, gamma2: 0.2, gamma3: 0.1


100%|██████████| 11307/11307 [00:34<00:00, 329.05it/s]


ndcg_cut_10              all     0.4393
recall_10                all     0.6044
MRR@10: 0.38775312666265255
gamma1: 0.7, gamma2: 0.1, gamma3: 0.2


100%|██████████| 11307/11307 [00:34<00:00, 326.18it/s]


ndcg_cut_10              all     0.4288
recall_10                all     0.5957
MRR@10: 0.3766225726162054
gamma1: 0.7, gamma2: 0.0, gamma3: 0.3


100%|██████████| 11307/11307 [00:34<00:00, 327.49it/s]


ndcg_cut_10              all     0.3902
recall_10                all     0.5627
MRR@10: 0.33650372504179954
gamma1: 0.6, gamma2: 0.4, gamma3: 0.0


100%|██████████| 11307/11307 [00:34<00:00, 328.16it/s]


ndcg_cut_10              all     0.4162
recall_10                all     0.5779
MRR@10: 0.3656724447982087
gamma1: 0.6, gamma2: 0.3, gamma3: 0.1


100%|██████████| 11307/11307 [00:34<00:00, 331.47it/s]


ndcg_cut_10              all     0.4235
recall_10                all     0.5882
MRR@10: 0.3720072268758934
gamma1: 0.6, gamma2: 0.2, gamma3: 0.2


100%|██████████| 11307/11307 [00:34<00:00, 331.25it/s]


ndcg_cut_10              all     0.4168
recall_10                all     0.5842
MRR@10: 0.3645533178632158
gamma1: 0.6, gamma2: 0.1, gamma3: 0.3


100%|██████████| 11307/11307 [00:33<00:00, 333.61it/s]


ndcg_cut_10              all     0.3805
recall_10                all     0.5534
MRR@10: 0.3266743736497007
gamma1: 0.6, gamma2: 0.0, gamma3: 0.4


100%|██████████| 11307/11307 [00:34<00:00, 332.23it/s]


ndcg_cut_10              all     0.3058
recall_10                all     0.4757
MRR@10: 0.25324391688811954
gamma1: 0.5, gamma2: 0.5, gamma3: 0.0


100%|██████████| 11307/11307 [00:34<00:00, 329.79it/s]


ndcg_cut_10              all     0.3905
recall_10                all     0.5542
MRR@10: 0.33943388770266136
gamma1: 0.5, gamma2: 0.4, gamma3: 0.1


100%|██████████| 11307/11307 [00:34<00:00, 325.06it/s]


ndcg_cut_10              all     0.4007
recall_10                all     0.5655
MRR@10: 0.349293666937606
gamma1: 0.5, gamma2: 0.3, gamma3: 0.2


100%|██████████| 11307/11307 [00:34<00:00, 331.92it/s]


ndcg_cut_10              all     0.3954
recall_10                all     0.5647
MRR@10: 0.34269184281123893
gamma1: 0.5, gamma2: 0.2, gamma3: 0.3


100%|██████████| 11307/11307 [00:34<00:00, 327.98it/s]


ndcg_cut_10              all     0.3617
recall_10                all     0.5365
MRR@10: 0.3073668720458333
gamma1: 0.5, gamma2: 0.1, gamma3: 0.4


 98%|█████████▊| 11030/11307 [00:34<00:00, 329.23it/s]

In [2]:
# ArxivQA

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/ArxivQA_corpus_embeddings_4x4.npy')
image_embeddings = np.load('embeddings/ArxivQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/ArxivQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/ArxivQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/ArxivQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(6, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 20.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.6


100%|██████████| 8640/8640 [01:53<00:00, 76.04it/s]


ndcg_cut_10              all     0.7055
recall_10                all     0.8096
MRR@10: 0.6723660714285689
gamma: 0.7


100%|██████████| 8640/8640 [01:53<00:00, 76.02it/s]


ndcg_cut_10              all     0.7151
recall_10                all     0.8149
MRR@10: 0.6833274544385634
gamma: 0.8


100%|██████████| 8640/8640 [01:55<00:00, 75.01it/s]


ndcg_cut_10              all     0.7164
recall_10                all     0.8148
MRR@10: 0.6851061875367406
gamma: 0.9


100%|██████████| 8640/8640 [02:07<00:00, 67.92it/s]


ndcg_cut_10              all     0.7114
recall_10                all     0.8115
MRR@10: 0.6796276087595516
gamma: 1.0


100%|██████████| 8640/8640 [02:59<00:00, 48.14it/s]

ndcg_cut_10              all     0.7000
recall_10                all     0.8042
MRR@10: 0.66699133781599





In [24]:
# ArxivQA in domain

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/ArxivQA_corpus_embeddings_4x4_in_domain.npy')
image_embeddings = np.load('embeddings/ArxivQA_corpus_embeddings_in_domain.npy').squeeze(1)
query_embeddings = np.load('embeddings/ArxivQA_query_embeddings_in_domain.npy').squeeze(1)
corpus_ids = np.load('embeddings/ArxivQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/ArxivQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-ArxivQA/qrels/arxivqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(6, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 20.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.6


100%|██████████| 8640/8640 [00:39<00:00, 220.59it/s]


ndcg_cut_10              all     0.7298
recall_10                all     0.8346
MRR@10: 0.6964588385508497
gamma: 0.7


100%|██████████| 8640/8640 [00:39<00:00, 216.37it/s]


ndcg_cut_10              all     0.7424
recall_10                all     0.8440
MRR@10: 0.7100931896678389
gamma: 0.8


100%|██████████| 8640/8640 [00:38<00:00, 223.72it/s]


ndcg_cut_10              all     0.7470
recall_10                all     0.8488
MRR@10: 0.7146937922545535
gamma: 0.9


100%|██████████| 8640/8640 [00:39<00:00, 220.05it/s]


ndcg_cut_10              all     0.7444
recall_10                all     0.8462
MRR@10: 0.712152180702526
gamma: 1.0


100%|██████████| 8640/8640 [00:39<00:00, 218.87it/s]

ndcg_cut_10              all     0.7342
recall_10                all     0.8381
MRR@10: 0.7012619415049944





In [24]:
# InfoVQA

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/InfoVQA_corpus_embeddings_4x4.npy')
image_embeddings = np.load('embeddings/InfoVQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/InfoVQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/InfoVQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/InfoVQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-InfoVQA/qrels/infographicsvqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(5, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 100.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.5


100%|██████████| 2046/2046 [00:02<00:00, 876.58it/s]


ndcg_cut_10              all     0.8884
recall_10                all     0.9761
MRR@10: 0.8600166410650273
gamma: 0.6


100%|██████████| 2046/2046 [00:02<00:00, 919.39it/s]


ndcg_cut_10              all     0.8901
recall_10                all     0.9775
MRR@10: 0.8617734565315203
gamma: 0.7


100%|██████████| 2046/2046 [00:02<00:00, 913.41it/s]


ndcg_cut_10              all     0.8906
recall_10                all     0.9736
MRR@10: 0.8635772083352723
gamma: 0.8


100%|██████████| 2046/2046 [00:02<00:00, 878.40it/s]


ndcg_cut_10              all     0.8877
recall_10                all     0.9721
MRR@10: 0.8603033018355594
gamma: 0.9


100%|██████████| 2046/2046 [00:02<00:00, 863.04it/s]


ndcg_cut_10              all     0.8805
recall_10                all     0.9663
MRR@10: 0.8524637698024786
gamma: 1.0


100%|██████████| 2046/2046 [00:02<00:00, 866.95it/s]

ndcg_cut_10              all     0.8711
recall_10                all     0.9624
MRR@10: 0.8412190258964445





In [31]:
# ChartQA

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/ChartQA_corpus_embeddings_4x4.npy')
image_embeddings = np.load('embeddings/ChartQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/ChartQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/ChartQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/ChartQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(5, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 25.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.5


100%|██████████| 718/718 [00:00<00:00, 719.66it/s]


ndcg_cut_10              all     0.6126
recall_10                all     0.7312
MRR@10: 0.5754929919971701
gamma: 0.6


100%|██████████| 718/718 [00:00<00:00, 747.38it/s]


ndcg_cut_10              all     0.6189
recall_10                all     0.7298
MRR@10: 0.5837179997347127
gamma: 0.7


100%|██████████| 718/718 [00:00<00:00, 726.12it/s]


ndcg_cut_10              all     0.6271
recall_10                all     0.7409
MRR@10: 0.5912239686961132
gamma: 0.8


100%|██████████| 718/718 [00:00<00:00, 747.16it/s]


ndcg_cut_10              all     0.6317
recall_10                all     0.7409
MRR@10: 0.5971144493080425
gamma: 0.9


100%|██████████| 718/718 [00:00<00:00, 718.14it/s]


ndcg_cut_10              all     0.6305
recall_10                all     0.7382
MRR@10: 0.5964943405403011
gamma: 1.0


100%|██████████| 718/718 [00:01<00:00, 715.93it/s]

ndcg_cut_10              all     0.6204
recall_10                all     0.7242
MRR@10: 0.5873922270858202





In [22]:
# ChartQA in domain

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/ChartQA_corpus_embeddings_4x4_in_domain.npy')
image_embeddings = np.load('embeddings/ChartQA_corpus_embeddings_in_domain.npy').squeeze(1)
query_embeddings = np.load('embeddings/ChartQA_query_embeddings_in_domain.npy').squeeze(1)
corpus_ids = np.load('embeddings/ChartQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/ChartQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-ChartQA/qrels/chartqa-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(5, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 100.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.5


100%|██████████| 718/718 [00:00<00:00, 789.19it/s]


ndcg_cut_10              all     0.6299
recall_10                all     0.7521
MRR@10: 0.5912449705973382
gamma: 0.6


100%|██████████| 718/718 [00:00<00:00, 1053.92it/s]


ndcg_cut_10              all     0.6355
recall_10                all     0.7618
MRR@10: 0.5955873900163591
gamma: 0.7


100%|██████████| 718/718 [00:00<00:00, 1206.77it/s]


ndcg_cut_10              all     0.6355
recall_10                all     0.7563
MRR@10: 0.5974466109563604
gamma: 0.8


100%|██████████| 718/718 [00:00<00:00, 1436.63it/s]


ndcg_cut_10              all     0.6326
recall_10                all     0.7577
MRR@10: 0.593229097581465
gamma: 0.9


100%|██████████| 718/718 [00:00<00:00, 1467.36it/s]


ndcg_cut_10              all     0.6225
recall_10                all     0.7409
MRR@10: 0.5848609453066275
gamma: 1.0


100%|██████████| 718/718 [00:00<00:00, 1467.29it/s]

ndcg_cut_10              all     0.6131
recall_10                all     0.7298
MRR@10: 0.5761296812132461





In [41]:
# MP-DocVQA

import math
import torch
from tqdm import tqdm
import pytrec_eval
import numpy as np
import csv

def load_beir_qrels(qrels_file):  
    qrels = {}  
    try:  
        with open(qrels_file) as f:  
            tsvreader = csv.DictReader(f, delimiter="\t")  
            for row in tsvreader:  
                qid = row["query-id"]  
                pid = row["corpus-id"]  
                rel = int(row["score"])  
                if qid in qrels:  
                    qrels[qid][pid] = rel  
                else:  
                    qrels[qid] = {pid: rel}  
    except Exception as e:  
        print(f"Error loading qrels file: {e}")  
    return qrels 


local_image_embeddings = np.load('embeddings/MP_DocVQA_corpus_embeddings_4x4.npy')
image_embeddings = np.load('embeddings/MP_DocVQA_corpus_embeddings.npy')
query_embeddings = np.load('embeddings/MP_DocVQA_queries_with_instruction_embeddings.npy')
corpus_ids = np.load('embeddings/MP_DocVQA_corpus_corpus_ids.npy')
query_ids = np.load('embeddings/MP_DocVQA_queries_query_ids.npy')
qrels = load_beir_qrels('dataset/VisRAG-Ret-Test-MP-DocVQA/qrels/docvqa_mp-eval-qrels.tsv')

gamma_list = [round(0.1 * i, 1) for i in range(5, 11)]
# gamma_list = [round(0.80 + 0.01 * i, 2) for i in range(21)]

# 将numpy数组转换为PyTorch张量并移动到GPU
local_image_embeddings_tensor = torch.tensor(local_image_embeddings).cuda()
image_embeddings_tensor = torch.tensor(image_embeddings).cuda()
query_embeddings_tensor = torch.tensor(query_embeddings).cuda()

# embedding_dim = local_image_embeddings_tensor.size(-1)
# d_k = math.sqrt(embedding_dim)
# batch_size = local_image_embeddings_tensor.size(0)
# d_k = math.sqrt(batch_size)

for gamma in gamma_list:
    print(f'gamma: {gamma}') 
    run = {} 
    for q_idx, query in enumerate(tqdm(query_embeddings_tensor)):
        qid = query_ids[q_idx]
        
        scores = torch.einsum('ijk,k->ij', local_image_embeddings_tensor, query)
        
        temperature = 35.0
        scaled_scores = scores * temperature
        # scaled_scores = scores / d_k
        alpha = torch.softmax(scaled_scores, dim=1)
        local_agg = torch.einsum('ij,ijk->ik', alpha, local_image_embeddings_tensor)
        

        final_fusion = gamma * image_embeddings_tensor + (1 - gamma) * local_agg
        final_score = torch.matmul(final_fusion, query)
        
        top_k_indices = torch.argsort(final_score, descending=True)[:10]  # 取前10个
        run[qid] = {corpus_ids[idx]: float(final_score[idx].cpu().numpy()) for idx in top_k_indices}
    for cutoff in [10]:    
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"ndcg_cut.{cutoff}", f"recall.{cutoff}"})  
        eval_results = evaluator.evaluate(run)  
        
        for measure in sorted(eval_results[next(iter(eval_results))].keys()):  
            value = pytrec_eval.compute_aggregated_measure(  
                measure, [query_measures[measure] for query_measures in eval_results.values()]  
            )  
            print(f"{measure:25s}{'all':8s}{value:.4f}")  

        
        mrr = eval_mrr(qrels, run, cutoff)['all']  
        print(f'MRR@{cutoff}: {mrr}')  
    

gamma: 0.5


100%|██████████| 1879/1879 [00:02<00:00, 813.77it/s]


ndcg_cut_10              all     0.8422
recall_10                all     0.9553
MRR@10: 0.8058864897741955
gamma: 0.6


100%|██████████| 1879/1879 [00:02<00:00, 816.43it/s]


ndcg_cut_10              all     0.8452
recall_10                all     0.9558
MRR@10: 0.8098454513968082
gamma: 0.7


100%|██████████| 1879/1879 [00:02<00:00, 818.09it/s]


ndcg_cut_10              all     0.8410
recall_10                all     0.9516
MRR@10: 0.8055971599212681
gamma: 0.8


100%|██████████| 1879/1879 [00:02<00:00, 799.65it/s]


ndcg_cut_10              all     0.8372
recall_10                all     0.9462
MRR@10: 0.8021665948621773
gamma: 0.9


100%|██████████| 1879/1879 [00:02<00:00, 816.66it/s]


ndcg_cut_10              all     0.8311
recall_10                all     0.9404
MRR@10: 0.7963662704748387
gamma: 1.0


100%|██████████| 1879/1879 [00:02<00:00, 827.18it/s]

ndcg_cut_10              all     0.8133
recall_10                all     0.9292
MRR@10: 0.7764369345396483



