In [1]:
from utils.read_file import *

In [2]:
# Example questions for the preliminary round; the JSON's "questions" entry is
# loaded in the next cell. All paths below are relative to the working directory.
QUESTION_PATH = "競賽資料集/dataset/preliminary/questions_example.json"

# Reference corpora passed to the utils readers. NOTE(review): INSURANCE_PATH is
# presumably a directory of PDFs (see read_target_insurance_pdf) — confirm.
INSURANCE_PATH = "競賽資料集/reference/insurance"
FAQ_PATH = "競賽資料集/reference/faq/pid_map_content.json"
# Ground truths: the validators read ans["ground_truths"][i]["retrieve"].
ANS_PATH = "競賽資料集/dataset/preliminary/ground_truths_example.json"

In [3]:
# Indexable list of question records; later cells read each record's
# "query" (text to embed) and "source" (candidate document ids) fields.
question = read_json(QUESTION_PATH)["questions"]

In [4]:
# from huggingface_hub import snapshot_download

# # 下載模型到指定路徑
# local_model_path = "./models/bge-m3"
# snapshot_download(repo_id="BAAI/bge-m3", revision="main", local_dir=local_model_path)

In [5]:
from sentence_transformers import SentenceTransformer

# Load the local bge-m3 checkpoint (downloaded by the commented-out cell above).
# device=None lets sentence-transformers use CUDA when it is available and fall
# back to CPU otherwise, instead of failing on machines without a GPU.
embbeded_model = SentenceTransformer("./models/bge-m3", device=None)

  from tqdm.autonotebook import tqdm, trange


In [31]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def get_ans(embedded_model, question_data: dict, documents: list, k=1):
    """Rank ``documents`` by cosine similarity to ``question_data["query"]``.

    Args:
        embedded_model: Object exposing ``encode(list_of_texts)`` that returns
            one embedding row per text (e.g. the SentenceTransformer loaded above).
        question_data: Question record; only its ``"query"`` field is read.
        documents: Candidate texts to rank.
        k: Number of top-ranked documents to return.

    Returns:
        ``(indices, scores)``: positions into ``documents`` of the ``k`` most
        similar texts, best first, and their cosine similarities in the same
        order. Both are empty arrays when ``documents`` is empty or ``k <= 0``.
    """
    if len(documents) == 0 or k <= 0:
        return np.array([], dtype=int), np.array([], dtype=float)

    document_embeddings = embedded_model.encode(documents)
    # Encode the query with the model that was passed in. The previous code used
    # the notebook-global ``embbeded_model`` here, silently ignoring the argument.
    query_embedding = embedded_model.encode([question_data["query"]])

    scores = cosine_similarity(query_embedding, document_embeddings)[0]
    # Reverse the ascending argsort so the best match comes first, then take k.
    # Slicing the reversed order (not ``[-k:]``) keeps k=0 from selecting everything.
    top_indices = np.argsort(scores)[::-1][:k]
    return top_indices, scores[top_indices]

In [7]:



def validate_faq(embbeded_model, question, start=100, stop=150):
    """Measure top-1 FAQ retrieval accuracy over ``question[start:stop]``.

    For each question, the FAQ entries listed in its ``"source"`` are ranked
    with ``get_ans`` and the best one is compared with the ground-truth
    ``"retrieve"`` id. Each mismatch is printed.

    Args:
        embbeded_model: Embedding model forwarded to ``get_ans``.
        question: Question records (see the loading cell).
        start, stop: Half-open index range of the FAQ questions to evaluate;
            the defaults reproduce the original hard-coded 100..149.

    Returns:
        Fraction of evaluated questions predicted correctly.

    Raises:
        ValueError: If the range is empty.
    """
    total = stop - start
    if total <= 0:
        raise ValueError(f"empty question range: start={start}, stop={stop}")

    faq = read_json(FAQ_PATH)
    ground_truths = read_json(ANS_PATH)["ground_truths"]
    correct = 0
    for i in range(start, stop):
        sources = question[i]["source"]
        # NOTE(review): assumes read_target_faq returns one text per id, in
        # ``sources`` order, so a ranked position maps back to an id — confirm.
        document = read_target_faq(faq, sources)
        k_highest, _ = get_ans(embbeded_model, question[i], document)
        predict = sources[k_highest[0]]
        expected = ground_truths[i]["retrieve"]
        if predict == expected:
            correct += 1
        else:
            # Bind the value first: a nested same-quote f-string is a
            # SyntaxError before Python 3.12.
            print(f"qid: {i + 1}, predict: {predict}, ans: {expected}")
    return correct / total

### 驗證faq資料集的正確率

In [8]:
# Top-1 FAQ retrieval accuracy on question indices 100-149; mismatches are printed.
validate_faq(embbeded_model, question)

In [34]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Splitter used to chunk insurance documents before embedding: 800-character
# target chunks, with consecutive chunks overlapping by 400 characters (half a chunk).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=400,  # characters shared between consecutive chunks
    chunk_size=800,  # target length of each chunk
)


