In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score, balanced_accuracy_score, confusion_matrix
import time
import numpy as np
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Filesystem layout, relative to the notebook's working directory.
DATA_PATH = "data/"
MODEL_SAVE_PATH = "models"

# Image folders: original ISIC images and the generated (synthetic) ones.
ORIGINAL_IMAGE_PATH, SYNTHETIC_IMAGE_PATH = (
    os.path.join(DATA_PATH, subdir) for subdir in ("images", "synthetic_images")
)

# Split CSVs.
TRAIN_CSV_PATH = os.path.join(DATA_PATH, "train_split.csv")
VAL_CSV_PATH = os.path.join(DATA_PATH, "validation_split.csv")

# Checkpoint written whenever validation AUC improves.
BEST_MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_PATH, "edgenext_best_model.pth")

In [3]:
print(" ".join((ORIGINAL_IMAGE_PATH, SYNTHETIC_IMAGE_PATH, TRAIN_CSV_PATH, VAL_CSV_PATH)))

data/images data/synthetic_images data/train_split.csv data/validation_split.csv


In [4]:
# Dataset schema: CSV column holding the image id (file stem), CSV column holding
# the binary target, and the file extension of the original images.
IMAGE_ID_COL, TARGET_COL = "isic_id", "malignant"
IMAGE_EXTENSION = ".jpg"

In [5]:
# Model settings: timm model identifier and number of output logits
# (binary classification on the TARGET_COL column).
MODEL_NAME, NUM_CLASSES = 'edgenext_base.in21k_ft_in1k', 2

In [6]:
# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1
WEIGHT_DECAY = 0.01

# Seed the global RNGs so augmentation, shuffling and head initialization are
# repeatable across Restart & Run All. torch.manual_seed seeds the CPU and CUDA
# generators; numpy is seeded for any numpy-based randomness.
# NOTE: bit-exact GPU determinism would additionally need cuDNN settings.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [7]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Number of DataLoader worker processes; 0 means batches are loaded in the main process.
NUM_WORKERS = int(0)

## 2. Definición de transformaciones de imágenes

In [9]:
# Preprocessing the pretrained weights expect, read from timm's pretrained config.
model_cfg = timm.get_pretrained_cfg(MODEL_NAME)
NORM_MEAN, NORM_STD = model_cfg.mean, model_cfg.std
# input_size is assumed to be (channels, height, width); height drives a square resize.
IMG_SIZE = model_cfg.input_size[1]

In [10]:
def _resize_and_normalize(augmentations=()):
    """Compose: square resize -> optional augmentations -> tensor -> normalization.

    Both pipelines share the resize and normalization steps so train and
    validation preprocessing cannot drift apart.
    """
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])


# Training pipeline: random flips, small rotations and mild color jitter to augment the data.
train_transforms = _resize_and_normalize([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
])

# Validation pipeline: deterministic preprocessing only.
val_transforms = _resize_and_normalize()

## 3. Crear Datasets y DataLoaders

