In [2]:
# !pip install numpy pandas tqdm nltk beir

Collecting nltk
  Using cached nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting beir
  Using cached beir-2.0.0-py3-none-any.whl
Collecting click (from nltk)
  Using cached click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting joblib (from nltk)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting regex>=2021.8.3 (from nltk)
  Using cached regex-2024.11.6-cp312-cp312-win_amd64.whl.metadata (41 kB)
Collecting sentence-transformers (from beir)
  Using cached sentence_transformers-3.4.1-py3-none-any.whl.metadata (10 kB)
Collecting pytrec_eval (from beir)
  Using cached pytrec_eval-0.5.tar.gz (15 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Collecting faiss_cpu (from beir)
 

In [None]:
from collections import OrderedDict
import pathlib
import numpy as np
import pandas as pd
from tqdm import tqdm

from utils.preprocessor import Stopwords_preprocessor
from utils.logging import beir_metrics_to_markdown_table
from IPython.display import Markdown

from sklearn.metrics.pairwise import pairwise_distances
from sklearn.feature_extraction.text import TfidfVectorizer

# from rank_bm25 import BM25Okapi as BM25
from transformers import logging, AutoTokenizer
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval

import torch
from torch import nn

In [16]:
# GPU to use when CUDA is available; 'cuda:pick_a_device' was an invalid device string
# that would make every `.to(device)` call fail on a CUDA machine.
CUDA_DEVICE_INDEX = 0
device = f'cuda:{CUDA_DEVICE_INDEX}' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [54]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# One checkpoint name for both tokenizer and model so the two cannot drift apart.
MODEL_NAME = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Everything below only runs inference, so set eval mode (dropout off) explicitly
# instead of relying on from_pretrained's default.
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).eval()

# Disabled experiment (uncomment to run): make every position in layer 0 issue the
# same attention query -- the query projection of the raw [MASK] word embedding --
# by folding that projection into the bias and zeroing the query weight matrix.
# mask_input = tokenizer('[MASK]', return_tensors='pt', add_special_tokens=False)
# mask_emb = model.bert.embeddings.word_embeddings(mask_input['input_ids'])
# mask_query_output = model.bert.encoder.layer[0].attention.self.query(mask_emb)
# model.bert.encoder.layer[0].attention.self.query.bias.data = mask_query_output[0, 0].detach()
# model.bert.encoder.layer[0].attention.self.query.weight.data = torch.zeros_like(model.bert.encoder.layer[0].attention.self.query.weight.data)

# The model is kept on the CPU. Before enabling this, check that every encoder used
# below puts its inputs on the same device as the model.
# model = model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [55]:
# BEIR dataset to evaluate; other datasets used with this notebook: 'scifact', 'nfcorpus'.
corpus_name = 'trec-covid'
DATA_DIR = pathlib.Path('data')

corpus, queries, qrels = GenericDataLoader(str(DATA_DIR / corpus_name)).load(split="test")
# Fix one iteration order for queries and documents; later cells pair vector rows
# with these keys positionally.
ordered_queries = OrderedDict(queries)
ordered_corpus = OrderedDict(corpus)

# Only each document's 'text' field is embedded.
# NOTE(review): BEIR corpora usually also carry a 'title' field, which is ignored here -- confirm intended.
texts = [article['text'] for article in ordered_corpus.values()]

  0%|          | 0/171332 [00:00<?, ?it/s]

In [56]:
# def tokenize(x):
#     return tokenizer.convert_ids_to_tokens(tokenizer.encode(x, add_special_tokens=False))

# vectorizer = TfidfVectorizer(tokenizer=tokenize, vocabulary=tokenizer.vocab)
# %time vectorizer.fit(texts)

In [57]:
class Get_bert_raw_output:
  """Token vectors from the last hidden layer of the BERT encoder (`model.bert`).

  Args:
    device: device the tokenizer tensors are moved to before the forward pass; it
      must be the device the model lives on.
    model: masked-LM model exposing `.bert`. Defaults to the notebook-global
      `model`, looked up at call time.

  Calling the instance with the tokenizer outputs (input_ids, attention_mask, ...)
  as keyword arguments returns `last_hidden_state`, detached and on the CPU.
  """
  def __init__(self, device, model=None):
    self.device = device
    self.model = model

  def __call__(self, **inputs):
    mlm = model if self.model is None else self.model
    batch = {name: tensor.to(self.device) for name, tensor in inputs.items()}
    # Inference only: no_grad stops autograd from recording the whole forward pass.
    with torch.no_grad():
      hidden = mlm.bert(**batch).last_hidden_state
    return hidden.detach().cpu()

