# Skin Lesion Segmentation using Mask R-CNN

## 1. Environment and Dependencies Setup

### 1.1 Install PyTorch with GPU support

In [None]:
# Install torch and torchvision from the PyTorch `cu126` wheel index (GPU-enabled builds).
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126

### 1.2 Imports and Dependencies

In [None]:
import os
import sys

# Make the parent of the current working directory importable so the `src.*` imports below resolve.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.mask.engine import train_one_epoch, evaluate
import src.mask.utils as utils
import src.mask.transforms as T

In [None]:
import torch
import torch.utils.data
from torch.utils.data import random_split
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random

# Dataset root; HAM10000Dataset reads the `images/` and `masks/` subfolders beneath it.
HAM10000_DIR = "../data/HAM10000" # Update this path as needed

### 1.3 System Configuration and Constants

In [None]:
# Report whether PyTorch can see a CUDA device and, if so, which one.
print("CUDA available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    print("Running on CPU")
else:
    print("CUDA version:", torch.version.cuda)
    print("Number of GPUs:", torch.cuda.device_count())
    # Only the first visible device (index 0) is named.
    print("GPU name:", torch.cuda.get_device_name(0))

In [None]:
# Prefer the GPU whenever PyTorch reports one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One foreground class (lesion) plus the background class.
num_classes = 2

## 2. Data Preparation

### 2.1 Custom dataset classes (HAM10000)

In [None]:
class HAM10000Dataset(torch.utils.data.Dataset):
    """Image/mask pairs from a HAM10000-style folder, returned as detection targets.

    ``root`` must contain an ``images`` and a ``masks`` subfolder. Files are
    paired by their position in the sorted directory listings, so the two
    folders are assumed to use names that sort in the same order.

    Args:
        root: Dataset root directory.
        transforms: Optional callable ``(img, target) -> (img, target)``
            applied to every sample.

    Raises:
        ValueError: If the two subfolders hold a different number of files,
            because index-based pairing would then be unreliable.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms

        # Load the names of all images and masks and sort them.
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))
        self.masks = sorted(os.listdir(os.path.join(root, "masks")))

        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"Found {len(self.imgs)} images but {len(self.masks)} masks "
                f"under {root!r}; images and masks are paired by index."
            )

    @staticmethod
    def _build_target(mask, idx):
        """Turn a 2-D label mask (numpy array) into the target dict for sample ``idx``.

        Every distinct non-zero pixel value is treated as one object. Objects
        whose bounding box has zero width or height are dropped.

        Returns:
            Dict with ``boxes`` (N, 4) float32 as ``[xmin, ymin, xmax, ymax]``,
            ``labels`` (N,) int64, ``masks`` (N, H, W) uint8, ``image_id``,
            ``area`` (N,) and ``iscrowd`` (N,) int64.
        """
        obj_ids = np.unique(mask)
        # Remove background by value rather than by position: a mask with no
        # background pixel would otherwise lose its only object.
        obj_ids = obj_ids[obj_ids != 0]

        # One boolean mask per object id.
        instance_masks = mask == obj_ids[:, None, None]

        boxes = []
        keep = []
        for i, instance in enumerate(instance_masks):
            rows, cols = np.nonzero(instance)
            xmin, xmax = int(cols.min()), int(cols.max())
            ymin, ymax = int(rows.min()), int(rows.max())
            # Skip degenerate boxes (a single pixel row or column).
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                keep.append(i)

        instance_masks = instance_masks[np.asarray(keep, dtype=np.intp)]
        num_objs = len(keep)

        # Convert everything into torch tensors.
        if num_objs == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)

        return {
            "boxes": boxes,
            "labels": torch.ones((num_objs,), dtype=torch.int64),  # Label 1 = skin lesion
            "masks": torch.as_tensor(instance_masks, dtype=torch.uint8),
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((num_objs,), dtype=torch.int64),
        }

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``, after optional transforms."""
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        mask_path = os.path.join(self.root, "masks", self.masks[idx])

        # Context managers release the underlying file handles promptly.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)

        target = self._build_target(mask, idx)

        # Apply any transformations.
        if self.transforms:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of samples (one per image file)."""
        return len(self.imgs)

### 2.2 Defining Transforms

In [None]:
import torchvision.transforms.functional as F

class ToTensor(torch.nn.Module):
    """Paired transform that converts the image to a tensor and passes the target through."""

    def forward(self, image, target):
        # Only the image is converted; the target is returned unchanged.
        return F.to_tensor(image), target
    
def get_transform():
    """Build the (image, target) transform pipeline; it currently holds a single ``ToTensor`` step."""
    return T.Compose([ToTensor()])

### 2.3 DataLoaders

In [None]:
# Two dataset objects over the same folder, each with its own transform pipeline.
dataset = HAM10000Dataset(root=HAM10000_DIR, transforms=get_transform())
dataset_test = HAM10000Dataset(root=HAM10000_DIR, transforms=get_transform())

In [None]:
# 1. Split sizes: 70% of the samples go to training; whatever is left after
#    truncation (~30%) is held out for testing.
total_len = len(dataset)
train_len = int(total_len * 0.7)
test_len = total_len - train_len

print(f"Total Samples: {total_len}")
print(f"Training Samples (70%): {train_len}")
print(f"Testing Samples (30%): {test_len}")

# 2. Random split driven by a fixed-seed generator, so the same samples land
#    in each subset on every run.
dataset_train, dataset_test = random_split(
    dataset,
    lengths=[train_len, test_len],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# DataLoaders
# The training loader must wrap the training split only: feeding it the full
# dataset would let the model train on the held-out test samples.
data_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=2, shuffle=True, num_workers=0, # num_workers=0 to avoid issues on Windows
    collate_fn=utils.collate_fn)

# Evaluation loader: one image per batch, in a fixed order.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=utils.collate_fn)

### 2.4 Data Exploration (EDA)

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
plt.suptitle('Training Dataset Sample (Image vs. Ground Truth Mask)', fontsize=16, y=0.92)

# Pick up to three distinct training samples; a smaller split simply leaves
# the remaining columns empty instead of raising.
indices_to_show = random.sample(range(len(dataset_train)), min(3, len(dataset_train)))

for i, idx in enumerate(indices_to_show):
    img_tensor, target = dataset_train[idx]
    # Move the first (channel) axis last for matplotlib.
    img_np = img_tensor.permute(1, 2, 0).cpu().numpy()

    # Union of all instance masks. Summing over the instance axis also works
    # when a sample has no object at all (the result is an all-zero mask),
    # whereas indexing the first mask would raise IndexError.
    mask_np = target['masks'].sum(dim=0).clamp(max=1).cpu().numpy()

    # Row 1: Original Image
    ax_img = axes[0, i]
    ax_img.imshow(img_np)
    ax_img.set_title(f"Sample {idx} (Original)")
    ax_img.axis('off')

    # Row 2: Binary Mask
    ax_mask = axes[1, i]
    ax_mask.imshow(mask_np, cmap='gray')
    ax_mask.set_title(f"Sample {idx} (Mask GT)")
    ax_mask.axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout
plt.show()

## 3. Model Definition and Training Components

### 3.1 Model Definition

In [None]:
def get_model(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    The network starts from torchvision's default pre-trained weights; the box
    and mask predictors are then replaced so their output size matches
    ``num_classes``.

    Args:
        num_classes: Number of output classes, background included.

    Returns:
        The model with newly created box and mask predictor heads.
    """
    # `weights="DEFAULT"` replaces the deprecated `pretrained=True` flag and
    # loads the pre-trained COCO checkpoint.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the box predictor (FastRCNN)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor (MaskRCNN)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

### 3.2 Optimizer and Scheduler Setup

In [None]:
# Instantiate the model with `num_classes` outputs and move its parameters to
# the selected device.
model = get_model(num_classes)
model.to(device)

In [None]:
# Optimizer: SGD with momentum over the parameters that require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Scheduler: multiply the LR by 0.1 every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 10

## 4. Training loop execution

### 4.1 Loop execution

In [None]:
# Per-epoch histories filled in by the training loop: loss, validation metric, learning rate.
train_loss_history, eval_metric_history, lr_history = [], [], []

In [None]:
for epoch in range(num_epochs):
    # Train for one epoch and keep the value returned by the project's
    # `train_one_epoch` as that epoch's loss.
    avg_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=100)
    train_loss_history.append(avg_loss)

    coco_metrics = evaluate(model, data_loader_test, device=device)

    # Record the measured validation metric instead of a synthetic placeholder.
    # NOTE(review): assumes `evaluate` returns a torchvision-reference style
    # CocoEvaluator whose `coco_eval["segm"].stats[0]` is the mask AP averaged
    # over IoU 0.50:0.95 -- confirm against src/mask/engine.py.
    mAP_score = float(coco_metrics.coco_eval["segm"].stats[0])
    eval_metric_history.append(mAP_score)

    # Log the LR that was in effect during this epoch, before the scheduler steps.
    lr_history.append(optimizer.param_groups[0]['lr'])

    lr_scheduler.step()

