In [1]:
import os
import cv2
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from glob import glob
from scipy import ndimage
from copy import deepcopy

from utils import place_in_img, is_wing_facing_up

In [2]:
def get_center(mask):
    """Integer (row, col, ...) centroid of the foreground of ``mask``.

    A 2-D mask is thresholded at 0.5 directly. Any other mask is first averaged
    over axis 2 (e.g. colour channels) and then thresholded at 0.5.
    The centroid is the mean foreground coordinate, cast to ``int``
    (truncation toward zero).
    """
    plane = mask if mask.ndim == 2 else mask.mean(axis=2)
    foreground_coords = np.argwhere(plane > 0.5)  # shape (n_pixels, plane.ndim)
    return foreground_coords.mean(axis=0).astype('int')

In [3]:

#next_moving_mask = place_in_img(next_moving_mask, ref, next_pos, next_scale, next_angle)


In [4]:
def iou(mask, moving_mask):
    """Weighted intersection-over-union between a fixed mask and a moving template.

    Parameters
    ----------
    mask : np.ndarray
        Fixed mask; assumed binary (0s and 1s).
    moving_mask : np.ndarray
        Moving mask of the same shape; expected to hold values 0, 1 and 2,
        where 2 marks template pixels that should count double.

    Returns
    -------
    float
        ``sum(mask * moving_mask)`` divided by a weighted union. The weighted
        union is the number of pixels that are foreground (> 0.1) in either mask,
        plus one extra count for every moving pixel above 1.1. The extra count
        mirrors the double weight those pixels get in the intersection, so a
        perfect overlap scores exactly 1.0. Returns 0.0 when both masks are empty.
    """
    intersection = (moving_mask * mask).sum()

    either_foreground = np.logical_or(mask > 0.1, moving_mask > 0.1)
    double_weighted = moving_mask > 1.1
    union = int(either_foreground.sum()) + int(double_weighted.sum())

    # Both masks empty: no overlap to speak of. Avoid a 0/0 NaN.
    if union == 0:
        return 0.0
    return intersection / union

In [11]:
# --- Configuration -----------------------------------------------------------
FIX_PROBLEM_WINGS = False  # only re-register the wings listed in `problem_wings`
DEBUG = False              # register one hard-coded wing and plot every accepted move
SEED = None                # set to an int to make the shuffle and the search reproducible

n_iter = 100000            # random-search iterations per wing

scale_factor = 10          # stride used to downsample mask and reference before searching
max_pos_step = 30          # max translation per 'walk' move, in downsampled pixels
max_angle_step = 20        # max rotation per 'rotate' move (units as expected by place_in_img)
max_scale_step = 0.5       # max change per 'scale' move

# Hard walls of the search space.
min_scale, max_scale = 1.5, 10
max_abs_angle = 45

ref_fp = '../2_live_bees/type_mask_Hive01_Sheet_01_slide17_left.png'
mask_dir = '../2_live_bees/2_final_masks'
seg_dir = '../2_live_bees/2_card_segs'
match_dir = '../2_live_bees/3_card_mask_matches'
metadata_dir = '../2_live_bees/3_card_mask_matches_metadata'

# Wings whose registration needed a manual re-run. Defined here so this cell
# also works on a fresh kernel when FIX_PROBLEM_WINGS is True.
problem_wings = ['2024_06_18_h11b15.png', '2024_07_02_h23b11.png', '2024_07_16_h32b18.png']

# Wings forced to "facing up" regardless of is_wing_facing_up()
# (presumably misclassified by it — confirm).
forced_facing_up = {'2024_06_27_h01b21.png', '2024_07_02_h13b29.png', '2024_06_07_h02bee19.png'}


