# Waste identification with instance segmentation in PyTorch

Welcome to the Instance Segmentation Colab! This notebook takes you through the steps of running an "out-of-the-box" Detectron2 Mask R-CNN instance segmentation model on an image.

To complete this task, you need to provide valid paths to the model and to a single image. The labels the model was trained on are stored in the waste_identification_ml directory of the TensorFlow Model Garden repository; the notebook loads the matching label file from there, so you do not need to supply it yourself.

## RESTART the Colab runtime after installing the Detectron2 packages.

In [None]:
# Install Detectron2 from its GitHub repository. This can take several
# minutes, because the package is built from source.
# `pip install git+...` fetches and builds Detectron2 on its own; it does not
# use the local checkout created by `git clone`.
# NOTE(review): no later cell shown here reads the cloned `detectron2/`
# directory, and re-running the clone errors once that directory exists —
# consider dropping the clone.
!git clone 'https://github.com/facebookresearch/detectron2'
!pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
# Install the supervision package, used to post-process the output of the
# Detectron2 Mask R-CNN model (conversion to `sv.Detections` and NMS).
# NOTE(review): the version is not pinned. The notebook relies on
# `sv.Detections.from_detectron2` and `Detections.with_nms(...,
# class_agnostic=...)`, so consider pinning a release that provides both.
!pip install -q supervision

## Clone the TF Model Garden repo where the waste identification project is located.

In [None]:
# Shallow-clone (latest commit only) the TensorFlow Model Garden, which holds
# the label CSV and the sample image used by this notebook.
# stderr is discarded, so a failed clone — or the "already exists" error on a
# re-run — produces no visible message.
!git clone --depth 1 https://github.com/tensorflow/models 2>/dev/null

## Imports and Setup

In [None]:
# Standard-library and third-party imports.
import csv  # Standard library; used by read_csv.
import torch
import cv2
import matplotlib.pyplot as plt
import supervision as sv
from PIL import Image  # NOTE(review): not referenced in the cells shown.

# Detectron2 imports.
import detectron2  # NOTE(review): not referenced directly in the cells shown.
from detectron2.utils.logger import setup_logger
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.structures import Instances, Boxes
from detectron2.data.catalog import Metadata
from detectron2.utils.visualizer import Visualizer

# Set up Detectron2's logger so its log messages appear in the notebook output.
setup_logger()

In [None]:
#@title Utilities


def convert_detections_to_instances(
  outputs: dict,
  image_size: tuple[int, int] = (1024, 1024),
  nms_threshold: float = 0.8,
  class_agnostic: bool = True
) -> dict[str, Instances]:
  """Filters Detectron2 predictions with NMS and repackages them as Instances.

  The raw predictor output is converted to a `supervision.Detections` object
  and Non-Maximum Suppression is applied. The surviving boxes, scores, class
  ids and (when present) masks are then copied into a new Detectron2
  `Instances` object.

  Args:
    outputs: Detectron2 predictor output containing instance predictions.
    image_size: (height, width) recorded on the returned Instances.
    nms_threshold: Overlap (IoU) threshold used by Non-Maximum Suppression.
    class_agnostic: Whether NMS is applied in a class-agnostic manner.

  Returns:
    The filtered predictions in Detectron2's format, {"instances": Instances}.
  """
  filtered = sv.Detections.from_detectron2(outputs).with_nms(
      threshold=nms_threshold, class_agnostic=class_agnostic
  )

  # Field name -> tensor, in the order the fields are attached below.
  fields = {
      "pred_boxes": Boxes(torch.tensor(filtered.xyxy, dtype=torch.float32)),
      "scores": torch.tensor(filtered.confidence, dtype=torch.float32),
      "pred_classes": torch.tensor(filtered.class_id, dtype=torch.int64),
  }
  # Masks only exist for segmentation models; boxes-only output has none.
  if filtered.mask is not None:
    fields["pred_masks"] = torch.tensor(filtered.mask, dtype=torch.uint8)

  result = Instances(image_size)
  for field_name, value in fields.items():
    result.set(field_name, value)
  return {"instances": result}


def read_csv(file_path: str, skip_header: bool = False) -> list[str]:
  """Reads the first column of a CSV file into a list of strings.

  Each non-empty row contributes its first field, in file order. Blank lines
  are ignored instead of raising IndexError. By default every row is
  returned, including the first one. Pass ``skip_header=True`` when the file
  starts with a header row that should be dropped.

  Args:
    file_path: The path to the CSV file.
    skip_header: If True, discard the first row of the file.

  Returns:
    The first-column values of the CSV file as a list of strings.
  """
  # The csv module documentation requires newline='' so that quoted fields
  # containing newlines are parsed correctly.
  with open(file_path, 'r', newline='') as csvfile:
    reader = csv.reader(csvfile)
    if skip_header:
      next(reader, None)  # The default keeps an empty file from raising.
    # csv.reader yields [] for blank lines; skip them instead of indexing.
    return [row[0] for row in reader if row]

## Import and load the labels.

In [None]:
# Label file shipped with the waste_identification_ml project in the cloned
# TF Model Garden repository.
LABELS_PATH = '/'.join([
    'models/official/projects/waste_identification_ml',
    'pre_processing/config/data/45_labels.csv',
])

# Class names (first CSV column), attached to a Metadata object that the
# Visualizer uses to label the predictions.
labels = read_csv(LABELS_PATH)
my_metadata = Metadata().set(thing_classes=labels)

