## Turn bounding boxes into segmented images using SAM (Segment Anything Model) mask predictions
##### (Warning: this is only a proof of concept. Running this notebook requires the open fishery dataset.)

# Setup

In [1]:
import os, json, cv2, torch, numpy as np, pandas as pd, matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Use the first CUDA GPU when torch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
def process_df(df, count_per_label=20):
    """Keep only fish rows, add an ``img_cat`` column and sample rows per label.

    Args:
        df: DataFrame with at least the columns ``label_l2`` and ``img_id``.
            The input frame is not modified.
        count_per_label: Maximum number of rows to keep for each ``label_l2``
            value. Labels with fewer rows keep all of them instead of raising.

    Returns:
        A new DataFrame with a fresh 0..n-1 index, ordered by label, holding
        the sampled rows plus the ``img_cat`` column (mapped from the first
        three characters of ``img_id``; unmapped prefixes become NaN).
    """
    fish_labels = ['YFT', 'ALB', 'OTH', 'BILL', 'DOL', 'BET', 'SKJ']
    img_cat_by_prefix = {'94a': 'A', '94b': 'B', '94c': 'C', '94d': 'D', '94e': 'E', '94f': 'F'}

    # only include rows that are fish labels
    df = df[df['label_l2'].isin(fish_labels)].copy()
    df['img_cat'] = df['img_id'].str[:3].map(img_cat_by_prefix)

    # pd.concat cannot take an empty list, so return the (empty) frame directly
    if df.empty:
        return df.reset_index(drop=True)

    # Sample each label group separately. Iterating the groups (instead of
    # groupby.apply) keeps the 'label_l2' column in the result regardless of
    # how the installed pandas version treats grouping columns, and
    # min(...) avoids an error when a label has fewer rows than requested.
    samples = [
        group.sample(n=min(len(group), count_per_label), random_state=42)
        for _, group in df.groupby('label_l2')
    ]
    return pd.concat(samples).reset_index(drop=True)

In [4]:
def build_bboxes_dict(df_split):
    """Build the nested bounding-box lookup used by the segment-storing functions.

    Args:
        df_split: DataFrame with the columns ``img_id``, ``label_l2``,
            ``x_min``, ``y_min``, ``x_max`` and ``y_max``.

    Returns:
        ``{'bboxes': {'<img_id>.jpg': {label: [np.array([x_min, y_min, x_max, y_max]), ...]}}}``.
        Boxes are grouped per image AND per label, so an image whose rows carry
        different labels keeps each box under its own label instead of filing
        all of them under the label of the image's first row.
    """
    bboxes = {}
    for (img_id, label), rows in df_split.groupby(['img_id', 'label_l2']):
        coords = rows[['x_min', 'y_min', 'x_max', 'y_max']].values
        bboxes.setdefault(f'{img_id}.jpg', {})[label] = [np.array(bbox) for bbox in coords]
    return {'bboxes': bboxes}


# --- paths -----------------------------------------------------------------
base_path = os.getcwd()
weights_dir_path = os.path.join(base_path, 'weights')
weights_path = os.path.join(weights_dir_path, 'sam_vit_h_4b8939.pth')

raw_ds = 'fishnet_v100'
data_path =  os.path.join(base_path, 'raw_data', raw_ds)
image_data_path = data_path


# --- annotations -----------------------------------------------------------
df = pd.read_csv(os.path.join(data_path, '_annotations.csv'))

# The split files are read from disk; the trailing comments show the
# process_df calls that can be used to derive them from `df`.
df_train = pd.read_csv(os.path.join(image_data_path, '_df_train.csv')) #process_df(df[df['train']].copy(), count_per_label=1000)
bboxes_dict_train = build_bboxes_dict(df_train)
labeled_output_path_train =  os.path.join(base_path, 'output_labeled', f'{raw_ds}_01_labeled', 'train')
helper_output_path_train =  os.path.join(base_path, 'output_labeled_extra', f'{raw_ds}_01_labeled_extra', 'train')

df_val = pd.read_csv(os.path.join(image_data_path, '_df_val.csv')) #process_df(df[df['val']].copy(), count_per_label=300)
bboxes_dict_val = build_bboxes_dict(df_val)
labeled_output_path_val =  os.path.join(base_path, 'output_labeled', f'{raw_ds}_01_labeled', 'val')
helper_output_path_val =  os.path.join(base_path, 'output_labeled_extra', f'{raw_ds}_01_labeled_extra', 'val')


