# Advanced Garment Extraction Pipeline using Segment Anything Model (SAM) and Segment Anything Model 2 (SAM 2)

## 1. Introduction

This notebook implements an interactive, point-prompted segmentation pipeline for extracting garment regions from images. It supports both the Segment Anything Model (SAM) and its successor, SAM2, developed by Meta's FAIR team (formerly Facebook AI Research). A Gradio interface lets you choose a model and click on a garment to generate its mask.


## Table of Contents

2. [Setup and Installation](#setup-dependencies)
3. [Import Required Libraries](#import-libraries)
4. [Model Loading Functions](#model-loading)
5. [Core Functionality](#core-functionality)
6. [Image Processing Pipeline](#image-processing-pipeline)
7. [Gradio Interface](#gradio-interface)
8. [Launch the Application](#execution)
9. Conclusion





In [1]:
# @markdown ## 2. Setup and Installation <a name="setup-dependencies"></a>
# @markdown - First, we install the necessary dependencies and download the pre-trained models.

!pip install -q gradio opencv-python matplotlib

# @markdown **Install SAM and SAM2**
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install -q 'git+https://github.com/facebookresearch/segment-anything-2.git'

# @markdown **Download model weights**
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt


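Because `wget -q` suppresses wget's own messages, a failed or truncated download is easy to miss until the model loader raises an error. The sketch below is a hypothetical helper, `missing_checkpoints`, that checks the two files this cell downloads. The filenames come from the `wget` commands above; the one-byte size threshold is an assumption you may want to raise.

```python
import os

# Checkpoint filenames fetched by the wget commands in the setup cell.
EXPECTED_CHECKPOINTS = ["sam_vit_h_4b8939.pth", "sam2_hiera_large.pt"]


def missing_checkpoints(paths, min_bytes=1):
    """Return the paths that are absent or smaller than `min_bytes` (hypothetical helper)."""
    missing = []
    for path in paths:
        # A file that exists but is empty usually means an interrupted download.
        if not os.path.isfile(path) or os.path.getsize(path) < min_bytes:
            missing.append(path)
    return missing


problems = missing_checkpoints(EXPECTED_CHECKPOINTS)
if problems:
    print("Missing or empty checkpoints:", problems)
```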
In [2]:
# @markdown ## 3. Import Required Libraries <a name="import-libraries"></a>

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr

# @markdown - SAM imports
from segment_anything import sam_model_registry, SamPredictor

# @markdown - SAM2 imports
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor



In [3]:
# @markdown ## 4. Model Loading Functions <a name="model-loading"></a>
# @markdown - We define functions to load both SAM and SAM2 models, allowing for flexibility in model selection.

def load_sam_model(checkpoint="sam_vit_h_4b8939.pth", model_type="vit_h", device="cuda"):
    """
    Load the Segment Anything Model (SAM).

    Args:
    checkpoint (str): Path to the SAM checkpoint file.
    model_type (str): Type of the SAM model (e.g., "vit_h", "vit_l", "vit_b").
    device (str): Device to load the model on ("cuda" or "cpu").

    Returns:
    SamPredictor: Loaded SAM predictor object.
    """
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    return SamPredictor(sam)

def load_sam2_model(model_cfg="sam2_hiera_l.yaml", sam2_checkpoint="sam2_hiera_large.pt", device="cuda"):
    """
    Load the Segment Anything Model 2 (SAM2).

    Args:
    model_cfg (str): Path to the SAM2 model configuration file.
    sam2_checkpoint (str): Path to the SAM2 checkpoint file.
    device (str): Device to load the model on ("cuda" or "cpu").

    Returns:
    SAM2ImagePredictor: Loaded SAM2 predictor object.
    """
    sam2_model = build_sam2(config_file=model_cfg, ckpt_path=sam2_checkpoint, device=device)
    return SAM2ImagePredictor(sam2_model)

# @markdown **Initialize the SAM & SAM2 models**
sam_predictor = load_sam_model()
sam2_predictor = load_sam2_model()
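`build_sam2` needs a config that matches the checkpoint, and `sam_model_registry` needs a `model_type` that matches the SAM checkpoint; a mismatch fails when the weights are loaded. The mapping below is a sketch based on the filenames listed in the SAM README and the original (July 2024) SAM 2 README. `SAM_CHECKPOINTS`, `SAM2_CONFIGS`, and `sam2_config_for` are hypothetical names, and the later SAM 2.1 checkpoints use different config paths that are not covered here.

```python
import os

# model_type key for sam_model_registry -> official SAM checkpoint filename
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

# SAM 2 (July 2024 release) checkpoint filename -> config name for build_sam2
SAM2_CONFIGS = {
    "sam2_hiera_tiny.pt": "sam2_hiera_t.yaml",
    "sam2_hiera_small.pt": "sam2_hiera_s.yaml",
    "sam2_hiera_base_plus.pt": "sam2_hiera_b+.yaml",
    "sam2_hiera_large.pt": "sam2_hiera_l.yaml",
}


def sam2_config_for(checkpoint_path):
    """Return the config name that matches a SAM 2 checkpoint (hypothetical helper)."""
    name = os.path.basename(checkpoint_path)
    try:
        return SAM2_CONFIGS[name]
    except KeyError:
        raise ValueError(f"No known SAM 2 config for checkpoint {name!r}") from None
```

For example, after downloading the small checkpoint, `load_sam2_model(model_cfg=sam2_config_for("sam2_hiera_small.pt"), sam2_checkpoint="sam2_hiera_small.pt")` keeps the two arguments in sync.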

In [4]:
# @markdown ## 5. Core Functionality <a name="core-functionality"></a>
# @markdown - Generate a mask for the selected region in the image.
# @markdown - Apply the generated mask to the input image.

def generate_mask(image, point_coords):
    """
    Generate a mask for the selected region in the image.

    Args:
    image (numpy.ndarray): Input image.
    point_coords (list): List of coordinates [x, y] selected by the user.

    Returns:
    numpy.ndarray: Generated binary mask.
    """
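    # active_predictor is the module-level predictor selected by toggle_model();
    # it is only read here, so no `global` declaration is needed.
    # set_image() runs the image encoder (the expensive step) on every click.
    active_predictor.set_image(image)
    input_point = np.array([point_coords])  # shape (1, 2), in (x, y) order
    input_label = np.array([1])  # 1 indicates a foreground point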

    masks, _, _ = active_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    return masks[0]

def mask_image(image, mask):
    """
    Apply the generated mask to the input image.

    Args:
    image (numpy.ndarray): Input image.
    mask (numpy.ndarray): Binary mask.

    Returns:
    numpy.ndarray: Masked image.
    """
    masked_image = image.copy()
    masked_image[mask == 0] = [0, 0, 0]  # Set background to black
    return masked_image.astype(np.uint8)
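`generate_mask` sends a single foreground click with `multimask_output=False`. Both predictors' `predict()` methods also accept several points, labelled 1 for foreground and 0 for background, and with `multimask_output=True` they return several candidate masks together with a predicted-IoU score for each. The helpers below, `build_point_prompts` and `select_best_mask`, are hypothetical names sketching how those arrays could be assembled and how the highest-scoring candidate could be chosen; they use NumPy only.

```python
import numpy as np


def build_point_prompts(foreground, background=()):
    """Pack clicked points into point_coords / point_labels arrays (hypothetical helper).

    point_coords: float32 array of shape (N, 2) in (x, y) pixel order.
    point_labels: int32 array of shape (N,), 1 = foreground, 0 = background.
    """
    fg = [tuple(p) for p in foreground]
    bg = [tuple(p) for p in background]
    coords = np.asarray(fg + bg, dtype=np.float32).reshape(-1, 2)
    labels = np.array([1] * len(fg) + [0] * len(bg), dtype=np.int32)
    return coords, labels


def select_best_mask(masks, scores):
    """Return the candidate mask with the highest score, plus that score.

    masks: array of shape (K, H, W); scores: array of shape (K,).
    """
    best = int(np.argmax(scores))
    return masks[best].astype(bool), float(scores[best])
```

Inside `generate_mask`, the returned `coords` and `labels` would replace `input_point` and `input_label`, and with `multimask_output=True` the call `select_best_mask(masks, scores)` would replace `masks[0]`.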

In [5]:
# @markdown ## 6. Image Processing Pipeline <a name="image-processing-pipeline"></a>
# @markdown - Process the input image and generate a masked output.
# @markdown - Ensure image is in the correct format (H, W, C) and uint8.
# @markdown - Convert mask to RGB for display.

def process_image(image, evt: gr.SelectData):
    """
    Process the input image and generate a masked output.

    Args:
    image (numpy.ndarray): Input image.
    evt (gr.SelectData): Event data containing selected coordinates.

    Returns:
    tuple: Tuple containing the original image, masked image, and mask.
    """
    if image is None:
        return None, None, None

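    # Ensure image is in the correct format (H, W, 3) and uint8
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)  # grayscale -> RGB
    elif image.shape[2] == 4:
        image = image[:, :, :3]  # drop the alpha channel

    if np.issubdtype(image.dtype, np.floating):
        # Assume float images (float32 or float64) are scaled to [0, 1]
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        image = image.astype(np.uint8)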

    point_coords = [evt.index[0], evt.index[1]]
    mask = generate_mask(image, point_coords)
    masked_image = mask_image(image, mask)

    # Convert mask to RGB for display
    mask_rgb = np.stack([mask] * 3, axis=-1).astype(np.uint8) * 255

    return image, masked_image, mask_rgb
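`mask_image` paints the background black, which works for display but makes dark garment pixels indistinguishable from the background. For extraction, a transparent-background cutout is often more useful. The hypothetical `garment_cutout` helper below stacks the mask as an alpha channel and optionally crops to the mask's bounding box; it assumes an (H, W, 3) uint8 RGB image and an (H, W) mask, as produced inside `process_image`.

```python
import numpy as np


def garment_cutout(image, mask, crop=True):
    """Return an RGBA cutout: garment pixels opaque, background transparent.

    image: (H, W, 3) uint8 RGB array; mask: (H, W) bool or 0/1 array.
    If `crop` is True and the mask is non-empty, trim to the mask's bounding box.
    """
    mask = np.asarray(mask).astype(bool)
    alpha = mask.astype(np.uint8) * 255  # 255 = opaque, 0 = transparent
    rgba = np.dstack([image.astype(np.uint8), alpha])  # (H, W, 4)
    if crop and mask.any():
        rows = np.flatnonzero(mask.any(axis=1))  # rows containing garment pixels
        cols = np.flatnonzero(mask.any(axis=0))  # columns containing garment pixels
        rgba = rgba[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
    return rgba
```

With the already-imported PIL, `Image.fromarray(garment_cutout(image, mask)).save("garment.png")` would write the result as an RGBA PNG that preserves the transparency.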

In [6]:
# @markdown ## 7. Gradio Interface <a name="gradio-interface"></a>
# @markdown - Create a Gradio interface for user-friendly mask generation.

active_predictor = sam_predictor
def toggle_model(choice):
    global active_predictor
    active_predictor = sam_predictor if choice == "SAM" else sam2_predictor
    return f"Active Model: {choice}"

with gr.Blocks() as demo:
    gr.Markdown("# Advanced Garment Extraction using SAM and SAM2")
    gr.Markdown("Upload an image, select a model, and click on a region to generate a mask.")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type="numpy")
        masked_output = gr.Image(label="Masked Output")
        mask_output = gr.Image(label="Generated Mask")

    model_choice = gr.Radio(["SAM", "SAM2"], label="Select Model", value="SAM")
    model_status = gr.Textbox(label="Active Model", value="Active Model: SAM")

    input_image.select(process_image, inputs=[input_image], outputs=[input_image, masked_output, mask_output])
    model_choice.change(toggle_model, inputs=[model_choice], outputs=[model_status])
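Each click calls `set_image()` inside `generate_mask`, so the image encoder reruns even when only the click position changes. The hypothetical `CachedPredictor` wrapper below sketches one way to avoid that: it fingerprints the image bytes and forwards `set_image()` only when the image differs from the last one it saw. It assumes the wrapped object exposes `set_image()` and `predict()`, which are the only two methods `generate_mask` uses.

```python
import hashlib

import numpy as np


class CachedPredictor:
    """Hypothetical wrapper that skips set_image() when the same image is reused."""

    def __init__(self, predictor):
        self.predictor = predictor
        self._last_key = None  # fingerprint of the image currently embedded

    @staticmethod
    def _fingerprint(image):
        image = np.ascontiguousarray(image)
        digest = hashlib.sha1(image.tobytes()).hexdigest()
        # Include shape and dtype so equal bytes with different layouts differ.
        return image.shape, image.dtype.str, digest

    def set_image(self, image):
        key = self._fingerprint(image)
        if key != self._last_key:
            # Only run the expensive encoder when the image actually changed.
            self.predictor.set_image(image)
            self._last_key = key

    def predict(self, **kwargs):
        # Prompts are cheap to decode, so predict() is always forwarded.
        return self.predictor.predict(**kwargs)
```

To use it, wrap each model once, for example `cached = {"SAM": CachedPredictor(sam_predictor), "SAM2": CachedPredictor(sam2_predictor)}`, and have both the initial `active_predictor` and `toggle_model` pick from that dict. Keeping one wrapper per model means an embedding computed by SAM is never reused for SAM2.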

In [7]:
# @markdown ## 8. Launch the Application <a name="execution"></a>
# @markdown - Run the Gradio application to start the mask generation process.

# Launch the Gradio application
demo.launch(debug=False)

In [8]:
# @markdown ## 9. Conclusion
# @markdown - This garment extraction pipeline puts SAM and SAM2 behind a single click-to-segment workflow. The Gradio interface lets you switch between the two models and generate a mask from a single foreground click, which makes it a convenient starting point for comparing them on garment images in research or fashion-technology applications.