## Start

In [1]:
import os
import pickle
import time

import regex as re

from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from typing import BinaryIO
from collections import defaultdict

In [2]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
TOKEN = tuple[bytes]
PAIR = tuple[bytes, bytes]

## 1 Assignment Overview

## 2 Byte-Pair Encoding (BPE) Tokenizer

### 2.1 The Unicode Standard

In [3]:
[ord('牛'), chr(29275)]

[29275, '牛']

#### Problem (unicode1)

##### a

In [None]:
chr(0)

'\x00'

##### b

In [None]:
print(chr(0))

 


##### c

In [8]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [9]:
print("this is a test" + chr(0) + "string")

this is a test string


### 2.2 Unicode Encodings

In [10]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [11]:
print(type(utf8_encoded))

<class 'bytes'>


In [12]:
list(utf8_encoded)

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [13]:
[len(test_string),len(utf8_encoded)]

[13, 23]

In [14]:
print(utf8_encoded.decode("utf-8"))

hello! こんにちは!


#### Problem (unicode2)

##### a

##### b

In [15]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [16]:
decode_utf8_bytes_to_str_wrong("hello, 你好".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

##### c

In [18]:
sentence = "hello, 你好".encode("utf-8")
[bytes([b]) for b in sentence]

[b'h',
 b'e',
 b'l',
 b'l',
 b'o',
 b',',
 b' ',
 b'\xe4',
 b'\xbd',
 b'\xa0',
 b'\xe5',
 b'\xa5',
 b'\xbd']

In [24]:
(b'\xe4\xbd\xa0').decode("utf-8")

'你'

### 2.3 Subword Tokenization

### 2.4 BPE Tokenizer Training

In [27]:
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [28]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [31]:
test_string = """low low low low low lower lower widest widest widest newest newest newest newest newest newest"""

In [32]:
re.findall(PAT, test_string)

['low',
 ' low',
 ' low',
 ' low',
 ' low',
 ' lower',
 ' lower',
 ' widest',
 ' widest',
 ' widest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest']

### 2.5 Experimenting with BPE Tokenizer Training

In [4]:
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

def update_token_freqs(
    token_freqs: dict[TOKEN, int],
    text_segment: str,
):
    matches = re.finditer(PAT, text_segment)
    for m in matches:
        token = m.group()
        token_bytes = tuple(bytes([b]) for b in token.encode("utf-8"))
        token_freqs[token_bytes] += 1 

def process_chunk(args):
    input_path, start, end, special_tokens, chunk_id = args
    # Read chunk
    with open(input_path, 'rb') as f:
        f.seek(start)
        chunk_bytes = f.read(end - start)
    
    # Decode
    chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
    
    # Logic from original train_bpe
    special_pat = "|".join(re.escape(st) for st in special_tokens)
    segments = re.split(special_pat, chunk_str)
    
    token_freqs = defaultdict(int)
    # Use position=chunk_id+1 so 0 is left for the main bar
    for seg in segments:
        update_token_freqs(seg, token_freqs)
    
    return token_freqs

def get_pairs(
    token: TOKEN
) -> list[PAIR]:
    if len(token) < 2:
        return []
    return [(token[i], token[i+1]) for i in range(len(token)-1)]

def update_pair_freqs(
    pair_freqs: dict[PAIR, int],
    token_freqs: dict[TOKEN, int],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        for p in pairs:
            pair_freqs[p] += freq

def update_pair2idx(
    pair_to_idx: dict[PAIR, set[TOKEN]],
    token_freqs: dict[TOKEN, int],
    token_to_idx: dict[TOKEN, int],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        for p in pairs:
            pair_to_idx[p].add(token_to_idx[token])

def update_token2pair(
    token_freqs: dict[TOKEN, int],
    token_to_pair: dict[TOKEN, list[PAIR]],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        token_to_pair[token] = pairs

def get_most_frequent_pair(
    pair_freqs: dict[PAIR, int],
) -> PAIR:
    return max(pair_freqs.keys(), key=lambda k: (pair_freqs[k], k))

def update_vocab(
    new_id: int,
    best_pair: PAIR,
    vocab: dict[int, bytes],
):
    new_vocab = best_pair[0] + best_pair[1]
    vocab[new_id] = new_vocab

def update_vocab_inverse(
    new_id: int,
    best_pair: PAIR,
    vocab_inverse: dict[bytes, int],
):
    new_vocab = best_pair[0] + best_pair[1]
    vocab_inverse[new_vocab] = new_id

def update_merges(
    best_pair: PAIR,
    merges: list[PAIR],
):
    merges.append(best_pair)

def update_all(
    best_pair: PAIR,
    pair_to_token: dict[PAIR, set[TOKEN]],
    token_to_pair: dict[TOKEN, list[PAIR]],
    token_freqs: dict[TOKEN, int],
    pair_freqs: dict[PAIR, int],
):
    affected_tokens = list(pair_to_token[best_pair])
    merged_bytes = best_pair[0] + best_pair[1]

    for token in affected_tokens:
        # get new token
        i=0
        new_token = []
        while(i<(len(token))):
            if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
                new_token.append(merged_bytes)
                i = i + 2
            else:
                new_token.append(token[i])
                i = i + 1
        new_token = tuple(new_token)


        ## update pair_to_token
        new_pairs = get_pairs(new_token)
        affected_pairs = token_to_pair[token]
        for pair in affected_pairs:
            pair_to_token[pair].discard(token)
        for pair in new_pairs:
            pair_to_token[pair].add(new_token)


        ## update token_to_pair
        token_to_pair.pop(token)
        token_to_pair[new_token] = new_pairs


        ## update pair_freqs
        for pair in affected_pairs:
            pair_freqs[pair] -= token_freqs[token]
        for pair in new_pairs:
            pair_freqs[pair] += token_freqs[token]


        ## update token_freqs
        origin_freq = token_freqs.pop(token)
        token_freqs[new_token] = origin_freq
        

def train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    num_processes: int = 4
):
    print("Start Training BPE...")
    start_total = time.time()

    print("Initialize Vocab and Merges...")
    vocab = {i:bytes([i]) for i in range(256)}
    vocab_inverse = {v:k for k,v in vocab.items()}
    # Add special tokens
    for st in special_tokens:
        new_id = len(vocab)
        st_bytes = st.encode("utf-8")
        vocab[new_id] = st_bytes
        vocab_inverse[st_bytes] = new_id
    
    merges = []

    print("Calculating chunk boundaries...")
    t0 = time.time()
    with open(input_path, 'rb') as f:
        # Assuming first special token is the split token
        split_token = special_tokens[0].encode("utf-8") if special_tokens else b"<|endoftext|>"
        boundaries = find_chunk_boundaries(f, num_processes, split_token) # More chunks than processes for load balancing
    print(f"Boundaries calculated in {time.time() - t0:.2f}s")
    
    token_freqs = defaultdict(int)

    ## Initiate token_freqs
    print("Update Token Freq (Parallel)...")
    t0 = time.time()
    
    tasks = []
    for i in range(len(boundaries) - 1):
        start = boundaries[i]
        end = boundaries[i+1]
        tasks.append((input_path, start, end, special_tokens, i))
    
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(executor.map(process_chunk, tasks), total=len(tasks), desc="Token Freqs (Chunks)", position=0))
        
        for local_freqs in results:
            for token, count in local_freqs.items():
                token_freqs[token] += count
                
    print(f"Update Token Freq took {time.time() - t0:.2f}s")

    pair_freqs = defaultdict(int)
    pair_to_token = defaultdict(set)
    token_to_pair = defaultdict(list)

    ## Initiate pair_freqs, pair_to_token, and token_to_pair
    print("Update Pair Freq...")
    t0 = time.time()
    update_pair_freqs(token_freqs, pair_freqs)
    print(f"Update Pair Freq took {time.time() - t0:.2f}s")

    print("Update Pair to Token...")
    t0 = time.time()
    update_pair2token(token_freqs, pair_to_token)
    print(f"Update Pair to Token took {time.time() - t0:.2f}s")

    print("Update Token to Pair...")
    t0 = time.time()
    update_token2pair(token_freqs, token_to_pair)
    print(f"Update Token to Pair took {time.time() - t0:.2f}s")

    print("Start Merging...")
    t0 = time.time()
    num_merges = vocab_size - 256 - len(special_tokens)
    for i in tqdm(range(num_merges), desc="Merging"):
        ### find most frequent pair
        if not pair_freqs:
            break
        best_pair = get_most_frequent_pair(pair_freqs)

        new_id = len(vocab)

        ## update vocab
        update_vocab(new_id, best_pair, vocab)

        ## update vocab_inverse
        update_vocab_inverse(new_id, best_pair, vocab_inverse)

        ## update merges
        update_merges(best_pair, merges)

        update_all(best_pair, pair_to_token, token_to_pair, token_freqs, pair_freqs)
    print(f"Merging took {time.time() - t0:.2f}s")
    print(f"Total Training took {time.time() - start_total:.2f}s")

    return vocab, merges

In [32]:
test_string = "abere ererea<|endoftext|>When and where is not as important as who and what. Hi, I am the the Ivan.<|endoftext|> aaa"
# test_string = "abere ererea"
input_path = ""
vocab_size = 260
special_tokens = ['<|endoftext|>']

In [33]:
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, test_string)

text_segment = segments[0]
text_segment

'abere ererea'

In [34]:
vocab = {i:bytes([i]) for i in range(256)}
vocab_inverse = {v:k for k,v in vocab.items()}
# Add special tokens
for st in special_tokens:
    new_id = len(vocab)
    st_bytes = st.encode("utf-8")
    vocab[new_id] = st_bytes
    vocab_inverse[st_bytes] = new_id

merges = []

In [35]:
token_freqs = defaultdict(int)
update_token_freqs(token_freqs,text_segment)
token_freqs

defaultdict(int,
            {(b'a', b'b', b'e', b'r', b'e'): 1,
             (b' ', b'e', b'r', b'e', b'r', b'e', b'a'): 1})

In [36]:
idx_to_token = {i: k for i, k in enumerate(token_freqs.keys())}
token_to_idx = {k: i for i, k in enumerate(token_freqs.keys())}
token_to_idx

{(b'a', b'b', b'e', b'r', b'e'): 0,
 (b' ', b'e', b'r', b'e', b'r', b'e', b'a'): 1}

In [37]:
pair_freqs = defaultdict(int)
update_pair_freqs(pair_freqs,token_freqs)
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 3,
             (b'r', b'e'): 3,
             (b' ', b'e'): 1,
             (b'e', b'a'): 1})

