In [None]:
def create_matrix(rows, cols):
    """Return a ``rows`` x ``cols`` grid (list of lists) filled with zeros.

    Every row is a distinct list object, so mutating one row never
    affects another.
    """
    grid = []

    for _ in range(rows):
        # Build each row separately so rows do not share storage.
        row = [0 for _ in range(cols)]
        grid.append(row)

    return grid

# Quick check: prints a 2x2 grid of zeros.
print(create_matrix(rows=2, cols=2))

In [None]:
def create_matrix(rows, cols):
    """Return a ``rows`` x ``cols`` grid (list of lists) filled with zeros.

    Each row is built independently, so rows never alias one another.
    """
    return [[0 for _ in range(cols)] for _ in range(rows)]

# Quick check: prints a 2x2 grid of zeros.
print(create_matrix(rows=2, cols=2))

In [None]:
from collections import Counter

def _count_pairs(tokens):
    pair_counts = Counter()

    for pair in zip(tokens, tokens[1:]):
        pair_counts[pair] += 1

    return pair_counts

def _merge_pair(tokens, target_pair, combined_id):
    merged_tokens = []

    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == target_pair:
            merged_tokens.append(combined_id)
            i += 2
        else:
            merged_tokens.append(tokens[i])
            i += 1

    return merged_tokens

# Walk through two byte-pair merges by hand on a small sample string.
prompt = 'rug pug hug pun bun hugs run gun bug'

# UTF-8 encode the text and view it as a list of integer byte values.
string_bytes = list(prompt.encode('utf-8'))
print('String ASCII Values (bytes):', string_bytes)

# Base vocabulary: every possible byte value maps to its one-byte string.
vocab = {idx: bytes([idx]) for idx in range(256)}
print('Initial Vocab:', vocab)

# Frequency of each adjacent byte pair in the sample.
pairs = _count_pairs(string_bytes)
print('Pairs and Counts:', pairs)

# (117, 103) is the byte pair for "ug". The new id 257 is picked by hand
# and is not added to `vocab` here.
merged_string = _merge_pair(string_bytes, (117, 103), 257)
print('After First Merge:', merged_string)

# (117, 110) is the byte pair for "un"; this merge is applied on top of
# the result of the first one.
merged_string = _merge_pair(merged_string, (117, 110), 258)
print('After Second Merge:', merged_string)

In [None]:
class Cat:
    """A cat with a coat color that can print a few sounds."""

    def __init__(self, color):
        # The color is stored as given; no validation is performed.
        self.color = color

    def _say(self, text):
        """Write ``text`` to standard output on its own line."""
        print(text)

    def purr(self):
        """Print the purring sound."""
        self._say('puurrrrrr')

    def meow(self):
        """Print the meowing sound."""
        self._say('meow')

    def get_color(self):
        """Print (not return) this cat's color."""
        self._say(self.color)

