In [1]:
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import pandas as pd
from datasets import Dataset
from omegaconf import OmegaConf

from src.core.translate import Translator
from src.utils.schemas import (GeneralTranslationResultSchema,
                               RationaleTranslationResultSchema)

In [2]:
# Heuristic for "English left in the text": any run of two or more ASCII
# letters delimited by word boundaries (single Latin letters are ignored).
# NOTE: the name contains a typo ("enhlsih"); it is kept as-is because
# contains_english() refers to it by this exact name.
regex_check_enhlsih = re.compile(r'\b[a-zA-Z]{2,}\b')

In [8]:
def contains_english(text: str) -> bool:
	"""
	Tell whether ``text`` has at least one English-looking word.

	A "word" is whatever the module-level ``regex_check_enhlsih`` pattern
	matches anywhere in the string.

	Args:
		text (str): String to scan.

	Returns:
		bool:		True when the pattern matches somewhere in ``text``, False otherwise.
	"""
	first_match = regex_check_enhlsih.search(text)
	return first_match is not None

def read_file(path: str) -> str:
	"""
	Read a whole text file and return its contents.

	Args:
		path (str): Location of the file to read.

	Returns:
		str:		Full contents of the file, decoded as UTF-8.
	"""
	# The context manager closes the handle deterministically, and the explicit
	# encoding keeps non-ASCII content readable regardless of the platform
	# default locale (the notebook also writes its outputs as UTF-8).
	with Path(path).open(encoding="utf-8") as file:
		return file.read()
	
def read_json(path: str):
	"""
	Load and parse a JSON file.

	Args:
		path (str): Location of the JSON file.

	Returns:
		The decoded JSON value (dict, list, str, number, bool or None).
	"""
	# Open with an explicit UTF-8 encoding and close the handle when done,
	# instead of leaving an open file object for the garbage collector.
	with Path(path).open(encoding="utf-8") as file:
		return json.load(file)

In [11]:
# Load the pipeline configuration and pull out one sub-config per stage.
# Later cells read prompt_path, model_config_path, filepath_examples,
# batch_size, batch_result_dir and batches from each of these (plus data_path
# and cols for the stages that load input data).
config = OmegaConf.load('configs/conf.yaml')
general_translation_config = config.general_translation
general_translation_correction_config = config.general_translation_correction
rationales_translation_config = config.rationales_translation
rationales_translation_correction_config = config.rationales_translation_correction

In [12]:
def _build_translator(stage_config):
    """
    Create a Translator from one stage's sub-config.

    Args:
        stage_config: Config node exposing ``prompt_path``,
            ``model_config_path``, ``filepath_examples``, ``batch_size``,
            ``batch_result_dir`` and ``batches``.

    Returns:
        Translator: Instance configured with the prompt text, model settings
        and example data read from the paths in ``stage_config``.
    """
    return Translator(
        system_message=read_file(stage_config.prompt_path),
        model_config=read_json(stage_config.model_config_path),
        example_data=read_json(stage_config.filepath_examples),
        batch_size=stage_config.batch_size,
        batch_result_dir=stage_config.batch_result_dir,
        batch_dir=stage_config.batches,
    )


# One translator per pipeline stage; all four are built the same way and
# differ only in the sub-config they are given.
general_translator = _build_translator(general_translation_config)
general_translator_corrector = _build_translator(general_translation_correction_config)
rational_translator = _build_translator(rationales_translation_config)
rational_translator_corrector = _build_translator(rationales_translation_correction_config)

# Paths of the intermediate JSON dumps written further down the notebook.
# Plain string literals: the previous f-strings had no placeholders.
general_translation_int_path = "int_path_general_translator"
rational_translation_int_path = "int_path_rational_translator"

## Translating

In [13]:
# Read the source CSV and keep only the columns named in the config
# (``cols`` is a whitespace-separated list of column names).
general_data = pd.read_csv(general_translation_config.data_path)
general_dataset = dict(
    (column_name, general_data[column_name].tolist())
    for column_name in general_translation_config.cols.split()
)
general_input_dataset = Dataset.from_dict(general_dataset)

In [14]:
# Translate the general dataset. Later cells treat the result as a list of
# dicts with at least the keys 'text' and 'text_rus'.
translation_result = general_translator.translate(general_input_dataset, GeneralTranslationResultSchema)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 288.35ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1376.08ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1494.76ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1298.55ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1592.37ba/s]


Processing batch 1/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1323.96ba/s]


Processing batch 2/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 400.33ba/s]


