In [2]:
import os

import torch
import sentencepiece

from fairseq import checkpoint_utils

In [31]:
# FLORES-200 language table, one entry per line: "<language name><TAB><code>".
codes_as_string = '''Acehnese (Arabic script)	ace_Arab
Acehnese (Latin script)	ace_Latn
Mesopotamian Arabic	acm_Arab
Ta’izzi-Adeni Arabic	acq_Arab
Tunisian Arabic	aeb_Arab
Afrikaans	afr_Latn
South Levantine Arabic	ajp_Arab
Akan	aka_Latn
Amharic	amh_Ethi
North Levantine Arabic	apc_Arab
Modern Standard Arabic	arb_Arab
Modern Standard Arabic (Romanized)	arb_Latn
Najdi Arabic	ars_Arab
Moroccan Arabic	ary_Arab
Egyptian Arabic	arz_Arab
Assamese	asm_Beng
Asturian	ast_Latn
Awadhi	awa_Deva
Central Aymara	ayr_Latn
South Azerbaijani	azb_Arab
North Azerbaijani	azj_Latn
Bashkir	bak_Cyrl
Bambara	bam_Latn
Balinese	ban_Latn
Belarusian	bel_Cyrl
Bemba	bem_Latn
Bengali	ben_Beng
Bhojpuri	bho_Deva
Banjar (Arabic script)	bjn_Arab
Banjar (Latin script)	bjn_Latn
Standard Tibetan	bod_Tibt
Bosnian	bos_Latn
Buginese	bug_Latn
Bulgarian	bul_Cyrl
Catalan	cat_Latn
Cebuano	ceb_Latn
Czech	ces_Latn
Chokwe	cjk_Latn
Central Kurdish	ckb_Arab
Crimean Tatar	crh_Latn
Welsh	cym_Latn
Danish	dan_Latn
German	deu_Latn
Southwestern Dinka	dik_Latn
Dyula	dyu_Latn
Dzongkha	dzo_Tibt
Greek	ell_Grek
English	eng_Latn
Esperanto	epo_Latn
Estonian	est_Latn
Basque	eus_Latn
Ewe	ewe_Latn
Faroese	fao_Latn
Fijian	fij_Latn
Finnish	fin_Latn
Fon	fon_Latn
French	fra_Latn
Friulian	fur_Latn
Nigerian Fulfulde	fuv_Latn
Scottish Gaelic	gla_Latn
Irish	gle_Latn
Galician	glg_Latn
Guarani	grn_Latn
Gujarati	guj_Gujr
Haitian Creole	hat_Latn
Hausa	hau_Latn
Hebrew	heb_Hebr
Hindi	hin_Deva
Chhattisgarhi	hne_Deva
Croatian	hrv_Latn
Hungarian	hun_Latn
Armenian	hye_Armn
Igbo	ibo_Latn
Ilocano	ilo_Latn
Indonesian	ind_Latn
Icelandic	isl_Latn
Italian	ita_Latn
Javanese	jav_Latn
Japanese	jpn_Jpan
Kabyle	kab_Latn
Jingpho	kac_Latn
Kamba	kam_Latn
Kannada	kan_Knda
Kashmiri (Arabic script)	kas_Arab
Kashmiri (Devanagari script)	kas_Deva
Georgian	kat_Geor
Central Kanuri (Arabic script)	knc_Arab
Central Kanuri (Latin script)	knc_Latn
Kazakh	kaz_Cyrl
Kabiyè	kbp_Latn
Kabuverdianu	kea_Latn
Khmer	khm_Khmr
Kikuyu	kik_Latn
Kinyarwanda	kin_Latn
Kyrgyz	kir_Cyrl
Kimbundu	kmb_Latn
Northern Kurdish	kmr_Latn
Kikongo	kon_Latn
Korean	kor_Hang
Lao	lao_Laoo
Ligurian	lij_Latn
Limburgish	lim_Latn
Lingala	lin_Latn
Lithuanian	lit_Latn
Lombard	lmo_Latn
Latgalian	ltg_Latn
Luxembourgish	ltz_Latn
Luba-Kasai	lua_Latn
Ganda	lug_Latn
Luo	luo_Latn
Mizo	lus_Latn
Standard Latvian	lvs_Latn
Magahi	mag_Deva
Maithili	mai_Deva
Malayalam	mal_Mlym
Marathi	mar_Deva
Minangkabau (Arabic script)	min_Arab
Minangkabau (Latin script)	min_Latn
Macedonian	mkd_Cyrl
Plateau Malagasy	plt_Latn
Maltese	mlt_Latn
Meitei (Bengali script)	mni_Beng
Halh Mongolian	khk_Cyrl
Mossi	mos_Latn
Maori	mri_Latn
Burmese	mya_Mymr
Dutch	nld_Latn
Norwegian Nynorsk	nno_Latn
Norwegian Bokmål	nob_Latn
Nepali	npi_Deva
Northern Sotho	nso_Latn
Nuer	nus_Latn
Nyanja	nya_Latn
Occitan	oci_Latn
West Central Oromo	gaz_Latn
Odia	ory_Orya
Pangasinan	pag_Latn
Eastern Panjabi	pan_Guru
Papiamento	pap_Latn
Western Persian	pes_Arab
Polish	pol_Latn
Portuguese	por_Latn
Dari	prs_Arab
Southern Pashto	pbt_Arab
Ayacucho Quechua	quy_Latn
Romanian	ron_Latn
Rundi	run_Latn
Russian	rus_Cyrl
Sango	sag_Latn
Sanskrit	san_Deva
Santali	sat_Olck
Sicilian	scn_Latn
Shan	shn_Mymr
Sinhala	sin_Sinh
Slovak	slk_Latn
Slovenian	slv_Latn
Samoan	smo_Latn
Shona	sna_Latn
Sindhi	snd_Arab
Somali	som_Latn
Southern Sotho	sot_Latn
Spanish	spa_Latn
Tosk Albanian	als_Latn
Sardinian	srd_Latn
Serbian	srp_Cyrl
Swati	ssw_Latn
Sundanese	sun_Latn
Swedish	swe_Latn
Swahili	swh_Latn
Silesian	szl_Latn
Tamil	tam_Taml
Tatar	tat_Cyrl
Telugu	tel_Telu
Tajik	tgk_Cyrl
Tagalog	tgl_Latn
Thai	tha_Thai
Tigrinya	tir_Ethi
Tamasheq (Latin script)	taq_Latn
Tamasheq (Tifinagh script)	taq_Tfng
Tok Pisin	tpi_Latn
Tswana	tsn_Latn
Tsonga	tso_Latn
Turkmen	tuk_Latn
Tumbuka	tum_Latn
Turkish	tur_Latn
Twi	twi_Latn
Central Atlas Tamazight	tzm_Tfng
Uyghur	uig_Arab
Ukrainian	ukr_Cyrl
Umbundu	umb_Latn
Urdu	urd_Arab
Northern Uzbek	uzn_Latn
Venetian	vec_Latn
Vietnamese	vie_Latn
Waray	war_Latn
Wolof	wol_Latn
Xhosa	xho_Latn
Eastern Yiddish	ydd_Hebr
Yoruba	yor_Latn
Yue Chinese	yue_Hant
Chinese (Simplified)	zho_Hans
Chinese (Traditional)	zho_Hant
Standard Malay	zsm_Latn
Zulu	zul_Latn'''

