## Start

In [1]:
import regex as re

from collections import defaultdict

In [2]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

## 1 Assignment Overview

## 2 Byte-Pair Encoding (BPE) Tokenizer

### 2.1 The Unicode Standard

In [2]:
[ord('牛'), chr(29275)]

[29275, '牛']

#### Problem (unicode1)

##### a

In [None]:
chr(0)

'\x00'

##### b

In [None]:
print(chr(0))

 


##### c

In [8]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [9]:
print("this is a test" + chr(0) + "string")

this is a test string


### 2.2 Unicode Encodings

In [10]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [11]:
print(type(utf8_encoded))

<class 'bytes'>


In [12]:
list(utf8_encoded)

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [13]:
[len(test_string),len(utf8_encoded)]

[13, 23]

In [14]:
print(utf8_encoded.decode("utf-8"))

hello! こんにちは!


#### Problem (unicode2)

##### a

##### b

In [15]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [16]:
decode_utf8_bytes_to_str_wrong("hello, 你好".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

##### c

In [18]:
sentence = "hello, 你好".encode("utf-8")
[bytes([b]) for b in sentence]

[b'h',
 b'e',
 b'l',
 b'l',
 b'o',
 b',',
 b' ',
 b'\xe4',
 b'\xbd',
 b'\xa0',
 b'\xe5',
 b'\xa5',
 b'\xbd']

In [24]:
(b'\xe4\xbd\xa0').decode("utf-8")

'你'

### 2.3 Subword Tokenization

### 2.4 BPE Tokenizer Training

In [27]:
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [28]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [31]:
test_string = """low low low low low lower lower widest widest widest newest newest newest newest newest newest"""

In [32]:
re.findall(PAT, test_string)

['low',
 ' low',
 ' low',
 ' low',
 ' low',
 ' lower',
 ' lower',
 ' widest',
 ' widest',
 ' widest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest']

### 2.5 Experimenting with BPE Tokenizer Training

In [3]:
test_string = "abere erererere<|endoftext|>When and where is not as important as who and what. Hi, I am the the Ivan.<|endoftext|> aaa"
input_path: str = ""
vocab_size: int = 260
special_tokens: list[str] = ['<|endoftext|>']

vocab: dict[int, bytes] = {i:bytes([i]) for i in range(256)}
vocab_inverse: dict[int, bytes] = {v:k for k,v in vocab.items()}
merges: list[tuple[bytes, bytes]] = []

- [x] vocab
- [x] vocab_inverse 
- [x] merges 
- [x] pair_freqs
- [x] pair_to_tokens

In [4]:
new_id = 256

In [5]:
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, test_string)
text_segment = segments[0]
matches = re.finditer(PAT, text_segment)
word_freqs = defaultdict(int)
pair_freqs = defaultdict(int)
pair_to_tokens = defaultdict(set)
text_segment

'abere erererere'

In [6]:
for match in matches:
    token = match.group()
    token_bytes = token.encode("utf-8")
    word_freqs[token_bytes] += 1
    if len(token_bytes)>1:
        for i in range(1, len(token_bytes)):
            pair = (bytes([token_bytes[i-1]]), bytes([token_bytes[i]]))
            pair_freqs[pair] += 1
            pair_to_tokens[pair].add(token_bytes)
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 5,
             (b'r', b'e'): 5,
             (b' ', b'e'): 1})

In [7]:
### find most frequent pair
best_pair = max(pair_freqs.keys(), key=lambda k: (pair_freqs[k], k))
vocab[new_id] = best_pair
vocab_inverse[best_pair] = new_id
merges.append(best_pair)
best_pair

(b'r', b'e')

In [9]:
affected_tokens = pair_to_tokens[best_pair]
merged_bytes = best_pair[0] + best_pair[1]

token = (next(iter(affected_tokens)))
token

b'abere'

In [12]:
merged_bytes

b're'

In [20]:
new_neighbors = []
i=0
duplications_count = 0
new_token = []

In [None]:
while(i<(len(token)-1)):
    print(token[i])
    print(token[i+1])
    if (bytes([token[i]]) == best_pair[0]) and (bytes([token[i+1]]) == best_pair[1]):
        print('yes')
        new_token.append(merged_bytes)
        i = i + 2
        duplications_count += 1
    else:
        new_token.append(bytes([token[i]]))
        i = i + 1

97
98
98
101
101
114
114
101


[97, 98, 101, 114]

In [24]:
token

b'abere'

In [23]:
best_pair

(b'r', b'e')

In [None]:
## update pair_freqs and pair_to_tokens
affected_tokens = pair_to_tokens[best_pair]
merged_bytes = best_pair[0]+best_pair[1]
for token in affected_tokens:
    new_neighbors = []
    i=0
    duplications_count = 0
    while(i<len(token)-len(merged_bytes)+1):
        if token[i:(i+len(merged_bytes))] == merged_bytes:
            duplications_count += 1
            if i != 0:
                new_neighbors.append((bytes([token[i-1]]), merged_bytes))
            if i + len(merged_bytes) != len(token):
                new_neighbors.append((merged_bytes, bytes([token[i + len(merged_bytes)]])))
            i += len(merged_bytes)
        else:
            i += 1
    pair_freqs[best_pair] -= duplications_count*word_freqs[token]
    for pair in new_neighbors:
        pair_freqs[pair] += word_freqs[token]
        pair_to_tokens[pair].add(token)


0
1
2
4
6
8
0
1
2
3


defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 5,
             (b'r', b'e'): 0,
             (b' ', b'e'): 1,
             (b'e', b're'): 5,
             (b're', b'r'): 3})

In [None]:
affected_token = next(iter(affected_tokens))

duplications_count = 0
while(i<len(affected_token)-len(merged_bytes)+1):
    print(i)
    if affected_token[i:(i+len(merged_bytes))] == merged_bytes:
        duplications_count += 1
        if i != 0:
            new_neighbors.append((bytes([affected_token[i-1]]), merged_bytes))
        if i + len(merged_bytes) != len(affected_token):
            new_neighbors.append((merged_bytes, bytes([affected_token[i + len(merged_bytes)]])))
        i += len(merged_bytes)
    else:
        i += 1
new_neighbors

{b' erererere', b'abere'}

In [41]:
duplications_count

4

In [None]:
pair_freqs[best_pair] -= duplications_count*word_freqs[affected_token]
for pair in new_neighbors:
    pair_freqs[pair] += word_freqs[affected_token]
    pair_to_tokens[pair].add(affected_token)


In [None]:
pair_to_tokens

defaultdict(set,
            {(b'a', b'b'): {b'abere'},
             (b'b', b'e'): {b'abere'},
             (b'e', b'r'): {b' erererere', b'abere'},
             (b'r', b'e'): {b' erererere', b'abere'},
             (b' ', b'e'): {b' erererere'},
             (b'e', b're'): {b' erererere'},
             (b're', b'r'): {b' erererere'}})

In [44]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 5,
             (b'r', b'e'): 1,
             (b' ', b'e'): 1,
             (b'e', b're'): 4,
             (b're', b'r'): 3})