In [2]:
from collections import Counter
from typing import Any, List

import pymorphy2


class TextNormalizer:
    """Lemmatise words with pymorphy2 and count the noun-like ones.

    A token is kept when its most probable pymorphy2 parse is tagged ``NOUN`` or
    ``LATN`` and is not tagged ``Name``; kept tokens are counted by their
    ``normal_form``.
    """

    def __init__(self):
        self.morph_ru = pymorphy2.MorphAnalyzer(lang="ru")

    def normalize(self, keyphrases: List[str]) -> Counter[Any]:
        """Count the normal forms of noun-like tokens in ``keyphrases``.

        Args:
            keyphrases: strings; each one is split on whitespace into tokens.

        Returns:
            Counter mapping normal form -> occurrences, in first-seen order.
            Empty input yields an empty Counter.
        """
        analyse = self.morph_ru.parse

        def is_noun_like(tag) -> bool:
            noun_or_latin = "NOUN" in tag or "LATN" in tag
            return noun_or_latin and "Name" not in tag

        # Only the first (most probable) parse of each token is considered.
        best_parses = (
            analyse(str(token))[0]
            for phrase in keyphrases
            for token in phrase.split()
        )
        return Counter(
            parse.normal_form for parse in best_parses if is_noun_like(parse.tag)
        )

In [3]:
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances


class KeywordExtractor:
    """Pick keywords by how far each candidate's embedding lies from the others.

    NOTE(review): ``extract`` keeps the candidates with the *largest* mean
    pairwise distance, i.e. the least typical words. KeyBERT-style extractors
    usually rank by similarity to a document embedding instead -- confirm this
    is the intended criterion.
    """

    def extract(
        self,
        words: list,
        embd: list,
        top_n: int,
        distance_metric: str = "cosine",
    ) -> np.ndarray:
        """Return the ``top_n`` candidates with the highest mean distance.

        Args:
            words: candidate words; a list or a 1-D array (both are accepted).
            embd: one embedding row per word, same order and length as ``words``.
            top_n: maximum number of keywords to return.
            distance_metric: any metric accepted by sklearn's
                ``pairwise_distances``.

        Returns:
            1-D array of the selected words, most distant first. Ties keep the
            input order. Empty input gives an empty array.

        Raises:
            ValueError: if ``words`` and ``embd`` have different lengths.
        """
        # asarray lets plain lists be indexed by the integer index array below.
        candidates = np.asarray(words)
        if len(candidates) != len(embd):
            raise ValueError(
                f"words and embd differ in length: {len(candidates)} != {len(embd)}"
            )
        if len(candidates) == 0:
            # pairwise_distances rejects zero samples; nothing to rank.
            return candidates.astype(str)

        mean_distances = pairwise_distances(embd, metric=distance_metric).mean(axis=1)
        # Descending by distance; a stable sort makes tie-breaking deterministic.
        order = np.argsort(-mean_distances, kind="stable")
        return candidates[order[:top_n]]

In [4]:
import re
import string

import torch
from nltk.corpus import stopwords