# The name is rebound to the list of table lines, exactly as this cell has
# always left it.
codes_as_string = codes_as_string.split('\n')

# Build the lookup: language name -> FLORES-200 code.
flores_codes = {}
for line_no, code in enumerate(codes_as_string, start=1):
    if not code.strip():
        # A blank or trailing line carries no entry; skip it instead of
        # failing on tuple unpacking.
        continue
    # The code never contains whitespace, so splitting once from the right on
    # any whitespace run separates it from the (possibly multi-word) name even
    # if the tab was converted to spaces by an editor or a copy/paste.
    fields = code.rsplit(maxsplit=1)
    if len(fields) != 2:
        raise ValueError(
            f'malformed FLORES-200 entry on line {line_no}: {code!r}'
        )
    lang, lang_code = fields
    flores_codes[lang] = lang_code

In [35]:
# Show every FLORES-200 code on its own line.
print('\n'.join(flores_codes.values()))

ace_Arab
ace_Latn
acm_Arab
acq_Arab
aeb_Arab
afr_Latn
ajp_Arab
aka_Latn
amh_Ethi
apc_Arab
arb_Arab
arb_Latn
ars_Arab
ary_Arab
arz_Arab
asm_Beng
ast_Latn
awa_Deva
ayr_Latn
azb_Arab
azj_Latn
bak_Cyrl
bam_Latn
ban_Latn
bel_Cyrl
bem_Latn
ben_Beng
bho_Deva
bjn_Arab
bjn_Latn
bod_Tibt
bos_Latn
bug_Latn
bul_Cyrl
cat_Latn
ceb_Latn
ces_Latn
cjk_Latn
ckb_Arab
crh_Latn
cym_Latn
dan_Latn
deu_Latn
dik_Latn
dyu_Latn
dzo_Tibt
ell_Grek
eng_Latn
epo_Latn
est_Latn
eus_Latn
ewe_Latn
fao_Latn
fij_Latn
fin_Latn
fon_Latn
fra_Latn
fur_Latn
fuv_Latn
gla_Latn
gle_Latn
glg_Latn
grn_Latn
guj_Gujr
hat_Latn
hau_Latn
heb_Hebr
hin_Deva
hne_Deva
hrv_Latn
hun_Latn
hye_Armn
ibo_Latn
ilo_Latn
ind_Latn
isl_Latn
ita_Latn
jav_Latn
jpn_Jpan
kab_Latn
kac_Latn
kam_Latn
kan_Knda
kas_Arab
kas_Deva
kat_Geor
knc_Arab
knc_Latn
kaz_Cyrl
kbp_Latn
kea_Latn
khm_Khmr
kik_Latn
kin_Latn
kir_Cyrl
kmb_Latn
kmr_Latn
kon_Latn
kor_Hang
lao_Laoo
lij_Latn
lim_Latn
lin_Latn
lit_Latn
lmo_Latn
ltg_Latn
ltz_Latn
lua_Latn
lug_Latn
luo_Latn
lus_Latn
l

