From 66b63b45521db4014fc15ccad2ca86d03074f183 Mon Sep 17 00:00:00 2001 From: Xing Han Lu <21180505+xhluca@users.noreply.github.com> Date: Tue, 18 Jul 2023 01:23:49 -0400 Subject: [PATCH] Add nllb (#58) * Add NLLB-200 * Improve readme and docs * Fix tests * Bump version * Bump python requirements * Add new module * Improve tests and fix capitalization errors * Change behavior of _resolve_lang_codes to resolve one at a time * Add new max token length default to 512, update tests * Add demo --- README.md | 43 +- demos/nllb_demo.ipynb | 1 + dl_translate/_pairs.py | 206 +++++++ dl_translate/_translation_model.py | 59 +- dl_translate/lang/__init__.py | 2 +- dl_translate/lang/nllb200.py | 205 +++++++ dl_translate/utils.py | 52 +- docs/available_languages.md | 540 ++++++++++++------ docs/index.md | 45 +- scripts/generate_langs.py | 15 +- scripts/langs_coverage/nllb200.json | 206 +++++++ scripts/render_available_langs.py | 6 +- .../templates/available_languages.md.jinja2 | 6 +- setup.py | 8 +- tests/quick/test_translation_model.py | 25 +- tests/quick/test_utils.py | 24 +- 16 files changed, 1189 insertions(+), 254 deletions(-) create mode 100644 demos/nllb_demo.ipynb create mode 100644 dl_translate/lang/nllb200.py create mode 100644 scripts/langs_coverage/nllb200.json diff --git a/README.md b/README.md index 3bb7e64..e164fab 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ *A deep learning-based translation library built on Huggingface `transformers`* 💻 [GitHub Repository](https://github.com/xhluca/dl-translate)
-📚 [Documentation](https://xhluca.github.io/dl-translate) / [Readthedocs](https://dl-translate.readthedocs.io)
+📚 [Documentation](https://xhluca.github.io/dl-translate)
🐍 [PyPi project](https://pypi.org/project/dl-translate/)
🧪 [Colab Demo](https://colab.research.google.com/github/xhluca/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/) @@ -58,24 +58,34 @@ By default, the value will be `device="auto"`, which means it will use a GPU if ### Choosing a different model -Two model families are available at the moment: [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html) and [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html), which respective allow translation across over 100 languages and 50 languages. By default, the model will select `m2m100`, but you can also explicitly choose the model by specifying the shorthand (`"m2m100"` or `"mbart50"`) or the full repository name (e.g. `"facebook/m2m100_418M"`). For example: +By default, the `m2m100` model will be used. However, there are a few options: +* [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages. +* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages. +* [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (On RTX A6000, we can see a speedup of 3x). 
+ +Here's an example: ```python -# The following ways are equivalent -mt = dlt.TranslationModel("m2m100") # Default -mt = dlt.TranslationModel("facebook/m2m100_418M") +# The default approach +mt = dlt.TranslationModel("m2m100") # Shorthand +mt = dlt.TranslationModel("facebook/m2m100_418M") # Huggingface repo -# The following ways are equivalent +# If you want to use mBART-50 Large mt = dlt.TranslationModel("mbart50") mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt") + +# Or NLLB-200 (faster and has 200 languages) +mt = dlt.TranslationModel("nllb200") +mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M") ``` Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`. -By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt) or [m2m100](https://huggingface.co/facebook/m2m100_418M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`: +By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`: ```python mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50") mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100") +mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200") ``` Notes: @@ -114,8 +124,8 @@ An alternative to `mt.available_languages()` is the `dlt.utils` module. 
You can ```python print(dlt.utils.available_languages('mbart50')) # All languages that you can use -print(dlt.utils.available_codes('mbart50')) # Code corresponding to each language accepted -print(dlt.utils.get_lang_code_map('mbart50')) # Dictionary of lang -> code +print(dlt.utils.available_codes('m2m100')) # Code corresponding to each language accepted +print(dlt.utils.get_lang_code_map('nllb200')) # Dictionary of lang -> code ``` ### Offline usage @@ -159,7 +169,7 @@ If you have knowledge of PyTorch and Huggingface Transformers, you can access ad * **Interacting with underlying model and tokenizer**: When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer with `model_options` and `tokenizer_options` respectively. You can also access the underlying `transformers` with `mt.get_transformers_model()`. * **Keyword arguments for the `generate()` method**: When running `mt.translate`, you can also give `generation_options` that is passed to the `generate()` method of the underlying transformer model. -For more information, please visit the [advanced section of the user guide](https://xhluca.github.io/dl-translate/#advanced) (also available in the [readthedocs version](https://dl-translate.readthedocs.io/en/latest/#advanced)). +For more information, please visit the [advanced section of the user guide](https://xhluca.github.io/dl-translate/#advanced). ## Acknowledgement @@ -186,6 +196,19 @@ For more information, please visit the [advanced section of the user guide](http } ``` +3. The [no language left behind](https://arxiv.org/abs/2207.04672) model, which extends NMT to 200+ languages. You can cite it here: + ``` + @misc{nllbteam2022language, + title={No Language Left Behind: Scaling Human-Centered Machine Translation}, + author={NLLB Team and Marta R. 
Costa-jussà and James Cross and Onur Çelebi and Maha Elbayad and Kenneth Heafield and Kevin Heffernan and Elahe Kalbassi and Janice Lam and Daniel Licht and Jean Maillard and Anna Sun and Skyler Wang and Guillaume Wenzek and Al Youngblood and Bapi Akula and Loic Barrault and Gabriel Mejia Gonzalez and Prangthip Hansanti and John Hoffman and Semarley Jarrett and Kaushik Ram Sadagopan and Dirk Rowe and Shannon Spruit and Chau Tran and Pierre Andrews and Necip Fazil Ayan and Shruti Bhosale and Sergey Edunov and Angela Fan and Cynthia Gao and Vedanuj Goswami and Francisco Guzmán and Philipp Koehn and Alexandre Mourachko and Christophe Ropers and Safiyyah Saleem and Holger Schwenk and Jeff Wang}, + year={2022}, + eprint={2207.04672}, + archivePrefix={arXiv}, + primaryClass={cs.CL} + } + ``` + + `dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example: ```python from transformers import MBartForConditionalGeneration, MBart50TokenizerFast diff --git a/demos/nllb_demo.ipynb b/demos/nllb_demo.ipynb new file mode 100644 index 0000000..706c6ce --- /dev/null +++ b/demos/nllb_demo.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2023-07-18T05:15:13.999614Z","iopub.status.busy":"2023-07-18T05:15:13.999228Z","iopub.status.idle":"2023-07-18T05:15:31.978108Z","shell.execute_reply":"2023-07-18T05:15:31.976681Z","shell.execute_reply.started":"2023-07-18T05:15:13.999573Z"},"trusted":true},"outputs":[],"source":["!pip install dl-translate==3.* 
-q"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:15:31.982361Z","iopub.status.busy":"2023-07-18T05:15:31.981992Z","iopub.status.idle":"2023-07-18T05:16:23.731908Z","shell.execute_reply":"2023-07-18T05:16:23.730776Z","shell.execute_reply.started":"2023-07-18T05:15:31.982327Z"},"trusted":true},"outputs":[],"source":["import dl_translate as dlt\n","\n","mt = dlt.TranslationModel(\"nllb200\")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:23.734336Z","iopub.status.busy":"2023-07-18T05:16:23.733295Z","iopub.status.idle":"2023-07-18T05:16:28.025038Z","shell.execute_reply":"2023-07-18T05:16:28.023933Z","shell.execute_reply.started":"2023-07-18T05:16:23.734293Z"},"trusted":true},"outputs":[],"source":["text = \"Meta AI has built a single AI model capable of translating across 200 different languages with state-of-the-art quality.\"\n","\n","# The new translation is much faster than before\n","%time print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.FRENCH))"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.028919Z","iopub.status.busy":"2023-07-18T05:16:28.028286Z","iopub.status.idle":"2023-07-18T05:16:28.717521Z","shell.execute_reply":"2023-07-18T05:16:28.716343Z","shell.execute_reply.started":"2023-07-18T05:16:28.028882Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["मेटाएआई एकमेव एआई मॉडलं निर्मितवान्, यत् 200 भिन्नभाषायां अवधीतमतमतमगुणैः अनुवादं कर्तुं समर्थः अस्ति।\n"]}],"source":["# Sanskrit is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, 
target=dlt.lang.nllb200.SANSKRIT))"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.719596Z","iopub.status.busy":"2023-07-18T05:16:28.719227Z","iopub.status.idle":"2023-07-18T05:16:29.443696Z","shell.execute_reply":"2023-07-18T05:16:29.442668Z","shell.execute_reply.started":"2023-07-18T05:16:28.719560Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Meta AI hà custruitu un solu mudellu d'AI capace di tradurisce in 200 lingue sfarenti cù qualità di u statu di l'arte.\n"]}],"source":["# Sicilian is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.SICILIAN))"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:29.447147Z","iopub.status.busy":"2023-07-18T05:16:29.445331Z","iopub.status.idle":"2023-07-18T05:16:30.145637Z","shell.execute_reply":"2023-07-18T05:16:30.144623Z","shell.execute_reply.started":"2023-07-18T05:16:29.447108Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["基於Meta AI 建立咗一個 AI 模型 可以用最先端嘅質量翻譯到 200 個唔同語言\n"]}],"source":["# Yue Chinese is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.YUE_CHINESE))"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4} diff --git a/dl_translate/_pairs.py b/dl_translate/_pairs.py index 6e14a66..a269952 100644 --- a/dl_translate/_pairs.py +++ b/dl_translate/_pairs.py @@ -166,3 +166,209 @@ ("Galician", "gl_ES"), ("Slovene", "sl_SI"), ) +_PAIRS_NLLB200 = ( + ("Acehnese (Arabic script)", "ace_Arab"), + ("Acehnese (Latin 
script)", "ace_Latn"), + ("Mesopotamian Arabic", "acm_Arab"), + ("Ta'izzi-Adeni Arabic", "acq_Arab"), + ("Tunisian Arabic", "aeb_Arab"), + ("Afrikaans", "afr_Latn"), + ("South Levantine Arabic", "ajp_Arab"), + ("Akan", "aka_Latn"), + ("Amharic", "amh_Ethi"), + ("North Levantine Arabic", "apc_Arab"), + ("Modern Standard Arabic", "arb_Arab"), + ("Modern Standard Arabic (Romanized)", "arb_Latn"), + ("Najdi Arabic", "ars_Arab"), + ("Moroccan Arabic", "ary_Arab"), + ("Egyptian Arabic", "arz_Arab"), + ("Assamese", "asm_Beng"), + ("Asturian", "ast_Latn"), + ("Awadhi", "awa_Deva"), + ("Central Aymara", "ayr_Latn"), + ("South Azerbaijani", "azb_Arab"), + ("North Azerbaijani", "azj_Latn"), + ("Bashkir", "bak_Cyrl"), + ("Bambara", "bam_Latn"), + ("Balinese", "ban_Latn"), + ("Belarusian", "bel_Cyrl"), + ("Bemba", "bem_Latn"), + ("Bengali", "ben_Beng"), + ("Bhojpuri", "bho_Deva"), + ("Banjar (Arabic script)", "bjn_Arab"), + ("Banjar (Latin script)", "bjn_Latn"), + ("Standard Tibetan", "bod_Tibt"), + ("Bosnian", "bos_Latn"), + ("Buginese", "bug_Latn"), + ("Bulgarian", "bul_Cyrl"), + ("Catalan", "cat_Latn"), + ("Cebuano", "ceb_Latn"), + ("Czech", "ces_Latn"), + ("Chokwe", "cjk_Latn"), + ("Central Kurdish", "ckb_Arab"), + ("Crimean Tatar", "crh_Latn"), + ("Welsh", "cym_Latn"), + ("Danish", "dan_Latn"), + ("German", "deu_Latn"), + ("Southwestern Dinka", "dik_Latn"), + ("Dyula", "dyu_Latn"), + ("Dzongkha", "dzo_Tibt"), + ("Greek", "ell_Grek"), + ("English", "eng_Latn"), + ("Esperanto", "epo_Latn"), + ("Estonian", "est_Latn"), + ("Basque", "eus_Latn"), + ("Ewe", "ewe_Latn"), + ("Faroese", "fao_Latn"), + ("Fijian", "fij_Latn"), + ("Finnish", "fin_Latn"), + ("Fon", "fon_Latn"), + ("French", "fra_Latn"), + ("Friulian", "fur_Latn"), + ("Nigerian Fulfulde", "fuv_Latn"), + ("Scottish Gaelic", "gla_Latn"), + ("Irish", "gle_Latn"), + ("Galician", "glg_Latn"), + ("Guarani", "grn_Latn"), + ("Gujarati", "guj_Gujr"), + ("Haitian Creole", "hat_Latn"), + ("Hausa", "hau_Latn"), + ("Hebrew", 
"heb_Hebr"), + ("Hindi", "hin_Deva"), + ("Chhattisgarhi", "hne_Deva"), + ("Croatian", "hrv_Latn"), + ("Hungarian", "hun_Latn"), + ("Armenian", "hye_Armn"), + ("Igbo", "ibo_Latn"), + ("Ilocano", "ilo_Latn"), + ("Indonesian", "ind_Latn"), + ("Icelandic", "isl_Latn"), + ("Italian", "ita_Latn"), + ("Javanese", "jav_Latn"), + ("Japanese", "jpn_Jpan"), + ("Kabyle", "kab_Latn"), + ("Jingpho", "kac_Latn"), + ("Kamba", "kam_Latn"), + ("Kannada", "kan_Knda"), + ("Kashmiri (Arabic script)", "kas_Arab"), + ("Kashmiri (Devanagari script)", "kas_Deva"), + ("Georgian", "kat_Geor"), + ("Central Kanuri (Arabic script)", "knc_Arab"), + ("Central Kanuri (Latin script)", "knc_Latn"), + ("Kazakh", "kaz_Cyrl"), + ("Kabiyè", "kbp_Latn"), + ("Kabuverdianu", "kea_Latn"), + ("Khmer", "khm_Khmr"), + ("Kikuyu", "kik_Latn"), + ("Kinyarwanda", "kin_Latn"), + ("Kyrgyz", "kir_Cyrl"), + ("Kimbundu", "kmb_Latn"), + ("Northern Kurdish", "kmr_Latn"), + ("Kikongo", "kon_Latn"), + ("Korean", "kor_Hang"), + ("Lao", "lao_Laoo"), + ("Ligurian", "lij_Latn"), + ("Limburgish", "lim_Latn"), + ("Lingala", "lin_Latn"), + ("Lithuanian", "lit_Latn"), + ("Lombard", "lmo_Latn"), + ("Latgalian", "ltg_Latn"), + ("Luxembourgish", "ltz_Latn"), + ("Luba-Kasai", "lua_Latn"), + ("Ganda", "lug_Latn"), + ("Luo", "luo_Latn"), + ("Mizo", "lus_Latn"), + ("Standard Latvian", "lvs_Latn"), + ("Magahi", "mag_Deva"), + ("Maithili", "mai_Deva"), + ("Malayalam", "mal_Mlym"), + ("Marathi", "mar_Deva"), + ("Minangkabau (Arabic script)", "min_Arab"), + ("Minangkabau (Latin script)", "min_Latn"), + ("Macedonian", "mkd_Cyrl"), + ("Plateau Malagasy", "plt_Latn"), + ("Maltese", "mlt_Latn"), + ("Meitei (Bengali script)", "mni_Beng"), + ("Halh Mongolian", "khk_Cyrl"), + ("Mossi", "mos_Latn"), + ("Maori", "mri_Latn"), + ("Burmese", "mya_Mymr"), + ("Dutch", "nld_Latn"), + ("Norwegian Nynorsk", "nno_Latn"), + ("Norwegian Bokmål", "nob_Latn"), + ("Nepali", "npi_Deva"), + ("Northern Sotho", "nso_Latn"), + ("Nuer", "nus_Latn"), + ("Nyanja", 
"nya_Latn"), + ("Occitan", "oci_Latn"), + ("West Central Oromo", "gaz_Latn"), + ("Odia", "ory_Orya"), + ("Pangasinan", "pag_Latn"), + ("Eastern Panjabi", "pan_Guru"), + ("Papiamento", "pap_Latn"), + ("Western Persian", "pes_Arab"), + ("Polish", "pol_Latn"), + ("Portuguese", "por_Latn"), + ("Dari", "prs_Arab"), + ("Southern Pashto", "pbt_Arab"), + ("Ayacucho Quechua", "quy_Latn"), + ("Romanian", "ron_Latn"), + ("Rundi", "run_Latn"), + ("Russian", "rus_Cyrl"), + ("Sango", "sag_Latn"), + ("Sanskrit", "san_Deva"), + ("Santali", "sat_Olck"), + ("Sicilian", "scn_Latn"), + ("Shan", "shn_Mymr"), + ("Sinhala", "sin_Sinh"), + ("Slovak", "slk_Latn"), + ("Slovenian", "slv_Latn"), + ("Samoan", "smo_Latn"), + ("Shona", "sna_Latn"), + ("Sindhi", "snd_Arab"), + ("Somali", "som_Latn"), + ("Southern Sotho", "sot_Latn"), + ("Spanish", "spa_Latn"), + ("Tosk Albanian", "als_Latn"), + ("Sardinian", "srd_Latn"), + ("Serbian", "srp_Cyrl"), + ("Swati", "ssw_Latn"), + ("Sundanese", "sun_Latn"), + ("Swedish", "swe_Latn"), + ("Swahili", "swh_Latn"), + ("Silesian", "szl_Latn"), + ("Tamil", "tam_Taml"), + ("Tatar", "tat_Cyrl"), + ("Telugu", "tel_Telu"), + ("Tajik", "tgk_Cyrl"), + ("Tagalog", "tgl_Latn"), + ("Thai", "tha_Thai"), + ("Tigrinya", "tir_Ethi"), + ("Tamasheq (Latin script)", "taq_Latn"), + ("Tamasheq (Tifinagh script)", "taq_Tfng"), + ("Tok Pisin", "tpi_Latn"), + ("Tswana", "tsn_Latn"), + ("Tsonga", "tso_Latn"), + ("Turkmen", "tuk_Latn"), + ("Tumbuka", "tum_Latn"), + ("Turkish", "tur_Latn"), + ("Twi", "twi_Latn"), + ("Central Atlas Tamazight", "tzm_Tfng"), + ("Uyghur", "uig_Arab"), + ("Ukrainian", "ukr_Cyrl"), + ("Umbundu", "umb_Latn"), + ("Urdu", "urd_Arab"), + ("Northern Uzbek", "uzn_Latn"), + ("Venetian", "vec_Latn"), + ("Vietnamese", "vie_Latn"), + ("Waray", "war_Latn"), + ("Wolof", "wol_Latn"), + ("Xhosa", "xho_Latn"), + ("Eastern Yiddish", "ydd_Hebr"), + ("Yoruba", "yor_Latn"), + ("Yue Chinese", "yue_Hant"), + ("Chinese (Simplified)", "zho_Hans"), + ("Chinese (Traditional)", 
"zho_Hant"), + ("Standard Malay", "zsm_Latn"), + ("Zulu", "zul_Latn"), +) diff --git a/dl_translate/_translation_model.py b/dl_translate/_translation_model.py index 493b9c8..1e65b16 100644 --- a/dl_translate/_translation_model.py +++ b/dl_translate/_translation_model.py @@ -7,6 +7,7 @@ from tqdm.auto import tqdm from . import utils +from .utils import _infer_model_family, _infer_model_or_path def _select_device(device_selection): @@ -23,28 +24,37 @@ def _select_device(device_selection): return device -def _resolve_lang_codes(source: str, target: str, model_family: str): +def _resolve_lang_codes(lang: str, name: str, model_family: str): def error_message(variable, value): - return f'Your {variable}="{value}" is not valid. Please run `print(mt.available_languages())` to see which languages are available.' + return f'Your {variable}="{value}" is not valid. Please run `print(mt.available_languages())` to see which languages are available. Make sure you are using the correct capital letters.' # If can't find in the lang -> code mapping, assumes it's already a code. 
lang_code_map = utils.get_lang_code_map(model_family) - source = lang_code_map.get(source.capitalize(), source) - target = lang_code_map.get(target.capitalize(), target) + if lang in lang_code_map: + code = lang_code_map[lang] + elif lang.capitalize() in lang_code_map: + code = lang_code_map[lang.capitalize()] + else: + lang_upper = lang.upper() + lang_code_map_upper = {k.upper(): v for k, v in lang_code_map.items()} + + if lang_upper in lang_code_map_upper: + code = lang_code_map_upper[lang_upper] + else: + code = lang # If the code is not valid, raises an error - if source not in utils.available_codes(model_family): - raise ValueError(error_message("source", source)) - if target not in utils.available_codes(model_family): - raise ValueError(error_message("target", target)) + if code not in utils.available_codes(model_family): + raise ValueError(error_message(name, code)) - return source, target + return code def _resolve_tokenizer(model_family): di = { "mbart50": transformers.MBart50TokenizerFast, "m2m100": transformers.M2M100Tokenizer, + "nllb200": transformers.AutoTokenizer, } if model_family in di: return di[model_family] @@ -57,6 +67,7 @@ def _resolve_transformers_model(model_family): di = { "mbart50": transformers.MBartForConditionalGeneration, "m2m100": transformers.M2M100ForConditionalGeneration, + "nllb200": transformers.AutoModelForSeq2SeqLM, } if model_family in di: return di[model_family] @@ -65,31 +76,6 @@ def _resolve_transformers_model(model_family): raise ValueError(error_msg) -def _infer_model_family(model_or_path): - di = { - "facebook/mbart-large-50-many-to-many-mmt": "mbart50", - "facebook/m2m100_418M": "m2m100", - "facebook/m2m100_1.2B": "m2m100", - } - - if model_or_path in di: - return di[model_or_path] - else: - error_msg = f'Unable to infer the model_family from "{model_or_path}". Try explicitly setting the value of model_family to "mbart50" or "m2m100".' 
- raise ValueError(error_msg) - - -def _infer_model_or_path(model_or_path): - di = { - "mbart50": "facebook/mbart-large-50-many-to-many-mmt", - "m2m100": "facebook/m2m100_418M", - "m2m100-small": "facebook/m2m100_418M", - "m2m100-medium": "facebook/m2m100_1.2B", - } - - return di.get(model_or_path, model_or_path) - - class TranslationModel: def __init__( self, @@ -171,7 +157,9 @@ def translate( if generation_options is None: generation_options = {} - source, target = _resolve_lang_codes(source, target, self.model_family) + source = _resolve_lang_codes(source, "source", self.model_family) + target = _resolve_lang_codes(target, "target", self.model_family) + self._tokenizer.src_lang = source original_text_type = type(text) @@ -184,6 +172,7 @@ def translate( generation_options.setdefault( "forced_bos_token_id", self._tokenizer.lang_code_to_id[target] ) + generation_options.setdefault("max_new_tokens", 512) data_loader = torch.utils.data.DataLoader(text, batch_size=batch_size) output_text = [] diff --git a/dl_translate/lang/__init__.py b/dl_translate/lang/__init__.py index 1254611..315e00b 100644 --- a/dl_translate/lang/__init__.py +++ b/dl_translate/lang/__init__.py @@ -1,2 +1,2 @@ from .m2m100 import * -from . import m2m100, mbart50 +from . import m2m100, mbart50, nllb200 diff --git a/dl_translate/lang/nllb200.py b/dl_translate/lang/nllb200.py new file mode 100644 index 0000000..13f29f4 --- /dev/null +++ b/dl_translate/lang/nllb200.py @@ -0,0 +1,205 @@ +# Auto-generated. Do not modify, use scripts/generate_langs.py instead. 
+ACEHNESE_ARABIC_SCRIPT = "Acehnese (Arabic script)" +ACEHNESE_LATIN_SCRIPT = "Acehnese (Latin script)" +MESOPOTAMIAN_ARABIC = "Mesopotamian Arabic" +TAIZZI_ADENI_ARABIC = "Ta'izzi-Adeni Arabic" +TUNISIAN_ARABIC = "Tunisian Arabic" +AFRIKAANS = "Afrikaans" +SOUTH_LEVANTINE_ARABIC = "South Levantine Arabic" +AKAN = "Akan" +AMHARIC = "Amharic" +NORTH_LEVANTINE_ARABIC = "North Levantine Arabic" +MODERN_STANDARD_ARABIC = "Modern Standard Arabic" +MODERN_STANDARD_ARABIC_ROMANIZED = "Modern Standard Arabic (Romanized)" +NAJDI_ARABIC = "Najdi Arabic" +MOROCCAN_ARABIC = "Moroccan Arabic" +EGYPTIAN_ARABIC = "Egyptian Arabic" +ASSAMESE = "Assamese" +ASTURIAN = "Asturian" +AWADHI = "Awadhi" +CENTRAL_AYMARA = "Central Aymara" +SOUTH_AZERBAIJANI = "South Azerbaijani" +NORTH_AZERBAIJANI = "North Azerbaijani" +BASHKIR = "Bashkir" +BAMBARA = "Bambara" +BALINESE = "Balinese" +BELARUSIAN = "Belarusian" +BEMBA = "Bemba" +BENGALI = "Bengali" +BHOJPURI = "Bhojpuri" +BANJAR_ARABIC_SCRIPT = "Banjar (Arabic script)" +BANJAR_LATIN_SCRIPT = "Banjar (Latin script)" +STANDARD_TIBETAN = "Standard Tibetan" +BOSNIAN = "Bosnian" +BUGINESE = "Buginese" +BULGARIAN = "Bulgarian" +CATALAN = "Catalan" +CEBUANO = "Cebuano" +CZECH = "Czech" +CHOKWE = "Chokwe" +CENTRAL_KURDISH = "Central Kurdish" +CRIMEAN_TATAR = "Crimean Tatar" +WELSH = "Welsh" +DANISH = "Danish" +GERMAN = "German" +SOUTHWESTERN_DINKA = "Southwestern Dinka" +DYULA = "Dyula" +DZONGKHA = "Dzongkha" +GREEK = "Greek" +ENGLISH = "English" +ESPERANTO = "Esperanto" +ESTONIAN = "Estonian" +BASQUE = "Basque" +EWE = "Ewe" +FAROESE = "Faroese" +FIJIAN = "Fijian" +FINNISH = "Finnish" +FON = "Fon" +FRENCH = "French" +FRIULIAN = "Friulian" +NIGERIAN_FULFULDE = "Nigerian Fulfulde" +SCOTTISH_GAELIC = "Scottish Gaelic" +IRISH = "Irish" +GALICIAN = "Galician" +GUARANI = "Guarani" +GUJARATI = "Gujarati" +HAITIAN_CREOLE = "Haitian Creole" +HAUSA = "Hausa" +HEBREW = "Hebrew" +HINDI = "Hindi" +CHHATTISGARHI = "Chhattisgarhi" +CROATIAN = "Croatian" +HUNGARIAN 
= "Hungarian" +ARMENIAN = "Armenian" +IGBO = "Igbo" +ILOCANO = "Ilocano" +INDONESIAN = "Indonesian" +ICELANDIC = "Icelandic" +ITALIAN = "Italian" +JAVANESE = "Javanese" +JAPANESE = "Japanese" +KABYLE = "Kabyle" +JINGPHO = "Jingpho" +KAMBA = "Kamba" +KANNADA = "Kannada" +KASHMIRI_ARABIC_SCRIPT = "Kashmiri (Arabic script)" +KASHMIRI_DEVANAGARI_SCRIPT = "Kashmiri (Devanagari script)" +GEORGIAN = "Georgian" +CENTRAL_KANURI_ARABIC_SCRIPT = "Central Kanuri (Arabic script)" +CENTRAL_KANURI_LATIN_SCRIPT = "Central Kanuri (Latin script)" +KAZAKH = "Kazakh" +KABIYÈ = "Kabiyè" +KABUVERDIANU = "Kabuverdianu" +KHMER = "Khmer" +KIKUYU = "Kikuyu" +KINYARWANDA = "Kinyarwanda" +KYRGYZ = "Kyrgyz" +KIMBUNDU = "Kimbundu" +NORTHERN_KURDISH = "Northern Kurdish" +KIKONGO = "Kikongo" +KOREAN = "Korean" +LAO = "Lao" +LIGURIAN = "Ligurian" +LIMBURGISH = "Limburgish" +LINGALA = "Lingala" +LITHUANIAN = "Lithuanian" +LOMBARD = "Lombard" +LATGALIAN = "Latgalian" +LUXEMBOURGISH = "Luxembourgish" +LUBA_KASAI = "Luba-Kasai" +GANDA = "Ganda" +LUO = "Luo" +MIZO = "Mizo" +STANDARD_LATVIAN = "Standard Latvian" +MAGAHI = "Magahi" +MAITHILI = "Maithili" +MALAYALAM = "Malayalam" +MARATHI = "Marathi" +MINANGKABAU_ARABIC_SCRIPT = "Minangkabau (Arabic script)" +MINANGKABAU_LATIN_SCRIPT = "Minangkabau (Latin script)" +MACEDONIAN = "Macedonian" +PLATEAU_MALAGASY = "Plateau Malagasy" +MALTESE = "Maltese" +MEITEI_BENGALI_SCRIPT = "Meitei (Bengali script)" +HALH_MONGOLIAN = "Halh Mongolian" +MOSSI = "Mossi" +MAORI = "Maori" +BURMESE = "Burmese" +DUTCH = "Dutch" +NORWEGIAN_NYNORSK = "Norwegian Nynorsk" +NORWEGIAN_BOKMÅL = "Norwegian Bokmål" +NEPALI = "Nepali" +NORTHERN_SOTHO = "Northern Sotho" +NUER = "Nuer" +NYANJA = "Nyanja" +OCCITAN = "Occitan" +WEST_CENTRAL_OROMO = "West Central Oromo" +ODIA = "Odia" +PANGASINAN = "Pangasinan" +EASTERN_PANJABI = "Eastern Panjabi" +PAPIAMENTO = "Papiamento" +WESTERN_PERSIAN = "Western Persian" +POLISH = "Polish" +PORTUGUESE = "Portuguese" +DARI = "Dari" +SOUTHERN_PASHTO = 
"Southern Pashto" +AYACUCHO_QUECHUA = "Ayacucho Quechua" +ROMANIAN = "Romanian" +RUNDI = "Rundi" +RUSSIAN = "Russian" +SANGO = "Sango" +SANSKRIT = "Sanskrit" +SANTALI = "Santali" +SICILIAN = "Sicilian" +SHAN = "Shan" +SINHALA = "Sinhala" +SLOVAK = "Slovak" +SLOVENIAN = "Slovenian" +SAMOAN = "Samoan" +SHONA = "Shona" +SINDHI = "Sindhi" +SOMALI = "Somali" +SOUTHERN_SOTHO = "Southern Sotho" +SPANISH = "Spanish" +TOSK_ALBANIAN = "Tosk Albanian" +SARDINIAN = "Sardinian" +SERBIAN = "Serbian" +SWATI = "Swati" +SUNDANESE = "Sundanese" +SWEDISH = "Swedish" +SWAHILI = "Swahili" +SILESIAN = "Silesian" +TAMIL = "Tamil" +TATAR = "Tatar" +TELUGU = "Telugu" +TAJIK = "Tajik" +TAGALOG = "Tagalog" +THAI = "Thai" +TIGRINYA = "Tigrinya" +TAMASHEQ_LATIN_SCRIPT = "Tamasheq (Latin script)" +TAMASHEQ_TIFINAGH_SCRIPT = "Tamasheq (Tifinagh script)" +TOK_PISIN = "Tok Pisin" +TSWANA = "Tswana" +TSONGA = "Tsonga" +TURKMEN = "Turkmen" +TUMBUKA = "Tumbuka" +TURKISH = "Turkish" +TWI = "Twi" +CENTRAL_ATLAS_TAMAZIGHT = "Central Atlas Tamazight" +UYGHUR = "Uyghur" +UKRAINIAN = "Ukrainian" +UMBUNDU = "Umbundu" +URDU = "Urdu" +NORTHERN_UZBEK = "Northern Uzbek" +VENETIAN = "Venetian" +VIETNAMESE = "Vietnamese" +WARAY = "Waray" +WOLOF = "Wolof" +XHOSA = "Xhosa" +EASTERN_YIDDISH = "Eastern Yiddish" +YORUBA = "Yoruba" +YUE_CHINESE = "Yue Chinese" +CHINESE_SIMPLIFIED = "Chinese (Simplified)" +CHINESE_TRADITIONAL = "Chinese (Traditional)" +STANDARD_MALAY = "Standard Malay" +ZULU = "Zulu" diff --git a/dl_translate/utils.py b/dl_translate/utils.py index e4af390..acb26df 100644 --- a/dl_translate/utils.py +++ b/dl_translate/utils.py @@ -1,6 +1,40 @@ from typing import Dict, List -from ._pairs import _PAIRS_MBART50, _PAIRS_M2M100 +from ._pairs import _PAIRS_MBART50, _PAIRS_M2M100, _PAIRS_NLLB200 + + +def _infer_model_family(model_or_path): + di = { + "facebook/mbart-large-50-many-to-many-mmt": "mbart50", + "facebook/m2m100_418M": "m2m100", + "facebook/m2m100_1.2B": "m2m100", + 
"facebook/nllb-200-distilled-600M": "nllb200", + "facebook/nllb-200-distilled-1.3B": "nllb200", + "facebook/nllb-200-1.3B": "nllb200", + "facebook/nllb-200-3.3B": "nllb200", + } + + if model_or_path in di: + return di[model_or_path] + else: + error_msg = f'Unable to infer the model_family from "{model_or_path}". Try explicitly setting the value of model_family to "mbart50", "m2m100", or "nllb200".' + raise ValueError(error_msg) + + +def _infer_model_or_path(model_or_path): + di = { + "mbart50": "facebook/mbart-large-50-many-to-many-mmt", + "m2m100": "facebook/m2m100_418M", + "m2m100-small": "facebook/m2m100_418M", + "m2m100-medium": "facebook/m2m100_1.2B", + "nllb200": "facebook/nllb-200-distilled-600M", + "nllb200-small": "facebook/nllb-200-distilled-600M", + "nllb200-medium": "facebook/nllb-200-distilled-1.3B", + "nllb200-medium-regular": "facebook/nllb-200-1.3B", + "nllb200-large": "facebook/nllb-200-3.3B", + } + + return di.get(model_or_path, model_or_path) @@ -13,6 +47,16 @@ def _weights2pairs(): "m2m100_1.2B": _PAIRS_M2M100, "facebook/m2m100_418M": _PAIRS_M2M100, "facebook/m2m100_1.2B": _PAIRS_M2M100, + "nllb200": _PAIRS_NLLB200, + "nllb-200-distilled": _PAIRS_NLLB200, + "nllb-200-distilled-600M": _PAIRS_NLLB200, + "nllb-200-distilled-1.3B": _PAIRS_NLLB200, + "nllb-200-1.3B": _PAIRS_NLLB200, + "nllb-200-3.3B": _PAIRS_NLLB200, + "facebook/nllb-200-distilled-600M": _PAIRS_NLLB200, + "facebook/nllb-200-distilled-1.3B": _PAIRS_NLLB200, + "facebook/nllb-200-1.3B": _PAIRS_NLLB200, + "facebook/nllb-200-3.3B": _PAIRS_NLLB200, } @@ -38,7 +82,7 @@ def _dict_from_weights(weights: str) -> dict: raise ValueError(error_message) -def get_lang_code_map(weights: str = "mbart50") -> Dict[str, str]: +def get_lang_code_map(weights: str = "m2m100") -> Dict[str, str]: """ *Get a dictionary mapping a language -> code for a given model. 
The code will depend on the model you choose.* @@ -48,7 +92,7 @@ def get_lang_code_map(weights: str = "mbart50") -> Dict[str, str]: return _dict_from_weights(weights)["pairs"] -def available_languages(weights: str = "mbart50") -> List[str]: +def available_languages(weights: str = "m2m100") -> List[str]: """ *Get all the languages available for a given model.* @@ -58,7 +102,7 @@ def available_languages(weights: str = "mbart50") -> List[str]: return _dict_from_weights(weights)["langs"] -def available_codes(weights: str = "mbart50") -> List[str]: +def available_codes(weights: str = "m2m100") -> List[str]: """ *Get all the codes available for a given model. The code format will depend on the model you select.* diff --git a/docs/available_languages.md b/docs/available_languages.md index 3490157..fa4f7c0 100644 --- a/docs/available_languages.md +++ b/docs/available_languages.md @@ -4,170 +4,384 @@ This page gives all the languages available for each model family. ## MBart 50 -- Arabic (ar_AR) -- Czech (cs_CZ) -- German (de_DE) -- English (en_XX) -- Spanish (es_XX) -- Estonian (et_EE) -- Finnish (fi_FI) -- French (fr_XX) -- Gujarati (gu_IN) -- Hindi (hi_IN) -- Italian (it_IT) -- Japanese (ja_XX) -- Kazakh (kk_KZ) -- Korean (ko_KR) -- Lithuanian (lt_LT) -- Latvian (lv_LV) -- Burmese (my_MM) -- Nepali (ne_NP) -- Dutch (nl_XX) -- Romanian (ro_RO) -- Russian (ru_RU) -- Sinhala (si_LK) -- Turkish (tr_TR) -- Vietnamese (vi_VN) -- Chinese (zh_CN) -- Afrikaans (af_ZA) -- Azerbaijani (az_AZ) -- Bengali (bn_IN) -- Persian (fa_IR) -- Hebrew (he_IL) -- Croatian (hr_HR) -- Indonesian (id_ID) -- Georgian (ka_GE) -- Khmer (km_KH) -- Macedonian (mk_MK) -- Malayalam (ml_IN) -- Mongolian (mn_MN) -- Marathi (mr_IN) -- Polish (pl_PL) -- Pashto (ps_AF) -- Portuguese (pt_XX) -- Swedish (sv_SE) -- Swahili (sw_KE) -- Tamil (ta_IN) -- Telugu (te_IN) -- Thai (th_TH) -- Tagalog (tl_XX) -- Ukrainian (uk_UA) -- Urdu (ur_PK) -- Xhosa (xh_ZA) -- Galician (gl_ES) -- Slovene (sl_SI) +| Language Name | 
Code | +| --- | --- | +| Arabic | ar_AR | +| Czech | cs_CZ | +| German | de_DE | +| English | en_XX | +| Spanish | es_XX | +| Estonian | et_EE | +| Finnish | fi_FI | +| French | fr_XX | +| Gujarati | gu_IN | +| Hindi | hi_IN | +| Italian | it_IT | +| Japanese | ja_XX | +| Kazakh | kk_KZ | +| Korean | ko_KR | +| Lithuanian | lt_LT | +| Latvian | lv_LV | +| Burmese | my_MM | +| Nepali | ne_NP | +| Dutch | nl_XX | +| Romanian | ro_RO | +| Russian | ru_RU | +| Sinhala | si_LK | +| Turkish | tr_TR | +| Vietnamese | vi_VN | +| Chinese | zh_CN | +| Afrikaans | af_ZA | +| Azerbaijani | az_AZ | +| Bengali | bn_IN | +| Persian | fa_IR | +| Hebrew | he_IL | +| Croatian | hr_HR | +| Indonesian | id_ID | +| Georgian | ka_GE | +| Khmer | km_KH | +| Macedonian | mk_MK | +| Malayalam | ml_IN | +| Mongolian | mn_MN | +| Marathi | mr_IN | +| Polish | pl_PL | +| Pashto | ps_AF | +| Portuguese | pt_XX | +| Swedish | sv_SE | +| Swahili | sw_KE | +| Tamil | ta_IN | +| Telugu | te_IN | +| Thai | th_TH | +| Tagalog | tl_XX | +| Ukrainian | uk_UA | +| Urdu | ur_PK | +| Xhosa | xh_ZA | +| Galician | gl_ES | +| Slovene | sl_SI | ## M2M-100 -- Afrikaans (af) -- Amharic (am) -- Arabic (ar) -- Asturian (ast) -- Azerbaijani (az) -- Bashkir (ba) -- Belarusian (be) -- Bulgarian (bg) -- Bengali (bn) -- Breton (br) -- Bosnian (bs) -- Catalan (ca) -- Valencian (ca) -- Cebuano (ceb) -- Czech (cs) -- Welsh (cy) -- Danish (da) -- German (de) -- Greek (el) -- English (en) -- Spanish (es) -- Estonian (et) -- Persian (fa) -- Fulah (ff) -- Finnish (fi) -- French (fr) -- Western Frisian (fy) -- Irish (ga) -- Gaelic (gd) -- Scottish Gaelic (gd) -- Galician (gl) -- Gujarati (gu) -- Hausa (ha) -- Hebrew (he) -- Hindi (hi) -- Croatian (hr) -- Haitian (ht) -- Haitian Creole (ht) -- Hungarian (hu) -- Armenian (hy) -- Indonesian (id) -- Igbo (ig) -- Iloko (ilo) -- Icelandic (is) -- Italian (it) -- Japanese (ja) -- Javanese (jv) -- Georgian (ka) -- Kazakh (kk) -- Khmer (km) -- Central Khmer (km) -- Kannada (kn) -- 
Korean (ko) -- Luxembourgish (lb) -- Letzeburgesch (lb) -- Ganda (lg) -- Lingala (ln) -- Lao (lo) -- Lithuanian (lt) -- Latvian (lv) -- Malagasy (mg) -- Macedonian (mk) -- Malayalam (ml) -- Mongolian (mn) -- Marathi (mr) -- Malay (ms) -- Burmese (my) -- Nepali (ne) -- Dutch (nl) -- Flemish (nl) -- Norwegian (no) -- Northern Sotho (ns) -- Occitan (oc) -- Oriya (or) -- Panjabi (pa) -- Punjabi (pa) -- Polish (pl) -- Pushto (ps) -- Pashto (ps) -- Portuguese (pt) -- Romanian (ro) -- Moldavian (ro) -- Moldovan (ro) -- Russian (ru) -- Sindhi (sd) -- Sinhala (si) -- Sinhalese (si) -- Slovak (sk) -- Slovenian (sl) -- Somali (so) -- Albanian (sq) -- Serbian (sr) -- Swati (ss) -- Sundanese (su) -- Swedish (sv) -- Swahili (sw) -- Tamil (ta) -- Thai (th) -- Tagalog (tl) -- Tswana (tn) -- Turkish (tr) -- Ukrainian (uk) -- Urdu (ur) -- Uzbek (uz) -- Vietnamese (vi) -- Wolof (wo) -- Xhosa (xh) -- Yiddish (yi) -- Yoruba (yo) -- Chinese (zh) -- Zulu (zu) +| Language Name | Code | +| --- | --- | +| Afrikaans | af | +| Amharic | am | +| Arabic | ar | +| Asturian | ast | +| Azerbaijani | az | +| Bashkir | ba | +| Belarusian | be | +| Bulgarian | bg | +| Bengali | bn | +| Breton | br | +| Bosnian | bs | +| Catalan | ca | +| Valencian | ca | +| Cebuano | ceb | +| Czech | cs | +| Welsh | cy | +| Danish | da | +| German | de | +| Greek | el | +| English | en | +| Spanish | es | +| Estonian | et | +| Persian | fa | +| Fulah | ff | +| Finnish | fi | +| French | fr | +| Western Frisian | fy | +| Irish | ga | +| Gaelic | gd | +| Scottish Gaelic | gd | +| Galician | gl | +| Gujarati | gu | +| Hausa | ha | +| Hebrew | he | +| Hindi | hi | +| Croatian | hr | +| Haitian | ht | +| Haitian Creole | ht | +| Hungarian | hu | +| Armenian | hy | +| Indonesian | id | +| Igbo | ig | +| Iloko | ilo | +| Icelandic | is | +| Italian | it | +| Japanese | ja | +| Javanese | jv | +| Georgian | ka | +| Kazakh | kk | +| Khmer | km | +| Central Khmer | km | +| Kannada | kn | +| Korean | ko | +| Luxembourgish | lb 
| +| Letzeburgesch | lb | +| Ganda | lg | +| Lingala | ln | +| Lao | lo | +| Lithuanian | lt | +| Latvian | lv | +| Malagasy | mg | +| Macedonian | mk | +| Malayalam | ml | +| Mongolian | mn | +| Marathi | mr | +| Malay | ms | +| Burmese | my | +| Nepali | ne | +| Dutch | nl | +| Flemish | nl | +| Norwegian | no | +| Northern Sotho | ns | +| Occitan | oc | +| Oriya | or | +| Panjabi | pa | +| Punjabi | pa | +| Polish | pl | +| Pushto | ps | +| Pashto | ps | +| Portuguese | pt | +| Romanian | ro | +| Moldavian | ro | +| Moldovan | ro | +| Russian | ru | +| Sindhi | sd | +| Sinhala | si | +| Sinhalese | si | +| Slovak | sk | +| Slovenian | sl | +| Somali | so | +| Albanian | sq | +| Serbian | sr | +| Swati | ss | +| Sundanese | su | +| Swedish | sv | +| Swahili | sw | +| Tamil | ta | +| Thai | th | +| Tagalog | tl | +| Tswana | tn | +| Turkish | tr | +| Ukrainian | uk | +| Urdu | ur | +| Uzbek | uz | +| Vietnamese | vi | +| Wolof | wo | +| Xhosa | xh | +| Yiddish | yi | +| Yoruba | yo | +| Chinese | zh | +| Zulu | zu | + + +## NLLB-200 + +| Language Name | Code | +| --- | --- | +| Acehnese (Arabic script) | ace_Arab | +| Acehnese (Latin script) | ace_Latn | +| Mesopotamian Arabic | acm_Arab | +| Ta'izzi-Adeni Arabic | acq_Arab | +| Tunisian Arabic | aeb_Arab | +| Afrikaans | afr_Latn | +| South Levantine Arabic | ajp_Arab | +| Akan | aka_Latn | +| Amharic | amh_Ethi | +| North Levantine Arabic | apc_Arab | +| Modern Standard Arabic | arb_Arab | +| Modern Standard Arabic (Romanized) | arb_Latn | +| Najdi Arabic | ars_Arab | +| Moroccan Arabic | ary_Arab | +| Egyptian Arabic | arz_Arab | +| Assamese | asm_Beng | +| Asturian | ast_Latn | +| Awadhi | awa_Deva | +| Central Aymara | ayr_Latn | +| South Azerbaijani | azb_Arab | +| North Azerbaijani | azj_Latn | +| Bashkir | bak_Cyrl | +| Bambara | bam_Latn | +| Balinese | ban_Latn | +| Belarusian | bel_Cyrl | +| Bemba | bem_Latn | +| Bengali | ben_Beng | +| Bhojpuri | bho_Deva | +| Banjar (Arabic script) | bjn_Arab | +| 
Banjar (Latin script) | bjn_Latn | +| Standard Tibetan | bod_Tibt | +| Bosnian | bos_Latn | +| Buginese | bug_Latn | +| Bulgarian | bul_Cyrl | +| Catalan | cat_Latn | +| Cebuano | ceb_Latn | +| Czech | ces_Latn | +| Chokwe | cjk_Latn | +| Central Kurdish | ckb_Arab | +| Crimean Tatar | crh_Latn | +| Welsh | cym_Latn | +| Danish | dan_Latn | +| German | deu_Latn | +| Southwestern Dinka | dik_Latn | +| Dyula | dyu_Latn | +| Dzongkha | dzo_Tibt | +| Greek | ell_Grek | +| English | eng_Latn | +| Esperanto | epo_Latn | +| Estonian | est_Latn | +| Basque | eus_Latn | +| Ewe | ewe_Latn | +| Faroese | fao_Latn | +| Fijian | fij_Latn | +| Finnish | fin_Latn | +| Fon | fon_Latn | +| French | fra_Latn | +| Friulian | fur_Latn | +| Nigerian Fulfulde | fuv_Latn | +| Scottish Gaelic | gla_Latn | +| Irish | gle_Latn | +| Galician | glg_Latn | +| Guarani | grn_Latn | +| Gujarati | guj_Gujr | +| Haitian Creole | hat_Latn | +| Hausa | hau_Latn | +| Hebrew | heb_Hebr | +| Hindi | hin_Deva | +| Chhattisgarhi | hne_Deva | +| Croatian | hrv_Latn | +| Hungarian | hun_Latn | +| Armenian | hye_Armn | +| Igbo | ibo_Latn | +| Ilocano | ilo_Latn | +| Indonesian | ind_Latn | +| Icelandic | isl_Latn | +| Italian | ita_Latn | +| Javanese | jav_Latn | +| Japanese | jpn_Jpan | +| Kabyle | kab_Latn | +| Jingpho | kac_Latn | +| Kamba | kam_Latn | +| Kannada | kan_Knda | +| Kashmiri (Arabic script) | kas_Arab | +| Kashmiri (Devanagari script) | kas_Deva | +| Georgian | kat_Geor | +| Central Kanuri (Arabic script) | knc_Arab | +| Central Kanuri (Latin script) | knc_Latn | +| Kazakh | kaz_Cyrl | +| Kabiyè | kbp_Latn | +| Kabuverdianu | kea_Latn | +| Khmer | khm_Khmr | +| Kikuyu | kik_Latn | +| Kinyarwanda | kin_Latn | +| Kyrgyz | kir_Cyrl | +| Kimbundu | kmb_Latn | +| Northern Kurdish | kmr_Latn | +| Kikongo | kon_Latn | +| Korean | kor_Hang | +| Lao | lao_Laoo | +| Ligurian | lij_Latn | +| Limburgish | lim_Latn | +| Lingala | lin_Latn | +| Lithuanian | lit_Latn | +| Lombard | lmo_Latn | +| Latgalian | 
ltg_Latn | +| Luxembourgish | ltz_Latn | +| Luba-Kasai | lua_Latn | +| Ganda | lug_Latn | +| Luo | luo_Latn | +| Mizo | lus_Latn | +| Standard Latvian | lvs_Latn | +| Magahi | mag_Deva | +| Maithili | mai_Deva | +| Malayalam | mal_Mlym | +| Marathi | mar_Deva | +| Minangkabau (Arabic script) | min_Arab | +| Minangkabau (Latin script) | min_Latn | +| Macedonian | mkd_Cyrl | +| Plateau Malagasy | plt_Latn | +| Maltese | mlt_Latn | +| Meitei (Bengali script) | mni_Beng | +| Halh Mongolian | khk_Cyrl | +| Mossi | mos_Latn | +| Maori | mri_Latn | +| Burmese | mya_Mymr | +| Dutch | nld_Latn | +| Norwegian Nynorsk | nno_Latn | +| Norwegian Bokmål | nob_Latn | +| Nepali | npi_Deva | +| Northern Sotho | nso_Latn | +| Nuer | nus_Latn | +| Nyanja | nya_Latn | +| Occitan | oci_Latn | +| West Central Oromo | gaz_Latn | +| Odia | ory_Orya | +| Pangasinan | pag_Latn | +| Eastern Panjabi | pan_Guru | +| Papiamento | pap_Latn | +| Western Persian | pes_Arab | +| Polish | pol_Latn | +| Portuguese | por_Latn | +| Dari | prs_Arab | +| Southern Pashto | pbt_Arab | +| Ayacucho Quechua | quy_Latn | +| Romanian | ron_Latn | +| Rundi | run_Latn | +| Russian | rus_Cyrl | +| Sango | sag_Latn | +| Sanskrit | san_Deva | +| Santali | sat_Olck | +| Sicilian | scn_Latn | +| Shan | shn_Mymr | +| Sinhala | sin_Sinh | +| Slovak | slk_Latn | +| Slovenian | slv_Latn | +| Samoan | smo_Latn | +| Shona | sna_Latn | +| Sindhi | snd_Arab | +| Somali | som_Latn | +| Southern Sotho | sot_Latn | +| Spanish | spa_Latn | +| Tosk Albanian | als_Latn | +| Sardinian | srd_Latn | +| Serbian | srp_Cyrl | +| Swati | ssw_Latn | +| Sundanese | sun_Latn | +| Swedish | swe_Latn | +| Swahili | swh_Latn | +| Silesian | szl_Latn | +| Tamil | tam_Taml | +| Tatar | tat_Cyrl | +| Telugu | tel_Telu | +| Tajik | tgk_Cyrl | +| Tagalog | tgl_Latn | +| Thai | tha_Thai | +| Tigrinya | tir_Ethi | +| Tamasheq (Latin script) | taq_Latn | +| Tamasheq (Tifinagh script) | taq_Tfng | +| Tok Pisin | tpi_Latn | +| Tswana | tsn_Latn | +| 
Tsonga | tso_Latn | +| Turkmen | tuk_Latn | +| Tumbuka | tum_Latn | +| Turkish | tur_Latn | +| Twi | twi_Latn | +| Central Atlas Tamazight | tzm_Tfng | +| Uyghur | uig_Arab | +| Ukrainian | ukr_Cyrl | +| Umbundu | umb_Latn | +| Urdu | urd_Arab | +| Northern Uzbek | uzn_Latn | +| Venetian | vec_Latn | +| Vietnamese | vie_Latn | +| Waray | war_Latn | +| Wolof | wol_Latn | +| Xhosa | xho_Latn | +| Eastern Yiddish | ydd_Hebr | +| Yoruba | yor_Latn | +| Yue Chinese | yue_Hant | +| Chinese (Simplified) | zho_Hans | +| Chinese (Traditional) | zho_Hant | +| Standard Malay | zsm_Latn | +| Zulu | zul_Latn | diff --git a/docs/index.md b/docs/index.md index 70a023d..7a1bb67 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,7 +3,7 @@ Quick links: 💻 [GitHub Repository](https://github.com/xhlulu/dl-translate)
-📚 [Documentation](https://xhluca.github.io/dl-translate) / [Readthedocs](https://dl-translate.readthedocs.io)
+📚 [Documentation](https://xhluca.github.io/dl-translate)
🐍 [PyPi project](https://pypi.org/project/dl-translate/)
🧪 [Colab Demo](https://colab.research.google.com/github/xhlulu/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/) @@ -53,43 +53,41 @@ mt = dlt.TranslationModel(device="gpu") # Force you to use a GPU mt = dlt.TranslationModel(device="cuda:2") # Use the 3rd GPU available ``` -### Changing the model you are loading +### Choosing a different model -Two model families are available at the moment: [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html) and [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html), which respective allow translation across over 100 languages and 50 languages. By default, the model will select `m2m100`, but you can also explicitly choose the model by specifying the shorthand (`"m2m100"` or `"mbart50"`) or the full repository name (e.g. `"facebook/m2m100_418M"`). For example: +By default, the `m2m100` model will be used. However, there are a few options: +* [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages. +* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages. +* [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (On RTX A6000, we can see speed up of 3x). 
+ +Here's an example: ```python -# The following ways are equivalent -mt = dlt.TranslationModel("m2m100") # Default -mt = dlt.TranslationModel("facebook/m2m100_418M") +# The default approval +mt = dlt.TranslationModel("m2m100") # Shorthand +mt = dlt.TranslationModel("facebook/m2m100_418M") # Huggingface repo -# The following ways are equivalent +# If you want to use mBART-50 Large mt = dlt.TranslationModel("mbart50") mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt") + +# Or NLLB-200 (faster and has 200 languages) +mt = dlt.TranslationModel("nllb200") +mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M") ``` Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`. -### Loading from a path - -By default, `dlt.TranslationModel` will download the model from the [huggingface repo](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt) and cache it. If your model is stored locally, you can also directly load that model, but in that case you will need to specify the model family (e.g. `"mbart50"` and `"m2m100"`). - +By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`: ```python mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50") -# or -mt = dlt.TranslationModel("/path/to/model/directory/", model_family="m2m100") -``` -Make sure that your tokenizer is also stored in the same directory if you use this approach. 
- -### Using a different model - -You can also choose another model that has the same format as [mbart50](https://huggingface.co/models?filter=mbart-50) or [m2m100](https://huggingface.co/models?search=facebook/m2m100) e.g. -```python -mt = dlt.TranslationModel("facebook/mbart-large-50-one-to-many-mmt", model_family="mbart50") -# or mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100") +mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200") ``` -Note that the available languages will change if you do this, so you will not be able to leverage `dlt.lang` or `dlt.utils` and the `mt.available_languages()` might also return the incorrect value. +Notes: +* Make sure your tokenizer is also stored in the same directory if you load from a file. +* The available languages will change if you select a different model, so you will not be able to leverage `dlt.lang` or `dlt.utils`. ### Breaking down into sentences @@ -143,6 +141,7 @@ print(dlt.utils.available_languages('m2m100')) # write the name of the model fa At the moment, the following models are accepted: - `"mbart50"` - `"m2m100"` +- `"nllb200"` ### Offline usage diff --git a/scripts/generate_langs.py b/scripts/generate_langs.py index fcc8fed..cd9d941 100644 --- a/scripts/generate_langs.py +++ b/scripts/generate_langs.py @@ -2,6 +2,17 @@ import os +def name_to_var(lang_name): + return ( + lang_name.upper() + .replace(" ", "_") + .replace("(", "") + .replace(")", "") + .replace("-", "_") + .replace("'", "") + ) + + def load_json(name): filepath = os.path.join(os.path.dirname(__file__), "langs_coverage", f"{name}.json") return json.loads(open(filepath).read()) @@ -11,13 +22,13 @@ def load_json(name): name2json = {} -for name in ["m2m100", "mbart50"]: +for name in ["m2m100", "mbart50", "nllb200"]: name2json[name] = lang2code = load_json(name) with open(f"./dl_translate/lang/{name}.py", "w") as f: f.write(auto_gen_comment) for lang, code in lang2code.items(): - 
f.write(f'{lang.upper().replace(" ", "_")} = "{lang}"\n') + f.write(f'{name_to_var(lang)} = "{lang}"\n') with open("./dl_translate/_pairs.py", "w") as f: diff --git a/scripts/langs_coverage/nllb200.json b/scripts/langs_coverage/nllb200.json new file mode 100644 index 0000000..fe7b95e --- /dev/null +++ b/scripts/langs_coverage/nllb200.json @@ -0,0 +1,206 @@ +{ + "Acehnese (Arabic script)": "ace_Arab", + "Acehnese (Latin script)": "ace_Latn", + "Mesopotamian Arabic": "acm_Arab", + "Ta'izzi-Adeni Arabic": "acq_Arab", + "Tunisian Arabic": "aeb_Arab", + "Afrikaans": "afr_Latn", + "South Levantine Arabic": "ajp_Arab", + "Akan": "aka_Latn", + "Amharic": "amh_Ethi", + "North Levantine Arabic": "apc_Arab", + "Modern Standard Arabic": "arb_Arab", + "Modern Standard Arabic (Romanized)": "arb_Latn", + "Najdi Arabic": "ars_Arab", + "Moroccan Arabic": "ary_Arab", + "Egyptian Arabic": "arz_Arab", + "Assamese": "asm_Beng", + "Asturian": "ast_Latn", + "Awadhi": "awa_Deva", + "Central Aymara": "ayr_Latn", + "South Azerbaijani": "azb_Arab", + "North Azerbaijani": "azj_Latn", + "Bashkir": "bak_Cyrl", + "Bambara": "bam_Latn", + "Balinese": "ban_Latn", + "Belarusian": "bel_Cyrl", + "Bemba": "bem_Latn", + "Bengali": "ben_Beng", + "Bhojpuri": "bho_Deva", + "Banjar (Arabic script)": "bjn_Arab", + "Banjar (Latin script)": "bjn_Latn", + "Standard Tibetan": "bod_Tibt", + "Bosnian": "bos_Latn", + "Buginese": "bug_Latn", + "Bulgarian": "bul_Cyrl", + "Catalan": "cat_Latn", + "Cebuano": "ceb_Latn", + "Czech": "ces_Latn", + "Chokwe": "cjk_Latn", + "Central Kurdish": "ckb_Arab", + "Crimean Tatar": "crh_Latn", + "Welsh": "cym_Latn", + "Danish": "dan_Latn", + "German": "deu_Latn", + "Southwestern Dinka": "dik_Latn", + "Dyula": "dyu_Latn", + "Dzongkha": "dzo_Tibt", + "Greek": "ell_Grek", + "English": "eng_Latn", + "Esperanto": "epo_Latn", + "Estonian": "est_Latn", + "Basque": "eus_Latn", + "Ewe": "ewe_Latn", + "Faroese": "fao_Latn", + "Fijian": "fij_Latn", + "Finnish": "fin_Latn", + "Fon": "fon_Latn", 
+ "French": "fra_Latn", + "Friulian": "fur_Latn", + "Nigerian Fulfulde": "fuv_Latn", + "Scottish Gaelic": "gla_Latn", + "Irish": "gle_Latn", + "Galician": "glg_Latn", + "Guarani": "grn_Latn", + "Gujarati": "guj_Gujr", + "Haitian Creole": "hat_Latn", + "Hausa": "hau_Latn", + "Hebrew": "heb_Hebr", + "Hindi": "hin_Deva", + "Chhattisgarhi": "hne_Deva", + "Croatian": "hrv_Latn", + "Hungarian": "hun_Latn", + "Armenian": "hye_Armn", + "Igbo": "ibo_Latn", + "Ilocano": "ilo_Latn", + "Indonesian": "ind_Latn", + "Icelandic": "isl_Latn", + "Italian": "ita_Latn", + "Javanese": "jav_Latn", + "Japanese": "jpn_Jpan", + "Kabyle": "kab_Latn", + "Jingpho": "kac_Latn", + "Kamba": "kam_Latn", + "Kannada": "kan_Knda", + "Kashmiri (Arabic script)": "kas_Arab", + "Kashmiri (Devanagari script)": "kas_Deva", + "Georgian": "kat_Geor", + "Central Kanuri (Arabic script)": "knc_Arab", + "Central Kanuri (Latin script)": "knc_Latn", + "Kazakh": "kaz_Cyrl", + "Kabiyè": "kbp_Latn", + "Kabuverdianu": "kea_Latn", + "Khmer": "khm_Khmr", + "Kikuyu": "kik_Latn", + "Kinyarwanda": "kin_Latn", + "Kyrgyz": "kir_Cyrl", + "Kimbundu": "kmb_Latn", + "Northern Kurdish": "kmr_Latn", + "Kikongo": "kon_Latn", + "Korean": "kor_Hang", + "Lao": "lao_Laoo", + "Ligurian": "lij_Latn", + "Limburgish": "lim_Latn", + "Lingala": "lin_Latn", + "Lithuanian": "lit_Latn", + "Lombard": "lmo_Latn", + "Latgalian": "ltg_Latn", + "Luxembourgish": "ltz_Latn", + "Luba-Kasai": "lua_Latn", + "Ganda": "lug_Latn", + "Luo": "luo_Latn", + "Mizo": "lus_Latn", + "Standard Latvian": "lvs_Latn", + "Magahi": "mag_Deva", + "Maithili": "mai_Deva", + "Malayalam": "mal_Mlym", + "Marathi": "mar_Deva", + "Minangkabau (Arabic script)": "min_Arab", + "Minangkabau (Latin script)": "min_Latn", + "Macedonian": "mkd_Cyrl", + "Plateau Malagasy": "plt_Latn", + "Maltese": "mlt_Latn", + "Meitei (Bengali script)": "mni_Beng", + "Halh Mongolian": "khk_Cyrl", + "Mossi": "mos_Latn", + "Maori": "mri_Latn", + "Burmese": "mya_Mymr", + "Dutch": "nld_Latn", + "Norwegian 
Nynorsk": "nno_Latn", + "Norwegian Bokmål": "nob_Latn", + "Nepali": "npi_Deva", + "Northern Sotho": "nso_Latn", + "Nuer": "nus_Latn", + "Nyanja": "nya_Latn", + "Occitan": "oci_Latn", + "West Central Oromo": "gaz_Latn", + "Odia": "ory_Orya", + "Pangasinan": "pag_Latn", + "Eastern Panjabi": "pan_Guru", + "Papiamento": "pap_Latn", + "Western Persian": "pes_Arab", + "Polish": "pol_Latn", + "Portuguese": "por_Latn", + "Dari": "prs_Arab", + "Southern Pashto": "pbt_Arab", + "Ayacucho Quechua": "quy_Latn", + "Romanian": "ron_Latn", + "Rundi": "run_Latn", + "Russian": "rus_Cyrl", + "Sango": "sag_Latn", + "Sanskrit": "san_Deva", + "Santali": "sat_Olck", + "Sicilian": "scn_Latn", + "Shan": "shn_Mymr", + "Sinhala": "sin_Sinh", + "Slovak": "slk_Latn", + "Slovenian": "slv_Latn", + "Samoan": "smo_Latn", + "Shona": "sna_Latn", + "Sindhi": "snd_Arab", + "Somali": "som_Latn", + "Southern Sotho": "sot_Latn", + "Spanish": "spa_Latn", + "Tosk Albanian": "als_Latn", + "Sardinian": "srd_Latn", + "Serbian": "srp_Cyrl", + "Swati": "ssw_Latn", + "Sundanese": "sun_Latn", + "Swedish": "swe_Latn", + "Swahili": "swh_Latn", + "Silesian": "szl_Latn", + "Tamil": "tam_Taml", + "Tatar": "tat_Cyrl", + "Telugu": "tel_Telu", + "Tajik": "tgk_Cyrl", + "Tagalog": "tgl_Latn", + "Thai": "tha_Thai", + "Tigrinya": "tir_Ethi", + "Tamasheq (Latin script)": "taq_Latn", + "Tamasheq (Tifinagh script)": "taq_Tfng", + "Tok Pisin": "tpi_Latn", + "Tswana": "tsn_Latn", + "Tsonga": "tso_Latn", + "Turkmen": "tuk_Latn", + "Tumbuka": "tum_Latn", + "Turkish": "tur_Latn", + "Twi": "twi_Latn", + "Central Atlas Tamazight": "tzm_Tfng", + "Uyghur": "uig_Arab", + "Ukrainian": "ukr_Cyrl", + "Umbundu": "umb_Latn", + "Urdu": "urd_Arab", + "Northern Uzbek": "uzn_Latn", + "Venetian": "vec_Latn", + "Vietnamese": "vie_Latn", + "Waray": "war_Latn", + "Wolof": "wol_Latn", + "Xhosa": "xho_Latn", + "Eastern Yiddish": "ydd_Hebr", + "Yoruba": "yor_Latn", + "Yue Chinese": "yue_Hant", + "Chinese (Simplified)": "zho_Hans", + "Chinese 
(Traditional)": "zho_Hant", + "Standard Malay": "zsm_Latn", + "Zulu": "zul_Latn" +} \ No newline at end of file diff --git a/scripts/render_available_langs.py b/scripts/render_available_langs.py index 13e3ea6..7ac6101 100644 --- a/scripts/render_available_langs.py +++ b/scripts/render_available_langs.py @@ -10,12 +10,14 @@ def load_json(name): template_values = {} -for name in ["m2m100", "mbart50"]: +for name in ["m2m100", "mbart50", "nllb200"]: content = "" di = load_json(name) + content += "| Language Name | Code |\n" + content += "| --- | --- |\n" for key, val in di.items(): - content += f"- {key} ({val})\n" + content += f"| {key} | {val} |\n" template_values[name] = content diff --git a/scripts/templates/available_languages.md.jinja2 b/scripts/templates/available_languages.md.jinja2 index b4af6b9..0df11e0 100644 --- a/scripts/templates/available_languages.md.jinja2 +++ b/scripts/templates/available_languages.md.jinja2 @@ -8,4 +8,8 @@ This page gives all the languages available for each model family. 
## M2M-100 -{{m2m100}} \ No newline at end of file +{{m2m100}} + +## NLLB-200 + +{{nllb200}} \ No newline at end of file diff --git a/setup.py b/setup.py index df3a118..c2dce87 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="dl-translate", - version="0.2.6", + version="0.3.0", author="Xing Han Lu", author_email="github@xinghanlu.com", description="A deep learning-based translation library built on Huggingface transformers", @@ -18,10 +18,10 @@ "Operating System :: OS Independent", ], packages=setuptools.find_packages(), - python_requires=">=3.6", + python_requires=">=3.7", install_requires=[ - "transformers>=4.4.0", - "torch", + "transformers>=4.30.2", + "torch>=2.0.0", "sentencepiece", "protobuf", "tqdm", diff --git a/tests/quick/test_translation_model.py b/tests/quick/test_translation_model.py index 0f44b35..042cff8 100644 --- a/tests/quick/test_translation_model.py +++ b/tests/quick/test_translation_model.py @@ -15,7 +15,8 @@ def test_resolve_lang_codes_mbart50(): targets = [dlt.lang.ENGLISH, "en_XX", "English"] for source, target in zip(sources, targets): - s, t = _resolve_lang_codes(source, target, "mbart50") + s = _resolve_lang_codes(source, "source", "mbart50") + t = _resolve_lang_codes(target, "target", "mbart50") assert s == "fr_XX" assert t == "en_XX" @@ -25,11 +26,31 @@ def test_resolve_lang_codes_m2m100(): targets = [dlt.lang.m2m100.ENGLISH, "en", "English"] for source, target in zip(sources, targets): - s, t = _resolve_lang_codes(source, target, "m2m100") + s = _resolve_lang_codes(source, "source", "m2m100") + t = _resolve_lang_codes(target, "target", "m2m100") assert s == "fr" assert t == "en" +def test_resolve_lang_codes_m2m100(): + sources = [dlt.lang.nllb200.FRENCH, "fra_Latn", "French"] + targets = [dlt.lang.nllb200.ENGLISH, "eng_Latn", "English"] + + for source, target in zip(sources, targets): + s = _resolve_lang_codes(source, "source", "nllb200") + t = _resolve_lang_codes(target, "target", "nllb200") + assert s == 
"fra_Latn" + assert t == "eng_Latn" + + sources = ["Central Kanuri (Latin script)"] + targets = ["Ta'izzi-Adeni Arabic"] + for source, target in zip(sources, targets): + s = _resolve_lang_codes(source, "source", "nllb200") + t = _resolve_lang_codes(target, "target", "nllb200") + assert s == "knc_Latn" + assert t == "acq_Arab" + + def test_select_device(): assert _select_device("cpu") == torch.device("cpu") assert _select_device("gpu") == torch.device("cuda") diff --git a/tests/quick/test_utils.py b/tests/quick/test_utils.py index 98565bb..dfa9863 100644 --- a/tests/quick/test_utils.py +++ b/tests/quick/test_utils.py @@ -1,7 +1,7 @@ import pytest from dl_translate import utils -from dl_translate._pairs import _PAIRS_MBART50, _PAIRS_M2M100 +from dl_translate._pairs import _PAIRS_MBART50, _PAIRS_M2M100, _PAIRS_NLLB200 def test_dict_from_weights(): @@ -32,28 +32,38 @@ def test_dict_from_weights_exception(): def test_available_languages(): - assert utils.available_languages() == utils.available_languages("mbart50") + assert utils.available_languages() == utils.available_languages() langs = utils.available_languages() + for lang, _ in _PAIRS_M2M100: + assert lang in langs + + langs = utils.available_languages("mbart50") + for lang, _ in _PAIRS_MBART50: assert lang in langs - langs = utils.available_languages("m2m100") + langs = utils.available_languages("nllb200") - for lang, _ in _PAIRS_M2M100: + for lang, _ in _PAIRS_NLLB200: assert lang in langs def test_available_codes(): - assert utils.available_codes() == utils.available_codes("mbart50") + assert utils.available_codes() == utils.available_codes("m2m100") codes = utils.available_codes() + for _, code in _PAIRS_M2M100: + assert code in codes + + codes = utils.available_codes("mbart50") + for _, code in _PAIRS_MBART50: assert code in codes - codes = utils.available_codes("m2m100") + codes = utils.available_codes("nllb200") - for _, code in _PAIRS_M2M100: + for _, code in _PAIRS_NLLB200: assert code in codes