In [1]:
import re
import torch
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configuración del dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Cargar modelo y tokenizer
model_name = "facebook/m2m100_418M"
tokenizer = M2M100Tokenizer.from_pretrained(model_name, src_lang="ru", tgt_lang="en")
model = M2M100ForConditionalGeneration.from_pretrained(model_name).to(device)

# Cargar modelo y tokenizer para corrección tipográfica (usando BERT como ejemplo)
correction_model_name = "Phind/Phind-CodeLlama-34B-v2"
correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_name)
correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_name).to(device)


# Patrones
image_pattern = r"!\[\]\((.*?)\)"
latex_inline_pattern = r"\$[^$]+\$"
latex_block_pattern = r"\$\$[\s\S]*?\$\$"
code_block_pattern = r"```[\s\S]*?```"
header_pattern = r"^(#{1,6})\s+(.*)$"

# Diccionarios para preservar bloques LaTeX
latex_blocks = {}
block_counter = 0


def translate_text(text):
    tokenizer.src_lang = "ru"
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
    translated = model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("en"), max_length=1024)
    return tokenizer.decode(translated[0], skip_special_tokens=True)


def preserve_latex_blocks(text):
    global block_counter
    def repl(match):
        global block_counter
        key = f"<LATEX_BLOCK_{block_counter}>"
        latex_blocks[key] = match.group(0)
        block_counter += 1
        return key
    return re.sub(latex_block_pattern, repl, text)


def restore_latex_blocks(text):
    for key, value in latex_blocks.items():
        text = text.replace(key, value)
    return text


def split_paragraph(paragraph):
    patterns = [image_pattern, latex_block_pattern, latex_inline_pattern, code_block_pattern]
    parts = [paragraph]
    for pattern in patterns:
        new_parts = []
        for part in parts:
            matches = list(re.finditer(pattern, part, re.DOTALL))
            last_pos = 0
            for match in matches:
                start, end = match.span()
                if last_pos < start:
                    new_parts.append(part[last_pos:start])
                new_parts.append(match.group(0))
                last_pos = end
            if last_pos < len(part):
                new_parts.append(part[last_pos:])
        parts = new_parts
    return parts


def is_russian(text):
    return re.search(r"[а-яА-Я]", text, re.DOTALL)

def correct_latex_in_text(text):
    latex_pattern = re.compile(r"(\${1,2})([^\$]+?)\1", re.DOTALL)
    matches = list(latex_pattern.finditer(text))

    corrected_latex = {}
    for match in tqdm(matches, desc="Corrigiendo LaTeX"):
        delimiter, formula = match.groups()
        prompt = f"Corrige cualquier error tipográfico o sintáctico en esta fórmula LaTeX sin cambiar su contenido matemático:\n\n{delimiter}{formula}{delimiter}"
        
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
        outputs = model.generate(**inputs, max_length=512)
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        
        # Quitar el prompt si el modelo responde con texto de más
        corrected_formula = re.search(r"(\${1,2})([^\$]+?)\1", corrected)
        if corrected_formula:
            corrected = f"{delimiter}{corrected_formula.group(2)}{delimiter}"
        else:
            corrected = f"{delimiter}{formula.strip()}{delimiter}"

        corrected_latex[match.group(0)] = corrected

    # Reemplazar en el contenido original
    for original, fixed in corrected_latex.items():
        text = text.replace(original, fixed)

    return text

def process_paragraph(paragraph):
    header_match = re.match(header_pattern, paragraph)
    if header_match:
        level, text = header_match.groups()
        if is_russian(text):
            text = translate_text(text)
        return f"{level} {text}"

    paragraph = preserve_latex_blocks(paragraph)
    parts = split_paragraph(paragraph)

    translated_parts = []
    for part in parts:
        if is_russian(part):
            translated_parts.append(translate_text(part))
        elif re.match(latex_inline_pattern, part) or re.match(latex_block_pattern, part):
            # Si es texto LaTeX, corregirlo
            translated_parts.append(correct_latex_in_text(part))
        else:
            translated_parts.append(part)

    final_paragraph = "".join(translated_parts)
    return restore_latex_blocks(final_paragraph).strip()


def main():
    input_file = "full.md"
    output_file = "output.md"

    with open(input_file, "r", encoding="utf-8") as f:
        content = f.read()

    paragraphs = content.split("\n\n")
    translated_paragraphs = []

    for paragraph in tqdm(paragraphs, desc="Translating markdown paragraphs"):
        translated_paragraph = process_paragraph(paragraph)
        translated_paragraphs.append(translated_paragraph)

    translated_content = "\n\n".join(translated_paragraphs)

    with open(output_file, "w", encoding="utf-8") as f:
        f.write(translated_content)

    print("Translation complete! Check 'output.md' for the result.")


if __name__ == "__main__":
    main()


Using device: cuda




tokenizer_config.json:   0%|          | 0.00/824 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/434 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/638 [00:00<?, ?B/s]

ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM.
Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, GPTSanJapaneseConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, NllbMoeConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, SwitchTransformersConfig, T5Config, XLMProphetNetConfig.

In [None]:
ls