In [None]:
# 🔧 Check GPU availability and system info
# Prints interpreter / PyTorch versions, CUDA availability, GPU details when a
# GPU is present, and the `df -h` line for /content.
import os
import platform
import subprocess

import torch

print("🖥️  System Information:")
# Report the interpreter version here (this line previously printed the
# PyTorch version under the "Python version" label).
print(f"   Python version: {platform.python_version()}")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"   GPU device: {torch.cuda.get_device_name(0)}")
    print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print("   ✅ GPU is ready for SAM2!")
else:
    print("   ⚠️  No GPU detected. Please enable GPU runtime for optimal performance.")
    print("   Go to Runtime → Change runtime type → Select GPU")

# Check available disk space
result = subprocess.run(['df', '-h', '/content'], capture_output=True, text=True)
print(f"\n💾 Available disk space:")
# `df` prints a header line followed by one data line; when it fails (for
# example because /content does not exist) there is no second line to index.
df_lines = result.stdout.splitlines()
if result.returncode == 0 and len(df_lines) > 1:
    print(df_lines[1])
else:
    print(f"   ⚠️  Could not read disk usage: {result.stderr.strip() or 'no output from df'}")


In [None]:
# 📥 Clone the SAM2 demo repository
import os

if os.path.exists('/content/sam2-demo'):
    print("🔄 Repository already exists, pulling latest changes...")
    %cd /content/sam2-demo
    !git pull origin main
else:
    print("📥 Cloning SAM2 demo repository...")
    %cd /content
    !git clone https://github.com/sarptandoven/sam2-demo.git
    %cd sam2-demo

print("✅ Repository ready!")
!pwd


In [None]:
# 📦 Install system dependencies
# Installs the apt packages listed below, then Node.js (via the NodeSource
# 18.x setup script) and Yarn, and finally echoes the installed tool versions.
print("📦 Installing system dependencies...")

# Update package list and install required packages
!apt-get update -qq
!apt-get install -y -qq \
    ffmpeg \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libglib2.0-0 \
    libgl1-mesa-glx \
    curl

# Install Node.js and Yarn for frontend
# NOTE(review): the next line pipes a remotely fetched script straight into
# bash — confirm this is acceptable for the environment this notebook targets.
print("📦 Installing Node.js and Yarn...")
!curl -fsSL https://deb.nodesource.com/setup_18.x | bash -
!apt-get install -y nodejs
!npm install -g yarn

# Verify installations
!echo "Node version: $(node --version)"
!echo "Yarn version: $(yarn --version)"
!echo "FFmpeg version: $(ffmpeg -version | head -1)"

# Printed unconditionally: a failing `!` command above does not stop the cell,
# so check the version lines rather than relying on this message.
print("✅ System dependencies installed!")


In [None]:
# 🐍 Install Python dependencies
# Installs PyTorch from the CUDA 11.8 wheel index, the demo's Python packages,
# and then the package in the current directory in editable mode.
print("🐍 Installing Python dependencies...")

# Install core dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q \
    opencv-python \
    Pillow \
    numpy \
    matplotlib \
    imagesize \
    dataclasses-json \
    pycocotools \
    av \
    flask \
    flask-cors \
    strawberry-graphql \
    pydantic \
    python-multipart

# Install SAM2 package
# `-e .` installs whatever project is in the current working directory;
# presumably that is /content/sam2-demo left by the clone cell — verify.
print("🔧 Installing SAM2 package...")
!pip install -q -e .

print("✅ Python dependencies installed!")

# Verify PyTorch CUDA
# NOTE(review): if torch was already imported in this kernel, this reports on
# the already-loaded module, not necessarily the build installed above.
import torch
print(f"🔥 PyTorch CUDA available: {torch.cuda.is_available()}")


In [None]:
# 🤖 Download SAM2 model checkpoints
# Runs the repository's download script from inside the checkpoints directory
# and lists the resulting *.pt files.
print("🤖 Downloading SAM2 model checkpoints...")

