In [366]:
import torch
from transformers import *
import numpy as np
import string

from num2words import num2words

from tqdm.notebook import tqdm

In [2]:
model_name = 'roberta-large'

model = RobertaForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1425941629.0, style=ProgressStyle(descr…




RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((

In [3]:
file_name = '../data/en_ewt-ud-train.conllu'

with open(file_name, 'r') as f:
    lines = f.readlines()

In [4]:
len(lines)

245245

In [5]:
lines[4].split()

['1',
 'Al',
 'Al',
 'PROPN',
 'NNP',
 'Number=Sing',
 '0',
 'root',
 '0:root',
 'SpaceAfter=No']

In [270]:
number_inflections = {}
tense_inflection = {}
all_words = {}

for line in tqdm(lines):
    if line.strip() == '':
        continue
    if line.startswith('#'):
        continue
#     print(line)
    parts = line.split()
    
    all_words[parts[2]] = parts[2]
    
    lemma = parts[2]
    if parts[3] == 'NOUN':
        morph = parts[5]
        morph_parts = morph.split('|')
        for morph_val in morph_parts:
            if morph_val.startswith('Number=Plur'):
                number_inflections[lemma] = parts[1]
    elif parts[3] == 'VERB':
        morphs = parts[5]
#         print(morphs)
        morphs_parts = morphs.split('|')
        for morph_val in morphs_parts:
#             print(morph_val)
            if morph_val.startswith('Tense=Past'):
                tense_inflection[lemma] = parts[1]

HBox(children=(FloatProgress(value=0.0, max=245245.0), HTML(value='')))




In [242]:
len(tense_inflection)

1154

## Word Prediction By Transformers

In [283]:
def sentences2ids(sentences):
    tokenized_sentences = []
    
    for sentence in sentences:
        prefix, suffix = sentence.split("[MASK]")
        prefix_tokens = tokenizer.tokenize(prefix)
        suffix_tokens = tokenizer.tokenize(suffix)
        tokens = [tokenizer.cls_token] + prefix_tokens + [tokenizer.mask_token] + suffix_tokens + [tokenizer.sep_token]
        input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
        target_idx = len(prefix_tokens) + 1
        tokenized_sentences.append(input_ids)
    
    return torch.cat(tokenized_sentences), target_idx

def get_predictions(sentences, k=10):
    
    tokenized_sentences, target_idx = sentences2ids(sentences)
#     print(tokenized_sentences)
    prediction_scores = model(tokenized_sentences)[0]
#     print(prediction_scores.shape)
    token_scores = prediction_scores[:, target_idx].detach().cpu().numpy()
    best_k = (np.argsort(token_scores, axis=1))[:, -k:]
#     print(best_k)
    sentences_best_k = []
    for top_per_sentence in best_k:
        best_k_tokens = tokenizer.convert_ids_to_tokens(top_per_sentence)
        best_k_tokens = [x.replace('Ġ', '') for x in best_k_tokens]
        sentences_best_k.append(best_k_tokens[::-1])
    return sentences_best_k

In [176]:
query1 = 'Conjugate the word "face" to plural form: [MASK].'
query2 = 'Conjugate the word "kid" to plural form: [MASK].'

## Data Prep

### Filter Data to Single Tokens

In [324]:
def filter_vals(in_dic):

    filter_inflections = {}

    for k, v in tqdm(in_dic.items()):
        tok_v = tokenizer.tokenize(v)
        tok_k = tokenizer.tokenize(k)

        if len(k) == 1:
            continue
        if len(tok_v) != 1 or len(tok_k) != 1:
            continue
        
        if not all([x in (string.ascii_lowercase + string.ascii_uppercase) for x in k]):
#             print(k)
            continue

        filter_inflections[k] = v
        
    return filter_inflections

In [252]:
filter_inflections = filter_vals(number_inflections)

HBox(children=(FloatProgress(value=0.0, max=2100.0), HTML(value='')))




In [211]:
tokenizer.tokenize('"$"')

['"', '$', '"']

In [216]:
len(filter_inflections)

246

In [147]:
len(inflections)

2100

## Conjugate to plural

In [226]:
singular_plural_query = 'Conjugate the word "{}" to plural form: [MASK].'

In [378]:
def eval_query(vals_dic, query, debug=False, bs=50, ignore_special_tokens=False):
    acc = 0.
    total_vals = len(vals_dic)

    batch_queries = []
    batch_answers = []
    for lemma, plural in tqdm(vals_dic.items()):
        if len(batch_queries) < bs:
            batch_queries.append(query.format(lemma))
            batch_answers.append(plural)
            continue

        if debug:
            print(batch_queries)
        try:
            top_k_per_ex = get_predictions(batch_queries)
        except RuntimeError:
            total_vals -= bs
            batch_queries = []
            batch_answers = []
            print('issue')
            continue

        for top_k, y in zip(top_k_per_ex, batch_answers):
            if top_k[0] == y:
                acc += 1
#             else:
#                 print(y, top_k[:2])
            if top_k[0] in tokenizer.special_tokens_map.values():
                i = 1
                while True:
                    if top_k[i] in tokenizer.special_tokens_map.values():
                        continue
                    else:
                        if top_k[i] == y:
                            acc += 1
                        else:
                            break
                    i += 1
        batch_queries = []
        batch_answers = []

    if len(batch_queries) != 0:
        total_vals -= len(batch_queries)
        
#         top_k_per_ex = get_predictions(batch_queries)
#         for top_k, y in zip(top_k_per_ex, batch_answers):
#             if top_k[0] == y:
#                 acc += 1
    
    return acc / total_vals

In [225]:
eval_query(filter_inflections, singular_plural_query)

HBox(children=(FloatProgress(value=0.0, max=246.0), HTML(value='')))




0.5040650406504065

In [228]:
plural_singular_dict = {v: k for k, v in filter_inflections.items()}
plural_singular_query = 'Conjugate the word "{}" to singular form: [MASK].'

In [229]:
eval_query(plural_singular_dict, plural_singular_query)

HBox(children=(FloatProgress(value=0.0, max=245.0), HTML(value='')))




0.05714285714285714

In [253]:
filter_tense_inflections = filter_vals(tense_inflection)

HBox(children=(FloatProgress(value=0.0, max=1154.0), HTML(value='')))




In [254]:
len(filter_tense_inflections)

172

In [263]:
tense_query = 'Conjugate the word "{}" to past tense: [MASK].'
eval_query(filter_tense_inflections, tense_query)

HBox(children=(FloatProgress(value=0.0, max=172.0), HTML(value='')))




0.12209302325581395

In [260]:
past_present_dict = {v: k for k, v in filter_tense_inflections.items()}
plural_singular_query = 'Conjugate the word "{}" to present form: [MASK].'

eval_query(past_present_dict, plural_singular_query)

HBox(children=(FloatProgress(value=0.0, max=171.0), HTML(value='')))




0.09941520467836257

## Sentiment

In [None]:
'Flip the sentiment of the word "joy" to the negative sense: [MASK].'
# need to find a dataset

## Capitalization

In [273]:
filter_all_words

{'Al': 'Al',
 'force': 'force',
 'kill': 'kill',
 'al': 'al',
 'the': 'the',
 'at': 'at',
 'in': 'in',
 'town': 'town',
 'of': 'of',
 'near': 'near',
 'border': 'border',
 'this': 'this',
 'killing': 'killing',
 'respected': 'respected',
 'will': 'will',
 'be': 'be',
 'cause': 'cause',
 'we': 'we',
 'for': 'for',
 'year': 'year',
 'to': 'to',
 'come': 'come',
 'that': 'that',
 'they': 'they',
 'have': 'have',
 'up': 'up',
 'terrorist': 'terrorist',
 'cell': 'cell',
 'two': 'two',
 'run': 'run',
 'by': 'by',
 'official': 'official',
 'Iraq': 'Iraq',
 'US': 'US',
 'FBI': 'FBI',
 'so': 'so',
 'would': 'would',
 'like': 'like',
 'employ': 'employ',
 'high': 'high',
 'level': 'level',
 'member': 'member',
 'back': 'back',
 '1960': '1960',
 'third': 'third',
 'head': 'head',
 'you': 'you',
 'if': 'if',
 'he': 'he',
 'market': 'market',
 'with': 'with',
 'target': 'target',
 'and': 'and',
 'capital': 'capital',
 'although': 'although',
 'probably': 'probably',
 'make': 'make',
 'show': 'show'

In [295]:
capitalized_words_dic = {}
for k in all_words.keys():
    capitalized_words_dic[k.lower()] = k[0].upper() + k[1:]

In [325]:
filter_capitalized_words = filter_vals(capitalized_words_dic)

HBox(children=(FloatProgress(value=0.0, max=13632.0), HTML(value='')))




In [365]:
capitalize_query = 'Capitalize the word " {} ": [MASK].'
eval_query(filter_capitalized_words, capitalize_query, debug=False, bs=10, ignore_special_tokens=False)

HBox(children=(FloatProgress(value=0.0, max=1582.0), HTML(value='')))




0.20120378331900257

## Numerics

In [371]:
numerics_dict = {}

for i in range(101):
    numerics_dict[num2words(i)] = str(i)

In [370]:
numerics_dict

{'zero': 0,
 'one': 1,
 'two': 2,
 'three': 3,
 'four': 4,
 'five': 5,
 'six': 6,
 'seven': 7,
 'eight': 8,
 'nine': 9,
 'ten': 10,
 'eleven': 11,
 'twelve': 12,
 'thirteen': 13,
 'fourteen': 14,
 'fifteen': 15,
 'sixteen': 16,
 'seventeen': 17,
 'eighteen': 18,
 'nineteen': 19,
 'twenty': 20,
 'twenty-one': 21,
 'twenty-two': 22,
 'twenty-three': 23,
 'twenty-four': 24,
 'twenty-five': 25,
 'twenty-six': 26,
 'twenty-seven': 27,
 'twenty-eight': 28,
 'twenty-nine': 29,
 'thirty': 30,
 'thirty-one': 31,
 'thirty-two': 32,
 'thirty-three': 33,
 'thirty-four': 34,
 'thirty-five': 35,
 'thirty-six': 36,
 'thirty-seven': 37,
 'thirty-eight': 38,
 'thirty-nine': 39,
 'forty': 40,
 'forty-one': 41,
 'forty-two': 42,
 'forty-three': 43,
 'forty-four': 44,
 'forty-five': 45,
 'forty-six': 46,
 'forty-seven': 47,
 'forty-eight': 48,
 'forty-nine': 49,
 'fifty': 50,
 'fifty-one': 51,
 'fifty-two': 52,
 'fifty-three': 53,
 'fifty-four': 54,
 'fifty-five': 55,
 'fifty-six': 56,
 'fifty-seven': 57,

In [380]:
numeric_query = 'The numeric version of "{}" is: [MASK].'
eval_query(numerics_dict, numeric_query, debug=False, bs=1, ignore_special_tokens=False)

HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))




0.41

## Truncate

In [387]:
truncate_words_dic = {}
for k in all_words.keys():
    truncate_words_dic[k.lower()] = k[0].lower()
    
#     if len(truncate_words_dic) > 200:
#         break

In [388]:
truncate_query = 'Truncate the word " {} " into a the first letter: [MASK].'
eval_query(truncate_words_dic, truncate_query, debug=False, bs=1, ignore_special_tokens=False)

HBox(children=(FloatProgress(value=0.0, max=13632.0), HTML(value='')))




0.1580105633802817

In [390]:
truncate_query = 'Truncate the word " {} " into a the first letter: [MASK].'
eval_query(truncate_words_dic, truncate_query, debug=False, bs=1, ignore_special_tokens=True)

HBox(children=(FloatProgress(value=0.0, max=13632.0), HTML(value='')))




0.1580105633802817

Extracting second letter didn't succeed. maybe with fine tunings?


## A word that starts with a letter

In [None]:
'A word which Starts with " l " is : [MASK].'
'Generate a word that starts with the letter " l " : [MASK].'