# Segment Every Grain

A SAM-based model for instance segmentation of images of grains

<a target="_blank" href="https://colab.research.google.com/github/zsylvester/segmenteverygrain/blob/main/Segment_every_grain_colab.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Install and import packages

### Set up interactive figure backend in Colab

In [None]:
# this is needed to make figures in Colab interactive
!pip install ipympl
exit(0) # this restarts the runtime after installing ipympl -- otherwise you get an error when switching the matplotlib backend to ipympl

In [1]:
from google.colab import output

output.enable_custom_widget_manager()

In [2]:
%matplotlib ipympl

### Install the other dependencies

In [4]:
import torch
import torchvision

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys

!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!{sys.executable} -m pip install segmenteverygrain
!{sys.executable} -m pip install rtree
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-06-17 14:35:48--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.169.149.41, 3.169.149.5, 3.169.149.36, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.169.149.41|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: ‘sam_vit_h_4b8939.pth’


2025-06-17 14:35:56 (304 MB/s) - ‘sam_vit_h_4b8939.pth’ saved [2564550879/2564550879]



In [5]:
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU if no GPU is available (much slower)
model_type = "default"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

In [7]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from keras.utils import load_img
from keras.saving import load_model
from importlib import reload
import segmenteverygrain as seg
from tqdm import trange

## Download Unet model weights and create Unet model

In [8]:
!wget "https://raw.githubusercontent.com/zsylvester/segmenteverygrain/main/models/seg_model.keras"

--2025-06-17 14:37:02--  https://raw.githubusercontent.com/zsylvester/segmenteverygrain/main/models/seg_model.keras
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26139262 (25M) [application/octet-stream]
Saving to: ‘seg_model.keras’


2025-06-17 14:37:02 (152 MB/s) - ‘seg_model.keras’ saved [26139262/26139262]



In [9]:
# UNET model
unet = load_model(
    "seg_model.keras",
    custom_objects={"weighted_crossentropy": seg.weighted_crossentropy},
)

## Run segmentation

Grains should be well defined in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images, so do not start with large images (downsample them if necessary).
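
As a rough guide for the downsampling mentioned above, the sketch below computes a target size that keeps an image under a pixel budget. The helper name `downsampled_shape` and the 4-megapixel default are assumptions made for illustration, not part of `segmenteverygrain`.

```python
import math


def downsampled_shape(height, width, max_pixels=4_000_000):
    """Return (new_height, new_width, scale) with new_height * new_width <= max_pixels."""
    n_pixels = height * width
    if n_pixels <= max_pixels:
        return height, width, 1.0  # already small enough, keep as is
    # scale both sides by the same factor so the aspect ratio is preserved
    scale = math.sqrt(max_pixels / n_pixels)
    new_height = max(1, math.floor(height * scale))
    new_width = max(1, math.floor(width * scale))
    return new_height, new_width, scale


new_h, new_w, scale = downsampled_shape(6000, 8000)
print(new_h, new_w, round(scale, 3))
# The resize itself could then be done with OpenCV (note the (width, height) order):
# image_small = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
```

If you downsample, measure the scale bar on the downsampled image so that the units-per-pixel value matches the pixels that were actually segmented.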

In [10]:
# get example image
!wget "https://raw.githubusercontent.com/zsylvester/segmenteverygrain/main/examples/barton_creek/barton_creek_image.jpg"

--2025-06-17 14:38:21--  https://raw.githubusercontent.com/zsylvester/segmenteverygrain/main/examples/barton_creek/barton_creek_image.jpg
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1451152 (1.4M) [image/jpeg]
Saving to: ‘barton_creek_image.jpg’


2025-06-17 14:38:21 (20.2 MB/s) - ‘barton_creek_image.jpg’ saved [1451152/1451152]



In [11]:
# replace this with the path to your image:
fname = "barton_creek_image.jpg"
image = np.array(load_img(fname))
image_pred = seg.predict_image(image, unet, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts
# (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=20.0)

segmenting image tiles...


100%|██████████| 9/9 [00:08<00:00,  1.04it/s]
100%|██████████| 8/8 [00:05<00:00,  1.45it/s]


In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(image_pred)
plt.scatter(np.array(coords)[:, 0], np.array(coords)[:, 1], c="k")
plt.xticks([])
plt.yticks([]);

In [None]:
# SAM segmentation, using the point prompts from the Unet:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(
    sam,
    image,
    image_pred,
    coords,
    labels,
    min_area=400.0,
    plot_image=True,
    remove_edge_grains=False,
    remove_large_objects=False,
)

## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)

In [None]:
grain_inds = []
cid1 = fig.canvas.mpl_connect(
    "button_press_event",
    lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax),
)
cid2 = fig.canvas.mpl_connect(
    "key_press_event",
    lambda event: seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax),
)

Run this cell once you have finished deleting and merging grains; it is a good idea to do this before moving on to the next step.

In [None]:
fig.canvas.mpl_disconnect(cid1)
fig.canvas.mpl_disconnect(cid2)

In [None]:
# Use this function to update the 'labels' array after deleting and merging grains
# (the 'all_grains' list is updated when doing the deletion and merging):
all_grains, labels, mask_all = seg.get_grains_from_patches(ax, image)

In [None]:
# plot the updated set of grains
fig, ax = plt.subplots(figsize=(8, 6))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap="Paired", plot_image=True)
seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

## Add new grains using the Segment Anything Model

* click on an unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right click outside the grain (but inside the most recent mask) if you want to restrict the grain to a smaller mask - this adds a background prompt
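
For readers unfamiliar with SAM prompts, the sketch below shows how foreground and background clicks can be packed into the arrays that `SamPredictor.predict` accepts (label 1 for foreground, 0 for background, coordinates in x, y pixel order). The helper `build_point_prompts` and the click coordinates are hypothetical; how `seg.onclick` assembles its prompts internally is not shown in this notebook and may differ.

```python
import numpy as np


def build_point_prompts(foreground_xy, background_xy=()):
    """Stack clicked points into (N, 2) coordinates and (N,) labels for SAM."""
    points = list(foreground_xy) + list(background_xy)
    point_coords = np.array(points, dtype=np.float32).reshape(-1, 2)
    # 1 marks a point inside the object, 0 marks a point that belongs to the background
    point_labels = np.array(
        [1] * len(foreground_xy) + [0] * len(background_xy), dtype=np.int32
    )
    return point_coords, point_labels


# one left click on the grain, one right click just outside it (made-up coordinates)
point_coords, point_labels = build_point_prompts([(512.0, 340.0)], [(530.0, 410.0)])
print(point_coords.shape, point_labels)
# With a SamPredictor on which set_image() has already been called:
# masks, scores, logits = predictor.predict(
#     point_coords=point_coords, point_labels=point_labels, multimask_output=True
# )
```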

In [None]:
predictor = SamPredictor(sam)
predictor.set_image(image)  # this can take a while
coords = []
cid3 = fig.canvas.mpl_connect(
    "button_press_event", lambda event: seg.onclick(event, ax, coords, image, predictor)
)
cid4 = fig.canvas.mpl_connect(
    "key_press_event", lambda event: seg.onpress(event, ax, fig)
)

In [None]:
fig.canvas.mpl_disconnect(cid3)
fig.canvas.mpl_disconnect(cid4)

After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [None]:
all_grains, labels, mask_all = seg.get_grains_from_patches(ax, image)

## Get grain size distribution

Run this cell, then left-click on one end of the scale bar in the image and right-click on the other end:

In [None]:
cid5 = fig.canvas.mpl_connect(
    "button_press_event", lambda event: seg.click_for_scale(event, ax)
)

Use the length of the scale bar in pixels (it should be printed below the image) to compute the scale of the image (in units per pixel):

In [None]:
n_of_units = 10.0  # centimeters if using 'barton_creek_image.jpg'
units_per_pixel = n_of_units / 374.26  # length of scale bar in pixels
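
The same arithmetic can be written out from the two endpoint coordinates, which may help if you read them off the image yourself. This is a minimal sketch that assumes the scale-bar length is the straight-line distance between the two clicked points; the function name and the coordinates are hypothetical.

```python
import math


def units_per_pixel_from_clicks(p1, p2, n_of_units):
    """Return (units_per_pixel, length_in_pixels) for a scale bar between p1 and p2."""
    # Euclidean distance between the two endpoints, in pixels
    length_px = math.hypot(p2[0] - p1[0], p2[1] - p1[1])
    if length_px == 0:
        raise ValueError("the two scale-bar endpoints must be different points")
    return n_of_units / length_px, length_px


# made-up endpoints of a 10 cm scale bar, given as (x, y) pixel coordinates
scale, length_px = units_per_pixel_from_clicks((100.0, 200.0), (400.0, 600.0), 10.0)
print(length_px, scale)  # 500.0 pixels, 0.02 cm per pixel
```

Lengths in pixels are multiplied by this factor, and areas by its square, to convert them to real-world units.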

In [None]:
from skimage.measure import regionprops, regionprops_table

props = regionprops_table(
    labels.astype("int"),
    intensity_image=image,
    properties=(
        "label",
        "area",
        "centroid",
        "major_axis_length",
        "minor_axis_length",
        "orientation",
        "perimeter",
        "max_intensity",
        "mean_intensity",
        "min_intensity",
    ),
)
grain_data = pd.DataFrame(props)
grain_data["major_axis_length"] = (
    grain_data["major_axis_length"].values * units_per_pixel
)
grain_data["minor_axis_length"] = (
    grain_data["minor_axis_length"].values * units_per_pixel
)
grain_data["perimeter"] = grain_data["perimeter"].values * units_per_pixel
grain_data["area"] = grain_data["area"].values * units_per_pixel**2
grain_data.head()

In [None]:
grain_data.to_csv(fname[:-4] + ".csv")  # save grain data to CSV file

In [None]:
# plot histogram of grain axis lengths
# note that the input data needs to be in millimeters (hence the factor of 10,
# which converts centimeters to millimeters)!
# these limits are for 'barton_creek_image.jpg'
fig, ax = seg.plot_histogram_of_axis_lengths(
    grain_data["major_axis_length"] * 10,
    grain_data["minor_axis_length"] * 10,
    binsize=0.4,
    xlimits=[2, 128],
)
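
Besides the histogram, summary percentiles such as D16, D50 and D84 are often reported for grain size distributions. The sketch below computes count-based percentiles with linear interpolation in plain Python; the `percentile` helper and the sample values are illustrative assumptions, and `segmenteverygrain` may offer its own way of doing this.

```python
def percentile(values, q):
    """Linear-interpolation percentile (q between 0 and 100) of a non-empty sequence."""
    data = sorted(values)
    if not data:
        raise ValueError("values must not be empty")
    pos = (len(data) - 1) * q / 100.0  # fractional index into the sorted data
    lower = int(pos)
    upper = min(lower + 1, len(data) - 1)
    return data[lower] + (data[upper] - data[lower]) * (pos - lower)


# made-up minor axis lengths in millimeters
b_axes_mm = [4.0, 6.5, 8.0, 11.0, 16.0, 22.5, 32.0, 45.0, 64.0]
d16, d50, d84 = (percentile(b_axes_mm, q) for q in (16, 50, 84))
print(round(d16, 2), round(d50, 2), round(d84, 2))
```

In the notebook, the input could be `grain_data["minor_axis_length"] * 10` (millimeters for the Barton Creek example). These percentiles weight every grain equally; area-weighted statistics would give different values.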

## Save mask and grain labels to PNG files

In [None]:
# write grayscale mask to PNG file
cv2.imwrite(fname.split("/")[-1][:-4] + "_mask.png", mask_all)
# Define a colormap using matplotlib
num_classes = len(all_grains)
cmap = plt.get_cmap("viridis", num_classes)
# Map each class label to a unique color using the colormap
vis_mask = cmap(labels.astype(np.uint16))[:, :, :3] * 255
vis_mask = vis_mask.astype(np.uint8)
# Save the mask as a PNG file
cv2.imwrite(fname.split("/")[-1][:-4] + "_labels.png", vis_mask)
# Save the image as a PNG file
# 'image' was loaded as RGB, but OpenCV expects BGR channel order when writing
cv2.imwrite(
    fname.split("/")[-1][:-4] + "_image.png", cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
)

## Run segmentation on large image

The 'predict_large_image' function works with images that are larger than a few megapixels and contain thousands of grains.
It breaks the input image into smaller patches and runs the segmentation process on each patch.
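
To illustrate what the `patch_size` and `overlap` parameters mean, here is a minimal sketch of one way to lay out overlapping patches along an image axis. It is not the tiling code used by `predict_large_image`, whose details may differ; `patch_origins` and the image dimensions are hypothetical.

```python
def patch_origins(length, patch_size=2000, overlap=200):
    """Start indices of overlapping patches that cover [0, length) along one axis."""
    if overlap >= patch_size:
        raise ValueError("overlap must be smaller than patch_size")
    if length <= patch_size:
        return [0]  # a single patch covers the whole axis
    step = patch_size - overlap
    starts = list(range(0, length - patch_size, step))
    starts.append(length - patch_size)  # last patch sits flush with the far edge
    return starts


# made-up image dimensions
rows = patch_origins(4000)
cols = patch_origins(5000)
print(rows, cols, len(rows) * len(cols))
```

Neighboring patches share at least `overlap` pixels, so a grain cut by one patch boundary can appear whole in the adjacent patch.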

The image used below (from [Mair et al., 2022, Earth Surface Dynamics](https://esurf.copernicus.org/articles/10/953/2022/)) is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/mair_et_al_L2_DJI_0382_image.jpg).

In [None]:
# get large example image
!wget "https://raw.githubusercontent.com/zsylvester/segmenteverygrain/main/examples/mair_et_al_L2_DJI_0382/mair_et_al_L2_DJI_0382_image.jpg"

In [None]:
from PIL import Image

Image.MAX_IMAGE_PIXELS = None  # needed if working with very large images
fname = "mair_et_al_L2_DJI_0382_image.jpg"
all_grains, image_pred, all_coords = seg.predict_large_image(
    fname, unet, sam, min_area=400.0, patch_size=2000, overlap=200
)

In [None]:
# plot results
image = np.array(load_img(fname))
fig, ax = plt.subplots(figsize=(10, 8))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap="Paired")