In [96]:
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import pandas as pd
from datasets import Dataset
from omegaconf import OmegaConf

from src.core.translate import Translator
from src.utils.schemas import (GeneralTranslationResultSchema,
                               RationaleTranslationResultSchema)

In [97]:
# Matches a run of two or more ASCII letters delimited by word boundaries.
# contains_english() searches translated text with this pattern to detect
# words that were left in Latin script.
regex_check_enhlsih = re.compile(r'\b[a-zA-Z]{2,}\b')

In [98]:
def contains_english(text: str) -> bool:
	"""
	Report whether ``text`` holds at least one Latin-letter word.

	A "word" is whatever ``regex_check_enhlsih`` matches: two or more
	consecutive ASCII letters bounded by word boundaries.

	Args:
		text (str): String to scan.

	Returns:
		bool:		True when the pattern matches anywhere in ``text``, False otherwise.
	"""
	first_match = regex_check_enhlsih.search(text)
	return first_match is not None

In [99]:
# Load the project configuration once and keep a handle on each of the four
# pipeline sections: translation and correction, for general texts and for rationales.
config = OmegaConf.load('configs/conf.yaml')
general_translation_config = config.general_translation
general_translation_correction_config = config.general_translation_correction
rationales_translation_config = config.rationales_translation
rationales_translation_correction_config = config.rationales_translation_correction

In [100]:
@dataclass
class TranslatorDataConfig:
    """
    Bundle of the loaded inputs for one translator plus a derived output file name.

    Attributes:
        system_message:    Prompt text for the translator.
        model_config:      Model settings; must contain a ``'model'`` key,
                           which becomes part of the output file name.
        example_data:      Example data for the translator.
        suffix:            Optional label added to the output file name.
        intermediate_path: File name for intermediate results, computed in
                           ``__post_init__`` (not accepted by the constructor).
    """
    system_message: str
    model_config: dict
    example_data: dict
    suffix: str = ""
    intermediate_path: str = field(init=False)

    def __post_init__(self):
        # A fixed strftime layout avoids the spaces and colons of str(datetime.now()),
        # which are not valid in file names on every platform.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        name_parts = ["int_path", str(self.model_config['model']), self.suffix, timestamp]
        # Drop empty parts (e.g. no suffix) and replace anything that is not a word
        # character, dot or dash, so a model name such as "org/model" cannot
        # introduce a path separator into the file name.
        safe_parts = [re.sub(r'[^\w.-]+', '-', part) for part in name_parts if part]
        self.intermediate_path = "_".join(safe_parts) + ".json"

In [132]:
def _load_translator_data_config(section, suffix: str = "") -> TranslatorDataConfig:
    """
    Build a TranslatorDataConfig from one section of the project configuration.

    Reads the prompt file and the two JSON files named by the section's
    ``prompt_path``, ``model_config_path`` and ``filepath_examples`` entries.
    ``Path.read_text`` closes each file as soon as it has been read, so no
    handles are left open.

    Args:
        section: Configuration section providing the three path entries.
        suffix:  Optional label forwarded to TranslatorDataConfig.

    Returns:
        TranslatorDataConfig: The populated configuration object.
    """
    # NOTE(review): files are read as UTF-8, matching how this notebook writes
    # its own JSON outputs - confirm the prompt files are stored as UTF-8.
    prompt_text = Path(section['prompt_path']).read_text(encoding="utf-8")
    model_config = json.loads(Path(section['model_config_path']).read_text(encoding="utf-8"))
    example_data = json.loads(Path(section['filepath_examples']).read_text(encoding="utf-8"))
    return TranslatorDataConfig(
        system_message=prompt_text,
        model_config=model_config,
        example_data=example_data,
        suffix=suffix,
    )


general_tr_conf = _load_translator_data_config(
    general_translation_config, suffix="general-translation"
)

general_tr_corr_conf = _load_translator_data_config(general_translation_correction_config)

rational_tr_conf = _load_translator_data_config(
    rationales_translation_config, suffix="rational-translation"
)

rational_tr_corr_conf = _load_translator_data_config(rationales_translation_correction_config)

In [102]:
def _build_translator(data_conf, section):
    """
    Create a Translator from a loaded data config and its configuration section.

    Args:
        data_conf: Object exposing ``system_message``, ``model_config`` and ``example_data``.
        section:   Configuration section exposing ``batch_size``,
                   ``batch_result_dir`` and ``batches``.

    Returns:
        Translator: The configured translator instance.
    """
    return Translator(
        system_message=data_conf.system_message,
        model_config=data_conf.model_config,
        example_data=data_conf.example_data,
        batch_size=section.batch_size,
        batch_result_dir=section.batch_result_dir,
        batch_dir=section.batches,
    )


# One translator per pipeline stage, each paired with its own config section.
general_translator = _build_translator(general_tr_conf, general_translation_config)

