In [3]:
# Loading Dataset
import utils

# Directory the dataset is loaded from; later cells read the resulting `dataset`.
folder_path = "datasets/cot-zsopt"
dataset = utils.load_dataset(folder_path)

In [2]:
# Translate using unofficial API
from concurrent.futures import wait, FIRST_EXCEPTION
from termcolor import colored
from deep_translator.exceptions import RequestError
from errors import InvalidOutputError, MissingTranslationError, GeneralError, ReachedMaxRetriesError
import csv
import concurrent.futures
import config, utils
from multi_thread_handler import MultiThreadHandler, mth

import os

dataset_name = 'test'
output_folder = f'outputs/{dataset_name}'
start_pointer_file_path = f'{output_folder}/start-pointer.txt'
next_file_index_file_path = f'{output_folder}/next-file-index.txt'
batch_size = 10  # rows submitted per batch; the start pointer is checkpointed after each batch

# Make sure the output folder exists before anything is written into it.
# On a fresh run the pointer files (and the folder) do not exist yet; files
# written under it (the pointer files, and presumably the CSV path built by
# config.get_file_name) need a parent directory.
os.makedirs(output_folder, exist_ok=True)

# Resume state persisted by earlier runs: the dataset row to start from and
# the index of the next CSV output file.
start_pointer = utils.read_integer_from_file(start_pointer_file_path)
next_file_index = utils.read_integer_from_file(next_file_index_file_path)
# Manual overrides for resuming from a known position (uncomment to use):
# start_pointer = 9461
# next_file_index = 9

file_name = config.get_file_name(output_folder, next_file_index, dataset_name, 'csv')
content_len = len(dataset)

# Module-level flag that translate_dataset sets while a batch is failing.
error_occurred = False


def process_row(args):
    """Translate the input and target text of a single dataset row.

    ``args`` is an ``(index, row)`` pair; ``row`` is read with the keys
    ``'inputs'`` and ``'targets'``. Both texts are passed together to
    ``utils.choose_translation_method_and_translate`` along with the shared
    ``mth.rate_limited_translate`` callable.

    Returns:
        ``(index, input_text, translated_input, target_text, translated_target)``.

    Raises:
        InvalidOutputError: if the translation call does not return exactly
            two items.
    """
    idx, row = args
    source_input, source_target = row['inputs'], row['targets']

    mth.safe_print(f"Processing Row: {idx}")
    translations = utils.choose_translation_method_and_translate(
        mth.rate_limited_translate, idx, [source_input, source_target]
    )
    if len(translations) != 2:
        raise InvalidOutputError

    translated_input, translated_target = translations[0], translations[1]

    mth.safe_print(f"Queued Translation: {idx}")
    return idx, source_input, translated_input, source_target, translated_target


def _translate_batch(batch_start, batch_end):
    """Translate the dataset rows at positions [batch_start, batch_end) concurrently.

    Rows are addressed by *position* (``iloc``), the same coordinate system as
    the persisted start pointer, so the returned keys line up with the
    positional range the caller writes out even if the DataFrame index is not
    0..n-1.

    Returns:
        dict mapping row position to
        ``(input_text, input_result, target_text, target_result)``.

    Raises:
        Whatever the first failing ``process_row`` call raised. Rows that had
        not started yet are cancelled before the exception propagates.
    """
    batch = dataset.iloc[batch_start:batch_end]
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(process_row, (position, row))
            for position, (_, row) in zip(range(batch_start, batch_end), batch.iterrows())
        ]
        # Without an exception this waits for every future (ALL_COMPLETED).
        done, not_done = wait(futures, return_when=FIRST_EXCEPTION)

        failed = next((f for f in done if f.exception() is not None), None)
        if failed is not None:
            # The batch is retried or abandoned as a whole, so don't make the
            # executor's shutdown wait on rows whose results would be discarded.
            for future in not_done:
                mth.safe_print(colored(f"Cancelling unsubmitted futures: {future}", 'yellow'))
                future.cancel()
            failed.result()  # re-raises the worker's exception

    results = {}
    for future in futures:
        i, input_text, input_result, target_text, target_result = future.result()
        results[i] = (input_text, input_result, target_text, target_result)
    return results


def _release_file_index_if_unused(batch_start):
    """Give back this run's file index when it aborts before writing any batch.

    translate_dataset bumps the persisted next-file index right after opening
    its CSV. If the run dies before the first batch is written, resetting the
    index lets the next run reuse (and overwrite) the header-only file instead
    of skipping to a new one.
    """
    if batch_start == start_pointer:
        mth.safe_print(colored("No batches could be processed. Exiting...", 'red'))
        utils.update_integer_in_file(next_file_index_file_path, next_file_index)


