In [1]:
from utils.read_file import read_target_insurance_pdf, read_json

In [8]:
# Reference corpora that retrieval searches over.
INSURANCE_PATH = "競賽資料集/reference/insurance"
FAQ_PATH = "競賽資料集/reference/faq/pid_map_content.json"

# Preliminary-round example questions and their ground-truth retrieval answers.
QUESTION_PATH = "競賽資料集/dataset/preliminary/questions_example.json"
ANS_PATH = "競賽資料集/dataset/preliminary/ground_truths_example.json"

In [3]:
# Load the example questions. The JSON is expected to hold a "questions" list;
# later cells use each entry's "query" text and its "source" list of candidate ids.
question = read_json(QUESTION_PATH)["questions"]

In [4]:
from sentence_transformers import SentenceTransformer

# Load the sentence-embedding model. It is passed to get_ans, which encodes
# queries and candidate documents with it.
# NOTE: the global name is spelled "embbeded_model" and later cells refer to it by that name.
embbeded_model = SentenceTransformer('thenlper/gte-large-zh')

  from tqdm.autonotebook import tqdm, trange


In [5]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def get_ans(embedded_model, question_data: dict, documents: list):
    """Return the source id whose document is most similar to the question's query.

    Every candidate document and the query are encoded with ``embedded_model``.
    The document with the highest cosine similarity wins (top-1 retrieval).

    Args:
        embedded_model: Object exposing ``encode(list_of_texts)``, e.g. a
            ``SentenceTransformer``.
        question_data: Question dict with a ``"query"`` string and a
            ``"source"`` list of candidate ids.
        documents: Candidate texts. ``documents[k]`` is assumed to correspond
            to ``question_data["source"][k]``. TODO confirm this ordering for
            ``read_target_insurance_pdf``.

    Returns:
        The element of ``question_data["source"]`` at the best-matching index.

    Raises:
        ValueError: If ``documents`` is empty (there is nothing to rank).
    """
    if len(documents) == 0:
        raise ValueError("get_ans: no candidate documents to rank")

    document_embeddings = embedded_model.encode(documents)

    # Encode the query with the model that was passed in, not a notebook-level
    # global, so the caller's choice of model is actually honored.
    query_embedding = embedded_model.encode([question_data["query"]])

    # A single query row gives a similarity matrix of shape (1, len(documents)).
    similarities = cosine_similarity(query_embedding, document_embeddings)
    best_index = int(np.argmax(similarities[0]))
    return question_data["source"][best_index]

In [14]:

def generate_faq_retrieve(faq: dict, id: int) -> str:
    """Flatten every Q/A pair stored under one FAQ id into a single text block.

    Entries in `faq` are keyed by the string form of the id. Each entry must
    provide "question" and "answers" keys. Every pair becomes one line (labels
    問題 / 回答, with `answers` rendered via str()), and each line ends with a
    newline.

    Returns an empty string when the id maps to an empty list. Raises KeyError
    when the id is absent or an entry lacks one of the keys.
    """
    entries = faq[str(id)]
    rendered = (
        f"問題:{entry['question']}, 回答:{entry['answers']}\n"
        for entry in entries
    )
    return "".join(rendered)

def read_target_faq(faq: dict, indexs: list) -> list:
    """Render each FAQ id in `indexs` to its text block, preserving order.

    The result has one entry per id, so position k corresponds to `indexs[k]`.
    get_ans depends on this alignment when it maps the best document back to
    a source id.
    """
    return [generate_faq_retrieve(faq, faq_id) for faq_id in indexs]


def validate_faq(embbeded_model, question, start: int = 100, end: int = 150) -> float:
    """Measure top-1 retrieval accuracy on FAQ questions and print every miss.

    For each index i in ``range(start, end)``:
      1. the FAQ entries listed in ``question[i]["source"]`` are rendered to text;
      2. ``get_ans`` picks the best match;
      3. the pick is compared with the ground-truth ``"retrieve"`` id.

    Args:
        embbeded_model: Embedding model forwarded to ``get_ans``.
        question: List of question dicts (as loaded from QUESTION_PATH).
        start: First question index, inclusive. The default 100 matches the
            range this notebook uses for FAQ questions.
        end: Last question index, exclusive. Defaults to 150.

    Returns:
        Fraction of questions in the range whose prediction matched.

    Raises:
        ValueError: If ``end <= start`` (the range is empty).
    """
    if end <= start:
        raise ValueError(f"validate_faq: empty question range start={start}, end={end}")

    faq = read_json(FAQ_PATH)
    ans = read_json(ANS_PATH)
    # NOTE(review): ground truths are matched to questions by list position;
    # this assumes both JSON files list questions in the same order.
    ground_truths = ans["ground_truths"]

    correct = 0
    for i in range(start, end):
        document = read_target_faq(faq, question[i]["source"])
        predict = get_ans(embbeded_model, question[i], document)
        expected = ground_truths[i]["retrieve"]
        if predict == expected:
            correct += 1
        else:
            # Use a local instead of nesting double quotes inside the f-string;
            # that nesting only parses on Python >= 3.12.
            print(f"qid: {i + 1}, predict: {predict}, ans: {expected}")
    return correct / (end - start)


