From 8f21cac05728ded1d1349255772b16f8b9bd3583 Mon Sep 17 00:00:00 2001
From: vTuanpham
Date: Mon, 1 Jan 2024 22:35:02 +0700
Subject: [PATCH] refactor, chore: add split_list function, change tqdm colour for each task

---
 translator/data_parser.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/translator/data_parser.py b/translator/data_parser.py
index f6b2328..1dd014a 100644
--- a/translator/data_parser.py
+++ b/translator/data_parser.py
@@ -90,6 +90,14 @@ def __init__(self, file_path: str,
     def get_translator(self) -> Translator:
         return deepcopy(self.translator)()
 
+    @staticmethod
+    def id_generator(size=6, chars=string.ascii_uppercase + string.digits) -> str:
+        return ''.join(random.choice(chars) for _ in range(size))
+
+    @staticmethod
+    def split_list(input_list: List[str], max_sub_length: int) -> List[list]:
+        return [input_list[x:x + max_sub_length] for x in range(0, len(input_list), max_sub_length)]
+
     def validate(self, keys: List[str]) -> bool:
         dict_fields = self.target_config.get_keys()
         for key in dict_fields:
@@ -138,15 +146,12 @@ def post_translate_validate(self) -> None:
         print(f"\nTotal data left after filtering fail translation: {len(post_validated_translate_data)}\n")
         self.converted_data_translated = post_validated_translate_data
 
-    @staticmethod
-    def id_generator(size=6, chars=string.ascii_uppercase + string.digits) -> str:
-        return ''.join(random.choice(chars) for _ in range(size))
-
     def __translate_per_key(self, example: Dict, translator: Translator = None, progress_idx: int = 0) -> Dict:
         '''
         This function loop through each key of one example and send to __translate_texts if the value of the key is
         under a certain threshold. If exceeded, then send to __sublist_multithread_translate
         '''
+
         assert self.do_translate, "Please enable translate via self.do_translate"
         keys = self.target_config.get_keys()
         for key in keys:
@@ -191,10 +196,10 @@ def __sublist_multithread_translate(self,
         This function split a large list into sub-list and translate it in parallel, orders are maintained when merge all
         sub-lists, this is useful when order are necessary (e.g Dialogs example)
         '''
+
         translated_list_data = []
         num_threads = len(list_str) / self.max_list_length_per_thread
-        sub_str_lists = [list_str[x:x + self.max_list_length_per_thread] for x in
-                         range(0, len(list_str), self.max_list_length_per_thread)]
+        sub_str_lists = self.split_list(list_str, max_sub_length=self.max_list_length_per_thread)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             futures = []
             finished_task = 0
@@ -254,6 +259,7 @@ def flatten_list(nested_list):
             '''
             Turn a list from [[], [], []] -> []
             '''
+
             flattened_list = []
             for item in nested_list:
                 if isinstance(item, list):
@@ -274,6 +280,7 @@ def __translate_texts(self,
         '''
         Actual place where translation take place
         '''
+
         assert self.do_translate, "Please enable translate via self.do_translate"
         # This if is for multithread Translator instance
         translator_instance = deepcopy(self.translator)() if not translator else translator
@@ -296,6 +303,7 @@ def extract_texts(obj):
             '''
             Extract .text attribute from Translator object
             '''
+
             if isinstance(obj, list):
                 return [extract_texts(item) for item in obj]
             else:
@@ -334,8 +342,7 @@ def translate_converted(self,
         # Split large data into large chunks, recursive feed to the same function
         if len(converted_data) > self.large_chunks_threshold and large_chunk is None:
             num_large_chunks = len(converted_data) / self.large_chunks_threshold
-            large_chunks = [converted_data[x:x + self.large_chunks_threshold] for x in
-                            range(0, len(converted_data), self.large_chunks_threshold)]
+            large_chunks = self.split_list(converted_data, max_sub_length=self.large_chunks_threshold)
             tqdm.write(
                 f"Data is way too large, spliting data into {num_large_chunks} large chunk for sequential translation")
 
@@ -347,8 +354,7 @@ def translate_converted(self,
         # Split large chunk into large example, recursive feed to the same function via multithread
         if len(converted_data) > self.max_example_per_thread and en_data is None:
             num_threads = len(converted_data) / self.max_example_per_thread
-            chunks = [converted_data[x:x + self.max_example_per_thread] for x in
-                      range(0, len(converted_data), self.max_example_per_thread)]
+            chunks = self.split_list(converted_data, max_sub_length=self.max_example_per_thread)
             tqdm.write(f"Data too large, splitting data into {num_threads} chunk, each chunk is {len(chunks[0])}"
                        f" Processing with multithread...")
 
@@ -417,7 +423,7 @@ def callback_done(future):
                     return None
 
         progress_bar_desc = "Translating converted data" if not desc else f"Translating converted data {desc}"
-        for example in tqdm(converted_data, desc=progress_bar_desc):
+        for example in tqdm(converted_data, desc=progress_bar_desc, colour="#add8e6"):
             translated_data_example = self.__translate_per_key(example,
                                                                translator,
                                                                progress_idx=int(re.findall(r'\d+', desc)[0]) if desc and re.findall(r'\d+', desc) else 0)
@@ -432,7 +438,6 @@ def callback_done(future):
 
     @abstractmethod
     @force_super_call
-    @timeit
     def convert(self) -> Union[List[Dict], None]:
         assert self.data_read is not None, "Please implement the read function for DataParser" \
                                            " and assign data to self.data_read"
@@ -440,7 +445,6 @@ def read(self) -> Union[List, Dict, None]:
 
     @abstractmethod
    @force_super_call
-    @timeit
     def read(self) -> Union[List, Dict, None]:
         assert os.path.isfile(self.file_path), f"Invalid path file for {self.file_path}"
         pass
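
For reference, outside the patch itself: the three inline slicing comprehensions removed by the hunks above are all replaced with calls to the new split_list static method. Below is a minimal standalone sketch of that helper, its body copied from the first hunk and exercised with purely illustrative demo values, showing the chunking behaviour the call sites rely on: consecutive sub-lists of at most max_sub_length items, in the original order, so the translated sub-lists can be merged back without reordering.

    from typing import List

    def split_list(input_list: List[str], max_sub_length: int) -> List[list]:
        # Slice the input into consecutive chunks of at most max_sub_length items;
        # relative order is preserved within and across chunks.
        return [input_list[x:x + max_sub_length] for x in range(0, len(input_list), max_sub_length)]

    # Illustrative values only: five items split into chunks of two.
    print(split_list(["a", "b", "c", "d", "e"], max_sub_length=2))  # [['a', 'b'], ['c', 'd'], ['e']]

The remaining changes are cosmetic: blank lines added after docstrings, the @timeit decorator dropped from the abstract convert/read methods, and colour="#add8e6" passed to tqdm so this task's progress bar is tinted light blue.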