import os
os.makedirs('/content/sam2-demo/checkpoints', exist_ok=True)

# Download model checkpoints
# The script is invoked by relative name, so download_ckpts.sh has to exist in
# this directory. Leaves the kernel's working directory at checkpoints/.
%cd /content/sam2-demo/checkpoints
!bash download_ckpts.sh

# Verify downloads
print("\n📋 Downloaded checkpoints:")
!ls -lh *.pt

# Printed unconditionally; the `ls` output above is the actual evidence.
print("✅ Model checkpoints ready!")


In [None]:
# 🎨 Install frontend dependencies
# Runs `yarn install` inside the frontend project and leaves the kernel's
# working directory there.
print("🎨 Installing frontend dependencies...")

%cd /content/sam2-demo/demo/frontend
# NOTE(review): --legacy-peer-deps is documented as an npm option; confirm the
# installed Yarn version accepts (or ignores) it.
!yarn install --legacy-peer-deps

print("✅ Frontend dependencies installed!")


In [None]:
# 🔧 Setup environment variables and configuration
import os

# Set environment variables for GPU usage
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['APP_ROOT'] = '/content/sam2-demo'
os.environ['DATA_PATH'] = '/content/sam2-demo/demo/data'
os.environ['API_URL'] = 'http://localhost:7263'
os.environ['MODEL_SIZE'] = 'small'  # Use small model for faster inference
os.environ['DEFAULT_VIDEO_PATH'] = 'gallery/01_dog.mp4'
os.environ['PYTHONPATH'] = '/content/sam2-demo:' + os.environ.get('PYTHONPATH', '')

print("🔧 Environment configured:")
print(f"   - App root: {os.environ['APP_ROOT']}")
print(f"   - Model size: {os.environ['MODEL_SIZE']}")
print(f"   - API URL: {os.environ['API_URL']}")
print(f"   - CUDA device: {os.environ.get('CUDA_VISIBLE_DEVICES', 'default')}")

%cd /content/sam2-demo
print("✅ Environment ready!")


In [None]:
# 🌐 Install Cloudflare Tunnel for public access
print("🌐 Installing Cloudflare Tunnel for public access...")

# Download and install cloudflared
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb
!dpkg -i cloudflared-linux-amd64.deb

print("✅ Cloudflare Tunnel installed!")


In [None]:
# 🚀 Start the backend server
# Launches demo/backend/server/app.py as a background process, sends its
# output to a log file, and polls the GraphQL endpoint until it answers.
import os
import subprocess
import time

import requests

print("🚀 Starting SAM2 backend server...")

# Re-running this cell must not leave an older backend holding the port.
_previous_backend = globals().get('backend_process')
if _previous_backend is not None and _previous_backend.poll() is None:
    print("🔄 Stopping previously started backend...")
    _previous_backend.terminate()
    try:
        _previous_backend.wait(timeout=10)
    except subprocess.TimeoutExpired:
        _previous_backend.kill()

# Change to backend directory
backend_dir = '/content/sam2-demo/demo/backend/server'
os.chdir(backend_dir)

# Start backend server in background
backend_cmd = ['python', 'app.py']

# Set environment for backend
backend_env = os.environ.copy()
backend_env.update({
    'PYTORCH_ENABLE_MPS_FALLBACK': '1',
    'APP_ROOT': '/content/sam2-demo',
    'DATA_PATH': '/content/sam2-demo/demo/data',
    'API_URL': 'http://localhost:7263',
    'MODEL_SIZE': 'small',
    'DEFAULT_VIDEO_PATH': 'gallery/01_dog.mp4',
    'PYTHONPATH': '/content/sam2-demo'
})

# Send output to a file rather than a pipe: nothing in this notebook reads the
# pipe, so the server would block on write once the pipe buffer filled up.
backend_log_path = '/content/sam2_backend.log'
with open(backend_log_path, 'w') as _backend_log:
    # The child inherits its own descriptor, so the parent handle can close.
    backend_process = subprocess.Popen(
        backend_cmd,
        env=backend_env,
        stdout=_backend_log,
        stderr=subprocess.STDOUT,
    )