Processing batch 3/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 841.22ba/s]


Processing batch 4/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 426.60ba/s]


Processing batch 5/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 590.33ba/s]


In [15]:
# Collect the translated rows whose Russian text still matches the
# English-word heuristic; these are the candidates for the correction pass.
texts_containing_english = Dataset.from_list(translation_result).filter(
    lambda row: contains_english(row['text_rus'])
)

Filter: 100%|██████████| 19/19 [00:00<00:00, 2248.26 examples/s]


In [16]:
# Inspect the Russian translations that were flagged as containing English.
texts_containing_english["text_rus"]

[]

In [17]:
# Drop every translation that still contains English; corrected versions are
# appended back after the correction pass.
# Filtering by the same predicate that built texts_containing_english (rather
# than calling list.remove() on rows round-tripped through a Dataset) makes
# the cell safe to re-run and avoids a ValueError when a round-tripped row no
# longer compares equal to the original dict. The slice assignment keeps the
# update in place, as the original loop did.
translation_result[:] = [
    item for item in translation_result
    if not contains_english(item['text_rus'])
]

### Correcting translation

In [18]:
# Build the corrector input: drop the 'text' column and rename 'text_rus' to
# 'text', so the flagged Russian translation becomes the text to be corrected.
# NOTE(review): the original 'text' column is discarded here, so rows returned
# by the corrector presumably no longer carry the source text — confirm this
# is intended before they are appended back to translation_result.
input_dataset_correction = texts_containing_english.remove_columns(["text"]).rename_column("text_rus", "text")

In [19]:
# Run the correction translator over the flagged rows, using the same result
# schema as the first pass.
translation_result_corrected = general_translator_corrector.translate(input_dataset_correction, GeneralTranslationResultSchema)

In [20]:
# Inspect the corrector's output for the general translations.
translation_result_corrected

[]

In [21]:
# Show the corrected rows that still match the English-word heuristic.
Dataset.from_list(translation_result_corrected).filter(
    lambda row: contains_english(row['text_rus'])
)

Dataset({
    features: [],
    num_rows: 0
})

In [22]:
# Put back only the corrected rows that are now free of English words.
translation_result.extend(
    corrected_row
    for corrected_row in translation_result_corrected
    if not contains_english(corrected_row['text_rus'])
)

In [23]:
# Persist the general translations as pretty-printed UTF-8 JSON, keeping
# non-ASCII characters readable instead of escaping them.
with Path(general_translation_int_path).open("w", encoding="utf-8") as out_file:
	json.dump(translation_result, out_file, ensure_ascii=False, indent=4)

## Translating rationales

In [24]:
# Load the rationale input data. The file is opened in a context manager so
# the handle is closed right away, and decoded explicitly as UTF-8 so that
# non-ASCII text does not depend on the platform default encoding.
with Path(rationales_translation_config.data_path).open(encoding="utf-8") as rationales_file:
    dataset_rationales = json.load(rationales_file)

In [25]:
# Wrap the loaded JSON in a Dataset. from_dict is given the parsed JSON as-is,
# so the file is presumably a mapping of column name -> list of values
# (TODO confirm against the data file).
input_dataset_rationales = Dataset.from_dict(dataset_rationales)

In [26]:
# Translate the rationales. Later cells treat the result as a list of dicts
# with the keys 'id', 'text_rus', 'text_eng', 'rationales_eng' and
# 'rationales_rus'.
translation_result_rationales = rational_translator.translate(input_dataset_rationales, RationaleTranslationResultSchema)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 290.26ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1414.61ba/s]


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 107.31ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 337.05ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1296.94ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1478.95ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 768.05ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1494.76ba/s]


Processing batch 1/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 303.85ba/s]


Processing batch 2/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 217.82ba/s]


Processing batch 3/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 280.59ba/s]


Processing batch 4/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 205.08ba/s]


Processing batch 5/8...


ERROR:root:Unexpected error on input #3 (id: 19): 'id'. Attempt 1 of 3. Retrying...
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 943.60ba/s]


Processing batch 6/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 579.72ba/s]


Processing batch 7/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 648.87ba/s]


Processing batch 8/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 313.48ba/s]


In [27]:
# Split the translated records by whether every '|'-separated Russian
# rationale occurs verbatim in the Russian text.
# Each record lands in exactly one list, exactly once: the earlier nested loop
# appended a record once per missing rationale, producing duplicates in
# incorrectly_translated, and then filtered with a quadratic "not in" scan.
incorrectly_translated = []
translated_rationales = []
for item in translation_result_rationales:
    rationale_spans = item['rationales_rus'].split('|')
    if all(span in item['text_rus'] for span in rationale_spans):
        translated_rationales.append(item)
    else:
        incorrectly_translated.append(item)

