In [None]:
# Installing Segment Anything to local machine

# pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
# Downloading model weights for segmentation

# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

In [None]:
# Essential imports

from __future__ import annotations

import os
from pathlib import Path

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import tqdm

In [None]:
# Building the Segmentation pipeline

# config
in_dir = '/content/input_images'
out_dir = '/content/segment_masks'
sam_model = "vit_l"
sam_check = "/content/sam_vit_l_0b3195.pth"
# Fall back to the CPU when no CUDA device is present so the pipeline still
# runs (slowly) on CPU-only runtimes instead of failing when the model is moved.
device = "cuda" if torch.cuda.is_available() else "cpu"
transparency = 0.3
max_masks = 300
# Seed for the mask palette: a fixed seed gives the same colors on every run,
# which keeps saved renderings comparable between sessions.
color_seed = 0

# list of random colors (one RGB triple in [0, 1) per drawable mask)
_color_rng = np.random.default_rng(color_seed)
colors = [_color_rng.random(3) for _ in range(max_masks)]

def draw_segmentation(anns):
    """Render annotations as a flat-colored RGB float image.

    Args:
        anns: Sequence of dicts, each with a 2-D ``'segmentation'`` mask
            (used to index the canvas) and an ``'area'`` value used for
            ordering.

    Returns:
        A ``float64`` array of shape ``(h, w, 3)`` where every drawn mask is
        filled with its entry from the module-level ``colors`` list, or
        ``None`` when ``anns`` is empty. Pixels covered by no mask stay zero.
    """
    if len(anns) == 0:
        return None

    height, width = anns[0]['segmentation'].shape
    canvas = np.zeros((height, width, 3), dtype=np.float64)

    # Largest masks are painted first so smaller ones end up on top.
    ordered = list(anns)
    ordered.sort(key=lambda ann: ann['area'], reverse=True)

    # Never draw more masks than the configured cap.
    limit = min(len(ordered), max_masks)
    for idx, ann in zip(range(limit), ordered):
        # the mask selects the pixels that receive this mask's color
        canvas[ann['segmentation']] = colors[idx]
    return canvas

def process_image(img_path, out_path, mask_generator):
    """Segment one image and save the colored mask rendering.

    Args:
        img_path: Path of the image, read with ``cv2.imread``.
        out_path: Destination path handed to ``cv2.imwrite``.
        mask_generator: Object whose ``generate(image)`` method is called with
            the RGB ``uint8`` image and returns the annotations consumed by
            ``draw_segmentation``.

    Raises:
        OSError: If the input image cannot be read or the output image cannot
            be written.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError(f"could not read image: {img_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # mask generator wants the default uint8 image
    masks = mask_generator.generate(image)

    seg = draw_segmentation(masks)
    if seg is None:
        # No masks were produced: draw_segmentation returns None in that case,
        # so save an all-black rendering with the input's dimensions.
        seg = np.zeros(image.shape, dtype=np.float64)

    # NOTE(review): only the mask rendering is saved; the input image is not
    # blended in and the `transparency` setting is unused here. Confirm that
    # mask-only output is the intended result.

    # convert back to uint8 for display/save
    rendered = (255 * seg).astype(np.uint8)
    rendered = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(out_path, rendered):
        # cv2.imwrite returns False on failure rather than raising
        raise OSError(f"could not write image: {out_path}")

if __name__ == "__main__":
    # make sure output dir exists; exist_ok avoids the check-then-create race
    os.makedirs(out_dir, exist_ok=True)

    # load SAM model + create mask generator
    sam = sam_model_registry[sam_model](checkpoint=sam_check)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)

    # process input directory (sorted so the processing order is stable)
    for img in tqdm(sorted(os.listdir(in_dir))):
        in_img = os.path.join(in_dir, img)

        # skip entries OpenCV reports no image reader for
        if not cv2.haveImageReader(in_img):
            continue

        # change extension of output image to .png
        out_img = os.path.join(out_dir, Path(img).stem + ".png")
        process_image(in_img, out_img, mask_generator)

In [None]:
# Function for overlaying annotation (segment) masks on the current matplotlib axes

def show_anns(anns):
    """Overlay annotations on the current matplotlib axes.

    Each annotation's ``'segmentation'`` mask is filled with a random RGB
    color at alpha 0.35 on an otherwise fully transparent RGBA image, which is
    then drawn with ``imshow`` on ``plt.gca()``. Annotations are painted in
    descending ``'area'`` order. Does nothing when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    ordered = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    axes = plt.gca()
    # keep the existing axis limits when the overlay is added
    axes.set_autoscale_on(False)

    rows, cols = ordered[0]['segmentation'].shape[:2]
    overlay = np.ones((rows, cols, 4))
    overlay[..., 3] = 0  # alpha 0 everywhere a mask does not paint
    for ann in ordered:
        rgb = np.random.random(3)
        overlay[ann['segmentation']] = np.concatenate([rgb, [0.35]])
    axes.imshow(overlay)

In [None]:
# Initialising the SAM model

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_TYPE = "vit_l"
# Checkpoint file for the weights that match MODEL_TYPE
CHECKPOINT_PATH = "/content/sam_vit_l_0b3195.pth"

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
# Trailing semicolon keeps the notebook from echoing the value of this last
# expression as the cell output.
sam.to(device=DEVICE);

In [None]:
# Segmenting sample images from the shipwreck dataset
# This is applied to images before and after applying CLAHE

mask_gen = SamAutomaticMaskGenerator(sam)

SAMPLE_IMAGE_PATH = '/content/test/c1754s413.jpg'
image_bgr = cv2.imread(SAMPLE_IMAGE_PATH)
if image_bgr is None:
    # cv2.imread returns None for a missing or undecodable file; fail here with
    # the path instead of with an opaque error from cvtColor below.
    raise FileNotFoundError(f"could not read sample image: {SAMPLE_IMAGE_PATH}")
# OpenCV loads images in BGR order; convert to RGB before segmentation/plotting
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
result = mask_gen.generate(image_rgb)

In [None]:
# Displaying the original image and segmented image for comparison

fig = plt.figure(figsize = (50,50))

# matplotlib interprets 3-channel arrays as RGB, so plot image_rgb; plotting
# the BGR array from cv2.imread would swap the red and blue channels.
original = fig.add_subplot(2,2,1)
original.imshow(image_rgb)
original.axis("off")

masked = fig.add_subplot(2,2,2)
masked.imshow(image_rgb)
masked.axis("off")
# show_anns draws on plt.gca(), so make the right-hand panel the current axes
plt.sca(masked)
show_anns(result)

plt.show()