# **Modelo CNN Simple - ResNet18 + Dropout + Soft Attention Espacial** 

In [1]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

In [34]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Redimensionar a un tamaño fijo (ej. 128x128)
    transforms.ToTensor(),  # Convertir la imagen a un tensor de PyTorch
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalización (media y desviación estándar de imágenes ImageNet)
])

In [2]:
class CustomDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file (str): Ruta al archivo CSV con las imágenes y sus etiquetas.
            img_dir (str): Ruta al directorio que contiene las imágenes.
            transform (callable, optional): Transformaciones que se aplican a las imágenes.
        """
        self.img_labels = pd.read_csv(csv_file)  # Leer el archivo CSV con las etiquetas
        self.img_dir = img_dir  # Ruta donde están las imágenes
        self.transform = transform  # Transformaciones a aplicar

        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.img_labels['diagnosis'].unique())}
    def __len__(self):
        """Retorna el número total de imágenes en el dataset"""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Obtiene una imagen y su etiqueta"""
        img_name = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])  # Nombre de la imagen
        image = Image.open(img_name)  # Abrir la imagen
        label = self.class_to_idx[self.img_labels.iloc[idx, 3]] # Etiqueta asociada

        if self.transform:
            image = self.transform(image)  # Aplicar transformaciones si es necesario

        return image, label

**Generación del dataset de entrenamiento y validación**

In [5]:
# Leer el CSV
csvRoute="bcn_20k_train.csv"
df = pd.read_csv(csvRoute)
#Definir las clases que deseas excluir por su nombre
clases_a_excluir = ['SCC', 'DF', 'VASC']  # Sustituye estos nombres por las clases que quieres excluir

# Filtrar el DataFrame para excluir las clases especificadas
df_filtrado = df[~df['diagnosis'].isin(clases_a_excluir)]

df_filtrado.to_csv("bcn_20k_train_filtrado.csv", index=False)

# Dividir el dataset en entrenamiento (80%) y validación (20%)
train_df, val_df = train_test_split(df_filtrado, test_size=0.2, random_state=42)

# Crear el dataset de entrenamiento y validación
train_dataset = CustomDataset(csv_file="bcn_20k_train_filtrado.csv", img_dir='bcn_20k_train', transform=transform)
val_dataset = CustomDataset(csv_file="bcn_20k_train_filtrado.csv", img_dir='bcn_20k_train', transform=transform)

# Actualizar los datasets con los subconjuntos correspondientes
train_dataset.img_labels = train_df
val_dataset.img_labels = val_df

# Crear los DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

**Guardado de los datasets divididos inicialmente**

In [69]:
train_df.to_json('data_train_resnet18_softAtt.json', orient='records', lines=True)
val_df.to_json('data_val_resnet18_softAtt.json', orient='records', lines=True)

**Definición del Modelo**

In [49]:
import torch
import torch.nn as nn
import torchvision.models as models

# Mecanismo de Soft-Attention Espacial
class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),  # Mapa de atención 1x1
            nn.Softmax(dim=2)                           # Normalización espacial
        )
        
    def forward(self, x):
        # x: (batch, 512, H, W) [Ej: (batch, 512, 7, 7)]
        
        # Generar mapa de atención (batch, 1, H, W)
        attn_weights = self.conv_attention(x)
        
        # Aplicar atención: características * pesos
        attended_features = x * attn_weights  # Broadcasting automático
        
        return attended_features


