In [1]:
import json

from langchain.text_splitter import RecursiveCharacterTextSplitter

from utils.read_file import *
from utils.model_tool import get_ans, use_reranker
from utils.split_law import split_law_to_token


In [2]:


# Locations of the reference documents and the example ground truths.
INSURANCE_PATH = "競賽資料集/reference/insurance"
FAQ_PATH = "競賽資料集/reference/faq/pid_map_content.json"
ANS_PATH = "競賽資料集/dataset/preliminary/ground_truths_example.json"
FINANCE_PATH = "競賽資料集/reference/finance"
# Length of the `already_read_*` lists created before the prediction loop.
FILE_NUM = 2000

# NOTE(review): the first assignment is dead -- it is overwritten on the very next
# line, so only "questions_preliminary.json" is ever used.
QUESTION_PATH = "競賽資料集/dataset/preliminary/questions_example.json"
QUESTION_PATH = "questions_preliminary.json"

In [3]:
from sentence_transformers import SentenceTransformer
from FlagEmbedding import FlagReranker

# Load the embedding model from the local ./models/bge-m3 directory onto the GPU.
embbeded_model = SentenceTransformer("./models/bge-m3", device='cuda')
# The reranker is currently disabled.
#reranker = FlagReranker('BAAI/bge-reranker-v2-m3')

  from tqdm.autonotebook import tqdm, trange


In [4]:

def save_predictions_to_json(predictions, filename="pred_retrieve.json"):
    """Dump ``predictions`` to ``filename`` as JSON, wrapped as ``{"answers": predictions}``.

    The file is written as UTF-8 with 4-space indentation and non-ASCII
    characters kept verbatim. A confirmation line is printed afterwards.
    """
    with open(filename, mode="w", encoding="utf-8") as json_file:
        json.dump({"answers": predictions}, json_file, ensure_ascii=False, indent=4)
    print(f"Predictions saved to {filename}")

def gernerate_faq_ans(embbeded_model, question, i: int):
    """Return the FAQ source id predicted for question ``i``.

    The FAQ file at ``FAQ_PATH`` is loaded, narrowed by ``read_target_faq`` to
    the candidates listed in ``question[i]["source"]``, and ranked by
    ``get_ans``. The first returned index is used to look up the answer in
    ``question[i]["source"]``.
    """
    faq_table = read_json(FAQ_PATH)
    candidate_ids = question[i]["source"]
    candidate_docs = read_target_faq(faq_table, candidate_ids)
    # get_ans returns a pair; only the ranked indices are needed here.
    ranked, _ = get_ans(embbeded_model, question[i], candidate_docs)
    best_position = ranked[0]
    return question[i]["source"][best_position]
        
    
    

def gernerate_insurance_ans(embedded_model, question, i: int, already_read: list[str]):
    """Return the insurance source id predicted for question ``i``.

    Each document returned by ``read_target_insurance_pdf`` for the ids in
    ``question[i]["source"]`` is cut into chunks with ``split_law_to_token``.
    The chunks are ranked by ``get_ans`` (asking for 10 results) and the source
    id owning the first-ranked chunk is returned.

    ``already_read`` is passed through to ``read_target_insurance_pdf``
    unchanged (presumably a cache of previously loaded texts -- confirm in
    utils.read_file).
    """
    source_ids = np.array(question[i]["source"])
    documents = read_target_insurance_pdf(INSURANCE_PATH, source_ids, already_read)

    all_chunks = []
    chunk_owner = []  # chunk_owner[n] is the source id that all_chunks[n] came from
    for doc_pos in range(len(documents)):
        pieces = split_law_to_token(documents[doc_pos])
        all_chunks += pieces
        chunk_owner += [source_ids[doc_pos] for _ in range(len(pieces))]
    chunk_owner = np.array(chunk_owner)

    ranked, _ = get_ans(embedded_model, question[i], all_chunks, 10)
    ranked_sources = chunk_owner[ranked]
    return ranked_sources[0]