In [38]:
# Rebind `flores_codes` to a two-entry mapping from short language codes to
# FLORES-200 codes (English and Maltese only).
flores_codes = dict(en='eng_Latn', mt='mlt_Latn')

In [19]:
# Directory that holds the NLLB checkpoint and SentencePiece model loaded by
# the following cells. The NLLB_MODEL_ROOT environment variable overrides the
# location, so the notebook is not tied to one machine's filesystem; when it
# is unset the original path is used.
model_root = os.environ.get(
    'NLLB_MODEL_ROOT',
    '/mnt/data/siqiouyang/runs/ConST/pretrained/nllb',
)

In [23]:
# Load the NLLB checkpoint stored under `model_root`; later cells read its
# 'cfg' and 'model' entries.
# map_location='cpu' keeps every tensor on the CPU, so the checkpoint can be
# inspected regardless of the device it was saved from.
# NOTE: torch.load unpickles the file; only open checkpoints from a trusted
# source.
model = torch.load(
    os.path.join(model_root, 'nllb200densedst600mcheckpoint.pt'),
    map_location='cpu',
)

In [43]:
# Read the `no_scale_embedding` flag from the model section of the
# checkpoint's config.
model['cfg']['model'].no_scale_embedding

False

In [49]:
# Print the name of every entry under the checkpoint's 'model' key, one per
# line.
print('\n'.join(model['model'].keys()))

encoder.version
encoder.embed_tokens.weight
encoder.embed_positions._float_tensor
encoder.layers.0.self_attn.k_proj.weight
encoder.layers.0.self_attn.k_proj.bias
encoder.layers.0.self_attn.v_proj.weight
encoder.layers.0.self_attn.v_proj.bias
encoder.layers.0.self_attn.q_proj.weight
encoder.layers.0.self_attn.q_proj.bias
encoder.layers.0.self_attn.out_proj.weight
encoder.layers.0.self_attn.out_proj.bias
encoder.layers.0.self_attn_layer_norm.weight
encoder.layers.0.self_attn_layer_norm.bias
encoder.layers.0.fc1.weight
encoder.layers.0.fc1.bias
encoder.layers.0.fc2.weight
encoder.layers.0.fc2.bias
encoder.layers.0.final_layer_norm.weight
encoder.layers.0.final_layer_norm.bias
encoder.layers.1.self_attn.k_proj.weight
encoder.layers.1.self_attn.k_proj.bias
encoder.layers.1.self_attn.v_proj.weight
encoder.layers.1.self_attn.v_proj.bias
encoder.layers.1.self_attn.q_proj.weight
encoder.layers.1.self_attn.q_proj.bias
encoder.layers.1.self_attn.out_proj.weight
encoder.layers.1.self_attn.out_proj

In [18]:
# Dimensions of the encoder token-embedding matrix stored in the checkpoint.
model['model']['encoder.embed_tokens.weight'].shape

torch.Size([256206, 1024])

