3 changes: 1 addition & 2 deletions silnlp/common/postprocess_draft.py
@@ -6,9 +6,8 @@
 from ..nmt.config_utils import load_config
 from ..nmt.postprocess import postprocess_experiment
 from .postprocesser import PostprocessConfig, PostprocessHandler
-from .utils import get_mt_exp_dir
 
-LOGGER = logging.getLogger(__package__ + ".postprocess_draft")
+LOGGER = logging.getLogger((__package__ or "") + ".postprocess_draft")
 
 
 def main() -> None:
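Reviewer note on the `(__package__ or "")` change, which this PR applies to every module-level logger: `__package__` is typed as `Optional[str]` and can be `None` (for instance when a module is executed directly rather than imported as part of a package), so the old `__package__ + "..."` fails strict type checking and can raise at runtime. A minimal standalone sketch of the guard:

```python
import logging

# __package__ is Optional[str]; when it is None, the unguarded form
#     logging.getLogger(__package__ + ".postprocess_draft")
# raises: TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'

# The `or ""` guard falls back to the empty string, so the logger name
# degrades to ".postprocess_draft" instead of crashing.
LOGGER = logging.getLogger((__package__ or "") + ".postprocess_draft")
```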
2 changes: 1 addition & 1 deletion silnlp/common/postprocesser.py
@@ -34,7 +34,7 @@
 from ..alignment.utils import compute_alignment_scores
 from .corpus import load_corpus, write_corpus
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 
 POSTPROCESS_DEFAULTS = {
     "paragraph_behavior": "end",  # Possible values: end, place, strip
13 changes: 6 additions & 7 deletions silnlp/common/translate_google.py
@@ -1,16 +1,15 @@
 import argparse
 import logging
-from typing import Iterable, Optional
+from typing import Generator, Iterable, Optional
 
 from google.cloud import translate_v2 as translate
 from machine.scripture import VerseRef, book_id_to_number
 
 from ..common.environment import SIL_NLP_ENV
 from .paratext import book_file_name_digits
-from .translator import TranslationGroup, Translator
+from .translator import SentenceTranslation, SentenceTranslationGroup, Translator
 from .utils import get_git_revision_hash, get_mt_exp_dir
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 
 
 class GoogleTranslator(Translator):
@@ -24,7 +23,7 @@ def translate(
         trg_iso: str,
         produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
-    ) -> Iterable[TranslationGroup]:
+    ) -> Generator[SentenceTranslationGroup, None, None]:
         if produce_multiple_translations:
             LOGGER.warning("Google Translator does not support --multiple-translations")
 
@@ -35,8 +34,8 @@
             results = self._translate_client.translate(
                 sentence, source_language=src_iso, target_language=trg_iso, format_="text"
             )
-            translation = results["translatedText"]
-            yield [translation]
+            translation: str = results["translatedText"]
+            yield [SentenceTranslation(translation, [], [], None)]
 
 
 def main() -> None:
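The Google backend previously yielded bare strings; it now wraps each result in `SentenceTranslation`, so all `Translator` subclasses yield the same `SentenceTranslationGroup` shape. The real class lives in `silnlp/common/translator.py`; the stand-in below only illustrates the constructor signature as inferred from the call sites in this diff (Google Translate supplies no tokens or scores, hence `[], [], None`):

```python
from typing import List, Optional

class SentenceTranslation:  # illustrative stand-in, not the project's definition
    def __init__(
        self,
        translation: str,                 # translated sentence text
        tokens: List[str],                # target-side tokens; [] when unavailable
        token_scores: List[float],        # per-token log-scores; [] when unavailable
        sequence_score: Optional[float],  # sentence-level log-score; None when unavailable
    ) -> None:
        self._translation = translation
        self._tokens = tokens
        self._token_scores = token_scores
        self._sequence_score = sequence_score

# A group holds every draft of one sentence; Google produces a single draft.
SentenceTranslationGroup = List[SentenceTranslation]
```

The log-score reading is inferred from `translator.py`, which applies `exp(...)` to these values before writing confidence files.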
116 changes: 70 additions & 46 deletions silnlp/common/translator.py
@@ -2,18 +2,21 @@
 import re
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from contextlib import AbstractContextManager
 from datetime import date
 from itertools import groupby
 from math import exp
 from pathlib import Path
-from typing import DefaultDict, Iterable, List, Optional
+from typing import DefaultDict, Generator, Iterable, List, Optional, Tuple, cast
 
 import docx
 import nltk
 from iso639 import Lang
 from machine.corpora import (
     FileParatextProjectSettingsParser,
     FileParatextProjectTextUpdater,
+    ParatextProjectSettings,
+    ScriptureRef,
     UpdateUsfmParserHandler,
     UpdateUsfmTextBehavior,
     UsfmFileText,
@@ -32,7 +35,7 @@
 from .postprocesser import NoDetectedQuoteConventionException, PostprocessHandler, UnknownQuoteConventionException
 from .usfm_utils import PARAGRAPH_TYPE_EMBEDS
 
-LOGGER = logging.getLogger(__package__ + ".translate")
+LOGGER = logging.getLogger((__package__ or "") + ".translate")
 nltk.download("punkt")
 
 CONFIDENCE_SCORES_SUFFIX = ".confidences.tsv"
@@ -68,7 +71,7 @@ def join_tokens_for_confidence_file(self) -> str:
         return "\t".join(self._tokens)
 
     def join_token_scores_for_confidence_file(self) -> str:
-        return "\t".join([str(exp(ts)) for ts in [self._sequence_score] + self._token_scores])
+        return "\t".join([str(exp(ts)) for ts in [self._sequence_score] + self._token_scores if ts is not None])
 
 
 # A group of multiple translations of a single sentence
@@ -87,7 +90,7 @@ def write_confidence_scores_to_file(
         self,
         confidences_path: Path,
         row1col1_label: str,
-        vrefs: Optional[List[VerseRef]] = None,
+        scripture_refs: Optional[List[ScriptureRef]] = None,
     ) -> None:
         with confidences_path.open("w", encoding="utf-8", newline="\n") as confidences_file:
             confidences_file.write("\t".join([f"{row1col1_label}"] + [f"Token {i}" for i in range(200)]) + "\n")
@@ -96,19 +99,24 @@
             if not sentence_translation.has_sequence_confidence_score():
                 continue
             sequence_label = str(sentence_num)
-            if vrefs is not None:
-                sequence_label = str(vrefs[sentence_num])
+            if scripture_refs is not None:
+                sequence_label = str(scripture_refs[sentence_num])
             confidences_file.write(
                 sequence_label + "\t" + sentence_translation.join_tokens_for_confidence_file() + "\n"
             )
             confidences_file.write(sentence_translation.join_token_scores_for_confidence_file() + "\n")
 
-    def write_chapter_confidence_scores_to_file(self, chapter_confidences_path: Path, vrefs: List[VerseRef]):
+    def write_chapter_confidence_scores_to_file(
+        self, chapter_confidences_path: Path, scripture_refs: List[ScriptureRef]
+    ):
         chapter_confidences: DefaultDict[int, List[float]] = defaultdict(list)
-        for sentence_num, vref in enumerate(vrefs):
-            if not vref.is_verse or self._sentence_translations[sentence_num].get_sequence_confidence_score() is None:
+        for sentence_num, vref in enumerate(scripture_refs):
+            sequence_confidence_score: Optional[float] = self._sentence_translations[
+                sentence_num
+            ].get_sequence_confidence_score()
+            if not vref.is_verse or sequence_confidence_score is None:
                 continue
-            vref_confidence = exp(self._sentence_translations[sentence_num].get_sequence_confidence_score())
+            vref_confidence = exp(sequence_confidence_score)
             chapter_confidences[vref.chapter_num].append(vref_confidence)
 
         with chapter_confidences_path.open("w", encoding="utf-8", newline="\n") as chapter_confidences_file:
@@ -119,9 +127,9 @@ def write_chapter_confidence_scores_to_file(self, chapter_confidences_path: Path
 
     def get_all_sequence_confidence_scores(self) -> List[float]:
         return [
-            exp(st.get_sequence_confidence_score())
-            for st in self._sentence_translations
-            if st.get_sequence_confidence_score() is not None
+            exp(scs)
+            for scs in [t.get_sequence_confidence_score() for t in self._sentence_translations]
+            if scs is not None
        ]
 
     def get_all_translations(self) -> List[str]:
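The comprehension rewrites in this class all follow one pattern: bind the `Optional[float]` score once, filter on `is not None`, then apply `exp`. That avoids calling the getter twice and lets the type checker narrow the bound name to `float`; in the old form, the second `get_sequence_confidence_score()` call inside `exp(...)` still looked `Optional`. The pattern in isolation:

```python
from math import exp
from typing import List, Optional

def log_scores_to_probs(log_scores: List[Optional[float]]) -> List[float]:
    # Each score is bound once; the `is not None` guard narrows it to float
    # before exp() runs, so nothing is fetched or checked twice.
    return [exp(s) for s in log_scores if s is not None]
```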
@@ -153,7 +161,7 @@ def generate_confidence_files(
     trg_file_path: Path,
     trg_prefix: str = "",
     produce_multiple_translations: bool = False,
-    vrefs: Optional[List[VerseRef]] = None,
+    scripture_refs: Optional[List[ScriptureRef]] = None,
     draft_index: int = 0,
 ) -> None:
     if not translated_draft.has_sequence_confidence_scores():
@@ -169,8 +177,8 @@
 
     ext = trg_file_path.suffix.lower()
     if ext in {".usfm", ".sfm"}:
-        assert vrefs is not None
-        generate_usfm_confidence_files(translated_draft, trg_file_path, confidences_path, vrefs, draft_index)
+        assert scripture_refs is not None
+        generate_usfm_confidence_files(translated_draft, trg_file_path, confidences_path, scripture_refs, draft_index)
     elif ext == ".txt":
         generate_txt_confidence_files(translated_draft, trg_file_path, confidences_path, trg_prefix)
     else:
@@ -184,24 +192,26 @@ def generate_usfm_confidence_files(
     translated_draft: TranslatedDraft,
     trg_file_path: Path,
     confidences_path: Path,
-    vrefs: List[VerseRef],
+    scripture_refs: List[ScriptureRef],
     draft_index: int = 0,
 ) -> None:
 
-    translated_draft.write_confidence_scores_to_file(confidences_path, "VRef", vrefs)
-    translated_draft.write_chapter_confidence_scores_to_file(confidences_path.with_suffix(".chapters.tsv"), vrefs)
-    _append_book_confidence_score(translated_draft, trg_file_path, vrefs)
+    translated_draft.write_confidence_scores_to_file(confidences_path, "VRef", scripture_refs)
+    translated_draft.write_chapter_confidence_scores_to_file(
+        confidences_path.with_suffix(".chapters.tsv"), scripture_refs
+    )
+    _append_book_confidence_score(translated_draft, trg_file_path, scripture_refs)
 
 
 def _append_book_confidence_score(
     translated_draft: TranslatedDraft,
     trg_file_path: Path,
-    vrefs: List[VerseRef],
+    scripture_refs: List[ScriptureRef],
 ) -> None:
     file_confidences_path = trg_file_path.parent / "confidences.books.tsv"
     row1_col1_header = "Book"
-    if vrefs:
-        col1_entry = vrefs[0].book
+    if scripture_refs:
+        col1_entry = scripture_refs[0].book
     else:
         col1_entry = trg_file_path.stem
 
@@ -250,7 +260,7 @@ def generate_test_confidence_files(
     translated_draft.write_confidence_scores_to_file(confidences_path, "Sequence Number")
 
 
-class Translator(ABC):
+class Translator(AbstractContextManager["Translator"], ABC):
     @abstractmethod
     def translate(
         self,
@@ -259,7 +269,7 @@ def translate(
         trg_iso: str,
         produce_multiple_translations: bool = False,
         vrefs: Optional[Iterable[VerseRef]] = None,
-    ) -> Iterable[SentenceTranslationGroup]:
+    ) -> Generator[SentenceTranslationGroup, None, None]:
         pass
 
     def translate_text(
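Deriving `Translator` from `AbstractContextManager["Translator"]` (with the no-op `__enter__`/`__exit__` added at the bottom of this file) lets callers scope a translator's lifetime with `with`, and gives subclasses a standard place to acquire and release resources. A usage sketch; whether `GoogleTranslator` takes constructor arguments is not shown in this diff:

```python
# Usage sketch only; constructor arguments are an assumption.
with GoogleTranslator() as translator:
    for group in translator.translate(["Hello, world."], "en", "fr"):
        first_draft = group[0]  # a SentenceTranslation
```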
@@ -348,13 +358,19 @@ def translate_usfm(
     ) -> None:
         # Create UsfmFileText object for source
         src_from_project = False
+        src_settings: Optional[ParatextProjectSettings] = None
+        stylesheet = UsfmStylesheet("usfm.sty")
         if str(src_file_path).startswith(str(get_project_dir(""))):
             src_from_project = True
             src_settings = FileParatextProjectSettingsParser(src_file_path.parent).parse()
+            stylesheet = src_settings.stylesheet
+            book_id = src_settings.get_book_id(src_file_path.name)
+            assert book_id is not None
+
             src_file_text = UsfmFileText(
                 src_settings.stylesheet,
                 src_settings.encoding,
-                src_settings.get_book_id(src_file_path.name),
+                book_id,
                 src_file_path,
                 src_settings.versification,
                 include_all_text=True,
@@ -367,28 +383,28 @@
             if not is_book_id_valid(book_id):
                 raise ValueError(f"Book ID not detected: {book_id}")
 
-            src_file_text = UsfmFileText("usfm.sty", "utf-8-sig", book_id, src_file_path, include_all_text=True)
-        stylesheet = src_settings.stylesheet if src_from_project else UsfmStylesheet("usfm.sty")
+            src_file_text = UsfmFileText(stylesheet, "utf-8-sig", book_id, src_file_path, include_all_text=True)
 
         sentences = [re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())) for s in src_file_text]
-        vrefs = [s.ref for s in src_file_text]
+        scripture_refs: List[ScriptureRef] = [s.ref for s in src_file_text]
+        vrefs: List[VerseRef] = [sr.verse_ref for sr in scripture_refs]
         LOGGER.info(f"File {src_file_path} parsed correctly.")
 
         # Filter sentences
         for i in reversed(range(len(sentences))):
-            marker = vrefs[i].path[-1].name if len(vrefs[i].path) > 0 else ""
+            marker = scripture_refs[i].path[-1].name if len(scripture_refs[i].path) > 0 else ""
             if (
-                (len(chapters) > 0 and vrefs[i].chapter_num not in chapters)
+                (len(chapters) > 0 and scripture_refs[i].chapter_num not in chapters)
                 or marker in PARAGRAPH_TYPE_EMBEDS
                 or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT
             ):
                 sentences.pop(i)
-                vrefs.pop(i)
-        empty_sents = []
+                scripture_refs.pop(i)
+        empty_sents: List[Tuple[int, ScriptureRef]] = []
         for i in reversed(range(len(sentences))):
             if len(sentences[i].strip()) == 0:
                 sentences.pop(i)
-                empty_sents.append((i, vrefs.pop(i)))
+                empty_sents.append((i, scripture_refs.pop(i)))
 
         sentence_translation_groups: List[SentenceTranslationGroup] = list(
             self.translate(sentences, src_iso, trg_iso, produce_multiple_translations, vrefs)
@@ -399,7 +415,7 @@
         # Prevents pre-existing text from showing up in the sections of translated text
         for idx, vref in reversed(empty_sents):
             sentences.insert(idx, "")
-            vrefs.insert(idx, vref)
+            scripture_refs.insert(idx, vref)
             sentence_translation_groups.insert(idx, [SentenceTranslation("", [], [], None)] * num_drafts)
 
         text_behavior = (
@@ -408,13 +424,13 @@
 
         draft_set: DraftGroup = DraftGroup(sentence_translation_groups)
         for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
-            postprocess_handler.construct_rows(vrefs, sentences, translated_draft.get_all_translations())
+            postprocess_handler.construct_rows(scripture_refs, sentences, translated_draft.get_all_translations())
 
             for config in postprocess_handler.configs:
 
                 # Compile draft remarks
                 draft_src_str = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
-                draft_remark = f"This draft of {vrefs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
+                draft_remark = f"This draft of {scripture_refs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
                 postprocess_remark = config.get_postprocess_remark()
                 remarks = [draft_remark] + ([postprocess_remark] if postprocess_remark else [])
 
@@ -449,7 +465,7 @@ def translate_usfm(
                     usfm = f.read()
                 handler = UpdateUsfmParserHandler(
                     rows=config.rows,
-                    id_text=vrefs[0].book,
+                    id_text=scripture_refs[0].book,
                     text_behavior=text_behavior,
                     paragraph_behavior=config.get_paragraph_behavior(),
                     embed_behavior=config.get_embed_behavior(),
@@ -483,9 +499,9 @@ def translate_usfm(
                     "w",
                     encoding=(
                         "utf-8"
-                        if not src_from_project
-                        or src_from_project
-                        and (src_settings.encoding == "utf-8-sig" or src_settings.encoding == "utf_8_sig")
+                        if src_settings is None
+                        or src_settings.encoding == "utf-8-sig"
+                        or src_settings.encoding == "utf_8_sig"
                         else src_settings.encoding
                     ),
                 ) as f:
@@ -496,7 +512,7 @@
                     translated_draft,
                     trg_file_path,
                     produce_multiple_translations=produce_multiple_translations,
-                    vrefs=vrefs,
+                    scripture_refs=scripture_refs,
                     draft_index=draft_index,
                 )
 

Expand All @@ -513,7 +529,7 @@ def translate_docx(
try:
src_lang = Lang(src_iso)
tokenizer = nltk.data.load(f"tokenizers/punkt/{src_lang.name.lower()}.pickle")
except:
except Exception:
tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")

with src_file_path.open("rb") as file:
Expand All @@ -522,8 +538,8 @@ def translate_docx(
sentences: List[str] = []
paras: List[int] = []

for i in range(len(doc.paragraphs)):
for sentence in tokenizer.tokenize(doc.paragraphs[i].text, "test"):
for i, paragraph in enumerate(doc.paragraphs):
for sentence in tokenizer.tokenize(paragraph.text):
sentences.append(add_tags_to_sentence(tags, sentence))
paras.append(i)

Expand All @@ -532,7 +548,7 @@ def translate_docx(
)

for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1):
for para, group in groupby(zip(translated_draft, paras), key=lambda t: t[1]):
for para, group in groupby(zip(translated_draft.get_all_translations(), paras), key=lambda t: t[1]):
text = " ".join(s[0] for s in group)
doc.paragraphs[para].text = text

Expand All @@ -543,3 +559,11 @@ def translate_docx(

with trg_draft_file_path.open("wb") as file:
doc.save(file)

def __enter__(self) -> "Translator":
return self

def __exit__(
self, exc_type, exc_val, exc_tb # pyright: ignore[reportMissingParameterType, reportUnknownParameterType]
) -> None:
pass
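The base-class hooks are deliberate no-ops so existing subclasses keep working unchanged; a subclass that owns a resource can override them. A self-contained sketch (the subclass and its file-logging behavior are hypothetical, but the `translate` signature matches this diff):

```python
from typing import Generator, Iterable, List, Optional

class EchoTranslator(Translator):  # hypothetical subclass for illustration
    def __enter__(self) -> "EchoTranslator":
        self._log = open("translate.log", "w", encoding="utf-8")  # acquire
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._log.close()  # released even if translation raises

    def translate(
        self,
        sentences: Iterable[str],
        src_iso: str,
        trg_iso: str,
        produce_multiple_translations: bool = False,
        vrefs: Optional[Iterable[VerseRef]] = None,
    ) -> Generator[SentenceTranslationGroup, None, None]:
        for sentence in sentences:
            self._log.write(sentence + "\n")
            yield [SentenceTranslation(sentence, [], [], None)]
```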