In [1]:
#!pip install sentencepiece
#!pip install transformers

In [2]:
import torch
from transformers import MBart50Tokenizer
from transformers import MBartForConditionalGeneration

In [3]:
import pandas as pd
import csv
from collections import Counter
from tqdm.auto import tqdm, trange

In [4]:
from sentencepiece import sentencepiece_model_pb2 as spmp

In [5]:
model_name = "facebook/mbart-large-50-many-to-many-mmt"

In [6]:
tokenizer = MBart50Tokenizer.from_pretrained(model_name)

In [7]:
tokenizer.save_pretrained('save_models/mbart-large-50-many-to-many-mmt')

('save_models/mbart-large-50-many-to-many-mmt/tokenizer_config.json',
 'save_models/mbart-large-50-many-to-many-mmt/special_tokens_map.json',
 'save_models/mbart-large-50-many-to-many-mmt/sentencepiece.bpe.model',
 'save_models/mbart-large-50-many-to-many-mmt/added_tokens.json')

In [8]:
model = MBartForConditionalGeneration.from_pretrained(model_name)

In [9]:
def msize(m):
    return sum(p.numel() for p in m.parameters())
print(msize(model.model.shared) / msize(model))   # 0.3298
print(msize(model.model.shared) / msize(model))   # 0.3298
print(msize(model.lm_head) / msize(model))  # 0.3298

0.41915844455396084
0.41915844455396084
0.41915844455396084


In [10]:
df_ru = pd.read_csv('/mnt/extendet_data/projects/translate-chinese/texts/src-train.txt', sep="WWWQweW", nrows=5000000, names=['text'], header=None)

  df_ru = pd.read_csv('/mnt/extendet_data/projects/translate-chinese/texts/src-train.txt', sep="WWWQweW", nrows=5000000, names=['text'], header=None)


In [11]:
df_ru

Unnamed: 0,text
0,Кровавое воскресенье (Кровавое в...)
1,"учитывая также, что терроризм создает обстанов..."
2,«Настало время принять жесткие меры по нормиро...
3,Краткая информация о проделанной работе
4,Г-н Харт (Барбадос) (говорит по-английски): Пр...
...,...
4999995,"Глава компании в ответном слове отметил, что в..."
4999996,Категория: Футболисты по клубам Франции
4999997,"Однако специалисты утверждают, что и количеств..."
4999998,Он - бесстрашным разрушителем колдовских чар.


In [12]:
df_ru.dropna(inplace=True)

In [13]:
print(df_ru.text[1])

учитывая также, что терроризм создает обстановку, которая лишает людей свободы жить без страха,


In [14]:
cnt_ru = Counter()
for text in tqdm(df_ru.text):
    cnt_ru.update(tokenizer.encode(text))

  0%|          | 0/4999995 [00:00<?, ?it/s]

In [15]:
tokenizer.vocab_size

250054

In [16]:
#Русский корпус занимает в общем словаре, такую то долю. 25%
print(len(cnt_ru), len(cnt_ru)/tokenizer.vocab_size)  
# 62790 0.2511057611555904

62790 0.2511057611555904


In [21]:
df_zh = pd.read_csv('/mnt/extendet_data/projects/translate-chinese/texts/tgt-train.txt', sep="WWWQweW", nrows=5000000, names=['text'], header=None)

  df_zh = pd.read_csv('/mnt/extendet_data/projects/translate-chinese/texts/tgt-train.txt', sep="WWWQweW", nrows=5000000, names=['text'], header=None)


In [22]:
df_zh.dropna(inplace=True)

In [28]:
df_zh.sample(5)

Unnamed: 0,text
3225249,技经评估小组正努力为各种船舶的现有设备和新设备进一步审查制冷剂备选方案，并预计于2014年4...
4381118,在世界某些地区的国内产品继续享有大幅度补贴的情况下，便没有建立均衡的贸易条件的希望。
450478,多种语文和联合国
2693789,2.54 2012-2013两年期预算已提交大会。
3900110,23. 审计委员会在上一期间建议人口基金保持正确、完整的休假记录，本报告人力资源管理部分将对...


