In [1]:
from transformers import BertModel, BertConfig
from transformers import BertForMaskedLM
from transformers import BertTokenizer
from transformers import AdamW
from transformers import Trainer, TrainingArguments

import torch
import numpy as np
import spacy
import re

from datetime import datetime

In [2]:
# pos_list = ['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']
# ner_list = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART', 'LAW', 'LANGUAGE', 'DATE', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL']

In [3]:
# Sample English sentences for quick experiments (only sentences[0] is inspected further down).
sentences = ["When on board H.M.S. 'Beagle,' as naturalist, I was much struck with certain facts in the distribution of the inhabitants of South America, and in the geological relations of the present to the past inhabitants of that continent.",
 'These facts seemed to me to throw some light on the origin of species--that mystery of mysteries, as it has been called by one of our greatest philosophers.',
 'On my return home, it occurred to me, in 1837, that something might perhaps be made out on this question by patiently accumulating and reflecting on all sorts of facts which could possibly have any bearing on it.',
 "After five years' work I allowed myself to speculate on the subject, and drew up some short notes; these I enlarged in 1844 into a sketch of the conclusions, which then seemed to me probable: from that period to the present day I have steadily pursued the same object.",
 'I hope that I may be excused for entering on these personal details, as I give them to show that I have not been hasty in coming to a decision.']

In [4]:
# Uncased BERT WordPiece tokenizer plus spaCy's en_core_web_sm pipeline, which supplies
# the POS tags, lemmas and named entities used by the masking/augmentation code below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
nlp = spacy.load("en_core_web_sm")

In [5]:
# Maximum sequence length (in tokens) the tokenizer is configured for.
tokenizer.model_max_length

512

In [12]:
def _wordpiece_words(tokens):
    """Group WordPiece tokens into whole words.

    Returns a list of ``(word_text, token_positions)``. Continuation pieces
    (``"##xyz"``) are glued onto the preceding piece with exactly their ``"##"``
    prefix removed.
    """
    words = []
    for position, token in enumerate(tokens):
        if token.startswith("##") and words:
            word_text, positions = words[-1]
            words[-1] = (word_text + token[2:], positions + [position])
        else:
            words.append((token, [position]))
    return words


def _pos_masked_views(tokenizer, spacy_sentence, input_ids, input_tokens):
    """Return one decoded copy of the input per POS tag, with every word of that tag masked."""
    words = _wordpiece_words(input_tokens)
    views = []
    # dict.fromkeys: the distinct POS tags in order of first appearance.
    for posoi in dict.fromkeys(token.pos_ for token in spacy_sentence):
        # Matching is by lower-cased text, not by position: a word is masked wherever its
        # text occurs if any spaCy token with this POS has that text (e.g. "to" as ADP and PART).
        # spaCy tokens that BERT splits into several words (e.g. "o'clock") never match.
        # NOTE(review): the uncased BERT tokenizer may also strip accents, so accented words
        # may not match spaCy's text - TODO confirm.
        posoi_vocab = {token.text.lower() for token in spacy_sentence if token.pos_ == posoi}
        # Mask a word only as a whole: all of its word-pieces or none of them.
        mask_positions = {p for word_text, positions in words if word_text in posoi_vocab for p in positions}
        masked_ids = [tokenizer.mask_token_id if p in mask_positions else token_id
                      for p, token_id in enumerate(input_ids)]
        views.append(tokenizer.decode(masked_ids))
    return views


def _lemmatized_view(sequence, spacy_sentence):
    """Replace every token whose lemma differs from its lower-cased text by that lemma."""
    lemmatized = sequence
    for token in spacy_sentence:
        if token.text.lower() != token.lemma_:
            # re.escape: token text may contain regex metacharacters ("." "(" "+" ...).
            # Callable replacement: a backslash in a lemma is not read as an escape.
            lemmatized = re.sub(r'\b' + re.escape(token.text) + r'\b',
                                lambda _match, lemma=token.lemma_: lemma,
                                lemmatized)
    return lemmatized.replace("  ", " ")


