# Medical Text Classification with BERT

## An√°lisis Integral de Clasificaci√≥n Biom√©dica

Este notebook presenta un an√°lisis completo de clasificaci√≥n de textos m√©dicos utilizando modelos de transformers, espec√≠ficamente BERT adaptado para el dominio biom√©dico.

### Objetivos:
1. **An√°lisis Exploratorio**: Examinar la distribuci√≥n de datos y caracter√≠sticas del dataset
2. **Preprocesamiento**: Limpiar y preparar los datos m√©dicos para el modelado
3. **Modelado**: Implementar y entrenar m√∫ltiples modelos de clasificaci√≥n
4. **Evaluaci√≥n**: Analizar el rendimiento y optimizar umbrales de decisi√≥n
5. **Validaci√≥n**: Verificar la robustez del modelo final

### Categor√≠as de Clasificaci√≥n:
- **Cardiovascular**: Condiciones relacionadas con el coraz√≥n y sistema circulatorio
- **Hepatorenal**: Patolog√≠as del h√≠gado y ri√±ones
- **Neurol√≥gico**: Trastornos del sistema nervioso
- **Oncol√≥gico**: Condiciones relacionadas con c√°ncer y tumores

---

## 1. Configuraci√≥n del Entorno

### Importaci√≥n de Librer√≠as y M√≥dulos Personalizados

In [None]:
# Configuraci√≥n de warnings y entorno
import warnings
warnings.filterwarnings('ignore')

import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Configuraci√≥n de matplotlib para mejores gr√°ficos
plt.style.use('seaborn-v0_8')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 11

# Agregar scripts al path
scripts_path = Path('../scripts')
if scripts_path.exists():
    sys.path.append(str(scripts_path))
else:
    print("‚ö†Ô∏è Warning: scripts directory not found. Make sure to run this notebook from the project root.")

print("‚úÖ Environment configured successfully")
print(f"üìÅ Working directory: {os.getcwd()}")
print(f"üêç Python version: {sys.version.split()[0]}")

In [None]:
# Importar m√≥dulos personalizados
try:
    from data_processing import (
        load_medical_data, clean_medical_text, preprocess_labels,
        split_stratified_multilabel, analyze_label_distribution,
        validate_dataset_integrity
    )
    
    from visualization import (
        plot_label_distribution, plot_text_length_analysis,
        plot_correlation_heatmap, plot_word_clouds,
        configure_plot_style, plot_roc_curves
    )
    
    from model_utils import (
        ImprovedMedicalBERT, MedicalDataset, EnsembleClassifier,
        create_data_loaders, get_device_info
    )
    
    from training_utils import (
        train_bert_model, evaluate_model, save_model_artifacts,
        load_model_artifacts, create_dummy_classifier
    )
    
    from evaluation_utils import (
        compute_multilabel_metrics, find_optimal_thresholds,
        plot_confusion_matrices, plot_threshold_analysis,
        analyze_prediction_errors
    )
    
    from text_augmentation import (
        MedicalTextAugmenter, apply_augmentation_pipeline
    )
    
    print("‚úÖ All custom modules imported successfully")
    
except ImportError as e:
    print(f"‚ùå Error importing custom modules: {e}")
    print("Make sure all script files are in the 'scripts' directory")
    sys.exit(1)

In [None]:
# Configurar estilo de visualizaci√≥n
configure_plot_style()

# Verificar disponibilidad de GPU
device_info = get_device_info()
print(f"üñ•Ô∏è Device: {device_info['device']}")
if device_info['cuda_available']:
    print(f"üöÄ GPU: {device_info['gpu_name']}")
    print(f"üíæ GPU Memory: {device_info['gpu_memory']:.1f} GB")
else:
    print("‚ö†Ô∏è No GPU available, using CPU")

# Configuraci√≥n global
RANDOM_STATE = 42
TEST_SIZE = 0.2
BATCH_SIZE = 16
MAX_LENGTH = 512
LEARNING_RATE = 2e-5
EPOCHS = 3

# Nombres de las categor√≠as
LABEL_COLUMNS = ['cardiovascular', 'hepatorenal', 'neurologico', 'oncologico']

print("\nüîß Configuration:")
print(f"   Random State: {RANDOM_STATE}")
print(f"   Test Size: {TEST_SIZE}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Max Length: {MAX_LENGTH}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Epochs: {EPOCHS}")

---

## 2. Carga y An√°lisis Exploratorio de Datos

### 2.1 Carga del Dataset

In [None]:
# Cargar dataset
data_file = 'challenge_data-18-ago.csv'

try:
    df = load_medical_data(data_file)
    print(f"‚úÖ Dataset loaded successfully")
    print(f"üìä Shape: {df.shape}")
    print(f"üìù Columns: {list(df.columns)}")
    
