In [8]:
from transformers import BertTokenizer, BertModel, BertConfig, BertForTokenClassification, BertForMaskedLM
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
import torch
import os
import pandas as pd

## Utils

In [9]:
def get_token_in_sequence_with_most_attention(model, tokenizer, input_sequence):
    """Run an input sequence through the BERT model, collect and average attention scores per token and return token with most average attention."""
    tokenized_input_sequence = tokenizer.tokenize(input_sequence)
    input_ids = torch.tensor(tokenizer.encode(input_sequence, add_special_tokens=False)).unsqueeze(0)
    outputs = model(input_ids)
    last_hidden_states, pooler_outputs, hidden_states, attentions = outputs
    attention_tensor = torch.squeeze(torch.stack(attentions))
    attention_tensor_averaged = torch.mean(attention_tensor, (0, 1))
    attention_average_scores_per_token = torch.sum(attention_tensor_averaged, dim=0)
    attention_scores_dict = dict()
    for token_position in range(len(tokenized_input_sequence)):
        attention_scores_dict[token_position] = attention_average_scores_per_token[token_position].item()
    print('Attention scores dictionary: ', attention_scores_dict)
    return {'token_index': max(attention_scores_dict, key=attention_scores_dict.get),
            'token_str': tokenized_input_sequence[max(attention_scores_dict, key=attention_scores_dict.get)]}

In [10]:
def extract_keywords_from_mlm_results(mlm_results_list, K_kw_explore):
    selected_keywords_list = list()
    for rank_mlm_keyword in range(K_kw_explore):
        selected_keywords_list.append(mlm_results_list[rank_mlm_keyword]['token_str'])
    return selected_keywords_list

## Data and model import 

### is_hired_1mo

In [11]:
#load inference data from random set
data_path = '/home/manuto/Documents/world_bank/bert_twitter_labor/data/inference/convBERT/it0/random'
is_hired_1mo_df = pd.read_pickle(os.path.join(data_path,'is_hired_1mo_ONNX_BERT_ST_merged_random_100m_jun22.pkl')).reset_index(drop=True)
is_hired_1mo_df.head()

Unnamed: 0,tweet_id,first,second,text
0,662755069675327489,0.014041,0.985959,Got hired today!!!
1,459847518688665600,0.014151,0.985849,Just got hired at Google.
2,535195703203880961,0.014181,0.985819,Just got hired at Hobby Lobby!
3,331873004168552448,0.014229,0.985771,Got hired at Chick Fil A!
4,327479239966330881,0.014336,0.985664,Got hired today at the hooters on the beach! 😉...


