In [1]:
import warnings
warnings.filterwarnings('ignore')

## 1. 평가를 위한 데이터셋

In [2]:
%%capture
! pip install datasets

In [3]:
from datasets import load_dataset

klue_mrc_test = load_dataset('klue', 'mrc', split='validation')
klue_mrc_test = klue_mrc_test.train_test_split(test_size=1000, seed=42)['test']

README.md:   0%|          | 0.00/22.5k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/8.68M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/17554 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5841 [00:00<?, ? examples/s]

## 2. 여러 함수들

In [4]:
%%capture
! pip install faiss-cpu

In [5]:
import faiss

def make_embedding_index(sentence_model, corpus):
  embeddings = sentence_model.encode(corpus)
  index = faiss.IndexFlatL2(embeddings.shape[1])
  index.add(embeddings)
  return index

def find_embedding_top_k(query, sentence_model, index, k):
  embedding = sentence_model.encode([query])
  distances, indices = index.search(embedding, k)
  return indices

In [6]:
import numpy as np

def make_question_context_pairs(question_idx, indices):
  return [[klue_mrc_test['question'][question_idx], klue_mrc_test['context'][idx]] for idx in indices]

def rerank_top_k(cross_model, question_idx, indices, k):
  input_examples = make_question_context_pairs(question_idx, indices)
  relevance_scores = cross_model.predict(input_examples)

  top_k_indices = np.argsort(relevance_scores)[::-1][:k]
  reranked_indices = indices[top_k_indices]
  return reranked_indices

In [7]:
import time

def evaluate_hit_rate(datasets, embedding_model, index, k=10):
  start_time = time.time()

  questions = datasets['question']
  contexts = datasets['context']

  predictions = []
  for question in questions:
    predictions.append(find_embedding_top_k(question, embedding_model, index, k)[0])

  hit_count = 0
  for idx, prediction in enumerate(predictions):
    for pred in prediction:
      if contexts[pred] == contexts[idx]:
        hit_count += 1
        break

  end_time = time.time()
  total_prediction_count = len(predictions)
  return (hit_count / total_prediction_count), (end_time - start_time)

## 3. Case 1: Embedding Model

In [8]:
from sentence_transformers import SentenceTransformer

base_embedding_model = SentenceTransformer('whatwant/klue-roberta-base-klue-sts')
base_index = make_embedding_index(base_embedding_model, klue_mrc_test['context'])

evaluate_hit_rate(klue_mrc_test, base_embedding_model, base_index, 10)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/205 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/14.8k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/744 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/442M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.31k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/248k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/752k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/971 [00:00<?, ?B/s]

1_Pooling%2Fconfig.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

(0.857, 11.544653415679932)

## 4. Case 2: Fine-Tuned Embedding Model

In [9]:
finetuned_embedding_model = SentenceTransformer('whatwant/klue-roberta-base-klue-sts-mrc')
finetuned_index = make_embedding_index(finetuned_embedding_model, klue_mrc_test['context'])

evaluate_hit_rate(klue_mrc_test, finetuned_embedding_model, finetuned_index, 10)

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/205 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/58.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/442M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/248k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/752k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/971 [00:00<?, ?B/s]

1_Pooling%2Fconfig.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

(0.936, 11.472103595733643)

## 5. Case 3: Hybrid

In [10]:
from tqdm.auto import tqdm

def evaluate_hit_rate_with_rerank(datasets, embedding_model, cross_model, index, bi_k=30, cross_k=10):
  start_time = time.time()

  questions = datasets['question']
  contexts = datasets['context']

  predictions = []
  for question_idx, question in enumerate(questions):
    indices = find_embedding_top_k(question, embedding_model, index, bi_k)[0]
    predictions.append(rerank_top_k(cross_model, question_idx, indices, k=cross_k))

  hit_count = 0
  for idx, prediction in enumerate(predictions):
    for pred in prediction:
      if contexts[pred] == contexts[idx]:
        hit_count += 1
        break

  end_time = time.time()
  total_prediction_count = len(predictions)
  return (hit_count / total_prediction_count), (end_time - start_time), predictions

In [11]:
from sentence_transformers.cross_encoder import CrossEncoder

cross_model = CrossEncoder('whatwant/klue-roberta-small-cross-encoder')

config.json:   0%|          | 0.00/842 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/272M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.31k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/248k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/752k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/971 [00:00<?, ?B/s]

In [12]:
hit_rate, consumed_time, predictions = evaluate_hit_rate_with_rerank(
    klue_mrc_test, finetuned_embedding_model, cross_model, finetuned_index, bi_k=30, cross_k=10
)

hit_rate, consumed_time

(0.967, 970.0450859069824)