In [38]:
pair_to_idx = defaultdict(set)
update_pair2idx(pair_to_idx, token_freqs, token_to_idx)
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): {0, 1},
             (b'r', b'e'): {0, 1},
             (b' ', b'e'): {1},
             (b'e', b'a'): {1}})

In [39]:
best_pair = get_most_frequent_pair(pair_freqs)
best_pair

(b'r', b'e')

In [40]:
new_id = len(vocab)

## update vocab
update_vocab(new_id, best_pair, vocab)

## update vocab_inverse
update_vocab_inverse(new_id, best_pair, vocab_inverse)

## update merges
update_merges(best_pair, merges)

In [41]:
affected_idxs = list(pair_to_idx[best_pair])
merged_bytes = best_pair[0] + best_pair[1]
for idx in affected_idxs:
    token = idx_to_token[idx]

    i=0
    new_token = []
    while(i<(len(token))):
        if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
            new_token.append(merged_bytes)
            i = i + 2
        else:
            new_token.append(token[i])
            i = i + 1
    new_token = tuple(new_token)

    new_pairs = get_pairs(new_token)
    affected_pairs = get_pairs(token)
    for p in set(new_pairs)-set(affected_pairs):
        pair_to_idx[p].add(idx)
    for p in set(affected_pairs)-set(new_pairs):
        pair_to_idx[p].discard(idx)


    for pair in affected_pairs:
        pair_freqs[pair] -= token_freqs[token]
    for pair in new_pairs:
        pair_freqs[pair] += token_freqs[token]

    origin_freq = token_freqs.pop(token)
    token_freqs[new_token] = origin_freq

    idx_to_token[idx] = new_token

