Skip to content

Commit

Permalink
fix, chore: Correct base translate Provider, fix test cases, add doc …
Browse files Browse the repository at this point in the history
…string
  • Loading branch information
vTuanpham committed Jan 4, 2024
1 parent 204b8bb commit fd0cdfe
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 73 deletions.
File renamed without changes.
49 changes: 49 additions & 0 deletions providers/base_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Union, List, Any
from abc import ABC, abstractmethod
from types import SimpleNamespace


class Provider(ABC):
    """
    Base Provider that must be inherited by all Provider classes.

    Implement your own provider by inheriting this class, assigning the backend object to
    ``self.translator`` inside ``__init__`` and overriding ``_do_translate``.
    """
    @abstractmethod
    def __init__(self):
        self.translator = None

    @abstractmethod
    def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str, **kwargs) -> Union[str, List[str], Any]:
        """
        Backend-specific translation hook that every concrete provider must override
        :param input_data: The input_data (Can be string or list of strings)
        :param src: The source lang of input_data
        :param dest: The target lang you want input_data to be translated
        :return: The raw translation result; ``translate`` wraps it when it has no 'text' attribute
        :raises NotImplementedError: Always, when the base implementation is invoked
        """
        # NotImplemented is a non-callable comparison sentinel; calling it raised a confusing
        # TypeError. NotImplementedError is the exception meant for unimplemented hooks.
        raise NotImplementedError(" The function _do_translate has not been implemented.")

    def translate(self, input_data: Union[str, List[str]], src: str, dest: str) -> Union[SimpleNamespace, List[SimpleNamespace]]:
        """
        Translate text input_data from a language to another language
        :param input_data: The input_data (Can be string or list of strings)
        :param src: The source lang of input_data
        :param dest: The target lang you want input_data to be translated
        :return: SimpleNamespace object or list of SimpleNamespace objects with 'text' attribute
                 (results that already expose 'text' are returned unchanged)
        :raises TypeError: If input_data is neither a str nor a list made only of str
        :raises AssertionError: If no translator object has been assigned to self.translator
        """

        # Type check for input_data
        if not isinstance(input_data, (str, list)):
            raise TypeError(f"input_data must be of type str or List[str], not {type(input_data).__name__}")

        if isinstance(input_data, list) and not all(isinstance(item, str) for item in input_data):
            raise TypeError("All elements of input_data list must be of type str")

        # Ensure the translator is set. Raised explicitly (instead of a bare assert) so the check
        # still runs under ``python -O``; getattr covers subclasses that never set the attribute.
        if not getattr(self, "translator", None):
            raise AssertionError("Please assign the translator object instance to self.translator")

        # Perform the translation
        translated_instance = self._do_translate(input_data, src=src, dest=dest)

        # Wrap each item in the list in SimpleNamespace if the item doesn't have a 'text' attribute
        if isinstance(translated_instance, list):
            return [item if hasattr(item, 'text') else SimpleNamespace(text=item) for item in translated_instance]

        # A single result that already exposes 'text' is passed through untouched
        if hasattr(translated_instance, 'text'):
            return translated_instance

        return SimpleNamespace(text=translated_instance)

36 changes: 36 additions & 0 deletions providers/google_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys
from typing import Union, List, Any
sys.path.insert(0, r'/')
from googletrans import Translator
from .base_provider import Provider


# https://github.com/ssut/py-googletrans
# This is the most reliable provider, as it has access to API calls instead of using the crawling method
class GoogleProvider(Provider):
    """
    Provider that delegates translation to a ``googletrans`` ``Translator`` instance.
    """
    def __init__(self):
        self.translator = Translator()

    def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str, **kwargs) -> Union[str, List[str], Any]:
        """
        Hand input_data over to ``Translator.translate(text, dest='en', src='auto', **kwargs)``

        Notes taken from the googletrans documentation:
        text - The source text(s) to be translated (str or a sequence of str). Batch translation
               is supported via sequence input.
        dest - The language to translate the source text into. The value should be one of the
               language codes listed in googletrans.LANGUAGES or one of the language names listed
               in googletrans.LANGCODES.
        src  - The language of the source text, using the same codes/names as dest. If a language
               is not specified, the system will attempt to identify the source language
               automatically.

        :param input_data: The text, or list of texts, to translate
        :param src: The source lang of input_data
        :param dest: The target lang you want input_data to be translated
        :param kwargs: Accepted for signature compatibility; not forwarded to googletrans here
        :return: Whatever ``Translator.translate`` returns (documented by googletrans as a
                 ``Translated`` object, or a list of them when a list is passed)
        """
        google_client = self.translator
        translation = google_client.translate(input_data, src=src, dest=dest)
        return translation