def _time_swapped_view(spacy_sentence):
    """Reverse the word order inside every TIME entity (e.g. "5 o'clock" -> "o'clock 5")."""
    swapped = spacy_sentence.text
    for ent in spacy_sentence.ents:
        if ent.label_ == 'TIME':
            words = swapped[ent.start_char:ent.end_char].split(" ")
            # Splice at the entity's own offsets (str.replace would hit every occurrence).
            # Reversing words keeps the span length, so later entity offsets stay valid.
            swapped = swapped[:ent.start_char] + " ".join(reversed(words)) + swapped[ent.end_char:]
    return swapped


def whole_word_MO_tokenization_and_masking(tokenizer, nlp_model, sequence: str, verbose: bool = True):
    """
    Build several modified views of `sequence` and tokenize them for MLM training.

    Views, in this order:
      1. One whole-word-masked copy per POS tag present in the sentence
         (posoi: Part-Of-Speech of interest).
      2. A lemmatized copy, in which each spaCy token whose lemma differs from its
         lower-cased text is replaced by that lemma.
      3. A copy in which the words of every TIME entity are reversed.

    POS possibilities:['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ',
                        'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']

    Args:
        tokenizer: WordPiece (BERT-style) tokenizer; must expose `encode`,
            `convert_ids_to_tokens`, `decode`, `mask_token_id` and be callable.
        nlp_model: spaCy pipeline providing `pos_`, `lemma_` and `ents`.
        sequence: the raw input text.
        verbose: print progress/debug information (default True, as before).

    Returns:
        The tokenizer's encoding of all views (padded, truncated, PyTorch tensors) plus
        'labels': the original `sequence` tokenized once per view, padded/truncated to
        the same length as 'input_ids'. Label padding positions are -100 so the MLM
        loss ignores them.

    TODO: What if no tokens are masked?
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    log('loading:', datetime.now().time())
    spacy_sentence = nlp_model(sequence, disable=["parser"])

    input_ids = tokenizer.encode(sequence, add_special_tokens=False)
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    log(sequence)
    log(input_tokens)

    # POS-masking
    log('pos-start:', datetime.now().time())
    modified_input_list = _pos_masked_views(tokenizer, spacy_sentence, input_ids, input_tokens)

    # POS-based lemmatization
    pos_replaced_sentence = _lemmatized_view(sequence, spacy_sentence)
    log('Lemma', pos_replaced_sentence)
    modified_input_list.append(pos_replaced_sentence)

    # NER-based swapping of time-place (if present)
    log('ner-start:', datetime.now().time())
    ner_swapped_sentence = _time_swapped_view(spacy_sentence)
    log('NER', ner_swapped_sentence)
    modified_input_list.append(ner_swapped_sentence)

    # Actually tokenize the views.
    inputs = tokenizer(modified_input_list, return_tensors="pt", padding=True, truncation=True)

    # Views may tokenize to a different length than the original text (e.g. after
    # lemmatization). Labels must have exactly the same shape as input_ids, so pad or
    # truncate them to that length instead of padding them independently.
    seq_len = inputs['input_ids'].shape[1]
    label_encoding = tokenizer([sequence] * len(modified_input_list),
                               return_token_type_ids=False,
                               return_tensors='pt',
                               padding='max_length',
                               truncation=True,
                               max_length=seq_len)
    labels = label_encoding['input_ids']
    labels[label_encoding['attention_mask'] == 0] = -100  # ignored by the MLM loss
    inputs['labels'] = labels

    return inputs

In [13]:
# Time one call of the masking function on a short example sentence
# (wall-clock timestamps printed before and after).
print(datetime.now().time())
test_sentence = "Anne went to the Albert Heijn at 5 o'clock to buy some milk for me."
example_sentence_inputs = whole_word_MO_tokenization_and_masking(tokenizer=tokenizer, nlp_model=nlp, sequence=test_sentence)
print(datetime.now().time())

12:49:09.696654
loading: 12:49:09.696654
Anne went to the Albert Heijn at 5 o'clock to buy some milk for me.
['anne', 'went', 'to', 'the', 'albert', 'he', '##ij', '##n', 'at', '5', 'o', "'", 'clock', 'to', 'buy', 'some', 'milk', 'for', 'me', '.']
pos-start: 12:49:09.704615
['anne', 'albert', 'heijn']
heij
heijn
PROPN [MASK] went to the [MASK] [MASK] [MASK] [MASK] at 5 o'clock to buy some milk for me.
['went', 'buy']
heij
heijn
VERB anne [MASK] to the albert heijn at 5 o'clock to [MASK] some milk for me.
['to', 'at', 'for']
heij
heijn
ADP anne went [MASK] the albert heijn [MASK] 5 o'clock [MASK] buy some milk [MASK] me.
['the', 'some']
heij
heijn
DET anne went to [MASK] albert heijn at 5 o'clock to buy [MASK] milk for me.
['5']
heij
heijn
NUM anne went to the albert heijn at [MASK] o'clock to buy some milk for me.
["o'clock", 'milk']
heij
heijn
NOUN anne went to the albert heijn at 5 o'clock to buy some [MASK] for me.
['to']
heij
heijn
PART anne went [MASK] the albert heijn at 5 o'clock

In [11]:
# Show spaCy's coarse-grained POS tag for every token of the example sentence.
spacy_sentence = nlp("Anne went to the Albert Heijn at 5 o'clock to buy some milk for me.")

for token in spacy_sentence:
    print(token.text, token.pos_)


Anne PROPN
went VERB
to ADP
the DET
Albert PROPN
Heijn PROPN
at ADP
5 NUM
o'clock NOUN
to PART
buy VERB
some DET
milk NOUN
for ADP
me PRON
. PUNCT


In [8]:
# Display the encoding (input_ids, token_type_ids, attention_mask, labels) of all views.
example_sentence_inputs

{'input_ids': tensor([[  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,  4965,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,   103,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,   103,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,   103,  1996,  4789,  2002, 28418,  2078,   103,
          1019,  1051,  1005,  5119,   103,  4965,  2070,  6501,   103,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,   103,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,  4965,   103,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
           103,  1051,  1005,  5119,  2000,  4965,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  201

In [9]:
# Only the token ids: one row per modified view ([MASK] has id 103 in this vocabulary).
example_sentence_inputs['input_ids']

tensor([[  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,  4965,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,   103,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,   103,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,   103,  1996,  4789,  2002, 28418,  2078,   103,
          1019,  1051,  1005,  5119,   103,  4965,  2070,  6501,   103,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,   103,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,  4965,   103,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
           103,  1051,  1005,  5119,  2000,  4965,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
          1

In [10]:
# Longer multi-sentence passage used to exercise the masking function.
text= 'On their very first meeting, Gilbert had not been pleasantly impressed with Hardy. But he soon saw that the man had a certain rugged strength, and there was no doubt he had suffered from the depredations of Mexico\'s casual visitors, and was ready to protect not only his own interests but those of any newcomers. He seemed to have the spirit of fair-mindedness; and he believed firmly in the possibilities of this magic land, particularly for young men. "It\'s God\'s country," he told Gilbert on more than one occasion. "Get into the soil all you can. Dig--and dig deep."'

In [11]:
# Run on the longer passage; the returned encoding is shown as the cell output.
whole_word_MO_tokenization_and_masking(tokenizer=tokenizer, nlp_model=nlp, sequence=text)

loading: 18:18:05.687520
On their very first meeting, Gilbert had not been pleasantly impressed with Hardy. But he soon saw that the man had a certain rugged strength, and there was no doubt he had suffered from the depredations of Mexico's casual visitors, and was ready to protect not only his own interests but those of any newcomers. He seemed to have the spirit of fair-mindedness; and he believed firmly in the possibilities of this magic land, particularly for young men. "It's God's country," he told Gilbert on more than one occasion. "Get into the soil all you can. Dig--and dig deep."
['on', 'their', 'very', 'first', 'meeting', ',', 'gilbert', 'had', 'not', 'been', 'pleasantly', 'impressed', 'with', 'hardy', '.', 'but', 'he', 'soon', 'saw', 'that', 'the', 'man', 'had', 'a', 'certain', 'rugged', 'strength', ',', 'and', 'there', 'was', 'no', 'doubt', 'he', 'had', 'suffered', 'from', 'the', 'de', '##pre', '##dation', '##s', 'of', 'mexico', "'", 's', 'casual', 'visitors', ',', 'and', '

{'input_ids': tensor([[ 101,  103, 2037,  ..., 1012, 1000,  102],
        [ 101, 2006,  103,  ..., 1012, 1000,  102],
        [ 101, 2006, 2037,  ..., 1012, 1000,  102],
        ...,
        [ 101, 2006, 2037,  ..., 1012, 1000,  102],
        [ 101, 2006, 2037,  ...,    0,    0,    0],
        [ 101, 2006, 2037,  ..., 1012, 1000,  102]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[ 101, 2006, 2037,  ..., 1012, 1000,  102],
        [ 101, 2006, 2037,  ..., 1012, 1000,  102],
        [ 101, 2006, 2037,  ..., 1012, 1000,  102],
        ...,
      

In [12]:
# Passage with missing spaces after sentence ends ("Bisbee.I") and many contractions.
text2 = '"Sturgis telegraphed me that there was a big possibility of a new vein of oil down on the border," Pell was telling her. "Some important men want to talk things over with me at Bisbee.I want to get started in a day or two.Don\'t take your maid.It\'s a rough country, but you\'ll be all right.Just old clothes.You can ride a lot, so bring your habit.I\'ll be busy most of the time; but I think you\'ll like the trip.Never been down that way, have you?"'
whole_word_MO_tokenization_and_masking(tokenizer=tokenizer, nlp_model=nlp, sequence=text2)

loading: 18:18:05.795231
"Sturgis telegraphed me that there was a big possibility of a new vein of oil down on the border," Pell was telling her. "Some important men want to talk things over with me at Bisbee.I want to get started in a day or two.Don't take your maid.It's a rough country, but you'll be all right.Just old clothes.You can ride a lot, so bring your habit.I'll be busy most of the time; but I think you'll like the trip.Never been down that way, have you?"
['"', 'stu', '##rg', '##is', 'telegraph', '##ed', 'me', 'that', 'there', 'was', 'a', 'big', 'possibility', 'of', 'a', 'new', 'vein', 'of', 'oil', 'down', 'on', 'the', 'border', ',', '"', 'pe', '##ll', 'was', 'telling', 'her', '.', '"', 'some', 'important', 'men', 'want', 'to', 'talk', 'things', 'over', 'with', 'me', 'at', 'bis', '##bee', '.', 'i', 'want', 'to', 'get', 'started', 'in', 'a', 'day', 'or', 'two', '.', 'don', "'", 't', 'take', 'your', 'maid', '.', 'it', "'", 's', 'a', 'rough', 'country', ',', 'but', 'you', "'",

{'input_ids': tensor([[  101,   103, 24646,  ...,   103,   103,   102],
        [  101,  1000, 24646,  ...,  1029,  1000,   102],
        [  101,  1000, 24646,  ...,  1029,  1000,   102],
        ...,
        [  101,  1000, 24646,  ...,  1029,  1000,   102],
        [  101,  1000, 24646,  ...,  1000,   102,     0],
        [  101,  1000, 24646,  ...,  1029,  1000,   102]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[  101,  1000, 24646,  ...,  1029,  1000,   102],
        [  101,  1000, 24646,  ...,  1029,  1000,   102],
        [  101,  1000, 

In [13]:
# Seed torch so the random weight initialisation of the tiny model is reproducible.
SEED = 42
torch.manual_seed(SEED)

# A very small BERT (2 layers, 2 heads, hidden size 128) for fast experiments,
# sharing the 30522-token vocabulary of bert-base-uncased.
bert_tiny_config = {"hidden_size": 128,
                    "hidden_act": "gelu",
                    "initializer_range": 0.02,
                    "vocab_size": 30522,
                    "hidden_dropout_prob": 0.1,
                    "num_attention_heads": 2,
                    "type_vocab_size": 2,
                    "max_position_embeddings": 512,
                    "num_hidden_layers": 2,
                    "intermediate_size": 512,
                    "attention_probs_dropout_prob": 0.1}


model = BertForMaskedLM(config=BertConfig(**bert_tiny_config))
model.train()
# torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay and eps are
# passed explicitly to keep transformers.AdamW's defaults (torch defaults: 1e-2 and 1e-8).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.0, eps=1e-6)

In [14]:
# One forward pass; because the encoding contains 'labels', the output also carries the MLM loss.
outputs = model(**example_sentence_inputs, return_dict=True)

In [15]:
# Display the MaskedLMOutput (loss and per-token vocabulary logits).
outputs

MaskedLMOutput(loss=tensor(10.3635, grad_fn=<NllLossBackward>), logits=tensor([[[ 0.4080,  0.2334, -0.1340,  ..., -0.0911,  0.3205, -0.1181],
         [ 0.1753,  0.1798, -0.0103,  ...,  0.0868,  0.2404, -0.0679],
         [ 0.3360, -0.2163,  0.2287,  ...,  0.1449,  0.2034, -0.3322],
         ...,
         [ 0.2280, -0.2229, -0.1290,  ..., -0.0234, -0.0660, -0.0227],
         [ 0.5470, -0.1972,  0.1615,  ...,  0.0868, -0.0703,  0.2246],
         [ 0.3038,  0.0603, -0.1612,  ..., -0.1186,  0.2610,  0.3539]],

        [[ 0.3188,  0.1850, -0.0963,  ..., -0.0993,  0.2571, -0.3354],
         [ 0.2630,  0.1995,  0.1957,  ...,  0.0708,  0.2714,  0.0073],
         [ 0.2004, -0.1573,  0.3025,  ...,  0.1859,  0.2579, -0.1575],
         ...,
         [ 0.3260, -0.3047, -0.0999,  ...,  0.0763,  0.0275, -0.0704],
         [ 0.5253, -0.3192,  0.1256,  ...,  0.0444, -0.0655,  0.4057],
         [ 0.3130, -0.0074, -0.1666,  ...,  0.0190,  0.3194,  0.3130]],

        [[ 0.3328,  0.2426, -0.1400,  ..., -0

In [16]:
# Scalar MLM loss from the forward pass.
loss = outputs.loss

In [17]:
# Backpropagate the loss to populate the parameters' .grad tensors.
loss.backward()

In [18]:
# Apply one optimizer update, then clear the gradients so that re-running the
# backward/step cells does not accumulate gradients from earlier passes.
optimizer.step()
optimizer.zero_grad()

In [19]:
# Forward pass again after one update, to compare the loss with the first pass.
model(**example_sentence_inputs, return_dict=True)

MaskedLMOutput(loss=tensor(10.3671, grad_fn=<NllLossBackward>), logits=tensor([[[ 0.2816,  0.2122, -0.0551,  ..., -0.2080,  0.1901, -0.2705],
         [-0.0213,  0.0675,  0.1182,  ...,  0.0587,  0.4813,  0.0793],
         [ 0.3258, -0.1267,  0.1164,  ..., -0.1771,  0.3251, -0.2187],
         ...,
         [ 0.2462, -0.1275, -0.2061,  ...,  0.0215,  0.0755, -0.1286],
         [ 0.5200, -0.2259,  0.2210,  ..., -0.0152, -0.0605,  0.2946],
         [ 0.3501, -0.0105, -0.1718,  ..., -0.0210,  0.2615,  0.3812]],

        [[ 0.2275,  0.1725, -0.0553,  ..., -0.1587,  0.2556, -0.1112],
         [ 0.1502,  0.0833,  0.0923,  ...,  0.1065,  0.3873,  0.0937],
         [ 0.2265, -0.1045,  0.4209,  ..., -0.0535,  0.1077, -0.3624],
         ...,
         [ 0.3259, -0.1775, -0.1235,  ..., -0.1407,  0.0281, -0.0293],
         [ 0.3989, -0.2696,  0.3612,  ..., -0.0359, -0.1130,  0.1344],
         [ 0.3135,  0.1012, -0.0567,  ..., -0.2063,  0.1898,  0.3347]],

        [[ 0.4600,  0.2001, -0.1444,  ..., -0

In [20]:
class MODataset(torch.utils.data.Dataset):
    """Map-style dataset over a tokenizer encoding that includes a 'labels' entry.

    Example ``idx`` is row ``idx`` of every non-label entry (input_ids,
    attention_mask, ...) followed by row ``idx`` of the labels.
    """

    def __init__(self, encodings):
        # Separate the targets from the model inputs; raises KeyError if 'labels' is absent.
        features = dict(encodings.items())
        self.labels = features.pop('labels')
        self.encodings = features

    def __getitem__(self, idx):
        names = list(self.encodings) + ['labels']
        columns = list(self.encodings.values()) + [self.labels]
        return dict(zip(names, (column[idx] for column in columns)))

    def __len__(self):
        # One example per label row.
        return len(self.labels)

# Wrap the example sentence's views as a Dataset that the Trainer can iterate over.
train_dataset = MODataset(example_sentence_inputs)
train_dataset

<__main__.MODataset at 0x220e4b25fd0>

In [21]:
# Hugging Face Trainer set-up. The dataset holds only the handful of views of one sentence,
# so with a per-device batch size of 256 each epoch is a single optimisation step
# (see global_step=3 for 3 epochs in the training output).
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total # of training epochs
    per_device_train_batch_size=256,  # batch size per device during training
    per_device_eval_batch_size=256,   # batch size for evaluation
    learning_rate=1e-5,     
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=None            # evaluation dataset
)

In [22]:
# Run training; returns a TrainOutput with global_step and training_loss.
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1.0, style=ProgressStyle(description_widt…





TrainOutput(global_step=3, training_loss=10.356541315714518)

Custom tokenizer
=====================

In [91]:
class StrategizedTokenizer(object):
    """Tokenizer that turns one text into several "strategized" training views.

    Depending on the constructor flags, `tokenize` builds:
      * one whole-word-masked copy per POS tag occurring in the text,
      * a lemmatized copy,
      * a copy in which the words of every TIME entity are reversed,
    and tokenizes them with the bert-base-uncased WordPiece tokenizer. The
    unmodified text is used as the label sequence for every view.

    ==Not guaranteed to work on cased vocabularies==
    """

    def __init__(self, pos_based_mask=True, lemmatize=True, ner_based_swap=True, verbose=True):
        """
        Constructs the strategized Tokenizer.
        Loads the required spacy model (en_core_web_sm) and the bert-base-uncased tokenizer.

        Args:
            pos_based_mask: add one whole-word-masked view per POS tag.
            lemmatize: add a lemmatized view.
            ner_based_swap: add a view with the words of TIME entities reversed.
            verbose: print every generated view in `tokenize` (default True, as before).
        """
        self.nlp = spacy.load("en_core_web_sm")
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.pos_based_mask = pos_based_mask
        self.lemmatize = lemmatize
        self.ner_based_swap = ner_based_swap
        self.verbose = verbose

    def tokenize(self, text):
        """Build the enabled views of `text` and tokenize them.

        Returns:
            The tokenizer's encoding of all views (padded, truncated, PyTorch tensors,
            no token_type_ids) plus 'labels': `text` tokenized once per view and
            padded/truncated to the same length as 'input_ids'. Label padding
            positions are -100 so the MLM loss ignores them.

        Raises:
            ValueError: if every strategy is disabled (there is nothing to tokenize).
        """
        # Use this instance's spaCy pipeline (not a notebook-global `nlp`).
        spacy_sentence = self.nlp(text, disable=['parser'])

        processed_text_list = []
        if self.pos_based_mask:
            processed_text_list += self.mask_text_pos_based(text, spacy_sentence)
        if self.lemmatize:
            processed_text_list += self.lemmatize_text(text, spacy_sentence)
        if self.ner_based_swap:
            processed_text_list += self.ner_swap_text(text, spacy_sentence)
        #TODO add more?
        if not processed_text_list:
            raise ValueError("StrategizedTokenizer: every strategy is disabled, nothing to tokenize")

        if self.verbose:
            for x in processed_text_list:
                print(x)

        inputs = self.tokenizer(processed_text_list,
                                return_token_type_ids=False,  # not needed: no NSP objective
                                return_tensors="pt",
                                padding=True,
                                truncation=True)
        # Views can tokenize to a different length than `text` (e.g. after lemmatization).
        # Labels must match input_ids exactly, so pad/truncate them to that length.
        seq_len = inputs['input_ids'].shape[1]
        label_encoding = self.tokenizer([text] * len(processed_text_list),
                                        return_token_type_ids=False,
                                        return_tensors='pt',
                                        padding='max_length',
                                        truncation=True,
                                        max_length=seq_len)
        labels = label_encoding['input_ids']
        labels[label_encoding['attention_mask'] == 0] = -100  # ignored by the MLM loss
        inputs['labels'] = labels
        return inputs

    @staticmethod
    def _group_wordpieces(tokens):
        """Group WordPiece tokens into ``(word_text, token_positions)`` whole words."""
        words = []
        for position, token in enumerate(tokens):
            if token.startswith("##") and words:
                word_text, positions = words[-1]
                # Remove exactly the "##" prefix (str.strip would eat any '#').
                words[-1] = (word_text + token[2:], positions + [position])
            else:
                words.append((token, [position]))
        return words

    def mask_text_pos_based(self, text, spacy_sentence) -> list:
        """Return one decoded copy of `text` per POS tag, with every word of that tag masked.

        Matching is by lower-cased text, so a word is masked wherever it occurs if any
        spaCy token with that POS has the same text. spaCy tokens that BERT splits into
        several words (e.g. "o'clock") never match.
        """
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        words = self._group_wordpieces(self.tokenizer.convert_ids_to_tokens(input_ids))
        mask_id = self.tokenizer.mask_token_id

        pos_masks = []
        # dict.fromkeys: the distinct POS tags in order of first appearance.
        for posoi in dict.fromkeys(token.pos_ for token in spacy_sentence):
            posoi_vocab = {token.text.lower() for token in spacy_sentence if token.pos_ == posoi}
            # Whole-word masking: all word-pieces of a matching word, never a lone piece.
            mask_positions = {p for word_text, positions in words if word_text in posoi_vocab for p in positions}
            masked_ids = [mask_id if p in mask_positions else token_id
                          for p, token_id in enumerate(input_ids)]
            pos_masks.append(self.tokenizer.decode(masked_ids))

        return pos_masks

    def lemmatize_text(self, text, spacy_sentence) -> list:
        """Return `text` with every token whose lemma differs from its lower-cased text replaced by that lemma."""
        lemmatized_text = text
        for token in spacy_sentence:
            if token.text.lower() != token.lemma_:
                # re.escape: token text may contain regex metacharacters ("." "(" "+" ...).
                # Callable replacement: a backslash in a lemma is not read as an escape.
                lemmatized_text = re.sub(r'\b' + re.escape(token.text) + r'\b',
                                         lambda _match, lemma=token.lemma_: lemma,
                                         lemmatized_text)

        lemmatized_text = lemmatized_text.replace("  ", " ")
        return [lemmatized_text]

    def ner_swap_text(self, text, spacy_sentence) -> list:
        """Return the text with the word order inside every TIME entity reversed."""
        ner_swapped_text = spacy_sentence.text
        for ent in spacy_sentence.ents:
            if ent.label_ == 'TIME':
                words = ner_swapped_text[ent.start_char:ent.end_char].split(" ")
                # Splice at the entity's own offsets (str.replace would hit every occurrence).
                # Reversing words keeps the span length, so later entity offsets stay valid.
                ner_swapped_text = (ner_swapped_text[:ent.start_char]
                                    + " ".join(reversed(words))
                                    + ner_swapped_text[ent.end_char:])
            #TODO add other possible ideas
        return [ner_swapped_text]

    def convert_ids_to_tokens(self, input_ids):
        """Convert each row of a batch of token ids back to token strings."""
        return [self.tokenizer.convert_ids_to_tokens(row) for row in input_ids]

In [92]:
# All strategies enabled by default; loads en_core_web_sm and bert-base-uncased.
ST_tokenizer = StrategizedTokenizer()

In [93]:
# Build every strategized view of the example sentence and display the resulting encoding.
inputs = ST_tokenizer.tokenize("Anne went to the Albert Heijn at 5 o'clock to buy some milk for me.")
inputs

[MASK] went to the [MASK] [MASK] [MASK] [MASK] at 5 o'clock to buy some milk for me.
anne [MASK] to the albert heijn at 5 o'clock to [MASK] some milk for me.
anne went [MASK] the albert heijn [MASK] 5 o'clock [MASK] buy some milk [MASK] me.
anne went to [MASK] albert heijn at 5 o'clock to buy [MASK] milk for me.
anne went to the albert heijn at [MASK] o'clock to buy some milk for me.
anne went to the albert heijn at 5 o'clock to buy some [MASK] for me.
anne went [MASK] the albert heijn at 5 o'clock [MASK] buy some milk for me.
anne went to the albert heijn at 5 o'clock to buy some milk for [MASK].
anne went to the albert heijn at 5 o'clock to buy some milk for me [MASK]
Anne go to the Albert Heijn at 5 o'clock to buy some milk for I.
Anne went to the Albert Heijn at o'clock 5 to buy some milk for me.


{'input_ids': tensor([[  101,   103,  2253,  2000,  1996,   103,   103,   103,   103,  2012,
          1019,  1051,  1005,  5119,  2000,  4965,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,   103,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,   103,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,   103,  1996,  4789,  2002, 28418,  2078,   103,
          1019,  1051,  1005,  5119,   103,  4965,  2070,  6501,   103,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,   103,  4789,  2002, 28418,  2078,  2012,
          1019,  1051,  1005,  5119,  2000,  4965,   103,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  2012,
           103,  1051,  1005,  5119,  2000,  4965,  2070,  6501,  2005,  2033,
          1012,   102],
        [  101,  4776,  2253,  2000,  1996,  4789,  2002, 28418,  2078,  201

In [72]:
# Scratch check that `+` concatenates lists, as relied on when accumulating views with `+=`.
[] + [1,2,3] + [4] + [5]

[1, 2, 3, 4, 5]

Optimizer settings and spaCy POS/lemma checks
====================

In [23]:
# Optimizer settings from the paper (BERT, Devlin et al. 2019):
# lr: 1e-4
# Beta1 = 0.9 (default)
# Beta2 = 0.999 (default)
# L2 weight decay = 0.01

# Longer sequences are disproportionately expensive
# because attention is quadratic to the sequence
# length. To speed up pretraining in our experiments,
# we pre-train the model with sequence length of
# 128 for 90% of the steps. Then, we train the rest
# 10% of the steps of sequence of 512 to learn the
# positional embeddings.

#Batch size 256 for 1e6 steps

# torch.optim.AdamW replaces the deprecated transformers.AdamW (no re-import needed).
# eps=1e-6 keeps the value transformers.AdamW used by default (torch's default is 1e-8).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01)

In [24]:
# First sample sentence.
sentences[0]

"When on board H.M.S. 'Beagle,' as naturalist, I was much struck with certain facts in the distribution of the inhabitants of South America, and in the geological relations of the present to the past inhabitants of that continent."

In [25]:
# POS tag and lemma of every token in the first sample sentence
# (reuses sentences[0] instead of repeating the same literal).
doc = nlp(sentences[0])

for token in doc:
    print(token.text, token.pos_, token.lemma_)

When ADV when
on ADP on
board NOUN board
H.M.S. PROPN H.M.S.
' PUNCT '
Beagle PROPN Beagle
, PUNCT ,
' PUNCT '
as ADP as
naturalist ADJ naturalist
, PUNCT ,
I PRON I
was AUX be
much ADV much
struck VERB strike
with ADP with
certain ADJ certain
facts NOUN fact
in ADP in
the DET the
distribution NOUN distribution
of ADP of
the DET the
inhabitants NOUN inhabitant
of ADP of
South PROPN South
America PROPN America
, PUNCT ,
and CCONJ and
in ADP in
the DET the
geological ADJ geological
relations NOUN relation
of ADP of
the DET the
present NOUN present
to ADP to
the DET the
past ADJ past
inhabitants NOUN inhabitant
of ADP of
that DET that
continent NOUN continent
. PUNCT .


In [26]:
# POS tags and lemmas for a few more varied sentences (auxiliaries, irregular verbs, plurals).
doc = nlp("San Francisco is a long drive away from here. Ah, I forgot what I was doing. He had to get a new pair of shoes.")

for token in doc:
    print(token.text, token.pos_, token.lemma_)

San PROPN San
Francisco PROPN Francisco
is AUX be
a DET a
long ADJ long
drive NOUN drive
away ADV away
from ADP from
here ADV here
. PUNCT .
Ah INTJ ah
, PUNCT ,
I PRON I
forgot VERB forget
what PRON what
I PRON I
was AUX be
doing VERB do
. PUNCT .
He PRON he
had VERB have
to PART to
get VERB get
a DET a
new ADJ new
pair NOUN pair
of ADP of
shoes NOUN shoe
. PUNCT .
