In [25]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score
import numpy as np
from collections import Counter

In [26]:
# Dataset locations, all resolved relative to DATA_PATH.
DATA_PATH = "data/"
ORIGINAL_IMAGE_PATH, SYNTHETIC_IMAGE_PATH, TRAIN_CSV_PATH, VAL_CSV_PATH = (
    os.path.join(DATA_PATH, entry)
    for entry in ("images", "synthetic_images", "train_split.csv", "val_split.csv")
)

In [6]:
# Sanity-check the resolved dataset locations (space-separated on one line).
print(" ".join((ORIGINAL_IMAGE_PATH, SYNTHETIC_IMAGE_PATH, TRAIN_CSV_PATH, VAL_CSV_PATH)))

data/images data/synthetic_images data/train_split.csv data/val_split.csv


In [None]:
# Model / task parameters
MODEL_NAME = 'edgenext_base.in21k_ft_in1k'  # timm pretrained checkpoint identifier
NUM_CLASSES = 2  # two-class (binary) classification

## 2. Definición de transformaciones de imágenes

In [7]:
# Fetch the preprocessing settings (input size, normalisation stats) that timm
# registers for the pretrained checkpoint, so our inputs match what it was trained on.
# MODEL_NAME comes from the parameters cell; it is intentionally not redefined here
# so the two definitions cannot drift apart.
model_cfg = timm.get_pretrained_cfg(MODEL_NAME)
if model_cfg is None:
    # get_pretrained_cfg returns None for architectures with no registered config;
    # fail with a clear message instead of an AttributeError on the next line.
    raise ValueError(f"timm has no pretrained config registered for {MODEL_NAME!r}")
IMG_SIZE = model_cfg.input_size[1]  # timm stores input_size as (C, H, W); H drives the square resize
NORM_MEAN = model_cfg.mean
NORM_STD = model_cfg.std

In [8]:
def _build_transforms(augment):
    """Build the image preprocessing pipeline.

    Every split is resized to IMG_SIZE x IMG_SIZE, converted to a tensor and
    normalised with the checkpoint's mean/std. When ``augment`` is true, the
    training-only random augmentations are inserted between resize and tensor
    conversion to enlarge the effective training set.
    """
    steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
    if augment:
        steps.extend([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        ])
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])
    return transforms.Compose(steps)


train_transforms = _build_transforms(augment=True)
val_transforms = _build_transforms(augment=False)

## 3. Crear Dataset personalizado y DataLoader

