In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [2]:
# Helper adapted from the Segment Anything example notebooks; overlays masks on the current matplotlib axes.
def show_anns(anns):
    """Overlay segmentation masks on the current matplotlib axes.

    Parameters
    ----------
    anns : list of dict
        Mask records such as those produced by ``SamAutomaticMaskGenerator.generate``;
        each must provide ``'segmentation'`` (a 2-D mask array) and ``'area'``.

    Masks are drawn largest-first, so smaller masks end up on top and stay
    visible. Each mask gets a random solid colour at 35% opacity. Nothing is
    drawn for an empty list.
    """
    if len(anns) == 0:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # RGBA overlay: one random colour everywhere, alpha 0.35 inside the mask and 0 outside.
        overlay = np.empty((mask.shape[0], mask.shape[1], 4))
        overlay[:, :, :3] = np.random.random(3)
        overlay[:, :, 3] = mask * 0.35
        ax.imshow(overlay)

# Mask generation

In [3]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Run on the GPU when one is available. A bare `sam.to()` with no device is a
# no-op, which silently left the model on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Generator with SAM's default settings (the tuned one below is what the loop uses).
mask_generator = SamAutomaticMaskGenerator(sam)

In [4]:
# Tuned generator used by the processing loop. Parameter meanings are documented
# in SamAutomaticMaskGenerator; the values below favour a few large masks.
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    # prompt grid density and crop settings
    points_per_side=8,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    # mask quality filters
    pred_iou_thresh=0.86,
    stability_score_thresh=0.9,
    # post-processing: clean up small holes / disconnected regions below this area (px)
    min_mask_region_area=300000,
)

In [5]:
# files = ['Cannabis_seeds/image288.jpg', 'Cannabis_seeds/image821.jpg']

In [6]:
# Reference colour used by `closest` to pick the seed mask. It is compared against
# mean colours of RGB-converted images, so the channel order is R, G, B.
# Presumably measured from a sample seed — TODO confirm provenance.
brown = np.array((93.91798457472527, 74.24539142925545, 44.09861938494327))

In [7]:
def closest(colors, target=None):
    """Return the index of the colour nearest (Euclidean distance) to `target`.

    Parameters
    ----------
    colors : array-like, shape (n, 3) or (3,)
        Candidate colours, one per row; a single colour is treated as one row.
    target : array-like, shape (3,), optional
        Reference colour. Defaults to the module-level ``brown``.

    Returns
    -------
    int
        Row index of the nearest colour. Ties go to the first such row.
        Rows whose distance is NaN (e.g. a NaN mean colour) are ignored.

    Raises
    ------
    ValueError
        If `colors` is empty or every row yields a NaN distance.
    """
    if target is None:
        target = brown
    colors = np.atleast_2d(np.asarray(colors, dtype=float))
    if colors.size == 0:
        raise ValueError("closest() needs at least one colour")
    distances = np.linalg.norm(colors - np.asarray(target, dtype=float), axis=1)
    # nanargmin skips NaN rows; it raises ValueError if all of them are NaN.
    return int(np.nanargmin(distances))

In [8]:
def bm1(mask1, mask2):
    """Return the number of pixels that are set in both masks (intersection area).

    Parameters
    ----------
    mask1, mask2 : array-like of the same shape
        Masks; any non-zero / True element counts as set.

    Returns
    -------
    int
        Count of positions where both masks are set.

    The previous version also computed an IoU that was never returned. That
    division raised ZeroDivisionError when both masks were empty, so it was removed.
    """
    return np.count_nonzero(np.logical_and(mask1, mask2))

In [None]:
import os

directory = 'tests'
output_dir = 'outputs'

# Only masks larger than this many pixels are considered seed candidates.
BIG_AREA_MIN = 40000
# A candidate counts as lying inside the current seed mask when its area and its
# overlap with the seed differ by fewer than this many pixels.
CONTAINMENT_TOLERANCE = 500


def _mean_color(image, mask):
    """Mean colour of `image` over the pixels where boolean `mask` is True.

    All mask pixels are used. A channel value of 0 is real data, not background,
    so it must not be dropped the way a per-channel np.nonzero filter would.
    """
    return image[mask].mean(axis=0)


def _pick_seed_mask(image, masks, big_areas):
    """Return the index into `masks` of the mask judged to be the seed.

    Starts from the big mask whose mean colour is closest to `brown`. It then walks
    the other big masks in order and switches to any mask that lies inside the
    current choice (within CONTAINMENT_TOLERANCE pixels).
    """
    colors = [_mean_color(image, np.asarray(masks[i]['segmentation'], dtype=bool))
              for i in big_areas]
    seed = big_areas[closest(colors)]
    for j in big_areas:
        if j == seed:
            continue  # comparing a mask with itself is pointless
        overlap = bm1(masks[seed]['segmentation'], masks[j]['segmentation'])
        if abs(overlap - masks[j]['area']) < CONTAINMENT_TOLERANCE:
            seed = j
    return seed


os.makedirs(output_dir, exist_ok=True)

for entry in sorted(os.listdir(directory)):
    f = os.path.join(directory, entry)
    print(entry)

    image = cv2.imread(f)
    if image is None:
        # Not a readable image (sub-directory, checkpoint folder, stray file, ...).
        print("skipped: not an image")
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    masks2 = mask_generator_2.generate(image)
    print("masked")

    big_areas = [i for i, m in enumerate(masks2) if m['area'] > BIG_AREA_MIN]
    if not big_areas:
        print("skipped: no mask larger than", BIG_AREA_MIN, "px")
        continue
    print("found big areas")

    seed_mask_number = _pick_seed_mask(image, masks2, big_areas)
    print("got seed number")

    seed_mask = np.asarray(masks2[seed_mask_number]['segmentation'], dtype=bool)
    # Keep the seed pixels unchanged (including genuinely black ones) and paint
    # everything outside the mask white.
    result = image.copy()
    result[~seed_mask] = 255
    print("got masked array")

    filename = os.path.splitext(entry)[0]
    # `image` is already RGB (converted above), so imshow displays it correctly.
    fig, ax = plt.subplots()
    ax.imshow(result)
    fig.savefig(os.path.join(output_dir, filename + '.png'))
    # Close each figure. Otherwise every iteration draws onto the same axes
    # and the images pile up in memory.
    plt.close(fig)

image236.jpg
