In [None]:
# Environment setup: install the libraries used below and download the SAM ViT-B
# checkpoint referenced by `checkpoint` in the config cell.
# NOTE(review): versions are unpinned, so re-runs may resolve different releases;
# consider pinning (and `%pip` so installs target the running kernel).
! pip install kaggle &> /dev/null
! pip install torch torchvision &> /dev/null
! pip install opencv-python pycocotools matplotlib onnxruntime onnx &> /dev/null
! pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth &> /dev/null

In [2]:
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

In [4]:
# Training-split inputs: per-leaf instance masks and the matching RGB images.
MASK_DIR, RGB_DIR = (
    "_data/combined/train/leaf_instances",
    "_data/combined/train/images",
)

In [6]:
# SAM backbone variant and the matching checkpoint downloaded in the setup cell.
model_type = 'vit_b'
checkpoint = 'sam_vit_b_01ec64.pth'
# Use the first GPU when available, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of logits produced by the classification head of CustomSAM.
num_classes = 2

In [8]:

import torch
import torchvision
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# Custom SAM with ResNet50 classifier
class CustomSAM(torch.nn.Module):
    """SAM automatic mask generation plus a ResNet-50 head over SAM image embeddings.

    Args:
        sam: SAM model; only its ``image_encoder`` is called here.
        mask_generator: Object whose ``generate(image)`` returns the predicted masks.
        num_classes: Number of logits produced by the classification head.
        embed_dim: Channel count of ``sam.image_encoder`` output, used as the input
            width of the classifier stem. 256 matches the released SAM encoders --
            NOTE(review): confirm if a custom encoder is used.
    """

    def __init__(self, sam, mask_generator, num_classes, embed_dim=256):
        super().__init__()
        self.sam = sam
        self.mask_generator = mask_generator
        # Randomly initialised ResNet-50 (`weights=None` replaces deprecated `pretrained=False`).
        self.classifier = torchvision.models.resnet50(weights=None)
        # The stock stem takes 3-channel RGB, but this head is fed SAM embeddings with
        # `embed_dim` channels. Swap in a stem with a matching input width (same output
        # width, kernel, stride and padding) so the rest of the ResNet is unchanged.
        stem = self.classifier.conv1
        self.classifier.conv1 = torch.nn.Conv2d(
            embed_dim,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )
        self.classifier.fc = torch.nn.Linear(self.classifier.fc.in_features, num_classes)

    def forward(self, image, mask_image=None):
        """Return ``(masks, class_logits)`` for one input.

        Args:
            image: Input to ``sam.image_encoder`` -- presumably the preprocessed
                (B, 3, H, W) tensor from ``sam.preprocess``; TODO confirm.
            mask_image: Optional separate input for ``mask_generator.generate``.
                SAM's automatic mask generator consumes an HWC uint8 numpy image,
                a different representation from the encoder input. When omitted,
                ``image`` is passed to both (previous behaviour).

        Returns:
            Tuple of the generator's mask list and the classifier logits.
        """
        masks = self.mask_generator.generate(image if mask_image is None else mask_image)

        # The SAM encoder is frozen: embeddings are computed without autograd,
        # so only the classifier receives gradients from `class_output`.
        with torch.no_grad():
            image_embeddings = self.sam.image_encoder(image)

        class_output = self.classifier(image_embeddings)
        return masks, class_output

In [None]:


# Build SAM from the downloaded checkpoint and wrap it with the automatic mask generator.
sam = sam_model_registry[model_type](checkpoint=checkpoint)
mask_generator = SamAutomaticMaskGenerator(sam)
# `.to()` moves the wrapper and its registered submodules (including `sam`, in place)
# to the configured device; the dataset creates its input tensors on `device`, so the
# model must live there too.
custom_sam = CustomSAM(sam, mask_generator, num_classes).to(device)

In [57]:
from torch.utils.data import Dataset
# Preprocess the images
from collections import defaultdict

import torch

from segment_anything.utils.transforms import ResizeLongestSide
import os

transform = ResizeLongestSide(sam.image_encoder.img_size)