### 驗證 FAQ 資料集的檢索正確率（問題 101–150，top-1）

In [17]:
# Top-1 retrieval accuracy over question indices 100-149 (qid 101-150).
# Misses are printed, and the accuracy is the cell's output value.
validate_faq(embbeded_model, question)

qid: 109, predict: 263, ans: 283
qid: 135, predict: 399, ans: 28


0.96

In [18]:
def validate_insurance(embbeded_model, question, start: int = 0, end: int = 50) -> float:
    """Measure top-1 retrieval accuracy on insurance questions and print every miss.

    For each index i in ``range(start, end)``:
      1. the insurance documents listed in ``question[i]["source"]`` are loaded
         with ``read_target_insurance_pdf``;
      2. ``get_ans`` picks the best match;
      3. the pick is compared with the ground-truth ``"retrieve"`` id.

    Args:
        embbeded_model: Embedding model forwarded to ``get_ans``.
        question: List of question dicts (as loaded from QUESTION_PATH).
        start: First question index, inclusive. The default 0 matches the
            range this notebook uses for insurance questions.
        end: Last question index, exclusive. Defaults to 50.

    Returns:
        Fraction of questions in the range whose prediction matched.

    Raises:
        ValueError: If ``end <= start`` (the range is empty).
    """
    if end <= start:
        raise ValueError(f"validate_insurance: empty question range start={start}, end={end}")

    ans = read_json(ANS_PATH)
    # NOTE(review): ground truths are matched to questions by list position;
    # this assumes both JSON files list questions in the same order.
    ground_truths = ans["ground_truths"]

    correct = 0
    for i in range(start, end):
        src = question[i]["source"]
        # get_ans maps the best document index back into `src`, so this reader
        # is assumed to return one document per id, in `src` order. TODO confirm.
        document = read_target_insurance_pdf(INSURANCE_PATH, src)
        predict = get_ans(embbeded_model, question[i], document)
        expected = ground_truths[i]["retrieve"]
        if predict == expected:
            correct += 1
        else:
            # Use a local instead of nesting double quotes inside the f-string;
            # that nesting only parses on Python >= 3.12.
            print(f"qid: {i + 1}, predict: {predict}, ans: {expected}")
    return correct / (end - start)

### 驗證 insurance 資料集的檢索正確率（問題 1–50，top-1）

In [19]:
# Top-1 retrieval accuracy over question indices 0-49 (qid 1-50).
# Misses are printed, and the accuracy is the cell's output value.
validate_insurance(embbeded_model, question)

qid: 1, predict: 51, ans: 392
qid: 2, predict: 578, ans: 428
qid: 5, predict: 179, ans: 162
qid: 6, predict: 338, ans: 116
qid: 8, predict: 244, ans: 78
qid: 9, predict: 148, ans: 62
qid: 11, predict: 357, ans: 7
qid: 14, predict: 149, ans: 526
qid: 15, predict: 337, ans: 536
qid: 16, predict: 325, ans: 54
qid: 17, predict: 353, ans: 606
qid: 18, predict: 439, ans: 184
qid: 19, predict: 172, ans: 315
qid: 20, predict: 578, ans: 292
qid: 21, predict: 392, ans: 36
qid: 22, predict: 439, ans: 614
qid: 24, predict: 625, ans: 359
qid: 25, predict: 591, ans: 4
qid: 26, predict: 32, ans: 147
qid: 27, predict: 414, ans: 171
qid: 28, predict: 258, ans: 298
qid: 29, predict: 527, ans: 578
qid: 30, predict: 483, ans: 327
qid: 31, predict: 486, ans: 10
qid: 32, predict: 351, ans: 434
qid: 36, predict: 63, ans: 148
qid: 37, predict: 414, ans: 148
qid: 38, predict: 256, ans: 353
qid: 39, predict: 578, ans: 29
qid: 41, predict: 14, ans: 165
qid: 42, predict: 365, ans: 325
qid: 43, predict: 322, ans: 

0.26