### Introduction

This notebook contains an example of a workflow for performing exploratory data analysis on a dataset that we want to use to train an object detection model. It then uses transfer learning to tune the `ssd_resnet50_v1_fpn_640x640_coco17_tpu-8` model so that it can detect the new class of object defined in our dataset. The model selected for tuning is described in detail later in this notebook.

In this case, our training dataset contains histology images with annotations that label cells of interest. The original images are not publicly available, so they have been omitted from this notebook.

For training, we heavily revise the methods described in the following tutorial [[Tutorial](https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb)] so that the model detects multiple objects per image instead of the single object detected in the tutorial.

As you will see in the section "Summary of Data Quality", our training data is significantly flawed, which limits our ability to achieve a high level of model accuracy. An important part of model training is ensuring that high-quality data is provided, and this notebook is a good example of the "garbage in, garbage out" principle popular in data science. In an enterprise setting, this notebook would serve as evidence that higher-quality data is necessary, and I would work with a scientist to obtain a better dataset.

### Summary of Data Quality
- The raw data is unsurprisingly messy:
  - Annotation Data Review:
    - Bounding Boxes:
      - Bounding boxes are inconsistent and not precise in how they define the objects (ROIs).
      - Some boxes are missing, which may make it harder for the model to learn to identify cells.
    - Inconsistent Data Labelling:
      - Some annotations are labelled as 'cell', however other annotations are missing this label.
        - Need clear understanding of what is labelled so we can properly train our model if multiple classes exist.
  - Image Data Review:
    - Image Capture:
      - Images are under-exposed, which reduces the range of image intensity values available for resolving different structures during training.
    - Cell Consistency:
      - There is high variability in both the illumination and sharpness of the imaged cells: some cells are very blurry, while others are in focus.
      - Some cells show internal structures, while others are uniform in appearance.
    - Aberrant Data:
      - Random, blurred circles exist within some images (Example: See 37.png)
        - I'm not sure what these are, but they look like some sort of post-processing artifact, or possibly an optical aberration from the sample preparation.
      - Black, straight or curly lines exist in some images.
        - Straight lines may be splitting of tissue samples.
        - Curly lines have a less obvious explanation. It would be good to understand what these are in order to decide whether the affected data needs to be discarded.



## 1. Image Preprocessing


In [None]:
! pip install matplotlib opencv-python

import os
import json
import cv2
import numpy as np

### 1.1 Loading Images
Images are loaded into the notebook for preprocessing and analysis.


### Evaluate Data Provided by Scientist and Establish Areas of Improvement

Through manual observation of our images, we can tell that some of the data provided has been artificially augmented: in some cases an image was rotated to increase the amount of available training data.
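Rotated copies can also be flagged automatically rather than by eye. The sketch below is one way to do this, assuming rotations are exact multiples of 90 degrees and that pixel values were not altered (for example by lossy re-compression); `find_rotated_duplicates` is a hypothetical helper, not part of this notebook's pipeline.

```python
import numpy as np

def find_rotated_duplicates(images):
    """Return (i, j, k) for each pair where images[j] equals images[i] rotated by k * 90 degrees.

    k == 0 marks an exact copy. Pixel-exact equality is assumed, so rotated copies that
    were resampled or re-compressed will not be detected.
    """
    duplicates = []
    for i in range(len(images)):
        for j in range(i + 1, len(images)):
            for k in range(4):
                # Rotate the spatial axes only, leaving any channel axis untouched
                rotated = np.rot90(images[i], k=k, axes=(0, 1))
                # array_equal returns False when shapes differ (e.g. 3x4 vs 4x3)
                if np.array_equal(rotated, images[j]):
                    duplicates.append((i, j, k))
                    break
    return duplicates
```

Once the images are loaded, `find_rotated_duplicates([img for img, _ in paired_data])` would list the suspected copies. It compares every pair of images, so its cost grows quadratically with dataset size. Knowing which images are rotations of one another also helps keep an image and its copies in the same train/validation split.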


In [None]:
image_folder = "/content/drive/MyDrive/Colab Notebooks/raw_imgs"
annotation_folder = "/content/drive/MyDrive/Colab Notebooks/raw_anno"

paired_data = []
unpaired_data = []

# We assume the image file has a corresponding json
for img_file in os.listdir(image_folder):
    # Here we get the name without an extension
    base_name = os.path.splitext(img_file)[0]

    img_path = os.path.join(image_folder, img_file)
    # We pair it with the .json extension to get the path for the annotation file
    anno_path = os.path.join(annotation_folder, base_name + '.json')

    # Check if annotation exists for the image
    if os.path.exists(anno_path):
        image = cv2.imread(img_path)
        with open(anno_path, 'r') as f:
            annotation = json.load(f)
        # make a tuple to pair the data
        paired_data.append((image, annotation))

    else:
        # collect the annotation paths that were expected but not found
        unpaired_data.append(anno_path)

if not unpaired_data:
    print("All annotations present and loaded!")
else:
    print(f"{len(unpaired_data)} images have no annotation file: {unpaired_data}")


Example Output:


```
All annotations present and loaded!
```




### 1.2 Channel Comparison and Redundancy Removal
Image channels are compared to identify and remove redundant data. Note that the package we used to load the data (OpenCV's `cv2.imread`, which loads images as 3-channel BGR by default) duplicated our single-channel data across three channels, so this step is technically redundant here. However, if you were working with an unfamiliar package to load your data, or with genuinely multichannel images, this step would be necessary to verify that you are not working with redundant data.
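The channel comparison described above can also be written as a single check that avoids subtracting unsigned integers. A minimal sketch, with `channels_identical` as a hypothetical helper name:

```python
import numpy as np

def channels_identical(image):
    """Return True if every channel of an (H, W, C) image holds the same values."""
    if image.ndim != 3:
        return True  # a 2-D array has only one channel
    # If the per-pixel maximum and minimum across channels agree, all channels match.
    # This avoids uint8 subtraction, which wraps around (3 - 5 becomes 254).
    return bool(np.all(image.max(axis=-1) == image.min(axis=-1)))
```

Usage: `channels_identical(paired_data[5][0])` returns `True` when all three channels are copies. Alternatively, if the data is known to be single-channel, passing `cv2.IMREAD_GRAYSCALE` to `cv2.imread` loads each file directly as a 2-D array, so there are no duplicated channels to remove.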


In [None]:
# check whether multiple color channels are present
paired_data[0][0].shape

# Three channels suggest a color image (OpenCV orders them as BGR, not RGB)

Example Output:


```
(512, 512, 3)
```



In [None]:
import matplotlib.pyplot as plt
import numpy as np

def inspect_channels(image):
    if image.ndim == 3 and image.shape[-1] == 3:  # Check if the image has 3 channels
        # OpenCV loads images in BGR order, so channel 0 is blue and channel 2 is red
        blue, green, red = image[:, :, 0], image[:, :, 1], image[:, :, 2]
        # Stack in RGB order so matplotlib displays the composite correctly
        composite_image = np.stack((red, green, blue), axis=-1)
    else:
        print("Image does not have 3 channels!")
        return

    # Display each channel side by side and the composite image
    plt.figure(figsize=(20, 5))

    plt.subplot(1, 4, 1)
    plt.imshow(red, cmap='gray')
    plt.title('Red Channel')

    plt.subplot(1, 4, 2)
    plt.imshow(green, cmap='gray')
    plt.title('Green Channel')

    plt.subplot(1, 4, 3)
    plt.imshow(blue, cmap='gray')
    plt.title('Blue Channel')

    plt.subplot(1, 4, 4)
    plt.imshow(composite_image)
    plt.title('Composite Image')

    plt.tight_layout()
    plt.show()

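    # Compare pixel values between channels; cast to a signed type first so that
    # subtracting uint8 values cannot wrap around
    red, green, blue = (c.astype(np.int16) for c in (red, green, blue))
    rg_diff = np.abs(red - green)
    rb_diff = np.abs(red - blue)
    gb_diff = np.abs(green - blue)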

    print("R-G Channel Mean Difference:", np.mean(rg_diff))
    print("R-B Channel Mean Difference:", np.mean(rb_diff))
    print("G-B Channel Mean Difference:", np.mean(gb_diff))

    if np.sum([np.mean(rg_diff), np.mean(rb_diff), np.mean(gb_diff)]) == 0:
        print("All channels are the same")

# check to see if there are differences across channels
inspect_channels(paired_data[5][0])


Example Output:
> Displays images of the red, green, and blue channels, followed by a composite image overlaying all three channels

```
R-G Channel Mean Difference: 0.0
R-B Channel Mean Difference: 0.0
G-B Channel Mean Difference: 0.0
All channels are the same
```



In [None]:
# Because all channels are the same, we discard the extra channels to clean up the data a little bit
for i in range(len(paired_data)):
    paired_data[i] = (paired_data[i][0][:,:,0:1], paired_data[i][1])

# we confirm only one channel is left
paired_data[0][0].shape

Example Output:


```
(512, 512, 1)
```



### 1.3 Image Data Format Confirmation
The images are converted to `float32`, and their intensities are checked to confirm that they fall within the 8-bit range of 0 to 255.


### 1.4 Intensity Histogram and Exposure Analysis
A histogram is created over all images to analyze exposure levels and highlight any underexposure.

Here we generate some statistics to assess image quality. The batch histogram shows that the images were generally underexposed; the scientist could increase the exposure during image capture to widen the range of pixel intensity values available for training. Manual review of the images also shows black splotches that should be removed from a future dataset, as well as blurry spots that may indicate poor-quality data.
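To put a number on the underexposure seen in the histogram, a few summary statistics can complement the plot. A minimal sketch, where `exposure_summary` and its `dark_threshold` of 50 are hypothetical choices rather than values used elsewhere in this analysis:

```python
import numpy as np

def exposure_summary(images, dark_threshold=50):
    """Summarize how much of the 0-255 intensity range a batch of images uses.

    dark_threshold is an assumed cut-off: pixels below it are counted as 'dark'.
    """
    # Pool every pixel in the batch into one flat float array
    pixels = np.concatenate([np.asarray(img, dtype=np.float32).ravel() for img in images])
    p1, p50, p99 = np.percentile(pixels, [1, 50, 99])
    return {
        "mean": float(pixels.mean()),
        "median": float(p50),
        "p1_to_p99_span": float(p99 - p1),  # width of the central 98% of intensities
        "fraction_dark": float(np.mean(pixels < dark_threshold)),
    }
```

Running `exposure_summary(validated_images)` after the cell below would report how much of the 0 to 255 range the central 98% of pixels occupy; a narrow span, a low mean, and a large `fraction_dark` would all be consistent with underexposure.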

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Image Preprocessing
def validate_image(image, target_dtype=np.float32, target_range=(0, 255)):
    """Ensure the image is in the correct format and range."""
    image = np.asarray(image, dtype=target_dtype)  # Convert to the target data type
    min_val, max_val = np.min(image), np.max(image)  # Check min and max values
    if min_val < target_range[0] or max_val > target_range[1]:
        print(f"Image out of expected range {target_range}: min={min_val}, max={max_val}")
    return image

# Single Image Histogram
def histogram_for_single_image(image, bins=256, value_range=(0, 255)):
    """Prepare histogram for a single image."""
    flattened_image = image.flatten()  # Flatten the image
    hist_values, bin_edges = np.histogram(flattened_image, bins=bins, range=value_range)
    return hist_values, bin_edges

# Batch Histogram
def histogram_for_batch(images, bins=256, value_range=(0, 255)):
    """Prepare histogram for a batch of images; also return the shared bin edges."""
    total_hist_values = np.zeros(bins, dtype=np.float32)  # Initialize histogram values
    bin_edges = None

    for image in images:
        hist_values, bin_edges = histogram_for_single_image(image, bins=bins, value_range=value_range)
        total_hist_values += hist_values  # Accumulate histogram values

    return total_hist_values, bin_edges

def plot_histogram(hist_values, bin_edges, title='Histogram', xlabel='Pixel Intensity', ylabel='Frequency'):
    """Plot the histogram."""
    plt.bar(bin_edges[:-1], hist_values, width=np.diff(bin_edges)[0], align='edge')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

# Preprocess all images
validated_images = [validate_image(img) for img, _ in paired_data]  # All images

# Generate and plot the batch histogram (every image uses the same bin edges)
hist_values_batch, bin_edges = histogram_for_batch(validated_images)
plot_histogram(hist_values_batch, bin_edges, title='Batch Histogram')


Example Output:
> Displays a histogram of pixel intensity, with a peak on the left side of the graph close to zero, and tail on the right side of the data. This indicates most images are under-exposed.

### 1.5 Cumulative Distribution Function (CDF) Calculation
The CDF is calculated from histogram data and used to remap pixel values for better exposure; this process is histogram equalization. The intended approach equalizes according to a CDF calculated for an entire batch rather than the CDF of each image. This gives more consistent results across images but **requires** images to be processed as a batch. Note that the `histogram_equalization` function in Section 1.6 computes a separate CDF for each image.

#### Why we do this:
In statistical terms, the CDF is a function that gives the probability of a variable taking a value less than or equal to a certain level. When applied to image processing, the CDF maps the intensity levels of pixels in an image to a cumulative probability: for any given intensity level, it gives the probability of finding a pixel with intensity less than or equal to that level. The histogram of an image shows the frequency of each intensity level. For histogram equalization, we first normalize this histogram to a probability distribution. Then, the CDF is calculated by taking the cumulative sum of these probabilities. The resulting CDF is used to remap the intensity levels of the image's pixels, which redistributes the pixel intensities more evenly across the available range.
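The steps in the paragraph above can be written compactly for an image with $N$ pixels and $L$ intensity levels, where $h(k)$ is the number of pixels with intensity $k$ and $r$ is a pixel's original intensity:

$$
p(k) = \frac{h(k)}{N}, \qquad
c(k) = \sum_{j=0}^{k} p(j), \qquad
s(r) = (L - 1)\, c(r), \quad L = 256
$$

The last step corresponds to the `255 * cdf / cdf[-1]` normalization in `histogram_equalization` (Section 1.6): dividing the cumulative counts by their total $N$ is the same as summing $p(j)$, and $L - 1 = 255$ for 8-bit data.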

Applying the CDF in histogram equalization achieves the following:
- **Enhances Contrast**: It spreads out the most frequent intensity values, which enhances the global contrast of the image.
- **Equalizes Intensities**: It redistributes pixels across the full intensity range, making better use of the available intensity levels.
- **Improves Visibility**: Details hidden in dark or bright regions become more visible after equalization.
- **Facilitates Better Analysis**: For further image processing tasks such as segmentation or object detection, equalized images often yield better results.
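A batch-level variant of the equalization described above pools the histograms of all images and builds a single CDF. A minimal sketch, assuming 8-bit intensities; `batch_histogram_equalization` is a hypothetical helper and is not the per-image function used in Section 1.6:

```python
import numpy as np

def batch_histogram_equalization(images, bins=256, value_range=(0, 256)):
    """Equalize every image with a single CDF computed from the pooled batch histogram.

    Sketch only: assumes 8-bit intensities, so value_range=(0, 256) gives one bin per level.
    """
    # Same bin edges np.histogram produces for this range and bin count
    bin_edges = np.linspace(value_range[0], value_range[1], bins + 1)

    # Pool the histograms of all images so every image shares one mapping
    hist = np.zeros(bins, dtype=np.float64)
    for img in images:
        counts, _ = np.histogram(np.ravel(img), bins=bin_edges)
        hist += counts

    # Cumulative sum, normalized so the brightest level maps to 255
    cdf = hist.cumsum()
    cdf = 255.0 * cdf / cdf[-1]

    # Remap each pixel through the shared CDF, interpolating as histogram_equalization does
    return [np.interp(np.ravel(img), bin_edges[:-1], cdf).reshape(np.shape(img)) for img in images]
```

Because the mapping is shared, the same input intensity produces the same output intensity in every image, which is the consistency the batch approach aims for. The trade-off is that the whole batch must be available before any image can be equalized.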




### 1.6 Pre- and Post-Equalization Comparison
A manual comparison is made between images before and after the application of histogram equalization.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

def histogram_equalization(image, bins=256):
    """Perform histogram equalization on a single image."""
    # Calculate the histogram
    hist, bin_edges = np.histogram(image.flatten(), bins, density=False)
    # Calculate the CDF
    cdf = hist.cumsum()
    cdf = 255 * cdf / cdf[-1]  # Normalize

    # Use linear interpolation of the CDF to find new pixel values
    image_equalized = np.interp(image.flatten(), bin_edges[:-1], cdf)
    return image_equalized.reshape(image.shape)

# We will perform this later
# def resize_image(image, target_size=(640, 640), interpolation=cv2.INTER_LANCZOS4):
#     """Resize images to the target size and add a channel dimension if needed."""
#     resized_image = cv2.resize(image, target_size, interpolation=interpolation)
#     if len(resized_image.shape) == 2:  # if only height and width, no channels
#         resized_image = resized_image[:, :, np.newaxis]  # add a new axis for channels
#     return resized_image

def display_images(original_images, equalized_images):
    """Display original and equalized images side by side."""
    fig, axes = plt.subplots(2, len(original_images), figsize=(15, 5))

    for i in range(len(original_images)):
        axes[0, i].imshow(original_images[i], cmap='gray')
        axes[0, i].set_title(f'Original Image {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(equalized_images[i], cmap='gray')
        axes[1, i].set_title(f'Equalized Image {i+1}')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

# Extract just the images from paired_data
images = [img for img, _ in paired_data]

# Equalize all images in the batch individually
equalized_images = [histogram_equalization(img) for img in images]

# Randomly select three images and their equalized counterparts for display
random_indices = np.random.choice(len(images), min(3, len(images)), replace=False)
selected_original_images = [images[i] for i in random_indices]
selected_equalized_images = [equalized_images[i] for i in random_indices]

display_images(selected_original_images, selected_equalized_images)


Expected Output:
> This displays a series of images sampled from the dataset. The top row is a set of original images which are dark and underexposed. The bottom shows equalized images that are brighter than the originals so that many features are more visible.

## 2. ROI Annotation Preprocessing


### 2.1 ROI Category Consistency Check
The `roi_format` field of each annotation is examined to ensure consistent categorization across annotations. Here I check my assumptions about the format of the data.
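A per-index set of unique values cannot reveal an annotation whose `roi_format` is shorter than the others (for example, one that lacks `class`), because the set for that index still contains the value from the other annotations. Comparing each whole list against an expected layout catches this. A minimal sketch; `find_nonstandard_roi_formats` is a hypothetical helper, and its default `expected` tuple is an assumption based on the values printed below:

```python
def find_nonstandard_roi_formats(annotations,
                                 expected=("type", "left", "top", "width", "height", "class")):
    """Return the indices of annotations whose roi_format differs from `expected`."""
    mismatches = []
    for idx, annotation in enumerate(annotations):
        # Compare the full ordered list, so missing or extra fields are both detected
        roi_format = tuple(annotation["information"]["roi_format"])
        if roi_format != tuple(expected):
            mismatches.append(idx)
    return mismatches
```

Usage: `find_nonstandard_roi_formats([annotation for _, annotation in paired_data])`.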


In [None]:

def get_unique_values_for_roi_format(paired_data):
    # Initialize a dictionary where keys are indices and values are sets of unique values at that index
    unique_values = {}

    # Iterate over paired data to populate unique values dictionary
    for _, annotation in paired_data:
        roi_format = annotation['information']['roi_format']
        for idx, value in enumerate(roi_format):
            if idx not in unique_values:
                unique_values[idx] = set()
            unique_values[idx].add(value)

    return unique_values

unique_values_dict = get_unique_values_for_roi_format(paired_data)

# Display unique values for each index
for idx, unique_vals in unique_values_dict.items():
    print(f"Index {idx} has unique values: {', '.join(unique_vals)}")


Expected Output:

```
Index 0 has unique values: type
Index 1 has unique values: left
Index 2 has unique values: top
Index 3 has unique values: width
Index 4 has unique values: height
Index 5 has unique values: class
```



### 2.2 Initial ROI Annotation Conversion
Conversion from the JSON structure to a list of bounding box dimensions is performed for each ROI annotation. We assume that each annotation box was traced starting at its top-left corner and drawn to its bottom-right corner. If this is the case, then `left` and `top` give the top-left corner of each box, and `left + width` and `top + height` give its bottom-right corner.

In [None]:
def reformat_annotations(pairs):
    reformatted_pairs = []

    for image, annotation in pairs:
        rois = annotation['rois']
        roi_format = annotation['information']['roi_format']

        # Initialize a dictionary with keys from roi_format and empty list as values
        formatted_annotation = {key: [] for key in roi_format}

        # Fill the lists with corresponding roi values
        for roi in rois:
            for i, key in enumerate(roi_format):
                formatted_annotation[key].append(roi[i])

        # Add the reformatted annotation along with its image to the new list
        reformatted_pairs.append((image, formatted_annotation))

    return reformatted_pairs

### 2.3 Bounding Box Validity Check
When `clean=True`, bounding boxes that fall outside the known image dimensions are flagged and removed to clean the data.
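An alternative to discarding out-of-bounds boxes is to clip them to the image edges, which keeps boxes that only slightly overrun the border (the first image contains a box with `top` = -1). This is a different technique from the removal used below; a minimal sketch with a hypothetical `clip_boxes_to_image` helper working on pixel coordinates:

```python
import numpy as np

def clip_boxes_to_image(left, top, width, height, image_width, image_height):
    """Clip pixel boxes to the image instead of discarding them.

    Returns clipped (left, top, width, height) arrays and a boolean mask of boxes
    that still have a positive area after clipping.
    """
    left, top = np.asarray(left, float), np.asarray(top, float)
    # Compute the far edges first, then clip every edge into the image
    right = np.clip(left + np.asarray(width, float), 0, image_width)
    bottom = np.clip(top + np.asarray(height, float), 0, image_height)
    left = np.clip(left, 0, image_width)
    top = np.clip(top, 0, image_height)
    # Boxes lying entirely outside the image collapse to zero area
    keep = (right > left) & (bottom > top)
    return left, top, right - left, bottom - top, keep
```

For the box with `top` = -1 and `height` = 10, clipping gives `top` = 0 and `height` = 9, so almost all of the annotation survives. Boxes that lie entirely outside the image get `keep == False` and can then be dropped.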


In [None]:
def reformat_clean_annotations(pairs, clean=False):
    reformatted_pairs = []

    for index, (image, annotation) in enumerate(pairs):
        rois = annotation['rois']
        roi_format = annotation['information']['roi_format']
        image_height, image_width = image.shape[:2]  # Where 'image' is a NumPy array

        # We will only use the keys that are related to the ROI dimensions
        keys = ['type', 'left', 'top', 'width', 'height']
        formatted_annotation = {key: [] for key in keys}

        # Check each annotation for validity
        for i, roi in enumerate(rois):
            # Extract ROI data, ignoring the 'class' key if it's present.
            # We only have one class in our data so we can leave this off until
            # we set up the tensor data for the model.
            type_, left, top, width, height = roi[:5]

            if clean:
                # Skip impossible annotations: boxes whose top-left corner lies
                # before (0, 0), or whose extent runs past the image edges.
                if (left < 0 or top < 0
                        or left + width > image_width
                        or top + height > image_height):
                    continue  # Skip the impossible annotation

            # Append the valid ROI data to the formatted_annotation
            for j, key in enumerate(keys):
                formatted_annotation[key].append(roi[j])

        # Add the filtered and reformatted annotations to the list
        reformatted_pairs.append((image, formatted_annotation))

    return reformatted_pairs


# Apply both functions to the equalized images (note: clean defaults to False)
equalized_pairs = [(equalized_images[i], annotation) for i, (_, annotation) in enumerate(paired_data)]
temp_pairs = reformat_annotations(equalized_pairs)
cleaned_pairs = reformat_clean_annotations(equalized_pairs)

print(f"Length of first pair without removing out-of-range values: {len(temp_pairs[0][1]['left'])}")
print(f"Length of first pair with out-of-range values removed: {len(cleaned_pairs[0][1]['left'])}")

print(f"First pair without removing out-of-range values: {temp_pairs[0][1]}")
print(f"First pair with out-of-range values removed: {cleaned_pairs[0][1]}")

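Expected Output:
> Nothing is removed in this case because `reformat_clean_annotations` is called with its default `clean=False`; this cell only shows what is present in the data and how the function reformats it. Note that the third box has a `top` value of -1, which the bounds check would discard if `clean=True` were passed.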

```
Length of first pair without removing out-of-range values: 12
Length of first pair with out-of-range values removed: 12
First pair without removing out-of-range values: {'type': [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 'left': [417, 272, 484, 311, 243, 442, 503, 498, 291, 95, 168, 172], 'top': [415, 423, -1, 117, 288, 478, 246, 95, 194, 98, 139, 447], 'width': [14, 12, 16, 11, 17, 12, 8, 12, 16, 15, 13, 14], 'height': [11, 12, 10, 11, 10, 16, 13, 21, 26, 12, 20, 22]}
First pair with out-of-range values removed: {'type': [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 'left': [417, 272, 484, 311, 243, 442, 503, 498, 291, 95, 168, 172], 'top': [415, 423, -1, 117, 288, 478, 246, 95, 194, 98, 139, 447], 'width': [14, 12, 16, 11, 17, 12, 8, 12, 16, 15, 13, 14], 'height': [11, 12, 10, 11, 10, 16, 13, 21, 26, 12, 20, 22]}
```



### 2.4 ROI 'Class' Assumptions and Corrections
The 'class' field of the ROI annotations is standardized by assuming that all unlabelled objects are cells. Here we discover that some annotations are missing the 'class' label entirely, so we fill in the missing data with the label "cell" for each box drawn. There are no other label categories, so this is only done for completeness.

In [None]:
import pandas as pd

# isolate annotations
annotations = [annotation for _, annotation in temp_pairs]

# Create DataFrame
df_annotations = pd.DataFrame(annotations)

print("roi_formats with missing values")
# Check for roi_formats with missing values across all columns
print(df_annotations.isnull().sum())

Expected Outputs:

```
roi_formats with missing values
type       0
left       0
top        0
width      0
height     0
class     39
dtype: int64
```



In [None]:
print("unique values for the 'class' roi_format")

# Replace NaN with empty lists
df_annotations['class'] = df_annotations['class'].apply(lambda x: x if isinstance(x, list) else [])

# Flatten all lists and remove empty fields
all_classes = [element for sublist in df_annotations['class'] if sublist for element in sublist]

# Get the unique values
unique_classes = list(set(all_classes))

# Print the list of unique classes. There was definitely an easier way to do
# this: (append all the class lists to a single list and then convert to set).
print(unique_classes)

Expected Output:

```
unique values for the 'class' roi_format
['cell']
```
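Since 'cell' is the only class present, the filling step described at the start of this section can be sketched as follows. This assumes each annotation is a dict of equal-length lists, as produced by `reformat_annotations`; `fill_missing_classes` is a hypothetical helper:

```python
def fill_missing_classes(annotations, default_class="cell"):
    """Give every box a class label, padding missing labels with default_class."""
    filled = []
    for annotation in annotations:
        n_boxes = len(annotation["left"])
        # A missing 'class' key (or an empty list) becomes an empty list of labels
        classes = list(annotation.get("class") or [])
        # Pad so there is exactly one label per box; existing labels are kept
        classes += [default_class] * (n_boxes - len(classes))
        # Build a new dict so the input annotations are not modified
        filled.append({**annotation, "class": classes})
    return filled
```

Usage: `fill_missing_classes([annotation for _, annotation in temp_pairs])`.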



### 2.5 ROI Format Conversion and Validation
ROI measurements are converted from the JSON format to the normalized [ymin, xmin, ymax, xmax] format used by the model, and the conversion is validated through manual inspection. The following link is useful for comparing bounding box formats. Technically, the format adopted below is the "albumentations" format with the x and y coordinates swapped: both use normalized minimum and maximum coordinates.

https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco
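The JSON boxes store `left, top, width, height` in pixels, the same layout as the COCO format's [x_min, y_min, width, height]. As a worked example of the conversion to normalized [y_min, x_min, y_max, x_max], here is a sketch with hypothetical helper names and an inverse that can serve as a round-trip check:

```python
import numpy as np

def coco_to_tf_boxes(boxes_xywh, image_height, image_width):
    """Convert pixel [x_min, y_min, width, height] boxes to normalized [y_min, x_min, y_max, x_max]."""
    b = np.asarray(boxes_xywh, dtype=np.float32)
    x_min, y_min, w, h = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
    return np.stack([y_min / image_height, x_min / image_width,
                     (y_min + h) / image_height, (x_min + w) / image_width], axis=1)

def tf_to_coco_boxes(boxes_tf, image_height, image_width):
    """Inverse of coco_to_tf_boxes: back to pixel [x_min, y_min, width, height]."""
    b = np.asarray(boxes_tf, dtype=np.float32)
    y_min, x_min = b[:, 0] * image_height, b[:, 1] * image_width
    y_max, x_max = b[:, 2] * image_height, b[:, 3] * image_width
    return np.stack([x_min, y_min, x_max - x_min, y_max - y_min], axis=1)
```

For the first box of the first image (`left` = 417, `top` = 415, `width` = 14, `height` = 11) on a 512 x 512 image, `coco_to_tf_boxes([[417, 415, 14, 11]], 512, 512)` gives [415/512, 417/512, 426/512, 431/512], roughly [0.8105, 0.8145, 0.8320, 0.8418], which matches the first row of the output below. A box with `top` = -1 converts to a slightly negative `y_min` of -1/512 (about -0.00195).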

In [None]:
import numpy as np

def convert_rois_to_cartesian(roi_pairs):
    image_array, roi_data = roi_pairs
    image_height, image_width, _ = image_array.shape

    # Create a list to hold the coordinates for bounding boxes
    box_coordinates = []
    for i in range(len(roi_data['top'])):
        # Convert to [y_min, x_min, y_max, x_max] and normalize by the image size
        y_min = roi_data['top'][i] / image_height
        x_min = roi_data['left'][i] / image_width
        y_max = (roi_data['top'][i] + roi_data['height'][i]) / image_height
        x_max = (roi_data['left'][i] + roi_data['width'][i]) / image_width
        # Append the box coordinates as a list
        box_coordinates.append([y_min, x_min, y_max, x_max])

    # Convert the list of box coordinates to a single numpy array
    # (reshape keeps the shape (N, 4) even when an image has no boxes)
    numpy_boxes = np.array(box_coordinates, dtype=np.float32).reshape(-1, 4)
    # Return a new tuple with the image array and the numpy array of box coordinates
    return (image_array, numpy_boxes)

cartesian_pair = convert_rois_to_cartesian(cleaned_pairs[0])
cartesian_pair[1]

Expected Output:

```
array([[ 0.8105469 ,  0.8144531 ,  0.83203125,  0.8417969 ],
       [ 0.8261719 ,  0.53125   ,  0.8496094 ,  0.5546875 ],
       [-0.00195312,  0.9453125 ,  0.01757812,  0.9765625 ],
       [ 0.22851562,  0.6074219 ,  0.25      ,  0.62890625],
       [ 0.5625    ,  0.47460938,  0.58203125,  0.5078125 ],
       [ 0.93359375,  0.86328125,  0.96484375,  0.88671875],
       [ 0.48046875,  0.9824219 ,  0.5058594 ,  0.9980469 ],
       [ 0.18554688,  0.97265625,  0.2265625 ,  0.99609375],
       [ 0.37890625,  0.5683594 ,  0.4296875 ,  0.5996094 ],
       [ 0.19140625,  0.18554688,  0.21484375,  0.21484375],
       [ 0.27148438,  0.328125  ,  0.31054688,  0.35351562],
       [ 0.8730469 ,  0.3359375 ,  0.9160156 ,  0.36328125]],
      dtype=float32)
```



In [None]:
cleaned_pairs[0][0].shape

Expected Output:


```
(512, 512, 1)
```



In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

def draw_boxes_on_image(image_array, boxes, is_converted=False):
    # Image preprocessing to ensure it's in the correct format
    if image_array.dtype != np.uint8:
        image_array = (255 * (image_array - image_array.min()) / (image_array.max() - image_array.min())).astype(np.uint8)
    if image_array.shape[-1] == 1:
        image_array = image_array.squeeze(-1)
    if image_array.ndim == 2:
        image = Image.fromarray(image_array, 'L').convert('RGB')
    else:
        image = Image.fromarray(image_array)

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    # Define the loop range based on whether boxes are converted or not
    loop_range = len(boxes) if is_converted else len(boxes['type'])

    for index in range(loop_range):
        # Determine the coordinates based on whether the boxes are converted
        if is_converted:
            # Converted boxes are normalized [y_min, x_min, y_max, x_max], while
            # ImageDraw expects pixel [left, top, right, bottom], so reorder and rescale.
            box = boxes[index]
            left, top, right, bottom = box[1] * image.width, box[0] * image.height, box[3] * image.width, box[2] * image.height
        else:
            left, top, width, height = boxes['left'][index], boxes['top'][index], boxes['width'][index], boxes['height'][index]
            right, bottom = left + width, top + height

        # Draw the rectangle with green color
        draw.rectangle([left, top, right, bottom], outline=(0, 255, 0), width=2)

        # Draw the index number above the box
        text_position = (left, top - 10 - 2)
        draw.text(text_position, str(index), fill=(0, 255, 0), font=font)

    return image

def validate_coordinate_conversion(roi_pair, cartesian_pair):
    # Extract the image and original ROIs from roi_pair
    image_array, rois = roi_pair

    # Extract the converted ROIs
    _, numpy_boxes = cartesian_pair

    # Draw the original ROIs on the image
    image_with_original_rois = draw_boxes_on_image(np.copy(image_array), rois, is_converted=False)

    # Draw the converted ROIs on the image
    image_with_converted_rois = draw_boxes_on_image(np.copy(image_array), numpy_boxes, is_converted=True)

    # Display both images side by side using matplotlib for comparison

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # Convert the PIL images back to numpy arrays for displaying
    image_with_original_rois_np = np.array(image_with_original_rois)
    image_with_converted_rois_np = np.array(image_with_converted_rois)

    # Display the images without cmap='gray' to ensure the green boxes are visible
    ax1.imshow(image_with_original_rois_np)
    ax1.set_title('Original ROI Annotation Format')

    ax2.imshow(image_with_converted_rois_np)
    ax2.set_title('Converted ROI Annotation Format')

    plt.show()

# Now we can compare the original and converted rois visually
validate_coordinate_conversion(cleaned_pairs[0], cartesian_pair)
cleaned_pairs[0][1]

Expected Output:
> This displays the same image twice, with annotation boxes plotted on each. The left image uses boxes in the original annotation format; the right image uses boxes in the converted format. The boxes appear in the same locations on both images, confirming that the conversion did not alter or skew the box geometry. This matters because the model expects boxes in this normalized [ymin, xmin, ymax, xmax] format, so the converted annotations can be passed to it directly.

```
{'type': [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
 'left': [417, 272, 484, 311, 243, 442, 503, 498, 291, 95, 168, 172],
 'top': [415, 423, -1, 117, 288, 478, 246, 95, 194, 98, 139, 447],
 'width': [14, 12, 16, 11, 17, 12, 8, 12, 16, 15, 13, 14],
 'height': [11, 12, 10, 11, 10, 16, 13, 21, 26, 12, 20, 22]}
```



Since converting from ROI to Cartesian coordinates produces the same plotted bounding boxes, we move forward and convert the rest of the annotations. We inspect one arbitrary index (50) as a spot check of quality. We then convert the images back to three-channel RGB.
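The conversion to three channels can also be done without rescaling each image by its own minimum and maximum, which is what `convert_monochrome_to_rgb` below does. A sketch, assuming intensities are already on a 0 to 255 scale (as the equalized images are); `to_three_channel_uint8` is a hypothetical helper:

```python
import numpy as np

def to_three_channel_uint8(image):
    """Convert an (H, W) or (H, W, 1) image with values on a 0-255 scale to (H, W, 3) uint8."""
    image = np.asarray(image, dtype=np.float64)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]  # (H, W, 1) -> (H, W)
    # Round to the nearest integer and clip into the valid 8-bit range before casting
    image = np.clip(np.rint(image), 0, 255).astype(np.uint8)
    if image.ndim == 2:
        # Copy the single gray channel into R, G and B
        image = np.repeat(image[..., np.newaxis], 3, axis=-1)
    return image
```

Keeping absolute intensities matters most when images share a batch-level CDF, because a per-image min-max stretch could partly undo that shared mapping.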

In [None]:
cartesian_pairs = [convert_rois_to_cartesian(pair) for pair in cleaned_pairs]
cartesian_pairs[50]

Expected Output:


```
(array([[[  6.43861771],
         [  2.35977374],
         [  0.78339438],
         ...,
         [242.6128975 ],
         [243.57959862],
         [222.61333466]],

        [[ 16.1225616 ],
         [  6.43861771],
         [  1.04764938],
         ...,
         [168.91843805],
         [232.33611514],
         [110.85475998]],

        [[ 47.35530853],
         [ 43.20379689],
         [ 26.05213165],
         ...,
         [176.28017876],
         [238.39163152],
         [110.85475998]],

        ...,

        [[ 16.1225616 ],
         [  6.43861771],
         [  6.43861771],
         ...,
         [187.81331623],
         [ 79.38983917],
         [150.06483078]],

        [[ 43.20379689],
         [ 10.13463897],
         [  5.01904138],
         ...,
         [150.06483078],
         [239.98952866],
         [246.46819522]],

        [[224.44064874],
         [ 76.22279718],
         [ 10.13463897],
         ...,
         [ 43.20379689],
         [196.8074304 ],
         [ 23.54496462]]]),
 array([[0.04296875, 0.8886719 , 0.08984375, 0.91796875],
        [0.19140625, 0.84375   , 0.22265625, 0.8730469 ],
        [0.03515625, 0.5       , 0.06835938, 0.5253906 ],
        [0.7519531 , 0.671875  , 0.79296875, 0.70703125],
        [0.06640625, 0.1796875 , 0.09570312, 0.20507812],
        [0.9140625 , 0.41210938, 0.9433594 , 0.44140625],
        [0.29882812, 0.51171875, 0.34179688, 0.55078125],
        [0.03320312, 0.32617188, 0.08984375, 0.36523438],
        [0.43359375, 0.1953125 , 0.45507812, 0.21679688],
        [0.40625   , 0.5957031 , 0.44335938, 0.62890625],
        [0.390625  , 0.36523438, 0.43359375, 0.40039062],
        [0.24414062, 0.8105469 , 0.28125   , 0.8417969 ],
        [0.7988281 , 0.25390625, 0.8339844 , 0.27734375],
        [0.19726562, 0.33398438, 0.25195312, 0.359375  ],
        [0.53515625, 0.12304688, 0.57421875, 0.1484375 ],
        [0.53125   , 0.921875  , 0.5703125 , 0.9550781 ],
        [0.171875  , 0.00976562, 0.20898438, 0.04296875],
        [0.46289062, 0.20703125, 0.5253906 , 0.25      ],
        [0.13476562, 0.38867188, 0.16796875, 0.41796875],
        [0.0078125 , 0.8105469 , 0.05273438, 0.8417969 ],
        [0.8417969 , 0.2421875 , 0.8828125 , 0.27929688],
        [0.13476562, 0.484375  , 0.16601562, 0.5097656 ],
        [0.82421875, 0.03710938, 0.8613281 , 0.07226562],
        [0.40039062, 0.17578125, 0.43164062, 0.19726562],
        [0.24023438, 0.171875  , 0.2734375 , 0.1953125 ],
        [0.01171875, 0.9277344 , 0.05078125, 0.95703125],
        [0.29101562, 0.5449219 , 0.3359375 , 0.5761719 ],
        [0.31054688, 0.609375  , 0.34375   , 0.63671875],
        [0.00390625, 0.23046875, 0.03320312, 0.25585938],
        [0.32617188, 0.6699219 , 0.34960938, 0.6875    ],
        [0.8847656 , 0.29882812, 0.9433594 , 0.32617188],
        [0.56640625, 0.6152344 , 0.5996094 , 0.640625  ],
        [0.31054688, 0.11523438, 0.36523438, 0.14453125],
        [0.46875   , 0.35742188, 0.5175781 , 0.38671875],
        [0.52734375, 0.04101562, 0.5625    , 0.07617188],
        [0.6777344 , 0.4921875 , 0.71875   , 0.5292969 ]], dtype=float32))
```



In [None]:
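def convert_monochrome_to_rgb(image_array):
    """Convert a grayscale image of shape (H, W) or (H, W, 1) to a uint8 RGB image (H, W, 3)."""
    if image_array.ndim == 3 and image_array.shape[-1] == 3:
        print("Already RGB!")
        return image_array

    # Rescale non-uint8 (e.g. float) images to the 0-255 range and cast to uint8.
    # A constant image (max == min) would otherwise cause a division by zero.
    if image_array.dtype != np.uint8:
        value_range = image_array.max() - image_array.min()
        if value_range == 0:
            image_array = np.zeros_like(image_array, dtype=np.uint8)
        else:
            image_array = (255 * (image_array - image_array.min()) / value_range).astype(np.uint8)

    # If the image array has a shape of (H, W, 1), squeeze it to (H, W)
    if image_array.ndim == 3 and image_array.shape[-1] == 1:
        image_array = image_array.squeeze(-1)

    # Convert grayscale (H, W) to RGB (H, W, 3) by stacking the image three times along a new last axis
    if image_array.ndim == 2:
        image_array = np.stack((image_array,) * 3, axis=-1)

    return image_array

# Replace every image with its RGB version, keeping its annotation alongside it
for i, (image_array, annotation) in enumerate(cartesian_pairs):
    rgb_image_array = convert_monochrome_to_rgb(image_array)
    cartesian_pairs[i] = (rgb_image_array, annotation)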


## 3. Model Preparation and Data Augmentation

The following code is heavily modified from the "Eager Few Shot Object Detection Colab" tutorial, which is part of the TensorFlow Object Detection API. The main difference is that the tutorial detects a single object within an image, while I have adapted the model to find multiple objects within each image. I have also changed the optimizer, added an IoU-based validation metric, added training checkpoints, and written some tooling for evaluating model performance holistically.

https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb


### 3.1 Cloning TensorFlow Model Zoo and Setup
The TensorFlow Models repository, which contains the Object Detection API, is cloned. The API's `.proto` files are then compiled with `protoc` and the package is installed.


In [None]:
!pip install -U --pre tensorflow=="2.2.0"


In [None]:
import os
import pathlib

# If we are already inside the cloned models directory, move back up to its parent;
# otherwise, clone the tensorflow/models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

In [None]:
%%bash
# Set up the TensorFlow Object Detection API: compile the .proto files with protoc, then install the package
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .

In [None]:
import matplotlib
import matplotlib.pyplot as plt

import os
import random
import io
import imageio
import glob
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import ops as utils_ops
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.builders import model_builder

%matplotlib inline

### 3.2 Define utility functions


In [None]:
def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: a file path.

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  img_data = tf.io.gfile.GFile(path, 'rb').read()
  image = Image.open(BytesIO(img_data))
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None,
                    min_score_thresh=0.8):
  """Wrapper function to visualize detections.

  Args:
    image_np: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None.  If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plot all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
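    figsize: size for the figure (not used by this wrapper; set the figure
      size on the calling side instead).
    image_name: a name for the image file. If given, the annotated image is
      saved to this file instead of being displayed.
    min_score_thresh: boxes whose score is below this threshold are not drawn.
  """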
  image_np_with_annotations = image_np.copy()
  viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_annotations,
      boxes,
      classes,
      scores,
      category_index,
      use_normalized_coordinates=True,
      min_score_thresh=min_score_thresh)
  if image_name:
    plt.imsave(image_name, image_np_with_annotations)
  else:
    plt.imshow(image_np_with_annotations)


### 3.3 Image Resizing for Model Input
Images are resized to 640x640 with `cv2.INTER_LANCZOS4` interpolation to match the model's input size. Because our source images are 512x512, this is an upscale. `cv2.INTER_LANCZOS4` performs Lanczos interpolation over an 8x8 pixel neighborhood; it is one of the highest-quality (and more computationally expensive) resampling methods OpenCV provides, and it tends to preserve fine detail when upscaling.
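Because the annotations are stored in normalized `[ymin, xmin, ymax, xmax]` coordinates, resizing an image does not change its boxes; only their pixel equivalents scale, by a factor of 640 / 512 = 1.25 here. The sketch below illustrates this with a hypothetical helper, `to_pixel_coords`, and a made-up example box; neither is part of the original pipeline.

```python
import numpy as np

def to_pixel_coords(boxes, height, width):
    """Convert normalized [ymin, xmin, ymax, xmax] boxes to pixel coordinates (hypothetical helper)."""
    scale = np.array([height, width, height, width], dtype=np.float64)
    return np.asarray(boxes, dtype=np.float64) * scale

# One hypothetical box in normalized coordinates; the same values apply before and after resizing
box = np.array([[0.25, 0.5, 0.375, 0.625]])

pixels_512 = to_pixel_coords(box, 512, 512)  # position on the original 512x512 image
pixels_640 = to_pixel_coords(box, 640, 640)  # position on the resized 640x640 image

print(pixels_512)  # [[128. 256. 192. 320.]]
print(pixels_640)  # [[160. 320. 240. 400.]]
```

This is why `resize_image_and_boxes` can return the boxes untouched: it only holds as long as the boxes stay normalized.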




In [None]:
import cv2

def resize_image_and_boxes(image, boxes, target_size=(640, 640)):
    '''Resize the image to target_size. The boxes are returned unchanged, which assumes their coordinates are already normalized to [0, 1].'''

    # Resize image
    resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_LANCZOS4)

    # Add channel dimension if missing
    if len(resized_image.shape) == 2:
        resized_image = resized_image[..., np.newaxis]

    return resized_image, boxes

# Resize the images and boxes to use the 640x640 model input shape
resized_data = [resize_image_and_boxes(image, boxes) for image, boxes in cartesian_pairs]


### 3.4 Data Split into Training and Testing Sets
The dataset is shuffled and divided into 70% training data and 30% validation data (the validation images are stored in the `test_*` lists below).


In [None]:
import numpy as np
def train_validation_split(data, train_ratio=0.70):
    """
    Splits the data into training and validation sets based on the specified training ratio.
    """
    np.random.shuffle(data)
    train_size = int(len(data) * train_ratio)
    training_data = data[:train_size]
    validation_data = data[train_size:]
    return training_data, validation_data

# Split the dataset
training_data, validation_data = train_validation_split(resized_data)
print(f"Training data size: {len(training_data)}")
print(f"Validation data size: {len(validation_data)}")

# Initialize lists to hold training and test data
train_images_np, train_gt_boxes = [], []
test_images_np, test_gt_boxes = [], []

# Load the training images and annotations
for images, annotations in training_data:
    train_images_np.append(images)
    train_gt_boxes.append(annotations)

# Load the validation images and annotations
for images, annotations in validation_data:
    test_images_np.append(images)
    test_gt_boxes.append(annotations)
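The split above calls `np.random.shuffle` without a seed, so each run produces a different train/validation split and results are hard to compare between runs. One option, sketched below under the assumption that a fixed seed is acceptable, is to shuffle indices with a seeded `numpy.random.Generator`; `seeded_train_validation_split` and its `seed` parameter are hypothetical and not part of the original code.

```python
import numpy as np

def seeded_train_validation_split(data, train_ratio=0.70, seed=42):
    """Reproducibly split a list into training and validation parts (hypothetical helper).

    Unlike shuffling `data` in place, this leaves the input list untouched and
    returns the same split every time for a given seed.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(data))            # shuffled indices 0..len(data)-1
    train_size = int(len(data) * train_ratio)
    training_data = [data[i] for i in order[:train_size]]
    validation_data = [data[i] for i in order[train_size:]]
    return training_data, validation_data

# Usage with a toy dataset: the same seed gives the same split on every call
items = list(range(10))
train_a, val_a = seeded_train_validation_split(items, seed=0)
train_b, val_b = seeded_train_validation_split(items, seed=0)
print(train_a == train_b and val_a == val_b)  # True
```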

### 3.5 Manual Inspection of Formatted Images
A sample of the preprocessed images is inspected visually to confirm that resizing and RGB conversion were applied consistently.

In [None]:
import matplotlib.pyplot as plt

plt.rcParams['axes.grid'] = False
plt.rcParams['xtick.labelsize'] = False
plt.rcParams['ytick.labelsize'] = False
plt.rcParams['xtick.top'] = False
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.right'] = False
plt.rcParams['figure.figsize'] = [14, 7]

print("Shape of resized image tensors:", train_images_np[0].shape)

for idx, train_image_np in enumerate(train_images_np[0:6]):
  plt.subplot(2, 3, idx+1)
  plt.imshow(train_image_np)
print("Sample of training images")
plt.show()

print("Sample of annotation data for the first image, aka 'ground truth boxes'")
print(train_gt_boxes[0])


### 3.6 Training Data Box Statistics Generation
Statistics for the training-data bounding boxes are generated to serve as a reference when evaluating on the test data. All values are based on normalized coordinates.

#### Average Distance Between Boxes
- `average_distance_between_boxes`: The mean Euclidean distance between the centers of every pair of bounding boxes within the same image, pooled across all training images. This statistic describes how closely objects are spaced.

#### Count per Image
- `box_count_per_image`: The number of annotated bounding boxes in each image. This list shows how many labelled objects a typical image contains.

#### Size Statistics
- `box_size_mean`: The mean size of the bounding boxes, given as (width, height). It indicates the typical dimensions of the annotated objects.
- `box_size_median`: The median size of the bounding boxes, which is less sensitive to outliers than the mean.
- `box_size_std`: The standard deviation of the box sizes, indicating how much object size varies.

#### Aspect Ratio Statistics
- `box_aspect_ratio_mean`: The mean aspect ratio (height divided by width) of all bounding boxes, which describes the typical object shape.
- `box_aspect_ratio_median`: The median aspect ratio, again less affected by outliers.
- `box_aspect_ratio_std`: The standard deviation of the aspect ratios, showing how much object shape varies.

#### Position Statistics
- `box_position_mean`: The mean position of the bounding-box centers, given as (y, x). This reflects where objects tend to be centered in the image.
- `box_position_std`: The standard deviation of the center positions, indicating how spread out the objects are across the image.

#### Edge Proximity Statistics
- `box_edge_proximity_mean`: For each box, the distance from its closest side to the nearest image edge, min(xmin, ymin, 1 - xmax, 1 - ymax), averaged over all boxes. Lower values mean objects tend to sit closer to the image border.
- `box_edge_proximity_std`: The standard deviation of those edge distances, showing how much they vary between boxes.
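As a worked example of these definitions, the sketch below applies the same arithmetic used in `calculate_statistics` to two made-up boxes in normalized `[ymin, xmin, ymax, xmax]` order. The box values are hypothetical and chosen only so the results are easy to check by hand.

```python
import numpy as np

# Two hypothetical boxes in normalized [ymin, xmin, ymax, xmax] order
boxes = np.array([[0.1, 0.2, 0.3, 0.6],
                  [0.5, 0.5, 0.7, 0.7]])

# Box centers as (y, x)
centers = np.c_[(boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2]
# Box sizes as (width, height)
sizes = np.c_[boxes[:, 3] - boxes[:, 1], boxes[:, 2] - boxes[:, 0]]
# Aspect ratio = height / width
aspect_ratios = sizes[:, 1] / sizes[:, 0]
# Distance from each box's closest side to the nearest image edge
edge_proximity = np.minimum.reduce([boxes[:, 1], boxes[:, 0], 1 - boxes[:, 3], 1 - boxes[:, 2]])
# Euclidean distance between the two box centers
center_distance = np.linalg.norm(centers[0] - centers[1])

print(centers, sizes, aspect_ratios, edge_proximity, center_distance, sep="\n")
```

By hand: the centers are (0.2, 0.4) and (0.6, 0.6), the sizes are (0.4, 0.2) and (0.2, 0.2), so the aspect ratios are 0.5 and 1.0; the edge proximities are 0.1 (the first box's top edge) and 0.3; and the center distance is sqrt(0.4² + 0.2²) = sqrt(0.2) ≈ 0.447.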


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def calculate_statistics(train_gt_boxes):
    # Assumes train_gt_boxes is a list of Numpy arrays with the shape (N, 4)
    # Each sub-array contains bounding boxes in the form [ymin, xmin, ymax, xmax]
    all_distances = []
    all_sizes = []
    all_aspect_ratios = []
    all_positions = []
    all_edge_proximities = []

    for boxes in train_gt_boxes:
        # Calculate center positions
        centers = np.c_[ (boxes[:, 2] + boxes[:, 0]) / 2,
                         (boxes[:, 3] + boxes[:, 1]) / 2 ]
        all_positions.append(centers)

        # Calculate sizes (width, height)
        sizes = np.c_[ boxes[:, 3] - boxes[:, 1],
                       boxes[:, 2] - boxes[:, 0] ]
        all_sizes.append(sizes)

        # Calculate aspect ratios
        aspect_ratios = sizes[:, 1] / sizes[:, 0]
        all_aspect_ratios.extend(aspect_ratios)

        # Calculate edge proximity
        edge_proximity = np.minimum.reduce([boxes[:, 1], boxes[:, 0], 1 - boxes[:, 3], 1 - boxes[:, 2]])
        all_edge_proximities.extend(edge_proximity)

        # Calculate distances between each pair of boxes
        if len(centers) > 1:
            distances = np.sqrt(np.sum((centers[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2, axis=2))
            # We only want the upper triangle without the diagonal, since the matrix is symmetric and the diagonal is 0
            triu_indices = np.triu_indices_from(distances, k=1)
            all_distances.extend(distances[triu_indices])

    # Convert to Numpy arrays for easier manipulation
    all_distances = np.array(all_distances)
    all_sizes = np.concatenate(all_sizes, axis=0)
    all_aspect_ratios = np.array(all_aspect_ratios)
    all_edge_proximities = np.array(all_edge_proximities)

    # Calculate the statistics
    stats = {
        'average_distance_between_boxes': np.mean(all_distances) if len(all_distances) else None,
        'box_count_per_image': [len(boxes) for boxes in train_gt_boxes],
        'box_size_mean': np.mean(all_sizes, axis=0),
        'box_size_median': np.median(all_sizes, axis=0),
        'box_size_std': np.std(all_sizes, axis=0),
        'box_aspect_ratio_mean': np.mean(all_aspect_ratios),
        'box_aspect_ratio_median': np.median(all_aspect_ratios),
        'box_aspect_ratio_std': np.std(all_aspect_ratios),
        'box_position_mean': np.mean(np.concatenate(all_positions, axis=0), axis=0),
        'box_position_std': np.std(np.concatenate(all_positions, axis=0), axis=0),
        'box_edge_proximity_mean': np.mean(all_edge_proximities),
        'box_edge_proximity_std': np.std(all_edge_proximities)
    }

    return stats

def plot_count_histogram(counts):
    '''Plot a histogram of the number of bounding boxes per image.'''
    # Axis numbers may not render correctly in this notebook, likely because
    # some of the TensorFlow visualization tools modify the notebook's HTML/CSS.
    # I included this plot as a general reference, but a production
    # implementation should be able to render axis numbers correctly.

    plt.figure(figsize=(10, 6))
    n, bins, patches = plt.hist(counts, bins=20, color='skyblue', edgecolor='black')

    # Set the title and labels
    plt.title('Count of Bounding Boxes per Image')
    plt.xlabel('Count')
    plt.ylabel('Frequency')

    # Place the x-ticks at the bin centers and label them with the center values
    bin_centers = 0.5 * (bins[1:] + bins[:-1])
    plt.xticks(bin_centers, [f'{center:.2f}' for center in bin_centers])
    # One y-tick per integer frequency, up to the tallest bar
    plt.yticks(range(int(max(n) + 1)))

    plt.show()

print("Statistics based upon normalized coordinates.")
stats = calculate_statistics(train_gt_boxes)
for key, value in stats.items():
  print(f"{key}: {value}")
plot_count_histogram(stats['box_count_per_image'])

## 4. Model Configuration and Training


### 4.1 Creation of Single Class Category Index
A category index for the single class present in the images is created.


### 4.2 Training Data Conversion to Tensors
The training images and ground-truth boxes are converted into tensors, and the class labels are one-hot encoded.


In [None]:
# By convention, our non-background classes start counting at 1.  Given
# that we will be predicting just one class, we will therefore assign it a
# `class id` of 1.
cell_class_id = 1
category_index = {cell_class_id: {'id': cell_class_id, 'name': 'cell'}}
num_classes = 1

# Convert class labels to one-hot; convert everything to tensors.
# The `label_id_offset` here shifts all classes by a certain number of indices;
# we do this here so that the model receives one-hot labels where non-background
# classes start counting at the zeroth index.  This is ordinarily just handled
# automatically in our training binaries, but we need to reproduce it here.
label_id_offset = 1
train_image_tensors = []

# lists containing the one-hot encoded classes and ground truth boxes
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(train_images_np, train_gt_boxes):

  # convert training image to tensor, add batch dimension, and add to list
  train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
      train_image_np, dtype=tf.float32), axis=0))

  # convert numpy array to tensor, then add to list
  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))

  # apply the offset to get zero-indexed ground truth classes
  zero_indexed_groundtruth_classes = tf.convert_to_tensor(
      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)

  # do one-hot encoding to ground truth classes
  gt_classes_one_hot_tensors.append(tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')
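To make the label offset concrete: every annotation has class id 1 (`cell`); subtracting `label_id_offset = 1` gives index 0, and one-hot encoding with `num_classes = 1` turns each box's label into `[1.0]`. The NumPy sketch below reproduces that arithmetic for a hypothetical image with three boxes, using `np.eye` as a stand-in for `tf.one_hot` so it runs without TensorFlow.

```python
import numpy as np

num_classes = 1
label_id_offset = 1
num_boxes = 3  # hypothetical number of boxes in one image

# Every box is labelled with the 1-based class id 1 ("cell")
class_ids = np.ones(num_boxes, dtype=np.int32)

# Shift to 0-based indices, as the model expects for non-background classes
zero_indexed = class_ids - label_id_offset

# Row i of the identity matrix is the one-hot vector for class index i
one_hot = np.eye(num_classes, dtype=np.float32)[zero_indexed]

print(one_hot.shape)  # (3, 1)
print(one_hot)
```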

### 4.3 Annotation Plotting, Inspection, and Optional Augmentation
Annotations are plotted so that we can manually check that the boxes line up with the objects. The data is shown here in an un-augmented state, with the option to introduce augmentations after this section of code. In many images the annotations are quite "fuzzy" and do not tightly bound the object. We also see many objects that look like cells but are not annotated, with little indication of why they were excluded. I believe this confounds the model's learning.

In [None]:
plt.figure(figsize=(30, 15))
indices = np.random.choice(len(train_images_np), 6, replace=False)
print(f"Random indices selected: {indices}")
temp_images = [train_images_np[i] for i in indices]
temp_train_boxes = [train_gt_boxes[i] for i in indices]

# Now you can loop over the temporary lists to display the images
for idx in range(len(temp_images)):
    plt.subplot(2, 3, idx + 1)

    # The number of boxes for the current image
    num_boxes = temp_train_boxes[idx].shape[0]

    # Create an array of ones with a length equal to the number of boxes
    # This assumes each box is to be given a score of 1.0 (or 100%)
    dummy_scores = np.ones(shape=[num_boxes], dtype=np.float32)

    plot_detections(
        temp_images[idx],
        temp_train_boxes[idx],
        np.ones(shape=[num_boxes], dtype=np.int32),  # class IDs for each box
        dummy_scores,  # scores for each box
        category_index)
plt.show()


Expected Output:
> This displays six images with their annotations rendered. The annotations are boxes drawn over the cells in each image. Each box has an object label and a score above it; because every ground-truth box is given a dummy score of 1.0, every label reads "cell: 100%".



```
Random indices selected: [57 33  6 11  4 51]
```



#### Optional Image Augmentation
Here we define data augmentation methods and visualize the augmented data to check that the bounding boxes still line up with the objects. This code is a work in progress and does not currently work. I did not prioritize it because the quality of our input data was causing more significant problems for training.

In [None]:
# # Function to flip an image horizontally using TensorFlow
# def tf_flip_image_horizontal(image):
#     return tf.image.flip_left_right(image)

# # Function to rotate an image by 90 degrees clockwise using TensorFlow
# def tf_rotate_image_90(image):
#     return tf.image.rot90(image, k=3)  # k=3 for 90 degrees clockwise rotation

# # Function to adjust bounding boxes for horizontal flip using TensorFlow
# def tf_adjust_boxes_for_flip(boxes, image_width):
#     new_boxes = tf.stack([boxes[:, 0], image_width - boxes[:, 3],
#                           boxes[:, 2], image_width - boxes[:, 1]], axis=1)
#     return new_boxes

# # Function to adjust bounding boxes for 90 degree rotation using TensorFlow
# def tf_adjust_boxes_for_rotation(boxes, image_height):
#     new_boxes = tf.stack([boxes[:, 1], image_height - boxes[:, 2],
#                           boxes[:, 3], image_height - boxes[:, 0]], axis=1)
#     return new_boxes

# # Define a function to apply augmentations
# def apply_augmentations(image, boxes):
#     # Apply horizontal flip
#     flipped_image = tf_flip_image_horizontal(image)
#     flipped_boxes = tf_adjust_boxes_for_flip(boxes, image.shape[1])

#     # Apply 90 degree rotation
#     rotated_image = tf_rotate_image_90(flipped_image)
#     rotated_boxes = tf_adjust_boxes_for_rotation(flipped_boxes, image.shape[0])

#     return rotated_image, rotated_boxes

# # Now augment and visualize some images
# plt.figure(figsize=(30, 15))
# indices = np.random.choice(len(train_images_np), 6, replace=False)

# for idx, image_idx in enumerate(indices):
#     image_np = train_images_np[image_idx]
#     gt_boxes = train_gt_boxes[image_idx]

#     # Apply augmentations
#     augmented_image_np, augmented_boxes = apply_augmentations(image_np, gt_boxes)

#     # Convert to numpy arrays if they are tensors
#     if isinstance(augmented_image_np, tf.Tensor):
#         augmented_image_np = augmented_image_np.numpy()
#     if isinstance(augmented_boxes, tf.Tensor):
#         augmented_boxes = augmented_boxes.numpy()

#     # Normalize box coordinates
#     image_height, image_width, _ = augmented_image_np.shape
#     normalized_boxes = augmented_boxes.copy()
#     normalized_boxes[:, [0, 2]] /= image_height  # y coordinates
#     normalized_boxes[:, [1, 3]] /= image_width   # x coordinates

#     # Check if the boxes are normalized between 0 and 1
#     assert normalized_boxes.min() >= 0.0 and normalized_boxes.max() <= 1.0, \
#         "Box coordinates are not normalized properly."

#     plt.subplot(2, 3, idx + 1)

#     num_boxes = len(normalized_boxes)
#     dummy_scores = np.ones(shape=[num_boxes], dtype=np.float32)  # Dummy scores for visualization

#     # Visualize the augmented image and boxes
#     plot_detections(
#         augmented_image_np,
#         normalized_boxes,
#         np.ones(shape=[num_boxes], dtype=np.int32),  # class IDs for each box
#         dummy_scores,  # scores for each box
#         category_index)

# plt.show()
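One likely reason the augmentation code above does not work is that it mixes coordinate systems: the ground-truth boxes are already normalized to [0, 1], but the adjustment functions subtract them from pixel dimensions (`image_width`, `image_height`) and the result is then divided by the image size again. The NumPy sketch below is a hedged suggestion, not the notebook's implementation: it keeps everything in normalized `[ymin, xmin, ymax, xmax]` coordinates and checks each transformed box against a synthetic image. The function names are hypothetical.

```python
import numpy as np

def flip_horizontal(image, boxes):
    """Mirror an (H, W, ...) image left-right and adjust normalized boxes."""
    flipped = image[:, ::-1]
    ymin, xmin, ymax, xmax = boxes.T
    # x coordinates mirror around the vertical center line: x -> 1 - x
    new_boxes = np.stack([ymin, 1.0 - xmax, ymax, 1.0 - xmin], axis=1)
    return flipped, new_boxes

def rotate_90_clockwise(image, boxes):
    """Rotate an (H, W, ...) image 90 degrees clockwise and adjust normalized boxes."""
    rotated = np.rot90(image, k=-1)  # negative k rotates clockwise
    ymin, xmin, ymax, xmax = boxes.T
    # After a clockwise turn, new y = old x and new x = 1 - old y
    new_boxes = np.stack([xmin, 1.0 - ymax, xmax, 1.0 - ymin], axis=1)
    return rotated, new_boxes

def box_from_mask(mask):
    """Normalized [ymin, xmin, ymax, xmax] box enclosing the nonzero pixels of a 2-D mask."""
    rows = np.where(mask.any(axis=1))[0]
    cols = np.where(mask.any(axis=0))[0]
    height, width = mask.shape
    return np.array([rows.min() / height, cols.min() / width,
                     (rows.max() + 1) / height, (cols.max() + 1) / width])

# Synthetic 10x20 image with one bright rectangle, and that rectangle's normalized box
image = np.zeros((10, 20))
image[2:4, 5:9] = 1.0
boxes = box_from_mask(image)[np.newaxis, :]

flipped_image, flipped_boxes = flip_horizontal(image, boxes)
rotated_image, rotated_boxes = rotate_90_clockwise(image, boxes)
# Same order as apply_augmentations above: flip first, then rotate
combo_image, combo_boxes = rotate_90_clockwise(*flip_horizontal(image, boxes))

# Each adjusted box should match the rectangle's new position in its image
print(np.allclose(flipped_boxes[0], box_from_mask(flipped_image)))  # True
print(np.allclose(rotated_boxes[0], box_from_mask(rotated_image)))  # True
print(np.allclose(combo_boxes[0], box_from_mask(combo_image)))      # True
```

Working in normalized coordinates also means no re-normalization step is needed afterwards, so the `assert` on the [0, 1] range becomes a simple sanity check.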


### 4.4 Model Checkpoint Download and Configuration
The `ssd_resnet50_v1_fpn_640x640_coco17_tpu-8` checkpoint is downloaded, and the model's pipeline configuration is adapted to our single class.
#### SSD with ResNet-50 and FPN Model Overview

The `ssd_resnet50_v1_fpn_640x640_coco17_tpu-8` configuration defines an object detection model that combines several components into an efficient, scalable architecture. This section describes each component and the benefits it brings to object detection.

#### Model Architecture Components

##### SSD (Single Shot Multibox Detector):
- Uses a single neural network to localize and classify objects in one forward pass, without a separate region-proposal stage, which makes inference fast.

##### ResNet-50 Backbone:
- Serves as the model's feature extractor.
- ResNet-50 uses 'residual connections' that let the input of a block bypass the block's layers and be added to its output. This mitigates the vanishing gradient problem and makes deeper networks easier to train.

##### FPN (Feature Pyramid Network):
- Extends the SSD framework with a 'multi-scale feature hierarchy' in which each pyramid level represents a different scale, improving the detection of objects of different sizes.

##### Input Size (640x640):
- The model accepts 640x640 input images, a compromise between capturing image detail and keeping computation tractable. Our images are 512x512, so this input size is a good match that requires only a modest upscale.

##### COCO Dataset:
- The checkpoint was pretrained on COCO 2017 (the `coco17` in its name), a widely used object detection benchmark covering 80 object categories.

##### Trained on TPU-8:
- The `tpu-8` suffix refers to the 8-core TPU configuration used to train the reference checkpoint. The model itself is not tied to TPUs and can also be run and fine-tuned on GPUs and CPUs.

#### Benefits of the Model

##### Speed:
- Single-stage detectors such as SSD are generally faster than two-stage detectors (for example, Faster R-CNN), which makes them suitable for near-real-time detection.

##### Accuracy:
- ResNet-50's track record of accurate feature extraction is well-established.
- FPN improves the model's accuracy by making detection more robust across object scales.

##### Scalability:
- The model can be trained on TPUs, as the reference checkpoint was, which helps when scaling training to large datasets.

##### Versatility:
- Pretraining on COCO's broad range of object categories gives the backbone general-purpose visual features, which is what we rely on when transferring the model to a new class (cells).

##### Balance of Performance and Resources:
- The 640x640 input size and the efficiency of the single-shot approach give a reasonable balance between detection performance and computational cost.

#### Detailed Architecture of ResNet-50

ResNet-50 is a variant of the Residual Network architecture that contains 50 layers. The key innovation in ResNet is the introduction of 'residual connections' which skip one or more layers. Traditional neural networks might suffer from the vanishing gradient problem as they become deeper, making training very deep networks challenging. Residual connections combat this by allowing gradients to flow through the network more easily. Each residual block in a ResNet contains two paths: one where the input is passed through weights (and non-linearities), and another where the input bypasses this transformation and is added to the output of the weighted path. This architecture encourages the network to learn residual functions, which are modifications to the identity mapping, rather than learning unreferenced functions each time.
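The residual idea can be shown in a few lines. The NumPy sketch below is a deliberately simplified block: ResNet-50 actually uses "bottleneck" blocks of 1x1, 3x3 and 1x1 convolutions with batch normalization, whereas this toy uses two small dense layers. It is only meant to show how the skip connection adds the input back onto the weighted path, y = ReLU(F(x) + x).

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, w1, w2):
    """Simplified residual block: relu(F(x) + x), where F is two dense layers."""
    f = relu(x @ w1) @ w2   # the weighted path F(x)
    return relu(f + x)      # the skip connection adds the unchanged input back

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))           # a batch of 4 feature vectors of size 8
w1 = rng.normal(size=(8, 8)) * 0.1
w2 = rng.normal(size=(8, 8)) * 0.1
out = residual_block(x, w1, w2)

# If the weighted path contributes nothing (all-zero weights), the block reduces to
# relu(x): the input signal still passes straight through.
identity_out = residual_block(x, np.zeros((8, 8)), np.zeros((8, 8)))
print(np.allclose(identity_out, relu(x)))  # True
```

Because a block can fall back to (nearly) passing its input through, stacking many of them does not force every layer to learn a useful transformation, which is part of why very deep residual networks remain trainable.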

#### Understanding Multi-Scale Feature Hierarchy

A multi-scale feature hierarchy within the FPN constructs a pyramidal structure where each level of the pyramid corresponds to features extracted at a different scale. This setup enables the model to detect objects at various sizes. In practice, lower levels of the pyramid have higher resolution features which are good for detecting smaller objects, while higher levels have coarser features suitable for identifying larger objects. This multi-scale approach is crucial for handling the inherent scale variability of objects within real-world images.
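To give a sense of scale, assume pyramid levels P3 to P7 with strides 8 to 128, a common RetinaNet-style choice (the actual levels are set by the FPN `min_level` and `max_level` fields of the pipeline config, which should be checked). A 640x640 input then yields feature maps from 80x80 down to 5x5. The NumPy sketch below computes those sizes and shows one simplified top-down FPN merge, in which a coarser map is upsampled 2x by nearest-neighbor repetition and added to the finer lateral map. The 1x1 lateral and 3x3 smoothing convolutions of a real FPN are omitted, and all array contents are random placeholders.

```python
import numpy as np

input_size = 640
# Assumed pyramid levels P3..P7; level l has stride 2**l relative to the input
feature_map_sizes = {f"P{level}": input_size // 2 ** level for level in range(3, 8)}
print(feature_map_sizes)  # {'P3': 80, 'P4': 40, 'P5': 20, 'P6': 10, 'P7': 5}

def upsample_nearest_2x(feature_map):
    """Double the height and width of an (H, W, C) map by repeating each cell."""
    return feature_map.repeat(2, axis=0).repeat(2, axis=1)

def top_down_merge(coarse, lateral):
    """Simplified FPN step: upsample the coarser level and add the finer lateral map."""
    return upsample_nearest_2x(coarse) + lateral

# Toy 4-channel maps sized like P5 (20x20) and the P4 lateral input (40x40)
rng = np.random.default_rng(0)
p5 = rng.normal(size=(20, 20, 4))
c4_lateral = rng.normal(size=(40, 40, 4))
p4 = top_down_merge(p5, c4_lateral)
print(p4.shape)  # (40, 40, 4)
```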


In [None]:
# Download the checkpoint and put it into models/research/object_detection/test_data/

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/

Expected Output:


```
--2023-11-09 19:39:49--  http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 142.251.12.207, 172.217.194.207, 172.253.118.207, ...
Connecting to download.tensorflow.org (download.tensorflow.org)|142.251.12.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 244817203 (233M) [application/x-tar]
Saving to: ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz.1’

ssd_resnet50_v1_fpn 100%[===================>] 233.48M  20.9MB/s    in 12s     

2023-11-09 19:40:01 (19.5 MB/s) - ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz.1’ saved [244817203/244817203]
```



### 4.5 Model Weight Initialization and Freezing
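The model is built with a single output class, and its weights are restored from the COCO checkpoint: the feature extractor and the box-regression head are restored, while the classification head is initialized from scratch because its number of outputs changes. A dummy all-zeros image is then passed through the model so that all of its variables are created. For fine-tuning, the weights are frozen except for the prediction heads.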


In [None]:
tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'

# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new cell class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')

## 5. Model Training and Evaluation
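The evaluation code in this section is based on intersection over union (IoU): for two boxes A and B, IoU = area(A ∩ B) / area(A ∪ B), which is 1 for identical boxes and 0 for disjoint ones. As a hedged reference point, the NumPy sketch below computes the same pairwise IoU matrix as `compute_iou` for boxes in `[ymin, xmin, ymax, xmax]` order, followed by a per-ground-truth recall: the fraction of annotated cells matched by at least one detection that clears a score threshold with IoU ≥ 0.5. This recall metric is an alternative, not the metric implemented in `validate_model_on_batch` (which averages the above-threshold indicator over every ground-truth/detection pair); the function names, the score threshold, and the example boxes are all hypothetical.

```python
import numpy as np

def pairwise_iou(gt_boxes, det_boxes):
    """IoU matrix of shape [num_gt, num_det] for boxes in [ymin, xmin, ymax, xmax] order."""
    gt = np.asarray(gt_boxes, dtype=np.float64)[:, np.newaxis, :]    # [G, 1, 4]
    det = np.asarray(det_boxes, dtype=np.float64)[np.newaxis, :, :]  # [1, D, 4]
    # Overlap height and width, clipped at zero for boxes that do not intersect
    inter_h = np.clip(np.minimum(gt[..., 2], det[..., 2]) - np.maximum(gt[..., 0], det[..., 0]), 0.0, None)
    inter_w = np.clip(np.minimum(gt[..., 3], det[..., 3]) - np.maximum(gt[..., 1], det[..., 1]), 0.0, None)
    inter = inter_h * inter_w
    area_gt = (gt[..., 2] - gt[..., 0]) * (gt[..., 3] - gt[..., 1])
    area_det = (det[..., 2] - det[..., 0]) * (det[..., 3] - det[..., 1])
    union = area_gt + area_det - inter
    # Guard against zero-area pairs instead of dividing by zero
    return np.where(union > 0, inter / np.maximum(union, 1e-12), 0.0)

def recall_at_iou(gt_boxes, det_boxes, det_scores, iou_threshold=0.5, score_threshold=0.5):
    """Fraction of ground-truth boxes matched by at least one confident detection."""
    gt_boxes = np.asarray(gt_boxes, dtype=np.float64).reshape(-1, 4)
    det_boxes = np.asarray(det_boxes, dtype=np.float64).reshape(-1, 4)
    keep = np.asarray(det_scores) >= score_threshold  # drop low-confidence detections
    if len(gt_boxes) == 0:
        return float("nan")  # recall is undefined when there is nothing to find
    if not keep.any():
        return 0.0
    best_iou_per_gt = pairwise_iou(gt_boxes, det_boxes[keep]).max(axis=1)
    return float(np.mean(best_iou_per_gt >= iou_threshold))

# Two hypothetical annotated cells and three detections
gt = [[0.0, 0.0, 0.5, 0.5],
      [0.5, 0.5, 1.0, 1.0]]
det = [[0.0, 0.0, 0.5, 0.25],   # covers half of the first box: IoU 0.5
       [0.6, 0.6, 0.9, 0.9],    # inside the second box: IoU about 0.36
       [0.5, 0.5, 1.0, 1.0]]    # exact match for the second box, but low score
scores = [0.9, 0.8, 0.3]

print(pairwise_iou(gt, det).round(2))
print(recall_at_iou(gt, det, scores, score_threshold=0.5))  # 0.5
print(recall_at_iou(gt, det, scores, score_threshold=0.2))  # 1.0
```

The low-scoring third detection shows why the score threshold matters: counting it turns a recall of 0.5 into 1.0.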


In [None]:
def compute_iou(groundtruth_boxes, detection_boxes):
    """
    Computes the IoU between ground truth and detection boxes.

    Args:
        groundtruth_boxes: a Tensor of shape [num_gt_boxes, 4]
        detection_boxes: a Tensor of shape [1, num_detections, 4]
    """
    # Expand the ground truth boxes tensor to match the shape of detection boxes
    gt_boxes = tf.expand_dims(groundtruth_boxes, axis=1)  # Now [num_gt_boxes, 1, 4]
    # Remove the batch dimension from detection boxes since it's not needed here
    detection_boxes = tf.squeeze(detection_boxes, axis=0)  # Now [num_detections, 4]

    # Calculate intersection areas
    gt_ymin, gt_xmin, gt_ymax, gt_xmax = tf.split(gt_boxes, 4, axis=2)
    d_ymin, d_xmin, d_ymax, d_xmax = tf.split(detection_boxes, 4, axis=1)

    inter_xmin = tf.maximum(gt_xmin, d_xmin)
    inter_ymin = tf.maximum(gt_ymin, d_ymin)
    inter_xmax = tf.minimum(gt_xmax, d_xmax)
    inter_ymax = tf.minimum(gt_ymax, d_ymax)

    inter_area = tf.maximum(inter_xmax - inter_xmin, 0) * tf.maximum(inter_ymax - inter_ymin, 0)

    # Calculate union areas
    gt_area = (gt_xmax - gt_xmin) * (gt_ymax - gt_ymin)
    d_area = (d_xmax - d_xmin) * (d_ymax - d_ymin)

    union_area = gt_area + d_area - inter_area

    # Calculate IoU
    iou = inter_area / union_area

    # Squeeze to remove the extra dimension
    iou = tf.squeeze(iou, axis=2)

    return iou


def validate_model_on_batch(detection_model, validation_data, iou_threshold=0.5):
    """
    Validate the model on a batch of validation data.

    For each image, this computes the fraction of all (ground-truth box, detection)
    pairs whose IoU is at least `iou_threshold`, then averages that fraction over
    the images. All returned detections are included, regardless of their score.
    """
    iou_scores = []
    for image_tensor, gt_boxes_tensor in validation_data:
        preprocessed_image, shapes = detection_model.preprocess(image_tensor)
        prediction_dict = detection_model.predict(preprocessed_image, shapes)
        detections = detection_model.postprocess(prediction_dict, shapes)

        # Pairwise IoU matrix of shape [num_gt_boxes, num_detections]
        iou = compute_iou(gt_boxes_tensor, detections['detection_boxes'])

        # Fraction of (ground truth, detection) pairs that clear the IoU threshold
        correct_predictions = tf.cast(iou >= iou_threshold, tf.float32)
        accuracy = tf.reduce_mean(correct_predictions)

        iou_scores.append(accuracy)

    # Average the per-image fractions across all validation images
    mean_iou = tf.reduce_mean(iou_scores)

    return mean_iou.numpy()  # Returned as a numpy float

def prepare_validation_tensors(images, boxes):
    image_tensors = []
    box_tensors = []
    for image_np, box_np in zip(images, boxes):
        # Convert to tensors and add a batch dimension to (H, W, 3) images
        image_tensor = tf.convert_to_tensor(image_np, dtype=tf.float32)
        box_tensor = tf.convert_to_tensor(box_np, dtype=tf.float32)
        if image_tensor.shape.rank == 3:
            image_tensor = tf.expand_dims(image_tensor, axis=0)
        image_tensors.append(image_tensor)
        box_tensors.append(box_tensor)
    return image_tensors, box_tensors

# Set up forward + backward pass for a single train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
  """Return a function that performs one training step."""

  # Decorating train_step_fn with tf.function compiles it into a graph for
  # speed. The decorator is commented out, so the step runs eagerly;
  # uncomment it to enable graph execution.
  #@tf.function
  def train_step_fn(image_tensors,
                    groundtruth_boxes_list,
                    groundtruth_classes_list):
    """A single training iteration.

    Args:
      image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
        Note that the height and width can vary across images, as they are
        resized to 640x640 by model.preprocess.
      groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
        tf.float32 representing groundtruth boxes for each image in the batch.
      groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
        with type tf.float32 representing one-hot groundtruth classes for each
        image in the batch.

    Returns:
      A scalar tensor representing the total loss for the input batch.
    """
    # One [height, width, channels] entry per image in the batch
    shapes = tf.constant(len(image_tensors) * [[640, 640, 3]], dtype=tf.int32)
    model.provide_groundtruth(
        groundtruth_boxes_list=groundtruth_boxes_list,
        groundtruth_classes_list=groundtruth_classes_list)
    with tf.GradientTape() as tape:
      preprocessed_images = tf.concat(
          [model.preprocess(image_tensor)[0]
           for image_tensor in image_tensors], axis=0)
      prediction_dict = model.predict(preprocessed_images, shapes)
      losses_dict = model.loss(prediction_dict, shapes)
      total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
    # Only the selected head variables receive gradient updates
    gradients = tape.gradient(total_loss, vars_to_fine_tune)
    optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
    return total_loss

  return train_step_fn
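For reference, here is a minimal NumPy sketch of the pairwise IoU that `compute_iou` calculates. It assumes boxes in the `[ymin, xmin, ymax, xmax]` order used elsewhere in this notebook (the IoU value does not depend on which axis comes first, as long as both inputs agree). `pairwise_iou` is a hypothetical helper for illustration, not part of the training code.

```python
import numpy as np

def pairwise_iou(gt_boxes, det_boxes):
    """Pairwise IoU between two sets of [ymin, xmin, ymax, xmax] boxes.

    Returns an array of shape [num_gt, num_det].
    """
    gt = np.asarray(gt_boxes, dtype=np.float64)[:, None, :]    # [N, 1, 4]
    det = np.asarray(det_boxes, dtype=np.float64)[None, :, :]  # [1, M, 4]

    # Intersection rectangle: largest min-corner, smallest max-corner
    inter_ymin = np.maximum(gt[..., 0], det[..., 0])
    inter_xmin = np.maximum(gt[..., 1], det[..., 1])
    inter_ymax = np.minimum(gt[..., 2], det[..., 2])
    inter_xmax = np.minimum(gt[..., 3], det[..., 3])
    # Non-overlapping boxes give a negative extent, clipped to zero
    inter = (np.clip(inter_ymax - inter_ymin, 0, None)
             * np.clip(inter_xmax - inter_xmin, 0, None))

    gt_area = (gt[..., 2] - gt[..., 0]) * (gt[..., 3] - gt[..., 1])
    det_area = (det[..., 2] - det[..., 0]) * (det[..., 3] - det[..., 1])
    union = gt_area + det_area - inter

    # Return 0 where the union is empty instead of dividing by zero
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

gt = [[0, 0, 2, 2]]
dets = [[1, 1, 3, 3], [0, 0, 2, 2], [5, 5, 6, 6]]
iou = pairwise_iou(gt, dets)  # shape (1, 3), approximately [[0.1429, 1.0, 0.0]]
```

The first detection overlaps the ground-truth box in a 1x1 square, so its IoU is 1 / (4 + 4 - 1) = 1/7; the second is identical to the ground truth and the third does not overlap it at all.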

### 5.1 Test Data Tensor Preparation
The test images and their ground-truth boxes are converted to float32 tensors so that validation metrics can be computed during model training.


### 5.2 Loading Pre-existing Checkpoints
If a checkpoint exists in `./checkpoints`, the latest one is restored so that training resumes from the previous state. Up to three checkpoints are kept.

### 5.3 Model Training and Performance Metrics
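Each training step runs a forward and backward pass over a random batch and updates only the variables of the box-prediction and class-prediction heads. The loss is printed every 10 batches, and every 50 batches a checkpoint is saved and the validation accuracy is computed on the test set.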
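`validate_model_on_batch` scores an image by averaging the thresholded IoU over every entry returned by `compute_iou`. Assuming that is a `[num_gt, num_det]` matrix covering all post-processed detections, including low-scoring ones, most entries pair a ground-truth cell with an unrelated detection, so the average stays well below 1 even when every cell is found. The minimal sketch below shows a recall-style alternative instead. It is *not* the metric used in this notebook, and `best_match_rate` is a hypothetical name. It takes the best IoU for each ground-truth box.

```python
import numpy as np

def best_match_rate(iou_matrix, iou_threshold=0.5):
    """Fraction of ground-truth boxes whose best-overlapping detection
    reaches iou_threshold. iou_matrix has shape [num_gt, num_det]."""
    iou_matrix = np.asarray(iou_matrix, dtype=np.float64)
    num_gt, num_det = iou_matrix.shape
    if num_gt == 0:
        return float('nan')  # no ground truth: the rate is undefined
    if num_det == 0:
        return 0.0  # ground truth present but nothing was detected
    best_iou_per_gt = iou_matrix.max(axis=1)
    return float(np.mean(best_iou_per_gt >= iou_threshold))

# Two ground-truth cells, three detections; each cell has one good match
iou = np.array([[0.9, 0.0, 0.0],
                [0.0, 0.6, 0.1]])
pairwise_mean = float(np.mean(iou >= 0.5))  # 2 of 6 entries pass: 0.333...
recall_like = best_match_rate(iou, 0.5)     # both cells matched: 1.0
```

This sketch lets one detection count for several ground-truth boxes and ignores false positives, so it only measures recall. A fuller evaluation such as COCO-style mAP also enforces one-to-one matching and ranks detections by score.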


In [None]:
import os
import random

# Kept from the original tutorial; this call is deprecated in recent TensorFlow 2.x releases.
tf.keras.backend.set_learning_phase(True)

# These parameters can be tuned.
batch_size = 4
learning_rate = 0.01
num_batches = 100
iou_threshold = 0.75

# Select variables in top layers to fine-tune.
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
for var in trainable_variables:
  if any(var.name.startswith(prefix) for prefix in prefixes_to_train):
    to_fine_tune.append(var)

# Pass the learning rate defined above to the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
train_step_fn = get_model_train_step_function(
    detection_model, optimizer, to_fine_tune)

# Convert the test set to tensors for validation during training
test_image_tensors, test_gt_box_tensors = prepare_validation_tensors(test_images_np, test_gt_boxes)
validation_tensors = list(zip(test_image_tensors, test_gt_box_tensors))

checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = tf.train.Checkpoint(model=detection_model, optimizer=optimizer)
manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, max_to_keep=3)

# Attempt to restore from the latest checkpoint
status = checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

print('Start fine-tuning!', flush=True)
for idx in range(num_batches):
  # Grab keys for a random subset of examples
  all_keys = list(range(len(train_images_np)))
  random.shuffle(all_keys)
  example_keys = all_keys[:batch_size]

  gt_boxes_list = [gt_box_tensors[key] for key in example_keys]
  gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
  image_tensors = [train_image_tensors[key] for key in example_keys]

  # Training step (forward pass + backwards pass), run exactly once per batch
  total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)
  if idx % 10 == 0:
    print(f'Batch {idx} of {num_batches}, loss={total_loss.numpy():.4f}', flush=True)
  if idx % 50 == 0:
    save_path = manager.save()
    print("Saved checkpoint for batch {}: {}".format(idx, save_path))
    validation_accuracy = validate_model_on_batch(detection_model, validation_tensors, iou_threshold)
    print(f'Batch {idx}, Validation Accuracy: {validation_accuracy:.2%}', flush=True)

# Perform a final validation step after training is complete
print('Final validation on test data...', flush=True)
validation_accuracy = validate_model_on_batch(detection_model, validation_tensors, iou_threshold)
print(f'Final Validation Accuracy: {validation_accuracy:.2%}')

# Save the final checkpoint after training
save_path = manager.save()
print(f"Saved final checkpoint: {save_path}")

print('Done fine-tuning!')
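The training cell fine-tunes only the variables whose names begin with the box-head or class-head prefix. Every other variable, including the ResNet-50 FPN backbone, receives no gradient updates. The sketch below isolates that selection rule using plain strings. The variable names are hypothetical and only mimic the TF Object Detection API naming style, and `select_by_prefix` is an illustrative helper.

```python
def select_by_prefix(variable_names, prefixes):
    """Return the names that start with any of the given prefixes."""
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    return [name for name in variable_names if name.startswith(prefixes)]

# Hypothetical variable names, for illustration only
names = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead/BoxPredictor/kernel:0',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead/ClassPredictor/bias:0',
    'WeightSharedConvolutionalBoxPredictor/BoxPredictionTower/conv2d_0/kernel:0',
    'ResNet50V1_FPN/bottom_up_block5_conv/kernel:0',
]
prefixes = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead',
]
selected = select_by_prefix(names, prefixes)  # only the two head variables
```

Because the match is on prefixes, a variable that shares the `WeightSharedConvolutionalBoxPredictor/` scope but sits outside the two heads (the third name above) is left frozen.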


### 5.4 Model Detection Inspection
Detections on a randomly chosen test image are plotted alongside its ground-truth boxes and inspected visually to judge how well the predicted boxes locate the annotated cells.
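Plotting a detection requires turning a normalized `[ymin, xmin, ymax, xmax]` box into the pixel anchor, width and height that `matplotlib.patches.Rectangle` expects, after dropping low-scoring boxes. The following is a minimal NumPy sketch of that conversion, assuming normalized box coordinates as produced by the model. `boxes_to_rectangles` is a hypothetical helper.

```python
import numpy as np

def boxes_to_rectangles(boxes, scores, image_height, image_width, score_threshold=0.2):
    """Convert normalized [ymin, xmin, ymax, xmax] boxes into pixel
    (x, y, width, height) tuples, keeping only boxes above score_threshold."""
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    scores = np.asarray(scores, dtype=np.float64).reshape(-1)
    keep = scores > score_threshold
    ymin, xmin, ymax, xmax = boxes[keep].T
    # Rectangle is anchored at (x, y) with width and height in data (pixel) units
    x = xmin * image_width
    y = ymin * image_height
    w = (xmax - xmin) * image_width
    h = (ymax - ymin) * image_height
    return list(zip(x.tolist(), y.tolist(), w.tolist(), h.tolist()))

rects = boxes_to_rectangles(
    [[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]], [0.9, 0.1],
    image_height=100, image_width=200)
# Only the first box passes the threshold: approximately [(40.0, 10.0, 80.0, 40.0)]
```

Each tuple `(x, y, w, h)` can be passed as `patches.Rectangle((x, y), w, h, ...)`. Note that x scales with the image width and y with the image height.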

In [None]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np

def plot_predictions_and_ground_truth(image, gt_boxes, pred_boxes, pred_scores, threshold=0.35):
    """
    Plot the unannotated image and the annotated image side by side.

    Boxes are expected in normalized [ymin, xmin, ymax, xmax] coordinates.
    Ground-truth boxes are drawn in purple and predictions in red.

    Args:
        image (ndarray): The [height, width, 3] image on which to plot the boxes.
        gt_boxes (ndarray): The ground truth boxes.
        pred_boxes (ndarray): The predicted boxes.
        pred_scores (ndarray): The confidence scores for the predicted boxes.
        threshold (float): Only predictions with a score above this value are drawn.
    """

    # Scale pixel values to [0, 1] for imshow without modifying the caller's array
    if image.dtype == np.float32 and image.max() > 1.0:
        image = image / 255.0
    elif image.dtype == np.uint8:
        image = image.astype(np.float32) / 255.0

    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 9), sharex=True, sharey=True)

    # Display the unannotated image on the first subplot
    ax1.imshow(image)
    ax1.set_title('Unannotated Image')
    ax1.axis('off')  # Turn off the axis

    # Display the annotated image on the second subplot
    ax2.imshow(image)
    ax2.set_title('Annotated Image')

    # Plot ground truth boxes and predicted boxes on the annotated image
    for box in gt_boxes:
        ymin, xmin, ymax, xmax = box
        rect = patches.Rectangle((xmin*image.shape[1], ymin*image.shape[0]),
                                 (xmax-xmin)*image.shape[1],
                                 (ymax-ymin)*image.shape[0],
                                 linewidth=2,
                                 edgecolor='purple',
                                 facecolor='none',
                                 label='Ground Truth')
        ax2.add_patch(rect)

    for box, score in zip(pred_boxes, pred_scores):
        if score > threshold:
            ymin, xmin, ymax, xmax = box
            rect = patches.Rectangle((xmin*image.shape[1], ymin*image.shape[0]),
                                     (xmax-xmin)*image.shape[1],
                                     (ymax-ymin)*image.shape[0],
                                     linewidth=2,
                                     edgecolor='red',
                                     facecolor='none',
                                     label='Predicted')
            ax2.add_patch(rect)

    # Create legend for the annotated image
    handles, labels = ax2.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))  # Remove duplicates
    ax2.legend(by_label.values(), by_label.keys())

    plt.tight_layout()
    plt.show()


In [None]:
import numpy as np
import tensorflow as tf

# Sample a random image to inspect
idx = np.random.choice(len(test_image_tensors))

# Select the sampled image tensor and its ground-truth boxes
test_image_tensor = validation_tensors[idx][0]
test_gt_box_tensor = validation_tensors[idx][1]

# Run detection on the single image tensor
# @tf.function  # Uncomment to compile detect() into a graph; as written it runs eagerly
def detect(input_tensor):
    """Run detection on an input image tensor."""
    preprocessed_image, shapes = detection_model.preprocess(input_tensor)
    prediction_dict = detection_model.predict(preprocessed_image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections, prediction_dict

# Perform detection on the image
detections, prediction_dict = detect(test_image_tensor)

# Extract detection results
detection_boxes = detections['detection_boxes'][0].numpy()
detection_classes = detections['detection_classes'][0].numpy().astype(np.uint32)
detection_scores = detections['detection_scores'][0].numpy()

# # Print shapes and contents of the annotations and predictions for understanding
# print("Ground truth boxes shape:", test_gt_box_tensor.shape)
# print("Ground truth boxes content:", test_gt_box_tensor.numpy())

# print("Prediction dict keys:", prediction_dict.keys())
# print("Detections shape (boxes, classes, scores):",
#       detection_boxes.shape, detection_classes.shape, detection_scores.shape)
# print("Detections content (boxes):", detection_boxes)
# print("Detections content (classes):", detection_classes)
# print("Detections content (scores):", detection_scores)

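# Score the model on the whole validation set (not only the sampled image).
# validate_model_on_batch returns the fraction of IoU entries at or above the
# threshold, not a mean IoU.
iou_threshold = 0.5
validation_score = validate_model_on_batch(detection_model, validation_tensors, iou_threshold)
print(f"Validation score (fraction of IoU entries >= {iou_threshold}): {validation_score:.4f}")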

# Filter detections based on the score threshold
score_threshold = 0.20
selected_indices = np.where(detection_scores > score_threshold)[0]
selected_boxes = detection_boxes[selected_indices]
selected_scores = detection_scores[selected_indices]

# Squeeze the batch dimension from the image tensor before plotting
squeezed_image = np.squeeze(test_image_tensor.numpy())

plot_predictions_and_ground_truth(
    image=squeezed_image,
    gt_boxes=test_gt_box_tensor.numpy(),
    pred_boxes=selected_boxes,
    pred_scores=selected_scores,
    threshold=score_threshold
)

## Appendix

### A. Unit Tests

In [None]:
import tensorflow as tf
import unittest

# Function to flip an image horizontally using TensorFlow
def tf_flip_image_horizontal(image):
    return tf.image.flip_left_right(image)

# Function to rotate an image by 90 degrees clockwise using TensorFlow
def tf_rotate_image_90(image):
    return tf.image.rot90(image, k=3)  # k=3 for 90 degrees clockwise rotation

# Function to adjust bounding boxes for horizontal flip using TensorFlow
def tf_adjust_boxes_for_flip(boxes, image_width):
    new_boxes = tf.stack([boxes[:, 0], image_width - boxes[:, 3],
                          boxes[:, 2], image_width - boxes[:, 1]], axis=1)
    return new_boxes

# Function to adjust bounding boxes for a 90 degree clockwise rotation using TensorFlow
# (image_height is the height of the image before rotation)
def tf_adjust_boxes_for_rotation(boxes, image_height):
    new_boxes = tf.stack([boxes[:, 1], image_height - boxes[:, 2],
                          boxes[:, 3], image_height - boxes[:, 0]], axis=1)
    return new_boxes

# TensorFlow unit test class
class TFImageAugmentationTest(tf.test.TestCase):
    def test_image_augmentation(self):
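        # Create a mock 640x640 image containing two white 100x100 squares
        square_1 = tf.pad(tf.ones((100, 100, 3), dtype=tf.uint8) * 255, [[100, 440], [100, 440], [0, 0]])
        square_2 = tf.pad(tf.ones((100, 100, 3), dtype=tf.uint8) * 255, [[400, 140], [300, 240], [0, 0]])
        # Sum the padded squares so the result stays a single [640, 640, 3] image
        # (stacking them would broadcast to a [2, 640, 640, 3] batch)
        image = tf.zeros((640, 640, 3), dtype=tf.uint8) + square_1 + square_2
        # Boxes in [ymin, xmin, ymax, xmax] pixel coordinates
        boxes = tf.constant([[100, 100, 200, 200], [400, 300, 500, 400]], dtype=tf.float32)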

        image_width, image_height = image.shape[1], image.shape[0]

        # Apply horizontal flip
        flipped_image = tf_flip_image_horizontal(image)
        flipped_boxes = tf_adjust_boxes_for_flip(boxes, image_width)

        # Apply 90 degree rotation
        rotated_image = tf_rotate_image_90(flipped_image)
        rotated_boxes = tf_adjust_boxes_for_rotation(flipped_boxes, image_height)

        # Check if the boxes still correctly identify the objects
        for box in rotated_boxes:
            ymin, xmin, ymax, xmax = tf.cast(box, tf.int32).numpy()
            object_area = rotated_image[ymin:ymax, xmin:xmax]
            # Assert that the object area is all white (255)
            self.assertAllEqual(tf.reduce_all(tf.equal(object_area, 255)), True)

        print("Test passed successfully!")

# Run the test case through unittest so setUp/tearDown are invoked and failures are reported
suite = unittest.TestLoader().loadTestsFromTestCase(TFImageAugmentationTest)
unittest.TextTestRunner(verbosity=2).run(suite)
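The box adjustments in `tf_adjust_boxes_for_flip` and `tf_adjust_boxes_for_rotation` follow from where a pixel coordinate $(y, x)$ lands under each transform, for an image of height $H$ and width $W$. Because each map reverses the order along one axis, the old maximum becomes the new minimum on that axis:

$$
\begin{aligned}
\text{horizontal flip:}\quad (y, x) &\mapsto (y,\; W - x)
  &&\Rightarrow\; [y_{\min}, x_{\min}, y_{\max}, x_{\max}] \mapsto [y_{\min},\; W - x_{\max},\; y_{\max},\; W - x_{\min}] \\
\text{clockwise rotation by } 90^\circ\text{:}\quad (y, x) &\mapsto (x,\; H - y)
  &&\Rightarrow\; [y_{\min}, x_{\min}, y_{\max}, x_{\max}] \mapsto [x_{\min},\; H - y_{\max},\; x_{\max},\; H - y_{\min}]
\end{aligned}
$$

The rotated image has height $W$ and width $H$, so the rotation must use the height of the image *before* rotation. Because the mock image in the test is square (640x640), the test cannot detect a mix-up between $H$ and $W$.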