if __name__ == '__main__':
    # Manual smoke test: translate one English word to Vietnamese and print the resulting text
    provider = GoogleProvider()
    result = provider.translate("Hello", src="en", dest="vi")
    print(result.text)
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import sys
sys.path.insert(0, r'./')
sys.path.insert(0, r'/')
from typing import Union, List
import translators as ts
from .base_provider import Provider


# https://github.com/UlionTse/translators
# This library is not as reliable a provider as googletrans; use it if you want to try out other translation services
class MultipleProviders(Provider):
def __init__(self, cache: bool=False):
def __init__(self, cache: bool = False):
self.translator = ts
self.config = {
"translator": "bing",
"timeout": 5.0,
"translator": "baidu",
"timeout": 10.0,
"if_ignore_empty_query": True
}
if cache:
_ = self.translator.preaccelerate_and_speedtest() # Optional. Caching sessions in advance, which can help improve access speed.
Expand Down Expand Up @@ -44,8 +47,13 @@ def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str)
:param myMemory_mode: str, default "web", choose from ("web", "api").
:return: str or dict
"""

return self.translator.translate_text(input_data, from_language=src, to_language=dest, **self.config)
# This provider does not support batch translation
translated_data = []
if isinstance(input_data, list):
for text in input_data:
translated_text = self.translator.translate_text(text, from_language=src, to_language=dest, **self.config)
translated_data.append(translated_text)
return translated_data


if __name__ == '__main__':
Expand Down
20 changes: 17 additions & 3 deletions tests/eli5_qaconfig_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ def step4(self):

def step5(self):
try:
self.translated_dataset = load_dataset("json", data_files=self.output_path, keep_in_memory=False)
self.parsed_dataset = load_dataset("json", data_files=self.output_path, keep_in_memory=False)
self.translated_dataset = load_dataset("json", data_files=self.output_path_translated, keep_in_memory=False)
except Exception as e:
raise SyntaxError("Invalid syntax for save function, the data output must be in the form of"
f"line-delimited json,\n Error message: {e}")

def step6(self):
self.assertEqual(len(self.translated_dataset['train']), len(self.parser.converted_data),
"The parsed translated dataset does not match the length of the parsed dataset")
self.assertEqual(len(self.parsed_dataset['train']), len(self.parser.converted_data),
msg="The parsed dataset does not match the length of the parsed dataset")
self.assertAlmostEqualInt(len(self.translated_dataset['train']), len(self.parser.converted_data),
msg="The parsed translated dataset fail too much and does not meet the length criteria of the parsed dataset",
tolerance=50)

def step7(self):
if os.path.exists(self.output_path):
Expand All @@ -63,6 +67,16 @@ def test_steps(self):
except Exception as e:
self.fail(f"{step} failed ({type(e)}: {e})")

def assertAlmostEqualInt(self, int1, int2, tolerance=1, msg=None):
    """
    Fail the test unless int1 and int2 differ by no more than tolerance.

    :param int1: First integer to compare
    :param int2: Second integer to compare
    :param tolerance: Largest absolute difference that is still accepted
    :param msg: Optional text placed in front of the generated failure message
    :raises self.failureException: When the absolute difference exceeds tolerance
    """
    difference = abs(int1 - int2)
    if difference > tolerance:
        detail = f"{int1} and {int2} are not almost equal within a tolerance of {tolerance}."
        # A caller-supplied message is prepended to the generated detail
        raise self.failureException(f"{msg}: {detail}" if msg else detail)


if __name__ == '__main__':
unittest.main()
Expand Down
20 changes: 17 additions & 3 deletions tests/eli5_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ def step4(self):

def step5(self):
try:
self.translated_dataset = load_dataset("json", data_files=self.output_path, keep_in_memory=False)
self.parsed_dataset = load_dataset("json", data_files=self.output_path, keep_in_memory=False)
self.translated_dataset = load_dataset("json", data_files=self.output_path_translated, keep_in_memory=False)
except Exception as e:
raise SyntaxError("Invalid syntax for save function, the data output must be in the form of"
f"line-delimited json,\n Error message: {e}")

def step6(self):
self.assertEqual(len(self.translated_dataset['train']), len(self.parser.converted_data),
"The parsed translated dataset does not match the length of the parsed dataset")
self.assertEqual(len(self.parsed_dataset['train']), len(self.parser.converted_data),
msg="The parsed dataset does not match the length of the parsed dataset")
self.assertAlmostEqualInt(len(self.translated_dataset['train']), len(self.parser.converted_data),
msg="The parsed translated dataset fail too much and does not meet the length criteria of the parsed dataset",
tolerance=50)

def step7(self):
if os.path.exists(self.output_path):
Expand All @@ -64,6 +68,16 @@ def test_steps(self):
except Exception as e:
self.fail(f"{step} failed ({type(e)}: {e})")

def assertAlmostEqualInt(self, int1, int2, tolerance=1, msg=None):
    """
    Check that two integers lie within tolerance of each other, failing the test otherwise.

    :param int1: First integer to compare
    :param int2: Second integer to compare
    :param tolerance: Maximum allowed absolute difference between the two values
    :param msg: Optional prefix for the generated failure message
    :raises self.failureException: When the absolute difference is greater than tolerance
    """
    gap = abs(int1 - int2)
    if gap > tolerance:
        explanation = f"{int1} and {int2} are not almost equal within a tolerance of {tolerance}."
        # Prefix the explanation with the caller's message when one was given
        raise self.failureException(f"{msg}: {explanation}" if msg else explanation)


if __name__ == '__main__':
unittest.main()
Expand Down
19 changes: 14 additions & 5 deletions translator/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,28 @@
import json
import os
import random
import string
import sys
sys.path.insert(0, r'./')
from copy import deepcopy

import string
import threading
import warnings
import traceback
try:
from google.colab import files
IN_COLAB = True
except ImportError:
IN_COLAB = False
from httpcore._exceptions import ConnectTimeout
from translators.server import TranslatorError
from typing import List, Dict, Union
from abc import abstractmethod
from tqdm.auto import tqdm

from concurrent.futures import ThreadPoolExecutor

# from googletrans import Translator
from .providers import Provider, MultipleProviders, GoogleProvider
from providers import Provider, GoogleProvider, MultipleProviders

from configs import BaseConfig, QAConfig, DialogsConfig
from .utils import force_super_call, ForceBaseCallMeta, timeit, have_internet
Expand All @@ -48,6 +49,7 @@ def __init__(self, file_path: str,
large_chunks_threshold: int = 20000, # Maximum number of examples that will be distributed evenly across threads, any examples exceed this threshold will be process in queue
max_list_length_per_thread: int = 3, # Maximum number of strings contain in a list in a single thread.
# if larger, split the list into sub-list and process in parallel
translator: Provider = GoogleProvider,
source_lang: str = "en",
target_lang: str = "vi",
fail_translation_code: str="P1OP1_F" # Fail code for unexpected fail translation and can be removed
Expand Down Expand Up @@ -85,7 +87,7 @@ def __init__(self, file_path: str,

self.converted_data_translated = None

self.translator = GoogleProvider
self.translator = translator

@property
def get_translator(self) -> Provider:
Expand Down Expand Up @@ -288,7 +290,14 @@ def __translate_texts(self,

try:
target_texts = translator_instance.translate(src_texts, src=self.source_lang, dest=self.target_lang)
except TypeError:
except (TypeError, TranslatorError):
# except Exception as exc:
# TODO: Move Error except to each individual Providers

# Log the full stack trace of the exception
# traceback_str = ''.join(traceback.format_exception(None, exc, exc.__traceback__))
# tqdm.write(f"An exception occurred:\n{traceback_str}")

# TypeError likely due to gender-specific translation, which has no fix yet. Please refer to
# ssut/py-googletrans#260 for more info
if sub_list_idx is None:
Expand Down
38 changes: 0 additions & 38 deletions translator/providers/base_provider.py

This file was deleted.

18 changes: 0 additions & 18 deletions translator/providers/google_provider.py

This file was deleted.

0 comments on commit fd0cdfe

Please sign in to comment.