print("⏳ Waiting for backend to initialize...")

# Monitor backend startup
for i in range(60):  # Wait up to 60 seconds
    if backend_process.poll() is not None:
        print("❌ Backend process exited unexpectedly")
        break

    try:
        response = requests.get('http://localhost:7263/graphql', timeout=2)
    except requests.exceptions.RequestException:
        response = None  # not accepting connections yet

    # Same success criterion as check_service_status: 200, or 405 when the
    # endpoint only accepts POST.
    if response is not None and response.status_code in (200, 405):
        print("✅ Backend server is running on port 7263")
        break

    time.sleep(1)
    if i % 10 == 0:
        print(f"   ... still waiting ({i}s)")
else:
    print("⚠️  Backend startup timeout, but continuing...")

if backend_process.poll() is None:
    print("🎯 Backend process started (PID: {})".format(backend_process.pid))
else:
    # Show why it died instead of reporting a PID that no longer exists.
    print(f"❌ Backend exited with code {backend_process.returncode}; last lines of {backend_log_path}:")
    with open(backend_log_path) as _backend_log:
        print(''.join(_backend_log.readlines()[-20:]))


In [None]:
# 🎨 Start the frontend server
# Launches `yarn dev` as a background process, sends its output to a log file,
# and polls port 5173 until the dev server answers.
import os
import subprocess
import time

import requests

print("🎨 Starting frontend server...")

# Re-running this cell must not leave an older dev server holding the port.
_previous_frontend = globals().get('frontend_process')
if _previous_frontend is not None and _previous_frontend.poll() is None:
    print("🔄 Stopping previously started frontend...")
    _previous_frontend.terminate()
    try:
        _previous_frontend.wait(timeout=10)
    except subprocess.TimeoutExpired:
        _previous_frontend.kill()

# Change to frontend directory
frontend_dir = '/content/sam2-demo/demo/frontend'
os.chdir(frontend_dir)

# Start frontend server in background
frontend_cmd = ['yarn', 'dev', '--host', '0.0.0.0', '--port', '5173']

# Send output to a file rather than a pipe: nothing in this notebook reads the
# pipe, so the dev server would block on write once the pipe buffer filled up.
frontend_log_path = '/content/sam2_frontend.log'
with open(frontend_log_path, 'w') as _frontend_log:
    # The child inherits its own descriptor, so the parent handle can close.
    frontend_process = subprocess.Popen(
        frontend_cmd,
        stdout=_frontend_log,
        stderr=subprocess.STDOUT,
    )

print("⏳ Waiting for frontend to start...")

# Monitor frontend startup
for i in range(30):  # Wait up to 30 seconds
    if frontend_process.poll() is not None:
        print("❌ Frontend process exited unexpectedly")
        break

    try:
        response = requests.get('http://localhost:5173', timeout=2)
    except requests.exceptions.RequestException:
        response = None  # not accepting connections yet

    if response is not None and response.status_code == 200:
        print("✅ Frontend server is running on port 5173")
        break

    time.sleep(1)
    if i % 5 == 0:
        print(f"   ... still waiting ({i}s)")
else:
    print("⚠️  Frontend startup timeout, but continuing...")

if frontend_process.poll() is None:
    print("🎯 Frontend process started (PID: {})".format(frontend_process.pid))
else:
    # Show why it died instead of reporting a PID that no longer exists.
    print(f"❌ Frontend exited with code {frontend_process.returncode}; last lines of {frontend_log_path}:")
    with open(frontend_log_path) as _frontend_log:
        print(''.join(_frontend_log.readlines()[-20:]))


In [None]:
# 🌐 Create public tunnel and get URL
# Starts `cloudflared tunnel` against the frontend port and extracts the
# trycloudflare.com URL from its output.
import subprocess
import time
import re
import threading

print("🌐 Creating public tunnel to frontend...")

# Start cloudflared tunnel
tunnel_cmd = ['cloudflared', 'tunnel', '--url', 'http://localhost:5173']

