In [1]:
%load_ext autoreload
%autoreload 2

In [110]:
from nltk.corpus import brown
import pkg_resources
from symspellpy import SymSpell, Verbosity
from collections import Counter
import numpy as np
import torch
from transformers import AutoTokenizer, BertForMaskedLM
from pdb import set_trace

In [84]:
sym_spell = SymSpell(max_dictionary_edit_distance=0, prefix_length=7)
dictionary_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_dictionary_en_82_765.txt")
# term_index is the column of the term and count_index is the
# column of the term frequency
sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)


the quick brown fox jumps over the lazy dog, 8, -0.8021201856258288


In [118]:
sym_spell = SymSpell(max_dictionary_edit_distance=0, prefix_length=7)
dictionary_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_dictionary_en_82_765.txt")
# term_index is the column of the term and count_index is the
# column of the term frequency
sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)

# a sentence without any spaces
input_term = "updown"
result = sym_spell.word_segmentation(input_term)
print("{}, {}, {}".format(result.corrected_string, result.distance_sum,
                          result.log_prob_sum/len(result.corrected_string)))

up down, 1, -0.9643268255583413


In [78]:
sym_spell = SymSpell(max_dictionary_edit_distance=0, prefix_length=7)
dictionary_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_dictionary_en_82_765.txt")
# term_index is the column of the term and count_index is the
# column of the term frequency
sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)

# a sentence without any spaces
input_term = "updown"
result = sym_spell.word_segmentation(input_term)
print("{}, {}, {}".format(result.corrected_string, result.distance_sum,
                          result.log_prob_sum))

up down, 1, -6.750287778908389


In [88]:
result.corrected_string.split(" ")

['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']

In [125]:
class WordSuggester():
    
    def __init__(self,): 
        print("Initializing the vocabulary set..")
        self.word_set = set(brown.words())
        print("Initializing BERT pipeline..")

        self.tok = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
        self.sym_spell_cut = SymSpell(max_dictionary_edit_distance=0, prefix_length=7)
        dictionary_path = pkg_resources.resource_filename(
            "symspellpy", "frequency_dictionary_en_82_765.txt")
        # term_index is the column of the term and count_index is the
        # column of the term frequency
        self.sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
        self.sym_spell_cut.load_dictionary(dictionary_path, term_index=0, count_index=1)

    def validate_word(self,word,word_counts,min_counts=2):
        """
        A word is considered valid if it occures many times or 
        """
        tot = sum(word_counts.values())
        return word_counts[word] >= min_counts    

    def get_word_suggestions(self,word,word_counts):
        """
        Return the suggestions for the word passed in parameter. If the 
        word passed in parameter is valid, return a list of len 1 with the
        word inside.
        
        Args:
            word (str): the word to find suggestions for
            word_counts (dict): value counts of word for a given emoji (context)
        """
        # if the word appears many times we keep it
        if self.validate_word(word,word_counts):
            return {'status':'present','words':[word]}
        
        # if the word is part of the vocabulary we keep it
        if word in self.word_set:
            return {'status':'exist','words':[word]}
        
        # if it is a combinaison of many words
        result = self.sym_spell_cut.word_segmentation(word)
        log_confidence = result.log_prob_sum/len(result.corrected_string)
        if log_confidence > -1:
            suggestions = result.corrected_string.split(" ")
            if len(suggestions) == 1:
                return {'status':'exist','words':suggestions}
            return {'status':'cut','words':suggestions}
        
        # otherwise we correct it
        suggestions = self.sym_spell.lookup(word, Verbosity.CLOSEST,
                                       max_edit_distance=2)
        # display suggestion term, term frequency, and edit distance
        if len(suggestions) == 0:
            print(f"Word {word} not found!")
            return {'status':'notfound','words':[word]}
        return {'status':'corrected','words':[sugg.term for sugg in suggestions]}

    def get_context_suggestions(self,word_list):
        """
        Applies get_word_suggestions for every word of an emoji's vocabulary (context)
        
        Args:
            word_list (list of str): words to describe the emoji
        
        Returns:
            [list of list of str]: list of suggestions: each word receives suggestions (list of str)
        """
        word_counts = Counter(word_list)
        context_suggestions = [self.get_word_suggestions(word,word_counts) for word in word_list]
        return context_suggestions
    

    def find_best_word(self,context,suggestions):
        """
        Find the most appropriate word in suggestions given the context
        
        Args:
            context (list of str): words defining the context
            suggestions (list of str): suggestions for the word to find
        
        Returns:
            [str]: the word of suggestions that matches the best the context
            according to BERT output
        """
        # We place the word of interest in the middle of the context
        n = len(context) // 2
        pre_context = ' '.join(context[:n])
        post_context = ' '.join(context[n:])
        sentence = f"{pre_context} {self.tok.mask_token} {post_context}"

        input_tokens = self.tok.encode(sentence)
        answer_pos = input_tokens.index(self.tok.mask_token_id)

        logits = self.bert(torch.tensor([input_tokens]))[0][0]
        logits = logits[answer_pos]
        suggestions_tokens = [self.tok.encode(word)[1:-1] for word in suggestions]
        scores = [np.mean([logits[i].item() for i in tokens]) for tokens in suggestions_tokens]
        best_sugg_idx = np.argmax(scores)
        return suggestions[best_sugg_idx]
    
    def extract_context_suggestions(self,context_suggestions):
        """
        Extract best words for each suggestions in the context suggestions
        
        Args:
            context_suggestions (list of list of str): list of suggestions
        
        Returns:
            [list of str]: most appropriate words

        """
        # we don't need the status in the current function
        context_suggestions = [sugg['words'] for sugg in context_suggestions]
        ret_words = []
        for suggestions in context_suggestions:
            # single suggestion: the word is not ambiguous
            if len(suggestions) == 1:
                ret_words.append(suggestions[0])
            else:
                # we gather the single words considered as healthy
                context = [word_list[0] for word_list in context_suggestions
                                         if word_list != suggestions and len(word_list) == 1 ]
                word = self.find_best_word(context,suggestions)
                
                ret_words.append(word)
        return ret_words
    
    def process_context(self,context,verbose=False):
        """
        Args:
            context (list of str): words
        
        Returns:
            [list of str]: corrected words
        """
        context_suggestions = self.get_context_suggestions(context)
        corr_words = self.extract_context_suggestions(context_suggestions)
        if verbose:
            for word,suggestions,corr_word in zip(context,context_suggestions,corr_words):
                if suggestions['status'] not in ['present','exist']:
                    status = suggestions['status']
                    print(f"Modified {word} --> {corr_word} ({status})")
        return corr_words


