In [None]:
import tqdm, os, math
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import re, gc, io, torch, fastwer, cv2, fitz, json, subprocess
from PIL import Image
from io import BytesIO
from paddleocr import PaddleOCR
from doctr.models import ocr_predictor
from together import Together
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import base64, urllib.request
import google.generativeai as genai
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, AutoModelForImageTextToText
from contextlib import redirect_stdout, redirect_stderr
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from openai import AzureOpenAI
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from skimage import filters, exposure, util

# Run configuration. Each list selects which engines/models this script uses.
# Language codes handed to the Surya recognizer (see ocr_surya).
langs = ["es"]
# OCR engines to run.
ocr_model = ["surya"] # "paddle", "doctr", "gemini", "surya", "olmocr"
# LLM consulted by fix_text / fix_text_with_images (empty list: falls through to their default branch).
fix_model = [] # "gpt-4o", "gpt-4o-mini", "gemini", "gemini2"
# LLM names checked by fix_text_remote_gemini ("gemini2") and by the API setup section.
pos_model = [] # "gpt-4o", "gpt-4o-mini", "gemini", "gemini2"
# LLM used by postprocess_surya; only the first entry is read.
pos_model_surya = ["gpt-4o"] # "gpt-4o", "gpt-4o-mini", "gemini", "gemini2"


### UTIL FUNCTIONS #################################################################

try:
    # Try to import pynvml (NVIDIA GPU management bindings).
    import pynvml
    has_gpu = True
except ImportError:
    # Not installed: fall back to psutil and report system memory instead.
    # Note that has_gpu only records whether the import worked, not whether a GPU exists.
    import psutil
    has_gpu = False

def check_mem():
    """Return a short, best-effort memory summary string.

    The result looks like ``"<used>MiB / <total>MiB - <free>"``:

    * the first part is the first ``used / total`` pair found in the
      ``nvidia-smi`` output;
    * the second part is the free memory (bytes) of GPU 0 reported by pynvml
      when pynvml was importable, otherwise the available system RAM (bytes)
      reported by psutil.

    Any part that cannot be obtained is simply omitted, so the function also
    works on hosts without the NVIDIA tools or driver.
    """
    mem = ""
    memoria_disponible = ""

    try:
        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, text=True)
        output = result.stdout or ""
    except OSError:
        # nvidia-smi is not installed / not executable on this machine.
        output = ""

    memory_pattern = re.search(r'(\d+)MiB\s+/\s+(\d+)MiB', output)
    if memory_pattern:
        used_mem = memory_pattern.group(1)
        total_mem = memory_pattern.group(2)
        mem = f"{used_mem}MiB / {total_mem}MiB"

    if has_gpu:
        try:
            pynvml.nvmlInit()
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first GPU
                memoria_disponible = pynvml.nvmlDeviceGetMemoryInfo(handle).free
            finally:
                # Always release NVML, even if the query above failed.
                pynvml.nvmlShutdown()
        except pynvml.NVMLError:
            # pynvml is installed but no usable driver/GPU is present.
            memoria_disponible = ""
    else:
        memoria_disponible = psutil.virtual_memory().available

    if memoria_disponible != "":
        mem += f" - {memoria_disponible}"

    return mem

def save_text_file(text, output_file):
    """Write *text* to *output_file* as UTF-8; an empty string writes nothing."""
    if text == "":
        return
    with open(output_file, mode='w', encoding='utf-8') as handle:
        handle.write(text)

def verify_text(engine, text, Ref):
    """Score an OCR result against a reference transcription.

    Both strings are normalised (see ``normalize_text``); the reference
    additionally has end-of-line hyphenation joined. Prints and returns the
    character and word error rates computed by ``fastwer.score_sent``.

    Args:
        engine: label used only in the printed summary.
        text: OCR output to evaluate.
        Ref: ground-truth text.

    Returns:
        ``(cer, wer)`` as returned by fastwer.
    """
    def normalize_text(input_text):
        # Drop the LLM end marker and a trailing library watermark.
        cleaned = re.sub(r'===END===', '', input_text)
        cleaned = re.sub(r'\s*biblioteca\s+nacional\s+de\s+españa\s*$', '', cleaned, flags=re.IGNORECASE)
        # Digits are removed from the first two lines only.
        rows = cleaned.split('\n')
        rows[:2] = [re.sub(r'\d+', '', row) for row in rows[:2]]
        cleaned = re.sub(r'["„©¶»«—]', '', '\n'.join(rows))
        # Trim every line, flatten to a single line, collapse runs of spaces.
        cleaned = ' '.join(row.strip() for row in cleaned.split('\n'))
        cleaned = re.sub(r' +', ' ', cleaned)
        return re.sub(r' +\.', '.', cleaned)

    hypothesis = normalize_text(text)
    reference = re.sub(
        r'(\w+)[\-\u00AD\u2010\u2011]+\s+(\w+)',
        r'\1\2',
        normalize_text(Ref),
    )

    cer = fastwer.score_sent(hypothesis, reference, char_level=True)
    wer = fastwer.score_sent(hypothesis, reference)
    print(f'[{engine}] CER:{cer:.2f} - WER:{wer:.2f}')
    return cer, wer

def postprocess(text):
    """Apply the engine-independent clean-up shared by all OCR outputs.

    Joins words split by a hyphen (ASCII, soft, Unicode or non-breaking)
    followed by whitespace, and restores the archaic accented spellings
    "vió" and "fué" when the words stand between whitespace. Prints the
    result when the module-level ``debug`` flag is truthy.
    """
    substitutions = (
        (r'(\w+)[\-\u00AD\u2010\u2011]+\s+(\w+)', r'\1\2'),
        (r"\s(vio)\s", r" vió "),
        (r"\s(fue)\s", r" fué "),
    )
    for regex, replacement in substitutions:
        text = re.sub(regex, replacement, text)
    if debug:
        print(text)
    return text

def postprocess_surya(text):
    """Clean up Surya OCR output and optionally refine it with a remote LLM.

    Steps: strip a trailing "Biblioteca Nacional de España" watermark, join
    words split by ``-`` plus whitespace, drop copyright glyphs, restore the
    archaic spellings "vió"/"fué", run the first model of
    ``pos_model_surya`` through ``pos_remote`` (if any), remove the
    ``===END===`` marker and finish with the shared ``postprocess``.

    ``pos_remote`` returns an empty string when the remote call fails or the
    model name is not supported; in that case the locally cleaned text is
    kept instead of being replaced by nothing.
    """
    text = re.sub(r'\s*biblioteca\s+nacional\s+de\s+espa[ñn]a\s*$', '', text, flags=re.IGNORECASE)
    text = re.sub(r'(\w+)-\s+(\w+)', r'\1\2', text)
    text = re.sub(r"[\u00A9\u24B8]", "", text)  # © and Ⓒ

    text = re.sub(r"\s(vio)\s", r" vió ", text)
    text = re.sub(r"\s(fue)\s", r" fué ", text)

    if pos_model_surya:
        corrected = pos_remote(text, pos_model_surya[0])
        # Keep the OCR text when the LLM step produced nothing usable.
        if corrected and corrected.strip():
            text = corrected

    text = re.sub(r'===END===', '', text)
    return postprocess(text)

