In [None]:
# 📚 Entrenamiento de Red Neuronal CNN para clasificación multiclase

# 🔧 Cargar librerías
import os
import sys
import yaml
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import load_model

In [None]:
# Make the project-local modules under ../src importable.
# The membership check keeps this cell idempotent: re-running it no longer
# appends a duplicate "../src" entry to sys.path each time.
if "../src" not in sys.path:
    sys.path.append("../src")
import model
import utils
import train

In [None]:
# 🔧 Load configuration
# The encoding is given explicitly so decoding does not depend on the platform
# locale (e.g. cp1252 on Windows), which would corrupt non-ASCII text in the file.
# Assumes config.yaml is stored as UTF-8 — TODO confirm.
with open("../config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Image dimensions used as the target size when loading images.
img_height = config["image"]["height"]
img_width = config["image"]["width"]

# Training hyper-parameters.
batch_size = config["training"]["batch_size"]
epochs = config["training"]["epochs"]

# Input data directory and the location where the model is written.
data_path = config["paths"]["prepared_data"]
model_path = config["paths"]["model"]

In [None]:
# 📦 Load data
# Both splits are read with the same target size, batch size and label mode,
# so the shared keyword arguments are collected once.
flow_options = dict(
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode="categorical",
)

# Training split: images under <data_path>/train, pixel values scaled by 1/255.
train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    os.path.join(data_path, "train"), **flow_options
)

# Validation split: images under <data_path>/val, same rescaling.
val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory(
    os.path.join(data_path, "val"), **flow_options
)

In [None]:
# ✅ Verify class consistency
# Both generators must expose the same class-name -> index mapping; the message
# shows the two mappings so a mismatch can be diagnosed directly from the error.
assert train_generator.class_indices == val_generator.class_indices, (
    "Train/val class mappings differ: "
    f"train={train_generator.class_indices}, val={val_generator.class_indices}"
)
num_classes = len(train_generator.class_indices)

In [None]:
# 🧠 Build the model
# build_model comes from the project-local `model` module (not shown here); it is
# given the loaded config dict and the number of classes found in the training data.
cnn_model = model.build_model(config, num_classes)

# Print the model summary for inspection.
cnn_model.summary()

In [None]:
# 🏋️ Train the model
# Early stopping with a patience of 5 and restore_best_weights enabled.
early_stopping = EarlyStopping(patience=5, restore_best_weights=True)
# Checkpoint written to model_path, keeping only the best model seen so far.
checkpoint = ModelCheckpoint(model_path, save_best_only=True)
callbacks = [early_stopping, checkpoint]

# Fit on the training generator, validating against the validation generator.
history = cnn_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=val_generator,
    callbacks=callbacks,
)

In [None]:
# 💾 Save the model
# The ModelCheckpoint(save_best_only=True) callback used during training already
# writes the best model to model_path. Saving unconditionally here would overwrite
# that file with the in-memory weights, which on some Keras versions are the
# last-epoch weights when early stopping did not trigger (restore_best_weights
# is then not applied). Only save when no checkpoint was written.
if not os.path.exists(model_path):
    cnn_model.save(model_path)

In [None]:
# 📈 Plot training history
# plot_training_history comes from the project-local `utils` module (not shown here);
# it receives the History object returned by fit() and a model name of "CNN".
utils.plot_training_history(history, model_name="CNN")

In [None]:
# ✅ Save class names for later use
# Serialize the class-name -> index mapping to YAML text first, then write it
# to class_indices.yaml in the current working directory.
class_indices_yaml = yaml.dump(train_generator.class_indices)
with open("class_indices.yaml", "w") as f:
    f.write(class_indices_yaml)