## Import packages

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.measure import regionprops, regionprops_table
from keras.utils import load_img
from keras.saving import load_model
from importlib import reload
import segmenteverygrain as seg
import interactions as segi
from segment_anything import sam_model_registry, SamPredictor
from tqdm import trange, tqdm
%matplotlib qt

## Load models

In [None]:
model = load_model("seg_model.keras", custom_objects={'weighted_crossentropy': seg.weighted_crossentropy})

# the SAM model checkpoints can be downloaded from: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
sam = sam_model_registry["default"](checkpoint="../sam_vit_h_4b8939.pth")

## Run segmentation

Grains should be well defined in the image; for example, a grain that consists of only a few pixels is unlikely to be detected.

The segmentation can take a few minutes even for medium-sized images. Images with ~2000 pixels along their largest dimension are a good starting point and give a quick idea of how well the segmentation works.
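
To bring a larger photo down to roughly this size for a first test, the target shape can be computed as in the minimal sketch below. The helper name `fit_largest_dimension` and its `max_dim` argument are hypothetical and are not part of `segmenteverygrain`.

```python
def fit_largest_dimension(height, width, max_dim=2000):
    """Return (new_height, new_width) so that the largest side is at most max_dim.

    Hypothetical helper for illustration; the aspect ratio is preserved.
    """
    largest = max(height, width)
    if largest <= max_dim:
        return height, width  # already small enough, leave unchanged
    scale = max_dim / largest
    return max(1, round(height * scale)), max(1, round(width * scale))


new_h, new_w = fit_largest_dimension(3000, 4000)
print(new_h, new_w)  # 1500 2000
```

The image itself could then be resized with, for example, `cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)`; note that OpenCV expects the size as (width, height).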

If you have a much larger image, see the section **"Run segmentation on large image"** at the end of the notebook. Running the `predict_large_image` function takes a lot longer (e.g., several hours), but it makes it possible to analyze very large images with tens of thousands of grains.

The image used below is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/torrey_pines_beach_image.jpeg).

In [None]:
# replace this with the path to your image:
fname = '../torrey_pines_beach_image.jpeg'

image = np.array(load_img(fname))
image_pred = seg.predict_image(image, model, I=256)

# decreasing the 'dbs_max_dist' parameter results in more SAM prompts (and longer processing times):
labels, coords = seg.label_grains(image, image_pred, dbs_max_dist=20.0) # Unet prediction
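
The effect described in the comment on `dbs_max_dist` can be illustrated with a toy example. The sketch below assumes, as the parameter name suggests, that candidate points closer than a maximum distance are grouped and that each group yields one prompt. It is not the implementation used by `label_grains`; the function `count_clusters` and the sample points are hypothetical.

```python
from math import dist


def count_clusters(points, max_dist):
    """Count groups of points linked by chains of neighbors within max_dist."""
    parent = list(range(len(points)))

    def find(i):
        # Follow parent links to the representative of the group
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            if dist(points[i], points[j]) <= max_dist:
                parent[find(i)] = find(j)  # merge the two groups
    return len({find(i) for i in range(len(points))})


# Hypothetical candidate prompt locations (x, y) in pixels
points = [(0, 0), (15, 0), (30, 0), (100, 100), (112, 100)]
print(count_clusters(points, max_dist=20.0))  # 2
print(count_clusters(points, max_dist=10.0))  # 5
```

With the smaller distance no points are merged, so every candidate remains a separate prompt, which is consistent with the longer processing times mentioned above.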

Use the figure created in the next cell to check the quality of the Unet labeling (sometimes it does not work at all) and the distribution of SAM prompts (= black dots). If the Unet prediction is of poor quality, it is a good idea to create some training data and fine-tune the base model so that it works better with the images of interest.

In [None]:
fig, ax = plt.subplots(figsize=(15,10))
ax.imshow(image_pred)
plt.scatter(np.array(coords)[:,0], np.array(coords)[:,1], c='k')
plt.xticks([])
plt.yticks([]);

In [None]:
# SAM segmentation, using the point prompts from the Unet:
all_grains, labels, mask_all, grain_data, fig, ax = seg.sam_segmentation(sam, image, image_pred, 
            coords, labels, min_area=400.0, plot_image=True, remove_edge_grains=False, remove_large_objects=False)

### Interactive editing

The editing interface itself is defined in the `interactions` module (imported above as `segi`).

Navigation within the interface is described in the [matplotlib documentation](https://matplotlib.org/stable/users/explain/figure/interactive.html#interactive-navigation). Additional controls are:

- `Left click`: Select/unselect existing grain or place foreground prompt for grain detection
- `Shift + left click/drag`: Create or adjust box prompt for grain detection
- `Right click`: Place background prompt for grain detection
- `Middle click`: Display measurement information about the indicated grain
- `Middle click + drag`: Measure scale bar to calibrate pixels per meter
- `Control`: Hold to temporarily hide selected grains
- `Escape`: Remove all prompts and unselect all grains
- `c`: Use selection box and/or foreground/background prompts to detect a grain
- `d`: Delete selected (highlighted) grains
- `m`: Merge selected grains (must be touching)
- `z`: Delete the most recently-created grain

`px_per_m`: The ratio of pixels to meters, if known. This will be overwritten if a scale bar is measured in the interface using middle click & drag.

`scale_m`: The length in meters of a reference object. Once the reference object is measured using middle click & drag, size/area values will be converted to meters. The diagonal of the selection box will be taken to represent `scale_m` meters.
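
The relationship between `scale_m`, the measured diagonal, and `px_per_m` can be sketched as follows. This only illustrates the arithmetic described above and is not the code used by the interface; the function name and the box coordinates are hypothetical.

```python
from math import hypot


def px_per_m_from_box(x0, y0, x1, y1, scale_m):
    """Pixels per meter, taking the box diagonal to represent scale_m meters."""
    diagonal_px = hypot(x1 - x0, y1 - y0)
    return diagonal_px / scale_m


# Hypothetical scale bar: a box dragged from (100, 200) to (1300, 1100),
# whose diagonal corresponds to a 0.5 m reference object
px_per_m = px_per_m_from_box(100, 200, 1300, 1100, scale_m=0.5)

# Lengths scale with px_per_m, areas with its square
length_m = 150 / px_per_m        # a 150 px long axis
area_m2 = 9000 / px_per_m ** 2   # a 9000 px^2 grain
print(px_per_m, length_m, area_m2)
```

Here the diagonal is 1500 px, so the conversion factor is 3000 px/m.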

`image_max_size` (y, x): Images larger than this in either dimension will be downscaled for display. Operations like grain detection will still be performed on the full image, but the display will not be able to zoom in at full quality. This is a performance optimization. Reduce this size for better performance, increase this size for better visual quality when zoomed.

`image_alpha`: Set this to a value lower than 1 to apply a fade effect to the background image.

In [None]:
# Create Grain objects from detected polygons, providing easy measurement and display methods
grains = segi.polygons_to_grains(all_grains, image=image)

# Prepare predictor for detecting new grains
predictor = SamPredictor(sam)
predictor.set_image(image)

In [None]:
# Display interactive interface
plot = segi.GrainPlot(
    grains,
    image = image, 
    predictor = predictor,
    figsize = (12, 8),              # inches
    px_per_m = 3390.,               # px/m
    scale_m = 0.5,                  # m
    # image_max_size = (240, 320),  # px
    # image_alpha = 1.
)

# Turn on interactive features (grain editing, etc)
plot.activate()

### Results

Once manual editing is complete, close the plot and run the following two cells.

The following results are saved, using the path prefix specified in `out_fn`:
- Grain shapes, for use elsewhere (geojson)
- Image with colorized grains and major/minor axes drawn in (jpg)
- Summary data, presenting measurements for each detected grain (csv)
- Summary histogram, representing major/minor axes of detected grains (jpg)
- Mask representations of the detected grains, in both computer-readable (png, 0-1) and human-readable (jpg, 0-255) formats

In [None]:
# Turn off interactive features
plot.deactivate()

# Draw the major and minor axes onto each grain
plot.draw_axes()

# Retrieve unit conversion factor
px_per_m = plot.px_per_m

In [None]:
# Save results
out_fn = '../torrey_pines'
# Grain shapes
segi.save_grains(out_fn + '_grains.geojson', grains)
# Grain image
plot.savefig(out_fn + '_grains.jpg')
# Summary data
summary = segi.save_summary(
    out_fn + '_summary.csv', grains, px_per_m=px_per_m)
summary.head()
# Summary histogram
segi.save_histogram(out_fn + '_summary.jpg', summary=summary)
# Training mask
segi.save_mask(out_fn + '_mask.png', grains, image, scale=False)  # For training
segi.save_mask(out_fn + '_mask2.jpg', grains, image, scale=True)  # For viewing

## Run segmentation on large image (new!)

In this case, `fname` points to an image that is larger than a few megapixels and contains thousands of grains.
The `predict_large_image` function breaks the input image into smaller patches and runs the segmentation process on each patch.
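
To get a feel for how many patches a given image produces, the tiling can be approximated with the sketch below. It assumes a regular grid with a fixed overlap and a last patch that is flush with the image edge; the actual tiling inside `predict_large_image` may differ, and `patch_origins` is a hypothetical helper.

```python
def patch_origins(length, patch_size=2000, overlap=200):
    """Start offsets of overlapping patches that cover [0, length) along one axis."""
    if length <= patch_size:
        return [0]  # a single patch covers the whole axis
    step = patch_size - overlap
    origins = list(range(0, length - patch_size, step))
    origins.append(length - patch_size)  # last patch ends exactly at the edge
    return origins


# Hypothetical image of 3800 x 5600 pixels (rows x columns)
rows = patch_origins(3800)
cols = patch_origins(5600)
grid = [(y, x) for y in rows for x in cols]
print(rows, cols, len(grid))  # [0, 1800] [0, 1800, 3600] 6
```

Since each patch goes through the full segmentation workflow, the number of patches gives a rough idea of how the run time grows with image size.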

The image used below (from [Mair et al., 2022, Earth Surface Dynamics](https://esurf.copernicus.org/articles/10/953/2022/)) is available [here](https://github.com/zsylvester/segmenteverygrain/blob/main/mair_et_al_L2_DJI_0382_image.jpg).

In [None]:
from PIL import Image
Image.MAX_IMAGE_PIXELS = None # needed if working with very large images
fname = "mair_et_al_L2_DJI_0382_image.jpg"
all_grains, image_pred, all_coords = seg.predict_large_image(fname, model, sam, min_area=400.0, patch_size=2000, overlap=200)

In [None]:
# plot results
image = np.array(load_img(fname))
fig, ax = plt.subplots(figsize=(15,10))
plt.xticks([])
plt.yticks([])
seg.plot_image_w_colorful_grains(image, all_grains, ax, cmap='Paired')
plt.axis('equal')
plt.xlim([0, np.shape(image)[1]])
plt.ylim([np.shape(image)[0], 0]);

In [None]:
# click on false positives to record their indices; this is faster because the 'bad' grains are neither highlighted nor deleted while clicking
grain_inds = []
cid1 = fig.canvas.mpl_connect('button_press_event', lambda event: seg.onclick2(event, all_grains, grain_inds, ax=ax, select_only=True))

In [None]:
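# delete the selected polygons from 'all_grains'
# (highest index first, so that the remaining indices stay valid)
grain_inds = sorted(np.unique(grain_inds), reverse=True)
for ind in tqdm(grain_inds):
    del all_grains[int(ind)]  # delete by position; list.remove() would search by equality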
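
The deletion cell above sorts the indices in descending order before removing anything. The small example below, which uses placeholder strings instead of polygons, shows why this matters: deleting from the end keeps the remaining indices valid.

```python
def delete_by_index(items, indices):
    """Delete the given positions from a list in place, highest index first."""
    for ind in sorted(set(indices), reverse=True):
        del items[ind]
    return items


grains = ['g0', 'g1', 'g2', 'g3', 'g4']  # placeholders for polygons

# Duplicate clicks on the same grain are harmless
print(delete_by_index(list(grains), [1, 3, 3]))  # ['g0', 'g2', 'g4']

# Deleting in ascending order shifts the later items and removes the wrong one
wrong = list(grains)
for ind in [1, 3]:
    del wrong[ind]
print(wrong)  # ['g0', 'g2', 'g3']
```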

After plotting the results, use the tools for deleting, merging, and adding grains (see above), then save the results with the same workflow as for a small image.

## Fine-tuning the base model

In [None]:
# patchify images and masks
input_dir = "./Masks_and_images/" # the input directory should contain files with 'image' and 'mask' in their filenames
patch_dir = "./New_project/" # a directory called "Patches" will be created here
image_dir, mask_dir = seg.patchify_training_data(input_dir, patch_dir)

In [None]:
# create training, validation, and test datasets
train_dataset, val_dataset, test_dataset = seg.create_train_val_test_data(image_dir, mask_dir, augmentation=True)

In [None]:
# load base model weights and train the model with the new data
model = seg.create_and_train_model(train_dataset, val_dataset, test_dataset, model_file='seg_model.keras', epochs=100)

In [None]:
# save the fine-tuned model under a new name; it can then be loaded with:
# model = load_model("new_model.keras", custom_objects={'weighted_crossentropy': seg.weighted_crossentropy})
model.save('new_model.keras')