In [None]:
import torch
from transformers import MT5ForConditionalGeneration, T5Tokenizer
from tqdm.auto import tqdm, trange

In [None]:
# Multilingual base checkpoint whose vocabulary is pruned further down.
base_checkpoint = "google/mt5-small"
tokenizer = T5Tokenizer.from_pretrained(base_checkpoint)
model = MT5ForConditionalGeneration.from_pretrained(base_checkpoint)

In [None]:
def msize(m):
    """Return the number of scalar parameters held by torch module ``m``."""
    return sum(map(lambda tensor: tensor.numel(), m.parameters()))

# Share of all mT5 parameters that sit in the shared input embedding,
# then in the output (LM-head) projection.
n_params_total = msize(model)
for component in (model.shared, model.lm_head):
    print(msize(component) / n_params_total)

In [None]:
from transformers import T5ForConditionalGeneration

# For comparison: the same parameter ratios for the English-only t5-small checkpoint.
# It is bound to its own name on purpose -- rebinding `model` here would make every
# later cell prune t5-small (32K vocab) instead of mT5, and the mT5 ids in `kept_ids`
# would then index past the end of its embedding matrix.
# `msize` is reused from the previous cell rather than redefined.
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

print(msize(t5_model.shared) / msize(t5_model))
print(msize(t5_model.lm_head) / msize(t5_model))

## Selecting the vocabulary

In [None]:
# Run configuration.
# NAME_CORPUS tags every file/directory written by this notebook.
NAME_CORPUS = 'Spanish_T5'
# Directory of corpus text files -- presumably a placeholder to edit before running.
path_spanish_medical_texts = 'my_path'

In [None]:
from collections import Counter
import os

# Frequency of every original mT5 token id over the Spanish corpus.
cnt_spa = Counter()

path_selected = path_spanish_medical_texts

# Only regular files (a sub-directory would make open() fail), in sorted order:
# Counter.most_common() breaks frequency ties by insertion order, so an arbitrary
# os.listdir() order would make the vocabulary selected later non-reproducible.
corpus_files = sorted(
    name for name in os.listdir(path_selected)
    if os.path.isfile(os.path.join(path_selected, name))
)

for file in tqdm(corpus_files, desc='tokenizing corpus'):
    # Explicit UTF-8: the platform default encoding may not decode accented text.
    with open(os.path.join(path_selected, file), 'r', encoding='utf-8') as ftext:
        text = ftext.read()

    cnt_spa.update(tokenizer.encode(text))

# Distinct ids observed, and their fraction of the full original vocabulary.
print(len(cnt_spa), len(cnt_spa)/tokenizer.vocab_size)

In [None]:
# Size of the original (unpruned) tokenizer vocabulary.
total_vocab_size = tokenizer.vocab_size
print("Total vocab", total_vocab_size)

In [None]:
# Fraction of all corpus tokens covered by the `top` most frequent ids.
n_corpus_tokens = sum(cnt_spa.values())
for top in (10_000, 20_000, 30_000):
    covered = sum(freq for _, freq in cnt_spa.most_common(top))
    print(top, covered / n_corpus_tokens)

In [None]:
# Eyeball the 30 most frequent corpus tokens together with their counts.
top_30 = cnt_spa.most_common(30)
for token_id, count in top_30:
    print(tokenizer.decode(token_id), count)

In [None]:
# Write the top-30K ids as TSV lines: "<id>\t<decoded token>\t<count>".
tsv_lines = (
    f"{token_id}\t{tokenizer.decode(token_id)}\t{count}\n"
    for token_id, count in cnt_spa.most_common(30_000)
)
s = "".join(tsv_lines)

out_name = '30_000_ESP_' + NAME_CORPUS + '_T5small.txt'
with open(out_name, 'w', encoding='utf-8') as fp:
    fp.write(s)

In [None]:
# Vocabulary budget for the pruned model.
N_KEEP_ORIGINAL = 1_000  # first 1K ids of the original tokenizer, kept "just in case"
N_SPECIAL = 100          # the last 100 ids of the original vocabulary (T5 special tokens)
TARGET_VOCAB = 30_000    # upper bound on the final number of ids

new_tokens = set(range(N_KEEP_ORIGINAL))
spanish_budget = TARGET_VOCAB - N_SPECIAL  # room left once the special ids are added