tunnel_process = subprocess.Popen(
    tunnel_cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    bufsize=1,
)

# Function to read tunnel output and extract URL
public_url = None

def read_tunnel_output():
    """Record the first trycloudflare.com URL, then keep draining the pipe.

    Reading continues after the URL is found: if the pipe were left unread,
    cloudflared would block on write once the pipe buffer filled up.
    The loop ends when the process closes its output.
    """
    global public_url
    for line in iter(tunnel_process.stdout.readline, ''):
        if public_url is None and 'trycloudflare.com' in line:
            match = re.search(r'https://[\w-]+\.trycloudflare\.com', line)
            if match:
                public_url = match.group(0)

# Start reading tunnel output in background.
# daemon=True so a tunnel that never prints a URL cannot keep the kernel from
# shutting down.
output_thread = threading.Thread(target=read_tunnel_output, daemon=True)
output_thread.start()

# Wait for URL
print("⏳ Waiting for public URL...")
for i in range(20):
    # Stop early when the URL arrived or cloudflared already exited.
    if public_url or tunnel_process.poll() is not None:
        break
    time.sleep(1)
    if i % 5 == 0:
        print(f"   ... still waiting ({i}s)")

if public_url:
    print(f"\n🎉 SAM2 Demo is ready!")
    print(f"\n🔗 Public URL: {public_url}")
    print(f"\n📱 Click the link above to access the demo")
    print(f"\n⚡ Features available:")
    print(f"   - Interactive video segmentation")
    print(f"   - Real-time object tracking")
    print(f"   - Video upload (up to 100MB)")
    print(f"   - Video effects and export")
    print(f"   - GPU-accelerated inference")
else:
    print("❌ Failed to get public URL. You can still access locally at http://localhost:5173")

if tunnel_process.poll() is None:
    print(f"\n🎯 Tunnel process started (PID: {tunnel_process.pid})")
else:
    print(f"\n❌ Tunnel process exited with code {tunnel_process.returncode}")


In [None]:
# 📊 Monitor services and show status
import requests
import time
import torch

def _probe_service(url, ok_codes):
    """Return a status label for `url`: running, HTTP error, or unreachable."""
    try:
        response = requests.get(url, timeout=5)
    except requests.exceptions.RequestException:
        # Connection refused / timeout: the service is not answering at all.
        return "❌ Not responding"
    return "✅ Running" if response.status_code in ok_codes else "❌ Error"


def check_service_status():
    """Print the state of the backend, the frontend and the GPU.

    Returns:
        tuple[str, str]: the backend and frontend status labels, in that order.
    """
    print("📊 Service Status Check:")

    # Check backend (405 is accepted: the endpoint may only allow POST)
    backend_status = _probe_service('http://localhost:7263/graphql', (200, 405))

    # Check frontend
    frontend_status = _probe_service('http://localhost:5173', (200,))

    # GPU info
    if torch.cuda.is_available():
        # Device-wide usage. torch.cuda.memory_allocated() only counts tensors
        # owned by this notebook process, and the backend runs in a separate
        # process, so it would not reflect the model's memory.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        gpu_memory = (total_bytes - free_bytes) / 1e9
        gpu_status = f"✅ GPU Memory: {gpu_memory:.1f}GB used"
    else:
        gpu_status = "⚠️  No GPU available"

    print(f"   🖥️  Backend (GraphQL): {backend_status}")
    print(f"   🎨 Frontend (React): {frontend_status}")
    print(f"   🔥 {gpu_status}")

    # public_url only exists once the tunnel cell has run.
    if 'public_url' in globals() and public_url:
        print(f"   🌐 Public URL: {public_url}")

    return backend_status, frontend_status

# Initial status check
# The returned (backend_status, frontend_status) tuple is not used here; the
# function prints its own report.
check_service_status()

# Static usage hints for the demo UI.
print("\n💡 Tips:")
print("   - If services show as not responding, wait a moment and run this cell again")
print("   - The demo works best with videos under 10 seconds and 100MB")
print("   - Try the gallery videos first to test the interface")
print("   - Use the 'Add Object' button to start segmenting objects in videos")

