# 02 Apply segmentation mask

## Import dependencies

In [None]:
import os

import torch
import torchvision
from PIL import Image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

## Define image and mask paths

In [None]:
# Folder that holds the input images.
DATA_PATH = "../data/ISIC/images"
# Folder where the generated masks are written.
OUTPUT_PATH = "../data/ISIC/masks"

# exist_ok=True creates the folder (and any missing parents) only when it is
# absent, without a separate existence check that could race with another
# process creating it in between.
os.makedirs(OUTPUT_PATH, exist_ok=True)

## Define model architecture

In [None]:
def get_model(num_classes: int, hidden_layer: int = 256):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes``.

    The network is created with COCO pre-trained weights, then its box
    predictor and mask predictor are replaced by freshly initialised heads
    sized for ``num_classes``.

    Args:
        num_classes: Number of output classes for both replaced heads
            (presumably including the background class, as torchvision
            detection models expect -- confirm against the training code).
        hidden_layer: Number of channels in the hidden layer of the new mask
            predictor. The default (256) matches the previously hard-coded
            value, so existing checkpoints keep loading unchanged.

    Returns:
        The modified Mask R-CNN model.
    """
    # Load an instance segmentation model pre-trained on COCO.
    # NOTE(review): newer torchvision releases deprecate ``pretrained=True``
    # in favour of the ``weights=`` argument; kept as-is so the notebook still
    # runs on the torchvision version it was written for.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    roi_heads = model.roi_heads

    # Replace the box predictor (FastRCNN), reusing the input width of the
    # existing classification layer.
    in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor (MaskRCNN), reusing the input channel count
    # of the existing mask head.
    in_features_mask = roi_heads.mask_predictor.conv5_mask.in_channels
    roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes,
    )
    return model

## Load trained weights

In [None]:
# Recreate the Mask R-CNN architecture, then restore the trained parameters.
model = get_model(num_classes=2)

# map_location remaps every stored tensor onto the CPU while loading.
cpu_device = torch.device('cpu')
model.load_state_dict(
    torch.load(
        '../models/maskrcnn_ham10000.pth',
        map_location=cpu_device,
    )
)

# Put the model in evaluation mode before running predictions.
model.eval()

## Apply and save masks

In [None]:
# Decision threshold: a pixel belongs to the mask when its predicted value is
# strictly greater than this. Defined once, outside the loop.
mask_threshold = 0.5

# Sorted so that files are processed in a reproducible order.
for img_name in sorted(os.listdir(DATA_PATH)):
    # Only process JPEG/PNG files; skip everything else in the folder.
    if not img_name.endswith(('.jpg', '.png')):
        continue

    img_path = os.path.join(DATA_PATH, img_name)
    # The context manager releases the file handle as soon as the pixels
    # have been copied into the converted RGB image.
    with Image.open(img_path) as source_image:
        image = source_image.convert("RGB")
    image_tensor = torchvision.transforms.functional.to_tensor(image)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model([image_tensor])

    masks = prediction[0]['masks']
    if len(masks) > 0:
        # 1. Take the raw mask map of the first detection.
        #    NOTE(review): torchvision presumably returns detections ordered
        #    by score, making index 0 the most confident one -- confirm.
        raw_mask = masks[0, 0]

        # 2. Apply the threshold; this yields a boolean (True/False) tensor.
        binary_mask = raw_mask > mask_threshold

        # 3. Map True/False to 255/0 and move the result to a numpy array.
        mask = binary_mask.mul(255).byte().cpu().numpy()

        # 4. Build the image and make sure it is in 'L' mode (8-bit
        #    grayscale), which is widely supported by other libraries.
        mask_image = Image.fromarray(mask).convert('L')
    else:
        # Nothing detected: save an all-black mask of the same size.
        mask_image = Image.new('L', image.size, 0)

    # NOTE(review): the mask reuses the source file name, so '.jpg' inputs are
    # written as JPEG, which is lossy and may soften the 0/255 edges --
    # consider writing PNG if downstream code allows it.
    mask_image.save(os.path.join(OUTPUT_PATH, img_name))