Execution steps are from: https://github.com/pytorch/fairseq/blob/main/examples/bart/README.summarization.md
<br/>

Refer to the following files for code:
1. https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py
2. https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/fairseq/binarizer.py#L49
3. https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/fairseq/data/dictionary.py#L304
4. https://github.com/neulab/guided_summarization/blob/ea4bbe91f189cdb51f7f6a827210f9adc5319b3c/bart/fairseq/models/bart/guided_hub_interface.py#L122

In [8]:
import os
import json
import regex as re

In [2]:
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

In [3]:
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

In [81]:
def bpe(token, bpe_ranks, debug=False):
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        if debug:
            print('pairs', pairs)
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
        if debug:
            print('bigram', bigram)
        if bigram not in bpe_ranks:
            if debug:
                print('{} not in bpe_ranks'.format(bigram))
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break
            
            if debug:
                print('new_word', new_word)

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
            
            if debug:
                print('new_word2', new_word)    
            
        new_word = tuple(new_word)
        if debug:
            print('outofwhile', new_word)
            print()
            
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = " ".join(word)
    return word

In [106]:
'''
wget -N https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
wget -N https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
'''

encoder_json_path = 'encoder.json'
vocab_bpe_path = 'vocab.bpe'

with open(encoder_json_path, "r") as f:
    encoder = json.load(f)
decoder = {v: k for k, v in encoder.items()}    

with open(vocab_bpe_path, "r", encoding="utf-8") as f:
    bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]

bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

In [83]:
# The first key appears different from the one in vocab.bpe when opened in browser. Related to encoding probably.
bpe_merges[:5]

[('Ġ', 't'), ('Ġ', 'a'), ('h', 'e'), ('i', 'n'), ('r', 'e')]

In [98]:
token = ' Shantanu'
benc_token = "".join(byte_encoder[b] for b in token.encode("utf-8"))
print(benc_token, '\n')
bpe(benc_token, bpe_ranks, debug=True)

ĠShantanu 

pairs {('Ġ', 'S'), ('S', 'h'), ('a', 'n'), ('n', 't'), ('n', 'u'), ('h', 'a'), ('t', 'a')}
bigram ('a', 'n')
new_word ['Ġ', 'S', 'h']
new_word2 ['Ġ', 'S', 'h', 'an']
new_word ['Ġ', 'S', 'h', 'an', 't']
new_word2 ['Ġ', 'S', 'h', 'an', 't', 'an']
outofwhile ('Ġ', 'S', 'h', 'an', 't', 'an', 'u')

pairs {('Ġ', 'S'), ('h', 'an'), ('S', 'h'), ('t', 'an'), ('an', 'u'), ('an', 't')}
bigram ('Ġ', 'S')
new_word []
new_word2 ['ĠS']
outofwhile ('ĠS', 'h', 'an', 't', 'an', 'u')

pairs {('ĠS', 'h'), ('h', 'an'), ('t', 'an'), ('an', 'u'), ('an', 't')}
bigram ('an', 't')
new_word ['ĠS', 'h']
new_word2 ['ĠS', 'h', 'ant']
new_word ['ĠS', 'h', 'ant']
new_word2 ['ĠS', 'h', 'ant', 'an']
outofwhile ('ĠS', 'h', 'ant', 'an', 'u')

pairs {('ĠS', 'h'), ('ant', 'an'), ('h', 'ant'), ('an', 'u')}
bigram ('ĠS', 'h')
new_word []
new_word2 ['ĠSh']
outofwhile ('ĠSh', 'ant', 'an', 'u')

