# In [None]:
import numpy as np
from PIL import Image
import cv2
import torch
import matplotlib.pyplot as plt

# -----------------------------
# 0. Load image
# -----------------------------
# Path of the photo to analyse; point this at another file to process a different plant.
imageloc = "images/plant.jpg"

# Decode the file and force a 3-channel RGB representation, so the array
# below is always (H, W, 3) whatever the source image mode was.
with Image.open(imageloc) as source_img:
    image_pil = source_img.convert("RGB")

# Writable numpy copy of the pixels, indexed as image[row, col, channel].
image = np.array(image_pil)
h, w, _ = image.shape  # rows, columns, (channels == 3)

# -----------------------------
# 1. Center crop + HSV auto leaf point (your original logic)
# -----------------------------
# Fraction of each image dimension kept around the center, in (0, 1].
# 1 == whole image; change to e.g. 0.2 for the central 20%.
crop_size = 1
if not 0 < crop_size <= 1:
    # Values > 1 would yield negative slice starts (numpy wraps them around
    # to the far edge) and values <= 0 yield an empty crop.
    raise ValueError(f"crop_size must be in (0, 1], got {crop_size!r}")


def _center_span(length, frac):
    """Return (start, stop) of the centered fraction *frac* of *length*.

    The span is forced to be at least one element wide: for tiny fractions on
    odd lengths both bounds can truncate to the same integer, and an empty
    crop would make the colour conversion below fail.
    """
    start = int(length * (0.5 - frac / 2))
    stop = int(length * (0.5 + frac / 2))
    return start, max(stop, start + 1)


h0, h1 = _center_span(h, crop_size)
w0, w1 = _center_span(w, crop_size)

center_crop = image[h0:h1, w0:w1, :]

# Convert the RGB crop straight to HSV (no intermediate BGR copy needed).
# For 8-bit input OpenCV stores hue in 0..179 and saturation/value in 0..255.
hsv = cv2.cvtColor(center_crop, cv2.COLOR_RGB2HSV)

# Typical green range in HSV (tweak as needed); the S and V floors reject
# washed-out and dark pixels.
mask_green = (
    (hsv[..., 0] > 35) & (hsv[..., 0] < 85) &
    (hsv[..., 1] > 50) & (hsv[..., 2] > 50)
)

candidates = np.argwhere(mask_green)  # (K, 2) array of (row, col) in crop coords
plantloc = None

if len(candidates):
    # Pick the green pixel closest (squared Euclidean distance) to the crop center.
    center_y, center_x = np.array(mask_green.shape) // 2
    dists = np.sum((candidates - [center_y, center_x]) ** 2, axis=1)
    chosen = candidates[np.argmin(dists)]
    y_rel, x_rel = chosen
    # Shift back to full-image coordinates; plain ints print cleanly and are
    # safe to hand to prompt APIs / serializers (np.int64 is not JSON-serializable).
    y_abs = int(y_rel + h0)
    x_abs = int(x_rel + w0)
    plantloc = [[x_abs, y_abs]]  # list of [x, y] points
    print("Auto-selected leaf point (HSV):", plantloc)
else:
    print("No green region found in center crop.")
    # You can fall back to a manual point here if you want:
    # plantloc = [[w // 2, h // 2]]

# Optional: visualize the HSV mask overlay for debugging.
# Paint every pixel that passed the green test pure red on a copy of the crop
# (the crop is RGB, so [255, 0, 0] is red).
overlay = center_crop.copy()
overlay[mask_green] = [255, 0, 0]