In [43]:
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): set(),
             (b'r', b'e'): set(),
             (b' ', b'e'): {1},
             (b'e', b'a'): set(),
             (b'e', b're'): {0, 1},
             (b're', b'a'): {1},
             (b're', b're'): {1}})

In [44]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 0,
             (b'r', b'e'): 0,
             (b' ', b'e'): 1,
             (b'e', b'a'): 0,
             (b'e', b're'): 2,
             (b're', b're'): 1,
             (b're', b'a'): 1})

In [45]:
token_freqs

defaultdict(int,
            {(b'a', b'b', b'e', b're'): 1,
             (b' ', b'e', b're', b're', b'a'): 1})

In [46]:
idx_to_token

{0: (b'a', b'b', b'e', b're'), 1: (b' ', b'e', b're', b're', b'a')}

In [16]:
affected_idxs = list(pair_to_idx[best_pair])
merged_bytes = best_pair[0] + best_pair[1]
idx = affected_idxs[0]
token = idx_to_token[idx]
print(token)

(b'a', b'b', b'e', b'r', b'e')


In [17]:
i=0
new_token = []
while(i<(len(token))):
    if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
        new_token.append(merged_bytes)
        i = i + 2
    else:
        new_token.append(token[i])
        i = i + 1
new_token = tuple(new_token)
new_token

