Segmenting Cars using Meta AI's Segment Anything Model (SAM)

First, we try to find a bounding box around the car in the image using a YOLOv4 model. If a car is detected, we ask SAM to segment only the object inside that bounding box. If no car is detected, we assume the image is a close-up shot of the car; segmentation is not a problem for such images, so we give them to SAM directly without a bounding box and keep the largest mask it produces. Finally, we black out everything outside the mask and save the segmented images to a folder in our Google Drive.


In [None]:
# Mount Google Drive so the dataset, YOLO model files and result folders
# used below are reachable under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Show which GPU (if any) this runtime has; the SAM loading cell below falls
# back to the CPU when CUDA is unavailable.
!nvidia-smi

In [None]:
import os

# Working directory of the runtime; the SAM checkpoint is downloaded beneath it.
HOME = os.getcwd()
print("HOME:", HOME)

Install Segment Anything Model (SAM) and other dependencies


In [None]:
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision==0.23.0

Download SAM weights

In [None]:
!mkdir -p {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

In [None]:
# Location of the downloaded SAM checkpoint. Fail here with a clear message
# rather than later, with a less obvious error, when the model is built.
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
checkpoint_exists = os.path.isfile(CHECKPOINT_PATH)
print(CHECKPOINT_PATH, "; exist:", checkpoint_exists)
if not checkpoint_exists:
    raise FileNotFoundError(f"SAM checkpoint not found: {CHECKPOINT_PATH}")

In [None]:
import cv2
import os
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [None]:
# Print the current working directory before navigating to the YOLO model folder.
pwd

In [None]:
cd drive

In [None]:
cd MyDrive

In [None]:
cd dnn_model

In [None]:
# Print the resolved model directory; this output is the folder path to use
# for the YOLO weights/config in VehicleDetector below.
pwd

In [None]:
# Copy the directory printed above and use it as the folder of
# yolov4.weights / yolov4.cfg in the cv2.dnn.readNet() call below.

In [None]:
class VehicleDetector:
    """Locate the most confident vehicle in an image with a YOLOv4 network (OpenCV DNN).

    The network is loaded once at construction; call :meth:`detect_vehicles`
    for each image.

    Args:
        weights_path: Darknet ``.weights`` file. Defaults to the shared Drive copy.
        config_path: Darknet ``.cfg`` file matching ``weights_path``.
        input_size: Network input size passed to ``setInputParams``.
        classes_allowed: Class ids treated as vehicles; other detections are ignored.
    """

    # Default model files on the shared Google Drive folder.
    DEFAULT_WEIGHTS_PATH = "/content/drive/.shortcut-targets-by-id/15CA35YuW2XOA8fXZRHR5CS8xCcUzCXo3/dnn_model/yolov4.weights"
    DEFAULT_CONFIG_PATH = "/content/drive/.shortcut-targets-by-id/15CA35YuW2XOA8fXZRHR5CS8xCcUzCXo3/dnn_model/yolov4.cfg"

    def __init__(self, weights_path=DEFAULT_WEIGHTS_PATH, config_path=DEFAULT_CONFIG_PATH,
                 input_size=(832, 832), classes_allowed=(2, 3, 5, 6, 7)):
        # Load Network
        net = cv2.dnn.readNet(weights_path, config_path)
        self.model = cv2.dnn_DetectionModel(net)
        # scale=1/255 maps 8-bit pixel values into [0, 1].
        # NOTE(review): swapRB is left at OpenCV's default, so the network sees
        # the BGR image as-is; Darknet YOLO models are usually fed RGB
        # (swapRB=True) — confirm whether that is intended.
        self.model.setInputParams(size=input_size, scale=1 / 255)

        # Class ids treated as vehicles. Assuming the stock 80-class, 0-indexed
        # COCO label order these are car, motorcycle, bus, train, truck —
        # TODO confirm against the label file shipped with the model.
        self.classes_allowed = list(classes_allowed)

    def detect_vehicles(self, img, nms_threshold=0.4):
        """Return the highest-scoring vehicle box found in ``img``.

        Args:
            img: Image passed unchanged to ``DetectionModel.detect`` (the
                notebook passes the BGR array returned by ``cv2.imread``).
            nms_threshold: Non-maximum-suppression threshold forwarded to
                ``detect``; OpenCV's default ``confThreshold`` is used.

        Returns:
            The ``(x, y, w, h)`` box of the most confident detection whose class
            is in ``self.classes_allowed``, or an empty list when there is none
            (callers check ``len(result) == 0``).
        """
        class_ids, scores, boxes = self.model.detect(img, nmsThreshold=nms_threshold)
        if len(scores) == 0:
            return []

        # Flatten in case an OpenCV version returns (N, 1) columns instead of (N,).
        class_ids = np.asarray(class_ids).ravel()
        scores = np.asarray(scores, dtype=float).ravel()
        boxes = np.asarray(boxes).reshape(-1, 4)

        is_vehicle = np.isin(class_ids, self.classes_allowed)
        if not is_vehicle.any():
            # Only non-vehicle objects found; np.argmax over nothing would raise.
            return []

        # Mask out non-vehicle scores so argmax picks the best vehicle (first on ties).
        vehicle_scores = np.where(is_vehicle, scores, -np.inf)
        return boxes[int(np.argmax(vehicle_scores))]



In [None]:
# Build the YOLOv4 car detector once; it is reused for every image below.
vd = VehicleDetector()

In [None]:
# Load the SAM ViT-H model on the first GPU when CUDA is available, else on the CPU.
MODEL_TYPE = "vit_h"
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

# Prompt-free generator, used for images where no vehicle box is found.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
# Box-prompted predictor sharing the same SAM weights; used with the YOLO box.
mask_predictor = SamPredictor(sam)

In [None]:
# Input folder of blurred "HASARLI" car images and the image files it contains.
HASARLI_dir = "/content/drive/MyDrive/DS542_FINAL_PROJECT_DATASET/HASARLI_BLURRED"

HASARLI = [
    name
    for name in os.listdir(HASARLI_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
]

In [None]:


# Segment every image listed in HASARLI and save it, with everything outside
# the chosen mask set to black, under the same file name in the output folder.
HASARLI_output_dir = "/content/drive/MyDrive/segment_results/HASARLI_segmented"
# cv2.imwrite does not create folders; it just returns False when one is missing.
os.makedirs(HASARLI_output_dir, exist_ok=True)

for image_path in tqdm(HASARLI):
    IMAGE_PATH = os.path.join(HASARLI_dir, image_path)

    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        # cv2.imread returns None (instead of raising) for unreadable files.
        tqdm.write(f"Skipping unreadable image: {IMAGE_PATH}")
        continue
    # OpenCV loads BGR, but SAM's predictor and mask generator expect RGB.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    vehicle_boxes = vd.detect_vehicles(image_bgr)

    if len(vehicle_boxes) == 0:
        # No vehicle box: treated as a close-up, so segment the whole image
        # and keep the largest mask (first one on ties).
        sam_result = mask_generator.generate(image_rgb)
        if not sam_result:
            tqdm.write(f"SAM produced no masks, skipping: {IMAGE_PATH}")
            continue
        mask = max(sam_result, key=lambda seg: seg['area'])['segmentation']
    else:
        # Convert the (x, y, w, h) detector box to the x0, y0, x1, y1 box prompt.
        x, y, w, h = vehicle_boxes
        box = np.array([x, y, x + w, y + h])

        mask_predictor.set_image(image_rgb)
        masks, _, _ = mask_predictor.predict(box=box, multimask_output=True)
        # Keep the candidate mask covering the most pixels (first one on ties).
        mask = masks[np.argmax(masks.sum(axis=(1, 2)))]

    segmented_image = image_bgr.copy()
    segmented_image[~mask] = (0, 0, 0)

    output_path = os.path.join(HASARLI_output_dir, image_path)
    if not cv2.imwrite(output_path, segmented_image):
        tqdm.write(f"Failed to write: {output_path}")




In [None]:
# Input folder of blurred "HASARSIZ" car images and the image files it contains.
HASARSIZ_dir = "/content/drive/MyDrive/DS542_FINAL_PROJECT_DATASET/HASARSIZ_BLURRED"

HASARSIZ = [
    name
    for name in os.listdir(HASARSIZ_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
]

In [None]:


# Segment every image listed in HASARSIZ and save it, with everything outside
# the chosen mask set to black, under the same file name in the output folder.
HASARSIZ_output_dir = "/content/drive/MyDrive/segment_results/HASARSIZ_segmented"
# cv2.imwrite does not create folders; it just returns False when one is missing.
os.makedirs(HASARSIZ_output_dir, exist_ok=True)

for image_path in tqdm(HASARSIZ):
    IMAGE_PATH = os.path.join(HASARSIZ_dir, image_path)

    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        # cv2.imread returns None (instead of raising) for unreadable files.
        tqdm.write(f"Skipping unreadable image: {IMAGE_PATH}")
        continue
    # OpenCV loads BGR, but SAM's predictor and mask generator expect RGB.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    vehicle_boxes = vd.detect_vehicles(image_bgr)

    if len(vehicle_boxes) == 0:
        # No vehicle box: treated as a close-up, so segment the whole image
        # and keep the largest mask (first one on ties).
        sam_result = mask_generator.generate(image_rgb)
        if not sam_result:
            tqdm.write(f"SAM produced no masks, skipping: {IMAGE_PATH}")
            continue
        mask = max(sam_result, key=lambda seg: seg['area'])['segmentation']
    else:
        # Convert the (x, y, w, h) detector box to the x0, y0, x1, y1 box prompt.
        x, y, w, h = vehicle_boxes
        box = np.array([x, y, x + w, y + h])

        mask_predictor.set_image(image_rgb)
        masks, _, _ = mask_predictor.predict(box=box, multimask_output=True)
        # Keep the candidate mask covering the most pixels (first one on ties).
        mask = masks[np.argmax(masks.sum(axis=(1, 2)))]

    segmented_image = image_bgr.copy()
    segmented_image[~mask] = (0, 0, 0)

    output_path = os.path.join(HASARSIZ_output_dir, image_path)
    if not cv2.imwrite(output_path, segmented_image):
        tqdm.write(f"Failed to write: {output_path}")