In [23]:
# Custom Dataset: original images listed in a CSV, optionally extended (training
# mode only) with a directory of synthetic images that all share a single label.
class ISICDataset(Dataset):
    """Image-classification dataset over original images plus optional synthetic images.

    Args:
        csv_path: CSV with one row per original image.
        original_image_dir: Directory containing the original images.
        image_id_col: CSV column holding the image id; the file name is
            ``str(id) + image_extension``.
        target_col: CSV column holding the label (converted with ``int``).
        transforms: Optional callable applied to each PIL image in ``__getitem__``.
        mode: Split name. Synthetic images are only added when ``mode == 'train'``.
        path_to_synthetic_images_to_use: Directory of synthetic images; every file
            whose lower-cased name ends in a known image extension is added.
        synthetic_positive_label: Label assigned to every synthetic image.
        image_extension: Extension appended to ids of the original images.

    Attributes:
        samples: list of dicts with keys ``'path'``, ``'label'`` and ``'source'``
            (``'original'`` or ``'synthetic'``), originals first.
        label_counts: ``Counter`` of labels over all samples (for imbalance checks).
        synthetic_added_count: number of synthetic images added (0 outside training).
    """

    # Lower-case file extensions recognised as synthetic images.
    _SYNTHETIC_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, csv_path, original_image_dir, image_id_col, target_col,
                 transforms=None, mode='train',
                 path_to_synthetic_images_to_use=None,
                 synthetic_positive_label=1,  # Label given to every synthetic image
                 image_extension='.png'):
        self.mode = mode
        self.original_image_dir = original_image_dir
        self.image_id_col = image_id_col
        self.target_col = target_col
        self.transforms = transforms
        self.path_to_synthetic_images_to_use = path_to_synthetic_images_to_use
        self.synthetic_positive_label = synthetic_positive_label
        self.image_extension = image_extension

        self.samples = []
        self.label_counts = Counter()
        self.synthetic_added_count = 0

        # Original images described by the CSV.
        self.original_df = pd.read_csv(csv_path)
        self._add_original_samples()

        # Synthetic images are only used to augment the training split.
        if self.mode == 'train' and self.path_to_synthetic_images_to_use:
            self.synthetic_added_count = self._add_synthetic_samples()

        self._print_label_distribution()

    def _add_sample(self, image_path, label, source):
        """Record one sample and update the per-label tally."""
        self.samples.append({'path': image_path, 'label': label, 'source': source})
        self.label_counts[label] += 1

    def _add_original_samples(self):
        """Append one sample per CSV row, in CSV order."""
        # Iterate whole columns rather than using iterrows(): iterrows builds a
        # Series per row, which upcasts mixed int/float rows to float (an integer
        # id 101 would become '101.0') and is much slower on large CSVs.
        image_ids = self.original_df[self.image_id_col].tolist()
        raw_labels = self.original_df[self.target_col].tolist()
        for image_id, raw_label in zip(image_ids, raw_labels):
            image_path = os.path.join(self.original_image_dir, str(image_id) + self.image_extension)
            self._add_sample(image_path, int(raw_label), 'original')

    def _add_synthetic_samples(self):
        """Append every image file in the synthetic directory; return how many were added."""
        added = 0
        # sorted() makes the sample order independent of the filesystem's listing order.
        for img_filename in sorted(os.listdir(self.path_to_synthetic_images_to_use)):
            if img_filename.lower().endswith(self._SYNTHETIC_EXTENSIONS):
                image_path = os.path.join(self.path_to_synthetic_images_to_use, img_filename)
                self._add_sample(image_path, self.synthetic_positive_label, 'synthetic')
                added += 1
        return added

    def _print_label_distribution(self):
        """Print count and percentage per label (to spot class imbalance)."""
        total = len(self.samples)
        print(f"Distribuciones de labels conjunto {self.mode}:")
        for label, count in sorted(self.label_counts.items()):
            percentage = (count / total) * 100 if total > 0 else 0
            print(f"  Label {label}: {count} samples ({percentage:.2f}%)")
        print("-" * 30)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the label is a ``torch.long`` tensor."""
        sample_info = self.samples[idx]
        # The context manager releases the file handle even for multi-frame formats
        # (e.g. GIF); convert() returns a new, fully loaded image that outlives it.
        with Image.open(sample_info['path']) as img:
            image = img.convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, torch.tensor(sample_info['label'], dtype=torch.long)

In [24]:
# DataLoader settings (not used in this cell; presumably consumed by later cells).
BATCH_SIZE = 32
NUM_WORKERS = 0  # 0 = load in the main process (easier to debug); raise to 2-4 for throughput

IMAGE_ID_COL = "isic_id"
TARGET_COL = "malignant"
# Reuse the configured path instead of hard-coding it a second time, so the
# two cannot drift apart (joined file paths are identical either way).
SYNTHETIC_HR_DIR_FOR_TRAINING = SYNTHETIC_IMAGE_PATH
IMAGE_EXTENSION = ".png"

print("--- Creating Training Dataset ---")
train_dataset = ISICDataset(
    csv_path=TRAIN_CSV_PATH,
    original_image_dir=ORIGINAL_IMAGE_PATH,
    image_id_col=IMAGE_ID_COL,
    target_col=TARGET_COL,
    transforms=train_transforms,
    mode='train',
    path_to_synthetic_images_to_use=SYNTHETIC_HR_DIR_FOR_TRAINING,
    # Every synthetic image gets this label; it must match the positive class
    # encoding in TARGET_COL (assumed 1 = malignant — confirm against the CSV).
    synthetic_positive_label=1,
    image_extension=IMAGE_EXTENSION,
)

--- Creating Training Dataset ---
Label distribution for train set:
  Label 0: 272452 samples (95.61%)
  Label 1: 12495 samples (4.39%)
------------------------------
