### Training Code

Paste your Python training code here.

In [3]:
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import cv2
import matplotlib.pyplot as plt

# Configuration
IMG_SIZE = 224  # Input size for MobileNet
BATCH_SIZE = 16  # Increased for better generalization
EPOCHS = 150  # Total epoch budget; main() runs EPOCHS // 2 in each of its two phases
LEARNING_RATE = 0.0001

# Paths
# WARNING! These paths must be updated for the Colab environment.
# DATA_DIR = r"C:\Users\Eduardo\Documents\Maestria\TFM\Proyecto\B. Disease Grading\B. Disease Grading\1. Original Images\a. Training Set"
# LABELS_FILE = r"C:\Users\Eduardo\Documents\Maestria\TFM\Proyecto\B. Disease Grading\B. Disease Grading\2. Groundtruths\a. IDRiD_Disease Grading_Training Labels.csv"

# EXAMPLE: if you upload your files to a 'data' folder in Colab
# You can mount Google Drive or upload the files directly to /content/
DATA_DIR = "/content/sample_data/a. Training Set" # Updated path to the images
LABELS_FILE = "/content/sample_data/a. IDRiD_Disease Grading_Training Labels.csv" # Updated path

# Directory where main() writes best_model.h5 and final_model.h5
MODEL_SAVE_PATH = "models"

