diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py
index 15af14c0..f228a545 100644
--- a/silnlp/common/postprocess_draft.py
+++ b/silnlp/common/postprocess_draft.py
@@ -6,9 +6,8 @@
 from ..nmt.config_utils import load_config
 from ..nmt.postprocess import postprocess_experiment
 from .postprocesser import PostprocessConfig, PostprocessHandler
-from .utils import get_mt_exp_dir
 
-LOGGER = logging.getLogger(__package__ + ".postprocess_draft")
+LOGGER = logging.getLogger((__package__ or "") + ".postprocess_draft")
 
 
 def main() -> None:
diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py
index 9bd898bf..1dcf033f 100644
--- a/silnlp/common/postprocesser.py
+++ b/silnlp/common/postprocesser.py
@@ -34,7 +34,7 @@
 from ..alignment.utils import compute_alignment_scores
 from .corpus import load_corpus, write_corpus
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 
 POSTPROCESS_DEFAULTS = {
     "paragraph_behavior": "end",  # Possible values: end, place, strip
diff --git a/silnlp/common/translate_google.py b/silnlp/common/translate_google.py
index 055af186..9cf3bd97 100644
--- a/silnlp/common/translate_google.py
+++ b/silnlp/common/translate_google.py
@@ -1,16 +1,15 @@
 import argparse
 import logging
-from typing import Iterable, Optional
+from typing import Generator, Iterable, Optional
 
 from google.cloud import translate_v2 as translate
 from machine.scripture import VerseRef, book_id_to_number
 
-from ..common.environment import SIL_NLP_ENV
 from .paratext import book_file_name_digits
-from .translator import TranslationGroup, Translator
+from .translator import SentenceTranslation, SentenceTranslationGroup, Translator
 from .utils import get_git_revision_hash, get_mt_exp_dir
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 
 
 class GoogleTranslator(Translator):
@@ -24,7 +23,7 @@ def translate(
         trg_iso: str,
         produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
-    ) -> Iterable[TranslationGroup]:
+    ) -> Generator[SentenceTranslationGroup, None, None]:
         if produce_multiple_translations:
             LOGGER.warning("Google Translator does not support --multiple-translations")
 
