In [None]:
import time
import json

import gc
import torch
import numpy as np
import cv2
from torch.nn import functional as F
from os import listdir, makedirs, getcwd
from os.path import join, exists, isfile, isdir, basename
import os
from ipywidgets import interact, widgets, FileUpload
from IPython.display import display, clear_output
from matplotlib import patches as patches
from matplotlib import pyplot as plt
from copy import deepcopy

def show_mask(mask, ax, random_color=False, alpha=0.95):
    """Overlay a mask on ``ax`` as a single-colour RGBA image.

    Args:
        mask: array whose last two dimensions are (H, W); it is multiplied by
            the colour, so zero pixels become fully transparent.
        ax: matplotlib Axes to draw on (only ``ax.imshow`` is used).
        random_color: draw the RGB part from ``np.random.random(3)`` instead
            of the fixed yellow (251, 252, 30) / 255.
        alpha: opacity stored in the fourth channel.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([251 / 255, 252 / 255, 30 / 255])
    rgba = np.append(rgb, alpha)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA overlay.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


class BboxPromptDemo:
    """Interactive bounding-box prompt annotation tool for a SAM-style model.

    Entry points
    ------------
    * ``load_image()`` - browse the images in ``directory_path`` with
      Previous / Next / Save / Clear / End buttons. Each Save appends a record
      (timestamps + clear count) to ``self.data_store[filename]``; End writes
      ``self.data_store`` to ``<directory_path>/all_data.json``.
    * ``show(image_path)`` - open a single image with Save / Clear buttons.

    Pressing, dragging and releasing the mouse on the image draws a box; the
    box is passed to the model's prompt encoder / mask decoder and the
    predicted mask is drawn over the image.

    Requires the ipympl backend (``%matplotlib widget``): the figure canvas is
    displayed as a widget and its mouse events drive the annotation.
    The model must provide ``eval()``, ``image_encoder``, ``prompt_encoder``,
    ``mask_decoder`` and ``device`` (all used below).
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Timestamp slots, each filled once per image in this order.
    CLICK_KEYS = ("first_click", "second_click", "third_click")
    CLEAR_KEYS = ("first_clear_clicked", "second_clear_clicked")

    def __init__(self, model, directory_path):
        self.model = model
        self.model.eval()
        self.directory_path = directory_path
        self.image_files = self.list_images(self.directory_path)
        self.current_image_index = 0
        self.image = None             # image currently shown (RGB array)
        self.image_path = None        # file the current image was read from, if any
        self.image_embeddings = None  # encoder output; computed lazily on the first box
        self.img_size = None          # (H, W) of self.image
        self.gt = None
        self.currently_selecting = False
        self.x0, self.y0, self.x1, self.y1 = 0., 0., 0., 0.
        self.rect = None
        self.fig, self.axes = None, None
        self.segs = []
        # Mask overlay style used by on_release; _show() can override these.
        self.random_color = True
        self.alpha = 0.65
        # filename -> list of records appended by each Save click.
        self.data_store = {}
        # One output area holds the figure, buttons and messages so output
        # produced inside widget callbacks still shows up in the notebook.
        self.output = widgets.Output()
        self._output_displayed = False
        self.timestamps = {}
        self.clear_click_count = 0
        self._reset_interaction_state()

    # ------------------------------------------------------------------ state

    def _reset_interaction_state(self):
        """Forget timestamps, clear count, masks and any box in progress."""
        self.timestamps = {key: None for key in self.CLICK_KEYS + self.CLEAR_KEYS}
        self.timestamps["first_save_clicked"] = None
        self.clear_click_count = 0
        self.segs = []
        self.currently_selecting = False
        self.rect = None

    def _stamp_first_free(self, keys):
        """Store time.time() in the first timestamp slot of ``keys`` that is still None."""
        for key in keys:
            if self.timestamps[key] is None:
                self.timestamps[key] = time.time()
                return

    def _current_filename(self):
        """Base name of the current image file (or a placeholder if it has none)."""
        if self.image_path:
            return os.path.basename(self.image_path)
        return "unnamed_image"

    def _log(self, message):
        """Print ``message`` into the demo's output area (stdout if there is none)."""
        if self.output is None:
            print(message)
            return
        self._ensure_output_displayed()
        with self.output:
            print(message)

    def _ensure_output_displayed(self):
        """Display the output area once, in the cell that first needs it."""
        if not self._output_displayed:
            display(self.output)
            self._output_displayed = True

    # ---------------------------------------------------------------- loading

    def list_images(self, directory_path):
        """Return the sorted names of image files directly inside ``directory_path``."""
        return sorted(
            name for name in os.listdir(directory_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(directory_path, name))
        )

    @staticmethod
    def _read_image(image_path):
        """Read ``image_path`` with OpenCV and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None).
        """
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _use_image(self, image, image_path=None):
        """Make ``image`` current; embeddings are invalidated and per-image state reset."""
        self.image = image
        self.image_path = image_path
        self.img_size = image.shape[:2]
        self.image_embeddings = None
        self._reset_interaction_state()

    def load_image(self):
        """Load the current directory image and display it with the navigation buttons."""
        if not self.image_files:
            self._log("No images found in the directory.")
            return
        image_path = os.path.join(self.directory_path, self.image_files[self.current_image_index])
        self._use_image(self._read_image(image_path), image_path)
        self.show_image()

    def show(self, image_path, fig_size=5, random_color=True, alpha=0.65):
        """Open a single image file for box annotation (Save / Clear buttons only)."""
        self.set_image_path(image_path)
        self._show(fig_size=fig_size, random_color=random_color, alpha=alpha)

    def set_image_path(self, image_path):
        """Read ``image_path`` and make it the current image (embeddings computed eagerly)."""
        self._set_image(self._read_image(image_path), image_path)

    def _set_image(self, image, image_path=None):
        """Make ``image`` current and run the model's image encoder on it."""
        self._use_image(image, image_path)
        self._compute_embeddings()

    def _compute_embeddings(self):
        """Run the image encoder on the current image and cache the result."""
        image_preprocess = self._preprocess_image(self.image)
        with torch.no_grad():
            self.image_embeddings = self.model.image_encoder(image_preprocess)

    # ---------------------------------------------------------------- display

    def show_image(self):
        """Show the current image with box drawing and the full button bar."""
        if self.image is None:
            self._log("No image loaded.")
            return
        self._create_figure()
        self.axes.set_title(self._current_filename())
        self._render(navigation=True)

    def _show(self, fig_size=5, random_color=True, alpha=0.65):
        """Show the current image with box drawing and Save / Clear buttons."""
        assert self.image is not None, "Please set image first."
        self.random_color = random_color
        self.alpha = alpha
        self._create_figure(fig_size)
        self._render(navigation=False)

    def _create_figure(self, fig_size=None):
        """Close the previous figure, draw the image on a new one and hook up mouse events."""
        if self.fig is not None:
            plt.close(self.fig)  # avoid accumulating one open figure per image
        figsize = None if fig_size is None else (fig_size, fig_size)
        # Create without auto-display; _render places the canvas in self.output.
        with plt.ioff():
            self.fig, self.axes = plt.subplots(1, 1, figsize=figsize)
        canvas = self.fig.canvas
        canvas.header_visible = False
        canvas.footer_visible = False
        canvas.toolbar_visible = False
        canvas.resizable = False
        self.axes.imshow(self.image)
        self.axes.axis('off')
        self.fig.tight_layout()
        canvas.mpl_connect('button_press_event', self.on_press)
        canvas.mpl_connect('motion_notify_event', self.on_motion)
        canvas.mpl_connect('button_release_event', self.on_release)

    def _render(self, navigation=True):
        """Replace the output area's content with the figure canvas and buttons."""
        self._ensure_output_displayed()
        self.output.clear_output(wait=True)
        with self.output:
            display(self.fig.canvas)
            self.add_buttons(navigation=navigation)

    def add_buttons(self, navigation=True):
        """Display the action buttons; with ``navigation`` also Previous / Next / End."""
        actions = [("Save", self.on_save_clicked), ("Clear", self.on_clear_clicked)]
        if navigation:
            actions = ([("Previous", self.on_previous_clicked), ("Next", self.on_next_clicked)]
                       + actions
                       + [("End", self.on_end_clicked)])
        buttons = []
        for description, handler in actions:
            button = widgets.Button(description=description)
            button.on_click(handler)
            buttons.append(button)
        display(widgets.HBox(buttons))

    # ----------------------------------------------------------- navigation

    def on_previous_clicked(self, b):
        """Go to the previous image (per-image state is reset by load_image)."""
        if self.current_image_index > 0:
            self.current_image_index -= 1
            self.load_image()

    def on_next_clicked(self, b):
        """Go to the next image (per-image state is reset by load_image)."""
        if self.current_image_index < len(self.image_files) - 1:
            self.current_image_index += 1
            self.load_image()

    # ---------------------------------------------------------- box drawing

    def _remove_rect(self):
        """Remove the rubber-band rectangle from the axes, if one exists."""
        if self.rect is not None:
            self.rect.remove()
            self.rect = None

    def on_press(self, event):
        """Start a box at the pressed point and record the first three click times."""
        if event.inaxes is not self.axes:
            return
        self._stamp_first_free(self.CLICK_KEYS)
        self.x0, self.y0 = float(event.xdata), float(event.ydata)
        self.x1, self.y1 = self.x0, self.y0
        self.currently_selecting = True
        self._remove_rect()
        self.rect = plt.Rectangle(
            (self.x0, self.y0), 1, 1,
            linestyle="--", edgecolor="crimson", fill=False,
        )
        self.axes.add_patch(self.rect)
        self.rect.set_visible(False)  # shown once the mouse actually moves

    def on_motion(self, event):
        """Resize the rubber-band rectangle while a box is being dragged."""
        if not self.currently_selecting or event.inaxes is not self.axes:
            return
        self.x1, self.y1 = float(event.xdata), float(event.ydata)
        x_min, x_max = sorted((self.x0, self.x1))
        y_min, y_max = sorted((self.y0, self.y1))
        self.rect.set_xy((x_min, y_min))
        self.rect.set_width(x_max - x_min)
        self.rect.set_height(y_max - y_min)
        self.rect.set_visible(True)
        self.fig.canvas.draw_idle()

    def on_release(self, event):
        """Finish the box, segment it with the model and overlay the predicted mask."""
        if not self.currently_selecting:
            return
        self.currently_selecting = False
        self._remove_rect()
        if event.inaxes is not self.axes:
            # Released outside the image: discard the box.
            self.fig.canvas.draw_idle()
            return
        self.x1, self.y1 = float(event.xdata), float(event.ydata)
        x_min, x_max = sorted((self.x0, self.x1))
        y_min, y_max = sorted((self.y0, self.y1))
        if x_max == x_min or y_max == y_min:
            # Plain click without a drag: no box to segment.
            self.fig.canvas.draw_idle()
            return
        seg = self._segment(np.array([x_min, y_min, x_max, y_max]))
        show_mask(seg, self.axes, random_color=self.random_color, alpha=self.alpha)
        self.segs.append(seg)
        self.fig.canvas.draw_idle()

    def _segment(self, bbox):
        """Compute embeddings if needed, then return the mask for ``bbox`` (x_min, y_min, x_max, y_max)."""
        if self.image_embeddings is None:
            self._compute_embeddings()
        seg = self._infer(bbox)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return seg

    # -------------------------------------------------------------- actions

    def on_clear_clicked(self, b):
        """Remove all masks from the current image and count the clear."""
        self.clear_click_count += 1
        self._stamp_first_free(self.CLEAR_KEYS)
        self.currently_selecting = False
        self._remove_rect()
        self.segs = []
        if self.axes is None:
            return
        title = self.axes.get_title()
        self.axes.clear()  # drops the image and every mask overlay
        self.axes.imshow(self.image)
        self.axes.axis('off')
        self.axes.set_title(title)
        self.fig.canvas.draw_idle()

    def _record_annotation(self, filename, save_time=None):
        """Append a snapshot of the current timestamps / clear count for ``filename``.

        ``first_save_clicked`` is set on the first save only; each record also
        carries its own ``save_clicked`` time. The timestamps are copied so later
        clicks do not change records already stored.

        Returns:
            The list of records stored for ``filename``.
        """
        now = time.time() if save_time is None else save_time
        if self.timestamps["first_save_clicked"] is None:
            self.timestamps["first_save_clicked"] = now
        record = {
            "timestamps": dict(self.timestamps, save_clicked=now),
            "clear_click_count": self.clear_click_count,
        }
        entries = self.data_store.setdefault(filename, [])
        entries.append(record)
        return entries

    def _save_results(self, filename):
        """Write the rendered figure and a label map of the masks into the working directory.

        Files are prefixed with the image's stem so saving another image does
        not overwrite them. In the label map, mask i (1-based) is written as
        value i; later masks win where masks overlap.
        """
        out_dir = os.getcwd()
        stem = os.path.splitext(filename)[0]
        self.fig.savefig(os.path.join(out_dir, f"{stem}_seg_result.png"),
                         bbox_inches='tight', pad_inches=0)
        if self.segs:
            label_dtype = np.uint8 if len(self.segs) < 256 else np.uint16
            label_map = np.zeros(self.segs[0].shape, dtype=label_dtype)
            for label, seg in enumerate(self.segs, start=1):
                label_map[seg > 0] = label
            cv2.imwrite(os.path.join(out_dir, f"{stem}_segs.png"), label_map)
            self._log(f"Segmentation result saved to {out_dir}")

    def on_save_clicked(self, b):
        """Record interaction data for the current image and save its results."""
        filename = self._current_filename()
        entries = self._record_annotation(filename)
        self._log(f"Data collected for {filename}. Total entries: {len(entries)}")
        self._save_results(filename)

    def on_end_clicked(self, b):
        """Save all collected data to a JSON file when ending the session."""
        file_path = os.path.join(self.directory_path, "all_data.json")
        with open(file_path, 'w') as f:
            json.dump(self.data_store, f, indent=4)
        self._log(f"All data saved to {file_path}. Session ended.")

    # ------------------------------------------------------------ inference

    def _preprocess_image(self, image):
        """Resize to 1024x1024, min-max normalise to [0, 1] and return a
        (1, C, 1024, 1024) float tensor on the model's device.

        C is 3 for the RGB images read by this class.
        """
        img_resize = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC
        )
        # Min-max normalise to [0, 1]; the clip avoids dividing by zero on constant images.
        img_resize = (img_resize - img_resize.min()) / np.clip(img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None)
        assert np.max(img_resize)<=1.0 and np.min(img_resize)>=0.0, 'image should be normalized to [0, 1]'
        # (H, W, C) -> (1, C, H, W)
        img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)

        return img_tensor

    @torch.no_grad()
    def _infer(self, bbox):
        """Predict a binary uint8 mask of shape ``self.img_size`` for one box.

        Args:
            bbox: array [x_min, y_min, x_max, y_max] in original-image pixels.
        """
        ori_H, ori_W = self.img_size
        # The model works in 1024x1024 coordinates; rescale the box accordingly.
        scale_to_1024 = 1024 / np.array([ori_W, ori_H, ori_W, ori_H])
        bbox_1024 = bbox * scale_to_1024
        bbox_torch = torch.as_tensor(bbox_1024, dtype=torch.float).unsqueeze(0).to(self.model.device)
        if len(bbox_torch.shape) == 2:
            bbox_torch = bbox_torch.unsqueeze(1)  # (B, 4) -> (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=None,
            boxes=bbox_torch,
            masks=None,
        )
        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        # assumes (1, 1, h, w) low-res logits (256x256 for SAM) — TODO confirm for other decoders
        low_res_pred = torch.sigmoid(low_res_logits)

        low_res_pred = F.interpolate(
            low_res_pred,
            size=self.img_size,
            mode="bilinear",
            align_corners=False,
        )  # (1, 1, H, W)
        low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        return medsam_seg

