# COCO Annotation Visualizer for Instance Segmentation

This Colab notebook is designed to visualize annotations from a single merged COCO JSON file containing instance segmentation annotations for multiple images. We will use Detectron2 to display images along with their corresponding segmentation masks, bounding boxes, and category labels.

In [None]:
# Install Detectron2 from its GitHub source, then RESTART the runtime before
# running the cells below.
# NOTE(review): no later cell appears to use the local clone (pip installs
# directly from the git URL) — confirm before removing the clone step.
!git clone 'https://github.com/facebookresearch/detectron2'
!pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
#@title Imports

from google.colab import drive
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.visualizer import Visualizer
from google.colab.patches import cv2_imshow
import json
import random
import os
import cv2
from collections import Counter

In [None]:
# Connect to Google Drive in case your data lives there, and expose it under
# the shorter path /mydrive (the default paths further down go through it).
drive.mount('/content/gdrive')

# An IPython `!ln -s ...` shell command never raises a Python exception, so a
# try/except around it always reported success. os.symlink does raise, which
# lets failures surface. On a re-run it raises FileExistsError instead of
# nesting a second link inside the already-linked Drive folder.
_drive_root = '/content/gdrive/My Drive/'
_link_path = '/mydrive'
try:
    os.symlink(_drive_root, _link_path)
    print('Successful')
except FileExistsError:
    # Typically left over from a previous run of this cell; reuse it.
    print(f'{_link_path} already exists; reusing it.')
except OSError as e:
    print(e)
    print('Not successful')

In [None]:
#@title Utils

def unregister_dataset(dataset_name: str) -> None:
    """
    Drop ``dataset_name`` from Detectron2's MetadataCatalog and DatasetCatalog.

    Each catalog is checked on its own, so a name present in only one of them
    is still cleaned up, and an unknown name is a no-op.

    Args:
      dataset_name: Name under which the dataset was registered.
    """
    # (list-registered-names, remove-one-name) pairs, metadata first.
    removals = (
        (MetadataCatalog.list, MetadataCatalog.pop),
        (DatasetCatalog.list, DatasetCatalog.remove),
    )
    for list_registered, drop in removals:
        if dataset_name in list_registered():
            drop(dataset_name)