def postprocess_doctr(text):
    """Post-process docTR output; currently just the shared ``postprocess``."""
    return postprocess(text)

def postprocess_paddle(text):
    """Post-process PaddleOCR output; currently just the shared ``postprocess``."""
    return postprocess(text)

def postprocess_gemini(text):
    """Post-process Gemini OCR output; currently just the shared ``postprocess``."""
    return postprocess(text)
   
def postprocess_olmocr(text):
    """Post-process olmOCR output; currently just the shared ``postprocess``."""
    return postprocess(text)
   
# --- Filtros individuales ---

def reescalar(img):
    """Resize the image by the same factor on both axes using bicubic interpolation."""
    # NOTE(review): `upscale` is a module-level name that is not defined in the
    # visible part of the file — confirm it is set before this filter runs.
    return cv2.resize(img, None, fx=upscale, fy=upscale, interpolation=cv2.INTER_CUBIC)
    
def limpiar_fondo(img):
    """Suppress the page background.

    The background is estimated with a 21-px median blur, subtracted from
    the image (absolute difference) and the result is min-max stretched to
    the 0-255 range.
    """
    background = cv2.medianBlur(img, 21)
    return cv2.normalize(cv2.absdiff(img, background), None, 0, 255, cv2.NORM_MINMAX)

def brillo_contraste(img, alpha=1.0, beta=30):
    """Adjust contrast (gain ``alpha``) and brightness (offset ``beta``) via ``cv2.convertScaleAbs``."""
    return cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

def suavizado(img):
    """Smooth the image with a 3-px median blur."""
    return cv2.medianBlur(img, 3)

def enfoque(img):
    """Sharpen the image by convolving it with a 3x3 kernel (centre 5, cross -1)."""
    sharpen_rows = [
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0],
    ]
    # ddepth=-1 keeps the output depth equal to the input depth.
    return cv2.filter2D(img, -1, np.array(sharpen_rows))

def clahe(img):
    """Apply CLAHE (clip limit 2.0, 8x8 tiles) to the image."""
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return equalizer.apply(img)

def binar_adaptativa(img):
    """Binarise with an inverted Gaussian adaptive threshold (block size 15, constant 7)."""
    block_size = 15
    offset = 7
    return cv2.adaptiveThreshold(
        img,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        block_size,
        offset,
    )

def otsu_inv(img):
    """Binarise with Otsu's threshold, inverted output; returns only the image."""
    flags = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    return cv2.threshold(img, 0, 255, flags)[1]

def otsu_normal(img):
    """Binarise with Otsu's threshold, non-inverted output; returns only the image."""
    flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    return cv2.threshold(img, 0, 255, flags)[1]

def dilatacion(img):
    """Dilate the image once with a 1x1 all-ones kernel."""
    # NOTE(review): a 1x1 structuring element looks like a no-op for dilation —
    # confirm whether a larger kernel was intended.
    kernel = np.ones((1, 1), np.uint8)
    return cv2.dilate(img, kernel, iterations=1)

def preprocess(img_orig, filtros_activados):
    """Run the selected filters, in order, on a grayscale uint8 copy of the image.

    Args:
        img_orig: image array; 3-channel input is treated as BGR and
            4-channel input as BGRA. It is never modified.
        filtros_activados: iterable of filter ids (1-10, see the table
            below). Unknown ids are reported and skipped.

    Returns:
        The filtered single-channel uint8 image, or ``img_orig`` unchanged if
        any step raises (the error is printed).
    """
    filtros = {
        1: reescalar,
        2: limpiar_fondo,
        3: brillo_contraste,
        4: suavizado,
        5: enfoque,
        6: clahe,
        7: binar_adaptativa,
        8: otsu_inv,
        9: otsu_normal,
        10: dilatacion
    }

    try:
        # The thresholding and CLAHE filters require a single-channel 8-bit
        # image, so the pipeline starts from the grayscale conversion rather
        # than from the colour copy.
        if len(img_orig.shape) == 3 and img_orig.shape[2] == 3:
            img = cv2.cvtColor(img_orig, cv2.COLOR_BGR2GRAY)
        elif len(img_orig.shape) == 3 and img_orig.shape[2] == 4:
            img = cv2.cvtColor(img_orig, cv2.COLOR_BGRA2GRAY)
        else:
            img = img_orig.copy()

        if img.dtype != np.uint8:
            img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        for id_filtro in filtros_activados:
            filtro_func = filtros.get(id_filtro)
            if filtro_func is None:
                print(f"Filtro {id_filtro} no está definido.")
                continue

            # Filter 3 is the only one called with non-default parameters.
            if id_filtro == 3:
                img = filtro_func(img, alpha=1.2, beta=30)
            else:
                img = filtro_func(img)

        return img

    except Exception as e:
        print(f"Error en preprocess: {str(e)}")
        return img_orig

def probar_combinaciones_top10(img_original, Ref):
    """Benchmark a fixed set of preprocessing combinations with olmOCR.

    For each filter combination the image is preprocessed and written to
    ``page_1_.png``; OCR is then scored against ``Ref`` twice: once on the
    crops returned by ``process_layout`` and once on the whole preprocessed
    image. A failure in one combination is printed and does not stop the
    remaining ones.
    """
    page_num = 1
    engine = "olmocr"
    combinaciones = (
        [1, 3, 5, 8],
        [1, 3, 6, 8],
        [1, 4, 6, 8],
        [1, 2, 6, 8],
        [1, 2, 4, 8],
        [1, 6, 8],
    )

    for numero, combo in enumerate(combinaciones, start=1):
        try:
            print(f"\n▶️ Probando combinación #{numero}: {combo}")
            img_proc = preprocess(img_original, combo)
            cv2.imwrite(f"page_{page_num}_.png", img_proc)

            # First pass: OCR over the layout crops.
            layout = process_layout(page_num)
            from_crops = process_images(engine, layout["cropped_images"])
            from_crops = globals()[f"postprocess_{engine}"](from_crops)
            verify_text(engine, from_crops, Ref)

            # Second pass: OCR over the full preprocessed page.
            from_page = globals()[f"ocr_{engine}"](Image.fromarray(img_proc))
            from_page = globals()[f"postprocess_{engine}"](from_page)
            verify_text(engine, from_page, Ref)
        except Exception as e:
            print(f"❌ Error en combinación {numero} {combo}: {e}")

import pytesseract