class DataLoader:
    """Load the grading labels and images and split them for training.

    Attributes:
        labels_file: Path to the CSV (or Excel) file with the labels.
        data_dir: Directory that contains the image files.
        img_size: Side length, in pixels, images are resized to.
        df: Label table; ``None`` until :meth:`load_labels` succeeds in
            reading the file.
        grade_column: Name of the column holding the grade; ``None`` until
            :meth:`load_labels` has identified it.
    """

    # Extensions tried when the label file lists an image name without one.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')

    def __init__(self, labels_file, data_dir, img_size=224):
        self.labels_file = labels_file
        self.data_dir = data_dir
        self.img_size = img_size
        self.df = None
        # Defined up front so prepare_data can detect a failed column lookup
        # instead of hitting an AttributeError.
        self.grade_column = None

    def load_labels(self):
        """Read the label file and locate the retinopathy-grade column.

        The file is read as CSV first; only when it cannot be decoded or
        parsed as CSV is it retried as an Excel workbook. Other errors
        (for example a missing file) propagate unchanged.

        Returns:
            The loaded DataFrame, or ``None`` when no grade/label column
            could be identified (an error message is printed in that case).
        """
        try:
            self.df = pd.read_csv(self.labels_file)
        except (UnicodeDecodeError, pd.errors.ParserError):
            # Not valid CSV text: assume an Excel workbook.
            self.df = pd.read_excel(self.labels_file)

        print(f"Dataset cargado: {len(self.df)} imágenes")
        print(f"\nColumnas: {self.df.columns.tolist()}")

        columns = list(self.df.columns)

        # Preferred match: a column mentioning both "retinopathy" and "grade".
        grade_column = next(
            (col for col in columns
             if 'retinopathy' in str(col).lower() and 'grade' in str(col).lower()),
            None,
        )

        if grade_column is None:
            # Fallback: first column that mentions "grade" or "label".
            grade_column = next(
                (col for col in columns
                 if 'grade' in str(col).lower() or 'label' in str(col).lower()),
                None,
            )

        if grade_column is None:
            self.grade_column = None
            print("ERROR: No se encontró la columna de grado de retinopatía")
            print(f"Columnas disponibles: {self.df.columns.tolist()}")
            return None

        self.grade_column = grade_column
        print(f"\nUsando columna: '{grade_column}'")
        print(f"\nDistribución de clases:")
        print(self.df[grade_column].value_counts().sort_index())
        return self.df

    def _resolve_image_path(self, image_name):
        """Return the path of an existing image file for *image_name*.

        Raises:
            FileNotFoundError: If no matching file exists in ``data_dir``.
        """
        image_name = str(image_name)

        # If the name already carries an extension, try it verbatim first.
        if image_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            candidate = os.path.join(self.data_dir, image_name)
            if os.path.exists(candidate):
                return candidate

        # Otherwise probe each known extension.
        for ext in self._IMAGE_EXTENSIONS:
            candidate = os.path.join(self.data_dir, image_name + ext)
            if os.path.exists(candidate):
                return candidate

        raise FileNotFoundError(f"No se encontró la imagen: {image_name}")

    def load_and_preprocess_image(self, image_name):
        """Load one image and prepare it as MobileNetV2 input.

        Args:
            image_name: File name, with or without extension, relative to
                ``data_dir``.

        Returns:
            The image converted to RGB, resized to ``img_size`` x
            ``img_size`` and passed through the MobileNetV2
            ``preprocess_input`` function.

        Raises:
            FileNotFoundError: If the image file cannot be located.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        image_path = self._resolve_image_path(image_name)

        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"No se pudo leer la imagen: {image_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.img_size, self.img_size))
        return tf.keras.applications.mobilenet_v2.preprocess_input(img)

    def prepare_data(self, test_size=0.2, val_size=0.1):
        """Load every image and split the data into train/validation/test.

        Args:
            test_size: Fraction of the whole dataset held out for testing.
            val_size: Fraction of the whole dataset used for validation.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``.

        Raises:
            ValueError: If the grade column is unknown or no image could
                be loaded.
        """
        if self.df is None:
            self.load_labels()

        if self.grade_column is None:
            raise ValueError(
                "No se encontró la columna de grado de retinopatía en "
                f"{self.labels_file}"
            )

        # Preferred match: a column mentioning both "image" and "name".
        image_column = next(
            (col for col in self.df.columns
             if 'image' in str(col).lower() and 'name' in str(col).lower()),
            None,
        )

        if image_column is None:
            # Fallback: first string-typed column, which presumably holds
            # file names.
            image_column = next(
                (col for col in self.df.columns if self.df[col].dtype == 'object'),
                None,
            )

        if image_column is None:
            image_column = self.df.columns[0]

        print(f"\nUsando columna de imágenes: '{image_column}'")

        images = []
        labels = []
        total = len(self.df)

        print("\nCargando imágenes...")
        # Count rows explicitly: the DataFrame index is a label and is not
        # guaranteed to be a 0-based integer position.
        for count, (_, row) in enumerate(self.df.iterrows(), start=1):
            try:
                img = self.load_and_preprocess_image(row[image_column])
            except Exception as e:
                # Best effort: report the bad row and keep loading the rest.
                print(f"Error cargando {row[image_column]}: {str(e)}")
                continue

            images.append(img)
            labels.append(row[self.grade_column])

            if count % 10 == 0:
                print(f"Procesadas {count}/{total} imágenes")

        if not images:
            raise ValueError(
                f"No se pudo cargar ninguna imagen desde: {self.data_dir}"
            )

        images = np.array(images)
        labels = np.array(labels)

        print(f"\nImágenes cargadas: {len(images)}")
        print(f"Shape de las imágenes: {images.shape}")

        # First carve out the test set, stratified by grade.
        X_train_val, X_test, y_train_val, y_test = train_test_split(
            images, labels, test_size=test_size, random_state=42, stratify=labels
        )

        # val_size is relative to the full dataset; rescale it to the
        # remaining train+validation portion.
        val_size_adjusted = val_size / (1 - test_size)
        X_train, X_val, y_train, y_val = train_test_split(
            X_train_val, y_train_val, test_size=val_size_adjusted,
            random_state=42, stratify=y_train_val
        )

        print(f"\nDatos divididos:")
        print(f"  Train: {len(X_train)} imágenes")
        print(f"  Validation: {len(X_val)} imágenes")
        print(f"  Test: {len(X_test)} imágenes")

        return X_train, X_val, X_test, y_train, y_val, y_test


class MobileNetClassifier:
    """MobileNetV2-based classifier for diabetic retinopathy grading.

    Attributes:
        num_classes: Number of output classes of the softmax head.
        img_size: Side length, in pixels, of the square RGB input.
        model: The full Keras model; ``None`` until :meth:`build_model`.
        base_model: The MobileNetV2 backbone; ``None`` until
            :meth:`build_model`. Kept so :meth:`unfreeze_layers` can reach
            the backbone layers directly.
    """

    def __init__(self, num_classes=5, img_size=224):
        self.num_classes = num_classes
        self.img_size = img_size
        self.model = None
        self.base_model = None

    def build_model(self):
        """Build the model: frozen MobileNetV2 backbone plus a dense head.

        Returns:
            The assembled (not yet compiled) Keras model.
        """
        # MobileNetV2 pre-trained on ImageNet, without its classifier.
        base_model = MobileNetV2(
            input_shape=(self.img_size, self.img_size, 3),
            include_top=False,
            weights='imagenet'
        )

        # Freeze the whole backbone for the feature-extraction phase.
        base_model.trainable = False
        # Because the head is attached to base_model.output, the backbone
        # layers become direct layers of self.model rather than one nested
        # sub-model; keep a handle so they can be found again later.
        self.base_model = base_model

        # Custom classification head.
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(512, activation='relu')(x)
        x = Dropout(0.5)(x)
        x = Dense(256, activation='relu')(x)
        x = Dropout(0.3)(x)
        predictions = Dense(self.num_classes, activation='softmax')(x)

        self.model = Model(inputs=base_model.input, outputs=predictions)

        print("\n=== Arquitectura del Modelo ===")
        print(f"Capas del modelo base (MobileNetV2): {len(base_model.layers)}")
        print(f"Total de capas: {len(self.model.layers)}")

        return self.model

    def compile_model(self, learning_rate=0.0001):
        """Compile ``self.model`` with Adam and sparse categorical cross-entropy.

        Args:
            learning_rate: Learning rate passed to the Adam optimizer.
        """
        self.model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy', tf.keras.metrics.SparseCategoricalAccuracy()]
        )
        print("\nModelo compilado con éxito")

    def unfreeze_layers(self, num_layers=20):
        """Make the last *num_layers* backbone layers trainable.

        All earlier backbone layers are frozen. The model must be compiled
        again afterwards for the change to take effect in training.

        Args:
            num_layers: How many trailing backbone layers to unfreeze.
                Zero or a negative value keeps the whole backbone frozen;
                a value larger than the backbone unfreezes all of it.
        """
        base_model = getattr(self, 'base_model', None)

        if base_model is None:
            # No stored backbone (e.g. the model was assigned from
            # outside): look for a nested MobileNet sub-model instead.
            for layer in self.model.layers:
                if isinstance(layer, tf.keras.Model) and 'mobilenet' in layer.name.lower():
                    base_model = layer
                    break

        if base_model is None:
            print("Advertencia: No se encontró MobileNetV2, descongelando todo el modelo")
            self.model.trainable = True
            return

        base_model.trainable = True

        # Set every layer explicitly. A slice such as layers[:-num_layers]
        # is empty for num_layers == 0 and would leave everything unfrozen.
        layers = base_model.layers
        first_trainable = max(len(layers) - max(num_layers, 0), 0)
        for position, layer in enumerate(layers):
            layer.trainable = position >= first_trainable

        print(f"\nDescongeladas las últimas {num_layers} capas para fine-tuning")


def create_data_augmentation():
    """Build the image generator configured for training augmentation.

    Returns:
        An ``ImageDataGenerator`` set up with rotation, shifts, flips on
        both axes and zoom, filling new pixels with the nearest value.
    """
    augmentation_options = dict(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        zoom_range=0.2,
        fill_mode='nearest',
    )
    return ImageDataGenerator(**augmentation_options)


def plot_training_history(history, save_path='training_history.png'):
    """Plot training/validation accuracy and loss curves and save them.

    Args:
        history: Object whose ``history`` dict holds the per-epoch lists
            ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss``.
        save_path: File the figure is written to.
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 5))
    curves = history.history

    # Left panel: accuracy; right panel: loss.
    for axis, (metric, caption) in zip(axes, (('accuracy', 'Accuracy'), ('loss', 'Loss'))):
        axis.plot(curves[metric], label=f'Train {caption}')
        axis.plot(curves[f'val_{metric}'], label=f'Val {caption}')
        axis.set_title(f'Model {caption}')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(caption)
        axis.legend()
        axis.grid(True)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nGráfica guardada en: {save_path}")
    plt.close()