# --- model -----------------------------------------------------------------
model_type = 'vit_h'
points_per_batch = 16 # change based on available gpu memory and model

#raw_ds = 'fishery_simple_v1'
#bboxes_dict = {'bboxes': {filename: {vals['class'].values[0]: [np.array(bbox) for bbox in vals[['xmin', 'ymin', 'xmax', 'ymax']].values]} for filename, vals in df.groupby('filename

In [5]:
# Show (rows, columns) of the train and validation annotation frames as a quick sanity check.
df_train.shape, df_val.shape

((7000, 16), (2100, 16))

#### install remaining dependencies, if needed

In [6]:
try:
    import supervision, segment_anything
    deps_install_needed = False
except:
    deps_install_needed = True

if deps_install_needed:
    # sam python package
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    # other packages
    %pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

    import supervision as sv
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

import supervision as sv
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam = sam_model_registry[model_type](checkpoint=weights_path).to(device=device)

#### Download sam weights, if needed

In [7]:
weights_download_needed = not os.path.isfile(weights_path)

if weights_download_needed:        
    
    %mkdir -p {weights_dir_path}
    !wget -P {weights_dir_path} 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'


if not os.path.isfile(weights_path): raise Exception('sam weights were not found')

## bboxes to segments and cuts

# Util

In [8]:
def iou(bbox1, bbox2):
    """Return the intersection over union of two axis-aligned boxes.

    Args:
        bbox1: Sequence ``(x_min, y_min, x_max, y_max)``.
        bbox2: Sequence ``(x_min, y_min, x_max, y_max)``.

    Returns:
        A value in ``[0.0, 1.0]``. Boxes that do not overlap, and pairs whose
        union has no area (e.g. two zero-width boxes), yield ``0.0``.

    Raises:
        AssertionError: If a box is inverted (``x_min > x_max`` or
            ``y_min > y_max``).
    """
    # Zero-width/height boxes are accepted (a one-pixel-wide mask can produce
    # x_min == x_max); only inverted boxes are rejected.
    for box in (bbox1, bbox2):
        assert box[0] <= box[2]
        assert box[1] <= box[3]

    x_left = max(bbox1[0], bbox2[0])
    y_top = max(bbox1[1], bbox2[1])
    x_right = min(bbox1[2], bbox2[2])
    y_bottom = min(bbox1[3], bbox2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

    union_area = float(bbox1_area + bbox2_area - intersection_area)
    # Both boxes are degenerate: avoid dividing by zero
    if union_area <= 0.0:
        return 0.0

    iou_value = intersection_area / union_area
    assert 0.0 <= iou_value <= 1.0
    return iou_value

In [9]:
def convert_widget_box(widget_box):
    """Convert a box mapping with 'x', 'y', 'width', 'height' into an xyxy array.

    Returns a NumPy array ``[x, y, x + width, y + height]``.
    """
    left, top = widget_box['x'], widget_box['y']
    right = left + widget_box['width']
    bottom = top + widget_box['height']
    return np.array([left, top, right, bottom])

In [10]:
def best_detection_for_bbox(mask_predictor, image_bgr, image_bbox):
    """Prompt the predictor with one box and keep the largest-area mask.

    The BGR image is converted to RGB and set on ``mask_predictor`` before
    predicting with ``multimask_output=True``. The candidate(s) whose area
    equals the maximum area are returned as an ``sv.Detections``; the
    predicted scores and logits are not used for the selection.
    """
    mask_predictor.set_image(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    masks, _scores, _logits = mask_predictor.predict(box=image_bbox, multimask_output=True)
    candidates = sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)
    is_largest = candidates.area == np.max(candidates.area)
    return candidates[is_largest]

In [11]:
def get_best_detection(sam_detections, label_detection):
    """Pick the SAM detection whose box overlaps the labelled box the most.

    Overlap is measured with ``iou`` on the first box of each detection.
    Returns ``None`` when no candidate has an IoU above zero; on ties the
    earliest candidate wins.
    """
    winner = None
    top_overlap = 0
    for candidate in sam_detections:
        overlap = iou(label_detection.xyxy[0], candidate.xyxy[0])
        if not overlap > top_overlap:
            continue
        winner, top_overlap = candidate, overlap
    return winner

In [12]:
def convert_detection(image_bgr, detection):
    """Extract the mask, the box crop and the masked crop for one detection.

    Args:
        image_bgr: Image array indexed as ``[row, column, channel]``.
        detection: Object exposing ``xyxy`` (first row ``x1, y1, x2, y2``) and
            ``mask`` (first entry a 2-D mask matching the image height/width).

    Returns:
        ``(seg_mask, segment_bbox, segment_cut)`` where ``seg_mask`` is the
        full-size mask, ``segment_bbox`` is the crop ``image_bgr[y1:y2, x1:x2]``
        (a view of the input), and ``segment_cut`` is the same crop with every
        pixel outside the mask set to 0.
    """
    # Slice bounds must be integers; cast so that float-valued boxes do not raise.
    x1, y1, x2, y2 = np.asarray(detection.xyxy[0]).astype(int)
    seg_mask = detection.mask[0]
    # NOTE(review): the slices exclude row y2 and column x2. If xyxy stores
    # inclusive pixel maxima the crop is one pixel short - confirm against the
    # box source.
    segment_bbox = image_bgr[y1:y2, x1:x2]
    segment_cut = np.where(seg_mask[:, :, None], image_bgr, 0)[y1:y2, x1:x2]
    return seg_mask, segment_bbox, segment_cut

In [13]:
def plot_detection(image_bgr, detection):
    """Show a 2x2 grid for one detection.

    The grid holds the image with the mask overlaid, the mask itself, the box
    crop and the masked crop (the last three come from ``convert_detection``).
    """
    seg_mask, segment_bbox, segment_cut = convert_detection(image_bgr, detection)
    # The box-annotated source image was previously built here but never shown,
    # so that unused work has been removed.
    segmented_image = sv.MaskAnnotator(color=sv.Color.red()).annotate(scene=image_bgr.copy(), detections=detection)
    sv.plot_images_grid(
        images=[segmented_image, seg_mask, segment_bbox, segment_cut],
        grid_size=(2, 2),
        titles=['image segmented', 'mask', 'segment bbox', 'segment_cut']
    )

In [14]:
def xyxy_to_detection(xyxy):
    """Wrap one box array as an ``sv.Detections`` by adding a leading axis."""
    boxes = xyxy[None, :]
    return sv.Detections(xyxy=boxes)

def sam_result_to_detections(sam_result):
    """Turn each entry of ``sam_result`` into its own single-mask ``sv.Detections``.

    Every entry must provide a 2-D mask under the key ``'segmentation'``; the
    box is derived from that mask with ``sv.mask_to_xyxy``.
    """
    detections = []
    for result in sam_result:
        # add a leading axis so the mask forms a batch of one
        mask = result['segmentation'][None, :, :]
        detections.append(sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask))
    return detections