print("Training complete")

### 4.2 Save model

In [None]:
# torch.save does not create missing folders, so make sure the target
# directory exists before writing the weights.
model_path = "../models/maskrcnn_ham10000.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Modelo guardado como maskrcnn_ham10000.pth")

## 5. Performance evaluation

In [None]:
# 1-based epoch numbers, one per recorded training loss, used as the x-axis of the plots below.
epochs = range(1, len(train_loss_history) + 1)

In [None]:
### --- PLOT 1: TRAINING LOSS ---
# Blue line with circle markers: one point per epoch.
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_loss_history, 'b-o', label='Pérdida de Entrenamiento')
plt.legend()
plt.grid(True)
plt.xlabel('Época')
plt.ylabel('Loss (Pérdida)')
plt.title('Pérdida de Entrenamiento por Época')
plt.show()

In [None]:
### --- PLOT 2: VALIDATION METRIC (mAP/IoU) ---
# Red line with circle markers: one point per epoch.
plt.figure(figsize=(10, 5))
plt.plot(epochs, eval_metric_history, 'r-o', label='Validación mAP/IoU')
plt.legend()
plt.grid(True)
plt.ylim(0, 1.0) # Metrics usually range from 0 to 1
plt.xlabel('Época')
plt.ylabel('mAP / IoU Score')
plt.title('Rendimiento de Validación por Época')
plt.show()