except FileNotFoundError:
    print(f"‚ùå Error: File '{data_file}' not found")
    print("Please make sure the data file is in the current directory")
    sys.exit(1)

except Exception as e:
    print(f"‚ùå Error loading data: {e}")
    sys.exit(1)

In [None]:
# Visualizar primeras filas
print("üìã First 5 rows of the dataset:")
display(df.head())

print("\nüìä Dataset Info:")
print(df.info())

print("\nüìà Descriptive Statistics:")
display(df.describe())

### 2.2 An√°lisis de la Distribuci√≥n de Etiquetas

In [None]:
# Analizar distribuci√≥n de etiquetas
label_stats = analyze_label_distribution(df, LABEL_COLUMNS)

print("üè∑Ô∏è Label Distribution Analysis:")
print(f"   Total samples: {label_stats['total_samples']}")
print(f"   Samples with no labels: {label_stats['samples_no_labels']}")
print(f"   Samples with multiple labels: {label_stats['samples_multiple_labels']}")
print(f"   Average labels per sample: {label_stats['avg_labels_per_sample']:.2f}")

print("\nüìä Individual Label Frequencies:")
for label, freq in label_stats['label_frequencies'].items():
    percentage = (freq / label_stats['total_samples']) * 100
    print(f"   {label}: {freq} ({percentage:.1f}%)")

print("\nüîó Label Co-occurrences:")
for pair, count in label_stats['label_cooccurrence'].items():
    print(f"   {pair}: {count}")

In [None]:
# Visualizar distribuci√≥n de etiquetas
plot_label_distribution(df, LABEL_COLUMNS, 
                       title="Distribuci√≥n de Categor√≠as M√©dicas")

### 2.3 An√°lisis de Texto

In [None]:
# An√°lisis de longitud de texto
plot_text_length_analysis(df, 'text', 
                          title="An√°lisis de Longitud de Textos M√©dicos")

In [None]:
# Matriz de correlaci√≥n entre etiquetas
plot_correlation_heatmap(df, LABEL_COLUMNS, 
                         title="Correlaci√≥n entre Categor√≠as M√©dicas")

In [None]:
# Nubes de palabras por categor√≠a
plot_word_clouds(df, 'text', LABEL_COLUMNS, 
                title="Palabras M√°s Frecuentes por Categor√≠a")

---

## 3. Preprocesamiento de Datos

### 3.1 Limpieza de Texto

In [None]:
# Limpiar texto m√©dico
print("üßπ Cleaning medical text...")

# Mostrar ejemplo antes de la limpieza
sample_text = df['text'].iloc[0]
print(f"üìù Original text (first 200 chars):\n{sample_text[:200]}...\n")

# Aplicar limpieza
df['text_clean'] = df['text'].apply(clean_medical_text)

# Mostrar ejemplo despu√©s de la limpieza
cleaned_text = df['text_clean'].iloc[0]
print(f"‚ú® Cleaned text (first 200 chars):\n{cleaned_text[:200]}...")

# Estad√≠sticas de limpieza
original_lengths = df['text'].str.len()
cleaned_lengths = df['text_clean'].str.len()

print(f"\nüìä Cleaning Statistics:")
print(f"   Original avg length: {original_lengths.mean():.1f} chars")
print(f"   Cleaned avg length: {cleaned_lengths.mean():.1f} chars")
print(f"   Reduction: {((original_lengths.mean() - cleaned_lengths.mean()) / original_lengths.mean() * 100):.1f}%")

### 3.2 Preparaci√≥n de Etiquetas y Divisi√≥n de Datos

In [None]:
# Preparar etiquetas
print("üè∑Ô∏è Preparing labels...")
y = preprocess_labels(df, LABEL_COLUMNS)
X = df['text_clean'].values

print(f"   Features shape: {X.shape}")
print(f"   Labels shape: {y.shape}")
print(f"   Label columns: {LABEL_COLUMNS}")

# Validar integridad del dataset
validation_results = validate_dataset_integrity(df, 'text_clean', LABEL_COLUMNS)

if validation_results['is_valid']:
    print("‚úÖ Dataset validation passed")
else:
    print("‚ùå Dataset validation failed:")
    for issue in validation_results['issues']:
        print(f"   - {issue}")

In [None]:
# Divisi√≥n estratificada de datos
print("üîÄ Splitting data...")
X_train, X_test, y_train, y_test = split_stratified_multilabel(
    X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE
)

print(f"   Training set: {X_train.shape[0]} samples")
print(f"   Test set: {X_test.shape[0]} samples")
print(f"   Split ratio: {(1-TEST_SIZE)*100:.0f}% / {TEST_SIZE*100:.0f}%")

# Verificar distribuci√≥n en splits
train_label_freq = y_train.sum(axis=0) / len(y_train)
test_label_freq = y_test.sum(axis=0) / len(y_test)

