## Import packages

In [231]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from skimage.measure import regionprops_table
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img
from importlib import reload
import segmenteverygrain as seg
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import trange
%matplotlib qt

## Load models

In [334]:
# Build the segmenteverygrain U-Net, compile it and load the weights stored in the checkpoint.
unet_weights = './checkpoints/seg_model'
model = seg.Unet()
model.compile(
    optimizer=Adam(),
    loss=seg.weighted_crossentropy,
    metrics=["accuracy"],
)
model.load_weights(unet_weights)

# The SAM model checkpoint can be downloaded from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# NOTE: this is an absolute, machine-specific path - point it at your own copy of the checkpoint.
sam_checkpoint = "/Users/zoltan/Dropbox/Segmentation/sam_vit_h_4b8939.pth"
build_sam = sam_model_registry["default"]
sam = build_sam(checkpoint=sam_checkpoint)

## Run segmentation

Grains are supposed to be well defined in the image; e.g., if a grain consists of only a few pixels, it is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images, so do not start with large images (downsample them if necessary).

In [916]:
reload(seg)  # pick up any edits made to the segmenteverygrain module

# Image to segment. Other example images that have been used with this notebook:
#   '../images/bucegi_conglomerate_1_image.png'
#   '../images/A003_20201229_103823_image.png'
fname = '../images/IMG_5208_image.png'

# Load the image into a NumPy array and run the U-Net prediction on it.
loaded_image = load_img(fname)
big_im = np.array(loaded_image)
big_im_pred = seg.predict_big_image(big_im, model, I=256)

# Label the prediction and get the SAM prompt coordinates;
# decreasing the 'dbs_max_dist' parameter results in more SAM prompts.
labels, grains, coords = seg.label_grains(
    big_im,
    big_im_pred,
    dbs_max_dist=10.0,
)

# SAM-based segmentation; note that `labels` is replaced by the value returned here.
all_grains, labels, mask_all, fig, ax = seg.sam_segmentation(
    sam,
    big_im,
    big_im_pred,
    coords,
    labels,
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.72it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.66it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1502/1502 [01:23<00:00, 17.92it/s]
1476it [01:50, 13.37it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████|

Use this figure to check the distribution of SAM prompts (= black dots):

In [920]:
# Overlay the SAM prompts (black dots) on the U-Net prediction:
# column 0 of `coords` goes on the x axis, column 1 on the y axis.
prompt_x = coords[:, 0]
prompt_y = coords[:, 1]
plt.figure()
plt.imshow(big_im_pred)
plt.scatter(prompt_x, prompt_y, c='k');

Plot results as a simple image with grains colored randomly:

In [918]:
# Draw the image with the segmented grains colored randomly. This rebinds `fig` and `ax`;
# the interactive delete / merge handlers connected in the following cell attach to this figure.
fig, ax = seg.plot_image_w_colorful_grains(big_im, all_grains)

## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)

In [902]:
# List shared by both handlers (presumably it collects the indices of the clicked grains).
grain_inds = []


def _handle_grain_click(event):
    """Forward a mouse click on the figure to seg.onclick2."""
    return seg.onclick2(event, all_grains, grain_inds, ax=ax)


def _handle_grain_key(event):
    """Forward a key press on the figure to seg.onpress2."""
    return seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax)


# Keep the connection ids so that the handlers can be disconnected later.
cid1 = fig.canvas.mpl_connect('button_press_event', _handle_grain_click)
cid2 = fig.canvas.mpl_connect('key_press_event', _handle_grain_key)

Run this cell if you do not want to delete / merge existing grains anymore; it is a good idea to do this before moving on to the next step.

In [258]:
# Detach the delete / merge handlers so that later clicks and key presses no longer reach them.
for connection_id in (cid1, cid2):
    fig.canvas.mpl_disconnect(connection_id)

In [204]:
# Re-import segmenteverygrain so that any edits to the module are picked up.
reload(seg)
# Regenerate the grain list, label image and mask after the manual deletions / merges;
# `fig` and `ax` are rebound to whatever the call returns.
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, big_im)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 715/715 [00:02<00:00, 269.26it/s]


## Add new grains using the Segment Anything Model

* click on an unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right click outside the grain (but inside the most recent mask) if you want to restrict the grain to a smaller mask - this adds a background prompt

In [270]:
# Prepare an interactive SAM predictor for the full image.
predictor = SamPredictor(sam)
predictor.set_image(big_im)  # this can take a while

# List handed to seg.onclick (presumably it accumulates the clicked prompt coordinates).
# It has its own name so that the `coords` array of automatic SAM prompts, which the
# prompt scatter plot indexes as coords[:,0] / coords[:,1], is not replaced by a plain list.
click_coords = []

# Keep the connection ids so that the handlers can be disconnected later.
cid3 = fig.canvas.mpl_connect(
    'button_press_event',
    lambda event: seg.onclick(event, ax, click_coords, big_im, predictor),
)
cid4 = fig.canvas.mpl_connect(
    'key_press_event',
    lambda event: seg.onpress(event, ax, fig),
)

