In [1]:
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from segment_anything import SamPredictor, sam_model_registry

In [9]:
class InteractiveSegmentation:
    def __init__(self, sam_checkpoint='../models/sam_vit_h_4b8939.pth', model_type='vit_h'):
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to('cuda' if torch.cuda.is_available() else 'cpu')
        self.predictor = SamPredictor(self.sam)
    
    def select_roi(self, image):
        roi_image = image.copy()
        cv2.namedWindow('Select Object')
        self.drawing = False
        self.roi_coords = []
        
        def mouse_callback(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                self.drawing = True
                self.roi_coords = [(x, y)]
            
            elif event == cv2.EVENT_MOUSEMOVE:
                if self.drawing:
                    temp_image = roi_image.copy()
                    cv2.rectangle(temp_image, self.roi_coords[0], (x, y), (0, 255, 0), 2)
                    cv2.imshow('Select Object', temp_image)
            
            elif event == cv2.EVENT_LBUTTONUP:
                self.drawing = False
                self.roi_coords.append((x, y))
                cv2.rectangle(roi_image, self.roi_coords[0], self.roi_coords[1], (0, 255, 0), 2)
                cv2.imshow('Select Object', roi_image)
        
        cv2.setMouseCallback('Select Object', mouse_callback)
        cv2.imshow('Select Object', roi_image)

        while True:
            key = cv2.waitKey(1) & 0xFF
            if key == 13:  # Enter key
                break
        
        cv2.destroyAllWindows()
        
        # Convert coordinates to bounding box
        x1, y1 = min(self.roi_coords[0][0], self.roi_coords[1][0]), min(self.roi_coords[0][1], self.roi_coords[1][1])
        x2, y2 = max(self.roi_coords[0][0], self.roi_coords[1][0]), max(self.roi_coords[0][1], self.roi_coords[1][1])
        
        return (x1, y1, x2, y2)
    
    def segment_object(self, image, bbox):

        self.predictor.set_image(image)
        input_box = np.array(bbox)[None, :]

        masks, _, _ = self.predictor.predict(
            point_coords=None,
            box=input_box,
            multimask_output=False
        )
        
        return masks[0]
    
    def visualize_segmentation(self, image, mask):

        overlay = image.copy()
        overlay[mask] = [0, 255, 0]  # Green mask
        alpha = 0.5
        result = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)
        
        cv2.imshow('Segmentation Result', result)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    
    def remove_and_fill_object(self, image, mask):
        result = image.copy()
        result[mask] = [255, 255, 255]
        
        return result
    
    def process_image(self, image_path, remove_object=False):
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bbox = self.select_roi(image)
        mask = self.segment_object(image, bbox)
        self.visualize_segmentation(image, mask)

        if remove_object:
            print('Removing object...')
            result = self.remove_and_fill_object(image, mask)
            
            # result
            cv2.imshow('Object Removed', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            
            return result
        
        return image

In [13]:
segmentor = InteractiveSegmentation()
segmentor.process_image('../data/test_images/landscape.jpg', remove_object=True)

array([[[145, 156, 186],
        [145, 156, 186],
        [146, 157, 187],
        ...,
        [113, 138, 179],
        [113, 138, 179],
        [114, 139, 179]],

       [[144, 155, 185],
        [144, 155, 185],
        [145, 156, 186],
        ...,
        [113, 138, 179],
        [113, 138, 179],
        [114, 139, 179]],

       [[143, 154, 184],
        [143, 154, 184],
        [144, 155, 185],
        ...,
        [114, 139, 180],
        [114, 139, 180],
        [113, 138, 179]],

       ...,

       [[ 78,  90,  52],
        [ 79,  91,  53],
        [ 78,  90,  52],
        ...,
        [ 94,  94,  56],
        [ 96,  97,  57],
        [ 90,  91,  51]],

       [[ 82,  94,  56],
        [ 82,  94,  56],
        [ 80,  92,  54],
        ...,
        [ 93,  93,  55],
        [ 95,  96,  56],
        [ 87,  88,  48]],

       [[ 76,  87,  47],
        [ 78,  89,  49],
        [ 79,  90,  50],
        ...,
        [ 92,  93,  53],
        [ 89,  90,  50],
        [ 85,  86,  46]]