# Functionality to split images into SAM segments and store each segment as an image

## setup

In [1]:
import os, sys, torch, math, gc, numpy as np, cv2
from tqdm.notebook import tqdm

# Run SAM on the first GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [2]:
IS_COLAB = 'google.colab' in sys.modules

if IS_COLAB:
    # On Colab, weights, inputs and outputs live on the mounted Google Drive.
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    base_path = '/content/drive/MyDrive/ba'
else:
    base_path = os.getcwd()

# paths
weights_dir_path = os.path.join(base_path, 'weights')
weights_path = os.path.join(weights_dir_path, 'sam_vit_h_4b8939.pth')
data_path = os.path.join(base_path, 'raw_data', 'mattress_target')
output_path = os.path.join(base_path, 'output_segmented', 'mattress_target')

# sam settings
model_type = 'vit_h'  # SAM variant; keep in sync with the checkpoint file name above
points_per_batch = 16  # change based on available gpu memory and model

## download weights

In [3]:
# Download the SAM ViT-H checkpoint once; later runs reuse the file already on disk.
import urllib.request

weights_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
weights_download_needed = not os.path.isfile(weights_path)

if weights_download_needed:
    # Pure-Python replacement for `%mkdir -p` / `!wget`, so the cell also works where wget is missing.
    os.makedirs(weights_dir_path, exist_ok=True)
    # Download under a temporary name and rename only once complete, so an interrupted download
    # never leaves a truncated checkpoint that the isfile() check above would accept on the next run.
    partial_path = weights_path + '.part'
    urllib.request.urlretrieve(weights_url, partial_path)
    os.replace(partial_path, weights_path)

if not os.path.isfile(weights_path):
    raise FileNotFoundError(f'sam weights were not found: {weights_path}')

## install and import remaining dependencies

In [4]:
# Install segment-anything and supervision only if they are not importable yet (e.g. on a fresh Colab runtime).
import importlib
import subprocess

try:
    import supervision, segment_anything
    deps_install_needed = False
except ImportError:  # only a missing package triggers an install; other errors surface as-is
    deps_install_needed = True

if deps_install_needed:
    # `sys.executable -m pip` installs into this kernel's own interpreter; check_call raises
    # if pip fails instead of letting the imports below fail with a confusing error.
    # sam python package
    subprocess.check_call([sys.executable, '-m', 'pip', 'install',
                           'git+https://github.com/facebookresearch/segment-anything.git'])
    # other packages
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q',
                           'jupyter_bbox_widget', 'roboflow', 'dataclasses-json', 'supervision'])
    # Packages installed after the interpreter started may be hidden by cached path finders.
    importlib.invalidate_caches()

import supervision as sv
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

## Methods to segment images and store the segment bboxes and cuts

In [5]:
def gen_annotated_image(img_path, sam_masks):
    """Return the image at ``img_path`` with all SAM masks drawn over it.

    Args:
        img_path: path of the image the masks were generated for.
        sam_masks: mask records as produced by ``SamAutomaticMaskGenerator.generate``.

    Returns:
        The image (as loaded by ``cv2.imread``) annotated by ``sv.MaskAnnotator``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img_bgr = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file is missing or unreadable.
    if img_bgr is None:
        raise FileNotFoundError(f'could not read image: {img_path}')
    detections = sv.Detections.from_sam(sam_result=sam_masks)
    return sv.MaskAnnotator().annotate(scene=img_bgr, detections=detections)

def gen_segment_images(img_bgr, sam_mask):
    """Crop one SAM segment out of an image.

    Args:
        img_bgr: image array of shape (H, W, C) that the mask was generated for.
        sam_mask: SAM mask record with 'bbox' as (x, y, w, h) and 'segmentation'
            as a boolean array of shape (H, W).

    Returns:
        (segment_bbox, segment_cut): the bounding-box crop of the image, and the same
        crop with every pixel outside the mask set to 0.
    """
    x, y, w, h = (int(v) for v in sam_mask['bbox'])
    # SAM computes w/h as (last - first) pixel index, i.e. its boxes include the right and
    # bottom edge pixels. +1 keeps that last column/row (and avoids an empty crop for
    # 1-pixel-wide masks); numpy slicing clamps at the image border.
    rows = slice(y, y + h + 1)
    cols = slice(x, x + w + 1)
    segment_bbox = img_bgr[rows, cols]
    # Mask only the cropped region instead of the whole image for every segment.
    segment_mask = sam_mask['segmentation'][rows, cols]
    segment_cut = np.where(segment_mask[:, :, None], segment_bbox, 0)
    return segment_bbox, segment_cut

def store_segments(sam_masks, img_bgr, img_name, output_path, silent=False):
    """Write every segment of one image to disk as a bbox crop and a masked cut.

    Files are written to
        <output_path>/<image name without extension>/segment_bbox/bbox_<i>.jpg
        <output_path>/<image name without extension>/segment_cuts/cut_<i>.jpg
    where <i> is the mask's position in ``sam_masks``, zero-padded to the number of
    digits in ``len(sam_masks)``.

    Args:
        sam_masks: SAM mask records (dicts with 'bbox' and 'segmentation').
        img_bgr: the image the masks belong to, in OpenCV's BGR channel order.
        img_name: file name of the image; its name without extension names the output folder.
        output_path: root directory of the segmented output.
        silent: if True, no per-segment progress bar is shown.

    Raises:
        IOError: if OpenCV reports that a segment image could not be written.
    """
    # splitext strips only the final extension, so 'a.v1.jpg' and 'a.v2.jpg' get separate folders.
    img_stem = os.path.splitext(img_name)[0]
    output_img_path = os.path.join(output_path, img_stem)
    output_img_bbox_path = os.path.join(output_img_path, 'segment_bbox')
    output_img_cut_path = os.path.join(output_img_path, 'segment_cuts')
    os.makedirs(output_img_bbox_path, exist_ok=True)
    os.makedirs(output_img_cut_path, exist_ok=True)

    # Same width as int(log10(n)) + 1 for n >= 1, but exact and safe when no masks were
    # found (math.log10(0) raises ValueError).
    num_digits = len(str(len(sam_masks)))
    for i, sam_mask in enumerate(sam_masks if silent else tqdm(sam_masks)):
        segment_bbox, segment_cut = gen_segment_images(img_bgr, sam_mask)
        outputs = (
            (os.path.join(output_img_bbox_path, f'bbox_{i:0{num_digits}}.jpg'), segment_bbox),
            (os.path.join(output_img_cut_path, f'cut_{i:0{num_digits}}.jpg'), segment_cut),
        )
        for file_path, image in outputs:
            # cv2.imwrite reports failure through its return value rather than raising.
            if not cv2.imwrite(file_path, image):
                raise IOError(f'could not write segment image: {file_path}')

def proccess_img(mask_generator, img_dir, img_name, output_path, silent=False):
    """Segment one image and store its segments, highest predicted IoU first.

    Args:
        mask_generator: a ``SamAutomaticMaskGenerator`` (anything with ``generate(img_rgb)``).
        img_dir: directory containing the image.
        img_name: file name of the image inside ``img_dir``.
        output_path: root directory of the segmented output.
        silent: if True, no per-segment progress bar is shown while storing.

    Raises:
        ValueError: if OpenCV cannot read ``img_name`` as an image.
    """
    img_path = os.path.join(img_dir, img_name)
    img_bgr = cv2.imread(img_path)
    # cv2.imread returns None for missing/non-image files; without this check the failure
    # surfaces as an obscure error inside cv2.cvtColor.
    if img_bgr is None:
        raise ValueError(f'could not read image: {img_path}')
    # The generator receives RGB; the BGR original is kept because cv2.imwrite expects BGR.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    sam_masks = mask_generator.generate(img_rgb)
    # Highest predicted IoU first, so index 0 in the output is the most confident segment.
    sorted_sam_masks = sorted(sam_masks, key=lambda mask: mask['predicted_iou'], reverse=True)
    store_segments(sorted_sam_masks, img_bgr, img_name, output_path, silent=silent)

def proccess_img_dir(img_dir, model_type, weights_path, output_path, points_per_batch):
    """Segment every image file in ``img_dir`` with one SAM model and store the segments.

    Args:
        img_dir: directory with the input images (hidden files and sub-directories are skipped).
        model_type: key into ``sam_model_registry`` (e.g. 'vit_h'); must match the checkpoint.
        weights_path: path of the SAM checkpoint to load.
        output_path: root directory of the segmented output.
        points_per_batch: forwarded to ``SamAutomaticMaskGenerator``; lower it to save GPU memory.
    """
    # The model is loaded once and moved to the notebook-level `device`.
    sam = sam_model_registry[model_type](checkpoint=weights_path).to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=points_per_batch)
    # Hidden files (e.g. .DS_Store) and sub-directories are not images and would make
    # cv2.imread return None; sorting gives a reproducible processing order.
    img_names = sorted(
        name for name in os.listdir(img_dir)
        if not name.startswith('.') and os.path.isfile(os.path.join(img_dir, name))
    )
    for img_name in tqdm(img_names):
        proccess_img(mask_generator, img_dir, img_name, output_path, silent=True)

In [None]:
# Segment every image in data_path and write the results below output_path.
proccess_img_dir(
    img_dir=data_path,
    model_type=model_type,
    weights_path=weights_path,
    output_path=output_path,
    points_per_batch=points_per_batch,
)

In [7]:
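# Debugging aid: uncomment to segment a single image (with a per-segment progress bar) instead of the whole directory.
#sam = sam_model_registry[model_type](checkpoint=weights_path).to(device=device)
#mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=points_per_batch)
#proccess_img(mask_generator, data_path, 'mattress1.jpg', output_path)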