In [350]:
import sys
import os
import re
import random
import ujson
import numpy as np
import pickle
import h5py

from typing import List, Dict

from collections import defaultdict, Counter
from tqdm.notebook import tqdm

from nltk.tokenize import word_tokenize

from table_bert.utils import BertTokenizer

from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [125]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
cmudict_path = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/CMUdict/cmudict-0.7b.txt'

In [19]:
entry_lines = []

with open(cmudict_path, 'r', encoding='latin-1') as f:
    for l in f:
        if len(l.strip()) > 0 and (not l.startswith(';;;')):
            entry_lines.append(l.strip())

len(entry_lines)

133854

In [23]:
entry_lines[5::10000]

['"IN-QUOTES  IH1 N K W OW1 T S',
 'BELDOCK  B EH1 L D AA2 K',
 'CHANNELED  CH AE1 N AH0 L D',
 'DEMILLE(1)  D IH0 M IH1 L',
 'EXTORTIONISTS  EH0 K S T AO1 R SH AH0 N IH0 S T S',
 'GRIPPED  G R IH1 P T',
 'INTERSPERSES  IH2 N T ER0 S P ER1 S AH0 Z',
 'LIMESTONE  L AY1 M S T OW2 N',
 'MONROEVILLE  M AA0 N R OW1 V IH2 L',
 'PENDERGAST  P EH1 N D ER0 G AE2 S T',
 'REPEATS(1)  R IY0 P IY1 T S',
 'SIELAFF  S IY0 L AE1 F',
 'THIBEDEAU  TH IH1 B IH0 D OW0',
 'WHITEAKER(1)  HH W IH1 T AH0 K ER0']

In [34]:
def strip_stress(phone: str) -> str:
    _m = re.match(r'(.*)\d+$', phone)
    if _m is not None:
        phone = _m.group(1)
    return phone

# test 
for l in entry_lines[5::10000]:
    _word, _pron = l.split('  ')
    
    if _word.endswith(')'):
        # variant
        _m = re.match(r'(.*)\((.*)\)$', _word)
        _word = _m.group(1)
        _variant = _m.group(2)
    else:
        # no variant
        _variant = None
    
    _phones = tuple([strip_stress(_phone) for _phone in _pron.split(' ')])
    
    print(_word, _variant, _phones)

"IN-QUOTES None ('IH', 'N', 'K', 'W', 'OW', 'T', 'S')
BELDOCK None ('B', 'EH', 'L', 'D', 'AA', 'K')
CHANNELED None ('CH', 'AE', 'N', 'AH', 'L', 'D')
DEMILLE 1 ('D', 'IH', 'M', 'IH', 'L')
EXTORTIONISTS None ('EH', 'K', 'S', 'T', 'AO', 'R', 'SH', 'AH', 'N', 'IH', 'S', 'T', 'S')
GRIPPED None ('G', 'R', 'IH', 'P', 'T')
INTERSPERSES None ('IH', 'N', 'T', 'ER', 'S', 'P', 'ER', 'S', 'AH', 'Z')
LIMESTONE None ('L', 'AY', 'M', 'S', 'T', 'OW', 'N')
MONROEVILLE None ('M', 'AA', 'N', 'R', 'OW', 'V', 'IH', 'L')
PENDERGAST None ('P', 'EH', 'N', 'D', 'ER', 'G', 'AE', 'S', 'T')
REPEATS 1 ('R', 'IY', 'P', 'IY', 'T', 'S')
SIELAFF None ('S', 'IY', 'L', 'AE', 'F')
THIBEDEAU None ('TH', 'IH', 'B', 'IH', 'D', 'OW')
WHITEAKER 1 ('HH', 'W', 'IH', 'T', 'AH', 'K', 'ER')


In [38]:
word2pron = defaultdict(set) # all possible prons, Dict[str, Set[Tuple[str]]]
pron2word = defaultdict(set) # all possible words, Dict[Tuple[str], Set[str]]

for l in tqdm(entry_lines):
    _word, _pron = l.split('  ')
    
    if _word.endswith(')'):
        # variant
        _m = re.match(r'(.*)\((.*)\)$', _word)
        _word = _m.group(1)
        _variant = _m.group(2)
    else:
        # no variant
        _variant = None
    
    _phones = tuple([strip_stress(_phone) for _phone in _pron.split(' ')])
    
    word2pron[_word].add(_phones)
    pron2word[_phones].add(_word)