# Modelo ResNet18 con Dropout + Soft-Attention
class ResNet18WithAttention(nn.Module):
    def __init__(self, num_classes):
        super(ResNet18WithAttention, self).__init__()
        
        # 1. Cargar ResNet18 preentrenado
        self.resnet18 = models.resnet18(pretrained=True)
        
        # 2. Congelar capas convolucionales
        for param in self.resnet18.parameters():
            param.requires_grad = False
            
        # 3. Añadir mecanismo de atención espacial
        self.attention = SpatialAttention(in_channels=512)  # ResNet18 tiene 512 canales al final
        
        # 4. Modificar capas FC con Dropout
        self.resnet18.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # Extraer características hasta la última capa convolucional
        x = self.resnet18.conv1(x)
        x = self.resnet18.bn1(x)
        x = self.resnet18.relu(x)
        x = self.resnet18.maxpool(x)
        x = self.resnet18.layer1(x)
        x = self.resnet18.layer2(x)
        x = self.resnet18.layer3(x)
        x = self.resnet18.layer4(x)  # Salida: (batch, 512, 7, 7)
        
        # Aplicar atención espacial
        x = self.attention(x)  # (batch, 512, 7, 7) con pesos aprendidos
        
        # Global Average Pooling y clasificación
        x = self.resnet18.avgpool(x)  # (batch, 512, 1, 1)
        x = torch.flatten(x, 1)       # (batch, 512)
        x = self.resnet18.fc(x)       # (batch, num_classes)
        
        return x

**Entrenamiento y validación del modelo**

In [None]:
# Verificar si CUDA (GPU) está disponible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Definición de la cantidad de clases, la función de perdida, el optimizador y el learning rate estático
modelAt = ResNet18WithAttention(num_classes=5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(modelAt.parameters(), lr=0.001)

# Entrenar el modelo
num_epochs = 10

model = modelAt.to(device)

for epoch in range(num_epochs):
    model.train()  # Modo entrenamiento
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    # Entrenamiento
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Estadísticas de la pérdida
        running_loss += loss.item()

        # Precisión
        _, predicted = torch.max(outputs, 1)
        correct_preds += (predicted == labels).sum().item()
        total_preds += labels.size(0)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_dataloader)}, Accuracy: {100 * correct_preds / total_preds}%")
    # Validación
    model.eval()  # Modo evaluación
    correct_preds = 0
    total_preds = 0
    with torch.no_grad():  # No calcular gradientes durante la validación
        for inputs, labels in val_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct_preds += (predicted == labels).sum().item()
            total_preds += labels.size(0)

    print(f"Validation Accuracy: {100 * correct_preds / total_preds}%")

**70% de Accuracy de validación en 50 Epocas**

**143 minutos 64.85% en validación 10 Epocas:**

Epoch 1/10, Loss: 1.134643373440723, Accuracy: 56.09236990528892%
Validation Accuracy: 59.91489361702128%
Epoch 2/10, Loss: 1.0131697717572556, Accuracy: 61.57284239650952%
Validation Accuracy: 60.5531914893617%
Epoch 3/10, Loss: 0.9697256232200026, Accuracy: 63.2648717675854%
Validation Accuracy: 62.723404255319146%
Epoch 4/10, Loss: 0.959287223969998, Accuracy: 63.14781313185059%
Validation Accuracy: 63.276595744680854%
Epoch 5/10, Loss: 0.9335303259950106, Accuracy: 64.60572523145684%
Validation Accuracy: 63.702127659574465%
Epoch 6/10, Loss: 0.9142697725166269, Accuracy: 64.97818452697669%
Validation Accuracy: 63.1063829787234%
Epoch 7/10, Loss: 0.8985673685868582, Accuracy: 65.32936043418113%
Validation Accuracy: 63.61702127659574%
Epoch 8/10, Loss: 0.8820558255221568, Accuracy: 66.14877088432479%
Validation Accuracy: 63.829787234042556%
Epoch 9/10, Loss: 0.8775967738660825, Accuracy: 66.44673832074066%
Validation Accuracy: 63.744680851063826%
Epoch 10/10, Loss: 0.8590307535768367, Accuracy: 66.86176439289135%
Validation Accuracy: 64.85106382978724%

**142 minutos 66.9% en validación 10 Epocas (Primero en detectar KA):**

