In [None]:
def count_chars(input_str):
    """Map each character of ``input_str`` to how many times it occurs.

    Keys appear in order of first occurrence.
    Example: 'abca' -> {'a': 2, 'b': 1, 'c': 1}
    """
    tally = {}
    for ch in input_str:
        if ch in tally:
            tally[ch] += 1
        else:
            tally[ch] = 1
    return tally


print(count_chars('hello there'))

range(0, 10)
<class 'range'>
5
6
7
8
9
10
11
12
13
14
5
6
7
8
9
10
11
12
13
14


In [None]:
def replace_pairs(lst, target_pair, combined_id):
    """Return a new list where each adjacent occurrence of ``target_pair`` is
    replaced by ``combined_id``.

    Matches are found left to right and do not overlap, so [1, 1, 1] with
    pair (1, 1) becomes [id, 1]. ``target_pair`` may be any two-element
    sequence (tuple or list).

    Example: replace_pairs([1, 2, 3, 1, 2], (1, 2), 4) -> [4, 3, 4]

    Raises:
        ValueError: if ``target_pair`` does not have exactly two items.
    """
    # Unpack once: avoids building a slice + tuple per element, and lets a
    # list pair match (a tuple never compares equal to a list).
    first, second = target_pair
    replaced_ids = []
    n = len(lst)

    i = 0
    while i < n:
        # only test the pair when a successor element exists
        if i + 1 < n and lst[i] == first and lst[i + 1] == second:
            replaced_ids.append(combined_id)
            i += 2
        else:
            replaced_ids.append(lst[i])
            i += 1

    return replaced_ids


print(replace_pairs([1, 2, 3, 1, 2], (1, 2), 4))

In [None]:
from collections import Counter

def count_pairs(nums):
    """Count how often each pair of adjacent elements occurs in ``nums``.

    ``nums`` must support slicing (e.g. a list). Returns a
    ``collections.Counter`` keyed by ``(a, b)`` tuples, so callers can use
    ``most_common()``. Pairs appear in order of first occurrence.

    Example: [1, 2, 3, 1, 2] -> Counter({(1, 2): 2, (2, 3): 1, (3, 1): 1})
    """
    # Counter tallies the zipped pairs directly; no manual .get(..) + 1 needed.
    return Counter(zip(nums, nums[1:]))

prompt = 'rug pug hug pun bun hugs run gun'

# Raw UTF-8 bytes of the prompt; ids 0..255 are the initial byte tokens.
string_bytes = list(prompt.encode('utf-8'))
vocab = {idx: bytes([idx]) for idx in range(256)}

print(string_bytes)

pairs = count_pairs(string_bytes)
print(pairs)

# Merge the single most frequent adjacent pair into a freshly minted token.
# The new id is the next free slot after the existing vocab (256) instead of
# a hard-coded constant, and the vocab records which bytes it stands for.
top_pair, _ = pairs.most_common(1)[0]
new_id = len(vocab)
vocab[new_id] = vocab[top_pair[0]] + vocab[top_pair[1]]

merged_string = replace_pairs(string_bytes, top_pair, new_id)
print(merged_string)

In [None]:
class Cat:
    """Toy class representing a cat with a color."""

    def __init__(self, color):
        # stored as given; no validation is performed
        self.color = color

    def purr(self):
        """Print a purring sound."""
        print('puurrrrrr')

    def meow(self):
        """Print a meow."""
        print('meow')

    def get_color(self):
        """Print the cat's color and return it.

        The value is still printed so existing console output is unchanged;
        returning it makes the getter usable in expressions.
        """
        print(self.color)
        return self.color

In [None]:
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.