def analizar_img(path_imagen):
    """Measure image-quality features of a scanned page and suggest fixes.

    Args:
        path_imagen: path of the image to analyse.

    Returns:
        One CSV row (no header) with, in order: brightness, contrast, skew,
        mean text height, text height relative to the page, line spacing,
        wide-contour count, edge density, noise estimate, black-pixel ratio,
        connected-component count, dark-text flag, noisy-background flag,
        white-margin flag, compression-artefact flag, ruling-lines flag and a
        ``|``-separated list of suggested preprocessing steps. If the image
        cannot be read, the string ``"ERROR,No se pudo cargar la imagen"``.
    """
    def _conf_positiva(valor):
        # Tesseract confidences may arrive as ints, floats or strings such
        # as "96.43", depending on the pytesseract version.
        try:
            return float(valor) > 0
        except (TypeError, ValueError):
            return False

    imagen = cv2.imread(path_imagen)

    if imagen is None:
        return "ERROR,No se pudo cargar la imagen"

    gris = cv2.cvtColor(imagen, cv2.COLOR_BGR2GRAY)
    alto, ancho = gris.shape

    # 1. Brightness and contrast
    brillo = np.mean(gris)
    contraste = np.std(gris)

    # 2. Column detection (contours wider than a third of the page)
    _, binaria = cv2.threshold(gris, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    bin_inv = 255 - binaria
    contornos, _ = cv2.findContours(bin_inv, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    columnas_detectadas = sum(1 for c in contornos if cv2.boundingRect(c)[2] > ancho // 3)

    # 3. Spacing between recognised words (difference of sorted top coordinates)
    datos_ocr = pytesseract.image_to_data(gris, output_type=pytesseract.Output.DICT)
    y_coords = sorted(
        top
        for top, conf, palabra in zip(datos_ocr['top'], datos_ocr['conf'], datos_ocr['text'])
        if _conf_positiva(conf) and palabra.strip() != ''
    )
    if len(y_coords) > 1:
        espaciado_lineas = np.mean(np.diff(y_coords))
    else:
        espaciado_lineas = 0

    # 4. Text skew
    coords = np.column_stack(np.where(gris < 128))
    if len(coords) == 0:
        skew = 0
    else:
        angulo = cv2.minAreaRect(coords)[-1]
        if angulo < -45:
            angulo = -(90 + angulo)
        else:
            angulo = -angulo
        skew = angulo

    # 5. Background type (noise/patterns)
    fondo_var = np.var(gris)
    fondo_ruido = fondo_var > 500  # empirical threshold

    # 6. Text/background polarity. On a page with dark ink on a light
    # background the binarised image is mostly white, so dark text means
    # white pixels outnumber black ones.
    blancos = np.sum(binaria == 255)
    negros = np.sum(binaria == 0)
    texto_oscuro = blancos > negros

    # 7. Compression artefacts (low Laplacian variance)
    bloques = cv2.Laplacian(gris, cv2.CV_64F).var()
    artefactos_compresion = bloques < 50  # empirical threshold

    # 8. Margins (white borders, 10-px strips)
    top_margin = np.mean(gris[0:10, :])
    bottom_margin = np.mean(gris[-10:, :])
    left_margin = np.mean(gris[:, 0:10])
    right_margin = np.mean(gris[:, -10:])
    margen_blanco = all(m > 240 for m in [top_margin, bottom_margin, left_margin, right_margin])

    # 9. Average text height
    alturas = [
        altura
        for altura, conf in zip(datos_ocr['height'], datos_ocr['conf'])
        if _conf_positiva(conf)
    ]
    if alturas:
        altura_prom_texto = int(np.mean(alturas))
    else:
        altura_prom_texto = 0
    tam_texto_relativo = altura_prom_texto / alto

    # 10. Edge density
    bordes = cv2.Canny(gris, 100, 200)
    densidad_bordes = np.sum(bordes > 0) / (gris.shape[0] * gris.shape[1])

    # 11. Noise estimate
    # NOTE(review): this measures the variance of freshly *added* random
    # Gaussian noise, so it is non-deterministic and largely independent of
    # the image — consider a real estimator before relying on this column.
    gris_norm = gris / 255.0
    ruido_estimado = np.var(util.random_noise(gris_norm, mode='gaussian') - gris_norm)

    # 12. Proportion of black pixels
    texto_fondo_ratio = negros / binaria.size

    # 13. Connected components
    num_componentes, _ = cv2.connectedComponents(binaria)

    # 14. Ruling lines / tables: open the ink mask (ink = 255) with long thin
    # kernels so that only horizontal/vertical runs of at least 40 px survive.
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
    detect_horizontal = cv2.morphologyEx(bin_inv, cv2.MORPH_OPEN, horizontal_kernel)
    detect_vertical = cv2.morphologyEx(bin_inv, cv2.MORPH_OPEN, vertical_kernel)
    pixeles_linea = np.count_nonzero(detect_horizontal) + np.count_nonzero(detect_vertical)
    lineas_detectadas = pixeles_linea > 500  # arbitrary threshold

    # 15. Suggestions
    sugerencias = []

    if brillo < 80:
        sugerencias.append("mejorar_brillo")
    if contraste < 30:
        sugerencias.append("mejorar_contraste")
    if densidad_bordes < 0.01:
        sugerencias.append("aplicar_enfoque")
    if ruido_estimado > 0.01:
        sugerencias.append("suavizar")
    if abs(skew) > 2:  # tilt in either direction needs deskewing
        sugerencias.append("corregir_inclinacion")
    if altura_prom_texto < 15:
        sugerencias.append("reescalar")
    if not texto_oscuro:
        sugerencias.append("invertir_colores")
    if fondo_ruido:
        sugerencias.append("limpiar_fondo")
    if margen_blanco:
        sugerencias.append("recortar_margenes")
    if artefactos_compresion:
        sugerencias.append("reducir_artefactos")

    # 16. CSV output row
    fila_csv = (
        f"{brillo:.2f},{contraste:.2f},{skew:.2f},{altura_prom_texto},"
        f"{tam_texto_relativo:.4f},{espaciado_lineas:.2f},{columnas_detectadas},{densidad_bordes:.4f},"
        f"{ruido_estimado:.4f},{texto_fondo_ratio:.4f},{num_componentes},{int(texto_oscuro)},"
        f"{int(fondo_ruido)},{int(margen_blanco)},{int(artefactos_compresion)},{int(lineas_detectadas)},"
        f"{'|'.join(sugerencias)}"
    )
    return fila_csv

def fix_text_with_image_remote_gemini(text, page_num):
    """Send *text* (used as the prompt) plus the page image to Gemini.

    Args:
        text: prompt text, sent as-is.
        page_num: page number; the image is read from ``page_<n>_.png``.

    Returns:
        The text of the model response. Errors from opening the image or
        from the API propagate to the caller.
    """
    image_path = f"page_{page_num}_.png"
    # The context manager closes the image file once the request has been sent.
    with Image.open(image_path) as img:
        respuesta = gmodel.generate_content([text, img])
    return respuesta.text

def fix_text_with_image_remote(text, page_num):
    """Ask the Azure OpenAI deployment to correct OCR text using the page image.

    Args:
        text: OCR text; it is wrapped with ``set_prompt_with_image`` here.
        page_num: page number; the image is read from ``page_<n>_.png``.

    Returns:
        The corrected text up to the ``===END===`` marker (stripped) when the
        marker is present, the full response otherwise, or ``""`` when the
        request fails (the error is printed). A missing image file still
        raises, because it is read before the request is attempted.
    """
    image_path = f"page_{page_num}_.png"
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Mixed content message: the text prompt plus the image as a data URL.
    prompt = set_prompt_with_image(text)

    # Defined up front so the failure path below has something to return.
    corrected_text = ""
    try:
        completion = client.chat.completions.create(
            model=deployment,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            max_tokens=4000,
            temperature=0,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            stream=False
        )
        data = json.loads(completion.model_dump_json())
        # The content can be null (e.g. filtered responses); treat it as empty.
        corrected_text = data['choices'][0]['message']['content'] or ""
        head, marker, _ = corrected_text.partition('===END===')
        if marker:
            return head.strip()
    except Exception as e:
        print(f"Error querying the remote LLM service: {e}")
        torch.cuda.empty_cache()
    return corrected_text

def fix_text_remote_gemini(prompt):
    """Send *prompt* to Gemini and return the response text.

    Uses ``gmodel2`` when "gemini2" is listed in ``pos_model``, otherwise
    ``gmodel``. Any failure is printed and results in ``""``.
    """
    try:
        chosen = gmodel2 if "gemini2" in pos_model else gmodel
        return chosen.generate_content(prompt).text
    except Exception as e:
        print(f"[fix_text_remote_gemini] {e}")
        torch.cuda.empty_cache()
        return ""

def fix_text_remote(prompt, deploy):
    """Send a text-only prompt to an Azure OpenAI chat deployment.

    Args:
        prompt: full prompt text, sent as a single user message.
        deploy: deployment name passed as ``model``.

    Returns:
        The response up to the ``===END===`` marker (stripped) when the
        marker is present, the full response otherwise, or ``""`` when the
        request fails (the error is printed).
    """
    corrected_text = ""
    try:
        completion = client.chat.completions.create(
            model=deploy,
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            max_tokens=4000,
            temperature=0,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            stream=False
        )
        data = json.loads(completion.model_dump_json())
        # The content can be null (e.g. filtered responses); treat it as empty.
        corrected_text = data['choices'][0]['message']['content'] or ""
        head, marker, _ = corrected_text.partition('===END===')
        if marker:
            return head.strip()
    except Exception as e:
        # Report the cause; a bare `except:` would also trap Ctrl-C.
        print(f"Error querying the remote LLM service: {e}")
        torch.cuda.empty_cache()
    return corrected_text

def pos_prompt(text):
    """Build the LLM prompt that asks for a review of Old Spanish OCR text.

    Returns the fixed instruction block followed by ``text``.
    """
    # NOTE(review): the instructions contradict each other — the prompt asks to
    # "Explain any changes made" and shows a commentary-style "Expected Output",
    # but also says "Provide only the corrected text". It also refers to "the
    # original image" although this function only embeds text. Consider
    # tightening the wording; the literal is left untouched here.
    prompt = """
    You are an expert in historical Spanish linguistics and orthography. Your task is to carefully review the following text written in Old Spanish (from the 16th to 19th centuries). You must analyze the text sentence by sentence, verifying the following aspects:
    1. Meaning & Coherence:
    - Does each sentence make logical sense in the context of Old Spanish?
    - Identify and highlight any phrases that seem unclear, ambiguous, or grammatically incorrect.
    2. Punctuation & Special Characters:
    - Ensure that question marks (¿...?) and exclamation marks (¡...!) are correctly opened and closed according to Spanish grammar rules.
    - Check for missing or misplaced punctuation marks that may affect readability.
    3. Orthography Consistency (Old Spanish Standards):
    - Verify whether the spelling is consistent with Old Spanish conventions of its time.
    - Identify anachronisms or incorrect modernized spellings.
    - Ensure that words are spelled as they would have been during the era the text belongs to.
    4. Suggestions & Corrections:
    If errors are found, provide a corrected version while preserving the historical style.
    Explain any changes made, especially if modern Spanish rules conflict with historical usage.
    
    Instructions:
    - Maintain the original tone and structure of the text.
    - Do not modernize the language.
    - Provide only the corrected text, without any additional commentary.
    - Do not add any new information or explanations.
    - Focus on fixing spelling and obvious OCR mistakes by comparing with the original image.
    - End your response with '===END===' on a new line.
    
    Input Example:
    "¿Quántos daños no causarán los criados con su olvido ó mala inteligencia de los recados que reciben?"
    
    Expected Output:
    - Sentence makes sense.
    - Proper use of question marks (¿...? is correctly opened and closed).
    - "Quántos" should be corrected to "¿Cuántos?" (based on orthographic conventions of the period).
    - "ó" with an accent is valid in Old Spanish when avoiding vowel collision.
    """
    # The text to review is appended after the instruction block.
    prompt += f"\nHere is the text:\n\n{text}"
    return prompt

def set_prompt(text):
    """Build the text-only prompt that asks an LLM to merge and correct OCR outputs.

    The prompt is: a role sentence, the OCR texts, a numbered instruction
    list and a trailing "Corrected text:" cue.
    """
    intro = (
        "You are an expert in text correction and OCR error fixing. "
        "Your task is to combine and correct several OCR outputs of the same text. "
    )
    rules = (
        "Combine the texts, correcting any OCR errors.",
        "Provide only the corrected text, without any additional commentary.",
        "Maintain the original structure and formatting.",
        "Do not add any new information or explanations.",
        "Join any words that have been separated by a hyphen at the end of a line. If there're blank spaces after the hyphen, remove them so the two parts of the word get joined correctly.",
        "The text is written using archaic Spanish spelling.",
        "Maintain all diacritical marks, old-fashioned spellings, and historical punctuation, such as the use of 'fué' instead of 'fue', 'dió' instead of 'dio', 'ví' instead of 'vi', 'á' instead of 'a' in prepositions. Do not replace older words or grammatical structures with modern equivalents.",
        "Ensure that all words retain their original diacritics, such as accents (é, á, ó), tildes (ñ), and umlauts (ü), without alteration.",
        "Focus on fixing spelling and obvious OCR mistakes.",
        "End your response with '===END===' on a new line.",
    )
    numbered = "".join(f"{position}. {rule}\n" for position, rule in enumerate(rules, start=1))
    return (
        intro
        + f"Here are the texts:\n\n{text}"
        + "\n\nInstructions:\n"
        + numbered
        + "\n"
        + "Corrected text:"
    )

def set_prompt_with_image(text):
    """Build the prompt for image-assisted merging and correction of OCR outputs.

    Same layout as the text-only prompt, but the instructions tell the model
    to use the attached page image as the source of truth.
    """
    intro = (
        "You are an expert in text correction and OCR error fixing. "
        "Your task is to combine and correct several OCR outputs of the same text. "
        "I'm providing both the OCR outputs and the original image of the document. "
    )
    rules = (
        "First, look at the image of the original document to understand the correct text.",
        "Compare the OCR outputs with what you see in the image and create the most accurate version.",
        "Combine the texts, correcting any OCR errors based on what's visible in the image.",
        "When the OCR outputs differ, refer to the image to determine the correct text.",
        "Provide only the corrected text, without any additional commentary.",
        "Maintain the original structure and formatting.",
        "Do not add any new information or explanations.",
        "Join any words that have been separated by a hyphen at the end of a line. If there're blank spaces after the hyphen, remove them so the two parts of the word get joined correctly.",
        "The text is written using archaic Spanish spelling.",
        "Maintain all diacritical marks, old-fashioned spellings, and historical punctuation, such as the use of 'fué' instead of 'fue', 'dió' instead of 'dio', 'ví' instead of 'vi', 'á' instead of 'a' in prepositions. Do not replace older words or grammatical structures with modern equivalents.",
        "Ensure that all words retain their original diacritics, such as accents (é, á, ó), tildes (ñ), and umlauts (ü), without alteration.",
        "Focus on fixing spelling and obvious OCR mistakes by comparing with the original image.",
        "End your response with '===END===' on a new line.",
    )
    numbered = "".join(f"{position}. {rule}\n" for position, rule in enumerate(rules, start=1))
    return (
        intro
        + f"Here are the OCR texts:\n\n{text}"
        + "\n\nInstructions:\n"
        + numbered
        + "\n"
        + "Corrected text:"
    )

def fix_text(text):
    """Correct combined OCR text with the LLM selected in ``fix_model``.

    "gemini" takes precedence; otherwise the Azure ``deployment`` is used
    when "gpt-4o" is listed, and ``deployment_mini`` in every other case.
    """
    prompt = set_prompt(text)
    if "gemini" in fix_model:
        return fix_text_remote_gemini(prompt)
    target = deployment if "gpt-4o" in fix_model else deployment_mini
    return fix_text_remote(prompt, target)

def fix_text_with_images(text, page_num):
    """Correct combined OCR text with an LLM that also sees the page image.

    Args:
        text: the OCR outputs to merge and correct.
        page_num: page number of the image (``page_<n>_.png``).

    Returns:
        The corrected text from the selected backend.
    """
    if "gemini" in fix_model:
        # The Gemini helper sends its first argument verbatim, so it needs
        # the finished prompt.
        return fix_text_with_image_remote_gemini(set_prompt_with_image(text), page_num)
    # The Azure helper builds the prompt itself; passing an already built
    # prompt would wrap the instructions twice.
    return fix_text_with_image_remote(text, page_num)

def pos_remote(text, pmodel=""):
    """Run the Old Spanish review prompt through the chosen remote model.

    Args:
        text: text to review.
        pmodel: "gemini", "gpt-4o" or "gpt-4o-mini".

    Returns:
        The model output, or ``""`` for any other ``pmodel`` value.
    """
    prompt = pos_prompt(text)
    if pmodel == "gemini":
        return fix_text_remote_gemini(prompt)
    if pmodel == "gpt-4o":
        return fix_text_remote(prompt, deployment)
    if pmodel == "gpt-4o-mini":
        return fix_text_remote(prompt, deployment_mini)
    return ""

### API MODELS #################################################################

# Gemini client setup. It runs whenever any configured model name starts with
# "gemini", including the post-processing lists, because pos_remote and
# fix_text_remote_gemini also use `gmodel`.
_gemini_selected = any(
    name.startswith("gemini")
    for name in (*ocr_model, *fix_model, *pos_model, *pos_model_surya)
)
if _gemini_selected:
    # The key comes from the environment; it falls back to an empty string,
    # which was the previous hard-coded value.
    gemini_api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY", "")
    genai.configure(api_key=gemini_api_key)
    gmodel = genai.GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
    if "gemini2" in pos_model:
        gmodel2 = genai.GenerativeModel('gemini-2.5-pro-exp-03-25')
 
# Azure OpenAI client setup. It runs when any configured model name starts
# with "gpt-4o", so that selecting only "gpt-4o-mini" also creates `client`.
_azure_selected = any(
    name.startswith("gpt-4o")
    for name in (*ocr_model, *fix_model, *pos_model, *pos_model_surya)
)
if _azure_selected:
    endpoint = os.getenv("ENDPOINT_URL", "https://open-ia-service.openai.azure.com/")
    deployment = os.getenv("DEPLOYMENT_NAME", "gpt-4o")
    # Separate variable: reading DEPLOYMENT_NAME here as well would make both
    # deployments identical whenever that variable is set.
    deployment_mini = os.getenv("DEPLOYMENT_NAME_MINI", "gpt-4o-mini")
    AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
    client = AzureOpenAI(
        api_key=AZURE_OPENAI_API_KEY,
        api_version="2024-05-01-preview",
        azure_endpoint=endpoint
    )

# Together client, created only when "r1" is selected as OCR or fix model.
# NOTE(review): this rebinds `client`, the same name the Azure OpenAI section
# assigns, and the API key is a literal placeholder — confirm both are intended.
if "r1" in ocr_model or "r1" in fix_model:
    client = Together(api_key = "xxx")
### LOCAL MODELS #################################################################

def ocr_gemini(img):
    """OCR an image with Gemini; returns the extracted text, or "" on any error."""
    prompt = """
        Perform OCR (Optical Character Recognition) on this image.
        Extract ALL visible text without modernizing or modifying Old Spanish.
        Correct spelling and punctuation while preserving the original language and format.
        Respond ONLY with the extracted text, without additional comments.
        """
    try:
        return gmodel.generate_content([prompt, img]).text
    except Exception as e:
        print(f"Error[ocr_gemini]: {str(e)}")
        return ""

# PaddleOCR engine used by ocr_paddle. It is built at import time, regardless of
# the `ocr_model` selection, for Spanish on CPU with angle classification enabled.
paddle_ocr = PaddleOCR(show_log=False, use_angle_cls=True, lang='es', use_gpu=False)
def ocr_paddle(img):
    """OCR an image with PaddleOCR.

    Returns the recognised lines of the first result joined by spaces, or
    "" when ``img`` is None, nothing is detected or an error occurs (errors
    are printed).
    """
    if img is None:
        return ""
    try:
        first_page = paddle_ocr.ocr(img)[0]
        if first_page is None:
            return ""
        # Each entry is (box, (text, score)); only the text is kept.
        return " ".join(entry[1][0] for entry in first_page)
    except Exception as e:
        print(f"Error[ocr_paddle]: {str(e)}")
        return ""

def _get_doctr_predictor():
    """Build the pretrained docTR predictor on first use and reuse it afterwards."""
    predictor = getattr(_get_doctr_predictor, "_cached", None)
    if predictor is None:
        predictor = ocr_predictor(pretrained=True)
        _get_doctr_predictor._cached = predictor
    return predictor


def ocr_doctr(img):
    """OCR an image with docTR.

    Returns the rendered text, or "" when ``img`` is None, the first page has
    no text blocks or an error occurs (errors are printed). The predictor is
    created once and cached instead of being reloaded on every call.
    """
    text = ""
    try:
        if img is not None:
            result = _get_doctr_predictor()([img])
            if len(result.pages[0].blocks) == 0:
                return ""
            text = result.render()
    except Exception as e:
        print(f"Error[ocr_doctr]: {str(e)}")
    return text

# Surya predictors, built at import time regardless of the `ocr_model` selection.
# TQDM_DISABLE is set and stdout/stderr are redirected to an in-memory buffer
# while the predictors are constructed, presumably to keep the output quiet.
os.environ["TQDM_DISABLE"] = "1"
null_stream = io.StringIO()
with redirect_stdout(null_stream), redirect_stderr(null_stream):
    detection_predictor = DetectionPredictor()
    recognition_predictor = RecognitionPredictor()
    layout_predictor = LayoutPredictor()
def ocr_surya(img):
    """OCR an image with Surya.

    Returns the recognised text lines joined by spaces, or "" when ``img``
    is None, nothing is recognised or an error occurs. Library output is
    swallowed during the call; errors are printed unless their message
    mentions 'fillPoly'.
    """
    if img is None:
        return ""

    text = ""
    sink = io.StringIO()
    try:
        with redirect_stdout(sink), redirect_stderr(sink):
            predictions = recognition_predictor([img], [langs], detection_predictor)
            if predictions and predictions[0] is not None:
                # text_lines is an attribute of the prediction object.
                text = " ".join(line.text for line in predictions[0].text_lines)
    except Exception as e:
        if 'fillPoly' not in str(e):
            print(f"Error[ocr_surya]: {str(e)}")
    return text

# olmOCR model loading; only done when "olmocr" is selected, so `processor`,
# `model` and `device` do not exist otherwise.
if "olmocr" in ocr_model:
    # Alternative checkpoint (RolmOCR), currently disabled:
    #processor = AutoProcessor.from_pretrained("reducto/RolmOCR")
    #model = AutoModelForImageTextToText.from_pretrained("reducto/RolmOCR")
    #model = model.half()

    # The processor comes from the base Qwen2-VL model; the weights are the olmOCR checkpoint in bfloat16.
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16).eval()
    # Use the GPU when available, CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

def _fit_image_for_olmocr(img, min_size=32, max_area=1800000):
    """Resize a PIL image so both sides are >= min_size and the area is <= max_area."""
    width, height = img.size
    if width < min_size or height < min_size:
        # Too small for the model: upscale, keeping the aspect ratio.
        scale = max(min_size / width, min_size / height)
        img = img.resize(
            (max(min_size, round(width * scale)), max(min_size, round(height * scale))),
            Image.LANCZOS,
        )
        # Refresh the dimensions so the area check sees the resized image.
        width, height = img.size

    area = width * height
    if area > max_area:
        # Too large: downscale to the area budget, never below min_size.
        scale = math.sqrt(max_area / area)
        img = img.resize(
            (max(min_size, int(width * scale)), max(min_size, int(height * scale))),
            Image.LANCZOS,
        )
    return img


def ocr_Rolmocr(img):
    """OCR a PIL image with the loaded vision-language model.

    The image is resized to the supported size range, sent with the olmOCR
    fine-tuning prompt (empty anchor text) and up to 2048 new tokens are
    sampled (temperature 0.3, top-p 0.7).

    Args:
        img: PIL image.

    Returns:
        The first decoded output sequence, or "" if anything fails (the
        error is printed).
    """
    out_text = ""
    torch.cuda.empty_cache()  # release unused cached GPU memory

    try:
        img = _fit_image_for_olmocr(img)

        # Build the prompt; no document anchor text is available here.
        anchor_text = ""
        prompt = build_finetuning_prompt(anchor_text)
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        # Chat message with the prompt and the image as a data URL.
        messages = [{"role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}},
                        ],}]

        # Apply the chat template and the processor.
        in_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = processor(
            text=[in_text],
            images=[img],
            padding=True,
            return_tensors="pt",
        )
        inputs = {key: value.to(device) for (key, value) in inputs.items()}
        del img_base64, buffer
        torch.cuda.empty_cache()

        # Generate the output.
        output = model.generate(
            **inputs,
            temperature=0.3,
            top_p=0.7,
            max_new_tokens=2048,
            num_return_sequences=1,
            do_sample=True,
            use_cache=True
        )
        # Decode only the newly generated tokens (skip the prompt).
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_length:]
        data = processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )

        del inputs, output, new_tokens
        torch.cuda.empty_cache()
        print('[olmo_output]\n', data)
        out_text = data[0]
    except Exception as e:
        print(f"Error[ocr_olmo]: {str(e)}")
    return out_text

