# YOLO-E Object Extraction with IoU-based NMS Demo

This notebook demonstrates how to use YOLO-E to segment an image, remove overlapping detections with Non-Maximum Suppression (NMS) based on an Intersection-over-Union (IoU) threshold, and extract each remaining object as its own image.

## 1. Setup and Imports

In [None]:
import os

# Select the GPU *before* importing any CUDA-backed library. CUDA reads
# CUDA_VISIBLE_DEVICES when it is initialised, so assigning it ahead of the
# heavy imports guarantees the setting is honoured regardless of what those
# libraries do at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ultralytics import YOLOE

print("Setup complete!")

## 2. Load Model and Image

In [None]:
# --- Tunable settings -------------------------------------------------------
IMAGE_PATH = "path/to/your/image.jpg"  # Change this to your image path
# NOTE(review): this file name looks like a plain YOLO11 segmentation
# checkpoint; YOLOE weights are presumably published under names such as
# "yoloe-11l-seg.pt" -- confirm this is the intended checkpoint.
MODEL_PATH = "yolo11l-seg.pt"  # YOLO-E segmentation model
CONFIDENCE_THRESHOLD = 0.2  # minimum detection confidence passed to predict()

# --- Model ------------------------------------------------------------------
print(f"Loading model: {MODEL_PATH}")
model = YOLOE(MODEL_PATH)

# --- Image (forced to 3-channel RGB) ----------------------------------------
print(f"Loading image: {IMAGE_PATH}")
image = Image.open(IMAGE_PATH).convert('RGB')
print(f"Image size: {image.size}")

# --- Preview of the untouched input -----------------------------------------
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(image)
ax.set_title('Original Image')
ax.axis('off')
plt.show()

## 3. Run YOLO-E Segmentation

In [None]:
# Run prediction
print("Running YOLO-E segmentation...")
# retina_masks=True asks Ultralytics to return masks at the original image
# resolution. Without it, masks.data lives on the (possibly padded) inference
# grid while the boxes are in original-image pixels, so stretching a mask to
# the image size can leave it slightly misaligned with its box.
results = model.predict(
    image,
    conf=CONFIDENCE_THRESHOLD,
    retina_masks=True,
    verbose=False,
)

# Check results
if results and results[0].boxes is not None and len(results[0].boxes) > 0:
    print(f"Detected {len(results[0].boxes)} objects")

    # masks is None when the model produced no segmentation output
    if getattr(results[0], 'masks', None) is not None:
        print(f"Segmentation masks: {len(results[0].masks.data)}")
    else:
        print("Warning: No segmentation masks found")
else:
    print("No objects detected")

## 4. Define Helper Functions