In [15]:
def save_segments(image_bgr, detection, image_name, label, output_path, helper_output_path, segment_nr, label_detection=None):
    """Write the masked crop of one detection and, for real labels, helper images.

    Args:
        image_bgr: Source image the detection refers to.
        detection: Detection passed to ``convert_detection``.
        image_name: File name of the source image; its stem (name without the
            final extension) prefixes every output file.
        label: Sub-directory of ``output_path`` for the crop. For the label
            ``'NONE'`` only the crop is written and no helper images.
        output_path: Root directory for the ``<stem>_<segment_nr>_CUT.jpg`` crop.
        helper_output_path: Root directory for the helper images, written to
            ``<helper_output_path>/<stem>/<label>/<segment_nr>/``.
        segment_nr: Number used in the output file names and helper directory.
        label_detection: Optional detection holding the labelled box; when
            given, an extra ``LABEL_BBOX`` image is written.
    """
    output_dir = os.path.join(output_path, label)
    os.makedirs(output_dir, exist_ok=True)

    seg_mask, segment_bbox, segment_cut = convert_detection(image_bgr, detection)

    # splitext only strips the final extension, so names that contain
    # additional dots are not truncated (and do not collide).
    image_stem = os.path.splitext(image_name)[0]
    out_name = f"{image_stem}_{segment_nr}"
    cv2.imwrite(os.path.join(output_dir, f'{out_name}_CUT.jpg'), segment_cut)

    if label == 'NONE':
        return

    helper_output_dir = os.path.join(helper_output_path, image_stem, label, str(segment_nr))
    os.makedirs(helper_output_dir, exist_ok=True)

    # File-name suffix -> image, in the order the files are written.
    helper_images = {}
    if label_detection is not None:
        helper_images['LABEL_BBOX'] = sv.BoxAnnotator(color=sv.Color.red()).annotate(scene=image_bgr.copy(), detections=label_detection, skip_label=True)
    helper_images['SAM_BBOX'] = sv.BoxAnnotator(color=sv.Color.red()).annotate(scene=image_bgr.copy(), detections=detection, skip_label=True)
    helper_images['SEGEMENTED'] = sv.MaskAnnotator(color=sv.Color.red()).annotate(scene=image_bgr.copy(), detections=detection)
    helper_images['MASK'] = seg_mask.astype('uint8') * 255
    helper_images['BBOX'] = segment_bbox
    helper_images['CUT'] = segment_cut

    for suffix, helper_image in helper_images.items():
        cv2.imwrite(os.path.join(helper_output_dir, f'{out_name}_{suffix}.jpg'), helper_image)