def main():
    """Run the full training pipeline.

    Steps: load and split the data, compute class weights, train the
    classification head on a frozen backbone, fine-tune the last backbone
    layers at a lower learning rate, evaluate on the test split, and save
    the training plot and the final model.
    """
    from sklearn.metrics import classification_report, confusion_matrix

    num_classes = 5
    separator = "=" * 60

    print(separator)
    print("ENTRENAMIENTO: Clasificación de Retinopatía Diabética")
    print(separator)

    # Create required directories.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    # 1. Load data
    print("\n[1/5] Cargando datos...")
    data_loader = DataLoader(LABELS_FILE, DATA_DIR, IMG_SIZE)
    X_train, X_val, X_test, y_train, y_val, y_test = data_loader.prepare_data()

    # Class weights to counter class imbalance. Key the dict by the actual
    # class labels: enumerate() would only be right when the labels present
    # in y_train are exactly 0..k-1.
    classes = np.unique(y_train)
    class_weights = compute_class_weight(
        'balanced',
        classes=classes,
        y=y_train
    )
    class_weight_dict = {
        int(label): float(weight) for label, weight in zip(classes, class_weights)
    }
    print(f"\nPesos de clase (para desbalance): {class_weight_dict}")

    # 2. Build model
    print("\n[2/5] Creando modelo MobileNet...")
    classifier = MobileNetClassifier(num_classes=num_classes, img_size=IMG_SIZE)
    model = classifier.build_model()
    classifier.compile_model(learning_rate=LEARNING_RATE)

    # 3. Callbacks (shared by both training phases)
    print("\n[3/5] Configurando callbacks...")
    callbacks = [
        ModelCheckpoint(
            os.path.join(MODEL_SAVE_PATH, 'best_model.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-7,
            verbose=1
        )
    ]

    # 4. Phase 1: feature extraction (backbone frozen)
    print("\n[4/5] Entrenando modelo - Fase 1: Feature Extraction...")
    print(f"Epochs: {EPOCHS}, Batch Size: {BATCH_SIZE}")

    history1 = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS // 2,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        class_weight=class_weight_dict,
        verbose=1
    )

    # 5. Phase 2: fine-tuning. Recompiling is required after changing
    # which layers are trainable.
    print("\n[5/5] Entrenando modelo - Fase 2: Fine-tuning...")
    classifier.unfreeze_layers(num_layers=30)
    classifier.compile_model(learning_rate=LEARNING_RATE / 10)

    history2 = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS // 2,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        class_weight=class_weight_dict,
        verbose=1
    )

    # Append phase 2 curves to phase 1; tolerate a key missing in phase 2.
    for key, values in history1.history.items():
        values.extend(history2.history.get(key, []))

    # 6. Evaluate on the test set
    print("\n" + separator)
    print("EVALUACIÓN EN TEST SET")
    print(separator)
    # evaluate() returns the loss followed by one value per compiled
    # metric; only the first two are reported.
    test_loss, test_accuracy, *_ = model.evaluate(X_test, y_test, verbose=0)
    print(f"\nTest Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Predictions on the test set
    y_pred = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred, axis=1)

    # Pass the label list explicitly so the report and matrix always cover
    # every grade, even one absent from this test split (otherwise
    # target_names would not match the inferred labels).
    grade_labels = list(range(num_classes))
    print("\n=== Classification Report ===")
    print(classification_report(y_test, y_pred_classes,
                                labels=grade_labels,
                                target_names=[f'Grade {i}' for i in grade_labels],
                                zero_division=0))

    print("\n=== Confusion Matrix ===")
    cm = confusion_matrix(y_test, y_pred_classes, labels=grade_labels)
    print(cm)

    # Save results
    plot_training_history(history1, 'training_history.png')
    model.save(os.path.join(MODEL_SAVE_PATH, 'final_model.h5'))

    print("\n" + separator)
    print("ENTRENAMIENTO COMPLETADO")
    print(separator)
    print(f"Modelo guardado en: {MODEL_SAVE_PATH}/final_model.h5")
    print(f"Mejor modelo en: {MODEL_SAVE_PATH}/best_model.h5")