Algorithmically follows the GPT-2 tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
"""

class Simple_Tokenizer:
    """Minimal byte-level BPE tokenizer with a small set of special tokens.

    Token ids 0..255 are raw bytes, followed by the special tokens, followed
    by merged tokens minted in ``train`` in increasing id order.
    """

    def __init__(self):
        self.merges = {} # (int, int) -> int

        self.special_tokens = {'<|sos|>': 256, '<|eos|>': 257} # str -> int, e.g. {'<|endoftext|>': 100257}
        # int -> bytes: every raw byte, then each special token's UTF-8 encoding
        self.vocab = {idx: bytes([idx]) for idx in range(256)} | {idx: special.encode("utf-8") for special, idx in self.special_tokens.items()}

    def _count_pairs(self, ids, counts=None):
        """
        Given a list of integers, return a dictionary of counts of consecutive pairs
        Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        Optionally allows to update an existing dictionary of counts
        """
        counts = {} if counts is None else counts

        for pair in zip(ids, ids[1:]): # iterate consecutive elements
            counts[pair] = counts.get(pair, 0) + 1

        return counts

    def _merge_pairs(self, ids, pair, idx):
        """
        In the list of integers (ids), replace all consecutive occurrences
        of pair with the new integer token idx
        Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
        """
        newids = []
        i = 0

        while i < len(ids):
            # if not at the very last position AND the pair matches, replace it
            if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1

        return newids

    def train(self, text, max_vocab_size):
        """Learn BPE merges from ``text`` until the vocab has ``max_vocab_size`` entries.

        Each iteration merges the most frequent adjacent pair into a new id
        (the next one after the current vocab). Stops early when fewer than
        two tokens remain, since there is then no pair left to merge. Does
        nothing if ``max_vocab_size`` is not larger than the current vocab.
        """
        vocab_size = len(self.vocab)
        num_merges = max_vocab_size - vocab_size

        # input text preprocessing: list of integers in range 0..255
        ids = list(text.encode("utf-8"))

        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = self._count_pairs(ids)
            if not stats:
                break  # nothing left to merge; max() would raise on an empty dict

            # find the pair with the highest count
            pair = max(stats, key=stats.get)

            # mint a new token: assign it the next available id
            idx = vocab_size + i

            # replace all occurrences of pair in ids with idx
            ids = self._merge_pairs(ids, pair, idx)

            # save the merge
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]

    def decode(self, ids):
        """Return the Python string for a list of token ids.

        Invalid UTF-8 sequences (e.g. a multi-byte character split across
        tokens) are replaced with U+FFFD rather than raising.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_ordinary(self, text):
        """Encode text that contains no special tokens by applying learned merges.

        At each step the adjacent pair with the lowest merge id (the
        earliest-learned merge) is merged, until no pair is mergeable.
        """
        chunk_ids = list(text.encode("utf-8")) # list of integers in range 0..255

        while len(chunk_ids) >= 2:
            counted_pairs = self._count_pairs(chunk_ids)
            earliest_pair = min(counted_pairs, key=lambda p: self.merges.get(p, float("inf")))

            # when no pair is a known merge, every key is inf and min() just
            # returns the first pair; detect that case by a membership check
            if earliest_pair not in self.merges:
                break # nothing else can be merged anymore

            chunk_ids = self._merge_pairs(chunk_ids, earliest_pair, self.merges[earliest_pair])

        return chunk_ids

    def encode(self, text):
        """Return the list of token ids for ``text``.

        Occurrences of special-token strings are emitted as their reserved
        ids; every other span is byte-encoded and merged independently, so
        no merge crosses a special-token boundary.
        """
        import re

        if self.special_tokens:
            # longest first so that a special token that is a prefix of
            # another does not shadow it in the alternation
            ordered = sorted(self.special_tokens, key=len, reverse=True)
            special_pattern = "(" + "|".join(re.escape(k) for k in ordered) + ")"
            # the capturing group keeps the special tokens in the split output
            special_chunks = re.split(special_pattern, text)
        else:
            # an empty alternation "()" would split between every character
            special_chunks = [text]

        encoded_ids = []
        for chunk in special_chunks:
            if not chunk:
                continue # re.split yields '' at the edges and between adjacent specials
            if chunk in self.special_tokens:
                # this is a special token, encode it separately as a special case
                encoded_ids.append(self.special_tokens[chunk])
            else:
                encoded_ids.extend(self._encode_ordinary(chunk))

        return encoded_ids

    def visualize_tokenization(self, ids):
        """Render ``ids`` as a colored, ' | '-separated string for debugging.

        Each token shows its decoded text (ANSI green) followed by its id
        (ANSI gray); a blank-line break is inserted after each special token.
        """
        GREEN = '\033[92m'
        RESET = '\033[0m'
        GRAY = '\033[90m'

        special_ids = set(self.special_tokens.values())

        tokens = []
        for token_id in ids:
            token_str = self.decode([token_id])
            tokens.append(f"{GREEN}{token_str}{GRAY}({token_id}){RESET}")

            if token_id in special_ids:
                tokens.append('\n\n\t')

        return ' | '.join(tokens)