Skip to content

Commit

Permalink
refactor: clean up Translator call, rm unnecessary args
Browse files Browse the repository at this point in the history
  • Loading branch information
vTuanpham committed Dec 21, 2023
1 parent ab5bc91 commit d3d61ac
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 51 deletions.
2 changes: 1 addition & 1 deletion configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@dataclass
class Config(ABC):
"""
Abstract config that inherited all method
Abstract config that all config classes must inherit from
"""

qas_id: str # Required field in all subclass
Expand Down
91 changes: 41 additions & 50 deletions translator/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import os
import random
import sys
from copy import deepcopy

sys.path.insert(0, r'./')
import string
import threading
Expand Down Expand Up @@ -38,11 +40,11 @@ def __init__(self, file_path: str,
target_fields: List[str],
target_config: Union[BaseConfig, QAConfig, DialogsConfig],
do_translate: bool = False,
enable_sub_task_thread: bool = True, # Enable splitting the list into sublist if a list of one example is too large to process
enable_sub_task_thread: bool = True, # Enable splitting a large list into sublist if a list of one example is too large to process
# This argument go with max_list_length_per_thread
no_translated_code: bool = False,
max_example_per_thread: int = 400, # How many examples, each thread can contain
large_chunks_threshold: int = 20000, # Maximum number of examples that will be evenly across threads
large_chunks_threshold: int = 20000, # Maximum number of examples that will be distributed evenly across threads, any examples exceed this threshold will be process in queue
max_list_length_per_thread: int = 3, # Maximum number of strings contain in a list in a single thread.
# if larger, split the list into sub-list and process in parallel
source_lang: str = "en",
Expand Down Expand Up @@ -82,10 +84,14 @@ def __init__(self, file_path: str,

self.converted_data_translated = None

self.translator = Translator()
self.translator = Translator

@property
def get_translator(self) -> Translator:
    """Build and return a fresh Translator instance.

    ``self.translator`` stores the Translator *class* (not an instance);
    a new instance is created on every access so each worker thread gets
    its own translator rather than sharing one.

    NOTE(review): ``deepcopy`` of a class object normally returns the very
    same class (classes are copied atomically), so the copy step is
    presumably defensive — confirm it is actually needed.
    """
    translator_cls = deepcopy(self.translator)
    return translator_cls()

@staticmethod
def validate(keys: List[str], dataclass: Union[BaseConfig, QAConfig] = BaseConfig) -> bool:
def validate(keys: List[str], dataclass: Union[BaseConfig, QAConfig, DialogsConfig] = BaseConfig) -> bool:
dict_fields = dataclass.get_keys()
for key in dict_fields:
assert key in keys, f"\n Invalid parser, the key '{key}' is missing from {dict_fields}\n" \
Expand All @@ -104,7 +110,7 @@ def pre_translate_validate(self) -> None:
contain_code, score, found_elements = have_code(example[key])
if contain_code:
example_filters += 1
if len(self.converted_data) - 1 == idx:
if len(self.converted_data) - 2 == idx:
tqdm.write(f"Number of example with code: {example_filters}")
break
elif key == self.target_fields[-1]:
Expand All @@ -124,7 +130,7 @@ def post_translate_validate(self) -> None:
example_filters = 0
if have_re_code(example[key], code=self.fail_translation_code):
example_filters += 1
if len(self.converted_data_translated) - 1 == idx:
if len(self.converted_data_translated) - 2 == idx:
tqdm.write(f"Number of example with fail code: {example_filters}")
break
elif key == self.target_fields[-1]:
Expand Down Expand Up @@ -162,21 +168,19 @@ def translate_en2vi_advance_qa(self, example: Dict, translator: Translator = Non
if average_length > 1600: average_length_sub_task_criteria = True
if type == "list" and average_length_sub_task_criteria and len(example[key]) >= self.max_list_length_per_thread:
# tqdm.write(f"\nSplitting {key} field which contain {len(example[key])} items on chunk {progress_idx}\n")
del translator
example[key] = self.multithread_list_str_translate(example[key],
type,
translator,
progress_idx,
key)
else:
example[key] = self.translate_en2vi(src_texts=example[key], data_type=type, translator=translator)
example[key] = self.translate_en2vi(src_texts=example[key], translator=translator)
else:
example[key] = self.translate_en2vi(src_texts=example[key], data_type=type, translator=translator)
example[key] = self.translate_en2vi(src_texts=example[key], translator=translator)

return example

def multithread_list_str_translate(self, list_str: List[str],
data_type: str = "list",
translator: Translator = None,
def multithread_list_str_translate(self,
list_str: List[str],
progress_idx: int = 0,
field_name: str=None) -> List[str]:
translated_list_data = []
Expand Down Expand Up @@ -207,8 +211,7 @@ def callback_list_done(future):
# Assign each thread with a new Translator instance
future_chunk = executor.submit(self.translate_en2vi,
src_texts=list_chunk,
data_type=data_type,
translator=Translator(),
translator=self.get_translator,
sub_list_idx=idx)
future_chunk.add_done_callback(callback_list_done)
future_dict = {
Expand All @@ -226,14 +229,11 @@ def callback_list_done(future):
f"Thread {future_dict['idx']} failed, restarting thread with chunk {future_dict['idx']}")
backup_future_chunk = executor.submit(self.translate_en2vi,
src_texts=sub_str_lists[future_dict['idx']],
data_type=data_type,
translator=Translator(),
translator=self.get_translator,
sub_list_idx=future_dict['idx'])
backup_future_chunk.add_done_callback(callback_list_done)
backup_future_dict = {
"future": backup_future_chunk,
"idx": future_dict['idx']
}
backup_future_dict = {"future": backup_future_chunk,
"idx": future_dict['idx']}
futures[future_dict['idx']] = backup_future_dict
continue
elif future_dict['future'].result():
Expand All @@ -257,13 +257,13 @@ def flatten_list(nested_list):

return translated_list_data

def translate_en2vi(self, src_texts: Union[List[str], str],
data_type: str,
def translate_en2vi(self,
src_texts: Union[List[str], str],
translator: Translator = None,
sub_list_idx: int=None) -> Union[List[str], str, Dict[List[str], int]]:
assert self.do_translate, "Please enable translate via self.do_translate"
# This if is for multithread Translator instance
translator_instance = self.translator if not translator else translator
translator_instance = deepcopy(self.translator)() if not translator else translator

try:
target_texts = translator_instance.translate(src_texts, src=self.source_lang, dest=self.target_lang)
Expand Down Expand Up @@ -363,11 +363,10 @@ def callback_done(future):
future_chunk = executor.submit(self.translate_converted,
en_data=chunk,
desc=f"chunk {idx}",
translator=Translator())
translator=self.get_translator)
future_chunk.add_done_callback(callback_done)
future_dict = {
"future": future_chunk,
"idx": idx}
future_dict = {"future": future_chunk,
"idx": idx}
futures.append(future_dict)

# Wait for all threads to complete
Expand All @@ -380,7 +379,7 @@ def callback_done(future):
backup_future_chunk = executor.submit(self.translate_converted,
en_data=chunks[future_dict['idx']],
desc=f"Backup chunk {future_dict['idx']}",
translator=Translator())
translator=self.get_translator)
backup_future_chunk.add_done_callback(callback_done)
backup_future_dict = {"future": backup_future_chunk,
"idx": future_dict['idx']}
Expand All @@ -399,27 +398,19 @@ def callback_done(future):
self.converted_data_translated = translated_data
return None

try:
progress_bar_desc = "Translating converted data" if not desc else f"Translating converted data {desc}"
for example in tqdm(converted_data, desc=progress_bar_desc, colour="blue"):
translated_data_example = self.translate_en2vi_advance_qa(example,
translator,
progress_idx=int(re.findall(r'\d+', desc)[0]) if desc and re.findall(r'\d+', desc) else 0)
translated_data.append(translated_data_example)
if en_data: return translated_data
if large_chunk:
# Assuming that the previous large chunk process already create self.converted_data_translated
# This cover the case where last large chunk only contain a single thread
self.converted_data_translated += translated_data
else:
self.converted_data_translated = translated_data
except ConnectTimeout as e:
if not desc:
raise ConnectTimeout(f" Connection timeout, please provide better connection")
else:
tqdm.write(f"Connection timeout from thread {desc}, please provide better connection")
raise ConnectTimeout(
f" Connection timeout raise from thread {desc}, please provide better connection")
progress_bar_desc = "Translating converted data" if not desc else f"Translating converted data {desc}"
for example in tqdm(converted_data, desc=progress_bar_desc, colour="blue"):
translated_data_example = self.translate_en2vi_advance_qa(example,
translator,
progress_idx=int(re.findall(r'\d+', desc)[0]) if desc and re.findall(r'\d+', desc) else 0)
translated_data.append(translated_data_example)
if en_data: return translated_data
if large_chunk:
# Assuming that the previous large chunk process already create self.converted_data_translated
# This cover the case where last large chunk only contain a single thread
self.converted_data_translated += translated_data
else:
self.converted_data_translated = translated_data

@abstractmethod
@force_super_call
Expand Down

0 comments on commit d3d61ac

Please sign in to comment.