Commit

refactor, chore: add split_list function, change tqdm colour for each task
vTuanpham committed Jan 1, 2024
1 parent 610b57a commit 8f21cac
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions translator/data_parser.py
@@ -90,6 +90,14 @@ def __init__(self, file_path: str,
def get_translator(self) -> Translator:
return deepcopy(self.translator)()

@staticmethod
def id_generator(size=6, chars=string.ascii_uppercase + string.digits) -> str:
return ''.join(random.choice(chars) for _ in range(size))

@staticmethod
def split_list(input_list: List[str], max_sub_length: int) -> List[list]:
return [input_list[x:x + max_sub_length] for x in range(0, len(input_list), max_sub_length)]
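# A minimal usage sketch of the two helpers above (illustrative values, not
# from the commit): id_generator returns a random uppercase/digit string,
# and split_list chunks a list into sub-lists of at most max_sub_length
# items, preserving order; the last sub-list may be shorter.
#   DataParser.id_generator()  # e.g. 'K7Q2ZP'
#   DataParser.split_list(['a', 'b', 'c', 'd', 'e'], max_sub_length=2)
#   # -> [['a', 'b'], ['c', 'd'], ['e']]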

def validate(self, keys: List[str]) -> bool:
dict_fields = self.target_config.get_keys()
for key in dict_fields:
@@ -138,15 +146,12 @@ def post_translate_validate(self) -> None:
print(f"\nTotal data left after filtering fail translation: {len(post_validated_translate_data)}\n")
self.converted_data_translated = post_validated_translate_data

@staticmethod
def id_generator(size=6, chars=string.ascii_uppercase + string.digits) -> str:
return ''.join(random.choice(chars) for _ in range(size))

def __translate_per_key(self, example: Dict, translator: Translator = None, progress_idx: int = 0) -> Dict:
'''
Loop through each key of one example and send the value to __translate_texts if its length is
under a certain threshold; if the threshold is exceeded, send it to __sublist_multithread_translate instead.
'''

assert self.do_translate, "Please enable translate via self.do_translate"
keys = self.target_config.get_keys()
for key in keys:
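# Dispatch sketch for the loop above; the threshold attribute and the call
# shapes are assumptions for illustration (max_list_length_per_thread does
# appear later in this diff), not the commit's exact code:
#   if len(example[key]) <= self.max_list_length_per_thread:
#       example[key] = self.__translate_texts(example[key], translator=translator)
#   else:
#       example[key] = self.__sublist_multithread_translate(example[key])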
@@ -191,10 +196,10 @@ def __sublist_multithread_translate(self,
Split a large list into sub-lists and translate them in parallel; order is maintained when merging the
sub-lists, which is useful when ordering matters (e.g. dialog examples)
'''

translated_list_data = []
num_threads = len(list_str) / self.max_list_length_per_thread
sub_str_lists = [list_str[x:x + self.max_list_length_per_thread] for x in
range(0, len(list_str), self.max_list_length_per_thread)]
sub_str_lists = self.split_list(list_str, max_sub_length=self.max_list_length_per_thread)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
finished_task = 0
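# Order-preserving merge, sketched: one common pattern (names here are
# illustrative, and as_completed comes from concurrent.futures) is to tag
# each submitted sub-list with its index and slot the result back by that
# index, so the flattened output keeps the original order:
#   future_to_idx = {executor.submit(self.__translate_texts, sub): i
#                    for i, sub in enumerate(sub_str_lists)}
#   ordered = [None] * len(sub_str_lists)
#   for future in as_completed(future_to_idx):
#       ordered[future_to_idx[future]] = future.result()
#   translated_list_data = [item for sub in ordered for item in sub]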
@@ -254,6 +259,7 @@ def flatten_list(nested_list):
'''
Flatten a nested list: [[], [], []] -> []
'''

flattened_list = []
for item in nested_list:
if isinstance(item, list):
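# Behaviour sketch, per the docstring above: nested items are merged into a
# single flat list, e.g. flatten_list([[1, 2], [3], [4, 5]]) -> [1, 2, 3, 4, 5].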
@@ -274,6 +280,7 @@ def __translate_texts(self,
'''
Actual place where the translation takes place
'''

assert self.do_translate, "Please enable translate via self.do_translate"
# This if is for multithread Translator instance
translator_instance = deepcopy(self.translator)() if not translator else translator
@@ -296,6 +303,7 @@ def extract_texts(obj):
'''
Extract the .text attribute from a Translator object
'''

if isinstance(obj, list):
return [extract_texts(item) for item in obj]
else:
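# Recursion sketch with hypothetical objects t1..t3 whose .text values are
# 'a', 'b', 'c': the output mirrors the input's nesting, with each
# Translator object replaced by its .text attribute:
#   extract_texts([t1, [t2, t3]])  # -> ['a', ['b', 'c']]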
@@ -334,8 +342,7 @@ def translate_converted(self,
# Split large data into large chunks, recursive feed to the same function
if len(converted_data) > self.large_chunks_threshold and large_chunk is None:
num_large_chunks = len(converted_data) / self.large_chunks_threshold
large_chunks = [converted_data[x:x + self.large_chunks_threshold] for x in
range(0, len(converted_data), self.large_chunks_threshold)]
large_chunks = self.split_list(converted_data, max_sub_length=self.large_chunks_threshold)
tqdm.write(
f"Data is way too large, spliting data into {num_large_chunks} large chunk for sequential translation")

@@ -347,8 +354,7 @@
# Split large chunk into large example, recursive feed to the same function via multithread
if len(converted_data) > self.max_example_per_thread and en_data is None:
num_threads = len(converted_data) / self.max_example_per_thread
chunks = [converted_data[x:x + self.max_example_per_thread] for x in
range(0, len(converted_data), self.max_example_per_thread)]
chunks = self.split_list(converted_data, max_sub_length=self.max_example_per_thread)
tqdm.write(f"Data too large, splitting data into {num_threads} chunk, each chunk is {len(chunks[0])}"
f" Processing with multithread...")

@@ -417,7 +423,7 @@ def callback_done(future):
return None

progress_bar_desc = "Translating converted data" if not desc else f"Translating converted data {desc}"
for example in tqdm(converted_data, desc=progress_bar_desc):
for example in tqdm(converted_data, desc=progress_bar_desc, colour="#add8e6"):
translated_data_example = self.__translate_per_key(example,
translator,
progress_idx=int(re.findall(r'\d+', desc)[0]) if desc and re.findall(r'\d+', desc) else 0)
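# Colour sketch: tqdm's colour parameter accepts hex strings such as the
# light blue '#add8e6' above, as well as named colours, so each task's bar
# can be tinted differently, e.g. (hypothetical values):
#   for example in tqdm(converted_data, desc='task 3', colour='#90ee90'):
#       ...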
@@ -432,15 +438,13 @@

@abstractmethod
@force_super_call
@timeit
def convert(self) -> Union[List[Dict], None]:
assert self.data_read is not None, "Please implement the read function for DataParser" \
" and assign data to self.data_read"
pass

@abstractmethod
@force_super_call
@timeit
def read(self) -> Union[List, Dict, None]:
assert os.path.isfile(self.file_path), f"Invalid path file for {self.file_path}"
pass
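# Subclass sketch with hypothetical names (JsonParser, the "id" field): a
# concrete parser implements read() and convert(), calling super() so the
# base-class assertions run, which is presumably what force_super_call
# enforces:
#   class JsonParser(DataParser):
#       def read(self):
#           super().read()
#           with open(self.file_path) as f:
#               self.data_read = json.load(f)
#       def convert(self):
#           super().convert()
#           self.converted_data = [{"id": self.id_generator(), **x}
#                                  for x in self.data_read]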
