In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import ros_numpy
import rospy
from IPython.display import display
from PIL import Image
from sensor_msgs.msg import Image as ImageMsg

In [None]:
# Register this notebook kernel as a ROS node so the cells below can talk to ROS.
# anonymous=True asks rospy to make the node name unique (per the rospy docs),
# so several copies of this notebook can run side by side.
rospy.init_node('collect_images', anonymous=True)

In [None]:
# Grab one message from the colour topic and one from the aligned-depth topic,
# then convert each to a NumPy array.
# Both topic names are absolute so they resolve the same way regardless of the
# namespace this node is started in (the colour topic used to be relative).
# No timeout is passed, so each call blocks until a message arrives.
# NOTE(review): the two messages are fetched one after the other, so they are not
# guaranteed to belong to the same capture instant.
rgb_message = rospy.wait_for_message("/camera/color/image_raw", ImageMsg)
depth_message = rospy.wait_for_message("/camera/aligned_depth_to_color/image_raw", ImageMsg)
rgb_data = ros_numpy.numpify(rgb_message)
depth_data = ros_numpy.numpify(depth_message)

In [None]:
# Sanity check: show the array dimensions of the colour frame, then the depth frame.
print(rgb_data.shape, depth_data.shape, sep="\n")

In [None]:
# Wrap the colour array in a PIL image.
# NOTE(review): assumes the array is already in RGB channel order — confirm the
# camera driver's encoding.
rgb_image = Image.fromarray(rgb_data)
# Bare last expression so the notebook renders the image inline.
rgb_image

In [None]:
# Wrap the depth array in a PIL image so it can be written to disk.
# NOTE(review): presumably a single-channel 16-bit depth map — verify the dtype.
depth_image = Image.fromarray(depth_data)

In [None]:
# Persist the captured frames as PNG files.
# Create the target directory first: Image.save does not create missing parent
# directories and would fail on a fresh checkout.
os.makedirs("data/lego_split", exist_ok=True)
rgb_image.save("data/lego_split/demo_start_rgb_left.png")
depth_image.save("data/lego_split/demo_start_depth_left.png")

In [None]:
from lang_sam import LangSAM

# Reload the frames from disk so segmentation can run without a live camera.
# The file names match the ones written by the capture cell; the previous names
# ("left_demo_start_*.png") were never produced by this notebook.
# NOTE(review): if the dataset on disk really uses the "left_demo_start_*" names,
# rename the files or adjust both cells together.
# convert("RGB") forces a 3-channel image, as in the upstream LangSAM usage example.
rgb_image = Image.open("data/lego_split/demo_start_rgb_left.png").convert("RGB")
depth_image = Image.open("data/lego_split/demo_start_depth_left.png")

# Instantiate the text-prompted segmentation model with its default settings.
model = LangSAM()

In [None]:
def save_mask(mask_np, filename):
    """Write a mask array to `filename` as an 8-bit image.

    Args:
        mask_np: NumPy array holding the mask; its values are multiplied by
            255 and cast to ``uint8`` before saving.
        filename: Destination passed straight to ``PIL.Image.Image.save``.
    """
    pixels = mask_np * 255
    pixels = pixels.astype(np.uint8)
    Image.fromarray(pixels).save(filename)

def display_image_with_masks(image, masks):
    """Plot the original image followed by each mask in a single row.

    Args:
        image: Image accepted by ``Axes.imshow``; shown in the first panel.
        masks: Sequence of 2-D mask arrays, one grayscale panel each. May be
            empty, in which case only the original image is shown.
    """
    num_masks = len(masks)

    # squeeze=False keeps `axes` a 2-D array even when there is only one panel;
    # with the default, an empty `masks` returns a bare Axes and axes[0] fails.
    fig, axes = plt.subplots(1, num_masks + 1, figsize=(15, 5), squeeze=False)
    panels = axes[0]

    panels[0].imshow(image)
    panels[0].set_title("Original Image")
    panels[0].axis('off')

    for i, mask_np in enumerate(masks):
        panels[i + 1].imshow(mask_np, cmap='gray')
        panels[i + 1].set_title(f"Mask {i+1}")
        panels[i + 1].axis('off')

    plt.tight_layout()
    plt.show()

def display_image_with_boxes(image, boxes, logits):
    """Show `image` with a red rectangle and confidence label per detection.

    Args:
        image: Image accepted by ``Axes.imshow``.
        boxes: Iterable of 4-element boxes, unpacked as
            (x_min, y_min, x_max, y_max).
        logits: Iterable of scores paired positionally with `boxes`; each
            element must provide ``.item()``.
    """
    figure, axis = plt.subplots()
    axis.imshow(image)
    axis.set_title("Image with Bounding Boxes")
    axis.axis('off')

    for (left, top, right, bottom), score in zip(boxes, logits):
        # .item() turns the score into a plain Python number before rounding.
        label = f"Confidence: {round(score.item(), 2)}"

        outline = plt.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            fill=False,
            edgecolor='red',
            linewidth=2,
        )
        axis.add_patch(outline)

        # Anchor the label at the box's (x_min, y_min) corner.
        axis.text(left, top, label, fontsize=8, color='red', verticalalignment='top')

    plt.show()

def print_bounding_boxes(boxes):
    """Print a header followed by each box on its own line, numbered from 1."""
    print("Bounding Boxes:")
    for number, box in enumerate(boxes, start=1):
        print(f"Box {number}: {box}")

def print_detected_phrases(phrases):
    """Print a blank line and header, then each phrase numbered from 1."""
    print("\nDetected Phrases:")
    for number, phrase in enumerate(phrases, start=1):
        print(f"Phrase {number}: {phrase}")

def print_logits(logits):
    """Print a blank line and header, then each logit numbered from 1."""
    print("\nConfidence:")
    for number, logit in enumerate(logits, start=1):
        print(f"Logit {number}: {logit}")

In [None]:
# Run text-prompted segmentation on the RGB frame, then visualise, save and
# print whatever was detected.
text_prompt = "lego"
image_pil = rgb_image.copy()  # separate copy used only for plotting
masks, boxes, phrases, logits = model.predict(rgb_image, text_prompt)

if len(masks) > 0:
    # Move each mask to the CPU as a NumPy array with singleton dimensions dropped.
    masks_np = [mask.squeeze().cpu().numpy() for mask in masks]

    # Visual checks: masks next to the source image, then boxes with scores.
    display_image_with_masks(image_pil, masks_np)
    display_image_with_boxes(image_pil, boxes, logits)

    # One PNG per mask, numbered from 1, written to the working directory.
    for mask_number, mask_np in enumerate(masks_np, start=1):
        save_mask(mask_np, f"demo_spray_mask_{mask_number}.png")

    # Text summary of the detections.
    print_bounding_boxes(boxes)
    print_detected_phrases(phrases)
    print_logits(logits)
else:
    print(f"No objects of the '{text_prompt}' prompt detected in the image.")


: 