In [28]:
# Inspect the records whose Russian rationales were not all found in the text.
incorrectly_translated

[]

### Correcting rationales

In [29]:
# Collect the records whose translated rationales still match the
# English-word heuristic; these go through the correction pass.
rationales_containing_english = Dataset.from_list(translation_result_rationales).filter(
    lambda row: contains_english(row['rationales_rus'])
)

Filter: 100%|██████████| 30/30 [00:00<00:00, 11455.67 examples/s]


In [30]:
# Inspect which rationale records were flagged as still containing English.
rationales_containing_english

Dataset({
    features: ['id', 'text_rus', 'text_eng', 'rationales_eng', 'rationales_rus'],
    num_rows: 0
})

In [31]:
# Drop every record whose Russian rationales still contain English; corrected
# versions are appended back after the correction pass.
# Filtering by the same predicate that built rationales_containing_english
# (rather than calling list.remove() on rows round-tripped through a Dataset)
# makes the cell safe to re-run and avoids a ValueError when a round-tripped
# row no longer compares equal to the original dict. The slice assignment
# keeps the update in place, as the original loop did.
translation_result_rationales[:] = [
    item for item in translation_result_rationales
    if not contains_english(item['rationales_rus'])
]

In [32]:
# Build the column-oriented input for the rationale corrector. The flagged
# Russian rationales are passed under the 'rationales_eng' key, i.e. they
# become the text the corrector is asked to translate.
keys = ['id', 'text_rus', 'text_eng', 'rationales_rus']
dataset_correction_rationales = {
    ('rationales_eng' if source_key == 'rationales_rus' else source_key): [
        row[source_key] for row in rationales_containing_english
    ]
    for source_key in keys
}

In [33]:
# Wrap the correction input in a Dataset.
# NOTE: this rebinds input_dataset_correction, which earlier in the notebook
# held the correction input for the general translations.
input_dataset_correction = Dataset.from_dict(dataset_correction_rationales)

In [34]:
# Run the rationale corrector over the flagged records.
# NOTE: this rebinds translation_result_corrected, which earlier in the
# notebook held the corrected general translations.
translation_result_corrected = rational_translator_corrector.translate(input_dataset_correction, RationaleTranslationResultSchema)

In [35]:
# Inspect the corrector's output for the rationales.
translation_result_corrected

[]

In [36]:
# Show the corrected rationale rows whose 'text_rus' field still matches the
# English-word heuristic.
Dataset.from_list(translation_result_corrected).filter(
    lambda row: contains_english(row['text_rus'])
)

Dataset({
    features: [],
    num_rows: 0
})

In [37]:
# Merge each English-free corrected rationale back into the results, taking
# the untouched fields from the record it was derived from.
# NOTE(review): corrected rows are paired with their source rows by position;
# this assumes the corrector returns one row per input in the same order —
# confirm, or match on 'id' instead.
translation_result_rationales.extend(
    {
        'id': corrected_row['id'],
        'rationales_eng': source_row['rationales_eng'],
        'rationales_rus': corrected_row['text_rus'],
        'text_rus': source_row['text_rus'],
        'text_eng': source_row['text_eng'],
    }
    for corrected_row, source_row in zip(translation_result_corrected, rationales_containing_english)
    if not contains_english(corrected_row['text_rus'])
)

In [38]:
# Round-trip the accumulated records through a Dataset and write them out as
# pretty-printed UTF-8 JSON with non-ASCII characters left unescaped.
d = Dataset.from_list(translation_result_rationales)
with Path(rational_translation_int_path).open("w", encoding="utf-8") as out_file:
	json.dump(d.to_list(), out_file, ensure_ascii=False, indent=4)

In [39]:
# Re-run the rationale check on the final result set (after the correction
# pass): a record is accepted only if every '|'-separated Russian rationale
# occurs verbatim in the Russian text.
# Each record lands in exactly one list, exactly once: the earlier nested loop
# appended a record once per missing rationale, producing duplicates in
# incorrectly_translated, and then filtered with a quadratic "not in" scan.
incorrectly_translated = []
translated_rationales = []
for item in translation_result_rationales:
    rationale_spans = item['rationales_rus'].split('|')
    if all(span in item['text_rus'] for span in rationale_spans):
        translated_rationales.append(item)
    else:
        incorrectly_translated.append(item)