def _search_pose(mask, ref, start_pos):
    """Greedy random search for the pose of `ref` that best overlaps `mask`.

    The search starts at `start_pos` with scale 2 and angle 0. Each of the
    `n_iter` iterations perturbs translation, rotation or scale (chosen
    uniformly), then clamps the proposal to the search walls. A proposal is
    kept only if it strictly improves iou(mask, placed ref). This is hill
    climbing, not Metropolis-Hastings: worse moves are never accepted.

    Returns (pos, scale, angle, score) of the best pose found.
    """
    pos = np.array(start_pos)
    scale = 2
    angle = 0
    placed = place_in_img(np.zeros(mask.shape), ref, start_pos, ref_scale=2)
    score = iou(mask, placed)

    for i in range(n_iter):
        # No in-place updates below, so plain references are safe (no deepcopy).
        next_pos, next_scale, next_angle = pos, scale, angle
        move = np.random.choice(['walk', 'rotate', 'scale'])
        if move == 'walk':
            x_walk = np.random.randint(-max_pos_step, max_pos_step)
            y_walk = np.random.randint(-max_pos_step, max_pos_step)
            next_pos = pos + np.array([y_walk, x_walk])
        elif move == 'rotate':
            next_angle = angle + np.random.uniform(-max_angle_step, max_angle_step)
        else:  # 'scale'
            next_scale = scale + np.random.uniform(-max_scale_step, max_scale_step)

        next_scale = min(max(next_scale, min_scale), max_scale)
        next_angle = min(max(next_angle, -max_abs_angle), max_abs_angle)

        placed = place_in_img(np.zeros(mask.shape), ref, next_pos, next_scale, next_angle)
        next_score = iou(mask, placed)

        # Compare directly instead of via next_score / score, which divides
        # by zero while the current pose has no overlap at all.
        if next_score > score:
            pos, scale, angle, score = next_pos, next_scale, next_angle, next_score
            if DEBUG:
                fig, ax = plt.subplots()
                ax.imshow(mask + placed)
                ax.set_title(f'Accepted Iter {i} score:{score}')

    return pos, scale, angle, score


# --- Choose which wings to register ------------------------------------------
if SEED is not None:
    np.random.seed(SEED)

mask_fps = sorted(glob(os.path.join(mask_dir, '*')))
np.random.shuffle(mask_fps)

if FIX_PROBLEM_WINGS:
    mask_fps = [os.path.join(mask_dir, x) for x in problem_wings]
if DEBUG:
    mask_fps = [os.path.join(mask_dir, '2024_06_10_h02b01.png')]

os.makedirs(match_dir, exist_ok=True)
os.makedirs(metadata_dir, exist_ok=True)

# The reference template is the same for every wing: load and normalise it once.
ref_full = cv2.imread(ref_fp, cv2.IMREAD_GRAYSCALE)
if ref_full is None:
    raise FileNotFoundError(f'Could not read reference mask {ref_fp}')
ref_full = ref_full / ref_full.max()