def vote_answer(all_predict: np.ndarray, similarities: np.ndarray, k=3, power=2, default=0)->int:
    """Pick a document id by similarity-weighted voting over the top-``k`` hits.

    Each of the first ``k`` (id, similarity) pairs adds
    ``max(similarity, 0) ** power`` to that id's total. The id with the largest
    total wins.

    Args:
        all_predict: Candidate ids, one per retrieved chunk, best first.
        similarities: Similarity of each chunk, aligned with ``all_predict``.
        k: How many leading hits take part in the vote.
        power: Exponent applied to each (non-negative) similarity; larger
            values favour a single very similar hit over several weaker ones.
        default: Returned when there is nothing to vote on (empty input or k <= 0).

    Returns:
        The winning id. On ties, the id that first appears in ``all_predict`` wins.
    """
    totals = {}
    for predict, similarity in list(zip(all_predict, similarities))[:max(k, 0)]:
        # Clip at zero before raising to ``power``: squaring a negative cosine
        # similarity would otherwise turn a dissimilar hit into positive support.
        weight = max(float(similarity), 0.0) ** power
        totals[predict] = totals.get(predict, 0.0) + weight

    if not totals:
        return default
    # max() returns the first key with the maximal total in insertion order,
    # i.e. the better-ranked id. This also holds when every weight is zero,
    # instead of returning a sentinel that is not a candidate.
    return max(totals, key=totals.get)

    

def _chunk_documents(texts, splitter):
    """Split each text into chunks; return (chunks, index of the owning text per chunk)."""
    chunks = []
    owners = []
    for text_index, text in enumerate(texts):
        pieces = splitter.split_text(text)
        chunks.extend(pieces)
        owners.extend([text_index] * len(pieces))
    return chunks, np.array(owners, dtype=int)


def validate_insurance(embbeded_model, question, start=0, stop=50, top_k=5, vote_k=None):
    """Evaluate chunk-level retrieval on the insurance questions and print the results.

    For each question, the documents listed in its ``"source"`` are split with
    ``text_splitter`` and the ``top_k`` most similar chunks are retrieved with
    ``get_ans``. Each chunk is mapped back to its document id. Prints every
    top-1 mismatch, then the top-1 accuracy and the rate at which the ground
    truth appears among the ``top_k`` ids.

    Args:
        embbeded_model: Embedding model forwarded to ``get_ans``.
        question: Question records (see the loading cell).
        start, stop: Half-open index range to evaluate; the defaults reproduce
            the original hard-coded 0..49.
        top_k: Number of chunks retrieved per question.
        vote_k: If None (default), the top-ranked chunk decides the prediction.
            Otherwise the prediction is ``vote_answer`` over the first
            ``vote_k`` hits.

    Raises:
        ValueError: If the range is empty.
    """
    total = stop - start
    if total <= 0:
        raise ValueError(f"empty question range: start={start}, stop={stop}")

    ground_truths = read_json(ANS_PATH)["ground_truths"]
    correct = 0
    half_correct = 0  # ground truth appears among the top_k predicted ids

    for i in range(start, stop):
        src = question[i]["source"]
        document, order_list = read_target_insurance_pdf(INSURANCE_PATH, src)
        tokens, chunk_real_index = _chunk_documents(document, text_splitter)

        token_index, similarities = get_ans(embbeded_model, question[i], tokens, top_k)
        # chunk -> owning document position -> document id.
        # NOTE(review): fancy indexing assumes order_list is a numpy array — confirm.
        all_predict = order_list[chunk_real_index[token_index]]
        if vote_k is None:
            predict = all_predict[0]
        else:
            predict = vote_answer(all_predict, similarities, vote_k)

        expected = ground_truths[i]["retrieve"]
        if predict == expected:
            correct += 1
        else:
            # Bind the value first: a nested same-quote f-string is a
            # SyntaxError before Python 3.12.
            print(f"qid: {i + 1}, predict: {predict}, ans: {expected}")
        if expected in all_predict:
            half_correct += 1

    print(f"acc: {correct / total * 100} %")
    print(f"in rank {top_k}: {half_correct / total * 100} %")

### 驗證insurance資料集的正確率

In [35]:
# Top-1 accuracy and top-5 hit rate for insurance question indices 0-49; mismatches are printed.
validate_insurance(embbeded_model, question)

qid: 2, predict: 258, ans: 428
qid: 3, predict: 476, ans: 83
qid: 4, predict: 179, ans: 186
qid: 6, predict: 242, ans: 116
qid: 15, predict: 526, ans: 536
qid: 23, predict: 228, ans: 224
qid: 28, predict: 78, ans: 298
qid: 29, predict: 256, ans: 578
qid: 35, predict: 224, ans: 319
qid: 42, predict: 445, ans: 325
qid: 49, predict: 357, ans: 482
qid: 50, predict: 555, ans: 78
acc: 76.0 %
in rank 5: 90.0 %