After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [275]:
# Generate the updated grain list, label image and mask from the edited figure; the
# measurement and export cells further down read `labels`, `all_grains` and `mask_all`.
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, big_im)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 853/853 [00:19<00:00, 44.40it/s]


## Get grain size distribution

In [226]:
# Attach the scale-measurement click handler to the figure. The cell output shows it
# reporting a length as "number of pixels: ..."; that value is needed to convert pixels to units.
# NOTE(review): `cid5` is never disconnected in this notebook - confirm whether it should be.
cid5 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.click_for_scale(event, ax))

number of pixels: 491.92


In [228]:
# Length of the measured reference distance in real-world units (centimeters in this case).
n_of_units = 10
# The same distance in pixels: copy the "number of pixels" value printed by the
# scale-measurement cell (seg.click_for_scale). It is assigned explicitly so that this
# cell also runs on a fresh kernel; update it whenever the scale is re-measured.
dist = 491.92
# Conversion factor used to turn pixel measurements into physical units.
units_per_pixel = n_of_units/dist

In [234]:
# Region properties measured for every labeled grain.
measured_properties = (
    'label', 'area', 'centroid', 'major_axis_length', 'minor_axis_length',
    'orientation', 'perimeter', 'max_intensity', 'mean_intensity', 'min_intensity',
)
props = regionprops_table(
    labels.astype('int'),
    intensity_image=big_im,
    properties=measured_properties,
)
grain_data = pd.DataFrame(props)

# Convert pixel measurements to physical units: lengths scale linearly ...
for length_column in ('major_axis_length', 'minor_axis_length', 'perimeter'):
    grain_data[length_column] = grain_data[length_column].values * units_per_pixel
# ... and areas scale with the square of the conversion factor.
grain_data['area'] = grain_data['area'].values * units_per_pixel**2

In [235]:
# Preview the first rows of the grain table (lengths and area have been scaled by units_per_pixel).
grain_data.head()

Unnamed: 0,label,area,centroid-0,centroid-1,major_axis_length,minor_axis_length,orientation,perimeter,max_intensity-0,max_intensity-1,max_intensity-2,mean_intensity-0,mean_intensity-1,mean_intensity-2,min_intensity-0,min_intensity-1,min_intensity-2
0,1,0.657837,15.05402,38.740578,1.166285,0.739201,1.534311,3.235238,255.0,255.0,255.0,181.621231,180.74309,182.645101,15.0,7.0,8.0
1,2,0.259085,11.588517,92.639553,0.617154,0.570253,-0.298606,2.038546,247.0,241.0,237.0,108.567783,93.953748,85.216906,0.0,0.0,0.0
2,3,0.544203,18.751708,121.830676,0.929578,0.762772,-0.41212,2.799937,255.0,255.0,255.0,156.492027,137.490509,124.29309,0.0,0.0,0.0
3,4,0.175616,18.4,150.910588,0.481739,0.478469,0.692243,1.574497,206.0,184.0,177.0,85.891765,66.447059,63.76,0.0,0.0,0.0
4,5,0.380984,18.08026,189.661605,0.93567,0.528288,-0.511963,2.511139,255.0,251.0,241.0,175.050976,150.254881,141.08026,0.0,0.0,0.0


In [240]:
# Histogram of the grain major-axis lengths, using 25 bins.
major_axis_lengths = grain_data['major_axis_length']
plt.figure()
plt.hist(major_axis_lengths, bins=25)
plt.xlabel('major axis length (cm)')
plt.ylabel('count');

Text(0, 0.5, 'count')

## Save mask and grain labels to PNG files

In [276]:
# Output directory. NOTE: absolute, machine-specific path - change it before running elsewhere.
dirname = '/Users/zoltan/Dropbox/Segmentation/images/'
# Common prefix of all outputs: directory + input file name with its last 4 characters
# (the '.png' extension) removed.
out_prefix = dirname + fname.split('/')[-1][:-4]

# Write the grayscale mask to a PNG file.
cv2.imwrite(out_prefix + '_mask.png', mask_all)

# Color the label image with a viridis colormap that has one entry per grain:
# look up every label, keep the first three (RGB) channels and scale to 0-255.
num_classes = len(all_grains)
cmap = plt.get_cmap('viridis', num_classes)
label_indices = labels.astype(np.uint16)
vis_mask = (cmap(label_indices)[:, :, :3] * 255).astype(np.uint8)
# NOTE(review): vis_mask is in matplotlib's RGB order and is written without a channel swap,
# while cv2.imwrite expects BGR - the saved colors likely have red and blue exchanged; confirm
# whether that matters for how the labels file is used.
cv2.imwrite(out_prefix + '_labels.png', vis_mask)

# Save the image as a PNG file. The conversion swaps the first and third channels, which
# puts the array loaded with load_img into the channel order cv2.imwrite expects.
cv2.imwrite(out_prefix + '_image.png', cv2.cvtColor(big_im, cv2.COLOR_BGR2RGB))

True