In [None]:
# Loading Dataset
from datasets import load_dataset

# Load the "cais/mmlu" dataset with the "all" configuration; the result is
# indexed by split name further down (e.g. dataset["dev"]).
dataset = load_dataset("cais/mmlu", "all")

In [None]:
# Show the loaded dataset object as the cell output.
dataset

In [None]:
import pandas as pd

# Convert the "dev" split to pandas; the translation code below reads it
# through .iloc / .iterrows() and the 'question' / 'choices' columns.
subset = dataset["dev"].to_pandas()

In [None]:
# Inspect the Python type of the first row's `choices` value.
type(subset.iloc[0]["choices"])

In [None]:
# Number of answer choices in each row of `subset`.
choices_lengths = [len(row_choices) for row_choices in subset['choices']]

# Print the distinct lengths; a single value means every row has the same
# number of choices.
print(list(set(choices_lengths)))

In [None]:
# Show the converted "dev" split as the cell output.
subset

In [None]:
# Translate using unofficial API
from concurrent.futures import wait, FIRST_EXCEPTION
from termcolor import colored
from deep_translator.exceptions import RequestError
from errors import InvalidOutputError, MissingTranslationError, GeneralError, ReachedMaxRetriesError
import csv
import concurrent.futures
import utils
from IPython.display import clear_output
from multi_thread_handler import mth
import importlib
from numpy import ndarray

# Re-import `utils` so the current version of the module is used by the
# helper calls below.
importlib.reload(utils)

# Name used for the output folder and passed to utils.get_output_csv_path.
dataset_name = 'mmlu-dev'
output_folder = f'outputs/{dataset_name}'
# Files holding the resume state: translate_dataset() writes the end of each
# finished batch to the start-pointer file and `next_file_index + 1` to the
# next-file-index file.
start_pointer_file_path = f'{output_folder}/start-pointer.txt'
next_file_index_file_path = f'{output_folder}/next-file-index.txt'
# Number of rows submitted to the thread pool per batch.
batch_size = 20

# Resume state read from the files above.
start_pointer = utils.read_integer_from_file(start_pointer_file_path)
next_file_index = utils.read_integer_from_file(next_file_index_file_path)
# Manual override of the resume state (uncomment to force a specific position):
# start_pointer = 9461
# next_file_index = 9


def process_row(args):
    """Translate the question and answer choices of a single row.

    Args:
        args: An ``(i, row)`` pair, where ``i`` is the row's index label and
            ``row`` exposes ``row['question']`` and ``row['choices']``.

    Returns:
        A tuple ``(i, original_question, translated_question,
        original_choices, translated_choices)``.

    Raises:
        InvalidOutputError: If the translation helper does not return exactly
            one item for the question plus one per choice.
    """
    i, row = args
    original_question = row['question']
    original_choices: ndarray = row['choices']

    mth.safe_print(f"Processing Row: {i}")

    # The question and its choices are sent together as one list, question first.
    texts = [original_question]
    texts.extend(original_choices)
    translated = utils.choose_translation_method_and_translate(
        mth.rate_limited_translate,
        mth.sdk_translate,
        i,
        texts,
    )

    expected_count = len(original_choices) + 1
    if len(translated) != expected_count:
        raise InvalidOutputError

    # First item is the question; everything after it is the choices.
    translated_question, translated_choices = translated[0], translated[1:]

    mth.safe_print(f"Queued Translation: {i}")
    return i, original_question, translated_question, original_choices, translated_choices


def _cancel_pending(pending):
    """Request cancellation of every future in ``pending`` that has not finished."""
    for future in pending:
        mth.safe_print(colored(f"Cancelling unsubmitted futures: {future}", 'yellow'))
        future.cancel()


