# Calculate Token Frequency

In [None]:
from datasets import load_from_disk
import sentencepiece as spm
import os
from collections import Counter
import json

sp = spm.SentencePieceProcessor()
sp.Load('Sentencepiece/zhtw-bpe-tok.model')

corpus = load_from_disk('corpus_cleaned')
os.makedirs("token_freq", exist_ok=True)
def count_tokens(examples, indicies):
    token_list = sp.encode_as_pieces(examples['text']) # encode_as_pieces returns List[List[str]]
    token_freq = Counter()
    for tokens in token_list:
        token_freq.update(tokens)

    with open(f"token_freq/{indicies[0]}.json", 'w') as f:
        json.dump(token_freq, f, indent=4, ensure_ascii=False)

corpus.map(count_tokens, batched=True, num_proc=16, batch_size=10_000, with_indices=True)

In [1]:
from collections import Counter
import os
import json

# load all token_freq json files
token_freq = Counter()
for file in os.listdir('token_freq'):
    if not file.endswith('.json'):
        print(f"Skipping {file}")
        continue
    with open(f'token_freq/{file}', 'r') as f:
        token_freq.update(json.load(f))

print(f"Total number of tokens: {sum(token_freq.values())}")

Total number of tokens: 3124575304


In [2]:
import plotly.graph_objects as go
import numpy as np

token_freq_sorted = sorted(token_freq.items(), key=lambda x: x[1], reverse=True)

cumulative_freq = np.cumsum([freq for token, freq in token_freq_sorted])

cumulative_percent = cumulative_freq / cumulative_freq[-1] * 100

fig = go.Figure(data=go.Scatter(x=list(range(len(token_freq_sorted))), y=cumulative_percent))
fig.update_layout(title='Cumulative Distribution of Token Frequencies',
                   xaxis_title='Token Index',
                   yaxis_title='Cumulative Percentage (%)')

fig.show()

# Prune Tokenizer

In [4]:
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from tqdm.auto import tqdm


TARGET_COVERAGE = 95.0
vocab_size = np.argmax(cumulative_percent > TARGET_COVERAGE) + 1

def is_special_token(token):
    return ((token.startswith('<') and token.endswith('>') and len(token) > 2) or
            (token.startswith('[') and token.endswith(']') and len(token) > 2))


chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load("./Sentencepiece/zhtw-bpe-tok.model")

chinese_spm = sp_pb2_model.ModelProto()
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())

print(f"Target Vocab size: {vocab_size}, Current Vocab size: {len(chinese_spm.pieces)}")

keep = set()
for idx, piece in enumerate(token_freq.keys()):
    if idx > vocab_size:
        break
    else:
        keep.add(piece)

print(f"Keeping special tokens: ", end="")
for piece in chinese_spm.pieces:
    if is_special_token(piece.piece):
        keep.add(piece.piece)
        print(piece.piece, end=", ")
print()

for piece in tqdm(chinese_spm.pieces[:]):
    if piece.piece in keep:
        continue
    else:
        chinese_spm.pieces.remove(piece)

with open("Sentencepiece/zhtw-bpe-tok-95.model", 'wb') as f:
    f.write(chinese_spm.SerializeToString())

print(f"Saved with new Vocab size: {len(chinese_spm.pieces)}")

Target Vocab size: 31072, Current Vocab size: 64000
Keeping special tokens: <unk>, <s>, </s>, <0x00>, <0x01>, <0x02>, <0x03>, <0x04>, <0x05>, <0x06>, <0x07>, <0x08>, <0x09>, <0x0A>, <0x0B>, <0x0C>, <0x0D>, <0x0E>, <0x0F>, <0x10>, <0x11>, <0x12>, <0x13>, <0x14>, <0x15>, <0x16>, <0x17>, <0x18>, <0x19>, <0x1A>, <0x1B>, <0x1C>, <0x1D>, <0x1E>, <0x1F>, <0x20>, <0x21>, <0x22>, <0x23>, <0x24>, <0x25>, <0x26>, <0x27>, <0x28>, <0x29>, <0x2A>, <0x2B>, <0x2C>, <0x2D>, <0x2E>, <0x2F>, <0x30>, <0x31>, <0x32>, <0x33>, <0x34>, <0x35>, <0x36>, <0x37>, <0x38>, <0x39>, <0x3A>, <0x3B>, <0x3C>, <0x3D>, <0x3E>, <0x3F>, <0x40>, <0x41>, <0x42>, <0x43>, <0x44>, <0x45>, <0x46>, <0x47>, <0x48>, <0x49>, <0x4A>, <0x4B>, <0x4C>, <0x4D>, <0x4E>, <0x4F>, <0x50>, <0x51>, <0x52>, <0x53>, <0x54>, <0x55>, <0x56>, <0x57>, <0x58>, <0x59>, <0x5A>, <0x5B>, <0x5C>, <0x5D>, <0x5E>, <0x5F>, <0x60>, <0x61>, <0x62>, <0x63>, <0x64>, <0x65>, <0x66>, <0x67>, <0x68>, <0x69>, <0x6A>, <0x6B>, <0x6C>, <0x6D>, <0x6E>, <0x6F>, <0x70>, <0

