# Monocular 3D object tracking with SAM-3D Objects

This notebook requires a self-hosted inference server running on a GPU with at least 32 GB of VRAM. See the README for the recommended setup.

In [None]:
# Automatically reload imported modules that changed on disk before running each cell.
%load_ext autoreload
%autoreload 2
# Install the notebook's Python dependencies.
%pip install -r requirements.txt

In [None]:
import os

# Self-hosted inference server (see the README for the recommended setup).
API_URL = "http://localhost:9001"
# Read the key from the environment so it does not get saved into the notebook;
# the placeholder is kept as the fallback when the variable is not set.
API_KEY = os.environ.get("ROBOFLOW_API_KEY", "YOUR_API_KEY")

# Instance-segmentation model used for per-frame detection + masks.
SEGMENTATION_MODEL_ID = "rfdetr-seg-preview"
# SAM-3D Objects model used to lift each mask to a 3D object.
SAM3_3D_MODEL_ID = "sam3-3d-objects"

In [None]:
from supervision.assets import VideoAssets, download_assets

# Input clip, fetched with supervision's download_assets helper.
# Swap which line is commented out to try the other sample video.
INPUT_VIDEO_PATH = download_assets(VideoAssets.MILK_BOTTLING_PLANT)
# INPUT_VIDEO_PATH = download_assets(VideoAssets.VEHICLES)

SAMPLE_FPS = 5  # rate (frames per second) at which the input video is sampled
MAX_FRAMES = None  # upper bound on processed frames; None processes the whole video

OUTPUT_DIR = "sam-3d-track"  # destination for the annotated video and the Rerun log

In [None]:
import os
import shutil


