In [21]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Dropout, GlobalAveragePooling2D,
                                     BatchNormalization, LeakyReLU)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import kagglehub
from google.colab import files



In [8]:
# Fetch the dataset through kagglehub; the call returns the local directory that holds it.
dataset_path = kagglehub.dataset_download("benpalma25/data-ocular-neurosolution-ues")
print("Path to dataset files:", dataset_path)

# Locations of the label CSV and of the image folder inside the downloaded archive.
base_path = os.path.join(dataset_path, 'archive')
csv_path = os.path.join(base_path, 'final_full_df.csv')
image_dir = os.path.join(base_path, 'ODIR-5K', 'ODIR-5K', 'Training Images')

# Resolve each 'ruta' entry to a path under image_dir and keep only the rows whose image
# file is really on disk.
#  * isfile (instead of exists) also rejects entries that resolve to a directory,
#    e.g. an empty 'ruta', which would otherwise pass the filter and fail at read time.
#  * .copy() detaches the filtered frame from the unfiltered one so that later column
#    assignments on df do not raise pandas' SettingWithCopyWarning.
df = pd.read_csv(csv_path)
df['full_path'] = df['ruta'].apply(lambda x: os.path.join(image_dir, x))
df = df[df['full_path'].apply(os.path.isfile)].copy()

print("Distribución inicial de clases:")
print(df['diagnostico'].value_counts())


Path to dataset files: /root/.cache/kagglehub/datasets/benpalma25/data-ocular-neurosolution-ues/versions/1
Distribución inicial de clases:
diagnostico
normal fundus                                                              2816
moderate non proliferative retinopathy                                      745
dry age-related macular degeneration                                        475
mild nonproliferative retinopathy                                           460
severe nonproliferative retinopathy                                         342
                                                                           ... 
laser spot，moderate non proliferative retinopathy，white vessel                1
laser spot，white vessel，moderate non proliferative retinopathy                1
myelinated nerve fibers，suspected glaucoma                                    1
macular epiretinal membrane，post laser photocoagulation                       1
cataract，myelinated nerve fibers，moderate non pro

In [9]:
def group_small_classes(series, min_count=10):
    """Collapse infrequent labels of ``series`` into a single ``'Other'`` label.

    Args:
        series: pandas Series of class labels.
        min_count: labels that occur fewer than this many times are replaced.

    Returns:
        A new Series (the input is not modified) in which every value whose
        frequency is below ``min_count`` has been replaced by ``'Other'``.
    """
    frequencies = series.value_counts()
    rare_labels = frequencies.index[frequencies < min_count]
    is_rare = series.isin(rare_labels)
    # where() keeps the entries for which the condition holds and substitutes the rest.
    return series.where(~is_rare, 'Other')

# Merge rare diagnoses into 'Other' and show the resulting class distribution.
df['grouped_diagnostico'] = group_small_classes(df['diagnostico'])
print("\nDistribución de clases tras agrupar:")
print(df['grouped_diagnostico'].value_counts())

# Encode the grouped diagnosis strings as integer class ids.
label_encoder = LabelEncoder()
df['encoded_diagnostico'] = label_encoder.fit_transform(df['grouped_diagnostico'])

# 80/20 train/validation split, stratified on the encoded label so both parts keep
# the same class proportions; random_state pins the split.
X_train, X_val, y_train, y_val = train_test_split(
    df['full_path'].values,
    df['encoded_diagnostico'].values,
    stratify=df['encoded_diagnostico'].values,
    test_size=0.2,
    random_state=42,
)

# 'balanced' weights computed from the training labels. Each weight is keyed by the
# class id it was computed for rather than by its position in the result array, so the
# mapping stays correct even if some class id were absent from y_train.
train_classes = np.unique(y_train)
class_weights = compute_class_weight('balanced', classes=train_classes, y=y_train)
class_weights = {
    int(class_id): float(weight)
    for class_id, weight in zip(train_classes, class_weights)
}


Distribución de clases tras agrupar:
grouped_diagnostico
normal fundus                                                  2816
moderate non proliferative retinopathy                          745
Other                                                           712
dry age-related macular degeneration                            475
mild nonproliferative retinopathy                               460
                                                               ... 
depigmentation of the retinal pigment epithelium                 10
pigment epithelium proliferation                                 10
proliferative diabetic retinopathy，hypertensive retinopathy      10
moderate non proliferative retinopathy，pathological myopia       10
dry age-related macular degeneration，diabetic retinopathy        10
Name: count, Length: 68, dtype: int64


In [10]:
# Random augmentation settings (rotation, shifts, zoom, horizontal flip) plus a 1/255 rescale.
# NOTE(review): data_gen is not referenced by any other cell visible here - the tf.data
# pipeline loads images through load_and_preprocess_image instead. Confirm whether this
# generator is still needed.
data_gen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    rescale=1./255,
)

# Size every image is resized to before being fed to the network.
IMG_SIZE = (160, 160)


