# Automatically Generating Document Masks

In [None]:
import colorsys
import random
import os

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [None]:
# Path to the SAM checkpoint file, relative to the notebook's working directory.
sam_checkpoint = "../../models/sam_models/sam_vit_h_4b8939.pth"
# Key used to look up the model builder in `sam_model_registry`; it should
# correspond to the architecture the checkpoint above was trained with.
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing when the model is moved below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint and move it to the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Bind the return value to `_` so the notebook cell does not echo the model.
_ = sam.to(device=device)

In [None]:
def generate_random_color(hue):
    """
    Step a hue by the golden ratio conjugate and return a pastel color for it.

    The function is deterministic: any randomness comes from the hue the caller
    starts with. Feeding the returned hue back in on the next call yields a
    sequence of hues that are spread around the color wheel.

    Parameters
    ----------
    hue : float
        The current hue, normally a value in the range [0, 1).

    Returns
    -------
    hue : float
        The advanced hue, wrapped back into the range [0, 1).
    color : tuple
        The RGB color for the advanced hue at a fixed lightness of 0.8 and a
        fixed saturation of 0.5, as a tuple of three floats in the range [0, 1].
    """
    saturation = 0.5
    lightness = 0.8

    # Adding the golden ratio conjugate and wrapping keeps successive hues apart.
    next_hue = (hue + 0.618033988749895) % 1

    # Note the argument order of colorsys: hue, lightness, saturation.
    return next_hue, colorsys.hls_to_rgb(next_hue, lightness, saturation)


def show_anns(anns, alpha=0.5, mask_upscale_factor=1.0):
    """
    Draw annotation masks as colored overlays on the current matplotlib axes.

    The annotations are sorted by area in descending order and painted into a
    single RGBA overlay, so smaller masks are painted last and stay visible on
    top of larger ones. Each mask gets its own color.

    Parameters
    ----------
    anns : list
        A list of annotation dictionaries. Each dictionary should have a 'segmentation' key and an 'area' key.
        The value of 'segmentation' should be a numpy array representing the mask of the annotation, and the
        value of 'area' should be a float representing the area of the annotation.
    alpha : float, optional
        The transparency level for the colors, represented as a float in the range [0, 1]. The default is 0.5.
    mask_upscale_factor : float, optional
        The scale factor to increase the mask size by. The default is 1.0 (i.e. no scaling).

    Returns
    -------
    None
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann["area"], reverse=True)
    ax = plt.gca()
    # Keep the axis limits set by the image that is already displayed.
    ax.set_autoscale_on(False)

    # Numpy shapes are (rows, cols), i.e. (height, width).
    mask_rows, mask_cols = sorted_anns[0]["segmentation"].shape[:2]
    rows = int(mask_rows * mask_upscale_factor)
    cols = int(mask_cols * mask_upscale_factor)

    # Start from a fully transparent RGBA canvas.
    overlay = np.ones((rows, cols, 4))
    overlay[:, :, 3] = 0

    hue = random.random()
    for ann in sorted_anns:
        mask = ann["segmentation"]
        if mask_upscale_factor != 1.0:
            # cv2.resize expects dsize as (width, height). The interpolation flag
            # must be passed by keyword: the third positional argument is `dst`,
            # so a positional flag would not select nearest-neighbour resizing.
            mask = cv2.resize(
                mask.astype(np.uint8),
                (cols, rows),
                interpolation=cv2.INTER_NEAREST,
            ) > 0
        hue, rgb_color = generate_random_color(hue)
        overlay[mask] = np.concatenate([rgb_color, [alpha]])
    ax.imshow(overlay)

## Testing Image Segmentation

In [None]:
def load_image(image_path):
    """
    Load an image from a given path and convert it to RGB.

    Parameters
    ----------
    image_path : str
        A string representing the path of the image file to be loaded.

    Returns
    -------
    image : ndarray
        The image in RGB channel order, at its original size.

    Raises
    ------
    OSError
        If OpenCV could not read an image from `image_path`.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising; fail
    # here with the offending path instead of with an opaque cvtColor error.
    if image is None:
        raise OSError(
            f"Could not read image from {image_path!r} "
            "(missing file, no read permission, or unsupported format)"
        )

    # OpenCV loads color images in BGR order; the rest of the notebook uses RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def preprocess_image(image, kernel_size=1, downscale_factor=1.0):
    """
    Downscale an image and then smooth it with a Gaussian blur.

    The image is first resized to `1/downscale_factor` times its original width
    and height using bilinear interpolation. The resized image is then blurred
    with a Gaussian kernel of size (`kernel_size`, `kernel_size`).

    Parameters
    ----------
    image : ndarray
        An RGB image.
    kernel_size : int, optional
        The side length of the Gaussian blur kernel applied after resizing.
        Defaults to 1. NOTE(review): OpenCV is expected to accept only odd,
        positive kernel sizes - confirm before passing other values.
    downscale_factor : float, optional
        The factor to reduce each side of the image by. For example, a
        `downscale_factor` of 2.0 halves both the width and the height.
        Defaults to 1.0 (no resizing).

    Returns
    -------
    pre_processed_image : ndarray
        The resized and blurred image.
    """
    source_rows, source_cols = image.shape[:2]

    # cv2.resize takes the target size as (width, height).
    target_size = (
        int(source_cols / downscale_factor),
        int(source_rows / downscale_factor),
    )
    shrunk = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

    # A sigma of 0 lets OpenCV derive the standard deviation from the kernel size.
    return cv2.GaussianBlur(shrunk, (kernel_size, kernel_size), 0)