def remove_repetitions(text, max_repetitions=5):
    """Cap runs of identical consecutive lines.

    Lines are separated by the two-character sequence backslash + ``n`` (an
    escaped newline), not by a real newline character. Within each run of
    equal consecutive lines at most ``max_repetitions`` copies are kept
    (always at least one).
    """
    pieces = text.split('\\n')
    kept = []
    streak = 0  # repeats seen so far in the current run, excluding its first line
    for position, piece in enumerate(pieces):
        if position > 0 and piece == pieces[position - 1]:
            streak += 1
            if streak >= max_repetitions:
                continue
        else:
            streak = 0
        kept.append(piece)
    return '\\n'.join(kept)

def repair_json_attempt(json_str):
    """Try to turn a truncated or repetitive JSON string into valid JSON.

    Runs of identical consecutive "lines" (separated by the escaped sequence
    backslash + ``n``) are capped at two copies, and ``"}`` is appended when
    the string does not already end with it.

    Args:
        json_str: raw JSON text.

    Returns:
        The repaired string if it parses as JSON, otherwise ``None``.
    """
    max_repetitions = 2
    segments = json_str.split('\\n')

    kept = []
    streak = 0  # repeats seen so far in the current run, excluding its first line
    for position, segment in enumerate(segments):
        if position > 0 and segment == segments[position - 1]:
            streak += 1
            if streak >= max_repetitions:
                continue
        else:
            streak = 0
        kept.append(segment)

    repaired_str = '\\n'.join(kept)

    # Close a string value and object that were cut off mid-output.
    if not repaired_str.endswith('"}'):
        repaired_str = repaired_str + '"}'

    try:
        json.loads(repaired_str)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return repaired_str

