## Import packages

In [1]:
from matplotlib import pyplot as plt
import keras
import segment_anything
import segmenteverygrain as seg
import segmenteverygrain.interactions as si
from tqdm import tqdm
from PIL import Image

%matplotlib qt

## Load models

In [2]:
# Load the U-Net model; its custom loss must be registered so Keras can deserialize it
unet = keras.saving.load_model(
    "./models/seg_model.keras",
    custom_objects={"weighted_crossentropy": seg.weighted_crossentropy},
)

# Download the SAM ViT-H checkpoint (only downloads it if it does not exist)
import os
import urllib.request

# Dedicated name for the checkpoint path so it cannot be confused with the
# image path (`fname`) used by the segmentation cell below.
sam_checkpoint = "./models/sam_vit_h_4b8939.pth"
sam_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(sam_checkpoint):
    os.makedirs(os.path.dirname(sam_checkpoint), exist_ok=True)
    # Download under a temporary name and rename only when complete: an interrupted
    # download would otherwise leave a truncated file that the existence check above
    # would accept on the next run, and loading it would then fail.
    partial_path = sam_checkpoint + ".part"
    urllib.request.urlretrieve(sam_url, partial_path)
    os.replace(partial_path, sam_checkpoint)

# Load SAM; the predictor wrapper is used later for interactive editing
sam = segment_anything.sam_model_registry['default'](checkpoint=sam_checkpoint)
predictor = segment_anything.SamPredictor(sam)

2025-11-15 12:56:33.471446: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Max
2025-11-15 12:56:33.471479: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 96.00 GB
2025-11-15 12:56:33.471484: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 36.00 GB
2025-11-15 12:56:33.471508: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-11-15 12:56:33.471523: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


## Run segmentation

Grains should be clearly resolved in the image: a grain that covers only a few pixels is unlikely to be detected.

Segmentation can take a few minutes even for medium-sized images. Images with about 2000 pixels along their longest dimension are a good starting point and let you quickly judge how well the segmentation works.

The image used below is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/examples/barton_creek/barton_creek_image.jpg). Save it to `./examples/barton_creek/`, or point `fname` in the next cell to wherever you saved it.

In [None]:
Image.MAX_IMAGE_PIXELS = None  # needed if working with very large images

# Always set the image path in this cell. Do not rely on a `fname` left over from an
# earlier cell: that would silently segment the wrong file.
fname = "./examples/barton_creek/barton_creek_image.jpg"
# fname = "./examples/mair_et_al_L2_DJI_0382/mair_et_al_L2_DJI_0382_image_small.jpg" # use this file if you want to try a larger image

image = si.load_image(fname)  # load image

# Segment the image patch by patch (patch_size px patches, overlapping by `overlap` px)
# with the U-Net + SAM pipeline. Grains smaller than min_area are discarded.
all_grains, image_pred, all_coords = seg.predict_large_image(
    fname, unet, sam,
    min_area=400.0,
    patch_size=2000,
    overlap=200,
    remove_edge_grains=False
)

# Overlay the detected grains on the image
fig, ax = plt.subplots()
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap="tab20b",
        plot_image=True, im_alpha=1.0)

segmenting image tiles...


100%|██████████| 9/9 [00:04<00:00,  1.99it/s]
100%|██████████| 8/8 [00:03<00:00,  2.16it/s]


creating masks using SAM...


100%|██████████| 1078/1078 [01:05<00:00, 16.53it/s]


finding overlapping polygons...


972it [00:04, 196.92it/s]


finding best polygons...


100%|██████████| 259/259 [00:10<00:00, 25.76it/s]


creating labeled image...
processed patch #1 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:04<00:00,  2.21it/s]
100%|██████████| 8/8 [00:04<00:00,  1.75it/s]


creating masks using SAM...


100%|██████████| 1115/1115 [01:04<00:00, 17.27it/s]


finding overlapping polygons...


960it [00:04, 214.49it/s]


finding best polygons...


100%|██████████| 260/260 [00:09<00:00, 28.68it/s]


creating labeled image...
processed patch #2 out of 6 patches
segmenting image tiles...


100%|██████████| 7/7 [00:03<00:00,  2.18it/s]
100%|██████████| 6/6 [00:02<00:00,  2.21it/s]


creating masks using SAM...


100%|██████████| 835/835 [00:39<00:00, 21.19it/s]


finding overlapping polygons...


737it [00:03, 225.40it/s]


finding best polygons...


100%|██████████| 202/202 [00:06<00:00, 32.89it/s]


creating labeled image...
processed patch #3 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:01<00:00,  5.62it/s]
100%|██████████| 8/8 [00:01<00:00,  5.70it/s]


creating masks using SAM...


100%|██████████| 515/515 [00:19<00:00, 25.80it/s]


finding overlapping polygons...


456it [00:01, 255.16it/s]


finding best polygons...


