# Setup

In [None]:
# Install libraries

!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q datasets
!pip install -q monai

In [None]:
# Import libraries

import os

import numpy as np

import pandas as pd

from collections import Counter

from PIL import Image, ImageOps

import matplotlib.pyplot as plt

from datasets import Dataset, load_dataset, load_from_disk

from tqdm.notebook import tqdm

from google.colab import drive

import monai

import torch

from torch.optim import Adam

from torch.nn.functional import threshold, normalize

from transformers import SamProcessor, SamModel

from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm

from statistics import mean

In [None]:
# Check GPU and seed the random number generators.
# get_bounding_box jitters prompt boxes with np.random and the DataLoaders shuffle
# with torch's RNG, so both are seeded to make runs repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

Mounted at /content/drive


In [None]:
# Functions

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent coloured overlay.

    Args:
        mask: array whose last two axes are (height, width); it is reshaped to
            (height, width, 1) and multiplied by an RGBA colour.
        ax: matplotlib Axes (anything with an ``imshow`` method).
        random_color: if True, use a random RGB colour; otherwise dodger blue.
            Alpha is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    rgb = np.random.random(3) if random_color else np.array([30, 144, 255]) / 255
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def load_image(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    The file is opened in a ``with`` block so its handle is closed as soon as the
    converted copy exists. The handle is not left to Pillow's lazy loading or to
    garbage collection.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

def load_mask(path):
    """Open the mask file at ``path`` and return it as a single-channel ("L") PIL image.

    The file is opened in a ``with`` block so its handle is closed as soon as the
    converted copy exists.
    """
    with Image.open(path) as img:
        return img.convert("L")

def transform(example):
    """Decode the image and mask referenced by one dataset row.

    Reads ``example['image_path']`` via ``load_image`` (RGB) and
    ``example['mask_path']`` via ``load_mask`` ("L"). Stores the PIL images under
    ``'image'`` and ``'mask'``, and returns the same (mutated) dict for use with
    ``Dataset.map``.
    """
    example.update(
        image=load_image(example['image_path']),
        mask=load_mask(example['mask_path']),
    )
    return example

def load_salient_object_dataset(image_dir, mask_dir, image_ext='.jpg', mask_ext='.png'):
    """Build a Hugging Face ``datasets.Dataset`` of (image, mask) pairs.

    Args:
        image_dir: directory holding the images (files ending in ``image_ext``).
        mask_dir: directory holding the masks (files ending in ``mask_ext``).
        image_ext: filename suffix selecting image files (default '.jpg').
        mask_ext: filename suffix selecting mask files (default '.png').

    Returns:
        A ``datasets.Dataset`` with columns ``image_path``, ``mask_path``,
        ``image`` (RGB PIL image) and ``mask`` ("L" PIL image).

    Raises:
        ValueError: if the number of images and masks differ.
    """
    # Pairing relies on both directories sorting into the same order.
    # NOTE(review): confirm image/mask filename stems correspond one-to-one.
    images = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(image_ext))
    masks = sorted(os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswith(mask_ext))
    if len(images) != len(masks):
        raise ValueError(
            f"Found {len(images)} images in {image_dir!r} but {len(masks)} masks in {mask_dir!r}"
        )

    # Imported here on purpose: the notebook's import cell later rebinds the
    # global name `Dataset` to torch.utils.data.Dataset, which has no
    # `from_pandas`. Without this import the call below raised AttributeError.
    from datasets import Dataset as HFDataset

    df = pd.DataFrame({'image_path': images, 'mask_path': masks})
    dataset = HFDataset.from_pandas(df)

    # Decode the files into PIL images.
    return dataset.map(transform)

def pad_image_to_square(image, desired_size = 1000):
    """Centre ``image`` on a black square canvas of side ``desired_size``.

    Args:
        image: PIL image of any mode.
        desired_size: side length, in pixels, of the returned square.

    Returns:
        A new PIL image of size (desired_size, desired_size) in the same mode as
        ``image``. Padding is split between opposite sides; when the difference is
        odd, the extra pixel goes to the right/bottom.

    NOTE(review): if a side of ``image`` already exceeds ``desired_size``, the
    computed border is negative. It is passed unchanged to ImageOps.expand, which
    presumably crops rather than pads. Confirm inputs fit.
    """
    # One zero per band gives black in any mode (transparent black for RGBA).
    # Previously only 'RGB' and 'L' got a fill value; other modes hit an unbound variable.
    band_count = len(image.getbands())
    fill_color = 0 if band_count == 1 else (0,) * band_count

    width, height = image.size
    delta_w = desired_size - width
    delta_h = desired_size - height
    # Border widths as (left, top, right, bottom).
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))

    return ImageOps.expand(image, padding, fill=fill_color)

