## Import packages

In [28]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from skimage.measure import regionprops, regionprops_table
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import load_img
from importlib import reload
import segmenteverygrain as seg
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import trange
%matplotlib qt

## Load models

In [29]:
# Build the U-Net, compile it with the package's weighted cross-entropy loss,
# and restore the weights stored under ./checkpoints.
model = seg.Unet()
model.compile(
    optimizer=Adam(),
    loss=seg.weighted_crossentropy,
    metrics=["accuracy"],
)
model.load_weights('./checkpoints/seg_model');

# the SAM model checkpoints can be downloaded from: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam_checkpoint = "/Users/zoltan/Dropbox/Segmentation/sam_vit_h_4b8939.pth"
sam = sam_model_registry["default"](checkpoint=sam_checkpoint)

## Run segmentation

Grains need to be well defined in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images, so do not start with large images (downsample them if necessary). Images with ~2000 pixels along their largest dimension are a good start.

In [30]:
# Re-import the package so that edits made to it since the last import are picked up.
reload(seg)

# Input image: exactly one 'fname' assignment should be active.
# fname = '../images/bucegi_conglomerate_1_image.png'
# fname = '../images/A003_20201229_103823_image.png'
# fname = '../images/IMG_5208_image.png'
fname = '/Users/zoltan/Downloads/Pebbles_on_beach_at_Broulee_-NSW_-Australia-2Jan2009.jpg'
# fname = '/Users/zoltan/Downloads/vecteezy_stone-pebbles-on-river-bed_3366528.jpg'

# Read the image into an array and run the U-Net model over it.
big_im = np.array(load_img(fname))
big_im_pred = seg.predict_big_image(big_im, model, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts (and longer processing times):
labels, grains, coords = seg.label_grains(
    big_im,
    big_im_pred,
    dbs_max_dist=10.0,
)

# Run SAM with the prompts found above; 'fig' and 'ax' are reused by the interactive cells below.
(all_grains, labels, mask_all,
 grain_data, fig, ax) = seg.sam_segmentation(
    sam, big_im, big_im_pred, coords, labels, min_area=50.0
)

  0%|                                                                                                     | 0/7 [00:00<?, ?it/s]2023-06-06 14:06:05.814693: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  6.80it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  8.75it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 950/950 [01:03<00:00, 14.91it/s]
950it [01:04, 14.78it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 231/231 [00:08<00:00, 28.46it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████| 247/247 [00:00<00:00, 595.13it/s]
100%|██████████████████████████████████████████



Use this figure to check the distribution of SAM prompts (= black dots):

In [27]:
# Show the model prediction with the SAM prompt locations drawn on top as black dots.
plt.figure()
plt.imshow(big_im_pred)
prompt_x, prompt_y = coords[:, 0], coords[:, 1]
plt.scatter(prompt_x, prompt_y, c='k');

## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)

In [547]:
# Shared with both handlers below; presumably they record the clicked grains here (verify in seg.onclick2).
grain_inds = []


def _forward_click(event):
    """Pass a mouse click on the figure on to seg.onclick2."""
    return seg.onclick2(event, all_grains, grain_inds, ax=ax)


def _forward_key(event):
    """Pass a key press on the figure on to seg.onpress2."""
    return seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax)


# Keep the connection ids so the handlers can be disconnected later.
cid1 = fig.canvas.mpl_connect('button_press_event', _forward_click)
cid2 = fig.canvas.mpl_connect('key_press_event', _forward_key)

Run this cell if you do not want to delete / merge existing grains anymore; it is a good idea to do this before moving on to the next step.

In [548]:
# Detach the delete / merge handlers registered under cid1 and cid2.
for connection_id in (cid1, cid2):
    fig.canvas.mpl_disconnect(connection_id)

Use this function to update the 'all_grains' list after deleting and merging grains:

In [549]:
# Refresh the grain list, label image and mask after the interactive edits.
# NOTE: this also rebinds `fig` and `ax` to whatever the helper returns, so
# later cells operate on those objects.
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, big_im)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 706/706 [00:02<00:00, 285.11it/s]


Plot the updated set of grains:

In [558]:
# Fresh figure for the updated grains; 'fig' and 'ax' are rebound on purpose,
# because the interactive cells that follow attach their callbacks to them.
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(big_im)
# Hide the tick marks on both axes.
ax.set_xticks([])
ax.set_yticks([])
seg.plot_image_w_colorful_grains(big_im, all_grains, ax, cmap='Paired')
seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
# Clip the view to the image extent, keeping row 0 at the top.
img_height, img_width = big_im.shape[0], big_im.shape[1]
ax.set_xlim([0, img_width])
ax.set_ylim([img_height, 0]);

2023-06-04 17:47:51.949 python[76879:46896633] +[CATransaction synchronize] called within transaction
2023-06-04 17:48:42.897 python[76879:46896633] +[CATransaction synchronize] called within transaction
2023-06-04 17:48:46.649 python[76879:46896633] +[CATransaction synchronize] called within transaction
2023-06-04 17:48:47.761 python[76879:46896633] +[CATransaction synchronize] called within transaction
2023-06-04 17:48:48.031 python[76879:46896633] +[CATransaction synchronize] called within transaction


## Add new grains using the Segment Anything Model

* click on an unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right click outside the grain (but inside the most recent mask) if you want to restrict the grain to a smaller mask - this adds a background prompt

In [483]:
# Prepare a SAM predictor for the current image so that single grains can be
# segmented from mouse clicks.
predictor = SamPredictor(sam)
predictor.set_image(big_im) # this can take a while
# Keep the clicked points in their own list. Reusing the name 'coords' would
# overwrite the array of automatic SAM prompts returned by seg.label_grains,
# and the prompt-distribution plot (which indexes coords[:,0]) would then fail
# when re-run.
click_coords = []
cid3 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.onclick(event, ax, click_coords, big_im, predictor))
cid4 = fig.canvas.mpl_connect('key_press_event', lambda event: seg.onpress(event, ax, fig))