def translate_dataset(block_after: int = None):
    """Translate rows of ``subset`` in batches and write them to a new CSV file.

    Processing starts at the module-level ``start_pointer``. Each batch of up
    to ``batch_size`` rows is translated concurrently via ``process_row``,
    written to the CSV, and then the start-pointer file is updated so that a
    later run can resume after the last completed batch. A batch that hits a
    network error (``RequestError``) is retried as a whole.

    Args:
        block_after: Maximum number of rows to translate in this call, counted
            from ``start_pointer``. ``None`` means "through the last row".
            Values reaching past the end of ``subset`` are clamped.

    Raises:
        ReachedMaxRetriesError: After five network errors in this call.
        MissingTranslationError: If a batch finished without a result for one
            of its rows.
        InvalidOutputError: If a row does not expand to exactly 11 CSV columns
            (id, question pair, four original and four translated choices).
        Exception: Any non-network error raised by a worker is re-raised.
    """
    max_connection_retries = 5

    file_name = utils.get_output_csv_path(output_folder, next_file_index, dataset_name, 'csv')
    content_len = len(subset)

    with open(file_name, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['Id', 'Original Question', 'Translated Question', 'Original Choice 1', 'Original Choice 2',
                         'Original Choice 3', 'Original Choice 4', 'Translated Choice 1', 'Translated Choice 2',
                         'Translated Choice 3', 'Translated Choice 4'])

        utils.update_integer_in_file(next_file_index_file_path, next_file_index + 1)

        # Clamp to the dataset length: without this an oversized block_after
        # makes the write loop expect rows that do not exist.
        if block_after is None:
            end_pointer = content_len
        else:
            end_pointer = min(start_pointer + block_after, content_len)

        current_batch_start = start_pointer
        connection_retries = 0
        start_time = utils.get_current_time()

        while current_batch_start < end_pointer:
            current_batch_end = min(current_batch_start + batch_size, end_pointer)

            results = {}
            network_error_occurred = False

            with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
                futures = [executor.submit(process_row, (i, row)) for i, row in
                           subset.iloc[current_batch_start:current_batch_end].iterrows()]

                # Returns as soon as any worker raises, or when all are done.
                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)

                for future in done:
                    try:
                        i, input_text, input_result, target_text, target_result = future.result()
                    except RequestError as e:
                        mth.safe_print(colored(f"[Network Error - Automatic Retry]: {e}", 'red'))
                        connection_retries += 1
                        _cancel_pending(not_done)
                        if connection_retries >= max_connection_retries:
                            raise ReachedMaxRetriesError from e
                        network_error_occurred = True
                        break
                    except Exception as e:
                        mth.safe_print(colored(f"[Non-Network Error]: {e}", 'red'))
                        # Cancel before re-raising; previously the raise made
                        # the cancellation branch unreachable.
                        _cancel_pending(not_done)
                        raise
                    else:
                        results[i] = (input_text, input_result, target_text, target_result)

            if network_error_occurred:
                # Retry the same batch; partial results are discarded.
                continue

            # Validate the whole batch before writing anything, so the CSV
            # never holds a partially written batch.
            batch_rows = []
            for i in range(current_batch_start, current_batch_end):
                mth.safe_print(f"Writing row {i}")

                # Check presence first so a missing row raises the intended
                # error rather than a bare KeyError.
                if i not in results:
                    raise MissingTranslationError(i)

                input_text, input_result, target_text, target_result = results[i]
                expanded_result = [i, input_text, input_result, *target_text, *target_result]

                if len(expanded_result) != 11:
                    raise InvalidOutputError

                batch_rows.append(expanded_result)

            writer.writerows(batch_rows)

            # Flush the data before persisting the pointer: a crash between
            # the two must not leave the pointer ahead of what is on disk.
            file.flush()
            utils.update_integer_in_file(start_pointer_file_path, current_batch_end)

            current_batch_start = current_batch_end
            current_time = utils.get_current_time()
            # Rows completed in this call; the same count feeds both helpers.
            # NOTE(review): the estimate previously received `i - start_pointer`
            # (one less) - confirm get_estimated_time expects a completed count.
            processed = current_batch_end - start_pointer
            speed = utils.get_speed(processed, start_time, current_time)
            estimated_time = utils.get_estimated_time(content_len - start_pointer,
                                                      processed, start_time,
                                                      current_time)

            clear_output(wait=True)
            mth.safe_print(
                colored(
                    f"Moving to next batch. Translated {current_batch_end} of {content_len}, Elapsed (Secs): {current_time - start_time}, Estimated (Hrs): {estimated_time}, Speed: {speed}",
                    "green"))


# Run the translation from the stored start pointer through the last row
# (no `block_after` limit).
translate_dataset()