# 02 Apply segmentation mask

## Import dependences

In [1]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import os
import pandas as pd
from PIL import Image
import shutil
from tqdm import tqdm

## Define images path and masks path

In [2]:
DATA_PATH = "../data/isic-archive/images_selected"
OUTPUT_PATH = "../data/isic-archive/masks_selected"

if not os.path.exists(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)

## Define architecture model

In [3]:
def get_model(num_classes):
    # Load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) 

    # Replace the box predictor (FastRCNN)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor (MaskRCNN)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model

## Load trained weights

In [4]:
#load mask rcnn model trained
model = get_model(num_classes=2)
model.load_state_dict(torch.load('../models/mask_models/maskrcnn_ham10000.pth', map_location=torch.device('cpu')))
model.eval()



MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in

## Apply and save masks

In [5]:
for img_name in tqdm(os.listdir(DATA_PATH), desc="Processing images"):
    if img_name.endswith('.jpg') or img_name.endswith('.png'):
        img_path = os.path.join(DATA_PATH, img_name)
        image = Image.open(img_path).convert("RGB")
        image_tensor = torchvision.transforms.functional.to_tensor(image)

        with torch.no_grad():
            prediction = model([image_tensor])
        
        mask_threshold = 0.5
        
        if len(prediction[0]['masks']) > 0:
            raw_mask = prediction[0]['masks'][0, 0]

            binary_mask = raw_mask > mask_threshold
            mask = binary_mask.mul(255).byte().cpu().numpy()

            mask_image = Image.fromarray(mask)
            mask_image = mask_image.convert('L') 
            
            mask_image.save(os.path.join(OUTPUT_PATH, img_name))

        else:
            empty_mask = Image.new('L', image.size, 0)
            empty_mask.save(os.path.join(OUTPUT_PATH, img_name))

Processing images: 100%|██████████| 74283/74283 [18:14:20<00:00,  1.13it/s]  
