In [1]:
import argparse
import os
import sys
from pathlib import Path
import json
import matplotlib.pyplot as plt
%matplotlib inline

import imageio
import numpy as np
from tqdm import tqdm
try:
    import renderpy
except ImportError:
    print("renderpy not installed. Please install renderpy from https://github.com/liu115/renderpy")
    sys.exit(1)

In [2]:
from utils.colmap import read_model, write_model, Image
from scene_release import ScannetppScene_Release
from utils.utils import run_command, load_yaml_munch, load_json, read_txt_list

In [3]:
# The config path is an optional positional argument with a default value.
# parse_args is given an explicit empty list so the notebook never picks up
# whatever command line the kernel process itself was started with.
p = argparse.ArgumentParser()
p.add_argument(
    "config_file",
    nargs="?",
    default="/home/kumaraditya/scannetpp/common/configs/render.yml",
    help="Path to config file",
)
args = p.parse_args(args=[])

print("Config file: {}".format(args.config_file))


Config file: /home/kumaraditya/scannetpp/common/configs/render.yml


In [4]:
cfg = load_yaml_munch(args.config_file)

# Scenes to process: an explicit `scene_ids` list wins; otherwise every scene
# listed in <data_root>/splits/<split>.txt for each configured split.
if cfg.get("scene_ids"):
    scene_ids = cfg.scene_ids
elif cfg.get("splits"):
    scene_ids = []
    for split in cfg.splits:
        split_path = Path(cfg.data_root) / "splits" / f"{split}.txt"
        scene_ids += read_txt_list(split_path)
else:
    # Without this, `scene_ids` would be undefined (or silently stale from an
    # earlier run of this cell) and only fail much later.
    raise ValueError("Config must define either 'scene_ids' or 'splits'.")

output_dir = cfg.get("output_dir")
if output_dir is None:
    # default to data folder in data_root
    output_dir = Path(cfg.data_root) / "data"
output_dir = Path(output_dir)
# Later cells build paths from cfg.output_dir directly; store the resolved value
# so they do not fail with Path(None) when the config leaves output_dir unset.
# (Assumes cfg is a Munch, where attribute assignment sets the key.)
cfg.output_dir = output_dir

render_devices = []
if cfg.get("render_dslr", False):
    render_devices.append("dslr")
    raise Exception("This code is has not been tested with the DSLR data.")
if cfg.get("render_iphone", False):
    render_devices.append("iphone")

In [5]:
# This notebook works on a single scene: the first one in the configured list.
scene_id = scene_ids[0]
for value in (scene_id, render_devices):
    print(value)

c0f5742640
['iphone']


In [6]:
from render_crops_utils import vert_to_obj_lookup, CropHeap, crop_rgb_mask, plot_grid_images

scene = ScannetppScene_Release(scene_id, data_root=Path(cfg.data_root) / "data")
render_engine = renderpy.Render()
render_engine.setupMesh(str(scene.scan_mesh_path))

# Load annotations. The `with` block closes the file handle promptly
# (a bare json.load(open(...)) leaves that to the garbage collector).
with open(scene.scan_anno_json_path, "r") as anno_file:
    segments_anno = json.load(anno_file)
n_objects = len(segments_anno["segGroups"])
# Random colour per instance (+1 row for background). Not used further in this
# cell; kept for ad-hoc instance visualisation (instance_colors[pix_instance]).
instance_colors = np.random.randint(low=0, high=256, size=(n_objects + 1, 3), dtype=np.uint8)
instance_colors[0] = 255  # White bg
vert_to_obj = vert_to_obj_lookup(segments_anno)

# One bounded crop heap (max 4 crops) per annotated object, keyed by object id.
crop_heaps = {
    obj["id"]: {"label": obj["label"], "heap": CropHeap(max_size=4)}
    for obj in segments_anno["segGroups"]
}

# Background class is 0
assert 0 not in crop_heaps
crop_heaps[0] = {"label": "BACKGROUND", "heap": CropHeap(max_size=4)}

# Clipping planes do not depend on the device; read them once.
near = cfg.get("near", 0.05)
far = cfg.get("far", 20.0)

