From cbedc341d1fe61faf463efe3e0678a1cc44cd958 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Wed, 5 Apr 2023 22:14:40 +0200 Subject: [PATCH 1/8] New example code for facebook research segment anything --- examples/python/segment_anything/.gitignore | 1 + examples/python/segment_anything/main.py | 189 ++++++++++++++++++ .../python/segment_anything/requirements.txt | 7 + 3 files changed, 197 insertions(+) create mode 100644 examples/python/segment_anything/.gitignore create mode 100755 examples/python/segment_anything/main.py create mode 100644 examples/python/segment_anything/requirements.txt diff --git a/examples/python/segment_anything/.gitignore b/examples/python/segment_anything/.gitignore new file mode 100644 index 000000000000..0447b0d4ac3a --- /dev/null +++ b/examples/python/segment_anything/.gitignore @@ -0,0 +1 @@ +model/ diff --git a/examples/python/segment_anything/main.py b/examples/python/segment_anything/main.py new file mode 100755 index 000000000000..5a66e514d0c1 --- /dev/null +++ b/examples/python/segment_anything/main.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Example of using Rerun to log and visualize the output of segment-anything. + +See: [segment_anything](https://segment-anything.com/). + +Can be used to test mask-generation on one or more images. Images can be local file-paths +or remote urls. + +Exa: +``` +python main.py --device cuda https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg +``` + +""" + + +import argparse +import logging +import os +from pathlib import Path +from typing import Final +from urllib.parse import urlparse + +import cv2 +import numpy as np +import requests +import rerun as rr +import torch +import torchvision +from cv2 import Mat +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +from segment_anything.modeling import Sam +from tqdm import tqdm + +MODEL_DIR: Final = Path(os.path.dirname(__file__)) / "model" +MODEL_URLS: Final = { + "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", +} + + +def download_with_progress(url: str, dest: Path) -> None: + """Download file with tqdm progress bar.""" + chunk_size = 1024 * 1024 + resp = requests.get(url, stream=True) + total_size = int(resp.headers.get("content-length", 0)) + with open(dest, "wb") as dest_file: + with tqdm( + desc="Downloading model", total=total_size, unit="iB", unit_scale=True, unit_divisor=1024 + ) as progress: + for data in resp.iter_content(chunk_size): + dest_file.write(data) + progress.update(len(data)) + + +def get_downloaded_model_path(model_name: str) -> Path: + """Fetch the segment-anything model to a local cache directory.""" + model_url = MODEL_URLS[model_name] + + model_location = MODEL_DIR / model_url.split("/")[-1] + if not model_location.exists(): + os.makedirs(MODEL_DIR, exist_ok=True) + download_with_progress(model_url, model_location) + + return model_location + + +def create_sam(model: str, device: str) -> Sam: + """Load the segment-anything model, fetching the model-file as necessary.""" + model_path = get_downloaded_model_path(model) + + logging.info("PyTorch version: {}".format(torch.__version__)) + logging.info("Torchvision version: {}".format(torchvision.__version__)) + logging.info("CUDA is available: {}".format(torch.cuda.is_available())) + + logging.info("Building sam from: {}".format(model_path)) + sam = sam_model_registry[model](checkpoint=model_path) + return sam.to(device=device) + + +def run_segmentation(mask_generator: SamAutomaticMaskGenerator, image: Mat) -> None: + """Run segmentation on a single image.""" + rr.log_image("image", image) + + logging.info("Finding masks") + masks = mask_generator.generate(image) + + logging.info("Found {} masks".format(len(masks))) + + # Note: it is important to sort these masks by area from largest to smallest + # this is because the masks are overlapping and we want smaller masks to + # be drawn on top of larger masks. + # TODO(jleibs): we could instead draw each mask as a separate image layer, but the current layer-stacking + # does not produce great results. + masks_with_ids = list(enumerate(masks, start=1)) + masks_with_ids.sort(key=(lambda x: x[1]["area"]), reverse=True) # type: ignore[no-any-return] + + # Layer all of the masks together, using the id as class-id in the segmentation + segmentation_img = np.zeros((image.shape[0], image.shape[1])) + for id, m in masks_with_ids: + segmentation_img[m["segmentation"]] = id + rr.log_segmentation_image("image/masks", segmentation_img) + + mask_bbox = np.array([m["bbox"] for _, m in masks_with_ids]) + rr.log_rects("image/boxes", rects=mask_bbox, class_ids=[id for id, _ in masks_with_ids]) + + +def is_url(path: str) -> bool: + """Check if a path is a url or a local file.""" + try: + result = urlparse(path) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + +def load_image(image_uri: str) -> Mat: + """Conditionally download an image from URL or load it from disk.""" + logging.info("Loading: {}".format(image_uri)) + if is_url(image_uri): + response = requests.get(image_uri) + response.raise_for_status() + image_data = np.asarray(bytearray(response.content), dtype="uint8") + image = cv2.imdecode(image_data, cv2.IMREAD_COLOR) + else: + image = cv2.imread(image_uri, cv2.IMREAD_COLOR) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run the Facebook Research Segment Anything example.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model", + action="store", + default="vit_b", + choices=MODEL_URLS.keys(), + help="Which model to use.", + ) + parser.add_argument( + "--device", + action="store", + default="cpu", + help="Which torch device to use, e.g. cpu or cuda. " + "(See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)", + ) + parser.add_argument( + "--points-per-batch", + action="store", + default=32, + type=int, + help="Points per batch. More points will run faster, but too many will exhaust GPU memory.", + ) + parser.add_argument("images", metavar="N", type=str, nargs="*", help="A list of images to process.") + + rr.script_add_args(parser) + args = parser.parse_args() + + rr.script_setup(args, "segment_anything") + logging.getLogger().addHandler(rr.LoggingHandler("logs")) + logging.getLogger().setLevel(logging.INFO) + + sam = create_sam(args.model, args.device) + + mask_config = {"points_per_batch": args.points_per_batch} + mask_generator = SamAutomaticMaskGenerator(sam, **mask_config) + + if len(args.images) == 0: + logging.info("No image provided. Using default.") + args.images = [ + "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg" + ] + + for n, image_uri in enumerate(args.images): + rr.set_time_sequence("image", n) + image = load_image(image_uri) + run_segmentation(mask_generator, image) + + rr.script_teardown(args) + + +if __name__ == "__main__": + main() diff --git a/examples/python/segment_anything/requirements.txt b/examples/python/segment_anything/requirements.txt new file mode 100644 index 000000000000..cf534e554f48 --- /dev/null +++ b/examples/python/segment_anything/requirements.txt @@ -0,0 +1,7 @@ +numpy +torchvision +opencv-python +rerun-sdk +tqdm +git+https://github.com/facebookresearch/segment-anything.git + From 6995edaf5d86fceb2ba282c762d0db59f70cce16 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Wed, 5 Apr 2023 22:26:44 +0200 Subject: [PATCH 2/8] Doc tweaks --- examples/python/segment_anything/main.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/python/segment_anything/main.py b/examples/python/segment_anything/main.py index 5a66e514d0c1..8bd869d0cf85 100755 --- a/examples/python/segment_anything/main.py +++ b/examples/python/segment_anything/main.py @@ -9,9 +9,14 @@ Exa: ``` -python main.py --device cuda https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg +# Run on a remote image: +python main.py https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg + +# Use cuda and a different model on a local image +python main.py --device cuda --model vit_h /path/to/my_image.jpg ``` + """ @@ -141,7 +146,7 @@ def main() -> None: action="store", default="vit_b", choices=MODEL_URLS.keys(), - help="Which model to use.", + help="Which model to use." "(See: https://github.com/facebookresearch/segment-anything#model-checkpoints)", ) parser.add_argument( "--device", From 90d8ed0e3a8ffca771fd2a182638ce6202262aef Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Wed, 5 Apr 2023 22:29:31 +0200 Subject: [PATCH 3/8] Remove blank lines --- examples/python/segment_anything/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/python/segment_anything/main.py b/examples/python/segment_anything/main.py index 8bd869d0cf85..74ea377cfd41 100755 --- a/examples/python/segment_anything/main.py +++ b/examples/python/segment_anything/main.py @@ -15,8 +15,6 @@ # Use cuda and a different model on a local image python main.py --device cuda --model vit_h /path/to/my_image.jpg ``` - - """ From d593ca4d4189745d23a256bdb8d1178122b5cbc4 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Wed, 5 Apr 2023 22:35:22 +0200 Subject: [PATCH 4/8] Add segmentation workaround for users still on 0.4.0 --- examples/python/segment_anything/main.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/python/segment_anything/main.py b/examples/python/segment_anything/main.py index 74ea377cfd41..c6c496a2547d 100755 --- a/examples/python/segment_anything/main.py +++ b/examples/python/segment_anything/main.py @@ -100,6 +100,15 @@ def run_segmentation(mask_generator: SamAutomaticMaskGenerator, image: Mat) -> N masks_with_ids = list(enumerate(masks, start=1)) masks_with_ids.sort(key=(lambda x: x[1]["area"]), reverse=True) # type: ignore[no-any-return] + # Work-around for https://github.com/rerun-io/rerun/issues/1782 + # Make sure we have an AnnotationInfo present for every class-id used in this image + # TODO(jleibs): Remove when fix is released + rr.log_annotation_context( + "image", + [rr.AnnotationInfo(id) for id, _ in masks_with_ids], + timeless=False, + ) + # Layer all of the masks together, using the id as class-id in the segmentation segmentation_img = np.zeros((image.shape[0], image.shape[1])) for id, m in masks_with_ids: From 138f03d473929cb9039492ca7db640ac6162853f Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Wed, 5 Apr 2023 22:55:42 +0200 Subject: [PATCH 5/8] Fix requirements.txt --- examples/python/segment_anything/requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/python/segment_anything/requirements.txt b/examples/python/segment_anything/requirements.txt index cf534e554f48..b9dcedfb0b79 100644 --- a/examples/python/segment_anything/requirements.txt +++ b/examples/python/segment_anything/requirements.txt @@ -1,7 +1,9 @@ numpy +torch torchvision opencv-python +requests rerun-sdk tqdm -git+https://github.com/facebookresearch/segment-anything.git +-e git+https://github.com/facebookresearch/segment-anything.git#egg=segment-anything From 86cb9c1ad36814ed14a758a9faf3569f474fc102 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Thu, 6 Apr 2023 17:54:37 +0200 Subject: [PATCH 6/8] Images should use class-id as label --- crates/re_viewer/src/ui/data_ui/image.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/re_viewer/src/ui/data_ui/image.rs b/crates/re_viewer/src/ui/data_ui/image.rs index bb3d5a483548..63d5a67130b3 100644 --- a/crates/re_viewer/src/ui/data_ui/image.rs +++ b/crates/re_viewer/src/ui/data_ui/image.rs @@ -366,7 +366,7 @@ pub fn show_zoomed_image_region( .class_description(Some(ClassId(u16_val))) .annotation_info() .label(None) - .unwrap_or_default(), + .unwrap_or_else(|| u16_val.to_string()) ); ui.end_row(); }; From 281791af81e3a83ca99d69ffd8c0964b6b43287f Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Thu, 6 Apr 2023 17:55:22 +0200 Subject: [PATCH 7/8] Add an alternative tensor-based view --- examples/python/segment_anything/main.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/python/segment_anything/main.py b/examples/python/segment_anything/main.py index c6c496a2547d..ad2069840f54 100755 --- a/examples/python/segment_anything/main.py +++ b/examples/python/segment_anything/main.py @@ -92,7 +92,15 @@ def run_segmentation(mask_generator: SamAutomaticMaskGenerator, image: Mat) -> N logging.info("Found {} masks".format(len(masks))) - # Note: it is important to sort these masks by area from largest to smallest + # Log all the masks stacked together as a tensor + # TODO(jleibs): Tensors with class-ids and annotation-coloring would make this much slicker + mask_tensor = ( + np.dstack([np.zeros((image.shape[0], image.shape[1]))] + [m["segmentation"] for m in masks]).astype("uint8") + * 128 + ) + rr.log_tensor("mask_tensor", mask_tensor) + + # Note: for stacking, it is important to sort these masks by area from largest to smallest # this is because the masks are overlapping and we want smaller masks to # be drawn on top of larger masks. # TODO(jleibs): we could instead draw each mask as a separate image layer, but the current layer-stacking @@ -113,6 +121,7 @@ def run_segmentation(mask_generator: SamAutomaticMaskGenerator, image: Mat) -> N segmentation_img = np.zeros((image.shape[0], image.shape[1])) for id, m in masks_with_ids: segmentation_img[m["segmentation"]] = id + rr.log_segmentation_image("image/masks", segmentation_img) mask_bbox = np.array([m["bbox"] for _, m in masks_with_ids]) From df5919308ae9cbebd9d31209066fe0d7c093b441 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Thu, 6 Apr 2023 21:15:32 +0200 Subject: [PATCH 8/8] Sort requirements.txt --- examples/python/segment_anything/requirements.txt | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/python/segment_anything/requirements.txt b/examples/python/segment_anything/requirements.txt index b9dcedfb0b79..9c0dfad84fa8 100644 --- a/examples/python/segment_anything/requirements.txt +++ b/examples/python/segment_anything/requirements.txt @@ -1,9 +1,8 @@ +-e git+https://github.com/facebookresearch/segment-anything.git#egg=segment-anything numpy -torch -torchvision opencv-python requests rerun-sdk +torch +torchvision tqdm --e git+https://github.com/facebookresearch/segment-anything.git#egg=segment-anything -