# def get_bert_head_embedding(**inputs):
#   return model.cls.predictions.transform(model.bert(**inputs).last_hidden_state)

class Get_by_input_embedding:
  """Token vectors read straight from BERT's input word-embedding table.

  Args:
    model: masked-LM model exposing `.bert.embeddings.word_embeddings`. Defaults
      to the notebook-global `model`, looked up at call time.

  Only `input_ids` is used; other tokenizer outputs are accepted and ignored. The
  result is detached and on the CPU.
  """
  def __init__(self, model=None):
    self.model = model

  def __call__(self, **inputs):
    mlm = model if self.model is None else self.model
    word_embeddings = mlm.bert.embeddings.word_embeddings
    # Look the ids up on whatever device the embedding table lives on.
    input_ids = inputs['input_ids'].to(word_embeddings.weight.device)
    with torch.no_grad():
      embs = word_embeddings(input_ids)
    return embs.detach().cpu()

class Get_by_input_embedding_and_transform:
  """Input word embeddings passed through the MLM head's `transform` layer.

  Args:
    model: masked-LM model exposing `.bert.embeddings.word_embeddings` and
      `.cls.predictions.transform`. Defaults to the notebook-global `model`,
      looked up at call time.

  Only `input_ids` is used; other tokenizer outputs are accepted and ignored. The
  result is detached and on the CPU.
  """
  def __init__(self, model=None):
    self.model = model

  def __call__(self, **inputs):
    mlm = model if self.model is None else self.model
    word_embeddings = mlm.bert.embeddings.word_embeddings
    input_ids = inputs['input_ids'].to(word_embeddings.weight.device)
    # Both the lookup and the transform run under no_grad; previously the transform
    # ran outside it and built an autograd graph for every batch.
    with torch.no_grad():
      embs = word_embeddings(input_ids)
      transformed = mlm.cls.predictions.transform(embs)
    return transformed.detach().cpu()

In [None]:
class Get_bert_last(Get_bert_raw_output):
  """Alias of Get_bert_raw_output, kept so `Get_bert_last(device)` still works.

  This used to be a verbatim copy of that class. Subclassing leaves a single
  implementation (last hidden state of `model.bert`, detached, on the CPU).
  """

In [58]:
# Token-vector encoder fed to get_sentence_vectors. Alternatives to compare:
#   Get_bert_raw_output(device)   - last hidden state of the full BERT encoder
#   Get_by_input_embedding()      - raw input word embeddings
model_fn = Get_by_input_embedding_and_transform()  # word embeddings through the MLM head transform

