In [1]:
import sys
import os

parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)

In [2]:
import json
import joblib
import spacy
import numpy as np
import pandas as pd
from sklearn.metrics import ndcg_score
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi

from scripts.processing.keywords import get_keywords_rake, get_keywords_textrank
from model.inference.bert import BertTextClassifier
from scripts.processing.temporal_score import compute_temporal_scores
from scripts import topics

/opt/conda/lib/python3.11/site-packages


In [None]:
docs = json.load(open('../data/experiment/documents.json', 'r'))
queries = json.load(open('../data/experiment/queries.json', 'r'))
rankings = json.load(open('../data/experiment/rankings.json'))
rankings = [[ranking[doc["url"]] / 10 for doc in docs] for ranking in rankings]

In [4]:
feature_sets = [
    ('bm25', 'embedding_similarity'),
    ('bm25', 'embedding_similarity', 'topic'),
    ('bm25', 'embedding_similarity', 'temporal_intent'),
    ('bm25', 'embedding_similarity', 'keywords_rake'),
    ('bm25', 'embedding_similarity', 'keywords_textrank'),
    ('bm25', 'embedding_similarity', 'temporal_intent', 'keywords_rake', 'topic'),
    ('bm25', 'embedding_similarity', 'temporal_intent', 'keywords_textrank', 'topic')
]

In [5]:
nlp = spacy.load("ru_core_news_sm")
bm25 = BM25Okapi([[token.lemma_ for token in nlp(doc["text"])] for doc in docs])
embedding_model = SentenceTransformer('cointegrated/rubert-tiny2')
temporal_model = joblib.load('../model/saved/time_precision_model.joblib')
topic_classification_model = BertTextClassifier(
    model_path='../model/saved/ruBert-large-mlp.pth', 
    base_model_name='ai-forever/ruBert-large', 
    tokenizer_path='ai-forever/ruBert-large'
)

  self.model.load_state_dict(torch.load(model_path))


In [6]:
def jaccard_similarity(query_keyword_list, text_keyword_list):
    intersection = len(set(query_keyword_list) & set(text_keyword_list))
    union = len(set(query_keyword_list) | set(text_keyword_list))
    if union == 0:
        return 0
    return intersection / union

def dice_similarity(query_keyword_list, text_keyword_list):
    intersection = len(set(query_keyword_list) & set(text_keyword_list))
    union = len(set(query_keyword_list) | set(text_keyword_list))
    if union == 0:
        return 0
    return 2 * intersection / (len(query_keyword_list) + len(text_keyword_list))

In [7]:
doc_topics = topic_classification_model.predict([doc["text"] for doc in docs], topics)
doc_embeddings = embedding_model.encode([doc["text"] for doc in docs])
doc_keywords_rake = [get_keywords_rake(doc["text"]) for doc in docs]
doc_keywords_textrank = [get_keywords_textrank(doc["text"]) for doc in docs]
for i, doc in enumerate(docs):
    docs[i]["topic"] = doc_topics[i]
    docs[i]["embedding"] = doc_embeddings[i]
    docs[i]["kw_rake"] = doc_keywords_rake[i]
    docs[i]["kw_textrank"] = doc_keywords_textrank[i]

query_topics = topic_classification_model.predict([query["text"] for query in queries], topics)
query_embeddings = embedding_model.encode([query["text"] for query in queries])
query_keywords_rake = [get_keywords_rake(query["text"]) for query in queries]
query_keywords_textrank = [get_keywords_textrank(query["text"]) for query in queries]
for i, query in enumerate(queries):
    queries[i]["topic"] = query_topics[i]
    queries[i]["embedding"] = query_embeddings[i]
    queries[i]["kw_rake"] = query_keywords_rake[i]
    queries[i]["kw_textrank"] = query_keywords_textrank[i]

Predicting topics: 4it [00:00,  4.27it/s]                       
Predicting topics: 1it [00:00, 111.39it/s]


In [None]:
for feature_set in feature_sets:
    n_features = len(feature_set)
    if "topic" in feature_set:
            n_features -= 1
    total_ndcg = 0
    for j, query in enumerate(queries):
        relevant_docs = docs
        lemmatized_query = [token.lemma_ for token in nlp(query["text"])]
        bm_scores = bm25.get_scores(lemmatized_query)
        if "temporal_intent" in feature_set:
            temporal_scores = compute_temporal_scores([doc["date"] for doc in docs], query["text"], temporal_model)
        for i, doc in enumerate(relevant_docs):
            relevant_docs[i]["score"] = 0
            if "topic" in feature_set and relevant_docs[i]["topic"] != query["topic"]: 
                continue
                
            if "bm25" in feature_set:
                relevant_docs[i]["score"] += bm_scores[i]
            if "embedding_similarity" in feature_set:
                relevant_docs[i]["score"] += 0.5 + 0.5 * np.dot(doc["embedding"], query["embedding"]) / \
                                            (np.linalg.norm(doc["embedding"]) * np.linalg.norm(query["embedding"]))
            if "temporal_intent" in feature_set:
                relevant_docs[i]["score"] += temporal_scores[i]
            if "keywords_rake" in feature_set:
                relevant_docs[i]["score"] += jaccard_similarity([word[0] for word in query["kw_rake"]], [word[0] for word in doc["kw_rake"]])
            if "keywords_textrank" in feature_set:
                relevant_docs[i]["score"] += jaccard_similarity([word[0] for word in query["kw_textrank"]], [word[0] for word in doc["kw_textrank"]])
            relevant_docs[i]["score"] /= n_features
        total_ndcg += ndcg_score([rankings[j]], [[doc["score"] for doc in relevant_docs]])
    print(f"NDCG for features {feature_set}: {total_ndcg / len(queries):.3f}")
        

NDCG for features ('bm25', 'embedding_similarity'): 0.758
NDCG for features ('bm25', 'embedding_similarity', 'topic'): 0.528
NDCG for features ('bm25', 'embedding_similarity', 'temporal_intent'): 0.758
NDCG for features ('bm25', 'embedding_similarity', 'keywords_rake'): 0.759
NDCG for features ('bm25', 'embedding_similarity', 'keywords_textrank'): 0.758
NDCG for features ('bm25', 'embedding_similarity', 'temporal_intent', 'keywords_rake', 'topic'): 0.528
NDCG for features ('bm25', 'embedding_similarity', 'temporal_intent', 'keywords_textrank', 'topic'): 0.528