def translate_dataset(block_after: int = None):
    """Translate dataset rows into a new CSV file, checkpointing after each batch.

    Starts at the persisted ``start_pointer`` and translates ``batch_size``
    rows at a time with a 4-worker thread pool. After every batch is written,
    the CSV is flushed and the start-pointer file is advanced, so an
    interrupted run can resume where it stopped.

    Args:
        block_after: maximum number of rows to translate in this run; ``None``
            translates through the end of the dataset.

    Raises:
        ReachedMaxRetriesError: one batch failed with a network error
            (``RequestError``) three times in a row.
        Exception: any non-network error raised by a worker is re-raised.
    """
    global error_occurred

    # Clamp to the dataset length: iloc silently returns fewer rows past the
    # end, which would otherwise surface as a spurious MissingTranslationError.
    if block_after is None:
        end_pointer = content_len
    else:
        end_pointer = min(start_pointer + block_after, content_len)

    with open(file_name, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['Id', 'Original Input', 'Translated Input', 'Original Target', 'Translated Target'])

        # Claim this file index so a later run writes to a fresh file.
        utils.update_integer_in_file(next_file_index_file_path, next_file_index + 1)

        current_batch_start = start_pointer
        connection_retries = 0
        start_time = utils.get_current_time()

        while current_batch_start < end_pointer:
            current_batch_end = min(current_batch_start + batch_size, end_pointer)
            error_occurred = False

            try:
                results = _translate_batch(current_batch_start, current_batch_end)
            except RequestError as e:
                mth.safe_print(colored(f"[Network Error - Automatic Retry]: {e}", 'red'))
                error_occurred = True
                connection_retries += 1
                if connection_retries >= 3:
                    _release_file_index_if_unused(current_batch_start)
                    raise ReachedMaxRetriesError from e
                continue  # retry the same batch; the file index stays claimed
            except Exception as e:
                mth.safe_print(colored(f"[Non-Network Error]: {e}", 'red'))
                error_occurred = True
                _release_file_index_if_unused(current_batch_start)
                raise

            # The retry budget is per batch: a few transient failures spread
            # over a long run should not abort it.
            connection_retries = 0

            # Write the results of this batch
            for i in range(current_batch_start, current_batch_end):
                mth.safe_print(f"Writing row {i}")
                if i not in results:
                    raise MissingTranslationError(i)
                writer.writerow([i, *results[i]])

            # Hand the rows to the OS before advancing the resume pointer, so a
            # crash of this process in between leads to re-translation rather
            # than skipped rows.
            file.flush()
            utils.update_integer_in_file(start_pointer_file_path, current_batch_end)

            current_batch_start = current_batch_end
            rows_done = current_batch_end - start_pointer
            current_time = utils.get_current_time()
            speed = utils.get_speed(rows_done, start_time, current_time)
            estimated_time = utils.get_estimated_time(content_len - start_pointer,
                                                      rows_done, start_time,
                                                      current_time)

            mth.safe_print(
                colored(
                    f"Moving to next batch. Translated {current_batch_end} of {content_len}, Elapsed (Secs): {current_time - start_time}, Estimated (Hrs): {estimated_time}, Speed: {speed}",
                    "green"))


# Run the translation from the persisted start pointer; with block_after=None
# there is no row limit, so it continues through the end of the dataset.
translate_dataset()

File not found: outputs-test/start-pointer.txt
File not found: outputs-test/next-file-index.txt
Processing Row: 0
Processing Row: 1
Processing Row: 2
Processing Row: 3
Translated by blob for index 1, Time: 1721741184
Queued Translation: 1
Translated by blob for index 3, Time: 1721741184
Processing Row: 4
Queued Translation: 3
Processing Row: 5
Translated by blob for index 0, Time: 1721741184
Queued Translation: 0
Processing Row: 6
Translated by blob for index 2, Time: 1721741184
Queued Translation: 2
Processing Row: 7
Translated by blob for index 4, Time: 1721741184
Queued Translation: 4
Processing Row: 8
Translated by blob for index 7, Time: 1721741185
Queued Translation: 7
Processing Row: 9
Translated by blob for index 6, Time: 1721741185
Queued Translation: 6
Translated by blob for index 8, Time: 1721741185
Queued Translation: 8
Translated by blob for index 5, Time: 1721741185
Queued Translation: 5
Translated by blob for index 9, Time: 1721741185
Queued Translation: 9
Writing row 0


KeyboardInterrupt: 