Skip to content

Commit

Permalink
feat: Add provider abstract class, support for self-implemented providers
Browse files Browse the repository at this point in the history
  • Loading branch information
vTuanpham committed Jan 4, 2024
1 parent 8f21cac commit 204b8bb
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 6 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
googletrans==3.1.0a0
translators
datasets
tqdm
13 changes: 7 additions & 6 deletions translator/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

from concurrent.futures import ThreadPoolExecutor

from googletrans import Translator
# from googletrans import Translator
from .providers import Provider, MultipleProviders, GoogleProvider

from configs import BaseConfig, QAConfig, DialogsConfig
from .utils import force_super_call, ForceBaseCallMeta, timeit, have_internet
Expand Down Expand Up @@ -84,10 +85,10 @@ def __init__(self, file_path: str,

self.converted_data_translated = None

self.translator = Translator
self.translator = GoogleProvider

@property
def get_translator(self) -> Translator:
def get_translator(self) -> Provider:
return deepcopy(self.translator)()

@staticmethod
Expand Down Expand Up @@ -146,7 +147,7 @@ def post_translate_validate(self) -> None:
print(f"\nTotal data left after filtering fail translation: {len(post_validated_translate_data)}\n")
self.converted_data_translated = post_validated_translate_data

def __translate_per_key(self, example: Dict, translator: Translator = None, progress_idx: int = 0) -> Dict:
def __translate_per_key(self, example: Dict, translator: Provider = None, progress_idx: int = 0) -> Dict:
'''
This function loop through each key of one example and send to __translate_texts if the value of the key is
under a certain threshold. If exceeded, then send to __sublist_multithread_translate
Expand Down Expand Up @@ -274,7 +275,7 @@ def flatten_list(nested_list):

def __translate_texts(self,
src_texts: Union[List[str], str],
translator: Translator = None,
translator: Provider = None,
sub_list_idx: int=None, # sub_list_idx is for pass through of index information and can be merge later by __sublist_multithread_translate
) -> Union[List[str], str, Dict[List[str], int]]:
'''
Expand Down Expand Up @@ -319,7 +320,7 @@ def extract_texts(obj):
def translate_converted(self,
en_data: List[str] = None,
desc: str = None,
translator: Translator = None,
translator: Provider = None,
large_chunk: List[str] = None) -> Union[None, List[str]]:
'''
This function support translation in multithread for large dataset
Expand Down
3 changes: 3 additions & 0 deletions translator/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base_provider import Provider
from .google_provider import GoogleProvider
from .multiple_providers import MultipleProviders
38 changes: 38 additions & 0 deletions translator/providers/base_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Union, List
from abc import ABC, abstractmethod
from types import SimpleNamespace


class Provider(ABC):
@abstractmethod
def __init__(self):
self.translator = None

@abstractmethod
def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str, **kwargs) -> Union[str, List[str]]:
raise NotImplemented(" The function _do_translate has not been implemented.")

def translate(self, input_data: Union[str, List[str]], src: str, dest: str) -> SimpleNamespace:
"""
Translate text input_data from a language to another language
:param input_data: The input_data(Can be string or list of string)
:param src: The source lang of input_data
:param dest: The target lang you want input_data to be translated
:return:
"""

assert self.translator, "Please assign the translator object instance to self.translator"
translated_instance = self._do_translate(input_data, src=src, dest=dest)
if not hasattr(translated_instance, 'text'):
if isinstance(translated_instance, list) or isinstance(translated_instance, str):
return SimpleNamespace(text=translated_instance)
else:
raise ValueError(f"The return object of _do_translate expected to be 'list' or 'string',"
f" found {type(translated_instance)}")
else:
if isinstance(translated_instance.text, list) or isinstance(translated_instance.text, str):
return translated_instance
else:
raise ValueError(f"The return object of _do_translate with required 'text' attribute expected to be 'list' or 'string' "
f"but found {type(translated_instance.text)}")

18 changes: 18 additions & 0 deletions translator/providers/google_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import sys
from typing import Union, List
sys.path.insert(0, r'./')
from googletrans import Translator
from .base_provider import Provider


class GoogleProvider(Provider):
    """Provider that delegates translation to a ``googletrans`` ``Translator``."""

    def __init__(self):
        # Each provider instance owns its own Translator client.
        self.translator = Translator()

    def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str, **kwargs) -> Union[str, List[str]]:
        """Send ``input_data`` to the underlying client and return its result.

        :param input_data: The text to translate (string or list of strings)
        :param src: The source lang of input_data
        :param dest: The target lang input_data should be translated to
        :param kwargs: Accepted for signature compatibility; not forwarded
        :return: Whatever the client's ``translate`` call returns
        """
        client = self.translator
        result = client.translate(input_data, src=src, dest=dest)
        return result


if __name__ == '__main__':
    # Manual smoke test: translate one English word to Vietnamese and print it.
    test = GoogleProvider()
    outcome = test.translate("Hello", src="en", dest="vi")
    print(outcome.text)
53 changes: 53 additions & 0 deletions translator/providers/multiple_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import sys
sys.path.insert(0, r'./')
from typing import Union, List
import translators as ts
from .base_provider import Provider


class MultipleProviders(Provider):
    """Provider backed by the ``translators`` package (imported as ``ts``).

    The back-end engine and request timeout are held in ``self.config`` and
    passed to every ``translate_text`` call.
    """

    def __init__(self, cache: bool = False):
        """
        :param cache: When True, call ``preaccelerate_and_speedtest()`` up
                      front to cache sessions in advance
        """
        self.translator = ts
        self.config = {
            "translator": "bing",
            "timeout": 5.0,
        }
        if cache:
            # Optional. Caching sessions in advance, which can help improve access speed.
            self.translator.preaccelerate_and_speedtest()

    def _do_translate(self, input_data: Union[str, List[str]], src: str, dest: str, **kwargs) -> Union[str, List[str]]:
        """Translate a string, or each string of a list, via ``translate_text``.

        ``translate_text`` documents ``query_text`` as a single ``str``, so a
        list input is translated element by element and returned as a list of
        the same length and order.

        :param input_data: The text to translate (string or list of strings)
        :param src: The source lang of input_data
        :param dest: The target lang input_data should be translated to
        :param kwargs: Per-call options forwarded to ``translate_text``; they
                       take precedence over the defaults in ``self.config``
        :return: The translated string, or a list of translated strings

        Upstream signature, kept for reference:

        translate_text(query_text: str, translator: str = 'bing', from_language: str = 'auto', to_language: str = 'en', **kwargs) -> Union[str, dict]
        :param query_text: str, must.
        :param translator: str, default 'bing'.
        :param from_language: str, default 'auto'.
        :param to_language: str, default 'en'.
        :param if_use_preacceleration: bool, default False.
        :param **kwargs:
        :param is_detail_result: bool, default False.
        :param professional_field: str, default None. Support alibaba(), baidu(), caiyun(), cloudTranslation(), elia(), sysTran(), youdao(), volcEngine() only.
        :param timeout: float, default None.
        :param proxies: dict, default None.
        :param sleep_seconds: float, default 0.
        :param update_session_after_freq: int, default 1000.
        :param update_session_after_seconds: float, default 1500.
        :param if_use_cn_host: bool, default False. Support google(), bing() only.
        :param reset_host_url: str, default None. Support google(), yandex() only.
        :param if_check_reset_host_url: bool, default True. Support google(), yandex() only.
        :param if_ignore_empty_query: bool, default False.
        :param limit_of_length: int, default 20000.
        :param if_ignore_limit_of_length: bool, default False.
        :param if_show_time_stat: bool, default False.
        :param show_time_stat_precision: int, default 2.
        :param if_print_warning: bool, default True.
        :param lingvanex_mode: str, default 'B2C', choose from ("B2C", "B2B").
        :param myMemory_mode: str, default "web", choose from ("web", "api").
        :return: str or dict
        """
        # Instance defaults first so explicit per-call options can override them.
        options = {**self.config, **kwargs}

        if isinstance(input_data, str):
            return self.translator.translate_text(input_data, from_language=src, to_language=dest, **options)

        # translate_text takes one string per call, so fan a list out item by item.
        return [
            self.translator.translate_text(text, from_language=src, to_language=dest, **options)
            for text in input_data
        ]


if __name__ == '__main__':
    # Manual smoke test: translate one English word to Vietnamese and print it.
    test = MultipleProviders()
    outcome = test.translate("Hello", src="en", dest="vi")
    print(outcome.text)

0 comments on commit 204b8bb

Please sign in to comment.