In [12]:
#load model
PATH_MODEL_FOLDER = '/home/manuto/Downloads/best_model'
config = BertConfig.from_pretrained(PATH_MODEL_FOLDER, output_hidden_states=True, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(PATH_MODEL_FOLDER)
model = BertModel.from_pretrained(PATH_MODEL_FOLDER, config=config)

We identify the token with most attention in each of the top 40 tweets. 

For each of these high-attention tokens, we identify 5 most similar tokens through masked language modeling (MLM).

In [20]:
mlm_pipeline_bert = pipeline('fill-mask', model='bert-base-uncased', tokenizer='bert-base-uncased',
                        config='bert-base-uncased', topk=5)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=230.0, style=ProgressStyle(description_…




In [21]:
mlm_pipeline_custom = pipeline('fill-mask', model=PATH_MODEL_FOLDER, tokenizer=PATH_MODEL_FOLDER,
                        config=PATH_MODEL_FOLDER, topk=5)

Model name '/home/manuto/Downloads/best_model' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-japanese, bert-base-japanese-whole-word-masking, bert-base-japanese-char, bert-base-japanese-char-whole-word-masking, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased, openai-gpt, transfo-xl-wt103, gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, ctrl, xlnet-base-cased, xlnet-large-cased, xlm-mlm-en-2048, xlm-mlm-ende-1024, xlm-mlm-enfr-1024, xlm-mlm-enro-1024, xlm-mlm-tlm-xnli15-1024, xlm-mlm-xnli15-1024, xlm-clm-

#### MLM with Vanilla BERT

In [22]:
for tweet_index in range(40):
    print('***********Tweet #{}***********'.format(str(tweet_index)))
    input_sequence = is_hired_1mo_df['text'][tweet_index]
    tokenized_tweet = tokenizer.tokenize(input_sequence)
    print('Tweet: {}'.format(input_sequence))
    #identify high attention token
    attention_results_dict = get_token_in_sequence_with_most_attention(model, tokenizer, input_sequence)
    attention_token_str = attention_results_dict['token_str']
    attention_token_index = attention_results_dict['token_index']
    print('Token with most attention on average: {}'.format(attention_token_str))
    print('Token index with most attention on average: {}'.format(attention_token_index))
    #do MLM
    ## replace high-attention token by a [MASK] token
    tokenized_tweet[attention_token_index] = '[MASK]'
    mlm_results_list = mlm_pipeline_bert(' '.join(tokenized_tweet))
    print('MLM results for token "{}": '.format(attention_results_dict['token_str']))
    for i in range(len(mlm_results_list)):
        print(mlm_results_list[i])

***********Tweet #0***********
Tweet: Got hired today!!!
Attention scores dictionary:  {0: 1.5126816034317017, 1: 1.4664392471313477, 2: 0.8879139423370361, 3: 0.6898782253265381, 4: 0.6860100626945496, 5: 0.7570769786834717}
Token with most attention on average: Got
Token index with most attention on average: 0
MLM results for token "Got": 
{'sequence': '[CLS] " hired today!!! [SEP]', 'score': 0.19336189329624176, 'token': 1000}
{'sequence': '[CLS] get hired today!!! [SEP]', 'score': 0.1401471197605133, 'token': 2131}
{'sequence': '[CLS] was hired today!!! [SEP]', 'score': 0.08166869729757309, 'token': 2001}
{'sequence': '[CLS] i hired today!!! [SEP]', 'score': 0.07634694874286652, 'token': 1045}
{'sequence': '[CLS] got hired today!!! [SEP]', 'score': 0.06446801126003265, 'token': 2288}
***********Tweet #1***********
Tweet: Just got hired at Google.
Attention scores dictionary:  {0: 1.4466649293899536, 1: 0.9768638610839844, 2: 1.231852650642395, 3: 0.6410866975784302, 4: 0.8262073397

Attention scores dictionary:  {0: 1.905438780784607, 1: 1.1128921508789062, 2: 1.3489856719970703, 3: 0.7385013699531555, 4: 0.6312955617904663, 5: 1.176804542541504, 6: 0.9598628878593445, 7: 0.8739089369773865, 8: 0.8465639352798462, 9: 0.9671778678894043, 10: 0.552645742893219, 11: 0.8859224319458008}
Token with most attention on average: Just
Token index with most attention on average: 0
MLM results for token "Just": 
{'sequence': '[CLS] i got hired in the spot. # two # # job # # s [UNK] [SEP]', 'score': 0.541392982006073, 'token': 1045}
{'sequence': '[CLS] he got hired in the spot. # two # # job # # s [UNK] [SEP]', 'score': 0.19962696731090546, 'token': 2002}
{'sequence': '[CLS] and got hired in the spot. # two # # job # # s [UNK] [SEP]', 'score': 0.06618180125951767, 'token': 1998}
{'sequence': '[CLS] she got hired in the spot. # two # # job # # s [UNK] [SEP]', 'score': 0.03615875542163849, 'token': 2016}
{'sequence': '[CLS] you got hired in the spot. # two # # job # # s [UNK] [S

MLM results for token "Got": 
{'sequence': '[CLS] and hired at the gulf. # yes # # y # # es # # y # # es [SEP]', 'score': 0.08975149691104889, 'token': 1998}
{'sequence': '[CLS] " hired at the gulf. # yes # # y # # es # # y # # es [SEP]', 'score': 0.06754028797149658, 'token': 1000}
{'sequence': '[CLS] was hired at the gulf. # yes # # y # # es # # y # # es [SEP]', 'score': 0.059309616684913635, 'token': 2001}
{'sequence': '[CLS] i hired at the gulf. # yes # # y # # es # # y # # es [SEP]', 'score': 0.05118577554821968, 'token': 1045}
{'sequence': '[CLS] he hired at the gulf. # yes # # y # # es # # y # # es [SEP]', 'score': 0.042256467044353485, 'token': 2002}
***********Tweet #18***********
Tweet: Just got hired at Olive Garden ✔️
Attention scores dictionary:  {0: 1.7673919200897217, 1: 0.9163364171981812, 2: 1.2997117042541504, 3: 0.6286437511444092, 4: 0.6396015882492065, 5: 0.6670652627944946, 6: 1.081249713897705}
Token with most attention on average: Just
Token index with most atte

Attention scores dictionary:  {0: 2.066157341003418, 1: 0.9446656107902527, 2: 1.0775426626205444, 3: 1.1994777917861938, 4: 0.8793065547943115, 5: 1.0466992855072021, 6: 1.2806215286254883, 7: 0.8468508720397949, 8: 0.6319315433502197, 9: 1.178275227546692, 10: 0.8223013281822205, 11: 1.1083741188049316, 12: 0.6953713893890381, 13: 0.8366166949272156, 14: 1.001132845878601, 15: 0.6766977310180664, 16: 0.707977294921875}
Token with most attention on average: Got
Token index with most attention on average: 0
MLM results for token "Got": 
{'sequence': '[CLS] got a new job today and started on the spot! doing some real work!! [SEP]', 'score': 0.5299140214920044, 'token': 2288}
{'sequence': '[CLS] found a new job today and started on the spot! doing some real work!! [SEP]', 'score': 0.2625945508480072, 'token': 2179}
{'sequence': '[CLS] started a new job today and started on the spot! doing some real work!! [SEP]', 'score': 0.13163936138153076, 'token': 2318}
{'sequence': '[CLS] landed a n

Attention scores dictionary:  {0: 1.8963210582733154, 1: 1.3329483270645142, 2: 1.022910237312317, 3: 0.9498006701469421, 4: 0.7221141457557678, 5: 1.0568573474884033, 6: 1.1116896867752075, 7: 0.9277344942092896, 8: 0.9968269467353821, 9: 0.8822914958000183, 10: 0.7221673130989075, 11: 0.7866895198822021, 12: 0.8074938058853149, 13: 0.7841547727584839}
Token with most attention on average: Was
Token index with most attention on average: 0
MLM results for token "Was": 
{'sequence': '[CLS] get hired on to a new job today. hope it goes well. [SEP]', 'score': 0.5580984354019165, 'token': 2131}
{'sequence': '[CLS] got hired on to a new job today. hope it goes well. [SEP]', 'score': 0.2617955803871155, 'token': 2288}
{'sequence': '[CLS] getting hired on to a new job today. hope it goes well. [SEP]', 'score': 0.037400003522634506, 'token': 2893}
{'sequence': '[CLS] " hired on to a new job today. hope it goes well. [SEP]', 'score': 0.028088131919503212, 'token': 1000}
{'sequence': '[CLS] be h

Attention scores dictionary:  {0: 2.451568126678467, 1: 2.4593515396118164, 2: 1.011380910873413, 3: 0.8046233057975769, 4: 0.9342045187950134, 5: 1.0449730157852173, 6: 1.0444341897964478, 7: 1.1438586711883545, 8: 0.9918031096458435, 9: 0.742372989654541, 10: 0.7891494035720825, 11: 0.9352538585662842, 12: 0.8444757461547852, 13: 1.083548665046692, 14: 0.858732283115387, 15: 0.905475914478302, 16: 0.7715875506401062, 17: 1.296288251876831, 18: 0.804690957069397, 19: 0.6597017645835876, 20: 0.6982892751693726, 21: 1.0510209798812866, 22: 0.6705811619758606, 23: 0.7652428150177002, 24: 0.6990610957145691, 25: 0.8288096189498901, 26: 0.9379563927650452, 27: 1.0600669384002686, 28: 0.9595183730125427, 29: 1.1375473737716675, 30: 1.0491234064102173, 31: 0.5653067827224731}
Token with most attention on average: hired
Token index with most attention on average: 1
MLM results for token "hired": 
{'sequence': '[CLS] got it at best buy mobile last week & amp ; amp ; today i have an interview w

#### MLM with our fine-tuned ConvBERT model

In [23]:
for tweet_index in range(40):
    print('***********Tweet #{}***********'.format(str(tweet_index)))
    input_sequence = is_hired_1mo_df['text'][tweet_index]
    tokenized_tweet = tokenizer.tokenize(input_sequence)
    print('Tweet: {}'.format(input_sequence))
    #identify high attention token
    attention_results_dict = get_token_in_sequence_with_most_attention(model, tokenizer, input_sequence)
    attention_token_str = attention_results_dict['token_str']
    attention_token_index = attention_results_dict['token_index']
    print('Token with most attention on average: {}'.format(attention_token_str))
    print('Token index with most attention on average: {}'.format(attention_token_index))
    #do MLM
    ## replace high-attention token by a [MASK] token
    tokenized_tweet[attention_token_index] = '[MASK]'
    mlm_results_list = mlm_pipeline_custom(' '.join(tokenized_tweet))
    print('MLM results for token "{}": '.format(attention_results_dict['token_str']))
    for i in range(len(mlm_results_list)):
        print(mlm_results_list[i])

***********Tweet #0***********
Tweet: Got hired today!!!
Attention scores dictionary:  {0: 1.5126816034317017, 1: 1.4664392471313477, 2: 0.8879139423370361, 3: 0.6898782253265381, 4: 0.6860100626945496, 5: 0.7570769786834717}
Token with most attention on average: Got
Token index with most attention on average: 0
MLM results for token "Got": 
{'sequence': '[CLS]All hired today!!! [SEP]', 'score': 0.0008348205592483282, 'token': 27644}
{'sequence': '[CLS] needles hired today!!! [SEP]', 'score': 0.0006449052598327398, 'token': 20625}
{'sequence': '[CLS]blem hired today!!! [SEP]', 'score': 0.0006268700235523283, 'token': 24151}
{'sequence': '[CLS]row hired today!!! [SEP]', 'score': 0.0006059640436433256, 'token': 6607}
{'sequence': '[CLS]uno hired today!!! [SEP]', 'score': 0.0005076837260276079, 'token': 26761}
***********Tweet #1***********
Tweet: Just got hired at Google.
Attention scores dictionary:  {0: 1.4466649293899536, 1: 0.9768638610839844, 2: 1.231852650642395, 3: 0.6410866975784

Attention scores dictionary:  {0: 1.905438780784607, 1: 1.1128921508789062, 2: 1.3489856719970703, 3: 0.7385013699531555, 4: 0.6312955617904663, 5: 1.176804542541504, 6: 0.9598628878593445, 7: 0.8739089369773865, 8: 0.8465639352798462, 9: 0.9671778678894043, 10: 0.552645742893219, 11: 0.8859224319458008}
Token with most attention on average: Just
Token index with most attention on average: 0
MLM results for token "Just": 
{'sequence': '[CLS]All got hired in the spot. # Two # # Job # # s [UNK] [SEP]', 'score': 0.0008871526224538684, 'token': 27644}
{'sequence': '[CLS] needles got hired in the spot. # Two # # Job # # s [UNK] [SEP]', 'score': 0.000580411811824888, 'token': 20625}
{'sequence': '[CLS]blem got hired in the spot. # Two # # Job # # s [UNK] [SEP]', 'score': 0.0005397737841121852, 'token': 24151}
{'sequence': '[CLS]hari got hired in the spot. # Two # # Job # # s [UNK] [SEP]', 'score': 0.0005042810225859284, 'token': 16234}
{'sequence': '[CLS]row got hired in the spot. # Two # # 

MLM results for token "Got": 
{'sequence': '[CLS]All hired at the Gulf. # Yes # # Y # # es # # Y # # es [SEP]', 'score': 0.0006648480775766075, 'token': 27644}
{'sequence': '[CLS] needles hired at the Gulf. # Yes # # Y # # es # # Y # # es [SEP]', 'score': 0.0006135276053100824, 'token': 20625}
{'sequence': '[CLS]hari hired at the Gulf. # Yes # # Y # # es # # Y # # es [SEP]', 'score': 0.0005542254075407982, 'token': 16234}
{'sequence': '[CLS]row hired at the Gulf. # Yes # # Y # # es # # Y # # es [SEP]', 'score': 0.0005023052799515426, 'token': 6607}
{'sequence': '[CLS]blem hired at the Gulf. # Yes # # Y # # es # # Y # # es [SEP]', 'score': 0.00048550040810368955, 'token': 24151}
***********Tweet #18***********
Tweet: Just got hired at Olive Garden ✔️
Attention scores dictionary:  {0: 1.7673919200897217, 1: 0.9163364171981812, 2: 1.2997117042541504, 3: 0.6286437511444092, 4: 0.6396015882492065, 5: 0.6670652627944946, 6: 1.081249713897705}
Token with most attention on average: Just
Token 

MLM results for token "Yes": 
{'sequence': '[CLS]All sir!!! Just got hired for wings # # top!!! God is good! [SEP]', 'score': 0.0007019881159067154, 'token': 27644}
{'sequence': '[CLS] needles sir!!! Just got hired for wings # # top!!! God is good! [SEP]', 'score': 0.000612558564171195, 'token': 20625}
{'sequence': '[CLS]SL sir!!! Just got hired for wings # # top!!! God is good! [SEP]', 'score': 0.0005858248914591968, 'token': 13726}
{'sequence': '[CLS]row sir!!! Just got hired for wings # # top!!! God is good! [SEP]', 'score': 0.0005164553294889629, 'token': 6607}
{'sequence': '[CLS]rid sir!!! Just got hired for wings # # top!!! God is good! [SEP]', 'score': 0.0004862633941229433, 'token': 10132}
***********Tweet #25***********
Tweet: Got a new job today and started on the spot!  Doing some real work!!
Attention scores dictionary:  {0: 2.066157341003418, 1: 0.9446656107902527, 2: 1.0775426626205444, 3: 1.1994777917861938, 4: 0.8793065547943115, 5: 1.0466992855072021, 6: 1.280621528625

Attention scores dictionary:  {0: 1.8963210582733154, 1: 1.3329483270645142, 2: 1.022910237312317, 3: 0.9498006701469421, 4: 0.7221141457557678, 5: 1.0568573474884033, 6: 1.1116896867752075, 7: 0.9277344942092896, 8: 0.9968269467353821, 9: 0.8822914958000183, 10: 0.7221673130989075, 11: 0.7866895198822021, 12: 0.8074938058853149, 13: 0.7841547727584839}
Token with most attention on average: Was
Token index with most attention on average: 0
MLM results for token "Was": 
{'sequence': '[CLS]All hired on to a new job today. Hope it goes well. [SEP]', 'score': 0.0008517606183886528, 'token': 27644}
{'sequence': '[CLS]hari hired on to a new job today. Hope it goes well. [SEP]', 'score': 0.000688530330080539, 'token': 16234}
{'sequence': '[CLS] needles hired on to a new job today. Hope it goes well. [SEP]', 'score': 0.0005947520257905126, 'token': 20625}
{'sequence': '[CLS]blem hired on to a new job today. Hope it goes well. [SEP]', 'score': 0.0005583846941590309, 'token': 24151}
{'sequence':

Attention scores dictionary:  {0: 2.451568126678467, 1: 2.4593515396118164, 2: 1.011380910873413, 3: 0.8046233057975769, 4: 0.9342045187950134, 5: 1.0449730157852173, 6: 1.0444341897964478, 7: 1.1438586711883545, 8: 0.9918031096458435, 9: 0.742372989654541, 10: 0.7891494035720825, 11: 0.9352538585662842, 12: 0.8444757461547852, 13: 1.083548665046692, 14: 0.858732283115387, 15: 0.905475914478302, 16: 0.7715875506401062, 17: 1.296288251876831, 18: 0.804690957069397, 19: 0.6597017645835876, 20: 0.6982892751693726, 21: 1.0510209798812866, 22: 0.6705811619758606, 23: 0.7652428150177002, 24: 0.6990610957145691, 25: 0.8288096189498901, 26: 0.9379563927650452, 27: 1.0600669384002686, 28: 0.9595183730125427, 29: 1.1375473737716675, 30: 1.0491234064102173, 31: 0.5653067827224731}
Token with most attention on average: hired
Token index with most attention on average: 1
MLM results for token "hired": 
{'sequence': '[CLS] Gothaving at best buy mobile last week & amp ; amp ; today I have an intervie