In [None]:
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.

Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
"""

class Simple_Tokenizer():
    """Minimal byte-level Byte Pair Encoding (BPE) tokenizer.

    The base vocabulary is the 256 possible byte values plus two special
    tokens (``<|sos|>`` = 256, ``<|eos|>`` = 257). ``train`` learns merge
    rules from a text, ``encode`` turns a string into token ids and
    ``decode`` turns token ids back into a string.
    """

    def __init__(self):
        self.merges = {}  # (int, int) -> int, in the order the merges were learned

        # str -> int, e.g. {'<|endoftext|>': 100257}
        self.special_tokens = {'<|sos|>': 256, '<|eos|>': 257}

        # int -> bytes: one entry per raw byte, plus the special tokens' text.
        self.vocab = {idx: bytes([idx]) for idx in range(256)} | {
            idx: special.encode("utf-8") for special, idx in self.special_tokens.items()
        }

    def _count_pairs(self, tokens, counts=None):
        """Count occurrences of each adjacent pair in ``tokens``.

        Args:
            tokens: Sequence of token ids.
            counts: Optional existing mapping ``(int, int) -> int`` to
                update in place. When omitted a fresh ``Counter`` is used.

        Returns:
            The mapping of pair -> count (the object passed as ``counts``
            if one was given, otherwise a new ``Counter``).
        """
        pair_counts = Counter() if counts is None else counts

        for pair in zip(tokens, tokens[1:]):
            # ``get`` keeps this working when the caller passes a plain dict.
            pair_counts[pair] = pair_counts.get(pair, 0) + 1

        return pair_counts

    def _merge_pair(self, tokens, target_pair, combined_id):
        """Return a copy of ``tokens`` with ``target_pair`` replaced by ``combined_id``.

        Matches are found left to right and do not overlap.
        """
        merged_tokens = []

        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == target_pair:
                merged_tokens.append(combined_id)
                i += 2
            else:
                merged_tokens.append(tokens[i])
                i += 1

        return merged_tokens

    def train(self, text, max_vocab_size):
        """Learn merge rules from ``text`` until the vocabulary is full.

        Repeatedly merges the most frequent adjacent pair, recording the
        rule in ``self.merges`` and the merged bytes in ``self.vocab``.

        Args:
            text: Training text.
            max_vocab_size: Target total vocabulary size. No merges are
                learned if it does not exceed the current vocabulary size.
        """
        vocab_size = len(self.vocab)
        num_merges = max_vocab_size - vocab_size

        tokens = list(text.encode('utf-8'))

        for i in range(num_merges):
            pairs = self._count_pairs(tokens)
            if not pairs:
                # Fewer than two tokens remain: nothing left to merge.
                break

            # Most frequent pair; ties go to the pair seen first.
            pair = max(pairs, key=pairs.get)

            merged_token_id = vocab_size + i

            # Carry the merged sequence into the next round so that later
            # merges can build on earlier ones.
            tokens = self._merge_pair(tokens, pair, merged_token_id)

            self.merges[pair] = merged_token_id
            self.vocab[merged_token_id] = self.vocab[pair[0]] + self.vocab[pair[1]]

    def decode(self, ids):
        """Convert a list of token ids back into a Python string.

        Bytes that do not form valid UTF-8 are replaced with U+FFFD.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        """Convert ``text`` into a list of token ids.

        Special tokens found in the text map directly to their ids; all
        other text is UTF-8 encoded and compressed with the learned
        merges, applied in the order they were learned.
        """
        import re

        if self.special_tokens:
            # Longest first, so a special token that is a prefix of another
            # cannot shadow it in the alternation.
            specials = sorted(self.special_tokens, key=len, reverse=True)
            special_pattern = "(" + "|".join(re.escape(k) for k in specials) + ")"
            # The capturing group makes re.split keep the special tokens.
            special_chunks = re.split(special_pattern, text)
        else:
            special_chunks = [text]

        encoded_str = []
        for chunk in special_chunks:
            if not chunk:
                # re.split yields empty strings around leading, trailing
                # or adjacent special tokens.
                continue

            if chunk in self.special_tokens:
                # this is a special token, encode it separately as a special case
                encoded_str.append(self.special_tokens[chunk])
                continue

            # ordinary text: start from the raw bytes (integers in 0..255)
            chunk_ids = list(chunk.encode("utf-8"))

            while len(chunk_ids) >= 2:
                # find the pair with the lowest merge index
                counted_pairs = self._count_pairs(chunk_ids)
                earliest_pair = min(counted_pairs, key=lambda p: self.merges.get(p, float("inf")))

                # If no pair is mergeable, every key is inf and min() just
                # returns the first pair arbitrarily; a membership check
                # detects that terminating case.
                if earliest_pair not in self.merges:
                    break  # nothing else can be merged anymore

                # otherwise merge the best pair (lowest merge index)
                pair_idx = self.merges[earliest_pair]
                chunk_ids = self._merge_pair(chunk_ids, earliest_pair, pair_idx)

            encoded_str.extend(chunk_ids)

        return encoded_str