In [20]:
# Create a SentencePiece processor and load the model file stored under
# `model_root`.
spm = sentencepiece.SentencePieceProcessor()
# `Load` is the cell's last expression, so its return value is what the cell
# displays.
spm.Load(os.path.join(model_root, 'flores200sacrebleuspm.model'))

True

In [50]:
# Look up the id that the loaded SentencePiece model assigns to '<pad>'.
spm.piece_to_id('<pad>')

0

In [3]:
import transformers

In [76]:
# NLLB tokenizer from the Hugging Face hub, configured with Maltese as the
# source language and English as the target language.
tokenizer = transformers.NllbTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    src_lang="mlt_Latn",
    tgt_lang="eng_Latn",
)

In [84]:
# Example sentence pair. `tgt_text` is not passed to the tokenizer here; only
# the source sentence is encoded.
tgt_text = "We should not leave it here."
src_text = "M'għandniex inħalluh haw'."

# Encode the source sentence, asking for PyTorch tensors ("pt").
inputs = tokenizer(src_text, return_tensors="pt")

In [89]:
# NOTE(review): `hf_nllb` is created by a cell that appears further down in
# this file, so that cell has to be executed before this one.
#
# Generate at most 30 tokens, forcing the first generated token to be the
# English language tag. The tag's id is looked up with
# `convert_tokens_to_ids`, because the `lang_code_to_id` mapping is deprecated
# in newer transformers releases.
translated_tokens = hf_nllb.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
    max_length=30,
)

In [91]:
# Decode the generated ids back into text. No `skip_special_tokens` argument
# is passed, and the recorded output still shows '</s>' and the language tag.
tokenizer.batch_decode(translated_tokens)

['</s>eng_Latn We should not leave him here.</s>']

In [9]:
# Check which token the tokenizer maps id 2 to.
tokenizer.convert_ids_to_tokens(2)

'</s>'

In [57]:
# Seq2seq model loaded from the Hugging Face hub checkpoint
# "facebook/nllb-200-distilled-600M".
hf_nllb = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M"
)

In [58]:
# Tokenizer resolved through AutoTokenizer for the same hub checkpoint name.
hf_nllb_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M"
)

Downloading tokenizer.json:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

In [75]:
# NOTE(review): the embedding matrices shown elsewhere in this file have
# 256206 rows, so id 256206 would be one past the last row index — presumably
# this is a deliberate out-of-range probe; confirm.
hf_nllb_tokenizer.convert_ids_to_tokens(256206)

In [32]:
# Print each (name, parameter) pair of the Hugging Face model, including the
# parameter values.
for name_and_param in hf_nllb.named_parameters():
    print(name_and_param)

('model.shared.weight', Parameter containing:
tensor([[-0.0321,  0.0348,  0.0181,  ...,  0.0312, -0.0099, -0.0133],
        [-0.0039,  0.0104, -0.0156,  ...,  0.0290, -0.0138, -0.0134],
        [-0.0245, -0.0283, -0.0295,  ...,  0.9712, -0.0255, -0.0273],
        ...,
        [-0.0123, -0.0031, -0.0089,  ...,  0.0645, -0.0182, -0.0740],
        [ 0.0085, -0.0088, -0.0091,  ...,  0.0571, -0.0035, -0.1298],
        [-0.0076, -0.0107, -0.0051,  ...,  1.0264, -0.0338, -0.1175]],
       requires_grad=True))