# Fill with the most frequent Spanish ids until the budget is reached
# (set.add ignores ids already present from the first block).
for i, (k, _) in enumerate(cnt_spa.most_common(TARGET_VOCAB)):
    if len(new_tokens) >= spanish_budget:
        print(i, 'Spanish tokens are included')
        break
    new_tokens.add(k)

# Previously silent: the corpus may not contain enough distinct ids to fill the budget.
if len(new_tokens) < spanish_budget:
    print(f'WARNING: only {len(new_tokens)} ids selected before the special tokens; '
          f'the corpus has too few distinct tokens to reach {spanish_budget}')

# The 100 special tokens that T5 uses
new_tokens.update(range(tokenizer.vocab_size - N_SPECIAL, tokenizer.vocab_size))

print(len(new_tokens))
# Sorted so new id i corresponds to old id kept_ids[i]; ids 0..N_KEEP_ORIGINAL-1 keep their value.
kept_ids = sorted(new_tokens)

## Updating the model

In [None]:
new_size = len(kept_ids)

# Guard: kept_ids are ids of the *original* mT5 vocabulary. If `model` was rebound
# to a different checkpoint, fail loudly instead of with an opaque indexing error.
assert max(kept_ids) < model.shared.num_embeddings, (
    f"kept_ids reaches id {max(kept_ids)} but the model has only "
    f"{model.shared.num_embeddings} embedding rows"
)

# Row i of the new matrices is row kept_ids[i] of the old ones (vectorized gather).
index = torch.tensor(kept_ids, dtype=torch.long, device=model.shared.weight.device)
with torch.no_grad():
    new_emb = torch.nn.Embedding(new_size, model.shared.embedding_dim)
    new_emb.weight.copy_(model.shared.weight.index_select(0, index))

    new_head = torch.nn.Linear(in_features=model.lm_head.in_features, out_features=new_size, bias=False)
    new_head.weight.copy_(model.lm_head.weight.index_select(0, index))

# Swap whole modules through the official setters: these also update the encoder/decoder
# embedding references, and keep num_embeddings/out_features consistent with the new size
# (assigning only `.weight` left those attributes at the old vocabulary size).
model.set_input_embeddings(new_emb)
model.set_output_embeddings(new_head)

model.config.vocab_size = new_size
model.config.name_or_path = NAME_CORPUS + '/es5-small'

## Updating the tokenizer

In [None]:
import copy

# NOTE(review): recent sentencepiece releases also ship this module as
# `sentencepiece.sentencepiece_model_pb2`; confirm which one this environment uses.
import sentencepiece_model_pb2 as spmp

# Parse the original SentencePiece model into its editable protobuf form.
sp_proto = spmp.ModelProto()
sp_proto.ParseFromString(tokenizer.sp_model.serialized_model_proto())
print('the loaded model has pieces:', len(sp_proto.pieces))

# Copy (rather than alias) the kept pieces in kept_ids order, so new id i is old id
# kept_ids[i] -- the same mapping used to slice the embedding and LM-head rows.
new_pieces = [copy.deepcopy(sp_proto.pieces[idx]) for idx in kept_ids]
print('the new pieces:', len(new_pieces))

# Replace the whole piece list with the kept pieces (everything else is dropped).
del sp_proto.pieces[:]
sp_proto.pieces.extend(new_pieces)

print(len(sp_proto.pieces))
assert len(sp_proto.pieces) == len(kept_ids), "pruned piece count does not match kept_ids"

with open(NAME_CORPUS + '_es5-small.model', 'wb') as f:
    f.write(sp_proto.SerializeToString())

In [None]:
# Rebuild the tokenizer from the pruned SentencePiece file.
# extra_ids=0: kept_ids already includes the last 100 ids of the original vocabulary
# (the special/sentinel slots), so the pruned file holds exactly len(kept_ids) pieces,
# which is the resized model's vocab size. extra_ids=100 would append 100 more ids
# past the end of the new embedding matrix.
new_tokenizer = T5Tokenizer(NAME_CORPUS + '_es5-small.model', extra_ids=0)
assert new_tokenizer.vocab_size <= model.config.vocab_size, (
    f"tokenizer has {new_tokenizer.vocab_size} ids but the model only "
    f"{model.config.vocab_size} embedding rows"
)

# Save tokenizer and model into the same directory so one
# `from_pretrained(output_dir)` call can load both.
output_dir = 'espt5-' + NAME_CORPUS + 'small'
new_tokenizer.save_pretrained(output_dir)
model.save_pretrained(output_dir)