In [11]:
# Custom Dataset for ISIC images (originals from a CSV, plus optional synthetic images)
class ISICDataset(Dataset):
    """Binary image dataset built from a split CSV plus optional synthetic images.

    Each entry of ``self.samples`` is a dict ``{'path', 'label', 'source'}`` where
    ``source`` is ``'original'`` (one CSV row) or ``'synthetic'`` (one file from the
    synthetic directory). ``self.label_counts`` counts samples per label and is used
    downstream to compute class weights.

    Args:
        csv_path: CSV with one row per original image.
        original_image_dir: directory holding ``<image_id><image_extension>`` files.
        image_id_col: CSV column with the image id (file stem).
        target_col: CSV column with the class label (cast with ``int``).
        transforms: optional callable applied to each PIL image in ``__getitem__``.
        mode: synthetic images are only loaded when this is ``'train'``.
        path_to_synthetic_images_to_use: directory of synthetic images, or ``None``.
        synthetic_positive_label: label assigned to every synthetic image.
        image_extension: extension appended to each CSV image id.

    Attributes (besides the constructor arguments):
        samples, label_counts, original_df, synthetic_added_count.
    """

    # Extensions accepted for synthetic images (compared case-insensitively).
    SYNTHETIC_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, csv_path, original_image_dir, image_id_col, target_col,
                 transforms=None, mode='train',
                 path_to_synthetic_images_to_use=None,
                 synthetic_positive_label=1,  # Label for synthetic images
                 image_extension='.jpg'):

        self.mode = mode
        self.original_image_dir = original_image_dir
        self.image_id_col = image_id_col
        self.target_col = target_col
        self.transforms = transforms
        self.path_to_synthetic_images_to_use = path_to_synthetic_images_to_use
        self.synthetic_positive_label = synthetic_positive_label
        self.image_extension = image_extension

        self.samples = []
        self.label_counts = Counter()  # used to detect / compensate class imbalance

        self._load_original_samples(csv_path)

        self.synthetic_added_count = 0
        # Synthetic images are only added to the training split.
        if self.mode == 'train' and self.path_to_synthetic_images_to_use:
            self._load_synthetic_samples()

        self._print_label_distribution()

    def _add_sample(self, image_path, label, source):
        """Append one sample record and update the per-label counter."""
        self.samples.append({'path': image_path, 'label': label, 'source': source})
        self.label_counts[label] += 1

    def _load_original_samples(self, csv_path):
        """Read the split CSV and register one 'original' sample per row."""
        self.original_df = pd.read_csv(csv_path)
        # Iterate the two columns directly instead of iterrows(): iterrows() builds a
        # Series per row (slow) and can upcast values to a common dtype (e.g. int ids
        # to float), which would corrupt the file name built from the id.
        ids = self.original_df[self.image_id_col]
        targets = self.original_df[self.target_col]
        for image_id, raw_label in zip(ids, targets):
            image_path = os.path.join(self.original_image_dir, str(image_id) + self.image_extension)
            self._add_sample(image_path, int(raw_label), 'original')

    def _load_synthetic_samples(self):
        """Register every image file in the synthetic directory with the synthetic label."""
        # sorted() makes sample order independent of the filesystem's listing order.
        for img_filename in sorted(os.listdir(self.path_to_synthetic_images_to_use)):
            if img_filename.lower().endswith(self.SYNTHETIC_EXTENSIONS):
                image_path = os.path.join(self.path_to_synthetic_images_to_use, img_filename)
                self._add_sample(image_path, self.synthetic_positive_label, 'synthetic')
                self.synthetic_added_count += 1

    def _print_label_distribution(self):
        """Print count and percentage of samples per label."""
        total = len(self.samples)
        print(f"Distribuciones de labels conjunto {self.mode}:")
        for label, count in sorted(self.label_counts.items()):
            percentage = (count / total) * 100 if total > 0 else 0
            print(f"  Label {label}: {count} samples ({percentage:.2f}%)")
        print("-" * 30)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image (transformed if configured) and a long label tensor."""
        sample_info = self.samples[idx]
        # The context manager closes the file handle deterministically; Pillow can keep
        # the file open after convert() for some formats (e.g. multi-frame ones).
        with Image.open(sample_info['path']) as img:
            image = img.convert('RGB')
        if self.transforms:
            image = self.transforms(image)
        return image, torch.tensor(sample_info['label'], dtype=torch.long)

In [12]:
# Build the training split (CSV originals + synthetic images, augmented transforms)
# and the validation split (CSV originals only, deterministic transforms).
_shared_dataset_kwargs = dict(
    original_image_dir=ORIGINAL_IMAGE_PATH,
    image_id_col=IMAGE_ID_COL,
    target_col=TARGET_COL,
    image_extension=IMAGE_EXTENSION,
)

print("--- Creando dataset entrenamiento ---")
train_dataset = ISICDataset(
    csv_path=TRAIN_CSV_PATH,
    transforms=train_transforms,
    mode='train',
    path_to_synthetic_images_to_use=SYNTHETIC_IMAGE_PATH,
    synthetic_positive_label=1,
    **_shared_dataset_kwargs,
)

print("\n--- Creando dataset de validacion ---")
val_dataset = ISICDataset(
    csv_path=VAL_CSV_PATH,
    transforms=val_transforms,
    mode='val',
    **_shared_dataset_kwargs,
)

--- Creando dataset entrenamiento ---
Distribuciones de labels conjunto train:
  Label 0: 272452 samples (95.61%)
  Label 1: 12495 samples (4.39%)
------------------------------

--- Creando dataset de validacion ---
Distribuciones de labels conjunto val:
  Label 0: 48081 samples (99.90%)
  Label 1: 47 samples (0.10%)
------------------------------


In [13]:
# Wrap the datasets in DataLoaders. Pinned host memory only helps host->GPU copies,
# so it is enabled only when training on CUDA.
_pin_memory = DEVICE.type == 'cuda'

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=_pin_memory)
# Validation keeps a fixed order (no shuffling).
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=_pin_memory)

## 4. Cargar modelo preentrenado y definir función de pérdida y optimizador

In [14]:
# Pretrained timm backbone with a classifier producing NUM_CLASSES logits, placed on DEVICE.
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES).to(DEVICE)

In [15]:
# Class-balanced loss: weight_c = N / (NUM_CLASSES * n_c), so every class has the
# same total weight (N / NUM_CLASSES) in the loss regardless of its frequency.
counts = train_dataset.label_counts
class_counts = [counts.get(c, 0) for c in range(NUM_CLASSES)]
total_count = sum(class_counts)

# A class with zero samples would divide by zero; fail with an explicit message instead.
missing_classes = [c for c, n in enumerate(class_counts) if n == 0]
if missing_classes:
    raise ValueError(
        f"Training set has no samples for class(es) {missing_classes}; cannot compute class weights."
    )

class_weights = torch.tensor(
    [total_count / (NUM_CLASSES * n) for n in class_counts], dtype=torch.float32
).to(DEVICE)

print(f"Pesos clase 0 , 1: {class_weights.cpu().tolist()}")
criterion = nn.CrossEntropyLoss(weight=class_weights)

# AdamW optimizer (Adam with decoupled weight decay)
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

Pesos clase 0 , 1: [0.5229306221008301, 11.402441024780273]


## 5. Funciones de entrenamiento y validación

In [16]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch_num, num_epochs, num_classes,
                    threshold=0.5, log_every=50):
    """Run one optimization pass over ``train_loader`` and return epoch metrics.

    Args:
        model: called as ``model(images)``; expected to return logits of shape
            (batch, num_classes). Column 1 of the softmax is used as the positive-class score.
        train_loader: iterable of ``(images, labels)`` batches; ``len()`` is used for logging.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: stepped once per batch.
        device: device the batches are moved to.
        epoch_num: zero-based epoch index (logging only).
        num_epochs: total number of epochs (logging only).
        num_classes: unused; kept for signature compatibility.
        threshold: positive-class probability cutoff for hard predictions.
        log_every: print the batch loss every ``log_every`` batches.

    Returns:
        dict with 'loss' (per-sample average of batch losses), 'auc', 'balanced_accuracy',
        'f1_score', 'recall' and 'precision' (positive label 1).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()

    running_loss = 0.0
    num_seen = 0
    label_batches = []
    proba_batches = []

    start_time = time.time()

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

        # Detach before softmax so metric bookkeeping does not extend the autograd graph.
        label_batches.append(labels.detach().cpu().numpy())
        proba_batches.append(torch.softmax(outputs.detach(), dim=1)[:, 1].cpu().numpy())

        if (batch_idx + 1) % log_every == 0:
            print(f"  Epoch [{epoch_num+1}/{num_epochs}] Batch [{batch_idx+1}/{len(train_loader)}] Train Loss: {loss.item():.4f}")

    if num_seen == 0:
        raise ValueError("train_loader produced no batches; cannot compute training metrics.")

    # Divide by samples actually seen (robust to drop_last / custom samplers).
    epoch_loss = running_loss / num_seen

    all_labels_np = np.concatenate(label_batches)
    all_preds_proba_np = np.concatenate(proba_batches)

    # roc_auc_score raises when only one class is present; mirror the validation fallback.
    if len(np.unique(all_labels_np)) < 2:
        print(f"Warning: Conjunto de entrenamiento para epoca {epoch_num+1} le falta una clase!.")
        epoch_auc = 0.5
    else:
        epoch_auc = roc_auc_score(all_labels_np, all_preds_proba_np)

    predicted_classes_np = (all_preds_proba_np >= threshold).astype(int)
    epoch_f1 = f1_score(all_labels_np, predicted_classes_np, pos_label=1, zero_division=0)
    epoch_recall = recall_score(all_labels_np, predicted_classes_np, pos_label=1, zero_division=0)
    epoch_precision = precision_score(all_labels_np, predicted_classes_np, pos_label=1, zero_division=0)
    epoch_balanced_acc = balanced_accuracy_score(all_labels_np, predicted_classes_np)

    epoch_duration = time.time() - start_time

    print(f"Epoch [{epoch_num+1}/{num_epochs}] Train Loss: {epoch_loss:.4f}, AUC: {epoch_auc:.4f}, BalAcc: {epoch_balanced_acc:.4f}, F1: {epoch_f1:.4f}, Recall: {epoch_recall:.4f}, Precision: {epoch_precision:.4f}, Time: {epoch_duration:.2f}s")

    return {
        'loss': epoch_loss,
        'auc': epoch_auc,
        'balanced_accuracy': epoch_balanced_acc,
        'f1_score': epoch_f1,
        'recall': epoch_recall,
        'precision': epoch_precision
    }


