In [2]:
"""
This notebook is used to do an initial pass at the wing segmentation in the card images.
It does this by:
 1. Matching the image histogram to a reference image to help with thresholding
 2. Blurring the image
 3. Thresholding the image
 4. Finding the large contour that is closest to the center of the image (we assume this is the wing)
 5. Using a smaller blur and re-thresholding inside that contour (so we find the wing edges more precisely)
 6. Saving the mask and segmentations

Once this notebook was done running, it was necessary to go through the images, check if the segmentation worked, and 
manually correct many of the segmentations. These manual corrections are saved in '2_card_mask_manual_corrections'
and combined with the masks from this notebook into the folder '2_final_masks'
"""

"\nThis notebook is used to do an initial pass at the wing segmentation in the card images.\nIt does this by:\n 1. Matching the image histogram to a reference image to help with thresholding\n 2. Blurring the image\n 3. Thresholding the image\n 4. Finding the large contour that is closest to the center of the image (we assume this is the wing)\n 5. Using a smaller blur and thresholding the contour to that (so we get the find edges more precisely)\n 6. Saving the mask and segmentations\n\nOnce this notebook was done running, it was necessary to go through the images, check if the segmentation worked, and \nmanually correct many of the segmentations. These manual corrections are saved in '2_card_mask_manual_corrections'\nand combined with the masks from this notebook into the folder '2_final_masks'\n"

In [1]:
import os
import cv2
import json
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from glob import glob
from tqdm import tqdm
from pathlib import Path
from copy import deepcopy

from scipy import ndimage
from skimage.measure import label

from numpy import linalg
from scipy.sparse.linalg import eigs

from skimage.exposure import match_histograms

from utils import is_wing_facing_up, segment_contour

In [5]:
#for fp in glob('../2_live_bees/2_card_segs/*'):
#    os.remove(fp)
#for fp in glob('../2_live_bees/2_card_masks/*'):
#    os.remove(fp)
#for fp in glob('../2_live_bees/2_card_segs_and_orig/*'):
#    os.remove(fp)

In [13]:
# First-pass wing segmentation of every card image.
#
# For each card: histogram-match to a reference card, downsample + crop, blur,
# threshold, pick the large blob closest to the image centre, then re-threshold
# at full resolution inside that blob to tighten the outline. Writes a mask, a
# segmentation and a side-by-side "segmentation | original" image per card.
# Cards that cannot be processed are reported and skipped (nothing is written).

DEBUG = False
# Card processed when DEBUG is True; the loop then handles only this one image
# and shows the intermediate results as matplotlib figures.
debug_img_fp = '../2_live_bees/1_cards/2024_07_24_h34b11.png'

# --- tunable parameters ------------------------------------------------------
scale_factor = 2             # stride of the crude downsampling
blur_kernel_size = 11        # box-blur size applied before the first threshold
threshold = 160              # blurred channel-0 values below this are foreground
row_brightness_thres = 170   # rows with mean value below this are trimmed from top/bottom
min_blob_area = 10000        # contours with a smaller area (downscaled px) are rejected
max_aspect_ratio = 5         # bounding boxes more elongated than this are rejected
max_blobs_examined = 10      # only the N largest contours are considered
max_candidate_blobs = 4      # stop once this many contours passed the filters
vein_thres = 50              # channel-0 values below this inside the blob count as veins
tighten_thres = 220          # gray-level threshold for the full-resolution refinement

img_fps = sorted(glob('../2_live_bees/1_cards/*'))
# Processing order is randomised; each card is handled independently of the others.
np.random.shuffle(img_fps)

# Reference card for histogram matching.
# (An alternative, 'reference_card_2024_06_13_h05b78.png', used to be loaded here
# and was immediately overwritten by this one.)
reference = cv2.imread('../2_live_bees/reference_card_2024_06_06_h05bee62.png')
if reference is None:
    raise FileNotFoundError('Could not read the histogram-matching reference card')

# Optional list of file names; when non-empty only these cards are processed.
subset = []

# cv2.imwrite does not create missing folders, so make sure the targets exist.
for out_dir in ('../2_live_bees/2_card_masks/',
                '../2_live_bees/2_card_segs/',
                '../2_live_bees/2_card_segs_and_orig/'):
    os.makedirs(out_dir, exist_ok=True)


