In [1]:
# !pip install gensim scipy==1.12

In [2]:
from collections import OrderedDict
import pathlib
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import pairwise_distances

from sklearn.feature_extraction.text import TfidfVectorizer

In [None]:
from utils.preprocessor import Stopwords_preprocessor
from utils.logging import beir_metrics_to_markdown_table
from IPython.display import Markdown

# from rank_bm25 import BM25Okapi as BM25
from transformers import logging, AutoTokenizer
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# Dataset to evaluate on; swap the active line to switch corpus.
corpus_name = 'scifact'
# corpus_name = 'trec-covid'
# corpus_name = 'nfcorpus'

# Load the "test" split from the local data directory of the chosen corpus.
data_loader = GenericDataLoader(f'data/{corpus_name}')
corpus, queries, qrels = data_loader.load(split="test")

# Plain list of document bodies, in the iteration order of `corpus`.
corpus_text = [doc['text'] for doc in corpus.values()]

100%|███████████████████████████████████████████████████████████████████████████████████████| 5183/5183 [00:00<00:00, 115001.50it/s]


In [6]:
# def tokenize(x):
#     return tokenizer.convert_ids_to_tokens(tokenizer.encode(x, add_special_tokens=False))

# vectorizer = TfidfVectorizer(tokenizer=tokenize, vocabulary=tokenizer.vocab)
# %time vectorizer.fit(corpus_text)

In [7]:
from gensim import models
from gensim.utils import tokenize

%time model = models.KeyedVectors.load_word2vec_format('data/glove.6B.100d_w2vformat_bash.txt')

CPU times: user 21.7 s, sys: 345 ms, total: 22.1 s
Wall time: 22 s


In [9]:
def mean_vector(text, dim=300):
    """Embed ``text`` as the unweighted mean of its word vectors.

    The text is split with ``tokenize``; each token is lower-cased and looked
    up in the module-level ``model``. Tokens the model does not know
    contribute an all-zero vector, so they still count towards the average.

    Args:
        text: String to embed.
        dim: Length of the zero vector used for unknown tokens and for texts
            that yield no tokens. It has to equal the width of the vectors
            stored in ``model``; otherwise the per-token vectors cannot be
            stacked.

    Returns:
        1-D ``numpy`` array holding the element-wise mean over all token
        vectors, or ``dim`` zeros when ``text`` yields no tokens.

    Raises:
        ValueError: If the token vectors do not all share one length, i.e.
            ``dim`` disagrees with the width of the model's vectors.
    """

    def fetch_vec(word):
        # Out-of-vocabulary words are represented by a zero vector.
        try:
            return model[word.lower()]
        except KeyError:
            return np.zeros([dim])

    word_vecs = [fetch_vec(word) for word in tokenize(text)]

    # Nothing to average: return a well-formed zero vector instead of the
    # scalar NaN that np.mean([]) would produce, so callers can still stack
    # the per-text embeddings.
    if not word_vecs:
        return np.zeros([dim])

    try:
        stacked = np.stack(word_vecs)
    except ValueError as err:
        # np.stack refuses arrays of different shapes; surface the cause
        # instead of continuing with an unusable ragged list.
        lengths = sorted({len(vec) for vec in word_vecs})
        raise ValueError(
            f"word vectors have inconsistent lengths {lengths}; "
            f"dim={dim} must match the width of the model's vectors"
        ) from err
    return np.mean(stacked, axis=0)

def idf_mean_vector(text):
    """Presumably meant to be an IDF-weighted counterpart of ``mean_vector``.

    The ``'ld,l'`` contraction multiplies each row of an ``(l, d)`` matrix by
    the matching entry of a length-``l`` weight vector and sums over ``l``,
    giving a length-``d`` weighted sum (not a mean).

    NOTE(review): this function cannot run as written in the visible notebook:
      * ``ids`` is not defined in this function or anywhere in the visible cells;
      * ``vectorizer`` is only created in a commented-out cell;
      * ``mean_vector(text)`` returns a 1-D vector (it already averages over
        tokens), while the ``'ld'`` subscript needs a 2-D operand;
      * the embedding cell calls ``method(text, dim)`` with two arguments,
        but this function accepts only ``text``.
    TODO confirm the intended token -> IDF-index mapping before enabling it.
    """
    return np.einsum('ld,l', mean_vector(text), vectorizer.idf_[ids])

