In [1]:
from langchain.document_loaders import PyPDFLoader
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
import tiktoken
# Count the tokens of the whole PDF, then cut it into overlapping chunks and
# tag every chunk with a running integer id stored in its metadata.
token_count = 0
# eb99860b-0909-4446-8ee6-85066f75bfc8
# NOTE(review): the recorded run of this cell failed with "not a valid file or
# url" for this path -- check the spelling of the file name ("konw").
documents = PyPDFLoader("soybean_konw.pdf").load()

# The encoder does not depend on the page, so build it once outside the loop.
encoder = tiktoken.get_encoding("gpt2")  # or another suitable encoding
for page in documents:
    # Encode the current page. The previous version always encoded
    # documents[0], so the total was (tokens of page 0) * (number of pages).
    tokens = encoder.encode(page.page_content)
    token_count += len(tokens)
print(token_count)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=200)
texts = text_splitter.split_documents(documents)
# The chunk id is the chunk's position in `texts`.
for idx, text in enumerate(texts):
    text.metadata["id"] = idx

ValueError: File path soybean_konw.pdf is not a valid file or url

In [3]:
from langchain_chroma import Chroma
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

# Number of chunks the retriever returns for each query.
top_k = 20

# Local embedding model, run on CPU with normalized output vectors.
model_name = '/mnt/workspace/.cache/modelscope/hub/maple77/zpoint_large_embedding_zh'
model_kwargs = dict(device='cpu')
encode_kwargs = dict(normalize_embeddings=True)
hf = HuggingFaceBgeEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

# Open the persisted Chroma collection and expose it as a similarity retriever.
vectorstore = Chroma(embedding_function=hf, persist_directory="soybean_db")
retriever = vectorstore.as_retriever(search_kwargs={"k": top_k}, search_type="similarity")

  from tqdm.autonotebook import tqdm, trange


In [98]:
import pandas as pd

df = pd.read_excel(
    '../learn_rag/soybean_question_500.xlsx',
    sheet_name='Sheet1',
)

# Convert every column into a plain Python list keyed by its header,
# then show which columns the sheet provides.
column_lists = dict((header, df[header].tolist()) for header in df.columns)
print(column_lists.keys())

dict_keys(['id', 'question', 'context'])


In [99]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Load the tokenizer and the sequence-classification model from the same local
# bge-reranker-base directory; both are used together to score
# (question, passage) pairs in the re-ranking cell.
tokenizer = AutoTokenizer.from_pretrained('/mnt/workspace/.cache/modelscope/hub/Xorbits/bge-reranker-base')
rerank_model = AutoModelForSequenceClassification.from_pretrained('/mnt/workspace/.cache/modelscope/hub/Xorbits/bge-reranker-base')

In [100]:
# Retrieve candidates for the first two questions only (a quick smoke run).
# retriever_result[q] holds the documents returned for question q; later cells
# read each document's .page_content and .metadata['id'].
retriever_result = retriever.batch(column_lists['question'][:2])

In [101]:
import torch
# Re-rank every query's retrieved documents with the cross-encoder.
# question_rerank_result[q] is a tuple holding the documents of
# retriever_result[q], ordered from the highest to the lowest reranker score.
questions = column_lists['question'][:2]  # slice once instead of per pair
question_rerank_result = []
for query, candidates in zip(questions, retriever_result):
    if not candidates:
        # Nothing was retrieved: there is nothing to tokenize or sort, and
        # unpacking zip(*[]) below would raise ValueError.
        question_rerank_result.append(())
        continue

    # One (question, passage) pair per retrieved document.
    pairs = [[query, doc.page_content] for doc in candidates]

    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        scores = rerank_model(**inputs, return_dict=True).logits.view(-1).float()

    # Sort on the score alone. Sorting the raw (score, document) tuples makes
    # Python compare the documents themselves whenever two scores are equal,
    # which raises TypeError for distinct documents. Plain floats are used as
    # keys; sorted() is stable, so ties keep their retrieval order.
    combined = sorted(
        zip(scores.tolist(), candidates),
        key=lambda scored: scored[0],
        reverse=True,
    )
    scores_rerank_list, retri_rerank_list = zip(*combined)
    question_rerank_result.append(retri_rerank_list)

In [None]:
# retriever_result
# question_rerank_result[0]

In [105]:
from retiever_eval_list import get_ht_score, get_mmr_score

def get_retriever_res_list(document_list, top_k = 20):
    """Return the chunk ids of the first ``top_k`` documents of every query.

    Args:
        document_list: One sequence of documents per query. Every document
            must expose a ``metadata`` mapping that contains an ``'id'`` key.
        top_k: How many leading documents to keep per query.

    Returns:
        A list with one inner list per query, holding the ``metadata['id']``
        values of that query's first ``top_k`` documents in their given order.
    """
    return [
        [doc.metadata['id'] for doc in ranked_docs[:top_k]]
        for ranked_docs in document_list
    ]

# Show the ranked chunk ids per question: first the re-ranked order, then the
# retriever's original order, so the two rankings can be compared by eye.
print(get_retriever_res_list(question_rerank_result))
print(get_retriever_res_list(retriever_result))

[[0, 0, 66, 106, 5, 57, 94, 19, 19, 83, 34, 30, 14, 14, 78, 115, 97, 110, 91, 114], [105, 13, 13, 93, 115, 0, 0, 73, 4, 4, 97, 82, 65, 110, 109, 63, 101, 74, 75, 94]]
[[115, 0, 0, 94, 57, 83, 78, 106, 34, 19, 19, 14, 14, 66, 30, 91, 97, 114, 110, 5], [115, 73, 65, 63, 4, 4, 82, 105, 94, 97, 74, 75, 110, 101, 0, 0, 93, 109, 13, 13]]


In [104]:
from retiever_eval_list import get_ht_score, get_mmr_score
# Never filled in this cell; kept so the name stays defined for other cells.
retiever_evaldict = {}

# Score both rankings at cutoffs 1..8 against the gold ids of the first two
# questions. get_ht_score / get_mmr_score come from retiever_eval_list
# (presumably a hit-rate and a reciprocal-rank style metric -- confirm there).
for cutoff in range(1, 9):
    print('top_', cutoff)
    # Both id lists are built before any scoring; the raw retriever ranking is
    # reported first, then the re-ranked one under a 'rerank' heading.
    rankings = (
        (None, get_retriever_res_list(retriever_result, cutoff)),
        ('rerank', get_retriever_res_list(question_rerank_result, cutoff)),
    )
    for heading, ranked_ids in rankings:
        if heading is not None:
            print(heading)
        ht_score = get_ht_score(column_lists['id'][:2], ranked_ids)
        mmr_score = get_mmr_score(column_lists['id'][:2], ranked_ids)
        print('ht:', round(ht_score, 3))
        print('mmr_score:', round(mmr_score, 3))

top_ 1
ht: 0.0
mmr_score: 0.0
rerank
ht: 0.5
mmr_score: 0.5
top_ 2
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
top_ 3
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
top_ 4
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
top_ 5
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
top_ 6
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
top_ 7
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
top_ 8
ht: 0.5
mmr_score: 0.25
rerank
ht: 0.5
mmr_score: 0.5