# SAM methods

In [16]:
def store_best_segment_for_labels(mask_predictor, bboxes_dict, image_data_dir, output_dir, extra_output_dir):
    """Segment every labelled box with the predictor and save the results.

    For each image in ``bboxes_dict['bboxes']`` the image is read from
    ``image_data_dir``; each box is passed to ``best_detection_for_bbox`` and
    the result is written with ``save_segments``. Segment numbers start at 1
    and restart for every label of an image.
    """
    boxes_per_image = bboxes_dict['bboxes']
    for image_name in tqdm(boxes_per_image):
        image_path = os.path.join(image_data_dir, image_name)
        image_bgr = cv2.imread(image_path)
        for label, bboxes in boxes_per_image[image_name].items():
            for segment_nr, bbox in enumerate(bboxes, start=1):
                best_detection = best_detection_for_bbox(mask_predictor, image_bgr, bbox)
                save_segments(
                    image_bgr, best_detection, image_name, label,
                    output_dir, extra_output_dir, segment_nr,
                    label_detection=xyxy_to_detection(bbox),
                )

In [17]:
def store_all_segments(mask_generator, image_names, image_data_path, output_dir):
    """Run the mask generator on every image and save each mask under 'NONE'.

    Images are read from ``image_data_path``, converted to RGB for the
    generator, and every resulting detection is written with
    ``save_segments`` using segment numbers starting at 1.
    """
    for image_name in tqdm(image_names):
        image_bgr = cv2.imread(os.path.join(image_data_path, image_name))
        sam_result = mask_generator.generate(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

        detections = sam_result_to_detections(sam_result)
        for segment_nr, detection in enumerate(detections, start=1):
            save_segments(image_bgr, detection, image_name, 'NONE', output_dir, '', segment_nr)

In [18]:
def store_all_segments_exclude_labels(mask_generator, bboxes_dict, image_data_path, output_dir):
    """Save every generated mask that is not the best match of a labelled box.

    For each image in ``bboxes_dict['bboxes']`` the mask generator is run.
    For every labelled box, the detection returned by ``get_best_detection``
    (if any) is dropped from the candidates. The remaining detections are
    written with ``save_segments`` under the label ``'NONE'``, numbered from 1.
    """
    for image_name in tqdm(bboxes_dict['bboxes']):
        image_bgr = cv2.imread(os.path.join(image_data_path, image_name))
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        sam_result = mask_generator.generate(image_rgb)
        sam_detections = sam_result_to_detections(sam_result)
        for bboxes in bboxes_dict['bboxes'][image_name].values():
            for bbox in bboxes:
                best_detection = get_best_detection(sam_detections, xyxy_to_detection(bbox))
                if best_detection is not None:
                    # Drop the chosen object by identity. list.remove compares
                    # with ==, which on array-holding detections depends on how
                    # the detection class implements equality.
                    sam_detections = [d for d in sam_detections if d is not best_detection]

        for segment_nr, sam_detection in enumerate(sam_detections, start=1):
            save_segments(image_bgr, sam_detection, image_name, 'NONE', output_dir, '', segment_nr)

In [19]:
def store_all_segments_seperate_labels(mask_generator, bboxes_dict, image_data_path, output_dir, helper_output_dir, row_nr=0):
    """Save all generated masks per image, filing label matches under their label.

    For each image in ``bboxes_dict['bboxes']`` the mask generator is run.
    The best-matching detection of every labelled box (if any) is saved under
    that box's label, with helper images below ``helper_output_dir``. Every
    remaining detection is saved under ``'NONE'``. All files of an image go
    below ``<output_dir>/NONE/<image_name>/`` and share one running segment
    number starting at 1.

    After each image the current ``row_nr`` is written to ``latest_row.txt``
    in the working directory and then incremented. The file therefore holds
    ``row_nr`` plus the index of the last completed image in this call.

    Args:
        mask_generator: Object whose ``generate(image_rgb)`` returns SAM results.
        bboxes_dict: ``{'bboxes': {image_name: {label: [bbox, ...]}}}``.
        image_data_path: Directory the images are read from.
        output_dir: Root directory for the segment crops.
        helper_output_dir: Root directory for the helper images of matches.
        row_nr: Progress counter value for the first image of this call.
    """
    for image_name in tqdm(bboxes_dict['bboxes']):
        image_bgr = cv2.imread(os.path.join(image_data_path, image_name))
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_output_dir = os.path.join(output_dir, 'NONE', image_name)

        sam_result = mask_generator.generate(image_rgb)
        sam_detections = sam_result_to_detections(sam_result)
        segment_nr = 1
        for label, bboxes in bboxes_dict['bboxes'][image_name].items():
            for bbox in bboxes:
                best_detection = get_best_detection(sam_detections, xyxy_to_detection(bbox))
                if best_detection is not None:
                    # Drop the chosen object by identity. list.remove compares
                    # with ==, which on array-holding detections depends on how
                    # the detection class implements equality.
                    sam_detections = [d for d in sam_detections if d is not best_detection]
                    save_segments(image_bgr, best_detection, image_name, label, image_output_dir, helper_output_dir, segment_nr)
                    segment_nr += 1

        for sam_detection in sam_detections:
            save_segments(image_bgr, sam_detection, image_name, 'NONE', image_output_dir, '', segment_nr)
            segment_nr += 1

        # Persist progress so an interrupted run can be resumed from this row.
        with open('latest_row.txt', 'w') as f:
            f.write(str(row_nr))
            row_nr += 1

# Execution 
### Warning: This might overwrite existing files 

In [None]:
# Deliberate stop so that "Run All" does not reach the cells below, which write
# segment images to disk. An explicit exception states the intent, where a bare
# `break` outside a loop would only produce a SyntaxError.
raise RuntimeError('Execution guard: the cells below may overwrite existing files - run them manually.')

In [None]:
#mask_predictor = SamPredictor(sam)
#store_best_segment_for_labels(mask_predictor, bboxes_dict_train, image_data_path, labeled_output_path_train, helper_output_path_train)
#store_best_segment_for_labels(mask_predictor, bboxes_dict_val, image_data_path, labeled_output_path_val, helper_output_path_val)

In [25]:
# Use the batch size from the configuration cell instead of a separate hard-coded value,
# so it can be tuned to the available GPU memory in one place.
mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=points_per_batch)

In [None]:
# TRAIN
# row_nr = 0
# bboxes_dict_train_row_nr = {'bboxes': {k: v for k, v in list(bboxes_dict_train['bboxes'].items())[row_nr:]}}
# len(bboxes_dict_train['bboxes']), len(bboxes_dict_train_row_nr['bboxes'])
#store_all_segments_seperate_labels(mask_generator, bboxes_dict_train_row_nr, image_data_path, labeled_output_path_train, os.path.join(helper_output_path_train, '_best_segments'), row_nr=row_nr)

In [29]:
# VAL
# `row_nr` skips the first images of the validation set; use the value stored in
# latest_row.txt to resume an interrupted run, or 0 to process every image.
row_nr = 0
bboxes_dict_val_row_nr = {'bboxes': dict(list(bboxes_dict_val['bboxes'].items())[row_nr:])}
store_all_segments_seperate_labels(mask_generator, bboxes_dict_val_row_nr, image_data_path, labeled_output_path_val, os.path.join(helper_output_path_val, '_best_segments'), row_nr=row_nr)

  0%|          | 0/1891 [00:00<?, ?it/s]