pairs {('an', 'u'), ('ant', 'an'), ('ĠSh', 'ant')}
bigram ('an', 'u')
new_word ['ĠSh', 'ant']
new_word2 ['ĠSh', 'ant', 'an

'ĠShant anu'

In [99]:
token = ' Shantanu'
benc_token = "".join(byte_encoder[b] for b in token.encode("utf-8"))
enc_token = [encoder[bpe_token] for bpe_token in bpe(benc_token, bpe_ranks).split(" ")]
enc_token

[49892, 42357]

In [107]:
text = "".join([decoder.get(token, token) for token in enc_token])
print(text)
text = bytearray([byte_decoder[c] for c in text]).decode(
    "utf-8", errors='replace'
)
print(text)

ĠShantanu
 Shantanu


In [84]:
pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

In [85]:
sent = "Hi, I am Shantanu. I've been trying understand this code for long now"

tokens = re.findall(pat, sent)
print(tokens)

['Hi', ',', ' I', ' am', ' Shantanu', '.', ' I', "'ve", ' been', ' trying', ' understand', ' this', ' code', ' for', ' long', ' now']


In [100]:
# Space encodede as Gdot. Punctuations remain the same

for token in tokens:
    benc_token = "".join(byte_encoder[b] for b in token.encode("utf-8"))
    print(token, benc_token)

Hi Hi
, ,
 I ĠI
 am Ġam
 Shantanu ĠShantanu
. .
 I ĠI
've 've
 been Ġbeen
 trying Ġtrying
 understand Ġunderstand
 this Ġthis
 code Ġcode
 for Ġfor
 long Ġlong
 now Ġnow


In [102]:
for token in tokens:
    benc_token = "".join(byte_encoder[b] for b in token.encode("utf-8"))
    print(benc_token, bpe(benc_token, bpe_ranks))

Hi Hi
, ,
ĠI ĠI
Ġam Ġam
ĠShantanu ĠShant anu
. .
ĠI ĠI
've 've
Ġbeen Ġbeen
Ġtrying Ġtrying
Ġunderstand Ġunderstand
Ġthis Ġthis
Ġcode Ġcode
Ġfor Ġfor
Ġlong Ġlong
Ġnow Ġnow


In [109]:
enc_tokens = []
for token in tokens:
    benc_token = "".join(byte_encoder[b] for b in token.encode("utf-8"))
    enc_tokens.extend([encoder[bpe_token] for bpe_token in bpe(benc_token, bpe_ranks).split(" ")])
print(enc_tokens)

text = "".join([decoder.get(token, token) for token in enc_tokens])
text = bytearray([byte_decoder[c] for c in text]).decode("utf-8", errors='replace')
print(text)

[17250, 11, 314, 716, 49892, 42357, 13, 314, 1053, 587, 2111, 1833, 428, 2438, 329, 890, 783]
Hi, I am Shantanu. I've been trying understand this code for long now


## Binarization

After we have the byte-pair encoding, we use binarization to convert these ids to a standard one that the encoder is trained on. The mapping is provided by a dict file. We will use this as an example:

https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt

**NOTE**: Refer to this for understanding the flow:
https://github.com/neulab/guided_summarization/blob/ea4bbe91f189cdb51f7f6a827210f9adc5319b3c/bart/fairseq/models/bart/guided_hub_interface.py#L84


In [116]:
'''
wget -N https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
'''

with open('dict.txt', "r", encoding="utf-8") as f:
    binarization_dict = dict([row.split() for row in f.read().splitlines()])

In [118]:
in_tokens = list(map(str, enc_tokens))
out_tokens = [binarization_dict[token] for token in in_tokens]
print(list(zip(in_tokens, out_tokens)))

[('17250', '38737'), ('11', '800251374'), ('314', '60989470'), ('716', '4365322'), ('49892', '14052'), ('42357', '178076'), ('13', '850314647'), ('314', '60989470'), ('1053', '6231567'), ('587', '34803895'), ('2111', '3590156'), ('1833', '1899229'), ('428', '52404829'), ('2438', '768787'), ('329', '155236946'), ('890', '8325716'), ('783', '15823492')]


In [121]:
print(list(map(int, out_tokens)))

[38737, 800251374, 60989470, 4365322, 14052, 178076, 850314647, 60989470, 6231567, 34803895, 3590156, 1899229, 52404829, 768787, 155236946, 8325716, 15823492]
