# CLIP Tokenizer

> Helper functions for byte-pair-encoding (BPE) tokenization and a thin wrapper around the Hugging Face CLIP tokenizer.

In [None]:
#| default_exp clip.tokenizier

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import re
import os
import gzip
from functools import lru_cache
from collections import namedtuple

import transformers
from fastcore.test import test_eq

  from .autonotebook import tqdm as notebook_tqdm


##### How to build a CLIP tokenizer?
- Step 1: Define the vocabulary of symbols that the BPE tokenizer will use to represent the text
- Step 2: Take a string of text as input and split it into a list of symbols
- Step 3: Calculate the frequency of each symbol pair in the text
- Step 4: Sort the symbol pairs by frequency, with the most frequent pairs appearing first
- Step 5: Implement a loop that repeatedly merges the most frequent symbol pair until a stopping condition is reached. This could be a fixed number of merges, or it could be based on the frequency of the symbol pairs (e.g., stop when the frequency of the most frequent pair drops below a certain threshold)
- Step 6: As each symbol pair is merged, update the list of symbols and the symbol pair frequencies to reflect the changes.
- Step 7: When the loop is finished, the resulting list of symbols will be the BPE vocabulary.
- Step 8: Use the BPE vocabulary to encode text by replacing each symbol pair in the text with a single symbol from the vocabulary.



In [None]:
#| export
def split_text(text):
  """Split `text` into its whitespace-delimited chunks.

  Every maximal run of non-whitespace characters becomes one symbol, so
  punctuation stays attached to the word it touches (e.g. ``'dog.'``).

  Returns a list of strings; the list is empty when `text` contains no
  non-whitespace characters.
  """
  return re.findall(r'\S+', text)

In [None]:
# Demo: split a sample sentence into whitespace-delimited symbols and display them.
text = "The quick brown fox jumps over the lazy dog."
symbols = split_text(text)
symbols

['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog.']