(b'a', b'b', b'e', b're')

In [28]:
new_pairs = get_pairs(new_token)
affected_pairs = get_pairs(token)
for p in set(new_pairs)-set(affected_pairs):
    pair_to_idx[p].add(idx)
for p in set(affected_pairs)-set(new_pairs):
    pair_to_idx[p].discard(idx)
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): {1},
             (b'r', b'e'): {1},
             (b' ', b'e'): {1},
             (b'e', b'a'): {1},
             (b'e', b're'): {0}})

In [30]:
for pair in affected_pairs:
    pair_freqs[pair] -= token_freqs[token]
for pair in new_pairs:
    pair_freqs[pair] += token_freqs[token]
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 2,
             (b'r', b'e'): 2,
             (b' ', b'e'): 1,
             (b'e', b'a'): 1,
             (b'e', b're'): 1})

In [None]:
## update token_freqs
origin_freq = token_freqs.pop(token)
token_freqs[new_token] = origin_freq

{(b'e', b'r'), (b'r', b'e')}

In [None]:
## update idx2token
idx_to_token[idx] = new_token

defaultdict(int,
            {(b'a', b'b', b'e', b'r', b'e'): 1,
             (b' ', b'e', b'r', b'e', b'r', b'e', b'a'): 1})

- [] vocab
- [] vocab_inverse 
- [] merges 
- [] tokens
- [] token_freqs
- [] pair_freqs
- [] pair_to_token
- [] token_to_pair

In [10]:
new_id = 256

In [11]:
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, test_string)
text_segment = segments[0]
text_segment

'abere ererea'

In [14]:
token_freqs = defaultdict(int)
pair_freqs = defaultdict(int)
pair_to_token = defaultdict(set)
token_to_pair = defaultdict(list)

## Initiate token_freqs
token_freqs = update_token_freqs(text_segment, token_freqs)

# Initiate pair_freqs, pair_to_token, and token_to_pair
pair_freqs = update_pair_freqs(token_freqs, pair_freqs)
pair_to_token = update_pair2token(token_freqs, pair_to_token)
token_to_pair = update_token2pair(token_freqs, token_to_pair)


### find most frequent pair
best_pair = get_most_frequent_pair(pair_freqs)

## update vocab
vocab = update_vocab(new_id, best_pair, vocab)