general_translator_corrector = _build_translator(
    general_tr_corr_conf, general_translation_correction_config
)

rational_translator = _build_translator(rational_tr_conf, rationales_translation_config)

rational_translator_corrector = _build_translator(
    rational_tr_corr_conf, rationales_translation_correction_config
)

## Translating

In [103]:
# Read the source CSV and keep only the whitespace-separated column names
# listed in the config, as plain Python lists keyed by column name.
general_data = pd.read_csv(general_translation_config.data_path)
general_columns = general_translation_config.cols.split()
general_dataset = dict(
    (colname, general_data[colname].tolist()) for colname in general_columns
)
general_input_dataset = Dataset.from_dict(general_dataset)

In [104]:
# Run the general translator over the input dataset. Later cells treat the
# result as a list of dicts with a 'text_rus' entry (Dataset.from_list,
# list.remove and list.append are all applied to it).
translation_result = general_translator.translate(general_input_dataset, GeneralTranslationResultSchema)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 157.56ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1631.39ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1669.04ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1818.87ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1594.19ba/s]


Processing batch 1/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 171.82ba/s]


Processing batch 2/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 462.28ba/s]


Processing batch 3/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 169.29ba/s]


Processing batch 4/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 473.61ba/s]


Processing batch 5/5...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 699.75ba/s]


In [105]:
# Collect the translated rows whose Russian text still contains English words.
texts_containing_english = Dataset.from_list(translation_result).filter(
    lambda row: contains_english(row['text_rus'])
)

Filter: 100%|██████████| 19/19 [00:00<00:00, 2511.56 examples/s]


In [106]:
# Display the Russian texts that were flagged as still containing English words.
texts_containing_english["text_rus"]

[]

In [107]:
# Drop every item whose Russian text still contains English words.
# The check is applied to the list items themselves rather than calling
# list.remove() with rows read back from the Dataset: that depended on each
# round-tripped row comparing equal to its original dict, and raised
# ValueError when the cell was executed a second time.
# Slice assignment keeps the same list object, so the update stays in place.
translation_result[:] = [
    item for item in translation_result
    if not contains_english(item['text_rus'])
]

### Correcting translation

In [108]:
# Prepare the corrector input: drop the original 'text' column and feed the
# flagged Russian translation in its place under the name 'text'.
input_dataset_correction = (
    texts_containing_english
    .remove_columns(["text"])
    .rename_column("text_rus", "text")
)

In [109]:
# Send the flagged translations through the correction translator.
translation_result_corrected = general_translator_corrector.translate(input_dataset_correction, GeneralTranslationResultSchema)

In [110]:
# Display the corrector output for inspection.
translation_result_corrected

[]

In [111]:
# Sanity check: show corrected items whose Russian text still contains English words.
Dataset.from_list(translation_result_corrected).filter(
    lambda row: contains_english(row['text_rus'])
)

Dataset({
    features: [],
    num_rows: 0
})

In [112]:
# Merge back only the corrected items whose Russian text no longer contains
# English words; items that still fail the check are not re-added.
translation_result.extend(
    corrected
    for corrected in translation_result_corrected
    if not contains_english(corrected['text_rus'])
)

In [135]:
# Persist the general translation results as indented UTF-8 JSON,
# keeping non-ASCII characters unescaped.
with Path(general_tr_conf.intermediate_path).open("w", encoding="utf-8") as output_file:
	json.dump(translation_result, output_file, ensure_ascii=False, indent=4)

## Translating rationales

In [114]:
# Load the rationale dataset. Path.read_text closes the file once it has been
# read; the previous Path(...).open() handle was never closed.
# NOTE(review): the file is read as UTF-8 - confirm it is stored that way.
dataset_rationales = json.loads(
    Path(rationales_translation_config.data_path).read_text(encoding="utf-8")
)

In [115]:
# Wrap the loaded rationale data in a Dataset for the translator.
# Dataset.from_dict expects a mapping of column name to list of values.
input_dataset_rationales = Dataset.from_dict(dataset_rationales)

In [116]:
# Translate the rationale dataset. Later cells read 'text_rus' and
# 'rationales_rus' from each result item and mutate the result as a list.
translation_result_rationales = rational_translator.translate(input_dataset_rationales, RationaleTranslationResultSchema)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 644.29ba/s]


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 65.13ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1103.76ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 848.02ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 727.29ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 960.67ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1331.95ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1466.54ba/s]


Processing batch 1/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 356.87ba/s]


Processing batch 2/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 500.75ba/s]


Processing batch 3/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 439.38ba/s]


Processing batch 4/8...


ERROR:root:Unexpected error on input #3 (id: 15): 'id'. Attempt 1 of 3. Retrying...
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 838.19ba/s]


Processing batch 5/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 269.25ba/s]


Processing batch 6/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 601.94ba/s]


