In [None]:
# ============================================================================
# CELDA 1: IMPORTS Y CONFIGURACIÓN INICIAL
# ============================================================================

import json
import os
import pickle
import random
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
from torchvision.transforms import functional as TF
import torchvision.models as models

# Timm para modelos
import timm

# Sklearn
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler, RobustScaler
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    confusion_matrix, classification_report, roc_auc_score
)
from sklearn.utils.class_weight import compute_class_weight

# Imbalanced-learn para SMOTE
from imblearn.over_sampling import SMOTE, ADASYN
from imblearn.combine import SMOTETomek

# Albumentations para augmentations avanzadas
import albumentations as A
from albumentations.pytorch import ToTensorV2

# PIL
from PIL import Image


In [None]:
# Pick the compute device: CUDA GPU 0 when available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"🚀 Usando dispositivo: {device}")
if device.type == 'cuda':
    # Report which GPU will be used and its total memory in GB.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memoria disponible: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    
# General configuration
class Config:
    """Class-level constants shared by the whole notebook (paths, model, training, augmentation)."""

    # Paths
    BASE_DIR = Path('skin-cancer-local')
    IMAGE_DIRS = [
        BASE_DIR / 'images' / 'imgs_part_1',
        BASE_DIR / 'images' / 'imgs_part_2',
        BASE_DIR / 'images' / 'imgs_part_3'
    ]
    METADATA_PATH = BASE_DIR / 'metadata.csv'
    OUTPUT_DIR = Path('outputs')
    MODELS_DIR = OUTPUT_DIR / 'models'
    LOGS_DIR = OUTPUT_DIR / 'logs'

    # Image parameters
    # NOTE: 128 px is a deliberately *low* resolution for EfficientNetV2-M (the
    # backbone is normally run at several hundred pixels); it keeps memory and
    # training time down. Increase it if the GPU allows.
    IMG_SIZE = 128
    IMG_MEAN = [0.485, 0.456, 0.406]  # ImageNet per-channel mean (RGB)
    IMG_STD = [0.229, 0.224, 0.225]   # ImageNet per-channel std (RGB)

    # Image backbone (timm model name) and its regularisation
    EFFICIENTNET_MODEL = 'tf_efficientnetv2_m'
    USE_PRETRAINED = True
    DROP_PATH_RATE = 0.2  # passed to timm.create_model as drop_path_rate
    DROP_RATE = 0.3       # passed to timm.create_model as drop_rate

    # TabTransformer (tabular branch)
    TAB_EMBED_DIM = 128
    TAB_NUM_HEADS = 8
    TAB_NUM_LAYERS = 6
    TAB_DROPOUT = 0.3

    # Fusion head: widths of the hidden layers
    FUSION_HIDDEN_DIMS = [512, 256, 128]
    FUSION_DROPOUT = 0.4

    # Training
    BATCH_SIZE = 16  # adjust to the available GPU memory
    EPOCHS = 100
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    WARMUP_EPOCHS = 5
    PATIENCE = 15  # early-stopping patience (presumably in epochs)
    MIN_LR = 1e-7

    # Cross-validation
    N_FOLDS = 5

    # Batch-level augmentation (mixup_data / cutmix_data)
    MIXUP_ALPHA = 0.4
    CUTMIX_ALPHA = 1.0
    MIXUP_PROB = 0.5

    # Test-time augmentation: number of views (keep in sync with TTATransform)
    TTA_TRANSFORMS = 5

    # Random seed
    SEED = 42

    # Misc
    NUM_WORKERS = 1
    PIN_MEMORY = True
    AMP_ENABLED = True  # automatic mixed precision

# Create the output directories. parents=True lets any of these paths be nested
# deeper in Config without first creating the intermediate folders by hand.
Config.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
Config.MODELS_DIR.mkdir(parents=True, exist_ok=True)
Config.LOGS_DIR.mkdir(parents=True, exist_ok=True)

# Seed every random number generator used in the notebook
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Python's ``random`` module is seeded too: augmentation libraries such as
    albumentations may draw from it, so leaving it unseeded makes augmented
    batches differ between runs. Also forces deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN algorithms: slower, but repeatable results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(Config.SEED)

# One-shot summary of the main run settings.
print("\n".join([
    "✅ Configuración completada",
    f"   Tamaño de imagen: {Config.IMG_SIZE}x{Config.IMG_SIZE}",
    f"   Modelo: {Config.EFFICIENTNET_MODEL}",
    f"   Batch size: {Config.BATCH_SIZE}",
    f"   Épocas: {Config.EPOCHS}",
]))

In [None]:
# ============================================================================
# CELDA 2: CARGA Y ANÁLISIS EXPLORATORIO DE DATOS
# ============================================================================

# Load the metadata CSV
df = pd.read_csv(Config.METADATA_PATH)
print(f"📊 Dataset cargado: {len(df)} muestras")
print(f"   Columnas: {df.shape[1]}")

# First look at the data
print("\n📋 Primeras filas:")
print(df.head())

print("\n🔍 Información del dataset:")
# DataFrame.info() prints its own report and returns None; wrapping it in
# print() used to emit a stray "None" line.
df.info()

print("\n📈 Estadísticas descriptivas:")
print(df.describe())

# Locate each image file on disk
def find_image_path(img_id, image_dirs=None):
    """Return the path of the first regular file named ``img_id`` in the image directories.

    Args:
        img_id: image file name (value of the ``img_id`` column). Non-string or
            blank ids (e.g. NaN for a missing cell) yield ``None`` instead of
            raising ``TypeError``.
        image_dirs: directories to search, in order. Defaults to ``Config.IMAGE_DIRS``.

    Returns:
        The matching path as ``str``, or ``None`` when no such file exists.
    """
    if not isinstance(img_id, str) or not img_id.strip():
        return None
    if image_dirs is None:
        image_dirs = Config.IMAGE_DIRS
    for img_dir in image_dirs:
        img_path = Path(img_dir) / img_id
        # is_file() instead of exists(): a directory with the same name is not an image.
        if img_path.is_file():
            return str(img_path)
    return None

# Keep only the rows whose image file can actually be found.
print("\n🖼️ Verificando imágenes...")
df['image_path'] = df['img_id'].apply(find_image_path)
missing_images = df['image_path'].isna().sum()
print(f"   Imágenes encontradas: {len(df) - missing_images}/{len(df)}")
if missing_images > 0:
    print(f"   ⚠️ Imágenes faltantes: {missing_images}")
    # reset_index: without it the index keeps gaps, and integer lookups on a
    # Series taken from df (e.g. df['image_path'][i]) become label lookups.
    df = df.dropna(subset=['image_path']).reset_index(drop=True)

# Target variable
print("\n🎯 Distribución de diagnósticos:")
diagnostic_counts = df['diagnostic'].value_counts()
print(diagnostic_counts)
print(f"\n   Clases únicas: {df['diagnostic'].nunique()}")
print(f"   Desbalanceo máximo: {diagnostic_counts.max() / diagnostic_counts.min():.2f}x")

# Class distribution: bar chart + pie chart
fig, (ax_bar, ax_pie) = plt.subplots(1, 2, figsize=(12, 5))

diagnostic_counts.plot(kind='bar', color='steelblue', ax=ax_bar)
ax_bar.set_title('Distribución de Diagnósticos', fontsize=14, fontweight='bold')
ax_bar.set_xlabel('Diagnóstico')
ax_bar.set_ylabel('Frecuencia')
ax_bar.tick_params(axis='x', labelrotation=45)
ax_bar.grid(axis='y', alpha=0.3)

ax_pie.pie(diagnostic_counts.values, labels=diagnostic_counts.index, autopct='%1.1f%%')
ax_pie.set_title('Proporción de Clases', fontsize=14, fontweight='bold')