for device in render_devices:
    colmap_dir = scene.dslr_colmap_dir if device == "dslr" else scene.iphone_colmap_dir
    cameras, images, points3D = read_model(colmap_dir, ".txt")
    assert len(cameras) == 1, "Multiple cameras not supported"
    camera = next(iter(cameras.values()))

    fx, fy, cx, cy = camera.params[:4]
    params = camera.params[4:]
    render_engine.setupCamera(
        camera.height, camera.width,
        fx, fy, cx, cy,
        camera.model,
        params,  # Distortion parameters np.array([k1, k2, k3, k4]) or np.array([k1, k2, p1, p2])
    )

    # Use the resolved `output_dir` (falls back to <data_root>/data) instead of
    # cfg.output_dir, which is None when the config does not set it.
    device_out_dir = output_dir / scene_id / device
    rgb_dir = device_out_dir / "render_rgb"
    depth_dir = device_out_dir / "render_depth"
    rgb_dir.mkdir(parents=True, exist_ok=True)
    depth_dir.mkdir(parents=True, exist_ok=True)

    for _, image in tqdm(images.items(), f"Rendering object crops using {device} images"):
        rgb_rendered, _, vert_indices = render_engine.renderAll(image.world_to_camera, near, far)

        # Captured frame for this pose. NOTE(review): always read from the iPhone RGB
        # folder; DSLR rendering is rejected earlier in the config cell.
        rgb = np.asarray(imageio.imread(Path(scene.iphone_rgb_dir) / image.name))

        # Per-pixel object id, presumably from the first vertex of the rendered triangle.
        # Some triangles span several objects; that is assumed not to matter for crops.
        vert_instance = vert_to_obj[vert_indices]
        pix_instance = vert_instance[:, :, 0]

        for obj in np.unique(pix_instance):
            mask = pix_instance == obj
            crop = crop_rgb_mask(rgb, rgb_rendered, mask, inflate_px=100)
            crop_heaps[obj]["heap"].push(crop)

Init EGL
Detected 5 devices
Using device 0
Using EGL version 1.5
OpenGL version: 4.6.0 NVIDIA 550.107.02
EGL version: 1.5
Loaded mesh:MeshData:
	Vertices:  917079
	Colors:    917079
	Normals:   0
	TexCoords: 0


Copy mesh to GPU: 917079 vertices, 1832073 faces
Setup the frame and render buffer


  rgb = np.asarray(imageio.imread(iphone_rgb_path))
Rendering object crops using iphone images: 100%|██████████| 586/586 [03:58<00:00,  2.46it/s]


In [7]:
# for id, entry in tqdm(crop_heaps.items(), f"Rendering image grids"):
#     heap = entry["heap"]
#     label = entry["label"]
#     if len(heap) and label.lower() not in [
#         "background",
#         "wall",
#         "floor",
#         "ceiling",
#         "split",
#         "remove",
#     ]:
#         crops = heap.get_sorted()
#         rgbs = [c.rgb for c in crops]
#         masks = [c.mask for c in crops]
#         scores = [c.score for c in crops]
#         plot_grid_images(
#             rgbs + masks, grid_width=len(rgbs), title=entry["label"]
#         )
#         plt.savefig(crop_dir / f"{str(id).zfill(5)}.jpg")
#         plt.close()

In [7]:
import torch
import os
import tempfile
import shutil
import cv2
import numpy as np
import logging
import random

