Image Segmentation Pipeline using Segment Anything Model (SAM)
This notebook provides an end-to-end pipeline for image segmentation using Meta AI's Segment Anything Model (SAM). It includes functions for loading the model, segmenting images, visualizing results, cropping and collaging segments, and monitoring system resources during processing.

In [1]:
#Import Libraries 
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gc
from tqdm import tqdm
import psutil
import GPUtil
import threading
import time
import os



```markdown
## Resource Monitoring

The next cell contains functions for monitoring system resources such as CPU, memory, and GPU usage. These functions will be used to track resource utilization during the image segmentation process.
```

In [2]:
#Resource Monitoring

# Function to get resource usage
def get_resource_usage():
    # Get CPU usage
    cpu_percent = psutil.cpu_percent()
    # Get memory usage
    memory = psutil.virtual_memory()
    memory_percent = memory.percent
    # Get GPU usage
    gpus = GPUtil.getGPUs()
    if gpus:
        gpu = gpus[0]
        gpu_percent = gpu.load * 100
        gpu_memory_used = gpu.memoryUsed
        gpu_memory_total = gpu.memoryTotal
        gpu_memory_percent = (gpu_memory_used / gpu_memory_total) * 100
    else:
        gpu_percent = 0
        gpu_memory_percent = 0
    return f"CPU:{cpu_percent:.1f}%, Mem:{memory_percent:.1f}%, GPU:{gpu_percent:.1f}%, GPU Mem:{gpu_memory_percent:.1f}%"

# Resource monitor function
def resource_monitor(pbar, stop_event, pbar_lock):
    while not stop_event.is_set():
        resource_usage = get_resource_usage()
        with pbar_lock:
            pbar.set_postfix_str(resource_usage)
        time.sleep(1)
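
The monitor above is meant to run in a background thread that is stopped through a `threading.Event`. A minimal, standard-library-only sketch of that start/stop pattern follows. The names `run_monitor`, `monitored`, `sample_fn`, and `work_fn` are hypothetical stand-ins (for example, `sample_fn` plays the role of `get_resource_usage`), and the sketch uses `stop_event.wait(interval)` instead of `time.sleep` on the assumption that a prompt shutdown is preferred.

```python
import threading
import time


def run_monitor(sample_fn, sink, stop_event, interval=0.05):
    """Call sample_fn repeatedly and append each result to sink until stopped."""
    while not stop_event.is_set():
        sink.append(sample_fn())
        # wait() returns early once stop_event is set, unlike time.sleep()
        stop_event.wait(interval)


def monitored(work_fn, sample_fn, interval=0.05):
    """Run work_fn while a background thread collects samples; return both."""
    samples = []
    stop_event = threading.Event()
    thread = threading.Thread(
        target=run_monitor, args=(sample_fn, samples, stop_event, interval)
    )
    thread.start()
    try:
        result = work_fn()
    finally:
        # Always stop and join the monitor, even if work_fn raises
        stop_event.set()
        thread.join()
    return result, samples


if __name__ == "__main__":
    # Hypothetical stand-ins: a short fake workload and a fixed usage string
    result, samples = monitored(
        work_fn=lambda: (time.sleep(0.2), "done")[1],
        sample_fn=lambda: "CPU:0.0%",
    )
    print(result, len(samples))
```

Wrapping the workload in `try`/`finally` mirrors the pipeline's own structure: the monitor thread is joined whether or not processing succeeds.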


```markdown
## Initialize Segment Anything Model (SAM)

In the next cell, we initialize the Segment Anything Model (SAM) using a pre-trained checkpoint. This model is used for automatic mask generation in image segmentation tasks. Below are the parameters used in the initialization function:

- **sam_checkpoint**: Path to the pre-trained SAM model checkpoint file.
- **model_type**: Type of the model architecture. In this case, it is "vit_h".
- **device**: The device on which the model will run. It uses "cuda" if a GPU is available, otherwise it falls back to "cpu".
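- **points_per_side**: Number of points sampled along each side of the image. The generator prompts SAM with a regular grid of points_per_side × points_per_side points.
- **pred_iou_thresh**: Threshold in [0, 1] on the model's own predicted Intersection over Union (IoU) score, which estimates mask quality. Masks scoring below this threshold are discarded.
- **stability_score_thresh**: Threshold in [0, 1] on the stability score, which measures how little a mask changes when the cutoff used to binarize the model's output is shifted. Masks scoring below this threshold are discarded.
- **crop_n_layers**: Number of additional crop layers. If greater than 0, mask prediction is run again on crops of the image; 0 (used here) processes only the full image.
- **crop_n_points_downscale_factor**: In crop layer n, the number of points sampled per side is divided by this factor raised to the power n. It has no effect while crop_n_layers is 0.
- **min_mask_region_area**: If greater than 0, postprocessing removes disconnected regions and holes smaller than this area (in pixels) from each mask. This step requires OpenCV.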
```

In [3]:
# Load the Segment Anything Model
def initialize_sam():
    sam_checkpoint = "c:\\Users\\Riley\\Desktop\\sam_vit_h_4b8939.pth"  # Update this path as needed
    model_type = "vit_h"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return SamAutomaticMaskGenerator(
        sam,
        points_per_side=8,  # Sample an 8 x 8 grid of prompt points over the image
        pred_iou_thresh=0.90,  # Discard masks whose predicted IoU (quality) score is below this
        stability_score_thresh=0.95,  # Discard masks whose stability score is below this
        crop_n_layers=0,  # 0 = run on the full image only, with no extra crop layers
        crop_n_points_downscale_factor=2,  # Per-layer reduction of points per side (unused when crop_n_layers=0)
        min_mask_region_area=5500,  # Remove disconnected regions and holes smaller than this many pixels
    )
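
To make the sampling parameters concrete, the sketch below computes how many prompt points and crops each layer would use. It reflects an assumption about how the `segment_anything` automatic mask generator lays out its work: a regular grid of normalized points centered in equal cells, points per side divided by `crop_n_points_downscale_factor**layer`, and `2**layer` crops per side in crop layer `layer`. The helper names `point_grid` and `layer_plan` are hypothetical, and the library itself remains the authoritative source.

```python
def point_grid(points_per_side):
    """Return normalized (x, y) points on a regular grid, centered in each cell."""
    offset = 1.0 / (2 * points_per_side)
    step = 1.0 / points_per_side
    coords = [offset + i * step for i in range(points_per_side)]
    return [(x, y) for y in coords for x in coords]


def layer_plan(points_per_side, crop_n_layers, downscale_factor):
    """Return (layer, crops, points_per_crop) tuples; layer 0 is the full image."""
    plan = []
    for layer in range(crop_n_layers + 1):
        side = int(points_per_side / (downscale_factor ** layer))
        crops = (2 ** layer) ** 2
        plan.append((layer, crops, side * side))
    return plan


if __name__ == "__main__":
    grid = point_grid(8)
    print(len(grid), grid[0])    # 64 (0.0625, 0.0625)
    print(layer_plan(8, 0, 2))   # [(0, 1, 64)]
    print(layer_plan(8, 1, 2))   # [(0, 1, 64), (1, 4, 16)]
```

Under these assumptions, the settings in this notebook (`points_per_side=8`, `crop_n_layers=0`) prompt SAM with 64 points per image, which keeps runtime low at the cost of possibly missing small objects.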


```markdown
## Output Full Segmented Image

The next cell contains a function to output the full segmented image. This function takes an image and its corresponding masks, combines all the masks into a single mask, applies this mask to the image, and saves the segmented image to the specified output path.
```

In [5]:
# Output the full segmented image
def output_full_segmented_image(image, masks, output_base_path):
    # Create an empty mask with the same dimensions as the image
    full_mask = np.zeros(image.shape[:2], dtype=np.uint8)

    # Combine all masks into the full mask
    for mask in masks:
        full_mask[mask['segmentation']] = 255

    # Apply the mask to the image
    segmented_image = cv2.bitwise_and(image, image, mask=full_mask)

    # Save the segmented image (cv2.imwrite expects BGR channel order)
    base_name, ext = os.path.splitext(output_base_path)
    output_path = f"{base_name}_segmented{ext}"
    cv2.imwrite(output_path, segmented_image)
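
The masking step can also be reproduced with NumPy alone, which makes its behaviour easy to check on a toy input. The sketch below is an illustrative equivalent under stated assumptions, not the notebook's function: `apply_union_mask` is a hypothetical name, and the mask records only mimic the boolean `'segmentation'` array used above.

```python
import numpy as np


def apply_union_mask(image, masks):
    """Keep pixels covered by at least one mask; set all others to zero."""
    union = np.zeros(image.shape[:2], dtype=bool)
    for mask in masks:
        union |= mask["segmentation"]
    # Broadcast the 2-D union over the channel axis
    return np.where(union[..., None], image, 0).astype(image.dtype)


if __name__ == "__main__":
    # Toy 2 x 3 image with three channels, every pixel set to 200
    image = np.full((2, 3, 3), 200, dtype=np.uint8)
    first = np.zeros((2, 3), dtype=bool)
    first[0, 0] = True
    second = np.zeros((2, 3), dtype=bool)
    second[1, 2] = True
    out = apply_union_mask(
        image, [{"segmentation": first}, {"segmentation": second}]
    )
    print(out[..., 0])
```

Only the two pixels covered by a mask keep their value; everything else becomes black, which is the same effect `cv2.bitwise_and` has when given the combined mask.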


```markdown
## Main Pipeline for Processing a Folder of Images

The next cell contains the main pipeline function that processes a folder of images. It initializes the Segment Anything Model (SAM) once, generates segmentation masks for each image, saves a visualization with the mask outlines, and calls a crop-and-collage step for the largest masks, which is currently a placeholder. A background thread reports CPU, memory, and GPU usage in the progress bar during processing.
```

In [6]:
# Main pipeline processing a folder of images
def main_pipeline(input_folder, output_folder):
    # Get list of image files in the input folder
    image_extensions = ('.jpg', '.jpeg', '.png')
    image_files = [f for f in os.listdir(input_folder) if f.lower().endswith(image_extensions)]
    num_images = len(image_files)

    if num_images == 0:
        print(f"No images found in {input_folder}.")
        return

    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    pbar_lock = threading.Lock()
    with tqdm(total=num_images, desc='Processing Images', unit='image') as pbar:
        # Start resource monitor thread
        stop_event = threading.Event()
        monitor_thread = threading.Thread(target=resource_monitor, args=(pbar, stop_event, pbar_lock))
        monitor_thread.start()

        try:
            # Initialize SAM once
            with pbar_lock:
                pbar.set_description('Initializing SAM')
            mask_generator = initialize_sam()

            for image_file in image_files:
                image_path = os.path.join(input_folder, image_file)
                output_path = os.path.join(output_folder, image_file)

                with pbar_lock:
                    pbar.set_description(f'Processing {image_file}')
                    # print(f"Processing {image_file}")

                try:
                    masks, image = generate_segmentation(image_path, mask_generator)
                    visualize_and_save_segmentation(image, masks, output_path)
                    crop_and_collage_largest_masks(image, masks, output_path)
                except Exception as e:
                    print(f"Error processing {image_file}: {e}")
                finally:
                    # Clean up to free memory
                    torch.cuda.empty_cache()
                    gc.collect()

                with pbar_lock:
                    pbar.update(1)

            # Clean up SAM model after processing
            del mask_generator
            torch.cuda.empty_cache()
            gc.collect()
        finally:
            # Stop the resource monitor thread
            stop_event.set()
            monitor_thread.join()

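# Helper functions used by main_pipeline
def generate_segmentation(image_path, mask_generator):
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    # OpenCV loads images as BGR; SAM and matplotlib expect RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)
    return masks, image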

def visualize_and_save_segmentation(image, masks, output_path):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        # The mask is boolean, so a single contour at 0.5 traces its boundary
        plt.contour(mask['segmentation'], levels=[0.5], colors='r')
    plt.axis('off')
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()

def crop_and_collage_largest_masks(image, masks, output_path):
    # Placeholder function for cropping and collaging masks
    pass
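
The placeholder leaves the cropping logic open. As one possible starting point, the sketch below picks the largest masks and turns their bounding boxes into array slices. It assumes each mask record carries an `'area'` value and a `'bbox'` in `[x, y, width, height]` order, which is how `SamAutomaticMaskGenerator` documents its output. The names `largest_masks`, `bbox_to_slices`, and `top_k` are hypothetical.

```python
def largest_masks(masks, top_k=3):
    """Return the top_k mask records with the largest 'area', biggest first."""
    return sorted(masks, key=lambda m: m["area"], reverse=True)[:top_k]


def bbox_to_slices(bbox):
    """Convert an XYWH box to (row_slice, col_slice) for array indexing."""
    x, y, w, h = (int(round(v)) for v in bbox)
    return slice(y, y + h), slice(x, x + w)


if __name__ == "__main__":
    # Hypothetical records containing only the keys this sketch needs
    masks = [
        {"area": 6000, "bbox": [10, 20, 100, 60]},
        {"area": 9000, "bbox": [0, 0, 50, 180]},
        {"area": 7000, "bbox": [5, 5, 70, 100]},
    ]
    for m in largest_masks(masks, top_k=2):
        rows, cols = bbox_to_slices(m["bbox"])
        print(m["area"], rows, cols)
```

A crop would then be `image[rows, cols]`. How the crops are arranged into a collage is left undecided here, as it is in the placeholder.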


In [7]:
# Initialize SAM and the mask generator
mask_generator = initialize_sam()




In [8]:
# Clean up to free memory
torch.cuda.empty_cache()
gc.collect()


60

In [9]:
#Optional Run with Resource Monitoring
if __name__ == "__main__":
    # Version checking
    if torch.cuda.is_available():
        print("Cuda Version:", torch.version.cuda)
        print("GPU Used:", torch.cuda.get_device_name(0))
        print("Current GPU Code Used:", torch.cuda.current_device())
        print("Number of GPUs installed:", torch.cuda.device_count())
    else:
        print("No GPU available")

    print("Starting...")
    input_folder = "c:\\Users\\Riley\\Desktop\\TestSet"  # Update this path as needed
    output_folder = "C:\\Users\\Riley\\Desktop\\SEGTESTINGFOLER3"  # Update this path as needed


    #Toggle if you want to test on one image or on a folder of images
    # image_path = "path_to_your_image.jpg"  # Update this path
    # output_path = "path_to_output_image.jpg"
    main_pipeline(input_folder, output_folder)


Cuda Version: 11.8
GPU Used: NVIDIA GeForce RTX 2060 SUPER
Current GPU Code Used: 0
Number of GPUs installed: 1
Starting...


Processing 0019_B%2010%200279821.jpg: 100%|██████████| 10/10 [04:05<00:00, 24.51s/image, CPU:19.8%, Mem:52.4%, GPU:9.0%, GPU Mem:66.8%] 