print("\n🔄 Services will run until this notebook session ends.")
print("📝 Re-run this cell anytime to check service status.")


In [None]:
# Detect the runtime: the presence of the COLAB_GPU variable marks Colab
import os
IN_COLAB = 'COLAB_GPU' in os.environ

print("🔥 Running in Google Colab - GPU enabled!" if IN_COLAB else "💻 Running locally")

# Report CUDA availability and, when present, the first GPU's name and memory
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_gib:.1f} GB")


In [None]:
# Install/upgrade PyTorch and torchvision to required versions
%pip install torch>=2.5.1 torchvision>=0.20.1 --upgrade --quiet

# Install other required packages
%pip install numpy>=1.24.4 tqdm>=4.66.1 hydra-core>=1.3.2 iopath>=0.1.10 pillow>=9.4.0 --quiet

# Install additional packages for notebooks
%pip install matplotlib>=3.9.1 opencv-python>=4.7.0 eva-decord>=0.6.1 --quiet

# Install packages for video processing
%pip install av imageio-ffmpeg --quiet

print("✅ All dependencies installed successfully!")


In [None]:
# Clone the SAM2 repository
# Makes a fresh clone in <parent>/sam2 and moves the kernel into it.
import os
import subprocess
import shutil

# Remember the parent directory the first time this cell runs. The cell ends
# with a chdir into the clone, so on a re-run a plain relative 'sam2' would
# point at the package directory *inside* the clone: that would be deleted and
# the repository cloned into itself.
if 'SAM2_PARENT_DIR' not in globals():
    SAM2_PARENT_DIR = os.getcwd()
sam2_repo_dir = os.path.join(SAM2_PARENT_DIR, 'sam2')

# Step out of the clone before removing it.
os.chdir(SAM2_PARENT_DIR)

# Remove existing sam2 directory if it exists
if os.path.exists(sam2_repo_dir):
    shutil.rmtree(sam2_repo_dir)

# Clone the repository
result = subprocess.run(
    ['git', 'clone', 'https://github.com/facebookresearch/sam2.git', sam2_repo_dir],
    capture_output=True, text=True)
if result.returncode != 0:
    # Raise instead of calling exit(): an exception stops the cell (and
    # "Run All") with a visible error and leaves the kernel alive.
    raise RuntimeError(f"Error cloning repository: {result.stderr}")

# Change to the sam2 directory
os.chdir(sam2_repo_dir)

print("✅ SAM2 repository cloned successfully!")
print(f"📂 Current directory: {os.getcwd()}")


In [None]:
# Install SAM2 in development mode
# Skip CUDA extension build in Colab to avoid issues
# Uses `os` without importing it, so an earlier cell must have imported it.
# The editable install targets the current working directory, which the clone
# cell leaves at the cloned repository.
os.environ['SAM2_BUILD_CUDA'] = '0'

# ".[notebooks]" also installs the project's "notebooks" extra.
%pip install -e ".[notebooks]" --quiet

print("✅ SAM2 installed successfully!")
print("⚠️  Note: CUDA extension disabled for Colab compatibility")


In [None]:
# Download model checkpoints
import os
import requests
from tqdm import tqdm

# Make sure the download target directory exists
os.makedirs('checkpoints', exist_ok=True)

# SAM 2.1 checkpoint file name -> download URL, one entry per model size
CHECKPOINTS = {
    f'sam2.1_hiera_{size}.pt':
        f'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_{size}.pt'
    for size in ('tiny', 'small', 'base_plus', 'large')
}