from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Set up logging
# Route INFO-and-above records to sam2_model.log. filemode "w" truncates the file
# when this call installs the handler ("a" would append instead).
logging.basicConfig(
    filename="sam2_model.log",
    filemode="w",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

class SAM2VideoMaskModel:
    """
    Refine a small set of object crops jointly with SAM2's video predictor.

    The crops are padded to a common size and written as numbered JPEG frames into
    a private temporary directory. SAM2 then treats them as one short video: a
    single prompt (points or a mask) on the highest-scoring frame is propagated
    to every frame.

    Typical use per object:
        pad_and_store(...) -> set_state_and_refine_masks*() ->
        unpad_masks_to_original_size() -> read masks_refined -> cleanup()
    """

    def __init__(self, sam2_checkpoint, model_cfg, num_points=5, device="cuda"):
        """
        Build the SAM2 video predictor and allocate the frame directory.

        Args:
            sam2_checkpoint: path to the SAM2 checkpoint file.
            model_cfg: SAM2 model config name, forwarded to build_sam2_video_predictor.
            num_points: number of positive points sampled by the point-prompt path.
            device: torch device string, e.g. "cuda" or "cpu".
        """
        self.device = torch.device(device)
        if self.device.type == "cuda":
            # Entered and never exited: bfloat16 autocast stays active for the
            # rest of the process, not just this constructor.
            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

        self.predictor = self._build_predictor(sam2_checkpoint, model_cfg)
        self._initialize_storage()
        self.num_points = num_points

    def _build_predictor(self, sam2_checkpoint, model_cfg):
        """
        Build the SAM2 video predictor on ``self.device``.
        """
        return build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=self.device)

    def _initialize_storage(self):
        """
        Create the temporary frame directory and empty per-object buffers.

        cleanup() only empties the directory so it can be reused between objects.
        The directory itself is removed when this model is garbage-collected or
        the interpreter exits (weakref.finalize), so it no longer leaks.
        """
        import weakref  # only needed here, to tie the temp dir's lifetime to self

        self.temp_dir = tempfile.mkdtemp()
        self._temp_dir_finalizer = weakref.finalize(self, shutil.rmtree, self.temp_dir, ignore_errors=True)
        logging.info(f"Temporary directory created at: {self.temp_dir}")

        # Per-object buffers: original and padded images/masks, refined masks,
        # padding applied to each frame, and the score attached to each frame.
        self.rgb = []
        self.rgb_padded = []
        self.mask = []
        self.masks_padded = []
        self.masks_refined = []
        self.padding_info = []
        self.scores = []

    def pad_and_store(self, rgbs, masks, scores):
        """
        Pad every RGB image and mask to the largest height/width and store them.

        Raises:
            ValueError: if the three lists differ in length.
        """
        if len(rgbs) != len(masks) or len(rgbs) != len(scores):
            raise ValueError("The number of RGB images, masks, and scores must match.")

        max_h, max_w = self._get_max_dimensions(rgbs)

        for idx, (rgb, mask, score) in enumerate(zip(rgbs, masks, scores)):
            padded_rgb, padded_mask, padding_info = self._pad_image_and_mask(rgb, mask, max_h, max_w)
            self._store_padded_data(idx, rgb, padded_rgb, mask, padded_mask, score, padding_info)

    def _get_max_dimensions(self, rgbs):
        """
        Return (max height, max width) over the given HxWx3 images.
        """
        max_h = max(rgb.shape[0] for rgb in rgbs)
        max_w = max(rgb.shape[1] for rgb in rgbs)
        return max_h, max_w

    def _pad_image_and_mask(self, rgb, mask, max_h, max_w):
        """
        Zero-pad an image and its mask on the bottom/right to (max_h, max_w).

        Returns:
            (padded_rgb, padded_mask, ((0, pad_h), (0, pad_w)))
        """
        h, w, _ = rgb.shape
        pad_h, pad_w = max_h - h, max_w - w
        padding_info = ((0, pad_h), (0, pad_w))

        padded_rgb = np.pad(rgb, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant", constant_values=0)
        padded_mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0)

        return padded_rgb, padded_mask, padding_info

    def _store_padded_data(self, idx, rgb, padded_rgb, mask, padded_mask, score, padding_info):
        """
        Write the padded frame as <idx>.jpg into the temp directory and record its data.

        Raises:
            IOError: if OpenCV fails to write the frame.
        """
        rgb_filename = os.path.join(self.temp_dir, f"{idx}.jpg")
        # The crops are held in RGB order (the save step converts RGB->BGR before
        # cv2.imwrite). cv2.imwrite expects BGR, so convert here too; otherwise the
        # frames SAM2 loads have red and blue swapped.
        if not cv2.imwrite(rgb_filename, cv2.cvtColor(padded_rgb, cv2.COLOR_RGB2BGR)):
            raise IOError(f"Failed to write frame {rgb_filename}")

        self.rgb.append(rgb)
        self.rgb_padded.append(padded_rgb)
        self.mask.append(mask)
        self.masks_padded.append(padded_mask)
        self.scores.append(score)
        self.padding_info.append(padding_info)

    def set_state_and_refine_masks(self):
        """
        Refine all frames from random point prompts on the highest-scoring frame.
        """
        inference_state = self._initialize_inference_state()
        try:
            points, labels, highest_score_idx = self._get_initial_prompts(points=True)
            self._refine_masks(inference_state, points, labels, highest_score_idx)
        finally:
            # Release the predictor state even when refinement fails.
            self.predictor.reset_state(inference_state)

    def set_state_and_refine_masks_w_mask_prompt(self):
        """
        Refine all frames using the highest-scoring mask itself as the prompt.
        """
        inference_state = self._initialize_inference_state()
        try:
            highest_score_mask, highest_score_idx = self._get_initial_prompts(mask=True)
            self._refine_masks_w_mask_prompt(inference_state, highest_score_mask, highest_score_idx)
        finally:
            self.predictor.reset_state(inference_state)

    def set_state_and_refine_masks_w_manual_prompt(self, points, labels, frame_idx):
        """
        Refine all frames from caller-provided (x, y) points/labels on ``frame_idx``.
        """
        inference_state = self._initialize_inference_state()
        try:
            self._refine_masks(inference_state, points, labels, frame_idx)
        finally:
            self.predictor.reset_state(inference_state)

    def _initialize_inference_state(self):
        """
        Initialize the predictor's inference state from the frames in the temp directory.
        """
        return self.predictor.init_state(video_path=self.temp_dir)

    def _get_initial_prompts(self, points=False, mask=False):
        """
        Build the initial prompt from the mask with the highest score.

        Returns:
            points=True: (points (num_points, 2) float32 in (x, y), labels (all 1), frame idx)
            mask=True:   (mask, frame idx)

        Raises:
            ValueError: if neither flag is set, or the chosen mask is empty (points mode).
        """
        highest_score_idx = int(np.argmax(self.scores))
        highest_score_mask = self.masks_padded[highest_score_idx]

        if highest_score_idx != 0:
            logging.info(f"Using mask with the highest score from frame {highest_score_idx} as the initial prompt.")

        if points:
            # np.argwhere yields (row, col) = (y, x); SAM2 point prompts are (x, y).
            mask_coords_xy = np.argwhere(highest_score_mask > 0)[:, ::-1]
            if len(mask_coords_xy) == 0:
                raise ValueError("Highest-scoring mask is empty; cannot sample point prompts.")
            chosen = np.random.choice(len(mask_coords_xy), size=self.num_points)  # with replacement
            sampled_points = mask_coords_xy[chosen].astype(np.float32)
            labels = np.ones(self.num_points, dtype=np.int32)

            return sampled_points, labels, highest_score_idx
        elif mask:
            return highest_score_mask, highest_score_idx
        else:
            raise ValueError("Either points or mask must be True.")

    def _refine_masks(self, inference_state, points, labels, highest_score_idx):
        """
        Add point prompts on ``highest_score_idx`` and propagate through all frames.
        """
        self.predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=highest_score_idx,
            obj_id=1,
            points=points,
            labels=labels,
            box=None,
        )
        self._collect_propagated_masks(inference_state, highest_score_idx)

    def _refine_masks_w_mask_prompt(self, inference_state, highest_score_mask, highest_score_idx):
        """
        Add a mask prompt on ``highest_score_idx`` and propagate through all frames.
        """
        self.predictor.add_new_mask(
            inference_state=inference_state,
            frame_idx=highest_score_idx,
            obj_id=1,
            mask=highest_score_mask,
        )
        self._collect_propagated_masks(inference_state, highest_score_idx)

    def _collect_propagated_masks(self, inference_state, prompt_frame_idx):
        """
        Propagate the prompt and append exactly one mask per frame, in frame order.

        Forward propagation only covers frames from the prompt onwards, so when the
        prompt is not on frame 0 a reverse pass is also run. Any frame that still has
        no (or an empty) prediction falls back to its padded input mask, keeping
        ``masks_refined`` aligned with ``padding_info``.
        """
        refined_by_frame = {}
        directions = [False, True] if prompt_frame_idx > 0 else [False]
        for reverse in directions:
            for out_frame_idx, _, out_mask_logits in self.predictor.propagate_in_video(inference_state, reverse=reverse):
                refined_by_frame[out_frame_idx] = (out_mask_logits[0] > 0.0).cpu().numpy().squeeze()

        for frame_idx in range(len(self.masks_padded)):
            refined_mask = refined_by_frame.get(frame_idx)
            if refined_mask is None:
                logging.warning(f"No refined mask produced for frame {frame_idx}. Using the old mask.")
                refined_mask = self.masks_padded[frame_idx]
            elif np.sum(refined_mask) == 0:
                logging.warning(f"Refined mask is empty for frame {frame_idx}. Using the old mask.")
                refined_mask = self.masks_padded[frame_idx]

            self.masks_refined.append(refined_mask)

    def unpad_masks_to_original_size(self):
        """
        Crop every refined mask back to the size of its original (unpadded) frame.
        """
        self.masks_refined = [self._unpad_mask(mask, frame_idx) for frame_idx, mask in enumerate(self.masks_refined)]

    def _unpad_mask(self, mask, frame_idx):
        """
        Remove the bottom/right padding recorded for ``frame_idx`` from ``mask``.
        """
        (_, pad_bottom), (_, pad_right) = self.padding_info[frame_idx]
        height = mask.shape[0] - pad_bottom
        width = mask.shape[1] - pad_right
        return mask[:height, :width]

    def cleanup(self):
        """
        Empty the temporary directory and clear stored data (the model is kept).
        """
        self._clear_temp_directory()
        self._clear_storage()

    def _clear_temp_directory(self):
        """
        Delete the contents of the temporary directory but keep the directory itself.
        """
        if os.path.exists(self.temp_dir):
            for filename in os.listdir(self.temp_dir):
                file_path = os.path.join(self.temp_dir, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove file or symbolic link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove subdirectory and its contents
                except Exception as e:
                    logging.error(f"Failed to delete {file_path}. Reason: {e}")

            logging.info(f"Contents of temporary directory {self.temp_dir} removed.")

    def _clear_storage(self):
        """
        Clear all stored per-object data; the SAM2 predictor is kept.
        """
        self.rgb = []
        self.rgb_padded = []
        self.mask = []
        self.masks_padded = []
        self.scores = []
        self.padding_info = []
        self.masks_refined = []
        logging.info("Cleared all stored images, masks, scores, and padding information.")


  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


In [8]:
class SAM2ImageMaskModel:
    """
    Refine object masks one image at a time with SAM2's image predictor.

    For every stored crop, random positive points are sampled from the input mask
    (plus the mask's bounding box) several times ("RANSAC-like"). The prediction
    with the highest SAM score is kept in ``masks_refined`` / ``sam_scores``.
    """

    def __init__(self, sam2_checkpoint, model_cfg, device="cuda", num_points=5, ransac_iterations=10):
        """
        Build the SAM2 image predictor.

        Args:
            sam2_checkpoint: path to the SAM2 checkpoint file.
            model_cfg: SAM2 model config name, forwarded to build_sam2.
            device: torch device string, e.g. "cuda" or "cpu".
            num_points: number of points sampled from the mask per iteration.
            ransac_iterations: number of sampling iterations per image.
        """
        self.device = torch.device(device)
        self.num_points = num_points  # Number of points to sample from the mask
        self.ransac_iterations = ransac_iterations  # Number of RANSAC iterations

        if self.device.type == "cuda":
            # Entered and never exited: bfloat16 autocast stays active process-wide.
            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # TF32 matmuls on compute capability >= 8 GPUs.
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        self.predictor = self._build_predictor(sam2_checkpoint, model_cfg)
        self._initialize_storage()

    def _build_predictor(self, sam2_checkpoint, model_cfg):
        """
        Build the SAM2 model and wrap it in an image predictor.
        """
        sam2 = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        return SAM2ImagePredictor(sam2)

    def _initialize_storage(self):
        """
        Reset the per-object buffers (inputs, refined masks, scores).
        """
        self.rgb = []
        self.mask = []
        self.masks_refined = []
        self.crop_scores = []
        self.sam_scores = []

    def store_data(self, rgbs, masks, scores):
        """
        Store the RGB crops, their input masks and crop scores.

        Raises:
            ValueError: if the three lists differ in length.
        """
        if len(rgbs) != len(masks) or len(rgbs) != len(scores):
            raise ValueError("The number of RGB images, masks, and scores must match.")

        for idx, (rgb, mask, score) in enumerate(zip(rgbs, masks, scores)):
            self._store_data(idx, rgb, mask, score)

    def _store_data(self, idx, rgb, mask, score):
        """
        Append one RGB crop, its mask and its crop score.
        """
        self.rgb.append(rgb)
        self.mask.append(mask)
        self.crop_scores.append(score)

    def _sample_points_from_mask(self, mask):
        """
        Sample ``num_points`` (x, y) points, with replacement, where ``mask > 0``.

        Raises:
            ValueError: if the mask has no positive pixels.
        """
        # np.argwhere yields (row, col) = (y, x); SAM2 point prompts are (x, y).
        mask_coords_xy = np.argwhere(mask > 0)[:, ::-1]
        if len(mask_coords_xy) == 0:
            raise ValueError("No valid mask points to sample from.")

        sampled_points = np.array(random.choices(mask_coords_xy, k=self.num_points), dtype=np.float32)
        return sampled_points

    def _set_image_for_predictor(self, rgb):
        """
        Compute the SAM image embedding for ``rgb``.

        The input is passed through unchanged (assumed already in RGB order).
        """
        self.predictor.set_image(rgb)

    def _predict_on_current_image(self, points, bbox=None):
        """
        Run SAM on the image last given to set_image with positive (x, y) points
        and an optional xyxy box.

        Returns:
            (bool masks, scores) as returned by the predictor with a single mask output.
        """
        labels = np.ones(len(points), dtype=np.int32)  # All positive labels (foreground)
        box = None if bbox is None else np.asarray(bbox)[None, :]

        masks, scores, _ = self.predictor.predict(
            point_coords=points,
            point_labels=labels,
            box=box,
            multimask_output=False,
        )

        if masks.dtype != bool:
            masks = masks.astype(bool)

        return masks, scores

    def _predict_mask_with_points(self, rgb, points):
        """
        Set ``rgb`` on the predictor and predict a mask from positive points only.
        """
        self._set_image_for_predictor(rgb)
        return self._predict_on_current_image(points)

    def _predict_mask_with_points_and_bbox(self, rgb, points, bbox):
        """
        Set ``rgb`` on the predictor and predict a mask from positive points plus an xyxy box.
        """
        self._set_image_for_predictor(rgb)
        return self._predict_on_current_image(points, bbox)

    def ransac_mask_selection(self):
        """
        For each stored crop, keep the highest-scoring SAM mask over several point samples.

        Appends one entry per crop to ``masks_refined`` and ``sam_scores``. When no
        prediction could be made (e.g. empty input mask), the input mask is kept
        with score 0.0, so downstream padding never receives ``None``.
        """
        for idx, (rgb, mask) in enumerate(zip(self.rgb, self.mask)):
            best_score = -float("inf")
            best_mask = None

            if not np.any(np.asarray(mask) > 0):
                logging.warning(f"Skipping frame {idx} due to error: empty input mask")
            else:
                # The image embedding only depends on the image, so compute it once
                # per crop instead of once per sampling iteration.
                self._set_image_for_predictor(rgb)
                bbox = self._get_bounding_box_from_mask(mask)

                for _ in range(self.ransac_iterations):
                    try:
                        sampled_points = self._sample_points_from_mask(mask)
                        predicted_masks, predicted_scores = self._predict_on_current_image(sampled_points, bbox)
                    except ValueError as e:
                        logging.warning(f"Skipping frame {idx} due to error: {e}")
                        continue

                    score = predicted_scores[0]  # single mask output
                    if score > best_score:
                        best_score = score
                        best_mask = predicted_masks[0]

            if best_mask is None:
                logging.warning(f"No SAM prediction for frame {idx}; keeping the input mask.")
                best_mask = np.asarray(mask) > 0
                best_score = 0.0

            self.masks_refined.append(best_mask)
            self.sam_scores.append(best_score)

    def _get_bounding_box_from_mask(self, mask):
        """
        Return the mask's bounding box as an xyxy array; all zeros for an empty mask.
        """
        rows, cols = np.where(mask)

        if len(rows) == 0:
            # Same type as the non-empty case so callers can always index bbox[None, :].
            return np.array([0, 0, 0, 0])

        return np.array([cols.min(), rows.min(), cols.max(), rows.max()])

    def _mask_score_calculation(self, idx, refined_mask, sam_score):
        """
        Alternative score: SAM score weighted by 1 / (1 + |area(input) - area(refined)|).
        """
        current_mask = self.mask[idx]

        current_area = np.sum(current_mask)
        refined_area = np.sum(refined_mask)
        areas_diff = abs(current_area - refined_area)

        areas_score = 1 / (1 + areas_diff)

        final_score = areas_score * sam_score
        return final_score

    def refine_masks(self):
        """
        Public method to trigger RANSAC-based mask refinement.
        """
        self.ransac_mask_selection()

    def cleanup(self):
        """
        Clear stored RGBs, masks and scores, keeping the SAM model loaded.
        """
        self.rgb = []
        self.mask = []
        self.masks_refined = []
        self.crop_scores = []
        self.sam_scores = []

        logging.info("Cleared all stored images, masks, scores, and refined masks.")

In [9]:
# SAM2 weights/config can be overridden from the render config (keys
# "sam2_checkpoint" / "sam2_model_cfg"); the defaults are the values this
# notebook was originally run with.
sam2_checkpoint = cfg.get("sam2_checkpoint", "/home/kumaraditya/checkpoints/sam2_hiera_large.pt")
sam2_model_cfg = cfg.get("sam2_model_cfg", "sam2_hiera_l.yaml")
sam2_video_model = SAM2VideoMaskModel(sam2_checkpoint, sam2_model_cfg)
sam2_img_model = SAM2ImageMaskModel(sam2_checkpoint, sam2_model_cfg, num_points=3, ransac_iterations=5)

In [10]:
# Labels for structure / annotation bookkeeping rather than objects; no grids for these.
EXCLUDED_LABELS = {"background", "wall", "floor", "ceiling", "split", "remove"}

refined_crops = dict()

# NOTE(review): `device` is the loop variable left over from the rendering cell
# (the last rendered device). `output_dir` is the resolved output directory from
# the config cell and falls back to <data_root>/data when the config leaves it unset.
crop_dir = output_dir / scene_id / device / "render_crops_sam2_img_video_2"
crop_dir.mkdir(parents=True, exist_ok=True)

for obj_id, entry in tqdm(crop_heaps.items(), "Rendering image grids"):
    heap = entry["heap"]
    label = entry["label"]
    if not len(heap) or label.lower() in EXCLUDED_LABELS:
        continue

    logging.info(f"Processing object id: {obj_id} with label: {label}")

    crops = heap.get_sorted()
    rgbs = [c.rgb for c in crops]
    masks = [c.mask for c in crops]
    scores = [c.score for c in crops]

    # Stage 1: per-image RANSAC refinement. The `finally` ensures a failure does
    # not leave this object's crops in the model for the next object.
    try:
        sam2_img_model.store_data(rgbs, masks, scores)
        sam2_img_model.refine_masks()
        img_refined_masks = sam2_img_model.masks_refined
        img_refined_scores = sam2_img_model.sam_scores
    finally:
        sam2_img_model.cleanup()

    # Order crops by SAM score (best first) so the video stage prompts frame 0.
    sorted_data = sorted(zip(img_refined_scores, rgbs, img_refined_masks), key=lambda x: x[0], reverse=True)
    img_refined_scores = [score for score, _, _ in sorted_data]
    rgbs = [rgb for _, rgb, _ in sorted_data]
    img_refined_masks = [mask for _, _, mask in sorted_data]

    # Stage 2: joint refinement of all crops as one short video.
    try:
        sam2_video_model.pad_and_store(rgbs, img_refined_masks, img_refined_scores)
        sam2_video_model.set_state_and_refine_masks_w_mask_prompt()
        sam2_video_model.unpad_masks_to_original_size()
        video_refined_masks = sam2_video_model.masks_refined
    finally:
        sam2_video_model.cleanup()

    refined_crops[obj_id] = {
        "rgbs": rgbs,
        "refined_masks": video_refined_masks,
        "scores": img_refined_scores,
        "label": label,
    }

    plot_grid_images(
        rgbs, video_refined_masks, grid_width=len(rgbs), title=label
    )
    plt.savefig(crop_dir / f"{str(obj_id).zfill(5)}.jpg")
    plt.close()

frame loading (JPEG): 100%|██████████| 4/4 [00:00<00:00, 37.53it/s]

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(
propagate in video: 100%|██████████| 4/4 [00:01<00:00,  3.87it/s]
frame loading (JPEG): 100%|██████████| 4/4 [00:00<00:00, 44.43it/s]it]

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(
propagate in video: 100%|██████████| 4/4 [00:01<00:00,  3.93it/s]
frame loading

In [17]:
def save_refined_crops(refined_crops, output_dir, pad_length=3):
    """
    Write refined crops to disk: one file per crop plus one metadata JSON per object.

    Layout under ``output_dir``:
        rgbs/<obj>_<crop>_rgb.png      crop image (held in RGB, converted to BGR for cv2.imwrite)
        masks/<obj>_<crop>_mask.npy    refined mask, written with np.save
        metadata/<obj>_metadata.json   {"scores": [...], "label": ...}

    Args:
        refined_crops: dict obj_id -> {"rgbs", "refined_masks", "scores", "label"}.
        output_dir: destination directory (str or Path); created if missing.
        pad_length: zero-padding width for object ids in file names.

    Raises:
        OSError: if OpenCV fails to write a crop image.
    """
    output_dir = Path(output_dir)  # accept plain strings too
    rgb_dir = output_dir / "rgbs"
    mask_dir = output_dir / "masks"
    metadata_dir = output_dir / "metadata"

    for directory in (rgb_dir, mask_dir, metadata_dir):
        directory.mkdir(parents=True, exist_ok=True)

    for obj_id, data in refined_crops.items():
        padded_obj_id = str(obj_id).zfill(pad_length)

        for crop_id, (rgb, mask) in enumerate(zip(data["rgbs"], data["refined_masks"])):
            rgb_filename = rgb_dir / f"{padded_obj_id}_{crop_id}_rgb.png"
            # cv2.imwrite reports failure via its return value instead of raising.
            if not cv2.imwrite(str(rgb_filename), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
                raise OSError(f"Failed to write {rgb_filename}")

            np.save(mask_dir / f"{padded_obj_id}_{crop_id}_mask.npy", mask)

        # Cast scores to native Python floats (json cannot serialize e.g. np.float32)
        # and round to 4 decimal places.
        metadata = {
            "scores": [round(float(score), 4) for score in data["scores"]],
            "label": data["label"],
        }
        with open(metadata_dir / f"{padded_obj_id}_metadata.json", "w") as metadata_file:
            json.dump(metadata, metadata_file, indent=4)


In [18]:
# Persist the refined crops under <output_dir>/<scene>/<device>/refined_crops_data.
refined_crops_data_dir = Path(cfg.output_dir).joinpath(scene_id, device, "refined_crops_data")
save_refined_crops(refined_crops, refined_crops_data_dir)