## Import packages

In [1]:
import keras
import matplotlib.pyplot as plt
import numpy as np
import segment_anything
import segmenteverygrain as seg
import segmenteverygrain.interactions as si
from tqdm import tqdm

%matplotlib qt

## Load models

In [2]:
# Load UNET model
unet = keras.saving.load_model(
    "./models/seg_model.keras",
    custom_objects={"weighted_crossentropy": seg.weighted_crossentropy},
)

# Download SAM model (only downloads it if it does not exist)
import os
import urllib.request

sam_path = "./models/sam_vit_h_4b8939.pth"
if not os.path.exists(sam_path):
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    urllib.request.urlretrieve(url, sam_path)

# Load SAM
fname = './models/sam_vit_h_4b8939.pth'
sam = segment_anything.sam_model_registry['default'](checkpoint=fname)
predictor = segment_anything.SamPredictor(sam)

2025-10-07 13:50:48.601769: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Max
2025-10-07 13:50:48.601795: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 96.00 GB
2025-10-07 13:50:48.601800: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 36.00 GB
2025-10-07 13:50:48.601819: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-07 13:50:48.601834: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


## Run segmentation

Grains should be well defined in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

Segmentation can take a few minutes even for medium-sized images. Images with ~2000 pixels along their largest dimension are a good starting point and give an idea of how well the segmentation works.
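
If your photo is much larger than that and you only want a quick first test, one option is to downscale it before segmenting. The helper below is a minimal sketch and not part of `segmenteverygrain`: `fit_within` is a hypothetical name, and it only computes the target size while preserving the aspect ratio.

```python
def fit_within(width, height, max_dim=2000):
    """Return (width, height) scaled so that the longer side is at most max_dim."""
    longest = max(width, height)
    if longest <= max_dim:
        return width, height  # already small enough; never upscale
    scale = max_dim / longest
    return max(1, round(width * scale)), max(1, round(height * scale))


# Hypothetical 6000 x 4000 pixel photo
new_size = fit_within(6000, 4000)
print(new_size)
```

The resulting size could then be passed to an image library's resize function (for example `PIL.Image.Image.resize`). Keep in mind that resizing changes the pixel-to-meter ratio, so the scale should be measured on the resized image.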

If you have a much larger image, see the section **"Run segmentation on large image"** at the end of the notebook. Running the `predict_large_image` function takes a lot longer (e.g., several hours), but it is possible to analyze very large images with tens of thousands of grains.

The image used below is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/examples/barton_creek/barton_creek_image.jpg).

In [3]:
# replace this with the path to your image:
fname = "./examples/barton_creek/barton_creek_image.jpg"
image = si.load_image(fname)
image_pred = seg.predict_image(image, unet, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts
# (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=20.0)

segmenting image tiles...


  0%|          | 0/9 [00:00<?, ?it/s]2025-10-07 13:50:55.163861: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
100%|██████████| 9/9 [00:03<00:00,  2.65it/s]
100%|██████████| 8/8 [00:02<00:00,  3.02it/s]


Use the figure created in the next cell to check the quality of the Unet labeling (it sometimes fails completely) and the distribution of the SAM prompts (black dots). If the Unet prediction is poor, it is a good idea to create some training data and fine-tune the base model so that it works better on the images of interest.

In [4]:
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(image_pred)
coords_arr = np.array(coords)  # SAM prompt locations; column 0 is x, column 1 is y
ax.scatter(coords_arr[:, 0], coords_arr[:, 1], c="k")
ax.set_xticks([])
ax.set_yticks([]);

In [5]:
# SAM segmentation, using the point prompts from the Unet:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(
    sam, image, image_pred, coords, labels,
    min_area=300.0,
    plot_image=True,
    remove_edge_grains=False,
    remove_large_objects=False,
)

creating masks using SAM...


100%|██████████| 1120/1120 [01:14<00:00, 15.11it/s]


finding overlapping polygons...


1112it [00:08, 136.36it/s]


finding best polygons...


100%|██████████| 309/309 [00:11<00:00, 27.48it/s]


creating labeled image...


100%|██████████| 340/340 [00:01<00:00, 190.24it/s]


In [6]:
# Save SAM image
out_fn = 'examples/auto_detection/barton_creek'
fig.savefig(out_fn + '_grains.jpg', bbox_inches='tight', pad_inches=0)
plt.close()

## Results

In [7]:
# Extract results
grains = si.polygons_to_grains(all_grains, image=image)
for g in tqdm(grains, desc='Measuring detected grains'):
    g.measure()

Measuring detected grains: 100%|██████████| 340/340 [00:00<00:00, 1324.39it/s]


The following results are then saved to the location specified in `out_fn`:
- Grain shapes, for use elsewhere (geojson)
- Summary data, presenting measurements for each detected grain (csv)
- Summary histogram, representing major/minor axes of detected grains (jpg)
- Mask representations of the detected grains, in both computer-readable (png, 0-1) and human-readable (jpg, 0-255) formats

In [8]:
# re-plot results if needed
fig, ax = plt.subplots(figsize=(15, 10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap="Paired", plot_image=True)
seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|██████████| 340/340 [00:01<00:00, 190.47it/s]


Run this cell, then left-click on one end of the scale bar in the image and right-click on the other end:

In [9]:
cid = fig.canvas.mpl_connect(
    "button_press_event", lambda event: seg.click_for_scale(event, ax)
)

number of pixels: 375.85


If `px_per_m` is 1, then the summary data and histogram will be in pixels. If the ratio of pixels to meters is known, set `px_per_m` in order to save them in meters.
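
The arithmetic behind `px_per_m` can be sketched as follows. This is an illustrative helper, not a function from `segmenteverygrain`: `pixels_per_meter` and the endpoint coordinates are hypothetical, chosen so that the scale bar spans 371.19 pixels and represents 10 cm.

```python
import math


def pixels_per_meter(p1, p2, bar_length_m):
    """Pixels per meter from two scale-bar endpoints (x, y) and the bar's real length."""
    n_pixels = math.hypot(p2[0] - p1[0], p2[1] - p1[1])  # Euclidean distance in pixels
    return n_pixels / bar_length_m


# Hypothetical endpoints: a horizontal scale bar spanning 371.19 pixels
px_per_m = pixels_per_meter((100.0, 50.0), (471.19, 50.0), bar_length_m=0.10)

# Convert a hypothetical grain axis of 120 pixels to millimeters
grain_axis_mm = 120.0 / px_per_m * 1000.0
print(round(px_per_m, 1), round(grain_axis_mm, 2))
```

Dividing any length in pixels by `px_per_m` gives that length in meters, which is why a value of 1 leaves the measurements in pixels.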

In [10]:
# Save results
px_per_m = 3711.9    # 371.19 pixels / 10 cm (scale bar on photo)
out_fn = 'examples/auto_detection/barton_creek'
# Grain shapes
si.save_grains(out_fn + '_grains.geojson', grains)
# Summary data
summary = si.save_summary(out_fn + '_summary.csv', grains, px_per_m=px_per_m)
# Summary histogram
si.save_histogram(out_fn + '_summary.jpg', summary=summary)
# Training mask
si.save_mask(out_fn + '_mask.png', grains, image, scale=False)
si.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)

## Delete, merge, and add grains

Open and run the [`interactive_edit.ipynb`](interactive_edit.ipynb) notebook to refine the results and generate training data. This new approach is faster and more user-friendly than the previous implementation; thanks to [Dave Matthews](https://github.com/dirtbirb) for contributing the `segmenteverygrain.interactions` module.

## Run segmentation on large image
In this case, `fname` points to an image that is larger than a few megapixels and contains thousands of grains.
The `predict_large_image` function breaks the input image into smaller, overlapping patches and runs the segmentation process on each patch.
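
To get a feel for how `patch_size` and `overlap` determine the number of patches, here is a minimal tiling sketch. It is an assumption about how such a tiling could work, not the library's actual implementation; `patch_boxes` and the 5000 x 3800 pixel image size are hypothetical.

```python
def patch_boxes(width, height, patch_size=2000, overlap=200):
    """Return (x0, y0, x1, y1) boxes that tile an image with overlapping patches."""
    if overlap >= patch_size:
        raise ValueError("overlap must be smaller than patch_size")
    step = patch_size - overlap  # distance between neighboring patch origins
    xs = range(0, max(width - overlap, 1), step)
    ys = range(0, max(height - overlap, 1), step)
    # Patches at the right and bottom edges are clipped to the image bounds
    return [
        (x0, y0, min(x0 + patch_size, width), min(y0 + patch_size, height))
        for y0 in ys
        for x0 in xs
    ]


# Hypothetical image size: 5000 x 3800 pixels
boxes = patch_boxes(5000, 3800)
print(len(boxes))
for box in boxes:
    print(box)
```

Neighboring patches share a strip `overlap` pixels wide, so a grain cut by one patch boundary can still appear whole in the adjacent patch; duplicate detections in the shared strip then need to be reconciled when the patches are merged.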

The image used below (from [Mair et al., 2022, Earth Surface Dynamics](https://esurf.copernicus.org/articles/10/953/2022/)) is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/examples/mair_et_al_L2_DJI_0382/mair_et_al_L2_DJI_0382_image.jpg).

In [20]:
from PIL import Image

Image.MAX_IMAGE_PIXELS = None  # needed if working with very large images
fname = "./examples/mair_et_al_L2_DJI_0382/mair_et_al_L2_DJI_0382_image.jpg"
all_grains, image_pred, all_coords = seg.predict_large_image(
    fname, unet, sam, min_area=400.0, patch_size=2000, overlap=200
)

segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.40it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  2.68it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2804/2804 [03:06<00:00, 15.02it/s]


finding overlapping polygons...


2537it [00:04, 570.08it/s]


finding best polygons...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1021/1021 [00:24<00:00, 42.12it/s]


creating labeled image...
processed patch #1 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.42it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.36it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1996/1996 [02:12<00:00, 15.09it/s]


finding overlapping polygons...


1766it [00:04, 402.55it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 666/666 [00:14<00:00, 47.25it/s]


creating labeled image...
processed patch #2 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.53it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.53it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1757/1757 [01:56<00:00, 15.05it/s]


finding overlapping polygons...


1545it [00:04, 367.67it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 581/581 [00:11<00:00, 51.63it/s]


creating labeled image...
processed patch #3 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.47it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.48it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3173/3173 [03:24<00:00, 15.53it/s]


finding overlapping polygons...


2941it [00:03, 840.91it/s] 


finding best polygons...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:32<00:00, 37.95it/s]


creating labeled image...
processed patch #4 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.33it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.40it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2041/2041 [02:22<00:00, 14.37it/s]


finding overlapping polygons...


1811it [00:04, 439.28it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 693/693 [00:15<00:00, 44.45it/s]


creating labeled image...
processed patch #5 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.43it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.44it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1662/1662 [01:46<00:00, 15.64it/s]


finding overlapping polygons...


1436it [00:04, 323.01it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 519/519 [00:10<00:00, 47.35it/s]


creating labeled image...
processed patch #6 out of 6 patches


4753it [00:02, 2302.03it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 331/331 [00:54<00:00,  6.09it/s]


In [22]:
# plot results
image = si.load_image(fname)
fig, ax = plt.subplots(figsize=(15, 10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap="Paired")
plt.axis("equal")
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 4399/4399 [00:13<00:00, 315.12it/s]


Run this cell, then left-click on one end of the scale bar in the image and right-click on the other end:

In [None]:
cid = fig.canvas.mpl_connect(
    "button_press_event", lambda event: seg.click_for_scale(event, ax)
)

In [25]:
# Extract results
grains = si.polygons_to_grains(all_grains, image=image)
for g in tqdm(grains, desc='Measuring detected grains'):
    g.measure()

Measuring detected grains: 100%|██████████████████████████████████████████████████████████████████| 4399/4399 [00:02<00:00, 2142.97it/s]


In [26]:
# Save results
px_per_m = 1812  # 181.2 pixels / 10 cm (scale bar on photo)
out_fn = 'examples/auto_detection/mair_et_al'
# Grain shapes
si.save_grains(out_fn + '_grains.geojson', grains)
# Summary data
summary = si.save_summary(out_fn + '_summary.csv', grains, px_per_m=px_per_m)
# Summary histogram
si.save_histogram(out_fn + '_summary.jpg', summary=summary)
# Training mask
si.save_mask(out_fn + '_mask.png', grains, image, scale=False)
si.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)

## Fine-tuning the base model

In [None]:
# patchify images and masks
# the input directory should contain files with 'image' and 'mask' in their filenames
input_dir = "./examples/unet_training/Masks_and_images/"
# a directory called "Patches" will be created here
patch_dir = "./examples/unet_training/"
image_dir, mask_dir = seg.patchify_training_data(input_dir, patch_dir)

In [None]:
# create training, validation, and test datasets
train_dataset, val_dataset, test_dataset = seg.create_train_val_test_data(
    image_dir, mask_dir, augmentation=True
)

In [None]:
# load base model weights and train the model with the new data
model = seg.create_and_train_model(
    train_dataset,
    val_dataset,
    test_dataset,
    model_file="./models/seg_model.keras",
    epochs=100,
)

In [None]:
# save the fine-tuned model as a new model; it can then be loaded with
# model = keras.saving.load_model("new_model.keras", custom_objects={"weighted_crossentropy": seg.weighted_crossentropy})
model.save("./examples/unet_training/new_model.keras")