In [None]:
# Detect the runtime environment.
# IN_COLAB is True when the COLAB_GPU environment variable is present; its presence
# alone does not show that a GPU is attached, so the accelerator is reported
# separately by the CUDA check below.
import os
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    print("🔥 Running in Google Colab")
else:
    print("💻 Running locally")

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; convert to GiB for display
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
elif IN_COLAB:
    # Colab runtime without an accelerator: tell the user how to get one
    print("⚠️  No GPU detected - enable one via Runtime > Change runtime type")


In [None]:
# Install/upgrade PyTorch and torchvision to required versions
%pip install torch>=2.5.1 torchvision>=0.20.1 --upgrade --quiet

# Install other required packages
%pip install numpy>=1.24.4 tqdm>=4.66.1 hydra-core>=1.3.2 iopath>=0.1.10 pillow>=9.4.0 --quiet

# Install additional packages for notebooks
%pip install matplotlib>=3.9.1 opencv-python>=4.7.0 eva-decord>=0.6.1 --quiet

# Install packages for video processing
%pip install av imageio-ffmpeg --quiet

print("✅ All dependencies installed successfully!")


In [None]:
# Clone the SAM2 repository
import os
import subprocess
import shutil

# Remember the directory the notebook started in. This cell ends with a chdir into
# the clone, so on a re-run we must return here first; otherwise the second run
# would clone into sam2/sam2.
if 'NOTEBOOK_ROOT' not in globals():
    NOTEBOOK_ROOT = os.getcwd()
os.chdir(NOTEBOOK_ROOT)

REPO_DIR = os.path.join(NOTEBOOK_ROOT, 'sam2')

# Remove existing sam2 directory if it exists
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)

# Clone the repository
result = subprocess.run(
    ['git', 'clone', 'https://github.com/facebookresearch/sam2.git', REPO_DIR],
    capture_output=True, text=True)
if result.returncode != 0:
    # Raise instead of calling exit(): an exception fails the cell visibly and
    # stops "Run All", without touching the kernel.
    raise RuntimeError(f"Error cloning repository: {result.stderr}")

# Change to the sam2 directory; later cells use paths relative to the repo root
os.chdir(REPO_DIR)

print("✅ SAM2 repository cloned successfully!")
print(f"📂 Current directory: {os.getcwd()}")


In [None]:
# Install SAM2 in development mode
# Skip CUDA extension build in Colab to avoid issues
os.environ['SAM2_BUILD_CUDA'] = '0'

%pip install -e ".[notebooks]" --quiet

print("✅ SAM2 installed successfully!")
print("⚠️  Note: CUDA extension disabled for Colab compatibility")


In [None]:
# Download model checkpoints
import os
import requests
from tqdm import tqdm

# Make sure the local download target exists (no error if it is already there)
os.makedirs('checkpoints', exist_ok=True)

# SAM 2.1 checkpoint files, keyed by local filename -> download URL
CHECKPOINTS = dict([
    ('sam2.1_hiera_tiny.pt',
     'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt'),
    ('sam2.1_hiera_small.pt',
     'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt'),
    ('sam2.1_hiera_base_plus.pt',
     'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt'),
    ('sam2.1_hiera_large.pt',
     'https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt'),
])

