Create a `HOME` constant.

In [1]:
import os
HOME = os.getcwd()
print("HOME:", HOME)

HOME: /home/ec2-user/geoseg/segment-anything/notebooks


## Load Model

In [2]:
import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_b"

In [4]:
import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-rf_k9auj
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-rf_k9auj
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25ldone
[?25h

In [3]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

chkpt_path = '../weights/sam_vit_b_01ec64.pth'
sam = sam_model_registry[MODEL_TYPE](checkpoint=chkpt_path).to(device=DEVICE)

predictor = SamPredictor(sam)

## Upload Image(s)

In [33]:
import cv2
import os
import xml.etree.ElementTree as ET
from tqdm import tqdm

def extract_metadata(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    
    metadata = {}
    ns = {'img': 'http://pds.nasa.gov/pds4/img/v1'}
    for elem in root.findall('.//img:focal_length', ns):
        metadata['focal_length'] = elem.text
    for elem in root.findall('.//img:line_fov', ns):
        metadata['line_fov'] = elem.text
    for elem in root.findall('.//img:sample_fov', ns):
        metadata['sample_fov'] = elem.text
    for elem in root.findall('.//img:focus_distance', ns):
        metadata['focus_distance'] = elem.text
    
    return metadata

def process_folder(folder_path):
    imgs = []
    png_files = [f for f in os.listdir(folder_path) if f.endswith(".png")]
    for filename in tqdm(png_files, desc="Processing .png and .xml files"):
        img_path = os.path.join(folder_path, filename)
        xml_filename = filename.split(".")[0] + ".xml"
        xml_path = os.path.join(folder_path, xml_filename)
        if os.path.exists(img_path) and os.path.exists(xml_path):
            try:
                metadata = extract_metadata(xml_path)
                img = cv2.imread(img_path)
                if img is not None:
                    imgs.append({"filename": filename,"img": img, "meta": metadata})
            except:
                continue
        else:
            print("no img: " + img_path)
    return imgs

folder_path = "../pds/imgs"
imgs = process_folder(folder_path)


Processing .png and .xml files: 100%|██████████| 563/563 [00:31<00:00, 18.00it/s]


In [34]:
print(len(imgs))

563


## Automated Mask Generation

To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended.

In [35]:
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.8,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=50,  # Requires open-cv to run post-processing
)

### Generate masks with SAM

In [48]:
import cv2
import json
import os
import numpy as np

# Create dataset directory if it doesn't exist
dataset_dir = "./dataset"
if not os.path.exists(dataset_dir):
    os.makedirs(dataset_dir)

# Process images and save results and metadata as JSON files
start_at = 0
end_at = 2
index = 0
for image in tqdm(imgs, desc="Processing images"):
    # if index < start_at:
    #     continue
    # if index > end_at:
    #     break
    # index += 1
    try:
        img_bgr = cv2.cvtColor(image['img'], cv2.COLOR_RGB2BGR)
        sam_results = mask_generator.generate(img_bgr)
        out_results = []
        for j in range(len(sam_results)):
            mask = (sam_results[j]['segmentation'] * 255).astype(np.uint8)
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 
            # Calculate polygon coordinates
            polygon = [cv2.approxPolyDP(contour, 1, True).reshape(-1, 2).tolist() for contour in contours]
            result = {
                'bbox': sam_results[j]['bbox'],
                'polygon': polygon,
                'area': sam_results[j]['area'],
                'stability_score': sam_results[j]['stability_score'],
                'point_coords': sam_results[j]['point_coords']
            }
            out_results.append(result)
        
        # Save image results and metadata as JSON files
        result_filename = os.path.join(dataset_dir, f"{image['filename'][:-4]}.json")
        with open(result_filename, 'w') as f:
            json.dump({"segments": out_results, "metadata": image['meta']}, f)
    except Exception as e:
        print(f"Error processing image {image['filename']}: {e}")
        continue


Processing images: 100%|██████████| 563/563 [2:03:18<00:00, 13.14s/it]  


In [46]:
print(imgs)

[{'filename': 'ZL0_0994_0755186846_738EBY_N0471434ZCAM07114_1100LMJ01.png', 'img': array([[[ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        ...,
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0]],

       [[ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        ...,
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0]],

       [[ 0,  0,  0],
        [ 0,  0,  0],
        [81, 77, 31],
        ...,
        [ 0,  2,  5],
        [ 0,  0,  0],
        [ 0,  0,  0]],

       ...,

       [[ 0,  0,  0],
        [ 0,  0,  0],
        [15, 15, 16],
        ...,
        [ 2,  2,  0],
        [ 0,  0,  0],
        [ 0,  0,  0]],

       [[ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        ...,
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0]],

       [[ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        ...,
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0]]], dtype=uint8), 'meta': 

### Output format

`SamAutomaticMaskGenerator` returns a `list` of masks, where each mask is a `dict` containing various information about the mask:

* `segmentation` - `[np.ndarray]` - the mask with `(W, H)` shape, and `bool` type
* `area` - `[int]` - the area of the mask in pixels
* `bbox` - `[List[int]]` - the boundary box of the mask in `xywh` format
* `predicted_iou` - `[float]` - the model's own prediction for the quality of the mask
* `point_coords` - `[List[List[float]]]` - the sampled input point that generated this mask
* `stability_score` - `[float]` - an additional measure of mask quality
* `crop_box` - `List[int]` - the crop of the image used to generate this mask in `xywh` format

### Results visualisation with Supervision

As of version `0.5.0` Supervision has native support for SAM.

In [15]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import json

save_path = '/home/sastrong/repos/m_rocks_dataset'

for i in range(len(results)):
    img = imgs[i].copy()
    img_num = "{:03}".format(i)
    entry = {"img_num": img_num, "bboxes": [], "polygons": []}
    bboxes = []
    polygons = []
    for j in range(len(results[i])):
        if (results[i][j]['area'] > 10000):
            continue

        bboxes.append(results[i][j]['bbox'])

        # Assuming the mask is a binary 2D numpy array
        mask = (results[i][j]['segmentation'] * 255).astype(np.uint8)

        # Find contours in the mask
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # Simplify contours to polygons
        polygons += [cv2.approxPolyDP(contour, 1, True).tolist() for contour in contours]

    entry["bboxes"] = bboxes
    entry["polygons"] = polygons

     # save the image to save_path/images (with the new name, i.e. 023.jpg)
    img_path = os.path.join(save_path, 'images', f'{img_num}.png')
    cv2.imwrite(img_path, img)

    # save the entry as a JSON file to save_path/labels
    json_path = os.path.join(save_path, 'labels', f'{img_num}.json')
    with open(json_path, 'w') as f:
        json.dump(entry, f)