def resize_image(image, new_size=(256, 256), resample=None):
    """Return a resized copy of a PIL image.

    Args:
        image: PIL image.
        new_size: (width, height) of the output, in pixels.
        resample: PIL resampling filter. ``None`` selects ``Image.LANCZOS``, the
            filter the old ``Image.ANTIALIAS`` alias referred to. That alias was
            removed in Pillow 10, where using it raises AttributeError.

    Returns:
        The resized PIL image.
    """
    if resample is None:
        resample = Image.LANCZOS
    return image.resize(new_size, resample)

def convert_grayscale(image):
    """Return ``image.convert('L')``, a single-channel grayscale copy of a PIL image."""
    return image.convert('L')

def convert_to_binary_mask_and_back(mask_pil, threshold=127):
    """Binarise a mask image: pixels strictly greater than ``threshold`` become 1, the rest 0.

    Args:
        mask_pil: PIL mask image (converted to an array with ``np.asarray``).
        threshold: cut-off value; comparison is strict (``>``).

    Returns:
        A PIL image built from a uint8 array of 0/1 values. The values are NOT
        rescaled to 0/255, so displayed directly it looks almost black.
        Downstream, ``get_bounding_box`` tests ``> 0`` and the training loss
        consumes these values directly as float targets.
    """
    is_foreground = np.asarray(mask_pil) > threshold
    return Image.fromarray(is_foreground.astype(np.uint8))

def resize_transform(example):
    """Resize the row's image and then its mask with ``resize_image`` defaults; returns the row."""
    for key in ('image', 'mask'):
        example[key] = resize_image(example[key])
    return example

def pad_transform(example):
    """Pad the row's image and then its mask with ``pad_image_to_square`` defaults; returns the row."""
    for key in ('image', 'mask'):
        example[key] = pad_image_to_square(example[key])
    return example

def grayscale_transform(example):
    """Convert the row's image to grayscale via ``convert_grayscale``.

    The mask is left untouched. The former ``example['mask'] = example['mask']``
    self-assignment was a no-op and has been dropped.
    """
    example['image'] = convert_grayscale(example['image'])
    return example

def binary_mask_transform(example, threshold=127):
    """Binarise the row's mask (values > ``threshold`` -> 1, else 0); the image is untouched.

    Args:
        example: dataset row with a ``'mask'`` PIL image.
        threshold: cut-off passed to ``convert_to_binary_mask_and_back``. It can
            be overridden through ``Dataset.map(..., fn_kwargs={"threshold": ...})``.

    Returns:
        The same row with ``'mask'`` replaced by its 0/1 version.
    """
    example['mask'] = convert_to_binary_mask_and_back(example['mask'], threshold=threshold)
    return example