def load_and_preprocess_image(path, label):
    """Load one image from disk and prepare it for the network.

    Args:
        path: location of the image file; it is decoded as a JPEG.
        label: class label, passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` is the 3-channel picture resized to
        ``IMG_SIZE`` and passed through the EfficientNet ``preprocess_input``.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, IMG_SIZE)
    return preprocess_input(resized), label

In [11]:
BATCH_SIZE = 32

# Training pipeline: (path, label) pairs -> shuffle -> decode/resize -> batch -> prefetch.
# The shuffle is applied to the path/label pairs *before* map(), so the buffer holds only
# short strings and can cover the whole training set. reshuffle_each_iteration gives every
# epoch a different sample order; without a shuffle each epoch would replay the exact same
# batches in the exact same order.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=len(X_train), seed=42,
                                      reshuffle_each_iteration=True)
train_dataset = train_dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)

# Validation pipeline: same steps but deliberately left unshuffled so that evaluation
# always runs over the samples in a fixed order.
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)

In [12]:
# Network input shape, derived from IMG_SIZE so that the model always matches the size
# the data pipeline resizes images to (previously 160x160 was hard-coded here as well).
INPUT_SHAPE = tuple(IMG_SIZE) + (3,)

# ImageNet-pretrained EfficientNetB3 without its classification head, frozen for the
# first training stage.
base_model = EfficientNetB3(weights='imagenet', include_top=False, input_shape=INPUT_SHAPE)
base_model.trainable = False

inputs = Input(shape=INPUT_SHAPE)
x = preprocess_input(inputs)
# training=False makes the backbone run in inference mode regardless of how the
# outer model is being called.
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)

# Classification head: two Dense + LeakyReLU stages with batch normalisation and
# dropout, followed by a softmax over the encoded diagnosis classes.
x = BatchNormalization()(x)
x = Dense(512)(x)
x = LeakyReLU(alpha=0.1)(x)
x = Dropout(0.5)(x)
x = Dense(256)(x)
x = LeakyReLU(alpha=0.1)(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
outputs = Dense(len(label_encoder.classes_), activation='softmax')(x)

# Labels are integer class ids, hence the sparse categorical cross-entropy loss.
model = Model(inputs, outputs)
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])




In [13]:
# Write the model to best_model.h5 whenever validation accuracy reaches a new maximum.
checkpoint = ModelCheckpoint(
    'best_model.h5',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)

# Stop after 5 epochs without a val_loss improvement and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Halve the learning rate after 3 epochs without a val_loss improvement, never below 1e-6.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)

In [14]:
# Stage 1: fit the model for up to 15 epochs, weighting the loss per class with
# class_weights and validating on val_dataset after every epoch.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=15,
    class_weight=class_weights,
    callbacks=[checkpoint, early_stopping, reduce_lr],
    verbose=1,
)


Epoch 1/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 411ms/step - accuracy: 0.0108 - loss: 5.5000



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m178s[0m 571ms/step - accuracy: 0.0109 - loss: 5.4991 - val_accuracy: 0.0372 - val_loss: 4.1178 - learning_rate: 1.0000e-04
Epoch 2/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 334ms/step - accuracy: 0.0246 - loss: 4.6884



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 422ms/step - accuracy: 0.0247 - loss: 4.6878 - val_accuracy: 0.0621 - val_loss: 4.0851 - learning_rate: 1.0000e-04
Epoch 3/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step - accuracy: 0.0325 - loss: 4.3007



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 425ms/step - accuracy: 0.0325 - loss: 4.3001 - val_accuracy: 0.0732 - val_loss: 4.0819 - learning_rate: 1.0000e-04
Epoch 4/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 331ms/step - accuracy: 0.0425 - loss: 3.8143



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 422ms/step - accuracy: 0.0425 - loss: 3.8141 - val_accuracy: 0.0849 - val_loss: 4.0592 - learning_rate: 1.0000e-04
Epoch 5/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step - accuracy: 0.0533 - loss: 3.6702



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.0534 - loss: 3.6698 - val_accuracy: 0.0945 - val_loss: 4.0340 - learning_rate: 1.0000e-04
Epoch 6/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step - accuracy: 0.0596 - loss: 3.4533



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.0596 - loss: 3.4528 - val_accuracy: 0.1019 - val_loss: 3.9921 - learning_rate: 1.0000e-04
Epoch 7/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 334ms/step - accuracy: 0.0682 - loss: 3.1300



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.0683 - loss: 3.1298 - val_accuracy: 0.1099 - val_loss: 3.9918 - learning_rate: 1.0000e-04
Epoch 8/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 334ms/step - accuracy: 0.0684 - loss: 3.0584



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.0685 - loss: 3.0581 - val_accuracy: 0.1221 - val_loss: 3.9296 - learning_rate: 1.0000e-04
Epoch 9/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 334ms/step - accuracy: 0.0774 - loss: 2.8543



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 424ms/step - accuracy: 0.0774 - loss: 2.8541 - val_accuracy: 0.1364 - val_loss: 3.8722 - learning_rate: 1.0000e-04
Epoch 10/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 333ms/step - accuracy: 0.0912 - loss: 2.7411



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.0913 - loss: 2.7407 - val_accuracy: 0.1433 - val_loss: 3.8382 - learning_rate: 1.0000e-04
Epoch 11/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 333ms/step - accuracy: 0.0928 - loss: 2.6428



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 424ms/step - accuracy: 0.0928 - loss: 2.6423 - val_accuracy: 0.1502 - val_loss: 3.7799 - learning_rate: 1.0000e-04
Epoch 12/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m141s[0m 420ms/step - accuracy: 0.1040 - loss: 2.4500 - val_accuracy: 0.1476 - val_loss: 3.7430 - learning_rate: 1.0000e-04
Epoch 13/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step - accuracy: 0.1076 - loss: 2.4504



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.1077 - loss: 2.4501 - val_accuracy: 0.1603 - val_loss: 3.7152 - learning_rate: 1.0000e-04
Epoch 14/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step - accuracy: 0.1153 - loss: 2.3264



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 425ms/step - accuracy: 0.1153 - loss: 2.3260 - val_accuracy: 0.1635 - val_loss: 3.6855 - learning_rate: 1.0000e-04
Epoch 15/15
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 344ms/step - accuracy: 0.1206 - loss: 2.2482



[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 435ms/step - accuracy: 0.1206 - loss: 2.2481 - val_accuracy: 0.1789 - val_loss: 3.6309 - learning_rate: 1.0000e-04


In [15]:
# NOTE(review): the model and all of its layers were built before this call. Per the
# TensorFlow mixed-precision guide a layer picks up the global policy when it is
# constructed, so this probably does not switch the existing model to float16 -
# confirm, and if mixed precision is wanted set the policy before building the model.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Stage 2: unfreeze the backbone. The model has to be compiled again for the change of
# `trainable` to take effect; a 10x smaller learning rate than in stage 1 is used.
base_model.trainable = True
model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Same callback instances and class weights as in the first fit() call.
history_finetune = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=20,
    class_weight=class_weights,
    callbacks=[checkpoint, early_stopping, reduce_lr],
    verbose=1,
)


Epoch 1/20
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m310s[0m 780ms/step - accuracy: 0.0164 - loss: 5.0620 - val_accuracy: 0.0446 - val_loss: 4.2014 - learning_rate: 1.0000e-05
Epoch 2/20
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 457ms/step - accuracy: 0.0193 - loss: 4.6777 - val_accuracy: 0.0504 - val_loss: 4.2018 - learning_rate: 1.0000e-05
Epoch 3/20
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m138s[0m 442ms/step - accuracy: 0.0256 - loss: 4.2035 - val_accuracy: 0.0605 - val_loss: 4.1498 - learning_rate: 1.0000e-05
Epoch 4/20
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 446ms/step - accuracy: 0.0348 - loss: 4.0828 - val_accuracy: 0.0711 - val_loss: 4.0987 - learning_rate: 1.0000e-05
Epoch 5/20
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m140s[0m 437ms/step - accuracy: 0.0409 - loss: 3.9033 - val_accuracy: 0.0791 - val_loss: 4.0654 - learning_rate: 1.0000e-05
Epoch 6/20
[1m236/236[0m [3

In [25]:
# The model is compiled with metrics=['accuracy'], so evaluate() yields (loss, accuracy).
loss, accuracy = model.evaluate(val_dataset)
print(f"Precisión final: {accuracy:.4f}")

# Save the model in its current state and hand the file to the browser (Colab helper).
export_path = 'optimized_model.h5'
model.save(export_path)
files.download(export_path)

[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 338ms/step - accuracy: 0.1618 - loss: 3.6835




Precisión final: 0.1614


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [17]:
# List every class id next to the (grouped) diagnosis string the encoder assigned to it.
class_names = label_encoder.classes_
print("\nEnfermedades entrenadas:")
for i, class_name in enumerate(class_names):
    print(f"{i}: {class_name}")



Enfermedades entrenadas:
0: Other
1: branch retinal vein occlusion
2: cataract
3: cataract，lens dust
4: cataract，moderate non proliferative retinopathy
5: central retinal artery occlusion
6: central retinal vein occlusion
7: chorioretinal atrophy
8: depigmentation of the retinal pigment epithelium
9: diabetic retinopathy
10: diabetic retinopathy，dry age-related macular degeneration
11: drusen
12: drusen，lens dust
13: dry age-related macular degeneration
14: dry age-related macular degeneration，diabetic retinopathy
15: dry age-related macular degeneration，glaucoma
16: epiretinal membrane
17: epiretinal membrane over the macula
18: epiretinal membrane，lens dust
19: glaucoma
20: glaucoma，diabetic retinopathy
21: glaucoma，hypertensive retinopathy
22: glaucoma，macular epiretinal membrane
23: glaucoma，moderate non proliferative retinopathy
24: hypertensive retinopathy
25: laser spot，moderate non proliferative retinopathy
26: lens dust，drusen
27: lens dust，macular epiretinal membrane
28: len