In [484]:
# Detach the add-grain handlers registered under cid3 and cid4.
for connection_id in (cid3, cid4):
    fig.canvas.mpl_disconnect(connection_id)

After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [485]:
# Regenerate the grain list, label image and mask once all deletions and
# additions are done. `fig` and `ax` are rebound to the helper's return values.
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, big_im)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 443/443 [00:03<00:00, 121.37it/s]


## Get grain size distribution

Run this cell and then click (left mouse button) on one end of the scale bar in the image and click (right mouse button) on the other end of the scale bar:

In [24]:
def _forward_scale_click(event):
    """Pass a mouse click on the figure on to seg.click_for_scale."""
    return seg.click_for_scale(event, ax)


# Keep the connection id so the handler can be disconnected later if needed.
cid5 = fig.canvas.mpl_connect('button_press_event', _forward_scale_click)

number of pixels: 492.06


Use the length of the scale bar in pixels (it should be printed above) to get the scale of the image (in units / pixel):

In [27]:
# Real-world length of the scale bar.
n_of_units = 10 # centimeters in the case of 'IMG_5208_image.png'
# length of scale bar in pixels (copy the value printed by the previous cell)
scale_bar_pixels = 492.06
units_per_pixel = n_of_units / scale_bar_pixels

In [28]:
# Measure every labelled region, using the image as the intensity source.
measured_properties = (
    'label', 'area', 'centroid', 'major_axis_length', 'minor_axis_length',
    'orientation', 'perimeter', 'max_intensity', 'mean_intensity', 'min_intensity',
)
props = regionprops_table(
    labels.astype('int'),
    intensity_image=big_im,
    properties=measured_properties,
)
grain_data = pd.DataFrame(props)

# Lengths scale linearly with the pixel size, areas with its square.
# The remaining columns (e.g. centroids) are left as returned by regionprops_table.
for length_column in ('major_axis_length', 'minor_axis_length', 'perimeter'):
    grain_data[length_column] = grain_data[length_column].values * units_per_pixel
grain_data['area'] = grain_data['area'].values * units_per_pixel**2

In [29]:
# Preview the first rows of the grain measurement table.
grain_data.head()

Unnamed: 0,label,area,centroid-0,centroid-1,major_axis_length,minor_axis_length,orientation,perimeter,max_intensity-0,max_intensity-1,max_intensity-2,mean_intensity-0,mean_intensity-1,mean_intensity-2,min_intensity-0,min_intensity-1,min_intensity-2
0,1,0.131751,289.238245,4.905956,0.706249,0.253123,-0.010334,1.58313,255.0,241.0,215.0,132.141066,111.128527,91.896552,0.0,0.0,0.0
1,2,0.280849,1066.869118,12.223529,0.652708,0.553314,-0.401725,1.976483,255.0,252.0,215.0,150.177941,121.995588,100.126471,0.0,0.0,0.0
2,3,0.109861,4.240602,614.067669,0.629969,0.233494,-1.569466,1.45422,223.0,206.0,187.0,86.680451,72.661654,60.428571,0.0,0.0,0.0
3,4,0.277958,35.549777,73.271917,0.808445,0.456453,0.058363,2.136176,255.0,255.0,255.0,112.745914,101.066865,94.392273,0.0,0.0,0.0
4,5,0.107796,1016.425287,952.931034,0.519751,0.275739,-0.039216,1.308474,245.0,241.0,233.0,128.256705,110.823755,94.172414,0.0,0.0,0.0


In [30]:
# Histogram of the major-axis lengths in 25 bins. Separate names are used for
# the figure and axes so the 'fig' / 'ax' of the grain plot stay untouched.
hist_fig, hist_ax = plt.subplots()
hist_ax.hist(grain_data['major_axis_length'], bins=25)
hist_ax.set_xlabel('major axis length (cm)')
hist_ax.set_ylabel('count');

## Save mask and grain labels to PNG files

In [486]:
dirname = '/Users/zoltan/Dropbox/Segmentation/images/'
# File name without directory and without extension. Splitting on the last
# '.' works for extensions of any length ('.jpg', '.jpeg', '.tiff', ...),
# unlike slicing off a fixed four characters.
basename = fname.split('/')[-1].rsplit('.', 1)[0]

# write grayscale mask to PNG file
cv2.imwrite(dirname + basename + '_mask.png', mask_all)

# Define a colormap using matplotlib. Label values run from 0 up to
# labels.max(), so the lookup table needs labels.max() + 1 entries; with
# fewer entries the highest label is out of range and shares a colour with
# its neighbour.
num_classes = int(labels.max()) + 1
cmap = plt.get_cmap('viridis', num_classes)
# Map each class label to a unique color using the colormap (RGBA -> RGB, 0-255)
vis_mask = cmap(labels.astype(int))[:, :, :3] * 255
vis_mask = vis_mask.astype(np.uint8)

# cv2.imwrite expects BGR channel order, while the colormap output and the
# loaded image are RGB, so both are converted before writing.
# Save the label visualization as a PNG file
cv2.imwrite(dirname + basename + '_labels.png', cv2.cvtColor(vis_mask, cv2.COLOR_RGB2BGR))
# Save the image as a PNG file
cv2.imwrite(dirname + basename + '_image.png', cv2.cvtColor(big_im, cv2.COLOR_RGB2BGR))

True