<a target="_blank" href="https://colab.research.google.com/github/younggon2/Education-ComputerVision-SAM/blob/master/SAM_MedSAM_tutorial.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Environment Set-up

In [None]:
!pip install git+https://github.com/bowang-lab/MedSAM.git

# 1. SAM을 이용한 레이블링 (직접 Box 입력)

In [None]:
# Download SAM model & demo.py

!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
!wget https://raw.githubusercontent.com/bowang-lab/MedSAM/main/utils/demo.py

# Download MedSAM model
model_id = "1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_&confirm=t"
!gdown $model_id

## Model Load

In [None]:
# Load sam model
import cv2
import os
import matplotlib.pyplot as plt
from google.colab import output
from demo import BboxPromptDemo
from segment_anything import sam_model_registry

# Switch on Colab's custom widget manager.
output.enable_custom_widget_manager()

# Build the ViT-B SAM from its checkpoint and move it to the GPU in one step.
SAM_CKPT_PATH = "sam_vit_b_01ec64.pth"
device = "cuda"
sam_model = sam_model_registry['vit_b'](checkpoint=SAM_CKPT_PATH).to(device)
sam_model.eval()

## Download demo image

In [None]:
# Download the demo image.
import subprocess

url = "https://cdn.pixabay.com/photo/2018/10/01/09/21/pets-3715733_640.jpg"
# Call curl with an argument list instead of a concatenated shell string.
# -f makes curl exit non-zero on an HTTP error and check=True turns that into
# an exception, so a failed download is reported here rather than leaving an
# error page in test.jpg. -L follows redirects.
# The trailing ";" keeps the CompletedProcess repr out of the cell output.
subprocess.run(["curl", "-fL", "-o", "test.jpg", url], check=True);

In [None]:
# 데이터 확인
import cv2
import matplotlib.pyplot as plt

img = cv2.imread("test.jpg")  # load the image (OpenCV returns BGR order)
# cv2.imread reports a missing or unreadable file by returning None rather
# than raising, so fail here with a clear message.
if img is None:
    raise FileNotFoundError("test.jpg could not be read; run the download cell first")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib

plt.imshow(img)
plt.show()

## Inference

In [None]:
%matplotlib widget

img = 'test.jpg'
bbox_prompt_demo = BboxPromptDemo(sam_model)
bbox_prompt_demo.show(img)

In [None]:
# Mask 확인
%matplotlib inline

mask = cv2.imread('segs.png') # 이미지 불러오기
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
plt.imshow(mask,'gray')
plt.show()

# 2. SAM vs. MedSAM
Medical image segmentation

In [None]:
# download demo medical image
!wget https://github.com/younggon2/Education-ComputerVision-SAM/raw/main/data/ct.png
!wget https://github.com/younggon2/Education-ComputerVision-SAM/raw/main/data/pathology.png

In [None]:
# 데이터 확인
import cv2
import matplotlib.pyplot as plt

# Read both demo images and convert OpenCV's BGR layout to RGB.
ct_img = cv2.cvtColor(cv2.imread("ct.png"), cv2.COLOR_BGR2RGB)  # CT image
pathology_img = cv2.cvtColor(cv2.imread("pathology.png"), cv2.COLOR_BGR2RGB)  # pathology image

# Show the two images side by side with the axes hidden.
for panel_index, panel_image in enumerate((ct_img, pathology_img), start=1):
    plt.subplot(1, 2, panel_index)
    plt.imshow(panel_image)
    plt.axis('off')
plt.show()

### Model load

In [18]:
# Both models are built with the 'vit_b' registry entry and placed on the same device.
device = "cuda"

# SAM model
SAM_CKPT_PATH = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry['vit_b'](checkpoint=SAM_CKPT_PATH)
sam_model = sam_model.to(device)
sam_model.eval()

# MedSAM model
MedSAM_CKPT_PATH = "medsam_vit_b.pth"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
medsam_model = medsam_model.to(device)
# Put MedSAM in inference mode explicitly, the same way SAM is handled above.
# The trailing ";" keeps the module repr out of the cell output.
medsam_model.eval();

### SAM inference

In [None]:
# SAM inference
%matplotlib widget

ct_img = 'ct.png'
bbox_prompt_demo = BboxPromptDemo(sam_model)
bbox_prompt_demo.show(ct_img)

# pathology_img = 'pathology.png'
# bbox_prompt_demo = BboxPromptDemo(sam_model)
# bbox_prompt_demo.show(pathology_img)

In [None]:
# Mask 확인
%matplotlib inline

mask = cv2.imread('segs.png') # 이미지 불러오기
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
plt.imshow(mask,'gray')
plt.show()

### MedSAM inference

In [None]:
# MedSAM inference
%matplotlib widget

ct_img = 'ct.png'
bbox_prompt_demo = BboxPromptDemo(medsam_model)
bbox_prompt_demo.show(ct_img)

# pathology_img = 'pathology.png'
# bbox_prompt_demo = BboxPromptDemo(medsam_model)
# bbox_prompt_demo.show(pathology_img)

In [None]:
# Mask 확인
%matplotlib inline

mask = cv2.imread('segs.png') # 이미지 불러오기
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
plt.imshow(mask,'gray')
plt.show()