# --- Register each wing ------------------------------------------------------
for mask_fp in tqdm(mask_fps):
    mask_fn = os.path.basename(mask_fp)
    _id = os.path.splitext(mask_fn)[0]
    metadata_fp = os.path.join(metadata_dir, f'{_id}.json')

    # Skip already-registered wings before doing any image I/O.
    if os.path.exists(metadata_fp) and not DEBUG and not FIX_PROBLEM_WINGS:
        continue

    mask = cv2.imread(mask_fp, cv2.IMREAD_GRAYSCALE)
    seg = cv2.imread(os.path.join(seg_dir, mask_fn))
    if mask is None or seg is None:
        tqdm.write(f'Skipping {mask_fn}: could not read mask or card segmentation')
        continue

    fg_rows, fg_cols = np.nonzero(mask > 0.5)
    if fg_rows.size == 0:
        tqdm.write(f'Skipping {mask_fn}: mask has no foreground')
        continue
    # Min/max foreground indices at full resolution (both ends inclusive).
    sy, ey = fg_rows.min(), fg_rows.max()
    sx, ex = fg_cols.min(), fg_cols.max()

    mask = mask / mask.max()
    orig_shape = mask.shape

    is_facing_up = bool(is_wing_facing_up(seg, mask))
    if mask_fn in forced_facing_up:
        is_facing_up = True

    # Flip at full resolution, then downsample (same order as before).
    ref_oriented = ref_full if is_facing_up else np.flipud(ref_full)
    mask = mask[::scale_factor, ::scale_factor]
    # .copy(): the strided slice is a view, and the in-place doubling below
    # must not leak back into ref_full, which is reused for every wing.
    ref = ref_oriented[::scale_factor, ::scale_factor].copy()
    ref[:, ref.shape[1] // 2:] *= 2  # weight the second half of the template double

    pos, scale, angle, score = _search_pose(mask, ref, get_center(mask))

    final_moving_mask = place_in_img(np.zeros(mask.shape), ref, pos, scale, angle)
    overlay = cv2.resize(mask + final_moving_mask, orig_shape[::-1]) * 75
    # Convert explicitly to a rounded, saturated 8-bit image instead of handing
    # imwrite a float64 array.
    overlay = np.clip(np.rint(overlay), 0, 255).astype(np.uint8)
    out_fp = os.path.join(match_dir, mask_fn)
    if not cv2.imwrite(out_fp, overlay):
        raise IOError(f'Failed to write {out_fp}')

    metadata = {
        'registration_ref': ref_fp,
        # pos was found on the downsampled grid; rescale to full resolution.
        'registration_pos': [int(v) for v in pos * scale_factor],
        'registration_angle': float(angle),
        'registration_scale': float(scale),
        'registration_score': float(score),
        'is_facing_up?': is_facing_up,
        'wing_side': 'right' if is_facing_up else 'left',
        'crop_y': [int(sy), int(ey)],
        'crop_x': [int(sx), int(ex)],
    }
    with open(metadata_fp, 'w') as f:
        json.dump(metadata, f)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [06:09<00:00, 123.33s/it]


In [10]:
# Wings whose automatic registration needed a manual re-run. The registration
# cell reads this list when FIX_PROBLEM_WINGS is True. NOTE: this cell sits
# below that cell, so on a fresh kernel it has to be executed first.
problem_wings = [
    '2024_06_18_h11b15.png',
    '2024_07_02_h23b11.png',
    '2024_07_16_h32b18.png',
]




In [5]:
# Make crops: cut each card segmentation down to the wing's bounding box and
# flip it vertically when the wing was not facing up, so all crops share one orientation.
card_seg_dir = '../2_live_bees/2_card_segs'
coarse_meta_dir = '../2_live_bees/3_card_mask_matches_metadata'
fine_meta_dir = '../2_live_bees/6_fine_registered_wings_metadata'
crop_dir = '../2_live_bees/6_cropped_and_flipped'
os.makedirs(crop_dir, exist_ok=True)

for fp in tqdm(sorted(glob(os.path.join(card_seg_dir, '*')))):
    fn = os.path.basename(fp)
    wing_name = os.path.splitext(fn)[0]

    coarse_fp = os.path.join(coarse_meta_dir, wing_name + '.json')
    fine_fp = os.path.join(fine_meta_dir, wing_name + '.json')
    # Only fine-registered wings are cropped. Check that first, so wings that
    # never got (coarse) registration metadata are skipped instead of crashing.
    if not os.path.exists(fine_fp):
        continue
    if not os.path.exists(coarse_fp):
        tqdm.write(f'Skipping {fn}: missing {coarse_fp}')
        continue

    with open(coarse_fp, 'r') as f:
        metadata = json.load(f)
    with open(fine_fp, 'r') as f:
        # Fine-registration values win on key clashes.
        metadata.update(json.load(f))

    img = cv2.imread(fp)
    if img is None:
        tqdm.write(f'Skipping {fn}: could not read image')
        continue

    sy, ey = metadata['crop_y']
    sx, ex = metadata['crop_x']
    # NOTE(review): if crop_y/crop_x are inclusive min/max foreground indices
    # (as the registration step computes them), this slice drops the last
    # row/column. Kept unchanged in case downstream steps rely on these exact
    # crop sizes — confirm before switching to ey + 1 / ex + 1.
    cropped = img[sy:ey, sx:ex]
    if not metadata['is_facing_up?']:
        cropped = np.flipud(cropped)

    cropped_fp = os.path.join(crop_dir, fn)
    if not cv2.imwrite(cropped_fp, cropped):
        raise IOError(f'Failed to write {cropped_fp}')


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1193/1193 [01:31<00:00, 13.02it/s]