In [None]:
class MedSAMModel:
    """Placeholder standing in for a real MedSAM model.

    It only tracks whether ``train`` has been called and prints status
    messages; no actual training or inference happens.
    """

    def __init__(self):
        self.trained = False

    def train(self):
        """Pretend to train: announce it, then flag the model as trained."""
        print("Training model...")
        self.trained = True
        print("Model trained.")

    def eval(self):
        """Pretend to switch to evaluation mode; only prints a status message."""
        if self.trained:
            message = "Model is now in evaluation mode."
        else:
            message = "Model is not trained yet. Please train the model before evaluating."
        print(message)

    def predict(self, input_data):
        """Return a canned prediction once trained.

        Args:
            input_data: ignored; accepted for interface compatibility.

        Returns:
            The string ``"prediction_result"``, or ``None`` (after printing a
            warning) when ``train`` has not been called yet.
        """
        if self.trained:
            return "prediction_result"
        print("Model is not trained yet. Cannot make predictions.")
        return None

# Example usage: create the stand-in model, "train" it, switch it to eval mode
# and run one mock prediction.
# NOTE: this mock has no image_encoder / prompt_encoder / mask_decoder, which
# BboxPromptDemo's inference code calls, so it cannot produce segmentations.
medsam_model = MedSAMModel()
for lifecycle_step in (medsam_model.train, medsam_model.eval):
    lifecycle_step()
prediction = medsam_model.predict("input data")
print("Prediction:", prediction)


In [1]:
%matplotlib widget
#from BboxPromptDemo import BboxPromptDemo  # Adjust the import according to your module's structure

# Initialize the demo with a model and directory path
bbox_prompt_demo = BboxPromptDemo(medsam_model, "Annotation Study")

# Load and display the first image
bbox_prompt_demo.load_image()

# Add interactive buttons to the session
bbox_prompt_demo.add_buttons()


NameError: name 'BboxPromptDemo' is not defined

In [15]:
# Relative paths used in this notebook (e.g. "Annotation Study") are resolved
# against the working directory printed here.
import os

working_dir = os.getcwd()
print("Current Working Directory:", working_dir)
# To work from another folder, uncomment and edit:
# os.chdir('path_to_your_directory')


Current Working Directory: C:\Users\m133326\Desktop\AI projects\SAM
