In [2]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import re
import numpy as np
from transformers import BertTokenizer, BertForMaskedLM
from random import random

# from bow_strategy import get_bow
from datasets.bow_strategy import get_bow

In [3]:
class QuoraBertMaskPredictDataset(Dataset):
    """Quora sentence pairs whose source side is enriched with BERT mask predictions.

    Every line of ``text_path`` holds ``seq1<TAB>seq2``.  For an item, each word
    piece of ``seq1`` is masked in turn, ``BertForMaskedLM`` scores the masked
    position, and ``get_bow`` turns those scores into a bag of predicted word-piece
    ids (BOW).  How the BOW is combined with ``seq1`` is controlled by the flags
    below; ``seq2`` is always returned as its ``[CLS] ... [SEP]`` id tensor.

    Args:
        mode: ``"train"``, ``"val"`` or ``"test"`` -- which slice of the shuffled
            file this instance serves.
        train_size, val_size, test_size: lengths of the three consecutive slices.
        text_path: tab-separated sentence-pair file.
        pretrained_model_name: checkpoint name used for both tokenizer and masked LM.
        topk, indiv_topk: forwarded to ``get_bow``.
        bow_strategy: ``get_bow`` strategy name (simple_sum, mask_sum, indiv_topk,
            indiv_topp, indiv_neighbors).
        indiv_topp: stored on the instance; NOTE(review): it is not forwarded to
            ``get_bow`` here -- confirm whether the indiv_topp strategy needs it.
        only_bow: return the BOW instead of ``seq1`` + BOW.
        use_origin: with ``only_bow``, merge the original word pieces into the BOW.
        replace_predict: randomly replace word pieces of ``seq1`` by samples from
            the masked-LM distribution.
        append_bow: with ``replace_predict``, append sorted original + BOW ids.
        replace_p: per-position replacement probability for ``replace_predict``.
        seed: seed of the line shuffle in ``read_text``.  With a fixed seed the
            train/val/test instances share one permutation, so their slices are
            disjoint.  ``None`` gives a fresh, non-reproducible shuffle per
            instance (slices of different instances may then overlap).
    """

    def __init__(self, mode, train_size=5000, val_size=1000, test_size=1000,
                 text_path='../data/quora_train.txt', pretrained_model_name='bert-base-cased',
                 topk=50, bow_strategy='simple_sum', indiv_topk=10, indiv_topp=0.01,
                 only_bow=False, use_origin=False, replace_predict=False, append_bow=True, replace_p=0.15,
                 seed=0):
        assert mode in ["train", "val", "test"]
        self.mode = mode
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.topk = topk
        self.bow_strategy = bow_strategy  # simple_sum, mask_sum, indiv_topk, indiv_topp, indiv_neighbors
        self.indiv_topk = indiv_topk
        self.indiv_topp = indiv_topp
        self.only_bow = only_bow
        self.use_origin = use_origin
        self.replace_predict = replace_predict
        self.append_bow = append_bow
        self.replace_p = replace_p
        # Must be set before read_text(), which uses it for the shuffle.
        self.seed = seed

        self.tokenizer = self.init_tokenizer(pretrained_model_name)
        self.mask_predict_model = BertForMaskedLM.from_pretrained(pretrained_model_name)
        self.sentences = self.read_text(text_path)
        self.init_constants()

        self.n_words = len(self.tokenizer)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_predict_model = self.mask_predict_model.to(self.device)

    def __getitem__(self, idx):
        """Return ``(source_ids, target_ids)`` for sentence pair ``idx``.

        The source layout depends on the flags:
          * replace_predict, not append_bow: ``[CLS] replaced seq1 [SEP]``
          * replace_predict, append_bow: the above followed by sorted(seq1 ids + BOW)
          * only_bow: the BOW, or sorted(seq1 ids + BOW) when use_origin
          * otherwise: ``[CLS] seq1 [SEP] BOW [SEP]``
        The target is always ``[CLS] seq2 [SEP]``.
        """
        seq1, seq2 = self.sentences[idx].rstrip('\r\n').split('\t')

        tokens1, seq1_tensor = self._encode(seq1)
        _, seq2_tensor = self._encode(seq2)

        # Mask each word piece of seq1 in turn and collect the masked-LM scores.
        pred = self.get_mask_pred_probs(seq1_tensor)

        # Turn those scores into a bag of predicted word-piece ids.
        bow_tokens_tensor = get_bow(self.bow_strategy, self.n_words, pred, self.topk, self.indiv_topk)

        if self.replace_predict:
            replaced_sentence_tensor = self.get_replaced_sentence(seq1_tensor, pred)
            if not self.append_bow:
                return replaced_sentence_tensor, seq2_tensor
            ret_seq1_tensor = self.get_concat_replaced_tensor(
                tokens1, bow_tokens_tensor, replaced_sentence_tensor
            )
            return ret_seq1_tensor, seq2_tensor

        if self.only_bow:
            ret_seq1_tensor = bow_tokens_tensor
            if self.use_origin:
                origin_tokens = self.tokenizer.convert_tokens_to_ids(tokens1)
                origin_tensors = torch.tensor(origin_tokens, dtype=torch.long)
                ret_seq1_tensor = torch.sort(torch.cat((origin_tensors, bow_tokens_tensor)))[0]
        else:
            # [CLS] seq1 [SEP] + BOW + closing [SEP]
            ret_seq1_tensor = torch.cat((seq1_tensor, bow_tokens_tensor, torch.tensor([self.EOS_token_id])))

        return ret_seq1_tensor, seq2_tensor

    def __len__(self):
        """Number of pairs actually available for this mode.

        ``read_text`` already slices the file to the configured size, so this
        equals ``train_size``/``val_size``/``test_size`` unless the file is too
        short -- in which case reporting the configured size would make the
        DataLoader index past the end of ``self.sentences``.
        """
        return len(self.sentences)

    def _encode(self, text):
        """Tokenize ``text`` and return ``(word_pieces, ids)`` with ids = ``[CLS] pieces [SEP]``."""
        tokens = self.tokenizer.tokenize(text)
        word_pieces = [self.SOS_token] + tokens + [self.EOS_token]
        ids = self.tokenizer.convert_tokens_to_ids(word_pieces)
        return tokens, torch.tensor(ids, dtype=torch.long)

    # [CLS]  [M]  w2  w3  [SEP]
    # [CLS]  w1  [M]  w3  [SEP]
    # [CLS]  w1 w2  [M]  [SEP]
    # TODO: get the pred for only once
    def get_mask_pred_probs(self, seq1):
        """Run the masked LM once per inner position of ``seq1`` with that position masked.

        Row ``i`` of the batch masks position ``i + 1`` (position 0 is [CLS] and
        the last position is [SEP], neither is masked).  Returns the model's first
        output moved to CPU; these are raw scores (logits), softmax is NOT applied.
        NOTE(review): shape presumably (len(seq1) - 2, len(seq1), vocab) -- confirm.

        Raises:
            ValueError: if ``seq1`` has no word pieces between [CLS] and [SEP].
        """
        if len(seq1) < 3:
            raise ValueError("get_mask_pred_probs: sequence has no word pieces to mask")

        mask_sentences = []
        for i in range(1, len(seq1) - 1):
            mask_seq = seq1.detach().clone()
            mask_seq[i] = self.MASK_token_id
            mask_sentences.append(mask_seq)

        mask_stack = torch.stack(mask_sentences).to(self.device)

        self.mask_predict_model.eval()
        with torch.no_grad():
            pred = self.mask_predict_model(mask_stack)[0]
        return pred.cpu()

    def get_replaced_sentence(self, seq1, pred):
        """Return ``seq1`` with each inner word piece resampled with probability ``replace_p``.

        A replaced position ``i + 1`` gets an id sampled from ``softmax(pred[i][i + 1])``
        (the distribution predicted when that position was masked).  The first and
        last ids of ``seq1`` ([CLS]/[SEP]) are always kept.  Greedy top-1 variants
        were tried previously; sampling is the current choice.
        """
        # keep BOS
        pred_ws = [seq1[0].item()]
        for i in range(pred.shape[0]):
            if random() < self.replace_p:
                probs = torch.softmax(pred[i][i + 1], dim=0)
                pred_ws.append(torch.multinomial(probs, 1)[0].item())
            else:
                # .item() so the list holds plain ints, like the other entries.
                pred_ws.append(seq1[i + 1].item())
        # keep EOS
        pred_ws.append(seq1[-1].item())
        return torch.tensor(pred_ws, dtype=torch.long)

    def get_concat_replaced_tensor(self, tokens1, bow_tokens_tensor, replaced_sentence_tensor):
        """Concatenate the replaced sentence with sorted(original word-piece ids + BOW ids).

        NOTE(review): unlike the default path in ``__getitem__``, no closing [SEP]
        is appended after the BOW here -- confirm this is intended.
        """
        # use original tokens by default
        origin_tokens = self.tokenizer.convert_tokens_to_ids(tokens1)
        origin_tensors = torch.tensor(origin_tokens, dtype=torch.long)
        append_tensors = torch.sort(torch.cat((origin_tensors, bow_tokens_tensor)))[0]
        return torch.cat((replaced_sentence_tensor, append_tensors))

    def init_tokenizer(self, pretrained_model_name):
        """Load the ``BertTokenizer`` for ``pretrained_model_name``."""
        return BertTokenizer.from_pretrained(pretrained_model_name)

    def init_constants(self):
        """Cache the special-token strings and their ids from the tokenizer."""
        PAD_id, SOS_id, EOS_id, UNK_id = self.tokenizer.convert_tokens_to_ids(["[PAD]", "[CLS]", "[SEP]", "[UNK]"])
        self.PAD_token_id = PAD_id
        self.SOS_token_id = SOS_id
        self.EOS_token_id = EOS_id
        self.UNK_token_id = UNK_id

        self.PAD_token = '[PAD]'
        self.SOS_token = '[CLS]'
        self.EOS_token = '[SEP]'
        self.UNK_token = '[UNK]'

        self.MASK_token = '[MASK]'
        self.MASK_token_id = self.tokenizer.convert_tokens_to_ids(["[MASK]"])[0]

    def read_text(self, text_path):
        """Read every line of ``text_path``, shuffle, and return this mode's slice.

        Slices are consecutive: train = ``[0, train_size)``, val = the next
        ``val_size`` lines, test = the next ``test_size`` lines.
        """
        with open(text_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Seeded shuffle: separately constructed train/val/test instances see the
        # same permutation, so the slices below never overlap (no data leakage).
        np.random.RandomState(self.seed).shuffle(lines)

        if self.mode == "train":
            start, size = 0, self.train_size
        elif self.mode == "val":
            start, size = self.train_size, self.val_size
        else:
            start, size = self.train_size + self.val_size, self.test_size
        return lines[start:start + size]



In [4]:
# dataset = QuoraBertMaskPredictDataset("train", 1000, 100, text_path='../../data/quora_train.txt')
# dataset = QuoraBertMaskPredictDataset("train", 1000, 100, text_path='../../data/quora_train.txt', bow_strategy='simple_sum')
# dataset = QuoraBertMaskPredictDataset("train", 1000, 100, text_path='../../data/quora_train.txt', bow_strategy='indiv_topk')
# dataset = QuoraBertMaskPredictDataset("train", 124000, 100, text_path='../../data/quora_train.txt', bow_strategy='indiv_neighbors')
# dataset = QuoraBertMaskPredictDataset("train", 1000, 100, text_path='../../data/quora_train.txt', bow_strategy='indiv_topk', only_bow=True, use_origin=True)
# dataset = QuoraBertMaskPredictDataset("train", 1000, 100, text_path='../../data/quora_train.txt', bow_strategy='indiv_topk', replace_predict=True, append_bow=False)



In [23]:
# idxs, idxs2 = dataset[1]
# tokens = dataset.tokenizer.convert_ids_to_tokens(idxs)
# tokens2 = dataset.tokenizer.convert_ids_to_tokens(idxs2)

# print(tokens)
# print(tokens2)

# print(dataset.tokenizer.convert_tokens_to_string(tokens))
# print(dataset.tokenizer.convert_tokens_to_string(tokens2))



['[CLS]', 'So', 'so', 'I', 'ask', 'questions', 'on', 'Q', '##uo', '##ra', '?', '[SEP]']
['[CLS]', 'How', 'can', 'I', 'ask', 'my', 'question', 'on', 'Q', '##uo', '##ra', '?', '[SEP]']
[CLS] So so I ask questions on Quora ? [SEP]
[CLS] How can I ask my question on Quora ? [SEP]


In [17]:
# def create_mini_batch(samples):
#     seq1_tensors = [s[0] for s in samples]
#     seq2_tensors = [s[1] for s in samples]
# #     bows_tensors = [s[2] for s in samples]

#     # zero pad
#     seq1_tensors = pad_sequence(seq1_tensors,
#                                   batch_first=True)

#     seq2_tensors = pad_sequence(seq2_tensors,
#                                   batch_first=True)    
    
# #     return seq1_tensors, seq2_tensors, torch.stack(bows_tensors)
#     return seq1_tensors, seq2_tensors


# # it takes time to predict masked component
# # to improve -> use gpu and calculate outside the dataset
# data_loader = DataLoader(dataset, batch_size=64, collate_fn=create_mini_batch)

In [527]:
# seq1, seq2 = next(iter(data_loader))

tensor([  101,  1249,  8005,  6360,   119,   146,  1821,  1103, 17212,  4907,
         8456,  1104,   117,   117,   118,  2801,  2940,   119,   119,   119,
         1327,  1674,  1122,  1474,  1164,  1128,   136,   102])
tensor([ 1249,  8005,  6360,   131,   146,  1821,   170, 17212,  4907,  8456,
         3477, 17212,  5907,  1105,  6707,  4703,   119,   119,   119,  1184,
         1674,  1115,  1474,  1164,  1143,   136])
tensor([  117,   119,   118,  1105,  1232,  1104,   131,  3336,  2851,  1103,
          106,  1110,   136,  5907,  1202,  1474,   132, 17212,  2940,   120,
         1821,  1221,  8005,  1122,  1114,  1252,  1573,  1115,  4907,   170,
         1674,  1143,  1262,  1249,   189,   146,  1327,  1986,  6707,  1297,
         1875,  1164,  1928,  1137,  1107,  8456,  3477, 16922,  2746,  1225])
tensor([  101,  1731,  1169,   146,  1129,   170,  1363, 23455,   136,   102])
tensor([ 1731,  1169,   146,  1129,   170,  1363, 25166,   136])
tensor([ 1141,  1115,   136,   170,  

tensor([  101,  1188,  1108,  1103,  1436,  1105,  1211, 13108,  7696,  1195,
          112,  1396,  1518,  1458,  1105,  8756,   119,   102])
tensor([ 1327,  1110,  1103,  1436,   120,  1211, 14113,  1645,  1128,   112,
         1396,  1518,  8527,  1105,  1725,   136])
tensor([  117,   136,  1137,   118,  1211,  1105,   120,  1694,  1125,  1103,
         1436,  1655,   119,  5095,  1396,  4997,   106,  1155,  1541,   112,
         1518,   132,  1106,  1232,   131,   146,  1191,  1128,   173,  1133,
         1110,  1231,  1325,  1562,  1645,  4583,  1108,  1195,  7696, 13108,
         8527,  2094, 10741,  1152,  4459,  1309,  1198,  1164,  1188,  6119])
tensor([  101,   146,  1138,  1579, 15430,  3375,  1154,   144, 14746,   119,
          146,  1274,   112,   189,  2373,  1103,   144, 14746,  3300,   118,
          146,  1208,  1139, 10322,  2618,  1110,  1185,  2039,  1117,   119,
         1327,  1169,   146,  1202,   136,   102])
tensor([  146,  1108,  2840,  9366,  3660,  1228,   

tensor([  101,  1731,  1180,  1103, 22182,  1250,   136,   102])
tensor([1731, 1674, 7860, 8455, 1250,  136])
tensor([  136,  1250,  1674,  1731,  1180,  1225,  1169,  1156,   119,  3053,
         1209,  1103,  1202,  1839,  1494,  1301,  1440,  1631,   106,  1327,
        13117,  4732,  3333,  2009,  6058,  2777,  8455,   132,  1142,  1110,
         1293,  1240,  1431,  1139,  1115,  1538,   170,  1252, 23741,  2815,
         1332,  1262,  2926, 13970,  1117,  8794, 15171,  1123,  1573, 12783])
tensor([ 101, 3100, 1175, 1518, 1129,  170, 1594, 1206, 6469, 1105, 1999, 4423,
        1103, 8166, 1571, 2186,  136, 1327, 1180, 1138, 1157, 2590,  136,  102])
tensor([3100, 1175, 1541, 1129, 1251, 1594, 1206, 1726, 1105, 3658, 1166, 1103,
         158, 2047, 2035,  136, 1327, 1209, 1129, 1157, 3154,  136])
tensor([1107, 1103,  136,  118,  119, 1129, 1222, 1142, 1170, 1518,  170, 1726,
        1105,  117, 1975, 1209, 3658, 1114, 3398,  106, 2612, 6735, 1251, 2733,
        1157, 2855, 1247, 133

tensor([  101,  1967,   146,  1202,  1136, 19863, 26883,  3171,  1292,  6581,
         1137,  1146,  9607,  8307,  1116,   117,  1184,  1132,  1175, 18552,
         1115,  2256,  1209,  1654,  1139,  6581,   136,   102])
tensor([ 1409,   146,  1202,  1136, 19863, 26883,  3171,  7673,  6581,   111,
         1146,  9607, 11409,  3438,   117,  1173,  1132,  1175,  9820,  1115,
         7986,  1336,  3510,  1139,  3300,   136])
tensor([ 1103,  7673,   117,  1115,  1139,  3438,   119,  1173,  1147,  1209,
         1202,  1137, 26883,  1169,  1105,  1152,  1156,  1128,  1180,   146,
         1184,  1106,  1232,  1336,  1674,  1225,  6581, 19863,  3171,  1431,
         1116,  7986,  1292,  1234,  1251,  1725,  1177,  1187,  1133,  3094,
         3300,  1132,   132,  1146,  1412,  1142,  1343,   136,  1165, 11014])
tensor([  101,  1327,  1674,  1103,   154, 11848,  1611, 26910,  1440,  1176,
         1111,  2936,  1104,   154, 11848,  1611, 18390, 19527,   136,   102])
tensor([ 1327,  1674,  1

In [7]:
# from tqdm import tqdm
# max_size = 0
# for seq1, seq2 in tqdm(data_loader):
# #     pass
# #     print(seq1)
#     print(seq1.shape, seq2.shape)


  6%|▋         | 1/16 [00:02<00:37,  2.53s/it]

torch.Size([64, 87]) torch.Size([64, 30])


 12%|█▎        | 2/16 [00:04<00:31,  2.24s/it]

torch.Size([64, 74]) torch.Size([64, 27])


 19%|█▉        | 3/16 [00:05<00:26,  2.00s/it]

torch.Size([64, 82]) torch.Size([64, 24])


 25%|██▌       | 4/16 [00:07<00:22,  1.88s/it]

torch.Size([64, 81]) torch.Size([64, 29])


 31%|███▏      | 5/16 [00:08<00:20,  1.87s/it]

torch.Size([64, 80]) torch.Size([64, 32])


 38%|███▊      | 6/16 [00:10<00:16,  1.68s/it]

torch.Size([64, 84]) torch.Size([64, 37])


 44%|████▍     | 7/16 [00:11<00:14,  1.66s/it]

torch.Size([64, 87]) torch.Size([64, 43])


 50%|█████     | 8/16 [00:13<00:14,  1.79s/it]

torch.Size([64, 81]) torch.Size([64, 37])


 56%|█████▋    | 9/16 [00:15<00:11,  1.61s/it]

torch.Size([64, 83]) torch.Size([64, 34])


 62%|██████▎   | 10/16 [00:16<00:08,  1.47s/it]

torch.Size([64, 77]) torch.Size([64, 27])


 69%|██████▉   | 11/16 [00:17<00:07,  1.45s/it]

torch.Size([64, 95]) torch.Size([64, 44])


 75%|███████▌  | 12/16 [00:18<00:05,  1.38s/it]

torch.Size([64, 87]) torch.Size([64, 42])


 81%|████████▏ | 13/16 [00:20<00:03,  1.30s/it]

torch.Size([64, 85]) torch.Size([64, 37])


 88%|████████▊ | 14/16 [00:21<00:02,  1.29s/it]

torch.Size([64, 78]) torch.Size([64, 36])


 94%|█████████▍| 15/16 [00:22<00:01,  1.25s/it]

torch.Size([64, 88]) torch.Size([64, 30])


100%|██████████| 16/16 [00:23<00:00,  1.46s/it]

torch.Size([40, 88]) torch.Size([40, 38])





In [8]:
# seq1, seq2 = next(iter(data_loader))
# dataset.tokenizer.convert_ids_to_tokens(seq1[10])

['[CLS]',
 'What',
 'are',
 'some',
 'special',
 'cares',
 'for',
 'someone',
 'with',
 'a',
 'nose',
 'that',
 'gets',
 'stuff',
 '##y',
 'during',
 'the',
 'night',
 '?',
 '[SEP]',
 ',',
 'that',
 'with',
 '-',
 'and',
 'in',
 'like',
 'for',
 'the',
 '...',
 'to',
 'about',
 'on',
 'or',
 'of',
 'you',
 'a',
 'who',
 '?',
 '.',
 'but',
 'just',
 '##s',
 'at',
 'up',
 'is',
 'as',
 'it',
 'when',
 'one',
 'from',
 "'",
 'out',
 'all',
 'if',
 'so',
 'over',
 '"',
 'people',
 ':',
 'after',
 'by',
 'not',
 ';',
 'being',
 'really',
 'only',
 'even',
 'do',
 'around',
 '[SEP]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]']

In [5]:
# Small training split (1,000 pairs) used below to inspect the masked-LM predictions.
dataset = QuoraBertMaskPredictDataset(
    mode="train",
    train_size=1000,
    val_size=100,
    text_path='../../data/quora_train.txt',
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Alternative probe: 'How do you speak in front of large groups of people ?'
sentence = 'How can I speak faster ?'

# Encode the probe the same way the dataset does: [CLS] word pieces [SEP].
tokens1 = dataset.tokenizer.tokenize(sentence)
word_pieces1 = [dataset.SOS_token, *tokens1, dataset.EOS_token]
idxes1 = dataset.tokenizer.convert_tokens_to_ids(word_pieces1)
seq1_tensor = torch.tensor(idxes1, dtype=torch.long)

def get_mask_pred_probs(seq1, ds=None):
    """Return the masked-LM scores for every single-mask variant of ``seq1``.

    Thin wrapper around ``QuoraBertMaskPredictDataset.get_mask_pred_probs`` so the
    analysis runs exactly the dataset's code path instead of a hand-copied
    duplicate that can silently drift out of sync.

    Args:
        seq1: id tensor of the form ``[CLS] ... [SEP]`` (as built in the cell above).
        ds: dataset instance providing the model; defaults to the notebook's
            global ``dataset``.
    """
    if ds is None:
        ds = dataset
    return ds.get_mask_pred_probs(seq1)


# Masked-LM scores for every "mask one word piece" variant of the probe sentence.
pred = get_mask_pred_probs(seq1_tensor)

# For each position print the original word piece, then the TOP_K most likely
# predictions for that position when it is masked.
TOP_K = 20
softmax = torch.nn.Softmax(dim=0)
for i in range(pred.shape[0]):
    # Row i of `pred` masks position i + 1 (position 0 is [CLS]).
    prob, idx = torch.topk(softmax(pred[i][i + 1]), TOP_K)
    print(dataset.tokenizer.convert_ids_to_tokens([seq1_tensor[i + 1].item()]))
    print(dataset.tokenizer.convert_ids_to_tokens(idx.tolist()))

['How']
['How', 'But', 'Why', 'Or', 'how', 'When', 'And', '.', 'Where', 'Now', 'Then', 'What', 'Please', 'So', 'Only', '...', 'Nor', 'but', '-', 'However']
['can']
['could', 'can', 'do', 'did', 'would', 'will', 'should', 'dare', 'must', 'may', 'does', 'shall', 'cannot', 'dared', 'might', 'am', 'couldn', 'have', 'had', 'Can']
['I']
['I', 'he', 'you', 'she', 'they', 'we', 'anyone', 'people', 'someone', 'it', 'one', 'everyone', 'men', 'that', 'this', 'women', 'God', 'man', 'me', 'words']
['speak']
['go', 'move', 'run', 'walk', 'be', 'get', 'travel', 'think', 'drive', 'breathe', 'swim', 'see', 'fly', 'heal', 'talk', 'turn', 'work', 'come', 'sprint', 'ride']
['faster']
['it', 'now', 'again', 'English', 'this', 'that', 'here', 'anymore', 'you', 'normally', 'up', 'them', 'aloud', 'properly', 'French', 'so', 'then', 'freely', 'words', 'him']
['?']
['?', '.', '!', ';', '...', ',', '-', 'that', ':', 'when', '"', 'and', 'if', 'after', 'what', "'", 'to', 'with', 'so', 'me']