class LeafInstanceDataset(Dataset):
    """Per-image leaf-instance samples for fine-tuning SAM.

    Reads ``<path>/images/<name>`` (colour image) and ``<path>/leaf_instances/<name>``
    (single-channel instance mask with the same file name). Each distinct non-zero
    mask value is treated as one instance; 0 is assumed to be background.

    Args:
        path: Split root directory containing ``images/`` and ``leaf_instances/``.
        sam_model: Model whose ``preprocess`` is applied to the image; defaults to
            the notebook-level ``sam``.
    """

    def __init__(self, path, sam_model=None):
        self.path = path
        # Sorted so the index -> file mapping is stable across runs and filesystems.
        self.files = sorted(os.listdir(os.path.join(path, "images")))
        self.sam_model = sam if sam_model is None else sam_model

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return a dict with the preprocessed image and per-instance boxes/masks.

        Keys: ``image`` (output of ``sam_model.preprocess`` on a 1 x C x H x W tensor
        on ``device``), ``original_image_size``, ``input_size``, and parallel lists
        ``bboxes`` (XYXY, original pixels), ``bboxes_transformed``, ``masks``
        (boolean, original size) and ``masks_transformed`` (resized uint8).
        NOTE(review): tensors are created on ``device``, so use ``num_workers=0``.
        """
        file_name = self.files[idx]
        img_path = os.path.join(self.path, "images", file_name)
        mask_path = os.path.join(self.path, "leaf_instances", file_name)

        instance_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if instance_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize so the longest side matches the encoder input, then HWC -> 1xCxHxW.
        input_image = transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=device)
        transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
        original_image_size = image.shape[:2]
        input_size = tuple(transformed_image.shape[-2:])

        data = {
            "image": self.sam_model.preprocess(transformed_image),
            "original_image_size": original_image_size,
            "input_size": input_size,
            "bboxes": [],
            "bboxes_transformed": [],
            "masks": [],
            "masks_transformed": [],
        }

        category_ids = np.unique(instance_mask)
        category_ids = category_ids[category_ids > 0]  # Exclude background (assumed to be 0)
        for category_id in category_ids:
            # Binary mask of this instance only; `instance_mask` itself is never
            # overwritten so every instance is compared against the original labels.
            category_mask = instance_mask == category_id
            ys, xs = np.nonzero(category_mask)
            bbox = np.array([xs.min(), ys.min(), xs.max(), ys.max()])
            data["bboxes"].append(bbox)
            data["bboxes_transformed"].append(transform.apply_boxes(bbox, original_image_size))
            data["masks"].append(category_mask)
            # `apply_image` takes a single array; uint8 so it can be resized as an image.
            data["masks_transformed"].append(transform.apply_image(category_mask.astype(np.uint8)))

        return data

In [44]:
# Optimizer setup; tuning these hyperparameters should improve performance.
lr = 1e-4
wd = 0
# Train the SAM mask decoder *and* the classification head. The encoder runs under
# no_grad inside CustomSAM and the mask term of `custom_loss` is rebuilt from numpy
# arrays, so the classifier logits are the only path gradients take. Without the
# classifier's parameters here, `optimizer.step()` would update nothing.
trainable_params = list(sam.mask_decoder.parameters()) + list(custom_sam.classifier.parameters())
optimizer = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wd)

In [None]:
from torchvision.ops import box_iou

def custom_loss(generated_masks, class_output, true_masks, true_bboxes, true_classes, iou_threshold=0.5):
    """Cross-entropy classification loss plus a ``1 - mask IoU`` matching penalty.

    Each generated mask is matched to the ground-truth box with the highest box IoU.
    If that IoU exceeds ``iou_threshold``, ``1 - IoU(generated mask, true mask)`` is
    accumulated. The summed penalty is averaged over all generated masks.

    Args:
        generated_masks: List of dicts with ``'bbox'`` in XYWH format and a boolean
            ``'segmentation'`` array (SAM automatic-mask-generator records).
        class_output: Logits passed to ``cross_entropy``.
        true_masks: Boolean ground-truth masks indexable by instance id; assumed to
            share the segmentation's height/width -- TODO confirm.
        true_bboxes: Ground-truth boxes in XYXY format; any shape that reshapes to
            ``(-1, 4)`` is accepted. Must be in the same coordinate frame as the
            generated boxes.
        true_classes: Class targets for ``cross_entropy``.
        iou_threshold: Minimum box IoU for a generated mask to count as matched.

    Returns:
        Scalar tensor ``class_loss + mask_loss``.

    Note:
        The mask term is built from non-differentiable tensors, so only the
        classification term contributes gradients.
    """
    class_loss = torch.nn.functional.cross_entropy(class_output, true_classes)

    true_bboxes = torch.as_tensor(true_bboxes, dtype=torch.float32).reshape(-1, 4)
    if not generated_masks or true_bboxes.shape[0] == 0:
        # Nothing to match against: the mask term is zero.
        return class_loss

    mask_loss = 0.0
    for gen_mask in generated_masks:
        # SAM reports boxes as XYWH; box_iou expects XYXY.
        x, y, w, h = gen_mask['bbox']
        gen_bbox = torch.tensor([[x, y, x + w, y + h]], dtype=torch.float32, device=true_bboxes.device)
        best_iou, best_idx = box_iou(gen_bbox, true_bboxes).max(dim=1)
        if best_iou.item() <= iou_threshold:
            continue

        true_mask_tensor = torch.as_tensor(true_masks[best_idx.item()], dtype=torch.bool)
        gen_mask_tensor = torch.as_tensor(
            gen_mask['segmentation'], dtype=torch.bool, device=true_mask_tensor.device
        )
        intersection = (gen_mask_tensor & true_mask_tensor).sum()
        # clamp guards against 0/0 when both masks are empty.
        union = (gen_mask_tensor | true_mask_tensor).sum().clamp(min=1)
        mask_loss = mask_loss + (1 - intersection / union)

    return class_loss + mask_loss / len(generated_masks)

In [None]:
from statistics import mean

from tqdm import tqdm
from torch.nn.functional import threshold, normalize

from torch.utils.data import DataLoader

num_epochs = 100
losses = []  # one list of per-step losses per epoch
data_loader = DataLoader(LeafInstanceDataset("_data/combined/train"))

for epoch in range(num_epochs):
    epoch_losses = []
    p_bar = tqdm(data_loader)
    for data in p_bar:
        optimizer.zero_grad()

        image = data["image"]
        true_masks = torch.stack(data["masks"])
        # Flatten the per-instance box arrays into one (N, 4) tensor; torch.tensor()
        # cannot build a tensor from a list of multi-element tensors.
        # NOTE(review): these boxes are in resized-input coordinates; confirm they match
        # the coordinate frame of the generator's boxes.
        true_bboxes = (
            torch.cat([torch.as_tensor(b).reshape(-1, 4) for b in data["bboxes_transformed"]])
            if data["bboxes_transformed"]
            else torch.zeros((0, 4))
        )
        # TODO: `your_class_mapping` is a placeholder that is not defined anywhere;
        # replace it with the real mask -> class-id mapping.
        true_classes = torch.tensor([your_class_mapping[mask.sum() > 0] for mask in data["masks"]])

        generated_masks, class_output = custom_sam(image)
        loss = custom_loss(generated_masks, class_output, true_masks, true_bboxes, true_classes)
        loss.backward()
        optimizer.step()

        # Record the step loss so the epoch mean and `losses` are populated.
        epoch_losses.append(loss.item())
        p_bar.set_description(f'Loss: {loss.item()}')
    losses.append(epoch_losses)
    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses) if epoch_losses else float("nan")}')

Loss: 73.42185974121094:   0%|          | 39/10010 [00:19<1:17:59,  2.13it/s]

In [None]:
# Mean loss per epoch; an epoch with no recorded steps becomes NaN (a gap in the
# line) instead of raising StatisticsError.
mean_losses = [mean(x) if x else float('nan') for x in losses]

fig, ax = plt.subplots()
ax.plot(range(len(mean_losses)), mean_losses)
ax.set_title('Mean epoch loss')
ax.set_xlabel('Epoch Number')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Set up a predictor around the SAM model built and fine-tuned above.
from segment_anything import sam_model_registry, SamPredictor
predictor_tuned = SamPredictor(sam)

In [None]:
import random

# Pick one test image at random; the listing is sorted so a seeded `random`
# reproduces the same choice across filesystems.
# NOTE(review): assumes the test directory holds image files directly, whereas the
# train split uses an `images/` sub-folder -- confirm the test layout.
TEST_DIR = "_data/combined/test"
image_file = os.path.join(TEST_DIR, random.choice(sorted(os.listdir(TEST_DIR))))
image = cv2.imread(image_file)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_file}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
predictor_tuned.set_image(image)

# No point or box prompt: the predictor returns several candidate masks.
prompt_free = dict(point_coords=None, box=None, multimask_output=True)
masks_tuned, _scores, _low_res_logits = predictor_tuned.predict(**prompt_free)

In [None]:
# Helper functions provided in https://github.com/facebookresearch/segment-anything/blob/9e8f1309c94f1128a6e5c047a10fdcb02fc8d651/notebooks/predictor_example.ipynb
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width) and which holds
            exactly height * width values, so it reshapes to a single mask.
        ax: Matplotlib axes to draw on.
        random_color: If True use a random RGB colour, otherwise a fixed blue.
            Alpha is 0.6 in both cases.
    """
    if random_color:
        rgba = np.append(np.random.random(3), 0.6)
    else:
        rgba = np.array([30, 144, 255, 153]) / 255
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline an XYXY ``box`` (x0, y0, x1, y1) on ``ax`` as a green rectangle (line width 2)."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

In [None]:
# Single-panel figure; `axs` stays indexable for consistency with multi-panel layouts.
fig, axs = plt.subplots(1, 1, figsize=(10, 10), squeeze=False)
axs = axs.ravel()
axs[0].imshow(image)
# predict(multimask_output=True) returns a stack of candidate masks, but show_mask
# draws one mask; overlay the first candidate.
# NOTE(review): consider picking the highest-scoring candidate instead.
show_mask(masks_tuned[0], axs[0])
# No box prompt was passed to predict(), so there is no input box to draw.
axs[0].set_title('Mask with Tuned Model', fontsize=26)
axs[0].axis('off')
plt.show()