## Download the Detectron2 Mask R-CNN model.

In [None]:
%%bash
# Download and extract the trained Detectron2 Mask R-CNN model archive.
# NOTE(review): the model-loading cell reads config.yaml and model_final.pth
# from the working directory; presumably both come from this archive — confirm.
set -e  # Abort on the first failing command instead of continuing silently.

ZIP="Detectron2_Jan2025_1024_1024.zip"
URL="https://storage.googleapis.com/tf_model_garden/vision/waste_identification_ml/${ZIP}"

# Download only if the archive is not already present, so re-running the cell
# does not create numbered duplicates (e.g. *.zip.1). Write to a temporary
# name first so an interrupted transfer never leaves a truncated archive.
if [[ ! -f "${ZIP}" ]]; then
  wget -O "${ZIP}.part" "${URL}"
  mv "${ZIP}.part" "${ZIP}"
fi

# -o overwrites previously extracted files without prompting (re-runs).
# -q silences the file listing but still reports errors.
unzip -o -q "${ZIP}"

## Load the model

In [None]:
# Initialize the Detectron2 configuration object.
cfg = get_cfg()

# Load the model configuration from a YAML file in the working directory.
cfg.merge_from_file("config.yaml")

# Confidence threshold used to filter predictions at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

# Path to the trained model weights.
cfg.MODEL.WEIGHTS = "model_final.pth"

# Fall back to the CPU when the config requests CUDA but no GPU is available
# (e.g. a CPU-only Colab runtime); otherwise building the predictor fails.
if cfg.MODEL.DEVICE.startswith("cuda") and not torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cpu"

# Create a predictor object using the configured model.
predictor = DefaultPredictor(cfg)

## Import and load an image

In [None]:
# Sample image bundled with the waste_identification_ml project in the cloned
# TF Model Garden repository, keyed by a display name.
IMAGES_FOR_TEST = dict(
    Image1=(
        'models/official/projects/waste_identification_ml/'
        'pre_processing/config/sample_images/image_2.png'
    ),
)

# Input resolution the model was trained on (1024 x 1024). The input image is
# resized to (HEIGHT, WIDTH) before inference.
HEIGHT, WIDTH = 1024, 1024

In [None]:
image_path = IMAGES_FOR_TEST['Image1']

# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise surface later as a confusing AttributeError on `.shape`.
# The loaded image has its channels in BGR order.
original_image = cv2.imread(image_path)
if original_image is None:
    raise FileNotFoundError(f"Could not read image at {image_path!r}.")
original_height, original_width = original_image.shape[:2]

# Resize to the model's input resolution (the aspect ratio is not preserved).
resized_image = cv2.resize(
    original_image,
    (WIDTH, HEIGHT),  # cv2.resize takes (width, height).
    interpolation=cv2.INTER_AREA
)

## Perform prediction

In [None]:
# Run the model on the resized image, then apply NMS and repackage the result
# as Detectron2 Instances. The image size is passed explicitly so it stays in
# sync with the HEIGHT/WIDTH used for resizing.
raw_outputs = predictor(resized_image)
outputs = convert_detections_to_instances(raw_outputs, image_size=(HEIGHT, WIDTH))

## Visualize the results

In [None]:
# Extract the predicted instances as a new Instances object on the CPU.
instances = outputs["instances"].to("cpu")

# Rescale bounding boxes back to the original image size.
# `Boxes.scale` modifies its tensor in place, and `.to("cpu")` does not copy
# tensors that already live on the CPU (as the NMS-converted outputs do).
# Scaling `instances.pred_boxes` directly would therefore also rescale
# `outputs["instances"]`, and the scaling would compound on every re-run of
# this cell. Scale a clone instead.
scale_x = original_width / WIDTH
scale_y = original_height / HEIGHT
scaled_boxes = instances.pred_boxes.clone()
scaled_boxes.scale(scale_x, scale_y)
instances.pred_boxes = scaled_boxes

In [None]:
# Resize masks to match the original image size. Nearest-neighbour
# interpolation keeps the masks binary.
if instances.has("pred_masks"):
    pred_masks = instances.pred_masks.numpy()  # Convert to NumPy array.

    # Preallocate the output so that zero detections still give a
    # (0, height, width) tensor. This also avoids building a tensor from a
    # Python list of NumPy arrays, which is slow and warned about by PyTorch.
    resized_masks = torch.zeros(
        (len(pred_masks), original_height, original_width), dtype=torch.uint8
    )
    for idx, mask in enumerate(pred_masks):
        resized_masks[idx] = torch.from_numpy(
            cv2.resize(
                mask.astype("uint8"),
                (original_width, original_height),  # cv2 takes (width, height).
                interpolation=cv2.INTER_NEAREST,
            )
        )

    instances.pred_masks = resized_masks

# Detectron2's Visualizer expects an RGB image, but cv2.imread returned BGR.
# Converting here keeps both the photo and the drawn overlay colours correct.
# get_image() then already returns RGB, ready for Matplotlib.
visualizer = Visualizer(
    img_rgb=cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB),
    metadata=my_metadata,  # Provides the class names used for the labels.
    scale=1,  # Draw at the original image resolution.
)

# Draw predictions on the original-resolution image.
visualized_image = visualizer.draw_instance_predictions(instances).get_image()

# Display the final image with predictions overlaid on the original image.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(visualized_image)
ax.axis("off")
plt.show()