diff --git a/examples/python/arkitscenes/.gitignore b/examples/python/arkitscenes/.gitignore
new file mode 100644
index 000000000000..d1ac3a94954e
--- /dev/null
+++ b/examples/python/arkitscenes/.gitignore
@@ -0,0 +1 @@
+dataset/**
diff --git a/examples/python/arkitscenes/download_dataset.py b/examples/python/arkitscenes/download_dataset.py
new file mode 100644
index 000000000000..d2f5e2cbe62a
--- /dev/null
+++ b/examples/python/arkitscenes/download_dataset.py
@@ -0,0 +1,321 @@
+# Copied from https://github.com/apple/ARKitScenes/blob/main/download_data.py
+# Licensing information: https://github.com/apple/ARKitScenes/blob/main/LICENSE
+import math
+import os
+import subprocess
+from pathlib import Path
+from typing import Final, List, Optional
+
+import pandas as pd
+
+ARkitscense_url = "https://docs-assets.developer.apple.com/ml-research/datasets/arkitscenes/v1"
+TRAINING: Final = "Training"
+VALIDATION: Final = "Validation"
+HIGRES_DEPTH_ASSET_NAME: Final = "highres_depth"
+POINT_CLOUDS_FOLDER: Final = "laser_scanner_point_clouds"
+
+AVAILABLE_RECORDINGS: Final = ["48458663", "42444949", "41069046", "41125722", "41125763", "42446167"]
+DATASET_DIR: Final = Path(os.path.dirname(__file__)) / "dataset"
+
+default_raw_dataset_assets = [
+    "mov",
+    "annotation",
+    "mesh",
+    "confidence",
+    "highres_depth",
+    "lowres_depth",
+    "lowres_wide.traj",
+    "lowres_wide",
+    "lowres_wide_intrinsics",
+    "ultrawide",
+    "ultrawide_intrinsics",
+    "vga_wide",
+    "vga_wide_intrinsics",
+]
+
+missing_3dod_assets_video_ids = [
+    "47334522",
+    "47334523",
+    "42897421",
+    "45261582",
+    "47333152",
+    "47333155",
+    "48458535",
+    "48018733",
+    "47429677",
+    "48458541",
+    "42897848",
+    "47895482",
+    "47333960",
+    "47430089",
+    "42899148",
+    "42897612",
+    "42899153",
+    "42446164",
+    "48018149",
+    "47332198",
+    "47334515",
+    "45663223",
+    "45663226",
+    "45663227",
+]
+
+
+def raw_files(video_id: str, assets: List[str], metadata: pd.DataFrame) -> List[str]:
+    file_names = []
+    for asset in assets:
+        if HIGRES_DEPTH_ASSET_NAME == asset:
+            in_upsampling = metadata.loc[metadata["video_id"] == float(video_id), ["is_in_upsampling"]].iat[0, 0]
+            if not in_upsampling:
+                print(f"Skipping asset {asset} for video_id {video_id} - Video not in upsampling dataset")
+                continue  # highres_depth asset only available for video ids from upsampling dataset
+
+        if asset in [
+            "confidence",
+            "highres_depth",
+            "lowres_depth",
+            "lowres_wide",
+            "lowres_wide_intrinsics",
+            "ultrawide",
+            "ultrawide_intrinsics",
+            "wide",
+            "wide_intrinsics",
+            "vga_wide",
+            "vga_wide_intrinsics",
+        ]:
+            file_names.append(asset + ".zip")
+        elif asset == "mov":
+            file_names.append(f"{video_id}.mov")
+        elif asset == "mesh":
+            if video_id not in missing_3dod_assets_video_ids:
+                file_names.append(f"{video_id}_3dod_mesh.ply")
+        elif asset == "annotation":
+            if video_id not in missing_3dod_assets_video_ids:
+                file_names.append(f"{video_id}_3dod_annotation.json")
+        elif asset == "lowres_wide.traj":
+            if video_id not in missing_3dod_assets_video_ids:
+                file_names.append("lowres_wide.traj")
+        else:
+            raise Exception(f"No asset = {asset} in raw dataset")
+    return file_names
+
+
+def download_file(url: str, file_name: str, dst: Path) -> bool:
+    os.makedirs(dst, exist_ok=True)
+    filepath = os.path.join(dst, file_name)
+
+    if not os.path.isfile(filepath):
+        command = f"curl {url} -o {file_name}.tmp --fail"
+        print(f"Downloading file {filepath}")
+        try:
+            subprocess.check_call(command, shell=True, cwd=dst)
+        except Exception as error:
+            print(f"Error downloading {url}, error: {error}")
+            return False
+        os.rename(filepath + ".tmp", filepath)
+    else:
+        print(f"WARNING: skipping download of existing file: {filepath}")
+    return True
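Note that `download_file` shells out to a `curl` binary. Where `curl` is not available, a stdlib-only variant along these lines could work; this is a minimal, untested sketch, and `download_file_stdlib` is a hypothetical name, not part of the upstream script:

    import urllib.request
    from pathlib import Path

    def download_file_stdlib(url: str, file_name: str, dst: Path) -> bool:
        dst.mkdir(parents=True, exist_ok=True)
        filepath = dst / file_name
        if filepath.is_file():
            print(f"WARNING: skipping download of existing file: {filepath}")
            return True
        tmp_path = Path(str(filepath) + ".tmp")
        try:
            # urlretrieve raises URLError/HTTPError (both OSError subclasses) on failure
            urllib.request.urlretrieve(url, tmp_path)
        except OSError as error:
            print(f"Error downloading {url}, error: {error}")
            return False
        tmp_path.rename(filepath)  # move into place only after a complete download
        return True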
+
+
+def unzip_file(file_name: str, dst: Path, keep_zip: bool = True) -> bool:
+    filepath = os.path.join(dst, file_name)
+    print(f"Unzipping zip file {filepath}")
+    command = f"unzip -oq {filepath} -d {dst}"
+    try:
+        subprocess.check_call(command, shell=True)
+    except Exception as error:
+        print(f"Error unzipping {filepath}, error: {error}")
+        return False
+    if not keep_zip:
+        os.remove(filepath)
+    return True
+
+
+def download_laser_scanner_point_clouds_for_video(video_id: str, metadata: pd.DataFrame, download_dir: Path) -> None:
+    video_metadata = metadata.loc[metadata["video_id"] == float(video_id)]
+    visit_id = video_metadata["visit_id"].iat[0]
+    has_laser_scanner_point_clouds = video_metadata["has_laser_scanner_point_clouds"].iat[0]
+
+    if not has_laser_scanner_point_clouds:
+        print(f"Warning: Laser scanner point clouds for video {video_id} are not available")
+        return
+
+    if math.isnan(visit_id) or not visit_id.is_integer():
+        print(f"Warning: Downloading laser scanner point clouds for video {video_id} failed - Bad visit id {visit_id}")
+        return
+
+    visit_id = int(visit_id)  # Expecting an 8 digit integer
+    laser_scanner_point_clouds_ids = laser_scanner_point_clouds_for_visit_id(visit_id, download_dir)
+
+    for point_cloud_id in laser_scanner_point_clouds_ids:
+        download_laser_scanner_point_clouds(point_cloud_id, visit_id, download_dir)
+
+
+def laser_scanner_point_clouds_for_visit_id(visit_id: int, download_dir: Path) -> List[str]:
+    point_cloud_to_visit_id_mapping_filename = "laser_scanner_point_clouds_mapping.csv"
+    # Check for the mapping file where it is actually downloaded to, not relative to the CWD
+    if not os.path.exists(os.path.join(download_dir, point_cloud_to_visit_id_mapping_filename)):
+        point_cloud_to_visit_id_mapping_url = (
+            f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{point_cloud_to_visit_id_mapping_filename}"
+        )
+        if not download_file(
+            point_cloud_to_visit_id_mapping_url,
+            point_cloud_to_visit_id_mapping_filename,
+            download_dir,
+        ):
+            print(
+                f"Error downloading point cloud for visit_id {visit_id} at location "
+                f"{point_cloud_to_visit_id_mapping_url}"
+            )
+            return []
+
+    point_cloud_to_visit_id_mapping_filepath = os.path.join(download_dir, point_cloud_to_visit_id_mapping_filename)
+    point_cloud_to_visit_id_mapping = pd.read_csv(point_cloud_to_visit_id_mapping_filepath)
+    point_cloud_ids = point_cloud_to_visit_id_mapping.loc[
+        point_cloud_to_visit_id_mapping["visit_id"] == visit_id,
+        ["laser_scanner_point_clouds_id"],
+    ]
+    point_cloud_ids_list = [scan_id[0] for scan_id in point_cloud_ids.values]
+
+    return point_cloud_ids_list
+
+
+def download_laser_scanner_point_clouds(laser_scanner_point_cloud_id: str, visit_id: int, download_dir: Path) -> None:
+    laser_scanner_point_clouds_folder_path = download_dir / POINT_CLOUDS_FOLDER / str(visit_id)
+    os.makedirs(laser_scanner_point_clouds_folder_path, exist_ok=True)
+
+    for extension in [".ply", "_pose.txt"]:
+        filename = f"{laser_scanner_point_cloud_id}{extension}"
+        filepath = os.path.join(laser_scanner_point_clouds_folder_path, filename)
+        if os.path.exists(filepath):
+            continue  # skip this file, but still check its companion file
+        file_url = f"{ARkitscense_url}/raw/laser_scanner_point_clouds/{visit_id}/{filename}"
+        download_file(file_url, filename, laser_scanner_point_clouds_folder_path)
+
+
+def get_metadata(dataset: str, download_dir: Path) -> Optional[pd.DataFrame]:
+    filename = "metadata.csv"
+    url = f"{ARkitscense_url}/threedod/{filename}" if "3dod" == dataset else f"{ARkitscense_url}/{dataset}/{filename}"
+    dst_folder = download_dir / dataset
+    dst_file = dst_folder / filename
+
+    if not download_file(url, filename, dst_folder):
+        return None
+
+    metadata = pd.read_csv(dst_file)
+    return metadata
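For reference, the `is_in_upsampling` check that `raw_files` performs against the frame returned by `get_metadata` looks like this in isolation; the column names match the code above, but the rows here are toy values, not real dataset entries:

    import pandas as pd

    metadata = pd.DataFrame(
        {"video_id": [48458663.0, 42444949.0], "is_in_upsampling": [True, False]}
    )
    video_id = "42444949"
    in_upsampling = metadata.loc[
        metadata["video_id"] == float(video_id), ["is_in_upsampling"]
    ].iat[0, 0]
    print(in_upsampling)  # -> False, so the highres_depth asset would be skipped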
+
+
+def download_data(
+    dataset: str,
+    video_ids: List[str],
+    dataset_splits: List[str],
+    download_dir: Path,
+    keep_zip: bool,
+    raw_dataset_assets: Optional[List[str]] = None,
+    should_download_laser_scanner_point_cloud: bool = False,
+) -> None:
+    """
+    Downloads data from the specified dataset and video IDs to the given download directory.
+
+    Args:
+    ----
+    dataset: the name of the dataset to download from (raw, 3dod, or upsampling)
+    video_ids: the list of video IDs to download data for
+    dataset_splits: the list of splits for each video ID (train, validation, or test)
+    download_dir: the directory to download data to
+    keep_zip: whether to keep the downloaded zip files after extracting them
+    raw_dataset_assets: a list of asset types to download from the raw dataset, if dataset is "raw"
+    should_download_laser_scanner_point_cloud: whether to download the laser scanner point cloud data, if available
+
+    Returns: None
+    """
+    metadata = get_metadata(dataset, download_dir)
+    if metadata is None:
+        print(f"Error retrieving metadata for dataset {dataset}")
+        return
+
+    for video_id in sorted(set(video_ids)):
+        split = dataset_splits[video_ids.index(video_id)]
+        dst_dir = download_dir / dataset / split
+        if dataset == "raw":
+            url_prefix = ""
+            file_names = []
+            if not raw_dataset_assets:
+                print(f"Warning: No raw assets given for video id {video_id}")
+            else:
+                dst_dir = dst_dir / str(video_id)
+                url_prefix = f"{ARkitscense_url}/raw/{split}/{video_id}" + "/{}"
+                file_names = raw_files(video_id, raw_dataset_assets, metadata)
+        elif dataset == "3dod":
+            url_prefix = f"{ARkitscense_url}/threedod/{split}" + "/{}"
+            file_names = [
+                f"{video_id}.zip",
+            ]
+        elif dataset == "upsampling":
+            url_prefix = f"{ARkitscense_url}/upsampling/{split}" + "/{}"
+            file_names = [
+                f"{video_id}.zip",
+            ]
+        else:
+            raise Exception(f"No such dataset = {dataset}")
+
+        if should_download_laser_scanner_point_cloud and dataset == "raw":
+            # Point clouds only available for the raw dataset
+            download_laser_scanner_point_clouds_for_video(video_id, metadata, download_dir)
+
+        for file_name in file_names:
+            dst_path = os.path.join(dst_dir, file_name)
+            url = url_prefix.format(file_name)
+
+            if not file_name.endswith(".zip") or not os.path.isdir(dst_path[: -len(".zip")]):
+                download_file(url, dst_path, dst_dir)
+            else:
+                print(f"WARNING: skipping download of existing zip file: {dst_path}")
+            if file_name.endswith(".zip") and os.path.isfile(dst_path):
+                unzip_file(file_name, dst_dir, keep_zip)
+
+
+def ensure_recording_downloaded(video_id: str, include_highres: bool) -> Path:
+    """Only downloads from the validation set."""
+    data_path = DATASET_DIR / "raw" / "Validation" / video_id
+    assets_to_download = [
+        "lowres_wide",
+        "lowres_depth",
+        "lowres_wide_intrinsics",
+        "lowres_wide.traj",
+        "annotation",
+        "mesh",
+    ]
+    if include_highres:
+        assets_to_download.extend(["highres_depth", "wide", "wide_intrinsics"])
+    download_data(
+        dataset="raw",
+        video_ids=[video_id],
+        dataset_splits=[VALIDATION],
+        download_dir=DATASET_DIR,
+        keep_zip=False,
+        raw_dataset_assets=assets_to_download,
+        should_download_laser_scanner_point_cloud=False,
+    )
+    return data_path
+
+
+def ensure_recording_available(video_id: str, include_highres: bool) -> Path:
+    """
+    Returns the path to the recording for a given video_id.
+
+    Args:
+        video_id (str): Identifier for the recording.
+        include_highres (bool): Whether high-resolution assets should also be downloaded.
+
+    Returns
+    -------
+    Path: Path object representing the path to the recording.
+
+    Raises
+    ------
+    AssertionError: If the recording path does not exist.
+    """
+    recording_path = ensure_recording_downloaded(video_id, include_highres)
+    assert recording_path.exists(), f"Recording path {recording_path} does not exist."
+    return recording_path  # Return the path to the recording
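Taken together, a minimal driver for the helpers in this file could look like the sketch below. It mirrors what `main.py` does through `ensure_recording_available`; running it triggers the actual downloads into `dataset/`:

    from download_dataset import AVAILABLE_RECORDINGS, ensure_recording_available

    # Fetch the low-res assets for the default recording (no high-res data)
    recording_path = ensure_recording_available(AVAILABLE_RECORDINGS[0], include_highres=False)
    print(f"Recording available at {recording_path}")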
diff --git a/examples/python/arkitscenes/main.py b/examples/python/arkitscenes/main.py
new file mode 100755
index 000000000000..5b2c62027022
--- /dev/null
+++ b/examples/python/arkitscenes/main.py
@@ -0,0 +1,450 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+import rerun as rr
+import trimesh
+from download_dataset import AVAILABLE_RECORDINGS, ensure_recording_available
+from scipy.spatial.transform import Rotation as R
+from tqdm import tqdm
+
+# Hack for now: the dataset does not provide orientation information, so this is
+# only known after an initial visual inspection of each recording.
+ORIENTATION = {
+    "48458663": "landscape",
+    "42444949": "portrait",
+    "41069046": "portrait",
+    "41125722": "portrait",
+    "41125763": "portrait",
+    "42446167": "portrait",
+}
+assert len(ORIENTATION) == len(AVAILABLE_RECORDINGS)
+assert set(ORIENTATION.keys()) == set(AVAILABLE_RECORDINGS)
+
+
+def load_json(js_path: Path) -> Dict[str, Any]:
+    with open(js_path, "r") as f:
+        json_data = json.load(f)  # type: Dict[str, Any]
+    return json_data
+
+
+def log_annotated_bboxes(
+    annotation: Dict[str, Any]
+) -> Tuple[npt.NDArray[np.float64], List[str], List[Tuple[int, int, int, int]]]:
+    """
+    Logs annotated oriented bounding boxes to Rerun.
+
+    We currently calculate and return the 3D bounding box keypoints, labels, and colors for each object so
+    that they can be logged in each camera frame.
+    TODO(pablovela5620): Once #1581 is resolved this can be removed.
+
+    annotation json file
+    |  |-- label: object name of bounding box
+    |  |-- axesLengths[x, y, z]: size of the origin bounding-box before transforming
+    |  |-- centroid[]: the translation matrix (1,3) of bounding-box
+    |  |-- normalizedAxes[]: the rotation matrix (3,3) of bounding-box
+    """
+    bbox_list = []
+    bbox_labels = []
+    num_objects = len(annotation["data"])
+    # Generate a color per object that can be reused across both the 3D obbs and their 2D projections
+    # TODO(pablovela5620): Once #1581 or #1728 is resolved this can be removed
+    color_positions = np.linspace(0, 1, num_objects)
+    colormap = plt.cm.get_cmap("viridis")
+    color_array_float = [colormap(pos) for pos in color_positions]
+    color_list = [(int(r * 255), int(g * 255), int(b * 255), int(a * 255)) for r, g, b, a in color_array_float]
+
+    for i, label_info in enumerate(annotation["data"]):
+        uid = label_info["uid"]
+        label = label_info["label"]
+
+        # TODO(pablovela5620): halve this value once #1701 is resolved
+        scale = np.array(label_info["segments"]["obbAligned"]["axesLengths"]).reshape(-1, 3)[0]
+        transform = np.array(label_info["segments"]["obbAligned"]["centroid"]).reshape(-1, 3)[0]
+        rotation = np.array(label_info["segments"]["obbAligned"]["normalizedAxes"]).reshape(3, 3)
+
+        rot = R.from_matrix(rotation).inv()
+
+        rr.log_obb(
+            f"world/annotations/box-{uid}-{label}",
+            half_size=scale,
+            position=transform,
+            rotation_q=rot.as_quat(),
+            label=label,
+            color=color_list[i],
+            timeless=True,
+        )
+
+        box3d = compute_box_3d(scale, transform, rotation)
+        bbox_list.append(box3d)
+        bbox_labels.append(label)
+    bboxes_3d = np.array(bbox_list)
+    return bboxes_3d, bbox_labels, color_list
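The per-object color assignment above, shown stand-alone with a toy object count (same calls as in the function, just isolated):

    import matplotlib.pyplot as plt
    import numpy as np

    num_objects = 3
    colormap = plt.cm.get_cmap("viridis")
    color_positions = np.linspace(0, 1, num_objects)
    # Sample viridis at evenly spaced positions and convert to 0-255 RGBA tuples
    colors = [tuple(int(c * 255) for c in colormap(pos)) for pos in color_positions]
    print(colors)  # three RGBA tuples, one per annotated object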
+
+
+def compute_box_3d(
+    scale: npt.NDArray[np.float64], transform: npt.NDArray[np.float64], rotation: npt.NDArray[np.float64]
+) -> npt.NDArray[np.float64]:
+    """
+    Given an obb, compute the 3D keypoints of the box.
+
+    TODO(pablovela5620): Once #1581 is resolved this can be removed
+    """
+    scale = scale.tolist()
+    scales = [i / 2 for i in scale]
+    length, height, width = scales
+    center = np.reshape(transform, (-1, 3))
+    center = center.reshape(3)
+    x_corners = [length, length, -length, -length, length, length, -length, -length]
+    y_corners = [height, -height, -height, height, height, -height, -height, height]
+    z_corners = [width, width, width, width, -width, -width, -width, -width]
+    corners_3d = np.dot(np.transpose(rotation), np.vstack([x_corners, y_corners, z_corners]))
+
+    corners_3d[0, :] += center[0]
+    corners_3d[1, :] += center[1]
+    corners_3d[2, :] += center[2]
+    bbox3d_raw = np.transpose(corners_3d)
+    return bbox3d_raw
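A quick sanity check for `compute_box_3d`: with identity rotation and the box centered at the origin, the eight corners are simply plus or minus half the `axesLengths`. This assumes `main.py` is importable from the example directory (its CLI only runs under the `__main__` guard):

    import numpy as np
    from main import compute_box_3d

    corners = compute_box_3d(
        scale=np.array([2.0, 2.0, 2.0]),  # axesLengths -> half-extents of 1.0
        transform=np.zeros(3),            # centroid at the origin
        rotation=np.eye(3),               # no rotation
    )
    assert corners.shape == (8, 3)
    assert np.allclose(np.abs(corners), 1.0)  # every corner lands at (+/-1, +/-1, +/-1)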
+ """ + + # Calculate the centroid of the 2D keypoints + valid_points = bboxes_2d_filtered[~np.isnan(bboxes_2d_filtered).any(axis=1)] + + # log centroid and add label so that object label is visible in the 2d view + if valid_points.size > 0: + centroid = valid_points.mean(axis=0) + rr.log_point(f"{entity_path}/centroid", centroid, color=color, label=label) + else: + pass + + # fmt: off + segments = np.array([ + # bottom of bbox + bboxes_2d_filtered[0], bboxes_2d_filtered[1], + bboxes_2d_filtered[1], bboxes_2d_filtered[2], + bboxes_2d_filtered[2], bboxes_2d_filtered[3], + bboxes_2d_filtered[3], bboxes_2d_filtered[0], + + # top of bbox + bboxes_2d_filtered[4], bboxes_2d_filtered[5], + bboxes_2d_filtered[5], bboxes_2d_filtered[6], + bboxes_2d_filtered[6], bboxes_2d_filtered[7], + bboxes_2d_filtered[7], bboxes_2d_filtered[4], + + # sides of bbox + bboxes_2d_filtered[0], bboxes_2d_filtered[4], + bboxes_2d_filtered[1], bboxes_2d_filtered[5], + bboxes_2d_filtered[2], bboxes_2d_filtered[6], + bboxes_2d_filtered[3], bboxes_2d_filtered[7] + ], dtype=np.float32) + + rr.log_line_segments(entity_path, segments, color=color) + + +def project_3d_bboxes_to_2d_keypoints( + bboxes_3d: npt.NDArray[np.float64], + camera_from_world: Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], + intrinsic: npt.NDArray[np.float64], + img_width: int, + img_height: int, +) -> npt.NDArray[np.float64]: + """ + Returns 2D keypoints of the 3D bounding box in the camera view. + + TODO(pablovela5620): Once #1581 is resolved this can be removed + Args: + bboxes_3d: (nObjects, 8, 3) containing the 3D bounding box keypoints in world frame. + camera_from_world: Tuple containing the camera translation and rotation_quaternion in world frame. + intrinsic: (3,3) containing the camera intrinsic matrix. + img_width: Width of the image. + img_height: Height of the image. + + Returns + ------- + bboxes_2d_filtered: + A numpy array of shape (nObjects, 8, 2), representing the 2D keypoints of the 3D bounding boxes. That + are within the image frame. 
+ """ + + translation, rotation_q = camera_from_world + rotation = R.from_quat(rotation_q) + + # Transform 3D keypoints from world to camera frame + world_to_camera_rotation = rotation.as_matrix() + world_to_camera_translation = translation.reshape(3, 1) + # Tile translation to match bounding box shape, (nObjects, 1, 3) + world_to_camera_translation_tiled = np.tile(world_to_camera_translation.T, (bboxes_3d.shape[0], 1, 1)) + # Transform 3D bounding box keypoints from world to camera frame to filter out points behind the camera + camera_points = ( + np.einsum("ij,afj->afi", world_to_camera_rotation, bboxes_3d[..., :3]) + world_to_camera_translation_tiled + ) + # Check if the points are in front of the camera + depth_mask = camera_points[..., 2] > 0 + # convert to transformation matrix shape of (3, 4) + world_to_camera = np.hstack([world_to_camera_rotation, world_to_camera_translation]) + transformation_matrix = intrinsic @ world_to_camera + # add batch dimension to match bounding box shape, (nObjects, 3, 4) + transformation_matrix = np.tile(transformation_matrix, (bboxes_3d.shape[0], 1, 1)) + # bboxes_3d: [nObjects, 8, 3] -> [nObjects, 8, 4] to allow for batch projection + bboxes_3d = np.concatenate([bboxes_3d, np.ones((bboxes_3d.shape[0], bboxes_3d.shape[1], 1))], axis=-1) + # Apply depth mask to filter out points behind the camera + bboxes_3d[~depth_mask] = np.nan + # batch projection of points using einsum + bboxes_2d = np.einsum("vab,fnb->vfna", transformation_matrix, bboxes_3d) + bboxes_2d = bboxes_2d[..., :2] / bboxes_2d[..., 2:] + # nViews irrelevant, squeeze out + bboxes_2d = bboxes_2d[0] + + # Filter out keypoints that are not within the frame + mask_x = (bboxes_2d[:, :, 0] >= 0) & (bboxes_2d[:, :, 0] < img_width) + mask_y = (bboxes_2d[:, :, 1] >= 0) & (bboxes_2d[:, :, 1] < img_height) + mask = mask_x & mask_y + bboxes_2d_filtered = np.where(mask[..., np.newaxis], bboxes_2d, np.nan) + + return bboxes_2d_filtered + + +def log_camera( + intri_path: Path, + frame_id: str, + poses_from_traj: Dict[str, Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]], + entity_id: str, + bboxes: npt.NDArray[np.float64], + bbox_labels: List[str], + color_list: List[Tuple[int, int, int, int]], +) -> None: + """Logs camera transform and 3D bounding boxes in the image frame.""" + w, h, fx, fy, cx, cy = np.loadtxt(intri_path) + intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + camera_from_world = poses_from_traj[frame_id] + + # TODO(pablovela5620): Once #1581 is resolved this can be removed + # Project 3D bounding boxes into 2D image + bboxes_2d = project_3d_bboxes_to_2d_keypoints(bboxes, camera_from_world, intrinsic, img_width=w, img_height=h) + # clear previous centroid labels + rr.log_cleared(f"{entity_id}/bbox-2d-segments", recursive=True) + # Log line segments for each bounding box in the image + for i, (label, bbox_2d) in enumerate(zip(bbox_labels, bboxes_2d)): + log_line_segments(f"{entity_id}/bbox-2d-segments/{label}", bbox_2d.reshape(-1, 2), color_list[i], label) + + rr.log_rigid3( + # pathlib makes it easy to get the parent, but log_rigid requires a string + str(Path(entity_id).parent), + child_from_parent=camera_from_world, + xyz="RDF", # X=Right, Y=Down, Z=Forward + ) + rr.log_pinhole(f"{entity_id}", child_from_parent=intrinsic, width=w, height=h) + + +def read_camera_from_world(traj_string: str) -> Tuple[str, Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]]: + """ + Reads out camera_from_world transform from trajectory string. 
+
+
+def read_camera_from_world(traj_string: str) -> Tuple[str, Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]]:
+    """
+    Reads out the camera_from_world transform from a trajectory string.
+
+    Args:
+        traj_string: A single line from the space-delimited trajectory file, where each line represents a
+            camera position at a particular timestamp. The line has seven columns:
+            * Column 1: timestamp
+            * Columns 2-4: rotation (axis-angle representation in radians)
+            * Columns 5-7: translation (usually in meters)
+
+    Returns
+    -------
+    timestamp: str
+        Timestamp in seconds.
+    camera_from_world: tuple of two numpy arrays
+        A tuple containing a translation vector and a quaternion that represent the camera_from_world transform
+
+    Raises
+    ------
+    AssertionError: If the input string does not contain 7 tokens.
+    """
+    tokens = traj_string.split()  # Split the input string into tokens
+    assert len(tokens) == 7, f"Input string must have 7 tokens, but found {len(tokens)}."
+    ts: str = tokens[0]  # Extract the timestamp from the first token
+
+    # Extract the rotation from the second to fourth tokens
+    angle_axis = [float(tokens[1]), float(tokens[2]), float(tokens[3])]
+    rotation = R.from_rotvec(np.asarray(angle_axis))
+
+    # Extract the translation from the fifth to seventh tokens
+    translation = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])])
+
+    # Create the tuple in the format log_rigid3 expects
+    camera_from_world = (translation, rotation.as_quat())
+
+    return (ts, camera_from_world)
+
+
+def find_closest_frame_id(target_id: str, frame_ids: Dict[str, Any]) -> str:
+    """Finds the closest frame id to the target id."""
+    target_value = float(target_id)
+    closest_id = min(frame_ids.keys(), key=lambda x: abs(float(x) - target_value))
+    return closest_id
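One synthetic trajectory line run through the same parsing logic as `read_camera_from_world`; the column layout matches the docstring above, but the numbers are made up:

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    traj_string = "1234.567 0.0 0.0 1.5708 0.1 0.2 0.3"
    tokens = traj_string.split()
    rotation = R.from_rotvec(np.asarray([float(t) for t in tokens[1:4]]))
    translation = np.asarray([float(t) for t in tokens[4:7]])
    print(tokens[0], rotation.as_quat(), translation)
    # -> 1234.567 [0. 0. 0.7071... 0.7071...] [0.1 0.2 0.3]  (~90 degrees about Z, as (x, y, z, w))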
+
+
+def log_arkit(recording_path: Path, include_highres: bool) -> None:
+    """
+    Logs ARKit recording data using Rerun.
+
+    Args:
+        recording_path (Path): The path to the ARKit recording.
+        include_highres (bool): Whether to also log the high resolution camera and depth images.
+
+    Returns
+    -------
+    None
+    """
+    video_id = recording_path.stem
+    lowres_image_dir = recording_path / "lowres_wide"
+    image_dir = recording_path / "wide"
+    lowres_depth_dir = recording_path / "lowres_depth"
+    depth_dir = recording_path / "highres_depth"
+    lowres_intrinsics_dir = recording_path / "lowres_wide_intrinsics"
+    intrinsics_dir = recording_path / "wide_intrinsics"
+    traj_path = recording_path / "lowres_wide.traj"
+
+    # frame_ids are indexed by timestamps, you can find more info here:
+    # https://github.com/apple/ARKitScenes/blob/main/threedod/README.md#data-organization-and-format-of-input-data
+    depth_filenames = [x.name for x in sorted(lowres_depth_dir.iterdir())]
+    lowres_frame_ids = [x.split(".png")[0].split("_")[1] for x in depth_filenames]
+    lowres_frame_ids.sort()
+
+    # Dict of timestamp to pose, which is a tuple of translation and quaternion
+    camera_from_world_dict = {}
+    with open(traj_path, "r", encoding="utf-8") as f:
+        trajectory = f.readlines()
+
+    for line in trajectory:
+        timestamp, camera_from_world = read_camera_from_world(line)
+        # Round the timestamp to 3 decimal places as seen in the original repo here:
+        # https://github.com/apple/ARKitScenes/blob/e2e975128a0a9695ea56fa215fe76b4295241538/threedod/benchmark_scripts/utils/tenFpsDataLoader.py#L247
+        timestamp = f"{round(float(timestamp), 3):.3f}"
+        camera_from_world_dict[timestamp] = camera_from_world
+
+    rr.log_view_coordinates("world", up="+Z", right_handed=True, timeless=True)
+    ply_path = recording_path / f"{recording_path.stem}_3dod_mesh.ply"
+    print(f"Loading {ply_path}…")
+    assert os.path.isfile(ply_path), f"Failed to find {ply_path}"
+
+    mesh_ply = trimesh.load(str(ply_path))
+    rr.log_mesh(
+        "world/mesh",
+        positions=mesh_ply.vertices,
+        indices=mesh_ply.faces,
+        vertex_colors=mesh_ply.visual.vertex_colors,
+        timeless=True,
+    )
+
+    # Load the obb annotations and log them in the world frame
+    bbox_annotations_path = recording_path / f"{recording_path.stem}_3dod_annotation.json"
+    annotation = load_json(bbox_annotations_path)
+    bboxes_3d, bbox_labels, colors_list = log_annotated_bboxes(annotation)
+
+    lowres_posed_entity_id = "world/camera_posed_lowres/image_posed_lowres"
+    highres_entity_id = "world/camera_highres/image_highres"
+
+    print("Processing frames…")
+    for frame_timestamp in tqdm(lowres_frame_ids):
+        # frame_id is equivalent to timestamp
+        rr.set_time_seconds("time", float(frame_timestamp))
+        # Load the low-res image and depth
+        bgr = cv2.imread(f"{lowres_image_dir}/{video_id}_{frame_timestamp}.png")
+        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+        depth = cv2.imread(f"{lowres_depth_dir}/{video_id}_{frame_timestamp}.png", cv2.IMREAD_ANYDEPTH)
+
+        high_res_exists: bool = (image_dir / f"{video_id}_{frame_timestamp}.png").exists() and include_highres
+
+        # Log the camera transforms:
+        if frame_timestamp in camera_from_world_dict:
+            lowres_intri_path = lowres_intrinsics_dir / f"{video_id}_{frame_timestamp}.pincam"
+            log_camera(
+                lowres_intri_path,
+                frame_timestamp,
+                camera_from_world_dict,
+                lowres_posed_entity_id,
+                bboxes_3d,
+                bbox_labels,
+                colors_list,
+            )
+
+        rr.log_image(f"{lowres_posed_entity_id}/rgb", rgb)
+        rr.log_depth_image(f"{lowres_posed_entity_id}/depth", depth, meter=1000)
+
+        # Log the high-res camera
+        if high_res_exists:
+            rr.set_time_seconds("time high resolution", float(frame_timestamp))
+            # Only the low-res camera has a trajectory; the high-res one does not, so we need to find the
+            # closest low-res frame id
+            closest_lowres_frame_id = find_closest_frame_id(frame_timestamp, camera_from_world_dict)
+            highres_intri_path = intrinsics_dir / f"{video_id}_{frame_timestamp}.pincam"
+            log_camera(
+                highres_intri_path,
+                closest_lowres_frame_id,
+                camera_from_world_dict,
+                highres_entity_id,
+                bboxes_3d,
+                bbox_labels,
+                colors_list,
+            )
+
+            # Load the high-res image and depth if they exist
+            highres_bgr = cv2.imread(f"{image_dir}/{video_id}_{frame_timestamp}.png")
+            highres_depth = cv2.imread(f"{depth_dir}/{video_id}_{frame_timestamp}.png", cv2.IMREAD_ANYDEPTH)
+
+            highres_rgb = cv2.cvtColor(highres_bgr, cv2.COLOR_BGR2RGB)
+            rr.log_image(f"{highres_entity_id}/rgb", highres_rgb)
+            rr.log_depth_image(f"{highres_entity_id}/depth", highres_depth, meter=1000)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Visualizes the ARKitScenes dataset using the Rerun SDK.")
+    parser.add_argument(
+        "--video-id",
+        type=str,
+        choices=AVAILABLE_RECORDINGS,
+        default=AVAILABLE_RECORDINGS[0],
+        help="Video ID of the ARKitScenes Dataset",
+    )
+    parser.add_argument(
+        "--include-highres",
+        action="store_true",
+        help="Include the high resolution camera and depth images",
+    )
+    rr.script_add_args(parser)
+    args = parser.parse_args()
+
+    rr.script_setup(args, "arkitscenes")
+    recording_path = ensure_recording_available(args.video_id, args.include_highres)
+    log_arkit(recording_path, args.include_highres)
+
+    rr.script_teardown(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/python/arkitscenes/requirements.txt b/examples/python/arkitscenes/requirements.txt
new file mode 100644
index 000000000000..132f56a29b4d
--- /dev/null
+++ b/examples/python/arkitscenes/requirements.txt
@@ -0,0 +1,8 @@
+rerun-sdk
+numpy
+pandas
+matplotlib
+opencv-python
+tqdm
+scipy
+trimesh
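The nearest-pose lookup used in the high-res branch of `log_arkit`, shown in isolation: the high-res stream has no trajectory of its own, so the pose of the nearest low-res timestamp is reused. The timestamps here are made up:

    frame_ids = {"1234.500": None, "1234.567": None, "1234.633": None}
    target_id = "1234.570"
    closest_id = min(frame_ids.keys(), key=lambda x: abs(float(x) - float(target_id)))
    print(closest_id)  # -> 1234.567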