def read_coco_json(json_path: str) -> dict:
    """
    Reads a COCO JSON file and returns its contents as a dictionary.

    The file is decoded explicitly as UTF-8 (the required encoding for JSON
    interchange) instead of the platform's locale default. This way non-ASCII
    category or file names load identically on every machine.

    Args:
        json_path: Path to the COCO JSON file.

    Returns:
        Parsed JSON data as a dictionary.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(json_path, encoding="utf-8") as file:
        return json.load(file)


def filter_dataset_by_category(dataset_dicts: list, category_id: int) -> list:
    """
    Keep only the images that have at least one annotation of ``category_id``.

    Each kept record is a shallow copy of the input record. Its
    ``'annotations'`` list is narrowed to the matching annotations, so the
    input records are left untouched. Images without a match are dropped.

    Args:
        dataset_dicts: Dataset records, each holding an ``'annotations'`` list
            of dicts that carry a ``'category_id'``.
        category_id: Category ID to keep, compared against each annotation's
            ``'category_id'`` exactly as stored in the records.

    Returns:
        A new list of the filtered records, in their original order.
    """
    kept_records = []
    for record in dataset_dicts:
        matching = [
            annotation
            for annotation in record['annotations']
            if annotation['category_id'] == category_id
        ]
        if not matching:
            continue  # nothing of this category in the image -> skip it

        narrowed = record.copy()  # shallow copy so the caller's record is intact
        narrowed['annotations'] = matching
        kept_records.append(narrowed)
    return kept_records


def get_object_counts_per_category(data):
    """
    Counts the annotated objects (not images) of each category in COCO data.

    Every annotation counts once, so an image with three bottles adds 3 to the
    bottle count. Categories without any annotation are omitted.

    Args:
      data: Parsed COCO JSON data with optional 'categories' and 'annotations'
        lists.

    Returns:
      A list of ``(category_id, category_name, object_count)`` tuples, sorted
      by ascending category_id.

    Raises:
      KeyError: If an annotation references a category_id that is not listed
        in ``data['categories']``.
    """
    # Map category IDs to names.
    id_to_name = {category['id']: category['name'] for category in data.get('categories', [])}

    # Count objects (annotations) by category ID.
    counts_by_id = Counter(anno['category_id'] for anno in data.get('annotations', []))

    return [
        (cat_id, id_to_name[cat_id], counts_by_id[cat_id])
        for cat_id in sorted(counts_by_id)
    ]

In [None]:
# Provide the path to the images and the corresponding merged COCO JSON file.
# (The `#@param` comments turn these lines into Colab form fields; keep each
# on the same line as its assignment.)
images_dir = '/mydrive/circularnet/Client_ManualAnnotation_data/Benjamin/batch1/images/' #@param {type:"string"}
coco_json = '/mydrive/circularnet/Client_ManualAnnotation_data/Benjamin/batch1/annotations/bottle_vs_non-bottle.json' #@param {type:"string"}
# Raw parsed JSON; later cells read its 'categories' and count its annotations.
data = read_coco_json(coco_json)

## Visualize random images of all categories

In [None]:
# Number of images to visualize randomly.
num_of_images = 25   #@param {type:"integer"}

# Unregister the previous dataset so re-running this cell registers it afresh.
unregister_dataset("dataset")

register_coco_instances("dataset", {}, coco_json, images_dir)
metadata = MetadataCatalog.get("dataset")
dataset_dicts = DatasetCatalog.get("dataset")

# random.sample raises ValueError when asked for more items than exist, so
# cap the request at the dataset size.
num_to_show = min(num_of_images, len(dataset_dicts))

# Visualize images randomly.
for image in random.sample(dataset_dicts, num_to_show):
    print(os.path.basename(image["file_name"]))
    img = cv2.imread(image["file_name"])
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        print("Could not read image; skipping.\n")
        continue
    # OpenCV loads BGR; Visualizer takes RGB, and cv2_imshow expects BGR back.
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(image)
    cv2_imshow(out.get_image()[:, :, ::-1])
    print()

## Visualize by image name

In [None]:
image_name =  'google_benjamin_669154948um_d5bcad7c584849d7557e74704701be2cd9bd6326097bb2e7902f919dc5b6284e.jpeg' #@param {type:"string"}

# Pick the first dataset record whose base file name equals `image_name`,
# or None when no record matches.
target_image = next(
    (
        record
        for record in dataset_dicts
        if os.path.basename(record["file_name"]) == image_name
    ),
    None,
)

if not target_image:
    print(f"Image '{image_name}' not found in the dataset.")
else:
    print("Found Image:", os.path.basename(target_image["file_name"]))
    img = cv2.imread(target_image["file_name"])
    # Flip channel order for the Visualizer, then flip back for display.
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(target_image)
    cv2_imshow(out.get_image()[:, :, ::-1])  # Use cv2.imshow if running locally

## Visualize random images of a single category

In [None]:
# Check the category IDs: display the 'categories' list from the COCO JSON
# (each entry has an 'id' and a 'name'). Enter these ids in the category
# parameters of the cells below.
data['categories']

In [None]:
num_of_images = 10   #@param {type:"integer"}
category_id_to_filter = 1 #@param {type:"integer"}

# Detectron2's COCO loader remaps the JSON category ids to contiguous indices
# and stores the mapping in the dataset metadata. Translate the COCO id
# through it instead of assuming the ids are exactly 1..K; fall back to the
# "id - 1" guess if the mapping is unavailable.
id_map = metadata.get("thing_dataset_id_to_contiguous_id", None) or {}
contiguous_category_id = id_map.get(category_id_to_filter, category_id_to_filter - 1)

unregister_dataset("filtered_dataset")
filtered_dataset_dicts = filter_dataset_by_category(dataset_dicts, contiguous_category_id)
# Bind the list as a default argument, so the registered loader keeps
# returning this cell's result even if the global name is rebound later.
DatasetCatalog.register("filtered_dataset", lambda d=filtered_dataset_dicts: d)
MetadataCatalog.get("filtered_dataset").set(thing_classes=MetadataCatalog.get("dataset").thing_classes)

if not filtered_dataset_dicts:
    print(f"No images contain category ID {category_id_to_filter}.")

# random.sample raises ValueError when asked for more images than actually
# contain the category, so cap the request.
num_to_show = min(num_of_images, len(filtered_dataset_dicts))

for image in random.sample(filtered_dataset_dicts, num_to_show):
    print(os.path.basename(image["file_name"]))
    img = cv2.imread(image["file_name"])
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        print("Could not read image; skipping.\n")
        continue
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(image)
    cv2_imshow(out.get_image()[:, :, ::-1])
    print()

## Visualize images of all categories

In [None]:
# Count annotated objects (not images) per category, as
# (category_id, category_name, object_count) tuples sorted by id. One image can
# hold several objects of a category, so each count is an upper bound on the
# number of images containing that category. Categories with no annotations
# are not listed.
category_counts = get_object_counts_per_category(data)
category_counts

In [None]:
max_samples_per_category = 6  #@param {type:"integer"}

# Define range of categories to visualize
CATEGORY_RANGE_START = 1  #@param {type:"integer"}
CATEGORY_RANGE_END = 1  #@param {type:"integer"}
CATEGORY_RANGE = (CATEGORY_RANGE_START, CATEGORY_RANGE_END)  # Inclusive range of COCO category IDs

# Detectron2's COCO loader remaps JSON category ids to contiguous indices and
# records the mapping in the dataset metadata. Use it to translate the real
# ids from `category_counts`; fall back to "id - 1" if it is missing.
id_map = metadata.get("thing_dataset_id_to_contiguous_id", None) or {}

# category_counts holds (category_id, category_name, object_count) tuples.
# Iterate over the real ids rather than list positions, so gaps in the ids or
# categories without annotations do not shift the numbering.
for category_id, category_name, object_count in category_counts:

    # Ensure we only process categories within the defined range
    if not CATEGORY_RANGE[0] <= category_id <= CATEGORY_RANGE[1]:
        continue

    # Unregister previous dataset if it exists
    unregister_dataset("filtered_dataset")

    # Image records that contain at least one object of this category.
    category_records = filter_dataset_by_category(
        dataset_dicts, id_map.get(category_id, category_id - 1)
    )

    print(
        f"Category ID: {category_id} | Name: {category_name} | "
        f"Object Count: {object_count} | Image Count: {len(category_records)}\n"
    )

    # Register filtered dataset (default arg binds this iteration's list).
    DatasetCatalog.register("filtered_dataset", lambda d=category_records: d)
    MetadataCatalog.get("filtered_dataset").set(thing_classes=MetadataCatalog.get("dataset").thing_classes)

    # Sample images, not objects. An image may hold several objects of this
    # category, so object_count can exceed the number of images and would
    # make random.sample raise ValueError.
    num_images_to_sample = min(max_samples_per_category, len(category_records))

    for record in random.sample(category_records, num_images_to_sample):
        print(os.path.basename(record["file_name"]))
        image = cv2.imread(record["file_name"])
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            print("Could not read image; skipping.\n")
            continue

        # OpenCV loads BGR; Visualizer takes RGB, and cv2_imshow expects BGR back.
        image_rgb = image[:, :, ::-1]
        original_output = Visualizer(image_rgb, metadata=metadata, scale=0.4).get_output()
        annotated_output = Visualizer(image_rgb, metadata=metadata, scale=0.4).draw_dataset_dict(record)

        # Show the original and annotated versions side by side.
        combined_image = cv2.hconcat([original_output.get_image()[:, :, ::-1], annotated_output.get_image()[:, :, ::-1]])
        cv2_imshow(combined_image)
        print()
