In [2]:
import numpy as np
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt

# Function to perform instance segmentation
def instance_segmentation(mask, *, connectivity=None):
    """Split a mask into connected-component instances.

    Parameters
    ----------
    mask : array-like
        Mask to split, e.g. a boolean foreground mask. Zero / ``False``
        pixels are background (the ``skimage.measure.label`` default).
        Neighbouring pixels are grouped only when they have the same value.
    connectivity : int, optional
        Maximum number of orthogonal hops for two pixels to count as
        neighbours. This is forwarded to ``skimage.measure.label``. ``None``
        (the default) uses full connectivity (8-connectivity for 2-D input),
        which matches the previous behaviour.

    Returns
    -------
    instance_masks : numpy.ndarray
        Integer array with the same shape as ``mask``. Background is 0 and
        the i-th connected component has value i (1-based, consecutive).
    instance_count : int
        Number of connected components found (0 for an empty mask).
    """
    # ``label`` already numbers components consecutively 1..N, in the same
    # order ``regionprops`` would list them. That is exactly the relabelling
    # the old per-region bounding-box loop rebuilt.
    #
    # Using label's own integer output also avoids the old approach of
    # writing labels into an array of ``mask``'s dtype. That raised for
    # boolean masks (in-place uint8 add into bool is not a same_kind cast)
    # and would wrap around past 255 instances for uint8 masks. The old
    # loop also only worked for 2-D input.
    instance_masks, instance_count = label(
        mask, return_num=True, connectivity=connectivity
    )
    return instance_masks, int(instance_count)

# Example usage
# `preds` is not defined in this cell; it must come from an earlier notebook
# cell. Take its first entry and threshold at 0.5 to get a binary mask.
# NOTE(review): assumes preds[0] holds per-pixel scores in [0, 1] -- confirm.
example_mask = preds[0].squeeze() > 0.5

# Split the binary mask into connected-component instances
instance_masks, instance_count = instance_segmentation(example_mask)

# A thresholded prediction can contain many tiny blobs. One subplot per blob
# would make the figure unreadable and slow to draw, so only the first few
# instances get their own panel. The combined label map still shows all of them.
MAX_INSTANCE_PANELS = 8
shown_instances = min(instance_count, MAX_INSTANCE_PANELS)
n_panels = shown_instances + 2  # original mask + combined label map + individual instances

fig, axes = plt.subplots(1, n_panels, figsize=(2.5 * n_panels, 3), squeeze=False)
axes = axes[0]

axes[0].imshow(example_mask, cmap='gray')
axes[0].set_title('Original Mask')

# All instances in one image: each label value maps to a different colour,
# with background (0) drawn black by this colormap.
axes[1].imshow(instance_masks, cmap='nipy_spectral', interpolation='nearest')
axes[1].set_title(f'All Instances ({instance_count})')

# Each individual instance is a binary mask, so grayscale is sufficient.
for instance_id, ax in enumerate(axes[2:], start=1):
    ax.imshow(instance_masks == instance_id, cmap='gray')
    ax.set_title(f'Instance {instance_id}')

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()