100%|██████████| 129/129 [00:03<00:00, 39.86it/s]


creating labeled image...
processed patch #4 out of 6 patches
segmenting image tiles...


100%|██████████| 9/9 [00:01<00:00,  5.71it/s]
100%|██████████| 8/8 [00:01<00:00,  5.82it/s]


creating masks using SAM...


100%|██████████| 485/485 [00:18<00:00, 26.29it/s]


finding overlapping polygons...


408it [00:01, 208.35it/s]


finding best polygons...


100%|██████████| 107/107 [00:03<00:00, 31.74it/s]


creating labeled image...
processed patch #5 out of 6 patches
segmenting image tiles...


100%|██████████| 7/7 [00:01<00:00,  5.67it/s]
100%|██████████| 6/6 [00:01<00:00,  5.78it/s]


creating masks using SAM...


100%|██████████| 372/372 [00:12<00:00, 29.99it/s]


finding overlapping polygons...


306it [00:01, 209.13it/s]


finding best polygons...


100%|██████████| 85/85 [00:02<00:00, 34.35it/s]


creating labeled image...
processed patch #6 out of 6 patches


1116it [00:00, 2052.15it/s]
100%|██████████| 66/66 [00:02<00:00, 23.22it/s]
100%|██████████| 1047/1047 [00:05<00:00, 176.81it/s]


In [223]:
import sys

# Development helper: evict every cached segmenteverygrain module (the package and
# all of its submodules) so the imports below pick up edits to the package source
# without restarting the kernel.
stale_modules = [name for name in list(sys.modules) if name.startswith('segmenteverygrain')]
for name in stale_modules:
    sys.modules.pop(name)

# Re-import the freshly loaded package
import segmenteverygrain as seg
import segmenteverygrain.interactions as si

## Results and interactive editing

In [231]:
# Extract results: wrap each segmented polygon in a grain object (the image is
# passed along so that per-grain intensity statistics can be measured)
grains = si.polygons_to_grains(all_grains, image=image)

# Measure every grain, with a progress bar
for grain in tqdm(grains, desc='Measuring detected grains'):
    grain.measure()

Measuring detected grains: 100%|██████████| 1047/1047 [00:00<00:00, 1258.50it/s]


The editing interface itself is defined in `segmenteverygrain.interactions`.

Navigation within the interface is described in the [matplotlib documentation](https://matplotlib.org/stable/users/explain/figure/interactive.html#interactive-navigation). Additional controls are:

- `Left click` on existing grain: select/unselect 
- `Left click` in grain-free area: place foreground prompt for instant grain creation
- `Alt` + `Left click` in grain-free area: place foreground prompt for multi-prompt grain creation
- `Alt` + `Right click`: place background prompt for multi-prompt grain creation
- `Shift` (hold): enable scale bar drawing
- `Ctrl` (hold): temporarily hide selected grains
- `Esc`: Remove all prompts and unselect all grains
- `d`: Delete selected (highlighted) grains
- `m`: Merge selected grains (must be touching)
- `z`: Delete the most recently created grain

Hints for these controls are shown in the figure title bar.

Important parameters when calling `GrainPlot`:

- `px_per_m`: The ratio of pixels to meters, if known. It is overwritten if a scale bar is drawn in the interface (hold `Shift` to enable scale bar drawing, then click & drag).
- `scale_m`: The length in meters of a reference object. Once the reference object has been measured by drawing a scale bar across it (click & drag), size and area values are converted to meters. The drawn line (shown in red) is taken to represent `scale_m` meters.
- `image_max_size` (y, x): Images larger than this in either dimension will be downscaled for display. Operations like grain detection will still be performed on the full image, but the display will not be able to zoom in at full quality. This is a performance optimization. Reduce this size for better performance, increase this size for better visual quality when zoomed.
- `image_alpha`: Set this to a value lower than 1 to apply a fade effect to the background image.
- `color_palette`: Matplotlib colormap to be used when plotting the grain masks.
- `color_by`: Property to color grains by ('major_axis_length', 'minor_axis_length', 'area', 'perimeter', etc.)

In [232]:
# Only needed for interactive editing of the segmentation results.
# SAM computes the image embedding once here; prompts placed in the GUI reuse it.
predictor.set_image(image)

# Build the interactive editor and show it
plot = si.GrainPlot(
    grains,
    image=image,
    predictor=predictor,
    blit=True,
    color_palette='tab20b',
    color_by=None,
    figsize=(12, 8),   # inches
    scale_m=0.1,       # length of the reference object, in meters
    px_per_m=1856.6,   # px/m; alternative to 'scale_m'; overwritten if a scale bar is drawn on the image
)
plot.activate()

Measuring and drawing grains: 100%|██████████| 1047/1047 [00:05<00:00, 178.46it/s]


In [236]:
# Retrieve the grains from the interactive plot so that later cells use any edits made in the GUI
grains = plot.get_grains()

In [127]:
# Turn off interactive (editing) features before drawing static annotations
plot.deactivate()

# Draw the major and minor axes of each grain on the plot
plot.draw_axes()

In [108]:
# Pixels-per-meter factor: the value given to GrainPlot, or the one derived
# from a scale bar if one was drawn in the interface
px_per_m = plot.px_per_m

# Per-grain measurements, converted to real-world units using px_per_m
summary = si.get_summary(grains, px_per_m)

# Axis lengths are in meters; multiply by 1000 to plot them in millimeters.
# NOTE(review): the unit of `binsize` is not visible here; confirm against
# seg.plot_histogram_of_axis_lengths.
hist = seg.plot_histogram_of_axis_lengths(
    summary['major_axis_length'] * 1000,
    summary['minor_axis_length'] * 1000,
    binsize=0.25,
    area=summary['area'],
)

The following results are then saved, using `out_fn` as the filename prefix:
- Grain shapes, for use elsewhere (geojson)
- Image of the detected grains (jpg)
- Summary data, presenting measurements for each detected grain (csv)
- Summary histogram of the major/minor axes of the detected grains (jpg)
- Mask representations of the detected grains, in both computer-readable (png, 0-1) and human-readable (jpg, 0-255) formats

In [None]:
import os

# Save results. Every output file starts with the `out_fn` prefix. The path is
# relative to the notebook, not a user-specific absolute path, so the cell works
# on any machine; the output directory is created if needed.
out_dir = "./results"
out_fn = os.path.join(out_dir, "barton_creek_image_small")  # filename prefix
os.makedirs(out_dir, exist_ok=True)

# Grain shapes
si.save_grains(out_fn + '_grains.geojson', grains)
# Grain image
plot.savefig(out_fn + '_grains.jpg')
# Summary data (uses the current pixel/meter factor from the plot)
summary = si.save_summary(
    out_fn + '_summary.csv', grains, px_per_m=plot.px_per_m)
# Summary histogram
si.save_histogram(out_fn + '_summary.jpg', summary=summary)
# Training masks: unscaled values (png) and values scaled for viewing (jpg)
si.save_mask(out_fn + '_mask.png', grains, image, scale=False)
si.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)
summary.head()