Epoch 1/10, Loss: 0.8453486519200462, Accuracy: 67.73438331382356%
Validation Accuracy: 65.23404255319149%
Epoch 2/10, Loss: 0.8239534395892604, Accuracy: 68.34095988081303%
Validation Accuracy: 65.02127659574468%
Epoch 3/10, Loss: 0.8275354070728328, Accuracy: 68.11748430350112%
Validation Accuracy: 66.25531914893617%
Epoch 4/10, Loss: 0.8035658865558858, Accuracy: 68.57507715228265%
Validation Accuracy: 65.40425531914893%
Epoch 5/10, Loss: 0.7843043409845456, Accuracy: 70.22453974672769%
Validation Accuracy: 65.48936170212765%
Epoch 6/10, Loss: 0.7854251228019494, Accuracy: 69.91593061615409%
Validation Accuracy: 66.59574468085107%
Epoch 7/10, Loss: 0.7668884254637218, Accuracy: 70.21389805256997%
Validation Accuracy: 65.82978723404256%
Epoch 8/10, Loss: 0.7613350249269382, Accuracy: 70.67149090135149%
Validation Accuracy: 66.34042553191489%
Epoch 9/10, Loss: 0.7379224217262398, Accuracy: 72.19325316590401%
Validation Accuracy: 67.19148936170212%
Epoch 10/10, Loss: 0.7312207400393324, Accuracy: 71.50154304565287%
Validation Accuracy: 66.93617021276596%

**150 minutos 68.12% en validación 10 Epocas**

Epoch 1/10, Loss: 0.7263397300729946, Accuracy: 72.32095349579653%
Validation Accuracy: 66.68085106382979%
Epoch 2/10, Loss: 0.7067413063479119, Accuracy: 72.42737043737363%
Validation Accuracy: 67.23404255319149%
Epoch 3/10, Loss: 0.701630963456063, Accuracy: 73.42768968819836%
Validation Accuracy: 67.06382978723404%
Epoch 4/10, Loss: 0.6938930058560404, Accuracy: 73.05523039267851%
Validation Accuracy: 66.12765957446808%
Epoch 5/10, Loss: 0.6726329312438056, Accuracy: 74.4386506331808%
Validation Accuracy: 67.70212765957447%
Epoch 6/10, Loss: 0.6607335396364432, Accuracy: 74.71533468128126%
Validation Accuracy: 67.48936170212765%
Epoch 7/10, Loss: 0.6529108140947056, Accuracy: 75.17292753006278%
Validation Accuracy: 67.44680851063829%
Epoch 8/10, Loss: 0.6382793583432023, Accuracy: 75.81142917952538%
Validation Accuracy: 68.12765957446808%
Epoch 9/10, Loss: 0.6398449084993933, Accuracy: 76.21581355751836%
Validation Accuracy: 68.68085106382979%
Epoch 10/10, Loss: 0.6278198312739937, Accuracy: 76.01362136852187%
Validation Accuracy: 68.12765957446808%

**130 minutos 68.7% en validación 10 Epocas**

Epoch 1/10, Loss: 0.6140470188491198, Accuracy: 76.65212301798447%
Validation Accuracy: 68.2127659574468%
Epoch 2/10, Loss: 0.6087234322311116, Accuracy: 76.29030541662233%
Validation Accuracy: 68.68085106382979%
Epoch 3/10, Loss: 0.5855527020433322, Accuracy: 77.73757582207088%
Validation Accuracy: 69.14893617021276%
Epoch 4/10, Loss: 0.5734712509881883, Accuracy: 78.36543577737577%
Validation Accuracy: 68.76595744680851%
Epoch 5/10, Loss: 0.5858643731089677, Accuracy: 78.23773544748325%
Validation Accuracy: 68.85106382978724%
Epoch 6/10, Loss: 0.5681056010277092, Accuracy: 78.1951686708524%
Validation Accuracy: 69.02127659574468%
Epoch 7/10, Loss: 0.563230012549835, Accuracy: 78.833670320315%
Validation Accuracy: 69.31914893617021%
Epoch 8/10, Loss: 0.5470320342146621, Accuracy: 79.9084814302437%
Validation Accuracy: 68.59574468085107%
Epoch 9/10, Loss: 0.5348383441161947, Accuracy: 79.63179738214323%
Validation Accuracy: 69.36170212765957%
Epoch 10/10, Loss: 0.5381639800002785, Accuracy: 79.35511333404278%
Validation Accuracy: 68.76595744680851%