def download_checkpoint(filename, url):
    """Download a model checkpoint into ``checkpoints/`` with a progress bar.

    Args:
        filename: Name of the file to create inside the ``checkpoints`` directory.
        url: HTTP(S) URL to stream the checkpoint from.

    Returns:
        None. If ``checkpoints/<filename>`` already exists nothing is downloaded.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    filepath = f'checkpoints/{filename}'
    if os.path.exists(filepath):
        print(f"✅ {filename} already exists")
        return

    print(f"📥 Downloading {filename}...")
    # Stream into a side file and rename only after the transfer completed, so an
    # interrupted download can never be mistaken for a finished checkpoint by the
    # "already exists" check above.
    partial_path = f'{filepath}.part'
    try:
        with requests.get(url, stream=True, timeout=30) as response:
            # Fail on 4xx/5xx instead of saving the error page as a checkpoint
            response.raise_for_status()
            # A missing/zero content-length becomes None so tqdm shows an open-ended bar
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(partial_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
        os.replace(partial_path, filepath)
    except BaseException:
        # Also covers KeyboardInterrupt (the notebook "stop" button)
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

# Download the small model by default (good balance of speed and accuracy)
download_checkpoint('sam2.1_hiera_small.pt', CHECKPOINTS['sam2.1_hiera_small.pt'])

# Optionally download other models (uncomment as needed)
# download_checkpoint('sam2.1_hiera_tiny.pt', CHECKPOINTS['sam2.1_hiera_tiny.pt'])
# download_checkpoint('sam2.1_hiera_base_plus.pt', CHECKPOINTS['sam2.1_hiera_base_plus.pt'])
# download_checkpoint('sam2.1_hiera_large.pt', CHECKPOINTS['sam2.1_hiera_large.pt'])

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n
print("\n✅ Model checkpoints ready!")
print(f"📂 Checkpoints directory: {os.path.abspath('checkpoints')}")
print("📝 Downloaded models:")
# Sorted so the listing is stable across runs and platforms
for f in sorted(os.listdir('checkpoints')):
    if f.endswith('.pt'):
        size = os.path.getsize(f'checkpoints/{f}') / (1024**2)  # bytes -> MiB
        print(f"  - {f} ({size:.1f} MB)")


In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

# SAM2 imports
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Report that imports succeeded, plus the PyTorch version and CUDA status
print(
    "✅ All libraries imported successfully!",
    f"🔥 PyTorch version: {torch.__version__}",
    f"🚀 CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)


In [None]:
# Model configuration
MODEL_NAME = "sam2.1_hiera_small"  # Change this to use different model sizes
CHECKPOINT_PATH = f"checkpoints/{MODEL_NAME}.pt"

# Checkpoint files use long size names, but the SAM 2 repository names its SAM 2.1
# config files with abbreviated suffixes (see the model table in the SAM 2 README),
# so the config name cannot be derived from MODEL_NAME by string substitution.
CONFIG_NAMES = {
    "sam2.1_hiera_tiny": "sam2.1_hiera_t",
    "sam2.1_hiera_small": "sam2.1_hiera_s",
    "sam2.1_hiera_base_plus": "sam2.1_hiera_b+",
    "sam2.1_hiera_large": "sam2.1_hiera_l",
}
if MODEL_NAME not in CONFIG_NAMES:
    raise ValueError(
        f"Unknown MODEL_NAME {MODEL_NAME!r}; expected one of {sorted(CONFIG_NAMES)}"
    )
CONFIG_PATH = f"configs/sam2.1/{CONFIG_NAMES[MODEL_NAME]}.yaml"

# Check if files exist
if not os.path.exists(CHECKPOINT_PATH):
    print(f"❌ Checkpoint not found: {CHECKPOINT_PATH}")
    print("Please download the checkpoint first using the cell above.")
    # Stop here: building the model without the weights would only fail later
    # with a less helpful error.
    raise FileNotFoundError(CHECKPOINT_PATH)
print(f"✅ Checkpoint found: {CHECKPOINT_PATH}")

# CONFIG_PATH is a config name resolved relative to the `sam2` package, not to the
# current directory, so look for the file next to the package sources.
# NOTE(review): assumes the yaml files ship inside the package directory - confirm.
import sam2
_config_file = os.path.join(os.path.dirname(sam2.__file__), CONFIG_PATH)
if not os.path.exists(_config_file):
    # Informational only; build_sam2 below performs the authoritative lookup
    print(f"❌ Config not found: {CONFIG_PATH}")
else:
    print(f"✅ Config found: {CONFIG_PATH}")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔧 Using device: {device}")

# Set up model for image prediction
sam2_model = build_sam2(CONFIG_PATH, CHECKPOINT_PATH, device=device)
image_predictor = SAM2ImagePredictor(sam2_model)

# Set up automatic mask generator (shares the same model instance as the predictor)
mask_generator = SAM2AutomaticMaskGenerator(sam2_model)

print(f"✅ SAM2 model ({MODEL_NAME}) loaded successfully!")
print(f"📊 Model parameters: {sum(p.numel() for p in sam2_model.parameters()) / 1e6:.1f}M")


In [None]:
def load_image_from_url(url):
    """Load an image from a URL as an RGB NumPy array.

    Args:
        url: HTTP(S) address of the image.

    Returns:
        ``np.ndarray`` built from the image after conversion to RGB, so RGBA,
        grayscale and palette images all yield three channels.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    response = requests.get(url, timeout=30)
    # Surface HTTP errors here instead of failing later while decoding an error page
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as image:
        return np.array(image.convert("RGB"))

