## Filtering the vocabulary of the mT5 SentencePiece tokenizer

In [None]:
from transformers import AutoTokenizer, T5TokenizerFast

In [None]:
# Load the original multilingual mT5 tokenizer; its SentencePiece model is what gets trimmed below.
model_id = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    legacy=False,
    force_download=True,  # always fetch a fresh copy instead of reusing the local cache
)

Load the previously saved German dataset from disk and tokenize it.

In [None]:
from datasets import load_from_disk

german_ds = load_from_disk("german_ds")

def convert_to_tokens(source):
    """Tokenize the ``text`` field of a dataset example or batch of examples.

    ``source["text"]`` may be a single string (unbatched map) or a list of
    strings (batched map); the tokenizer accepts both and returns the
    corresponding ``input_ids`` / ``attention_mask`` fields.
    """
    return tokenizer(source["text"])

# batched=True hands the fast tokenizer whole batches instead of one example at a
# time, which is much faster; the resulting columns and values are the same.
tokenized_ds = german_ds.map(convert_to_tokens, batched=True, remove_columns=['text'])

In [None]:
# Gather the token ids of every example from both splits into one list of per-example lists.
# list(...) materializes each column explicitly: newer `datasets` versions may return a lazy
# column object for ds["col"] instead of a plain list, and that object has no .extend().
input_ids = list(tokenized_ds["train"]["input_ids"]) + list(tokenized_ds["test"]["input_ids"])

Flatten the per-example input ids from both splits into a single list of token ids.

In [None]:
import itertools

full_vocabulary = tokenizer.get_vocab()
# Flatten the per-example id lists into one flat list of token ids (duplicates kept).
token_list = list(itertools.chain.from_iterable(input_ids))

In [None]:
# Total number of token occurrences across both splits (duplicates included).
len(token_list)

Set the target vocabulary size and order the token ids by frequency (most frequent first).

In [None]:
import collections

target_vocab_size = 32000 # T5 vocabulary size; number of regular (non-reserved) ids to keep
counts = collections.Counter(token_list)
# Unique token ids, most frequent first. most_common() breaks ties by first occurrence
# in token_list, which is exactly the order that "sort every occurrence by frequency,
# then deduplicate" produces -- but without sorting the whole token stream.
sorted_list = [token_id for token_id, _ in counts.most_common()]

In [None]:
# Show the token string of the most frequent token id.
tokenizer.convert_ids_to_tokens(sorted_list[0])

Remove duplicate ids so that the result is a list of unique token ids sorted by frequency. The reserved ids 0–258 are excluded here; they are always kept and are added back later.

In [None]:
# Deduplicate while keeping frequency order (dict keys preserve insertion order), and drop
# the reserved ids 0-258: those are always kept and are added back separately later.
filtered_sorted_list = [token_id for token_id in dict.fromkeys(sorted_list) if token_id > 258]

Keep the 32,000 most frequent token ids from the list.

In [None]:
# Keep the `target_vocab_size` most frequent regular token ids. The reserved ids 0-258
# are kept in addition to these, so the final vocabulary is slightly larger.
filtered_vocab_ids = filtered_sorted_list[:target_vocab_size]

Load the tokenizer's SentencePiece model, which is needed to look up the piece that corresponds to each token id.

In [None]:
from sentencepiece import sentencepiece_model_pb2 as sp_model

def load_spm_protopub(path=None):
    """Parse a serialized SentencePiece model file into a ``ModelProto``.

    Args:
        path: Path of the SentencePiece ``.model`` file. Defaults to
            ``tokenizer.vocab_file`` (the model of the loaded tokenizer).

    Returns:
        A freshly parsed ``sp_model.ModelProto``. Each call parses the file
        again, so callers may mutate the result without affecting later calls.
    """
    if path is None:
        path = tokenizer.vocab_file
    m = sp_model.ModelProto()
    # `with` guarantees the file handle is closed after reading.
    with open(path, 'rb') as f:
        m.ParseFromString(f.read())
    return m

m = load_spm_protopub()

# The first 259 pieces (ids 0-258) are reserved for special tokens;
# print a few more to see where the regular pieces start.
for i, piece in enumerate(m.pieces[:320]):
    print(i, piece.piece)

Find the SentencePiece pieces to keep: the reserved ids 0–258 plus the filtered token ids.

In [None]:
from tqdm import tqdm

# Set membership is O(1); testing `i in filtered_vocab_ids` against the list would
# scan up to `target_vocab_size` ids for every piece of the full model.
keep_ids = set(filtered_vocab_ids)

# Iterate instead of popping, so `m` is left intact and re-running this cell gives the
# same result. Pieces are collected in ascending id order.
kept_pieces = [
    piece
    for i, piece in enumerate(tqdm(m.pieces, total=len(m.pieces)))
    if i < 259 or i in keep_ids  # ids 0-258 are reserved and always kept
]

Reload the original model and replace its vocabulary with the kept pieces (piece strings and scores).

In [None]:
m = load_spm_protopub()

# Replace the model's vocabulary with the kept pieces. extend() copies each kept
# message whole (piece string, score and type), so the result no longer depends on the
# original piece at the same position happening to have the same type.
m.ClearField("pieces")
m.pieces.extend(kept_pieces)
print(len(m.pieces))
    

In [None]:
# Size of the trimmed vocabulary: reserved ids plus the kept frequent pieces.
len(m.pieces)

Save the trimmed SentencePiece model.

In [None]:
# Write the trimmed model in SentencePiece's binary protobuf format.
with open("spiece.model", mode="wb") as f:
    payload = m.SerializeToString()
    f.write(payload)

Save a file containing the ids of the tokens that remain in the vocabulary. It is needed later to adjust the embedding layer of a model that is trained with the trimmed vocabulary.

In [None]:
import json

# Every original id that survives the trim: the reserved ids 0-258 plus the selected
# frequent ids, in ascending order (new id k corresponds to original id kept_ids[k]).
kept_ids = sorted(set(range(259)) | set(filtered_vocab_ids))
print(len(kept_ids))
with open("kept_ids.json", 'w') as f:
    json.dump(kept_ids, f)

Construct a new tokenizer from the trimmed SentencePiece model using `T5TokenizerFast` and save it.

In [None]:
tokenizer_path = "filtered_tokenizer"
# Build a fast T5 tokenizer directly from the trimmed SentencePiece model.
# extra_ids=0 means no <extra_id_*> sentinel tokens are appended to the vocabulary.
# NOTE(review): span-corruption pre-training relies on sentinel tokens; confirm this is intended.
new_tokenizer = T5TokenizerFast(
    vocab_file="spiece.model",
    legacy=False,
    extra_ids=0,
)
new_tokenizer.save_pretrained(tokenizer_path)