## Import packages

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.measure import regionprops, regionprops_table
from keras.utils import load_img
from keras.saving import load_model
from importlib import reload
import segmenteverygrain as seg
import interactions as segi
from segment_anything import sam_model_registry, SamPredictor
from tqdm import trange, tqdm
%matplotlib qt

## Load models

In [None]:
model = load_model("seg_model.keras", custom_objects={'weighted_crossentropy': seg.weighted_crossentropy})

# the SAM model checkpoints can be downloaded from: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")

2025-03-21 13:42:01.239706: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


## Run segmentation

Grains should be well defined in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images. Images with ~2000 pixels along their largest dimension are a good start and allow the user to get an idea about how well the segmentation works.
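
As a rough illustration of the size guideline above, the sketch below computes the dimensions an image would need so that its longest side is at most about 2000 pixels. The helper name `fit_within` and its `max_dim` parameter are hypothetical and not part of `segmenteverygrain`.

```python
def fit_within(width, height, max_dim=2000):
    """Return (new_width, new_height) so that the longest side is at most max_dim pixels."""
    longest = max(width, height)
    if longest <= max_dim:
        return width, height  # already small enough; never upscale
    # scale both sides by the same factor and round to whole pixels (at least 1)
    new_width = max(1, round(width * max_dim / longest))
    new_height = max(1, round(height * max_dim / longest))
    return new_width, new_height


print(fit_within(6000, 4000))  # (2000, 1333)
```

The resulting size could then be passed to a resizing function such as `cv2.resize`. Keep in mind that downscaling also shrinks the grains, so very small grains may drop below the size at which they can be detected, and any pixel-to-meter calibration has to be scaled by the same factor.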

If you have a much larger image, see the section **"Run segmentation on large image"** at the end of the notebook. Running the `predict_large_image` function takes a lot longer (e.g., several hours), but it is possible to analyze very large images with tens of thousands of grains.

The image used below is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/torrey_pines_beach_image.jpeg).

In [None]:
reload(seg)
# replace this with the path to your image:
fname = 'torrey_pines_beach_image.jpeg'

image = np.array(load_img(fname))
image_pred = seg.predict_image(image, model, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=20.0) # Unet prediction

segmenting image tiles...


100%|██████████| 7/7 [00:02<00:00,  2.56it/s]
100%|██████████| 6/6 [00:02<00:00,  2.91it/s]
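
To build some intuition for why a smaller `dbs_max_dist` leads to more SAM prompts, here is a toy sketch. It assumes that the parameter acts as a neighborhood radius for grouping nearby candidate points and that each group ends up as roughly one prompt. This is a simplified distance-threshold grouping, not DBSCAN itself and not the code used by `label_grains`. The function `count_groups` and the sample points are hypothetical.

```python
import math


def count_groups(points, max_dist):
    """Count groups of points linked by chains of neighbors no farther apart than max_dist."""
    unvisited = set(range(len(points)))
    groups = 0
    while unvisited:
        groups += 1
        stack = [unvisited.pop()]  # start a new group from any remaining point
        while stack:
            i = stack.pop()
            # collect all still-unvisited points within max_dist of point i
            near = [j for j in unvisited
                    if math.dist(points[i], points[j]) <= max_dist]
            unvisited.difference_update(near)
            stack.extend(near)
    return groups


points = [(0, 0), (15, 0), (30, 0), (100, 0), (112, 0)]
print(count_groups(points, max_dist=20.0))  # 2
print(count_groups(points, max_dist=10.0))  # 5
```

With the larger radius, neighboring points merge into two groups. With the smaller radius, every point stays on its own, which in this simplified picture would mean more prompts and more SAM calls.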


Use the figure created in the next cell to check the quality of the Unet labeling (which sometimes fails completely) and the distribution of the SAM prompts (black dots). If the Unet prediction is poor, it is a good idea to create some training data and fine-tune the base model so that it works better with the images of interest.

In [4]:
fig, ax = plt.subplots(figsize=(15,10))
ax.imshow(image_pred)
plt.scatter(np.array(coords)[:,0], np.array(coords)[:,1], c='k')
plt.xticks([])
plt.yticks([]);

In [5]:
# SAM segmentation, using the point prompts from the Unet:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(sam, image, image_pred, 
            coords, labels, min_area=400.0, plot_image=True, remove_edge_grains=False, remove_large_objects=False)

creating masks using SAM...


100%|██████████| 1041/1041 [01:09<00:00, 15.04it/s]


finding overlapping polygons...


1041it [00:06, 160.51it/s]


finding best polygons...


100%|██████████| 394/394 [00:12<00:00, 31.77it/s]


creating labeled image...


100%|██████████| 467/467 [00:02<00:00, 162.70it/s]
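
When choosing a value for `min_area`, it can help to translate the area threshold into an approximate grain diameter. The sketch below assumes that `min_area` is expressed in square pixels and treats a grain as a circle of equal area. The helper names are hypothetical.

```python
import math


def equivalent_diameter(area):
    """Diameter of the circle that has the given area."""
    return 2.0 * math.sqrt(area / math.pi)


def area_for_diameter(diameter):
    """Area of a circle with the given diameter."""
    return math.pi * (diameter / 2.0) ** 2


print(round(equivalent_diameter(400.0), 1))  # 22.6
```

Under these assumptions, `min_area=400.0` corresponds to discarding objects smaller than a roughly 23-pixel-wide circle.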


## Edit grains in segmentation result

The editing interface will open in a separate window to allow fullscreen and other functionality. Navigation within the interface is described in the [matplotlib documentation](https://matplotlib.org/stable/users/explain/figure/interactive.html#interactive-navigation).

Additional controls:
- `Left click`: Select/unselect existing grain or place foreground prompt for grain detection
- `Shift + left click/drag`: Create or adjust box prompt for grain detection
- `Right click`: Place background prompt for grain detection
- `Middle click`: Display measurement information about the indicated grain
- `Middle click + drag`: Measure scale bar to calibrate pixels per meter
- `Control`: Hold to temporarily hide selected grains
- `Escape`: Remove all prompts and unselect all grains
- `c`: Use selection box and/or foreground/background prompts to detect a grain
- `d`: Delete selected (highlighted) grains
- `m`: Merge selected grains (must be touching)
- `z`: Delete the most recently-created grain

Plot creation parameters:
- `px_per_m`: The ratio of pixels to meters, if known. This will be overwritten if a scale bar is measured in the interface using middle click & drag.
- `scale_m`: The length in meters of a scale bar in the image. Measure it with middle click & drag to calibrate size/area measurements.
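
The relationship between the two parameters above is simple arithmetic, sketched below: the pixel length of the measured scale bar divided by its real length gives `px_per_m`, which in turn converts pixel measurements to meters. The function names are hypothetical and this is not the code used by the `interactions` module.

```python
import math


def px_per_m_from_scale_bar(start_px, end_px, scale_m):
    """Pixels per meter, from the two end points of a scale bar of known length."""
    length_px = math.dist(start_px, end_px)
    return length_px / scale_m


def px_length_to_m(length_px, px_per_m):
    """Convert a length from pixels to meters."""
    return length_px / px_per_m


def px_area_to_m2(area_px, px_per_m):
    """Convert an area from square pixels to square meters."""
    return area_px / px_per_m ** 2


# a 0.5 m scale bar that spans 500 pixels in the image
px_per_m = px_per_m_from_scale_bar((100, 200), (400, 600), scale_m=0.5)
print(px_per_m)  # 1000.0
```

Note that areas scale with the square of `px_per_m`, so a small error in the scale bar measurement has a larger effect on areas than on axis lengths.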

In [13]:
# Get image predictor
predictor = SamPredictor(sam)
predictor.set_image(image)

# Create Grain objects from detected polygons
grains = segi.polygons_to_grains(all_grains)

# Display interactive interface
plot = segi.GrainPlot(
    grains, 
    image = image, 
    predictor = predictor,
    blit = True,
    figsize = (12, 8),              # in
    px_per_m = 1.0,                 # px/m
    scale_m = 1.0                   # m
)
plot.activate()

100%|██████████| 467/467 [00:06<00:00, 76.69it/s]


In [28]:
plot.deactivate()

grain_data = segi.get_summary(plot.grains, plot.px_per_m)
grain_data.head()

Unnamed: 0,area,centroid-0,centroid-1,major_axis_length,minor_axis_length,orientation,perimeter,max_intensity-0,max_intensity-1,max_intensity-2,mean_intensity-0,mean_intensity-1,mean_intensity-2,min_intensity-0,min_intensity-1,min_intensity-2
0,0.069863,863.404234,25.738911,0.487285,0.239275,1.087517,1.294373,253.0,255.0,255.0,170.155242,175.738407,171.784778,0.0,0.0,0.0
0,0.016163,151.172113,916.464052,0.217807,0.111859,0.797045,0.583475,116.0,130.0,141.0,30.259259,34.215686,36.211329,0.0,0.0,0.0
0,0.051905,552.653324,1255.13772,0.279542,0.251729,0.009142,0.93952,235.0,233.0,210.0,141.476255,137.125509,122.672999,4.0,4.0,4.0
0,0.032608,589.158747,901.592873,0.315568,0.142672,1.395407,0.797102,234.0,218.0,195.0,76.356371,63.087473,51.423326,3.0,0.0,0.0
0,0.025389,690.399445,1242.565881,0.447397,0.11025,1.136451,0.995632,167.0,166.0,138.0,25.797503,33.572816,37.941748,0.0,0.0,0.0


In [29]:
# strip the extension (whatever its length) before adding '.csv'
grain_data.to_csv(fname.rsplit('.', 1)[0] + '.csv') # save grain data to CSV file

In [30]:
# plot histogram of grain axis lengths
# note that input data needs to be in millimeters
fig, ax = seg.plot_histogram_of_axis_lengths(grain_data['major_axis_length'], grain_data['minor_axis_length'], binsize=0.2, xlimits=[0.25, 4])

## Save mask and grain labels to PNG files

In [29]:
dirname = '/Users/zoltan/Dropbox/Segmentation/images/'
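# file name without directory or extension (works for '.jpg' and '.jpeg' alike)
basename = fname.split('/')[-1].rsplit('.', 1)[0]
# write grayscale mask to PNG file
cv2.imwrite(dirname + basename + '_mask.png', mask_all)
# save the image as a PNG file; 'image' is RGB, but OpenCV expects BGR channel order
cv2.imwrite(dirname + basename + '_image.png', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))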

True
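
Output file names in this notebook are derived from the input file name. As an alternative, here is a minimal sketch that uses the standard-library `pathlib` module to build such names independently of the directory and of the extension length. The helper `output_path` is hypothetical.

```python
from pathlib import Path


def output_path(fname, suffix, out_dir="."):
    """Build an output path from the input file's stem plus a new suffix."""
    stem = Path(fname).stem  # file name without directory and extension
    return Path(out_dir) / f"{stem}{suffix}"


print(output_path("torrey_pines_beach_image.jpeg", "_mask.png").name)
# torrey_pines_beach_image_mask.png
```

Because `Path.stem` drops everything after the last dot, the same helper works for `.jpg`, `.jpeg`, and `.png` inputs.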

## Run segmentation on large image (new!)
In this case, `fname` points to an image that is larger than a few megapixels and contains thousands of grains.
The `predict_large_image` function breaks the input image into smaller patches and runs the segmentation process on each patch.

The image used below (from [Mair et al., 2022, Earth Surface Dynamics](https://esurf.copernicus.org/articles/10/953/2022/)) is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/mair_et_al_L2_DJI_0382_image.jpg).
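
To illustrate the idea of splitting an image into overlapping patches, the sketch below shows one common tiling scheme. It is only an assumption about how such tiling can work; `predict_large_image` may place its patches differently. The function names and the 5000 x 3500 pixel image size are hypothetical.

```python
def patch_origins(length, patch_size=2000, overlap=200):
    """Start offsets of overlapping patches that cover `length` pixels along one axis."""
    if overlap >= patch_size:
        raise ValueError("overlap must be smaller than patch_size")
    if length <= patch_size:
        return [0]  # a single patch covers the whole axis
    step = patch_size - overlap
    starts = list(range(0, length - patch_size, step))
    starts.append(length - patch_size)  # last patch sits flush with the far edge
    return starts


def patch_windows(width, height, patch_size=2000, overlap=200):
    """(x0, y0, x1, y1) windows that cover an image of the given size."""
    return [
        (x, y, min(x + patch_size, width), min(y + patch_size, height))
        for y in patch_origins(height, patch_size, overlap)
        for x in patch_origins(width, patch_size, overlap)
    ]


print(patch_origins(5000))             # [0, 1800, 3000]
print(len(patch_windows(5000, 3500)))  # 6
```

The overlap ensures that a grain cut by one patch boundary is fully contained in a neighboring patch, at the cost of having to reconcile duplicate detections in the overlapping strips.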

In [9]:
from PIL import Image
Image.MAX_IMAGE_PIXELS = None # needed if working with very large images
fname = "mair_et_al_L2_DJI_0382_image.jpg"
all_grains, image_pred, all_coords = seg.predict_large_image(fname, model, sam, min_area=400.0, patch_size=2000, overlap=200)

segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.24it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.25it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3197/3197 [03:56<00:00, 13.51it/s]


finding overlapping polygons...


2887it [00:06, 463.54it/s]


finding best polygons...


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 969/969 [00:09<00:00, 106.43it/s]


creating labeled image...
processed patch #1 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.21it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.19it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2575/2575 [03:08<00:00, 13.66it/s]


finding overlapping polygons...


2282it [00:08, 278.10it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [00:12<00:00, 53.85it/s]


creating labeled image...
processed patch #2 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.20it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.17it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2140/2140 [02:24<00:00, 14.78it/s]


finding overlapping polygons...


1894it [00:05, 332.08it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 595/595 [00:07<00:00, 80.87it/s]


creating labeled image...
processed patch #3 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.25it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.26it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3564/3564 [03:59<00:00, 14.87it/s]


finding overlapping polygons...


3312it [00:04, 729.47it/s]


finding best polygons...


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1157/1157 [00:06<00:00, 190.14it/s]


creating labeled image...
processed patch #4 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.21it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.24it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2484/2484 [02:48<00:00, 14.74it/s]


finding overlapping polygons...


2242it [00:05, 390.88it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 709/709 [00:07<00:00, 98.42it/s]


creating labeled image...
processed patch #5 out of 6 patches
segmenting image tiles...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.19it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:06<00:00,  1.23it/s]


creating masks using SAM...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1978/1978 [02:06<00:00, 15.58it/s]


finding overlapping polygons...


1707it [00:06, 259.20it/s]


finding best polygons...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 509/509 [00:12<00:00, 41.57it/s]


creating labeled image...
processed patch #6 out of 6 patches


4946it [00:02, 2167.19it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 336/336 [00:00<00:00, 610.84it/s]


In [13]:
# plot results
image = np.array(load_img(fname))
fig, ax = plt.subplots(figsize=(15,10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired')
plt.axis('equal')
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 4567/4567 [00:14<00:00, 322.67it/s]


In [11]:
# this is a faster way of deleting false positives (because it avoids highlighting and deleting the 'bad' grains)
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax, select_only=True))

In [12]:
# delete polygons from 'all_grains'
# highest index first, so that earlier deletions do not shift the remaining indices
grain_inds = sorted(np.unique(grain_inds), reverse=True)
for ind in tqdm(grain_inds):
    all_grains.pop(ind) # delete by position

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 13.45it/s]
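
The cell above deletes the selected grains in descending index order. A minimal sketch of why the order matters, using a plain list and a hypothetical helper `delete_indices`:

```python
def delete_indices(items, indices):
    """Delete the given positions from `items` in place, ignoring duplicate indices."""
    # going from the highest index down keeps the lower positions valid
    for ind in sorted(set(indices), reverse=True):
        del items[ind]
    return items


grains = ["a", "b", "c", "d", "e"]
print(delete_indices(grains, [1, 3, 1]))  # ['a', 'c', 'e']
```

Deleting in ascending order instead would shift every later element one position to the left after each deletion, so the remaining indices would point at the wrong grains.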


After plotting the results, you will want to use the functions for deleting, merging, and adding grains (see above), before saving the results (same workflow as for a small image).

## Fine-tuning the base model

In [None]:
# patchify images and masks
input_dir = "./Masks_and_images/" # the input directory should contain files with 'image' and 'mask' in their filenames
patch_dir = "./New_project/" # a directory called "Patches" will be created here
image_dir, mask_dir = seg.patchify_training_data(input_dir, patch_dir)
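
Before patchifying, it may be worth checking that every training image has a matching mask. The sketch below assumes a naming convention in which the mask file name is the image file name with `image` replaced by `mask` (for example `a_image.png` and `a_mask.png`). That convention and the helper `unpaired_images` are assumptions, so adjust them to your own files.

```python
from pathlib import Path


def unpaired_images(filenames):
    """Return the stems of image files that have no matching mask file."""
    stems = {Path(f).stem for f in filenames}
    return sorted(
        s for s in stems
        if "image" in s and s.replace("image", "mask") not in stems
    )


print(unpaired_images(["a_image.png", "a_mask.png", "b_image.png"]))  # ['b_image']
```

In practice the list of file names could come from `os.listdir(input_dir)`, and an empty result would indicate that all images are paired.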

In [None]:
# create training, validation, and test datasets
train_dataset, val_dataset, test_dataset = seg.create_train_val_test_data(image_dir, mask_dir, augmentation=True)

In [None]:
# load base model weights and train the model with the new data
model = seg.create_and_train_model(train_dataset, val_dataset, test_dataset, model_file='seg_model.keras', epochs=100)

In [None]:
# save the fine-tuned model as a new model; it can then be loaded with:
# model = load_model("new_model.keras", custom_objects={'weighted_crossentropy': seg.weighted_crossentropy})
model.save('new_model.keras')