# Annotation Quality Control using Gemini

Welcome to this Colab notebook, designed to help you analyze and verify instance segmentation annotations stored in COCO JSON format. Accurate annotations are critical for training high-performance computer vision models. This notebook provides a quality assurance pipeline that crops each annotated object out of its image and asks Google's Gemini model whether the object's category label is correct, suggesting a corrected label when it is not.

Choose a Gemini model ID for this notebook from the [list of available Gemini models](https://ai.google.dev/gemini-api/docs/models/gemini), and enter it as `MODEL_ID` in the GCP Config cell below.

In [None]:
# Install (or upgrade) the Google Gen AI SDK: the `google-genai` package that
# provides the `from google import genai` import used in the next cell. The
# leading `!` makes the notebook run this line as a shell command, not Python.
!pip install --upgrade --quiet google-genai

In [None]:
#@title Imports

import sys
from google.colab import auth
from google import genai
from PIL import Image
import io
import os
import requests
from io import BytesIO
from google.cloud import storage
import csv
import subprocess
from typing import Any
import json
import cv2
import matplotlib.pyplot as plt
import numpy as np
import re

from google.genai.types import (
    FunctionDeclaration,
    GenerateContentConfig,
    GoogleSearch,
    Part,
    Retrieval,
    SafetySetting,
    Tool,
    VertexAISearch,
)

In [None]:
# Authenticate the Colab runtime's Google account so the GCP / Vertex AI
# client created later can use those credentials. Skipped outside Colab.
# NOTE(review): `from google.colab import auth` in the imports cell is
# unconditional, so this guard alone does not make the notebook run elsewhere.
_in_colab = "google.colab" in sys.modules
if _in_colab:
    auth.authenticate_user()

In [None]:
#@title Utils

def read_csv(file_path: str) -> list[str]:
    """Reads the first column of every non-empty row of a CSV file.

    No header row is skipped: every non-empty row, including the first one,
    contributes its first cell. The file is therefore expected to hold only
    values (e.g. one label per line). Columns after the first are ignored.

    Args:
        file_path: The path to the CSV file (decoded as UTF-8).

    Returns:
        The first-column value of each non-empty row, in file order.
    """
    # newline="" lets the csv module handle line endings itself, so quoted
    # fields that contain newlines are read correctly (per the csv docs).
    with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
        return [row[0] for row in csv.reader(csvfile) if row]


def read_json(file_path: str) -> dict[str, Any]:
    """
    Reads a JSON file and returns its parsed contents.

    Args:
        file_path: Path to the JSON file (decoded as UTF-8).

    Returns:
        The parsed JSON content. Annotated as a dict because COCO annotation
        files are JSON objects, but whatever top-level value the file holds
        is returned unchanged.

    Raises:
        FileNotFoundError: If `file_path` does not exist.
        json.JSONDecodeError: If the file is not valid JSON. The message names
            the file and the parser's own reason; `pos` is preserved.
    """
    try:
        with open(file_path, mode="r", encoding="utf-8") as json_file:
            return json.load(json_file)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"File not found: {file_path}") from err
    except json.JSONDecodeError as err:
        # Keep the parser's reason (e.g. "Expecting value") in the message;
        # JSONDecodeError itself appends line/column computed from doc and pos.
        raise json.JSONDecodeError(
            f"Invalid JSON format in file: {file_path} ({err.msg})",
            doc=err.doc,
            pos=err.pos,
        ) from err


def convert_bbox_coco_to_xyxy(bbox: list) -> list:
    """Turns a COCO box [x, y, width, height] into corner form [x1, y1, x2, y2].

    (x1, y1) is the box origin exactly as given; (x2, y2) is that origin
    shifted by the width and height. Only the first four entries are read.

    Args:
        bbox: Box in COCO layout [x, y, width, height].

    Returns:
        A new list [x1, y1, x2, y2].
    """
    left, top, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
    return [left, top, left + width, top + height]

## GCP Config

In [None]:
# Model / project settings. The `# @param` annotations are Colab form fields;
# keep each on its assignment line.
MODEL_ID = "gemini-2.0-flash-001" # @param {type: "string", placeholder: "[your-model-id]", isTemplate: true}
PROJECT_ID = "projectidgoeshere"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
# Region is taken from $GOOGLE_CLOUD_REGION when set, else "us-central1".
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

# Gemini 2.0 Client - authentication through GCP/Vertex AI
# NOTE(review): PROJECT_ID above is a placeholder value; replace it with a
# real GCP project id, otherwise model calls will presumably fail.
client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)

## Download the labels

In [None]:
url = (
    "https://raw.githubusercontent.com/tensorflow/models/refs/heads/master/"
    "official/projects/waste_identification_ml/pre_processing/config/data/45_labels.csv"
)

# Save under a fixed name in the working directory: re-running the cell
# overwrites it instead of creating "45_labels.csv.1", and check=True raises
# CalledProcessError if wget fails instead of silently continuing.
subprocess.run(["wget", "-O", "45_labels.csv", url], check=True)

In [None]:
labels = read_csv('45_labels.csv')
# category_id -> label name: id 1 is the first CSV row, id 2 the second, ...
# (assumes the COCO file numbers its categories in the CSV's order starting
# at 1 — TODO confirm against the dataset's "categories" list).
labels_mapping = dict(enumerate(labels, start=1))

## Prompt

## Download COCO JSON annotation file & Image

In [None]:
# Download sample image file.
image_url = (
    "https://raw.githubusercontent.com/tensorflow/models/refs/heads/master/"
    "official/projects/waste_identification_ml/pre_processing/config/"
    "sample_images/image_2.png"
)

