In [1]:
import os
import torch
from segment_anything import sam_model_registry
import cv2
from segment_anything import SamAutomaticMaskGenerator
import supervision as sv
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle
import numpy as np

In [2]:
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(device=DEVICE)



Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-31): 32 x Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

In [3]:
print(DEVICE)

cuda:0


In [4]:
def get_mask(result_dicts):
    result_dicts = sorted(result_dicts, key=lambda x: x['area'], reverse=True)
    mask = np.zeros((result_dicts[0]['segmentation'].shape[0], result_dicts[0]['segmentation'].shape[1])) == 1
    for result_dict in result_dicts[1:]:
        mask = mask | result_dict['segmentation']
    return mask

In [5]:
IMG_FOLDER = "patches_512"
# OUT_FOLDER = "annotated"
OUT_FOLDER_RESULT_DICTS = "result_dicts"
OUT_FOLDER_MASKS = "binary_masks"

for foldername in os.listdir(IMG_FOLDER)[4:5]:
    print(f"Currently on {foldername}")
    files = os.listdir(os.path.join(IMG_FOLDER, foldername))
    num_patches = len(files)

    # out_folder_path = os.path.join(OUT_FOLDER, foldername)
    # os.makedirs(out_folder_path, exist_ok=True)

    out_folder_mask_path = os.path.join(OUT_FOLDER_MASKS, foldername)
    os.makedirs(out_folder_mask_path, exist_ok=True)

    # out_folder_res_dict_path = os.path.join(OUT_FOLDER_RESULT_DICTS, foldername)
    # os.makedirs(out_folder_res_dict_path, exist_ok=True)


    # Continue where I left off
    for patch_number in tqdm(range(0, num_patches)):
        mask_generator = SamAutomaticMaskGenerator(sam)
        IMAGE_PATH = os.path.join(IMG_FOLDER, foldername) + "/patch_" + str(patch_number) + ".png"
        image_bgr = cv2.imread(IMAGE_PATH)
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        result = mask_generator.generate(image_rgb)
        mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
        detections = sv.Detections.from_sam(result)
        annotated_image = mask_annotator.annotate(image_bgr, detections)
        # OUT_PATH = os.path.join(out_folder_path, f"patch_{patch_number}.png")
        # cv2.imwrite(OUT_PATH, annotated_image)

        mask = get_mask(result)
        OUT_PATH_MASK = os.path.join(out_folder_mask_path, f"patch_{patch_number}.png")
        plt.imsave(OUT_PATH_MASK, mask)

        # OUT_PATH_RESULT_DICT = os.path.join(out_folder_res_dict_path, f"patch_{patch_number}.pkl")
        # with open(OUT_PATH_RESULT_DICT, 'wb') as f:
        #     pickle.dump(result, f)

Currently on Smooth Muscle-NA-2-dapi-20a-telC-CEPRNB-40x-MaxIP - Stitched_patches


100%|██████████| 588/588 [10:19:12<00:00, 63.18s/it] 