fig.tight_layout()
fig.savefig(Config.OUTPUT_DIR / 'class_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

# Categorical variables (same set the MetadataPreprocessor encodes, incl. 'pesticide')
categorical_cols = ['smoke', 'drink', 'gender', 'background_father', 'background_mother',
                    'pesticide', 'skin_cancer_history', 'cancer_history', 'has_piped_water',
                    'has_sewage_system', 'region', 'itch', 'grew', 'hurt',
                    'changed', 'bleed', 'elevation', 'biopsed']

print("\n📊 Valores únicos en variables categóricas:")
for col in categorical_cols:
    if col in df.columns:
        unique_vals = df[col].nunique()
        missing = df[col].isna().sum()
        print(f"   {col}: {unique_vals} valores únicos, {missing} missing ({missing/len(df)*100:.1f}%)")

# Numerical variables
numerical_cols = ['age', 'diameter_1', 'diameter_2', 'fitspatrick']

print("\n📊 Estadísticas de variables numéricas:")
for col in numerical_cols:
    if col in df.columns:
        print(f"\n   {col}:")
        print(f"      Media: {df[col].mean():.2f}")
        print(f"      Mediana: {df[col].median():.2f}")
        print(f"      Std: {df[col].std():.2f}")
        print(f"      Missing: {df[col].isna().sum()} ({df[col].isna().sum()/len(df)*100:.1f}%)")

# Pearson r against the label-encoded diagnosis. LabelEncoder orders classes
# alphabetically and the diagnoses are nominal, so these values are only a
# rough hint, not a real measure of association.
print("\n🔗 Análisis de correlaciones con diagnóstico:")
df_encoded = df.copy()
le_diag = LabelEncoder()
df_encoded['diagnostic_encoded'] = le_diag.fit_transform(df['diagnostic'])

for col in numerical_cols:
    if (col in df.columns
            and pd.api.types.is_numeric_dtype(df[col])
            and df[col].notna().sum() > 0):
        corr = df_encoded[[col, 'diagnostic_encoded']].corr().iloc[0, 1]
        print(f"   {col}: {corr:.3f}")

# Save the filtered metadata
print(f"\n💾 Dataset final: {len(df)} muestras con imágenes válidas")
df.to_csv(Config.OUTPUT_DIR / 'processed_metadata.csv', index=False)
print("✅ Análisis exploratorio completado")

In [None]:
# ============================================================================
# CELDA 3: PREPROCESAMIENTO AVANZADO DE METADATOS
# ============================================================================

class MetadataPreprocessor:
    """Encode, impute and scale the tabular lesion metadata.

    ``fit_transform`` learns label encoders, robust scalers and imputation values
    from a DataFrame. ``transform`` re-applies exactly that fitted state to new
    data. Both return ``(X, df)``:
      - ``X`` is a float32 array whose columns follow ``self.feature_names``.
      - ``df`` is a copy of the input with the intermediate columns added
        (``*_encoded``, ``*_scaled``, derived).

    Column layout of ``X``: all categorical ``*_encoded`` columns first, then all
    numerical ``*_scaled`` columns. ``TabTransformer`` embeds the categorical
    columns and treats the remaining ones as numerical, so a categorical column
    must never sit between numerical ones.
    """

    UNKNOWN = 'UNKNOWN'  # placeholder for missing / unseen categorical values
    CATEGORICAL_COLS = [
        'smoke', 'drink', 'gender', 'background_father', 'background_mother',
        'pesticide', 'skin_cancer_history', 'cancer_history',
        'has_piped_water', 'has_sewage_system', 'region',
        'itch', 'grew', 'hurt', 'changed', 'bleed', 'elevation', 'biopsed'
    ]
    NUMERICAL_COLS = ['age', 'diameter_1', 'diameter_2', 'fitspatrick']
    SYMPTOM_COLS = ['itch', 'grew', 'hurt', 'changed', 'bleed', 'elevation']
    RISK_COLS = ['smoke', 'drink', 'skin_cancer_history', 'cancer_history']
    # Open upper edge so ages above 100 still land in '66+';
    # include_lowest (see _age_groups) keeps age 0 in '0-18'.
    AGE_BINS = [0, 18, 35, 50, 65, np.inf]
    AGE_LABELS = ['0-18', '19-35', '36-50', '51-65', '66+']

    def __init__(self):
        self._reset()

    def _reset(self):
        """Clear all fitted state so ``fit_transform`` can safely be called again."""
        self.label_encoders = {}     # column -> fitted LabelEncoder (classes are strings)
        self.scalers = {}            # raw column -> fitted RobustScaler
        self.fill_values = {}        # numerical column -> training median used for imputation
        self.fallback_classes = {}   # categorical column -> class used for unseen values
        self.feature_names = []
        self.categorical_features = []
        self.numerical_features = []
        self.categorical_dims = {}   # *_encoded column -> vocabulary size (for embeddings)

    # ------------------------------------------------------------------ helpers
    @classmethod
    def _fill_categorical(cls, series):
        """Replace missing values with UNKNOWN; turn a pure bool column into strings."""
        series = series.fillna(cls.UNKNOWN)
        if series.dtype == 'bool':
            series = series.astype(str)
        return series

    @classmethod
    def _age_groups(cls, age):
        """Bin ages into AGE_LABELS strings; ages outside the bins become UNKNOWN."""
        groups = pd.cut(age, bins=cls.AGE_BINS, labels=cls.AGE_LABELS, include_lowest=True)
        # Convert to str before replacing: a Categorical rejects values outside its categories.
        return groups.astype(str).replace('nan', cls.UNKNOWN)

    @classmethod
    def _pick_fallback(cls, encoder, values):
        """Class that ``transform`` uses for values never seen during fit."""
        if cls.UNKNOWN in encoder.classes_:
            return cls.UNKNOWN
        modes = values.mode()  # most frequent training class
        return modes.iloc[0] if len(modes) else cls.UNKNOWN

    def _encode_known(self, col, values):
        """Label-encode ``values`` with the fitted encoder, mapping unseen values to the fallback."""
        le = self.label_encoders[col]
        # Compare as strings: the encoder was fitted on str values, so raw bools/floats
        # would otherwise never match its classes.
        values = values.astype(str)
        fallback = getattr(self, 'fallback_classes', {}).get(col)
        if fallback is None:  # preprocessors pickled before fallback_classes existed
            fallback = self.UNKNOWN if self.UNKNOWN in le.classes_ else le.classes_[0]
        values = values.where(values.isin(le.classes_), fallback)
        return le.transform(values)

    def _register_categorical(self, name, n_classes):
        """Record an encoded categorical feature and its vocabulary size."""
        self.categorical_features.append(name)
        self.categorical_dims[name] = n_classes

    def _fit_scaler(self, df, col):
        """Fit a RobustScaler on ``df[col]`` and add the ``{col}_scaled`` feature."""
        scaler = RobustScaler()
        df[f'{col}_scaled'] = scaler.fit_transform(df[[col]]).ravel()
        self.scalers[col] = scaler
        self.numerical_features.append(f'{col}_scaled')

    @staticmethod
    def _add_diameter_features(df):
        """Add raw ``diameter_ratio`` and ``lesion_area`` columns; return False if a diameter is absent."""
        if 'diameter_1' not in df.columns or 'diameter_2' not in df.columns:
            return False
        df['diameter_ratio'] = (df['diameter_1'] / (df['diameter_2'] + 1e-6)).fillna(1.0)
        df['lesion_area'] = df['diameter_1'] * df['diameter_2']
        return True

    @staticmethod
    def _add_encoded_sum(df, name, source_cols):
        """Add column ``name`` = row sum of the available ``{c}_encoded`` columns; False if none exist."""
        encoded = [f'{c}_encoded' for c in source_cols if f'{c}_encoded' in df.columns]
        if not encoded:
            return False
        # NOTE(review): this sums label-encoder *codes* (alphabetical order, with
        # UNKNOWN usually getting the largest code), not the number of positive
        # answers. Consider counting values equal to the "yes" class instead.
        df[name] = df[encoded].sum(axis=1)
        return True

    # ------------------------------------------------------------------ public API
    def fit_transform(self, df):
        """Fit encoders, scalers and imputation values on ``df`` and transform it.

        Returns:
            (X, df): float32 array ordered like ``self.feature_names``, and a copy
            of ``df`` with the intermediate columns added.
        """
        self._reset()  # re-fitting must not append duplicate features
        df = df.copy()

        # 1. Categorical variables -> label codes (missing values become UNKNOWN)
        for col in self.CATEGORICAL_COLS:
            if col not in df.columns:
                continue
            df[col] = self._fill_categorical(df[col])
            values = df[col].astype(str)
            le = LabelEncoder()
            df[f'{col}_encoded'] = le.fit_transform(values)
            self.label_encoders[col] = le
            self.fallback_classes[col] = self._pick_fallback(le, values)
            self._register_categorical(f'{col}_encoded', len(le.classes_))

        # 2. Numerical variables -> median imputation + robust scaling (outlier-tolerant)
        for col in self.NUMERICAL_COLS:
            if col not in df.columns:
                continue
            self.fill_values[col] = df[col].median()
            df[col] = df[col].fillna(self.fill_values[col])
            self._fit_scaler(df, col)

        # 3. Feature engineering
        # Diameter ratio and approximate lesion area
        if self._add_diameter_features(df):
            self._fit_scaler(df, 'diameter_ratio')
            self.fill_values['lesion_area'] = df['lesion_area'].median()
            df['lesion_area'] = df['lesion_area'].fillna(self.fill_values['lesion_area'])
            self._fit_scaler(df, 'lesion_area')

        # Age groups (computed from the imputed age)
        if 'age' in df.columns:
            df['age_group'] = self._age_groups(df['age'])
            le_age = LabelEncoder()
            df['age_group_encoded'] = le_age.fit_transform(df['age_group'])
            self.label_encoders['age_group'] = le_age
            self.fallback_classes['age_group'] = self._pick_fallback(le_age, df['age_group'])
            self._register_categorical('age_group_encoded', len(le_age.classes_))

        # Aggregated symptom / risk-factor features
        for name, source_cols in (('symptom_count', self.SYMPTOM_COLS),
                                  ('risk_score', self.RISK_COLS)):
            if self._add_encoded_sum(df, name, source_cols):
                self._fit_scaler(df, name)

        # Categorical block first, numerical block second (see class docstring)
        self.feature_names = self.categorical_features + self.numerical_features

        X = df[self.feature_names].values.astype(np.float32)

        print(f"✅ Preprocesamiento completado:")
        print(f"   Features categóricas: {len(self.categorical_features)}")
        print(f"   Features numéricas: {len(self.numerical_features)}")
        print(f"   Total features: {len(self.feature_names)}")
        print(f"   Shape: {X.shape}")

        return X, df

    def transform(self, df):
        """Transform new data with the state learned by ``fit_transform``.

        - Missing numerical values are imputed with the training medians.
        - Unseen categories map to UNKNOWN when that class existed during fit,
          otherwise to the most frequent training class.

        Returns:
            (X, df) with the same column layout as ``fit_transform``.
        """
        df = df.copy()
        # Preprocessors pickled before fill_values existed fall back to the new data's median.
        fill_values = getattr(self, 'fill_values', {})

        # Impute numerical columns first: age_group is derived from the imputed age,
        # exactly as in fit_transform.
        for col in self.NUMERICAL_COLS:
            if col in self.scalers:
                df[col] = df[col].fillna(fill_values.get(col, df[col].median()))

        # Categorical columns
        for col in self.label_encoders:
            if col == 'age_group':
                continue
            df[col] = self._fill_categorical(df[col])
            df[f'{col}_encoded'] = self._encode_known(col, df[col])

        # Derived features
        if 'diameter_ratio' in self.scalers or 'lesion_area' in self.scalers:
            self._add_diameter_features(df)
            if 'lesion_area' in self.scalers:
                df['lesion_area'] = df['lesion_area'].fillna(
                    fill_values.get('lesion_area', df['lesion_area'].median())
                )

        if 'age_group' in self.label_encoders:
            df['age_group'] = self._age_groups(df['age'])
            df['age_group_encoded'] = self._encode_known('age_group', df['age_group'])

        for name, source_cols in (('symptom_count', self.SYMPTOM_COLS),
                                  ('risk_score', self.RISK_COLS)):
            if name in self.scalers:
                self._add_encoded_sum(df, name, source_cols)

        # Scale every fitted numerical column
        for col, scaler in self.scalers.items():
            df[f'{col}_scaled'] = scaler.transform(df[[col]]).ravel()

        X = df[self.feature_names].values.astype(np.float32)
        return X, df

# Fit the preprocessor and build the tabular feature matrix.
# NOTE(review): fitting on the full dataset lets validation folds influence the
# encoders, scalers and imputation medians (mild leakage); consider refitting
# inside each CV fold.
preprocessor = MetadataPreprocessor()
X_metadata, df_processed = preprocessor.fit_transform(df)

print(f"\n📊 Metadata shape: {X_metadata.shape}")
print(f"   Rango de valores: [{X_metadata.min():.3f}, {X_metadata.max():.3f}]")

# Persist the fitted preprocessor so inference reuses the exact same transform.
# Security: only ever unpickle this file from a trusted source (pickle can run code).
with open(Config.OUTPUT_DIR / 'metadata_preprocessor.pkl', 'wb') as f:
    pickle.dump(preprocessor, f)

print("✅ Preprocessor guardado")

In [None]:
# ============================================================================
# CELDA 4: AUGMENTATIONS AVANZADAS Y DATASET
# ============================================================================

class AdvancedAugmentation:
    """Albumentations pipeline for skin-lesion images.

    Pipeline by mode:
      - ``mode='train'``: resize, then a stochastic chain (geometry, optical
        distortion, perspective, colour, lighting, blur, noise, JPEG artefacts,
        cutout).
      - Any other mode: resize only.

    Both end with normalisation by ``Config.IMG_MEAN`` / ``Config.IMG_STD`` and
    ``ToTensorV2``. Calling the object accepts a file path, a PIL image or a numpy
    array (assumed HWC with 0-255 values) and returns the transformed tensor.
    """

    def __init__(self, img_size=512, mode='train'):
        self.img_size = img_size
        self.mode = mode

        pipeline = [A.Resize(img_size, img_size)]
        if mode == 'train':
            pipeline.extend(self._train_augmentations(img_size))
        pipeline.append(
            A.Normalize(mean=Config.IMG_MEAN, std=Config.IMG_STD, max_pixel_value=255.0)
        )
        pipeline.append(ToTensorV2())
        self.transform = A.Compose(pipeline)

    @staticmethod
    def _train_augmentations(img_size):
        """Stochastic train-time transforms, applied after resizing and before normalisation."""
        geometric = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=45,
                               border_mode=0, p=0.7),
        ]
        optical = A.OneOf([
            A.ElasticTransform(alpha=1, sigma=50, p=1.0),
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
            A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1.0),
        ], p=0.3)
        perspective = A.Perspective(scale=(0.05, 0.1), p=0.3)
        colour = A.OneOf([
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30,
                                 val_shift_limit=20, p=1.0),
            A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=1.0),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),
        ], p=0.8)
        lighting = A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
            A.RandomGamma(gamma_limit=(80, 120), p=1.0),
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
        ], p=0.5)
        blur = A.OneOf([
            A.GaussianBlur(blur_limit=(3, 7), p=1.0),
            A.MotionBlur(blur_limit=7, p=1.0),
            A.MedianBlur(blur_limit=7, p=1.0),
        ], p=0.2)
        noise = A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=1.0),
            A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=1.0),
        ], p=0.2)
        jpeg = A.ImageCompression(quality_lower=75, quality_upper=100, p=0.3)
        hole_size = int(img_size * 0.15)  # cutout holes up to 15% of the side
        cutout = A.CoarseDropout(max_holes=8, max_height=hole_size, max_width=hole_size,
                                 min_holes=1, fill_value=0, p=0.3)
        return geometric + [optical, perspective, colour, lighting, blur, noise, jpeg, cutout]

    def __call__(self, image):
        """Transform a path / PIL image / numpy array and return the resulting tensor."""
        if isinstance(image, (str, Image.Image)):
            pil_image = Image.open(image) if isinstance(image, str) else image
            image = np.array(pil_image.convert('RGB'))
        return self.transform(image=image)['image']


class TTATransform:
    """Test-time augmentation: returns one transformed tensor per view of an image.

    Views, in order: original, horizontal flip, vertical flip, exact 90° rotation,
    random brightness/contrast jitter. The first four are deterministic; the last
    is re-sampled on every call.
    """

    def __init__(self, img_size=512):
        self.img_size = img_size

        def view(*ops):
            # Every view shares the same resize -> op(s) -> normalize -> tensor pipeline.
            return A.Compose([
                A.Resize(img_size, img_size),
                *ops,
                A.Normalize(mean=Config.IMG_MEAN, std=Config.IMG_STD),
                ToTensorV2()
            ])

        self.transforms = [
            view(),                           # original
            view(A.HorizontalFlip(p=1.0)),    # horizontal flip
            view(A.VerticalFlip(p=1.0)),      # vertical flip
            # Exact 90° rotation: transpose followed by a vertical flip equals np.rot90.
            # A.Rotate(limit=90) would instead sample a random angle in [-90, 90]
            # and pad the corners, making this view non-deterministic.
            view(A.Transpose(p=1.0), A.VerticalFlip(p=1.0)),
            # Brightness/contrast jitter (random on each call)
            view(A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0)),
        ]

    def __call__(self, image):
        """Return the list of view tensors for a path, PIL image or numpy array."""
        if isinstance(image, str):
            image = np.array(Image.open(image).convert('RGB'))
        elif isinstance(image, Image.Image):
            image = np.array(image.convert('RGB'))

        return [t(image=image)['image'] for t in self.transforms]


class SkinCancerDataset(Dataset):
    """Image + tabular-metadata dataset for the multimodal model.

    Each item is a dict with three entries:
      - ``image``: transformed image tensor.
      - ``metadata``: float32 feature row.
      - ``label``: int64 class id.

    Mixup/CutMix are not applied here; ``mixup_data`` / ``cutmix_data`` operate
    on whole batches.

    Args:
        image_paths: sequence of image file paths (list, array or pandas Series).
            Items are looked up by position, so a Series' index labels are ignored.
        metadata: array-like with one feature row per image.
        labels: integer class ids, one per image.
        transform: callable taking a PIL image and returning a tensor; when None,
            ``T.ToTensor()`` is used.
        mode: stored as ``self.mode``; not used by this class.

    Raises:
        ValueError: if the three inputs do not have the same length.
    """

    def __init__(self, image_paths, metadata, labels, transform=None, mode='train'):
        # list(): positional access even for a pandas Series whose index has gaps
        # (e.g. after dropna) -- Series[idx] would be a label lookup.
        self.image_paths = list(image_paths)
        # torch.tensor always copies; np.asarray accepts arrays, lists, Series and DataFrames.
        self.metadata = torch.tensor(np.asarray(metadata), dtype=torch.float32)
        self.labels = torch.tensor(np.asarray(labels), dtype=torch.long)
        if not len(self.image_paths) == len(self.metadata) == len(self.labels):
            raise ValueError(
                f"Length mismatch: {len(self.image_paths)} images, "
                f"{len(self.metadata)} metadata rows, {len(self.labels)} labels"
            )
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file; convert() returns a fully loaded copy.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        image = self.transform(image) if self.transform else T.ToTensor()(image)

        return {
            'image': image,
            'metadata': self.metadata[idx],
            'label': self.labels[idx]
        }