In [5]:
import re
import jieba.analyse
def remove_useless(text):
    """Collapse numbers in ``text`` to ``1`` and strip its spaces.

    A number is either comma-grouped (``1,234,567``) or plain, optionally
    signed and/or decimal (``-12.5``). It is only matched when it sits on word
    boundaries, so digits glued to letters or other word characters
    (``abc123``) are left untouched. After the substitution every ASCII space
    character is removed.
    """
    number_re = re.compile(r"(\b\d{1,3}(,\d{3})+\b)|(-?\b\d+(\.\d+)?\b)")
    without_numbers = number_re.sub('1', text)
    return str(without_numbers).replace(' ', '')

def find_most_correct(top10: list[str], query: str, all_predict, similarity: np.array):
    """Re-rank retrieved chunks by keyword overlap weighted by similarity.

    Up to 10 keywords are extracted from ``query`` with
    ``jieba.analyse.extract_tags``. Each candidate chunk is scored as
    ``(number of keywords it contains) * similarity`` and the entry of
    ``all_predict`` belonging to the best-scoring chunk is returned.

    Args:
        top10: Candidate chunk texts.
        query: The question text the keywords are extracted from.
        all_predict: Sequence aligned with ``top10``; ``all_predict[n]`` is the
            value to return when ``top10[n]`` wins.
        similarity: Sequence aligned with ``top10`` holding each chunk's
            similarity to the query.

    Returns:
        The element of ``all_predict`` for the winning chunk.

    Raises:
        IndexError: If ``top10`` is empty.
    """
    if len(top10) == 0:
        raise IndexError("find_most_correct() needs at least one candidate chunk")

    # Digit-only keywords are ignored, as before.
    keywords = [kw for kw in jieba.analyse.extract_tags(query, topK=10) if not kw.isdigit()]

    similarities = [float(similarity[idx]) for idx in range(len(top10))]
    scores = []
    for idx, chunk in enumerate(top10):
        # Literal substring test: keywords are plain text, not regex patterns, so
        # tokens such as "c++" must not be interpreted (or crash) as a regex.
        hits = sum(1 for kw in keywords if kw in chunk)
        scores.append(hits * similarities[idx])

    # Ties on the keyword score (most importantly "no keyword matched anywhere",
    # where every score is 0) fall back to the higher similarity, then to the
    # earlier candidate. max() keeps the first of equal maxima.
    best = max(range(len(top10)), key=lambda idx: (scores[idx], similarities[idx]))
    return all_predict[best]

# Splitter that gernerate_finance_ans uses to cut document text into overlapping chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # target length of each chunk
    chunk_overlap=300  # number of characters overlapping between adjacent chunks
)
        

In [6]:


def gernerate_finance_ans(embbeded_model, question, i: int, already_read: list[str]):
    """Return the finance source id predicted for question ``i``.

    Each document returned by ``read_target_insurance_pdf`` (called with
    ``FINANCE_PATH``) for the ids in ``question[i]["source"]`` is cleaned with
    ``remove_useless`` and split by the module-level ``text_splitter``.
    ``get_ans`` ranks the chunks (asking for 10 results). The ranked chunks,
    their source ids and their similarities are then handed to
    ``find_most_correct``, whose choice is returned.
    """
    query = question[i]["query"]
    source_ids = np.array(question[i]["source"])
    documents = read_target_insurance_pdf(FINANCE_PATH, source_ids, already_read)

    all_chunks = []
    chunk_owner = []  # chunk_owner[n] is the source id that all_chunks[n] came from
    for doc_pos in range(len(documents)):
        cleaned = remove_useless(documents[doc_pos])
        pieces = text_splitter.split_text(cleaned)
        all_chunks += pieces
        chunk_owner += [source_ids[doc_pos] for _ in range(len(pieces))]
    chunk_owner = np.array(chunk_owner)

    ranked, similarities = get_ans(embbeded_model, question[i], all_chunks, 10)
    ranked_sources = chunk_owner[ranked]
    ranked_texts = [all_chunks[pos] for pos in ranked]
    return find_most_correct(ranked_texts, query, ranked_sources, similarities)


In [7]:

# One list of FILE_NUM empty strings per document category; each is passed to the
# matching gernerate_*_ans call as `already_read`.
already_read_insurance = [""] * FILE_NUM
already_read_finance = [""] * FILE_NUM

