In [1]:
import cv2
import tkinter as tk
import xml.etree.ElementTree as ET
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import supervision as sv
import torch
import io
from PIL import Image, ImageTk
import tkinter.filedialog as filedialog
import uuid



# Run on the first CUDA GPU when one is available; otherwise fall back to CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Key into `sam_model_registry` and the checkpoint file passed to the builder.
# NOTE(review): the checkpoint file name suggests it belongs to the "vit_h"
# variant -- keep the two in sync if either is changed.
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"

# Build the model once and move it to DEVICE. Previously DEVICE was computed but
# never applied, so the model stayed on CPU even when a GPU was present.
# `.to()` returns the module itself, so the assignment also keeps the notebook
# cell from echoing the model repr.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

In [9]:

# Module-level placeholder. Nothing in the visible code reads this global:
# App keeps its own PhotoImage reference in `self.photo`.
photo = None

class App:
    """Tkinter window that segments an image with SAM and records clicked regions.

    Workflow:
      1. "Abrir Imagen" loads a file, runs the SAM automatic mask generator on
         it and shows the mask-annotated result.
      2. Clicking the image flood-fills from the clicked pixel (on the original
         image colours) and stores the resulting pixel coordinates under a
         fresh UUID in ``self.objects``.
      3. "Guardar Coordenadas" writes every stored region to ``coordinates.xml``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root):
        """Configure the window and build its widgets.

        Args:
            root: the ``tk.Tk`` (or compatible) top-level window to populate.
        """
        self.root = root
        self.root.title("Segmentación de objetos")
        # Reference to the currently displayed PhotoImage. It is stored on the
        # instance so the image object stays alive while the label shows it.
        self.photo = None
        # Original image as returned by cv2.imread (BGR channel order).
        self.image = None
        # Copy of the image with SAM masks (and click markers) drawn on it.
        self.annotated_image = None
        # Maps object id (UUID string) -> list of (x, y) pixel coordinates.
        self.objects = {}
        # Lazily created SamAutomaticMaskGenerator, reused for every image.
        self._mask_generator = None

        self.root.geometry("800x600")

        self.create_widgets()

    def create_widgets(self):
        """Create the two buttons and the image label, and bind the click handler."""
        open_button = tk.Button(self.root, text="Abrir Imagen", command=self.load_image)
        open_button.pack()

        save_button = tk.Button(self.root, text="Guardar Coordenadas", command=self.save_coordinates)
        save_button.pack()

        # The label displays the annotated image; left clicks on it select a region.
        self.label = tk.Label(self.root)
        self.label.pack()
        self.label.bind("<Button-1>", self.handle_click)

    def _get_mask_generator(self):
        """Return the shared mask generator, creating it on first use.

        Reuses the module-level ``sam`` model instead of re-reading the
        checkpoint for every image, and makes sure the model is on ``DEVICE``.
        """
        if self._mask_generator is None:
            sam.to(device=DEVICE)
            self._mask_generator = SamAutomaticMaskGenerator(sam)
        return self._mask_generator

    def _show_annotated(self):
        """Render ``self.annotated_image`` (BGR) into the label."""
        rgb = cv2.cvtColor(self.annotated_image, cv2.COLOR_BGR2RGB)
        self.photo = ImageTk.PhotoImage(Image.fromarray(rgb))
        self.label.config(image=self.photo)

    def load_image(self):
        """Ask for an image file, segment it with SAM and display the result.

        Does nothing if the dialog is cancelled. If the file cannot be decoded
        an error is printed and the previously loaded image (if any) is kept.
        """
        file_path = filedialog.askopenfilename()
        if not file_path:
            return

        loaded = cv2.imread(file_path)
        # cv2.imread signals failure by returning None rather than raising.
        if loaded is None:
            print("Error: La imagen no se ha cargado correctamente.")
            return
        self.image = loaded

        # OpenCV decodes to BGR, but SAM expects RGB input.
        image_rgb = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        result = self._get_mask_generator().generate(image_rgb)

        mask_annotator = sv.MaskAnnotator()
        detections = sv.Detections.from_sam(result)
        # Annotate a copy so self.image keeps the untouched pixel colours that
        # the flood fill relies on.
        self.annotated_image = mask_annotator.annotate(self.image.copy(), detections)

        self._show_annotated()

    def handle_click(self, event):
        """Select the region around the clicked pixel and mark it on screen.

        Args:
            event: Tk mouse event; ``event.x`` / ``event.y`` are used as image
                pixel coordinates.
        """
        x, y = event.x, event.y
        print("Coordenadas del clic:", x, y)

        # The label is clickable before any image has been loaded.
        if self.image is None or self.annotated_image is None:
            print("Error: Primero debe abrir una imagen.")
            return

        # Event coordinates are relative to the widget and can fall outside
        # the image array, so validate before indexing.
        height, width = self.image.shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            print("Error: El clic está fuera de la imagen.")
            return

        color = self.get_pixel_color(x, y)
        connected_coordinates = self.find_connected_coordinates(x, y, color)

        object_id = str(uuid.uuid4())
        self.objects[object_id] = connected_coordinates

        # Mark every pixel of the region with a filled green dot (BGR colour).
        for coord_x, coord_y in connected_coordinates:
            cv2.circle(self.annotated_image, (coord_x, coord_y), 3, (0, 255, 0), -1)

        self._show_annotated()

    def get_pixel_color(self, x, y):
        """Return the colour of ``self.image`` at column ``x``, row ``y``.

        Returns:
            An ``(r, g, b)`` tuple of plain Python ints. Plain ints are used so
            that callers can subtract channels without uint8 wrap-around.
        """
        b, g, r = self.image[y, x]
        return (int(r), int(g), int(b))

    def find_connected_coordinates(self, x, y, target_color, color_threshold=20):
        """Flood-fill from ``(x, y)`` over 4-connected pixels similar to ``target_color``.

        A neighbour joins the region when every RGB channel differs from
        ``target_color`` by at most ``color_threshold``. The comparison is
        always against ``target_color`` itself, not against the neighbouring
        pixel, so the region cannot drift gradually to unrelated colours.

        Args:
            x, y: seed pixel (column, row); always part of the result.
            target_color: ``(r, g, b)`` reference colour.
            color_threshold: maximum allowed per-channel absolute difference.
                ``0`` gives exact-colour matching.

        Returns:
            Sorted list of ``(x, y)`` tuples belonging to the region.
        """
        height, width = self.image.shape[:2]
        target = tuple(int(channel) for channel in target_color)

        visited = {(x, y)}
        pending = [(x, y)]

        while pending:
            current_x, current_y = pending.pop()
            neighbors = (
                (current_x - 1, current_y),
                (current_x + 1, current_y),
                (current_x, current_y - 1),
                (current_x, current_y + 1),
            )
            for neighbor in neighbors:
                neighbor_x, neighbor_y = neighbor
                if neighbor in visited:
                    continue
                if not (0 <= neighbor_x < width and 0 <= neighbor_y < height):
                    continue
                color = self.get_pixel_color(neighbor_x, neighbor_y)
                if all(abs(c - t) <= color_threshold for c, t in zip(color, target)):
                    visited.add(neighbor)
                    pending.append(neighbor)

        # Sorted so the stored regions and the saved XML have a stable order.
        return sorted(visited)

    def save_coordinates(self):
        """Write all recorded regions to ``coordinates.xml`` in the working directory.

        Layout: ``<coordinates>`` contains one ``<object id="...">`` per region,
        each holding ``<point><x/><y/></point>`` children.
        """
        xml_file_path = "coordinates.xml"
        root = ET.Element("coordinates")

        for object_id, connected_coordinates in self.objects.items():
            object_elem = ET.SubElement(root, "object")
            object_elem.set("id", object_id)

            for coord in connected_coordinates:
                point = ET.SubElement(object_elem, "point")
                x_elem = ET.SubElement(point, "x")
                x_elem.text = str(coord[0])
                y_elem = ET.SubElement(point, "y")
                y_elem.text = str(coord[1])

        tree = ET.ElementTree(root)
        tree.write(xml_file_path)
        print(f"Coordenadas guardadas en {xml_file_path}")

# Launch the GUI when this code is executed as the main module.
if __name__ == "__main__":
    root = tk.Tk()  # top-level Tk window handed to App
    app = App(root)  # sets the title/geometry and builds the widgets
    root.mainloop()  # enter the Tk event loop

Coordenadas del clic: 125 58
Coordenadas guardadas en coordinates.xml
Coordenadas del clic: 120 104
Coordenadas guardadas en coordinates.xml
Coordenadas del clic: 91 128