def download_checkpoint(filename, url):
    """Download a model checkpoint into checkpoints/ with a progress bar.

    Does nothing when checkpoints/<filename> already exists. The data is
    streamed to a temporary ".part" file and renamed only after the transfer
    finished, so an interrupted or failed download is never mistaken for a
    complete checkpoint on the next run.

    Args:
        filename: File name to create inside the checkpoints directory.
        url: HTTP(S) location of the checkpoint.

    Raises:
        requests.exceptions.RequestException: on connection problems or a
            non-success HTTP status.
    """
    filepath = f'checkpoints/{filename}'
    if os.path.exists(filepath):
        print(f"✅ {filename} already exists")
        return

    print(f"📥 Downloading {filename}...")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    partial_path = filepath + '.part'
    try:
        with requests.get(url, stream=True, timeout=30) as response:
            # Fail loudly on 4xx/5xx instead of saving the error body as a .pt file.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
        os.replace(partial_path, filepath)
    finally:
        # Only present here when the download did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)

# Download the small model by default (good balance of speed and accuracy)
download_checkpoint('sam2.1_hiera_small.pt', CHECKPOINTS['sam2.1_hiera_small.pt'])

# Optionally download other models (uncomment as needed)
# download_checkpoint('sam2.1_hiera_tiny.pt', CHECKPOINTS['sam2.1_hiera_tiny.pt'])
# download_checkpoint('sam2.1_hiera_base_plus.pt', CHECKPOINTS['sam2.1_hiera_base_plus.pt'])
# download_checkpoint('sam2.1_hiera_large.pt', CHECKPOINTS['sam2.1_hiera_large.pt'])

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n
# instead of a blank line.
print("\n✅ Model checkpoints ready!")
print(f"📂 Checkpoints directory: {os.path.abspath('checkpoints')}")
print("📝 Downloaded models:")
# List every checkpoint on disk with its size in MiB.
for f in os.listdir('checkpoints'):
    if f.endswith('.pt'):
        size = os.path.getsize(f'checkpoints/{f}') / (1024**2)
        print(f"  - {f} ({size:.1f} MB)")


In [None]:
# Imports used by the model-loading, helper and demo cells that follow.
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import requests
from io import BytesIO
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including ones that might explain a failure below.
warnings.filterwarnings('ignore')

# SAM2 imports
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

print("✅ All libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🚀 CUDA available: {torch.cuda.is_available()}")


In [None]:
# Model configuration
MODEL_NAME = "sam2.1_hiera_small"  # Change this to use different model sizes
CHECKPOINT_PATH = f"checkpoints/{MODEL_NAME}.pt"

# Checkpoints are named with the full size ("small"), while the SAM 2 config
# files use an abbreviation ("s"); translate one into the other.
_CONFIG_SIZE_SUFFIX = {'tiny': 't', 'small': 's', 'base_plus': 'b+', 'large': 'l'}
_model_family, _model_size = MODEL_NAME.split('_hiera_')
if _model_size not in _CONFIG_SIZE_SUFFIX:
    raise ValueError(
        f"Unknown model size '{_model_size}' in MODEL_NAME; "
        f"expected one of {sorted(_CONFIG_SIZE_SUFFIX)}"
    )
CONFIG_PATH = f"configs/sam2.1/{_model_family}_hiera_{_CONFIG_SIZE_SUFFIX[_model_size]}.yaml"

# Check if files exist
if not os.path.exists(CHECKPOINT_PATH):
    print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
    print("Please download the checkpoint first using the cell above.")
    # Stop here: building the model without weights can only fail later with
    # a less helpful error.
    raise FileNotFoundError(CHECKPOINT_PATH)
print(f"✅ Checkpoint found: {CHECKPOINT_PATH}")

# The config name is resolved relative to the installed `sam2` package, not
# the working directory, so look for the file there. This is only a warning:
# build_sam2 below is the authoritative check.
import sam2
_config_file = os.path.join(os.path.dirname(sam2.__file__), CONFIG_PATH)
if not os.path.exists(_config_file):
    print(f"❌ Config not found: {CONFIG_PATH}")
else:
    print(f"✅ Config found: {CONFIG_PATH}")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔧 Using device: {device}")

# Set up model for image prediction
sam2_model = build_sam2(CONFIG_PATH, CHECKPOINT_PATH, device=device)
image_predictor = SAM2ImagePredictor(sam2_model)