def validate_one_epoch(model, val_loader, criterion, device, epoch_num, num_epochs, num_classes,
                       threshold=0.5):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking and return epoch metrics.

    Args:
        model: called as ``model(images)``; expected to return logits of shape
            (batch, num_classes). Column 1 of the softmax is used as the positive-class score.
        val_loader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batches are moved to.
        epoch_num: zero-based epoch index (logging only).
        num_epochs: total number of epochs (logging only).
        num_classes: unused; kept for signature compatibility.
        threshold: positive-class probability cutoff for hard predictions.

    Returns:
        dict with 'loss', 'auc' (0.5 when only one class is present), 'balanced_accuracy',
        'f1_score_positive', 'recall_positive', 'precision_positive' and
        'conf_matrix' (2x2, label order [0, 1]).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()

    running_loss = 0.0
    num_seen = 0
    label_batches = []
    proba_batches = []

    start_time = time.time()

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            num_seen += batch_size

            label_batches.append(labels.cpu().numpy())
            proba_batches.append(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())

    if num_seen == 0:
        raise ValueError("val_loader produced no batches; cannot compute validation metrics.")

    # Divide by samples actually seen (robust to drop_last / custom samplers).
    epoch_loss = running_loss / num_seen

    all_labels_np = np.concatenate(label_batches)
    all_preds_proba_np = np.concatenate(proba_batches)

    # roc_auc_score raises when only one class is present; fall back to chance level.
    if len(np.unique(all_labels_np)) < 2:
        print(f"Warning: Conjunto de validacion para epoca {epoch_num+1} le falta una clase!.")
        epoch_auc = 0.5
    else:
        epoch_auc = roc_auc_score(all_labels_np, all_preds_proba_np)

    predicted_classes_np = (all_preds_proba_np >= threshold).astype(int)
    epoch_f1 = f1_score(all_labels_np, predicted_classes_np, pos_label=1, zero_division=0)
    epoch_recall = recall_score(all_labels_np, predicted_classes_np, pos_label=1, zero_division=0)
    epoch_precision = precision_score(all_labels_np, predicted_classes_np, pos_label=1, zero_division=0)
    # Fixing labels=[0, 1] keeps the matrix 2x2 even if a class is absent.
    conf_matrix = confusion_matrix(all_labels_np, predicted_classes_np, labels=[0, 1])
    epoch_balanced_acc = balanced_accuracy_score(all_labels_np, predicted_classes_np)

    epoch_duration = time.time() - start_time

    print(f"Epoch [{epoch_num+1}/{num_epochs}] Val Loss: {epoch_loss:.4f}, AUC: {epoch_auc:.4f}, BalAcc: {epoch_balanced_acc:.4f}, F1(pos): {epoch_f1:.4f}, Recall(pos): {epoch_recall:.4f}, Precision(pos): {epoch_precision:.4f}, Time: {epoch_duration:.2f}s")
    print(f"Matriz de confusion para epoca {epoch_num+1}:\n{conf_matrix}")

    return {
        'loss': epoch_loss,
        'auc': epoch_auc,
        'balanced_accuracy': epoch_balanced_acc,
        'f1_score_positive': epoch_f1,
        'recall_positive': epoch_recall,
        'precision_positive': epoch_precision,
        'conf_matrix': conf_matrix
    }