def get_sentence_vectors(text_iterable, inputs2vecs_fn, aggregating_method='mean', batch_size=24, drop_special_token_vec=True, tokenize_fn=None, show_progress=True):
  """Embed each text as a single vector by pooling its token vectors.

  Texts are tokenized in batches of `batch_size` and mapped to per-token vectors by
  `inputs2vecs_fn`. Those vectors are pooled over the positions that are not padding
  and, when `drop_special_token_vec` is set, are not special tokens (e.g. [CLS]/[SEP]).

  Args:
    text_iterable: iterable of strings; consumed exactly once.
    inputs2vecs_fn: called with the tokenizer's tensors as keyword arguments
      (input_ids, attention_mask, ...). Must return a (batch, seq_len, dim) tensor
      on the same device as those tensors (the tokenizer creates them on the CPU).
    aggregating_method: 'sum' or 'mean'. 'mean' divides by the number of kept
      tokens, so padding and dropped special tokens do not dilute it. A text's
      vector therefore does not depend on the other texts in its batch.
    batch_size: number of texts tokenized and encoded together (>= 1).
    drop_special_token_vec: exclude special-token positions from the pooling.
    tokenize_fn: callable with the Hugging Face tokenizer call signature; defaults
      to the notebook-global `tokenizer`.
    show_progress: wrap `text_iterable` in a tqdm progress bar.

  Returns:
    torch.Tensor of shape (n_texts, dim), rows in input order.

  Raises:
    KeyError: `aggregating_method` is neither 'sum' nor 'mean'.
    ValueError: `batch_size` < 1, or `text_iterable` yields no texts.
  """
  poolers = {
    'sum': lambda summed, n_kept: summed,
    # clamp: a text whose tokens were all dropped pools to a zero vector, not NaN.
    'mean': lambda summed, n_kept: summed / n_kept.clamp(min=1),
  }
  pool = poolers[aggregating_method]  # unknown method -> KeyError, before any work
  if batch_size < 1:
    raise ValueError(f'batch_size must be >= 1, got {batch_size}')
  tokenize = tokenizer if tokenize_fn is None else tokenize_fn

  def encode_batch(batch):
    # Tokenize one non-empty batch and pool each row's kept token vectors.
    inputs = tokenize(batch, return_tensors='pt', padding=True, truncation=True, return_special_tokens_mask=drop_special_token_vec)
    # Pop the mask so inputs2vecs_fn only receives tensors a BERT forward accepts.
    special_tokens_mask = inputs.pop('special_tokens_mask') if drop_special_token_vec else None
    output = inputs2vecs_fn(**inputs)
    mask = inputs['attention_mask']
    if special_tokens_mask is not None:
      mask = mask * (1 - special_tokens_mask)
    summed = (output * mask[:, :, None]).sum(dim=1)
    n_kept = mask.sum(dim=1, keepdim=True)
    return pool(summed, n_kept)

  texts = tqdm(text_iterable) if show_progress else text_iterable
  text_vectors = []
  batch = []
  for text in texts:
    batch.append(text)
    if len(batch) == batch_size:
      text_vectors.append(encode_batch(batch))
      batch = []
  if batch:  # trailing partial batch; an empty batch is never encoded
    text_vectors.append(encode_batch(batch))

  if not text_vectors:
    raise ValueError('text_iterable is empty')
  return torch.cat(text_vectors)

