# Waste identification with instance segmentation in PyTorch

This Colab notebook demonstrates an end-to-end pipeline for object detection, feature extraction, object tracking, and data aggregation using the Mask R-CNN model from Detectron2.
Key Steps in the Notebook:



*   Object Detection and segmentation – Detect objects in a set of images using Mask R-CNN.
*   Feature Extraction & Tracking – Extract object features and track them across multiple frames to eliminate duplicate counts.
*   Color Detection – Identify the color of each detected object.
*   Postprocessing – Aggregate tracking results and apply filtering to reduce false positives and false negatives.
*   Save detection and tracking results.










To run this notebook, valid paths to the trained model and to the input images must be provided. The label file that the model was trained with is in the waste_identification_ml directory inside the TensorFlow Model Garden repository.

This notebook outputs three folders and one CSV file:


*   **prediction_folder** : Will contain prediction results with bbox and masks.
*   **tracking** : Will contain tracking visualization.
*   **cropped_objects** : Will contain crops of the detected objects, grouped into one subfolder per category.
*   **count.csv** : Will contain the number of unique tracked objects in each category.






## Install Detectron2 and RESTART the runtime

In [None]:
# Install Detectron2 from source. The pip command builds directly from the
# GitHub URL; the local clone made on the first line is not referenced by it.
!git clone 'https://github.com/facebookresearch/detectron2'
!pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
#@title Imports and Setup

# Install supervision (detection conversion and NMS), trackpy (object
# tracking) and openpyxl pinned to 3.1.2; -q keeps pip's output short.
!pip install -q supervision trackpy openpyxl==3.1.2

import sys
import tensorflow as tf
import csv
from typing import Any, TypedDict, Callable
import cv2
import logging
import numpy as np
import matplotlib.pyplot as plt
import glob
import natsort
import tqdm
import os
from PIL import Image
from scipy import ndimage
import pandas as pd
import skimage
import datetime
import trackpy as tp
import shutil
import supervision as sv

# Detectron2 Utilities
import torch
import detectron2
from detectron2.utils.logger import setup_logger
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.structures import Instances, Boxes
from detectron2.data.catalog import Metadata
from detectron2.utils.visualizer import Visualizer
# Configure Detectron2's logger (detectron2.utils.logger.setup_logger).
setup_logger()


# Disable all logging calls at WARNING severity and below, for every logger.
logging.disable(logging.WARNING)

# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
# Connect to Google Drive if your data is stored there.
import os

from google.colab import drive
drive.mount('/content/gdrive')

# Expose Drive under the short path /mydrive, which later cells use.
# A shell `!ln -s` call never raises a Python exception, so wrapping it in
# try/except always reported success; os.symlink raises on real failures.
# It also avoids `ln` creating a nested link inside /mydrive when the cell is
# re-run and /mydrive already points at a directory.
try:
  os.symlink('/content/gdrive/My Drive/', '/mydrive')
  print('Successful')
except FileExistsError:
  print('/mydrive already exists; leaving it unchanged.')
except OSError as e:
  print(e)
  print('Not successful')

In [None]:
# # Connect to a GCP bucket if your data is stored there, and copy it locally.
# !gcloud init
# !gsutil -m cp -r gs://input .