Unnamed: 0,area,centroid-0,centroid-1,perimeter,orientation,major_axis_length,minor_axis_length,max_intensity-0,min_intensity-0,mean_intensity-0,max_intensity-1,min_intensity-1,mean_intensity-1,max_intensity-2,min_intensity-2,mean_intensity-2
0,0.000169,234.805875,336.768084,0.057834,-1.369626,0.01984,0.011519,149.0,0.0,77.413793,150.0,3.0,77.940109,154.0,6.0,77.484574
1,0.000748,494.690862,828.917383,0.136905,1.398136,0.048257,0.024631,254.0,3.0,182.829454,251.0,0.0,181.176966,236.0,0.0,173.810995
2,0.00015,9.242273,12.770522,0.050209,1.461306,0.015641,0.012975,79.0,7.0,45.590336,79.0,7.0,45.35084,79.0,9.0,47.378151
3,0.000175,12.417879,44.241802,0.054928,0.651942,0.019837,0.01176,197.0,0.0,90.673145,201.0,0.0,94.240283,212.0,2.0,102.113074
4,0.000226,23.688918,509.875523,0.067647,0.884231,0.023875,0.013615,208.0,0.0,95.851604,208.0,0.0,95.239305,206.0,0.0,93.859626


### Finetuning the base model

In [None]:
# Cut the full-size training images and masks into patches.
# input_dir must contain files with 'image' and 'mask' in their filenames;
# a directory called "Patches" will be created inside patch_dir.
input_dir = "./examples/unet_training/Masks_and_images/"
patch_dir = "./examples/unet_training/"
image_dir, mask_dir = seg.patchify_training_data(input_dir, patch_dir)

In [None]:
# create training, validation, and test datasets from the patch directories produced above
# NOTE(review): augmentation=True presumably augments the training patches; confirm which splits it affects
train_dataset, val_dataset, test_dataset = seg.create_train_val_test_data(
    image_dir, mask_dir, augmentation=True
)

In [None]:
# Load the base model weights from model_file and fine-tune the model on the new data.
# NOTE(review): epochs=100 is a tunable choice; adjust it to the size of the new dataset.
model = seg.create_and_train_model(
    train_dataset,
    val_dataset,
    test_dataset,
    model_file="./models/seg_model.keras",
    epochs=100,
)

In [None]:
# Save the fine-tuned model as a new model. It can be loaded later in the same way as the base model:
#   keras.saving.load_model("./examples/unet_training/new_model.keras",
#                           custom_objects={"weighted_crossentropy": seg.weighted_crossentropy})
model.save("./examples/unet_training/new_model.keras")