# Set up automatic mask generator (shares the same model instance)
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)

print(f"✅ SAM2 model ({MODEL_NAME}) loaded successfully!")
print(f"📊 Model parameters: {sum(p.numel() for p in sam2_model.parameters()) / 1e6:.1f}M")


In [None]:
def load_image_from_url(url):
    """Download an image and return it as an RGB array.

    Args:
        url: HTTP(S) location of the image.

    Returns:
        numpy.ndarray: the decoded image converted to RGB, so grayscale,
        palette and RGBA sources all yield three channels.

    Raises:
        requests.exceptions.RequestException: on connection problems or a
            non-success HTTP status.
    """
    response = requests.get(url, timeout=30)
    # Surface HTTP errors here rather than as a confusing decode failure.
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as image:
        return np.array(image.convert('RGB'))

def load_image_from_path(path):
    """Read an image file and return it as an RGB array.

    Args:
        path: Location of the image on disk.

    Returns:
        numpy.ndarray: the decoded image converted to RGB, so grayscale,
        palette and RGBA files all yield three channels.
    """
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(path) as image:
        return np.array(image.convert('RGB'))

def show_mask(mask, ax, random_color=False, alpha=0.5):
    """Overlay `mask` on `ax` as a translucent RGBA image.

    The overlay colour is a fixed blue, or a random colour when
    `random_color` is set; `alpha` becomes the fourth (opacity) channel.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, [alpha]])
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4): each mask pixel scales the RGBA colour.
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on `ax`: label 1 in green, label 0 in red.

    `coords` is indexed as rows of (x, y); `labels` selects the rows.
    """
    # Positive points are drawn first, then negative ones.
    for label_value, marker_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=marker_color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Draw `box`, given as [x0, y0, x1, y1], on `ax` as a green outline."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

def show_anns(anns, ax):
    """Paint every annotation's 'segmentation' mask on `ax` in a random colour.

    Annotations are drawn from largest to smallest 'area', so smaller masks
    end up on top. Nothing is drawn for an empty list.
    """
    if len(anns) < 1:
        return
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    height, width = by_area[0]['segmentation'].shape[:2]
    # Start from a fully transparent canvas (alpha channel = 0).
    overlay = np.ones((height, width, 4))
    overlay[..., 3] = 0
    for ann in by_area:
        overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [0.5]])
    ax.imshow(overlay)

def download_sample_image(url, filename):
    """Download a sample image to `filename` unless that file already exists.

    Args:
        url: HTTP(S) location of the image.
        filename: Destination path on disk.

    Raises:
        requests.exceptions.RequestException: on connection problems or a
            non-success HTTP status.
    """
    if not os.path.exists(filename):
        print(f"📥 Downloading sample image: {filename}")
        response = requests.get(url, timeout=30)
        # Check the status before writing: otherwise an error page would be
        # saved under the image name and reported as "already exists" on
        # every later run.
        response.raise_for_status()
        with open(filename, 'wb') as f:
            f.write(response.content)
        print(f"✅ Sample image downloaded: {filename}")
    else:
        print(f"✅ Sample image already exists: {filename}")

print("✅ Helper functions defined!")


In [None]:
# Download sample images for demonstration
SAMPLE_IMAGES = {
    'truck.jpg': 'https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg',
    'groceries.jpg': 'https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/groceries.jpg',
}

# Create images directory
os.makedirs('images', exist_ok=True)

# Download sample images (files already on disk are skipped)
for filename, url in SAMPLE_IMAGES.items():
    download_sample_image(url, f'images/{filename}')

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n
# instead of a blank line.
print("\n📂 Available sample images:")
for f in os.listdir('images'):
    if f.lower().endswith(('.jpg', '.jpeg', '.png')):
        print(f"  - {f}")


In [None]:
# Point-prompt segmentation on the truck sample image
image_path = 'images/truck.jpg'
image = load_image_from_path(image_path)

# Prompt: a single click given as (x, y); label 1 = positive, 0 = negative
input_points = np.array([[500, 375]])  # Click on truck
input_labels = np.array([1])

