In [1]:
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from segment_anything import SamPredictor, sam_model_registry
import os

In [2]:
class InteractiveSegmentation:
    def __init__(self, sam_checkpoint='../models/sam_vit_h_4b8939.pth', model_type='vit_h', output_dir='../data/saved_images'):
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to('cuda' if torch.cuda.is_available() else 'cpu')
        self.predictor = SamPredictor(self.sam)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
    
    def save_image(self, image, filename):
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        full_path = os.path.join(self.output_dir, filename)
        cv2.imwrite(full_path, image_bgr)
        
        return full_path
    
    def select_roi(self, image):
        roi_image = image.copy()
        cv2.namedWindow('Select Object')
        self.drawing = False
        self.roi_coords = []
        
        def mouse_callback(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                self.drawing = True
                self.roi_coords = [(x, y)]
            
            elif event == cv2.EVENT_MOUSEMOVE:
                if self.drawing:
                    temp_image = roi_image.copy()
                    cv2.rectangle(temp_image, self.roi_coords[0], (x, y), (0, 255, 0), 2)
                    cv2.imshow('Select Object', temp_image)
            
            elif event == cv2.EVENT_LBUTTONUP:
                self.drawing = False
                self.roi_coords.append((x, y))
                cv2.rectangle(roi_image, self.roi_coords[0], self.roi_coords[1], (0, 255, 0), 2)
                cv2.imshow('Select Object', roi_image)
        
        cv2.setMouseCallback('Select Object', mouse_callback)
        cv2.imshow('Select Object', roi_image)
        
        # Wait for user to press 'Enter'
        while True:
            key = cv2.waitKey(1) & 0xFF
            if key == 13:  # Enter key
                break
        
        cv2.destroyAllWindows()
        
        # Convert coordinates to bounding box
        x1, y1 = min(self.roi_coords[0][0], self.roi_coords[1][0]), min(self.roi_coords[0][1], self.roi_coords[1][1])
        x2, y2 = max(self.roi_coords[0][0], self.roi_coords[1][0]), max(self.roi_coords[0][1], self.roi_coords[1][1])
        
        return (x1, y1, x2, y2)
    
    def segment_object(self, image, bbox):
        self.predictor.set_image(image)
        input_box = np.array(bbox)[None, :]

        masks, _, _ = self.predictor.predict(
            point_coords=None,
            box=input_box,
            multimask_output=False
        )
        
        return masks[0]
    
    def visualize_segmentation(self, image, mask):

        overlay = image.copy()
        overlay[mask] = [0, 255, 0]  # Green mask
        alpha = 0.5
        result = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)
        
        cv2.imshow('Segmentation Result', result)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        
        return result
    
    def remove_and_fill_object(self, image, mask):
        result = image.copy()
        result[mask] = [255, 255, 255]
        
        return result
    
    def process_image(self, image_path, remove_object=False):

        filename = os.path.splitext(os.path.basename(image_path))[0]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bbox = self.select_roi(image)
        mask = self.segment_object(image, bbox)

        segmentation_vis = self.visualize_segmentation(image, mask)
        seg_save_path = self.save_image(segmentation_vis, f'{filename}_segmentation.png')
        print(f"Segmentation visualization saved to: {seg_save_path}")

        if remove_object:
            result = self.remove_and_fill_object(image, mask)
            
            # result
            cv2.imshow('Object Removed', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            
            # Save result
            remove_save_path = self.save_image(result, f'{filename}_removed.png')
            print(f"Image with object removed saved to: {remove_save_path}")
            
            return image, result
        
        return image, None

In [8]:
segmentor = InteractiveSegmentation()
segmentor.process_image('../data/test_images/tree.jpg', remove_object=True)

Segmentation visualization saved to: ../data/saved_images\tree_segmentation.png
Image with object removed saved to: ../data/saved_images\tree_removed.png


(array([[[  4,  44,  70],
         [  4,  44,  70],
         [  4,  44,  70],
         ...,
         [ 37,  86, 129],
         [ 36,  85, 128],
         [ 33,  81, 129]],
 
        [[  4,  44,  70],
         [  4,  44,  70],
         [  4,  44,  70],
         ...,
         [ 37,  86, 129],
         [ 36,  85, 128],
         [ 34,  82, 130]],
 
        [[  5,  45,  71],
         [  5,  45,  71],
         [  5,  45,  71],
         ...,
         [ 38,  87, 130],
         [ 37,  86, 129],
         [ 35,  83, 131]],
 
        ...,
 
        [[ 78,  31,  21],
         [ 85,  41,  30],
         [ 88,  44,  33],
         ...,
         [ 80,  64,  41],
         [ 63,  48,  25],
         [ 59,  47,  23]],
 
        [[ 84,  36,  24],
         [ 75,  29,  16],
         [ 84,  40,  27],
         ...,
         [ 84,  68,  43],
         [ 71,  55,  30],
         [ 75,  61,  34]],
 
        [[ 97,  48,  31],
         [ 93,  46,  30],
         [103,  57,  42],
         ...,
         [ 76,  62,  36],
  