## Import packages

In [1]:
import cv2
from keras.utils import load_img
from keras.saving import load_model
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from segment_anything import sam_model_registry, SamPredictor
from skimage.measure import regionprops, regionprops_table
from tqdm import trange, tqdm

import segmenteverygrain as seg

%matplotlib qt



## Load models

In [2]:
# UNET model
unet = load_model(
    "./models/seg_model.keras",
    custom_objects={'weighted_crossentropy': seg.weighted_crossentropy})

# SAM checkpoints. Download from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam = sam_model_registry["default"](
    checkpoint="./models/sam_vit_h_4b8939.pth")

2025-04-21 16:59:28.761006: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


## Run segmentation

Grains need to be well defined in the image: a grain that covers only a few pixels is unlikely to be detected.

Segmentation can take a few minutes even for medium-sized images. Images with about 2000 pixels along their largest dimension are a good starting point and give a quick idea of how well the segmentation works.
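If your image is much larger than this, one option for a first test is to downsample it so that its largest dimension is about 2000 pixels. The sketch below is a minimal illustration using plain NumPy nearest-neighbour sampling; the function name `downsample_to_max_dim` is hypothetical, and a proper resampling routine (for example `cv2.resize` with `cv2.INTER_AREA`) gives smoother results. Keep in mind that downsampling makes every grain smaller in pixels, which works against detection, and that if you reuse a scale measured on the full-resolution image, `units_per_pixel` has to be divided by the scale factor `s`.

```python
import numpy as np

def downsample_to_max_dim(image, max_dim=2000):
    """Nearest-neighbour downsampling so that max(height, width) <= max_dim.

    Returns the downsampled image and the scale factor s (new size / old size).
    Images that are already small enough are returned unchanged with s = 1.0.
    """
    h, w = image.shape[:2]
    s = min(1.0, max_dim / max(h, w))
    if s == 1.0:
        return image, 1.0
    new_h, new_w = max(1, round(h * s)), max(1, round(w * s))
    # for each output pixel, take the source pixel that contains its centre
    rows = np.minimum(((np.arange(new_h) + 0.5) / s).astype(int), h - 1)
    cols = np.minimum(((np.arange(new_w) + 0.5) / s).astype(int), w - 1)
    return image[rows[:, None], cols[None, :]], s

# example with a synthetic 3000 x 4000 RGB image
big = np.zeros((3000, 4000, 3), dtype=np.uint8)
small, s = downsample_to_max_dim(big)
print(small.shape, s)  # (1500, 2000, 3) 0.5
```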

For much larger images, see the section **"Run segmentation on large image"** at the end of the notebook. The `predict_large_image` function takes much longer to run (e.g., several hours), but it can analyze very large images with tens of thousands of grains.

The image used below is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/torrey_pines_beach_image.jpeg).

In [None]:
# replace this with the path to your image:
fname = './examples/torrey_pines_beach/torrey_pines_beach_image.jpg'
image = np.array(load_img(fname))
image_pred = seg.predict_image(image, unet, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts
# (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=20.0)

segmenting image tiles...


100%|██████████| 7/7 [00:02<00:00,  2.77it/s]
100%|██████████| 6/6 [00:01<00:00,  3.12it/s]
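The comment in the cell above says that a smaller `dbs_max_dist` yields more SAM prompts. The parameter name suggests that `label_grains` clusters candidate points with DBSCAN and places one prompt per cluster; that is an assumption, not something this notebook confirms. The sketch below illustrates the effect with a hypothetical helper, `cluster_prompts`: with `min_samples=1`, DBSCAN reduces to grouping points that are chained together by steps no longer than `eps`, and each group gives one prompt at its centroid.

```python
import numpy as np

def cluster_prompts(points, eps):
    """Group 2-D points whose chain of neighbour distances is <= eps
    (equivalent to DBSCAN with min_samples=1) and return one centroid per group."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # all pairwise distances: fine for a sketch, but O(n^2) memory
    dists = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    for i, j in zip(*np.nonzero(np.triu(dists <= eps, k=1))):
        parent[find(i)] = find(j)  # union the two groups

    roots = np.array([find(i) for i in range(n)])
    return np.array([points[roots == r].mean(axis=0) for r in np.unique(roots)])

pts = [(0, 0), (5, 0), (50, 0), (55, 0), (200, 0)]
for eps in (2.0, 10.0, 60.0):
    print(eps, len(cluster_prompts(pts, eps)))
# 2.0 5
# 10.0 3
# 60.0 2
```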


Use the figure created in the next cell to check the quality of the Unet prediction (sometimes it fails completely) and the distribution of the SAM prompts (black dots). If the Unet prediction is poor, it is a good idea to create some training data and fine-tune the base model so that it works better on the images of interest.

In [4]:
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(image_pred)
coords_arr = np.array(coords)  # SAM prompt locations, one (x, y) pair per row
ax.scatter(coords_arr[:, 0], coords_arr[:, 1], c='k')
ax.set_xticks([])
ax.set_yticks([]);
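Besides the visual check, a quick numerical summary can flag a failed Unet prediction. The sketch below assumes that `image_pred` is an H × W × C array of per-pixel class probabilities (consistent with it being displayed as an image above); since the notebook does not state the meaning or order of the classes, the hypothetical helper `class_fractions` only reports the share of pixels assigned to each channel. A class that covers almost none, or almost all, of the image is a hint that the prediction did not work.

```python
import numpy as np

def class_fractions(pred):
    """Fraction of pixels whose most probable class is each channel of an (H, W, C) array."""
    winners = np.argmax(pred, axis=-1)
    counts = np.bincount(winners.ravel(), minlength=pred.shape[-1])
    return counts / winners.size

# synthetic 3-class prediction: left half class 0, right half class 1
pred = np.zeros((4, 4, 3))
pred[:, :2, 0] = 1.0
pred[:, 2:, 1] = 1.0
print(class_fractions(pred))  # [0.5 0.5 0. ]
```

In the notebook this would be called as `class_fractions(image_pred)`.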

In [None]:
# SAM segmentation, using the point prompts from the Unet:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(
    sam, image, image_pred, coords, labels, 
    min_area=400.0, plot_image=True, 
    remove_edge_grains=False, remove_large_objects=False)

creating masks using SAM...


100%|██████████| 1041/1041 [01:03<00:00, 16.37it/s]


finding overlapping polygons...


1041it [00:06, 160.79it/s]


finding best polygons...


100%|██████████| 394/394 [00:12<00:00, 31.68it/s]


creating labeled image...


100%|██████████| 467/467 [00:03<00:00, 152.18it/s]


## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)
* press the 'g' key to hide the grain masks (so that you can see the original image better); press the 'g' key again to show the grain masks
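The interaction above follows the usual Matplotlib pattern: `mpl_connect` registers callbacks for mouse clicks and key presses, and the callbacks update shared state (here, the list of clicked grain indices). The sketch below is a hypothetical, GUI-free version of that bookkeeping, not the code of `seg.onclick2` or `seg.onpress2`; the class `GrainEditor`, its methods, and the use of integer grain ids instead of polygons are illustrative. Key events are simulated with `SimpleNamespace` objects carrying the same `key` attribute that Matplotlib key events have.

```python
from types import SimpleNamespace

class GrainEditor:
    """Hypothetical bookkeeping behind click/key grain editing."""

    def __init__(self, grain_ids):
        self.grains = set(grain_ids)  # stand-in for the list of grain polygons
        self.clicked = []             # grain ids in click order
        self.merged = []              # (kept_id, absorbed_id) pairs

    def on_click(self, grain_id):
        # in Matplotlib the id would come from hit-testing event.xdata / event.ydata
        if grain_id in self.grains:
            self.clicked.append(grain_id)

    def on_key(self, event):
        if event.key == 'x' and self.clicked:
            self.grains.discard(self.clicked.pop())  # delete the last clicked grain
        elif event.key == 'm' and len(self.clicked) >= 2:
            absorbed, kept = self.clicked.pop(), self.clicked.pop()
            self.grains.discard(absorbed)            # merge: only one id survives
            self.merged.append((kept, absorbed))

editor = GrainEditor(range(5))
editor.on_click(1)
editor.on_key(SimpleNamespace(key='x'))
editor.on_click(2)
editor.on_click(3)
editor.on_key(SimpleNamespace(key='m'))
print(sorted(editor.grains), editor.merged)  # [0, 2, 4] [(2, 3)]
```

With a real figure, `on_key` could be wired up directly with `fig.canvas.mpl_connect('key_press_event', editor.on_key)`.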

In [6]:
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event', 
    lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax))
cid2 = fig.canvas.mpl_connect('key_press_event', 
    lambda event: seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax))

Run this cell when you are done deleting and merging grains; it is a good idea to disconnect these event handlers before moving on to the next step.

In [7]:
fig.canvas.mpl_disconnect(cid1)
fig.canvas.mpl_disconnect(cid2)

Use this function to update the `labels` array after deleting and merging grains (the `all_grains` list itself is already updated during deletion and merging):

In [8]:
all_grains, labels, mask_all = seg.get_grains_from_patches(ax, image)

100%|██████████| 467/467 [00:00<00:00, 1381.18it/s]
467it [00:00, 643.95it/s]
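Rebuilding a label image from grain outlines is a rasterization step. The sketch below shows the idea with an even-odd (ray-casting) point-in-polygon test evaluated at every pixel centre; the helpers `points_in_polygon` and `rasterize_polygons` are hypothetical, and this notebook does not say how `get_grains_from_patches` does it internally. In practice a library routine such as `skimage.draw.polygon` is much faster. Polygons are lists of (x, y) vertices, with x along columns and y along rows; where polygons overlap, the later one wins.

```python
import numpy as np

def points_in_polygon(xs, ys, poly):
    """Even-odd ray casting: True where the point (xs, ys) lies inside poly [(x, y), ...]."""
    inside = np.zeros(xs.shape, dtype=bool)
    n = len(poly)
    for i in range(n):
        x1, y1 = poly[i]
        x2, y2 = poly[(i + 1) % n]
        crosses = (y1 > ys) != (y2 > ys)  # edge spans the height of the horizontal ray
        with np.errstate(divide='ignore', invalid='ignore'):
            x_cross = x1 + (ys - y1) * (x2 - x1) / (y2 - y1)
        inside ^= crosses & (xs < x_cross)  # toggle for each edge crossed to the right
    return inside

def rasterize_polygons(polygons, shape):
    """Label image with value k + 1 inside polygon k (later polygons win on overlap)."""
    rows, cols = np.mgrid[0:shape[0], 0:shape[1]].astype(float)
    out = np.zeros(shape, dtype=np.int32)
    for label, poly in enumerate(polygons, start=1):
        out[points_in_polygon(cols, rows, poly)] = label
    return out

square_a = [(0.5, 0.5), (3.5, 0.5), (3.5, 3.5), (0.5, 3.5)]
square_b = [(2.5, 2.5), (5.5, 2.5), (5.5, 5.5), (2.5, 5.5)]
demo_labels = rasterize_polygons([square_a, square_b], (7, 7))
print(int((demo_labels == 1).sum()), int((demo_labels == 2).sum()))  # 8 9
```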


Plot the updated set of grains:

In [None]:
fig, ax = plt.subplots(figsize=(15,10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(
    image, all_grains, ax, cmap='Paired', plot_image=True)
seg.plot_grain_axes_and_centroids(
    all_grains, labels, ax, linewidth=1, markersize=10)
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|██████████| 467/467 [00:03<00:00, 153.36it/s]
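`plot_grain_axes_and_centroids` draws each grain's centroid together with its major and minor axes. Assuming these follow scikit-image's `regionprops` conventions (the function takes the `labels` array, but this notebook does not say how it computes the axes), the sketch below shows the second-moment definition documented for `regionprops`: each axis length is four times the square root of an eigenvalue of the covariance matrix of the grain's pixel coordinates. The same definition underlies the `major_axis_length` and `minor_axis_length` columns computed in the grain size section. The helper `centroid_and_axes` is hypothetical.

```python
import numpy as np

def centroid_and_axes(mask):
    """Centroid (row, col) and major/minor axis lengths of a binary mask,
    using axis length = 4 * sqrt(eigenvalue of the pixel-coordinate covariance)."""
    rows, cols = np.nonzero(mask)
    coords = np.stack([rows, cols]).astype(float)
    centroid = coords.mean(axis=1)
    cov = np.cov(coords, bias=True)                  # population (1/N) covariance
    evals = np.sort(np.linalg.eigvalsh(cov))[::-1]   # largest eigenvalue first
    major, minor = 4 * np.sqrt(np.clip(evals, 0, None))
    return centroid, major, minor

mask = np.zeros((20, 20), dtype=bool)
mask[5:8, 2:13] = True  # a 3-row x 11-column rectangle
c, major, minor = centroid_and_axes(mask)
print(c, round(float(major), 3), round(float(minor), 3))  # [6. 7.] 12.649 3.266
```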


## Add new grains using the Segment Anything Model

* click on an unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right-click outside the grain (but inside the most recent mask) to restrict the grain to a smaller mask; this adds a background prompt
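Segment Anything takes point prompts as an (N, 2) array of (x, y) pixel coordinates plus an (N,) array of labels, where 1 marks a foreground point (inside the object) and 0 a background point; this is the interface of `SamPredictor.predict(point_coords=..., point_labels=..., multimask_output=...)`. The helper `clicks_to_prompts` below is a hypothetical sketch of how left and right clicks (Matplotlib mouse buttons 1 and 3) map onto those arrays; whether `seg.onclick` builds them exactly this way is an assumption. The `predictor.predict` call is left as a comment because it needs the loaded model.

```python
import numpy as np

def clicks_to_prompts(clicks):
    """Convert [(x, y, button), ...] into SAM point prompts.

    button 1 (left click)  -> label 1, foreground point inside the grain
    button 3 (right click) -> label 0, background point outside the grain
    """
    coords = np.array([(x, y) for x, y, _ in clicks], dtype=float).reshape(-1, 2)
    labels = np.array([1 if button == 1 else 0 for _, _, button in clicks], dtype=int)
    return coords, labels

point_coords, point_labels = clicks_to_prompts([(120, 85, 1), (140, 60, 3)])
print(point_coords.shape, point_labels)  # (2, 2) [1 0]
# masks, scores, logits = predictor.predict(
#     point_coords=point_coords, point_labels=point_labels, multimask_output=False)
```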

In [None]:
predictor = SamPredictor(sam)
predictor.set_image(image) # this can take a while
coords = []
cid3 = fig.canvas.mpl_connect('button_press_event',
    lambda event: seg.onclick(event, ax, coords, image, predictor))
cid4 = fig.canvas.mpl_connect('key_press_event',
    lambda event: seg.onpress(event, ax, fig))

In [11]:
fig.canvas.mpl_disconnect(cid3)
fig.canvas.mpl_disconnect(cid4)

After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [12]:
all_grains, labels, mask_all = seg.get_grains_from_patches(ax, image)

100%|██████████| 467/467 [00:00<00:00, 1350.03it/s]
467it [00:00, 647.83it/s]


## Get grain size distribution

Run this cell, then left-click on one end of the scale bar in the image and right-click on the other end:

In [None]:
cid5 = fig.canvas.mpl_connect('button_press_event',
    lambda event: seg.click_for_scale(event, ax))

Use the length of the scale bar in pixels (printed by the previous cell) to compute the scale of the image (in units per pixel):

In [None]:
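n_of_units = 10.59  # scale bar length in real-world units; use millimeters, since the histogram below expects mm
units_per_pixel = n_of_units / 507.96  # 507.96 = scale bar length in pixels, from the clicks above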
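The scale is simply the real-world length of the scale bar divided by its length in pixels. Assuming `click_for_scale` reports the straight-line (Euclidean) distance between the two clicked end points, the computation can be sketched as follows; `units_per_pixel_from_clicks` is a hypothetical helper. With the numbers above, 10.59 / 507.96 gives roughly 0.0208 mm per pixel.

```python
import math

def units_per_pixel_from_clicks(p1, p2, bar_length_units):
    """Scale (units per pixel) from the two clicked scale-bar end points (x, y)."""
    bar_length_px = math.dist(p1, p2)  # Euclidean distance in pixels
    return bar_length_units / bar_length_px

print(units_per_pixel_from_clicks((0, 0), (300, 400), 10.0))  # 0.02
```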

In [15]:
# measure grain properties from the label image (all values are in pixels at this point)
props = regionprops_table(
    labels.astype('int'), intensity_image=image,
    properties=('label', 'area', 'centroid', 'major_axis_length', 'minor_axis_length',
                'orientation', 'perimeter', 'max_intensity', 'mean_intensity', 'min_intensity'))
grain_data = pd.DataFrame(props)
# convert lengths to physical units; areas scale with the square of the pixel size
for col in ['major_axis_length', 'minor_axis_length', 'perimeter']:
    grain_data[col] = grain_data[col] * units_per_pixel
grain_data['area'] = grain_data['area'] * units_per_pixel**2
# note: the centroid columns stay in pixel coordinates and 'orientation' is in radians
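A worked example of the unit conversion above, using the scale from the scale-bar cell and made-up pixel measurements for a single grain: lengths are multiplied by the pixel size, areas by its square. The variable names are illustrative only and chosen so they do not overwrite the notebook's own variables.

```python
scale = 10.59 / 507.96  # mm per pixel, from the scale-bar cell (~0.0208)

major_axis_px, area_px = 80.0, 1000.0   # hypothetical pixel measurements of one grain
major_axis_mm = major_axis_px * scale   # lengths scale linearly with pixel size
area_mm2 = area_px * scale ** 2         # areas scale with its square
print(round(major_axis_mm, 3), round(area_mm2, 4))  # 1.668 0.4346
```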

In [16]:
grain_data.head()

| | label | area | centroid-0 | centroid-1 | major_axis_length | minor_axis_length | orientation | perimeter | max_intensity-0 | max_intensity-1 | max_intensity-2 | mean_intensity-0 | mean_intensity-1 | mean_intensity-2 | min_intensity-0 | min_intensity-1 | min_intensity-2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.85842 | 863.925063 | 25.706329 | 1.710283 | 0.838469 | 1.094589 | 4.589185 | 253.0 | 255.0 | 255.0 | 169.554937 | 175.094177 | 171.145823 | 0.0 | 0.0 | 0.0 |
| 1 | 2 | 0.205152 | 151.879237 | 916.902542 | 0.782615 | 0.394915 | 0.800688 | 2.096663 | 116.0 | 130.0 | 141.0 | 28.425847 | 32.326271 | 34.319915 | 0.0 | 0.0 | 0.0 |
| 2 | 3 | 0.641968 | 553.547055 | 1255.301963 | 0.986216 | 0.882518 | -0.075191 | 3.313006 | 235.0 | 233.0 | 210.0 | 143.442112 | 138.888287 | 123.955315 | 4.0 | 4.0 | 4.0 |
| 3 | 4 | 0.401176 | 589.889491 | 901.600217 | 1.102045 | 0.503646 | 1.403738 | 2.817711 | 234.0 | 218.0 | 195.0 | 76.784399 | 63.483207 | 51.793066 | 3.0 | 2.0 | 0.0 |
| 4 | 5 | 0.289907 | 691.875562 | 1244.031484 | 1.532 | 0.388542 | 1.141661 | 3.467706 | 115.0 | 124.0 | 123.0 | 19.632684 | 28.26087 | 34.646177 | 0.0 | 0.0 | 0.0 |


In [17]:
grain_data.to_csv(fname[:-4]+'.csv') # save grain data to CSV file

In [None]:
# plot histogram of grain axis lengths
# note that the input data need to be in millimeters
fig, ax = seg.plot_histogram_of_axis_lengths(
    grain_data['major_axis_length'], grain_data['minor_axis_length'],
    binsize=0.2, xlimits=[0.25, 4])
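To see the counts behind such a histogram, the sketch below bins axis lengths into fixed-width bins of `binsize` starting at the lower x limit. This is an assumption about how `binsize` and `xlimits` are used: the package function may bin differently (for example on a logarithmic scale), so treat `axis_length_histogram` as a hypothetical helper for inspecting the data, not a reimplementation.

```python
import math
import numpy as np

def axis_length_histogram(lengths, binsize=0.2, xlimits=(0.25, 4.0)):
    """Counts of axis lengths in fixed-width bins starting at xlimits[0]."""
    lo, hi = xlimits
    n_bins = math.ceil((hi - lo) / binsize)
    edges = lo + binsize * np.arange(n_bins + 1)
    counts, _ = np.histogram(lengths, bins=edges)
    return counts, edges

counts, edges = axis_length_histogram([0.3, 0.5, 0.5, 1.0])
print(len(counts), counts[:4])  # 19 [1 2 0 1]
```

In the notebook this would be applied to `grain_data['major_axis_length']` or `grain_data['minor_axis_length']`.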

## Save mask and grain labels to PNG files

In [None]:
dirname = './examples/torrey_pines_beach/'
basename = fname.split('/')[-1][:-4]  # image file name without the '.jpg' extension
# write the mask to a PNG file (cast to 8-bit; otherwise OpenCV falls back to 8-bit with a warning)
cv2.imwrite(dirname + basename + '_mask.png', mask_all.astype(np.uint8))
# save the image as a PNG file; OpenCV expects BGR channel order, while load_img returns RGB
cv2.imwrite(dirname + basename + '_image.png',
    cv2.cvtColor(image, cv2.COLOR_RGB2BGR))


True

## Run segmentation on large image (new!)
In this case, `fname` points to an image that is larger than a few megapixels and contains thousands of grains.
The `predict_large_image` function splits the input image into smaller, overlapping patches (controlled by `patch_size` and `overlap`) and runs the segmentation on each patch.
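One common way to tile an image into overlapping patches is to advance by `patch_size - overlap` along each dimension and shift the last patch back so that it ends at the image edge. The sketch below shows that scheme for one dimension using a hypothetical helper, `patch_starts`, and a hypothetical 4000 × 5000 pixel image; the package's own tiling may differ in detail.

```python
def patch_starts(length, patch_size, overlap):
    """Start offsets of overlapping patches along one image dimension.

    Consecutive patches advance by patch_size - overlap; the last patch is
    shifted back so that it ends exactly at the image edge.
    """
    step = patch_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than patch_size")
    if length <= patch_size:
        return [0]
    starts = list(range(0, length - patch_size + 1, step))
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)
    return starts

rows = patch_starts(4000, 2000, 200)
cols = patch_starts(5000, 2000, 200)
print(rows, cols, len(rows) * len(cols))  # [0, 1800, 2000] [0, 1800, 3000] 9
```

Each patch is then the sub-array `image[r:r + patch_size, c:c + patch_size]` for every combination of `r` in `rows` and `c` in `cols`; the overlap lets grains cut by one patch boundary appear whole in a neighbouring patch.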

The image used below (from [Mair et al., 2022, Earth Surface Dynamics](https://esurf.copernicus.org/articles/10/953/2022/)) is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/mair_et_al_L2_DJI_0382_image.jpg).

In [22]:
from PIL import Image
Image.MAX_IMAGE_PIXELS = None # needed if working with very large images
fname = "./examples/mair_et_al_L2_DJI_0382/mair_et_al_L2_DJI_0382_image.jpg"
all_grains, image_pred, all_coords = seg.predict_large_image(
    fname, unet, sam, min_area=400.0, patch_size=2000, overlap=200)

segmenting image tiles...


100%|██████████| 9/9 [00:07<00:00,  1.18it/s]
100%|██████████| 8/8 [00:06<00:00,  1.17it/s]


creating masks using SAM...


100%|██████████| 2804/2804 [05:00<00:00,  9.33it/s]


finding overlapping polygons...


2537it [00:13, 188.53it/s]


finding best polygons...


100%|██████████| 1021/1021 [00:52<00:00, 19.34it/s]


creating labeled image...
processed patch #1 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:07<00:00,  1.18it/s]
100%|██████████| 8/8 [00:06<00:00,  1.18it/s]


creating masks using SAM...


100%|██████████| 1996/1996 [03:36<00:00,  9.20it/s]


finding overlapping polygons...


1766it [00:13, 127.02it/s]


finding best polygons...


100%|██████████| 666/666 [00:29<00:00, 22.26it/s]


creating labeled image...
processed patch #2 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:07<00:00,  1.14it/s]
100%|██████████| 8/8 [00:07<00:00,  1.14it/s]


creating masks using SAM...


100%|██████████| 1757/1757 [02:57<00:00,  9.90it/s]


finding overlapping polygons...


1545it [00:12, 122.48it/s]


finding best polygons...


100%|██████████| 581/581 [00:23<00:00, 25.10it/s]


creating labeled image...
processed patch #3 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:07<00:00,  1.16it/s]
100%|██████████| 8/8 [00:06<00:00,  1.15it/s]


creating masks using SAM...


100%|██████████| 3173/3173 [05:18<00:00,  9.95it/s]


finding overlapping polygons...


2941it [00:10, 270.81it/s]


finding best polygons...


100%|██████████| 1250/1250 [01:10<00:00, 17.66it/s]


creating labeled image...
processed patch #4 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:07<00:00,  1.14it/s]
100%|██████████| 8/8 [00:06<00:00,  1.17it/s]


creating masks using SAM...


100%|██████████| 2041/2041 [03:36<00:00,  9.45it/s]


finding overlapping polygons...


1811it [00:12, 142.33it/s]


finding best polygons...


100%|██████████| 693/693 [00:31<00:00, 21.85it/s]


creating labeled image...
processed patch #5 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:07<00:00,  1.16it/s]
100%|██████████| 8/8 [00:06<00:00,  1.17it/s]


creating masks using SAM...


100%|██████████| 1662/1662 [02:38<00:00, 10.46it/s]


finding overlapping polygons...


1436it [00:13, 103.43it/s]


finding best polygons...


100%|██████████| 519/519 [00:22<00:00, 23.52it/s]


creating labeled image...
processed patch #6 out of 6 patches


4753it [00:05, 831.37it/s] 
100%|██████████| 331/331 [02:01<00:00,  2.72it/s]


In [23]:
# plot results
image = np.array(load_img(fname))
fig, ax = plt.subplots(figsize=(15,10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired')
plt.axis('equal')
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|██████████| 4399/4399 [00:28<00:00, 153.01it/s]


In [24]:
# faster way of deleting false positives: with select_only=True, clicks only record the indices of the
# selected grains (instead of highlighting and deleting each grain); the next cell removes them
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event',
    lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax, select_only=True))

In [25]:
# delete the selected polygons from 'all_grains'
# (highest index first, so that earlier deletions do not shift the remaining indices)
grain_inds = sorted(np.unique(grain_inds), reverse=True)
for ind in tqdm(grain_inds):
    del all_grains[ind]

0it [00:00, ?it/s]
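The deletion loop above goes from the highest index to the lowest. Deleting in ascending order would be wrong: after removing item 1, every later item moves one position to the left, so deleting "index 3" next would remove what was originally item 4. Deleting by index with `del` also avoids `list.remove`, which searches for the first element that compares equal to the given value. A small sketch with a hypothetical helper, `delete_indices`, and plain strings standing in for grain polygons:

```python
def delete_indices(items, indices):
    """Delete items at the given indices, highest index first, so that
    earlier deletions do not shift the positions of the remaining targets."""
    for ind in sorted(set(indices), reverse=True):
        del items[ind]
    return items

grains = ['g0', 'g1', 'g2', 'g3', 'g4']
print(delete_indices(grains, [1, 3, 3]))  # ['g0', 'g2', 'g4']
```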


After plotting the results, use the functions for deleting, merging, and adding grains described above before saving the results (the same workflow as for a small image).

## Fine-tuning the base model

In [None]:
# patchify images and masks
input_dir = "./examples/unet_training/Masks_and_images/" # the input directory should contain files with 'image' and 'mask' in their filenames
patch_dir = "./examples/unet_training/" # a directory called "Patches" will be created here
image_dir, mask_dir = seg.patchify_training_data(input_dir, patch_dir)
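Training on patches means cutting each image and its mask into small tiles of the same size. The sketch below shows non-overlapping tiling with NumPy reshapes; the patch size of 256 pixels is an assumption (it matches the `I=256` used with `predict_image` above), and dropping the incomplete tiles at the right and bottom edges is a simplification. `patchify_training_data` may handle sizes, edges, and file output differently; `patchify` here is a hypothetical helper.

```python
import numpy as np

def patchify(image, patch_size=256):
    """Split an (H, W, ...) array into non-overlapping patch_size x patch_size tiles,
    dropping the incomplete tiles at the right and bottom edges."""
    h, w = image.shape[:2]
    n_r, n_c = h // patch_size, w // patch_size
    cropped = image[:n_r * patch_size, :n_c * patch_size]
    tiles = cropped.reshape(n_r, patch_size, n_c, patch_size, *image.shape[2:])
    tiles = tiles.swapaxes(1, 2)  # -> (n_r, n_c, patch, patch, ...)
    return tiles.reshape(n_r * n_c, patch_size, patch_size, *image.shape[2:])

img = np.zeros((600, 1000, 3), dtype=np.uint8)
mask = np.zeros((600, 1000), dtype=np.uint8)
print(patchify(img).shape, patchify(mask).shape)  # (6, 256, 256, 3) (6, 256, 256)
```

Because the image and its mask are tiled with the same function, tile `k` of the image lines up with tile `k` of the mask.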

In [None]:
# create training, validation, and test datasets
train_dataset, val_dataset, test_dataset = seg.create_train_val_test_data(
    image_dir, mask_dir, augmentation=True)
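Two ideas sit behind this step: the patches are shuffled and split into training, validation, and test sets, and augmentation applies the *same* random transform to an image and its mask so that the labels stay aligned. The sketch below illustrates both with hypothetical helpers; the 80/10/10 split and the use of only horizontal and vertical flips are assumptions, not necessarily what `create_train_val_test_data` does.

```python
import numpy as np

def split_indices(n, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle 0..n-1 and split into train / validation / test index arrays."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_val, n_test = int(round(n * val_frac)), int(round(n * test_frac))
    return idx[n_val + n_test:], idx[:n_val], idx[n_val:n_val + n_test]

def augment_pair(image, mask, rng):
    """Apply the same random flips to an image and its mask so labels stay aligned."""
    if rng.random() < 0.5:
        image, mask = image[:, ::-1], mask[:, ::-1]  # horizontal flip
    if rng.random() < 0.5:
        image, mask = image[::-1, :], mask[::-1, :]  # vertical flip
    return image, mask

train, val, test = split_indices(100)
print(len(train), len(val), len(test))  # 80 10 10
```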

In [None]:
# load base model weights and train the model with the new data
model = seg.create_and_train_model(
    train_dataset, val_dataset, test_dataset,
    model_file='./models/seg_model.keras', epochs=100)

In [None]:
# save the fine-tuned model as a new model; it can be loaded later with:
# model = load_model('./examples/unet_training/new_model.keras',
#                    custom_objects={'weighted_crossentropy': seg.weighted_crossentropy})
model.save('./examples/unet_training/new_model.keras')