class Model:
    """Summarise-then-rank keyword extraction pipeline.

    ``extract`` summarises a text with a seq2seq model. It splits the summary
    into words, drops stop words and punctuation, and normalises the rest with
    ``normalizer``. Each surviving word is embedded with ``embedder``, and
    ``ranker`` picks the final keywords.
    """

    def __init__(
        self,
        summarizer_t,
        summarizer_m,
        normalizer,
        tokenizer,
        embedder,
        ranker,
        device: str = "cuda",
    ):
        """Store the pipeline components and move both neural models to ``device``.

        Args:
            summarizer_t: tokenizer of the summarisation model (called on a list
                of texts, also used for ``decode``).
            summarizer_m: seq2seq model exposing ``generate`` and ``to``.
            normalizer: object whose ``normalize(words)`` returns a mapping keyed
                by normalised words (e.g. ``TextNormalizer``).
            tokenizer: tokenizer of ``embedder``.
            embedder: encoder whose output has ``last_hidden_state``.
            ranker: object with ``extract(words, embeddings, top_n)``
                (e.g. ``KeywordExtractor``).
            device: torch device for both models; defaults to ``"cuda"``.
        """
        self.device = device
        self.summarizer_t = summarizer_t
        self.summarizer_m = summarizer_m.to(device)
        self.normalizer = normalizer
        self.tokenizer = tokenizer
        self.embedder = embedder.to(device)
        self.ranker = ranker

        # "-" is kept out of the separator set so hyphenated words survive.
        self.punct = string.punctuation.replace("-", "")

        self.stop_words = set(stopwords.words("russian"))

    def extract(self, text: str, top_n: int = 5):
        """Return up to ``top_n`` keywords for ``text``.

        Args:
            text: raw document text.
            top_n: number of keywords requested from the ranker (default 5).

        Returns:
            Whatever ``self.ranker.extract`` returns (an array of words for
            ``KeywordExtractor``). An empty string array is returned when the
            summary contains no usable words.
        """
        summary = self._summarize(text)
        words = self._clean_words(summary)
        normalized_words = list(self.normalizer.normalize(words).keys())
        if not normalized_words:
            # Nothing to embed or rank; the tokenizer and ranker fail on [].
            return np.array([], dtype=str)

        embeddings = self._embed(normalized_words)
        return self.ranker.extract(np.array(normalized_words), embeddings, top_n)

    def _summarize(self, text: str) -> str:
        """Generate a summary of ``text`` with the seq2seq model."""
        encoded = self.summarizer_t(
            [text],
            add_special_tokens=True,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        output_ids = self.summarizer_m.generate(
            input_ids=encoded["input_ids"].to(self.device),
            # Explicit mask for the max_length padding rather than relying on
            # generate() to infer it from pad tokens.
            attention_mask=encoded["attention_mask"].to(self.device),
            max_length=128,
            no_repeat_ngram_size=3,
            num_beams=10,
        )[0].detach().cpu().numpy()
        return self.summarizer_t.decode(output_ids, skip_special_tokens=True)

    def _clean_words(self, summary: str) -> List[str]:
        """Split ``summary`` on punctuation/whitespace and drop stop words."""
        # re.escape is required: string.punctuation contains "\\", "]" and "^",
        # which would otherwise change the meaning of the character class.
        separators = f"[{re.escape(self.punct)}\\s]"
        pieces = (piece.strip() for piece in re.split(separators, summary))
        # Lower-case before the lookup so sentence-initial stop words are
        # removed too (assumes the stop-word list is lower-case).
        return [w for w in pieces if w and w.lower() not in self.stop_words]

    def _embed(self, words: List[str]) -> np.ndarray:
        """Embed each word as the encoder's first-token hidden state (CLS for BERT-style models)."""
        encoded = self.tokenizer(
            words, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)
        # Inference only: skip autograd bookkeeping to save (GPU) memory.
        with torch.no_grad():
            hidden = self.embedder(**encoded).last_hidden_state
        return hidden[:, 0, :].cpu().numpy()

In [5]:
from transformers import AutoTokenizer, AutoModel
from transformers import MBartTokenizer, MBartForConditionalGeneration

# Word embedder (rubert-tiny2): tokenizer plus encoder, encoder moved to the GPU.
tokenizer_embd = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
embedder = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
embedder = embedder.to("cuda")

# Summariser: mBART checkpoint and its tokenizer, model moved to the GPU.
model_name = "IlyaGusev/mbart_ru_sum_gazeta"
summarizer_t = MBartTokenizer.from_pretrained(model_name)
summarizer_m = MBartForConditionalGeneration.from_pretrained(model_name)
summarizer_m = summarizer_m.to("cuda")

# Non-neural pipeline pieces defined earlier in the notebook.
normalizer, ext = TextNormalizer(), KeywordExtractor()

In [6]:
model = Model(
    summarizer_t=summarizer_t,
    summarizer_m=summarizer_m,
    normalizer=normalizer,
    tokenizer=tokenizer_embd,
    embedder=embedder,
    ranker=ext,
)

## Scoring

Compare the extracted keywords with the gold article tags using ROUGE-1.

In [7]:
import ast
import pandas as pd

DATA_PATH = "/kaggle/input/habr-articles-tags/data_extractive_habr.csv"

data = pd.read_csv(DATA_PATH)

X_test = data["text"]
# fillna("") turns a missing tag string into [] instead of NaN.
# rouge_score_corpus skips [], but len(NaN) would raise TypeError there.
y_test = data["tag"].fillna("").str.split().to_list()

In [8]:
from tqdm import tqdm

# Run the full pipeline on every test article (~1.2 s per text here, see the
# progress bar below). Appending inside the loop keeps the results gathered so
# far if the cell is interrupted.
extracted_keywords = []
for article in tqdm(X_test):
    extracted_keywords.append(model.extract(article))

100%|██████████| 3128/3128 [1:04:49<00:00,  1.24s/it]


In [10]:
from typing import Dict, List

import numpy as np
from rouge import Rouge
from tqdm import tqdm


def rouge_score_corpus(true: List[List[str]], pred: List[List[str]]) -> Dict[str, float]:
    """Average ROUGE-1 recall/precision/F1 over paired tag lists.

    Each tag list is joined with spaces and scored with ``rouge.Rouge``; only the
    ``"rouge-1"`` entry is used. A pair is skipped when either side is empty,
    because an empty hypothesis/reference cannot be scored.

    Args:
        true: gold tags per document.
        pred: predicted keywords per document (lists or 1-D string arrays), in
            the same order as ``true``.

    Returns:
        Dict with ``"recall"``, ``"precision"`` and ``"f1"`` means (plain floats)
        over the scored pairs. All three are NaN when no pair could be scored.

    Raises:
        ValueError: if ``true`` and ``pred`` differ in length. ``zip`` would
            otherwise silently drop the tail.
    """
    if len(true) != len(pred):
        raise ValueError(f"true and pred differ in length: {len(true)} != {len(pred)}")

    rouge = Rouge()
    rec, prec, f1 = [], [], []

    # total= lets tqdm draw a real progress bar (zip objects have no len()).
    for true_tags, pred_tags in tqdm(zip(true, pred), total=len(true)):
        if len(true_tags) == 0 or len(pred_tags) == 0:
            continue

        scores = rouge.get_scores(" ".join(pred_tags), " ".join(true_tags))
        rouge_1 = scores[0]["rouge-1"]

        rec.append(rouge_1["r"])
        prec.append(rouge_1["p"])
        f1.append(rouge_1["f"])

    if not f1:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        nan = float("nan")
        return {"recall": nan, "precision": nan, "f1": nan}

    return {
        "recall": float(np.mean(rec)),
        "precision": float(np.mean(prec)),
        "f1": float(np.mean(f1)),
    }

In [11]:
rouge_score_corpus(true=y_test, pred=extracted_keywords)

3128it [00:00, 9065.90it/s]


{'recall': 0.15428731275118743,
 'precision': 0.12197890025575447,
 'f1': 0.13343315484204998}

In [12]:
# Gold tags of the first ten articles.
y_test[0:10]

[['перевод', 'css', 'css'],
 ['sendfile', 'ланит', 'файлообменник', 'artezio'],
 ['фцп', 'грант', 'экспертиза'],
 ['csv', 'excel', 'libreoffice', 'данные'],
 ['javascript', 'kotlin', 'ajax', 'jquery'],
 ['deb', 'пакет', 'репозиторий', 'lintian', 'dpkg-deb'],
 ['css', 'grid', 'subgrid'],
 ['квайн', 'ассемблер', 'windows'],
 ['nfc', 'nfc-метки', 'визитки', 'магия', 'abbyy'],
 ['видеокарты', 'amd', 'gpgpu', 'фермы', 'пароли', 'безопасность', 'opencl']]

In [13]:
# Extracted keywords for the same ten articles.
extracted_keywords[0:10]

[array(['псевдо-класс', 'not', 'рамка', 'задание', 'обзор'], dtype='<U12'),
 array(['платформа', 'artezio', 'sendfile', 'обзор', 'компания'],
       dtype='<U9'),
 array(['статья', 'конкурс', 'экспертиза', 'деньга', 'россия'],
       dtype='<U14'),
 array(['csv-файл', 'office', 'материал', 'open', 'excel'], dtype='<U8'),
 array(['languages', 'kotlin', 'обзор', 'ktor', 'mysql'], dtype='<U12'),
 array(['deb-пакет', 'статья', 'checkinstall', 'сайт', 'система'],
       dtype='<U12'),
 array(['teamwork', 'дизайн', 'таблица', 'данные', 'контакт'], dtype='<U8'),
 array(['помощь', 'ассемблер', 'язык', 'программирование', 'код'],
       dtype='<U16'),
 array(['технология', 'nfc-метка', 'мир', 'телефон', 'роль'], dtype='<U10'),
 array(['технология', 'обзор', 'множество', 'компьютер', 'кластер'],
       dtype='<U10')]