# HuggingFace🤗 SAM Fine-tuning

[HF SAM 공식문서](https://huggingface.co/docs/transformers/model_doc/sam)

## Prep

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from statistics import mean
from PIL import Image

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.functional import threshold, normalize

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

### Transformers SAM modules

In [None]:
import transformers

# Print every symbol exported by transformers whose name contains "Sam".
hf_list = dir(transformers)
for sam_name in filter(lambda attr: "Sam" in attr, hf_list):
    print(sam_name)

### 데이터셋 불러오기
유방암 초음파 진단 Segmentation dataset

In [None]:
from datasets import load_dataset

dataset = load_dataset("nielsr/breast-cancer", split="train")

# Slice rows first so only the first five samples are decoded; indexing the
# column first (dataset['image'][:5]) materialises every image in the split.
first_five = dataset[:5]
print(first_five['image'])
print(first_five['label'])

### 데이터 시각화

In [None]:
# Display the image of the first raw sample (the cell's last expression is rendered inline).
example = dataset[0]
example["image"]

In [109]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two axes are (H, W); non-zero pixels are coloured.
        ax: matplotlib-like axes exposing ``imshow``.
        random_color: use a random RGB colour instead of the default blue.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an (H, W, 4) image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_box(box, ax):
    """Outline an xyxy ``box`` on ``ax`` with an unfilled green rectangle."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    ax.add_patch(
        plt.Rectangle((left, top), right - left, bottom - top, edgecolor='green', facecolor=(0,0,0,0), lw=2)
    )

def show_boxes_on_image(raw_image, boxes):
    """Show ``raw_image`` in a new 10x10 figure with every xyxy box in ``boxes`` outlined."""
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    current_ax = plt.gca()
    for single_box in boxes:
        show_box(single_box, current_ax)
    plt.axis('on')
    plt.show()

def show_points_on_image(raw_image, input_points, input_labels=None):
    """Show ``raw_image`` in a new figure with point prompts drawn as stars.

    Args:
        raw_image: image to display.
        input_points: sequence of (x, y) points.
        input_labels: per-point labels (1 = positive, 0 = negative); when None,
            every point is treated as positive.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    points = np.array(input_points)
    labels = np.ones_like(points[:, 0]) if input_labels is None else np.array(input_labels)
    show_points(points, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
    """Show ``raw_image`` with both point prompts and xyxy box prompts overlaid.

    Args:
        raw_image: image to display.
        boxes: sequence of xyxy boxes.
        input_points: sequence of (x, y) points.
        input_labels: per-point labels (1 = positive, 0 = negative); when None,
            every point is treated as positive.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    points = np.array(input_points)
    labels = np.ones_like(points[:, 0]) if input_labels is None else np.array(input_labels)
    current_ax = plt.gca()
    show_points(points, labels, current_ax)
    for single_box in boxes:
        show_box(single_box, current_ax)
    plt.axis('on')
    plt.show()


# NOTE: a second, verbatim copy of show_points_and_boxes_on_image was removed
# from this spot; it only re-bound the name to identical code.


def show_points(coords, labels, ax, marker_size=375):
    """Scatter positive (label 1, green) then negative (label 0, red) points as stars.

    Args:
        coords: (N, 2) array of (x, y) points.
        labels: (N,) array of point labels.
        ax: matplotlib-like axes exposing ``scatter``.
        marker_size: star marker area passed as ``s``.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )


def show_masks_on_image(raw_image, masks, scores):
    """Plot each predicted mask of the first image, titled with its predicted IoU score.

    Args:
        raw_image: image the masks belong to (anything ``np.array`` accepts).
        masks: output of ``processor.image_processor.post_process_masks``; only
            ``masks[0]`` (the first image) is plotted. Its last two axes are (H, W).
        scores: model ``iou_scores``; ``scores[0]`` holds one score per mask of
            the first image.
    """
    first_masks = masks[0]
    # Collapse every leading axis into a single "mask" axis: (..., H, W) -> (N, H, W).
    # A plain `.squeeze()` turns a single mask into 2-D (H, W), after which zip()
    # would iterate over image rows instead of masks.
    first_masks = first_masks.reshape(-1, *first_masks.shape[-2:])
    first_scores = scores[0].reshape(-1)

    nb_predictions = first_scores.shape[-1]
    # squeeze=False keeps `axes` indexable even when there is exactly one subplot.
    _, axes = plt.subplots(1, nb_predictions, figsize=(15, 15), squeeze=False)
    axes = axes[0]

    image_array = np.array(raw_image)
    for i, (mask, score) in enumerate(zip(first_masks, first_scores)):
        axes[i].imshow(image_array)
        show_mask(mask.cpu().detach().numpy(), axes[i])
        axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
        axes[i].axis("off")
    plt.show()

In [None]:
# Top row: first five raw images; bottom row: the same images with their GT masks.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i in range(5):
    # Row access decodes only sample i; dataset["image"][i] would decode the whole column.
    sample = dataset[i]
    image_array = np.array(sample["image"])
    ground_truth_seg = np.array(sample["label"])
    axes[0][i].imshow(image_array, cmap="gray")
    axes[1][i].imshow(image_array, cmap="gray")
    show_mask(ground_truth_seg, axes[1][i])
    axes[0][i].title.set_text(f"Image {i}")
    axes[0][i].axis("off")
    axes[1][i].axis("off")
plt.tight_layout()
plt.show()

### 프롬프트 생성
- 프롬프트 1: bounding box
    - 마스크를 둘러 싼 Bbox
- 프롬프트 2: point
    - 마스크 내 임의의 점

In [8]:
def get_bounding_box(ground_truth_map, perturbation=20):
    """Derive a jittered xyxy bounding-box prompt from a segmentation mask.

    Each side of the tight box around the foreground (pixels > 0) is pushed
    outward by a random offset drawn from ``[0, perturbation)`` and then clipped
    to the image extent, so the prompt is slightly looser than the mask.

    Args:
        ground_truth_map: 2-D array (H, W); foreground pixels are > 0.
        perturbation: exclusive upper bound (pixels) of the per-side random
            jitter. A value <= 0 disables the jitter. Defaults to 20.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints.

    Raises:
        ValueError: if the mask contains no foreground pixels.
    """
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if x_indices.size == 0:
        # np.min on an empty array would otherwise fail with an opaque message.
        raise ValueError("ground_truth_map has no foreground pixels; cannot build a bounding box")

    def jitter():
        return np.random.randint(0, perturbation) if perturbation > 0 else 0

    H, W = ground_truth_map.shape
    # Jitter is drawn in the order x_min, x_max, y_min, y_max.
    x_min = max(0, int(x_indices.min()) - jitter())
    x_max = min(W, int(x_indices.max()) + jitter())
    y_min = max(0, int(y_indices.min()) - jitter())
    y_max = min(H, int(y_indices.max()) + jitter())
    return [int(x_min), int(y_min), int(x_max), int(y_max)]

def get_point_prompt(ground_truth_map):
    """Pick one random foreground pixel (value > 0) of the mask as a point prompt.

    Returns:
        ``(x, y)`` of the chosen pixel, or None when the mask has no foreground.
    """
    rows, cols = np.nonzero(ground_truth_map > 0)

    # Nothing to point at: signal it to the caller with None.
    if cols.size == 0 or rows.size == 0:
        return None

    pick = np.random.randint(0, len(cols))
    return (cols[pick], rows[pick])

In [None]:
# Row access (dataset[0]) decodes one sample instead of the whole "label" column.
_mask = get_bounding_box(np.array(dataset[0]["label"]))
_mask

In [None]:
# Row access (dataset[0]) decodes one sample instead of the whole "label" column.
_point = get_point_prompt(np.array(dataset[0]["label"]))
_point

### Dataset & DataLoader 클래스 선언

In [11]:
class SAMDataset(Dataset):
    """Turn rows of an image/label dataset into SAM-ready training samples.

    Every sample is the processor output for one RGB image prompted with a
    bounding box and a single point, both synthesised from the row's label
    mask, plus the raw label mask under ``"ground_truth_mask"``.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build one sample from row ``idx``.

        The label (mask) is used to generate two prompts, a bounding box and a point.
        x: image, prompt 1 (bounding box), prompt 2 (point)
        y: segmentation mask
        """
        record = self.dataset[idx]
        rgb_image = record["image"].convert("RGB")
        mask = np.array(record["label"])

        # Both prompts come from the mask itself (box first, then point).
        box = get_bounding_box(mask)
        point = get_point_prompt(mask)

        encoded = self.processor(rgb_image, input_boxes=[[box]], input_points=[[point]], return_tensors="pt")

        # The processor always prepends a batch axis; drop it so the DataLoader can add its own.
        sample = {key: tensor.squeeze(0) for key, tensor in encoded.items()}

        # Target: the untouched segmentation mask.
        sample["ground_truth_mask"] = mask
        return sample

In [12]:
from transformers import SamProcessor

# Processor matching the facebook/sam-vit-base checkpoint; SAMDataset uses it to turn
# an image plus box/point prompts into model inputs.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
# Wrap the raw dataset and display one processed sample (prompts are re-randomised per access).
sam_dataset = SAMDataset(dataset=dataset, processor=processor)
sam_dataset[0]

In [None]:
# Print the shape of every field in one processed sample.
example = sam_dataset[0]
for field_name, value in example.items():
    print(field_name, value.shape)

In [94]:
# 80/20 random train/test split, batched by 4 (only the training loader is shuffled).
n_total = len(sam_dataset)
train_size = int(n_total * 0.8)
test_size = n_total - train_size
train_dataset, test_dataset = random_split(sam_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=4)

In [None]:
# Inspect the collated shapes of one training batch.
batch = next(iter(train_loader))
for field_name, tensor in batch.items():
    print(field_name, tensor.shape)

## 학습

### 모델 불러오기

[공식 문서](https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/sam/modeling_sam.py#L1173)

In [17]:
from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Train only the mask prediction head (decoder): freeze both encoders.
for param_name, param in model.named_parameters():
    if param_name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

### Inputs
- input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
    - Optional input points for the prompt encoder. The padding of the point is automatically done by the processor. `point_batch_size` refers to the number of masks that we want the model to predict per point. The model will output `point_batch_size` times 3 masks in total.
    
    
- input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
    - Optional input labels for the prompt encoder. The padding of the labels is automatically done by the processor, or can be fed by the user.


- input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
    - Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the processor. Users can also pass the input boxes manually.


- input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
    - Optional input masks for the prompt encoder.

In [18]:
# A random 4-image batch with all-zero box and point prompts, just to probe output shapes.
model.to(device)
batch = {
    "pixel_values": torch.randn(4, 3, 1024, 1024).to(device),
    "input_boxes": torch.zeros(4, 1, 4).to(device),
    "input_points": torch.zeros(4, 1, 1, 2).to(device),
}

### Outputs
- iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
    The IoU scores of the predicted masks.
- pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
    The predicted low-resolution masks. They need to be post-processed by the processor.
- vision_hidden_states  (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
    Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
    one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

    Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.

- vision_attentions  (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
    Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
    sequence_length)`.

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
    heads.

- mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
    Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
    sequence_length)`.

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
    heads.

In [None]:
cpu_device = torch.device("cpu")
model = model.to(cpu_device)
# Exploratory forward pass only: no_grad avoids building an autograd graph
# for the 4 x 3 x 1024 x 1024 dummy batch.
with torch.no_grad():
    output = model(
        pixel_values=batch["pixel_values"].to(cpu_device),  # image
        input_boxes=batch["input_boxes"].to(cpu_device),    # bbox prompt
        # To prompt with points instead, pass input_points=batch["input_points"].to(cpu_device)
        # The keyword is `multimask_output` (singular): True -> 3 masks per prompt
        # (as in the SAM paper), False -> a single mask.
        multimask_output=True,
        return_dict=True,
    )

output.keys()

In [None]:
output  # display the full model output object

In [None]:
print(output['iou_scores'])
# Take the argmax over the mask axis for each sample; a bare .argmax() would
# return an index into the flattened (batch * masks) tensor instead.
print(f"Best mask channel: {output['iou_scores'].argmax(dim=-1)}")

In [None]:
output['pred_masks'].shape  # shape of the predicted (low-resolution) masks

In [None]:
output['iou_scores'].shape  # shape of the predicted IoU scores

### Pre-trained model로 Zero-shot prediction 수행

In [None]:
# Image of sample 0, used for the zero-shot prediction with the pre-trained model.
image = dataset[0]["image"]
image

In [50]:
# Build prompts from the ground-truth mask of sample 0 and encode them with the image.
gt_mask = np.array(dataset[0]["label"])
bbox_prompt = get_bounding_box(gt_mask)
point_prompt = get_point_prompt(gt_mask)

# Convert to RGB exactly as SAMDataset does for the training inputs.
inputs = processor(image.convert("RGB"), input_boxes=[[bbox_prompt]], input_points=[[point_prompt]], return_tensors="pt").to(cpu_device)

In [None]:
# Print the generated prompts and draw the box and point over the image.
print(bbox_prompt, point_prompt)
show_points_and_boxes_on_image(image, [bbox_prompt], [point_prompt])

In [None]:
%%time
model.to(cpu_device)
model.eval()

with torch.no_grad():
    outputs = model(**inputs, multimask_outputs=True)

In [59]:
# Upscale the low-resolution predictions back to the original image size.
original_sizes = inputs["original_sizes"].cpu()
reshaped_sizes = inputs["reshaped_input_sizes"].cpu()
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), original_sizes, reshaped_sizes)
scores = outputs.iou_scores

In [None]:
# Show the ground-truth mask over the image for comparison with the predictions.
fig, axes = plt.subplots()

axes.imshow(np.array(image))
show_mask(gt_mask, axes)
axes.title.set_text("Ground truth mask")
axes.axis("off")

In [None]:
show_masks_on_image(image, masks, scores)  # plot each predicted mask with its predicted IoU score

### Fine-tuning

In [69]:
def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss, averaged over the batch.

    Args:
        pred: predicted foreground probabilities of shape (B, ...), e.g. (B, H, W).
        target: ground-truth masks with the same shape as ``pred``.
        smooth: additive term keeping the ratio defined when both masks are empty.

    Returns:
        Scalar tensor: mean over the batch of
        ``1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    """
    # Flatten every non-batch axis so (B, H, W), (B, C, H, W), ... are all accepted.
    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)

    intersection = (pred_flat * target_flat).sum(dim=1)
    denominator = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    loss = 1 - (2. * intersection + smooth) / (denominator + smooth)

    return loss.mean()

In [70]:
from torch.optim import AdamW

epochs = 20

# Only the mask decoder's parameters are handed to the optimizer.
optimizer = AdamW(model.mask_decoder.parameters(), weight_decay=0, lr=1e-4)

In [110]:
def choose_best_mask(output):
    """Select, per sample, the mask with the highest predicted IoU and return its probabilities.

    Args:
        output: model output (or dict) with ``iou_scores`` of shape
            (B, point_batch, num_masks) and ``pred_masks`` of shape
            (B, point_batch, num_masks, H, W). Only point-batch entry 0 is used.

    Returns:
        float32 tensor (B, H, W) of sigmoid probabilities on the same device as
        ``pred_masks``. Gradients flow back into ``pred_masks``.
    """
    # Use point-batch entry 0 for both tensors so scores and masks stay aligned.
    iou_scores = output['iou_scores'][:, 0]   # (B, num_masks)
    pred_masks = output['pred_masks'][:, 0]   # (B, num_masks, H, W)

    best_channel_indices = iou_scores.argmax(dim=-1)
    batch_indices = torch.arange(pred_masks.shape[0], device=pred_masks.device)

    # Advanced indexing picks one mask per sample without a Python loop and
    # without a round trip through a CPU buffer.
    best_logits = pred_masks[batch_indices, best_channel_indices.to(pred_masks.device)]
    return torch.sigmoid(best_logits).float()

In [71]:
def train_one_epoch(model, train_dataloader, optimizer, dice_loss, prompt_bbox=True):
    """Train ``model`` for one pass over ``train_dataloader`` and return the mean batch loss.

    Each batch is prompted with either its bounding box or its point (not both).
    A single mask is predicted per sample, and ``dice_loss`` against the
    ground-truth mask is minimised with ``optimizer``.

    Args:
        model: SAM model; moved to the module-level ``device`` and set to train mode.
        train_dataloader: iterable of dict batches with "pixel_values",
            "input_boxes", "input_points" and "ground_truth_mask" (as built by SAMDataset).
        optimizer: optimizer over the trainable parameters.
        dice_loss: loss callable ``(pred_mask, gt_mask) -> scalar tensor``.
        prompt_bbox: prompt with "input_boxes" if True, otherwise with "input_points".

    Returns:
        Mean of the per-batch loss values as a Python float.
    """
    model.to(device)
    model.train()
    prompt_key = "input_boxes" if prompt_bbox else "input_points"
    batch_loss = []
    for batch in tqdm(train_dataloader):
        # The SamModel.forward keyword is `multimask_output` (singular). The previous
        # `multimask_outputs` spelling is not a forward parameter, so it never
        # switched off multi-mask prediction.
        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            multimask_output=False,
            **{prompt_key: batch[prompt_key].to(device)},
        )
        pred_mask = choose_best_mask(outputs).to(device)
        gt_mask = batch["ground_truth_mask"].float().to(device)
        loss = dice_loss(pred_mask, gt_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    return sum(batch_loss) / len(batch_loss)

In [None]:
# Fine-tune with point prompts, recording the mean train loss of each epoch.
loss_history = []
for epoch in range(epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, dice_loss, prompt_bbox=False)
    loss_history.append(train_loss)
    print(f"Epoch {epoch}, train loss: {train_loss}")

In [77]:
# Persist the fine-tuned weights and the optimizer state.
model_state = model.state_dict()
torch.save(model_state, 'SAM_model.pth')
optim_state = optimizer.state_dict()
torch.save(optim_state, 'optimizer.pth')

### Inference

In [None]:
# map_location lets checkpoints written on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load('SAM_model.pth', map_location=device))
optimizer.load_state_dict(torch.load('optimizer.pth', map_location=device))

In [None]:
# Sample to inspect after fine-tuning.
# NOTE(review): idx indexes the full dataset, so this sample may belong to the training split.
idx = 55
image = dataset[idx]["image"]
image

In [113]:
# Build prompts from the ground-truth mask of sample `idx` and encode them with the image.
gt_mask = np.array(dataset[idx]["label"])
bbox_prompt = get_bounding_box(gt_mask)
point_prompt = get_point_prompt(gt_mask)

# Convert to RGB exactly as SAMDataset does for the training inputs.
inputs = processor(image.convert("RGB"), input_boxes=[[bbox_prompt]], input_points=[[point_prompt]], return_tensors="pt").to(cpu_device)

In [None]:
# Print the generated prompts and draw the box and point over the image.
print(bbox_prompt, point_prompt)
show_points_and_boxes_on_image(image, [bbox_prompt], [point_prompt])

In [None]:
%%time
model.to(cpu_device)
model.eval()

with torch.no_grad():
    outputs = model(**inputs, multimask_outputs=True)

In [116]:
# Upscale the low-resolution predictions back to the original image size.
original_sizes = inputs["original_sizes"].cpu()
reshaped_sizes = inputs["reshaped_input_sizes"].cpu()
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), original_sizes, reshaped_sizes)
scores = outputs.iou_scores

In [None]:
# Show the ground-truth mask over the image for comparison with the predictions.
fig, axes = plt.subplots()

axes.imshow(np.array(image))
show_mask(gt_mask, axes)
axes.title.set_text("Ground truth mask")
axes.axis("off")

In [None]:
show_masks_on_image(image, masks, scores)  # plot each predicted mask with its predicted IoU score