In [1]:
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image

class SAMDataset(Dataset):
    def __init__(self, img_dir, mask_dir, processor):
        self.processor = processor

        # get mask file path list

        self.img_dir = img_dir
        self.mask_dir = mask_dir
        
        self.mask_path_list = os.listdir(mask_dir)
        

    def get_bounding_box(self, ground_truth_map):
        # get bounding box from mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = ground_truth_map.shape
        x_min = max(0, x_min - np.random.randint(0, 20))
        x_max = min(W, x_max + np.random.randint(0, 20))
        y_min = max(0, y_min - np.random.randint(0, 20))
        y_max = min(H, y_max + np.random.randint(0, 20))
        bbox = [x_min, y_min, x_max, y_max]
        return bbox
    
    def __len__(self):
        return len(self.mask_path_list)
    
    def __getitem__(self, idx):
        # item = self.dataset[idx]
        mask_path = os.path.join(self.mask_dir,self.mask_path_list[idx])
        mask = Image.open(mask_path)
        mask = mask.resize((256,256))
        mask = np.array(mask)
        mask[mask==2] =  1

        ground_truth_mask = mask
        img_path = os.path.join(self.img_dir, self.mask_path_list[idx].replace('_mask',''))
        image = Image.open(img_path)
        # image = item["image"]
        # ground_truth_mask = np.array(item["label"])
    
        # get bounding box prompt
        prompt = self.get_bounding_box(ground_truth_mask)
        
        # prepare image and prompt for the model
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        
        # remove batch dimension which the processor adds by default
        inputs = {k:v.squeeze(0) for k,v in inputs.items()}
        
        # add ground truth segmentation
        inputs["ground_truth_mask"] = ground_truth_mask
        
        return inputs

    
    
    

In [2]:
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_dataset = SAMDataset(img_dir='../data/data_crop1024_shift512/train_images', mask_dir='../data/data_crop1024_shift512/1024-train-mask-mult', processor=processor)

In [4]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [5]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

pixel_values torch.Size([4, 3, 1024, 1024])
original_sizes torch.Size([4, 2])
reshaped_input_sizes torch.Size([4, 2])
input_boxes torch.Size([4, 1, 4])
ground_truth_mask torch.Size([4, 256, 256])


In [6]:
batch["ground_truth_mask"].shape

torch.Size([4, 256, 256])

In [7]:
from transformers import SamModel 

model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

In [8]:
from torch.optim import Adam
import monai

# Note: Hyperparameter tuning could improve performance here
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [9]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
      # forward pass
      outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)

      # compute loss
      predicted_masks = outputs.pred_masks.squeeze(1)
      ground_truth_masks = batch["ground_truth_mask"].float().to(device)
      loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

      # backward pass (compute gradients of parameters w.r.t. loss)
      optimizer.zero_grad()
      loss.backward()

      # optimize
      optimizer.step()
      epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

  0%|                                                   | 0/451 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 0; 10.92 GiB total capacity; 6.89 GiB already allocated; 2.85 GiB free; 6.99 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF