# In [1]:
import html
import math
import os
import subprocess
import sys

import numpy as np
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

# métricas externas
try:
    from skimage.metrics import structural_similarity as ssim_sk
except Exception as e:
    ssim_sk = None  # validaremos más abajo

import torch
import lpips  # asumiste que ya está instalado

# === GLOBAL PATHS ===
# Every path is anchored at the directory the notebook kernel started in.
BASE_PATH = os.getcwd()

# Cloned third-party model repositories.
SRC_DIR = os.path.join(BASE_PATH, "model-repos/EDSR-PyTorch/src")
SWINIR_SRC_SCRIPT = os.path.join(BASE_PATH, "model-repos/SwinIR/main_test_swinir.py")

# Pretrained weights.
MODELS_DIR = os.path.join(BASE_PATH, "model")
MODEL_PATH = os.path.join(BASE_PATH, "model/model_best.pt")
SWINIR_MODEL_PATH = os.path.join(BASE_PATH, "model/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth")

# Working folders: uploaded LR images, model outputs and optional HR references.
INPUT_DIR = os.path.join(BASE_PATH, "inputs-Demo")
INPUT_HR_DIR = os.path.join(BASE_PATH, "inputs-HR-Demo")
BASE_OUTPUT_DIR = BASE_PATH
RESULTS_DIR = os.path.join(BASE_OUTPUT_DIR, "results-Demo")

# "./<folder>" forms of the working folders; used for file I/O and as
# <img src> values in the rendered HTML.
INPUT_RELATIVE_DIR = os.path.join(".", os.path.basename(INPUT_DIR))
RESULTS_RELATIVE_DIR = os.path.join(".", os.path.basename(RESULTS_DIR))
INPUT_HR_RELATIVE_DIR = os.path.join(".", os.path.basename(INPUT_HR_DIR))

for _carpeta in (INPUT_DIR, RESULTS_DIR, INPUT_HR_DIR):
    os.makedirs(_carpeta, exist_ok=True)
del _carpeta

# === DOWNLOAD MODELS ===
import urllib.request

# === VERIFY AND DOWNLOAD MODELS IF MISSING ===
os.makedirs(MODELS_DIR, exist_ok=True)

# Pretrained weights: source URL and local destination path.
model_files = [
    {
        "url": "https://huggingface.co/UbaUser/EDSR-DF2K-X2/resolve/main/model_best.pt",
        "path": MODEL_PATH
    },
    {
        "url": "https://huggingface.co/UbaUser/SWINIR-DF2K-X2/resolve/main/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth",
        "path": SWINIR_MODEL_PATH
    }
]

for m in model_files:
    if os.path.exists(m["path"]):
        continue
    print(f"Descargando {os.path.basename(m['path'])} desde Hugging Face...")
    # Download into a ".part" file and rename it only once the download
    # succeeded: an interrupted download must not leave a truncated file at the
    # final path, because the exists() check above would accept it forever.
    tmp_path = m["path"] + ".part"
    try:
        urllib.request.urlretrieve(m["url"], tmp_path)
        os.replace(tmp_path, m["path"])
    except Exception as e:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise RuntimeError(f"No se pudo descargar {m['path']} desde {m['url']}. Error: {e}") from e
    print(f"Descargado correctamente: {m['path']}")

# === WIDGETS ===
# Model used for super-resolution; SwinIR is preselected.
selector_modelo = widgets.Dropdown(
    options=['EDSR', 'SwinIR'],
    value='SwinIR',
    description='Modelo:',
    style=dict(description_width='initial'),
)

# File types accepted by both uploaders.
_FORMATOS_IMAGEN = '.png, .jpg, .jpeg'

# LR uploader: picking a file triggers on_upload_change.
uploader = widgets.FileUpload(
    accept=_FORMATOS_IMAGEN,
    multiple=False,
    description='Seleccionar imagen LR',
    style=dict(button_color='#4CAF50', text_color='white'),
    layout=widgets.Layout(width='250px', height='40px'),
    tooltip="Seleccioná una imagen LR (.png, .jpg, .jpeg) para mejorar su resolución.",
)

# HR uploader: not placed in the UI until a super-resolved image exists.
hr_uploader = widgets.FileUpload(
    accept=_FORMATOS_IMAGEN,
    multiple=False,
    description='Subir imagen HR (opcional)',
    style=dict(button_color='#2196F3', text_color='white'),
    layout=widgets.Layout(width='300px', height='40px'),
    tooltip="Subí la imagen HR correspondiente para calcular PSNR / SSIM / LPIPS.",
)

# Empty container; hr_uploader is inserted into it after a successful run.
hr_box = widgets.HBox([])

# Area where LR/SR previews, progress and metrics are rendered.
output_box = widgets.Output()

# === STYLES ===
# CSS for the demo: result columns, dashed image boxes, the blinking
# "processing" dots, the metrics table and the footer. Rendered once when this
# cell runs.
_ESTILOS_APP = """
<style>
    body { background-color: #f6f8fa; }
    .app-container { font-family: 'Segoe UI', sans-serif; text-align: center; margin-top: 30px; color: #333; }
    .title { font-size: 26px; font-weight: bold; margin-bottom: 8px; }
    .subtitle { font-size: 15px; color: #555; margin-bottom: 25px; }
    .result-box { display: flex; justify-content: center; align-items: flex-start; gap: 60px; margin-top: 40px; flex-wrap: wrap; }
    .img-col { text-align: center; }
    .img-box { border: 2px dashed #ccc; border-radius: 10px; width: 500px; height: 500px; background-color: #f9f9f9;
               display: flex; justify-content: center; align-items: center; margin-bottom: 10px; position: relative; overflow: hidden; }
    .img-title { font-weight: bold; font-size: 15px; margin-bottom: 8px; }
    .placeholder { color: #aaa; font-size: 13px; text-align: center; }
    .processing-dots { font-size: 15px; color: #4CAF50; text-align: center; }
    .processing-dots .dot { font-size: 30px; font-weight: bold; animation: blink 1.5s infinite step-start; }
    .processing-dots .dot:nth-child(2) { animation-delay: 0.3s; }
    .processing-dots .dot:nth-child(3) { animation-delay: 0.6s; }
    @keyframes blink { 0%, 20% { opacity: 0; } 40% { opacity: 1; } 100% { opacity: 0; } }
    .metrics-box { text-align: left; font-family: monospace; margin-top: 12px; }
    .footer { text-align: center; margin-top: 40px; color: #777; font-size: 13px; line-height: 1.6em; max-width: 700px;
              margin-left: auto; margin-right: auto; }
    .footer a { color: #4CAF50; text-decoration: none; }
    table.metrics { border-collapse: collapse; margin-top: 10px; }
    table.metrics td, table.metrics th { border: 1px solid #ddd; padding: 8px; }
    table.metrics th { background-color: #f2f2f2; font-weight: bold; }
</style>
"""
display(HTML(_ESTILOS_APP))

# === FUNCIONES AUXILIARES ===
def limpiar_carpetas_on_lr():
    """Empty the LR input, SR result and HR folders before a new LR upload.

    A folder that does not exist is created. Only regular files are deleted;
    subdirectories are left in place.
    """
    for carpeta in (INPUT_DIR, RESULTS_DIR, INPUT_HR_DIR):
        if not os.path.exists(carpeta):
            os.makedirs(carpeta, exist_ok=True)
        rutas = [os.path.join(carpeta, nombre) for nombre in os.listdir(carpeta)]
        for ruta in filter(os.path.isfile, rutas):
            os.remove(ruta)

def limpiar_carpetas_simple():
    """Delete the regular files in INPUT_DIR and RESULTS_DIR (INPUT_HR_DIR is untouched).

    Older variant kept for compatibility. Like limpiar_carpetas_on_lr, a
    missing folder is created instead of letting os.listdir raise
    FileNotFoundError. Subdirectories are left in place.
    """
    for folder in (INPUT_DIR, RESULTS_DIR):
        os.makedirs(folder, exist_ok=True)
        for f in os.listdir(folder):
            path = os.path.join(folder, f)
            if os.path.isfile(path):
                os.remove(path)

# --- Shared output lookup ---
def _buscar_resultado(input_path):
    """Return the RESULTS_RELATIVE_DIR path of the model output for *input_path*.

    A result is any regular file in RESULTS_DIR whose name contains the stem of
    *input_path*. Names are scanned in sorted order so the choice is
    deterministic.

    Raises:
        FileNotFoundError: if no matching file exists in RESULTS_DIR.
    """
    name = os.path.splitext(os.path.basename(input_path))[0]
    for f in sorted(os.listdir(RESULTS_DIR)):
        # Check against the absolute folder that was listed, not the relative
        # path, so the check does not depend on the current working directory.
        if name in f and os.path.isfile(os.path.join(RESULTS_DIR, f)):
            # Relative path: callers use it both for file I/O and as <img src>.
            return os.path.join(RESULTS_RELATIVE_DIR, f)
    raise FileNotFoundError(f"No se encontró la imagen procesada en {RESULTS_DIR}")


# --- EDSR ---
def procesar_imagen_edsr(input_path):
    """Run EDSR (x2) over INPUT_DIR and return the output matching *input_path*.

    Runs EDSR-PyTorch's ``main.py`` from SRC_DIR in test-only mode on the
    "Demo" dataset, with weights from MODEL_PATH and ``--save`` set to
    BASE_OUTPUT_DIR. The output is then looked up in RESULTS_DIR.

    Args:
        input_path: path of the uploaded LR image. Only its file name is used,
            to locate the output; the script processes every image in INPUT_DIR.

    Returns:
        Path of the SR image, relative (under RESULTS_RELATIVE_DIR).

    Raises:
        subprocess.CalledProcessError: if the EDSR script exits with an error.
        FileNotFoundError: if no output matching the input name is found.
    """
    cmd = [
        # The kernel's own interpreter: a bare "python" may resolve to another
        # environment on PATH that lacks torch or the model dependencies.
        sys.executable, "main.py",
        "--data_test", "Demo",
        "--dir_demo", INPUT_DIR,
        "--scale", "2",
        "--pre_train", MODEL_PATH,
        "--test_only",
        "--save_results",
        "--save", BASE_OUTPUT_DIR
    ]
    subprocess.run(cmd, cwd=SRC_DIR, check=True)
    return _buscar_resultado(input_path)


# --- SwinIR ---
def procesar_imagen_swinir(input_path):
    """Run SwinIR (classical SR, x2) over INPUT_DIR and return the output matching *input_path*.

    Runs SWINIR_SRC_SCRIPT with training patch size 64 and the weights at
    SWINIR_MODEL_PATH, reading from INPUT_DIR and passing RESULTS_DIR as
    ``--save``.

    NOTE(review): upstream ``main_test_swinir.py`` does not appear to define
    ``--save`` (it writes to ``results/swinir_<task>_x<scale>``); this assumes
    a locally modified copy. Confirm against model-repos/SwinIR.

    Args:
        input_path: path of the uploaded LR image. Only its file name is used,
            to locate the output.

    Returns:
        Path of the SR image, relative (under RESULTS_RELATIVE_DIR).

    Raises:
        subprocess.CalledProcessError: if the SwinIR script exits with an error.
        FileNotFoundError: if no output matching the input name is found.
    """
    cmd = [
        sys.executable, SWINIR_SRC_SCRIPT,
        "--task", "classical_sr",
        "--scale", "2",
        "--training_patch_size", "64",
        "--model_path", SWINIR_MODEL_PATH,
        "--folder_lq", INPUT_DIR,
        "--save", RESULTS_DIR
    ]
    subprocess.run(cmd, check=True)
    return _buscar_resultado(input_path)

# --- Función de despacho según modelo ---
def procesar_imagen(input_path, modelo):
    """Dispatch *input_path* to the super-resolution model named *modelo*.

    The name is compared case-insensitively with 'edsr' and 'swinir'.

    Returns:
        Whatever the model-specific function returns (the path of the SR image).

    Raises:
        ValueError: if *modelo* names neither model.
    """
    clave = modelo.lower()
    if clave == 'swinir':
        return procesar_imagen_swinir(input_path)
    if clave == 'edsr':
        return procesar_imagen_edsr(input_path)
    raise ValueError(f"Modelo desconocido: {modelo}")