# Embed the image and decode the candidate masks under one
# inference-mode / autocast scope
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    image_predictor.set_image(image)
    masks, scores, logits = image_predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

# First panel shows the input, the remaining panels one mask each
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(image)
axes[0].set_title('Original Image')
axes[0].axis('off')

for i, (mask, score) in enumerate(zip(masks, scores)):
    panel = axes[i + 1]
    panel.imshow(image)
    show_mask(mask, panel)
    show_points(input_points, input_labels, panel)
    panel.set_title(f'Mask {i+1} (Score: {score:.3f})')
    panel.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated {len(masks)} masks with scores: {scores}")
print(f"🎯 Best mask index: {np.argmax(scores)}")


In [None]:
# Box-prompt segmentation on the groceries sample image
image_path = 'images/groceries.jpg'
image = load_image_from_path(image_path)

# Prompt: one bounding box in [x1, y1, x2, y2] order
input_box = np.array([425, 600, 700, 875])  # Bounding box around an object

# Embed the image and decode a single mask under one
# inference-mode / autocast scope
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    image_predictor.set_image(image)
    masks, scores, logits = image_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],  # add a leading axis to the box
        multimask_output=False,
    )

# Left: input with the box. Right: the predicted mask with the box.
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
prompt_ax, result_ax = axes[0], axes[1]

prompt_ax.imshow(image)
show_box(input_box, prompt_ax)
prompt_ax.set_title('Original Image with Bounding Box')
prompt_ax.axis('off')

result_ax.imshow(image)
show_mask(masks[0], result_ax)
show_box(input_box, result_ax)
result_ax.set_title(f'Segmentation Result (Score: {scores[0]:.3f})')
result_ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated mask with score: {scores[0]:.3f}")


In [None]:
# Points and a box together on the truck sample image
image_path = 'images/truck.jpg'
image = load_image_from_path(image_path)

# Prompts: two positive clicks plus a bounding box in [x1, y1, x2, y2] order
input_points = np.array([[500, 375], [600, 400]])  # Multiple positive clicks
input_labels = np.array([1, 1])  # Both positive
input_box = np.array([425, 300, 700, 500])  # Bounding box

# Embed the image and decode a single mask under one
# inference-mode / autocast scope
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    image_predictor.set_image(image)
    masks, scores, logits = image_predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=input_box[None, :],  # add a leading axis to the box
        multimask_output=False,
    )

# Left: input with all prompts. Right: the predicted mask with the prompts.
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
prompt_ax, result_ax = axes[0], axes[1]

prompt_ax.imshow(image)
show_points(input_points, input_labels, prompt_ax)
show_box(input_box, prompt_ax)
prompt_ax.set_title('Original Image with Combined Prompts')
prompt_ax.axis('off')

result_ax.imshow(image)
show_mask(masks[0], result_ax)
show_points(input_points, input_labels, result_ax)
show_box(input_box, result_ax)
result_ax.set_title(f'Combined Prompts Result (Score: {scores[0]:.3f})')
result_ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated mask with combined prompts, score: {scores[0]:.3f}")


In [None]:
# Prompt-free segmentation of the groceries sample image
image_path = 'images/groceries.jpg'
image = load_image_from_path(image_path)

print("🔄 Generating automatic masks... This may take a moment.")

# Let the mask generator propose masks for the whole image
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    masks = mask_generator.generate(image)

# Left: the input. Right: all generated masks overlaid.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
original_ax, overlay_ax = axes[0], axes[1]

original_ax.imshow(image)
original_ax.set_title('Original Image')
original_ax.axis('off')

overlay_ax.imshow(image)
show_anns(masks, overlay_ax)
overlay_ax.set_title(f'Automatic Segmentation ({len(masks)} masks)')
overlay_ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated {len(masks)} automatic masks")
# Areas of the first five masks only, to keep the output short
print(f"📊 Mask areas: {[m['area'] for m in masks[:5]]}...")