def _olmocr_natural_text(parsed):
    """Return the ``natural_text`` string of a decoded olmOCR reply, else None."""
    if isinstance(parsed, dict):
        value = parsed.get("natural_text")
        if isinstance(value, str):
            return value
    return None


def ocr_olmocr(img):
    """Transcribe a PIL image with the olmOCR model and return its text.

    The image is resized when needed (a side shorter than 32 px, or an area
    above 1,800,000 px), encoded as PNG and sent, together with the olmOCR
    fine-tuning prompt, through the module-level ``processor`` and ``model``
    on ``device``. Generation uses sampling (``do_sample=True``), so repeated
    calls on the same image may differ.

    The reply is expected to be JSON; its ``natural_text`` field is returned
    with consecutive duplicate lines removed. When the reply does not parse,
    ``repair_json_attempt`` is tried and ``natural_text`` is extracted from
    the repaired document as well; if the repaired document has no such
    field, the repaired JSON string itself is returned.

    Args:
        img: PIL image to recognise.

    Returns:
        The recognised text, or ``""`` when nothing usable was produced or
        any step raised (the error is printed, not propagated).
    """
    out_text = ""
    torch.cuda.empty_cache()  # release unused cached GPU memory

    try:
        width, height = img.size
        min_size = 32  # slightly above the factor the model requires (28)
        # Enlarge images that are too small, preserving the aspect ratio.
        if width < min_size or height < min_size:
            scale = max(min_size / width, min_size / height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            img = img.resize((new_width, new_height), Image.LANCZOS)

        # Shrink images whose area exceeds the budget, preserving aspect ratio.
        area = width * height
        max_area = 1800000
        if area > max_area:
            scale = math.sqrt(max_area / area)
            new_width = int(width * scale)
            new_height = int(height * scale)
            img = img.resize((new_width, new_height), Image.LANCZOS)

        # Build the prompt; no anchor text (document metadata) is supplied.
        anchor_text = ""
        prompt = build_finetuning_prompt(anchor_text)
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        messages = [{"role": "user",
                     "content": [
                         {"type": "text", "text": prompt},
                         {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}},
                     ]}]

        # Apply the chat template and run the processor.
        in_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[in_text],
            images=[img],
            padding=True,
            return_tensors="pt",
        )
        inputs = {key: value.to(device) for (key, value) in inputs.items()}
        del img_base64, buffer
        torch.cuda.empty_cache()  # release unused cached GPU memory

        output = model.generate(
            **inputs,
            temperature=0.7,
            top_p=0.7,
            max_new_tokens=2048,
            num_return_sequences=1,
            do_sample=True,
            use_cache=True,
        )

        # Keep only the tokens generated after the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_length:]
        data = processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )

        del inputs, output, new_tokens
        torch.cuda.empty_cache()

        json_str = data[0]
        try:
            natural_text = _olmocr_natural_text(json.loads(json_str))
            if natural_text is not None:
                out_text = remove_repetitions(natural_text, 1)
        except json.JSONDecodeError:
            # Unparseable reply: try to repair it, then extract the text the
            # same way as on the normal path instead of returning the JSON
            # wrapper as if it were OCR output.
            repaired = repair_json_attempt(json_str)
            if repaired is not None:
                natural_text = _olmocr_natural_text(json.loads(repaired))
                if natural_text is not None:
                    out_text = remove_repetitions(natural_text, 1)
                else:
                    out_text = repaired
    except Exception as e:
        print(f"Error[ocr_olmo]: {str(e)}")
    return out_text

