<a href="https://colab.research.google.com/github/wyldescience/Cellpose-batch-segmentation-and-counts/blob/main/SAM_crop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Segment Anything Model**

Here I use Meta's [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything#model-checkpoints), which was trained on roughly 11 million images and over 1 billion masks and can segment an object from a single click (see their [demo](https://segment-anything.com/demo), where you can upload your own image and cut out objects in real time). In this instance I used the model to segment out the filter paper that holds the eggs of *Folsomia candida*, in order to count reproductive output for a study of ageing and lifespan. I wanted to crop the small piece of filter paper out of the background because the petri dishes were reused between samples and often had stray eggs that I do not want to pick up in my egg counts during segmentation.
After installing the required packages, the first chunk of this script is for testing and working with single images.
The second chunk is designed to run on batches of images and writes the cropped images to an output folder of your choosing. In a number of cases the first mask returned by the model is not the filter paper I want to retain, so the region of interest ends up cropped out. For those images, use the first chunk and change the integer index in the line `mask = predictor[0]['segmentation']` until the correct crop is produced. The script also moves the processed originals to another folder of your choosing.

In [None]:
!pip install segment_anything
!pip install numpy
!pip install matplotlib
!pip install pillow

This cell runs on an individual image file. It is handy for images that fail in the batch run because the first mask is not the filter paper (background such as the edge of the petri dish or other material is sometimes selected instead, so the filter paper with the eggs gets cropped out). For problematic images, try the second mask: `mask = predictor[1]['segmentation']`. The cell displays the original image and plots the processed one.
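
Rather than trying indices by hand, one option is to rank the masks by size. The sketch below assumes the filter paper is among the largest regions in the image, which may not hold for every photo. `pick_mask_by_area` and its `rank` parameter are hypothetical names, not part of `segment_anything`; the function relies only on the `area` and `segmentation` keys that `SamAutomaticMaskGenerator.generate` returns for each mask.

```python
import numpy as np

def pick_mask_by_area(masks, rank=0):
    """Return the boolean segmentation of the rank-th largest mask (0 = largest).

    `masks` is a list of dicts shaped like the output of
    SamAutomaticMaskGenerator.generate (each has 'segmentation' and 'area').
    """
    if not masks:
        raise ValueError("no masks were generated for this image")
    if rank < 0 or rank >= len(masks):
        raise IndexError(f"rank {rank} is out of range for {len(masks)} masks")
    ordered = sorted(masks, key=lambda m: m["area"], reverse=True)
    return ordered[rank]["segmentation"]

# Synthetic stand-in for real SAM output: one 1-pixel mask and one 9-pixel mask.
small = np.zeros((4, 4), dtype=bool)
small[0, 0] = True
large = np.zeros((4, 4), dtype=bool)
large[1:, 1:] = True
demo_masks = [
    {"segmentation": small, "area": int(small.sum())},
    {"segmentation": large, "area": int(large.sum())},
]

largest = pick_mask_by_area(demo_masks)          # the 9-pixel mask
second = pick_mask_by_area(demo_masks, rank=1)   # the 1-pixel mask
```

With real output this would be called as `mask = pick_mask_by_area(predictor)`, moving to `rank=1` when the largest region turns out to be the dish or the background.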

In [None]:
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from matplotlib import pyplot as plt
from google.colab.patches import cv2_imshow

MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "/content/drive/Othercomputers/ThinkPad/Desktop/opencv/sam_vit_h_4b8939.pth"
DEVICE = "cuda"

# Load the SAM model
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

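# Load the image (cv2.imread returns None if the file cannot be read)
image_path = "/content/drive/Othercomputers/ThinkPad/Desktop/Folsomia candida/Data/egg count images/reprocess SAM/I1_F1_O20_CON_R5_10-09-23.jpg"
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Display the loaded image inline; cv2_imshow is the Colab replacement for
# cv2.imshow, so no waitKey/destroyAllWindows calls are needed
cv2_imshow(image)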

# Create a mask generator using SAM
mask_generator = SamAutomaticMaskGenerator(sam)

# Convert image to RGB format for mask generation
img_arr = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Convert image array to float32
# img_arr = img_arr.astype(np.float32) / 255.0  # Normalize pixel values to [0, 1] range

# Generate masks using the mask generator
predictor = mask_generator.generate(img_arr)

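# Choose a mask by index: 0 is the first mask; 1 is the second, for images
# where the first mask is not the filter paper
mask = predictor[1]['segmentation']

# Remove the background by turning every pixel outside the mask black
img_arr[~mask] = [0, 0, 0]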

# Display the modified image using matplotlib
plt.imshow(img_arr)
plt.axis('off')
plt.show()
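
The cell above blacks out the background but keeps the full frame. If a tighter crop is wanted, one possible follow-up is to trim the image to the bounding box of the mask. This is only a sketch and is not what the batch script does; `crop_to_mask` is a hypothetical helper that uses NumPy alone.

```python
import numpy as np

def crop_to_mask(img, mask):
    """Crop img (H x W x C) to the smallest rectangle containing every True pixel of mask."""
    rows = np.any(mask, axis=1)  # which rows contain any mask pixel
    cols = np.any(mask, axis=0)  # which columns contain any mask pixel
    if not rows.any():
        raise ValueError("mask is empty, nothing to crop to")
    top, bottom = np.where(rows)[0][[0, -1]]
    left, right = np.where(cols)[0][[0, -1]]
    return img[top:bottom + 1, left:right + 1]

# Synthetic 5 x 6 RGB image with a 3 x 3 mask in the middle.
demo_img = np.arange(5 * 6 * 3, dtype=np.uint8).reshape(5, 6, 3)
demo_mask = np.zeros((5, 6), dtype=bool)
demo_mask[1:4, 2:5] = True

cropped = crop_to_mask(demo_img, demo_mask)
```

On the real data this would be called as `crop_to_mask(img_arr, mask)` after the background has been blacked out.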


**Batch script to run on multiple image files and save output to folder**
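
The batch script collects its inputs by file extension. As an optional variation, the sketch below lists images in a fixed sorted order and also accepts `.jpeg` and `.tiff`. Both `list_images` and the `IMAGE_SUFFIXES` set are hypothetical additions; the script itself matches only `.jpg`, `.png`, and `.tif`.

```python
import tempfile
from pathlib import Path

# Assumed set of accepted extensions; adjust to match your own files.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}

def list_images(folder):
    """Return the names of image files in folder, sorted, matching suffixes case-insensitively."""
    return sorted(
        p.name
        for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )

# Demonstration on a throwaway directory with two images and one text file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("b.JPG", "a.tiff", "notes.txt"):
        (Path(tmp) / name).touch()
    found = list_images(tmp)
```

A sorted list makes the order of the `Processing image:` log lines reproducible between runs, which helps when tracking down the images that need a different mask index.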

In [None]:
import os
import cv2
import numpy as np
import shutil
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import torch
from PIL import Image
from torch.profiler import profile, record_function, ProfilerActivity

def clear_cuda_memory():
    torch.cuda.empty_cache()

# Set environment variable to help with memory issues on GPU
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'garbage_collection_threshold:0.6,max_split_size_mb:512'

MODEL_TYPE = "vit_b"
CHECKPOINT_PATH = "/content/drive/Othercomputers/ThinkPad/Desktop/Folsomia candida/Final Scripts/sam_vit_b_01ec64.pth"
DEVICE = "cuda"

# Input and output directories
input_dir = "/content/drive/Othercomputers/ThinkPad/Desktop/Folsomia candida/Data/egg count images/reprocess SAM"
output_dir = "/content/drive/Othercomputers/ThinkPad/Desktop/Folsomia candida/Data/egg count images/cropped"
processed_dir = "/content/drive/Othercomputers/ThinkPad/Desktop/Folsomia candida/Data/egg count images/processed"

# List all image files in the input directory
image_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.jpg', '.png', '.tif'))]

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
os.makedirs(processed_dir, exist_ok=True)

# Process each image file
batch_size = 1  # Adjust the batch size based on your available memory
points_per_batch = 4  # Lowering points_per_batch fixed the GPU out-of-memory issue, but runs a lot slower
for i in range(0, len(image_files), batch_size):
    # Load the SAM model for each batch
    sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
    sam.eval()

    # Clear GPU memory before processing a new batch
    clear_cuda_memory()

    for j in range(i, min(i + batch_size, len(image_files))):
        image_file = image_files[j]
        image_path = os.path.join(input_dir, image_file)
        image = cv2.imread(image_path)

        print(f"Processing image: {image_file}")

        # Load and process the image on CPU to reduce GPU memory usage
        img_arr = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Create a mask generator and generate masks, without tracking gradients
        with torch.no_grad():
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
                with record_function("model_inference"):
                    mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=points_per_batch)
                    predictor = mask_generator.generate(img_arr)
                    mask = predictor[1]['segmentation']  # typically 0 but for erroneous crops try 1

            # Print the memory profile
            print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

        # Process the image and mask on CPU
        img_arr[mask == False] = [0, 0, 0]

        # Convert the image to PIL format
        pil_image = Image.fromarray(img_arr)

        # Set the DPI value to 300
        dpi_value = 300
        pil_image.info['dpi'] = (dpi_value, dpi_value)

        # Save the modified image in the output directory using PIL,
        # keeping the original file name
        output_path = os.path.join(output_dir, image_file)
        pil_image.save(output_path, dpi=(dpi_value, dpi_value))

        # Move the original image to the processed originals directory
        processed_image_path = os.path.join(processed_dir, image_file)
        shutil.move(image_path, processed_image_path)

        # Drop references to the per-image objects so their memory can be reclaimed
        del img_arr, mask, predictor, mask_generator

    # Empty the CUDA cache after processing each batch
    torch.cuda.empty_cache()

print("Processing complete. Images saved in the output directory.")


Processing image: I5_F1_Y25_CON_R4_08-08-23_1.jpg
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        23.95%        1.576s       100.00%        6.581s        6.581s       0.000us         0.00%        4.176s        4.176s             1  
                                       aten::lift_fresh         0.01%     585.000us         0.01%     585.000us       0.210us       0.000us         0.00%    

In [None]:
import gc
import torch

torch.cuda.empty_cache()
gc.collect()

9034570

In [None]:
!nvidia-smi


Sun Jan 14 13:32:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   54C    P0              28W /  70W |   8437MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import gc
import torch

def report_gpu():
    # Print the processes currently using the GPU, then release cached memory
    print(torch.cuda.list_gpu_processes())
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()