## 6. Entrenamiento

In [None]:
print("\n--- Entrenamiento ---")

# torch.save does not create missing parent directories; create the checkpoint folder up front.
os.makedirs(os.path.dirname(BEST_MODEL_SAVE_PATH) or ".", exist_ok=True)

# Per-epoch metric dicts
train_history = []
val_history = []

# Start below any achievable AUC so the first validated epoch is always checkpointed.
best_val_auc = float('-inf')

start_training_time = time.time()

for epoch in range(NUM_EPOCHS):
    print(f"\n===== Epoch {epoch+1}/{NUM_EPOCHS} =====")
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Tasa de aprendizaje: {current_lr:.6e}")

    # --- Training ---
    if train_loader:
        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE, epoch, NUM_EPOCHS, NUM_CLASSES)
        train_history.append(train_metrics)
    else:
        train_history.append({'loss': float('nan'), 'auc': float('nan'), 'balanced_accuracy': float('nan'), 'f1_score': float('nan'), 'recall': float('nan'), 'precision': float('nan')})

    # --- Validation ---
    if val_loader:
        val_metrics = validate_one_epoch(model, val_loader, criterion, DEVICE, epoch, NUM_EPOCHS, NUM_CLASSES)
        val_history.append(val_metrics)

        # --- Checkpoint: keep the weights with the best validation AUC so far ---
        current_val_auc = val_metrics['auc']
        if current_val_auc > best_val_auc:
            best_val_auc = current_val_auc
            torch.save(model.state_dict(), BEST_MODEL_SAVE_PATH)
            print(f"Epoch {epoch+1}: Mejor modelo con AUC: {best_val_auc:.4f} guardado en {BEST_MODEL_SAVE_PATH}")
    else:
        val_history.append({'loss': float('nan'), 'auc': float('nan'), 'balanced_accuracy': float('nan'), 'f1_score_positive': float('nan'), 'recall_positive': float('nan'), 'precision_positive': float('nan'), 'conf_matrix': None })


total_training_duration = time.time() - start_training_time
print(f"\n--- Entrenamiento terminado ---")
print(f"Tiempo: {total_training_duration / 60:.2f} minutos ({total_training_duration:.2f} segundos)")
if val_loader:
    print(f"Mejor AUC validacion: {best_val_auc:.4f}")
    print(f"Nombre mejor modelo: {BEST_MODEL_SAVE_PATH}")


--- Entrenamiento ---

===== Epoch 1/1 =====
Tasa de aprendizaje: 1.000000e-04
  Epoch [1/1] Batch [50/8905] Train Loss: 0.0051
  Epoch [1/1] Batch [100/8905] Train Loss: 1.5077
  Epoch [1/1] Batch [150/8905] Train Loss: 0.0006
  Epoch [1/1] Batch [200/8905] Train Loss: 0.0019
  Epoch [1/1] Batch [250/8905] Train Loss: 0.0009