@@ -35,8 +34,8 @@ def translate(
             results = self._translate_client.translate(
                 sentence, source_language=src_iso, target_language=trg_iso, format_="text"
            )
-            translation = results["translatedText"]
-            yield [translation]
+            translation: str = results["translatedText"]
+            yield [SentenceTranslation(translation, [], [], None)]
 
 
 def main() -> None:
diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py
index 3464d80f..f1721c0c 100644
--- a/silnlp/common/translator.py
+++ b/silnlp/common/translator.py
@@ -2,11 +2,12 @@
 import re
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from contextlib import AbstractContextManager
 from datetime import date
 from itertools import groupby
 from math import exp
 from pathlib import Path
-from typing import DefaultDict, Iterable, List, Optional
+from typing import DefaultDict, Generator, Iterable, List, Optional, Tuple, cast
 
 import docx
 import nltk
@@ -14,6 +15,8 @@
 from machine.corpora import (
     FileParatextProjectSettingsParser,
     FileParatextProjectTextUpdater,
+    ParatextProjectSettings,
+    ScriptureRef,
     UpdateUsfmParserHandler,
     UpdateUsfmTextBehavior,
     UsfmFileText,
@@ -32,7 +35,7 @@ from .postprocesser import NoDetectedQuoteConventionException, PostprocessHandler, UnknownQuoteConventionException
 from .usfm_utils import PARAGRAPH_TYPE_EMBEDS
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 
 nltk.download("punkt")
 
 CONFIDENCE_SCORES_SUFFIX = ".confidences.tsv"
@@ -68,7 +71,7 @@ def join_tokens_for_confidence_file(self) -> str:
         return "\t".join(self._tokens)
 
     def join_token_scores_for_confidence_file(self) -> str:
-        return "\t".join([str(exp(ts)) for ts in [self._sequence_score] + self._token_scores])
+        return "\t".join([str(exp(ts)) for ts in [self._sequence_score] + self._token_scores if ts is not None])
 
 
 # A group of multiple translations of a single sentence
@@ -87,7 +90,7 @@ def write_confidence_scores_to_file(
         self,
         confidences_path: Path,
         row1col1_label: str,
-        vrefs: Optional[List[VerseRef]] = None,
+        scripture_refs: Optional[List[ScriptureRef]] = None,
     ) -> None:
         with confidences_path.open("w", encoding="utf-8", newline="\n") as confidences_file:
             confidences_file.write("\t".join([f"{row1col1_label}"] + [f"Token {i}" for i in range(200)]) + "\n")
@@ -96,19 +99,24 @@ def write_confidence_scores_to_file(
                 if not sentence_translation.has_sequence_confidence_score():
                     continue
                 sequence_label = str(sentence_num)
-                if vrefs is not None:
-                    sequence_label = str(vrefs[sentence_num])
+                if scripture_refs is not None:
+                    sequence_label = str(scripture_refs[sentence_num])
                 confidences_file.write(
                     sequence_label + "\t" + sentence_translation.join_tokens_for_confidence_file() + "\n"
                 )
                 confidences_file.write(sentence_translation.join_token_scores_for_confidence_file() + "\n")
 
-    def write_chapter_confidence_scores_to_file(self, chapter_confidences_path: Path, vrefs: List[VerseRef]):
+    def write_chapter_confidence_scores_to_file(
+        self, chapter_confidences_path: Path, scripture_refs: List[ScriptureRef]
+    ):
         chapter_confidences: DefaultDict[int, List[float]] = defaultdict(list)
-        for sentence_num, vref in enumerate(vrefs):
-            if not vref.is_verse or self._sentence_translations[sentence_num].get_sequence_confidence_score() is None:
+        for sentence_num, vref in enumerate(scripture_refs):
+            sequence_confidence_score: Optional[float] = self._sentence_translations[
+                sentence_num
+            ].get_sequence_confidence_score()
+            if not vref.is_verse or sequence_confidence_score is None:
                 continue
-            vref_confidence = exp(self._sentence_translations[sentence_num].get_sequence_confidence_score())
+            vref_confidence = exp(sequence_confidence_score)
             chapter_confidences[vref.chapter_num].append(vref_confidence)
 
         with chapter_confidences_path.open("w", encoding="utf-8", newline="\n") as chapter_confidences_file:
@@ -119,9 +127,9 @@ def write_chapter_confidence_scores_to_file(self, chapter_confidences_path: Path
 
     def get_all_sequence_confidence_scores(self) -> List[float]:
         return [
-            exp(st.get_sequence_confidence_score())
-            for st in self._sentence_translations
-            if st.get_sequence_confidence_score() is not None
+            exp(scs)
+            for scs in [t.get_sequence_confidence_score() for t in self._sentence_translations]
+            if scs is not None
         ]
 
     def get_all_translations(self) -> List[str]:
@@ -153,7 +161,7 @@ def generate_confidence_files(
     trg_file_path: Path,
     trg_prefix: str = "",
    produce_multiple_translations: bool = False,
-    vrefs: Optional[List[VerseRef]] = None,
+    scripture_refs: Optional[List[ScriptureRef]] = None,
     draft_index: int = 0,
 ) -> None:
     if not translated_draft.has_sequence_confidence_scores():
@@ -169,8 +177,8 @@ def generate_confidence_files(
 
     ext = trg_file_path.suffix.lower()
     if ext in {".usfm", ".sfm"}:
-        assert vrefs is not None
-        generate_usfm_confidence_files(translated_draft, trg_file_path, confidences_path, vrefs, draft_index)
+        assert scripture_refs is not None
+        generate_usfm_confidence_files(translated_draft, trg_file_path, confidences_path, scripture_refs, draft_index)
     elif ext == ".txt":
         generate_txt_confidence_files(translated_draft, trg_file_path, confidences_path, trg_prefix)
     else:
@@ -184,24 +192,26 @@ def generate_usfm_confidence_files(
     translated_draft: TranslatedDraft,
     trg_file_path: Path,
     confidences_path: Path,
-    vrefs: List[VerseRef],
+    scripture_refs: List[ScriptureRef],
     draft_index: int = 0,
 ) -> None:
-    translated_draft.write_confidence_scores_to_file(confidences_path, "VRef", vrefs)
-    translated_draft.write_chapter_confidence_scores_to_file(confidences_path.with_suffix(".chapters.tsv"), vrefs)
-    _append_book_confidence_score(translated_draft, trg_file_path, vrefs)
+    translated_draft.write_confidence_scores_to_file(confidences_path, "VRef", scripture_refs)
+    translated_draft.write_chapter_confidence_scores_to_file(
+        confidences_path.with_suffix(".chapters.tsv"), scripture_refs
+    )
+    _append_book_confidence_score(translated_draft, trg_file_path, scripture_refs)
 
 
 def _append_book_confidence_score(
     translated_draft: TranslatedDraft,
     trg_file_path: Path,
-    vrefs: List[VerseRef],
+    scripture_refs: List[ScriptureRef],
 ) -> None:
     file_confidences_path = trg_file_path.parent / "confidences.books.tsv"
     row1_col1_header = "Book"
 
-    if vrefs:
-        col1_entry = vrefs[0].book
+    if scripture_refs:
+        col1_entry = scripture_refs[0].book
     else:
         col1_entry = trg_file_path.stem
@@ -250,7 +260,7 @@ def generate_test_confidence_files(
     translated_draft.write_confidence_scores_to_file(confidences_path, "Sequence Number")
 
 
-class Translator(ABC):
+class Translator(AbstractContextManager["Translator"], ABC):
     @abstractmethod
     def translate(
         self,
@@ -259,7 +269,7 @@ def translate(
         trg_iso: str,
         produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
-    ) -> Iterable[SentenceTranslationGroup]:
+    ) -> Generator[SentenceTranslationGroup, None, None]:
         pass
 
     def translate_text(
@@ -348,13 +358,19 @@ def translate_usfm(
     ) -> None:
         # Create UsfmFileText object for source
         src_from_project = False
+        src_settings: Optional[ParatextProjectSettings] = None
+        stylesheet = UsfmStylesheet("usfm.sty")
         if str(src_file_path).startswith(str(get_project_dir(""))):
             src_from_project = True
             src_settings = FileParatextProjectSettingsParser(src_file_path.parent).parse()
+            stylesheet = src_settings.stylesheet
+            book_id = src_settings.get_book_id(src_file_path.name)
+            assert book_id is not None
+
             src_file_text = UsfmFileText(
                 src_settings.stylesheet,
                 src_settings.encoding,
-                src_settings.get_book_id(src_file_path.name),
+                book_id,
                 src_file_path,
                 src_settings.versification,
                 include_all_text=True,
@@ -367,28 +383,28 @@ def translate_usfm(
             if not is_book_id_valid(book_id):
                 raise ValueError(f"Book ID not detected: {book_id}")
 
-            src_file_text = UsfmFileText("usfm.sty", "utf-8-sig", book_id, src_file_path, include_all_text=True)
-        stylesheet = src_settings.stylesheet if src_from_project else UsfmStylesheet("usfm.sty")
+            src_file_text = UsfmFileText(stylesheet, "utf-8-sig", book_id, src_file_path, include_all_text=True)
 
         sentences = [re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())) for s in src_file_text]
-        vrefs = [s.ref for s in src_file_text]
+        scripture_refs: List[ScriptureRef] = [s.ref for s in src_file_text]
+        vrefs: List[VerseRef] = [sr.verse_ref for sr in scripture_refs]
         LOGGER.info(f"File {src_file_path} parsed correctly.")
 
         # Filter sentences
         for i in reversed(range(len(sentences))):
-            marker = vrefs[i].path[-1].name if len(vrefs[i].path) > 0 else ""
+            marker = scripture_refs[i].path[-1].name if len(scripture_refs[i].path) > 0 else ""
             if (
-                (len(chapters) > 0 and vrefs[i].chapter_num not in chapters)
+                (len(chapters) > 0 and scripture_refs[i].chapter_num not in chapters)
                 or marker in PARAGRAPH_TYPE_EMBEDS
                 or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT
             ):
                 sentences.pop(i)
-                vrefs.pop(i)
-        empty_sents = []
+                scripture_refs.pop(i)
+        empty_sents: List[Tuple[int, ScriptureRef]] = []
         for i in reversed(range(len(sentences))):
             if len(sentences[i].strip()) == 0:
                 sentences.pop(i)
-                empty_sents.append((i, vrefs.pop(i)))
+                empty_sents.append((i, scripture_refs.pop(i)))
 
         sentence_translation_groups: List[SentenceTranslationGroup] = list(
             self.translate(sentences, src_iso, trg_iso, produce_multiple_translations, vrefs)
@@ -399,7 +415,7 @@ def translate_usfm(
         # Prevents pre-existing text from showing up in the sections of translated text
         for idx, vref in reversed(empty_sents):
             sentences.insert(idx, "")
-            vrefs.insert(idx, vref)
+            scripture_refs.insert(idx, vref)
             sentence_translation_groups.insert(idx, [SentenceTranslation("", [], [], None)] * num_drafts)
 
         text_behavior = (
@@ -408,13 +424,13 @@ def translate_usfm(
 
         draft_set: DraftGroup = DraftGroup(sentence_translation_groups)
         for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
-            postprocess_handler.construct_rows(vrefs, sentences, translated_draft.get_all_translations())
+            postprocess_handler.construct_rows(scripture_refs, sentences, translated_draft.get_all_translations())
 
             for config in postprocess_handler.configs:
 
                 # Compile draft remarks
                 draft_src_str = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
-                draft_remark = f"This draft of {vrefs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
+                draft_remark = f"This draft of {scripture_refs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
                 postprocess_remark = config.get_postprocess_remark()
                 remarks = [draft_remark] + ([postprocess_remark] if postprocess_remark else [])
@@ -449,7 +465,7 @@ def translate_usfm(
                         usfm = f.read()
                     handler = UpdateUsfmParserHandler(
                         rows=config.rows,
-                        id_text=vrefs[0].book,
+                        id_text=scripture_refs[0].book,
                         text_behavior=text_behavior,
                         paragraph_behavior=config.get_paragraph_behavior(),
                         embed_behavior=config.get_embed_behavior(),
@@ -483,9 +499,9 @@ def translate_usfm(
                     "w",
                     encoding=(
                         "utf-8"
-                        if not src_from_project
-                        or src_from_project
-                        and (src_settings.encoding == "utf-8-sig" or src_settings.encoding == "utf_8_sig")
+                        if src_settings is None
+                        or src_settings.encoding == "utf-8-sig"
+                        or src_settings.encoding == "utf_8_sig"
                         else src_settings.encoding
                     ),
                 ) as f:
@@ -496,7 +512,7 @@ def translate_usfm(
                     translated_draft,
                     trg_file_path,
                     produce_multiple_translations=produce_multiple_translations,
-                    vrefs=vrefs,
+                    scripture_refs=scripture_refs,
                     draft_index=draft_index,
                 )
@@ -513,7 +529,7 @@ def translate_docx(
         try:
             src_lang = Lang(src_iso)
             tokenizer = nltk.data.load(f"tokenizers/punkt/{src_lang.name.lower()}.pickle")
-        except:
+        except Exception:
            tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
 
         with src_file_path.open("rb") as file:
@@ -522,8 +538,8 @@ def translate_docx(
 
         sentences: List[str] = []
         paras: List[int] = []
-        for i in range(len(doc.paragraphs)):
-            for sentence in tokenizer.tokenize(doc.paragraphs[i].text, "test"):
+        for i, paragraph in enumerate(doc.paragraphs):
+            for sentence in tokenizer.tokenize(paragraph.text):
                 sentences.append(add_tags_to_sentence(tags, sentence))
                 paras.append(i)
@@ -532,7 +548,7 @@ def translate_docx(
         )
 
         for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
-            for para, group in groupby(zip(translated_draft, paras), key=lambda t: t[1]):
+            for para, group in groupby(zip(translated_draft.get_all_translations(), paras), key=lambda t: t[1]):
                 text = " ".join(s[0] for s in group)
                 doc.paragraphs[para].text = text
@@ -543,3 +559,11 @@ def translate_docx(
 
         with trg_draft_file_path.open("wb") as file:
             doc.save(file)
+
+    def __enter__(self) -> "Translator":
+        return self
+
+    def __exit__(
+        self, exc_type, exc_val, exc_tb  # pyright: ignore[reportMissingParameterType, reportUnknownParameterType]
+    ) -> None:
+        pass
diff --git a/silnlp/common/utils.py b/silnlp/common/utils.py
index e3feebd7..55cc4297 100644
--- a/silnlp/common/utils.py
+++ b/silnlp/common/utils.py
@@ -3,6 +3,7 @@
 import random
 import subprocess
 from abc import ABC, abstractmethod
+from argparse import Namespace
 from enum import Enum, Flag, auto
 from inspect import getmembers
 from pathlib import Path, PurePath
@@ -12,7 +13,7 @@
 import numpy as np
 import pandas as pd
 
-from ..common.environment import SIL_NLP_ENV
+from ..common.environment import SIL_NLP_ENV, SilNlpEnv
 
 LOGGER = logging.getLogger(__name__)
 
@@ -44,7 +45,7 @@ def print_table(rows):
     print()
 
 
-def show_attrs(cli_args, envs=SIL_NLP_ENV, actions=[]):
+def show_attrs(cli_args: Namespace, envs: SilNlpEnv = SIL_NLP_ENV, actions: List[str] = []) -> None:
     env_rows = [(k, v) for k, v in attrs(envs).items()]
     arg_rows = [(k, v) for k, v in cli_args.__dict__.items() if v is not None]
 
@@ -122,7 +123,7 @@ def check_dotnet() -> None:
             stderr=subprocess.DEVNULL,
         )
         _is_dotnet_installed = True
-    except:
+    except Exception:
        _is_dotnet_installed = False
 
     if not _is_dotnet_installed:
@@ -131,7 +132,7 @@
 
 
 class NoiseMethod(ABC):
     @abstractmethod
-    def __call__(self, tokens: list) -> list:
+    def __call__(self, tokens: list[str]) -> list[str]:
         pass
diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py
index 96e75d5c..d2b24538 100644
--- a/silnlp/nmt/config.py
+++ b/silnlp/nmt/config.py
@@ -8,7 +8,7 @@
 from enum import Enum, auto
 from pathlib import Path
 from statistics import mean, median, stdev
-from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
+from typing import Dict, Generator, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
 
 import pandas as pd
 from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books
@@ -48,7 +48,7 @@
 )
 from .tokenizer import Tokenizer
 
-LOGGER = logging.getLogger(__package__ + ".config")
+LOGGER = logging.getLogger((__package__ or "") + ".config")
 
 ALIGNMENT_SCORES_FILE = re.compile(r"([a-z]{2,3}-.+)_([a-z]{2,3}-.+)")
 
@@ -84,9 +84,10 @@ def translate(
         sentences: Iterable[str],
         src_iso: str,
         trg_iso: str,
+        produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
-    ) -> Iterable[SentenceTranslationGroup]: ...
+    ) -> Generator[SentenceTranslationGroup, None, None]: ...
 
     @abstractmethod
     def get_checkpoint_path(self, ckpt: Union[CheckpointType, str, int]) -> Tuple[Path, int]: ...
@@ -94,6 +95,9 @@ def get_checkpoint_path(self, ckpt: Union[CheckpointType, str, int]) -> Tuple[Pa
     @abstractmethod
     def clear_cache(self) -> None: ...
 
+    @abstractmethod
+    def get_num_drafts(self) -> int: ...
+
 
 class Config(ABC):
     def __init__(self, exp_dir: Path, config: dict) -> None:
@@ -328,7 +332,8 @@ def _calculate_tokenization_stats(self) -> None:
         else:
             existing_stats = pd.DataFrame({(" ", "Translation Side"): ["Source", "Target"]})
 
-        src_tokens_per_verse, src_chars_per_token = [], []
+        src_tokens_per_verse: List[int] = []
+        src_chars_per_token: List[int] = []
         for src_tok_file in self.exp_dir.glob("*.src.txt"):
             src_tok_file = src_tok_file.name
             with open(self.exp_dir / src_tok_file, "r+", encoding="utf-8") as f:
@@ -336,7 +341,8 @@ def _calculate_tokenization_stats(self) -> None:
                     src_tokens_per_verse.append(len(line.split()))
                     src_chars_per_token.extend([len(token) for token in line.split()])
 
-        trg_tokens_per_verse, trg_chars_per_token = [], []
+        trg_tokens_per_verse: List[int] = []
+        trg_chars_per_token: List[int] = []
         for trg_tok_file in self.exp_dir.glob("*.trg.txt"):
             trg_tok_file = trg_tok_file.name
             with open(self.exp_dir / trg_tok_file, "r+", encoding="utf-8") as f:
@@ -344,7 +350,9 @@ def _calculate_tokenization_stats(self) -> None:
                     trg_tokens_per_verse.append(len(line.split()))
                     trg_chars_per_token.extend([len(token) for token in line.split()])
 
-        src_chars_per_verse, src_words_per_verse, src_chars_per_word = [], [], []
+        src_chars_per_verse: List[int] = []
+        src_words_per_verse: List[int] = []
+        src_chars_per_word: List[int] = []
         for src_detok_file in self.exp_dir.glob("*.src.detok.txt"):
             src_detok_file = src_detok_file.name
             with open(self.exp_dir / src_detok_file, "r+", encoding="utf-8") as f:
@@ -354,7 +362,9 @@ def _calculate_tokenization_stats(self) -> None:
                     src_words_per_verse.append(len(word_line.split()))
                     src_chars_per_word.extend([len(word) for word in word_line.split()])
 
-        trg_chars_per_verse, trg_words_per_verse, trg_chars_per_word = [], [], []
+        trg_chars_per_verse: List[int] = []
+        trg_words_per_verse: List[int] = []
+        trg_chars_per_word: List[int] = []
         for trg_detok_file in self.exp_dir.glob("*.trg.detok.txt"):
             trg_detok_file = trg_detok_file.name
             with open(self.exp_dir / trg_detok_file, "r+", encoding="utf-8") as f:
@@ -850,7 +860,7 @@ def _collect_terms(
         categories_set: Optional[Set[str]] = None if categories is None else set(categories)
 
         if terms_config["include_glosses"]:
-            gloss_iso: str = str(terms_config["include_glosses"]).lower()
+            gloss_iso: Optional[str] = str(terms_config["include_glosses"]).lower()
             if gloss_iso == "true":
                 src_gloss_iso = list(self.src_isos.intersection(["en", "fr", "id", "es"]))
                 trg_gloss_iso = list(self.trg_isos.intersection(["en", "fr", "id", "es"]))
@@ -871,16 +881,16 @@ def _collect_terms(
             else:
                 gloss_iso = None
 
-        all_src_terms: List[Tuple[DataFile, Dict[str, Term], str]] = []
+        all_src_terms: List[Tuple[DataFile, Dict[str, Term], List[str]]] = []
         for src_terms_file, tags in src_terms_files:
             all_src_terms.append((src_terms_file, get_terms(src_terms_file.path, iso=gloss_iso), tags))
 
-        all_trg_terms: List[Tuple[DataFile, Dict[str, Term], str]] = []
+        all_trg_terms: List[Tuple[DataFile, Dict[str, Term], List[str]]] = []
         for trg_terms_file, tags in trg_terms_files:
             all_trg_terms.append((trg_terms_file, get_terms(trg_terms_file.path, iso=gloss_iso), tags))
 
         for src_terms_file, src_terms, tags in all_src_terms:
-            for trg_terms_file, trg_terms, trg_tags in all_trg_terms:
+            for trg_terms_file, trg_terms, _ in all_trg_terms:
                 if src_terms_file.iso == trg_terms_file.iso:
                     continue
                 cur_terms = get_terms_corpus(src_terms, trg_terms, categories_set, filter_books)
diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py
index 58577bfa..1c91e351 100644
--- a/silnlp/nmt/experiment.py
+++ b/silnlp/nmt/experiment.py
@@ -1,8 +1,8 @@
 import argparse
 import os
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Set
+from typing import Optional, Set, Union
 
 import yaml
 
@@ -11,7 +11,7 @@
 from ..common.utils import get_git_revision_hash, show_attrs
 from .clearml_connection import TAGS_LIST, SILClearML
 from .config import Config, get_mt_exp_dir
-from .test import _SUPPORTED_SCORERS, test
+from .test import SUPPORTED_SCORERS, test
 from .translate import TranslationTask
 
 
@@ -30,7 +30,7 @@ class SILExperiment:
     run_translate: bool = False
     produce_multiple_translations: bool = False
     save_confidences: bool = False
-    scorers: Set[str] = field(default_factory=set)
+    scorers: Optional[Set[str]] = None
     score_by_book: bool = False
     commit: Optional[str] = None
     clearml_tag: Optional[str] = None
@@ -48,6 +48,9 @@ def __post_init__(self):
         self.rev_hash = get_git_revision_hash()
         self.config.set_seed()
 
+        if self.scorers is None:
+            self.scorers = set()
+
     def run(self):
         if self.run_prep:
             self.preprocess()
@@ -77,6 +80,7 @@ def train(self):
         print("Training completed")
 
     def test(self):
+        assert self.scorers is not None
         test(
             experiment=self.name,
             last=self.config.model_dir.exists(),
@@ -96,9 +100,10 @@ def translate(self):
         postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs])
 
         for translate_config in translate_configs.get("translate", []):
+            checkpoint: Union[str, int] = translate_config.get("checkpoint", "last") or "last"
             translator = TranslationTask(
                 name=self.name,
-                checkpoint=translate_config.get("checkpoint", "last"),
+                checkpoint=checkpoint,
                 use_default_model_dir=self.save_checkpoints,
                 commit=self.commit,
             )
@@ -123,6 +128,11 @@ def translate(self):
                     translate_config.get("tags"),
                 )
             elif translate_config.get("src_prefix"):
+                if translate_config.get("trg_prefix") is None:
+                    raise RuntimeError("A target file prefix must be specified.")
+                if translate_config.get("start_seq") is None or translate_config.get("end_seq") is None:
+                    raise RuntimeError("Start and end sequence numbers must be specified.")
+
                 translator.translate_text_files(
                     translate_config.get("src_prefix"),
                     translate_config.get("trg_prefix"),
@@ -210,7 +220,7 @@ def main() -> None:
         "--scorers",
         nargs="*",
         metavar="scorer",
-        choices=_SUPPORTED_SCORERS,
+        choices=SUPPORTED_SCORERS,
         default=[
             "bleu",
             "sentencebleu",
@@ -223,7 +233,7 @@ def main() -> None:
             "m-chrf3++",
             "spbleu",
         ],
-        help=f"List of scorers - {_SUPPORTED_SCORERS}",
+        help=f"List of scorers - {SUPPORTED_SCORERS}",
     )
 
     args = parser.parse_args()
diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py
index 9edd19c5..59225275 100644
--- a/silnlp/nmt/hugging_face_config.py
+++ b/silnlp/nmt/hugging_face_config.py
@@ -11,7 +11,7 @@
 from itertools import repeat
 from math import prod
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
 
 import datasets.utils.logging as datasets_logging
 import evaluate
@@ -27,7 +27,7 @@
 from tokenizers import AddedToken, NormalizedString, Regex
 from tokenizers.implementations import SentencePieceBPETokenizer, SentencePieceUnigramTokenizer
 from tokenizers.normalizers import Normalizer
-from torch import Tensor, TensorType, nn, optim
+from torch import Tensor, nn, optim
 from torch.utils.data import Sampler
 from transformers import (
     AutoConfig,
@@ -52,6 +52,7 @@
     Seq2SeqTrainingArguments,
     T5Tokenizer,
     T5TokenizerFast,
+    TensorType,
     TrainerCallback,
     TranslationPipeline,
     set_seed,
@@ -90,6 +91,8 @@
 
 if is_peft_available():
     from peft import LoraConfig, PeftModel, TaskType, get_peft_model
+else:
+    LoraConfig, PeftModel, TaskType, get_peft_model = None, None, None, None
 
 LOGGER = logging.getLogger(__name__)
 
@@ -181,7 +184,16 @@ def prepare_decoder_input_ids_from_labels(self: M2M100ForConditionalGeneration,
 # "loss" and "eval_loss" are both evaluation loss
 # The early stopping callback adds "eval_" to all metrics that don't already start with it
 DEFAULT_METRICS = ["loss", "eval_loss"]
-EVAL_METRICS_MODULES = {"bleu": "sacrebleu", "chrf3": "chrf", "chrf3+": "chrf", "chrf3++": "chrf", "m-bleu": "sacrebleu", "m-chrf3": "chrf", "m-chrf3+": "chrf", "m-chrf3++": "chrf"}
+EVAL_METRICS_MODULES = {
+    "bleu": "sacrebleu",
+    "chrf3": "chrf",
+    "chrf3+": "chrf",
+    "chrf3++": "chrf",
+    "m-bleu": "sacrebleu",
+    "m-chrf3": "chrf",
+    "m-chrf3+": "chrf",
+    "m-chrf3++": "chrf",
+}
 
 
 def get_best_checkpoint(model_dir: Path) -> Path:
@@ -235,7 +247,11 @@ def delete_tokenizer(checkpoint_path: Path) -> None:
 
 
 def add_lang_code_to_tokenizer(tokenizer: PreTrainedTokenizer, lang_code: str) -> None:
-    tokenizer.add_special_tokens({"additional_special_tokens": [lang_code]}, replace_additional_special_tokens=False)
+    # Huggingface does not follow its own type hints with this function and expects Dict[str, List[str]]
+    tokenizer.add_special_tokens(
+        {"additional_special_tokens": [lang_code]},  # pyright: ignore[reportArgumentType]
+        replace_additional_special_tokens=False,
+    )
     lang_id = tokenizer.convert_tokens_to_ids(lang_code)
     if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
         tokenizer.id_to_lang_code[lang_id] = lang_code
@@ -704,7 +720,7 @@ def _write_dictionary(
         categories_set: Optional[Set[str]] = None if categories is None else set(categories)
 
         if terms_config["include_glosses"]:
-            gloss_iso: str = str(terms_config["include_glosses"]).lower()
+            gloss_iso: Optional[str] = str(terms_config["include_glosses"]).lower()
             if gloss_iso == "true":
                 src_gloss_iso = list(self.src_isos.intersection(["en", "fr", "id", "es"]))
                 trg_gloss_iso = list(self.trg_isos.intersection(["en", "fr", "id", "es"]))
@@ -1242,7 +1258,7 @@ def translate(
         produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
         ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
-    ) -> Iterable[SentenceTranslationGroup]:
+    ) -> Generator[SentenceTranslationGroup, None, None]:
         src_lang = self._config.data["lang_codes"].get(src_iso, src_iso)
         trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso)
         inference_model_params = InferenceModelParams(ckpt, src_lang, trg_lang)
@@ -1410,6 +1426,7 @@ def _convert_to_lora_model(self, model: PreTrainedModel) -> PreTrainedModel:
         if "embed_tokens" in modules_to_save or "embed_tokens" in target_modules:
             model = self._create_tied_embedding_weights(model)
 
+        assert LoraConfig is not None  # This was conditionally imported
         peft_config = LoraConfig(
             task_type=TaskType.SEQ_2_SEQ_LM,
             r=lora_config.get("r", 4),
diff --git a/silnlp/nmt/preprocess.py b/silnlp/nmt/preprocess.py
index 6870d4c8..7b7a35e1 100644
--- a/silnlp/nmt/preprocess.py
+++ b/silnlp/nmt/preprocess.py
@@ -4,7 +4,7 @@
 from ..common.utils import get_git_revision_hash
 from .config_utils import load_config
 
-LOGGER = logging.getLogger(__package__ + ".preprocess")
+LOGGER = logging.getLogger((__package__ or "") + ".preprocess")
 
 
 def main() -> None:
diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py
index de4734fd..b6c1405b 100644
--- a/silnlp/nmt/test.py
+++ b/silnlp/nmt/test.py
@@ -4,7 +4,7 @@
 from contextlib import ExitStack
 from io import StringIO
 from pathlib import Path
-from typing import IO, Dict, List, Optional, Set, TextIO, Tuple
+from typing import Dict, List, Optional, Set, TextIO, Tuple
 
 import sacrebleu
 from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, book_number_to_id, get_chapters
@@ -18,11 +18,11 @@
 from .config_utils import load_config
 from .tokenizer import Tokenizer
 
-LOGGER = logging.getLogger(__package__ + ".test")
+LOGGER = logging.getLogger((__package__ or "") + ".test")
 
 logging.getLogger("sacrebleu").setLevel(logging.ERROR)
 
-_SUPPORTED_SCORERS = [
+SUPPORTED_SCORERS = [
     "bleu",
     "sentencebleu",
     "chrf3",
@@ -61,7 +61,7 @@ def __init__(
         self.book = book
         self.draft_index = draft_index
 
-    def writeHeader(self, file: IO) -> None:
+    def writeHeader(self, file: TextIO) -> None:
         header = (
             "book,draft_index,src_iso,trg_iso,num_refs,references,sent_len"
             + (
@@ -75,7 +75,7 @@ def writeHeader(self, file: IO) -> None:
         )
         file.write(header)
 
-    def write(self, file: IO) -> None:
+    def write(self, file: TextIO) -> None:
         file.write(
             f"{self.book},{self.draft_index},{self.src_iso},{self.trg_iso},"
             f"{self.num_refs},{self.refs},{self.sent_len:d}"
@@ -237,7 +237,7 @@ def score_pair(
 
 
 def score_individual_books(
-    book_dict: Dict[str, Tuple[List[str], List[List[str]]]],
+    book_dict: Dict[str, Tuple[List[str], List[List[str]], List[float]]],
     src_iso: str,
     trg_iso: str,
     predictions_detok_file_name: str,
@@ -280,7 +280,7 @@ def process_individual_books(
     conf_file_path: Path,
     select_rand_ref_line: bool,
     books: Dict[int, List[int]],
-) -> Dict[str, Tuple[List[str], List[List[str]]]]:
+) -> Dict[str, Tuple[List[str], List[List[str]], List[float]]]:
     # Output data structure
     book_dict: Dict[str, Tuple[List[str], List[List[str]], List[float]]] = {}
     with ExitStack() as stack:
@@ -344,10 +344,10 @@ def load_test_data(
     config: Config,
     books: Dict[int, List[int]],
     by_book: bool,
-) -> Tuple[List[str], List[List[str]], Dict[str, Tuple[List[str], List[List[str]]]]]:
+) -> Tuple[List[str], List[List[str]], Dict[str, Tuple[List[str], List[List[str]], List[float]]]]:
     sys: List[str] = []
     refs: List[List[str]] = []
-    book_dict: Dict[str, Tuple[List[str], List[List[str]]]] = {}
+    book_dict: Dict[str, Tuple[List[str], List[List[str]], List[float]]] = {}
     pred_file_path = config.exp_dir / pred_file_name
     conf_file_path = config.exp_dir / conf_file_name
     with ExitStack() as stack:
@@ -364,8 +364,8 @@ def load_test_data(
         else:
             # use specified refs only
             ref_file_paths = [p for p in ref_file_paths if config.is_ref_project(ref_projects, p)]
-        ref_files: List[IO] = []
-        vref_file: Optional[IO] = None
+        ref_files: List[TextIO] = []
+        vref_file: Optional[TextIO] = None
         vref_file_path = config.exp_dir / vref_file_name
         if len(books) > 0 and vref_file_path.is_file():
             vref_file = stack.enter_context(vref_file_path.open("r", encoding="utf-8"))
@@ -644,7 +644,7 @@ def test_checkpoint(
         ref_projects_suffix = "_".join(sorted(ref_projects))
         scores_file_root += f"-{ref_projects_suffix}"
     with (config.exp_dir / f"{scores_file_root}.csv").open("w", encoding="utf-8") as scores_file:
-        if scores is not None:
+        if len(scores) > 0:
             scores[0].writeHeader(scores_file)
         for results in scores:
             results.write(scores_file)
@@ -679,7 +679,7 @@ def test(
         scorers.add("confidence")
     if len(scorers) == 0:
         scorers.add("bleu")
-    scorers.intersection_update(set(_SUPPORTED_SCORERS))
+    scorers.intersection_update(set(SUPPORTED_SCORERS))
 
     tokenizer = config.create_tokenizer()
     model = config.create_model()
@@ -720,7 +720,7 @@ def test(
                 save_confidences,
             )
         except ValueError:
-            LOGGER.warn("No average checkpoint available.")
+            LOGGER.warning("No average checkpoint available.")
             best_step = 0
 
     if best and config.has_best_checkpoint:
@@ -823,9 +823,9 @@ def main() -> None:
         "--scorers",
         nargs="*",
         metavar="scorer",
-        choices=_SUPPORTED_SCORERS,
+        choices=SUPPORTED_SCORERS,
         default=[],
-        help=f"List of scorers - {_SUPPORTED_SCORERS}",
+        help=f"List of scorers - {SUPPORTED_SCORERS}",
     )
     parser.add_argument("--books", nargs="*", metavar="book", default=[], help="Books")
     parser.add_argument("--by-book", default=False, action="store_true", help="Score individual books")
diff --git a/silnlp/nmt/train.py b/silnlp/nmt/train.py
index 4e26865d..f05778f2 100644
--- a/silnlp/nmt/train.py
+++ b/silnlp/nmt/train.py
@@ -5,7 +5,7 @@
 from .clearml_connection import TAGS_LIST, SILClearML
 from .config_utils import load_config
 
-LOGGER = logging.getLogger(__package__ + ".train")
+LOGGER = logging.getLogger((__package__ or "") + ".train")
 
 # As of TF 2.7, deterministic mode is slower, so we will disable it for now.
 # os.environ["TF_DETERMINISTIC_OPS"] = "True"
diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py
index 29bbd7f7..a9e73d33 100644
--- a/silnlp/nmt/translate.py
+++ b/silnlp/nmt/translate.py
@@ -5,7 +5,7 @@
 from contextlib import AbstractContextManager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Generator, Iterable, List, Optional, Tuple, Union
 
 from machine.scripture import VerseRef, book_number_to_id, get_chapters
 
@@ -17,12 +17,12 @@
 from .clearml_connection import TAGS_LIST, SILClearML
 from .config import CheckpointType, Config, NMTModel
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 
 
-class NMTTranslator(Translator, AbstractContextManager):
+class NMTTranslator(Translator):
     def __init__(self, model: NMTModel, checkpoint: Union[CheckpointType, str, int]) -> None:
-        self._model = model
+        self._model: NMTModel = model
         self._checkpoint = checkpoint
 
     def translate(
@@ -32,12 +32,14 @@ def translate(
         trg_iso: str,
         produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
-    ) -> Iterable[SentenceTranslationGroup]:
-        return self._model.translate(
+    ) -> Generator[SentenceTranslationGroup, None, None]:
+        yield from self._model.translate(
             sentences, src_iso, trg_iso, produce_multiple_translations, vrefs, self._checkpoint
         )
 
-    def __exit__(self, exc_type, exc_value, traceback) -> None:
+    def __exit__(
+        self, exc_type, exc_value, traceback  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
+    ) -> None:
         self._model.clear_cache()
 
 
@@ -50,10 +52,6 @@ class TranslationTask:
     commit: Optional[str] = None
     clearml_tag: Optional[str] = None
 
-    def __post_init__(self) -> None:
-        if self.checkpoint is None:
-            self.checkpoint = "last"
-
     def translate_books(
         self,
         books: str,
@@ -107,7 +105,7 @@ def translate_books(
         if not config.model_dir.exists():
             experiment_ckpt_str = f"{self.name}:base"
 
-        translation_failed = []
+        translation_failed: List[str] = []
         for book_num, chapters in book_nums.items():
             book = book_number_to_id(book_num)
             try:
@@ -127,7 +125,7 @@ def translate_books(
                     config.corpus_pairs,
                     tags,
                 )
-            except Exception as e:
+            except Exception:
                 translation_failed.append(book)
                 LOGGER.exception(f"Was not able to translate {book}.")
 
@@ -148,11 +146,6 @@ def translate_text_files(
     ) -> None:
         translator, config, _ = self._init_translation_task(experiment_suffix=f"_{self.checkpoint}_{src_prefix}")
         with translator:
-            if trg_prefix is None:
-                raise RuntimeError("A target file prefix must be specified.")
-            if start_seq is None or end_seq is None:
-                raise RuntimeError("Start and end sequence numbers must be specified.")
-
             if src_iso is None:
                 src_iso = config.default_test_src_iso
                 if src_iso == "" and len(config.src_iso) > 0:
@@ -261,7 +254,7 @@ def translate_files(
                         trg_iso,
                         produce_multiple_translations,
                         save_confidences,
-                        tags,
+                        tags=tags,
                     )
                 elif ext == ".docx":
                     translator.translate_docx(
@@ -301,7 +294,7 @@ def _init_translation_task(self, experiment_suffix: str) -> Tuple[Translator, Co
         model = clearml.config.create_model()
         translator = NMTTranslator(model, self.checkpoint)
         if clearml.config.model_dir.exists():
-            checkpoint_path, step = model.get_checkpoint_path(self.checkpoint)
+            _, step = model.get_checkpoint_path(self.checkpoint)
             step_str = "avg" if step == -1 else str(step)
         else:
             step_str = "last"
@@ -446,9 +439,10 @@ def main() -> None:
 
     get_git_revision_hash()
 
+    checkpoint: str = args.checkpoint or "last"
     translator = TranslationTask(
         name=args.experiment,
-        checkpoint=args.checkpoint,
+        checkpoint=checkpoint,
         clearml_queue=args.clearml_queue,
         commit=args.commit,
         clearml_tag=args.clearml_tag,
@@ -476,6 +470,10 @@ def main() -> None:
             actions=[f"Will attempt to translate matching files from {args.src_iso} into {args.trg_iso}."],
         )
         exit()
+    if args.trg_prefix is None:
+        raise RuntimeError("A target file prefix must be specified.")
+    if args.start_seq is None or args.end_seq is None:
+        raise RuntimeError("Start and end sequence numbers must be specified.")
     translator.translate_text_files(
         args.src_prefix,
         args.trg_prefix,