In [None]:
### --- PLOT 3: LEARNING RATE (optional) ---
# Green line with circle markers: the LR recorded for each epoch.
plt.figure(figsize=(10, 5))
plt.plot(epochs, lr_history, 'g-o', label='Learning Rate')
plt.legend()
plt.grid(True)
plt.xlabel('Época')
plt.ylabel('LR')
plt.title('Learning Rate (LR) por Época')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Put the model in evaluation mode before predicting.
model.eval()

# Use the first sample of the test split; its target is not needed here.
img = dataset_test[0][0]

# Run the prediction without tracking gradients. The model takes a list of
# image tensors, so the single image is wrapped in a list.
with torch.no_grad():
    prediction = model([img.to(device)])

In [None]:
# Predicted masks and their scores for the single input image.
masks = prediction[0]['masks']
scores = prediction[0]['scores']
# Minimum score the first detection must exceed to be drawn.
mask_threshold = 0.5

# Convert the image to a displayable channels-last uint8 array.
img_show = img.mul(255).permute(1, 2, 0).byte().numpy()
plt.figure(figsize=(10,10))
plt.imshow(img_show)

# Overlay the first detected mask only when it is confident enough.
if len(masks) == 0 or not scores[0] > mask_threshold:
    print("No se detectaron lesiones con suficiente confianza.")
else:
    mask_show = masks[0, 0].mul(255).byte().cpu().numpy()
    plt.imshow(mask_show, alpha=0.5, cmap='jet') # alpha makes the overlay translucent
    print(f"Lesión detectada con confianza: {scores[0]:.2f}")

plt.show()

In [None]:
# Displayable channels-last uint8 version of the input image.
img_show = img.mul(255).permute(1, 2, 0).byte().numpy()

if len(masks) > 0:
    # First predicted mask, scaled to the 0-255 range.
    mask_show = masks[0, 0].mul(255).byte().cpu().numpy()
else:
    # No detections: fall back to an empty mask so the figure still renders
    # instead of raising IndexError.
    mask_show = np.zeros(img_show.shape[:2], dtype=np.uint8)


fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes[0,0].imshow(img_show)
axes[0,0].set_title('original image')
axes[0,1].imshow(mask_show)
axes[0,1].set_title('predicted mask')
axes[1,0].imshow(img_show)
axes[1,0].imshow(mask_show, alpha=0.5, cmap='jet')
axes[1,0].set_title('overlayed image')
# Binary mask: keep pixels whose scaled value is above 128.
binary_mask = mask_show > 128
axes[1,1].imshow(binary_mask, cmap='gray')
axes[1,1].set_title('binary mask')
plt.show()