In [None]:
#| hide
# Check that the sample sentence was split on whitespace only (punctuation stays attached: 'dog.').
test_eq(symbols, ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog.'])

In [None]:
#| export
def calculate_pair_frequencies(symbols=None):
    """Count how often each pair of adjacent symbols occurs.

    This exported definition used to be an empty placeholder, so the
    generated module shipped a function that did nothing.

    Args:
        symbols: Sliceable sequence of symbols (e.g. a list of strings).
            Defaults to None so that the historical zero-argument call
            keeps working.

    Returns:
        A dict mapping each ``(symbol, next_symbol)`` tuple to the number
        of times it occurs, in order of first appearance. Returns None
        when called without `symbols`, matching the old placeholder.
    """
    if symbols is None:
        # Preserve the behaviour of the former no-op stub.
        return None

    pair_frequencies = {}
    # zip pairs every symbol with its successor and stops at the shorter
    # operand, so empty and single-symbol inputs produce no pairs.
    for pair in zip(symbols, symbols[1:]):
        pair_frequencies[pair] = pair_frequencies.get(pair, 0) + 1
    return pair_frequencies

In [None]:
def calculate_pair_frequencies(symbols):
  """Tally every pair of neighbouring symbols in `symbols`.

  Returns a dict keyed by ``(symbol, following_symbol)`` tuples whose
  values are occurrence counts. Keys appear in order of first occurrence;
  inputs with fewer than two symbols give an empty dict.
  """
  counts = {}

  # Walk the sequence alongside a copy shifted by one position.
  for left, right in zip(symbols, symbols[1:]):
    key = (left, right)
    counts[key] = counts.get(key, 0) + 1

  return counts


In [None]:
# Demo: count adjacent symbol pairs in the sample sentence and print the resulting dict.
symbols = ['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog.']
pair_frequencies = calculate_pair_frequencies(symbols)
print(pair_frequencies)


{('The', 'quick'): 1, ('quick', 'brown'): 1, ('brown', 'fox'): 1, ('fox', 'jumps'): 1, ('jumps', 'over'): 1, ('over', 'the'): 1, ('the', 'lazy'): 1, ('lazy', 'dog.'): 1}


### CLIP Tokenizer

In [None]:
#| export
# Token sequence length constant.
# NOTE(review): presumably CLIP's fixed context length — confirm it stays in sync with the tokenizer's model_max_length.
TOKEN_LENGTH = 77

Temporarily, I am borrowing the CLIP tokenizer from tinygrad. I will implement it from scratch later.

In [None]:
#| export
@lru_cache()
def default_bpe():
  """Return the path to the gzipped BPE vocabulary file.

  The path is built from the directory containing this module, joined with
  ``../weights/bpe_simple_vocab_16e6.txt.gz``. The file's existence is not
  checked. The result is memoized by `lru_cache`.
  """
  module_dir = os.path.dirname(os.path.abspath(__file__))
  # Alternative kept from the original, based on the working directory instead:
  # module_dir = os.path.dirname(os.path.abspath("."))
  return os.path.join(module_dir, "../weights/bpe_simple_vocab_16e6.txt.gz")

In [None]:
#| export
def get_pairs(word):
  """Return the set of adjacent symbol pairs in a word.

  Args:
    word: Sliceable sequence of symbols, typically a tuple of
      variable-length strings.

  Returns:
    A set of ``(symbol, next_symbol)`` tuples, one per distinct adjacent
    pair. The set is empty when `word` has fewer than two symbols.
  """
  # zip stops at the shorter operand, so an empty word yields an empty set
  # rather than raising IndexError on word[0] as the manual loop did.
  return set(zip(word, word[1:]))

def whitespace_clean(text):
  """Collapse each run of whitespace in `text` to one space and trim both ends."""
  return re.sub(r'\s+', ' ', text).strip()

In [None]:
#| export
def bytes_to_unicode():
  """
  Build a lookup table from every byte value (0-255) to a one-character string.

  The reversible BPE codes operate on unicode strings, so each byte needs a
  character of its own. Bytes in the ranges ``!``-``~``, ``¡``-``¬`` and
  ``®``-``ÿ`` map to the character with the same code point. Every other
  byte is assigned the next unused code point starting at 256, which keeps
  whitespace and control characters out of the table's values.

  Returns a dict with 256 entries: the self-mapped bytes come first, then
  the remapped bytes in ascending order.
  """
  self_mapped = (
    list(range(ord("!"), ord("~") + 1))
    + list(range(ord("¡"), ord("¬") + 1))
    + list(range(ord("®"), ord("ÿ") + 1))
  )
  keep_as_is = set(self_mapped)

  byte_values = list(self_mapped)
  code_points = list(self_mapped)
  shifted = 0
  for byte in range(256):
    if byte in keep_as_is:
      continue
    # Remaining bytes get fresh code points above the single-byte range.
    byte_values.append(byte)
    code_points.append(256 + shifted)
    shifted += 1

  return {byte: chr(point) for byte, point in zip(byte_values, code_points)}

In [None]:
# #| export
# class ClipTokenizer:
#   def __init__(self, bpe_path: str = default_bpe()):
#     self.byte_encoder = bytes_to_unicode()
#     merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
#     merges = merges[1:49152-256-2+1]
#     merges = [tuple(merge.split()) for merge in merges]
#     vocab = list(bytes_to_unicode().values())
#     vocab = vocab + [v+'</w>' for v in vocab]
#     for merge in merges:
#       vocab.append(''.join(merge))
#     vocab.extend(['<|startoftext|>', '<|endoftext|>'])
#     self.encoder = dict(zip(vocab, range(len(vocab))))
#     self.bpe_ranks = dict(zip(merges, range(len(merges))))
#     self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
#     self.pat = self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

#   def bpe(self, token):
#     if token in self.cache:
#       return self.cache[token]
#     word = tuple(token[:-1]) + ( token[-1] + '</w>',)
#     pairs = get_pairs(word)

#     if not pairs:
#       return token+'</w>'

#     while True:
#       bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
#       if bigram not in self.bpe_ranks:
#         break
#       first, second = bigram
#       new_word = []
#       i = 0
#       while i < len(word):
#         try:
#           j = word.index(first, i)
#           new_word.extend(word[i:j])
#           i = j
#         except Exception:
#           new_word.extend(word[i:])
#           break

#         if word[i] == first and i < len(word)-1 and word[i+1] == second:
#           new_word.append(first+second)
#           i += 2
#         else:
#           new_word.append(word[i])
#           i += 1
#       new_word = tuple(new_word)
#       word = new_word
#       if len(word) == 1:
#         break
#       else:
#         pairs = get_pairs(word)
#     word = ' '.join(word)
#     self.cache[token] = word
#     return word

#   def encode(self, text):
#     bpe_tokens = []
#     text = whitespace_clean(text.strip()).lower()
#     for token in re.findall(self.pat, text):
#       token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
#       bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
#     # Truncation, keeping two slots for start and end tokens.
#     if len(bpe_tokens) > 75:
#       bpe_tokens = bpe_tokens[:75]
#     return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)

In [None]:
# tokenizier = ClipTokenizer()

In [None]:
# prompt = "persistence is all you need"
# tokenizier.encode(prompt)

In [None]:
#| export
class CLIPTokenizer:
    """Thin wrapper around the `transformers` CLIP tokenizer.

    The wrapped tokenizer is loaded from the
    ``openai/clip-vit-large-patch14`` pretrained checkpoint and stored on
    ``self.tokenizer``.
    """

    def __init__(self):
        """Load the pretrained tokenizer via `from_pretrained`."""
        self.tokenizer = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    @property
    def model_max_length(self):
        """The wrapped tokenizer's `model_max_length`."""
        return self.tokenizer.model_max_length

    def encode(self, prompt: str):
        """Tokenize `prompt`, padded and truncated to the maximum length.

        Calls the wrapped tokenizer with max-length padding, truncation
        enabled and ``return_tensors="pt"``, and returns its result as-is.
        """
        options = {
            "padding": "max_length",
            "max_length": self.tokenizer.model_max_length,
            "truncation": True,
            "return_tensors": "pt",
        }
        return self.tokenizer(prompt, **options)

    def __call__(self, *args, **kwargs):
        """Forward all arguments unchanged to the wrapped tokenizer."""
        return self.tokenizer(*args, **kwargs)

In [None]:
# Instantiate the wrapper; this loads the pretrained "openai/clip-vit-large-patch14" tokenizer.
tokenizier = CLIPTokenizer()

In [None]:
# Display the wrapped tokenizer's maximum sequence length.
tokenizier.model_max_length

77

In [None]:
# Encode a sample prompt with max-length padding and truncation.
output = tokenizier.encode("hello world")

In [None]:
# Inspect the shape of the encoded token ids.
output['input_ids'].shape

torch.Size([1, 77])