## Import packages

Note that `rasterio` and `geopandas` need to be installed in the current Python environment for this notebook to work.
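A quick way to confirm that the optional dependencies are available before running the rest of the notebook is sketched below. The helper name `missing_packages` is hypothetical, and `shapely` is included only because it is imported further down in this notebook.

```python
import importlib.util


def missing_packages(names):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# packages used in the georeferencing part of the notebook
required = ["rasterio", "geopandas", "shapely"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
```

If the list is non-empty, install the listed packages into the current environment before continuing.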

In [1]:
import cv2
from keras.utils import load_img
from keras.saving import load_model
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from segment_anything import sam_model_registry, SamPredictor
from skimage.measure import regionprops, regionprops_table
from tqdm import trange, tqdm

import segmenteverygrain as seg

%matplotlib qt

## Load models

In [2]:
# UNET model
unet = load_model(
    "./models/seg_model.keras",
    custom_objects={"weighted_crossentropy": seg.weighted_crossentropy},
)

# SAM checkpoints. Download from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam = sam_model_registry["default"](checkpoint="./models/sam_vit_h_4b8939.pth")



## Run segmentation

Grains need to be well defined in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images, so start with small images and downsample large ones if necessary. Images with about 2000 pixels along their largest dimension are a good starting point.
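One possible way to downsample a large image to roughly this size is sketched here. The helper names `downsampled_shape` and `downsample` and the default `max_dim=2000` are assumptions made for illustration; they are not part of `segmenteverygrain`.

```python
def downsampled_shape(height, width, max_dim=2000):
    """Return (new_height, new_width) so that the longest side is at most max_dim."""
    longest = max(height, width)
    if longest <= max_dim:
        return height, width  # already small enough, keep the original size
    scale = max_dim / longest
    return max(1, round(height * scale)), max(1, round(width * scale))


def downsample(image, max_dim=2000):
    """Resize a (rows, cols, channels) array with OpenCV, preserving aspect ratio."""
    import cv2  # imported lazily so that downsampled_shape works without OpenCV

    new_h, new_w = downsampled_shape(image.shape[0], image.shape[1], max_dim)
    # cv2.resize expects the target size as (width, height)
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
```

Keep in mind that a downsampled image has a larger pixel size than the original file, so the transform of the original GeoTIFF no longer applies to its row and column coordinates without rescaling.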

In [4]:
fname = "./examples/RI_T01_Grid_65/RI_T01_Grid_65.tif"
image = np.array(load_img(fname))
image_pred = seg.predict_image(image, unet, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts
# (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=10.0)

# SAM segmentation, using the point prompts from the UNET model:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(
    sam,
    image,
    image_pred,
    coords,
    labels,
    min_area=50.0,
    plot_image=True,
    remove_edge_grains=False,
    remove_large_objects=False,
)

segmenting image tiles...


100%|██████████| 6/6 [00:02<00:00,  2.87it/s]
100%|██████████| 5/5 [00:01<00:00,  4.30it/s]


creating masks using SAM...


100%|██████████| 1301/1301 [00:51<00:00, 25.35it/s]


finding overlapping polygons...


1293it [00:03, 332.53it/s]


finding best polygons...


100%|██████████| 393/393 [00:08<00:00, 46.36it/s]


creating labeled image...


100%|██████████| 428/428 [00:01<00:00, 360.95it/s]


## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)
* press the 'g' key to hide the grain masks (so that you can see the original image better); press the 'g' key again to show the grain masks

In [5]:
grain_inds = []
cid1 = fig.canvas.mpl_connect(
    "button_press_event",
    lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax),
)
cid2 = fig.canvas.mpl_connect(
    "key_press_event",
    lambda event: seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax),
)

Run this cell once you have finished deleting and merging grains; it disconnects the click and key handlers, which is a good idea before moving on to the next step.

In [6]:
fig.canvas.mpl_disconnect(cid1)
fig.canvas.mpl_disconnect(cid2)

Use this function to update the 'all_grains' list after deleting and merging grains:

In [7]:
all_grains, labels, mask_all = seg.get_grains_from_patches(ax, image)

100%|██████████| 426/426 [00:00<00:00, 2234.33it/s]
426it [00:00, 3445.41it/s]


Plot the updated set of grains:

In [8]:
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(image)
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap="Paired")
seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|██████████| 426/426 [00:01<00:00, 371.36it/s]


## Add new grains using the Segment Anything Model

* click on an unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right click outside the grain (but inside the most recent mask) if you want to restrict the grain to a smaller mask - this adds a background prompt
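The foreground and background prompts described above correspond to the point prompts accepted by `SamPredictor.predict`, where a label of 1 marks a point inside the object and 0 marks background. The sketch below shows one way such clicks could be turned into a mask; the helper names and the choice of keeping the highest-scoring mask are assumptions, and the handlers in `segmenteverygrain` may work differently.

```python
def build_point_prompts(clicks):
    """Convert (x, y, is_foreground) clicks into SAM point coordinates and labels."""
    coords = [[x, y] for x, y, _ in clicks]
    labels = [1 if is_foreground else 0 for _, _, is_foreground in clicks]
    return coords, labels


def predict_mask(predictor, clicks):
    """Run SAM on the clicks and return the mask with the highest score."""
    import numpy as np  # imported lazily; the notebook already imports numpy

    coords, labels = build_point_prompts(clicks)
    masks, scores, _ = predictor.predict(
        point_coords=np.array(coords, dtype=float),
        point_labels=np.array(labels),
        multimask_output=True,
    )
    return masks[int(np.argmax(scores))]
```

For example, a left click at (120, 80) followed by a right click at (140, 95) would be passed as `[(120, 80, True), (140, 95, False)]`.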

In [9]:
predictor = SamPredictor(sam)
predictor.set_image(image)  # this can take a while
coords = []
cid3 = fig.canvas.mpl_connect(
    "button_press_event", lambda event: seg.onclick(event, ax, coords, image, predictor)
)
cid4 = fig.canvas.mpl_connect(
    "key_press_event", lambda event: seg.onpress(event, ax, fig)
)

In [10]:
fig.canvas.mpl_disconnect(cid3)
fig.canvas.mpl_disconnect(cid4)

After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [11]:
all_grains, labels, mask_all = seg.get_grains_from_patches(ax, image)

100%|██████████| 429/429 [00:00<00:00, 2207.81it/s]
429it [00:00, 3357.38it/s]


## Save mask and image to PNG files

In [None]:
dirname = "./examples/RI_T01_Grid_65/"
# write the grayscale mask to a PNG file
cv2.imwrite(dirname + fname.split("/")[-1][:-4] + "_mask.png", mask_all)
# save the image as a PNG file; 'image' is in RGB order but OpenCV
# writes BGR, so swap the channel order first
cv2.imwrite(
    dirname + fname.split("/")[-1][:-4] + "_image.png",
    cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
)

True
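The output file names above are built with string slicing, which assumes a three-letter file extension. A slightly more robust alternative, sketched here with a hypothetical helper `output_path`, derives the names with `pathlib`:

```python
from pathlib import Path


def output_path(image_path, suffix, out_dir=None):
    """Build an output path from the stem of image_path plus a suffix."""
    src = Path(image_path)
    # default to the folder that contains the input image
    folder = Path(out_dir) if out_dir is not None else src.parent
    return folder / f"{src.stem}{suffix}"


mask_path = output_path("./examples/RI_T01_Grid_65/RI_T01_Grid_65.tif", "_mask.png")
print(mask_path)
```

Since `cv2.imwrite` expects a string file name, the result would be passed as `str(mask_path)`.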

## Convert polygon row, col coordinates to projected coordinates and save them to shapefile

In [12]:
import rasterio

dataset = rasterio.open(fname)

In [13]:
# convert polygon coordinates from row, col to UTM
from shapely.geometry import Polygon

projected_polys = []
for grain in all_grains:
    x, y = rasterio.transform.xy(
        dataset.transform, grain.exterior.xy[1], grain.exterior.xy[0]
    )
    poly = Polygon(np.vstack((x, y)).T)
    projected_polys.append(poly)
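To make the conversion less of a black box, the sketch below reproduces in plain Python what `rasterio.transform.xy` computes with its default pixel-center offset, given the six affine coefficients of the dataset. The function name `pixel_to_map` and the coefficient values in the example are illustrative assumptions, not values taken from the example file.

```python
def pixel_to_map(row, col, transform, offset=0.5):
    """Apply affine coefficients (a, b, c, d, e, f) to a pixel position.

    offset=0.5 refers to the pixel center, which matches the default
    behavior of rasterio.transform.xy.
    """
    a, b, c, d, e, f = transform
    x = a * (col + offset) + b * (row + offset) + c
    y = d * (col + offset) + e * (row + offset) + f
    return x, y


# hypothetical north-up transform: 2 m pixels, upper-left corner at (100, 200)
print(pixel_to_map(0, 0, (2.0, 0.0, 100.0, 0.0, -2.0, 200.0)))
```

For a north-up image, `e` is negative because row numbers increase downward while the northing decreases.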

In [15]:
# create geopandas dataframe
import geopandas

gdf = geopandas.GeoDataFrame(projected_polys, columns=["geometry"])
gdf.head()

|   | geometry |
|---|---|
| 0 | POLYGON ((339043.943 4686358.495, 339043.941 4... |
| 1 | POLYGON ((339043.824 4686358.479, 339043.822 4... |
| 2 | POLYGON ((339043.228 4686358.364, 339043.226 4... |
| 3 | POLYGON ((339043.768 4686358.191, 339043.767 4... |
| 4 | POLYGON ((339043.036 4686357.764, 339043.034 4... |


In [16]:
# create property dataframe from labeled image
props = regionprops_table(
    labels.astype("int"),
    intensity_image=image,
    properties=("label", "area", "centroid", "major_axis_length", "minor_axis_length"),
)
grain_data = pd.DataFrame(props)
grain_data.head()

|   | label | area | centroid-0 | centroid-1 | major_axis_length | minor_axis_length |
|---|---|---|---|---|---|---|
| 0 | 1 | 683.0 | 182.021962 | 1097.698389 | 37.846881 | 24.424525 |
| 1 | 2 | 676.0 | 189.627219 | 1041.39645 | 37.30994 | 23.467704 |
| 2 | 3 | 410.0 | 255.880488 | 719.34878 | 30.05785 | 17.937276 |
| 3 | 4 | 393.0 | 348.519084 | 1015.414758 | 29.249796 | 19.877529 |
| 4 | 5 | 354.0 | 587.09322 | 607.022599 | 25.084603 | 18.253541 |


In [17]:
# convert centroids from row, col to UTM and add them to geodataframe
centroid_x, centroid_y = rasterio.transform.xy(
    dataset.transform, grain_data["centroid-0"], grain_data["centroid-1"]
)
gdf["centroid_x"] = centroid_x
gdf["centroid_y"] = centroid_y

In [18]:
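# convert grain axis lengths from pixels to map units (meters for UTM);
# dataset.transform[0] is the pixel width, so this assumes square pixels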
gdf["major_axis_length"] = grain_data["major_axis_length"] * dataset.transform[0]
gdf["minor_axis_length"] = grain_data["minor_axis_length"] * dataset.transform[0]
gdf.head()

|   | geometry | centroid_x | centroid_y | major_axis_length | minor_axis_length |
|---|---|---|---|---|---|
| 0 | POLYGON ((339043.943 4686358.495, 339043.941 4... | 339043.91715 | 4686359.0 | 0.068124 | 0.043964 |
| 1 | POLYGON ((339043.824 4686358.479, 339043.822 4... | 339043.815807 | 4686359.0 | 0.067158 | 0.042242 |
| 2 | POLYGON ((339043.228 4686358.364, 339043.226 4... | 339043.236121 | 4686358.0 | 0.054104 | 0.032287 |
| 3 | POLYGON ((339043.768 4686358.191, 339043.767 4... | 339043.76904 | 4686358.0 | 0.05265 | 0.03578 |
| 4 | POLYGON ((339043.036 4686357.764, 339043.034 4... | 339043.033934 | 4686358.0 | 0.045152 | 0.032856 |
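Scaling the axis lengths by `dataset.transform[0]` is only exact when the pixels are square and the raster is not rotated. A small check along the following lines can confirm this before trusting the converted lengths; the helper name `pixel_lengths` and the example coefficients are assumptions for illustration.

```python
import math


def pixel_lengths(transform):
    """Return (pixel_width, pixel_height) in map units from affine coefficients."""
    a, b, c, d, e, f = tuple(transform)[:6]
    width = math.hypot(a, d)   # map distance covered by one step along a row
    height = math.hypot(b, e)  # map distance covered by one step along a column
    return width, height


# hypothetical north-up transform with 0.5 m square pixels
width, height = pixel_lengths((0.5, 0.0, 100.0, 0.0, -0.5, 200.0))
if not math.isclose(width, height):
    print("Pixels are not square; a single scale factor is only approximate.")
```

The same function can be called on `dataset.transform`, since an affine transform can be unpacked into its coefficients.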


In [30]:
# plot grain size distribution
# units need to be in mm; the factor of 1000 assumes that the
# CRS units are meters (as they are for UTM)
fig, ax = seg.plot_histogram_of_axis_lengths(
    gdf["major_axis_length"] * 1000,
    gdf["minor_axis_length"] * 1000,
    binsize=0.25,
    xlimits=[8, 2 * 256],
)

In [19]:
# check if everything looks good
band1 = dataset.read(1)
band2 = dataset.read(2)
band3 = dataset.read(3)
plt.figure()
plt.imshow(
    np.stack((band1, band2, band3), axis=2),
    extent=[dataset.bounds[0], dataset.bounds[2], dataset.bounds[1], dataset.bounds[3]],
)
plt.scatter(gdf["centroid_x"], gdf["centroid_y"]);

In [20]:
# set geodataframe CRS
gdf.crs = dataset.crs

In [21]:
# write shapefile
gdf.to_file("./examples/RI_T01_Grid_65/projected_grains.shp")
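The ESRI Shapefile format limits attribute field names to 10 characters, so column names such as `major_axis_length` and `minor_axis_length` are truncated when the file is written. One option is to choose the short names explicitly beforehand, as sketched below; the helper `shorten_field_names` is hypothetical and its numbering scheme for collisions is just one possible choice.

```python
def shorten_field_names(names, max_len=10):
    """Map each name to a unique name that is at most max_len characters long."""
    mapping, used = {}, set()
    for name in names:
        short = name[:max_len]
        counter = 1
        # on a collision, replace the tail of the name with a running number
        while short in used:
            tag = str(counter)
            short = name[: max_len - len(tag)] + tag
            counter += 1
        used.add(short)
        mapping[name] = short
    return mapping


columns = ["centroid_x", "centroid_y", "major_axis_length", "minor_axis_length"]
print(shorten_field_names(columns))
```

The resulting dictionary could then be applied with `gdf.rename(columns=mapping)` before calling `gdf.to_file`.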



In [22]:
dataset.close()