# 3. MedSAM을 이용한 반자동 레이블링 (Box 자동 입력)

In [None]:
from skimage import io, transform
import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
import glob
from segment_anything import sam_model_registry, SamPredictor

## Model load

In [None]:
# Load MedSAM

MedSAM_CKPT_PATH = "medsam_vit_b.pth"
device = "cuda"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
# Move the model to the GPU and switch it to inference mode explicitly
# before wrapping it in a predictor.
medsam_model = medsam_model.to(device).eval()
mask_predictor = SamPredictor(medsam_model)

## Dataset load

In [None]:
# Dataset load
!wget https://github.com/younggon2/Education-ComputerVision-SAM/raw/main/data/sample_lung_data.zip
!unzip sample_lung_data.zip

In [None]:
# Inspect the txt file that holds the box coordinates
with open(f'sample_lung_data/box/resize_CHNCXR_0001_0.txt') as box_file:
    for row in box_file:
        fields = row.strip().split()
        # Drop the first field and scale every remaining value by 256.
        scaled_coords = [int(float(field) * 256) for field in fields[1:]]
        print('Box 좌표 :', scaled_coords)

In [None]:
# Gather every dataset image path into one list, in sorted order
img_pathes = glob.glob('sample_lung_data/img/*.png')
img_pathes.sort()

## Main inference

In [None]:
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width). It is
            reshaped to (height, width, 1), so it must hold exactly
            height * width elements.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When True, use a random RGB colour; otherwise a fixed
            blue. Alpha is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast (height, width, 1) against (1, 1, 4): each mask pixel scales the colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled green rectangle.

    Args:
        box: Indexable of corner coordinates ``[x_min, y_min, x_max, y_max]``.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    # Rectangle takes the anchor corner plus width/height, not two corners.
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

def modify_coords(bbox_list):
    """Convert centre-format boxes to corner-format boxes.

    Args:
        bbox_list: Iterable of ``(x_center, y_center, width, height)`` items.

    Returns:
        A list of ``[x_min, y_min, x_max, y_max]`` lists. Each coordinate is
        passed through ``int()``, which truncates toward zero.
    """
    return [
        [
            int(cx - bw / 2),
            int(cy - bh / 2),
            int(cx + bw / 2),
            int(cy + bh / 2),
        ]
        for cx, cy, bw, bh in bbox_list
    ]

In [None]:
# Main
# MedSAM multi-class: for every image, read its boxes, predict one mask per
# box and merge the masks into a single label image.

for path in tqdm(img_pathes):
    # File stem without directory or extension; the matching box file is
    # looked up as sample_lung_data/box/<stem>.txt.
    name = os.path.splitext(os.path.basename(path))[0]

    # Load the image. cv2.imread returns None (no exception) on failure.
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"could not read image: {path}")
    # OpenCV loads BGR; convert to RGB, the default format of SamPredictor.set_image.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    H, W, _ = image.shape

    bbox_list = []
    classes = []
    # Read the box file. Each row is "<class> <x_center> <y_center> <width> <height>";
    # coordinates are assumed to be normalised to [0, 1] (YOLO-style) — confirm.
    with open(f'sample_lung_data/box/{name}.txt') as file :
        for line in file:
            values = line.strip().split()
            if not values:
                continue  # tolerate blank lines
            x_center, y_center, box_w, box_h = (float(value) for value in values[1:5])
            # Horizontal quantities scale with the width, vertical ones with the height.
            bbox_list.append([int(x_center * W), int(y_center * H),
                              int(box_w * W), int(box_h * H)])
            # Parse the whole token so multi-digit class ids stay intact.
            classes.append(int(values[0]))

    if not bbox_list:
        # Nothing to prompt the model with; skip instead of failing downstream.
        tqdm.write(f"{name}: no boxes found, skipping")
        continue

    mask_predictor.set_image(image)

    bboxes = modify_coords(bbox_list)
    input_boxes = torch.tensor(bboxes, device=mask_predictor.device)

    # Map the boxes from original-image coordinates to the model's input frame
    transformed_boxes = mask_predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

    # Predict one mask per box
    masks, _, _ = mask_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Show masks and boxes on top of the image
    plt.figure(figsize=(10, 5))
    plt.subplot(1,2,1)
    plt.title('Segmented Image')
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box in input_boxes:
        show_box(box.cpu().numpy(), plt.gca())
    plt.axis('off')
    # plt.show()

    # Save box image
#     plt.savefig(f"sample_lung_data/mask/{name}_seg.png", bbox_inches='tight', pad_inches=0)

    # Merge the per-box masks into one label image: pixels of class c get the
    # value c + 1 (0 stays background); overlaps keep the larger label.
    combined_mask = np.zeros((H, W), dtype=np.uint8)
    for idx, mask in enumerate(masks):
        class_ = classes[idx]
        mask_np = mask.cpu().numpy()[0]
        mask_binary = (mask_np > 0.5).astype(np.uint8) * (class_+1)  # binarise and label
        combined_mask = np.maximum(combined_mask, mask_binary)  # merge by maximum
    plt.subplot(1,2,2)
    plt.title('Mask Image')
    plt.imshow(combined_mask)
    plt.show()

    # Save mask
    # io.imsave(f"sample_lung_data/mask/{name}.png", combined_mask)