In [None]:
cnt_zh = Counter()
for text in tqdm(df_zh.text):
    cnt_zh.update(tokenizer.encode(text))

  0%|          | 0/4999974 [00:00<?, ?it/s]

In [29]:
#Китайский корпус занимает в общем словаре, такую то долю. 25%
print(len(cnt_zh), len(cnt_zh)/tokenizer.vocab_size)  
# 64575 0.25824421924864227

64575 0.25824421924864227


In [30]:
#Посмотрим какой процент всех слов мы захватим
for top in 10_000, 20_000, 30_000:
    print(top, sum(v for k, v in cnt_ru.most_common(top)) / sum(cnt_ru.values()))
# 10000 0.9645
# 20000 0.9940
# 30000 0.9982

10000 0.9669254292671258
20000 0.9945023243662758
30000 0.998594687964201


In [31]:
#Посмотрим какой процент всех слов мы захватим
for top in 10_000, 20_000, 30_000:
    print(top, sum(v for k, v in cnt_zh.most_common(top)) / sum(cnt_zh.values()))
# 10000 0.9645
# 20000 0.9940
# 30000 0.9982

10000 0.9727450863454133
20000 0.9930192174410284
30000 0.997806606426593


In [32]:
#Формируем список всех слов new_tokens
new_tokens = set(range(1000))
for i, (k, v) in enumerate(cnt_zh.most_common(30_000)):
    if k not in new_tokens:
        new_tokens.add(k)
for i, (k, v) in enumerate(cnt_ru.most_common(25_000)):
    if k not in new_tokens:
        new_tokens.add(k)

In [33]:
#Добавим 100 последних токенов
for t in range(tokenizer.vocab_size - 100, tokenizer.vocab_size):
    new_tokens.add(t)

In [34]:
print(len(new_tokens))
kept_ids = sorted(new_tokens)

45241


In [35]:
new_size = len(kept_ids)

In [36]:
new_emb = torch.nn.Embedding(new_size, model.model.shared.embedding_dim)
new_emb

Embedding(45241, 1024)

In [37]:
new_embed_tokens = torch.nn.Embedding(new_size, model.model.shared.embedding_dim, padding_idx=1)
new_embed_tokens

Embedding(45241, 1024, padding_idx=1)

In [38]:
new_decoder_tokens = torch.nn.Embedding(new_size, model.model.shared.embedding_dim, padding_idx=1)
new_decoder_tokens

Embedding(45241, 1024, padding_idx=1)

In [39]:
new_head = torch.nn.Linear(in_features=model.lm_head.in_features, out_features=new_size, bias=False)
new_head

Linear(in_features=1024, out_features=45241, bias=False)

In [41]:
for new_id, old_id in enumerate(kept_ids):
    new_emb.weight.data[new_id] = model.model.shared.weight.data[old_id]
    new_head.weight.data[new_id] = model.lm_head.weight.data[old_id]
    new_embed_tokens.weight.data[new_id] = model.model.encoder.embed_tokens.weight.data[old_id]
    new_decoder_tokens.weight.data[new_id] = model.model.decoder.embed_tokens.weight.data[old_id]
    
model.model.shared = new_emb
model.lm_head = new_head
model.model.encoder.embed_tokens = new_embed_tokens
model.model.decoder.embed_tokens = new_decoder_tokens
    
model.model.shared.weight = new_emb.weight
model.lm_head.weight = new_head.weight
model.model.encoder.embed_tokens.weight = new_embed_tokens.weight
model.model.decoder.embed_tokens.weight = new_decoder_tokens.weight

In [42]:
model.config.__dict__['vocab_size'] = new_size
model.config.__dict__['_name_or_path'] = 'joefox/mbart-large-ru-zh-many-to-many-mmt'

In [43]:
sp_model = tokenizer.sp_model
smp = tokenizer.sp_model.serialized_model_proto()
m = spmp.ModelProto()
m.ParseFromString(smp)

5069051

In [45]:
max_ids = max(kept_ids)
max_ids