def mixup_data(x_img, x_meta, y, alpha=0.4, categorical_indices=None):
    """Mixup for a batch of images and metadata.

    Each sample is blended with a randomly permuted partner using the weight
    ``lam ~ Beta(alpha, alpha)``. When ``alpha <= 0``, ``lam = 1`` (no mixing).

    Args:
        x_img: image batch (batch dimension first).
        x_meta: metadata batch, assumed ``[batch, features]``.
        y: labels.
        alpha: Beta distribution parameter.
        categorical_indices: optional metadata column indices that hold category
            codes. Interpolating codes gives fractional values that match no
            category, so these columns are copied from whichever sample dominates
            the mix (``lam >= 0.5`` -> original sample, else its partner).
            ``None`` (default) blends every column, as before.

    Returns:
        ``(mixed_img, mixed_meta, y_a, y_b, lam)``; train with
        ``lam * loss(pred, y_a) + (1 - lam) * loss(pred, y_b)``.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    batch_size = x_img.size(0)
    index = torch.randperm(batch_size).to(x_img.device)

    mixed_img = lam * x_img + (1 - lam) * x_img[index]
    mixed_meta = lam * x_meta + (1 - lam) * x_meta[index]

    if categorical_indices is not None and len(categorical_indices) > 0:
        cols = list(categorical_indices)
        dominant = x_meta if lam >= 0.5 else x_meta[index]
        mixed_meta[:, cols] = dominant[:, cols]

    y_a, y_b = y, y[index]
    return mixed_img, mixed_meta, y_a, y_b, lam


def cutmix_data(x_img, x_meta, y, alpha=1.0):
    """CutMix: paste a random rectangle from a shuffled partner into every image.

    The box side scales with ``sqrt(1 - lam)`` where ``lam ~ Beta(alpha, alpha)``
    (``lam = 1`` when ``alpha <= 0``, giving an empty box). The box is clipped to
    the image, and the returned ``lam`` is recomputed as the fraction of pixels
    kept from the original image. Metadata is returned untouched (same object).

    Returns:
        ``(mixed_img, x_meta, y_a, y_b, lam)`` with ``y_a = y`` and ``y_b`` the
        permuted labels.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x_img.size(0)).to(x_img.device)

    _, _, height, width = x_img.size()
    ratio = np.sqrt(1. - lam)
    half_w = int(width * ratio) // 2
    half_h = int(height * ratio) // 2

    # Random centre (x first, then y), then clip the box to the image bounds.
    center_x, center_y = np.random.randint(width), np.random.randint(height)
    x1, x2 = np.clip([center_x - half_w, center_x + half_w], 0, width)
    y1, y2 = np.clip([center_y - half_h, center_y + half_h], 0, height)

    mixed_img = x_img.clone()
    mixed_img[:, :, y1:y2, x1:x2] = x_img[perm, :, y1:y2, x1:x2]

    # Actual kept-pixel fraction after clipping.
    lam = 1 - ((x2 - x1) * (y2 - y1) / (width * height))
    return mixed_img, x_meta, y, y[perm], lam


# Summary of the augmentation settings.
print("\n".join([
    "✅ Augmentations y Dataset configurados",
    f"   Políticas de augmentation: Train (avanzado) / Val (básico)",
    f"   TTA transforms: {Config.TTA_TRANSFORMS}",
    f"   Mixup alpha: {Config.MIXUP_ALPHA}",
    f"   CutMix alpha: {Config.CUTMIX_ALPHA}",
]))

In [None]:
# ============================================================================
# CELDA 5: ARQUITECTURA DEL MODELO - FUSIÓN MULTIMODAL
# ============================================================================