In [None]:
def show_image_masks(image, masks, alpha=0.7, mask_upscale_factor=1.0):
    """
    Show an image in a new figure with the given masks drawn on top of it.

    A 20 x 20 inch figure is opened, the image is drawn, the masks are overlaid
    in different colors via `show_anns`, the axes are hidden and the figure is shown.

    Parameters
    ----------
    image : ndarray
        The base image to be displayed. This should be a numpy array of shape
        (height, width, 3) and the values should be in the range [0, 1] for floats or [0, 255] for integers.
    masks : list
        A list of annotation dictionaries. Each dictionary should have a 'segmentation' key and
        an 'area' key. The value of 'segmentation' should be a numpy array representing the mask
        of the annotation, and the value of 'area' should be a float representing the area of the annotation.
    alpha : float, optional
        The alpha (transparency) value used to draw the masks over the image. Defaults to 0.7.
    mask_upscale_factor : float, optional
        The scale factor to increase the mask size by before drawing. Defaults to 1.0 (i.e. no scaling).

    Returns
    -------
    None
    """
    figure_size = (20, 20)

    plt.figure(figsize=figure_size)
    plt.imshow(image)
    # The overlay is drawn onto the axes that now hold the base image.
    show_anns(masks, mask_upscale_factor=mask_upscale_factor, alpha=alpha)
    plt.axis("off")
    plt.show()

In [None]:
def generate_and_show_masks_from_image_path(image_path, sam_mask_generator, blur_kernel_size, processed_image_downscale_factor):
    """
    Load an image, generate masks for a preprocessed copy of it and display them.

    The masks are generated on a downscaled and blurred copy of the image, then
    drawn over the original image, scaled back up by the same factor.

    Parameters
    ----------
    image_path : str
        A string representing the path of the image file to be loaded.
    sam_mask_generator : object
        A mask generator instance; its `generate` method is called with the preprocessed image.
    blur_kernel_size : int
        Size of the Gaussian blur kernel applied during preprocessing.
    processed_image_downscale_factor : float
        How much to downscale the image passed to the mask generator. The same
        factor is used to scale the masks back up for display.

    Returns
    -------
    None
    """
    original = load_image(image_path)
    model_input = preprocess_image(
        original,
        kernel_size=blur_kernel_size,
        downscale_factor=processed_image_downscale_factor,
    )

    annotations = sam_mask_generator.generate(model_input)

    # Draw on the original image, undoing the downscale applied to the model input.
    show_image_masks(
        original,
        annotations,
        mask_upscale_factor=processed_image_downscale_factor,
    )

In [None]:
# Tuning settings for automatic mask generation on document images.
# NOTE(review): the per-argument notes below are assumptions about the
# segment_anything API - confirm them against the installed version.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=8,  # presumably the prompt points sampled along each image side
    points_per_batch=64,  # presumably how many prompt points are processed at once
    pred_iou_thresh=0.98,  # presumably a minimum predicted mask quality - TODO confirm
    stability_score_thresh=0.98,  # presumably a minimum mask stability score - TODO confirm
    min_mask_region_area=128 * 128  # 16384; presumably an area in pixels - TODO confirm
)

In [None]:
# Example image to segment, given as an absolute path.
pth = "/projects/RUSTOW/htr_deskewing_image_dataset/needs_deskewing/ENA 1178/000046/ENA_1178_046_r.tif"
# Positional arguments: blur kernel size of 1 and a downscale factor of 2.0.
generate_and_show_masks_from_image_path(pth, mask_generator, 1, 2.0)