# Save under a fixed name in the working directory (re-runs overwrite it rather
# than creating "image_2.png.1") and raise CalledProcessError if wget fails.
subprocess.run(["wget", "-O", "image_2.png", image_url], check=True)

In [None]:
# Download sample COCO JSON file.
json_url = (
    "https://raw.githubusercontent.com/tensorflow/models/refs/heads/master/"
    "official/projects/waste_identification_ml/pre_processing/config/"
    "sample_json/gemini_sample.json"
)

# Save under a fixed name in the working directory (re-runs overwrite it rather
# than creating "gemini_sample.json.1") and raise CalledProcessError if wget
# fails.
subprocess.run(["wget", "-O", "gemini_sample.json", json_url], check=True)

In [None]:
# Read COCO JSON file. The path is relative to the working directory, which is
# where the wget download cell saves it (e.g. /content on Colab), so this also
# works outside Colab or after changing directory.
json_coco_data = read_json('gemini_sample.json')

In [None]:
## Load an image.
# Relative to the working directory, where the download cell saved it.
image_path = 'image_2.png'
image = cv2.imread(image_path)
# cv2.imread returns None (rather than raising) for a missing or unreadable
# file, which would otherwise surface as a confusing cvtColor error.
if image is None:
    raise FileNotFoundError(f"Could not read image file: {image_path}")
# OpenCV loads channels as BGR; matplotlib and PIL expect RGB.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image_rgb)
plt.axis('off')
plt.show()

## Inference

In [None]:
# Instruction prompt sent with every cropped object. `{labels}` is filled in by
# str.format() below; the literal braces of the JSON example are doubled
# ("{{" / "}}") so format() emits them as single braces.
prompt = """
<OBJECTIVE_AND_PERSONA>
You're an Object Annotation Data Quality Checker.

You will be given 1) an image of an object and 2) the annotation label for the object - please confirm if the annotation label is correct for the object.
If it's not correct, please 1) provide reasoning why it's not correct, and 2) provide the correct annotation label from this list of labels: {labels}

For context, the object images you will be given are from a waste pile or on a conveyor belt.
</OBJECTIVE_AND_PERSONA>

<OUTPUT_FORMAT>
Respond with a single valid JSON object and nothing else, structured like the following:
{{
  "original_label": "<insert original label here>",
  "correct": <insert true or false here>,
  "reasoning": "<insert reasoning on why it's correct or not here>",
  "correct_label": "<insert correct label from the list, or N/A if already correct>"
}}
</OUTPUT_FORMAT>
""".format(labels=labels)

In [None]:
def _crop_to_bbox(image: np.ndarray, bbox: list) -> np.ndarray:
    """Returns the region of `image` covered by a COCO [x, y, width, height] box.

    COCO bbox values may be floats, and NumPy slicing needs integers, so the
    box is expanded outward to whole pixels (floor the top-left corner, ceil
    the bottom-right corner) and clamped to the image bounds. Clamping also
    stops a slightly negative coordinate from being treated as an index
    counted from the end of the array.
    """
    x1, y1, x2, y2 = convert_bbox_coco_to_xyxy(bbox)
    height, width = image.shape[:2]
    left = max(0, int(np.floor(x1)))
    top = max(0, int(np.floor(y1)))
    right = min(width, int(np.ceil(x2)))
    bottom = min(height, int(np.ceil(y2)))
    return image[top:bottom, left:right]


for annotation in json_coco_data['annotations']:
    annotated_label = labels_mapping[annotation['category_id']]

    # Get the image of the object using its bbox (COCO x, y, width, height).
    cropped_image_rgb_coords = _crop_to_bbox(image_rgb, annotation['bbox'])
    if cropped_image_rgb_coords.size == 0:
        # Zero-area or fully out-of-bounds box: nothing to show or send.
        print(f"Skipping annotation {annotation.get('id')} ({annotated_label}): empty bbox crop")
        continue

    cropped_image = Image.fromarray(cropped_image_rgb_coords)
    cropped_image.thumbnail([256, 256])  # In-place downscale, keeps aspect ratio.

    print(f"Original Label: {annotated_label}")
    plt.imshow(cropped_image)
    plt.axis('off')
    plt.show()

    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[
            cropped_image,
            prompt + "\nAnnotation Label: " + annotated_label,
        ],
        # Ask for bare JSON output so the parsing cell can json.loads() it.
        config=GenerateContentConfig(response_mime_type="application/json"),
    )

    print(response.text)

In [None]:
# Extract generated text from the most recent response (the inference loop
# leaves `response` bound to the last annotation's result) and parse its JSON.
if response.candidates:
    # `.text` returns the first candidate's text, or None when it has no text
    # parts (e.g. a blocked response) instead of raising like parts[0] would.
    raw_text = response.text or ""

    # Remove an optional Markdown code fence (```json ... ``` or ``` ... ```)
    # around the answer, tolerating extra whitespace and "JSON" casing.
    json_string = re.sub(
        r"^```(?:json)?\s*|\s*```$", "", raw_text.strip(), flags=re.IGNORECASE
    ).strip()

    try:
        response_json = json.loads(json_string)  # Convert to Python objects
    except json.JSONDecodeError:
        print("Failed to decode JSON. Raw text:", json_string)
    else:
        if isinstance(response_json, dict):
            print(response_json)  # Print parsed JSON

            # Access specific fields if needed
            original_label = response_json.get("original_label")
            correct = response_json.get("correct")
            reasoning = response_json.get("reasoning")
            correct_label = response_json.get("correct_label")

            print(f"Original Label: {original_label}")
            print(f"Correct: {correct}")
            print(f"Reasoning: {reasoning}")
            print(f"Correct Label: {correct_label}")
        else:
            print("Expected a JSON object. Parsed value:", response_json)
else:
    print("No candidates returned in the response.")