In [1]:
import os
import sys
import torch
import torchvision
import json

import cv2

import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk

In [2]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [3]:
torch.cuda.is_available()

True

In [4]:
HOME = "../"
#gros : sam_vit_h_4b8939.pth  petit : sam_vit_b_01ec64.pth
CHECKPOINT_PATH = os.path.join(HOME, "model", "sam_vit_b_01ec64.pth")  
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
MODEL_TYPE = "vit_b"

../model/sam_vit_b_01ec64.pth ; exist: True


In [128]:
class ImageLoader:
    def __init__(self, master):
        self.master = master
        self.master.geometry("1024x720")
        self.master.title("Image Loader")
        
        self.image = None
        self.segmented_image = None
        self.loaded_image = None
        
        self.my_dict = {(0,0,0):"RIEN",
                        (255,0,0):"item1",
                        (0,255,0):"item2",
                        (0,0,255):"item3"}
        
        self.display_str = ""
        for key, value in self.my_dict.items():
            self.display_str += f"{key}: {value}\n"
        
        
        # create a frame to hold the buttons
        button_frame = tk.Frame(self.master)
        button_frame.pack(padx=10, pady=10)
        
        # create a button to select an image file
        self.select_button = tk.Button(button_frame, text="Select Image", command=self.select_image)
        self.select_button.pack(side='left', padx=10, pady=10)
        
        # create a button to create new dict
        self.create_dict_button = tk.Button(button_frame, text="Create Dict", command=self.create_dict)
        self.create_dict_button.pack(side='left', padx=10, pady=10)
        
        # create a button to reset the dict to its initial state
        self.reset_dict_button = tk.Button(button_frame, text="Reset Dict", command=self.reset_dict)
        self.reset_dict_button.pack(side='left', padx=10, pady=10)
        
        # create a button to perform segmentation
        self.segment_button = tk.Button(button_frame, text="Perform Segmentation", command=self.segment_image)
        self.segment_button.pack(side='left', padx=10, pady=10)
        
        
        # create a frame to hold the dict label
        self.dict_frame = tk.Frame(self.master, borderwidth=2, relief="ridge")
        self.dict_frame.pack(side="left", padx=10, pady=10)
        
        # create a label to display the dict
        self.dict_label = tk.Label(self.dict_frame, text=self.display_str, font=("Arial", 12), fg="blue")
        self.dict_label.pack(padx=10, pady=10)


        # create a label to display the selected image
        self.image_label = tk.Label(self.master)
        self.image_label.pack(padx=10, pady=10)
    
    
    
    def update_display(self):
        # Clear the current display
        self.dict_label.config(text="")

        # Rebuild the display string
        self.display_str = ""
        for key, value in self.my_dict.items():
            self.display_str += f"{key}: {value}\n"

        # Update the label with the new display string
        self.dict_label.config(text=self.display_str)
        self.dict_label.pack()   
           
    def reset_dict(self):
        self.my_dict = {(0,0,0):"RIEN",
                        (255,0,0):"item1",
                        (0,255,0):"item2",
                        (0,0,255):"item3"}
        self.update_display()   
        
        
        
    def create_dict(self):
        # Create a dialog box to prompt the user for the RGB value and word
        dialog = tk.Toplevel(self.master)
        dialog.title("Add Color Word")
        
        tk.Label(dialog, text="Enter the RGB value:").grid(row=0, column=0)
        rgb_entry = tk.Entry(dialog)
        rgb_entry.grid(row=0, column=1)
        tk.Label(dialog, text="Enter the word:").grid(row=1, column=0)
        word_entry = tk.Entry(dialog)
        word_entry.grid(row=1, column=1)

        # When the user clicks OK, add the color and word to the dictionary
        def ok():
            rgb = tuple(int(x.strip()) for x in rgb_entry.get().split(','))
            word = word_entry.get()
            self.my_dict[rgb] = word
            dialog.destroy()
            
            # update the label that displays the color dictionary
            self.update_display()
            
        tk.Button(dialog, text="OK", command=ok).grid(row=2, column=1)
        
        

    def select_image(self):
        # open a file dialog to select an image file
        file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.*")])
        print("Selected file:", file_path)
        self.file_path = file_path
        
        # load and resize the selected image file using openCV
        self.image = cv2.imread(self.file_path)
        self.image = cv2.resize(self.image, (512,512))
        
        # load and resize the selected image file using PIL
        image = Image.open(file_path)
        image = image.resize((512, 512))
    
        # convert the image to a Tkinter-compatible format
        tk_image = ImageTk.PhotoImage(image)
        
        # update the image label to display the selected image
        self.image_label.configure(image=tk_image)
        self.image_label.image = tk_image
        
        # save the image as an attribute of the ImageLoader instance
        self.loaded_image = image
        
        
        
    def segment_image(self):
        # check if an image is loaded
        if self.loaded_image is None:
            print("Please select an image first.")
            return
         
        # create a SAM model and mask generator
        sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
        mask_generator = SamAutomaticMaskGenerator(sam)
        
        # generate masks for the loaded image
        sam_result = mask_generator.generate(self.image)
        print(f"There is {len(sam_result)} masks detected in this image.")
        print(sam_result[0].keys())
        print(sam_result[0]['segmentation'].shape)

        
        # extract the segmentation masks
        masks = [result["segmentation"] for result in sam_result]
        
        # combine the masks into a single segmentation mask
        segmentation = np.zeros((512,512), dtype=np.uint8)
        for mask in masks:
            segmentation[mask] = 255

        # create a PIL Image from the segmentation mask
        segmentation_image = Image.fromarray(segmentation)
        
        # convert the image to a Tkinter-compatible format
        tk_segmentation = ImageTk.PhotoImage(segmentation_image)
        
        # update the image label to display the segmented image
        self.image_label.configure(image=tk_segmentation)
        self.image_label.image = tk_segmentation
        
        # save a reference to the segmentation result
        self.sam_result = sam_result

In [130]:
if __name__ == "__main__":
    # create the main window and start the event loop
    root = tk.Tk()
    app = ImageLoader(root)
    root.mainloop()

Selected file: ()


Exception in Tkinter callback
Traceback (most recent call last):
  File "/home/thmsguerin/anaconda3/envs/labelWithSAM/lib/python3.11/tkinter/__init__.py", line 1948, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_6099/164434093.py", line 112, in select_image
    self.image = cv2.imread(self.file_path)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Can't convert object to 'str' for 'filename'