## update vocab_inverse
vocab_inverse = update_vocab_inverse(new_id, best_pair, vocab_inverse)

## update merges
merges = update_merges(best_pair, merges)


pair_freqs, pair_to_token, token_to_pair, token_freqs = update_all(best_pair, pair_freqs, pair_to_token, token_to_pair, token_freqs)


In [15]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 0,
             (b'r', b'e'): 0,
             (b' ', b'e'): 1,
             (b'e', b'a'): 0,
             (b'e', b're'): 2,
             (b're', b're'): 1,
             (b're', b'a'): 1})

In [13]:
pair_to_token

defaultdict(set,
            {(b'a', b'b'): {(b'a', b'b', b'e', b'r', b'e')},
             (b'b', b'e'): {(b'a', b'b', b'e', b'r', b'e')},
             (b'e', b'r'): {(b' ', b'e', b'r', b'e', b'r', b'e', b'a'),
              (b'a', b'b', b'e', b'r', b'e')},
             (b'r', b'e'): {(b' ', b'e', b'r', b'e', b'r', b'e', b'a'),
              (b'a', b'b', b'e', b'r', b'e')},
             (b' ', b'e'): {(b' ', b'e', b'r', b'e', b'r', b'e', b'a')},
             (b'e', b'a'): {(b' ', b'e', b'r', b'e', b'r', b'e', b'a')}})

In [70]:
affected_tokens = list(pair_to_token[best_pair])
merged_bytes = best_pair[0] + best_pair[1]
token = affected_tokens[0]
token

IndexError: list index out of range

In [25]:
i=0
new_token = []
while(i<(len(token))):
    if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
        new_token.append(merged_bytes)
        i = i + 2
    else:
        new_token.append(token[i])
        i = i + 1
new_token = tuple(new_token)


## update pair_to_token
new_pairs = [(new_token[i], new_token[i+1]) for i in range(len(new_token)-1)]
affected_pairs = token_to_pair[token]
for pair in affected_pairs:
    pair_to_token[pair].discard(token)
for pair in new_pairs:
    pair_to_token[pair].add(new_token)


## update token_to_pair
token_to_pair.pop(token)
token_to_pair[new_token] = new_pairs

In [27]:
new_pairs

[(b'a', b'b'), (b'b', b'e'), (b'e', b're')]

In [None]:
affected_tokens = list(pair_to_token[best_pair])
merged_bytes = best_pair[0] + best_pair[1]

for token in list(affected_tokens):
    # get new token
    i=0
    new_token = []
    while(i<(len(token))):
        if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
            new_token.append(merged_bytes)
            i = i + 2
        else:
            new_token.append(token[i])
            i = i + 1
    new_token = tuple(new_token)


    ## update pair_to_token
    new_pairs = [(new_token[i], new_token[i+1]) for i in range(len(new_token)-1)]
    affected_pairs = token_to_pair[token]
    for pair in affected_pairs:
        pair_to_token[pair].discard(token)
    for pair in new_pairs:
        pair_to_token[pair].add(new_token)


    ## update token_to_pair
    token_to_pair.pop(token)
    token_to_pair[new_token] = new_pairs


    ## update pair_freqs
    for pair in affected_pairs:
        pair_freqs[pair] -= token_freqs[token]
    for pair in new_pairs:
        pair_freqs[pair] += token_freqs[token]


    ## update token_freqs
    origin_freq = token_freqs.pop(token)
    token_freqs[new_token] = origin_freq
    

RuntimeError: Set changed size during iteration

In [126]:
token_freqs

defaultdict(int,
            {(b'a', b'b', b'e', b're'): 1,
             (b' ', b'e', b're', b're', b'a'): 1})

In [None]:
## update affected_tokens
affected_tokens = pair_to_token[best_pair]
merged_bytes = best_pair[0] + best_pair[1]

token = (next(iter(affected_tokens)))
token = (b' ', b'e', b'r', b'e', b'r', b'e', b'a')
token

(b' ', b'e', b'r', b'e', b'r', b'e', b'a')