class TabTransformer(nn.Module):
    """Transformer encoder over tabular features.

    Tokenisation and pooling:
      - Each categorical feature becomes one token through its own ``nn.Embedding``.
      - All numerical columns together are projected into one extra token.
      - A learned positional embedding is added, the tokens pass through an
        ``nn.TransformerEncoder``, and they are mean-pooled into
        ``[batch, embed_dim]``.

    Args:
        categorical_dims: mapping feature name -> number of categories.
        numerical_features: number of numerical columns (an int count).
        embed_dim, num_heads, num_layers, dropout: transformer hyper-parameters.
    """

    def __init__(self, 
                 categorical_dims,
                 numerical_features,
                 embed_dim=128,
                 num_heads=8,
                 num_layers=6,
                 dropout=0.3):
        super().__init__()

        self.categorical_dims = categorical_dims
        self.num_numerical = numerical_features
        self.embed_dim = embed_dim

        # One embedding table per categorical feature
        self.categorical_embeddings = nn.ModuleDict({
            name: nn.Embedding(dim + 1, embed_dim)
            for name, dim in categorical_dims.items()
        })

        # Projection of all numerical features into a single token
        if self.num_numerical > 0:
            self.numerical_projection = nn.Sequential(
                nn.Linear(self.num_numerical, embed_dim),
                nn.LayerNorm(embed_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # Learned positional embedding, one per token
        total_features = len(categorical_dims) + (1 if self.num_numerical > 0 else 0)
        self.pos_embedding = nn.Parameter(torch.randn(1, total_features, embed_dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final layer normalization
        self.norm = nn.LayerNorm(embed_dim)

    def _numerical_columns(self, x, cat_indices):
        """Return the columns of ``x`` not used by any categorical feature, in column order."""
        cat_cols = {int(cat_indices[name]) for name in self.categorical_dims}
        num_cols = [i for i in range(x.shape[1]) if i not in cat_cols]
        if len(num_cols) != self.num_numerical:
            raise ValueError(
                f"expected {self.num_numerical} numerical columns, found {len(num_cols)} "
                f"(x has {x.shape[1]} columns, {len(cat_cols)} categorical)"
            )
        return x[:, num_cols]

    def forward(self, x, cat_indices):
        """Encode a batch of tabular rows.

        Args:
            x: tensor ``[batch, total_features]`` holding categorical codes and
                numerical values.
            cat_indices: mapping categorical feature name -> column index in ``x``.
                Every column not listed here is treated as numerical (kept in
                column order), so categorical and numerical columns may be
                interleaved; previously the numerical block was assumed to start
                right after the first ``len(categorical_dims)`` columns.

        Returns:
            Tensor ``[batch, embed_dim]``.
        """
        tokens = []

        # Categorical embeddings
        for name, dim in self.categorical_dims.items():
            cat_vals = x[:, cat_indices[name]].long()
            # Clamp into [0, dim-1] so out-of-range codes cannot index past the table.
            cat_vals = torch.clamp(cat_vals, 0, dim - 1)
            tokens.append(self.categorical_embeddings[name](cat_vals))

        # Numerical token
        if self.num_numerical > 0:
            tokens.append(self.numerical_projection(self._numerical_columns(x, cat_indices)))

        h = torch.stack(tokens, dim=1)  # [batch, n_tokens, embed_dim]
        h = h + self.pos_embedding
        h = self.transformer(h)
        h = h.mean(dim=1)  # global average pooling -> [batch, embed_dim]
        return self.norm(h)


class MultimodalSkinCancerModel(nn.Module):
    """Multimodal model: timm image backbone + TabTransformer metadata branch, late fusion.

    The pooled image features and the metadata embedding are concatenated and fed through
    an MLP fusion head (Linear -> BatchNorm1d -> GELU -> Dropout per hidden size in
    ``Config.FUSION_HIDDEN_DIMS``) followed by a linear classifier.

    Args:
        num_classes: number of output logits.
        categorical_dims: forwarded to TabTransformer as ``categorical_dims``.
        num_numerical_features: forwarded to TabTransformer as ``numerical_features``.
        img_model_name: timm model name for the image backbone.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self,
                 num_classes,
                 categorical_dims,
                 num_numerical_features,
                 img_model_name='tf_efficientnetv2_m',
                 pretrained=True):
        super().__init__()

        # ===== BRANCH 1: IMAGE MODEL (EfficientNetV2) =====
        self.image_model = timm.create_model(
            img_model_name,
            pretrained=pretrained,
            num_classes=0,  # Remove classifier so the model returns pooled features
            drop_rate=Config.DROP_RATE,
            drop_path_rate=Config.DROP_PATH_RATE
        )

        # Probe the feature dimension with a dummy forward pass. This must run in eval mode:
        # torch.no_grad() does not stop BatchNorm layers in train mode from updating their
        # running_mean / running_var, so a random input would corrupt the pretrained stats.
        was_training = self.image_model.training
        self.image_model.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, Config.IMG_SIZE, Config.IMG_SIZE)
                img_features = self.image_model(dummy_input)
        finally:
            self.image_model.train(was_training)
        self.img_feature_dim = img_features.shape[1]

        print(f"✅ EfficientNetV2 cargado: {self.img_feature_dim} features")

        # ===== BRANCH 2: METADATA MODEL (TabTransformer) =====
        self.metadata_model = TabTransformer(
            categorical_dims=categorical_dims,
            numerical_features=num_numerical_features,
            embed_dim=Config.TAB_EMBED_DIM,
            num_heads=Config.TAB_NUM_HEADS,
            num_layers=Config.TAB_NUM_LAYERS,
            dropout=Config.TAB_DROPOUT
        )

        self.meta_feature_dim = Config.TAB_EMBED_DIM

        print(f"✅ TabTransformer construido: {self.meta_feature_dim} features")

        # ===== FUSION HEAD =====
        fusion_input_dim = self.img_feature_dim + self.meta_feature_dim

        fusion_layers = []
        prev_dim = fusion_input_dim

        for hidden_dim in Config.FUSION_HIDDEN_DIMS:
            fusion_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(Config.FUSION_DROPOUT)
            ])
            prev_dim = hidden_dim

        self.fusion_head = nn.Sequential(*fusion_layers)

        # Classifier
        self.classifier = nn.Linear(prev_dim, num_classes)

        # Gating weights for an optional attention-weighted fusion. NOTE: not used by
        # forward() (the gating code there is commented out); kept so state_dict keys stay stable.
        self.attention = nn.Sequential(
            nn.Linear(fusion_input_dim, 2),
            nn.Softmax(dim=1)
        )

        print(f"✅ Fusion head: {fusion_input_dim} -> {prev_dim} -> {num_classes}")

        # Weight initialization (new layers only)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for Linear layers and reset BatchNorm1d, outside the backbone.

        Modules belonging to ``self.image_model`` are skipped so pretrained backbone
        weights are never overwritten.
        """
        backbone_module_ids = {id(m) for m in self.image_model.modules()}
        for m in self.modules():
            if id(m) in backbone_module_ids:
                continue
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, images, metadata, cat_indices):
        """Compute class logits.

        Args:
            images: image batch [batch, 3, H, W].
            metadata: metadata batch [batch, n_features], passed to the TabTransformer.
            cat_indices: dict with the column index of each categorical feature.

        Returns:
            Logits of shape [batch, num_classes].
        """
        img_features = self.image_model(images)  # [batch, img_dim]
        meta_features = self.metadata_model(metadata, cat_indices)  # [batch, meta_dim]

        # Late fusion by concatenation.
        fused = torch.cat([img_features, meta_features], dim=1)

        # Optional: Attention-weighted fusion (disabled)
        # att_weights = self.attention(fused)  # [batch, 2]
        # img_weight = att_weights[:, 0:1]
        # meta_weight = att_weights[:, 1:2]
        # fused = torch.cat([
        #     img_features * img_weight,
        #     meta_features * meta_weight
        # ], dim=1)

        fused = self.fusion_head(fused)
        logits = self.classifier(fused)

        return logits

    def freeze_image_backbone(self):
        """Disable gradients for every image-backbone parameter."""
        for param in self.image_model.parameters():
            param.requires_grad = False

    def unfreeze_image_backbone(self):
        """Re-enable gradients for every image-backbone parameter."""
        for param in self.image_model.parameters():
            param.requires_grad = True


# Model factory
def create_model(num_classes, categorical_dims, num_numerical_features):
    """Build a MultimodalSkinCancerModel whose backbone name and pretrained flag come from Config."""
    model_kwargs = dict(
        num_classes=num_classes,
        categorical_dims=categorical_dims,
        num_numerical_features=num_numerical_features,
        img_model_name=Config.EFFICIENTNET_MODEL,
        pretrained=Config.USE_PRETRAINED,
    )
    return MultimodalSkinCancerModel(**model_kwargs)


# Parameter counting
def count_parameters(model):
    """Return ``(trainable, total)`` element counts over ``model.parameters()``.

    ``trainable`` only includes tensors with ``requires_grad=True``.
    """
    trainable = total = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return trainable, total


# Summary of the architecture configuration.
print("\n".join([
    "✅ Arquitectura del modelo definida",
    f"   Modelo de imagen: {Config.EFFICIENTNET_MODEL}",
    f"   TabTransformer: {Config.TAB_NUM_LAYERS} layers, {Config.TAB_NUM_HEADS} heads",
    f"   Fusion dims: {Config.FUSION_HIDDEN_DIMS}",
]))

In [None]:
# ============================================================================
# CELDA 6: LOSS FUNCTIONS Y MÉTRICAS AVANZADAS
# ============================================================================

class FocalLoss(nn.Module):
    """Focal loss for class imbalance: ``alpha[y] * (1 - p_t) ** gamma * CE``.

    ``p_t`` is the predicted probability of the true class.

    Args:
        alpha: optional per-class weights (tensor or sequence indexed by class id);
            None means unweighted.
        gamma: focusing parameter; 0 reduces to (alpha-weighted) per-sample cross-entropy.
        reduction: 'mean', 'sum', or anything else to return per-sample losses.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss for logits ``inputs`` [batch, C] and integer ``targets`` [batch]."""
        # Unweighted CE so that exp(-CE) is exactly p_t. Passing the class weights into
        # cross_entropy would give exp(-w * CE) = p_t ** w and distort the focal factor.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # Apply the per-class weight after the focal term; move it to the loss device/dtype.
            alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing.

    loss = (1 - smoothing) * NLL(true class) + smoothing * mean_over_classes(-log p),
    averaged over the batch.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the smoothed loss for logits ``pred`` [batch, C] and integer ``target`` [batch]."""
        log_probs = pred.log_softmax(dim=1)
        # Uniform part: batch-mean of the summed negative log-probs, divided by the class count.
        uniform_term = -log_probs.sum(dim=1).mean() / pred.size(1)
        target_term = F.nll_loss(log_probs, target, reduction='mean')
        eps = self.smoothing
        return eps * uniform_term + (1 - eps) * target_term


class CombinedLoss(nn.Module):
    """Weighted sum of a focal loss and a label-smoothed cross-entropy.

    loss = focal_weight * FocalLoss(pred, target) + ls_weight * LabelSmoothingCE(pred, target)

    Args:
        class_weights: per-class weights forwarded to FocalLoss as ``alpha``.
        focal_gamma: focusing parameter of the focal term.
        label_smoothing: smoothing factor of the label-smoothing term.
        focal_weight: coefficient of the focal term (defaults to the previous fixed 0.7).
        ls_weight: coefficient of the label-smoothing term (defaults to the previous fixed 0.3).
    """

    def __init__(self, class_weights=None, focal_gamma=2.0, label_smoothing=0.1,
                 focal_weight=0.7, ls_weight=0.3):
        super().__init__()

        # Focal term: handles class imbalance
        self.focal_loss = FocalLoss(alpha=class_weights, gamma=focal_gamma)

        # Label-smoothing term: regularization
        self.ls_loss = LabelSmoothingCrossEntropy(smoothing=label_smoothing)

        # Mixing coefficients
        self.focal_weight = focal_weight
        self.ls_weight = ls_weight

    def forward(self, pred, target):
        """Return the weighted sum of both loss terms."""
        focal = self.focal_loss(pred, target)
        ls = self.ls_loss(pred, target)

        return self.focal_weight * focal + self.ls_weight * ls


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup/CutMix loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


class MetricsTracker:
    """Accumulates per-batch predictions, targets and losses, then computes epoch metrics.

    Args:
        num_classes: number of classes; fixes the confusion-matrix size.
        class_names: display names, where position i names class index i.
    """

    def __init__(self, num_classes, class_names):
        self.num_classes = num_classes
        self.class_names = class_names
        self.reset()

    def reset(self):
        """Clear all accumulated predictions, targets and losses."""
        self.predictions = []
        self.targets = []
        self.losses = []

    def update(self, preds, targets, loss):
        """Record one batch.

        Args:
            preds: logits [batch, num_classes]; only the argmax over dim 1 is stored.
            targets: integer labels [batch].
            loss: scalar batch loss (callers in this notebook pass ``loss.item()``).
        """
        pred_classes = torch.argmax(preds, dim=1)

        self.predictions.extend(pred_classes.cpu().numpy())
        self.targets.extend(targets.cpu().numpy())
        self.losses.append(loss)

    def compute(self):
        """Compute metrics over everything recorded since the last reset.

        Returns:
            dict with loss, accuracy, weighted and macro precision/recall/F1, the
            ``num_classes x num_classes`` confusion matrix, and per-class accuracy
            (NaN for classes with no samples in the targets).
        """
        preds = np.array(self.predictions)
        targets = np.array(self.targets)

        acc = accuracy_score(targets, preds)

        # Weighted averages
        precision = precision_score(targets, preds, average='weighted', zero_division=0)
        recall = recall_score(targets, preds, average='weighted', zero_division=0)
        f1 = f1_score(targets, preds, average='weighted', zero_division=0)

        # Macro averages (more informative under class imbalance)
        precision_macro = precision_score(targets, preds, average='macro', zero_division=0)
        recall_macro = recall_score(targets, preds, average='macro', zero_division=0)
        f1_macro = f1_score(targets, preds, average='macro', zero_division=0)

        avg_loss = np.mean(self.losses)

        # Pin the label set so row/column i always corresponds to class i. Without it, classes
        # absent from both targets and predictions are dropped, the matrix shrinks, and the
        # per-class values below get zipped onto the wrong class names.
        cm = confusion_matrix(targets, preds, labels=np.arange(self.num_classes))

        # Per-class accuracy (= per-class recall): diagonal / row support, NaN when no support.
        support = cm.sum(axis=1)
        per_class_acc = np.divide(
            cm.diagonal(),
            support,
            out=np.full(self.num_classes, np.nan),
            where=support > 0,
        )

        metrics = {
            'loss': avg_loss,
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'confusion_matrix': cm,
            'per_class_accuracy': dict(zip(self.class_names, per_class_acc))
        }

        return metrics

    def print_metrics(self, metrics, prefix=''):
        """Print a metrics dict produced by ``compute`` in a readable layout."""
        print(f"\n{prefix} Métricas:")
        print(f"  Loss: {metrics['loss']:.4f}")
        print(f"  Accuracy: {metrics['accuracy']:.4f}")
        print(f"  Precision (weighted): {metrics['precision']:.4f}")
        print(f"  Recall (weighted): {metrics['recall']:.4f}")
        print(f"  F1-Score (weighted): {metrics['f1']:.4f}")
        print(f"  Precision (macro): {metrics['precision_macro']:.4f}")
        print(f"  Recall (macro): {metrics['recall_macro']:.4f}")
        print(f"  F1-Score (macro): {metrics['f1_macro']:.4f}")

        print(f"\n  Accuracy por clase:")
        for class_name, acc in metrics['per_class_accuracy'].items():
            print(f"    {class_name}: {acc:.4f}")


def plot_confusion_matrix(cm, class_names, save_path=None):
    """Plot a row-normalized confusion matrix (rows = true class, columns = predicted).

    Args:
        cm: square matrix of counts.
        class_names: tick labels for both axes.
        save_path: optional output path; the figure is saved (after layout) before showing.
    """
    fig = plt.figure(figsize=(12, 10))

    # Row-normalize by support; rows with no samples become NaN (blank cells) instead of
    # triggering a division-by-zero warning.
    counts = np.asarray(cm, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    cm_norm = np.divide(counts, row_sums, out=np.full_like(counts, np.nan), where=row_sums > 0)

    sns.heatmap(
        cm_norm,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Proporción'}
    )

    plt.title('Matriz de Confusión Normalizada', fontsize=16, fontweight='bold')
    plt.ylabel('Verdadero', fontsize=12)
    plt.xlabel('Predicho', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    # Apply the layout before saving so the saved file matches what is displayed.
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()


def plot_training_history(history, save_path=None):
    """Plot train/val loss, accuracy and weighted F1 per epoch, plus the learning rate.

    Args:
        history: dict with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc',
            'train_f1', 'val_f1' and 'lr' (per-epoch lists, as built by Trainer).
        save_path: optional output path; the figure is saved (after layout) before showing.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (axis, history key suffix, title, y-label) for the train-vs-val panels.
    panels = [
        (axes[0, 0], 'loss', 'Loss', 'Loss'),
        (axes[0, 1], 'acc', 'Accuracy', 'Accuracy'),
        (axes[1, 0], 'f1', 'F1-Score (Weighted)', 'F1-Score'),
    ]
    for ax, key, title, ylabel in panels:
        ax.plot(history[f'train_{key}'], label='Train', linewidth=2)
        ax.plot(history[f'val_{key}'], label='Val', linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(alpha=0.3)

    # Learning rate (log scale)
    lr_ax = axes[1, 1]
    lr_ax.plot(history['lr'], linewidth=2, color='green')
    lr_ax.set_title('Learning Rate', fontsize=14, fontweight='bold')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('LR')
    lr_ax.set_yscale('log')
    lr_ax.grid(alpha=0.3)

    # Apply the layout before saving so the saved file matches what is displayed.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()


# Summary of the configured losses and metrics.
print("\n".join(
    ["✅ Loss functions y métricas configuradas"]
    + [f"   - {item}" for item in (
        "Focal Loss (para desbalanceo)",
        "Label Smoothing (regularización)",
        "Combined Loss (0.7*Focal + 0.3*LS)",
        "Métricas: Accuracy, Precision, Recall, F1 (weighted y macro)",
    )]
))

In [None]:
# ============================================================================
# CELDA 7: PREPARACIÓN Y BALANCEO DE DATOS
# ============================================================================

# Map diagnostic strings to consecutive integer labels.
le_target = LabelEncoder()
df_processed['diagnostic_encoded'] = le_target.fit_transform(df_processed['diagnostic'])

class_names = le_target.classes_
num_classes = len(class_names)

print(f"🎯 Clases detectadas: {num_classes}")
encoded_labels = df_processed['diagnostic_encoded']
for i, name in enumerate(class_names):
    print(f"   {i}: {name} - {(encoded_labels == i).sum()} muestras")

# Arrays shared by the split / dataset code below.
image_paths = df_processed['image_path'].values
labels = df_processed['diagnostic_encoded'].values

# Column position of each categorical feature for the TabTransformer. Assumes the
# preprocessor places the categorical columns first, in this order — TODO confirm.
cat_feature_indices = {
    feat: position for position, feat in enumerate(preprocessor.categorical_features)
}

print(f"\n📊 Features para TabTransformer:")
print(f"   Categóricas: {len(preprocessor.categorical_features)}")
print(f"   Numéricas: {len(preprocessor.numerical_features)}")
print(f"   Total: {X_metadata.shape[1]}")

# Stratified split over row indices: hold out 15% for test, then carve validation out of
# the remainder so that it is also ~15% of the full dataset.
all_row_indices = np.arange(len(image_paths))
X_temp, X_test, y_temp, y_test = train_test_split(
    all_row_indices, labels,
    test_size=0.15, stratify=labels, random_state=Config.SEED,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp,
    test_size=0.15 / 0.85, stratify=y_temp, random_state=Config.SEED,
)

n_total_rows = len(labels)
print(f"\n📦 Splits:")
for split_name, split_indices in [('Train', X_train), ('Val', X_val), ('Test', X_test)]:
    print(f"   {split_name}: {len(split_indices)} ({len(split_indices)/n_total_rows*100:.1f}%)")

# Per-class distribution inside each split
print(f"\n📊 Distribución por split:")
for split_name, split_indices in [('Train', X_train), ('Val', X_val), ('Test', X_test)]:
    print(f"\n{split_name}:")
    split_labels = labels[split_indices]
    split_size = len(split_labels)
    for i, name in enumerate(class_names):
        count = (split_labels == i).sum()
        print(f"   {name}: {count} ({count / split_size * 100:.1f}%)")

# 'balanced' class weights computed on the training split only, moved to the training device.
train_class_ids = np.unique(y_train)
class_weights = torch.FloatTensor(
    compute_class_weight(class_weight='balanced', classes=train_class_ids, y=y_train)
).to(device)

print(f"\n⚖️ Class weights calculados:")
for name, weight in zip(class_names, class_weights):
    print(f"   {name}: {weight:.3f}")

# OPTIONAL: SMOTE to rebalance the training split, only when the imbalance is extreme
# (majority / minority class ratio > 10x).
train_class_counts = np.bincount(y_train)
# Ignore labels absent from y_train (bincount reports them as 0); otherwise the ratio is
# infinite and SMOTE would be asked to resample an empty class with k_neighbors = -1.
present_class_counts = train_class_counts[train_class_counts > 0]
max_samples = np.max(present_class_counts)
min_samples = np.min(present_class_counts)
imbalance_ratio = max_samples / min_samples

print(f"\n📊 Ratio de desbalanceo en train: {imbalance_ratio:.2f}x")

# Default: train on the original split unchanged.
X_train_final = X_train
y_train_final = y_train
X_meta_train_final = X_metadata[X_train]

if imbalance_ratio > 10:
    print("⚠️ Desbalanceo alto detectado. Aplicando SMOTE...")

    # SMOTE can only synthesize metadata rows, not images; synthetic rows are paired with
    # existing training images below. Classes under 50% of the majority size are
    # oversampled to 70% of it.
    sampling_strategy = {
        i: int(max_samples * 0.7)
        for i in range(num_classes)
        if 0 < (y_train == i).sum() < max_samples * 0.5
    }
    # SMOTE needs k_neighbors < number of samples of every class it resamples.
    k_neighbors = min(5, min_samples - 1)

    if len(sampling_strategy) == 0:
        print("   No se requiere SMOTE (desbalanceo moderado)")
    elif k_neighbors < 1:
        print("   ⚠️ SMOTE omitido: alguna clase tiene menos de 2 muestras en train")
    else:
        smote = SMOTE(
            sampling_strategy=sampling_strategy,
            random_state=Config.SEED,
            k_neighbors=k_neighbors
        )

        # NOTE(review): SMOTE interpolates every metadata column, including the categorical
        # codes (TabTransformer.forward truncates them with .long() and clamps them), so
        # synthetic rows may carry category codes no real sample has. Consider SMOTENC.
        X_meta_train = X_metadata[X_train]
        X_meta_resampled, y_train_resampled = smote.fit_resample(X_meta_train, y_train)

        original_indices = X_train.copy()
        n_original = len(X_train)
        n_resampled = len(y_train_resampled)
        n_synthetic = n_resampled - n_original

        print(f"   Muestras originales: {n_original}")
        print(f"   Muestras después de SMOTE: {n_resampled}")
        print(f"   Muestras sintéticas: {n_synthetic}")

        # SMOTE's output is taken to be the original rows followed by the synthetic ones.
        # Each synthetic row is paired with a random real training image of the same class.
        # Pools are built once per class (not once per synthetic row) and sampled from a
        # seeded generator so the pairing is reproducible.
        synthetic_labels = np.asarray(y_train_resampled[n_original:])
        synthetic_indices = np.empty(n_synthetic, dtype=original_indices.dtype)
        pairing_rng = np.random.default_rng(Config.SEED)
        for cls in np.unique(synthetic_labels):
            slots = synthetic_labels == cls
            pool = original_indices[y_train == cls]
            synthetic_indices[slots] = pairing_rng.choice(pool, size=int(slots.sum()))

        X_train_final = np.concatenate([original_indices, synthetic_indices])
        y_train_final = y_train_resampled
        X_meta_train_final = X_meta_resampled

        print(f"\n✅ SMOTE aplicado. Nuevo tamaño de train: {len(X_train_final)}")
        print("   Distribución después de SMOTE:")
        for i, name in enumerate(class_names):
            count = (y_train_final == i).sum()
            print(f"   {name}: {count}")
else:
    print("✅ Desbalanceo aceptable. No se aplica SMOTE.")

# Build the three datasets. Training uses the 'train' augmentation mode; validation and test
# use 'val' (each split gets its own AdvancedAugmentation instance).
def _make_split_dataset(split_image_paths, split_metadata, split_labels, aug_mode, ds_mode):
    """Wrap one split in a SkinCancerDataset with a fresh AdvancedAugmentation."""
    return SkinCancerDataset(
        image_paths=split_image_paths,
        metadata=split_metadata,
        labels=split_labels,
        transform=AdvancedAugmentation(img_size=Config.IMG_SIZE, mode=aug_mode),
        mode=ds_mode,
    )


train_dataset = _make_split_dataset(
    image_paths[X_train_final], X_meta_train_final, y_train_final, 'train', 'train'
)
val_dataset = _make_split_dataset(image_paths[X_val], X_metadata[X_val], y_val, 'val', 'val')
test_dataset = _make_split_dataset(image_paths[X_test], X_metadata[X_test], y_test, 'val', 'test')

print(f"\n✅ Datasets creados:")
for split_title, split_ds in (('Train', train_dataset), ('Val', val_dataset), ('Test', test_dataset)):
    print(f"   {split_title}: {len(split_ds)}")

# DataLoaders share batch size / workers / pinning; only training shuffles.
shared_loader_kwargs = dict(
    batch_size=Config.BATCH_SIZE,
    num_workers=Config.NUM_WORKERS,
    pin_memory=Config.PIN_MEMORY,
)
# drop_last keeps every training batch full (originally motivated by mixup/cutmix; it also
# avoids size-1 batches reaching the fusion head's BatchNorm1d).
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_kwargs)

print(f"\n✅ DataLoaders creados:")
for split_title, split_loader in (('Train', train_loader), ('Val', val_loader), ('Test', test_loader)):
    print(f"   {split_title} batches: {len(split_loader)}")

# Persist class / split metadata as JSON for later use.
dataset_info = {
    'class_names': class_names.tolist(),
    'num_classes': num_classes,
    'class_weights': class_weights.cpu().numpy().tolist(),
    'train_size': len(train_dataset),
    'val_size': len(val_dataset),
    'test_size': len(test_dataset),
    # Cast to built-in int: NumPy integer values are not JSON-serializable and would make
    # json.dump raise TypeError.
    'cat_feature_indices': {name: int(pos) for name, pos in cat_feature_indices.items()},
    'categorical_dims': {name: int(dim) for name, dim in preprocessor.categorical_dims.items()},
}

with open(Config.OUTPUT_DIR / 'dataset_info.json', 'w') as f:
    json.dump(dataset_info, f, indent=2)

print("\n💾 Información del dataset guardada")

In [None]:
# ============================================================================
# CELDA 8: TRAINING LOOP COMPLETO CON TÉCNICAS AVANZADAS
# ============================================================================

class Trainer:
    """Training/validation driver: AMP, mixup/cutmix, gradient clipping, checkpointing and
    early stopping on validation accuracy.

    Args:
        model: module called as ``model(images, metadata, cat_indices)`` returning logits.
        train_loader, val_loader: loaders yielding dicts with 'image', 'metadata', 'label'.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: torch optimizer.
        scheduler: LR scheduler, stepped once per epoch (with the val loss for
            ReduceLROnPlateau).
        cat_indices: categorical column indices forwarded to the model.
        device: device every batch is moved to.
        config: provides AMP_ENABLED, EPOCHS, MIXUP_PROB, MIXUP_ALPHA, CUTMIX_ALPHA,
            PATIENCE and MODELS_DIR.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer,
                 scheduler, cat_indices, device, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.cat_indices = cat_indices
        self.device = device
        self.config = config

        # Gradient scaler for AMP
        self.scaler = torch.cuda.amp.GradScaler() if config.AMP_ENABLED else None

        # Per-epoch history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': [],
            'lr': []
        }

        # Best-model tracking
        self.best_val_acc = 0.0
        self.best_val_f1 = 0.0
        self.best_epoch = 0
        self.patience_counter = 0

    def train_epoch(self, epoch):
        """Run one training epoch and return its metrics dict."""
        self.model.train()
        metrics_tracker = MetricsTracker(len(class_names), class_names)

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.config.EPOCHS} [Train]')

        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)
            metadata = batch['metadata'].to(self.device)
            labels = batch['label'].to(self.device)

            # Apply mixup or cutmix with probability MIXUP_PROB
            use_mixup = np.random.rand() < self.config.MIXUP_PROB

            if use_mixup:
                if np.random.rand() < 0.5:  # 50% mixup, 50% cutmix
                    images, metadata, labels_a, labels_b, lam = mixup_data(
                        images, metadata, labels, alpha=self.config.MIXUP_ALPHA
                    )
                else:
                    images, metadata, labels_a, labels_b, lam = cutmix_data(
                        images, metadata, labels, alpha=self.config.CUTMIX_ALPHA
                    )

            # Forward pass under AMP
            with torch.cuda.amp.autocast(enabled=self.config.AMP_ENABLED):
                logits = self.model(images, metadata, self.cat_indices)

                if use_mixup:
                    loss = mixup_criterion(self.criterion, logits, labels_a, labels_b, lam)
                else:
                    loss = self.criterion(logits, labels)

            # Backward pass with gradient clipping (unscale first when using the scaler)
            self.optimizer.zero_grad()

            if self.scaler is not None:
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            # With mixup/cutmix, metrics are tracked against the first label set
            if use_mixup:
                metrics_tracker.update(logits.detach(), labels_a, loss.item())
            else:
                metrics_tracker.update(logits.detach(), labels, loss.item())

            pbar.set_postfix({'loss': loss.item()})

        train_metrics = metrics_tracker.compute()

        return train_metrics

    def validate(self, epoch):
        """Evaluate on the validation loader and return its metrics dict."""
        self.model.eval()
        metrics_tracker = MetricsTracker(len(class_names), class_names)

        pbar = tqdm(self.val_loader, desc=f'Epoch {epoch+1}/{self.config.EPOCHS} [Val]')

        with torch.no_grad():
            for batch in pbar:
                images = batch['image'].to(self.device)
                metadata = batch['metadata'].to(self.device)
                labels = batch['label'].to(self.device)

                with torch.cuda.amp.autocast(enabled=self.config.AMP_ENABLED):
                    logits = self.model(images, metadata, self.cat_indices)
                    loss = self.criterion(logits, labels)

                metrics_tracker.update(logits, labels, loss.item())

                pbar.set_postfix({'loss': loss.item()})

        val_metrics = metrics_tracker.compute()

        return val_metrics

    def _build_checkpoint(self, epoch):
        """Snapshot of model/optimizer/scheduler state and tracking info at ``epoch``."""
        return {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'best_val_f1': self.best_val_f1,
            'history': self.history,
        }

    def train(self):
        """Run the full training loop with early stopping; returns the history dict."""
        print("\n" + "="*80)
        print("🚀 INICIANDO ENTRENAMIENTO")
        print("="*80)

        for epoch in range(self.config.EPOCHS):
            print(f"\n{'='*80}")
            print(f"Epoch {epoch+1}/{self.config.EPOCHS}")
            print(f"{'='*80}")

            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate(epoch)

            # Epoch-level scheduler step (ReduceLROnPlateau needs the monitored value)
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.scheduler.step(val_metrics['loss'])
            else:
                self.scheduler.step()

            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_metrics['loss'])
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['train_acc'].append(train_metrics['accuracy'])
            self.history['val_acc'].append(val_metrics['accuracy'])
            self.history['train_f1'].append(train_metrics['f1'])
            self.history['val_f1'].append(val_metrics['f1'])
            self.history['lr'].append(current_lr)

            print(f"\n📊 TRAIN - Loss: {train_metrics['loss']:.4f} | "
                  f"Acc: {train_metrics['accuracy']:.4f} | F1: {train_metrics['f1']:.4f}")
            print(f"📊 VAL   - Loss: {val_metrics['loss']:.4f} | "
                  f"Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['f1']:.4f}")
            print(f"📈 LR: {current_lr:.2e}")

            # Best model is selected on validation accuracy
            is_best = val_metrics['accuracy'] > self.best_val_acc

            if is_best:
                self.best_val_acc = val_metrics['accuracy']
                self.best_val_f1 = val_metrics['f1']
                self.best_epoch = epoch
                self.patience_counter = 0

                torch.save(self._build_checkpoint(epoch), self.config.MODELS_DIR / 'best_model.pth')
                print(f"💾 Mejor modelo guardado! Val Acc: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1
                print(f"⏳ Patience: {self.patience_counter}/{self.config.PATIENCE}")

            # Early stopping
            if self.patience_counter >= self.config.PATIENCE:
                print(f"\n⚠️ Early stopping activado en epoch {epoch+1}")
                print(f"   Mejor val acc: {self.best_val_acc:.4f} (epoch {self.best_epoch+1})")
                break

            # Periodic snapshot of the *current* epoch every 10 epochs. A fresh dict is built
            # here: reusing the best-model checkpoint would save stale weights (or fail with
            # NameError when no epoch has improved yet).
            if (epoch + 1) % 10 == 0:
                checkpoint_path = self.config.MODELS_DIR / f'checkpoint_epoch_{epoch+1}.pth'
                torch.save(self._build_checkpoint(epoch), checkpoint_path)
                print(f"💾 Checkpoint guardado: {checkpoint_path}")

        print("\n" + "="*80)
        print("✅ ENTRENAMIENTO COMPLETADO")
        print("="*80)
        print(f"Mejor modelo: Epoch {self.best_epoch+1}")
        print(f"  Val Accuracy: {self.best_val_acc:.4f}")
        print(f"  Val F1-Score: {self.best_val_f1:.4f}")

        return self.history


def create_optimizer_and_scheduler(model, train_loader, config):
    """Create an AdamW optimizer with differential LRs and a warmup + cosine LambdaLR.

    Parameters whose name contains 'image_model' (the backbone) get 0.1x the base LR;
    everything else gets ``config.LEARNING_RATE``.

    The schedule is expressed in EPOCHS because ``Trainer.train`` calls
    ``scheduler.step()`` once per epoch. (Defining it in batches while stepping per epoch
    stretched the warmup over ``len(train_loader) * WARMUP_EPOCHS`` epochs and kept the LR
    near zero — exactly zero in the first epoch.)

    Args:
        model: module whose ``named_parameters()`` are split into backbone/head groups.
        train_loader: kept for interface compatibility; not needed for an epoch-based schedule.
        config: provides LEARNING_RATE, WEIGHT_DECAY, EPOCHS and WARMUP_EPOCHS.

    Returns:
        (optimizer, scheduler)
    """
    backbone_params = []
    head_params = []

    for name, param in model.named_parameters():
        if 'image_model' in name:
            backbone_params.append(param)
        else:
            head_params.append(param)

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': config.LEARNING_RATE * 0.1},  # 10x lower for backbone
        {'params': head_params, 'lr': config.LEARNING_RATE}
    ], weight_decay=config.WEIGHT_DECAY)

    total_epochs = config.EPOCHS
    warmup_epochs = config.WARMUP_EPOCHS

    def lr_lambda(epoch):
        """LR multiplier for the given (0-based) epoch."""
        if epoch < warmup_epochs:
            # Linear warmup; epoch 0 already trains at 1/warmup of the base LR, never 0.
            return float(epoch + 1) / float(warmup_epochs)
        # Cosine annealing from 1 down towards 0 over the remaining epochs.
        progress = float(epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Alternative: ReduceLROnPlateau
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer, mode='min', factor=0.5, patience=5, verbose=True, min_lr=config.MIN_LR
    # )

    return optimizer, scheduler


# Summary of the training setup.
print("\n".join(
    ["✅ Training loop y optimización configurados"]
    + [f"   - {feature}" for feature in (
        "AdamW optimizer con differential learning rates",
        "Cosine annealing con warmup",
        "Gradient clipping",
        "Automatic Mixed Precision (AMP)",
        "Early stopping",
    )]
))

In [None]:
# ============================================================================
# DIAGNÓSTICO COMPLETO DE DATOS
# ============================================================================
# Sanity checks on one training batch (shapes, value ranges, label bounds,
# categorical cardinalities) plus a single forward pass, run before training.
# NOTE(review): `model` is (re)created in the training cell that follows, so on
# a fresh kernel the forward-pass check below reports a NameError.

import os
import traceback

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # Para ver errores exactos
# The CUDA runtime reads CUDA_LAUNCH_BLOCKING when its context is created; the
# configuration cell already queried the GPU, in which case setting it here has
# no effect for the current kernel. Warn instead of silently doing nothing.
if torch.cuda.is_available() and torch.cuda.is_initialized():
    print("⚠️ CUDA ya inicializado: CUDA_LAUNCH_BLOCKING solo surte efecto tras reiniciar el kernel")


def _check_forward_pass(batch):
    """Run one inference-mode forward pass on `batch` and check the logits width.

    The input tensors are locals of this function, so their device copies are
    released when it returns instead of lingering as notebook globals during
    training. The model's train/eval mode is restored afterwards.
    """
    print(f"\n🧪 PRUEBA DE FORWARD PASS:")
    was_training = None
    try:
        images = batch['image'].to(device)
        metadata = batch['metadata'].to(device)

        print(f"   Datos movidos a GPU...")

        # eval(): no_grad alone does not stop BatchNorm (in train mode) from updating
        # its running statistics with this diagnostic batch, nor disable Dropout.
        was_training = model.training
        model.eval()
        with torch.no_grad():
            print(f"   Intentando forward pass...")
            logits = model(images, metadata, cat_feature_indices)
            print(f"   ✅ Forward pass exitoso!")
            print(f"   Output shape: {logits.shape}")

            # The output must have one logit per class
            if logits.shape[1] != num_classes:
                print(f"   ❌ ERROR: Output tiene {logits.shape[1]} clases, esperado {num_classes}")
            else:
                print(f"   ✅ Output correcto: {num_classes} clases")

    except Exception as e:
        print(f"   ❌ ERROR en forward pass:")
        print(f"   {str(e)}")
        traceback.print_exc()
    finally:
        # Restore training mode only if the model was training before this check
        if was_training:
            model.train()


print("🔍 DIAGNÓSTICO COMPLETO DE DATOS")
print("="*80)

# 1. Verificar estructura del dataset
print("\n📦 Verificando primer batch del train_loader...")
try:
    test_batch = next(iter(train_loader))
    print(f"✅ Batch cargado exitosamente")
    print(f"\nShapes:")
    print(f"   Images: {test_batch['image'].shape}")
    print(f"   Metadata: {test_batch['metadata'].shape}")
    print(f"   Labels: {test_batch['label'].shape}")

    # Value ranges (.item() prints plain numbers instead of tensor reprs)
    min_label = test_batch['label'].min().item()
    max_label = test_batch['label'].max().item()
    print(f"\nRangos de valores:")
    print(f"   Images: [{test_batch['image'].min():.3f}, {test_batch['image'].max():.3f}]")
    print(f"   Metadata: [{test_batch['metadata'].min():.3f}, {test_batch['metadata'].max():.3f}]")
    print(f"   Labels: [{min_label}, {max_label}]")

    # CRITICAL: labels must lie in [0, num_classes - 1] for the loss / indexing
    print(f"\n🎯 Verificación de labels:")
    print(f"   Max label en batch: {max_label}")
    print(f"   Num classes esperado: {num_classes}")

    if max_label >= num_classes:
        print(f"   ❌ ERROR CRÍTICO: Label {max_label} >= num_classes {num_classes}")
        print(f"   Las labels deben estar en rango [0, {num_classes-1}]")
    elif min_label < 0:
        print(f"   ❌ ERROR CRÍTICO: Label {min_label} < 0")
        print(f"   Las labels deben estar en rango [0, {num_classes-1}]")
    else:
        print(f"   ✅ Labels en rango correcto [0, {num_classes-1}]")

    # 2. Categorical dimensions
    print(f"\n📊 Verificando categorical dimensions:")
    print(f"   Total categorical features: {len(preprocessor.categorical_features)}")
    print(f"   Categorical dims: {preprocessor.categorical_dims}")

    # 3. cat_feature_indices must cover exactly the features in categorical_dims
    print(f"\n🔑 Verificando cat_feature_indices:")
    print(f"   cat_feature_indices keys: {list(cat_feature_indices.keys())}")
    print(f"   categorical_dims keys: {list(preprocessor.categorical_dims.keys())}")

    if set(cat_feature_indices.keys()) != set(preprocessor.categorical_dims.keys()):
        print(f"   ❌ ERROR: Las keys no coinciden!")
    else:
        print(f"   ✅ Keys coinciden correctamente")

    # 4. Categorical values in the batch must be < their embedding cardinality.
    # NOTE(review): assumes the categorical columns come first in `metadata`, in
    # `preprocessor.categorical_features` order — TODO confirm against cat_feature_indices.
    metadata_batch = test_batch['metadata']
    print(f"\n📊 Verificando valores categóricos en metadata:")

    cat_start_idx = 0
    for feat_name in preprocessor.categorical_features:
        if feat_name in preprocessor.categorical_dims:
            max_dim = preprocessor.categorical_dims[feat_name]
            col_values = metadata_batch[:, cat_start_idx]
            min_val = col_values.min().item()
            max_val = col_values.max().item()

            print(f"   {feat_name}: range=[{min_val:.0f}, {max_val:.0f}], dim={max_dim}")

            if max_val >= max_dim:
                print(f"      ❌ ERROR: Valor {max_val:.0f} >= dim {max_dim}")

            cat_start_idx += 1

    # 5. CRITICAL TEST: forward pass (inference mode, tensors scoped to the helper)
    _check_forward_pass(test_batch)

    del test_batch

except Exception as e:
    print(f"❌ Error cargando batch:")
    print(f"   {str(e)}")
    traceback.print_exc()

print("\n" + "="*80)
print("✅ Diagnóstico completado")
print("="*80)

In [None]:
# ============================================================================
# CELDA 9: EJECUTAR ENTRENAMIENTO
# ============================================================================


def _to_json_value(value):
    """Return `value` converted to plain Python types that `json.dump` accepts.

    Handles torch tensors, NumPy scalars/arrays (e.g. np.int64 or np.bool_,
    which json cannot encode) and nested lists/tuples; any other value is
    returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().tolist()
    elif isinstance(value, (np.generic, np.ndarray)):
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        return [_to_json_value(item) for item in value]
    return value


# Crear modelo
print("🏗️ Creando modelo...")
model = create_model(
    num_classes=num_classes,
    categorical_dims=preprocessor.categorical_dims,
    num_numerical_features=len(preprocessor.numerical_features)
)

model = model.to(device)

# Contar parámetros
trainable_params, total_params = count_parameters(model)
print(f"\n📊 Parámetros del modelo:")
print(f"   Total: {total_params:,}")
print(f"   Entrenables: {trainable_params:,}")
print(f"   No entrenables: {total_params - trainable_params:,}")

# Crear loss function
criterion = CombinedLoss(
    class_weights=class_weights,
    focal_gamma=2.0,
    label_smoothing=0.1
)

print(f"\n🎯 Loss function: Combined (Focal + Label Smoothing)")

# Crear optimizer y scheduler
optimizer, scheduler = create_optimizer_and_scheduler(model, train_loader, Config)

# Report the LRs actually configured on the optimizer (group 0 = backbone, last
# group = heads) rather than re-deriving them. LR schedulers store the base LR
# in 'initial_lr' and may already have rescaled 'lr' (warmup starts near 0).
group_lrs = [group.get('initial_lr', group['lr']) for group in optimizer.param_groups]

print(f"\n⚙️ Optimizer: AdamW")
print(f"   Learning rate (backbone): {group_lrs[0]:.2e}")
print(f"   Learning rate (heads): {group_lrs[-1]:.2e}")
print(f"   Weight decay: {Config.WEIGHT_DECAY:.2e}")
print(f"\n📅 Scheduler: Cosine Annealing con Warmup")
print(f"   Warmup epochs: {Config.WARMUP_EPOCHS}")
print(f"   Total epochs: {Config.EPOCHS}")

# Crear trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    cat_indices=cat_feature_indices,
    device=device,
    config=Config
)

# ENTRENAR
history = trainer.train()

# Save history: convert BEFORE opening the file so a conversion error cannot
# leave a truncated/empty JSON behind.
history_serializable = {
    k: [_to_json_value(v) for v in vals]
    for k, vals in history.items()
}
Config.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
with open(Config.OUTPUT_DIR / 'training_history.json', 'w') as f:
    json.dump(history_serializable, f, indent=2)

print("\n💾 Training history guardado")

# Visualizar training history
plot_training_history(history, save_path=Config.OUTPUT_DIR / 'training_curves.png')

print("\n" + "="*80)
print("✅ ENTRENAMIENTO FINALIZADO EXITOSAMENTE")
print("="*80)

In [None]:
# ============================================================================
# CELDA 10: EVALUACIÓN CON TEST-TIME AUGMENTATION (TTA)
# ============================================================================

def evaluate_with_tta(model, dataloader, cat_indices, device, tta_transforms=5):
    """Evaluate the model with test-time augmentation (TTA).

    Each image is re-read from disk via ``dataloader.dataset.image_paths``,
    expanded into augmented views by ``TTATransform`` and the softmax outputs of
    at most ``tta_transforms`` views are averaged.

    Args:
        model: model called as ``model(images, metadata, cat_indices)``.
        dataloader: loader whose batches follow ``dataset.image_paths`` order,
            i.e. a non-shuffled (sequential) loader.
        cat_indices: categorical-feature indices forwarded to the model.
        device: device for the model inputs.
        tta_transforms: maximum number of augmented views per image; ``None``
            (or a non-positive value) uses every view ``TTATransform`` returns.

    Returns:
        Tuple ``(preds, probs, targets)`` of NumPy arrays.

    Raises:
        ValueError: if the loader samples in non-sequential order, because the
            re-read images could then not be matched to their metadata/labels.
    """
    from torch.utils.data import SequentialSampler

    model.eval()

    # Images are looked up by dataset position, so batch order must be sequential.
    sampler = getattr(dataloader, 'sampler', None)
    if sampler is not None and not isinstance(sampler, SequentialSampler):
        raise ValueError(
            "evaluate_with_tta requires a non-shuffled DataLoader "
            f"(got sampler {type(sampler).__name__})"
        )

    image_paths = dataloader.dataset.image_paths
    max_views = tta_transforms if tta_transforms and tta_transforms > 0 else None

    all_preds = []
    all_probs = []
    all_targets = []

    tta_augmenter = TTATransform(img_size=Config.IMG_SIZE)

    print(f"🔍 Evaluando con TTA ({tta_transforms} augmentations)...")

    offset = 0  # dataset position of the first sample of the current batch
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='TTA Evaluation'):
            metadata = batch['metadata'].to(device)
            labels = batch['label']

            batch_size = batch['image'].size(0)
            tta_probs = []

            for i in range(batch_size):
                # Global index: using `i` alone re-read the first batch's images
                # for every batch and paired them with the wrong metadata/labels.
                img_path = image_paths[offset + i]
                # Context manager closes the file; convert() loads the pixels first.
                with Image.open(img_path) as img_file:
                    img_pil = img_file.convert('RGB')

                tta_images = list(tta_augmenter(img_pil))[:max_views]
                meta_batch = metadata[i:i + 1]

                # Average the softmax probabilities over all augmented views
                img_probs = []
                for tta_img in tta_images:
                    tta_img_batch = tta_img.unsqueeze(0).to(device)

                    with torch.cuda.amp.autocast(enabled=Config.AMP_ENABLED):
                        logits = model(tta_img_batch, meta_batch, cat_indices)
                        probs = F.softmax(logits, dim=1)

                    img_probs.append(probs.cpu())

                tta_probs.append(torch.stack(img_probs).mean(dim=0))

            offset += batch_size

            batch_probs = torch.cat(tta_probs, dim=0)
            batch_preds = torch.argmax(batch_probs, dim=1)

            all_probs.append(batch_probs)
            all_preds.append(batch_preds)
            all_targets.append(labels)

    all_probs = torch.cat(all_probs, dim=0).numpy()
    all_preds = torch.cat(all_preds, dim=0).numpy()
    all_targets = torch.cat(all_targets, dim=0).numpy()

    return all_preds, all_probs, all_targets


def evaluate_standard(model, dataloader, cat_indices, device):
    """Evaluate the model on `dataloader` without test-time augmentation.

    Args:
        model: model called as ``model(images, metadata, cat_indices)``.
        dataloader: iterable of dict batches with 'image', 'metadata', 'label'.
        cat_indices: categorical-feature indices forwarded to the model.
        device: device for the model inputs.

    Returns:
        Tuple ``(preds, probs, targets)`` of NumPy arrays, concatenated over batches.
    """
    model.eval()

    prob_chunks, pred_chunks, target_chunks = [], [], []

    print(f"🔍 Evaluando (estándar)...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Standard Evaluation'):
            batch_images = batch['image'].to(device)
            batch_meta = batch['metadata'].to(device)
            batch_labels = batch['label']

            with torch.cuda.amp.autocast(enabled=Config.AMP_ENABLED):
                batch_probs = F.softmax(model(batch_images, batch_meta, cat_indices), dim=1)

            # argmax is taken outside autocast, on the softmax output
            batch_preds = batch_probs.argmax(dim=1)

            prob_chunks.append(batch_probs.cpu())
            pred_chunks.append(batch_preds.cpu())
            target_chunks.append(batch_labels)

    return tuple(
        torch.cat(chunks, dim=0).numpy()
        for chunks in (pred_chunks, prob_chunks, target_chunks)
    )


# Cargar mejor modelo
print("📥 Cargando mejor modelo...")
checkpoint = torch.load(Config.MODELS_DIR / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✅ Modelo cargado (Epoch {checkpoint['epoch']+1})")

# Explicit list of every class id: keeps one row/column per class in confusion
# matrices (aligned with class_names) and stops classification_report from
# raising when a class is missing from a split.
class_labels = list(range(len(class_names)))

# NOTE: deltas below use the `+` format flag, which already prints the sign.

# ===== EVALUACIÓN EN VALIDATION SET =====
print("\n" + "="*80)
print("📊 EVALUACIÓN EN VALIDATION SET")
print("="*80)

# Sin TTA
val_preds, val_probs, val_targets = evaluate_standard(model, val_loader, cat_feature_indices, device)

val_acc = accuracy_score(val_targets, val_preds)
val_f1_weighted = f1_score(val_targets, val_preds, average='weighted')
val_f1_macro = f1_score(val_targets, val_preds, average='macro')

print(f"\n📈 Resultados (Sin TTA):")
print(f"   Accuracy: {val_acc:.4f}")
print(f"   F1-Score (weighted): {val_f1_weighted:.4f}")
print(f"   F1-Score (macro): {val_f1_macro:.4f}")

# Con TTA
val_preds_tta, val_probs_tta, val_targets_tta = evaluate_with_tta(
    model, val_loader, cat_feature_indices, device, tta_transforms=Config.TTA_TRANSFORMS
)

val_acc_tta = accuracy_score(val_targets_tta, val_preds_tta)
val_f1_weighted_tta = f1_score(val_targets_tta, val_preds_tta, average='weighted')
val_f1_macro_tta = f1_score(val_targets_tta, val_preds_tta, average='macro')

print(f"\n📈 Resultados (Con TTA):")
print(f"   Accuracy: {val_acc_tta:.4f} ({val_acc_tta - val_acc:+.4f})")
print(f"   F1-Score (weighted): {val_f1_weighted_tta:.4f} ({val_f1_weighted_tta - val_f1_weighted:+.4f})")
print(f"   F1-Score (macro): {val_f1_macro_tta:.4f} ({val_f1_macro_tta - val_f1_macro:+.4f})")

# Matriz de confusión
cm_val = confusion_matrix(val_targets_tta, val_preds_tta, labels=class_labels)
plot_confusion_matrix(cm_val, class_names, save_path=Config.OUTPUT_DIR / 'confusion_matrix_val.png')

# ===== EVALUACIÓN EN TEST SET =====
print("\n" + "="*80)
print("📊 EVALUACIÓN FINAL EN TEST SET")
print("="*80)

# Sin TTA
test_preds, test_probs, test_targets = evaluate_standard(model, test_loader, cat_feature_indices, device)

test_acc = accuracy_score(test_targets, test_preds)
test_f1_weighted = f1_score(test_targets, test_preds, average='weighted')
test_f1_macro = f1_score(test_targets, test_preds, average='macro')

print(f"\n📈 Resultados (Sin TTA):")
print(f"   Accuracy: {test_acc:.4f}")
print(f"   F1-Score (weighted): {test_f1_weighted:.4f}")
print(f"   F1-Score (macro): {test_f1_macro:.4f}")

# Con TTA
test_preds_tta, test_probs_tta, test_targets_tta = evaluate_with_tta(
    model, test_loader, cat_feature_indices, device, tta_transforms=Config.TTA_TRANSFORMS
)

test_acc_tta = accuracy_score(test_targets_tta, test_preds_tta)
test_f1_weighted_tta = f1_score(test_targets_tta, test_preds_tta, average='weighted')
test_f1_macro_tta = f1_score(test_targets_tta, test_preds_tta, average='macro')
test_precision_tta = precision_score(test_targets_tta, test_preds_tta, average='weighted')
test_recall_tta = recall_score(test_targets_tta, test_preds_tta, average='weighted')

print(f"\n📈 Resultados (Con TTA):")
print(f"   Accuracy: {test_acc_tta:.4f} ({test_acc_tta - test_acc:+.4f})")
print(f"   Precision (weighted): {test_precision_tta:.4f}")
print(f"   Recall (weighted): {test_recall_tta:.4f}")
print(f"   F1-Score (weighted): {test_f1_weighted_tta:.4f} ({test_f1_weighted_tta - test_f1_weighted:+.4f})")
print(f"   F1-Score (macro): {test_f1_macro_tta:.4f} ({test_f1_macro_tta - test_f1_macro:+.4f})")

# Classification report detallado
print(f"\n📋 Classification Report:")
print(classification_report(test_targets_tta, test_preds_tta, labels=class_labels, target_names=class_names))

# Matriz de confusión
cm_test = confusion_matrix(test_targets_tta, test_preds_tta, labels=class_labels)
plot_confusion_matrix(cm_test, class_names, save_path=Config.OUTPUT_DIR / 'confusion_matrix_test.png')

# Guardar resultados
test_results = {
    'test_accuracy': float(test_acc_tta),
    'test_accuracy_no_tta': float(test_acc),
    'test_precision': float(test_precision_tta),
    'test_recall': float(test_recall_tta),
    'test_f1_weighted': float(test_f1_weighted_tta),
    'test_f1_macro': float(test_f1_macro_tta),
    'per_class_metrics': {},
    'confusion_matrix': cm_test.tolist()
}

# Per-class metrics. Accuracy restricted to the samples of one true class is
# that class's recall (fraction of its samples predicted correctly).
for i, class_name in enumerate(class_names):
    class_mask = test_targets_tta == i
    class_preds = test_preds_tta[class_mask]
    class_targets = test_targets_tta[class_mask]

    if len(class_targets) > 0:
        class_acc = accuracy_score(class_targets, class_preds)
        test_results['per_class_metrics'][class_name] = {
            'accuracy': float(class_acc),
            'samples': int(len(class_targets))
        }

with open(Config.OUTPUT_DIR / 'test_results.json', 'w') as f:
    json.dump(test_results, f, indent=2)

print("\n💾 Resultados de test guardados")

print("\n" + "="*80)
print("✅ EVALUACIÓN COMPLETADA")
print("="*80)
print(f"\n🎯 RESULTADO FINAL:")
print(f"   Test Accuracy (con TTA): {test_acc_tta:.4f} ({test_acc_tta*100:.2f}%)")
print(f"   Test F1-Score (weighted): {test_f1_weighted_tta:.4f}")
print(f"   Mejora con TTA: {(test_acc_tta - test_acc)*100:+.2f}%")

In [None]:
# ============================================================================
# CELDA 11: INFERENCIA Y PREDICCIÓN EN NUEVAS IMÁGENES
# ============================================================================

class SkinCancerPredictor:
    """Inference wrapper that predicts the lesion class of new images from image + metadata."""

    def __init__(self, model, preprocessor, cat_indices, class_names, device, use_tta=True):
        """Store the inference dependencies and put the model in eval mode.

        Args:
            model: trained model, called as ``model(images, metadata, cat_indices)``.
            preprocessor: fitted metadata preprocessor; ``transform(df)`` is
                expected to return ``(features, _)``.
            cat_indices: categorical-feature indices forwarded to the model.
            class_names: class names, indexed by predicted class id.
            device: device used for inference.
            use_tta: when True, average softmax outputs over ``TTATransform`` views.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.cat_indices = cat_indices
        self.class_names = class_names
        self.device = device
        self.use_tta = use_tta

        self.transform = AdvancedAugmentation(img_size=Config.IMG_SIZE, mode='val')
        self.tta_transform = TTATransform(img_size=Config.IMG_SIZE) if use_tta else None

        self.model.eval()

    def predict_single(self, image_path, metadata_dict):
        """Predict a single image.

        Args:
            image_path: path to the image file.
            metadata_dict: metadata for this image (same fields as the CSV).

        Returns:
            dict with ``predicted_class``, ``predicted_class_idx``, ``confidence``,
            ``class_probabilities`` (name -> probability) and ``top_3_predictions``
            (``(name, probability)`` pairs sorted by descending probability).
        """
        # Preprocess metadata as a one-row DataFrame
        metadata_df = pd.DataFrame([metadata_dict])
        metadata_processed, _ = self.preprocessor.transform(metadata_df)
        metadata_tensor = torch.FloatTensor(metadata_processed).to(self.device)

        # Load the image; the context manager closes the file handle, and
        # convert() has already loaded the pixels into a new image.
        with Image.open(image_path) as img_file:
            image = img_file.convert('RGB')

        with torch.no_grad():
            if self.use_tta and self.tta_transform is not None:
                # TTA: run every augmented view and average the probabilities
                tta_images = self.tta_transform(image)
                tta_probs = []

                for tta_img in tta_images:
                    tta_img_batch = tta_img.unsqueeze(0).to(self.device)

                    with torch.cuda.amp.autocast(enabled=Config.AMP_ENABLED):
                        logits = self.model(tta_img_batch, metadata_tensor, self.cat_indices)
                        probs = F.softmax(logits, dim=1)

                    tta_probs.append(probs.cpu())

                avg_probs = torch.stack(tta_probs).mean(dim=0)
            else:
                # Standard single-view prediction
                img_tensor = self.transform(image).unsqueeze(0).to(self.device)

                with torch.cuda.amp.autocast(enabled=Config.AMP_ENABLED):
                    logits = self.model(img_tensor, metadata_tensor, self.cat_indices)
                    avg_probs = F.softmax(logits, dim=1).cpu()

        # Predicted class and its probability
        pred_class_idx = torch.argmax(avg_probs, dim=1).item()
        pred_class_name = self.class_names[pred_class_idx]
        confidence = avg_probs[0, pred_class_idx].item()

        # Probability for every class
        class_probs = {
            name: float(avg_probs[0, i].item())
            for i, name in enumerate(self.class_names)
        }

        result = {
            'predicted_class': pred_class_name,
            'predicted_class_idx': pred_class_idx,
            'confidence': confidence,
            'class_probabilities': class_probs,
            'top_3_predictions': sorted(
                class_probs.items(),
                key=lambda x: x[1],
                reverse=True
            )[:3]
        }

        return result

    def predict_batch(self, image_paths, metadata_dicts):
        """Predict several images; returns one ``predict_single`` dict per image."""
        results = []

        for img_path, meta_dict in tqdm(zip(image_paths, metadata_dicts),
                                        total=len(image_paths),
                                        desc='Prediciendo'):
            result = self.predict_single(img_path, meta_dict)
            results.append(result)

        return results

    def visualize_prediction(self, image_path, metadata_dict, save_path=None):
        """Plot the image next to its per-class probabilities and print a summary.

        Args:
            image_path: path to the image file.
            metadata_dict: metadata for this image (same fields as the CSV).
            save_path: optional path where the figure is saved.

        Returns:
            The ``predict_single`` result dict, so callers do not need to run
            inference again.
        """
        result = self.predict_single(image_path, metadata_dict)

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Image panel; copy() loads the pixels so the file can be closed right away
        with Image.open(image_path) as img_file:
            img = img_file.copy()
        axes[0].imshow(img)
        axes[0].axis('off')
        axes[0].set_title(f'Predicción: {result["predicted_class"]}\n'
                         f'Confianza: {result["confidence"]:.2%}',
                         fontsize=14, fontweight='bold')

        # Probability panel (predicted class highlighted)
        probs = result['class_probabilities']
        classes = list(probs.keys())
        values = list(probs.values())

        colors = ['green' if c == result['predicted_class'] else 'steelblue'
                 for c in classes]

        axes[1].barh(classes, values, color=colors)
        axes[1].set_xlabel('Probabilidad', fontsize=12)
        axes[1].set_title('Probabilidades por Clase', fontsize=14, fontweight='bold')
        axes[1].set_xlim(0, 1)

        for i, (c, v) in enumerate(zip(classes, values)):
            axes[1].text(v + 0.01, i, f'{v:.3f}', va='center', fontsize=10)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)

        # Detailed text summary
        print(f"\n{'='*60}")
        print(f"PREDICCIÓN DETALLADA")
        print(f"{'='*60}")
        print(f"Clase predicha: {result['predicted_class']}")
        print(f"Confianza: {result['confidence']:.2%}")
        print(f"\nTop 3 predicciones:")
        for i, (cls, prob) in enumerate(result['top_3_predictions'], 1):
            print(f"  {i}. {cls}: {prob:.2%}")
        print(f"{'='*60}\n")

        return result


# Build the inference wrapper around the trained model (TTA enabled); arguments
# follow SkinCancerPredictor's positional order.
print("🔮 Creando predictor...")
predictor = SkinCancerPredictor(
    model, preprocessor, cat_feature_indices, class_names, device, use_tta=True
)
print("✅ Predictor listo")

# ===== EJEMPLO DE USO =====
# Run the predictor on the first test sample and show its true label alongside.
print("\n" + "=" * 80, "EJEMPLO DE PREDICCIÓN", "=" * 80, sep="\n")

sample_idx = 0
sample_img_path = test_dataset.image_paths[sample_idx]
# NOTE(review): predict_single() runs preprocessor.transform() on this dict; if
# `df_processed` already holds transformed features it would be processed twice
# — confirm it contains raw CSV-format values.
sample_metadata = df_processed.iloc[X_test[sample_idx]].to_dict()
sample_true_label = class_names[test_dataset.labels[sample_idx].item()]

print(f"\nImagen de ejemplo: {sample_img_path}",
      f"Diagnóstico real: {sample_true_label}",
      sep="\n")

predictor.visualize_prediction(
    sample_img_path, sample_metadata,
    save_path=Config.OUTPUT_DIR / 'example_prediction.png',
)

# ===== FUNCIÓN PARA PREDICCIÓN DE NUEVAS IMÁGENES =====
def predict_new_image(image_path, metadata_dict=None):
    """Predict and plot the diagnosis for a new image using the global `predictor`.

    Args:
        image_path: path to the image.
        metadata_dict: metadata dict (optional); when None, the default values
            below are used.

    Returns:
        The prediction dict produced by ``SkinCancerPredictor.predict_single``.

    Example:
        metadata = {
            'age': 45,
            'gender': 'FEMALE',
            'region': 'ARM',
            'diameter_1': 5.0,
            'diameter_2': 4.0,
            'smoke': False,
            'drink': False,
            # ... more fields
        }
        result = predict_new_image('path/to/image.png', metadata)
    """
    if metadata_dict is None:
        # Default values
        metadata_dict = {
            'age': 50,
            'gender': 'UNKNOWN',
            'region': 'UNKNOWN',
            'smoke': 'UNKNOWN',
            'drink': 'UNKNOWN',
            'background_father': 'UNKNOWN',
            'background_mother': 'UNKNOWN',
            'pesticide': 'UNKNOWN',
            'skin_cancer_history': 'UNKNOWN',
            'cancer_history': 'UNKNOWN',
            'has_piped_water': 'UNKNOWN',
            'has_sewage_system': 'UNKNOWN',
            'fitspatrick': 3.0,
            'diameter_1': 5.0,
            'diameter_2': 5.0,
            'itch': False,
            'grew': False,
            'hurt': False,
            'changed': False,
            'bleed': False,
            'elevation': False,
            'biopsed': False
        }

    save_path = Config.OUTPUT_DIR / f'prediction_{Path(image_path).stem}.png'

    # visualize_prediction() already runs predict_single() internally; reuse its
    # result so the (TTA-expanded) forward passes are not executed twice. Fall
    # back to an explicit prediction if it returns nothing.
    result = predictor.visualize_prediction(image_path, metadata_dict, save_path=save_path)
    if result is None:
        result = predictor.predict_single(image_path, metadata_dict)

    return result


# Announce that the prediction helpers are ready and show how to call them.
print(
    "\n" + "=" * 80,
    "✅ SISTEMA DE PREDICCIÓN LISTO",
    "=" * 80,
    "\nPara predecir nuevas imágenes, usa:",
    "  result = predict_new_image('path/to/image.png', metadata_dict)",
    "\nDonde metadata_dict es un diccionario con los campos del CSV",
    sep="\n",
)

In [None]:
# ============================================================================
# CELDA 12: RESUMEN Y ANÁLISIS FINAL
# ============================================================================

# Final report: architecture, techniques, headline metrics and generated files.
# The TTA gain uses the `+` format flag so negative changes print as "-x.xx%"
# (a hard-coded '+' prefix produced "+-x.xx%").
print("="*80)
print(" "*20 + "🏆 RESUMEN DEL MODELO 🏆")
print("="*80)

print(f"""
📊 ARQUITECTURA:
   • Modelo de imagen: {Config.EFFICIENTNET_MODEL}
   • Tamaño de entrada: {Config.IMG_SIZE}x{Config.IMG_SIZE}
   • TabTransformer: {Config.TAB_NUM_LAYERS} layers, {Config.TAB_NUM_HEADS} heads
   • Fusión: Late fusion con MLP ({' -> '.join(map(str, Config.FUSION_HIDDEN_DIMS))})
   • Parámetros totales: {total_params:,}
   • Parámetros entrenables: {trainable_params:,}

🎯 TÉCNICAS APLICADAS:
   ✓ EfficientNetV2-M pre-entrenado
   ✓ TabTransformer para metadatos tabulares
   ✓ Late Fusion multimodal
   ✓ Augmentations avanzadas (geométricas, colorimétricas, ópticas)
   ✓ Mixup & CutMix durante entrenamiento
   ✓ Combined Loss (Focal + Label Smoothing)
   ✓ Class weighting para desbalanceo
   ✓ Differential learning rates (backbone vs heads)
   ✓ Cosine Annealing con Warmup
   ✓ Gradient clipping
   ✓ Automatic Mixed Precision (AMP)
   ✓ Early stopping
   ✓ Test-Time Augmentation (TTA)
   
📈 RESULTADOS:
   • Val Accuracy: {checkpoint['best_val_acc']:.4f} ({checkpoint['best_val_acc']*100:.2f}%)
   • Test Accuracy (sin TTA): {test_acc:.4f} ({test_acc*100:.2f}%)
   • Test Accuracy (con TTA): {test_acc_tta:.4f} ({test_acc_tta*100:.2f}%)
   • Test F1-Score: {test_f1_weighted_tta:.4f}
   • Mejora con TTA: {(test_acc_tta - test_acc)*100:+.2f}%
   
💾 ARCHIVOS GENERADOS:
   • {Config.MODELS_DIR / 'best_model.pth'}
   • {Config.OUTPUT_DIR / 'metadata_preprocessor.pkl'}
   • {Config.OUTPUT_DIR / 'training_history.json'}
   • {Config.OUTPUT_DIR / 'test_results.json'}
   • {Config.OUTPUT_DIR / 'training_curves.png'}
   • {Config.OUTPUT_DIR / 'confusion_matrix_test.png'}
""")

# Static guidance: prints a list of ideas for improving the model further.
# Everything below is literal text; nothing in it is executed.
print("="*80)
print(" "*15 + "🚀 TÉCNICAS PARA MEJORAR RENDIMIENTO 🚀")
print("="*80)

print("""
Si quieres mejorar aún más el modelo (objetivo: 98%+), considera:

1. 📸 DATOS Y AUGMENTATION:
   • Aumentar resolución a 640px (requiere más GPU memory)
   • Agregar más augmentations específicas para dermatología:
     - Simulación de diferentes iluminaciones (dermoscopia)
     - Simulación de vello (hair occlusion)
     - Ajustes de contraste específicos para lesiones
   • Recolectar más datos (especialmente de clases minoritarias)
   • Usar pseudo-labeling con datos no etiquetados

2. 🏗️ ARQUITECTURA:
   • Ensembling de múltiples modelos:
     - EfficientNetV2-L + ConvNeXt + Swin Transformer
   • Attention mechanisms más sofisticados:
     - CBAM (Convolutional Block Attention Module)
     - Squeeze-and-Excitation
   • Cross-attention entre imagen y metadatos

3. 🎓 ENTRENAMIENTO:
   • Progressive resizing (empezar con 224px, luego 384px, finalmente 512px)
   • Two-stage training:
     - Stage 1: Solo clasificación básica
     - Stage 2: Fine-tuning con metadatos
   • Knowledge distillation de modelos grandes
   • Self-supervised pre-training en datos de dermatología

4. 🔧 OPTIMIZACIÓN:
   • Usar SAM optimizer (Sharpness-Aware Minimization)
   • Stochastic Weight Averaging (SWA)
   • Gradient accumulation para batches más grandes
   • FP16/BF16 training completo

5. 📊 POST-PROCESSING:
   • Calibración de probabilidades (Temperature Scaling, Platt Scaling)
   • Threshold optimization por clase
   • Ensembling con diferentes seeds
   • TTA más agresivo (más transformaciones)

6. 🔍 ANÁLISIS DE ERRORES:
   • Identificar clases confusas y aplicar técnicas específicas
   • Análisis de muestras mal clasificadas
   • Feature importance analysis
   • GradCAM para interpretabilidad

7. 🎯 TÉCNICAS ESPECÍFICAS PARA DESBALANCEO:
   • Focal Loss con diferentes gammas por clase
   • Class-balanced Loss
   • Oversampling más agresivo con ADASYN
   • Crear muestras sintéticas con GANs

8. 💡 METADATA ENGINEERING:
   • Extraer features automáticas de las imágenes:
     - Descriptores de color (histogramas HSV)
     - Textura (Haralick features, LBP)
     - Forma (momentos de Hu, contornos)
   • Combinar con metadatos clínicos

""")

print("="*80)
print(" "*25 + "📝 CÓDIGO DE EJEMPLO")
print("="*80)

# Illustrative snippets, printed as a string literal; nothing in it is executed
# (e.g. the `models = [...]` inside it does not rebind any name in this notebook).
print("""
# EJEMPLO 1: Entrenar con resolución más alta
Config.IMG_SIZE = 640
Config.BATCH_SIZE = 8  # Reducir batch size

# EJEMPLO 2: Ensembling de modelos
models = [
    create_model(..., img_model_name='tf_efficientnetv2_l'),
    create_model(..., img_model_name='convnext_base'),
    create_model(..., img_model_name='swin_base_patch4_window7_224')
]

def ensemble_predict(models, image, metadata):
    probs = []
    for model in models:
        logits = model(image, metadata, cat_indices)
        prob = F.softmax(logits, dim=1)
        probs.append(prob)
    return torch.stack(probs).mean(dim=0)

# EJEMPLO 3: Progressive resizing
# Stage 1: 224px, 30 epochs
# Stage 2: 384px, 20 epochs  
# Stage 3: 512px, 20 epochs

# EJEMPLO 4: Pseudo-labeling
# 1. Entrenar modelo inicial
# 2. Predecir en datos no etiquetados con alta confianza
# 3. Agregar predicciones confiables al training set
# 4. Re-entrenar

""")

print("="*80)
print(" "*20 + "✅ SISTEMA COMPLETO Y LISTO")
print("="*80)

# Persist a compact JSON summary (configuration, headline metrics, techniques);
# key order matches the insertion order below.
# NOTE(review): 'epochs_trained' comes from the best checkpoint, i.e. it is the
# epoch of the best model, which may be earlier than the last epoch run.
summary = dict(
    model_config=dict(
        image_model=Config.EFFICIENTNET_MODEL,
        img_size=Config.IMG_SIZE,
        batch_size=Config.BATCH_SIZE,
        epochs_trained=checkpoint['epoch'] + 1,
    ),
    results=dict(
        best_val_acc=float(checkpoint['best_val_acc']),
        best_val_f1=float(checkpoint['best_val_f1']),
        test_acc_no_tta=float(test_acc),
        test_acc_tta=float(test_acc_tta),
        test_f1_weighted=float(test_f1_weighted_tta),
        test_f1_macro=float(test_f1_macro_tta),
        tta_improvement=float((test_acc_tta - test_acc) * 100),
    ),
    techniques=[
        'EfficientNetV2-M pre-trained',
        'TabTransformer for metadata',
        'Late Fusion',
        'Advanced Augmentations',
        'Mixup & CutMix',
        'Combined Loss (Focal + Label Smoothing)',
        'Class Weighting',
        'Differential Learning Rates',
        'Cosine Annealing with Warmup',
        'Gradient Clipping',
        'Automatic Mixed Precision',
        'Early Stopping',
        'Test-Time Augmentation',
    ],
)

with open(Config.OUTPUT_DIR / 'final_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(
    f"\n💾 Resumen final guardado en: {Config.OUTPUT_DIR / 'final_summary.json'}",
    "\n🎉 ¡Entrenamiento completado con éxito!",
    f"   Accuracy final: {test_acc_tta*100:.2f}%",
    sep="\n",
)