## Main Project File: All Important Functions Are Defined Here

#### Jack Cells / Helpers

In [None]:
%pip install opencv-python
%pip install os
%pip install numpy
%pip install segment_anything
%pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118

import os
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator

In [None]:
def get_mask_for_point(image_path, point, model_type="default", checkpoint_path=None, device=None):
    """Segment the object under a single foreground point using SAM.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        point: ``(x, y)`` pixel coordinate of the foreground click.
        model_type: Key into ``sam_model_registry`` (e.g. ``"vit_b"``). It must
            match the architecture stored in ``checkpoint_path``.
        checkpoint_path: Path to the SAM ``.pth`` checkpoint. Required.
        device: Optional torch device (e.g. ``"cuda"``) to move the model to
            before inference. When ``None`` the model is left where
            ``sam_model_registry`` built it.

    Returns:
        The single predicted mask (``masks[0]`` of ``SamPredictor.predict``,
        i.e. an ``(H, W)`` array for the image), or ``None`` if anything
        fails. In that case the error is printed.
    """
    try:
        # Validate cheap inputs first so we fail fast, before the slow model load.
        if checkpoint_path is None:
            raise ValueError("checkpoint_path is required")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
        if model_type not in sam_model_registry:
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                f"expected one of {sorted(sam_model_registry)}"
            )

        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image at: {image_path}")
        # OpenCV loads BGR; SamPredictor.set_image expects RGB by default.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # A prompt outside the image cannot select anything meaningful; reject it early.
        height, width = image.shape[:2]
        x, y = point
        if not (0 <= x < width and 0 <= y < height):
            raise ValueError(f"Point {point} is outside the {width}x{height} image")

        # Load SAM model
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        if device is not None:
            sam.to(device=device)
        predictor = SamPredictor(sam)
        predictor.set_image(image)

        masks, _, _ = predictor.predict(
            point_coords=np.array([point]),
            point_labels=np.array([1]),  # 1 marks a foreground point
            # Ask SAM for a single mask (its single-output head) instead of three candidates.
            multimask_output=False,
        )
        return masks[0]

    except Exception as e:
        # Notebook-level boundary: report the problem and signal failure with None,
        # which the calling cell checks for.
        print(f"Error: {e}")
        return None

In [None]:
# Example usage: segment whatever lies under one clicked point and overlay the result.
image_path = "test.jpg"
checkpoint_path_small = "sam_vit_b_01ec64.pth"  # ViT-B checkpoint, paired with model_type="vit_b" below
point = (1000, 1000)  # (x, y) pixel position of the click

mask = get_mask_for_point(
    image_path,
    point,
    model_type="vit_b",
    checkpoint_path=checkpoint_path_small,
)

if mask is None:
    print("Failed to generate mask.")
else:
    print("Mask shape:", mask.shape)

    # Visual check: the image, the mask blended on top, and the prompt point in red.
    import matplotlib.pyplot as plt

    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

    plt.figure()
    plt.imshow(image)
    plt.imshow(mask, cmap="gray", alpha=0.5)
    plt.scatter(*point, c="red", s=10)
    plt.title("Image with Mask Overlay")
    plt.show()