In [126]:
sugg = WordSuggester()

Initializing the vocabulary set..
Initializing BERT pipeline..


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [32]:
voc = sugg.process_context(["applette","applette","fruito","sugary","sweet"],verbose=True)

Corrected fruito --> fruit


In [37]:
from src.analysis.postprocessing import scrap_form_results,generate_production_format
from src.constants import EMOJI_DATASET_DIR
from tqdm import tqdm
tqdm.pandas()

In [96]:
form_dfs = scrap_form_results(EMOJI_DATASET_DIR)

100%|██████████| 133/133 [00:51<00:00,  2.57it/s]


In [97]:
form_df = generate_production_format(form_dfs)

100%|██████████| 133/133 [00:03<00:00, 37.80it/s]


In [101]:
form_df.word.apply(lambda x: " " in x).any()

False

In [106]:
small_form_df = form_df.groupby('emoji')['word'].agg(list)

In [107]:
small_form_df = small_form_df.iloc[:14]

In [108]:
small_form_df

emoji
#️⃣    [number, hashtag, pound, number, pound, pound,...
*️⃣    [asterisk, pound, snowflake, asterisk, asterik...
©️     [copywrite, copyright, copywrite, copyright, c...
®️     [registered, r, letter, copyright, r, rest, ra...
‼️     [exclamation, exclamation, excited, exclamatio...
⁉️     [confused, exclaim, seriously, surprised, ques...
™️     [text, trademark, tm, trademark, tm, trademark...
ℹ️     [information, i, hand, doubt, letter, exclamat...
↔️     [sign, arrow, navigation, leftorright, turn, w...
↕️     [directions, arrows, vertical, perpendicular, ...
↖️     [diagonal, click, upleftarrow, diagonal, angle...
↗️     [right, turn, up, up, arrow, arrow, up, right,...
↘️     [direction, down, corner, down, down, downhear...
↙️     [arrow, down, arrow, arrow, down, arrow, diago...
Name: word, dtype: object

In [127]:
small_form_df.progress_apply(lambda x: sugg.process_context(x,verbose=True))

 29%|██▊       | 4/14 [00:00<00:00, 27.84it/s]

Modified asterik --> asterisk (corrected)
Modified astrik --> astrid (corrected)
Modified coppyright --> copyright (corrected)
Modified cee --> cen (corrected)


 50%|█████     | 7/14 [00:00<00:00,  9.89it/s]

Modified capitalr --> capital (corrected)
Modified icon --> icon (corrected)
Modified circler --> circlet (corrected)
Modified excalmation --> exclamation (corrected)
Modified exclamationquestion --> question (cut)


 79%|███████▊  | 11/14 [00:01<00:00,  8.65it/s]

Modified leftorright --> or (cut)
Modified whichway --> way (cut)
Modified leftright --> left (cut)
Modified upanddown --> and (cut)


 93%|█████████▎| 13/14 [00:01<00:00,  6.47it/s]

Modified upleftarrow --> up (cut)
Modified angled --> angled (corrected)
Modified risingsign --> sign (cut)


100%|██████████| 14/14 [00:01<00:00,  7.75it/s]

Modified arrowbottomleftcorner --> arrow (cut)
Modified thatway --> way (cut)





emoji
#️⃣    [number, hashtag, pound, number, pound, pound,...
*️⃣    [asterisk, pound, snowflake, asterisk, asteris...
©️     [copywrite, copyright, copywrite, copyright, c...
®️     [registered, r, letter, copyright, r, rest, ra...
‼️     [exclamation, exclamation, excited, exclamatio...
⁉️     [confused, exclaim, seriously, surprised, ques...
™️     [text, trademark, tm, trademark, tm, trademark...
ℹ️     [information, i, hand, doubt, letter, exclamat...
↔️     [sign, arrow, navigation, or, turn, wider, dir...
↕️     [directions, arrows, vertical, perpendicular, ...
↖️     [diagonal, click, up, diagonal, angled, arrow,...
↗️     [right, turn, up, up, arrow, arrow, up, right,...
↘️     [direction, down, corner, down, down, downhear...
↙️     [arrow, down, arrow, arrow, down, arrow, diago...
Name: word, dtype: object