**113 minutos 70.29% en validación 10 Epocas (Ya Detecta KA y CCB)**

Epoch 1/10, Loss: 0.5233964961605008, Accuracy: 80.05746514845163%
Validation Accuracy: 69.23404255319149%
Epoch 2/10, Loss: 0.5092849620953709, Accuracy: 80.0255400659785%
Validation Accuracy: 69.48936170212765%
Epoch 3/10, Loss: 0.5059386601253432, Accuracy: 80.62147493881025%
Validation Accuracy: 68.17021276595744%
Epoch 4/10, Loss: 0.5176538025238075, Accuracy: 79.86591465361286%
Validation Accuracy: 68.8936170212766%
Epoch 5/10, Loss: 0.4845995868043024, Accuracy: 81.61115249547728%
Validation Accuracy: 69.61702127659575%
Epoch 6/10, Loss: 0.497197156657978, Accuracy: 81.17484303501118%
Validation Accuracy: 70.2127659574468%
Epoch 7/10, Loss: 0.4828392481621431, Accuracy: 81.64307757795041%
Validation Accuracy: 69.65957446808511%
Epoch 8/10, Loss: 0.47273193527849355, Accuracy: 81.76013621368521%
Validation Accuracy: 69.91489361702128%
Epoch 9/10, Loss: 0.4651789545607405, Accuracy: 82.17516228583591%
Validation Accuracy: 69.87234042553192%
Epoch 10/10, Loss: 0.46274231151253187, Accuracy: 82.09002873257423%
Validation Accuracy: 70.29787234042553%


**Guardado del modelo completo**

In [107]:
torch.save(model, 'modelo_entrenado_resnet18_softAtt_ka_completo_70_5_clases.pth')

**Guardado de los pesos del modelo**

In [108]:
torch.save(model.state_dict(), 'modelo_entrenado_resnet18_softAtt_ka_pesos_70_5_clases.pth')

**Evaluación en datos reales**

In [13]:
# Mapeo de clases
print(train_dataset.class_to_idx)

{'MEL': 0, 'NV': 1, 'BCC': 2, 'BKL': 3, 'AK': 4}


In [109]:
model.eval()  # Ponemos el modelo en modo de evaluación

# Paso 3: Cargar la imagen y aplicar las transformaciones
image_path = 'ka.jpg'  # Pon aquí la ruta de tu imagen
image = Image.open(image_path)  # Abrir la imagen
image_tensor = transform(image)  # Aplicar las transformaciones

image_tensor = image_tensor.unsqueeze(0)  # Convertirlo a un batch de tamaño 1

# Paso 5: Mover la imagen al dispositivo (GPU o CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_tensor = image_tensor.to(device)
model = model.to(device)

# Paso 6: Realizar la predicción
with torch.no_grad():  # No necesitamos gradientes para la inferencia
    output = model(image_tensor)

# Paso 7: Convertir las predicciones en probabilidades con softmax
probabilities = torch.nn.functional.softmax(output, dim=1)  # Usamos dim=1 porque tenemos un batch

# Paso 8: Obtener la clase con la mayor probabilidad
_, predicted_class = torch.max(probabilities, dim=1)

# Paso 9: Interpretar la clase predicha
# Usamos el mapeo que ya tienes de clases (el 'class_to_idx' que ya definiste en tu dataset)
predicted_idx = predicted_class.item()  # Obtenemos el índice de la clase predicha
print(predicted_idx)
# Aquí usamos el mapeo de clases que creamos antes para convertir el índice a una clase legible
predicted_class_name = [key for key, value in train_dataset.class_to_idx.items() if value == predicted_idx][0]

# Mostrar la clase predicha
print(f"Predicción: {predicted_class_name}")

4
Predicción: AK