# Flip to True to show the crop next to its green-mask overlay.
show_hsv_debug = False
if show_hsv_debug:
    plt.figure(figsize=(10, 5))
    debug_panels = ((center_crop, "Center crop"), (overlay, "Green mask overlay"))
    for slot, (panel_img, panel_title) in enumerate(debug_panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(panel_img)
        plt.title(panel_title)
        plt.axis("off")
    plt.show()

# -----------------------------
# 2. SAM 3: load model + processor
# -----------------------------
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor  # name from SAM3 README

# Prefer the GPU when one is visible (your teammate should have CUDA); else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Build the SAM3 image model, move it onto `device`, and wrap it in
# Sam3Processor, which exposes set_image / set_text_prompt.
# NOTE: the builder loads checkpoints from Hugging Face, so the machine must be
# authenticated (`hf auth login`) and have been granted SAM3 access.
# NOTE(review): Sam3Processor gets no device argument here; if its default
# assumes CUDA, CPU-only runs may fail — confirm against the sam3 package.
model = build_sam3_image_model()
model = model.to(device)
processor = Sam3Processor(model)

# -----------------------------
# 3A. Use SAM3 with a TEXT prompt ("plant leaves")
#     (simpler to wire; SAM3 will segment all instances of that concept)
# -----------------------------
# You can tune this phrase based on your images:
text_prompt = "green plant leaves"

# Pure inference: disable autograd so no computation graph (and the
# activations it would keep alive) is recorded for these forward passes.
# no_grad (rather than inference_mode) keeps the returned state tensors
# usable by any later processor calls made outside this context.
with torch.no_grad():
    inference_state = processor.set_image(image_pil)
    output = processor.set_text_prompt(
        state=inference_state,
        prompt=text_prompt,
    )

# SAM3 returns instance masks + boxes + scores, one entry per detected instance.
masks = output["masks"]    # NOTE(review): (N, H, W) or (N, 1, H, W) depending on SAM3 version — confirm
boxes = output["boxes"]    # (N, 4)
scores = output["scores"]  # (N,)

# -----------------------------
# 3B. Convert masks to numpy and build union mask
# -----------------------------
# Move the masks to a host numpy array. Floating tensors are upcast to
# float32 first because numpy has no bfloat16 dtype.
if isinstance(masks, torch.Tensor):
    masks_t = masks.detach().cpu()
    if masks_t.is_floating_point():
        masks_t = masks_t.float()
    masks_np = masks_t.numpy()
else:
    masks_np = np.asarray(masks)

print("SAM3 returned", masks_np.shape[0], "instance masks for prompt:", text_prompt)

# Normalize scores to a 1-D float numpy array whether SAM3 handed back a
# tensor (possibly on GPU / reduced precision), an ndarray, or a plain list.
if isinstance(scores, torch.Tensor):
    scores_np = scores.detach().float().cpu().numpy()
else:
    scores_np = np.asarray(scores, dtype=float)
scores_np = scores_np.reshape(-1)

# Optionally filter by score threshold
score_thresh = 0.5
keep = scores_np >= score_thresh

if not keep.any():
    # Deliberate fallback: keep every mask rather than segmenting nothing.
    print(f"No masks above score {score_thresh}. "
          f"Using all masks anyway.")
    keep = np.ones_like(keep)

masks_kept = masks_np[keep]

# Union over instances -> single plant mask.
# Collapse every leading axis (the instance axis plus any singleton channel
# axis, as in an (N, 1, H, W) layout) so the union is always a 2-D (H, W)
# boolean mask; a (1, H, W) result would break the indexing in section 4.
if masks_kept.ndim >= 3:
    union_mask = (
        masks_kept.astype(bool)
        .reshape(-1, *masks_kept.shape[-2:])
        .any(axis=0)
    )
else:
    # No spatial masks at all (e.g. an empty result): nothing is segmented.
    union_mask = np.zeros((h, w), dtype=bool)

if union_mask.shape != (h, w):
    raise ValueError(
        f"SAM3 mask shape {union_mask.shape} does not match image shape {(h, w)}"
    )

num_segmented_pixels = int(union_mask.sum())
print("Number of SAM3 plant pixels (union of instances):", num_segmented_pixels)

# -----------------------------
# 4. Visualize original vs SAM3 segmentation
# -----------------------------
# Black canvas with only the pixels inside the union mask copied from the original.
segmented_image = np.zeros_like(image)
segmented_image[union_mask] = image[union_mask]

fig, axs = plt.subplots(1, 2, figsize=(14, 7))

# (axis, picture, caption) for each side-by-side panel, drawn left to right.
panels = (
    (axs[0], image, "Original image"),
    (axs[1], segmented_image, f"SAM3 segmentation ({num_segmented_pixels} px)"),
)
for ax, picture, caption in panels:
    ax.imshow(picture)
    ax.set_title(caption)
    ax.axis("off")

plt.tight_layout()
plt.show()
