# Homework 2 - Solution

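Implement the `train`, `encode`, and `decode` methods of the tokenizer below.
The tokenizer should pre-split text with the GPT-2 split pattern (GPT2_SPLIT_PATTERN) and treat the special token <|endoftext|> appropriately: it must be encoded as its single reserved id, and never split into bytes or merged with neighboring text.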

In [1]:
import regex as re


class Tokenizer:
    """Byte-level BPE tokenizer using the GPT-2 pre-tokenization split pattern.

    Token ids 0..255 are the raw bytes, ids 256..vocab_size-1 are learned by
    ``train``, and the special token ``<|endoftext|>`` is reserved the id
    ``vocab_size`` (so ``vocab_size + 1`` distinct ids exist in total).
    """

    def __init__(self, vocab_size):
        """Create an untrained tokenizer.

        Args:
            vocab_size: number of non-special ids; must be at least 256 (one
                per byte). ``vocab_size - 256`` merges are learned by train.
        """
        assert vocab_size >= 256
        # default: vocab of 256 (all bytes) plus the special token, no merges
        self.merges = {}  # (int, int) -> int; a lower id means higher merge priority
        # GPT2_SPLIT_PATTERN; the \p{...} classes need the `regex` module
        self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        self.special_tokens = {'<|endoftext|>': vocab_size}
        self.vocab_size = vocab_size
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, verbose=False):
        """Learn up to ``vocab_size - 256`` BPE merges from ``text``.

        Previously learned merges are discarded first, so calling train twice
        does not mix two vocabularies. Occurrences of special tokens are cut
        out of the text before training, so no merge is ever learned inside
        or across them. If no adjacent pair is left to merge (e.g. very short
        text), training stops early instead of raising.

        Args:
            text: training corpus.
            verbose: if True, print every learned merge.
        """
        num_merges = self.vocab_size - 256

        # start from a clean slate: raw bytes + special tokens only
        self.merges = {}
        self.vocab = self._build_vocab()

        # one list of byte ids per pre-token; merges never cross chunk borders
        tokens = [list(chunk.encode("utf-8")) for chunk in self._split_for_training(text)]

        for i in range(num_merges):
            # count every consecutive pair across all chunks
            counts = {}
            for splits in tokens:
                self._get_frequencies(splits, counts)
            if not counts:
                # every chunk is already a single token: nothing left to merge
                if verbose:
                    print(f"stopping early after {i} merges: no more pairs to merge")
                break
            # most frequent pair; ties go to the pair counted first
            pair = max(counts, key=counts.get)
            # assign a new token to the next available id
            idx = 256 + i
            tokens = [self._merge(splits, pair, idx) for splits in tokens]
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {counts[pair]} occurrences")

    def encode(self, text):
        """Encode ``text`` into a list of token ids.

        Unlike ``_encode_ordinary``, this handles special tokens: every exact
        occurrence of one (e.g. ``<|endoftext|>``) becomes its single reserved
        id instead of being split and BPE-encoded as ordinary text.
        """
        special = self.special_tokens
        if not special:
            # an empty alternation "()" matches the empty string and would
            # split the text between characters, defeating all merges
            return self._encode_ordinary(text)
        # the capturing group makes re.split keep the special tokens themselves
        special_chunks = re.split("(" + self._special_pattern() + ")", text)
        ids = []
        for part in special_chunks:
            if part in special:
                ids.append(special[part])
            else:
                ids.extend(self._encode_ordinary(part))
        return ids

    def decode(self, ids):
        """Convert a list of token ids back into a string.

        Byte sequences that are not valid UTF-8 are decoded with replacement
        characters (``errors="replace"``) rather than raising.

        Raises:
            ValueError: if an id is not in the vocabulary.
        """
        part_bytes = []
        for idx in ids:
            if idx not in self.vocab:
                raise ValueError(f"invalid token id: {idx}")
            part_bytes.append(self.vocab[idx])
        return b"".join(part_bytes).decode("utf-8", errors="replace")

    def _build_vocab(self):
        """Return the id -> bytes table derived from bytes, merges and special tokens."""
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # merges are inserted in id order, so both parents always exist already
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def _special_pattern(self):
        """Return a regex alternation matching any special token, longest first."""
        # longest first so that a token which is a prefix of another cannot win
        ordered = sorted(self.special_tokens, key=len, reverse=True)
        return "|".join(re.escape(k) for k in ordered)

    def _split_for_training(self, text):
        """Split ``text`` into GPT-2 pre-tokens, dropping special-token occurrences."""
        if self.special_tokens:
            pieces = re.split(self._special_pattern(), text)
        else:
            pieces = [text]
        return [chunk for piece in pieces for chunk in re.findall(self.pattern, piece)]

    def _get_frequencies(self, seq, counts=None):
        """Count adjacent pairs of ``seq``.

        Accumulates into ``counts`` in place when given; returns the counts dict.
        """
        counts = {} if counts is None else counts
        for pair in zip(seq, seq[1:]):  # iterate over consecutive elements
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, seq, pair, index):
        """Return a new list where each occurrence of ``pair`` in ``seq`` is
        replaced by ``index``, scanning left to right without overlaps."""
        first, second = pair
        last = len(seq) - 1
        new_seq = []
        i = 0
        while i < len(seq):
            if i < last and seq[i] == first and seq[i + 1] == second:
                new_seq.append(index)
                i += 2
            else:
                new_seq.append(seq[i])
                i += 1
        return new_seq

    def _encode_chunk(self, text_bytes):
        """BPE-encode one pre-token (``bytes``) into token ids using learned merges."""
        # start from one id per byte (0..255)
        ids = list(text_bytes)
        while len(ids) >= 2:
            # the pair with the lowest merge id is the one learned earliest
            stats = self._get_frequencies(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # if no pair is mergeable every key is inf and min() just returns
            # the first pair; the membership check detects that terminal case
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            ids = self._merge(ids, pair, idx)
        return ids

    def _encode_ordinary(self, text):
        """Encode ``text`` ignoring special tokens (they are treated as plain text)."""
        # split into pre-tokens by the GPT-2 pattern; each is encoded separately
        text_chunks = re.findall(self.pattern, text)
        ids = []
        for chunk in text_chunks:
            ids.extend(self._encode_chunk(chunk.encode("utf-8")))
        return ids

In [2]:
# sample corpus (Wikipedia's definition of BPE) used to train the demo tokenizer
training_text = (
    "Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, "
    "first described in 1994 by Philip Gage for encoding strings of text into "
    "tabular form for use in downstream modeling."
)

In [3]:
# 256 byte ids + 14 learned merges; <|endoftext|> is reserved id 270
tokenizer = Tokenizer(vocab_size=270)

tokenizer.train(training_text, verbose=True)

merge 1/14: (105, 110) -> 256 (b'in') had 8 occurrences
merge 2/14: (256, 103) -> 257 (b'ing') had 5 occurrences
merge 3/14: (111, 100) -> 258 (b'od') had 4 occurrences
merge 4/14: (111, 114) -> 259 (b'or') had 4 occurrences
merge 5/14: (32, 102) -> 260 (b' f') had 4 occurrences
merge 6/14: (99, 258) -> 261 (b'cod') had 3 occurrences
merge 7/14: (261, 257) -> 262 (b'coding') had 3 occurrences
merge 8/14: (32, 97) -> 263 (b' a') had 3 occurrences
merge 9/14: (32, 100) -> 264 (b' d') had 3 occurrences
merge 10/14: (115, 116) -> 265 (b'st') had 3 occurrences
merge 11/14: (32, 256) -> 266 (b' in') had 3 occurrences
merge 12/14: (260, 259) -> 267 (b' for') had 3 occurrences
merge 13/14: (116, 101) -> 268 (b'te') had 2 occurrences
merge 14/14: (105, 114) -> 269 (b'ir') had 2 occurrences


In [4]:
# entries after the 256 byte ids: the special token, then the learned merges
{idx: token_bytes for idx, token_bytes in list(tokenizer.vocab.items())[256:]}

{270: b'<|endoftext|>',
 256: b'in',
 257: b'ing',
 258: b'od',
 259: b'or',
 260: b' f',
 261: b'cod',
 262: b'coding',
 263: b' a',
 264: b' d',
 265: b'st',
 266: b' in',
 267: b' for',
 268: b'te',
 269: b'ir'}

In [5]:
# round trip: decoding the encoded training text should reproduce it exactly
tokenizer.decode(tokenizer.encode(training_text))

'Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.'