('model.encoder.layers.0.self_attn.k_proj.weight', Parameter containing:
tensor([[ 5.9570e-01,  6.5479e-01,  5.9863e-01,  ...,  1.2024e-01,
         -1.4905e-01, -1.1023e-01],
        [ 3.1909e-01, -1.8079e-01, -2.6294e-01,  ..., -2.0264e-01,
          1.9861e-01,  4.2529e-01],
        [ 5.2148e-01,  8.4229e-01,  9.9951e-01,  ...,  2.0679e-01,
          2.0203e-01, -5.2261e-03],
        ...,
        [ 1.3794e-01, -1.1102e-01, -5.5176e-01,  ..., -3.5400e-03,
          4.2267e-02, -3.9399e

In [34]:
# Dimensions of the Hugging Face model's shared embedding matrix.
hf_nllb.model.shared.weight.shape

torch.Size([256206, 1024])

In [43]:
# Display the encoder token-embedding tensor stored in the checkpoint.
model['model']['encoder.embed_tokens.weight']

tensor([[-0.0321,  0.0348,  0.0181,  ...,  0.0312, -0.0099, -0.0133],
        [-0.0039,  0.0104, -0.0156,  ...,  0.0290, -0.0138, -0.0134],
        [-0.0245, -0.0283, -0.0295,  ...,  0.9712, -0.0255, -0.0273],
        ...,
        [-0.0123, -0.0031, -0.0089,  ...,  0.0645, -0.0182, -0.0740],
        [ 0.0085, -0.0088, -0.0091,  ...,  0.0571, -0.0035, -0.1298],
        [-0.0076, -0.0107, -0.0051,  ...,  1.0264, -0.0338, -0.1175]],
       dtype=torch.float16)

In [48]:
# Display the full model section of the checkpoint's config.
model['cfg']['model']

Namespace(_name='transformer', activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.98)', adam_eps=1e-06, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, add_data_source_prefix_tags=True, add_ssl_task_tokens=False, all_gather_list_size=16384, alpha_ce=0.5, amp=False, amp_batch_retries=2, amp_init_scale=128, amp_scale_window=None, arch='transformer', attention_dropout=0.1, azureml_logging=False, batch_size=None, batch_size_valid=None, best_checkpoint_metric='nll_loss', bf16=False, block_wise=False, bpe=None, broadcast_buffers=False, bucket_cap_mb=25, checkpoint_activations=False, checkpoint_shard_count=1, checkpoint_suffix='', clip_norm=0.0, combine_valid_subsets=None, continue_once=None, cpu=False, cpu_offload=False, criterion='soft_label_smoothed_cross_entropy', cross_self_attention=False, curriculum=0, data='/data/nllb/nllb/flores200.en_xx_en.v4.4.256k/data_bin/shard000:/data/nllb/nllb/flores200.en_xx_en.v4.4.256k/data_bin/shard001:/data/n

In [50]:
# Load the XLS-R checkpoint file for inspection.
# map_location='cpu' keeps every tensor on the CPU, so the file can be opened
# regardless of the device it was saved from.
# NOTE: torch.load unpickles the file; only open checkpoints from a trusted
# source.
xlsr = torch.load(
    '/mnt/data/siqiouyang/runs/mST/pretrained/xlsr2_300m.pt',
    map_location='cpu',
)

In [55]:
# List the top-level entries of the loaded checkpoint.
xlsr.keys()

dict_keys(['args', 'cfg', 'model', 'criterion', 'optimizer_history', 'task_state', 'extra_state', 'last_optimizer_state'])

In [56]:
# Display the checkpoint's 'model' entry (parameter names mapped to tensors).
xlsr['model']

{'mask_emb': tensor([ 0.3542, -0.0494,  0.1383,  ...,  0.1020, -0.0372, -0.3062]),
 'feature_extractor.conv_layers.0.0.weight': tensor([[[ 2.8641e-02, -2.3376e-02,  8.2932e-03,  ...,  5.4230e-02,
            1.2428e-02, -1.9028e-02]],
 
         [[-1.2030e-01,  1.2115e-01,  3.9490e-02,  ...,  1.0455e-01,
           -2.0935e-01,  9.1431e-02]],
 
         [[-3.5248e-02,  1.2421e-01, -1.8872e-01,  ...,  1.8066e-01,
           -1.0065e-01,  3.0670e-02]],
 
         ...,
 
         [[-4.5624e-02,  1.2128e-01, -1.2976e-01,  ..., -7.0038e-03,
           -8.4381e-03,  8.2626e-03]],
 
         [[ 2.8181e-04, -1.7670e-02,  5.6915e-02,  ..., -1.0803e-01,
            1.8001e-04,  2.4429e-02]],
 
         [[ 3.7445e-02, -8.4229e-02,  5.0201e-02,  ..., -5.6915e-02,
           -3.2013e-02,  3.4363e-02]]]),
 'feature_extractor.conv_layers.0.0.bias': tensor([-0.0159, -0.0166, -0.0268, -0.0550, -0.0154, -0.0116, -0.0149, -0.1331,
         -0.0151, -0.0453, -0.0158, -0.0152, -0.0401, -0.0160, -0.0302, -0

In [96]:
# Slicing experiment: start and stop both refer to the last position, so the
# slice selects nothing.
a = list(range(1, 4))
a[-1:-1]

[]

NameError: name 'torch' is not defined

In [4]:
# Small bias-free linear layer mapping 5 input features to 4 output features.
# Note that this rebinds the name `model`.
model = torch.nn.Linear(in_features=5, out_features=4, bias=False)

In [6]:
# Turn off gradient tracking for the layer's weight.
model.weight.requires_grad = False