### PROCESS FLOW #################################################################

def process_images(engine, images):
    """Run one OCR engine over a list of image crops and join the results.

    The OCR function is resolved by name as the module-level ``ocr_<engine>``.
    The "paddle" and "doctr" engines receive each crop as a NumPy array; every
    other engine receives the crop object unchanged.

    Args:
        engine: Engine name; selects the ``ocr_<engine>`` function.
        images: Iterable of image crops (PIL images in the visible callers).

    Returns:
        The non-empty, stripped texts of all crops joined by a blank line, or
        ``""`` when the engine is unknown, there are no images, or nothing was
        recognised. A failure on one crop is printed and that crop is skipped.
    """
    ocr_function = globals().get(f"ocr_{engine}")
    if ocr_function is None or not images:
        return ""

    # Only these engines take arrays; converting every crop for the others
    # would be wasted work.
    needs_array = engine in ("paddle", "doctr")

    text_list = []
    for cropped_img in images:
        try:
            ocr_input = np.array(cropped_img) if needs_array else cropped_img
            text = ocr_function(ocr_input)

            # Engines may return a dict with a 'text' entry, a plain string,
            # or some other object that is stringified.
            if isinstance(text, dict) and 'text' in text:
                processed_text = text['text'].strip()
            elif isinstance(text, str):
                processed_text = text.strip()
            else:
                processed_text = str(text).strip()

            if processed_text:
                text_list.append(processed_text)
        except Exception as e:
            print(f"Error en OCR {engine}: {str(e)}")

    return "\n\n".join(text_list)