In [10]:
dim = 100
method = idf_mean_vector
method = mean_vector

%time text_vec_dict = OrderedDict({k: method(v['text'], dim) for k, v in corpus.items()})
%time query_vec_dict = OrderedDict({k: method(v, dim) for k, v in queries.items()})
text_vecs = np.stack(list(text_vec_dict.values()))

CPU times: user 3.09 s, sys: 85.6 ms, total: 3.18 s
Wall time: 3.18 s
CPU times: user 18.9 ms, sys: 0 ns, total: 18.9 ms
Wall time: 18.9 ms


In [11]:
# Distance metric name handed to pairwise_distances; swap the active line to
# switch metric.
# metric = 'euclidean'
metric = 'cosine'


def score(query_vector, metric=metric):
    """Score every document in ``text_vecs`` against one query embedding.

    The score is the reciprocal of the distance between the query and each
    document row, so closer documents get larger scores.

    Args:
        query_vector: 1-D embedding of the query; it is reshaped to a single
            row before being compared with ``text_vecs``.
        metric: Metric name forwarded to ``pairwise_distances``. The default
            is the value of the module-level ``metric`` at the time this
            function was defined.

    Returns:
        1-D array with one score per row of ``text_vecs``, in row order.
    """
    distances = pairwise_distances(query_vector[None, :], text_vecs, metric=metric)[0]
    # A document that coincides with the query has distance 0, and 1/0 would
    # yield an infinite score plus a divide-by-zero warning. Clamping to
    # machine epsilon keeps every score finite; only distances below epsilon
    # become tied, all other rankings are unchanged.
    return 1 / np.maximum(distances, np.finfo(float).eps)


%time results = {qid: dict(zip(text_vec_dict.keys(), score(query_vector).tolist())) \
            for qid, query_vector in query_vec_dict.items()}

metrics = EvaluateRetrieval.evaluate(qrels, results, [1, 3, 5, 10, 100, 1000])

flatten_metrics = {k: v for metric_type in metrics for k, v in metric_type.items()}
metric_names, metric_values = zip(*flatten_metrics.items())
print(*metric_names, sep='\t')
print(*metric_values, sep='\t')

md = beir_metrics_to_markdown_table(*metrics)
print(md)
Markdown(md)

CPU times: user 21.4 s, sys: 1min 16s, total: 1min 37s
Wall time: 1.84 s
NDCG@1	NDCG@3	NDCG@5	NDCG@10	NDCG@100	NDCG@1000	MAP@1	MAP@3	MAP@5	MAP@10	MAP@100	MAP@1000	Recall@1	Recall@3	Recall@5	Recall@10	Recall@100	Recall@1000	P@1	P@3	P@5	P@10	P@100	P@1000
0.12	0.15543	0.17567	0.19194	0.23819	0.2819	0.11667	0.14431	0.15589	0.16219	0.17033	0.17172	0.11667	0.18167	0.23	0.28056	0.50778	0.86072	0.12	0.06444	0.04867	0.03	0.00563	0.00097
||NDCG|MAP|Recall|P|
|-|-|-|-|-|
|@1|0.1200|0.1167|0.1167|0.1200|
|@3|0.1554|0.1443|0.1817|0.0644|
|@5|0.1757|0.1559|0.2300|0.0487|
|@10|0.1919|0.1622|0.2806|0.0300|
|@100|0.2382|0.1703|0.5078|0.0056|
|@1000|0.2819|0.1717|0.8607|0.0010|


||NDCG|MAP|Recall|P|
|-|-|-|-|-|
|@1|0.1200|0.1167|0.1167|0.1200|
|@3|0.1554|0.1443|0.1817|0.0644|
|@5|0.1757|0.1559|0.2300|0.0487|
|@10|0.1919|0.1622|0.2806|0.0300|
|@100|0.2382|0.1703|0.5078|0.0056|
|@1000|0.2819|0.1717|0.8607|0.0010|