# --- Funciones de métricas ---
def compute_psnr(img1, img2, data_range=1.0):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Args:
        img1, img2: arrays with identical shape, values in ``[0, data_range]``.
        data_range: maximum possible pixel value; 1.0 (the default) for float
            images in [0, 1], 255 for 8-bit images.

    Returns:
        PSNR in decibels as a Python float, or ``inf`` for identical images.

    Raises:
        ValueError: if the shapes differ or the images are empty.
    """
    # float64 avoids uint8 wrap-around on subtraction and float32 rounding
    # when accumulating the squared error over large images.
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    # Without this check numpy would broadcast mismatched shapes silently.
    if a.shape != b.shape:
        raise ValueError(f"Dimensiones distintas: {a.shape} vs {b.shape}")
    if a.size == 0:
        raise ValueError("Las imágenes están vacías.")
    mse = float(np.mean((a - b) ** 2))
    if mse == 0:
        return float('inf')
    return 20 * math.log10(data_range) - 10 * math.log10(mse)

def compute_ssim(img1, img2, win_size=11):
    """Return the SSIM between two images using scikit-image.

    Args:
        img1, img2: float arrays in [0, 1] with identical shape, either
            ``(H, W)`` (grayscale) or ``(H, W, C)`` (channels last).
        win_size: preferred sliding-window side (11 by default, as before).
            It is reduced to the largest odd value that fits the image,
            because scikit-image rejects windows larger than the image.

    Returns:
        The mean SSIM value returned by scikit-image.

    Raises:
        ImportError: if scikit-image could not be imported.
        ValueError: if the shapes differ or the image is smaller than 3x3.
    """
    if ssim_sk is None:
        raise ImportError("scikit-image no está disponible. Instala con `pip install scikit-image` para usar SSIM.")
    if img1.shape != img2.shape:
        raise ValueError(f"Dimensiones distintas: {img1.shape} vs {img2.shape}")

    # The window must be odd and fit inside the smallest spatial side.
    win = min(int(win_size), img1.shape[0], img1.shape[1])
    if win % 2 == 0:
        win -= 1
    if win < 3:
        raise ValueError("La imagen es demasiado pequeña para calcular SSIM (mínimo 3x3).")

    # Only 3-D arrays have a channel axis; for 2-D grayscale, -1 would be a spatial axis.
    channel_axis = -1 if img1.ndim == 3 else None
    return ssim_sk(img1, img2, data_range=1.0, channel_axis=channel_axis, win_size=win)

# preparar modelo LPIPS (cargar una vez)
# Lazily built LPIPS network shared by every compute_lpips call.
_lpips_model = None


def get_lpips_model():
    """Return the shared LPIPS model (AlexNet backbone), creating it on the first call.

    The instance is cached in the module-level ``_lpips_model``. It is moved
    to the GPU when CUDA is available. The first call may take a while.
    """
    global _lpips_model
    if _lpips_model is not None:
        return _lpips_model
    _lpips_model = lpips.LPIPS(net='alex')
    if torch.cuda.is_available():
        _lpips_model = _lpips_model.cuda()
    return _lpips_model

def compute_lpips(img1_pil, img2_pil):
    """Return the LPIPS distance between two PIL images as a Python float.

    Each image is converted to RGB and then to a 1x3xHxW float tensor scaled
    to [-1, 1] before being fed to the shared LPIPS network. Tensors go to
    CUDA when it is available. Gradients are disabled during evaluation.
    The caller is expected to pass images of the same size (on_hr_upload_change
    resizes the HR to the SR size first).
    """
    red = get_lpips_model()

    def _como_tensor(imagen):
        # HxWx3 uint8 -> float32 in [0, 1]
        rgb = np.asarray(imagen.convert('RGB')).astype(np.float32) / 255.0
        # HxWx3 -> 3xHxW, plus a leading batch axis -> 1x3xHxW
        lote = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
        # [0, 1] -> [-1, 1]
        return (lote - 0.5) * 2.0

    tensores = [_como_tensor(img1_pil), _como_tensor(img2_pil)]
    if torch.cuda.is_available():
        tensores = [t.cuda() for t in tensores]

    with torch.no_grad():
        distancia = red(*tensores)
    return float(distancia.cpu().item())

# === UPLOAD EVENTS: helpers shared by the LR and HR handlers ===
def _leer_archivo_subido(widget):
    """Return ``(filename, data)`` for the single file held by a FileUpload widget.

    Handles both ``value`` layouts:
      - a dict keyed by file name whose entries hold a ``'content'`` field
        (ipywidgets 7);
      - a tuple of dicts with ``'name'`` and ``'content'`` keys (ipywidgets 8).
    ``content`` may be ``bytes`` or a ``memoryview``; it is always returned as
    ``bytes``.

    Raises:
        ValueError: if the uploaded file name is empty.
    """
    valor = widget.value
    if isinstance(valor, dict):
        nombre, entrada = next(iter(valor.items()))
    else:
        entrada = valor[0]
        nombre = entrada['name']
    # Keep only the last path component so a crafted name cannot write
    # outside the destination folder.
    nombre = os.path.basename(nombre)
    if not nombre:
        raise ValueError("El archivo subido no tiene un nombre válido.")
    # bytes() accepts both bytes and memoryview; bytes has no .tobytes().
    return nombre, bytes(entrada['content'])


def _resetear_uploader(widget):
    """Empty a FileUpload widget so the same file can be selected again."""
    if isinstance(widget.value, dict):
        # ipywidgets 7: ``value`` is a Dict trait and cannot be set to a tuple.
        widget.value.clear()
    else:
        widget.value = ()
    widget._counter = 0  # upload counter used by ipywidgets 7


def _listar_archivos(carpeta):
    """Return the sorted names of the regular files directly inside *carpeta*."""
    return sorted(f for f in os.listdir(carpeta) if os.path.isfile(os.path.join(carpeta, f)))


def _abrir_rgb(ruta):
    """Load the image at *ruta* as an RGB PIL image, closing the file handle."""
    with Image.open(ruta) as imagen:
        return imagen.convert('RGB')


def _mostrar_en_output(contenido_html):
    """Replace whatever output_box shows with *contenido_html*."""
    with output_box:
        clear_output(wait=True)
        display(HTML(contenido_html))


def _formatear_metrica(valor):
    """Format a metric for the results table: "N/A", "inf" or 4 decimals."""
    if valor is None:
        return "N/A"
    if isinstance(valor, float):
        return "inf" if math.isinf(valor) else f"{valor:.4f}"
    return str(valor)


def _calcular_metrica(nombre, funcion, a, b, errores):
    """Return ``funcion(a, b)``; on failure append ``"nombre: error"`` to *errores* and return None.

    Metrics are best-effort: one failing metric (e.g. scikit-image missing)
    must not prevent the others from being shown, but its reason is kept.
    """
    try:
        return funcion(a, b)
    except Exception as e:
        errores.append(f"{nombre}: {e}")
        return None


# === MAIN EVENT: LR upload ===
def on_upload_change(change):
    """Observer for the LR uploader: save the image, run the selected model and show LR vs SR.

    Steps:
      1. Hide the HR uploader.
      2. Empty the working folders.
      3. Save the upload into INPUT_RELATIVE_DIR.
      4. Show a "processing" placeholder.
      5. Run the model chosen in ``selector_modelo`` and render both images.
    On success the HR uploader is placed in ``hr_box``. Any error is reported
    in ``output_box``. The LR uploader is always reset afterwards.

    Args:
        change: traitlets change notification (unused).
    """
    if not uploader.value:
        # Empty value, e.g. the notification caused by resetting the widget.
        return

    # A new LR invalidates any previous SR: hide the HR uploader until this
    # run succeeds.
    hr_box.children = []
    try:
        # Empty inputs, results and HRs so exactly one LR/SR pair exists.
        limpiar_carpetas_on_lr()
        modelo_actual = selector_modelo.value

        filename, data = _leer_archivo_subido(uploader)
        input_path = os.path.join(INPUT_RELATIVE_DIR, filename)
        with open(input_path, "wb") as f:
            f.write(data)
        src_lr = html.escape(input_path, quote=True)

        # Processing animation while the model runs.
        _mostrar_en_output(f"""
        <div class='result-box'>
            <div class='img-col'>
                <div class='img-title'>Imagen LR (original)</div>
                <div class='img-box'><img src='{src_lr}' width='250'></div>
            </div>
            <div class='img-col'>
                <div class='img-title'>Resultado mejorado</div>
                <div class='img-box'>
                    <div class='processing-dots'>
                        <div>Procesando imagen,</div>
                        <div>por favor espere</div>
                        <div>
                            <span class='dot'>.</span>
                            <span class='dot'>.</span>
                            <span class='dot'>.</span>
                        </div>
                    </div>
                </div>
            </div>
        </div>
        """)

        result_path = procesar_imagen(input_path, modelo_actual)
        src_sr = html.escape(result_path, quote=True)

        # LR + SR side by side, plus the invitation to upload the HR.
        _mostrar_en_output(f"""
        <div class='result-box' id='result-area'>
            <div class='img-col'>
                <div class='img-title'>Imagen LR (original)</div>
                <div class='img-box'><img src='{src_lr}' width='250'></div>
            </div>
            <div class='img-col'>
                <div class='img-title'>Resultado mejorado ({html.escape(modelo_actual)})</div>
                <div class='img-box'><img src='{src_sr}' width='250'></div>
            </div>
        </div>

        <div style='text-align:center;margin-top:14px;'>
            <div style='font-weight:bold;margin-bottom:6px;'>Evaluación de calidad (opcional)</div>
            <div style='color:#555;'>Si querés comparar contra la imagen HR real, subila abajo — el cálculo de PSNR/SSIM/LPIPS se realizará automáticamente.</div>
        </div>
        """)
        # Reveal the HR uploader below the results.
        hr_box.children = [hr_uploader]
    except Exception as e:
        # UI boundary: report any failure instead of letting the observer die silently.
        _mostrar_en_output(
            "<div style='color:red;font-weight:bold;'>Error al procesar la imagen:<br>"
            f"{html.escape(str(e))}</div>"
        )
    finally:
        _resetear_uploader(uploader)


# === EVENT: HR upload ===
def on_hr_upload_change(change):
    """Observer for the HR uploader: compare the current SR image against the uploaded HR.

    Steps:
      1. Save the HR into INPUT_HR_RELATIVE_DIR.
      2. Take the LR/SR pair left by on_upload_change.
      3. Resize the HR to the SR size (bicubic) when they differ.
      4. Show LR, SR and HR side by side with PSNR, SSIM and LPIPS.
    A metric that fails is shown as "N/A" and its reason is listed below the
    table. The HR uploader is always reset so another HR can be uploaded.

    Args:
        change: traitlets change notification (unused).
    """
    if not hr_uploader.value:
        return

    try:
        filename, data = _leer_archivo_subido(hr_uploader)
        hr_path = os.path.join(INPUT_HR_RELATIVE_DIR, filename)
        with open(hr_path, "wb") as f:
            f.write(data)

        # on_upload_change leaves exactly one LR and one SR in these folders.
        lr_files = _listar_archivos(INPUT_DIR)
        sr_files = _listar_archivos(RESULTS_DIR)
        if not lr_files or not sr_files:
            _mostrar_en_output("<div style='color:red;font-weight:bold;'>No se encontró LR o SR actual para comparar. Asegurate de procesar primero la LR.</div>")
            return

        lr_file = os.path.join(INPUT_RELATIVE_DIR, lr_files[0])
        sr_file = os.path.join(RESULTS_RELATIVE_DIR, sr_files[0])

        pil_sr = _abrir_rgb(sr_file)
        pil_hr = _abrir_rgb(hr_path)

        # The metrics need equal sizes: resize the HR to the SR size.
        if pil_hr.size != pil_sr.size:
            pil_hr_resized = pil_hr.resize(pil_sr.size, Image.BICUBIC)
        else:
            pil_hr_resized = pil_hr

        # float32 in [0, 1] for PSNR and SSIM.
        sr_np = np.asarray(pil_sr).astype(np.float32) / 255.0
        hr_np = np.asarray(pil_hr_resized).astype(np.float32) / 255.0

        errores = []
        psnr_val = _calcular_metrica("PSNR", compute_psnr, sr_np, hr_np, errores)
        ssim_val = _calcular_metrica("SSIM", compute_ssim, sr_np, hr_np, errores)
        lpips_val = _calcular_metrica("LPIPS", compute_lpips, pil_sr, pil_hr_resized, errores)

        psnr_txt = "N/A" if psnr_val is None else f"{_formatear_metrica(psnr_val)} dB"

        html_metrics = f"""
        <div class='result-box'>
            <div class='img-col'>
                <div class='img-title'>Imagen LR (original)</div>
                <div class='img-box'><img src='{html.escape(lr_file, quote=True)}' width='250'></div>
            </div>
            <div class='img-col'>
                <div class='img-title'>Resultado mejorado</div>
                <div class='img-box'><img src='{html.escape(sr_file, quote=True)}' width='250'></div>
            </div>
            <div class='img-col'>
                <div class='img-title'>Imagen HR (ground-truth)</div>
                <div class='img-box'><img src='{html.escape(hr_path, quote=True)}' width='250'></div>
            </div>
        </div>
        """

        html_table = f"""
        <div style='display:flex;justify-content:center;margin-top:12px;'>
            <table class='metrics'>
                <tr><th>Métrica</th><th>Valor</th></tr>
                <tr><td>PSNR</td><td>{psnr_txt}</td></tr>
                <tr><td>SSIM</td><td>{_formatear_metrica(ssim_val)}</td></tr>
                <tr><td>LPIPS</td><td>{_formatear_metrica(lpips_val)}</td></tr>
            </table>
        </div>
        """

        # Explain any "N/A" instead of hiding why a metric failed.
        html_errores = ""
        if errores:
            items = "".join(f"<li>{html.escape(msg)}</li>" for msg in errores)
            html_errores = (
                "<div style='text-align:center;color:#a00;margin-top:8px;'>Métricas no disponibles:"
                f"<ul style='display:inline-block;text-align:left;'>{items}</ul></div>"
            )

        _mostrar_en_output(html_metrics + html_table + html_errores)
    except Exception as e:
        # UI boundary: e.g. an HR file that PIL cannot decode.
        _mostrar_en_output(
            "<div style='color:red;font-weight:bold;'>Error al evaluar la imagen HR:<br>"
            f"{html.escape(str(e))}</div>"
        )
    finally:
        _resetear_uploader(hr_uploader)

# Wire both handlers to changes of the uploaders' value.
uploader.observe(on_upload_change, names='value')
hr_uploader.observe(on_hr_upload_change, names='value')

# === MAIN INTERFACE (layout) ===
# Statement order fixes what appears in the notebook output: first the
# title/subtitle, then the widget column, then the initial placeholder written
# into output_box, and finally the footer.
display(HTML("""
<div class='app-container'>
    <div class='title'>Prototipo de Mejora de Resolución de Imágenes</div>
    <div class='subtitle'>Subí una imagen LR para mejorar su resolución. Luego podrás subir la HR para calcular métricas (opcional).</div>