def show_img(img, title=""):
    """Display *img* with matplotlib using the gray colormap and hidden axes.

    Args:
        img: Image data accepted by ``plt.imshow``.
        title: Optional caption drawn above the image.
    """
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title(title)
    plt.show()


def _boxes_share_rows(boxes, min_ratio=0.3):
    """Tell whether two y-adjacent boxes overlap vertically by more than *min_ratio*.

    Boxes are ordered by their top edge and each one is compared with the next;
    the overlap is measured relative to the shorter of the two boxes.
    """
    ordered = sorted(boxes, key=lambda box: box.bbox[1])
    for upper, lower in zip(ordered, ordered[1:]):
        top_a, bottom_a = upper.bbox[1], upper.bbox[3]
        top_b, bottom_b = lower.bbox[1], lower.bbox[3]
        overlap = min(bottom_a, bottom_b) - max(top_a, top_b)
        if overlap > 0 and overlap / min(bottom_a - top_a, bottom_b - top_b) > min_ratio:
            return True
    return False


def _stack_crops(cropped_images):
    """Paste the crops top-to-bottom on a white RGB canvas (None if no crops)."""
    if not cropped_images:
        return None

    total_height = sum(crop.height for crop in cropped_images)
    max_width = max(crop.width for crop in cropped_images)
    combined_img = Image.new('RGB', (max_width, total_height), (255, 255, 255))

    current_y = 0
    for cropped_img in cropped_images:
        combined_img.paste(cropped_img, (0, current_y))
        current_y += cropped_img.height
        if img_show:
            show_img(cropped_img, "join")
    if img_show:
        show_img(combined_img, "combined")
    return combined_img


def process_layout(page_num):
    """Detect layout regions of a rendered page and crop them.

    Reads ``page_<page_num>_.png`` from the working directory, runs the
    module-level ``layout_predictor`` on it (with stdout/stderr silenced),
    orders the predicted boxes by their ``position`` attribute and crops each
    one out of the page.

    Args:
        page_num: Page index used to build the PNG file name.

    Returns:
        A dict with:
            "combined_image": all crops stacked vertically on a white RGB
                canvas, or None when there are no crops;
            "cropped_images": list of PIL crops in box order;
            "has_columns": True when two y-adjacent boxes overlap vertically
                by more than 30% of the shorter one.
        On any error the message is printed and an empty layout
        (None / [] / False) is returned.
    """
    img_path = f"page_{page_num}_.png"

    try:
        # The context manager guarantees the file handle is released even if
        # decoding or prediction fails.
        with Image.open(img_path) as img:
            img_np = np.array(img)
            with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
                layout_predictions = layout_predictor([img])

        boxes = sorted(layout_predictions[0].bboxes, key=lambda box: box.position)

        cropped_images = []
        for box in boxes:
            # Clamp at 0: a negative coordinate would be read by NumPy as an
            # index from the far edge and yield an empty or wrong crop.
            x1, y1, x2, y2 = (max(0, int(value)) for value in box.bbox)
            cropped_array = img_np[y1:y2, x1:x2]
            if cropped_array.size == 0:
                continue
            cropped_images.append(Image.fromarray(cropped_array))

        has_columns = _boxes_share_rows(boxes)
        combined_img = _stack_crops(cropped_images)

        return {
            "combined_image": combined_img,
            "cropped_images": cropped_images,
            "has_columns": has_columns
        }

    except Exception as e:
        print(f"Error procesando layout de página {page_num}: {str(e)}")
        return {
            "combined_image": None,
            "cropped_images": [],
            "has_columns": False
        }