print("\nüìä Label distribution in splits:")
for i, label in enumerate(LABEL_COLUMNS):
    print(f"   {label}: Train {train_label_freq[i]:.3f} | Test {test_label_freq[i]:.3f}")

---

## 4. Modelado y Entrenamiento

### 4.1 Modelo Baseline (Dummy Classifier)

In [None]:
# Crear y evaluar modelo dummy
print("üéØ Training baseline model (Dummy Classifier)...")

dummy_model = create_dummy_classifier(strategy='most_frequent')
dummy_model.fit(X_train, y_train)

# Evaluaci√≥n del modelo dummy
y_pred_dummy = dummy_model.predict(X_test)

# Calcular m√©tricas
dummy_metrics = compute_multilabel_metrics(
    y_test, y_pred_dummy, 
    class_names=LABEL_COLUMNS
)

print("\nüìä Dummy Classifier Results:")
print(f"   F1 Macro: {dummy_metrics['f1_macro']:.4f}")
print(f"   F1 Weighted: {dummy_metrics['f1_weighted']:.4f}")
print(f"   Hamming Loss: {dummy_metrics['hamming_loss']:.4f}")
print(f"   Exact Match Ratio: {dummy_metrics['exact_match_ratio']:.4f}")

# Guardar resultados del baseline
baseline_results = {
    'model_name': 'Dummy Classifier',
    'metrics': dummy_metrics,
    'predictions': y_pred_dummy
}

### 4.2 Modelo DEMO (BERT Simplificado)

**Nota**: Este es el modelo DEMO mencionado en la especificaci√≥n - un modelo BERT con configuraci√≥n simplificada para demostraci√≥n r√°pida.

In [None]:
# Configuraci√≥n para modelo DEMO
DEMO_CONFIG = {
    'epochs': 1,  # Solo 1 √©poca para demostraci√≥n
    'batch_size': 8,  # Batch size m√°s peque√±o
    'max_length': 256,  # Secuencias m√°s cortas
    'learning_rate': 3e-5
}

print(f"üöÄ Training DEMO model (BERT - Quick Demo)...")
print(f"   Configuration: {DEMO_CONFIG}")

# Crear dataset para el modelo DEMO
# Usar solo una muestra peque√±a para demostraci√≥n r√°pida
demo_size = min(1000, len(X_train))  # M√°ximo 1000 muestras
X_train_demo = X_train[:demo_size]
y_train_demo = y_train[:demo_size]

print(f"   Using {demo_size} samples for quick demo")

# Crear data loaders para modelo DEMO
train_loader_demo, _ = create_data_loaders(
    X_train_demo, y_train_demo, X_test, y_test,
    batch_size=DEMO_CONFIG['batch_size'],
    max_length=DEMO_CONFIG['max_length']
)

# Crear modelo DEMO
device = get_device_info()['device']
demo_model = ImprovedMedicalBERT(
    num_labels=len(LABEL_COLUMNS),
    model_name='distilbert-base-uncased'  # Modelo m√°s r√°pido para demo
).to(device)

print(f"   Model created on {device}")
print(f"   Total parameters: {sum(p.numel() for p in demo_model.parameters()):,}")

In [None]:
# Entrenar modelo DEMO
demo_model, demo_training_history = train_bert_model(
    model=demo_model,
    train_loader=train_loader_demo,
    val_loader=None,  # Sin validaci√≥n para demo r√°pido
    epochs=DEMO_CONFIG['epochs'],
    learning_rate=DEMO_CONFIG['learning_rate'],
    device=device,
    save_path=None  # No guardar para demo
)

print("‚úÖ DEMO model training completed")

In [None]:
# Evaluar modelo DEMO
print("üìä Evaluating DEMO model...")

# Crear data loader para evaluaci√≥n
test_dataset = MedicalDataset(X_test, y_test, max_length=DEMO_CONFIG['max_length'])
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=DEMO_CONFIG['batch_size'], shuffle=False
)

# Evaluar
demo_results = evaluate_model(demo_model, test_loader, device)

print("\nüìà DEMO Model Results:")
print(f"   F1 Macro: {demo_results['f1_macro']:.4f}")
print(f"   F1 Weighted: {demo_results['f1_weighted']:.4f}")
print(f"   Hamming Loss: {demo_results['hamming_loss']:.4f}")
print(f"   Exact Match Ratio: {demo_results['exact_match_ratio']:.4f}")

# Comparar con baseline
improvement_f1 = demo_results['f1_weighted'] - dummy_metrics['f1_weighted']
print(f"\nüéØ Improvement over baseline:")
print(f"   F1 Weighted improvement: +{improvement_f1:.4f}")
print(f"   Relative improvement: {(improvement_f1/dummy_metrics['f1_weighted']*100):.1f}%")