250053

In [46]:
add_ids = max_ids - len(m.pieces)
add_ids

53

In [47]:
m_ids = []
m_token = []
print('the loaded model has pieces:', len(m.pieces))
#new_pieces = [m.pieces[idx] for idx in kept_ids]
new_pieces = []
for idx in kept_ids:
    if idx < len(m.pieces) and (idx > 3):
        token = tokenizer._convert_id_to_token(idx)
        spm_id = sp_model.PieceToId(token)
        m_ids.append(spm_id)
        m_token.append(token)
        new_pieces.append(m.pieces[spm_id])
        pass
print('the new pieces:', len(new_pieces))

the loaded model has pieces: 250000
the new pieces: 45183


In [48]:
len(new_pieces), len(new_pieces)+ add_ids

(45183, 45236)

In [49]:
list_ids = []
# replace the content of the first 30K pieces
for i, p in enumerate(new_pieces):
    # if i in fairseq_tokens:
    idx = i + 3
    # else:
    #    idx = i+tokenizer.fairseq_offset
    m.pieces[idx].piece = p.piece
    m.pieces[idx].score = p.score
    m.pieces[idx].type = p.type
    list_ids.append(idx)

In [50]:
len(m.pieces), len(new_pieces)

(250000, 45183)

In [51]:
# drop the remaining pieces
n = len(new_pieces) + 4
for i in trange(len(m.pieces) - n):
    m.pieces.pop(len(m.pieces) - 1)

  0%|          | 0/204813 [00:00<?, ?it/s]

In [52]:
len(m.pieces)

45187

In [53]:
print(len(m.pieces))
with open('new_sp.model', 'wb') as f:
    f.write(m.SerializeToString())


45187


In [54]:
new_tokenizer = MBart50Tokenizer('new_sp.model')
new_tokenizer

MBart50Tokenizer(name_or_path='', vocab_size=45241, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True), 'additional_special_tokens': ['ar_AR', 'cs_CZ', 'de_DE', 'en_XX', 'es_XX', 'et_EE', 'fi_FI', 'fr_XX', 'gu_IN', 'hi_IN', 'it_IT', 'ja_XX', 'kk_KZ', 'ko_KR', 'lt_LT', 'lv_LV', 'my_MM', 'ne_NP', 'nl_XX', 'ro_RO', 'ru_RU', 'si_LK', 'tr_TR', 'vi_VN', 'zh_CN', 'af_ZA', 'az_AZ', 'bn_IN', 'fa_IR', 'he_IL', 'hr_HR', 'id_ID', 'ka_GE', 'km_KH', 'mk_MK', 'ml_IN', 'mn_MN', 'mr_IN', 'pl_PL', 'ps_AF', 'pt_XX', 'sv_SE', 'sw_KE', 'ta_IN', 'te_IN', 'th_TH', 'tl_XX', 'uk_UA', 'ur_PK', 'xh_ZA', 'gl_ES', 'sl_SI']})

In [55]:
tokenizer.vocab_file

'/srv/dev-disk-by-uuid-fa7c0d11-e6eb-484b-858c-52bb87843d42/home/joefox/.cache/huggingface/hub/models--facebook--mbart-large-50-many-to-many-mmt/snapshots/e2cfb9f4d0cfb8879734094041a08c37397b3177/sentencepiece.bpe.model'

In [56]:
new_tokenizer.save_pretrained('save_models/mbart-large-ru-zh-many-to-many-mmt')
model.save_pretrained('save_models/mbart-large-ru-zh-many-to-many-mmt')

In [58]:
model = MBartForConditionalGeneration.from_pretrained('save_models/mbart-large-ru-zh-many-to-many-mmt', ignore_mismatched_sizes=True)

Some weights of MBartForConditionalGeneration were not initialized from the model checkpoint at save_models/mbart-large-ru-zh-many-to-many-mmt and are newly initialized because the shapes did not match:
- final_logits_bias: found shape torch.Size([1, 250054]) in the checkpoint and torch.Size([1, 45241]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
