# Reconocimiento de caras de animales con PyTorch
En este notebook replico el flujo del cuaderno original pero entrenando un clasificador basado en PyTorch. Mantengo el preprocesado clásico (escala de grises a 64x64), genero un dataset listo para reutilizar y entreno una CNN sencilla para distinguir las especies.


## Pasos previstos
1. Reutilizar la rutina de preprocesado para dejar todas las imágenes en 64x64 escala de grises.
2. Guardar/recuperar un descriptor intermedio para no recalcular todo cada vez.
3. Construir `Dataset` y `DataLoader` de PyTorch para cada fold de una validación cruzada estratificada (`StratifiedKFold`, 5 folds).
4. Definir y entrenar una CNN ligera con monitorización de validación.
5. Evaluar con las predicciones out-of-fold de la validación cruzada (métricas y matriz de confusión) y guardar los pesos entrenados de cada fold y del mejor fold.


In [None]:
from pathlib import Path
import random
import json

import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image, ImageOps, UnidentifiedImageError

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import StratifiedKFold
from IPython.display import display

# Seed the Python, NumPy and PyTorch RNGs (same order as before) so weight
# init, DataLoader shuffling and random sample selection are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)

# Apply the seaborn-v0_8 matplotlib style, then seaborn's own theme on top.
plt.style.use('seaborn-v0_8')
sns.set_theme()


In [None]:
# Folder layout (relative to the notebook's working directory).
BASE_DIR = Path('.').resolve()
DATA_DIR = BASE_DIR / 'AnimalFace' / 'Image'
PROCESSED_DIR = BASE_DIR / 'AnimalFace' / 'processed_64_gray'
OUTPUT_DIR = BASE_DIR / 'AnimalFace' / 'outputs'
MODEL_DIR = OUTPUT_DIR / 'pytorch'

IMAGE_SIZE = (64, 64)
BATCH_SIZE = 64

# Validate the input folder BEFORE creating any output folder, so running from
# the wrong working directory does not leave an empty AnimalFace/ tree behind.
# is_dir() (not exists()) also rejects a regular file with that name.
if not DATA_DIR.is_dir():
    raise FileNotFoundError(f'No encontré la carpeta de imágenes en {DATA_DIR}')

for folder in (PROCESSED_DIR, OUTPUT_DIR, MODEL_DIR):
    folder.mkdir(parents=True, exist_ok=True)

# Each visible sub-folder is one class; hidden folders (e.g. .ipynb_checkpoints)
# are not classes.
CATEGORIES = sorted(
    p.name for p in DATA_DIR.iterdir()
    if p.is_dir() and not p.name.startswith('.')
)
if not CATEGORIES:
    # Fail here with a clear message instead of a KeyError deep in the pipeline.
    raise RuntimeError(f'No encontré subcarpetas de clases en {DATA_DIR}')
print(f'Trabajo con {len(CATEGORIES)} clases: {CATEGORIES}')


### Funciones de apoyo
Igual que en el notebook base, centralizo la preparación en helpers reutilizables.


In [None]:
from typing import List, Tuple


def collect_image_paths(class_name: str) -> List[Path]:
    """Return the sorted image files of one class folder under ``DATA_DIR``.

    Files are selected by extension (.jpg, .jpeg, .png, .bmp), compared
    case-insensitively so names like ``IMG_01.JPG`` are not silently dropped on
    case-sensitive filesystems. Only regular files are returned. A missing
    class folder yields an empty list.
    """
    supported_suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    folder = DATA_DIR / class_name
    if not folder.is_dir():
        return []
    return sorted(
        path
        for path in folder.iterdir()
        if path.is_file() and path.suffix.lower() in supported_suffixes
    )


def preprocess_image(path: Path, size: Tuple[int, int]) -> np.ndarray:
    """Load an image as a grayscale float32 array scaled to [0, 1].

    Steps: apply the EXIF orientation, convert to grayscale and resize to
    ``size`` (PIL order: width, height) with LANCZOS resampling.

    Raises:
        ValueError: if PIL cannot identify or read the file (the original
            ``UnidentifiedImageError``/``OSError`` is chained as the cause).
    """
    try:
        with Image.open(path) as raw:
            upright = ImageOps.exif_transpose(raw)
            gray = ImageOps.grayscale(upright)
            resized = gray.resize(size, Image.Resampling.LANCZOS)
            pixels = np.asarray(resized, dtype=np.float32)
    except (UnidentifiedImageError, OSError) as exc:
        raise ValueError(f'No pude cargar {path}') from exc
    return pixels / 255.0


def save_preprocessed_image(array: np.ndarray, original_path: Path) -> Path:
    """Write ``array`` as an 8-bit grayscale PNG and return its path.

    The file goes to ``PROCESSED_DIR/<class>/<stem>_64_gray.png``, where the
    class is the parent folder name of ``original_path``. ``array`` is expected
    in [0, 1] (as produced by ``preprocess_image``). Values are rounded to the
    nearest byte and clipped to [0, 255], so slightly out-of-range inputs
    cannot wrap around in the uint8 cast.

    NOTE(review): the output name only uses ``original_path.stem``, so e.g.
    ``x.jpg`` and ``x.png`` in the same class folder would write the same file.
    """
    label = original_path.parent.name
    save_dir = PROCESSED_DIR / label
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{original_path.stem}_64_gray.png"
    # rint + clip instead of a bare astype(): astype truncates and wraps values.
    pixels = np.clip(np.rint(array * 255), 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(save_path)
    return save_path


def build_dataset_descriptor() -> pd.DataFrame:
    """Preprocess every image of every class and return the descriptor table.

    For each readable image in ``CATEGORIES`` the preprocessed version is
    written through ``save_preprocessed_image``. One row is recorded with the
    columns ``label``, ``original_path`` and ``processed_path`` (POSIX
    strings). Images that ``preprocess_image`` rejects with ``ValueError`` are
    reported and skipped.

    The result always has those three columns, even when no image was found.
    Rows are sorted by label with a stable sort, so the per-class path order
    is kept.
    """
    columns = ['label', 'original_path', 'processed_path']
    rows = []
    skipped = 0
    for label in CATEGORIES:
        for path in collect_image_paths(label):
            try:
                processed = preprocess_image(path, IMAGE_SIZE)
            except ValueError:
                print(f"Salté {path.name} porque PIL no la reconoce bien.")
                skipped += 1
                continue
            saved_path = save_preprocessed_image(processed, path)
            rows.append({
                'label': label,
                'original_path': path.as_posix(),
                'processed_path': saved_path.as_posix()
            })
    # Explicit columns keep the schema for an empty dataset (otherwise
    # sort_values('label') raises KeyError); a stable sort keeps path order.
    df = pd.DataFrame(rows, columns=columns)
    df = df.sort_values('label', kind='stable').reset_index(drop=True)
    if skipped:
        print(f"Salté {skipped} imágenes que estaban dañadas o con formato raro.")
    return df


def show_before_after(df: pd.DataFrame, samples: int = 4) -> None:
    """Plot original vs. preprocessed image for up to ``samples`` classes.

    Takes the first row of each label, then a seeded random sample of those
    (``random_state=42``). Each row of the figure shows the original image and
    its 64x64 grayscale version. The original gets its EXIF orientation
    applied, as ``preprocess_image`` does, so both panels share the same
    framing. Prints a message instead of plotting when there is nothing to show.
    """
    subset = df.groupby('label').head(1)
    n_rows = max(0, min(samples, len(subset)))
    if n_rows == 0:
        # plt.subplots(0, 2) would raise; report instead.
        print('No hay imágenes que mostrar.')
        return
    subset = subset.sample(n_rows, random_state=42)
    # squeeze=False always returns a 2-D (rows, 2) grid, even for a single row.
    _, axes = plt.subplots(n_rows, 2, figsize=(6, 3 * n_rows), squeeze=False)
    for (_, row), (ax_orig, ax_proc) in zip(subset.iterrows(), axes):
        with Image.open(row['original_path']) as original:
            # cmap only affects single-channel originals; RGB images ignore it.
            ax_orig.imshow(ImageOps.exif_transpose(original), cmap='gray')
        with Image.open(row['processed_path']) as processed:
            ax_proc.imshow(processed, cmap='gray')
        ax_orig.set_title(f"Original ({row['label']})")
        ax_orig.axis('off')
        ax_proc.set_title('64x64 gris')
        ax_proc.axis('off')
    plt.tight_layout()
    plt.show()


### Preparo (o cargo) el descriptor del dataset
Uso el mismo formato que antes pero guardo una copia independiente para PyTorch.


In [None]:
# joblib cache of the descriptor DataFrame, kept separate for the PyTorch pipeline.
descriptor_path = OUTPUT_DIR.joinpath('dataset_descriptor_pytorch.joblib')

def rebuild_descriptor():
    """Regenerate the descriptor from the raw images, cache it with joblib and return it."""
    fresh_df = build_dataset_descriptor()
    joblib.dump(fresh_df, descriptor_path)
    print(f'Descriptor actualizado y guardado en {descriptor_path}')
    return fresh_df

def descriptor_is_outdated(df: pd.DataFrame) -> bool:
    """Return True when a cached descriptor no longer matches the files on disk.

    The descriptor is outdated when any of these holds (a message is printed):

    * a readable image on disk is not listed in ``df``;
    * ``df`` lists an original image that no longer exists;
    * a ``processed_path`` listed in ``df`` is missing.

    Images that ``preprocess_image`` rejects are skipped by
    ``build_dataset_descriptor``, so they are never in ``df``. Only files
    missing from ``df`` are re-probed here. Unreadable ones are not counted as
    new; otherwise every run would rebuild the cache as soon as one corrupt
    file exists.
    """
    def _is_readable(path_str: str) -> bool:
        try:
            preprocess_image(Path(path_str), IMAGE_SIZE)
        except ValueError:
            return False
        return True

    current_originals = {
        path.as_posix()
        for label in CATEGORIES
        for path in collect_image_paths(label)
    }
    stored_originals = set(df['original_path'])

    new_images = sum(1 for p in current_originals - stored_originals if _is_readable(p))
    removed_images = len(stored_originals - current_originals)
    if new_images:
        print(f'Hay {new_images} imágenes nuevas sin describir.')
    if removed_images:
        print(f'Hay {removed_images} imágenes que ya no existen en la carpeta original.')
    if new_images or removed_images:
        return True

    missing_processed = sum(not Path(p).exists() for p in df['processed_path'])
    if missing_processed:
        print(f'Faltan {missing_processed} archivos preprocesados; regenero descriptor...')
        return True
    return False

if not descriptor_path.exists():
    # First run: nothing cached yet.
    dataset_df = rebuild_descriptor()
else:
    dataset_df = joblib.load(descriptor_path)
    stored_labels = set(dataset_df['label'].unique())
    expected_labels = set(CATEGORIES)
    # Short-circuit: the folder scan in descriptor_is_outdated only runs when
    # the class sets already match.
    needs_rebuild = stored_labels != expected_labels or descriptor_is_outdated(dataset_df)
    if not needs_rebuild:
        print(f'Descriptor cargado de {descriptor_path}')
    else:
        print('El descriptor está desactualizado, lo vuelvo a generar...')
        dataset_df = rebuild_descriptor()

# Quick sanity check of the table and the class balance.
print(dataset_df.head(), end='\n\n')
print(dataset_df['label'].value_counts())

show_before_after(dataset_df)



### Dataset y DataLoaders
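Ahora defino el `Dataset` que lee las imágenes preprocesadas del descriptor y una función que construye los DataLoaders; los splits estratificados se generan más adelante, en la celda de validación cruzada.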


In [None]:
# Class <-> integer index lookups; indices follow the sorted CATEGORIES order.
label_to_idx = {name: position for position, name in enumerate(CATEGORIES)}
idx_to_label = {position: name for name, position in label_to_idx.items()}

# The processed PNGs are already IMAGE_SIZE grayscale, so Grayscale/Resize are
# effectively no-ops for them. Normalize(0.5, 0.5) maps [0, 1] pixels to [-1, 1].
transform_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(transform_steps)

class AnimalFaceDataset(Dataset):
    """Map-style dataset over the preprocessed images listed in a descriptor.

    Each item is ``(tensor, class_index)``. The image at ``processed_path`` is
    opened with PIL and passed through ``transform`` (or ``ToTensor`` when no
    transform is given). It is paired with ``label_to_idx[label]``.
    """

    def __init__(self, df: pd.DataFrame, label_to_idx: dict, transform=None):
        # A fresh 0..n-1 index keeps iloc positions aligned with ``idx``.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.label_to_idx = label_to_idx

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx: int):
        record = self.df.iloc[idx]
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        with Image.open(Path(record['processed_path'])) as img:
            oriented = ImageOps.exif_transpose(img)
            tensor = to_tensor(oriented)
        return tensor, self.label_to_idx[record['label']]

def build_loader(df: pd.DataFrame, shuffle: bool):
    """Wrap ``df`` in an ``AnimalFaceDataset`` and return ``(dataset, DataLoader)``.

    Uses the notebook-wide ``label_to_idx``, ``transform`` and ``BATCH_SIZE``.
    Loading happens in the main process (``num_workers=0``).
    """
    dataset = AnimalFaceDataset(df, label_to_idx, transform=transform)
    return dataset, DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=0,
    )


### Defino la CNN
Modelo compacto con tres bloques conv-bn-relu-maxpool y un clasificador totalmente conectado con dropout.


In [None]:
# Flattened size after SimpleCNN's conv stack: three MaxPool2d(2) stages divide
# each spatial side by 2**3, and the last conv outputs 128 channels.
FEATURE_DIM = 128 * (IMAGE_SIZE[0] // 2 ** 3) * (IMAGE_SIZE[1] // 2 ** 3)

class SimpleCNN(nn.Module):
    """Compact CNN for 1-channel images: three conv blocks plus an MLP head.

    Each conv block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2), with channels 1 -> 32 -> 64 -> 128. The head flattens the
    feature map (expected size: the module-level ``FEATURE_DIM``) into
    Linear(256) -> ReLU -> Dropout(0.3) -> Linear(num_classes). It returns raw
    logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        conv_layers = []
        widths = (1, 32, 64, 128)
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            conv_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        # Same layer order as a hand-written Sequential, so state_dict keys
        # ("features.0.weight", ...) match previously saved checkpoints.
        self.features = nn.Sequential(*conv_layers)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(FEATURE_DIM, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Prefer the GPU when PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_classes = len(CATEGORIES)
print(f'Entrenaré en: {device}')

def create_model():
    """Build a fresh ``SimpleCNN`` for ``num_classes`` classes, already moved to ``device``."""
    model = SimpleCNN(num_classes=num_classes)
    return model.to(device)

def create_criterion():
    """Cross-entropy loss on raw logits (default mean reduction over the batch)."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn

def create_optimizer(model):
    """Adam over all model parameters with lr=1e-3 and weight_decay=1e-4."""
    trainable = model.parameters()
    return torch.optim.Adam(trainable, weight_decay=1e-4, lr=1e-3)

def create_scheduler(optimizer):
    """Halve the learning rate when the monitored metric stops improving.

    ``mode='max'`` because the training loop steps it with validation accuracy.
    The LR is multiplied by 0.5 once more than 3 consecutive steps pass without
    improvement.

    ``verbose`` is intentionally not passed: it is deprecated for PyTorch LR
    schedulers (warning since 2.2) and slated for removal. To log LR changes,
    read ``optimizer.param_groups[i]['lr']`` (or ``scheduler.get_last_lr()``
    on recent PyTorch).
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)


### Entrenamiento
Ciclo clásico con early stopping cuando la validación deja de mejorar.


In [None]:
def run_epoch(dataloader, model, criterion, optimizer=None):
    """Run one pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is trained: train mode, gradients, one
    optimizer step per batch. Without one it is evaluated: eval mode with
    autograd disabled, so no computation graph is built for validation.

    The mean loss weights each batch by its size. This assumes a mean-reduced
    ``criterion``, as the default ``nn.CrossEntropyLoss`` is.

    Raises:
        ValueError: if the dataloader yields no samples, instead of a bare
            ZeroDivisionError when averaging.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is exactly model.eval()
    epoch_loss = 0.0
    correct = 0
    total = 0

    # Gradients are only needed when backward() will be called.
    with torch.set_grad_enabled(is_train):
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            if is_train:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if is_train:
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            epoch_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('El dataloader no produjo ningún ejemplo.')
    return epoch_loss / total, correct / total

# Max epochs per fold, and consecutive non-improving validation epochs
# tolerated before early stopping kicks in.
EPOCHS, PATIENCE = 30, 5

def _snapshot_state_dict(model):
    """Return an independent CPU copy of ``model.state_dict()``.

    ``Tensor.cpu()`` returns the *same* tensor when it already lives on the
    CPU. A plain ``detach().cpu()`` snapshot would therefore alias the live
    parameters/buffers and silently follow later optimizer updates.
    ``to('cpu', copy=True)`` always copies, exactly once, on any device.
    """
    return {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}


def train_model(train_loader, val_loader, fold_id=None):
    """Train a fresh model with LR scheduling and early stopping.

    Each epoch does:

    * one training pass and one validation pass;
    * a scheduler step with the validation accuracy;
    * a log line with the metrics.

    The weights are snapshotted whenever validation accuracy strictly improves.
    Training stops after ``PATIENCE`` consecutive epochs without improvement,
    or after ``EPOCHS`` epochs.

    Args:
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        fold_id: optional fold number, only used as a log prefix.

    Returns:
        ``(model, history, best_state_dict, best_val_acc)``:

        * ``model`` has the best weights loaded;
        * ``history`` holds per-epoch lists of train/val loss and accuracy;
        * ``best_state_dict`` is an independent CPU copy of those weights (the
          final weights if accuracy never rose above 0);
        * ``best_val_acc`` is the best validation accuracy.
    """
    model = create_model()
    criterion = create_criterion()
    optimizer = create_optimizer(model)
    scheduler = create_scheduler(optimizer)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_val_acc = 0.0
    patience_counter = 0
    best_state_dict = None
    prefix = f"Fold {fold_id} " if fold_id is not None else ""

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = run_epoch(train_loader, model, criterion, optimizer)
        val_loss, val_acc = run_epoch(val_loader, model, criterion)

        scheduler.step(val_acc)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"{prefix}Epoch {epoch:02d} | train_loss={train_loss:.4f} acc={train_acc:.4f} | val_loss={val_loss:.4f} acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Must be a real copy: later optimizer steps mutate params in place.
            best_state_dict = _snapshot_state_dict(model)
            patience_counter = 0
            print('  ↳ nuevo mejor modelo')
        else:
            patience_counter += 1

        if patience_counter >= PATIENCE:
            print('Early stopping activado.')
            break

    if best_state_dict is None:
        best_state_dict = _snapshot_state_dict(model)
    model.load_state_dict(best_state_dict)
    return model, history, best_state_dict, best_val_acc


def collect_predictions(model, dataloader):
    """Return ``(y_true, y_pred)`` NumPy arrays of class indices for ``dataloader``.

    The model is put in eval mode and run without gradients. Predictions are
    the arg-max over the output logits. An empty dataloader yields two empty
    int64 arrays instead of failing in ``np.concatenate``.
    """
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs.to(device))
            all_preds.append(outputs.argmax(dim=1).cpu().numpy())
            # Targets are only needed on the host; skip the device round trip.
            all_targets.append(targets.cpu().numpy())
    if not all_preds:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_targets)
    return y_true, y_pred


In [ ]:
N_SPLITS = 5
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

# Per-fold bookkeeping plus out-of-fold (OOF) predictions used by later cells.
fold_histories, fold_metrics = [], []
oof_true, oof_pred = [], []

# Best fold so far by validation accuracy (>= means ties go to the later fold).
best_fold = {
    'val_acc': float('-inf'),
    'fold': None,
    'state_dict': None,
    'val_df': None,
    'dataset': None,
    'history': None,
    'state_path': None
}

fold_splits = skf.split(dataset_df, dataset_df['label'])
for fold, (train_idx, val_idx) in enumerate(fold_splits, start=1):
    print(f"===== Fold {fold}/{N_SPLITS} =====")
    train_df, val_df = dataset_df.iloc[train_idx], dataset_df.iloc[val_idx]

    _, train_loader = build_loader(train_df, shuffle=True)
    val_dataset, val_loader = build_loader(val_df, shuffle=False)

    model, history, best_state_dict, best_val_acc = train_model(
        train_loader, val_loader, fold_id=fold
    )

    fold_histories.append(history)
    fold_metrics.append({
        'fold': fold,
        'train_size': len(train_df),
        'val_size': len(val_df),
        'best_val_acc': best_val_acc,
        'epochs_ran': len(history['train_loss'])
    })

    fold_checkpoint = MODEL_DIR / f'animal_face_cnn_fold{fold}.pt'
    torch.save(best_state_dict, fold_checkpoint)

    # train_model already restored the best weights into `model`.
    y_true, y_pred = collect_predictions(model, val_loader)
    oof_true.extend(y_true.tolist())
    oof_pred.extend(y_pred.tolist())

    history_path = MODEL_DIR / f'training_history_fold{fold}.json'
    with open(history_path, 'w') as fp:
        json.dump(history, fp, indent=2)

    if best_val_acc >= best_fold['val_acc']:
        best_path = MODEL_DIR / 'animal_face_cnn_best.pt'
        best_fold = {
            'val_acc': best_val_acc,
            'fold': fold,
            'state_dict': best_state_dict,
            'val_df': val_df.reset_index(drop=True),
            'dataset': val_dataset,
            'history': {k: v[:] for k, v in history.items()},
            'state_path': best_path
        }
        torch.save(best_state_dict, best_path)

fold_metrics_df = pd.DataFrame(fold_metrics)
display(fold_metrics_df)

oof_accuracy = (np.asarray(oof_true) == np.asarray(oof_pred)).mean()
print("\nAccuracy promedio (OOF):", oof_accuracy)


### Curvas de entrenamiento
Visualizo la pérdida y la accuracy por época del fold con mejor accuracy de validación para verificar el aprendizaje.


In [None]:
best_history = best_fold['history']
if best_history is None:
    raise RuntimeError('Ejecuta la celda de validación cruzada antes de graficar.')

epochs_ran = len(best_history['train_loss'])
epoch_range = range(1, epochs_ran + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# One panel per metric: (axis, train key, val key, title, y label).
panels = (
    (axes[0], 'train_loss', 'val_loss', f'Fold {best_fold["fold"]} - Loss', 'Loss'),
    (axes[1], 'train_acc', 'val_acc', f'Fold {best_fold["fold"]} - Accuracy', 'Accuracy'),
)
for ax, train_key, val_key, title, ylabel in panels:
    ax.plot(epoch_range, best_history[train_key], label='Train')
    ax.plot(epoch_range, best_history[val_key], label='Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()

plt.tight_layout()
plt.show()


### Evaluación con validación cruzada
Calculo métricas y la matriz de confusión usando las predicciones out-of-fold.


In [None]:
if not oof_true:
    raise RuntimeError('Ejecuta antes la celda de validación cruzada para obtener predicciones OOF.')

oof_true_arr = np.asarray(oof_true, dtype=int)
oof_pred_arr = np.asarray(oof_pred, dtype=int)
# Pass every class index explicitly so classes absent from the predictions
# still get a row/column.
labels_idx = list(range(len(CATEGORIES)))

report = classification_report(oof_true_arr, oof_pred_arr, labels=labels_idx, target_names=CATEGORIES)
print(report)

cm = confusion_matrix(oof_true_arr, oof_pred_arr, labels=labels_idx)
num_classes = len(CATEGORIES)
# Grow the figure with the number of classes (at least 8 inches per side).
fig_size = max(8, num_classes * 0.6)
fig, ax = plt.subplots(figsize=(fig_size, fig_size))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CATEGORIES)
disp.plot(ax=ax, cmap='Blues', colorbar=False)
rotation = 45 if num_classes <= 10 else 90
disp.ax_.set_xticklabels(CATEGORIES, rotation=rotation, ha='right')
disp.ax_.set_yticklabels(CATEGORIES)
plt.tight_layout()
plt.show()


### Ejemplos de predicciones
Muestro algunas imágenes del fold con mejor validación (usar tras la validación cruzada).


In [None]:
def show_predictions(model, dataset: Dataset, num_samples: int = 6):
    """Plot up to ``num_samples`` random items of ``dataset`` with true and predicted labels.

    Items are drawn with NumPy's global RNG, without replacement, in a grid of
    at most 3 columns. Pixels are un-normalised with ``x * 0.5 + 0.5``. This
    assumes the notebook's ``Normalize(mean=[0.5], std=[0.5])``. They are then
    drawn on a fixed [0, 1] gray scale, so contrast is not auto-stretched per
    image. Prints a message and returns when the dataset is empty.
    """
    n_show = min(num_samples, len(dataset))
    if n_show <= 0:
        # Avoids a division by zero when sizing the grid.
        print('No hay ejemplos que mostrar.')
        return

    model.eval()
    indices = np.random.choice(len(dataset), size=n_show, replace=False)
    cols = min(3, n_show)
    rows = int(np.ceil(n_show / cols))
    # squeeze=False always returns a 2-D grid; flatten it for easy zipping.
    _, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    axes = axes.reshape(-1)

    with torch.no_grad():
        for ax, idx in zip(axes, indices):
            image_tensor, label_idx = dataset[idx]
            output = model(image_tensor.unsqueeze(0).to(device))
            pred_idx = output.argmax(dim=1).item()

            image_np = (image_tensor * 0.5 + 0.5).squeeze().cpu().numpy()
            ax.imshow(image_np, cmap='gray', vmin=0.0, vmax=1.0)
            ax.set_title(f"Real: {idx_to_label[label_idx]}\nPred: {idx_to_label[pred_idx]}")
            ax.axis('off')

    # Hide any leftover empty cells of the grid.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Both the weights and the validation dataset come from the cross-validation cell.
cv_missing = best_fold['state_dict'] is None or best_fold['dataset'] is None
if cv_missing:
    raise RuntimeError('Ejecuta antes la celda de validación cruzada para obtener un modelo entrenado.')

best_model_for_display = create_model()
best_model_for_display.load_state_dict(best_fold['state_dict'])
show_predictions(best_model_for_display, best_fold['dataset'])


### Ideas para próximas iteraciones
- Ajustar hiperparámetros (capas, regularización o learning rate schedule).
- Probar data augmentation antes del `Normalize` para robustecer el modelo.
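- Registrar métricas adicionales durante el entrenamiento (p. ej. F1 macro por época; el `classification_report` final ya incluye F1 macro y ponderado) y monitorizarlas con TensorBoard.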