len(word2pron), len(pron2word)

HBox(children=(IntProgress(value=0, max=133854), HTML(value='')))




(125074, 113745)

In [133]:
idx2word = list(word2pron.keys())
word2idx = {w : idx for idx, w in enumerate(idx2word)}
word2pieces = {w : bert_tokenizer.tokenize(w) for w in idx2word}
idx2pron = list(pron2word.keys())
pron2idx = {p : idx for idx, p in enumerate(idx2pron)}
len(word2idx), len(word2pieces), len(pron2idx)

(125074, 125074, 113745)

In [39]:
word2pron['RECORD']

{('R', 'AH', 'K', 'AO', 'R', 'D'),
 ('R', 'EH', 'K', 'ER', 'D'),
 ('R', 'IH', 'K', 'AO', 'R', 'D')}

In [42]:
word2idx['RECORD'], idx2word[91994]

(91994, 'RECORD')

In [45]:
pron2idx[('R', 'EH', 'D')], idx2pron[83079]

(83079, ('R', 'EH', 'D'))

In [61]:
# Suffixes
suffixes = [
    ['S'],
    ['IY NG'],
    ['IH NG'],
    ['D']
]

# Similar phone clusters 
clusters = [
    ['Z', 'S'],
    ['AA', 'AO', 'EY', 'UH'],
    ['AXR', 'AX'],
    ['P', 'B', 'F'],
    ['DH', 'CH', 'ZH', 'T', 'SH'],
    ['IY', 'AY', 'OW'],
    ['EH', 'AH', 'IH', 'AW', 'ER', 'UW']
]

phone2cluster = defaultdict(list) # Dict[str, List(str)]

for c in clusters:
    for p in c:
        phone2cluster[p] = c

len(phone2cluster)

25

In [62]:
phone2cluster['IY']

['IY', 'AY', 'OW']

In [257]:
# 1-step confusion 

# confusable pron ids, List[int]; [i] has j: i can be replaced by j 
# including self, i.e. [i] always has i  
pron_confusions = [set([idx]) for idx in range(len(idx2pron))]

remove_consonant_cnt = 0
double_vowel_cnt = 0
add_suffix_cnt = 0
substitute_cnt = 0

for _idx, _pron in tqdm(enumerate(idx2pron), total=len(idx2pron)):
    _phones = list(_pron)
    
    # remove a consonant (outgoing edge only)
    for j in range(len(_phones)):
        if _phones[j][0] in 'AEIOU':
            continue
            
        _confs_pron = tuple(_phones[:j] + _phones[j+1:])
        try:
            _confs_idx = pron2idx[_confs_pron]
            pron_confusions[_idx].add(_confs_idx)
            remove_consonant_cnt += 1
            
            if remove_consonant_cnt <= 5:
                _src_pron = _pron
                _src_w = next(iter(pron2word[_src_pron]))
                _tgt_pron = _confs_pron
                _tgt_w = next(iter(pron2word[_tgt_pron]))
                print(f'Remove consonant: {_src_w}{_src_pron} -> {_tgt_w}{_tgt_pron}')
        except KeyError:
            continue
    
    # remove a doubled vowel (incoming edge only, for doubling vowel)
    for j in range(1, len(_phones)):
        if _phones[j] != _phones[j-1] or (not _phones[j][0] in 'AEIOU'):
            continue
            
        _confs_pron = tuple(_phones[:j] + _phones[j+1:])
        try:
            _confs_idx = pron2idx[_confs_pron]
            pron_confusions[_confs_idx].add(_idx)
            double_vowel_cnt += 1
            
            if double_vowel_cnt <= 5:
                _src_pron = _confs_pron
                _src_w = next(iter(pron2word[_src_pron]))
                _tgt_pron = _pron
                _tgt_w = next(iter(pron2word[_tgt_pron]))
                print(f'Double vowel: {_src_w}{_src_pron} -> {_tgt_w}{_tgt_pron}')
        except KeyError:
            continue
    
    # add a suffix (outgoing edge only)
    for _suffix in suffixes:
        _confs_pron = tuple(_phones + _suffix)
        try:
            _confs_idx = pron2idx[_confs_pron]
            pron_confusions[_idx].add(_confs_idx)
            add_suffix_cnt += 1
            
            if add_suffix_cnt <= 5:
                _src_pron = _pron
                _src_w = next(iter(pron2word[_src_pron]))
                _tgt_pron = _confs_pron
                _tgt_w = next(iter(pron2word[_tgt_pron]))
                print(f'Add suffix: {_src_w}{_src_pron} -> {_tgt_w}{_tgt_pron}')
        except KeyError:
            continue
    
    # substitute a phone (outgoing edge only; it's bidirectional in essense)
    for j in range(len(_phones)):
        _ph = _phones[j]
        _ph_cluster = phone2cluster[_ph]
        if len(_ph_cluster) <= 1:
            continue
        
        # has a cluster, do replace
        for _confs_ph in _ph_cluster:
            if _confs_ph == _ph:
                continue
            
            _confs_pron = tuple(_phones[:j] + [_confs_ph] + _phones[j+1:])
            try:
                _confs_idx = pron2idx[_confs_pron]
                pron_confusions[_idx].add(_confs_idx)
                substitute_cnt += 1

                if substitute_cnt <= 5:
                    _src_pron = _pron
                    _src_w = next(iter(pron2word[_src_pron]))
                    _tgt_pron = _confs_pron
                    _tgt_w = next(iter(pron2word[_tgt_pron]))
                    print(f'Substitute phone: {_src_w}{_src_pron} -> {_tgt_w}{_tgt_pron}')
            except KeyError:
                continue