class SAMDataset(Dataset):
    """Torch dataset that turns (image, mask) rows into SAM model inputs.

    Each item holds the processor outputs for the row's image, with a
    bounding-box prompt computed from the row's mask by ``get_bounding_box``.
    It also holds the mask itself, as a numpy array, under ``"ground_truth_mask"``.

    Args:
        dataset: indexable collection whose items have ``"image"`` and ``"mask"``
            entries (e.g. the Hugging Face dataset from load_salient_object_dataset).
        processor: a ``SamProcessor`` or a callable with the same signature.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        mask = np.array(row["mask"])

        # Box prompt in the mask's pixel coordinates (randomly jittered).
        box = get_bounding_box(mask)
        encoded = self.processor(row["image"], input_boxes=[[box]], return_tensors="pt")

        # The processor returns a batch of one; drop that leading axis so the
        # DataLoader can add its own batch dimension.
        sample = {key: value.squeeze(0) for key, value in encoded.items()}
        sample["ground_truth_mask"] = mask
        return sample

def get_bounding_box(ground_truth_map, max_perturbation=20):
    """Return a randomly enlarged box ``[x_min, y_min, x_max, y_max]`` around a mask's foreground.

    Args:
        ground_truth_map: 2-D array (H, W); pixels > 0 count as foreground.
        max_perturbation: each edge is pushed outward by an independent random
            offset in ``[0, max_perturbation)`` pixels. The result is clipped to
            ``[0, W]`` / ``[0, H]``. ``0`` disables the jitter.

    Returns:
        A list of four Python ints in the pixel coordinates of ``ground_truth_map``.
        A mask with no foreground yields the full-image box ``[0, 0, W, H]``
        instead of an error from ``min`` on an empty array.
    """
    H, W = ground_truth_map.shape
    y_indices, x_indices = np.nonzero(ground_truth_map > 0)
    if x_indices.size == 0:
        # No foreground (e.g. a tiny object lost when resizing): prompt with the
        # whole image rather than raising.
        return [0, 0, W, H]

    def jitter():
        # np.random.randint's upper bound is exclusive and must be > 0.
        return int(np.random.randint(0, max_perturbation)) if max_perturbation > 0 else 0

    # Draw order (x_min, x_max, y_min, y_max) matches the original, so seeded runs agree.
    x_min = max(0, int(x_indices.min()) - jitter())
    x_max = min(W, int(x_indices.max()) + jitter())
    y_min = max(0, int(y_indices.min()) - jitter())
    y_max = min(H, int(y_indices.max()) + jitter())
    return [x_min, y_min, x_max, y_max]


# Format data

In [None]:
# Load the SAM processor from the same checkpoint the model is loaded from below.
sam_checkpoint = "facebook/sam-vit-huge"
processor = SamProcessor.from_pretrained(sam_checkpoint)

In [None]:
# Load train dataset. Pipeline: load -> pad to a 1000x1000 square -> resize to
# 256x256 -> binarise masks -> wrap as SAM inputs -> shuffled batches.

train_image_path = 'CarDD/CarDD_SOD/CarDD-TR/CarDD-TR-Image'
train_mask_path = 'CarDD/CarDD_SOD/CarDD-TR/CarDD-TR-Mask'

train_dataset = (
    load_salient_object_dataset(train_image_path, train_mask_path)  # images + masks as PIL
    .map(pad_transform)          # pad to a 1000x1000 square
    .map(resize_transform)       # resize to 256x256
    .map(binary_mask_transform)  # threshold masks to 0/1
)
train_dataset = SAMDataset(dataset=train_dataset, processor=processor)  # processor inputs + GT mask

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [None]:
# Load test dataset (used for validation)

test_image_path = 'CarDD/CarDD_SOD/CarDD-TE/CarDD-TE-Image'
test_mask_path = 'CarDD/CarDD_SOD/CarDD-TE/CarDD-TE-Mask'
test_dataset = load_salient_object_dataset(test_image_path, test_mask_path) # Load images and masks
test_dataset = test_dataset.map(pad_transform)  # Pad to 1000,1000 square
test_dataset = test_dataset.map(resize_transform) # Resize to 256, 256
test_dataset = test_dataset.map(binary_mask_transform) # Convert non-binary masks to binary masks
# NOTE(review): SAMDataset draws a randomly jittered box for every item, so
# validation prompts (and losses) still vary from epoch to epoch.
test_dataset = SAMDataset(dataset = test_dataset, processor = processor) # Format using Transformers library

# No shuffling for evaluation: a fixed order keeps the validation batches
# identical across epochs.
test_dataloader = DataLoader(test_dataset, batch_size = 8, shuffle = False)

# Model

In [None]:
# Set up: load SAM, freeze the vision and prompt encoders, and optimise only the
# mask decoder.

# Parameters whose names start with one of these prefixes get requires_grad=False.
FROZEN_PREFIXES = ("vision_encoder", "prompt_encoder")

model = SamModel.from_pretrained("facebook/sam-vit-huge")
for param_name, param_tensor in model.named_parameters():
    if param_name.startswith(FROZEN_PREFIXES):
        param_tensor.requires_grad_(False)

# Optimizer over the mask decoder's parameters only
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# Dice + cross-entropy loss, applying a sigmoid to the predictions, mean reduction
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
num_epochs = 50


def compute_batch_loss(model, batch, loss_fn, device):
    """Run SAM on one collated batch and return the segmentation loss tensor.

    Args:
        model: SAM model, called with ``pixel_values``, ``input_boxes`` and
            ``multimask_output=False``.
        batch: DataLoader batch with "pixel_values", "input_boxes" and
            "ground_truth_mask" entries.
        loss_fn: callable ``(pred, target) -> loss tensor`` (here ``seg_loss``).
        device: device the inputs and targets are moved to.

    Returns:
        The loss tensor. It is attached to the autograd graph unless called under
        ``torch.no_grad()``.
    """
    outputs = model(
        pixel_values=batch["pixel_values"].to(device),
        input_boxes=batch["input_boxes"].to(device),
        multimask_output=False,
    )
    # NOTE(review): assumes pred_masks is (B, 1, 1, H, W) and ground_truth_mask
    # is (B, H, W), so both sides end up (B, 1, H, W). Confirm against SAM output.
    predicted_masks = outputs.pred_masks.squeeze(1)
    ground_truth_masks = batch["ground_truth_mask"].float().to(device).unsqueeze(1)
    return loss_fn(predicted_masks, ground_truth_masks)


model.to(device)
model.train()
history = []  # one {"epoch", "train_loss", "val_loss"} dict per epoch, for plotting
for epoch in range(num_epochs):

    epoch_losses = []
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch} train"):
        loss = compute_batch_loss(model, batch, seg_loss, device)

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    # After this epoch's training pass, switch to evaluation mode for validation
    model.eval()
    val_losses = []
    with torch.no_grad():  # No need to track gradients during validation
        for batch in tqdm(test_dataloader, desc=f"epoch {epoch} val"):
            val_losses.append(compute_batch_loss(model, batch, seg_loss, device).item())
    model.train()

    # Calculate, record and print the average losses for this epoch
    train_loss = mean(epoch_losses)
    val_loss = mean(val_losses)
    history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})

    print(f'EPOCH: {epoch}')
    print(f'Training Mean loss: {train_loss}')
    print(f'Validation Mean loss: {val_loss}')