The color-detection helpers (`color_and_property_extractor`) and the label file used below come from the TensorFlow Model Garden, so we clone that repository. (The predictions themselves are drawn with Detectron2's `Visualizer`.)



In [None]:
# Clone the tensorflow models repository.
# Shallow clone (latest commit only); git's stderr is discarded. The clone
# provides the color/property extraction helpers and the label file used below.
!git clone --depth 1 https://github.com/tensorflow/models 2>/dev/null

In [None]:
# Make the Model Garden inference helpers importable, then load the module
# that provides find_dominant_color and get_generic_color_name.
sys.path.append('models/official/projects/waste_identification_ml/model_inference/')
import color_and_property_extractor

In [None]:
#@title Utilities

_PROPERTIES = (
    'area',
    'bbox',
    'convex_area',
    'bbox_area',
    'major_axis_length',
    'minor_axis_length',
    'eccentricity',
    'centroid',
    'label',
    'mean_intensity',
    'max_intensity',
    'min_intensity',
    'perimeter'
)

def convert_detections_to_instances(
    outputs: dict,
    image_size: tuple[int, int] = (1024, 1024),
    nms_threshold: float = 0.8,
    class_agnostic: bool = True
) -> dict[str, Instances]:
    """Re-pack Detectron2 predictions as `Instances` after running NMS.

    The raw predictor output is converted to a supervision `Detections`
    object, Non-Maximum Suppression is applied there, and the surviving
    boxes, scores, class ids and (when present) masks are copied into a new
    Detectron2 `Instances`.

    Args:
        outputs: Raw Detectron2 predictor output.
        image_size: (height, width) recorded on the new Instances object.
        nms_threshold: IoU threshold handed to supervision's `with_nms`.
        class_agnostic: If True, overlaps are suppressed regardless of class.

    Returns:
        A dict shaped like the predictor output: {"instances": Instances}.
    """
    kept = sv.Detections.from_detectron2(outputs).with_nms(
        threshold=nms_threshold, class_agnostic=class_agnostic
    )

    # Field name -> tensor, inserted in the order they are set below.
    new_fields = {
        "pred_boxes": Boxes(torch.tensor(kept.xyxy, dtype=torch.float32)),
        "scores": torch.tensor(kept.confidence, dtype=torch.float32),
        "pred_classes": torch.tensor(kept.class_id, dtype=torch.int64),
    }
    if kept.mask is not None:
        new_fields["pred_masks"] = torch.tensor(kept.mask, dtype=torch.uint8)

    instances = Instances(image_size)
    for field_name, value in new_fields.items():
        instances.set(field_name, value)

    return {"instances": instances}


def read_csv(file_path: str) -> list[str]:
  """Reads a single-column CSV file and returns its values as a list.

  Every non-empty row contributes its first field; extra fields on a row are
  ignored. The first row is NOT treated as a header: it is returned like any
  other row. Blank lines are skipped.

  Args:
      file_path: The path to the CSV file.

  Returns:
      The first field of each non-empty row, in file order.
  """
  # newline='' is what the csv module requires so quoted fields containing
  # line breaks are parsed correctly.
  with open(file_path, 'r', newline='') as csvfile:
    # csv.reader yields [] for blank lines; indexing row[0] on those would
    # raise IndexError, so they are skipped.
    return [row[0] for row in csv.reader(csvfile) if row]


def adjust_image_size(height: int, width: int, min_size: int) -> tuple[int, int]:
  """Downscale an image size so that its shorter side equals `min_size`.

  The aspect ratio is preserved and no upscaling is done: if either dimension
  is already smaller than `min_size`, the size is returned unchanged. When a
  resize does happen, both output dimensions are at least `min_size`.

  Args:
    height: The height of the image.
    width: The width of the image.
    min_size: Target length of the shorter side.

  Returns:
    The adjusted (height, width) of the image.
  """
  if height < min_size or width < min_size:
    return height, width

  # Exact integer arithmetic: floor(dim * min_size / shorter_side). The float
  # form int(dim / (shorter / min_size)) can round the shorter side down to
  # min_size - 1 for some inputs.
  shorter = min(height, width)
  new_height = height * min_size // shorter
  new_width = width * min_size // shorter

  return new_height, new_width


def dilated_largest_component(mask: np.ndarray) -> np.ndarray:
    """Keeps only the largest connected component of a mask and fills its holes.

    Despite the name, no morphological dilation is applied.

    Args:
        mask: Input 2D mask; any non-zero pixel counts as foreground.

    Returns:
        Integer 0/1 mask of the largest 8-connected foreground component with
        its interior holes filled. If `mask` has no foreground pixels, an
        all-zero mask of the same shape is returned.
    """
    # A full 3x3 structuring element gives 8-connectivity.
    labels, num_components = ndimage.label(
        mask != 0, structure=np.ones((3, 3), dtype=int)
    )
    if num_components == 0:
        # Nothing to keep; np.argmax over zero components would raise.
        return np.zeros(mask.shape, dtype=int)

    # bincount bin 0 is the background; component ids start at 1.
    component_sizes = np.bincount(labels.ravel())[1:]
    largest_label = 1 + int(np.argmax(component_sizes))
    return ndimage.binary_fill_holes(labels == largest_label).astype(int)


def extract_properties(image, masks):
  """Measure region properties of each mask over the given image.

  Each mask is binarised (non-zero -> 1, so it forms a single labelled
  region) and measured with skimage.measure.regionprops_table using the
  properties listed in _PROPERTIES. The per-mask tables are stacked.

  Args:
    image: Intensity image the masks belong to.
    masks: Iterable of 2D masks to measure.

  Returns:
    DataFrame of the stacked measurements, with 'centroid-0'/'centroid-1'
    renamed to 'y'/'x' and 'bbox-0'..'bbox-3' renamed to 'bbox_0'..'bbox_3'.
  """
  renames = {'centroid-0': 'y', 'centroid-1': 'x'}
  renames.update({f'bbox-{i}': f'bbox_{i}' for i in range(4)})

  per_mask_tables = [
      pd.DataFrame(
          skimage.measure.regionprops_table(
              np.where(mask, 1, 0),
              intensity_image=image,
              properties=_PROPERTIES,
          )
      )
      for mask in masks
  ]
  return pd.concat(per_mask_tables, ignore_index=True).rename(columns=renames)


def get_image_creation_time(image_path):
  """Retrieves the creation time of an image, trying multiple methods.

  The EXIF "DateTimeOriginal" tag is used when the image plugin exposes EXIF
  data and the tag parses; otherwise the file's modification time is used.

  Args:
    image_path: The path to the image file.

  Returns:
    The time formatted as "%Y-%m-%d %H:%M:%S". Returns "Image not found" if
    the file does not exist, or "Error: <message>" on any other failure.
  """
  output_format = "%Y-%m-%d %H:%M:%S"
  try:
    # 1. Try EXIF data (if available). The context manager closes the file
    # handle that Image.open keeps open.
    with Image.open(image_path) as image:
      # `_getexif` is only defined by some Pillow image plugins (e.g. JPEG).
      # Calling it unconditionally raised AttributeError for other formats,
      # which skipped the modification-time fallback below.
      getexif = getattr(image, "_getexif", None)
      exif_data = getexif() if getexif is not None else None
    if exif_data:
      datetime_tag_id = 36867  # Tag ID for "DateTimeOriginal"
      datetime_str = exif_data.get(datetime_tag_id)
      if datetime_str:
        try:
          datetime_obj = datetime.datetime.strptime(datetime_str, "%Y:%m:%d %H:%M:%S")
          return datetime_obj.strftime(output_format)
        except (TypeError, ValueError):
          # Malformed or placeholder tag (e.g. "0000:00:00 00:00:00"):
          # deliberately fall through to the modification time.
          pass

    # 2. Try file modification time (less accurate, but better than nothing)
    file_modified_time = os.path.getmtime(image_path)
    datetime_obj = datetime.datetime.fromtimestamp(file_modified_time)
    return datetime_obj.strftime(output_format)

  except FileNotFoundError:
    return "Image not found"
  except Exception as e:
    return f"Error: {e}"


def process_tracking_result(df):
    """Collapse per-frame tracking rows into one row per tracked particle.

    Args:
      df: Tracking result with a 'particle' column plus the source, image,
        score, creation time, bbox and class columns.

    Returns:
      One row per particle: the first source name, image name, creation time
      and bbox; the maximum detection score; and the class id/name chosen by
      select_class_with_scores, stored as 'detection_classes' and
      'detection_classes_names'.
    """
    # Column -> reduction; the insertion order fixes the output column order.
    reductions = {
        'source_name': 'first',
        'image_name': 'first',
        'detection_scores': 'max',
        'creation_time': 'first',
        **{f'bbox_{i}': 'first' for i in range(4)},
    }
    per_particle = df.groupby('particle').agg(reductions).reset_index()

    # One winning (class_id, class_name) per particle, in particle order.
    winners = df.groupby('particle', as_index=False).apply(
        select_class_with_scores,
        include_groups=False,
    )

    return per_particle.assign(
        detection_classes=winners['class_id'],
        detection_classes_names=winners['class_name'],
    )

def select_class_with_scores(group):
    """Choose the class label for one tracked particle.

    The most frequent class among the particle's detections wins. If several
    classes share the highest count, the one with the highest single
    detection score is chosen (the first in value_counts order if those
    scores also tie).

    Args:
      group: The particle's rows, with 'detection_classes',
        'detection_scores' and 'detection_classes_names' columns.

    Returns:
      pd.Series with 'class_id' and the matching 'class_name'.
    """
    class_col = group['detection_classes']
    counts = class_col.value_counts()
    contenders = counts[counts == counts.iloc[0]].index

    if len(contenders) == 1:
        class_id = contenders[0]
    else:
        best_score = {
            cls: group.loc[class_col == cls, 'detection_scores'].max()
            for cls in contenders
        }
        class_id = max(best_score, key=best_score.get)

    class_name = group.loc[class_col == class_id, 'detection_classes_names'].iloc[0]
    return pd.Series({'class_id': class_id, 'class_name': class_name})


def apply_tracking(df,
                   search_range_x,
                   search_range_y,
                   memory):
  """Link per-frame detections into trajectories with trackpy.

  Args:
    df: Per-detection features. Must contain the tracking columns listed below
      and the metadata columns that are copied onto the result.
    search_range_x: The search range of pixels for tracking along x axis.
    search_range_y: The search range of pixels for tracking along y axis.
    memory: The frames memory for tracking (trackpy's `memory` argument).

  Returns:
    The tracking result dataframe: the tracking columns except 'frame', the
    'particle' id added by trackpy and the copied metadata columns, with a
    fresh 0..n-1 index.
  """
  # Only 'y'/'x' (pos_columns) and 'frame' (t_column) drive the linking; the
  # bbox and shape columns are carried through for later steps. Other
  # features that could be carried are 'area', 'label', 'color',
  # 'eccentricity', 'convex_area', 'mean_intensity-*', 'max_intensity-*' and
  # 'min_intensity-*'.
  tracking_columns = [
      'x',
      'y',
      'frame',
      'bbox_0',
      'bbox_1',
      'bbox_2',
      'bbox_3',
      'major_axis_length',
      'minor_axis_length',
      'perimeter',
  ]

  # pos_columns is explicit so it is guaranteed to match the (y, x) order of
  # search_range instead of relying on trackpy guessing the position columns.
  track_df = tp.link_df(
      df[tracking_columns],
      search_range=(search_range_y, search_range_x),
      pos_columns=['y', 'x'],
      t_column='frame',
      memory=memory,
  )

  # Copy the additional columns from the original dataframe. Assignment aligns
  # on the index (assumes link_df keeps df's index labels).
  additional_columns = [
      'source_name',
      'image_name',
      'detection_scores',
      'detection_classes_names',
      'detection_classes',
      'color',
      'creation_time'
  ]
  track_df[additional_columns] = df[additional_columns]

  return track_df.drop(columns=['frame']).reset_index(drop=True)


def resize_bbox(y1, x1, y2, x2, old_height, old_width, new_height, new_width):
    """Map a (y1, x1, y2, x2) box from one image size to another.

    Row coordinates are scaled by new_height / old_height and column
    coordinates by new_width / old_width; each result is truncated with int().

    Args:
        y1, x1, y2, x2: Box corners in the original image.
        old_height, old_width: Size of the image the box was measured on.
        new_height, new_width: Size of the target image.

    Returns:
        Tuple (new_y1, new_x1, new_y2, new_x2) of ints.
    """
    row_scale = new_height / old_height
    col_scale = new_width / old_width
    corners = ((y1, row_scale), (x1, col_scale), (y2, row_scale), (x2, col_scale))
    return tuple(int(coord * scale) for coord, scale in corners)


## Download and extract the pre-trained model.

In [None]:
%%bash
# Download the pre-trained Detectron2 model archive from the TF Model Garden
# bucket (the trailing backslash continues the URL on the next line).
wget https://storage.googleapis.com/tf_model_garden/vision/\
waste_identification_ml/Detectron2_Jan2025_1024_1024.zip

# Extract into the working directory; all unzip output, including errors, is
# discarded. Presumably this provides the config.yaml and model_final.pth read
# by the predictor setup — confirm if extraction appears to fail silently.
unzip Detectron2_Jan2025_1024_1024.zip > /dev/null 2>&1

## Import and load the labels

In [None]:
# Class names from the Model Garden label file (first field of every row);
# a predicted class id is used as an index into this list.
LABELS_PATH = (
    'models/official/projects/waste_identification_ml/pre_processing/'
    'config/data/45_labels.csv'
)

labels = read_csv(LABELS_PATH)

# Detectron2 metadata so the Visualizer can label predictions by class name.
my_metadata = Metadata()
my_metadata.set(thing_classes=labels)

## Load all images

In [None]:
# Folder with the input frames; adjust to your data location.
images_dir = "/mydrive/circularnet/TestData/input-test-05022025"
# NOTE(review): "*" matches every file in the folder, not only images.
images = glob.glob(os.path.join(images_dir, "*"))

# Make sure that the files are sorted.
# natsort orders embedded numbers numerically (e.g. frame2 before frame10),
# so the frame order used for tracking follows the file numbering.
images = natsort.natsorted(images)
# Display how many files were found.
len(images)

In [None]:
# Prediction confidence score.
PREDICTION_THRESHOLD = 0.70

# The model is trained on 1024 x 1024 image dimensions.
HEIGHT = 1024
WIDTH = 1024

# Object Tracking parameters.
SEARCH_RANGE_X=150
SEARCH_RANGE_Y=20
MEMORY=1

# Create a folder for saving prediction results.
os.makedirs('prediction_folder', exist_ok=True)
prediction_folder = os.path.join(os.getcwd(), 'prediction_folder')

# Dimensions for tracking images.
HEIGHT_TRACKING = 300
WIDTH_TRACKING = 300

# Create a folder to troubleshoot tracking results.
os.makedirs('tracking', exist_ok=True)

# Create a folder to save detected objects from Mask RCNN
# accpording to categories
output_dir = "cropped_objects"
os.makedirs(output_dir, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Initialize the Detectron2 configuration object
cfg = get_cfg()

# Load the model configuration from a YAML file.
cfg.merge_from_file("config.yaml")

# Set the confidence threshold.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = PREDICTION_THRESHOLD

# Specify the path to the trained model weights.
cfg.MODEL.WEIGHTS = "model_final.pth"

cfg.MODEL.DEVICE = "cuda"

# Create a predictor object using the configured model.
predictor = DefaultPredictor(cfg)
predictor.model.to(device)  # Ensure the model is on GPU

## Perform inference

In [None]:
tracking_images = {}
features_set = []

# Predictions whose binary mask (in the WIDTH x HEIGHT model input) covers at
# most this many pixels are discarded.
MIN_MASK_AREA = 4000

# Frame index for tracking; only readable images advance it.
frame = 0
for image_path in tqdm.tqdm(images):
  original_image = cv2.imread(image_path)
  if original_image is None:
    # cv2.imread returns None for files it cannot decode (e.g. non-image files
    # matched by the glob); skip them instead of crashing on `.shape`.
    print(f'Skipping unreadable file: {image_path}')
    continue
  frame += 1

  resized_image = cv2.resize(
      original_image,
      (WIDTH, HEIGHT),
      interpolation=cv2.INTER_AREA
  )

  # Perform inference.
  results = predictor(resized_image)

  # Implement class agnostic NMS.
  results = convert_detections_to_instances(results)

  # Extract the attributes from the prediction result.
  fields = results["instances"].to("cpu").get_fields()
  bboxes = fields["pred_boxes"].tensor.numpy().astype(int)
  if not len(bboxes):
    continue

  scores = fields["scores"].numpy()
  classes = fields["pred_classes"].numpy()
  masks = fields["pred_masks"].numpy()

  # Keep the predictions whose binary mask area > MIN_MASK_AREA.
  mask_areas = masks.reshape(len(masks), -1).sum(axis=1)
  valid_indices = mask_areas > MIN_MASK_AREA
  bboxes = bboxes[valid_indices]
  if not len(bboxes):
    continue

  scores = scores[valid_indices]
  classes = classes[valid_indices]
  masks = masks[valid_indices]

  # Size used for saving images with bbox and masks: adjust_image_size
  # downscales the original so its shorter side is 1024 px when both sides
  # are at least 1024 px, and otherwise keeps it as-is.
  height_plot, width_plot = adjust_image_size(
      original_image.shape[0], original_image.shape[1], 1024
  )
  image_plot = cv2.resize(
      original_image,
      (width_plot, height_plot),
      interpolation=cv2.INTER_AREA,
  )
  # cv2 loads BGR; the Visualizer, PIL and the color extractor receive RGB.
  image_plot_rgb = np.ascontiguousarray(image_plot[:, :, ::-1])

  # Rescale bounding boxes (x1, y1, x2, y2) from model-input to plot size.
  scale_x = width_plot / WIDTH
  scale_y = height_plot / HEIGHT
  bboxes = (bboxes * [scale_x, scale_y, scale_x, scale_y]).astype(int)

  # Rescale masks to plot size.
  resized_masks = np.array([
      cv2.resize(
          mask.astype("uint8"),
          (width_plot, height_plot),
          interpolation=cv2.INTER_NEAREST,
      )
      for mask in masks
  ])

  # Convert predictions to Detectron2 visualization format.
  predictions = Instances(
      (height_plot, width_plot),
      pred_boxes=Boxes(torch.tensor(bboxes, dtype=torch.float32)),
      scores=torch.tensor(scores, dtype=torch.float32),
      pred_classes=torch.tensor(classes, dtype=torch.int64),
      pred_masks=torch.tensor(resized_masks, dtype=torch.uint8),
  )

  # Save the original and the annotated image side by side. The Visualizer
  # must get RGB input; passing BGR swaps red and blue in the overlay colors.
  visualizer = Visualizer(
      img_rgb=image_plot_rgb, metadata=my_metadata, scale=1
  )
  visualized_image = visualizer.draw_instance_predictions(predictions).get_image()
  final_image = Image.fromarray(cv2.hconcat([image_plot_rgb, visualized_image]))
  final_image.save(f'prediction_folder/{os.path.basename(image_path)}')

  # Create object tracking data on a small fixed-size copy of the frame.
  tracking_image = cv2.resize(
      original_image,
      (WIDTH_TRACKING, HEIGHT_TRACKING),
      interpolation=cv2.INTER_AREA,
  )
  tracking_images[os.path.basename(image_path)] = tracking_image

  tracking_masks = np.array([
      cv2.resize(
          mask.astype("uint8"),
          (WIDTH_TRACKING, HEIGHT_TRACKING),
          interpolation=cv2.INTER_NEAREST
      ) for mask in masks
  ])
  # In case of connected masks, keep the biggest mask and fill the holes
  # in case of incomplete detections by Mask RCNN.
  tracking_masks = np.array([
      dilated_largest_component(i) for i in tracking_masks
  ])

  # Crop objects from the RGB plot image using masks for color detection.
  cropped_objects = [
      np.where(np.expand_dims(i, -1), image_plot_rgb, 0)
      for i in resized_masks
  ]

  # Perform color detection using clustering approach.
  dominant_colors = [
      color_and_property_extractor.find_dominant_color(obj)
      for obj in cropped_objects
  ]
  generic_color_names = color_and_property_extractor.get_generic_color_name(
      dominant_colors
  )

  # Extract features.
  features = extract_properties(tracking_image, tracking_masks)
  features["source_name"] = os.path.basename(os.path.dirname(image_path))
  features["image_name"] = os.path.basename(image_path)
  features["creation_time"] = get_image_creation_time(image_path)
  features["frame"] = frame
  features["detection_scores"] = list(scores)
  features["detection_classes"] = list(classes)
  features["detection_classes_names"] = [labels[i] for i in classes]
  features["color"] = generic_color_names
  features_set.append(features)


# Link detections across frames, pick one class per tracked particle and
# count particles per class name.
if features_set:
  features_df = pd.concat(features_set, ignore_index=True)
  tracking_features = apply_tracking(
      features_df,
      search_range_x=SEARCH_RANGE_X,
      search_range_y=SEARCH_RANGE_Y,
      memory=MEMORY
  )
  # One row per particle, so the counts are tracked objects, not detections.
  agg_features = process_tracking_result(tracking_features)
  counts = agg_features.groupby("detection_classes_names").size()
  counts.to_frame().to_csv(os.path.join(os.getcwd(), "count.csv"))
  print(counts)
# NOTE(review): if no image produced detections, tracking_features and
# agg_features are never defined and the later cells raise NameError.

## Visualize Object Tracking  

In [None]:
CIRCLE_RADIUS = 7
CIRCLE_THICKNESS = 3
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 1
color = (255, 0, 0)
groups = tracking_features.groupby('image_name')

# On each downscaled tracking frame, mark every object's centroid with a
# filled dot and print its particle id there, then save the frame under
# tracking/ with the original file name.
for name, group in groups:
  img = tracking_images[name].copy()
  for row in group.itertuples(index=False):
    center = (int(row.x), int(row.y))
    cv2.circle(img, center, CIRCLE_RADIUS, (255, 133, 233), -1)
    cv2.putText(
        img, str(int(row.particle)), center,
        font, fontScale, color, 2, cv2.LINE_AA,
    )

  cv2.imwrite(os.path.join('tracking', name), img)

## Visualize Predictions by Categories

In [None]:
if not agg_features.empty:
  for group_name, df in tqdm.tqdm(agg_features.groupby("detection_classes_names")):
    os.makedirs(f'{output_dir}/{group_name}', exist_ok=True)

    for row in df.itertuples(index=False):
      # Get the image. Keep cv2's BGR order: cv2.imwrite below expects BGR,
      # so converting to RGB here would swap red and blue in the saved crops.
      image = cv2.imread(os.path.join(images_dir, row.image_name))
      if image is None:
        print(f'Skipping unreadable file: {row.image_name}')
        continue
      new_h, new_w = image.shape[0], image.shape[1]

      # The bbox was measured on the HEIGHT_TRACKING x WIDTH_TRACKING tracking
      # image; scale it to the full-resolution image.
      y1, x1, y2, x2 = row.bbox_0, row.bbox_1, row.bbox_2, row.bbox_3
      new_bbox = resize_bbox(y1, x1, y2, x2, HEIGHT_TRACKING, WIDTH_TRACKING, new_h, new_w)

      # Include the score in the filename
      score = row.detection_scores if hasattr(row, 'detection_scores') else 0.0
      name = f'{os.path.splitext(row.image_name)[0]}_{row.particle}_{score:.2f}.png'

      # Save the cropped image
      cv2.imwrite(f'{output_dir}/{row.detection_classes_names}/{name}',
                  image[new_bbox[0]:new_bbox[2], new_bbox[1]:new_bbox[3]])


## Copy the output folders to Google Drive

In [None]:
# Drive folder that receives the outputs (reached through the /mydrive link
# created when Drive was mounted).
destination_folder = '/mydrive/circularnet/TestModel'
os.makedirs(destination_folder, exist_ok=True)

def copytree_replace(src, dst):
  """Copy the directory tree `src` to `dst`, deleting any existing `dst` first.

  Removing the old tree ensures no files from a previous run survive.
  """
  dst_exists = os.path.exists(dst)
  if dst_exists:
    shutil.rmtree(dst)
  shutil.copytree(src, dst)

def copy_replace(src, dst):
  """Copy the file `src` to `dst`, deleting any existing file at `dst` first."""
  dst_exists = os.path.exists(dst)
  if dst_exists:
    os.remove(dst)
  shutil.copy(src, dst)

# Mirror each output folder, then the count file, into the Drive destination.
for output_folder in ("prediction_folder", "cropped_objects", "tracking"):
  copytree_replace(
      os.path.join(os.getcwd(), output_folder),
      os.path.join(destination_folder, output_folder),
  )
copy_replace(
    os.path.join(os.getcwd(), "count.csv"),
    os.path.join(destination_folder, "count.csv"),
)