In [5]:
from langchain_chroma import Chroma
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import pandas as pd
import json
from tqdm import tqdm, trange
from retiever_eval_list import get_result_retrieva, get_retriever_res_list
import os
import shutil
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Q/A evaluation set; its columns are listed by the column_lists printout further down.
df = pd.read_excel('soybean_q_gt_609.xlsx', sheet_name='Sheet1')

# --- Experiment settings (adjust per run) ---
folder_path = 'Result/a_adcance_rag/'
# Both output files are written inside folder_path.
retriever_filename = os.path.join(folder_path, "retriever_result.json")
save_info_result_filename = os.path.join(folder_path, "save_info_result.json")
top_k = 20    # number of documents the vector retriever returns per question
s_index = 10  # only the first s_index questions are processed

# --- Dense retriever: local embedding model + persisted Chroma store ---
model_name = '/mnt/workspace/.cache/modelscope/hub/maple77/zpoint_large_embedding_zh'
model_kwargs = dict(device='cpu')
encode_kwargs = dict(normalize_embeddings=True)
hf = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
vectorstore = Chroma(embedding_function=hf, persist_directory="soybean_db2")
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs=dict(k=top_k),
)

# --- Cross-encoder reranker: tokenizer and model come from the same local checkpoint ---
reranker_path = '/mnt/workspace/.cache/modelscope/hub/Xorbits/bge-reranker-base'
tokenizer = AutoTokenizer.from_pretrained(reranker_path)
rerank_model = AutoModelForSequenceClassification.from_pretrained(reranker_path)



def create_folder_if_not_exists(folder_path):
    """Make ``folder_path`` an empty directory, discarding any previous contents.

    Despite the name, an existing folder is not reused: it is removed together
    with everything inside it and then recreated from scratch. Missing parent
    directories are created as well. A status line is printed for each step.

    Args:
        folder_path: Path of the directory to (re)create.
    """
    already_present = os.path.exists(folder_path)
    if already_present:
        # Destructive: wipes the old folder and every file below it.
        shutil.rmtree(folder_path)
        print(f"Folder '{folder_path}' existed and has been removed.")

    os.makedirs(folder_path)
    print(f"Folder '{folder_path}' has been created.")

# Convert every DataFrame column into a plain Python list keyed by column name,
# then show which columns are available.
column_lists = {}
for col_name in df.columns:
    column_lists[col_name] = df[col_name].tolist()
print(column_lists.keys())

dict_keys(['id', 'source', 'page', 'question', 'ground_truth', 'context'])


In [2]:
# Query the vector retriever once per evaluated question; each entry holds that
# question's top_k retrieved documents, in the same order as the questions.
questions = column_lists['question'][:s_index]
retriever_result = []
for question in tqdm(questions, desc='Get retriever result'):
    retriever_result.append(retriever.invoke(question))

Get retriever result: 100%|██████████| 10/10 [06:22<00:00, 38.24s/it]


In [4]:
# Rerank each question's retrieved documents with the cross-encoder and keep them
# ordered from highest to lowest reranker score. question_rerank_result[i] is a
# tuple of Documents for questions[i].
questions = column_lists['question'][:s_index]
question_rerank_result = []
for query_idx, question in enumerate(tqdm(questions, desc='Rerank result')):
    docs = retriever_result[query_idx]
    if not docs:
        # Nothing retrieved: tokenizing an empty batch and unpacking an empty
        # zip(*...) would both fail, so record an empty ranking instead.
        question_rerank_result.append(())
        continue

    pairs = [[question, doc.page_content] for doc in docs]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        # Flatten the logits into plain Python floats; assumes the reranker emits
        # a single logit per (question, passage) pair -- TODO confirm num_labels == 1.
        scores = rerank_model(**inputs, return_dict=True).logits.view(-1).float().tolist()

    # Sort on the score only. Sorting (score, Document) tuples directly falls back to
    # comparing Document objects whenever two scores tie, which raises TypeError.
    # sorted() is stable, so tied documents keep their original retrieval order.
    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)
    question_rerank_result.append(tuple(doc for _, doc in ranked))

Rerank result: 100%|██████████| 10/10 [04:03<00:00, 24.31s/it]


In [6]:
# Convert the reranked documents into the evaluation format (via the project helper).
col_id = column_lists['id'][:s_index]
rerank_result = get_result_retrieva(col_id, question_rerank_result, topk=top_k)

# Start from an empty output folder: this deletes anything already in folder_path,
# including results from earlier runs.
create_folder_if_not_exists(folder_path)

# Write the reranked results as pretty-printed UTF-8 JSON (non-ASCII kept as-is).
serialized = json.dumps(rerank_result, ensure_ascii=False, indent=4)
with open(retriever_filename, 'w', encoding='utf-8') as out_file:
    out_file.write(serialized)
print('完成!')

Folder 'Result/a_adcance_rag/' has been created.
完成!


In [7]:
# Bundle the evaluated slice of the dataset with the retrieved documents so that a
# later step can compare them against the ground truth.
save_info_result = {}
for key in ('id', 'source', 'page', 'question', 'ground_truth', 'context'):
    save_info_result[key] = column_lists[key][:s_index]

# NOTE(review): this stores the raw vector-search ranking (`retriever_result`), not
# the cross-encoder ranking (`question_rerank_result`) that the earlier cell writes to
# retriever_filename. Confirm this is intended for the reranking experiment.
save_info_result['retriever_result_list'] = get_retriever_res_list(retriever_result, top_k)

# Write as pretty-printed UTF-8 JSON (non-ASCII characters kept as-is).
with open(save_info_result_filename, 'w', encoding='utf-8') as out_file:
    json.dump(save_info_result, out_file, ensure_ascii=False, indent=4)

print('完成!')

完成!
