Skip to content

Commit

Permalink
refractor, chore: Move try except to each Providers, add baidu API test
Browse files Browse the repository at this point in the history
  • Loading branch information
vTuanpham committed Jan 7, 2024
1 parent 8e712bf commit 1a866c8
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 35 deletions.
8 changes: 5 additions & 3 deletions examples/ELI5/ELI5_10docs_Parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

from configs import BaseConfig
from translator import DataParser

from providers import Provider, GoogleProvider, MultipleProviders

PARSER_NAME = "ELI5_val"


class ELI5Val(DataParser):
def __init__(self, file_path: str, output_path: str, target_lang: str="vi",
max_example_per_thread=400, large_chunks_threshold=20000):
max_example_per_thread=400, large_chunks_threshold=20000,
translator: Provider = GoogleProvider):
super().__init__(file_path, output_path,
parser_name=PARSER_NAME,
target_config=BaseConfig, # The data config to be validated to check if self implement "convert" function is correct or not,
Expand All @@ -22,7 +23,8 @@ def __init__(self, file_path: str, output_path: str, target_lang: str="vi",
do_translate=True,
target_lang=target_lang,
max_example_per_thread=max_example_per_thread,
large_chunks_threshold=large_chunks_threshold)
large_chunks_threshold=large_chunks_threshold,
translator=translator)

self.max_ctxs = 5

Expand Down
15 changes: 13 additions & 2 deletions providers/google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ class GoogleProvider(Provider):
def __init__(self):
self.translator = Translator()

def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str, **kwargs) -> Union[str, List[str], Any]:
def _do_translate(self, input_data: Union[str, List[str]],
src: str, dest: str,
fail_translation_code:str = "P1OP1_F", # Pass in this code to replace the input_data if the exception is *unavoidable*, any example that contain this will be remove post translation
**kwargs) -> Union[str, List[str]]:
"""
translate(text, dest='en', src='auto', **kwargs)
Translate text from source language to destination language
Expand All @@ -28,7 +31,15 @@ def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str,
Return type: list (when a list is passed) else str
"""

return self.translator.translate(input_data, src=src, dest=dest)
data_type = "list" if isinstance(input_data, list) else "str"

try:
return self.translator.translate(input_data, src=src, dest=dest)
# TypeError likely due to gender-specific translation, which has no fix yet. Please refer to
# ssut/py-googletrans#260 for more info
except TypeError:
if data_type == "list": return self.translator.translate([fail_translation_code, fail_translation_code], src=src, dest=dest)
return self.translator.translate(fail_translation_code, src=src, dest=dest)


if __name__ == '__main__':
Expand Down
31 changes: 21 additions & 10 deletions providers/multiple_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
sys.path.insert(0, r'/')
from typing import Union, List
import translators as ts
from translators.server import TranslatorError
from .base_provider import Provider


Expand All @@ -18,7 +19,10 @@ def __init__(self, cache: bool = False):
if cache:
_ = self.translator.preaccelerate_and_speedtest() # Optional. Caching sessions in advance, which can help improve access speed.

def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str) -> Union[str, List[str]]:
def _do_translate(self, input_data: Union[str, List[str]],
src: str, dest: str,
fail_translation_code:str = "P1OP1_F", # Pass in this code to replace the input_data if the exception is unavoidable, any example that contain this will be remove post translation
**kwargs) -> Union[str, List[str]]:
"""
translate_text(query_text: str, translator: str = 'bing', from_language: str = 'auto', to_language: str = 'en', **kwargs) -> Union[str, dict]
:param query_text: str, must.
Expand Down Expand Up @@ -47,21 +51,28 @@ def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str)
:param myMemory_mode: str, default "web", choose from ("web", "api").
:return: str or dict
"""
# This provider does not support batch translation
if isinstance(input_data, list):
translated_data = []
for text in input_data:
translated_text = self.translator.translate_text(text, from_language=src, to_language=dest, **self.config)
translated_data.append(translated_text)
else:
translated_data = self.translator.translate_text(input_data, from_language=src, to_language=dest, **self.config)

data_type = "list" if isinstance(input_data, list) else "str"

try:
# This provider does not support batch translation
if data_type == "list":
translated_data = []
for text in input_data:
translated_text = self.translator.translate_text(text, from_language=src, to_language=dest, **self.config)
translated_data.append(translated_text)
else:
translated_data = self.translator.translate_text(input_data, from_language=src, to_language=dest, **self.config)
except TranslatorError:
if data_type == "list": return self.translator.translate_text([fail_translation_code, fail_translation_code], from_language=src, to_language=dest, **self.config)
return self.translator.translate_text(fail_translation_code, from_language=src, to_language=dest, **self.config)

