## Import packages

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage import measure
from skimage.measure import regionprops, regionprops_table
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.preprocessing.image import load_img
from importlib import reload
import segmenteverygrain as seg
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from tqdm import trange
%matplotlib qt

## Load models

In [3]:
model = seg.Unet()
model.compile(optimizer=Adam(), loss=seg.weighted_crossentropy, metrics=["accuracy"])
model.load_weights('./checkpoints/seg_model') # replace this if you have a finetuned Unet model and want to use it

# the SAM model checkpoints can be downloaded from: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam = sam_model_registry["default"](checkpoint=r"C:\Users\juan_\Python\JSG_2024_Hackathon\segmenteverygrain\SAM\sam_vit_h_4b8939.pth")

## Run segmentation

Grains need to be well resolved in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images. Images with ~2000 pixels along their largest dimension are a good start and allow the user to get an idea about how well the segmentation works.
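
As a rough aid for choosing a working resolution, the sketch below computes the shape an image would have after scaling its largest dimension down to about 2000 pixels. The helper name `resized_shape` and the `target_max_dim` parameter are hypothetical and not part of `segmenteverygrain`; the actual resizing would be done with an image library of your choice.

```python
def resized_shape(height, width, target_max_dim=2000):
    """Hypothetical helper: shape after shrinking the largest dimension to target_max_dim.

    Images that are already small enough keep their shape.
    """
    largest = max(height, width)
    if largest <= target_max_dim:
        return height, width
    # multiply before dividing so that exact ratios stay exact
    return (round(height * target_max_dim / largest),
            round(width * target_max_dim / largest))

print(resized_shape(6000, 4000))  # (2000, 1333)
print(resized_shape(1500, 1000))  # (1500, 1000)
```

If an image is resized before segmentation, the scale (units per pixel) measured later in the notebook has to come from the resized image, not from the original.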

If you have a much larger image, see the section **"Run segmentation on large image"** at the end of the notebook. Running the `predict_large_image` function takes a lot longer (e.g., several hours), but it is possible to analyze very large images with tens of thousands of grains.

The example image from the segmenteverygrain repository is available from [here](https://github.com/zsylvester/segmenteverygrain/blob/main/torrey_pines_beach.jpeg).

In [4]:
# replace this with the path to your image:
fname = r"C:\Users\juan_\Python\JSG_2024_Hackathon\1973\wellsorted_lowerfinegrainsize.PNG"

image = np.array(load_img(fname))
image_pred = seg.predict_image(image, model, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=20.0) # Unet prediction

segmenting image tiles...


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.21it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.08it/s]
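
To build some intuition for why a smaller `dbs_max_dist` produces more prompts, here is a toy sketch that groups points whenever they are linked by neighbors closer than a distance threshold. It assumes the parameter acts as a clustering distance (the name suggests DBSCAN, but that is an assumption); `count_clusters` is a hypothetical helper and not how `seg.label_grains` is implemented.

```python
import numpy as np

def count_clusters(points, max_dist):
    """Hypothetical helper: count groups of points chained by neighbors within max_dist."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    parent = list(range(n))  # union-find forest, one tree per cluster

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path compression
            i = parent[i]
        return i

    # pairwise Euclidean distances between all points
    diffs = points[:, None, :] - points[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    for i in range(n):
        for j in range(i + 1, n):
            if dists[i, j] <= max_dist:
                parent[find(i)] = find(j)  # merge the two clusters
    return len({find(i) for i in range(n)})

pts = [(0, 0), (10, 0), (25, 0), (100, 100)]
print(count_clusters(pts, 20.0))  # 2
print(count_clusters(pts, 12.0))  # 3
print(count_clusters(pts, 5.0))   # 4
```

In this toy setting, each cluster would yield one prompt, so shrinking the threshold splits clusters and raises the prompt count.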


Use the figure created in the next cell to check the quality of the Unet labeling (it sometimes fails completely) and the distribution of the SAM prompts (the black dots). If the Unet prediction is poor, it is a good idea to create some training data and fine-tune the base model so that it works better with the images of interest.

In [5]:
plt.figure(figsize=(15,10))
plt.imshow(image_pred)
plt.scatter(np.array(coords)[:,0], np.array(coords)[:,1], c='k')
plt.xticks([])
plt.yticks([]);

In [6]:
# SAM segmentation, using the point prompts from the Unet:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(sam, image, image_pred, 
                                                coords, labels, min_area=400.0, plot_image=True, remove_edge_grains=False, remove_large_objects=False)

creating masks using SAM...


100%|████████████████████████████████████████████████████████████████████████████████| 350/350 [00:27<00:00, 12.60it/s]


finding overlapping polygons...


336it [00:00, 357.24it/s]


finding best polygons...


100%|███████████████████████████████████████████████████████████████████████████████| 121/121 [00:00<00:00, 127.70it/s]


creating labeled image...


100%|███████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 159.73it/s]


## Delete or merge grains in segmentation result
* click on the grain that you want to remove and press the 'x' key
* click on two grains that you want to merge and press the 'm' key (they have to be the last two grains you clicked on)
* press the 'g' key to hide the grain masks (so that you can see the original image better); press the 'g' key again to show the grain masks

In [7]:
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event', 
                              lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax))
cid2 = fig.canvas.mpl_connect('key_press_event', 
                              lambda event: seg.onpress2(event, all_grains, grain_inds, fig=fig, ax=ax))

Run this cell if you do not want to delete / merge existing grains anymore; it is a good idea to do this before moving on to the next step.

In [8]:
fig.canvas.mpl_disconnect(cid1)
fig.canvas.mpl_disconnect(cid2)

Use this function to update the 'labels' array after deleting and merging grains (the 'all_grains' list is updated when doing the deletion and merging):

In [9]:
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, image)

100%|██████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 3204.36it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 203.96it/s]


Plot the updated set of grains:

In [10]:
fig, ax = plt.subplots(figsize=(15,10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired', plot_image=False)
seg.plot_grain_axes_and_centroids(all_grains, labels, ax, linewidth=1, markersize=10)
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|███████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 165.55it/s]


## Add new grains using the Segment Anything Model

* click on an unsegmented grain that you want to add
* press the 'x' key if you want to delete the last grain you added
* press the 'm' key if you want to merge the last two grains that you added
* right click outside the grain (but inside the most recent mask) if you want to restrict the grain to a smaller mask - this adds a background prompt

In [11]:
predictor = SamPredictor(sam)
predictor.set_image(image) # this can take a while
coords = []
cid3 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.onclick(event, ax, coords, image, predictor))
cid4 = fig.canvas.mpl_connect('key_press_event', lambda event: seg.onpress(event, ax, fig))

In [12]:
fig.canvas.mpl_disconnect(cid3)
fig.canvas.mpl_disconnect(cid4)

In [13]:
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, image)

100%|██████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 2670.72it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 183.92it/s]


After you are done with the deletion / addition of grain masks, run this cell to generate an updated set of grains:

In [14]:
all_grains, labels, mask_all, fig, ax = seg.get_grains_from_patches(ax, image)

100%|██████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 8017.34it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 163.07it/s]


## Get grain size distribution

Run this cell, then left-click on one end of the scale bar in the image and right-click on the other end:

In [15]:
cid5 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.click_for_scale(event, ax))

Use the length of the scale bar in pixels (it should be printed above) to get the scale of the image (in units / pixel):

In [16]:
n_of_units = 1000
units_per_pixel = n_of_units/1552.77 # length of scale bar in pixels
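
The same calculation can be written out from the two clicked points. The sketch below is only an illustration with made-up coordinates; `units_per_pixel_from_clicks` is a hypothetical helper, and in the notebook the pixel length is the value printed by `seg.click_for_scale`.

```python
import math

def units_per_pixel_from_clicks(p1, p2, n_of_units):
    """Hypothetical helper: scale from two (x, y) points at the ends of a scale bar."""
    length_px = math.hypot(p2[0] - p1[0], p2[1] - p1[1])  # scale bar length in pixels
    if length_px == 0:
        raise ValueError("the two points must be different")
    return n_of_units / length_px

# made-up clicks: the bar spans 300 px horizontally and 400 px vertically, i.e. 500 px
scale = units_per_pixel_from_clicks((100, 50), (400, 450), n_of_units=1000)
print(scale)  # 2.0
```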

In [17]:
props = regionprops_table(labels.astype('int'), intensity_image = image, properties =\
        ('label', 'area', 'centroid', 'major_axis_length', 'minor_axis_length', 
         'orientation', 'perimeter', 'max_intensity', 'mean_intensity', 'min_intensity'))
grain_data = pd.DataFrame(props)
grain_data['major_axis_length'] = grain_data['major_axis_length'].values*units_per_pixel
grain_data['minor_axis_length'] = grain_data['minor_axis_length'].values*units_per_pixel
grain_data['perimeter'] = grain_data['perimeter'].values*units_per_pixel
grain_data['area'] = grain_data['area'].values*units_per_pixel**2
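
The conversion above relies on lengths scaling linearly with `units_per_pixel` and areas scaling with its square; centroids and orientations are left in pixel coordinates and radians. A minimal sketch of the same idea, using a hypothetical helper `to_physical_units` that returns a copy so that rerunning it does not apply the scale twice:

```python
import pandas as pd

def to_physical_units(df, units_per_pixel,
                      length_cols=('major_axis_length', 'minor_axis_length', 'perimeter'),
                      area_cols=('area',)):
    """Hypothetical helper: return a copy of df with pixel measurements converted."""
    out = df.copy()
    for col in length_cols:
        out[col] = out[col] * units_per_pixel        # lengths scale linearly
    for col in area_cols:
        out[col] = out[col] * units_per_pixel ** 2   # areas scale with the square
    return out

# made-up measurements for a single grain, in pixels
pixels = pd.DataFrame({'area': [100.0], 'major_axis_length': [20.0],
                       'minor_axis_length': [10.0], 'perimeter': [50.0]})
scaled = to_physical_units(pixels, units_per_pixel=0.5)
print({k: float(v) for k, v in scaled.iloc[0].items()})
# {'area': 25.0, 'major_axis_length': 10.0, 'minor_axis_length': 5.0, 'perimeter': 25.0}
```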

In [18]:
grain_data.head()

Unnamed: 0,label,area,centroid-0,centroid-1,major_axis_length,minor_axis_length,orientation,perimeter,max_intensity-0,max_intensity-1,max_intensity-2,mean_intensity-0,mean_intensity-1,mean_intensity-2,min_intensity-0,min_intensity-1,min_intensity-2
0,1,182.904473,18.868481,141.802721,18.116601,12.93946,1.493505,50.194894,255.0,255.0,255.0,213.113379,213.113379,213.113379,15.0,15.0,15.0
1,2,309.403031,124.817694,741.646113,24.76144,16.069528,0.115414,68.097732,255.0,255.0,255.0,227.486595,227.486595,227.486595,63.0,63.0,63.0
2,3,243.87263,249.421769,394.579932,22.468607,14.242538,-0.239939,60.655323,255.0,255.0,255.0,213.530612,213.530612,213.530612,34.0,34.0,34.0
3,4,329.311001,52.093199,58.955919,28.881017,15.033689,0.26298,74.915088,255.0,255.0,255.0,219.547859,219.547859,219.547859,65.0,65.0,65.0
4,5,184.97822,74.475336,99.488789,17.555486,13.524346,0.45682,50.572147,251.0,251.0,251.0,123.484305,123.484305,123.484305,15.0,15.0,15.0


In [19]:
grain_data.to_csv(fname[:-4]+'.csv') # save grain data to CSV file

In [20]:
# plot histogram of grain axis lengths
plt.figure()
plt.hist(grain_data['major_axis_length'], np.arange(0, 100, 1), alpha=0.5)
plt.hist(grain_data['minor_axis_length'], np.arange(0, 100, 1), alpha=0.5)
plt.xlim(0,100)
plt.xlabel('axis length (microns)')
plt.ylabel('count');
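
Besides the histogram, a few percentiles are a compact way to summarize the distribution. The sketch below computes count-based percentiles with NumPy; the `D16`/`D50`/`D84` labels and the helper `size_percentiles` are illustrative choices, not something the notebook or `segmenteverygrain` defines, and the input values are made up.

```python
import numpy as np

def size_percentiles(lengths, percentiles=(16, 50, 84)):
    """Hypothetical helper: count-based percentiles of a set of axis lengths."""
    lengths = np.asarray(lengths, dtype=float)
    values = np.percentile(lengths, percentiles)
    return {f"D{p}": round(float(v), 3) for p, v in zip(percentiles, values)}

# made-up lengths 1, 2, ..., 101 so that the percentiles are easy to check by hand
print(size_percentiles(np.arange(1, 102)))
# {'D16': 17.0, 'D50': 51.0, 'D84': 85.0}
```

In the notebook this could be called as `size_percentiles(grain_data['major_axis_length'])`. Note that these percentiles weight every grain equally, unlike the area-weighted sampling used later.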

## Normalize area for Monte Carlo simulation

In [31]:
def select_grains_weighted(data, n):
    """Selects n grains from a grain size .csv based on their area, weighted by probability.

    Args:
        data: Pandas DataFrame containing grain data with columns 'area'.
        n: Number of grains to select.

    Returns:
        Pandas DataFrame containing the selected grains.
    """

    # Calculate probabilities based on area
    data['probability'] = data['area'] / data['area'].sum()

    # Sample n grains based on probabilities
    selected_grains = data.sample(n, replace=False, weights='probability')

    return selected_grains

# Example usage
# Replace 'grain_size.csv' with your actual file path
data = pd.read_csv(r"C:\Users\juan_\Python\JSG_2024_Hackathon\1973\wellsorted_lowerfinegrainsize.csv")

# Select 30 grains based on area-weighted probability
selected_grains = select_grains_weighted(data, 30)

print(selected_grains)

     Unnamed: 0  label         area  centroid-0  centroid-1  \
128         128    129   237.236640  340.202797  437.895105   
32           32     33  1800.427021    3.791753  449.822391   
39           39     40   288.250813  128.650360  662.282014   
14           14     15   283.273821   84.759883   47.339678   
5             5      6   308.573532  157.071237   18.174731   
115         115    116   192.858458  443.791398  707.516129   
98           98     99   313.135775  324.728477  564.162914   
68           68     69   235.577643   39.720070  431.616197   
54           54     55   179.586478  422.006928  709.926097   
20           20     21   252.582367  307.916256  602.715928   
2             2      3   243.872630  249.421769  394.579932   
63           63     64   405.210136  554.254862  334.859775   
109         109    110   233.918645  393.154255  331.092199   
6             6      7   294.472054  199.301408   56.361972   
38           38     39   197.420701  118.037815  273.32
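
A variant of the area-weighted selection is sketched below. It passes the `area` column directly as weights (pandas normalizes weights that do not sum to 1), so the input DataFrame is not modified, and it accepts a seed for reproducible draws. The name `sample_by_area`, the `seed` parameter, and the toy data are assumptions made for illustration.

```python
import pandas as pd

def sample_by_area(data, n, seed=None):
    """Hypothetical helper: draw n rows without replacement, weighted by 'area'.

    The input DataFrame is left unchanged.
    """
    return data.sample(n=n, replace=False, weights=data['area'], random_state=seed)

# toy data (made-up values): one grain is much larger than the others
toy = pd.DataFrame({'label': [1, 2, 3, 4], 'area': [1.0, 1.0, 1.0, 97.0]})

picked = sample_by_area(toy, 2, seed=0)

# draw a single grain many times and check how often the large grain is chosen
first_picks = [int(sample_by_area(toy, 1, seed=s)['label'].iloc[0]) for s in range(200)]
share_of_large = first_picks.count(4) / len(first_picks)

print(list(toy.columns))  # ['label', 'area']
print(share_of_large)     # should be close to 0.97
```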

In [32]:
# plot histogram of grain axis lengths
plt.figure()
plt.hist(selected_grains['major_axis_length'], np.arange(0, 100, 1), alpha=0.5)
plt.hist(selected_grains['minor_axis_length'], np.arange(0, 100, 1), alpha=0.5)
plt.xlim(0,100)
plt.xlabel('axis length (microns)')
plt.ylabel('count');

## Run and plot Monte Carlo simulation

In [35]:
# Attempt 1

from scipy.stats import gaussian_kde

# Repeat the selection process 1000 times
num_iterations = 1000
major_axis_lengths = []
minor_axis_lengths = []

for _ in range(num_iterations):
    selected_grains = select_grains_weighted(data, 30)
    major_axis_lengths.extend(selected_grains['major_axis_length'])
    minor_axis_lengths.extend(selected_grains['minor_axis_length'])

# Calculate KDEs
kde_major = gaussian_kde(major_axis_lengths)
kde_minor = gaussian_kde(minor_axis_lengths)

# Plot KDEs
x_vals = np.linspace(0, 100, 1000)
plt.figure()
plt.plot(x_vals, kde_major(x_vals), label='Major Axis Length')
plt.plot(x_vals, kde_minor(x_vals), label='Minor Axis Length')
plt.xlim(0, 100)
plt.xlabel('Axis Length (microns)')
plt.ylabel('Density')
plt.legend()
plt.title('KDE of Grain Axis Lengths')
plt.show()

In [39]:
# Successful attempt

from scipy.stats import gaussian_kde

# Repeat the selection process 100 times
num_iterations = 100
major_axis_lengths = []
minor_axis_lengths = []

for _ in range(num_iterations):
    selected_grains = select_grains_weighted(data, 30)
    major_axis_lengths.append(selected_grains['major_axis_length'])
    minor_axis_lengths.append(selected_grains['minor_axis_length'])

# Calculate one KDE per Monte Carlo sample
# (the loop variable is named 'sample' so that it is not confused with the 'data' DataFrame)
kde_major = [gaussian_kde(sample) for sample in major_axis_lengths]
kde_minor = [gaussian_kde(sample) for sample in minor_axis_lengths]

# Plot KDEs
x_vals = np.linspace(0, 100, 1000)
plt.figure()
cmap = plt.get_cmap('tab20')  # plt.get_cmap replaces the deprecated matplotlib.cm.get_cmap

for i in range(num_iterations):
    plt.plot(x_vals, kde_major[i](x_vals), label=f'Major Axis Length {i+1}', color=cmap(i/num_iterations))
    plt.plot(x_vals, kde_minor[i](x_vals), label=f'Minor Axis Length {i+1}', color=cmap(i/num_iterations))

plt.xlim(0, 100)
plt.xlabel('Axis Length (microns)')
plt.ylabel('Density')
#plt.legend()
plt.title('KDE of Grain Axis Lengths (Monte Carlo Simulations)')
plt.show()


In [None]:
## Better figure

In [40]:
# KDE of all Monte Carlo samples pooled together (a pooled estimate, not a pointwise median)
pooled_major_kde = gaussian_kde(np.concatenate(major_axis_lengths))
pooled_minor_kde = gaussian_kde(np.concatenate(minor_axis_lengths))

# Evaluate each KDE once, then find the minimum and maximum density at each x-value
x_vals = np.linspace(0, 100, 1000)
major_curves = np.array([kde(x_vals) for kde in kde_major])
minor_curves = np.array([kde(x_vals) for kde in kde_minor])
major_kde_min, major_kde_max = major_curves.min(axis=0), major_curves.max(axis=0)
minor_kde_min, minor_kde_max = minor_curves.min(axis=0), minor_curves.max(axis=0)

# Plot the min-max range (shaded) and the pooled KDE (line)
plt.figure()
plt.fill_between(x_vals, major_kde_min, major_kde_max, alpha=0.2, label='Major Axis Length (min-max range)')
plt.plot(x_vals, pooled_major_kde(x_vals), label='Major Axis Length (pooled KDE)')
plt.fill_between(x_vals, minor_kde_min, minor_kde_max, alpha=0.2, label='Minor Axis Length (min-max range)')
plt.plot(x_vals, pooled_minor_kde(x_vals), label='Minor Axis Length (pooled KDE)')

plt.xlim(0, 100)
plt.xlabel('Axis Length (microns)')
plt.ylabel('Density')
plt.legend()
plt.title('KDE of Grain Axis Lengths (Monte Carlo Simulations)')
plt.show()
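
The shaded band in the figure spans the minimum and maximum density over all iterations, which makes it sensitive to a single unusual sample. One possible alternative, sketched here with synthetic curves, is a pointwise percentile band together with a pointwise median. The helper `percentile_envelope` and the 5th/95th percentile choice are assumptions, not part of the original workflow.

```python
import numpy as np

def percentile_envelope(curves, lower=5, upper=95):
    """Hypothetical helper: pointwise lower percentile, median, and upper percentile.

    curves has shape (n_iterations, n_x): one evaluated density curve per row.
    """
    curves = np.asarray(curves, dtype=float)
    low, median, high = np.percentile(curves, [lower, 50, upper], axis=0)
    return low, median, high

# synthetic stand-in for the per-iteration KDE curves: Gaussian bumps with jittered centers
rng = np.random.default_rng(0)
x_vals = np.linspace(0, 100, 200)
centers = rng.normal(50, 3, size=100)
curves = np.exp(-0.5 * ((x_vals[None, :] - centers[:, None]) / 10) ** 2)

low, median, high = percentile_envelope(curves)
print(low.shape, median.shape, high.shape)  # (200,) (200,) (200,)
```

With the notebook's variables, the input could be built as `np.array([kde(x_vals) for kde in kde_major])`, and `low` and `high` could then be passed to `plt.fill_between`.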

## Save mask and grain labels to PNG files

In [23]:
import os

dirname = r'C:\Users\juan_\Python\JSG_2024_Hackathon\segmenteverygrain\segmenteverygrain\results'
basename = os.path.splitext(os.path.basename(fname))[0] # file name without folder or extension
# write grayscale mask to PNG file
cv2.imwrite(os.path.join(dirname, basename + '_mask.png'), mask_all)
# save the image as a PNG file (OpenCV expects BGR channel order; the loaded image is RGB)
cv2.imwrite(os.path.join(dirname, basename + '_image.png'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

True
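
Output file names built by string concatenation are easy to get wrong when the input path uses backslashes, because splitting on `'/'` then returns the whole path. The sketch below shows one way to derive the names with `pathlib`; `output_path` is a hypothetical helper and the paths are made up. `PureWindowsPath` is used so that the example behaves the same on any operating system.

```python
from pathlib import PureWindowsPath

def output_path(dirname, fname, suffix):
    """Hypothetical helper: dirname joined with fname's stem plus a suffix."""
    stem = PureWindowsPath(fname).stem  # file name without folder or extension
    return str(PureWindowsPath(dirname) / (stem + suffix))

fname_example = r"C:\data\1973\sample.PNG"  # made-up path

# splitting on '/' does nothing for a backslash path: the folder part is kept
print(fname_example.split('/')[-1][:-4])    # C:\data\1973\sample

print(output_path(r"C:\results", fname_example, "_mask.png"))
# C:\results\sample_mask.png
```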

## Run segmentation on large image (new!)
In this case, 'fname' points to an image that is larger than a few megapixels and has thousands of grains.
The 'predict_large_image' function breaks the input image into smaller patches and runs the segmentation process on each patch.
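
To illustrate what breaking an image into overlapping patches can look like, here is a toy tiling scheme for one image dimension. It is only a sketch of the general idea: `patch_origins` is a hypothetical helper, and the way `predict_large_image` actually places its patches may differ.

```python
def patch_origins(length, patch_size=2000, overlap=200):
    """Hypothetical helper: start offsets of overlapping patches along one dimension."""
    if overlap >= patch_size:
        raise ValueError("overlap must be smaller than patch_size")
    if length <= patch_size:
        return [0]  # the whole dimension fits into a single patch
    step = patch_size - overlap
    origins = list(range(0, length - patch_size, step))
    origins.append(length - patch_size)  # last patch sits flush with the far edge
    return origins

rows = patch_origins(5000)  # image height of 5000 pixels
cols = patch_origins(3800)  # image width of 3800 pixels
print(rows)                   # [0, 1800, 3000]
print(cols)                   # [0, 1800]
print(len(rows) * len(cols))  # 6
```

In this scheme, an image that is smaller than `patch_size` in both dimensions is covered by a single patch.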

In [24]:
from PIL import Image
Image.MAX_IMAGE_PIXELS = None # needed if working with very large images
fname = r"C:\Users\juan_\Python\JSG_2024_Hackathon\1973\wellsorted_lowerfinegrainsize.PNG"
all_grains = seg.predict_large_image(fname, model, sam, min_area=400.0, patch_size=2000, overlap=200)

segmenting image tiles...


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.78it/s]


creating masks using SAM...


100%|████████████████████████████████████████████████████████████████████████████████| 350/350 [00:27<00:00, 12.60it/s]


finding overlapping polygons...


320it [00:00, 596.72it/s]


finding best polygons...


100%|███████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 246.86it/s]


creating labeled image...
processed patch #1 out of 1 patches


128it [00:00, 2573.88it/s]
0it [00:00, ?it/s]


In [25]:
# plot results
image = np.array(load_img(fname))
fig, ax = plt.subplots(figsize=(15,10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired')
plt.axis('equal')
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

100%|███████████████████████████████████████████████████████████████████████████████| 128/128 [00:00<00:00, 180.55it/s]


In [26]:
# this is a faster way of deleting false positives (because it avoids highlighting and deleting the 'bad' grains)
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax, select_only=True))

In [27]:
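from tqdm import tqdm
# delete polygons from 'all_grains', from the highest index down so that earlier deletions do not shift later ones
grain_inds = sorted(np.unique(grain_inds), reverse=True)
for ind in tqdm(grain_inds):
    del all_grains[ind] # delete by position; 'remove' would search for the first equal polygon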

0it [00:00, ?it/s]
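
The order of deletion matters when removing list items by position: deleting a low index first shifts every later item down by one. A small sketch with made-up data, using a hypothetical helper `delete_by_index`:

```python
def delete_by_index(items, indices):
    """Hypothetical helper: delete the given positions from a list in place.

    Duplicate indices are ignored; deletion runs from the highest index down.
    """
    for ind in sorted(set(indices), reverse=True):
        del items[ind]
    return items

grains = ['a', 'b', 'c', 'd', 'e']
delete_by_index(grains, [1, 3, 3])
print(grains)  # ['a', 'c', 'e']
```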


After plotting the results, you will want to use the functions for deleting, merging, and adding grains (see above), before saving the results (same workflow as for a small image).