# Questions to answer, and the list the prediction loop appends to.
question = read_json(QUESTION_PATH)["questions"]
results = []

In [10]:
# Number of predictions collected so far (one dict is appended per question).
len(results)

553

In [16]:

# Resume from wherever `results` currently ends instead of a hard-coded index.
# After an interruption the cell simply continues with the next unanswered
# question; on a fresh kernel (empty `results`) it starts at question 0, so no
# question is skipped or answered twice.
for i in range(len(results), len(question)):
    category = question[i]['category']
    predict = 0  # fallback recorded when the category matches none of the branches
    if category == "insurance":
        predict = gernerate_insurance_ans(embbeded_model, question, i, already_read_insurance)
    elif category == "finance":
        predict = gernerate_finance_ans(embbeded_model, question, i, already_read_finance)
    elif category == "faq":
        predict = gernerate_faq_ans(embbeded_model, question, i)
    # Appended only after the prediction succeeded, so len(results) always equals
    # the index of the next question to process.
    results.append({"qid": i + 1, "retrieve": int(predict)})

KeyboardInterrupt: 

: 

In [14]:
# NOTE(review): this bare expression is not the last statement of the cell, so its
# value is not displayed; only the print below produces output.
len(results)

# NOTE(review): prints the entire predictions list, which floods the cell output.
print(results)

[{'qid': 1, 'retrieve': 388}, {'qid': 2, 'retrieve': 600}, {'qid': 3, 'retrieve': 600}, {'qid': 4, 'retrieve': 81}, {'qid': 5, 'retrieve': 30}, {'qid': 6, 'retrieve': 30}, {'qid': 7, 'retrieve': 30}, {'qid': 8, 'retrieve': 641}, {'qid': 9, 'retrieve': 419}, {'qid': 10, 'retrieve': 271}, {'qid': 11, 'retrieve': 259}, {'qid': 12, 'retrieve': 290}, {'qid': 13, 'retrieve': 490}, {'qid': 14, 'retrieve': 251}, {'qid': 15, 'retrieve': 560}, {'qid': 16, 'retrieve': 512}, {'qid': 17, 'retrieve': 328}, {'qid': 18, 'retrieve': 403}, {'qid': 19, 'retrieve': 378}, {'qid': 20, 'retrieve': 271}, {'qid': 21, 'retrieve': 523}, {'qid': 22, 'retrieve': 264}, {'qid': 23, 'retrieve': 42}, {'qid': 24, 'retrieve': 507}, {'qid': 25, 'retrieve': 328}, {'qid': 26, 'retrieve': 250}, {'qid': 27, 'retrieve': 192}, {'qid': 28, 'retrieve': 600}, {'qid': 29, 'retrieve': 380}, {'qid': 30, 'retrieve': 96}, {'qid': 31, 'retrieve': 567}, {'qid': 32, 'retrieve': 118}, {'qid': 33, 'retrieve': 523}, {'qid': 34, 'retrieve': 

In [15]:
# Write the collected predictions to new.json (wrapped under the "answers" key).
save_predictions_to_json(results, "new.json")

Predictions saved to new.json


In [13]:
#check

# ans = read_json(ANS_PATH)['ground_truths']
# pred_retrieve = read_json('pred_retrieve.json')['answers']
# correct = 0
# total = 0
# for i in range(len(ans)):
#     if ans[i]['retrieve'] != pred_retrieve[i]['retrieve']:
#         print(f"qid: {i + 1}, predict: {pred_retrieve[i]['retrieve']}, ans: {ans[i]['retrieve']}")
#     else:
#         correct += 1
#     total += 1
# print(correct / total)

qid: 2, predict: 258, ans: 428
qid: 4, predict: 179, ans: 186
qid: 50, predict: 555, ans: 78
qid: 59, predict: 639, ans: 632
qid: 64, predict: 706, ans: 124
qid: 72, predict: 497, ans: 204
qid: 80, predict: 386, ans: 213
qid: 86, predict: 757, ans: 189
qid: 94, predict: 481, ans: 699
qid: 97, predict: 579, ans: 282
qid: 109, predict: 234, ans: 283
qid: 135, predict: 399, ans: 28
0.92