Processing batch 7/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 541.20ba/s]


Processing batch 8/8...


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 434.91ba/s]


In [117]:
# Split the results by whether every '|'-separated Russian rationale occurs
# verbatim in the Russian text. Each item goes into exactly one list, so an
# item with several missing rationales is reported once rather than once per
# missing rationale, and no second pass with list-membership tests is needed.
incorrectly_translated = []
translated_rationales = []
for item in translation_result_rationales:
    rationale_spans = item['rationales_rus'].split('|')
    if all(span in item['text_rus'] for span in rationale_spans):
        translated_rationales.append(item)
    else:
        incorrectly_translated.append(item)

In [118]:
# Display the items with at least one rationale missing from the translated text.
incorrectly_translated

[]

### Correcting rationales

In [119]:
# Collect the rows whose translated rationales still contain English words.
rationales_containing_english = Dataset.from_list(translation_result_rationales).filter(
    lambda row: contains_english(row['rationales_rus'])
)

Filter: 100%|██████████| 30/30 [00:00<00:00, 13140.05 examples/s]


In [120]:
# Display the flagged subset (features and row count).
rationales_containing_english

Dataset({
    features: ['id', 'text_rus', 'text_eng', 'rationales_eng', 'rationales_rus'],
    num_rows: 0
})

In [121]:
# Drop every item whose translated rationales still contain English words.
# The check is applied to the list items themselves rather than calling
# list.remove() with rows read back from the Dataset: that depended on each
# round-tripped row comparing equal to its original dict, and raised
# ValueError when the cell was executed a second time.
# Slice assignment keeps the same list object, so the update stays in place.
translation_result_rationales[:] = [
    item for item in translation_result_rationales
    if not contains_english(item['rationales_rus'])
]

In [122]:
# Build the column-oriented corrector input from the flagged rows. The
# Russian rationales are supplied under 'rationales_eng', so they become the
# text the corrector is asked to translate.
keys = ['id', 'text_rus', 'text_eng', 'rationales_rus']
dataset_correction_rationales = {
    ('rationales_eng' if source_key == 'rationales_rus' else source_key): [
        row[source_key] for row in rationales_containing_english
    ]
    for source_key in keys
}

In [123]:
# Wrap the corrector input in a Dataset. This rebinds input_dataset_correction,
# which earlier held the general-translation corrector input.
input_dataset_correction = Dataset.from_dict(dataset_correction_rationales)

In [124]:
# Send the flagged rationales through the rationale corrector. This rebinds
# translation_result_corrected, which earlier held the general-translation corrections.
translation_result_corrected = rational_translator_corrector.translate(input_dataset_correction, RationaleTranslationResultSchema)

In [125]:
# Display the rationale corrector output for inspection.
translation_result_corrected

[]

In [126]:
# Sanity check: show corrected items whose 'text_rus' still contains English words.
Dataset.from_list(translation_result_corrected).filter(
    lambda row: contains_english(row['text_rus'])
)

Dataset({
    features: [],
    num_rows: 0
})

In [127]:
# Re-add corrected rationales that no longer contain English words, combining
# the corrector output with the fields of the row it was produced from.
# NOTE(review): zip pairs results with inputs by position and stops at the
# shorter sequence - this assumes the corrector returns one result per input,
# in input order; confirm against Translator.translate.
for corrected, source_row in zip(translation_result_corrected, rationales_containing_english):
    if contains_english(corrected['text_rus']):
        continue
    translation_result_rationales.append({
        'id': corrected['id'],
        'rationales_eng': source_row['rationales_eng'],
        # The corrector's 'text_rus' output is stored as the Russian rationale.
        'rationales_rus': corrected['text_rus'],
        'text_rus': source_row['text_rus'],
        'text_eng': source_row['text_eng'],
    })

In [133]:
# Persist the rationale translation results as indented UTF-8 JSON. The list
# is passed through a Dataset and back before being serialized.
if rational_tr_conf.intermediate_path:
	rationales_dataset = Dataset.from_list(translation_result_rationales)
	output_path = Path(rational_tr_conf.intermediate_path)
	with output_path.open("w", encoding="utf-8") as output_file:
		json.dump(rationales_dataset.to_list(), output_file, ensure_ascii=False, indent=4)

In [129]:
# Re-run the consistency check after the correction step: split the results
# by whether every '|'-separated Russian rationale occurs verbatim in the
# Russian text. Each item goes into exactly one list, so an item with several
# missing rationales is reported once rather than once per missing rationale,
# and no second pass with list-membership tests is needed.
incorrectly_translated = []
translated_rationales = []
for item in translation_result_rationales:
    rationale_spans = item['rationales_rus'].split('|')
    if all(span in item['text_rus'] for span in rationale_spans):
        translated_rationales.append(item)
    else:
        incorrectly_translated.append(item)