In [1]:
from collections import OrderedDict
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import pairwise_distances

from sklearn.feature_extraction.text import TfidfVectorizer

In [2]:
from utils.preprocessor import Stopwords_preprocessor
from utils.markdown import beir_metrics_to_markdown_table
from IPython.display import Markdown

# from rank_bm25 import BM25Okapi as BM25
from transformers import logging, AutoTokenizer
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval

In [3]:
import pathlib
import torch
from torch import nn

In [4]:
from transformers import AutoTokenizer

# Load the pretrained 'bert-base-uncased' tokenizer. The resulting module-level
# `tokenizer` is what the tokenization and pooling helpers in later cells call.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# from beir import util
# dataset =  'trec-covid' # "nfcorpus" 
# url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
# data_path = util.download_and_unzip(url, 'data')

In [6]:
# Dataset toggle: the last uncommented assignment is the one that takes effect.
corpus_name = 'scifact'
corpus_name = 'trec-covid'
# corpus_name = 'nfcorpus'

# Load the test split of the selected dataset from data/<corpus_name>.
corpus, queries, qrels = GenericDataLoader(f'data/{corpus_name}').load(split="test")
# Document bodies only, in the corpus' iteration order.
corpus_text = [doc['text'] for doc in corpus.values()]

  0%|          | 0/171332 [00:00<?, ?it/s]

In [7]:
def tokenize(x):
    """Split ``x`` into the tokenizer's token strings, without special tokens.

    The text is first encoded to vocabulary ids with
    ``add_special_tokens=False``; the ids are then mapped back to their token
    strings by the module-level ``tokenizer``.
    """
    token_ids = tokenizer.encode(x, add_special_tokens=False)
    return tokenizer.convert_ids_to_tokens(token_ids)

# TF-IDF vectorizer that tokenizes with `tokenize` and is given the tokenizer's
# own vocabulary (`tokenizer.vocab`) as a fixed vocabulary — presumably so that
# the vectorizer's feature indices coincide with the tokenizer's token ids.
vectorizer = TfidfVectorizer(tokenizer=tokenize, vocabulary=tokenizer.vocab)
# NOTE(review): the fit below is commented out, yet `idf_mean_vector` reads
# `vectorizer.idf_`; presumably that attribute only exists after fitting, so
# re-enable this line before using `idf_mean_vector`.
# %time vectorizer.fit(corpus_text)

In [8]:
# # test
# text_sample = corpus[list(corpus.keys())[0]]['text']
# res = mean_rotary_discrepency(text_sample)
# res.shape

In [9]:
def mean_vector(text):
    """Embed ``text`` as the mean of its tokens' rows in ``word_reprs``.

    Parameters
    ----------
    text : str
        Text to embed; it is encoded with the module-level ``tokenizer``
        without special tokens.

    Returns
    -------
    numpy.ndarray
        1-D vector of length ``word_reprs.shape[1]``. When ``text`` yields no
        token ids, an all-zero vector is returned instead of the NaN that the
        mean of an empty selection would produce.
    """
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) == 0:
        # Use word_reprs' dtype for the fallback: a default (float64) zero
        # vector would upcast the whole matrix when the per-document vectors
        # are stacked together and word_reprs has a narrower dtype.
        return np.zeros(word_reprs.shape[1], dtype=word_reprs.dtype)
    return word_reprs[ids].mean(axis=0)

def sum_vector(text):
    """Embed ``text`` as the sum of its tokens' rows in ``word_reprs``.

    The text is encoded with the module-level ``tokenizer`` (no special
    tokens) and the selected rows are summed along the token axis.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    token_vectors = word_reprs[token_ids]
    return token_vectors.sum(axis=0)


def idf_mean_vector(text):
    """Embed ``text`` as an IDF-weighted average of its tokens' rows in ``word_reprs``.

    Each token's row is multiplied by ``vectorizer.idf_`` at that token id,
    the weighted rows are summed, and the sum is divided by the token count
    plus ``1e-8`` (which keeps the division defined for an empty token list).
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    token_vectors = word_reprs[token_ids]
    token_idf = vectorizer.idf_[token_ids]
    # Alternative: (token_idf @ token_vectors) / (len(token_ids) + 1e-8).
    # That form was slower, possibly because of memory contiguity.
    weighted_sum = np.einsum('ld,l', token_vectors, token_idf)
    return weighted_sum / (len(token_ids) + 1e-8)

# def idf_mean_vector(text):
#   ids = tokenizer.encode(text, add_special_tokens=False)
#   

In [37]:
# Root folder that holds one sub-folder per training run.
folder_path = pathlib.Path('data/limanet/')

# Run selection: every assignment below overrides the previous one, so the
# last uncommented `subpath` is the run that actually gets loaded.
subpath = '20240422.11:52:35-batch_size_1024' #trec-covid best, batch 10000
# subpath = '20240422.02:12:51-batch_size_1024' #for report
subpath = '20240423.15:38:14-batch_size_1024' #rotator_lima3, dim 96, head 12, depth 3, rotary_denom 2

# subpath = '20240430.16:16:30-batch_size_256' # rotator_lima4_hippo, mse+multiplet, dim128, hdim32, head32, dep3
subpath = '20240501.12:18:59-batch_size_512' # rotator_lima4_multihippo, mse+multiplet, dim128, hdim32, head32, dep3
subpath = '20240502.19:17:09-batch_size_512' # rotator_lima4_noClocks, cos_dist+multiplet, dim128, hdim32, head32, dep3
subpath = '20240502.21:10:49-batch_size_64' # rotator_lima4_noClocks, msmarco, mse+multiplet, dim128, hdim32, head32, dep3
subpath = '20240502.22:32:46-batch_size_512' # rotator_lima4_noTime, mse+multiplet, dim128, hdim32, head32, dep3
# subpath = '20240502.23:36:34-batch_size_512' # rotator_lima4_noTime, cos_dist+multiplet

# Which saved batch checkpoint of the selected run to load.
batch_num = 0

checkpoint_path = folder_path / subpath / f'batch_{batch_num}-model.pt'
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
model = torch.load(checkpoint_path, map_location='cpu')

# Complex-valued word embeddings from the model's predictor, detached once.
word_embeddings = model.predictor.all_word_embeddings().detach()
# Real-valued view: real and imaginary parts concatenated on the last axis.
word_reprs = torch.concat([word_embeddings.real, word_embeddings.imag], dim=-1).numpy()
word_reprs_complex = word_embeddings.numpy()

In [38]:
# for i in range(len(model.limas)):
#     lima_shape = model.limas[i].lima_shape
#     print(lima_shape)
#     print(f'{i}: {lima_shape.min()}, {lima_shape.max()}')

In [39]:
method = idf_mean_vector
method = mean_vector
# method = sum_vector

part = 'text'
# part = 'title'

%time text_vec_dict = OrderedDict({k: method(v[part]) for k, v in corpus.items()})
%time query_vec_dict = OrderedDict({k: method(v) for k, v in queries.items()})
text_vecs = np.stack(list(text_vec_dict.values()))

CPU times: user 2min 15s, sys: 390 ms, total: 2min 15s
Wall time: 2min 15s
CPU times: user 8.62 ms, sys: 279 µs, total: 8.9 ms
Wall time: 8.93 ms


In [40]:
# Distance toggle: the last assignment wins. Note that `score` binds its
# default `metric` when the function is defined, so re-run this whole cell
# after changing the value.
metric = 'euclidean'
metric = 'cosine'

def score(query_vector, metric=metric, eps=1e-12):
    """Score every row of ``text_vecs`` against a single query vector.

    The score is the reciprocal of the pairwise distance, so closer documents
    receive larger scores.

    Parameters
    ----------
    query_vector : numpy.ndarray
        1-D query embedding; it is expanded to a single-row matrix before the
        distance computation.
    metric : str
        Distance metric name forwarded to ``pairwise_distances``.
    eps : float
        Lower bound applied to the distances before inversion. Without it a
        document whose distance to the query is exactly 0 would get an ``inf``
        score (and trigger a divide-by-zero warning). Distances that are
        already >= ``eps`` are scored exactly as ``1 / distance``.

    Returns
    -------
    numpy.ndarray
        1-D array with one finite score per row of ``text_vecs``.
    """
    distances = pairwise_distances(query_vector[None, :], text_vecs, metric=metric)[0]
    return 1 / np.maximum(distances, eps)


%time results = {qid: dict(zip(text_vec_dict.keys(), score(query_vector).tolist())) \
            for qid, query_vector in query_vec_dict.items()}

metrics = EvaluateRetrieval.evaluate(qrels, results, [1, 3, 5, 10, 100, 1000])

flatten_metrics = {k: v for metric_type in metrics for k, v in metric_type.items()}
metric_names, metric_values = zip(*flatten_metrics.items())
print(*metric_names, sep='\t')
print(*metric_values, sep='\t')
print()

md = beir_metrics_to_markdown_table(*metrics)
Markdown(md)

CPU times: user 1min 28s, sys: 3min 43s, total: 5min 12s
Wall time: 18.9 s
NDCG@1	NDCG@3	NDCG@5	NDCG@10	NDCG@100	NDCG@1000	MAP@1	MAP@3	MAP@5	MAP@10	MAP@100	MAP@1000	Recall@1	Recall@3	Recall@5	Recall@10	Recall@100	Recall@1000	P@1	P@3	P@5	P@10	P@100	P@1000
0.25	0.25592	0.24596	0.21615	0.1243	0.10074	0.00053	0.0014	0.00204	0.00314	0.01046	0.01944	0.00053	0.00173	0.00261	0.00451	0.02378	0.09455	0.28	0.28	0.268	0.22	0.1232	0.0493



||NDCG|MAP|Recall|P|
|-|-|-|-|-|
|@1|0.2500|0.0005|0.0005|0.2800|
|@3|0.2559|0.0014|0.0017|0.2800|
|@5|0.2460|0.0020|0.0026|0.2680|
|@10|0.2162|0.0031|0.0045|0.2200|
|@100|0.1243|0.0105|0.0238|0.1232|
|@1000|0.1007|0.0194|0.0945|0.0493|

In [14]:
# Display the `rotary_denom` attribute of the loaded model's predictor.
model.predictor.rotary_denom

Parameter containing:
tensor(0.1280, requires_grad=True)

In [15]:
# # write first 10 questions and top 10 answer to file

# samples = list(results.items())[:10]
# for q_num, score_dict in samples:
#     with open(f'question_{q_num}.txt', 'w') as f:
#         f.write(f'{queries[q_num]}\n\n')
#         tokens = tokenizer.convert_ids_to_tokens(tokenizer(queries[q_num], add_special_tokens=False)['input_ids'])
#         f.write(f'{tokens}\n\n')
        
#         text_ids, text_scores = zip(*score_dict.items())
#         text_scores = np.array(text_scores)
#         top_10_idx = np.argsort(text_scores)[:-10:-1]

#         for idx in top_10_idx:
#             f.write(f'{corpus[text_ids[idx]]}\n\n')

In [16]:
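# #test: rotate each word backwards along the time direction by one step t and look at which words end up nearby. In theory these should be words unrelated to the original word (text independent), because this rotation cancels out the time rotation.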

# inverse_metric_theta = - 1/model.predictor.rotary_denom**(model.predictor.dimension_nums/model.predictor.dim)
# inverse_pos_rotation = torch.complex(inverse_metric_theta.cos(), inverse_metric_theta.sin())
# least_effective_position_of_the_word = model.predictor.all_word_embeddings() * inverse_pos_rotation
# least_effective_position_of_the_word = torch.concat([least_effective_position_of_the_word.real, least_effective_position_of_the_word.imag], dim=-1).detach().numpy()

# least_effective_position_of_the_word.shape

# %time d = pairwise_distances(word_reprs, least_effective_position_of_the_word, metric='euclidean') # metric='cosine'

# %time pair = d.argsort(axis=1)[:, :10]

# for input_id in tokenizer.encode(text_sample):
#     print(f'{tokenizer.convert_ids_to_tokens(input_id)}: {tokenizer.convert_ids_to_tokens(pair[input_id])}')