Skip to content

Commit

Permalink
refractor, chore: remove unnecessary api calls, move extract_texts to…
Browse files Browse the repository at this point in the history
… google_provider, add type check for base_provider
  • Loading branch information
vTuanpham committed Jan 16, 2024
1 parent c11aca9 commit cb3a9cb
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 40 deletions.
21 changes: 9 additions & 12 deletions providers/base_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Union, List, Any
from typing import Union, List
from abc import ABC, abstractmethod
from types import SimpleNamespace


class Provider(ABC):
Expand All @@ -15,18 +14,19 @@ def __init__(self):
def _do_translate(self, input_data: Union[str, List[str]],
src: str, dest: str,
fail_translation_code:str = "P1OP1_F",
**kwargs) -> Union[str, List[str], Any]:
**kwargs) -> Union[str, List[str]]:
raise NotImplemented(" The function _do_translate has not been implemented.")

def translate(self, input_data: Union[str, List[str]],
src: str, dest: str,
fail_translation_code: str="P1OP1_F") -> Union[SimpleNamespace, List[SimpleNamespace]]:
fail_translation_code: str="P1OP1_F") -> Union[str, List[str]]:
"""
Translate text input_data from a language to another language
:param input_data: The input_data (Can be string or list of strings)
:param src: The source lang of input_data
:param dest: The target lang you want input_data to be translated
:return: SimpleNamespace object or list of SimpleNamespace objects with 'text' attribute
:param fail_translation_code: The code that can be use for unavoidable translation error and can be remove post translation
:return: str or list of str
"""

# Type check for input_data
Expand All @@ -44,13 +44,10 @@ def translate(self, input_data: Union[str, List[str]],
src=src, dest=dest,
fail_translation_code=fail_translation_code)

# Wrap non-list objects in SimpleNamespace if they don't have a 'text' attribute
if not isinstance(translated_instance, list):
if not hasattr(translated_instance, 'text'):
return SimpleNamespace(text=translated_instance)
else:
# Wrap each item in the list in SimpleNamespace if the item doesn't have a 'text' attribute
return [SimpleNamespace(text=item) if not hasattr(item, 'text') else item for item in translated_instance]
assert type(input_data) == type(translated_instance),\
f" The function self._do_translate() return mismatch datatype from the input_data," \
f" expected {type(input_data)} from self._do_translate() but got {type(translated_instance)}"

return translated_instance


26 changes: 20 additions & 6 deletions providers/google_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Union, List, Any
from typing import Union, List
sys.path.insert(0, r'/')
from googletrans import Translator
from .base_provider import Provider
Expand All @@ -11,6 +11,19 @@ class GoogleProvider(Provider):
def __init__(self):
self.translator = Translator()

def extract_texts(self, obj):
'''
Extract .text attribute from Translator object
'''

if isinstance(obj, list):
return [self.extract_texts(item) for item in obj]
else:
try:
return obj.text
except AttributeError:
return obj

def _do_translate(self, input_data: Union[str, List[str]],
src: str, dest: str,
fail_translation_code:str = "P1OP1_F", # Pass in this code to replace the input_data if the exception is *unavoidable*, any example that contain this will be remove post translation
Expand All @@ -28,20 +41,21 @@ def _do_translate(self, input_data: Union[str, List[str]],
Return type:
Translated
Return type: list (when a list is passed) else str
Return type: list (when a list is passed) else Translated object
"""

data_type = "list" if isinstance(input_data, list) else "str"

try:
return self.translator.translate(input_data, src=src, dest=dest)
return self.extract_texts(self.translator.translate(input_data, src=src, dest=dest))
# TypeError likely due to gender-specific translation, which has no fix yet. Please refer to
# ssut/py-googletrans#260 for more info
except TypeError:
if data_type == "list": return self.translator.translate([fail_translation_code, fail_translation_code], src=src, dest=dest)
return self.translator.translate(fail_translation_code, src=src, dest=dest)
if data_type == "list": return [fail_translation_code, fail_translation_code]
return fail_translation_code


if __name__ == '__main__':
test = GoogleProvider()
print(test.translate("Hello", src="en", dest="vi").text)
print(test.translate(["Hello", "How are you today ?"], src="en", dest="vi"))
print(test.translate("Hello", src="en", dest="vi"))
7 changes: 4 additions & 3 deletions providers/multiple_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,16 @@ def _do_translate(self, input_data: Union[str, List[str]],
else:
translated_data = self.translator.translate_text(input_data, from_language=src, to_language=dest, **self.config)
except TranslatorError:
if data_type == "list": return self.translator.translate_text([fail_translation_code, fail_translation_code], from_language=src, to_language=dest, **self.config)
return self.translator.translate_text(fail_translation_code, from_language=src, to_language=dest, **self.config)
if data_type == "list": return [fail_translation_code, fail_translation_code]
return fail_translation_code

return translated_data


if __name__ == '__main__':
test = MultipleProviders()
print(test.translate("Hello", src="en", dest="vie").text)
print(test.translate(["Hello", "How are you today ?"], src="en", dest="vie"))
print(test.translate("Hello", src="en", dest="vie"))

"""
Supported languages:
Expand Down
21 changes: 2 additions & 19 deletions translator/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
except ImportError:
IN_COLAB = False
from httpcore._exceptions import ConnectTimeout
from translators.server import TranslatorError
from typing import List, Dict, Union
from abc import abstractmethod
from tqdm.auto import tqdm

from concurrent.futures import ThreadPoolExecutor

from providers import Provider, GoogleProvider, MultipleProviders

from configs import BaseConfig, QAConfig, DialogsConfig
from .utils import force_super_call, ForceBaseCallMeta, timeit, have_internet
from .filters import have_code, have_re_code
Expand All @@ -52,7 +50,7 @@ def __init__(self, file_path: str,
translator: Provider = GoogleProvider,
source_lang: str = "en",
target_lang: str = "vi",
fail_translation_code: str="P1OP1_F" # Fail code for unexpected fail translation and can be removed
fail_translation_code: str="P1OP1_F" # Fail code for *expected* fail translation and can be removed
# post-translation
) -> None:

Expand Down Expand Up @@ -293,21 +291,6 @@ def __translate_texts(self,
dest=self.target_lang,
fail_translation_code=self.fail_translation_code)

def extract_texts(obj):
'''
Extract .text attribute from Translator object
'''

if isinstance(obj, list):
return [extract_texts(item) for item in obj]
else:
try:
return obj.text
except AttributeError:
return obj

target_texts = extract_texts(target_texts)

return {'text_list': target_texts, 'key': sub_list_idx} if sub_list_idx is not None else target_texts

def translate_converted(self,
Expand Down Expand Up @@ -354,7 +337,7 @@ def translate_converted(self,

# Progress bar
desc = "Translating total converted large chunk data" if large_chunk else "Translating total converted data"
progress_bar = tqdm(total=math.ceil(num_threads), desc=desc)
progress_bar = tqdm(total=math.ceil(num_threads), desc=desc, position=math.ceil(num_threads)+1)

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
Expand Down

0 comments on commit cb3a9cb

Please sign in to comment.