In [1]:
%cd "/gscratch/xlab/alisaliu/hack-tokenizers"

/mmfs1/gscratch/xlab/alisaliu/hack-tokenizers


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [1]:
from utils import read_tokenizer_json, ensure_dir
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict, Counter
from tokenizers import Tokenizer
import seaborn as sns
import numpy as np

In [2]:
tokenizer_dir = Path('tokenizer_json/llm_tokenizers')

In [3]:
model_name = 'qwen'
tokenizer_json = read_tokenizer_json(tokenizer_dir / f'{model_name}_tokenizer.json')

In [4]:
# check if merges are unique
rhs = []
for m in tokenizer_json['merges']:
    rhs.append(''.join(m.split(' ')))
print(f'All merges unique: {len(rhs) == len(set(rhs))}')
print(f'Fraction of non-redundant merges: {len(set(rhs))/len(rhs)}')

All merges unique: True
Fraction of non-redundant merges: 1.0


# If all merges are unique (no cleaning needed)

In [5]:
merges = tokenizer_json['merges']
ensure_dir(f'experiments/llm_tokenizers/{model_name}')
with open(f'experiments/llm_tokenizers/{model_name}/merges.txt', 'w') as fin:
    fin.write('#version: 0.2\n')
    fin.writelines('\n'.join(merges))

# Else: clean redundant merges

In [6]:
token_to_rank = {}
for i, m in tqdm(enumerate(tokenizer_json['merges'])):
    merged_token = ''.join(m.split(' '))
    token_to_rank[merged_token] = i

61249it [00:00, 1051021.27it/s]


In [7]:
token_to_merges = defaultdict(list)
for m in tokenizer_json['merges']:
    if model_name == 'gemma':
        if any([all([x == y for x in ''.join(m.split(' '))]) for y in ['\t', '\n', '▁']]):
            continue
    token = ''.join(m.split(' '))
    token_to_merges[token].append(m)

In [8]:
token_to_merges

defaultdict(list,
            {'▁t': ['▁ t'],
             'er': ['e r'],
             'in': ['i n'],
             '▁a': ['▁ a'],
             'en': ['e n'],
             'on': ['o n'],
             '▁th': ['▁t h', '▁ th'],
             'es': ['e s'],
             '▁s': ['▁ s'],
             '▁d': ['▁ d'],
             'at': ['a t'],
             'or': ['o r'],
             'an': ['a n'],
             '▁c': ['▁ c'],
             'is': ['i s'],
             're': ['r e'],
             'it': ['i t'],
             '▁the': ['▁t he', '▁th e', '▁ the'],
             'ar': ['a r'],
             'le': ['l e'],
             '▁w': ['▁ w'],
             '▁p': ['▁ p'],
             'ou': ['o u'],
             'al': ['a l'],
             '▁f': ['▁ f'],
             '▁m': ['▁ m'],
             'ed': ['e d'],
             '▁o': ['▁ o'],
             '▁b': ['▁ b'],
             'om': ['o m'],
             'ion': ['io n', 'i on'],
             'ing': ['in g', 'i ng'],
             'ic': ['i c'],
      

In [9]:
def get_merge_ranks(token, token_to_used_merge):
    """
    Given a token and the dictionary token_to_used_merge which contains the last merge for forming each token,
    return a list containing all merge ranks used in the formation of token.
    """
    if token not in token_to_rank:
        return [-1]
    if token not in token_to_used_merge:  # this token hasn't even been formed by prior merge rules yet!
        return [float('inf')]
    l, r = token_to_used_merge[token].split(' ')
    ranks = get_merge_ranks(l, token_to_used_merge) + get_merge_ranks(r, token_to_used_merge)
    ranks.append(token_to_rank[token])
    return ranks

In [10]:
def clean_up_symmetric_merges(rank):
    """
    Break ties by putting the longer token on the left.
    """
    token_pairs = []
    to_remove = []
    for m in rank:
        l, r = m.split(' ')
        unordered_pair = set([l, r])
        if unordered_pair in token_pairs:
            if len(l) > len(r):
                to_remove.append(f'{r} {l}')
            elif len(l) < len(r):
                to_remove.append(f'{l} {r}')
        token_pairs.append(set([l, r]))
    for m in to_remove:
        rank.pop(m)
    return rank

In [11]:
token_to_used_merge = {}
for token, group in token_to_merges.items():
    if len(group) == 1:
        token_to_used_merge[token] = group[0]
    else:
        rank = {}
        for m in group:
            l, r = m.split(' ')
            rank[m] = get_merge_ranks(l, token_to_used_merge) + get_merge_ranks(r, token_to_used_merge)
        rank = clean_up_symmetric_merges(rank)
        rank = {k: sorted([m for m in v if m != -1]) for k, v in rank.items()}
        token_to_used_merge[token] = min(rank, key=rank.get)

In [12]:
cleaned_merges = []
for token in sorted(token_to_rank, key=token_to_rank.get):
    if token in token_to_used_merge:
        used_merge = token_to_used_merge[token]
        if model_name in ['llama', 'gemma', 'mixtral']:
            used_merge = used_merge.replace('\r', '\\r')
        cleaned_merges.append(used_merge)

In [13]:
cleaned_merges

['▁ t',
 'e r',
 'i n',
 '▁ a',
 'e n',
 'o n',
 '▁t h',
 'e s',
 '▁ s',
 '▁ d',
 'a t',
 'o r',
 'a n',
 '▁ c',
 'i s',
 'r e',
 'i t',
 '▁th e',
 'a r',
 'l e',
 '▁ w',
 '▁ p',
 'o u',
 'a l',
 '▁ f',
 '▁ m',
 'e d',
 '▁ o',
 '▁ b',
 'o m',
 'i on',
 'in g',
 'i c',
 'a s',
 'e l',
 'en t',
 '▁ in',
 '▁ h',
 'n d',
 'e t',
 '▁ l',
 '▁ n',
 's t',
 '▁t o',
 'c h',
 '▁ I',
 'r o',
 'i l',
 '▁o f',
 'd e',
 'c t',
 '▁ (',
 'a m',
 '▁ C',
 '▁d e',
 '▁ S',
 '▁ u',
 '▁ A',
 '▁ \\',
 '▁ e',
 '▁a nd',
 '▁ T',
 'o l',
 '▁ v',
 'i m',
 'o t',
 'a d',
 'u t',
 '▁ g',
 'e m',
 'u r',
 'i d',
 '▁ *',
 'i g',
 'r a',
 '▁ re',
 '▁ is',
 'q u',
 'o w',
 '▁ M',
 'es t',
 '▁ y',
 's e',
 'v e',
 'c e',
 'i e',
 'u n',
 '▁ P',
 '▁ B',
 'a g',
 'u l',
 '▁ =',
 'h e',
 'en d',
 'o de',
 't er',
 'm ent',
 'o s',
 '▁ D',
 'i f',
 'at ion',
 '▁f or',
 '▁ r',
 '▁ L',
 '▁y ou',
 '▁b e',
 'l y',
 'v er',
 'a b',
 't e',
 '▁ it',
 '▁ on',
 'r i',
 'u s',
 '▁ "',
 '▁w h',
 '▁c on',
 '▁ H',
 '▁s t',
 'i r',
 '▁ 

In [16]:
ensure_dir(f'data/llm_tokenizers/{model_name}')
with open(f'data/llm_tokenizers/{model_name}/merges.txt', 'w') as fin:
    fin.write('#version: 0.2\n')
    fin.writelines([m + '\n' for m in cleaned_merges])