def _recreate_dir(path: str) -> None:
    """Remove *path* if it exists, then create it again as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


# Start every run from an empty output directory.
_recreate_dir(OUTPUT_DIR)

In [None]:
from inference_sdk import InferenceHTTPClient

# Single HTTP client shared by the detection and the SAM-3D inference calls.
client = InferenceHTTPClient(
    api_key=API_KEY,
    api_url=API_URL,
)

In [None]:
import numpy as np

import supervision as sv

def detect_and_track(
    image: np.ndarray,
    tracker: sv.ByteTrack,
    confidence_threshold: float = 0.5,
) -> sv.Detections:
    """Segment objects in *image* and assign them persistent ByteTrack IDs.

    Args:
        image: frame to run the segmentation model on.
        tracker: tracker updated with this frame's detections.
        confidence_threshold: detections with confidence at or below this value
            are discarded before tracking.

    Returns:
        The detections that were matched to a track, with ``tracker_id`` set.
    """
    result = client.infer(image, model_id=SEGMENTATION_MODEL_ID)
    detections = sv.Detections.from_inference(result)

    # remove low-confidence detections
    detections = detections[detections.confidence > confidence_threshold]

    # update_with_detections *returns* the tracked detections; use that value rather
    # than relying on the argument being modified in place. Ignoring it could leave
    # tracker_id as None (e.g. when no track matched), which is where stray
    # untracked/-1 entries came from.
    detections = tracker.update_with_detections(detections)

    # Defensive: drop anything not matched to a confirmed track (-1).
    if detections.tracker_id is not None:
        detections = detections[detections.tracker_id != -1]

    return detections

In [None]:
# Annotators are created once, not per frame: TraceAnnotator keeps each track's
# position history between calls, so a fresh instance every frame never has more
# than one point per track and draws no trail.
# Re-run this cell to reset the trails (e.g. before re-processing the video).
_MASK_ANNOTATOR = sv.MaskAnnotator()
_LABEL_ANNOTATOR = sv.LabelAnnotator()
_TRACE_ANNOTATOR = sv.TraceAnnotator()


def annotate(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
    """Draw masks, "#<id> (<class>)" labels and motion trails on a copy of *image*.

    Args:
        image: frame to annotate; it is not modified.
        detections: tracked detections for this frame (``tracker_id`` must be set).

    Returns:
        The annotated copy of *image*.
    """
    class_names = detections.data.get("class_name")
    if class_names is None:
        # e.g. an empty Detections without a "class_name" entry
        labels = [f"#{tracker_id}" for tracker_id in detections.tracker_id]
    else:
        labels = [
            f"#{tracker_id} ({class_name})"
            for tracker_id, class_name in zip(detections.tracker_id, class_names)
        ]

    annotated = _MASK_ANNOTATOR.annotate(scene=image.copy(), detections=detections)
    annotated = _LABEL_ANNOTATOR.annotate(scene=annotated, detections=detections, labels=labels)
    annotated = _TRACE_ANNOTATOR.annotate(scene=annotated, detections=detections)

    return annotated

In [None]:
def _polygon_area(polygon: np.ndarray) -> float:
    """Return the area of an (N, 2) closed polygon (shoelace formula)."""
    x = polygon[:, 0].astype(float)
    y = polygon[:, 1].astype(float)
    return 0.5 * abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))


def _mask_to_flat_polygon(mask: np.ndarray) -> list | None:
    """Flatten the largest polygon of *mask* to [x1, y1, ..., xN, yN]; None if it has none."""
    polygons = sv.mask_to_polygons(mask)
    if not polygons:
        # Tiny/degenerate masks can produce no polygon at all.
        return None
    largest = max(polygons, key=_polygon_area)
    return np.asarray(largest).flatten().tolist()


def get_3d_objects(image: np.ndarray, detections: sv.Detections) -> sv.Detections:
    """Lift each detection mask to a 3D object with SAM-3D and attach the results.

    Each mask is sent as its largest polygon. Detections whose mask yields no
    polygon are dropped, so the returned detections line up one-to-one with the
    server's ``"objects"`` list, stored under ``detections.data["sam3_3d"]``.
    When there is nothing to send (no detections or no masks), the server is not
    called and the detections are returned without a ``"sam3_3d"`` entry.

    Args:
        image: frame the masks belong to.
        detections: detections carrying segmentation masks.

    Returns:
        The (possibly filtered) detections.
    """
    if len(detections) == 0 or detections.mask is None:
        return detections

    # flatten polygons to the expected [x1 y1 x2 y2 ... xN yN] format
    flat_polygons = [_mask_to_flat_polygon(mask) for mask in detections.mask]
    keep = np.array([polygon is not None for polygon in flat_polygons], dtype=bool)
    if not keep.all():
        detections = detections[keep]
        flat_polygons = [polygon for polygon in flat_polygons if polygon is not None]
    if len(detections) == 0:
        return detections

    sam3_3d_result = client.sam3_3d_infer(
        inference_input=image,
        mask_input=flat_polygons,
        model_id=SAM3_3D_MODEL_ID,
        # 'Fast' SAM-3D config
        output_meshes=False,
        output_scene=False,
        with_mesh_postprocess=False,
        with_texture_baking=False,
        use_distillations=True,
    )

    # NOTE(review): assumes the server returns one object per input mask, in order.
    detections.data["sam3_3d"] = sam3_3d_result["objects"]

    return detections

In [None]:
from base64 import b64decode
from io import BytesIO

import torch
from pytorch3d.io import IO
from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix

import rerun as rr

# Coordinate transforms used in make_scene_glb. They are constant, so they are built
# once here instead of on every call.
_Z_TO_Y_UP = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float)
_Y_TO_Z_UP = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float)
_R_VIEW = torch.tensor([[-1, 0, 0], [0, 0, -1], [0, -1, 0]], dtype=torch.float)
# Only every 100th point (1%) is sent to Rerun to speed up rendering.
_POINT_STRIDE = 100


def _entity_path(tracker_id: int) -> str:
    """Return the Rerun entity path of one tracked object (shared by log and clear)."""
    return f"objects/#{tracker_id}"


def log_to_rerun(annotated: np.ndarray, detections: sv.Detections, tracker: sv.ByteTrack | None, index: int):
    """Log one processed frame to Rerun: the 2D image plus a 3D box/point cloud per object.

    Args:
        annotated: annotated frame, logged as a BGR image at ``/camera/image``.
        detections: tracked detections; when they carry ``data["sam3_3d"]`` each one
            is logged under ``objects/#<tracker_id>``.
        tracker: tracker whose ``removed_tracks`` are cleared from the 3D view.
            ``None`` skips the clearing step.
        index: frame index, used as the ``"tick"`` timeline position.
    """
    rr.set_time("tick", sequence=index)
    rr.log("/camera/image", rr.Image(annotated, color_model="bgr"))

    # Clear removed tracks. The path must match the one used when logging below.
    if tracker is not None:
        for track in tracker.removed_tracks:
            tracker_id = track.external_track_id
            if tracker_id == -1:
                # Never got a public ID; detect_and_track drops -1 IDs, so nothing was logged.
                continue
            path = _entity_path(tracker_id)
            print(f"Clear {path}")
            rr.log(path, rr.Clear(recursive=True))

    has_3d = "sam3_3d" in detections.data
    ply_io = IO()

    # Add or update active tracks
    for i in range(len(detections)):
        # Read the i-th entries directly: `detections[i]` is a length-1 Detections whose
        # tracker_id formats as "[3]", which produced "objects/#[3]" and never matched
        # the "objects/#3" path cleared above.
        tracker_id = int(detections.tracker_id[i])
        obj_id = f"#{tracker_id}"
        if not has_3d:
            print(f"No 3D data available for {obj_id}")
            continue
        obj_sam3_3d = detections.data["sam3_3d"][i]

        obj_ply = ply_io.load_pointcloud(BytesIO(b64decode(obj_sam3_3d["gaussian_ply"])))
        all_pts = obj_ply.points_list()[0]
        # Measure the box on the full cloud so subsampling cannot shrink it, and
        # centre it on the cloud rather than on the local origin.
        pts_min = all_pts.amin(dim=0)
        pts_max = all_pts.amax(dim=0)
        obj_box_size = pts_max - pts_min
        obj_box_center = (pts_max + pts_min) / 2
        obj_pts = all_pts[::_POINT_STRIDE, :]
        obj_rgb = sv.annotators.utils.resolve_color(sv.ColorPalette.DEFAULT, detections, i).as_rgb()

        metadata = obj_sam3_3d["metadata"]
        t = torch.tensor(metadata["translation"], dtype=torch.float)
        R = quaternion_to_matrix(torch.tensor(metadata["rotation"], dtype=torch.float))
        s = torch.tensor(metadata["scale"], dtype=torch.float)
        # 1. Z-up -> Y-up coordinate conversion (row-vector convention throughout SAM3D)
        # 2. PyTorch3D quaternion_to_matrix is column-vector (R @ v), but SAM3D uses it
        #    row-vector (v @ R), so pass R.T to Rerun's column-vector mat3x3
        # 3. R_view: global scene correction from make_scene_glb, applied in world space
        t = t @ _Z_TO_Y_UP @ _R_VIEW
        R = _R_VIEW @ _Y_TO_Z_UP @ R.T @ _Z_TO_Y_UP

        path = _entity_path(tracker_id)
        rr.log(
            path,
            rr.Boxes3D(centers=obj_box_center, sizes=obj_box_size, colors=obj_rgb, labels=obj_id),
            rr.Transform3D(translation=t, mat3x3=R, scale=s),
        )
        rr.log(
            f"{path}/pts",
            rr.Points3D(positions=obj_pts, colors=obj_rgb),
        )

In [None]:
import itertools
import os
import time

# Initialize 3D viewer; everything logged below is also saved to OUTPUT_DIR.
rr.init("sam-3d-track")
rr.save(os.path.join(OUTPUT_DIR, "rerun_log.rrd"))
rr.log("/", rr.ViewCoordinates.RIGHT_HAND_Y_UP, rr.TransformAxes3D(0.5), static=True)

# Initialize tracker.
# ByteTrack turns lost_track_buffer into a frame count as
# int(frame_rate / 30 * lost_track_buffer), i.e. the buffer is expressed in 30 fps
# frames. Passing SAMPLE_FPS gave int(5 / 30 * 5) = 0, so lost tracks were removed
# immediately. 30 keeps a lost track alive for about one second (SAMPLE_FPS frames).
tracker = sv.ByteTrack(frame_rate=SAMPLE_FPS, lost_track_buffer=30)

# Read and process the video, sampling roughly SAMPLE_FPS frames per second.
video_info = sv.VideoInfo.from_video_path(INPUT_VIDEO_PATH)
# Round rather than truncate so the effective rate is as close to SAMPLE_FPS as
# possible, and never use a stride below 1 (SAMPLE_FPS above the source fps).
stride = max(1, round(video_info.fps / SAMPLE_FPS))
frames = sv.get_video_frames_generator(INPUT_VIDEO_PATH, stride=stride)
# islice(..., None) yields every frame, so MAX_FRAMES = None means "no limit".
frames = itertools.islice(frames, MAX_FRAMES)
# The annotated output video plays back at the sampling rate.
video_info.fps = SAMPLE_FPS

with sv.VideoSink(os.path.join(OUTPUT_DIR, "annotated.mp4"), video_info) as sink:
    start = time.perf_counter()

    for index, frame in enumerate(frames):
        frame_start = time.perf_counter()

        detections = detect_and_track(frame, tracker)
        annotated = annotate(frame, detections)
        sink.write_frame(annotated)

        detections = get_3d_objects(frame, detections)
        log_to_rerun(annotated, detections, tracker, index)

        elapsed = time.perf_counter() - frame_start
        print(f"Finished processing frame #{index} in {elapsed:.2f} sec")

    elapsed = time.perf_counter() - start
    print(f"Finished processing video in {elapsed:.2f} sec")

In [None]:
# Open an embedded Rerun viewer in the notebook and load the saved recording into it.
# You can also use the standalone viewer app
# rerun [OUTPUT_DIR]/rerun_log.rrd
# NOTE(review): the global recording may still be saving to this same .rrd file
# (rr.save in the processing cell); confirm that loading the file back into the
# recording does not duplicate its contents in the file.
rr.notebook_show()
rr.log_file_from_path(os.path.join(OUTPUT_DIR, "rerun_log.rrd"))