In [None]:
i=0
new_token = []
while(i<(len(token))):
    if (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
        new_token.append(merged_bytes)
        i = i + 2
    else:
        new_token.append(token[i])
        i = i + 1
new_token = tuple(new_token)
new_token

(b' ', b'e', b're', b're', b'a')

In [68]:
## update pair_to_token
new_pairs = [(new_token[i], new_token[i+1]) for i in range(len(new_token)-1)]
affected_pairs = token_to_pair[token]
for pair in affected_pairs:
    pair_to_token[pair].discard(token)
for pair in new_pairs:
    pair_to_token[pair].add(new_token)
pair_to_token

defaultdict(set,
            {(b'a', b'b'): {(b'a', b'b', b'e', b'r', b'e')},
             (b'b', b'e'): {(b'a', b'b', b'e', b'r', b'e')},
             (b'e', b'r'): {(b'a', b'b', b'e', b'r', b'e')},
             (b'r', b'e'): {(b'a', b'b', b'e', b'r', b'e')},
             (b' ', b'e'): {(b' ', b'e', b're', b're', b'a')},
             (b'e', b'a'): set(),
             (b'e', b're'): {(b' ', b'e', b're', b're', b'a')},
             (b're', b're'): {(b' ', b'e', b're', b're', b'a')},
             (b're', b'a'): {(b' ', b'e', b're', b're', b'a')}})

In [69]:
pair_to_token[pair]

{(b' ', b'e', b're', b're', b'a')}

In [70]:
## update token_to_pair
token_to_pair.pop(token)
token_to_pair[new_token] = new_pairs
token_to_pair

defaultdict(set,
            {(b'a', b'b', b'e', b'r', b'e'): {(b'a', b'b'),
              (b'b', b'e'),
              (b'e', b'r'),
              (b'r', b'e')},
             (b' ', b'e', b're', b're', b'a'): [(b' ', b'e'),
              (b'e', b're'),
              (b're', b're'),
              (b're', b'a')]})

In [71]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 3,
             (b'r', b'e'): 3,
             (b' ', b'e'): 1,
             (b'e', b'a'): 1})

In [72]:
affected_pairs = [(token[i], token[i+1]) for i in range(len(token)-1)]
new_pairs = [(new_token[i], new_token[i+1]) for i in range(len(new_token)-1)]

In [75]:
## update pair_freqs
for pair in affected_pairs:
    pair_freqs[pair] -= token_freqs[token]
for pair in new_pairs:
    pair_freqs[pair] += token_freqs[token]
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 1,
             (b'r', b'e'): 1,
             (b' ', b'e'): 1,
             (b'e', b'a'): 0,
             (b'e', b're'): 1,
             (b're', b're'): 1,
             (b're', b'a'): 1})

In [18]:
token_freqs

defaultdict(int, {(b'a', b'b', b'e', b'r', b'e'): 1, (b' ',): 1})

In [19]:
## update token_freqs
origin_freq = token_freqs.pop(token)
token_freqs[new_token] = origin_freq
token_freqs

defaultdict(int, {(b' ',): 1, (b'a', b'b', b'e', b're'): 1})

[[(b'e', b're')], 1]

In [16]:
pair_freqs[best_pair] -= duplications_count*token_freqs[new_token]
for n in new_neighbors:
    pair_freqs[n] += token_freqs[new_token]
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 5,
             (b'r', b'e'): 4,
             (b' ', b'e'): 1,
             (b'e', b'a'): 1,
             (b'e', b're'): 1})

In [None]:
## update pair_to_token and token_to_pair
pair_to_token[best_pair].discard(token)
for n in new_neighbors:
    pair_to_token[n].add(new_token)
print(pair_to_token)

