In [200]:
# import
import re
import json
import numpy as np
from tqdm import tqdm
from pprint import pprint
import pickle

import torch
import torch.nn.functional as F
from datasets import load_from_disk
from transformers import (
    AutoTokenizer, TrainingArguments,
    BertModel, BertPreTrainedModel
)

In [51]:
class BertEncoder(BertPreTrainedModel):
    """Encoder that wraps a ``BertModel`` and returns only its pooled output.

    The class name and the ``bert`` attribute are kept as-is because the
    retrieval code in this notebook loads pickled encoder objects.
    """

    def __init__(self, config):
        """Build the underlying ``BertModel`` from ``config`` and initialise weights."""
        super().__init__(config)

        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run BERT on the tokenized inputs and return the pooled output.

        Args:
            input_ids: Token ids, passed positionally to ``BertModel``.
            attention_mask: Optional attention mask forwarded to ``BertModel``.
            token_type_ids: Optional segment ids forwarded to ``BertModel``.

        Returns:
            Element 1 of the ``BertModel`` output (the pooled output).
        """
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 is the pooled representation; index 0 (per-token states) is unused.
        return bert_outputs[1]

In [44]:
# 0 ~ 1.0 사이 값으로 scaling
def min_max_scaling(scores):
    """Linearly rescale ``scores`` so the minimum maps to 0 and the maximum to 1.

    Args:
        scores: ``torch.Tensor`` of scores (any shape).

    Returns:
        A tensor with the same shape as ``scores``. An empty tensor is
        returned unchanged. When every score is identical (including the
        single-element case) the result is all zeros instead of the NaN that
        a 0 / 0 division would produce.
    """
    # torch.min / torch.max raise on empty input, so there is nothing to scale.
    if scores.numel() == 0:
        return scores

    min_score = torch.min(scores)
    max_score = torch.max(scores)
    score_range = max_score - min_score

    # A zero range would divide 0 by 0; dividing by 1 instead yields zeros
    # while keeping the same dtype promotion as the regular path.
    safe_range = torch.where(
        score_range == 0, torch.ones_like(score_range), score_range
    )
    return (scores - min_score) / safe_range

In [None]:
# top-k 확인
def get_relevant_doc(
        tokenizer,
        query,
        passage,
        ground_truth,
        p_name,
        q_name,
        topk=1,
        args=None
    ):
    """Rank passages for a query and score every passage against the ground truth.

    Both encoders are unpickled from ``../code/retrieval/dense_encoder/`` on
    every call (nothing is cached between calls).

    Args:
        tokenizer: Callable tokenizer; called with ``padding="max_length"``,
            ``truncation=True`` and ``return_tensors="pt"``.
        query: A single query; it is wrapped in a one-element list before
            tokenization.
        passage: Iterable of passages; each one is encoded separately with
            the passage encoder.
        ground_truth: The ground-truth passage, encoded with the passage
            encoder.
        p_name: File name of the pickled passage encoder.
        q_name: File name of the pickled query encoder.
        topk: Number of top-ranked passage indices to return.
        args: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(scores, top_indices)``:
        ``scores`` holds the min-max-scaled cosine similarity between the
        ground-truth passage and every passage, and ``top_indices`` holds the
        indices of the ``topk`` passages with the highest dot-product
        similarity to the query.
        Assumes each encoder call returns a (1, emb_dim) tensor, so both
        results are 1-D over the passages — TODO confirm for other encoders.

    Raises:
        ValueError: If ``passage`` is empty.
    """

    def _device_of(encoder):
        # Feed inputs on whatever device the unpickled encoder lives on instead
        # of assuming CUDA; fall back to CPU for a parameter-less module.
        first_param = next(encoder.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def _encode(encoder, text, device):
        # Tokenize, run the encoder, and bring the embedding back to the CPU.
        inputs = tokenizer(
            text, padding="max_length", truncation=True, return_tensors="pt"
        ).to(device)
        return encoder(**inputs).to("cpu")

    # SECURITY: pickle.load can execute arbitrary code; only load encoder files
    # that come from a trusted source.
    with open("../code/retrieval/dense_encoder/" + p_name, "rb") as f:
        p_encoder = pickle.load(f)
    with open("../code/retrieval/dense_encoder/" + q_name, "rb") as f:
        q_encoder = pickle.load(f)

    p_encoder.eval()
    q_encoder.eval()
    p_device = _device_of(p_encoder)
    q_device = _device_of(q_encoder)

    with torch.no_grad():
        q_emb = _encode(q_encoder, [query], q_device)
        g_emb = _encode(p_encoder, ground_truth, p_device)
        p_emb_list = [_encode(p_encoder, p, p_device) for p in passage]

    if not p_emb_list:
        raise ValueError("passage must contain at least one document")

    # Stack along the passage axis. torch.cat keeps that axis even when there
    # is only one passage, which a blanket .squeeze() would collapse.
    p_embs = torch.cat(p_emb_list, dim=0)

    # Similarity to the query: rank passages by raw dot product.
    query_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1))
    # squeeze(0) drops only the single-query axis, never the passage axis.
    rank = torch.argsort(query_scores, dim=1, descending=True).squeeze(0)

    # Similarity to the ground truth: cosine similarity (normalized dot product).
    p_embs = F.normalize(p_embs, dim=-1)
    g_emb = F.normalize(g_emb, dim=-1)
    gt_scores = torch.matmul(g_emb, torch.transpose(p_embs, 0, 1)).squeeze(0)
    gt_scores = min_max_scaling(gt_scores)

    return gt_scores, rank[:topk]

In [187]:
# document search
def search_doc(tokenizer, query, contexts, p_name, q_name, topk=1):
    """Retrieve the top-k contexts for every query, in order.

    ``contexts[i]`` is used as the ground-truth passage for ``query[i]``, so
    ``contexts`` must have at least as many entries as ``query``.

    Args:
        tokenizer: Tokenizer forwarded to ``get_relevant_doc``.
        query: Sequence of queries (must support ``len``).
        contexts: Sequence of passages to search; also indexed by query
            position to obtain the ground truth.
        p_name: File name of the pickled passage encoder.
        q_name: File name of the pickled query encoder.
        topk: Number of contexts to keep per query. Defaults to 1, matching
            the default of ``get_relevant_doc``, so callers that omit it no
            longer fail with a ``TypeError``.

    Returns:
        A list with one entry per query; each entry is a list of
        ``[context, score]`` pairs for the top-k retrieved contexts, where
        ``score`` is the value ``get_relevant_doc`` returned for that context.
    """
    context_score = []

    progress = tqdm(enumerate(query), total=len(query), desc="Searching", unit="query")
    for i, q in progress:
        score, indices = get_relevant_doc(
            tokenizer=tokenizer,
            query=q,
            passage=contexts,
            ground_truth=contexts[i],
            p_name=p_name,
            q_name=q_name,
            topk=topk,
        )

        # One [context, score] pair per retrieved index.
        context_score.append([[contexts[idx], score[idx]] for idx in indices])

    return context_score  # (num_query, k, 2)

        # top-k 문서 및 유사도 확인
        # for rank, idx in enumerate(indices):
        #     print(f"Top-{rank + 1}th Passage (Index {idx})")
        #     # pprint(retriever.passage['context'][idx])
        #     print(f"유사도 : {score[idx]}")
        #     pprint(contexts[idx])

In [188]:
# Load Wikipedia documents (currently disabled; the validation contexts are used instead)
# with open("../data/wikipedia_documents.json", "r", encoding="utf-8") as f:
#     wiki = json.load(f)
# contexts = list(dict.fromkeys([
#             value["text"] for value in wiki.values()
#         ]))

model_checkpoint = "klue/bert-base"
p_name = "p_korquad_1.bin"
q_name = "q_korquad_1.bin"
# Number of contexts retrieved per query. The evaluation below calls
# top_k_see(40), so at least 40 results per query are needed.
TOP_K = 40

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

dataset = load_from_disk("../data/train_dataset")
val_dataset = dataset["validation"]["question"]
contexts = dataset["validation"]["context"]

# Search the contexts for each validation question
query = val_dataset

# search_doc declares topk as a required parameter, so it must be passed here.
context_score = search_doc(tokenizer, query, contexts, p_name, q_name, topk=TOP_K)

Searching: 100%|██████████| 240/240 [16:24<00:00,  4.10s/query]


In [191]:
# Convert the per-query search results into a NumPy array (rebinds context_score).
# NOTE(review): top_k_see runs re.sub on each score, so the mixed
# [context, score] entries are presumably coerced to strings here — confirm.
context_score = np.array(context_score)

In [194]:
def top_k_see(k, results=None, threshold=0.99):
    """Print how many queries have a hit among their top-k retrieved contexts.

    A query counts as correct when any of its first ``k`` results has a score
    of at least ``threshold``.

    Args:
        k: Number of top results to inspect per query. If a query has fewer
            than ``k`` results, only the available ones are inspected.
        results: Per-query sequences of ``[context, score]`` pairs. Defaults
            to the module-level ``context_score``.
        threshold: Minimum score for a result to count as a hit.

    Raises:
        ValueError: If a string score contains no parsable number.
    """
    if results is None:
        results = context_score

    # Matches a decimal literal with an optional exponent, e.g. "1.", "0.98",
    # "1.2e-05". Stripping all letters would also strip the exponent marker.
    number_pattern = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")

    def _as_float(score):
        # Scores may arrive as strings such as "tensor(0.9876)" or as numbers.
        if isinstance(score, str):
            found = number_pattern.search(score)
            if found is None:
                raise ValueError(f"could not parse a score from {score!r}")
            return float(found.group())
        return float(score)

    total = len(results)
    num_correct = 0
    for con_sco in results:
        inspected = range(min(k, len(con_sco)))
        # any() stops at the first hit, like the original early break.
        if any(_as_float(con_sco[idx][1]) >= threshold for idx in inspected):
            num_correct += 1
    num_wrong = total - num_correct

    # Avoid dividing by zero when there are no results at all.
    accuracy = num_correct / total if total else 0.0

    print(f"전체 개수: {total}, 정답 개수: {num_correct}, 오답 개수: {num_wrong}, 정답률: {accuracy:.4%}")

In [198]:
# Print the hit counts and accuracy when the top 40 results per query are inspected.
top_k_see(40)

전체 개수: 240, 정답 개수: 237, 오답 개수: 3, 정답률: 98.7500%


train_dataset
epoch: 3, lr: 5e-5, batch: 8, weight_decay: 0.01, top-k: 20
전체 개수: 240, 정답 개수: 221, 오답 개수: 19, 정답률: 92.0833%

train_dataset + korquad v1.0 (context 당 question 랜덤 하나만 추출)
epoch: 2, lr: 4e-5, batch: 5, weight_decay: 0.01, top-k: 5
전체 개수: 240, 정답 개수: 211, 오답 개수: 29, 정답률: 87.9167%

epoch: 2, lr: 4e-5, batch: 5, weight_decay: 0.01, top-k: 20
전체 개수: 240, 정답 개수: 231, 오답 개수: 9, 정답률: 96.2500%

epoch: 2, lr: 4e-5, batch: 5, weight_decay: 0.01, top-k: 40
전체 개수: 240, 정답 개수: 236, 오답 개수: 4, 정답률: 98.3333%
