In [15]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img
from importlib import reload
import segmenteverygrain as seg
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import trange
%matplotlib qt

In [16]:
# Instantiate the segmentation network, compile it with the project's weighted
# cross-entropy loss, and load its weights from the local checkpoint.
unet_weights = './checkpoints/seg_model'
model = seg.Unet()
model.compile(
    optimizer=Adam(),
    loss=seg.weighted_crossentropy,
    metrics=["accuracy"],
)
model.load_weights(unet_weights)

# Load the Segment Anything model (SAM). The checkpoint can be downloaded from:
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# NOTE(review): the path below is specific to one machine; point it at your own copy.
sam_checkpoint = "/Users/zoltan/Dropbox/Segmentation/sam_vit_h_4b8939.pth"
build_sam = sam_model_registry["default"]
sam = build_sam(checkpoint=sam_checkpoint)

In [None]:
# Image to segment. Swap in the commented-out path to run the other example.
# fname = '../images/bucegi_conglomerate_1_image.png'
fname = '../images/A003_20201229_103823_image.png'

# --- parameters handed to seg.label_and_dilate_grains ---
# objects that have an area smaller than this (in pixels) will be deleted
small_grain_threshold = 25
# increasing this parameter reduces the tendency to split grains
selem_size = 6
# NOTE(review): presumably the size of the dilation applied to the labeled grains — confirm in seg
dilate_size = 3
# grain splitting can be turned on or off; usually it is better to leave it set to True
splitting = True

# Read the image into a numpy array and run the model over it
# (I=256 is forwarded as-is; presumably the tile size — confirm in seg).
big_im = np.array(load_img(fname))
big_im_pred = seg.predict_big_image(big_im, model, I=256)

# First pass: label the grains found in the prediction.
labels, grain_data = seg.label_and_dilate_grains(
    big_im,
    big_im_pred,
    small_grain_threshold,
    dilate_size,
    selem_size,
    splitting,
)

# Second pass: SAM-based segmentation. Note that `labels` is reassigned here;
# `fig` and `ax` are the figure/axes used by the interactive cells below.
all_grains, labels, mask_all, fig, ax = seg.sam_segmentation(
    sam, big_im, big_im_pred, grain_data
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.91it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 2193.57it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 121/121 [00:00<00:00, 210.

In [29]:
# Interactive clean-up of the automatically segmented grains (in the figure window):
#   * click on a grain that you want to remove and press the 'x' key
#   * click on two grains that you want to merge and press the 'm' key
grain_inds = []  # shared between the click and key handlers


def _click_existing_grain(event):
    # forward mouse clicks to the delete/merge click handler
    return seg.onclick2(event, all_grains, grain_inds, ax)


def _key_existing_grain(event):
    # forward key presses to the delete/merge key handler
    return seg.onpress2(event, all_grains, grain_inds, fig, ax)


# keep the connection ids so that the handlers can be disconnected later
cid1 = fig.canvas.mpl_connect('button_press_event', _click_existing_grain)
cid2 = fig.canvas.mpl_connect('key_press_event', _key_existing_grain)

In [28]:
# Run this cell once you no longer want to delete / merge existing grains:
# it detaches the click and key handlers registered as cid1 and cid2.
for handler_id in (cid1, cid2):
    fig.canvas.mpl_disconnect(handler_id)

In [26]:
# Add new grains using SAM (in the figure window):
#   * click on a grain that you want to add
#   * press the 'x' key if you want to delete it
#   * press the 'm' key if you want to merge the last two grains that you clicked
#   * right click outside the grain if you want to (try to) restrict the grain
#     to a smaller mask
predictor = SamPredictor(sam)
predictor.set_image(big_im)  # this can take a while
coords = []  # list handed to the click handler


def _click_new_grain(event):
    # forward mouse clicks to the SAM-based add-grain handler
    return seg.onclick(event, ax, coords, big_im, predictor)


def _key_new_grain(event):
    # forward key presses to the add-grain key handler
    return seg.onpress(event, ax, fig)


cid3 = fig.canvas.mpl_connect('button_press_event', _click_new_grain)
cid4 = fig.canvas.mpl_connect('key_press_event', _key_new_grain)

In [30]:
# Generate an updated set of grains after the interactive edits.
# This overwrites all_grains, labels and mask_all from the automatic segmentation.
updated_grains = seg.get_grains_from_patches(ax, big_im)
all_grains, labels, mask_all = updated_grains

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1782/1782 [00:42<00:00, 41.86it/s]


In [34]:
# Export the results as PNG files: the mask, a color-coded label image and the input image.
# dirname must end with a path separator because file names are appended by
# plain string concatenation.
# NOTE(review): this absolute path is specific to one machine — adjust before running.
dirname = '/Users/zoltan/Dropbox/Segmentation/images/'
# Common prefix of all output files: target directory + input file name without
# its extension (rsplit drops the extension whatever its length, not just 4 characters).
out_prefix = dirname + fname.split('/')[-1].rsplit('.', 1)[0]

# write grayscale mask to PNG file
cv2.imwrite(out_prefix + '_mask.png', mask_all)

# Define a colormap using matplotlib, with one entry per grain
num_classes = len(all_grains)
cmap = plt.get_cmap('viridis', num_classes)
# Map each class label to a color using the colormap; keep RGB, drop alpha, scale to 0-255.
# NOTE(review): label values >= num_classes fall outside the lookup table and all get
# the last color — confirm the range of `labels` if every grain needs a distinct color.
vis_mask = (cmap(labels.astype(np.uint16))[:, :, :3] * 255).astype(np.uint8)
# Save the label image as a PNG file. The colormap output is RGB, whereas
# cv2.imwrite expects BGR, so convert first; otherwise red and blue are swapped on disk.
cv2.imwrite(out_prefix + '_labels.png', cv2.cvtColor(vis_mask, cv2.COLOR_RGB2BGR))

# Save the image as a PNG file. big_im was loaded with load_img (RGB), so it
# needs the same RGB -> BGR conversion before being handed to OpenCV.
# cv2.imwrite returns False instead of raising when the write fails; as the last
# expression of the cell, that result is displayed below.
cv2.imwrite(out_prefix + '_image.png', cv2.cvtColor(big_im, cv2.COLOR_RGB2BGR))

True