100%|██████████| 64000/64000 [03:19<00:00, 321.06it/s] 

Saved with new Vocab size: 31261





# Merge Tokenizer with Llama

In [3]:
import argparse
import os

import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers import GemmaTokenizer, LlamaTokenizer

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

llama_tokenizer_dir = "microsoft/Phi-3-mini-4k-instruct"
output_dir = "Phi-3-mini-4k-instruct-zhtw-tokenizer"

if "gemma" in llama_tokenizer_dir:
    TOK_CLASS = GemmaTokenizer
else:
    TOK_CLASS = LlamaTokenizer

llama_tokenizer_dir = llama_tokenizer_dir
chinese_sp_model_file = "Sentencepiece/zhtw-bpe-tok-95.model"

# load
llama_tokenizer = TOK_CLASS.from_pretrained(
    llama_tokenizer_dir, legacy=True)
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load(chinese_sp_model_file)

llama_spm = sp_pb2_model.ModelProto()
llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
chinese_spm = sp_pb2_model.ModelProto()
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())

# print number of tokens
print(len(llama_tokenizer), len(chinese_sp_model))
print(llama_tokenizer.all_special_tokens)
print(llama_tokenizer.all_special_ids)
print(llama_tokenizer.special_tokens_map)

# Add Chinese tokens to LLaMA tokenizer
llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
print(len(llama_spm_tokens_set))
print(f"Before:{len(llama_spm_tokens_set)}")
for p in chinese_spm.pieces:
    piece = p.piece
    if piece not in llama_spm_tokens_set:
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        llama_spm.pieces.append(new_p)
print(f"New model pieces: {len(llama_spm.pieces)}")

# Save
with open('chinese_llama_tmp.model', 'wb') as f:
    f.write(llama_spm.SerializeToString())
tokenizer = TOK_CLASS(vocab_file="chinese_llama_tmp.model")
os.remove('chinese_llama_tmp.model')

tokenizer.save_pretrained(output_dir)
print(f"Chinese-LLaMA tokenizer has been saved to {output_dir}")


# Test
llama_tokenizer = TOK_CLASS.from_pretrained(
    llama_tokenizer_dir, legacy=True)
chinese_llama_tokenizer = TOK_CLASS.from_pretrained(output_dir)
print(tokenizer.all_special_tokens)
print(tokenizer.all_special_ids)
print(tokenizer.special_tokens_map)
text = '''蔡英文，中華民國政治人物、法學家與律師，民主進步黨籍，現任中華民國總統。她曾擔任民主進步黨主席、行政院副院長、立法委員、大陸委員會主任委員、國家安全會議諮詢委員等職。
Tsai Ing-wen is a Taiwanese politician who has been serving as the 7th president of the Republic of China (Taiwan) since 2016.'''
print("##########\nTesting:\n", text)
print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
print(
    f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


32011 31261
['<s>', '<|endoftext|>', '<unk>']
[1, 32000, 0]
{'bos_token': '<s>', 'eos_token': '<|endoftext|>', 'unk_token': '<unk>', 'pad_token': '<|endoftext|>'}
32000
Before:32000
New model pieces: 61758
Chinese-LLaMA tokenizer has been saved to Phi-3-mini-4k-instruct-zhtw-tokenizer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


['<s>', '</s>', '<unk>']
[1, 2, 0]
{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}
##########
Testing:
 蔡英文，中華民國政治人物、法學家與律師，民主進步黨籍，現任中華民國總統。她曾擔任民主進步黨主席、行政院副院長、立法委員、大陸委員會主任委員、國家安全會議諮詢委員等職。
Tsai Ing-wen is a Taiwanese politician who has been serving as the 7th president of the Republic of China (Taiwan) since 2016.
Tokenized by LLaMA tokenizer:['▁', '<0xE8>', '<0x94>', '<0xA1>', '英', '文', '，', '中', '華', '民', '國', '政', '治', '人', '物', '、', '法', '學', '家', '<0xE8>', '<0x88>', '<0x87>', '<0xE5>', '<0xBE>', '<0x8B>', '師', '，', '民', '主', '進', '<0xE6>', '<0xAD>', '<0xA5>', '<0xE9>', '<0xBB>', '<0xA8>', '<0xE7>', '<0xB1>', '<0x8D>', '，', '現', '任', '中', '華', '民', '國', '<0xE7>', '<0xB8>', '<0xBD>', '<0xE7>', '<0xB5>', '<0xB1>', '。', '<0xE5>', '<0xA5>', '<0xB9>', '<0xE6>', '<0x9B>', '<0xBE>', '<0xE6>', '<0x93>', '<0x94>', '任', '民', '主', '進', '<0xE6>', '<0xAD>', '<0xA5>', '<0xE9>', '<0xBB>', '<0xA8>', '主', '<0xE5>', '<0xB8>', '<0xAD>', '、', '行', '政', '院', '<0xE5>', '<0x89>', '<0xAF>',