def load_image_from_path(path):
    """Load an image from a local path as an RGB NumPy array.

    Args:
        path: Filesystem path of the image.

    Returns:
        ``np.ndarray`` built from the image after conversion to RGB, so RGBA,
        grayscale and palette images all yield three channels.
    """
    # The context manager closes the underlying file as soon as pixels are copied
    with Image.open(path) as image:
        return np.array(image.convert("RGB"))

def show_mask(mask, ax, random_color=False, alpha=0.5):
    """Overlay ``mask`` on ``ax`` as a semi-transparent colored layer.

    The overlay color is a fixed blue unless ``random_color`` is set, in which
    case an RGB triple is drawn from ``np.random``. ``alpha`` is used as the
    fourth (opacity) channel. Only the last two dimensions of ``mask`` are
    used as height and width.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, [alpha]])
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an (H, W, 4) RGBA layer
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on ``ax`` as star markers.

    Points whose label equals 1 are drawn in green first, then points whose
    label equals 0 are drawn in red. Column 0 of ``coords`` is used as x and
    column 1 as y.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, point_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, **star_style)

def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled green rectangle.

    ``box`` is indexed as ``[x0, y0, x1, y1]``: the first two entries are the
    rectangle's anchor corner and the last two the opposite corner.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

def show_anns(anns, ax):
    """Render automatic-mask annotations on ``ax`` as one RGBA overlay.

    Each annotation is a mapping with an ``'area'`` value and a
    ``'segmentation'`` array used as an index mask. Annotations are painted
    from largest to smallest area, each in a random color with 0.5 opacity,
    on a canvas that starts fully transparent. Does nothing for empty input.
    """
    if not len(anns):
        return

    by_area = list(anns)
    by_area.sort(key=lambda ann: ann['area'], reverse=True)

    first_mask = by_area[0]['segmentation']
    canvas = np.ones((first_mask.shape[0], first_mask.shape[1], 4))
    canvas[..., 3] = 0  # alpha channel: start transparent
    for ann in by_area:
        # Later (smaller) masks overwrite earlier (larger) ones where they overlap
        canvas[ann['segmentation']] = np.concatenate([np.random.random(3), [0.5]])
    ax.imshow(canvas)

def download_sample_image(url, filename):
    """Download a sample image for testing, unless it is already on disk.

    Args:
        url: HTTP(S) address of the image.
        filename: Local path to write the image to.

    Returns:
        None.

    Raises:
        requests.HTTPError: If the server answers with an error status. Nothing is
            written in that case, so a later call retries the download.
        requests.RequestException: On connection problems or timeouts.
    """
    if os.path.exists(filename):
        print(f"✅ Sample image already exists: {filename}")
        return

    print(f"📥 Downloading sample image: {filename}")
    response = requests.get(url, timeout=30)
    # Check the status BEFORE opening the file: otherwise an error page would be
    # saved under the image name and treated as a cached image from then on.
    response.raise_for_status()
    with open(filename, 'wb') as f:
        f.write(response.content)
    print(f"✅ Sample image downloaded: {filename}")

print("✅ Helper functions defined!")


In [None]:
# Download sample images for demonstration
SAMPLE_IMAGES = {
    'truck.jpg': 'https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg',
    'groceries.jpg': 'https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/groceries.jpg',
}

# Create images directory
os.makedirs('images', exist_ok=True)

# Download sample images
for filename, url in SAMPLE_IMAGES.items():
    download_sample_image(url, f'images/{filename}')

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n
print("\n📂 Available sample images:")
# Sorted so the listing is stable across runs and platforms
for f in sorted(os.listdir('images')):
    if f.lower().endswith(('.jpg', '.jpeg', '.png')):
        print(f"  - {f}")


In [None]:
# Point-prompt segmentation demo on the truck image
image_path = 'images/truck.jpg'
image = load_image_from_path(image_path)

# One click given as an (x, y) pixel coordinate, labelled as positive
input_points = np.array([[500, 375]])  # Click on truck
input_labels = np.array([1])  # 1 = positive click, 0 = negative click

# Register the image with the predictor, then request several candidate masks
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    image_predictor.set_image(image)
    masks, scores, logits = image_predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

# Panel 0 shows the plain image; each following panel overlays one candidate mask
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(image)
axes[0].set_title('Original Image')
axes[0].axis('off')

for i, (mask, score) in enumerate(zip(masks, scores)):
    ax = axes[i + 1]
    ax.imshow(image)
    show_mask(mask, ax)
    show_points(input_points, input_labels, ax)
    ax.set_title(f'Mask {i+1} (Score: {score:.3f})')
    ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated {len(masks)} masks with scores: {scores}")
print(f"🎯 Best mask index: {np.argmax(scores)}")


In [None]:
# Box-prompt segmentation demo on the groceries image
image_path = 'images/groceries.jpg'
image = load_image_from_path(image_path)

# Prompt box in [x1, y1, x2, y2] pixel coordinates
input_box = np.array([425, 600, 700, 875])  # Bounding box around an object

# Register the image with the predictor, then request a single mask for the box
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    image_predictor.set_image(image)
    masks, scores, logits = image_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],  # add a leading axis: (4,) -> (1, 4)
        multimask_output=False,
    )

# Left panel: image + prompt box. Right panel: the same plus the predicted mask.
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
panel_titles = (
    'Original Image with Bounding Box',
    f'Segmentation Result (Score: {scores[0]:.3f})',
)
for idx, (ax, title) in enumerate(zip(axes, panel_titles)):
    ax.imshow(image)
    if idx == 1:
        show_mask(masks[0], ax)
    show_box(input_box, ax)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated mask with score: {scores[0]:.3f}")


In [None]:
# Combined point + box prompt demo on the truck image
image_path = 'images/truck.jpg'
image = load_image_from_path(image_path)

# Prompts: two positive clicks plus a bounding box
input_points = np.array([[500, 375], [600, 400]])  # Multiple positive clicks
input_labels = np.array([1, 1])  # Both positive
input_box = np.array([425, 300, 700, 500])  # Bounding box

# Register the image with the predictor, then request a single mask for all prompts
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    image_predictor.set_image(image)
    masks, scores, logits = image_predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=input_box[None, :],  # add a leading axis: (4,) -> (1, 4)
        multimask_output=False,
    )

# Left panel: image + prompts. Right panel: the same plus the predicted mask.
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
panel_titles = (
    'Original Image with Combined Prompts',
    f'Combined Prompts Result (Score: {scores[0]:.3f})',
)
for idx, (ax, title) in enumerate(zip(axes, panel_titles)):
    ax.imshow(image)
    if idx == 1:
        show_mask(masks[0], ax)
    show_points(input_points, input_labels, ax)
    show_box(input_box, ax)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated mask with combined prompts, score: {scores[0]:.3f}")


In [None]:
# Prompt-free segmentation demo: let the mask generator propose masks on its own
image_path = 'images/groceries.jpg'
image = load_image_from_path(image_path)

print("🔄 Generating automatic masks... This may take a moment.")

# `masks` is a sequence of annotation records (indexed below by 'area')
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
    masks = mask_generator.generate(image)

# Side-by-side view: plain image on the left, colored mask overlay on the right
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
original_ax, overlay_ax = axes

original_ax.imshow(image)
original_ax.set_title('Original Image')
original_ax.axis('off')

overlay_ax.imshow(image)
show_anns(masks, overlay_ax)
overlay_ax.set_title(f'Automatic Segmentation ({len(masks)} masks)')
overlay_ax.axis('off')

plt.tight_layout()
plt.show()

print(f"✅ Generated {len(masks)} automatic masks")
preview_areas = [ann['area'] for ann in masks[:5]]  # Show first 5 areas
print(f"📊 Mask areas: {preview_areas}...")