In [None]:
def calculate_iou(box1, box2):
    """
    Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Each box is an ``(x1, y1, x2, y2)`` sequence. The result is 0.0 when the
    boxes do not overlap (boxes that merely touch count as non-overlapping)
    or when the union area is not positive.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2

    # Corners of the overlap rectangle
    left = max(left_a, left_b)
    top = max(top_a, top_b)
    right = min(right_a, right_b)
    bottom = min(bottom_a, bottom_b)

    # Empty or degenerate overlap -> no intersection at all
    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)

    # Union = sum of both areas minus the part counted twice
    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    union = area_a + area_b - overlap

    if union > 0:
        return overlap / union
    return 0.0


def extract_objects_with_nms(results, original_image, padding=10, iou_threshold=0.5):
    """
    Extract segmented objects with NMS filtering.

    Detections are visited in descending confidence order; a detection is
    dropped when the IoU of its box with any already-kept box exceeds
    ``iou_threshold`` (greedy NMS that ignores class labels). Every kept
    detection is cropped from the image using its box grown by ``padding``
    pixels (clipped to the image bounds), and pixels whose mask value is
    below 0.5 are painted white.

    Args:
        results: Sequence whose first element exposes ``masks.data``,
            ``boxes.xyxy``, ``boxes.conf`` and ``boxes.cls``, each supporting
            ``.cpu().numpy()``.
        original_image: Anything ``np.array`` accepts (e.g. a PIL image).
            2-D grayscale, 3-channel and 4-channel inputs are handled; a
            fourth channel is discarded.
        padding: Pixels added on every side of the box before cropping.
        iou_threshold: Boxes overlapping a kept box by more than this IoU
            are suppressed.

    Returns:
        ``(extracted_objects, object_info)``: a list of PIL images and a
        parallel list of dicts with keys ``index`` (1-based), ``confidence``,
        ``class_id``, ``size`` and ``bbox`` (padded, clipped ``x1, y1, x2, y2``
        as plain Python ints). Both lists are empty when there are no masks.

    NOTE(review): a mask whose shape differs from the image is stretched
    directly to the image size. That presumably relies on the mask grid
    covering exactly the same field of view as the image -- confirm against
    how the masks were produced.
    """
    first = results[0] if results else None
    if first is None or getattr(first, 'masks', None) is None:
        print("No segmentation masks found.")
        return [], []

    img_array = np.array(original_image)
    # Only a 3-D array has a channel axis; a 2-D grayscale image must not be
    # mistaken for RGBA just because its width happens to be 4.
    if img_array.ndim == 3 and img_array.shape[-1] == 4:
        img_array = img_array[:, :, :3]
    h, w = img_array.shape[:2]

    masks = first.masks.data.cpu().numpy()
    boxes = first.boxes.xyxy.cpu().numpy()
    conf = first.boxes.conf.cpu().numpy()
    cls = first.boxes.cls.cpu().numpy() if first.boxes.cls is not None else np.zeros(len(boxes))

    # Greedy NMS: highest confidence first, keep a box only if it does not
    # overlap any previously kept box by more than the threshold.
    keep_indices = []
    for candidate in np.argsort(conf)[::-1]:
        overlaps_kept = any(
            calculate_iou(boxes[candidate], boxes[kept]) > iou_threshold
            for kept in keep_indices
        )
        if not overlaps_kept:
            keep_indices.append(int(candidate))

    print(f"Before NMS: {len(boxes)} objects")
    print(f"After NMS: {len(keep_indices)} objects (IoU threshold: {iou_threshold})")
    print(f"Removed: {len(boxes) - len(keep_indices)} overlapping objects")

    # Extract objects
    extracted_objects = []
    object_info = []

    for original_idx in keep_indices:
        # cv2.resize does not accept every dtype (e.g. float16), so normalise first.
        mask = masks[original_idx].astype(np.float32, copy=False)
        if mask.shape[:2] != (h, w):
            mask = cv2.resize(mask, (w, h))

        # Padded box, clipped to the image and stored as plain Python ints so
        # the values print cleanly and can be serialised.
        x1, y1, x2, y2 = (int(v) for v in boxes[original_idx])
        x1 = max(0, x1 - padding)
        y1 = max(0, y1 - padding)
        x2 = min(w, x2 + padding)
        y2 = min(h, y2 + padding)

        # A zero-area crop cannot be turned into an image; skip it.
        if x2 <= x1 or y2 <= y1:
            continue

        masked_obj = img_array[y1:y2, x1:x2].copy()
        background = mask[y1:y2, x1:x2] < 0.5
        # A 2-D boolean index selects whole pixels, so this works for both
        # grayscale (H, W) and colour (H, W, C) crops.
        masked_obj[background] = 255

        pil_obj = Image.fromarray(masked_obj.astype(np.uint8))

        extracted_objects.append(pil_obj)
        object_info.append({
            'index': len(extracted_objects),
            'confidence': float(conf[original_idx]),
            'class_id': int(cls[original_idx]),
            'size': pil_obj.size,
            'bbox': (x1, y1, x2, y2)
        })

    return extracted_objects, object_info

print("Helper functions defined!")

## 5. Extract Objects with NMS

In [None]:
# Parameters for suppression and cropping
IOU_THRESHOLD = 0.5  # boxes overlapping a kept box by more than this are dropped
PADDING = 20         # extra pixels kept around each box when cropping

# Run NMS + extraction on the prediction results
extracted_objects, object_info = extract_objects_with_nms(
    results,
    image,
    iou_threshold=IOU_THRESHOLD,
    padding=PADDING,
)

print(f"\nExtracted {len(extracted_objects)} objects after NMS filtering")

## 6. Visualize Extracted Objects

In [None]:
if not extracted_objects:
    print("No objects to visualize.")
else:
    # Grid layout: at most 4 columns, as many rows as needed
    n_objects = len(extracted_objects)
    cols = min(4, n_objects)
    rows = (n_objects + cols - 1) // cols

    fig = plt.figure(figsize=(16, 4 * rows))

    for i, obj_image in enumerate(extracted_objects):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(obj_image)
        ax.set_title(
            f"Object {object_info[i]['index']}\n"
            f"Confidence: {object_info[i]['confidence']:.2f}",
            fontsize=10,
        )
        ax.axis('off')

    fig.suptitle(f'Extracted Objects ({n_objects} total)', fontsize=14)
    fig.tight_layout()
    plt.show()

## 7. Print Object Details

In [None]:
# Header followed by one four-line record (plus a blank line) per object
print("\nObject Details:")
print("=" * 60)
for info in object_info:
    detail_lines = (
        f"Object {info['index']}:",
        f"  Confidence: {info['confidence']:.3f}",
        f"  Size: {info['size']}",
        f"  Bounding Box: {info['bbox']}",
        "",  # trailing empty entry yields the blank separator line
    )
    print("\n".join(detail_lines))

## 8. Save Results (Optional)

In [None]:
from pathlib import Path

# Output layout: output/ for everything, output/objects/ for the crops
output_dir = Path('output')
objects_dir = output_dir / 'objects'
for directory in (output_dir, objects_dir):
    directory.mkdir(exist_ok=True)

# Write each crop as a PNG named after its index and confidence
for obj, info in zip(extracted_objects, object_info):
    filepath = objects_dir / f"object_{info['index']:03d}_conf{info['confidence']:.2f}.png"
    obj.save(filepath)
    print(f"Saved: {filepath}")

print(f"\nAll objects saved to: {objects_dir}")