defaultdict(<class 'set'>, {(b'a', b'b'): {(b'a', b'b', b'e', b'r', b'e')}, (b'b', b'e'): {(b'a', b'b', b'e', b'r', b'e')}, (b'e', b'r'): {(b'a', b'b', b'e', b'r', b'e'), (b' ', b'e', b'r', b'e', b'r', b'e', b'r', b'e', b'r', b'e', b'a')}, (b'r', b'e'): {(b'a', b'b', b'e', b'r', b'e')}, (b' ', b'e'): {(b' ', b'e', b'r', b'e', b'r', b'e', b'r', b'e', b'r', b'e', b'a')}, (b'e', b'a'): {(b' ', b'e', b'r', b'e', b'r', b'e', b'r', b'e', b'r', b'e', b'a')}, (b'e', b're'): {(b' ', b'e', b're', b're', b're', b're', b'a')}, (b're', b're'): {(b' ', b'e', b're', b're', b're', b're', b'a')}, (b're', b'a'): {(b' ', b'e', b're', b're', b're', b're', b'a')}})


In [17]:
token_to_pair

defaultdict(set,
            {(b'a', b'b', b'e', b'r', b'e'): {(b'a', b'b'),
              (b'b', b'e'),
              (b'e', b'r'),
              (b'r', b'e')},
             (b' ',
              b'e',
              b'r',
              b'e',
              b'r',
              b'e',
              b'r',
              b'e',
              b'r',
              b'e',
              b'a'): {(b' ', b'e'), (b'e', b'a'), (b'e', b'r'), (b'r', b'e')}})

In [13]:
new_token

[b'a', b'b', b'e', b're']

In [23]:
best_pair

(b'r', b'e')

In [None]:
## update pair_freqs and pair_to_token
affected_tokens = pair_to_token[best_pair]
merged_bytes = best_pair[0]+best_pair[1]
for token in affected_tokens:
    new_neighbors = []
    i=0
    duplications_count = 0
    while(i<len(token)-len(merged_bytes)+1):
        if token[i:(i+len(merged_bytes))] == merged_bytes:
            duplications_count += 1
            if i != 0:
                new_neighbors.append((bytes([token[i-1]]), merged_bytes))
            if i + len(merged_bytes) != len(token):
                new_neighbors.append((merged_bytes, bytes([token[i + len(merged_bytes)]])))
            i += len(merged_bytes)
        else:
            i += 1
    pair_freqs[best_pair] -= duplications_count*word_freqs[token]
    for pair in new_neighbors:
        pair_freqs[pair] += word_freqs[token]
        pair_to_token[pair].add(token)


0
1
2
4
6
8
0
1
2
3


defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 5,
             (b'r', b'e'): 0,
             (b' ', b'e'): 1,
             (b'e', b're'): 5,
             (b're', b'r'): 3})

In [None]:
affected_token = next(iter(affected_tokens))

duplications_count = 0
while(i<len(affected_token)-len(merged_bytes)+1):
    print(i)
    if affected_token[i:(i+len(merged_bytes))] == merged_bytes:
        duplications_count += 1
        if i != 0:
            new_neighbors.append((bytes([affected_token[i-1]]), merged_bytes))
        if i + len(merged_bytes) != len(affected_token):
            new_neighbors.append((merged_bytes, bytes([affected_token[i + len(merged_bytes)]])))
        i += len(merged_bytes)
    else:
        i += 1
new_neighbors

{b' erererere', b'abere'}

In [41]:
duplications_count

4

In [None]:
pair_freqs[best_pair] -= duplications_count*word_freqs[affected_token]
for pair in new_neighbors:
    pair_freqs[pair] += word_freqs[affected_token]
    pair_to_token[pair].add(affected_token)


In [None]:
pair_to_token

defaultdict(set,
            {(b'a', b'b'): {b'abere'},
             (b'b', b'e'): {b'abere'},
             (b'e', b'r'): {b' erererere', b'abere'},
             (b'r', b'e'): {b' erererere', b'abere'},
             (b' ', b'e'): {b' erererere'},
             (b'e', b're'): {b' erererere'},
             (b're', b'r'): {b' erererere'}})

In [44]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 5,
             (b'r', b'e'): 1,
             (b' ', b'e'): 1,
             (b'e', b're'): 4,
             (b're', b'r'): 3})