def read_pdf(doc, Ref, ocr_path2):
    """OCR every page of *doc* with the configured engines and score the output.

    For each page the function renders it at the module-level ``zoom``,
    converts it to grayscale, preprocesses it, writes it to
    ``page_<n>_.png`` and runs ``process_layout`` on it. Each engine listed in
    ``ocr_model`` is then run (in the fixed order gemini, paddle, doctr,
    surya, olmocr), its text post-processed with ``postprocess_<engine>``,
    scored with ``verify_text`` against *Ref* and collected. The collected
    texts are wrapped in ``<Text>`` blocks and saved to *ocr_path2*, then
    optionally passed to the fix model and to the ``pos_remote`` step.

    Results accumulate across pages in a single list, so the text saved for a
    later page also contains the earlier pages' results.

    Args:
        doc: Opened PDF document (``len()`` and ``load_page`` are used).
        Ref: Reference transcription used for scoring.
        ocr_path2: Path handed to ``save_text_file`` for the collected texts.

    Returns:
        None. Any exception is printed and ends processing of the document.
    """
    collected = []

    def keep(engine, raw_text):
        # Post-process, score against the reference, then store the result.
        cleaned = globals()[f"postprocess_{engine}"](raw_text)
        verify_text(engine, cleaned, Ref)
        collected.append(cleaned)

    # (engine, also run on the whole page?) in the order results are produced.
    crop_engines = (
        ("paddle", False),
        ("doctr", False),
        ("surya", True),
        ("olmocr", True),
    )

    try:
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            mat = fitz.Matrix(zoom, zoom)
            pix = page.get_pixmap(matrix=mat, alpha=False)
            img_data = pix.tobytes("png")
            img_cv2 = np.array(Image.open(BytesIO(img_data)).convert('L'))
            img = preprocess(img_cv2, [1, 3, 5, 8])
            cv2.imwrite(f"page_{page_num}_.png", img)

            layout = process_layout(page_num)

            # Gemini reads one image: the stacked crops when columns were
            # detected, otherwise the full preprocessed page.
            if "gemini" in ocr_model:
                keep(
                    "gemini",
                    globals()["ocr_gemini"](
                        layout["combined_image"] if layout["has_columns"] else Image.fromarray(img)
                    ),
                )

            # Remaining engines work crop by crop; some also get the full page.
            for engine, with_full_page in crop_engines:
                if engine not in ocr_model:
                    continue
                keep(engine, process_images(engine, layout["cropped_images"]))
                if with_full_page:
                    keep(engine, globals()[f"ocr_{engine}"](Image.fromarray(img)))

            text = "\n".join(f"\n<Text>\n{s}\n<\\Text>\n" for s in collected)
            save_text_file(text, ocr_path2)

            if collected and fix_model:
                if fix_with_images:
                    text_fixed = postprocess(fix_text_with_images(text, page_num))
                    verify_text(f"{fix_model[0]} w/img", text_fixed, Ref)
                else:
                    text_fixed = postprocess(fix_text(text))
                    verify_text(fix_model[0], text_fixed, Ref)

            if len(pos_model) > 0:
                text = pos_remote(text, pos_model[0])
                verify_text(f"pos_{pos_model[0]}", text, Ref)
    except Exception as e:
        print(f"Error[read_pdf]: {str(e)}")

def process_directory(directory, directory_ocr, zoom):
    """Run ``read_pdf`` on every PDF in *directory* that has a reference text.

    For each ``<name>.pdf`` the reference transcription is read from
    ``<directory_ocr>/<name>.txt``; PDFs without one are reported and skipped.
    The OCR output path handed to ``read_pdf`` is
    ``<directory_ocr>/<name>_.txt``.

    Args:
        directory: Folder scanned (non-recursively) for ``.pdf`` files.
        directory_ocr: Folder holding the reference ``.txt`` files.
        zoom: Accepted for interface compatibility; it is not forwarded,
            because ``read_pdf`` reads the module-level ``zoom``.

    Returns:
        None. Errors while reading or processing a file are printed and the
        loop continues with the next file.
    """
    # The scandir iterator holds an OS directory handle; close it
    # deterministically.
    with os.scandir(directory) as entries:
        for entry in entries:
            if not (entry.is_file() and entry.name.lower().endswith('.pdf')):
                continue

            file_path = entry.path
            base_name = os.path.splitext(entry.name)[0]
            ocr_path = os.path.join(directory_ocr, base_name + '.txt')
            ocr_path2 = os.path.join(directory_ocr, base_name + '_.txt')

            print(file_path)
            if not os.path.exists(ocr_path):
                print(f"Error: El archivo {ocr_path} no existe.")
                continue
            try:
                # Read the reference and release the file before the long OCR
                # run; a distinct name avoids shadowing the directory entry.
                with open(ocr_path, 'r', encoding='utf-8') as ref_file:
                    Ref = ref_file.read()
                # Close the PDF when done instead of leaking the document.
                with fitz.open(file_path) as doc:
                    read_pdf(doc, Ref, ocr_path2)
            except Exception as e:
                print(f"Error al leer el archivo: {str(e)}")
    
# Folder with the input PDFs and folder with their reference transcriptions
# (<name>.txt); OCR results are written next to the references as <name>_.txt.
directory = "./dev/test2/"
directory_ocr = "./dev/test2/"

zoom = 2  # render scale read by read_pdf when rasterising each page
upscale = 1.2  # (max_upscale)
img_show = False  # when True, process_layout displays the crops it builds
debug = False
fix_with_images = False  # selects fix_text_with_images over fix_text in read_pdf

print("[Memory status]", check_mem())
process_directory(directory, directory_ocr, zoom)

if "olmocr" in ocr_model:
    del model
    del processor
    # Collect unreachable objects first so the tensors they hold go back to
    # the caching allocator, and only then hand the cached blocks back to the
    # driver.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

[Memory status] 5494MiB / 32760MiB - 28064153600
./dev/test2/5.pdf
[surya] CER:5.41 - WER:13.71
[surya] CER:10.42 - WER:26.29
./dev/test2/32.pdf
[surya] CER:26.63 - WER:35.36
[surya] CER:5.06 - WER:15.47
./dev/test2/55.pdf
[surya] CER:21.31 - WER:54.38
[surya] CER:10.44 - WER:36.25
./dev/test2/11.pdf
[surya] CER:25.06 - WER:48.99
[surya] CER:7.83 - WER:26.85
./dev/test2/26.pdf
[surya] CER:53.37 - WER:58.74
[surya] CER:35.79 - WER:62.94
./dev/test2/77.pdf
Error querying the remote LLM service.
[surya] CER:100.00 - WER:100.00


Exception ignored in: <function SyncHttpxClientWrapper.__del__ at 0x75a38132af20>
Traceback (most recent call last):
  File "/home/tmp/anaconda3/lib/python3.12/site-packages/openai/_base_client.py", line 778, in __del__
    def __del__(self) -> None:

KeyboardInterrupt: 


KeyboardInterrupt: 