</div>
"""))

# Vertical stack: model selector, LR uploader, result area, HR slot.
display(
    widgets.VBox([
        widgets.HBox([selector_modelo], layout=widgets.Layout(justify_content='center')),
        widgets.HBox([uploader], layout=widgets.Layout(justify_content='center', width='100%')),
        widgets.HBox([output_box], layout=widgets.Layout(justify_content='center', width='100%')),
        hr_box  # starts empty; hr_uploader is put here after an LR has been processed
    ])
)

# Initial placeholder shown in output_box before any upload.
with output_box:
    display(HTML("""
    <div class='result-box'>
        <div class='img-col'>
            <div class='img-title'>Imagen original</div>
            <div class='img-box'><div class='placeholder'>No hay imagen seleccionada</div></div>
        </div>
        <div class='img-col'>
            <div class='img-title'>Resultado mejorado</div>
            <div class='img-box'><div class='placeholder'>Esperando imagen...</div></div>
        </div>
    </div>
    """))

# === STATIC FOOTER ===
display(HTML("""
<div class='footer'>
    <p><b>Autor:</b> Juan Pérez – UBA Laboratorio de IA</p>
    <p><b>Repositorio GitLab:</b> <a href='https://gitlab.com/juanperez/image-upscale-proto' target='_blank'>gitlab.com/juanperez/image-upscale-proto</a></p>
    <p><b>Descripción:</b> Prototipo demostrativo de mejora de resolución de imágenes mediante EDSR y SwinIR.</p>
</div>
"""))


# Out[1]: VBox(children=(HBox(children=(Dropdown(description='Modelo:', index=1, options=('EDSR', 'SwinIR'), style=Descr…