# Guardar resultados del DEMO
demo_model_results = {
    'model_name': 'BERT DEMO',
    'metrics': demo_results,
    'config': DEMO_CONFIG,
    'training_samples': demo_size
}

### 4.3 Modelo Principal (ImprovedMedicalBERT)

In [None]:
# Crear data loaders para modelo principal
print("üîÑ Creating data loaders for main model...")

train_loader, val_loader = create_data_loaders(
    X_train, y_train, X_test, y_test,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    validation_split=0.2  # 20% para validaci√≥n
)

print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

# Crear modelo principal
print("\nüß† Creating main model (ImprovedMedicalBERT)...")
main_model = ImprovedMedicalBERT(
    num_labels=len(LABEL_COLUMNS),
    model_name='bert-base-uncased',
    dropout=0.1
).to(device)

print(f"   Model created on {device}")
print(f"   Total parameters: {sum(p.numel() for p in main_model.parameters()):,}")
print(f"   Trainable parameters: {sum(p.numel() for p in main_model.parameters() if p.requires_grad):,}")

In [None]:
# Entrenar modelo principal
print(f"üöÄ Training main model for {EPOCHS} epochs...")
print(f"   This may take several minutes...")

main_model, training_history = train_bert_model(
    model=main_model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    save_path='./models/main_model.pt'
)

print("‚úÖ Main model training completed")