# Encode documents and queries with the same encoder and pooling. Queries are read
# from ordered_queries so q_vecs rows line up with the ids the scoring cell zips them with.
t_vecs = get_sentence_vectors(texts, model_fn).detach()
q_vecs = get_sentence_vectors(ordered_queries.values(), model_fn).detach()
# torch.save(t_vecs, 'bert_raw_output.pt')


  0%|          | 0/171332 [00:00<?, ?it/s][A
  0%|          | 48/171332 [00:00<06:03, 471.10it/s][A
  0%|          | 96/171332 [00:00<06:01, 474.21it/s][A
  0%|          | 168/171332 [00:00<05:35, 509.46it/s][A
  0%|          | 219/171332 [00:00<05:48, 491.21it/s][A
  0%|          | 268/171332 [00:00<06:03, 470.21it/s][A
  0%|          | 315/171332 [00:00<06:16, 453.82it/s][A
  0%|          | 361/171332 [00:00<06:23, 445.71it/s][A
  0%|          | 408/171332 [00:00<06:27, 440.64it/s][A
  0%|          | 456/171332 [00:01<06:36, 431.35it/s][A
  0%|          | 504/171332 [00:01<06:37, 429.43it/s][A
  0%|          | 552/171332 [00:01<06:33, 434.45it/s][A
  0%|          | 600/171332 [00:01<06:28, 439.07it/s][A
  0%|          | 648/171332 [00:01<06:30, 436.57it/s][A
  0%|          | 696/171332 [00:01<06:24, 443.86it/s][A
  0%|          | 744/171332 [00:01<06:24, 444.19it/s][A
  0%|          | 792/171332 [00:01<06:24, 443.00it/s][A
  1%|          | 864/171332 [00:01<06:06, 4

In [8]:
# def get_sentence_vectors(text_iterable, aggregating_method='mean', batch_size=24):
  
#   text_vectors = []
  
#   iterator = iter(tqdm(text_iterable))

#   stop = False
#   while True:
#     batch = []
#     try:
#       for i in range(batch_size):
#         batch.append(next(iterator))
#     except StopIteration:
#       stop = True
#       pass
  
#     inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
#     output = model.bert(**{k: v.to(device) for k, v in inputs.items()}).last_hidden_state.detach().cpu()
#     result = (output*inputs['attention_mask'][:,:,None])
#     result = {
#       'sum': lambda res: torch.sum(res, dim=1),
#       'mean': lambda res: torch.mean(res, dim=1)
#     }[aggregating_method](result)
#     text_vectors.append(result)
    
#     if stop:
#       break
#   return torch.concat(text_vectors)

# t_vecs = get_sentence_vectors(texts)
# q_vecs = get_sentence_vectors(queries.values())
# # torch.save(total, 'bert_raw_output.pt')

In [9]:
# torch.save(t_vecs, 'bert_mlm_mean_without_special_tokens.pt')

In [10]:
# def mean_vector(text):
#     ids = tokenizer.encode(text, add_special_tokens=False)
#     if len(ids) == 0:
#         return np.zeros(word_reprs.shape[1])
#     return word_reprs[ids].mean(axis=0)

# def sum_vector(text):
#     ids = tokenizer.encode(text, add_special_tokens=False)
#     return word_reprs[ids].sum(axis=0)


# def idf_mean_vector(text):
#   ids = tokenizer.encode(text, add_special_tokens=False)
#   # return (vectorizer.idf_[ids] @ word_reprs[ids]) / (len(ids) + 1e-8) # this version is slower, possibly due to non-contiguous memory
#   return np.einsum('ld,l', word_reprs[ids], vectorizer.idf_[ids]) / (len(ids) + 1e-8)

In [61]:
# Distance between query and document vectors: 'cosine' or 'euclidean'.
metric = 'cosine'

# (n_queries, n_docs) distance matrix; rows are expected to follow ordered_queries and
# columns follow ordered_corpus (the order t_vecs was built in).
distances = pairwise_distances(q_vecs, t_vecs, metric=metric)

# Scores are "higher is better", as with the previous 1 / distance. The negated
# distance orders documents identically but never divides by zero; previously a
# document at distance 0 got an infinite score plus a RuntimeWarning.
doc_ids = list(ordered_corpus.keys())
results = {
    qid: dict(zip(doc_ids, (-row).tolist()))
    for qid, row in zip(ordered_queries.keys(), distances)
}

In [62]:
# One dict per metric family (NDCG, MAP, Recall, P), each keyed like 'NDCG@10'.
metrics = EvaluateRetrieval.evaluate(qrels, results, [1, 3, 5, 10, 100, 1000])

# Merge the per-family dicts into one ordered name -> score mapping and print it as a
# tab-separated header row followed by a value row.
flatten_metrics = dict(
    pair for family_scores in metrics for pair in family_scores.items()
)
metric_names, metric_values = zip(*flatten_metrics.items())
print('\t'.join(map(str, metric_names)))
print('\t'.join(map(str, metric_values)))
print()

# The same numbers rendered as a Markdown table (the cell's displayed output).
md = beir_metrics_to_markdown_table(*metrics)
Markdown(md)

NDCG@1	NDCG@3	NDCG@5	NDCG@10	NDCG@100	NDCG@1000	MAP@1	MAP@3	MAP@5	MAP@10	MAP@100	MAP@1000	Recall@1	Recall@3	Recall@5	Recall@10	Recall@100	Recall@1000	P@1	P@3	P@5	P@10	P@100	P@1000
0.4	0.3805	0.36009	0.31858	0.19082	0.15347	0.00115	0.00304	0.0041	0.00611	0.02157	0.04207	0.00115	0.00321	0.00452	0.00729	0.0374	0.1428	0.46	0.42667	0.384	0.334	0.189	0.07384



||NDCG|MAP|Recall|P|
|-|-|-|-|-|
|@1|0.4000|0.0011|0.0011|0.4600|
|@3|0.3805|0.0030|0.0032|0.4267|
|@5|0.3601|0.0041|0.0045|0.3840|
|@10|0.3186|0.0061|0.0073|0.3340|
|@100|0.1908|0.0216|0.0374|0.1890|
|@1000|0.1535|0.0421|0.1428|0.0738|