def _debug_show(image, title=None, figsize=None):
    """Show `image` in a new matplotlib figure; does nothing unless DEBUG is set."""
    if not DEBUG:
        return
    plt.figure(figsize=figsize)
    plt.imshow(image)
    if title is not None:
        plt.title(title)


fps_to_process = [debug_img_fp] if DEBUG else img_fps

for img_fp in tqdm(fps_to_process):
    fn = os.path.basename(img_fp)
    if len(subset) and fn not in subset:
        continue
    # File names look like '<date parts>_<bee id>.<ext>'; split on the last '_'.
    date = '_'.join(fn.split('_')[:-1])
    bee_id = fn.split('_')[-1].split('.')[0]

    img = cv2.imread(img_fp)
    if img is None:
        tqdm.write(f'Skipping {fn}: image could not be read')
        continue

    matched = match_histograms(img, reference, channel_axis=-1)
    _debug_show(img, 'img')
    _debug_show(matched, 'histogram matched')

    # Fixed border crop, expressed in downscaled pixels. These are recomputed for
    # every card because the top/bottom values are extended further below.
    crop_top = 300 // scale_factor
    crop_bott = 300 // scale_factor
    crop_right = 50 // scale_factor
    crop_left = 50 // scale_factor

    scaled_img = matched[::scale_factor, ::scale_factor]  # crude downsampling
    scaled_img = scaled_img[crop_top:scaled_img.shape[0] - crop_bott,
                            crop_left:scaled_img.shape[1] - crop_right]
    _debug_show(scaled_img, 'downscaled')

    # Trim contiguous dark rows at the top and bottom of the crop.
    row_is_dark = scaled_img.mean(axis=2).mean(axis=1) < row_brightness_thres
    n_rows = scaled_img.shape[0]
    y_start_hist = 1
    y_end_hist = n_rows - 1
    while y_start_hist < n_rows and row_is_dark[y_start_hist]:
        y_start_hist += 1
    while y_end_hist >= 0 and row_is_dark[y_end_hist]:
        y_end_hist -= 1
    if y_end_hist <= y_start_hist:
        tqdm.write(f'Skipping {fn}: no bright rows left after cropping')
        continue

    # Remember how much was removed so the mask can be padded back later.
    crop_top += y_start_hist
    crop_bott += n_rows - y_end_hist
    scaled_img = scaled_img[y_start_hist:y_end_hist]

    blurred = cv2.blur(scaled_img, (blur_kernel_size, blur_kernel_size))
    _debug_show(blurred, 'blurred')

    # Only channel 0 is thresholded (not a true grayscale conversion).
    blurred_gray = blurred[:, :, 0]
    if '2024_06_06_h05bee50' in img_fp:
        # Hand-tuned exception for this card: whiten its 10 left-most columns.
        blurred_gray[:, :10] = 255
    thres = blurred_gray < threshold
    _debug_show(thres, 'thresholded')

    closed = ndimage.binary_closing(thres, iterations=3).astype('uint8') * 255

    contours, hierarchy = cv2.findContours(closed, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    # Dark card edges are sometimes detected instead of the wing, so walk through
    # the largest contours (largest first) and keep those with a reasonable
    # aspect ratio and area.
    contour_areas = [cv2.contourArea(c) for c in contours]
    largest_first = np.argsort(contour_areas)[::-1]
    four_biggest_blobs = []
    for blob_index in largest_first[:max_blobs_examined]:
        if len(four_biggest_blobs) >= max_candidate_blobs:
            break
        x, y, w, h = cv2.boundingRect(contours[blob_index])
        if w / h > max_aspect_ratio or h / w > max_aspect_ratio:
            continue
        if contour_areas[blob_index] < min_blob_area:
            continue
        four_biggest_blobs.append(blob_index)

    # Among the candidates, keep the blob whose centroid is closest to the centre.
    # The result variables are reset for every card so that a card without a
    # usable blob can never inherit the mask of the previously processed card.
    img_center = (thres.shape[0] // 2, thres.shape[1] // 2)
    min_dist = np.inf
    closest_blob_mask = None
    closest_blob_seg = None
    for blob_idx in four_biggest_blobs:
        seg, mask = segment_contour(scaled_img, contours[blob_idx])
        mask_center = np.array(np.where(mask[:, :, 0] > 0.5)).mean(axis=1)
        dist_to_center = ((mask_center - img_center) ** 2).mean()
        if dist_to_center < min_dist:
            min_dist = dist_to_center
            closest_blob_mask = mask[:, :, 0]
            closest_blob_seg = seg

    if closest_blob_mask is None:
        tqdm.write(f'Skipping {fn}: no wing-like blob found')
        continue

    # Decide whether the wing is upside down: compare the row of the blob centroid
    # with the row of the centroid of the dark ("vein") pixels.
    mask_center = np.array(np.where(closest_blob_mask > 0.5)).mean(axis=1)
    veins = (closest_blob_seg[:, :, 0] < vein_thres).astype('uint8') * 255
    _debug_show(veins, 'veins')
    # With no vein pixels the centroid is NaN and the comparison yields False.
    veins_center = np.array(np.where(veins > 0.5)).mean(axis=1)
    flip = mask_center[0] < veins_center[0]

    # Undo the cropping so the mask lines up with the downscaled full image.
    closest_blob_mask = np.pad(closest_blob_mask,
                               ((crop_top, crop_bott), (crop_left, crop_right)))
    if flip:
        closest_blob_mask = np.flipud(closest_blob_mask)
    _debug_show(closest_blob_mask, 'wing detected')

    original_mask = cv2.resize(closest_blob_mask, (img.shape[1], img.shape[0]))
    original_mask = (original_mask > 0.5).astype('uint8')
    if flip:
        # Flip back: the saved mask is always in the orientation of the input card.
        original_mask = np.flipud(original_mask)
    _debug_show(original_mask, 'mask in original orientation')

    # Tighten the outline: paint the coarse segmentation on white at full
    # resolution, blur lightly, re-threshold and take the largest contour.
    coarse_idx = np.where(original_mask > 0)
    coarse_seg = np.full(img.shape, 255, dtype='uint8')
    coarse_seg[coarse_idx] = img[coarse_idx]

    gray = cv2.cvtColor(coarse_seg, cv2.COLOR_BGR2GRAY)
    gray = cv2.blur(gray, (3, 3))
    thres2 = (gray < tighten_thres).astype('uint8') * 255
    _debug_show(thres2, 'Tightened Threshold')

    contours, hierarchy = cv2.findContours(thres2, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        tqdm.write(f'Skipping {fn}: refinement threshold produced no contour')
        continue
    area_sorted_indices = np.argsort([cv2.contourArea(c) for c in contours])
    biggest_contour = contours[area_sorted_indices[-1]]
    seg2, mask2 = segment_contour(img, biggest_contour)
    _debug_show(seg2, figsize=(20, 20))

    # Side-by-side image: segmentation on white (left) and the original (right).
    wing_idx = np.where(mask2 > 0)
    seg_and_orig = np.full((img.shape[0], img.shape[1] * 2, img.shape[2]), 255, dtype='uint8')
    seg_and_orig[wing_idx] = img[wing_idx]
    seg_and_orig[:, img.shape[1]:] = img

    mask_fp = f'../2_live_bees/2_card_masks/{date}_{bee_id}.png'
    seg_fp = f'../2_live_bees/2_card_segs/{date}_{bee_id}.png'
    seg_and_orig_fp = f'../2_live_bees/2_card_segs_and_orig/{date}_{bee_id}.png'

    cv2.imwrite(mask_fp, mask2 * 255)
    cv2.imwrite(seg_fp, seg2)
    cv2.imwrite(seg_and_orig_fp, seg_and_orig)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1193/1193 [00:46<00:00, 25.79it/s]


In [12]:
import shutil

# Dry run: build the source/destination paths for copying every mask that is not
# listed in `problem_wings` into '2_card_masks_corrected', with a trailing
# '_L.png' / '_R.png' replaced by '.png'. The actual copy is commented out.
# NOTE(review): `seen` and `problem_wings` are not defined in any cell visible
# here - presumably left over from an earlier session; confirm before running.
for fn in seen:
    if fn in problem_wings:
        continue
    new_fn = fn.replace('_L.png','.png').replace('_R.png','.png')
    old_fp = f'../2_live_bees/2_card_masks/{fn}'
    new_fp = f'../2_live_bees/2_card_masks_corrected/{new_fn}'
    #shutil.copyfile(old_fp, new_fp)

In [14]:
import shutil

# Dry run: for each name in `seen`, work out the path it would be renamed to
# inside '2_card_masks' once a trailing '_L.png' / '_R.png' is replaced by '.png'.
# The move itself is commented out, so this cell changes nothing on disk.
# NOTE(review): `seen` is not defined in any cell visible here - confirm where it
# comes from before re-enabling the move.
for fn in seen:
    new_fn = fn.replace('_L.png','.png').replace('_R.png','.png')
    old_fp = f'../2_live_bees/2_card_masks/{fn}'
    new_fp = f'../2_live_bees/2_card_masks/{new_fn}'
    if not os.path.exists(old_fp):
        continue
    #shutil.move(old_fp, new_fp)

In [2]:
# Create final masks folder based on manual annotation if available, or mask if not.
#
# For every card image the manually corrected mask is copied to '2_final_masks'
# when one exists; otherwise the automatically generated mask is used. Cards that
# have neither are reported after the loop instead of aborting the run half-way.

img_fps = sorted(glob('../2_live_bees/1_cards/*'))

# shutil.copyfile does not create the destination folder.
os.makedirs('../2_live_bees/2_final_masks/', exist_ok=True)

cards_without_mask = []

for img_fp in tqdm(img_fps):
    img_fp = Path(img_fp)
    fn = img_fp.name

    mask_fp = '../2_live_bees/2_card_masks/' + fn
    manual_mask_fp = '../2_live_bees/2_card_mask_manual_corrections/' + fn
    final_fp = '../2_live_bees/2_final_masks/' + fn

    if os.path.exists(manual_mask_fp):
        # A manual correction always takes precedence over the automatic mask.
        shutil.copyfile(manual_mask_fp, final_fp)
    elif os.path.exists(mask_fp):
        shutil.copyfile(mask_fp, final_fp)
    else:
        cards_without_mask.append(fn)

if cards_without_mask:
    print(f'{len(cards_without_mask)} card(s) have neither a manual nor an automatic mask:')
    for fn in cards_without_mask:
        print('  ' + fn)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1193/1193 [00:00<00:00, 3264.09it/s]


In [3]:
# Remake segmentations including final masks.
#
# For every card image, load its mask from '2_final_masks' and rewrite
#   * '2_card_segs/<name>':          the masked pixels on a white background
#   * '2_card_segs_and_orig/<name>': that segmentation next to the original card
# Cards whose image or mask cannot be read are reported and skipped.

img_fps = sorted(glob('../2_live_bees/1_cards/*'))


for img_fp in tqdm(img_fps):
    img_fp = Path(img_fp)
    fn = img_fp.name

    mask_fp = '../2_live_bees/2_final_masks/' + fn
    seg_fp = '../2_live_bees/2_card_segs/' + fn
    seg_and_orig_fp = '../2_live_bees/2_card_segs_and_orig/' + fn

    # cv2.imread returns None (it does not raise) when a file is missing or unreadable.
    img = cv2.imread(str(img_fp))
    mask = cv2.imread(mask_fp, cv2.IMREAD_GRAYSCALE)
    if img is None or mask is None:
        tqdm.write(f'Skipping {fn}: could not read the image or its final mask')
        continue

    # Pixel coordinates of the mask, computed once and reused for the copy.
    wing_idx = np.where(mask > 0)

    seg = np.full(img.shape, 255, dtype='uint8')
    seg[wing_idx] = img[wing_idx]

    # Left half: segmentation on white; right half: untouched original.
    seg_and_orig = np.hstack((seg, img))

    cv2.imwrite(seg_and_orig_fp, seg_and_orig)
    cv2.imwrite(seg_fp, seg)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1193/1193 [14:50<00:00,  1.34it/s]