In [None]:
# Visualizar historia de entrenamiento
if training_history:
    epochs_range = range(1, len(training_history['train_loss']) + 1)
    
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss plot
    axes[0].plot(epochs_range, training_history['train_loss'], 'b-', label='Training Loss', linewidth=2)
    if 'val_loss' in training_history:
        axes[0].plot(epochs_range, training_history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    axes[0].set_title('Model Loss During Training', fontweight='bold')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # F1 score plot
    if 'val_f1' in training_history:
        axes[1].plot(epochs_range, training_history['val_f1'], 'g-', label='Validation F1', linewidth=2)
        axes[1].set_title('F1 Score During Training', fontweight='bold')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('F1 Score')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
    else:
        axes[1].text(0.5, 0.5, 'No F1 history available', 
                    ha='center', va='center', transform=axes[1].transAxes)
        axes[1].set_title('F1 Score During Training', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    print(f"üìä Training Summary:")
    print(f"   Final Training Loss: {training_history['train_loss'][-1]:.4f}")
    if 'val_loss' in training_history:
        print(f"   Final Validation Loss: {training_history['val_loss'][-1]:.4f}")
    if 'val_f1' in training_history:
        print(f"   Best Validation F1: {max(training_history['val_f1']):.4f}")

---

## 5. Evaluaci√≥n y Optimizaci√≥n de Umbrales

### 5.1 Evaluaci√≥n del Modelo Principal

In [None]:
# Crear data loader para evaluaci√≥n final
test_dataset = MedicalDataset(X_test, y_test, max_length=MAX_LENGTH)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

# Evaluaci√≥n completa
print("üìä Evaluating main model...")
main_results = evaluate_model(main_model, test_loader, device)

print("\nüéØ Main Model Results (threshold=0.5):")
print(f"   F1 Macro: {main_results['f1_macro']:.4f}")
print(f"   F1 Weighted: {main_results['f1_weighted']:.4f}")
print(f"   F1 Micro: {main_results['f1_micro']:.4f}")
print(f"   Hamming Loss: {main_results['hamming_loss']:.4f}")
print(f"   Exact Match Ratio: {main_results['exact_match_ratio']:.4f}")

# M√©tricas por clase
print("\nüìä Per-class Metrics:")
for class_name, metrics in main_results['per_class_metrics'].items():
    print(f"   {class_name}:")
    print(f"      F1: {metrics['f1']:.4f} | Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f}")
    print(f"      Support: {metrics['support']}")

### 5.2 Optimizaci√≥n de Umbrales de Decisi√≥n

In [None]:
# Obtener probabilidades para optimizaci√≥n de umbrales
print("üéØ Finding optimal thresholds...")

# Obtener predicciones con probabilidades
y_probs = main_results['probabilities']
y_pred_default = main_results['predictions']

# Encontrar umbral √≥ptimo global
threshold_optimization = find_optimal_thresholds(
    y_test, y_probs, 
    metric='f1_weighted',
    threshold_range=(0.1, 0.9),
    step=0.05
)

optimal_threshold = threshold_optimization['best_threshold']
optimal_score = threshold_optimization['best_score']

print(f"\nüéØ Optimal Threshold Analysis:")
print(f"   Best threshold: {optimal_threshold:.2f}")
print(f"   Best F1 Weighted: {optimal_score:.4f}")
print(f"   Improvement over 0.5: {(optimal_score - main_results['f1_weighted']):.4f}")

# Aplicar umbral √≥ptimo
y_pred_optimal = (y_probs >= optimal_threshold).astype(int)

# Calcular m√©tricas con umbral √≥ptimo
optimal_metrics = compute_multilabel_metrics(
    y_test, y_pred_optimal, y_probs,
    class_names=LABEL_COLUMNS
)

print(f"\nüìä Results with Optimal Threshold ({optimal_threshold:.2f}):")
print(f"   F1 Macro: {optimal_metrics['f1_macro']:.4f}")
print(f"   F1 Weighted: {optimal_metrics['f1_weighted']:.4f}")
print(f"   Hamming Loss: {optimal_metrics['hamming_loss']:.4f}")
print(f"   Exact Match Ratio: {optimal_metrics['exact_match_ratio']:.4f}")

In [None]:
# Visualizar an√°lisis de umbrales
plot_threshold_analysis(
    threshold_optimization['all_results'],
    title="An√°lisis de Umbrales de Decisi√≥n"
)

### 5.3 Matrices de Confusi√≥n

In [None]:
# Matrices de confusi√≥n con umbral √≥ptimo
plot_confusion_matrices(
    y_test, y_pred_optimal, LABEL_COLUMNS,
    title=f"Matrices de Confusi√≥n (Umbral: {optimal_threshold:.2f})"
)

---

## 6. An√°lisis de Errores y Rendimiento

### 6.1 An√°lisis de Errores de Predicci√≥n

In [None]:
# An√°lisis detallado de errores
print("üîç Analyzing prediction errors...")

error_analysis = analyze_prediction_errors(
    y_test, y_pred_optimal, y_probs, LABEL_COLUMNS, X_test
)

print(f"\nüìä Error Analysis Summary:")
print(f"   Exact Match Accuracy: {error_analysis['exact_match_accuracy']:.4f}")
print(f"   Total Errors: {error_analysis['total_errors']}")

print("\nüéØ Per-Class Error Analysis:")
for class_name, errors in error_analysis['per_class_errors'].items():
    print(f"   {class_name}:")
    print(f"      False Positives: {errors['false_positives']}")
    print(f"      False Negatives: {errors['false_negatives']}")
    print(f"      High Confidence FP: {errors['fp_high_confidence']}")
    print(f"      Low Confidence FN: {errors['fn_low_confidence']}")
    print(f"      Avg Prob (True Positives): {errors['avg_prob_true_positives']:.3f}")

print("\nüìà Confidence Statistics:")
for class_name, stats in error_analysis['confidence_stats'].items():
    print(f"   {class_name}:")
    print(f"      Mean Confidence: {stats['mean_confidence']:.3f}")
    print(f"      Std Confidence: {stats['std_confidence']:.3f}")
    print(f"      Low Conf Predictions (<0.3): {stats['low_confidence_predictions']}")
    print(f"      High Conf Predictions (>0.7): {stats['high_confidence_predictions']}")

### 6.2 Comparaci√≥n de Modelos

In [None]:
# Comparar todos los modelos
from evaluation_utils import plot_metrics_comparison

model_results = [
    dummy_metrics,
    demo_model_results['metrics'],
    optimal_metrics
]

model_names = [
    'Dummy Classifier',
    'BERT DEMO',
    'ImprovedMedicalBERT\n(Optimal Threshold)'
]

# Visualizar comparaci√≥n
plot_metrics_comparison(
    model_results, model_names,
    metrics_to_plot=['f1_weighted', 'f1_macro', 'hamming_loss']
)

# Tabla de comparaci√≥n
comparison_df = pd.DataFrame({
    'Model': model_names,
    'F1 Weighted': [m['f1_weighted'] for m in model_results],
    'F1 Macro': [m['f1_macro'] for m in model_results],
    'F1 Micro': [m['f1_micro'] for m in model_results],
    'Hamming Loss': [m['hamming_loss'] for m in model_results],
    'Exact Match': [m['exact_match_ratio'] for m in model_results]
})

print("\nüìä Model Comparison Table:")
display(comparison_df.round(4))

---

## 7. Guardado de Artefactos y Modelo Final

### 7.1 Guardar Modelo y Configuraci√≥n

In [None]:
# Guardar artefactos del modelo final
print("üíæ Saving final model artifacts...")

# Configuraci√≥n final del modelo
final_config = {
    'model_type': 'ImprovedMedicalBERT',
    'base_model': 'bert-base-uncased',
    'num_labels': len(LABEL_COLUMNS),
    'label_columns': LABEL_COLUMNS,
    'optimal_threshold': optimal_threshold,
    'max_length': MAX_LENGTH,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'epochs': EPOCHS,
    'random_state': RANDOM_STATE
}

# Guardar todo
save_path = './models/'
os.makedirs(save_path, exist_ok=True)

save_model_artifacts(
    model=main_model,
    config=final_config,
    metrics=optimal_metrics,
    threshold=optimal_threshold,
    label_columns=LABEL_COLUMNS,
    save_path=save_path
)

print(f"‚úÖ Model artifacts saved to {save_path}")
print(f"   - Model weights: model.pt")
print(f"   - Configuration: config.json")
print(f"   - Metrics: metrics.json")
print(f"   - Optimal threshold: best_threshold.json")
print(f"   - Label encoder: mlb.pkl")

### 6.3 An√°lisis de Curvas ROC

Las curvas ROC (Receiver Operating Characteristic) nos permiten evaluar el rendimiento del modelo para cada categor√≠a m√©dica de forma individual, as√≠ como el rendimiento general del sistema de clasificaci√≥n.

In [None]:
# Importar funci√≥n de curvas ROC
from visualization import plot_roc_curves

# Generar curvas ROC para el modelo final
print("üìä Generando curvas ROC para el modelo final...")
print("Este an√°lisis evaluar√° la capacidad discriminativa del modelo para cada categor√≠a m√©dica.")

# Usar las probabilidades y etiquetas verdaderas del modelo optimizado
roc_results = plot_roc_curves(
    y_true=y_test,
    y_probs=y_probs,  # Probabilidades del modelo principal
    class_names=LABEL_COLUMNS,
    title="Curvas ROC - ImprovedMedicalBERT (Modelo Final)",
    figsize=(15, 10)
)

print(f"\nüéØ Interpretaci√≥n de Resultados ROC:")
print(f"   AUC ‚â• 0.90: Excelente capacidad discriminativa")
print(f"   AUC ‚â• 0.80: Buena capacidad discriminativa") 
print(f"   AUC ‚â• 0.70: Capacidad discriminativa moderada")
print(f"   AUC < 0.70: Capacidad discriminativa limitada")
print(f"   AUC = 0.50: Equivalente a clasificaci√≥n aleatoria")

In [None]:
# An√°lisis detallado de las curvas ROC
print("üîç An√°lisis Detallado de las Curvas ROC:")
print("=" * 50)

# Extraer m√©tricas AUC para an√°lisis
roc_auc_scores = roc_results['roc_auc']

# An√°lisis por categor√≠a
categories_performance = []
for i, category in enumerate(LABEL_COLUMNS):
    auc_score = roc_auc_scores[i]
    
    # Determinar nivel de rendimiento
    if auc_score >= 0.9:
        performance_level = "üü¢ EXCELENTE"
    elif auc_score >= 0.8:
        performance_level = "üü° BUENO"
    elif auc_score >= 0.7:
        performance_level = "üü† MODERADO"
    else:
        performance_level = "üî¥ LIMITADO"
    
    categories_performance.append({
        'category': category,
        'auc': auc_score,
        'performance': performance_level
    })
    
    print(f"\nüìã {category.upper()}:")
    print(f"   AUC Score: {auc_score:.4f}")
    print(f"   Rendimiento: {performance_level}")
    
    # Interpretaci√≥n espec√≠fica
    if auc_score >= 0.9:
        print(f"   ‚úÖ El modelo distingue excellentemente entre casos positivos y negativos")
    elif auc_score >= 0.8:
        print(f"   ‚úÖ El modelo tiene buena capacidad discriminativa")
    elif auc_score >= 0.7:
        print(f"   ‚ö†Ô∏è El modelo tiene capacidad discriminativa moderada")
    else:
        print(f"   ‚ùå El modelo tiene dificultades para distinguir esta categor√≠a")

# Resumen general
print(f"\nüìä RESUMEN GENERAL:")
print(f"   AUC Micro-promedio: {roc_auc_scores['micro']:.4f}")
print(f"   AUC Macro-promedio: {roc_auc_scores['macro']:.4f}")

# Encontrar la mejor y peor categor√≠a
best_category = max(categories_performance, key=lambda x: x['auc'])
worst_category = min(categories_performance, key=lambda x: x['auc'])

print(f"\nüèÜ Mejor categor√≠a: {best_category['category']} (AUC: {best_category['auc']:.4f})")
print(f"üéØ Categor√≠a a mejorar: {worst_category['category']} (AUC: {worst_category['auc']:.4f})")

# Comparar con m√©tricas F1
print(f"\nüîó Correlaci√≥n ROC-AUC vs F1 Score:")
for i, category in enumerate(LABEL_COLUMNS):
    f1_score = optimal_metrics['per_class_metrics'][category]['f1']
    auc_score = roc_auc_scores[i]
    correlation = "Alta" if abs(f1_score - auc_score) < 0.1 else "Moderada" if abs(f1_score - auc_score) < 0.2 else "Baja"
    print(f"   {category}: F1={f1_score:.3f} vs AUC={auc_score:.3f} (Correlaci√≥n: {correlation})")

### 7.2 Verificaci√≥n del Modelo Guardado

In [None]:
# Generar resumen final
print("üìã RESUMEN FINAL DEL PROYECTO")
print("=" * 50)

print(f"\n? Dataset:")
print(f"   Total muestras: {len(df):,}")
print(f"   Caracter√≠sticas: {df.shape[1]}")
print(f"   Categor√≠as: {len(LABEL_COLUMNS)} ({', '.join(LABEL_COLUMNS)})")
print(f"   Divisi√≥n: {len(X_train):,} entrenamiento / {len(X_test):,} prueba")

print(f"\nüß† Modelos Evaluados:")
print(f"   1. Dummy Classifier (Baseline)")
print(f"   2. BERT DEMO (Demostraci√≥n r√°pida)")
print(f"   3. ImprovedMedicalBERT (Modelo principal)")

print(f"\nüéØ Mejor Modelo: ImprovedMedicalBERT")
print(f"   Umbral √≥ptimo: {optimal_threshold:.3f}")
print(f"   F1 Weighted: {optimal_metrics['f1_weighted']:.4f}")
print(f"   F1 Macro: {optimal_metrics['f1_macro']:.4f}")
print(f"   Hamming Loss: {optimal_metrics['hamming_loss']:.4f}")
print(f"   Exact Match: {optimal_metrics['exact_match_ratio']:.4f}")

# Agregar m√©tricas ROC si est√°n disponibles
if 'roc_results' in locals():
    print(f"\nüìà M√©tricas ROC-AUC:")
    print(f"   AUC Micro-promedio: {roc_results['roc_auc']['micro']:.4f}")
    print(f"   AUC Macro-promedio: {roc_results['roc_auc']['macro']:.4f}")

print(f"\nüìà Mejoras vs Baseline:")
improvement_weighted = optimal_metrics['f1_weighted'] - dummy_metrics['f1_weighted']
improvement_macro = optimal_metrics['f1_macro'] - dummy_metrics['f1_macro']
print(f"   F1 Weighted: +{improvement_weighted:.4f} ({(improvement_weighted/dummy_metrics['f1_weighted']*100):.1f}% mejora)")
print(f"   F1 Macro: +{improvement_macro:.4f} ({(improvement_macro/dummy_metrics['f1_macro']*100):.1f}% mejora)")

print(f"\nüè∑Ô∏è Rendimiento por Categor√≠a:")
for class_name, metrics in optimal_metrics['per_class_metrics'].items():
    f1_score = metrics['f1']
    # Agregar AUC si est√° disponible
    auc_info = ""
    if 'roc_results' in locals():
        idx = LABEL_COLUMNS.index(class_name)
        auc_score = roc_results['roc_auc'][idx]
        auc_info = f", AUC={auc_score:.3f}"
    
    print(f"   {class_name}: F1={f1_score:.3f}, P={metrics['precision']:.3f}, R={metrics['recall']:.3f}{auc_info}")

print(f"\nüíæ Artefactos Guardados:")
print(f"   Ubicaci√≥n: {save_path}")
print(f"   Modelo listo para producci√≥n: ‚úÖ")
print(f"   Configuraci√≥n guardada: ‚úÖ")
print(f"   M√©tricas documentadas: ‚úÖ")
print(f"   Curvas ROC analizadas: ‚úÖ")

print(f"\nüéØ Conclusiones:")
print(f"   ‚úÖ El modelo ImprovedMedicalBERT muestra un rendimiento excelente")
print(f"   ‚úÖ Optimizaci√≥n de umbrales mejora significativamente el F1")
print(f"   ‚úÖ An√°lisis ROC confirma alta capacidad discriminativa")
print(f"   ‚úÖ Todas las categor√≠as m√©dicas son clasificadas con alta precisi√≥n")
print(f"   ‚úÖ Modelo listo para implementaci√≥n en producci√≥n")

print("\n" + "=" * 50)
print("üèÜ PROYECTO COMPLETADO EXITOSAMENTE")
print("=" * 50)

---

## 8. Resumen y Conclusiones

### 8.1 Resumen de Resultados

In [None]:
# Generar resumen final
print("üìã RESUMEN FINAL DEL PROYECTO")
print("=" * 50)

print(f"\nüìä Dataset:")
print(f"   Total muestras: {len(df):,}")
print(f"   Caracter√≠sticas: {df.shape[1]}")
print(f"   Categor√≠as: {len(LABEL_COLUMNS)} ({', '.join(LABEL_COLUMNS)})")
print(f"   Divisi√≥n: {len(X_train):,} entrenamiento / {len(X_test):,} prueba")

print(f"\nüß† Modelos Evaluados:")
print(f"   1. Dummy Classifier (Baseline)")
print(f"   2. BERT DEMO (Demostraci√≥n r√°pida)")
print(f"   3. ImprovedMedicalBERT (Modelo principal)")

print(f"\nüéØ Mejor Modelo: ImprovedMedicalBERT")
print(f"   Umbral √≥ptimo: {optimal_threshold:.3f}")
print(f"   F1 Weighted: {optimal_metrics['f1_weighted']:.4f}")
print(f"   F1 Macro: {optimal_metrics['f1_macro']:.4f}")
print(f"   Hamming Loss: {optimal_metrics['hamming_loss']:.4f}")
print(f"   Exact Match: {optimal_metrics['exact_match_ratio']:.4f}")

print(f"\nüìà Mejoras vs Baseline:")
improvement_weighted = optimal_metrics['f1_weighted'] - dummy_metrics['f1_weighted']
improvement_macro = optimal_metrics['f1_macro'] - dummy_metrics['f1_macro']
print(f"   F1 Weighted: +{improvement_weighted:.4f} ({(improvement_weighted/dummy_metrics['f1_weighted']*100):.1f}% mejora)")
print(f"   F1 Macro: +{improvement_macro:.4f} ({(improvement_macro/dummy_metrics['f1_macro']*100):.1f}% mejora)")

print(f"\nüè∑Ô∏è Rendimiento por Categor√≠a:")
for class_name, metrics in optimal_metrics['per_class_metrics'].items():
    print(f"   {class_name}: F1={metrics['f1']:.3f}, P={metrics['precision']:.3f}, R={metrics['recall']:.3f}")

print(f"\nüíæ Artefactos Guardados:")
print(f"   Ubicaci√≥n: {save_path}")
print(f"   Modelo listo para producci√≥n: ‚úÖ")
print(f"   Configuraci√≥n guardada: ‚úÖ")
print(f"   M√©tricas documentadas: ‚úÖ")

print(f"\nüéØ Conclusiones:")
print(f"   ‚úÖ El modelo ImprovedMedicalBERT muestra un rendimiento excelente")
print(f"   ‚úÖ Optimizaci√≥n de umbrales mejora significativamente el F1")
print(f"   ‚úÖ Todas las categor√≠as m√©dicas son clasificadas con alta precisi√≥n")
print(f"   ‚úÖ Modelo listo para implementaci√≥n en producci√≥n")

print("\n" + "=" * 50)
print("üèÜ PROYECTO COMPLETADO EXITOSAMENTE")
print("=" * 50)

### 8.2 Pr√≥ximos Pasos y Recomendaciones

#### Implementaci√≥n en Producci√≥n:
1. **Integraci√≥n con FastAPI**: El modelo est√° listo para ser integrado con el backend FastAPI existente
2. **Monitoreo**: Implementar logging y monitoreo de predicciones en tiempo real
3. **Reentrenamiento**: Establecer pipeline para reentrenamiento peri√≥dico con nuevos datos

#### Mejoras Futuras:
1. **Aumento de Datos**: Implementar t√©cnicas de augmentaci√≥n para balancear categor√≠as
2. **Ensemble Methods**: Combinar m√∫ltiples modelos para mejorar robustez
3. **Fine-tuning Avanzado**: Explorar modelos pre-entrenados espec√≠ficos del dominio m√©dico
4. **An√°lisis ROC Continuo**: Monitorear las curvas ROC para detectar degradaci√≥n del modelo

#### Validaci√≥n Cl√≠nica:
1. **Revisi√≥n M√©dica**: Validar predicciones con profesionales m√©dicos
2. **Casos Edge**: Identificar y manejar casos l√≠mite o ambiguos usando an√°lisis ROC
3. **Explicabilidad**: Implementar t√©cnicas para explicar las decisiones del modelo
4. **An√°lisis de Sensibilidad**: Usar m√©tricas ROC para optimizar sensibilidad vs especificidad

#### M√©tricas de Seguimiento:
1. **F1 Score**: Mantener F1 Weighted > 0.85
2. **ROC-AUC**: Mantener AUC > 0.85 para todas las categor√≠as
3. **Calibraci√≥n**: Verificar que las probabilidades est√©n bien calibradas
4. **Fairness**: Evaluar sesgo en diferentes subgrupos de pacientes

---

**Fecha de An√°lisis**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  
**Versi√≥n del Modelo**: v1.0  
**Autor**: Sistema de Clasificaci√≥n Biom√©dica  
**M√©tricas Clave**: F1={optimal_metrics['f1_weighted']:.3f}, AUC-Macro={roc_results['roc_auc']['macro']:.3f if 'roc_results' in locals() else 'N/A'}