Taken from https://huggingface.co/spaces/ml6team/post-processing-summarization

In [1]:
import itertools
import numpy as np

import spacy

from flair.nn import Classifier
from flair.data import Sentence

from sentence_transformers import SentenceTransformer

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

def get_transformer_pipeline(model_name="xlm-roberta-large-finetuned-conll03-english"):
    """Build a Hugging Face token-classification ("ner") pipeline.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained`` for
            both the tokenizer and the model. Defaults to the checkpoint this
            notebook has always used.

    Returns:
        The object returned by ``transformers.pipeline``; calling it on a
        string yields a list of dicts (the notebook reads their ``"word"`` key).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    # `grouped_entities=True` is deprecated; `aggregation_strategy="simple"` is
    # the documented replacement that merges word pieces into whole entities.
    return pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

# Module-level models shared by the functions defined in the cells below.
# Sentence-embedding model; its `.encode(...)` output is used to compare entity strings.
sentence_embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Hugging Face "ner" pipeline built by get_transformer_pipeline().
ner_model = get_transformer_pipeline()
# spaCy pipeline; used for sentence splitting, its own entities and the dependency parse.
nlp = spacy.load("en_core_web_sm")
# Flair NER tagger; applied to each sentence via `flair_tagger.predict(...)`.
flair_tagger = Classifier.load('ner')

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at xlm-roberta-large-finetuned-conll03-english were not used when initializing XLMRobertaForTokenClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


2023-09-22 14:03:32,981 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>


In [2]:
def get_all_entities_per_sentence(text):
    """Collect named-entity strings for every sentence of ``text``.

    ``text`` is split into sentences by the spaCy pipeline ``nlp``. Each
    sentence is then tagged by three recognisers whose results are simply
    concatenated (duplicates are kept), in this order: spaCy, Flair, and the
    Hugging Face ``ner_model`` pipeline.

    Returns:
        A list with one inner list of entity strings per sentence, in
        sentence order.
    """
    per_sentence = []
    for sent in list(nlp(text).sents):
        sent_text = str(sent)

        # spaCy entities attached to this sentence span.
        found = [str(ent) for ent in sent.ents]

        # Flair entities: predict() annotates the Sentence object in place.
        flair_sentence = Sentence(sent_text)
        flair_tagger.predict(flair_sentence)
        found.extend(span.text for span in flair_sentence.get_spans('ner'))

        # Transformer-pipeline entities: one dict per entity, text under "word".
        found.extend(str(hit["word"]) for hit in ner_model(sent_text))

        per_sentence.append(found)

    return per_sentence


In [3]:
def _drop_contained_entities(entities):
    """Deduplicate ``entities`` and drop those contained in a longer entity.

    Comparison is case-insensitive. Spellings that differ only in case count
    as the same entity and the first spelling seen is kept; previously such
    variants removed each other and the entity disappeared altogether.
    """
    unique = []
    seen_lowered = set()
    for entity in entities:
        key = entity.lower()
        if key not in seen_lowered:
            seen_lowered.add(key)
            unique.append(entity)

    lowered = [entity.lower() for entity in unique]
    return [
        entity
        for entity, low in zip(unique, lowered)
        if not any(low != other and low in other for other in lowered)
    ]


def get_and_compare_entities(source, summary, similarity_threshold=0.9):
    """Split the entities of ``summary`` into those found in ``source`` and those not.

    Entities of both texts come from ``get_all_entities_per_sentence``. A
    summary entity counts as matched when either

    * its lower-cased text is a substring of a lower-cased source entity, or
    * the inner product of its embedding with the embedding of some source
      entity exceeds ``similarity_threshold``.

    Both result lists are deduplicated, and entities that are contained in a
    longer entity of the same list are dropped.

    Args:
        source: Reference text (e.g. the article).
        summary: Text whose entities are checked against ``source``.
        similarity_threshold: Inner-product cut-off for the embedding
            fallback. Defaults to 0.9, the value previously hard-coded.

    Returns:
        ``(matched_entities, unmatched_entities)``, two lists of strings in
        first-occurrence order.
    """
    entities_source = list(itertools.chain.from_iterable(get_all_entities_per_sentence(source)))
    entities_summary = list(itertools.chain.from_iterable(get_all_entities_per_sentence(summary)))

    source_lowered = [entity.lower() for entity in entities_source]

    # Memoise embeddings so each distinct string is encoded at most once,
    # instead of once per (summary entity, source entity) pair.
    embedding_cache = {}

    def _embed(text):
        if text not in embedding_cache:
            embedding_cache[text] = sentence_embedding_model.encode(text, show_progress_bar=False)
        return embedding_cache[text]

    matched_entities = []
    unmatched_entities = []
    # dict.fromkeys drops repeated entities while preserving order; the same
    # entity is typically reported by several of the NER models.
    for entity in dict.fromkeys(entities_summary):
        needle = entity.lower()
        if any(needle in candidate for candidate in source_lowered):
            matched_entities.append(entity)
        elif any(
                np.inner(_embed(entity), _embed(source_entity)) > similarity_threshold
                for source_entity in entities_source):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)

    return _drop_contained_entities(matched_entities), _drop_contained_entities(unmatched_entities)



In [4]:
# Example inputs for the entity comparison below.
# `article` is the reference text.
article = "Lionel Andrés Messi (born 24 June 1987) is an Argentine professional footballer who plays as a forward and captains both Spanish club Barcelona and the Argentina national team. Often considered as the best player in the world and widely regarded as one of the greatest players of all time, Messi has won a record six Ballon d'Or awards, a record six European Golden Shoes, and in 2020 was named to the Ballon d'Or Dream Team."
# `summary` differs from the article in the birth date and the nationality
# (presumably on purpose, to give the comparison something to flag).
summary = "Lionel Andrés Messi (born 24 Aug 1997) is an Spanish professional footballer who plays as a forward and captains both Spanish club Barcelona and the Spanish national team."

In [5]:
# Compare the summary's entities against the article's; returns two lists.
matched, unmatched = get_and_compare_entities(article, summary)

In [6]:
# Summary entities that were also found among the article's entities.
matched

['24', 'Spanish', 'Barcelona', 'Lionel Andrés Messi']

In [7]:
# Summary entities with no counterpart among the article's entities.
unmatched

['1997']

In [None]:
def check_dependency(article: bool, text=None):
    """List ``amod`` / ``pobj`` dependencies of a text that involve a named entity.

    Args:
        article: Only used when ``text`` is None. Selects which field of
            ``st.session_state`` to read: ``article_text`` when true,
            ``summary_output`` otherwise.
        text: Optional text to analyse directly. When given, ``st`` is not
            touched at all.
            NOTE(review): ``st`` is not imported in the visible notebook
            (presumably Streamlit), so pass ``text`` when running outside
            that app.

    Returns:
        A list of dicts, one per reported dependency, with the keys ``dep``,
        ``cur_word_index`` and ``target_word_index`` (token positions relative
        to the start of the sentence), ``identifier``
        (token text + dep label + head text) and ``sentence``.
    """
    if text is None:
        state = st.session_state
        text = state.article_text if article else state.summary_output

    all_entities = get_all_entities_per_sentence(text)
    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    test_list_dict_output = []

    for i, sentence in enumerate(list(doc.sents)):
        sentence_entities = all_entities[i]
        for t in tok_l:
            # `sentence.end` is exclusive in spaCy (one past the last token).
            # The previous `> end` test let the first token of the next
            # sentence through and reported it under this sentence.
            if not (sentence.start <= t["id"] < sentence.end):
                continue
            if t['dep'] not in ('amod', 'pobj'):
                continue

            head = tok_l[t['head']]
            object_here = text[t['start']:t['end']]
            object_target = text[head['start']:head['end']]

            # Prepositional objects are only kept when the preposition is "in".
            if t['dep'] == "pobj" and object_target.lower() != "in":
                continue

            # At least one side of the dependency has to be a known entity.
            if object_here in sentence_entities or object_target in sentence_entities:
                test_list_dict_output.append({
                    "dep": t['dep'],
                    "cur_word_index": (t['id'] - sentence.start),
                    "target_word_index": (t['head'] - sentence.start),
                    "identifier": object_here + t['dep'] + object_target,
                    "sentence": str(sentence),
                })
    return test_list_dict_output


In [None]:
# Entity-related dependencies found in the summary and in the article.
# NOTE(review): `render_dependency_parsing` is not defined in the visible part
# of the notebook; it has to be provided before this cell can run.
summary_deps = check_dependency(False)
article_deps = check_dependency(True)

# Keep the summary dependencies whose identifier is not contained in the
# identifier of any article dependency.
total_unmatched_deps = [
    dep
    for dep in summary_deps
    if all(dep['identifier'] not in art_dep['identifier'] for art_dep in article_deps)
]

# Draw every dependency that only the summary contains.
for unmatched_dep in total_unmatched_deps:
    render_dependency_parsing(unmatched_dep)