HBox(children=(IntProgress(value=0, max=113745), HTML(value='')))

Remove consonant: QUOTE('K', 'W', 'OW', 'T') -> COAT('K', 'OW', 'T')
Remove consonant: QUOTE('K', 'W', 'OW', 'T') -> QUO('K', 'W', 'OW')
Add suffix: QUOTE('K', 'W', 'OW', 'T') -> QUOTES('K', 'W', 'OW', 'T', 'S')
Substitute phone: QUOTE('K', 'W', 'OW', 'T') -> QUIETT('K', 'W', 'IY', 'T')
Substitute phone: QUOTE('K', 'W', 'OW', 'T') -> QUITE('K', 'W', 'AY', 'T')
Remove consonant: "UNQUOTE('AH', 'N', 'K', 'W', 'OW', 'T') -> UNCOAT('AH', 'N', 'K', 'OW', 'T')
Remove consonant: PERCENT('P', 'ER', 'S', 'EH', 'N', 'T') -> PERSET('P', 'ER', 'S', 'EH', 'T')
Add suffix: PERCENT('P', 'ER', 'S', 'EH', 'N', 'T') -> PERCENTS('P', 'ER', 'S', 'EH', 'N', 'T', 'S')
Substitute phone: PERCENT('P', 'ER', 'S', 'EH', 'N', 'T') -> PRESENT('P', 'ER', 'Z', 'EH', 'N', 'T')
Substitute phone: 'ALLO('AA', 'L', 'OW') -> OLLY('AA', 'L', 'IY')
Remove consonant: BOUT('B', 'AW', 'T') -> OUT('AW', 'T')
Add suffix: BOUT('B', 'AW', 'T') -> BOUTS('B', 'AW', 'T', 'S')
Substitute phone: BOUT('B', 'AW', 'T') -> POUT('P', 'AW', 

In [258]:
remove_consonant_cnt, double_vowel_cnt, add_suffix_cnt, substitute_cnt

(65833, 86, 7700, 51210)

In [259]:
sorted(Counter([len(x) for x in pron_confusions]).most_common())

[(1, 51250),
 (2, 32982),
 (3, 13536),
 (4, 7222),
 (5, 4249),
 (6, 2400),
 (7, 1154),
 (8, 613),
 (9, 236),
 (10, 76),
 (11, 14),
 (12, 11),
 (13, 2)]

In [260]:
bert_tokenizer.tokenize('RUUD')

['ru', '##ud']

In [261]:
# Word confusion 
word_confusions = defaultdict(set) # confusable words (not ids!), Dict[str, Set[str]]

for _word in tqdm(idx2word):
    _word_pieces = word2pieces[_word]
    
    for _pron in word2pron[_word]:
        _pron_idx = pron2idx[_pron]
        _confs_pron_ids = pron_confusions[_pron_idx]
        for _confs_pron_idx in _confs_pron_ids:
            _confs_pron = idx2pron[_confs_pron_idx]
            _confs_words = pron2word[_confs_pron]
            
            for w in _confs_words:
                if len(word2pieces[w]) != len(_word_pieces): continue
                if w == _word: continue
                word_confusions[_word].add(w)
                

HBox(children=(IntProgress(value=0, max=125074), HTML(value='')))




In [306]:
# dump the word_confusions dict 
# word_confusions_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/TaBERT/data/word_confusions.pkl'

# with open(word_confusions_path, 'wb') as f:
#     pickle.dump(word_confusions, f)

In [307]:
# load the word_confusions dict 
word_confusions_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/TaBERT/data/word_confusions.pkl'

with open(word_confusions_path, 'rb') as f:
    word_confusions = pickle.load(f)

In [308]:
sorted(Counter([len(word_confusions[w]) for w in idx2word]).most_common())

[(0, 69467),
 (1, 24437),
 (2, 9617),
 (3, 6079),
 (4, 4126),
 (5, 3007),
 (6, 2207),
 (7, 1615),
 (8, 1096),
 (9, 920),
 (10, 641),
 (11, 447),
 (12, 388),
 (13, 258),
 (14, 222),
 (15, 144),
 (16, 136),
 (17, 75),
 (18, 74),
 (19, 40),
 (20, 26),
 (21, 13),
 (22, 3),
 (23, 19),
 (24, 7),
 (25, 2),
 (27, 3),
 (28, 2),
 (30, 1),
 (31, 1),
 (34, 1)]

In [316]:
for w in idx2word:
    if len(word_confusions[w]) > 20:
        print(word2pron[w], word2pieces[w])
        print([(cw, word2pieces[cw]) for cw in word_confusions[w]])
        print('-'*50)

{('B', 'AY')} ['bae']
[('BI', ['bi']), ('B', ['b']), ('PIE', ['pie']), ('BEE', ['bee']), ('I', ['i']), ('EYE', ['eye']), ('BY', ['by']), ('PHI', ['phi']), ('BYE', ['bye']), ('FI', ['fi']), ('BUY', ['buy']), ('BE', ['be']), ('BOW', ['bow']), ('PI', ['pi']), ('BEA', ['bea']), ('AI', ['ai']), ('BO', ['bo']), ('BEAU', ['beau']), ('FAE', ['fae']), ('AYE', ['aye']), ('BEAUX', ['beaux'])]
--------------------------------------------------
{('B', 'IY', 'K'), ('B', 'EH', 'K')} ['bae', '##k']
[('BUC', ['bu', '##c']), ('BOAKE', ['bo', '##ake']), ('BECKS', ['beck', '##s']), ('BIC', ['bi', '##c']), ('PAEK', ['pa', '##ek']), ('BEEKS', ['bee', '##ks']), ('BOUCK', ['bo', '##uck']), ('BICK', ['bi', '##ck']), ('B.', ['b', '.']), ('EKE', ['ek', '##e']), ('FECK', ['fe', '##ck']), ('BOEKE', ['bo', '##eke']), ('BIRK', ['bi', '##rk']), ('PIQUE', ['pi', '##que']), ('BOOCK', ['boo', '##ck']), ('BERCH', ['be', '##rch']), ('PECH', ['pe', '##ch']), ('BURK', ['bu', '##rk']), ('BEEK', ['bee', '##k']), ('BERKE', ['b

In [310]:
w = 'THREE'
word2pron[w], word2pieces[w]

({('TH', 'R', 'IY')}, ['three'])

In [311]:
[(cw, word2pieces[cw], word2pron[cw]) for cw in word_confusions[w]]

[('THROW', ['throw'], {('TH', 'R', 'OW')}),
 ('RE', ['re'], {('R', 'EY'), ('R', 'IY')})]

In [218]:
# GPT2 lm score 
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model_lm = GPT2LMHeadModel.from_pretrained('gpt2')

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
inputs = gpt2_tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = gpt2_model_lm(**inputs, labels=inputs["input_ids"])
loss, logits = outputs[:2]
loss, logits

In [221]:
def GPT2_LM_Loss(sentence: str, gpt2_tokenizer: GPT2Tokenizer, gpt2_model_lm: GPT2LMHeadModel) -> float:
    inputs = gpt2_tokenizer(sentence, return_tensors="pt")
    outputs = gpt2_model_lm(**inputs, labels=inputs["input_ids"])
    loss = outputs[0].item()
    return loss

In [230]:
GPT2_LM_Loss("Hello, my dog is very cute .", gpt2_tokenizer, gpt2_model_lm)

4.128024578094482

In [318]:
class SentenceAcousticConfuser(object):
    def __init__(self,
                 word_confusions_path: str):
                 # fix_subword_lengths: bool = True,  # confusion word have same number of subwords as original word 
                 # bert_tokenizer_name: str = 'bert-base-uncased'
        
        super().__init__()
        
        self.word_confusions_path = word_confusions_path
        # self.fix_subword_lengths = fix_subword_lengths
        # self.bert_tokenizer_name = bert_tokenizer_name
        
        with open(self.word_confusions_path, 'rb') as f:
            self.word_confusions = pickle.load(f)
    
    def sentence_confuse(self, sentence: List[str], p: float) -> List[str]:
        raise NotImplementedError

In [319]:
class SentenceAcousticConfuser_RandomReplace(SentenceAcousticConfuser):
    def __init__(self,
                 word_confusions_path: str):
        
        super().__init__(word_confusions_path)
    
    def sentence_confuse(self, sentence: List[str], p: float) -> List[str]:
        assert 0 <= p <= 1
    
        sen_len = len(sentence)
        confs_cnt = int(p * sen_len)

        confusable_positions = []
        for pos in range(sen_len):
            word = sentence[pos].upper()
            if len(self.word_confusions[word]) > 0:
                confusable_positions.append(pos)

        if len(confusable_positions) <= confs_cnt:
            # not enough positions for confusion
            confs_positions = confusable_positions
        else:
            confs_positions = random.sample(confusable_positions, k=confs_cnt)

        confs_sentence = list(sentence)
        for pos in confs_positions:
            word = sentence[pos].upper()
            confs_word = random.choice(list(self.word_confusions[word])).lower()
            if pos == 0:
                confs_word = confs_word.capitalize()
            confs_sentence[pos] = confs_word

        return confs_sentence

In [324]:
class SentenceAcousticConfuser_GPT2LossReplace(SentenceAcousticConfuser):
    def __init__(self,
                 word_confusions_path: str):
        
        super().__init__(word_confusions_path)
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.gpt2_model_lm = GPT2LMHeadModel.from_pretrained('gpt2')
    
    def _gpt2_lm_loss(self, sentence: List[str]) -> float:
        inputs = self.gpt2_tokenizer(' '.join(sentence), return_tensors="pt")
        outputs = self.gpt2_model_lm(**inputs, labels=inputs["input_ids"])
        loss = outputs[0].item()
        return loss
    
    def sentence_confuse(self, sentence: List[str], p: float) -> List[str]:
        assert 0 <= p <= 1
    
        sen_len = len(sentence)
        confs_cnt = int(p * sen_len)

        confusable_positions = []
        for pos in range(sen_len):
            word = sentence[pos].upper()
            if len(word_confusions[word]) > 0:
                confusable_positions.append(pos)

        if len(confusable_positions) <= confs_cnt:
            # not enough positions for confusion
            confs_positions = confusable_positions
        else:
            confs_positions = sorted(random.sample(confusable_positions, k=confs_cnt))

        confs_sentence = list(sentence)
        for pos in confs_positions:
            word = sentence[pos].upper()
            assert len(self.word_confusions[word]) > 0

            if len(self.word_confusions[word]) == 1:
                # No need for LM loss 
                _cw = next(iter(self.word_confusions[word])).lower()
                if pos == 0:
                    _cw = _cw.capitalize()
                confs_sentence[pos] = _cw
                continue

            # Compare different confusion words by their LM losses 
            best_lm_loss = np.inf
            best_confs_word = None
            for _confs_word in self.word_confusions[word]:
                _cw = _confs_word.lower()
                if pos == 0:
                    _cw = _cw.capitalize()

                _confs_sen = list(confs_sentence)
                _confs_sen[pos] = _cw

                _loss = self._gpt2_lm_loss(_confs_sen)
                # print(_confs_sen, _loss)

                if _loss < best_lm_loss:
                    best_lm_loss = _loss
                    best_confs_word = _cw

            assert best_confs_word is not None
            confs_sentence[pos] = best_confs_word

        return confs_sentence
    


In [325]:
word_confusions_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/TaBERT/data/word_confusions.pkl'

confuser_random = SentenceAcousticConfuser_RandomReplace(word_confusions_path)
confuser_gpt2 = SentenceAcousticConfuser_GPT2LossReplace(word_confusions_path)

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [326]:
confuser_random, confuser_gpt2

(<__main__.SentenceAcousticConfuser_RandomReplace at 0x147ea3f50>,
 <__main__.SentenceAcousticConfuser_GPT2LossReplace at 0x1a0d841d0>)

In [349]:
sentence = [
    "What",
    "are",
    "the",
    "ids",
    "of",
    "the",
    "TV",
    "channels",
    "that",
    "do",
    "not",
    "have",
    "any",
    "cartoons",
    "directed",
    "by",
    "Ben",
    "Jones",
    "?"
]

print(confuser_random.sentence_confuse(sentence, p=0.15))
print(confuser_gpt2.sentence_confuse(sentence, p=0.15))

['What', 'are', 'the', 'ids', 'of', 'the', 'TV', 'channels', 'that', 'do', 'na', 'have', 'any', 'cartoons', 'erected', 'by', 'Ben', 'Jones', '?']
['What', 'are', 'the', 'ids', 'a', 'to', 'TV', 'channels', 'that', 'do', 'not', 'have', 'any', 'cartoons', 'directed', 'by', 'Ben', 'Jones', '?']


In [None]:
# (trial) Apply on tables.jsonl 
tabert_dataset_path = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/TaBERT_datasets/tables_sample.jsonl'

with open(tabert_dataset_path, 'r') as f:
    l = f.readline()

# print(ujson.dumps(ujson.loads(l), indent=4))
d = ujson.loads(l)
d

In [None]:
p = 0.15

confs_d = d.copy()
for idx, sen in enumerate(d['context_before']):
    sen_tokens = word_tokenize(sen)
    sen_tokens_confused = SentenceAcousticConfusion(sen_tokens, p)
    sen_confused = ' '.join(sen_tokens_confused)
    confs_d['context_before'][idx] = sen_confused
for idx, sen in enumerate(d['context_after']):
    sen_tokens = word_tokenize(sen)
    sen_tokens_confused = SentenceAcousticConfusion(sen_tokens, p)
    sen_confused = ' '.join(sen_tokens_confused)
    confs_d['context_after'][idx] = sen_confused
confs_d

In [411]:
# Check preprocessed sample data 
sample_dir = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/TaBERT_datasets/train_data/vanilla_tabert_sample_ac3/train'
sample_json_path = os.path.join(sample_dir, 'epoch_0.shard0.sample.json')
sample_h5_path = os.path.join(sample_dir, 'epoch_0.shard0.h5')

with open(sample_json_path, 'r') as f:
    sample_json = [ujson.loads(l) for l in f]

with h5py.File(sample_h5_path, 'r') as f:
    sample_h5 = {k : np.array(v) for k, v in f.items()}


In [412]:
len(sample_json), list(sample_json[0].keys())

(74,
 ['tokens',
  'token_ids',
  'segment_a_length',
  'masked_lm_positions',
  'masked_lm_labels',
  'masked_lm_label_ids',
  'info',
  'tokens_ref',
  'token_ref_ids',
  'source'])

In [413]:
list(sample_h5.keys())

['masked_lm_label_ids',
 'masked_lm_offsets',
 'masked_lm_positions',
 'segment_a_lengths',
 'sequence_offsets',
 'sequences',
 'sequences_ref']

In [414]:
sample_h5['sequences'].shape

(168859,)

In [415]:
l0 = list(sample_json[0]['token_ids'])
l = list(sample_h5['sequences'])

for _idx in range(len(l) - len(l0)):
    if l[_idx : _idx + len(l0)] == l0:
        print(_idx)
        idx = _idx

1424


In [416]:
print(l[idx : idx + len(l0)])
print(l0)

[101, 103, 2017, 17816, 1037, 7615, 2008, 2017, 3373, 23640, 2015, 2151, 1997, 1996, 11594, 14801, 2682, 1010, 3531, 11562, 1000, 5210, 15501, 15884, 1012, 103, 4067, 2017, 1012, 7929, 102, 103, 1064, 13319, 1064, 1014, 8466, 102, 2580, 1064, 2613, 1064, 2184, 2086, 1023, 3134, 3283, 102, 2197, 7514, 1064, 2613, 1064, 1050, 1013, 1037, 102]
[101, 103, 2017, 17816, 1037, 7615, 2008, 2017, 3373, 23640, 2015, 2151, 1997, 1996, 11594, 14801, 2682, 1010, 3531, 11562, 1000, 5210, 15501, 15884, 1012, 103, 4067, 2017, 1012, 7929, 102, 103, 1064, 13319, 1064, 1014, 8466, 102, 2580, 1064, 2613, 1064, 2184, 2086, 1023, 3134, 3283, 102, 2197, 7514, 1064, 2613, 1064, 1050, 1013, 1037, 102]


In [417]:
print(list(sample_json[0]['token_ref_ids']))
print(list(sample_h5['sequences_ref'])[idx : idx + len(l0)])

[101, 103, 2017, 2156, 1037, 7615, 2008, 2017, 2903, 23640, 2015, 2151, 1997, 1996, 11594, 14801, 2682, 1010, 3531, 11562, 1000, 5210, 2004, 15884, 1012, 103, 4067, 2017, 1012, 7929, 102, 103, 1064, 13319, 1064, 1014, 8466, 102, 2580, 1064, 2613, 1064, 2184, 2086, 1023, 3134, 3283, 102, 2197, 7514, 1064, 2613, 1064, 1050, 1013, 1037, 102]
[101, 103, 2017, 2156, 1037, 7615, 2008, 2017, 2903, 23640, 2015, 2151, 1997, 1996, 11594, 14801, 2682, 1010, 3531, 11562, 1000, 5210, 2004, 15884, 1012, 103, 4067, 2017, 1012, 7929, 102, 103, 1064, 13319, 1064, 1014, 8466, 102, 2580, 1064, 2613, 1064, 2184, 2086, 1023, 3134, 3283, 102, 2197, 7514, 1064, 2613, 1064, 1050, 1013, 1037, 102]


In [418]:
l0 = list(sample_json[0]['masked_lm_label_ids'])
l = list(sample_h5['masked_lm_label_ids'])

for _idx in range(len(l) - len(l0)):
    if l[_idx : _idx + len(l0)] == l0:
        print(_idx)
        idx = _idx

print(l[idx : idx + len(l0)])
print(l0)

302
[2065, 2156, 2903, 2004, 15884, 1000, 14054, 2613]
[2065, 2156, 2903, 2004, 15884, 1000, 14054, 2613]


In [419]:
print(list(sample_json[0]['masked_lm_positions']))
print(list(sample_h5['masked_lm_positions'])[idx : idx + len(l0)])

[1, 3, 8, 22, 23, 25, 31, 33]
[1, 3, 8, 22, 23, 25, 31, 33]


### Temp

In [425]:
_sentence = "I'm a Mr. 'hyusyahuti', with value 10.0 to test tokenizerxs. Don't mind if you've you're!"
_sentence_tok = bert_tokenizer.tokenize(_sentence)
print(_sentence_tok)

['i', "'", 'm', 'a', 'mr', '.', "'", 'h', '##yu', '##sy', '##ahu', '##ti', "'", ',', 'with', 'value', '10', '.', '0', 'to', 'test', 'token', '##izer', '##x', '##s', '.', 'don', "'", 't', 'mind', 'if', 'you', "'", 've', 'you', "'", 're', '!']


In [430]:
_sentence_tok = ['images', 'reginald', 'barclay', 'reginald', 'barclay', 'the', 'next', 'generation', 'title', 'role', 'disc', 'no', 'episode', 'number', 'air', '##date', 'star', '##date', 'season', 'year', 'rating', 'hollow', 'pursuits', 'lieutenant', 'reginald', 'barclay', '3', '.', '5', '68', '30', 'apr', '1990', '43', '##80', '##7', '.', '4', '3', '236', '##6', 'the', 'n', "'", 'th', 'degree', 'lieutenant', 'reginald', 'barclay', '4', '.', '4', '92', '1', 'apr', '1991', '44', '##70', '##4', '.', '2', '4', '236', '##7', 'realm', 'of', 'fear', 'lieutenant', 'reginald', 'barclay', '6', '.', '1', '127', '28', 'sep', '1992', '460', '##41', '.', '1', '6', '236', '##9', 'ship', 'in', 'a', 'bottle', 'lieutenant', 'reginald', 'barclay', '6', '.', '3', '137', '25', 'jan', '1993', '46', '##42', '##4', '.', '1', '6', '236', '##9']
print(_sentence_tok, len(_sentence_tok))

['images', 'reginald', 'barclay', 'reginald', 'barclay', 'the', 'next', 'generation', 'title', 'role', 'disc', 'no', 'episode', 'number', 'air', '##date', 'star', '##date', 'season', 'year', 'rating', 'hollow', 'pursuits', 'lieutenant', 'reginald', 'barclay', '3', '.', '5', '68', '30', 'apr', '1990', '43', '##80', '##7', '.', '4', '3', '236', '##6', 'the', 'n', "'", 'th', 'degree', 'lieutenant', 'reginald', 'barclay', '4', '.', '4', '92', '1', 'apr', '1991', '44', '##70', '##4', '.', '2', '4', '236', '##7', 'realm', 'of', 'fear', 'lieutenant', 'reginald', 'barclay', '6', '.', '1', '127', '28', 'sep', '1992', '460', '##41', '.', '1', '6', '236', '##9', 'ship', 'in', 'a', 'bottle', 'lieutenant', 'reginald', 'barclay', '6', '.', '3', '137', '25', 'jan', '1993', '46', '##42', '##4', '.', '1', '6', '236', '##9'] 106


In [431]:
def detokenize_BertTokenizer(sentence: List[str]) -> str:
    return (
        " ".join(sentence)
        .replace(" ##", "")
#         .replace(" .", ".")
#         .replace(" ?", "?")
#         .replace(" !", "!")
#         .replace(" ,", ",")
#         .replace(" ' ", "'")
#         .replace(" n't", "n't")
#         .replace(" 'm", "'m")
#         .replace(" 's", "'s")
#         .replace(" 've", "'ve")
#         .replace(" 're", "'re")
    )

In [432]:
_sentence_detok = detokenize_BertTokenizer(_sentence_tok)
_sentence_detok

"images reginald barclay reginald barclay the next generation title role disc no episode number airdate stardate season year rating hollow pursuits lieutenant reginald barclay 3 . 5 68 30 apr 1990 43807 . 4 3 2366 the n ' th degree lieutenant reginald barclay 4 . 4 92 1 apr 1991 44704 . 2 4 2367 realm of fear lieutenant reginald barclay 6 . 1 127 28 sep 1992 46041 . 1 6 2369 ship in a bottle lieutenant reginald barclay 6 . 3 137 25 jan 1993 46424 . 1 6 2369"

In [433]:
print(bert_tokenizer.tokenize(_sentence_detok))

['images', 'reginald', 'barclay', 'reginald', 'barclay', 'the', 'next', 'generation', 'title', 'role', 'disc', 'no', 'episode', 'number', 'air', '##date', 'star', '##date', 'season', 'year', 'rating', 'hollow', 'pursuits', 'lieutenant', 'reginald', 'barclay', '3', '.', '5', '68', '30', 'apr', '1990', '43', '##80', '##7', '.', '4', '3', '236', '##6', 'the', 'n', "'", 'th', 'degree', 'lieutenant', 'reginald', 'barclay', '4', '.', '4', '92', '1', 'apr', '1991', '44', '##70', '##4', '.', '2', '4', '236', '##7', 'realm', 'of', 'fear', 'lieutenant', 'reginald', 'barclay', '6', '.', '1', '127', '28', 'sep', '1992', '460', '##41', '.', '1', '6', '236', '##9', 'ship', 'in', 'a', 'bottle', 'lieutenant', 'reginald', 'barclay', '6', '.', '3', '137', '25', 'jan', '1993', '46', '##42', '##4', '.', '1', '6', '236', '##9']


In [434]:
bert_tokenizer.tokenize(_sentence_detok) == _sentence_tok

True