return translated_data


if __name__ == '__main__':
test = MultipleProviders()
print(test.translate("Hello", src="en", dest="vi").text)
print(test.translate("Hello", src="en", dest="vie").text)

"""
Supported languages:
Expand Down
90 changes: 90 additions & 0 deletions tests/eli5_test_baidu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import unittest
import warnings
import sys
sys.path.insert(0,r'./')

from datasets import load_dataset

from providers import MultipleProviders
from examples.ELI5.ELI5_10docs_Parser import ELI5Val


class TestELI5Val(unittest.TestCase):

def step1(self):
self.file_path = "examples/ELI5/ELI5_val_10_doc.json"
self.output_dir = "examples/ELI5"
self.parser = ELI5Val(self.file_path, self.output_dir, target_lang="vie",
max_example_per_thread=50, large_chunks_threshold=500,
translator=MultipleProviders)

def step2(self):
self.parser.read()
self.assertIsNotNone(self.parser.data_read) # Check that data_read is not None

def step3(self):
self.parser.convert()
self.assertIsNotNone(self.parser.converted_data) # Check that converted_data is not None

def step4(self):
self.parser.save

self.output_path = os.path.join(self.output_dir, "ELI5_val.json")
self.output_path_translated = os.path.join(self.output_dir, "ELI5_val_translated_vie.json")

self.assertTrue(os.path.exists(self.output_path), f"File '{self.output_path}' does not exist")
self.assertTrue(os.path.exists(self.output_path_translated), f"File '{self.output_path_translated}' does not exist")

def step5(self):
try:
self.parsed_dataset = load_dataset("json", data_files=self.output_path, keep_in_memory=False)
self.translated_dataset = load_dataset("json", data_files=self.output_path_translated, keep_in_memory=False)
except Exception as e:
raise SyntaxError("Invalid syntax for save function, the data output must be in the form of"
f"line-delimited json,\n Error message: {e}")

def step6(self):
self.assertEqual(len(self.parsed_dataset['train']), len(self.parser.converted_data),
msg="The parsed dataset does not match the length of the parsed dataset")
self.assertAlmostEqualInt(len(self.translated_dataset['train']), len(self.parser.converted_data),
msg="The parsed translated dataset fail too much and does not meet the length criteria of the parsed dataset",
tolerance=50)

def step7(self):
if os.path.exists(self.output_path):
os.remove(self.output_path)
if os.path.exists(self.output_path_translated):
os.remove(self.output_path_translated)

def _steps(self):
for name in dir(self): # dir() result is implicitly sorted
if name.startswith("step"):
yield name, getattr(self, name)

def test_steps(self):
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
for name, step in self._steps():
try:
step()
except Exception as e:
self.fail(f"{step} failed ({type(e)}: {e})")

def assertAlmostEqualInt(self, int1, int2, tolerance=1, msg=None):
"""
Asserts that two integers are almost equal within a specified tolerance range.
"""
if abs(int1 - int2) > tolerance:
standard_msg = f"{int1} and {int2} are not almost equal within a tolerance of {tolerance}."
if msg:
standard_msg = f"{msg}: {standard_msg}"
raise self.failureException(standard_msg)


if __name__ == '__main__':
unittest.main()





24 changes: 4 additions & 20 deletions translator/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,26 +288,10 @@ def __translate_texts(self,
# This if is for multithread Translator instance
translator_instance = deepcopy(self.translator)() if not translator else translator

try:
target_texts = translator_instance.translate(src_texts, src=self.source_lang, dest=self.target_lang)
except (TypeError, TranslatorError):
# except Exception as exc:
# TODO: Move Error except to each individual Providers

# Log the full stack trace of the exception
# traceback_str = ''.join(traceback.format_exception(None, exc, exc.__traceback__))
# tqdm.write(f"An exception occurred:\n{traceback_str}")

# TypeError likely due to gender-specific translation, which has no fix yet. Please refer to
# ssut/py-googletrans#260 for more info
if sub_list_idx is None:
target_texts = translator_instance.translate(self.fail_translation_code,
src=self.source_lang,
dest=self.target_lang)
else:
target_texts = translator_instance.translate([self.fail_translation_code, self.fail_translation_code],
src=self.source_lang,
dest=self.target_lang)
target_texts = translator_instance.translate(src_texts,
src=self.source_lang,
dest=self.target_lang,
fail_translation_code=self.fail_translation_code)

def extract_texts(obj):
'''
Expand Down

0 comments on commit 1a866c8

Please sign in to comment.