## Image Segmentation and Clustering

This notebook demonstrates a pipeline for segmenting an image with the Segment Anything Model (SAM), filtering the resulting masks, and clustering them by color. It uses a pre-trained SAM model to generate segmentation masks, a custom filtering algorithm to resolve overlapping masks, and K-Means clustering to group the masks by their average RGB color. It also calls the OpenAI API (GPT-4) to choose the number of clusters from the elbow-method distortion curve.

### Prerequisites

Before running this notebook, ensure you have installed all the necessary dependencies. You can do this by running the following command in your terminal:

```
pip install -r requirements.txt
```

In [None]:
# Import necessary libraries
import cv2
import torch

# Check whether a CUDA-enabled GPU is available and, if so, print its name
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
 print("Device name:", torch.cuda.get_device_name(0))

# Set the computation device to CUDA if available, otherwise use the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load the SAM Model
This cell initializes the Segment Anything Model (SAM) from a specified checkpoint and model type. The model is then moved to the selected computation device (GPU or CPU).

In [None]:
# Import the Segment Anything Model (SAM) components
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Define the path to the model checkpoint and the model type
checkpoint = "/checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Register and initialize the SAM model with the specified checkpoint
sam = sam_model_registry[model_type](checkpoint=checkpoint)
# Move the model to the selected device (GPU/CPU)
sam.to(device)

# Create an automatic mask generator from the SAM model
mask_generator = SamAutomaticMaskGenerator(sam)

### Configure File Paths
This cell sets up the file paths for the input image and defines a naming convention for the output files based on the model type.

In [None]:
import os

# Define the path to the input image
image_path = "/image/test.png"
# Extract the base name of the image file without the extension
image_name = os.path.splitext(os.path.basename(image_path))[0]

# Create a mapping for model types to filename suffixes
suffix_map = {
 "vit_h": "_h",
 "vit_l": "_l",
 "vit_b": "_b"
}
# Get the appropriate suffix for the current model type
model_suffix = suffix_map.get(model_type, "")

# Define the filenames for the output pickle and segmented image files
pkl_filename = f"data_{image_name}{model_suffix}.pkl"
image_output_filename = f"segmented_{image_name}{model_suffix}"
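
The same suffix convention is rebuilt inside `cluster_and_visualize_masks` later on. A minimal sketch of keeping it in one helper so both places stay in sync; `output_names` and `SUFFIX_MAP` are hypothetical names, not part of the original notebook.

```python
import os

# Hypothetical shared mapping from SAM model type to filename suffix
SUFFIX_MAP = {"vit_h": "_h", "vit_l": "_l", "vit_b": "_b"}


def output_names(image_path, model_type):
    """Return (pickle filename, segmented-image filename) for an input image."""
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    model_suffix = SUFFIX_MAP.get(model_type, "")  # unknown model types get no suffix
    return f"data_{image_name}{model_suffix}.pkl", f"segmented_{image_name}{model_suffix}"


print(output_names("/image/test.png", "vit_h"))
```

Calling `output_names(image_path, model_type)` reproduces `pkl_filename` and `image_output_filename` from the cell above.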

### Generate Image Masks
This cell loads the input image, converts it from BGR to RGB color space, and then uses the `mask_generator` to create segmentation masks. The time taken for mask generation is also measured and printed.

In [None]:
import time
import numpy as np

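# Read the input image using OpenCV (imread returns None instead of raising if the file can't be read)
image = cv2.imread(image_path)
if image is None:
 raise FileNotFoundError(f"Could not read image: {image_path}")
# Convert the image from BGR (OpenCV's default) to RGB
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)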

# Start the timer
start_time = time.time()
# Generate masks for the input image
masks = mask_generator.generate(image_rgb)
# Stop the timer
end_time = time.time()
# Calculate the elapsed time
elapsed_time = end_time - start_time

# Print the number of masks generated and the time taken
print(f"Number of masks generated: {len(masks)}")
print(f"Time taken to generate masks: {elapsed_time:.2f} seconds")
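
`mask_generator.generate` returns a list of dictionaries, one per mask; each has a boolean `segmentation` array and an integer `area` (plus `bbox`, `predicted_iou`, `stability_score`, `point_coords` and `crop_box`). A quick summary helps judge how much the masks overlap before filtering. The sketch below uses only `segmentation` and `area`; `mask_stats` is a hypothetical helper, exercised here on synthetic records rather than real SAM output.

```python
import numpy as np


def mask_stats(records):
    """Summarize SAM-style mask records (dicts with 'segmentation' and 'area')."""
    if not records:
        return {"count": 0}
    areas = np.array([r["area"] for r in records])
    union = np.zeros_like(records[0]["segmentation"], dtype=bool)
    total_pixels = 0
    for r in records:
        union |= r["segmentation"]
        total_pixels += int(r["segmentation"].sum())
    return {
        "count": len(records),
        "min_area": int(areas.min()),
        "median_area": float(np.median(areas)),
        "max_area": int(areas.max()),
        # Fraction of the image covered by at least one mask
        "coverage": float(union.mean()),
        # Mask pixels per covered pixel: values above 1 mean masks overlap
        "overlap_ratio": total_pixels / max(int(union.sum()), 1),
    }


seg1 = np.zeros((4, 4), dtype=bool)
seg1[:2, :] = True
seg2 = np.zeros((4, 4), dtype=bool)
seg2[1:3, :] = True
print(mask_stats([{"segmentation": seg1, "area": 8}, {"segmentation": seg2, "area": 8}]))
```

On real output, `mask_stats(masks)` with an `overlap_ratio` well above 1 indicates that many pixels are claimed by several masks, which is what the filtering step below addresses.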

### Visualization Function
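The `show_anns` function is defined here to visualize the generated masks. It paints each mask in a random color on a transparent RGBA canvas, drawing the largest masks first so that smaller masks stay visible on top, and then displays the canvas on the current axes.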

In [None]:
import matplotlib.pyplot as plt

# Function to display the annotations (masks)
def show_anns(anns):
 # Return if there are no annotations
 if len(anns) == 0:
  return
 # Sort annotations by area in descending order
 sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
 ax = plt.gca()
 ax.set_autoscale_on(False)

 # Create a blank image with an alpha channel
 img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
 img[:, :, 3] = 0

 # Iterate through the sorted annotations and draw them on the image
 for ann in sorted_anns:
  m = ann['segmentation']
  # Generate a random RGB color with an alpha of 0.35 so the image underneath stays visible
  color_mask = np.concatenate([np.random.random(3), [0.35]])
  # Apply the color to the mask area
  img[m] = color_mask

 # Display the image with the annotations
 ax.imshow(img)
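
`show_anns` draws straight onto the current Matplotlib axes. If you want the overlay as an array instead (to save it or blend it yourself), here is a NumPy-only sketch with the same largest-first painting order. `build_overlay` is a hypothetical name, and its seeded `np.random.default_rng` makes colors reproducible, which the original function does not do.

```python
import numpy as np


def build_overlay(seg_masks, alpha=0.35, seed=0):
    """Return an H x W x 4 RGBA overlay, or None if there are no masks.

    Masks are painted largest first, so smaller masks end up on top.
    """
    if len(seg_masks) == 0:
        return None
    rng = np.random.default_rng(seed)
    seg_masks = [np.asarray(m, dtype=bool) for m in seg_masks]
    seg_masks.sort(key=lambda m: int(m.sum()), reverse=True)
    h, w = seg_masks[0].shape
    overlay = np.zeros((h, w, 4))  # fully transparent where no mask is drawn
    for m in seg_masks:
        overlay[m] = np.concatenate([rng.random(3), [alpha]])
    return overlay


big = np.zeros((4, 4), dtype=bool)
big[:, :2] = True
small = np.zeros((4, 4), dtype=bool)
small[0, 0] = True
overlay = build_overlay([small, big])
print(overlay[..., 3])
```

With real output, `plt.imshow(image_rgb)` followed by `plt.imshow(build_overlay([m['segmentation'] for m in masks]))` gives a similar picture to `show_anns`.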

### Display the Segmented Image
This cell displays the original RGB image with the generated masks overlaid on top.

In [None]:
# Display the original RGB image
plt.imshow(image_rgb)
# Overlay the generated masks on the image
show_anns(masks)
# Turn off the axis labels
plt.axis('off')
# Uncomment to save the figure (savefig must be called before plt.show())
# plt.savefig(image_output_filename, bbox_inches='tight', pad_inches=0)
# Show the plot
plt.show()

### Filter Overlapping Masks
The `filter_overlapping_masks` function refines the raw SAM output, whose masks often overlap or nest inside one another. It repeatedly scans all pairs of masks. When the overlap between two masks covers at least `min_percentage` (default 20%) of either mask, the overlap becomes a new mask of its own, it is subtracted from both originals, and any original that keeps less than `min_percentage` of its area is dropped. Passes repeat until no pair is split, which leaves fewer redundant and more distinct segments.

In [None]:
# Function to filter masks by handling overlaps
def filter_overlapping_masks(masks, min_percentage=0.2):
 filtered_masks = list(masks)
 updated = True

 # Repeat full passes until a pass splits no pair of masks
 while updated:
  updated = False
  print(f"Current number of masks: {len(filtered_masks)}")

  new_masks = []
  masks_to_remove = set()

  # Iterate through pairs of masks to check for overlaps
  for i in range(len(filtered_masks)):
   for j in range(i + 1, len(filtered_masks)):
    # Re-read both masks for every pair: earlier pairs may have shrunk or removed them
    mask_i = filtered_masks[i]
    mask_j = filtered_masks[j]
    if mask_i is None:
     break  # mask i was removed; move on to the next i
    if mask_j is None:
     continue

    original_area_i = np.sum(mask_i)
    original_area_j = np.sum(mask_j)
    if original_area_i == 0 or original_area_j == 0:
     continue

    # Find the overlapping region
    overlap = mask_i & mask_j
    overlap_area = np.sum(overlap)

    # Skip if there is no overlap or it is not significant for either mask
    if overlap_area == 0 or ((overlap_area / original_area_i < min_percentage) and (overlap_area / original_area_j < min_percentage)):
     continue

    # Keep the overlap as a new mask of its own
    new_masks.append(overlap)

    # Remove the overlap from both original masks
    mask_i_updated = mask_i & ~overlap
    mask_j_updated = mask_j & ~overlap

    updated_area_i = np.sum(mask_i_updated)
    updated_area_j = np.sum(mask_j_updated)

    print(f"Processing masks {i} and {j}")
    print(f"Overlap area: {overlap_area}")
    print(f"Mask {i} area after removing overlap: {updated_area_i}")
    print(f"Mask {j} area after removing overlap: {updated_area_j}")

    # If the remaining area of a mask is too small, mark it for removal
    if updated_area_i / original_area_i >= min_percentage:
     filtered_masks[i] = mask_i_updated
    else:
     filtered_masks[i] = None
     masks_to_remove.add(i)

    if updated_area_j / original_area_j >= min_percentage:
     filtered_masks[j] = mask_j_updated
    else:
     filtered_masks[j] = None
     masks_to_remove.add(j)

    updated = True

  # Remove the marked masks
  filtered_masks = [m for k, m in enumerate(filtered_masks) if k not in masks_to_remove]

  # Add the new overlap masks; they are compared with the others on the next pass
  filtered_masks.extend(new_masks)

 # Return only after the while loop has finished (a pass with no splits)
 return filtered_masks
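
The rule applied to each pair of masks is easiest to see on tiny arrays. The sketch below isolates one pairwise step with the same thresholds as `filter_overlapping_masks`; `split_overlap` is a hypothetical helper written for illustration, not a function from the notebook.

```python
import numpy as np


def split_overlap(mask_a, mask_b, min_percentage=0.2):
    """One pairwise step of the overlap filter (hypothetical helper).

    Returns None if the pair is left alone, otherwise (new_a, new_b, overlap),
    where new_a / new_b are None if that mask keeps too little of its area.
    """
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    area_a, area_b = a.sum(), b.sum()
    overlap = a & b
    area_o = overlap.sum()
    if area_a == 0 or area_b == 0 or area_o == 0:
        return None
    # Overlap too small relative to both masks: leave the pair alone
    if area_o / area_a < min_percentage and area_o / area_b < min_percentage:
        return None
    new_a = a & ~overlap
    new_b = b & ~overlap
    keep_a = new_a if new_a.sum() / area_a >= min_percentage else None
    keep_b = new_b if new_b.sum() / area_b >= min_percentage else None
    return keep_a, keep_b, overlap


a = np.array([1, 1, 1, 1, 0, 0], dtype=bool)
b = np.array([0, 0, 1, 1, 1, 1], dtype=bool)
print(split_overlap(a, b))
```

For two half-overlapping masks, both keep their non-shared parts and the shared part becomes a third mask; a small mask nested inside a large one is replaced by the overlap (itself), while the large mask keeps the rest.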

### Apply the Filtering Algorithm
This cell extracts the boolean `segmentation` array from each dictionary in the `masks` list and applies the `filter_overlapping_masks` function to them.

In [None]:
# Extract the 'segmentation' boolean arrays from the list of mask dictionaries
masks_list = [item['segmentation'] for item in masks if 'segmentation' in item]
# Apply the filtering function to handle overlapping masks
filtered_masks = filter_overlapping_masks(masks_list)
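
If the filter has run until no pair is split, no two remaining masks should overlap by `min_percentage` or more of either mask. A hypothetical check, `max_overlap_fraction`, makes that easy to verify; it compares every pair, so its cost grows quadratically with the number of masks.

```python
import itertools

import numpy as np


def max_overlap_fraction(masks):
    """Largest overlap between any two masks, as a fraction of either mask's area."""
    worst = 0.0
    for a, b in itertools.combinations(masks, 2):
        a = np.asarray(a, dtype=bool)
        b = np.asarray(b, dtype=bool)
        inter = np.logical_and(a, b).sum()
        if inter == 0:
            continue  # disjoint pair; also avoids dividing by an empty mask's area
        worst = max(worst, inter / a.sum(), inter / b.sum())
    return float(worst)


print(max_overlap_fraction([[1, 1, 0, 0], [0, 0, 1, 1]]))
```

With the default threshold, `max_overlap_fraction(filtered_masks) < 0.2` is expected after filtering.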

### Set Up OpenAI API
This cell initializes the OpenAI client with your API key. This is necessary for programmatically determining the optimal number of clusters later on.

In [None]:
import os
from openai import OpenAI

# NOTE: set the OPENAI_API_KEY environment variable, or replace 'your_api_key' with your actual key
api_key = os.environ.get("OPENAI_API_KEY", "your_api_key")
client = OpenAI(api_key=api_key)

### Caching Functions
To avoid redundant API calls to OpenAI, these functions provide a simple caching mechanism. `load_cache` reads previous results from a JSON file, and `save_cache` writes new results to it.

In [None]:
import json

# Function to load cached data from a file
def load_cache(cache_file):
 if not isinstance(cache_file, str):
  raise TypeError("Expected 'cache_file' to be a string")
 # Check if the cache file exists
 if os.path.exists(cache_file):
  try:
   # Open and load the JSON data
   with open(cache_file, 'r') as f:
    return json.load(f)
  except json.JSONDecodeError:
   print(f"Error reading {cache_file}. File may be corrupted.")
 # Return an empty dictionary if the file doesn't exist or is invalid
 return {}

# Function to save data to the cache file
def save_cache(cache_file, cache_data):
 with open(cache_file, 'w') as f:
  json.dump(cache_data, f)
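
`save_cache` rewrites the JSON file in place, so an interrupted write can leave a truncated file that `load_cache` then reports as corrupted. A sketch of an atomic variant, assuming a filesystem where `os.replace` renames atomically (a POSIX guarantee); `save_cache_atomic` is a hypothetical drop-in with the same arguments as `save_cache`.

```python
import json
import os
import tempfile


def save_cache_atomic(cache_file, cache_data):
    """Write JSON to a temporary file in the same folder, then swap it into place."""
    directory = os.path.dirname(os.path.abspath(cache_file))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(cache_data, f)
        # Readers see either the old file or the complete new one, never a partial write
        os.replace(tmp_path, cache_file)
    except BaseException:
        os.remove(tmp_path)  # clean up the temporary file if anything failed
        raise
```

Keeping the temporary file in the same directory matters: `os.replace` across filesystems is not atomic and may fail.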

### Get Optimal Clusters using OpenAI
This function, `get_optimal_clusters_from_openai`, takes the list of elbow-method distortion values (one per cluster count, starting at k = 2), sends them to the GPT-4 model, and asks for the optimal number of clusters. It caches each answer in a JSON file via the caching functions above, so the same set of distortions triggers only one API call, and it falls back to 3 clusters if the reply contains no number.

In [None]:
import re

# Function to get the optimal number of clusters from OpenAI's GPT-4.
# distortions[0] is assumed to belong to k = k_start (the elbow loop below starts at k = 2).
def get_optimal_clusters_from_openai(distortions, cache_file='cache.json', k_start=2):
 # Load existing cache
 cache_data = load_cache(cache_file)
 # Create a cache key from the starting k and the list of distortions
 cache_key = f"{k_start}:{distortions}"

 # If the result is already in the cache, return it
 if cache_key in cache_data:
  print("Using cached result for distortions:", distortions)
  return cache_data[cache_key]

 # Label each distortion with its k so the model knows which cluster count it belongs to
 distortions_str = ', '.join(f"k={k_start + i}: {d:.2f}" for i, d in enumerate(distortions))
 # Create the prompt for the GPT-4 model
 prompt = (f"The following are K-Means distortions (inertia) for different numbers of clusters: {distortions_str}. "
   "Please suggest the optimal number of clusters based on the Elbow Method, responding with just the number.")

 # Make the API call to OpenAI
 response = client.chat.completions.create(
  model="gpt-4",
  messages=[{"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}],
  max_tokens=10
 )

 # Get the response text (content can be None, so fall back to an empty string)
 response_text = response.choices[0].message.content or ""
 print("GPT-4 response:", response_text)

 # Extract the first integer from the response; fall back to 3 if there is none
 match = re.search(r'\b\d+\b', response_text)
 optimal_clusters = int(match.group()) if match else 3
 # Keep the answer inside the range of k values that was actually evaluated
 optimal_clusters = max(k_start, min(optimal_clusters, k_start + len(distortions) - 1))

 # Save the new result to the cache
 cache_data[cache_key] = optimal_clusters
 save_cache(cache_file, cache_data)

 return optimal_clusters
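
The LLM answer can vary between calls and needs network access. As a local cross-check, or as a fallback instead of the hard-coded 3, a common deterministic heuristic picks the k whose point on the normalized elbow curve lies farthest below the straight line joining the first and last points. This is a swapped-in technique, not what the notebook does; `pick_elbow` is a hypothetical name, and `k_start=2` matches the elbow loop.

```python
import numpy as np


def pick_elbow(distortions, k_start=2):
    """Return the k whose normalized point lies farthest below the first-to-last chord."""
    d = np.asarray(distortions, dtype=float)
    if len(d) < 3 or d.max() == d.min():
        return k_start  # too few points or a flat curve: no elbow to find
    # Scale both axes to [0, 1] so the choice does not depend on inertia units
    x = np.linspace(0.0, 1.0, len(d))
    y = (d - d.min()) / (d.max() - d.min())
    # Chord from the first to the last point; for a fixed chord the vertical gap
    # is proportional to the perpendicular distance, so argmax is the same
    chord = y[0] + (y[-1] - y[0]) * x
    return k_start + int(np.argmax(chord - y))


print(pick_elbow([100, 30, 25, 22, 20]))
```

For the distortions above (k = 2 to 6), the curve drops sharply until k = 3 and flattens afterwards, so the heuristic returns 3.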

### Cluster and Visualize Masks
This is the main function for clustering and visualization. It performs the following steps:
1. Computes the average RGB color of each filtered mask.
2. Runs the elbow method: fits K-Means for k = 2 up to min(11, number of masks) and records the distortion (inertia) of each fit.
3. Calls `get_optimal_clusters_from_openai` to choose the number of clusters from those distortions.
4. Runs the final K-Means clustering on the mask colors, using `initial_center` as one of the starting centers (the other centers come from k-means++).
5. Visualizes the results in a 2×2 figure showing the original image, the filtered masks, and the clustered masks with and without labels.
6. Saves the figure as a 300-dpi TIFF in `output_folder_img` and passes each mask to `visualize_each_mask_individually` for saving in a `<image_name>_individual_masks` subfolder.
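
Step 4 relies on `KMeans(init=centers, n_init=1)`, which runs roughly Lloyd's algorithm once from the given starting centers; `initial_center` therefore only seeds one center, which then moves like any other. The NumPy sketch below shows that update loop and the inertia (sum of squared distances to the assigned center) that the elbow method records. `kmeans_seeded` is a hypothetical illustration, not scikit-learn's implementation, which adds convergence tolerances and optimized algorithms.

```python
import numpy as np


def kmeans_seeded(points, init_centers, n_iter=100):
    """Lloyd's algorithm from explicit starting centers; returns (centers, labels, inertia)."""
    points = np.asarray(points, dtype=float)
    centers = np.asarray(init_centers, dtype=float).copy()
    for _ in range(n_iter):
        # Squared distance from every point to every center, shape (n_points, k)
        d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Move each center to the mean of its points (keep it in place if it has none)
        new_centers = np.array([
            points[labels == c].mean(axis=0) if np.any(labels == c) else centers[c]
            for c in range(len(centers))
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    inertia = d2[np.arange(len(points)), labels].sum()
    return centers, labels, inertia


colors = [[0, 0, 0], [2, 0, 0], [100, 100, 100], [102, 100, 100]]
centers, labels, inertia = kmeans_seeded(colors, [[0, 0, 0], [100, 100, 100]])
print(centers, labels, inertia)
```

Here the seeded centers drift from `[0, 0, 0]` and `[100, 100, 100]` to the cluster means, which is why the final center need not equal `initial_center`.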

In [None]:
from sklearn.cluster import KMeans

# Main function to cluster masks and visualize the results.
# `initial_center` is an RGB triple used as one starting center; pass None to use k-means++ only.
def cluster_and_visualize_masks(output_folder_img, output_folder_final, image_rgb, masks, filtered_masks, image_name, model_type, initial_center=None):
 # Average RGB color of each filtered mask, shape (n_masks, 3)
 rgb_array = np.array([image_rgb[mask].mean(axis=0) for mask in filtered_masks])
 if len(rgb_array) < 2:
  raise ValueError("Clustering needs at least two masks")

 # Fit K-Means with k clusters, using initial_center as one starting center when it is given
 def fit_kmeans(k):
  if initial_center is None:
   return KMeans(n_clusters=k, init='k-means++', random_state=0).fit(rgb_array)
  # Let k-means++ place the other k - 1 centers, then prepend initial_center
  remaining_centers = KMeans(n_clusters=k - 1, init='k-means++', random_state=0).fit(rgb_array).cluster_centers_
  current_centers = np.vstack([initial_center, remaining_centers])
  return KMeans(n_clusters=k, init=current_centers, n_init=1, random_state=0).fit(rgb_array)

 # Elbow method: distortion (inertia) for k = 2 .. min(11, number of masks)
 K_range = range(2, min(11, len(rgb_array)) + 1)
 distortions = [fit_kmeans(k).inertia_ for k in K_range]

 # Ask OpenAI for the elbow, then keep the answer inside the evaluated range of k
 n_clusters = get_optimal_clusters_from_openai(distortions)
 n_clusters = max(K_range.start, min(n_clusters, K_range.stop - 1))

 # Final K-Means clustering with the chosen number of clusters
 kmeans = fit_kmeans(n_clusters)
 labels = kmeans.labels_

 # Predefined colors for clusters: one distinct RGBA color for each possible k (up to 11)
 cluster_colors = [
  [1, 0, 0, 0.35], [0, 1, 0, 0.35], [1, 1, 0, 0.35],
  [0.5, 0, 1, 0.35], [0, 0, 1, 0.35], [0, 1, 1, 0.35],
  [0.5, 0.5, 0.5, 0.35], [1, 0.5, 0, 0.35], [1, 0, 1, 0.35],
  [0, 0.5, 1, 0.35], [1, 0, 0.5, 0.35]
 ]

 # The nested helpers' implementations are not reproduced in this excerpt;
 # the `...` bodies only keep the cell syntactically valid.

 # Helper function to show masks with random colors
 def show_anns(anns, ax):
  ...

 # Helper function to show clustered masks with labels
 def show_anns_with_clusters_and_labels(anns, labels, rgb_array, ax):
  ...

 # Helper function to show clustered masks
 def show_anns_with_clusters(anns, labels, rgb_array, ax, first_class_color=[0, 0, 1, 0.35]):
  ...

 # Helper function to save each mask individually
 def visualize_each_mask_individually(anns, labels, rgb_array, output_folder):
  ...

 # Create a 2x2 subplot for visualizations
 fig, axs = plt.subplots(2, 2, figsize=(15, 8))

 axs[0, 0].imshow(image_rgb)
 axs[0, 0].set_title('Original Image')
 axs[0, 0].axis('off')

 show_anns(filtered_masks, axs[0, 1])
 axs[0, 1].set_title('Filtered Masks')
 axs[0, 1].axis('off')

 show_anns_with_clusters(filtered_masks, labels, rgb_array, axs[1, 0])
 axs[1, 0].set_title('Clustered Masks')
 axs[1, 0].axis('off')

 show_anns_with_clusters_and_labels(filtered_masks, labels, rgb_array, axs[1, 1])
 axs[1, 1].set_title('Clustered Masks with Labels')
 axs[1, 1].axis('off')

 plt.tight_layout()

 # Save the main visualization figure, creating the output folder if needed
 os.makedirs(output_folder_img, exist_ok=True)
 suffix_map = {"vit_h": "_h", "vit_l": "_l", "vit_b": "_b"}
 model_suffix = suffix_map.get(model_type, "")
 image_filename = os.path.join(output_folder_img, f"data_{image_name}{model_suffix}.tiff")
 plt.savefig(image_filename, bbox_inches='tight', pad_inches=0, format="tiff", dpi=300)

 plt.show()

 # Save each mask as an individual image
 output_individual_masks = os.path.join(output_folder_img, f"{image_name}_individual_masks")
 visualize_each_mask_individually(filtered_masks, labels, rgb_array, output_individual_masks)

 return rgb_array, kmeans.cluster_centers_, labels
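
The bodies of the nested plotting helpers are not shown. As an illustration only, here is a hypothetical stand-in for the kind of coloring `show_anns_with_clusters` presumably performs: each mask is painted with the RGBA color of its cluster label. The name `color_masks_by_cluster` and the modulo wrap-around for labels beyond the palette are assumptions.

```python
import numpy as np


def color_masks_by_cluster(seg_masks, labels, cluster_colors):
    """Return an H x W x 4 overlay with each mask painted in its cluster's RGBA color."""
    h, w = np.asarray(seg_masks[0]).shape
    overlay = np.zeros((h, w, 4))  # transparent background
    for m, label in zip(seg_masks, labels):
        # Wrap around if there are more clusters than colors
        overlay[np.asarray(m, dtype=bool)] = cluster_colors[int(label) % len(cluster_colors)]
    return overlay


m0 = np.array([[True, False], [False, False]])
m1 = np.array([[False, True], [True, True]])
palette = [[1, 0, 0, 0.35], [0, 1, 0, 0.35]]
print(color_masks_by_cluster([m0, m1], [0, 1], palette))
```

The resulting array can be passed to `ax.imshow` on top of `image_rgb`, for example with `color_masks_by_cluster(filtered_masks, labels, cluster_colors)`.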

### Run the Pipeline
This cell sets the output directories, defines the initial cluster center based on the image filename, and then executes the entire clustering and visualization pipeline by calling the `cluster_and_visualize_masks` function.

In [None]:
# Define output folders for the final and intermediate images
output_folder_final = '/output'
output_folder_img = '/output_image'

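# Extract the initial cluster center from the image filename.
# This assumes a name of the form '<prefix>_R,G,B' (e.g., 'image_255,0,0.png'); a name with no
# underscore, such as the example 'test.png', makes the line below raise an IndexError.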
initial_center = list(map(int, image_name.split('_')[1].split(',')[:3]))

# Run the main clustering and visualization function
rgb_array, centers, labels = cluster_and_visualize_masks(output_folder_img, output_folder_final, image_rgb, masks, filtered_masks, image_name, model_type, initial_center)
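
Because the filename convention is strict, a parser that returns `None` instead of raising makes the fallback explicit. `parse_initial_center` is a hypothetical helper; passing its `None` result on assumes that `cluster_and_visualize_masks` treats `initial_center=None` as "no seeded center".

```python
def parse_initial_center(image_name):
    """Parse '<prefix>_R,G,B' into [R, G, B]; return None if the name doesn't match."""
    parts = image_name.split("_")
    if len(parts) < 2:
        return None  # no underscore, e.g. 'test'
    try:
        values = [int(v) for v in parts[1].split(",")[:3]]
    except ValueError:
        return None  # non-numeric components
    return values if len(values) == 3 else None


print(parse_initial_center("image_255,0,0"))
print(parse_initial_center("test"))
```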

### Print Results
This final cell prints the results of the clustering: the array of average RGB values for each mask, the coordinates of the final cluster centers, and the cluster label assigned to each mask.

In [None]:
# Print the final results from the clustering process
print("RGB values of masks:", rgb_array)
print("Cluster centers:", centers)
print("Labels for each mask:", labels)
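
To see which masks landed in each cluster, grouping mask indices by label is often easier to read than the raw `labels` array. A small sketch; `masks_by_cluster` is a hypothetical helper.

```python
from collections import defaultdict


def masks_by_cluster(labels):
    """Map each cluster label to the list of mask indices assigned to it."""
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[int(label)].append(idx)  # int() also accepts NumPy integer labels
    return dict(sorted(groups.items()))


print(masks_by_cluster([1, 0, 1, 2]))
```

Calling `masks_by_cluster(labels)` after the pipeline has run lists, for each cluster, the indices into `filtered_masks` and `rgb_array`.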