# Start training only when this code runs as the main module, not on import.
if __name__ == "__main__":
    main()

ENTRENAMIENTO: Clasificación de Retinopatía Diabética

[1/5] Cargando datos...
Dataset cargado: 413 imágenes

Columnas: ['Image name', 'Retinopathy grade', 'Risk of macular edema ', 'Unnamed: 3', 'Unnamed: 4', 'Unnamed: 5', 'Unnamed: 6', 'Unnamed: 7', 'Unnamed: 8', 'Unnamed: 9', 'Unnamed: 10', 'Unnamed: 11']

Usando columna: 'Retinopathy grade'

Distribución de clases:
Retinopathy grade
0    134
1     20
2    136
3     74
4     49
Name: count, dtype: int64

Usando columna de imágenes: 'Image name'

Cargando imágenes...
Procesadas 10/413 imágenes
Procesadas 20/413 imágenes
Procesadas 30/413 imágenes
Procesadas 40/413 imágenes
Procesadas 50/413 imágenes
Procesadas 60/413 imágenes
Procesadas 70/413 imágenes
Procesadas 80/413 imágenes
Procesadas 90/413 imágenes
Procesadas 100/413 imágenes
Procesadas 110/413 imágenes
Procesadas 120/413 imágenes
Procesadas 130/413 imágenes
Procesadas 140/413 imágenes
Procesadas 150/413 imágenes
Procesadas 160/413 imágenes
Procesadas 170/413 imágenes
Procesad



