# 03 - Entrenamiento de Clasificadores de Documentos

Este notebook entrena y compara clasificadores de documentos médicos:
- **Sklearn Ensemble**: Random Forest + SVM + Gradient Boosting
- **HuggingFace Transformer**: RoBERTa biomedical fine-tuned
- **PyTorch Custom**: BERT con training loop personalizado

## Requisitos
- Datos sintéticos generados con `scripts/generate_synthetic_data.py`
- Dependencias: scikit-learn, transformers, torch, numpy, matplotlib, seaborn

In [None]:
import json
import sys
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Put the parent directory on the import path so the `app` package resolves.
sys.path.insert(0, '..')

from app.core.ml.document_classifier import SklearnDocumentClassifier
from app.core.ml.feature_engineering import FeatureEngineer

print('Imports OK')

## 1. Cargar Datos Sintéticos

In [None]:
# Load the dataset produced by scripts/generate_synthetic_data.py.
# The file is read as a JSON sequence of records, each providing a 'text'
# and a 'label' entry.
data_path = '../data/training/classifier_data.json'

try:
    # Only the file read is guarded, so the FileNotFoundError handler cannot
    # be triggered by anything other than a missing data file.
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
except FileNotFoundError:
    print('Run scripts/generate_synthetic_data.py first!')
    # Fall back to a tiny hand-written corpus (4 templates repeated 10 times)
    # so the remaining cells can still run as a demonstration.
    texts = [
        'Metformina 850mg tabletas cada 12 horas receta Rx',
        'Glucosa 126 mg/dL resultado laboratorio quimica',
        'Nota medica exploracion fisica signos vitales',
        'Referencia segundo nivel hospital motivo envio',
    ] * 10
    labels = ['receta', 'laboratorio', 'nota_medica', 'referencia'] * 10
    print(f'Using {len(texts)} sample documents')
else:
    texts = [item['text'] for item in data]
    labels = [item['label'] for item in data]
    print(f'Loaded {len(texts)} documents')

# Report the class balance for whichever dataset ended up in use, most
# frequent label first.
print('Label distribution:')
for label, count in Counter(labels).most_common():
    print(f'  {label}: {count}')

## 2. Entrenar Sklearn Ensemble

In [None]:
# Fit the sklearn classifier on the full corpus with 3 cross-validation folds.
# `metrics` is iterated below as a mapping from model name to an object
# exposing cv_mean, cv_std and train_f1.
clf = SklearnDocumentClassifier()
metrics = clf.train(texts, labels, cv_folds=3)

print('\n=== Resultados de Cross-Validation ===')
for model_name, scores in metrics.items():
    # One three-line report per candidate model.
    report_lines = (
        f'\n{model_name}:',
        f'  CV F1 Mean: {scores.cv_mean:.4f} (+/- {scores.cv_std:.4f})',
        f'  Train F1:   {scores.train_f1:.4f}',
    )
    print(*report_lines, sep='\n')

print(f'\nBest model: {clf.best_model_name}')

In [None]:
# Bar chart comparing the cross-validated F1 of each trained model;
# error bars show the standard deviation across folds.
model_names = list(metrics)
cv_means = [metrics[key].cv_mean for key in model_names]
cv_stds = [metrics[key].cv_std for key in model_names]
bar_colors = ['#0F766E', '#1E40AF', '#F59E0B']

fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.bar(
    model_names,
    cv_means,
    yerr=cv_stds,
    capsize=5,
    color=bar_colors,
)
ax.set(
    ylabel='F1 Score (macro)',
    title='Comparacion de Clasificadores Sklearn',
    ylim=(0, 1.1),
)

# Write each mean just above its bar, horizontally centred on the bar.
for bar, mean in zip(bars, cv_means):
    x_center = bar.get_x() + bar.get_width() / 2
    y_label = bar.get_height() + 0.02
    ax.text(x_center, y_label, f'{mean:.3f}', ha='center')

plt.tight_layout()
plt.show()

## 3. Confusion Matrix del Mejor Modelo

In [None]:
# Confusion matrix of the selected model.
# NOTE(review): `texts`/`labels` are the same data passed to clf.train, so
# unless get_confusion_matrix splits internally this is not a held-out
# estimate - confirm.
# NOTE(review): the tick labels assume the matrix rows/columns follow the
# sorted label order - verify against get_confusion_matrix.
unique_labels = sorted(set(labels))
cm = clf.get_confusion_matrix(texts, labels)

fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=unique_labels)
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title(f'Confusion Matrix - {clf.best_model_name}')

plt.tight_layout()
plt.show()

## 4. Feature Engineering Analysis

In [None]:
# Build the feature matrix for the first 50 documents and summarise it.
engineer = FeatureEngineer(max_tfidf_features=100)
features = engineer.extract_text_features_batch(texts[:50])

# Derive the non-TF-IDF column count from the engineer's own setting instead
# of repeating the literal 100, so changing the cap above cannot leave the
# reported count out of sync.
# NOTE(review): this assumes the TF-IDF part always occupies exactly
# max_tfidf_features columns; a vocabulary smaller than the cap might yield
# fewer - confirm against FeatureEngineer.
n_tfidf = engineer.max_tfidf_features
n_manual = features.shape[1] - n_tfidf

print(f'Feature matrix shape: {features.shape}')
print(f'TF-IDF features: {n_tfidf}')
print(f'Manual features: {n_manual}')

# Global statistics computed over every entry of the matrix.
print('\nFeature statistics:')
print(f'  Mean: {features.mean():.4f}')
print(f'  Std:  {features.std():.4f}')
print(f'  Min:  {features.min():.4f}')
print(f'  Max:  {features.max():.4f}')

## 5. Guardar Mejor Modelo

In [None]:
# Persist the trained classifier.
# NOTE(review): this runs even when the loading cell fell back to the small
# built-in sample corpus - confirm that writing to save_path is intended then.
save_path = '../models/document_classifier'
clf.save(save_path)
print(f'Model saved to {save_path}')

# Smoke-test the classifier on three hand-written snippets.
test_texts = [
    'Metformina 850mg tabletas cada 12 horas por 30 dias',
    'Glucosa 200 mg/dL laboratorio resultado quimica sanguinea',
    'Nota medica exploracion fisica signos vitales presion arterial',
]
for sample in test_texts:
    prediction = clf.predict(sample)
    preview = sample[:50]  # show only the first 50 characters
    print(f'\n"{preview}..."')
    print(f'  -> {prediction.document_type} ({prediction.confidence:.2%})')