[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 808ms/step - accuracy: 0.1984 - loss: 2.1542 - sparse_categorical_accuracy: 0.1984 - val_accuracy: 0.1667 - val_loss: 1.5407 - val_sparse_categorical_accuracy: 0.1667 - learning_rate: 1.0000e-04
Epoch 2/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 493ms/step - accuracy: 0.2247 - loss: 1.6217 - sparse_categorical_accuracy: 0.2247
Epoch 2: val_accuracy improved from 0.16667 to 0.30952, saving model to models/best_model.h5




[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 654ms/step - accuracy: 0.2268 - loss: 1.6250 - sparse_categorical_accuracy: 0.2268 - val_accuracy: 0.3095 - val_loss: 1.4507 - val_sparse_categorical_accuracy: 0.3095 - learning_rate: 1.0000e-04
Epoch 3/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 504ms/step - accuracy: 0.2440 - loss: 1.6492 - sparse_categorical_accuracy: 0.2440
Epoch 3: val_accuracy improved from 0.30952 to 0.42857, saving model to models/best_model.h5




[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 688ms/step - accuracy: 0.2443 - loss: 1.6488 - sparse_categorical_accuracy: 0.2443 - val_accuracy: 0.4286 - val_loss: 1.3621 - val_sparse_categorical_accuracy: 0.4286 - learning_rate: 1.0000e-04
Epoch 4/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 501ms/step - accuracy: 0.3786 - loss: 1.3854 - sparse_categorical_accuracy: 0.3786
Epoch 4: val_accuracy did not improve from 0.42857
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 601ms/step - accuracy: 0.3790 - loss: 1.3842 - sparse_categorical_accuracy: 0.3790 - val_accuracy: 0.3571 - val_loss: 1.3310 - val_sparse_categorical_accuracy: 0.3571 - learning_rate: 1.0000e-04
Epoch 5/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 511ms/step - accuracy: 0.3475 - loss: 1.5664 - sparse_categorical_accuracy: 0.3475




[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 671ms/step - accuracy: 0.4220 - loss: 1.3857 - sparse_categorical_accuracy: 0.4220 - val_accuracy: 0.4762 - val_loss: 1.2184 - val_sparse_categorical_accuracy: 0.4762 - learning_rate: 1.0000e-04
Epoch 7/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 567ms/step - accuracy: 0.4313 - loss: 1.2090 - sparse_categorical_accuracy: 0.4313
Epoch 7: val_accuracy did not improve from 0.47619
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 653ms/step - accuracy: 0.4319 - loss: 1.2102 - sparse_categorical_accuracy: 0.4319 - val_accuracy: 0.4286 - val_loss: 1.1999 - val_sparse_categorical_accuracy: 0.4286 - learning_rate: 1.0000e-04
Epoch 8/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 584ms/step - accuracy: 0.4951 - loss: 1.1731 - sparse_categorical_accuracy: 0.4951




[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 669ms/step - accuracy: 0.6579 - loss: 0.7967 - sparse_categorical_accuracy: 0.6579 - val_accuracy: 0.5238 - val_loss: 1.1072 - val_sparse_categorical_accuracy: 0.5238 - learning_rate: 5.0000e-05
Epoch 24/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 579ms/step - accuracy: 0.5865 - loss: 0.7692 - sparse_categorical_accuracy: 0.5865
Epoch 24: val_accuracy did not improve from 0.52381
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 658ms/step - accuracy: 0.5867 - loss: 0.7719 - sparse_categorical_accuracy: 0.5867 - val_accuracy: 0.4762 - val_loss: 1.1271 - val_sparse_categorical_accuracy: 0.4762 - learning_rate: 5.0000e-05
Epoch 25/25
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 567ms/step - accuracy: 0.5529 - loss: 0.8243 - sparse_categorical_accuracy: 0.55




Gráfica guardada en: training_history.png

ENTRENAMIENTO COMPLETADO
Modelo guardado en: models/final_model.h5
Mejor modelo en: models/best_model.h5




### Execution Code

Paste your Python execution code here.

In [2]:
import tensorflow as tf
import cv2
import numpy as np
import os
import sys

def load_model(model_path='models/best_model.h5'):
    """Load a trained Keras model from disk.

    Args:
        model_path: Path to the saved model file.

    Returns:
        The loaded model, or ``None`` (after printing an error) when the
        file does not exist.
    """
    if os.path.exists(model_path):
        loaded = tf.keras.models.load_model(model_path)
        print(f"✓ Modelo cargado desde: {model_path}")
        return loaded

    print(f"Error: No se encontró el modelo en {model_path}")
    return None

def preprocess_image(image_path, img_size=224):
    """Read an image file and turn it into a one-image batch for the model.

    Args:
        image_path: Path of the image to read.
        img_size: Side length, in pixels, the image is resized to.

    Returns:
        The image converted to RGB, resized, passed through the MobileNetV2
        ``preprocess_input`` function, with a leading batch axis added.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"No se pudo leer la imagen: {image_path}")

    # OpenCV decodes as BGR; convert to RGB before resizing.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (img_size, img_size))
    normalized = tf.keras.applications.mobilenet_v2.preprocess_input(resized)

    # Add the batch dimension in front.
    return np.expand_dims(normalized, axis=0)

def predict_image(model, image_path):
    """Predict the retinopathy grade for a single image.

    Args:
        model: Object with a Keras-style ``predict`` method.
        image_path: Path of the image to classify.

    Returns:
        ``(grade, confidence, probabilities)``: the index of the highest
        score, that score as a percentage, and all scores of the first
        (only) batch row as percentages.
    """
    batch = preprocess_image(image_path)
    scores = model.predict(batch, verbose=0)

    predicted_grade = np.argmax(scores)
    top_percent = np.max(scores) * 100
    percent_per_class = scores[0] * 100

    return predicted_grade, top_percent, percent_per_class

def get_grade_description(grade):
    """Return the description for a retinopathy grade (0-4).

    Any grade outside the known ones yields ``"Desconocido"``.
    """
    ordered_descriptions = (
        "Sin retinopatía aparente",
        "Retinopatía no proliferativa leve",
        "Retinopatía no proliferativa moderada",
        "Retinopatía no proliferativa severa",
        "Retinopatía proliferativa",
    )
    # Position in the tuple is the grade number.
    by_grade = dict(enumerate(ordered_descriptions))
    return by_grade.get(grade, "Desconocido")

def predict_single_image(model_path, image_path):
    """Classify one image and print the result.

    Prints the predicted grade, its description, the confidence and the
    per-class probabilities. Returns early, after printing an error, when
    the model or the image is missing; errors raised during prediction are
    caught and printed.

    Args:
        model_path: Path to the saved model file.
        image_path: Path of the image to classify.
    """
    rule = "=" * 60

    print("\n" + rule)
    print("PREDICCIÓN DE RETINOPATÍA DIABÉTICA")
    print(rule)

    model = load_model(model_path)
    if model is None:
        return

    image_found = os.path.exists(image_path)
    if not image_found:
        print(f"Error: No se encontró la imagen: {image_path}")
        return

    print(f"\nImagen: {image_path}")

    try:
        predicted, certainty, per_class = predict_image(model, image_path)

        print("\n--- RESULTADO ---")
        print(f"Grado predicho: {predicted}")
        print(f"Diagnóstico: {get_grade_description(predicted)}")
        print(f"Confianza: {certainty:.2f}%")

        print("\n--- PROBABILIDADES POR CLASE ---")
        for class_index, percent in enumerate(per_class):
            label = get_grade_description(class_index)
            print(f"  Grado {class_index}: {percent:.2f}% - {label}")

        print("\n" + rule)

    except Exception as exc:
        print(f"Error durante la predicción: {exc}")

def predict_batch(model_path, images_dir):
    """Classify every image in a directory and print a summary.

    For each image the predicted grade and confidence are printed; images
    that fail are reported and skipped. A summary with the grade
    distribution and the mean confidence follows.

    Args:
        model_path: Path to the saved model file.
        images_dir: Directory containing the images to classify.

    Returns:
        None. Returns early, after printing a message, when the directory
        or the model is missing or the directory holds no images.
    """
    from collections import Counter

    separator = "=" * 60

    print("\n" + separator)
    print("PREDICCIÓN BATCH - RETINOPATÍA DIABÉTICA")
    print(separator)

    # Validate the directory before the (much slower) model load, and
    # report a missing one the same way a missing image is reported
    # instead of letting os.listdir raise.
    if not os.path.isdir(images_dir):
        print(f"Error: No se encontró el directorio: {images_dir}")
        return

    model = load_model(model_path)
    if model is None:
        return

    # Case-insensitive match so mixed-case suffixes (e.g. ".Jpg") are
    # included; sorted so the processing order is deterministic.
    extensions = ('.jpg', '.jpeg', '.png')
    image_files = sorted(
        name for name in os.listdir(images_dir)
        if name.lower().endswith(extensions)
    )

    if not image_files:
        print(f"No se encontraron imágenes en: {images_dir}")
        return

    print(f"\nEncontradas {len(image_files)} imágenes")
    print("\nProcesando...\n")

    results = []

    for image_file in image_files:
        image_path = os.path.join(images_dir, image_file)

        try:
            grade, confidence, probabilities = predict_image(model, image_path)
        except Exception as e:
            # Best effort: report the failure and continue with the rest.
            print(f"✗ {image_file:30s} -> Error: {str(e)}")
            continue

        results.append({
            'image': image_file,
            'grade': grade,
            'confidence': confidence,
            'diagnosis': get_grade_description(grade)
        })

        print(f"✓ {image_file:30s} -> Grado {grade} ({confidence:.1f}%)")

    # Summary
    print("\n" + separator)
    print("RESUMEN")
    print(separator)

    if results:
        print(f"\nImágenes procesadas: {len(results)}")
        print("\nDistribución de grados:")

        grade_counts = Counter(r['grade'] for r in results)
        for grade in sorted(grade_counts):
            count = grade_counts[grade]
            percentage = (count / len(results)) * 100
            print(f"  Grado {grade}: {count} imágenes ({percentage:.1f}%)")

        # Mean confidence over the successfully processed images.
        avg_confidence = np.mean([r['confidence'] for r in results])
        print(f"\nConfianza promedio: {avg_confidence:.2f}%")

    print("\n" + separator)

# def main():
#     """Función principal"""
#     import argparse

#     parser = argparse.ArgumentParser(description='Predicción de Retinopatía Diabética')
#     parser.add_argument('--model', type=str, default='models/best_model.h5',
#                         help='Ruta al modelo entrenado')
#     parser.add_argument('--image', type=str, help='Ruta a una imagen individual')
#     parser.add_argument('--batch', type=str, help='Ruta a directorio con imágenes')

#     args = parser.parse_args()

#     if args.image:
#         predict_single_image(args.model, args.image)
#     elif args.batch:
#         predict_batch(args.model, args.batch)
#     else:
#         print("Uso:")
#         print("  Predicción individual: python predict.py --image ruta/imagen.jpg")
#         print("  Predicción batch:      python predict.py --batch ruta/directorio")
#         print("  Especificar modelo:    python predict.py --model models/final_model.h5 --image imagen.jpg")

# if __name__ == "__main__":
#     main()

# Para usar en Colab, llama directamente a las funciones. Por ejemplo:
# predict_single_image(model_path='models/best_model.h5', image_path='/content/data/test_images/image1.jpg')
# predict_batch(model_path='models/best_model.h5', images_dir='/content/data/test_images/')
