In [None]:
! pip install kaggle &> /dev/null
! pip install torch torchvision &> /dev/null
! pip install opencv-python pycocotools matplotlib onnxruntime onnx &> /dev/null
! pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth &> /dev/null

In [2]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [4]:
# Folders of the training split: the images and their leaf-instance masks.
RGB_DIR, MASK_DIR = (
    "_data/combined/train/images",
    "_data/combined/train/leaf_instances",
)

In [6]:
import torch

# SAM variant to build and the checkpoint file holding its pretrained weights.
model_type = 'vit_b'
checkpoint = 'sam_vit_b_01ec64.pth'
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [8]:
from segment_anything import SamPredictor, sam_model_registry

# Build SAM and load the pretrained weights from `checkpoint`.
# Calling the registry entry without `checkpoint=` returns a randomly
# initialised network, which makes fine-tuning only the mask decoder on top of
# the (frozen) encoders pointless.
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)
# Training mode; the trailing `;` suppresses the module repr in the cell output.
sam_model.train();

In [57]:
from torch.utils.data import Dataset
# Preprocess the images
from collections import defaultdict

import torch

from segment_anything.utils.transforms import ResizeLongestSide
import os

# Resizer used for both images and box prompts; its target length is taken
# from the SAM image encoder so inputs match what the encoder expects.
transform = ResizeLongestSide(target_length=sam_model.image_encoder.img_size)

class LeafInstanceDataset(Dataset):
    """Images with leaf-instance masks, prepared as SAM box-prompt samples.

    ``path`` must contain two folders holding identically named files:
    ``images`` (colour images) and ``leaf_instances`` (label masks, read as
    grayscale, where 0 is background and every other value marks one instance).

    One dataset item corresponds to one image and is a list with one dict per
    instance found in its mask. The class relies on the notebook-level
    ``transform``, ``sam_model`` and ``device`` objects.
    """

    def __init__(self, path):
        """Index the files found under ``<path>/images``."""
        self.path = path
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # idx -> file mapping reproducible between runs and machines.
        self.files = sorted(os.listdir(os.path.join(path, "images")))

    def __len__(self):
        """Return the number of images (not the number of instances)."""
        return len(self.files)

    @staticmethod
    def _instance_bbox(mask, category_id):
        """Return the tight ``[x_min, y_min, x_max, y_max]`` box of one instance.

        Only pixels equal to ``category_id`` are considered, so each instance
        gets its own box rather than the box around every labelled pixel.
        """
        rows, cols = np.nonzero(np.asarray(mask) == category_id)
        return np.array([np.min(cols), np.min(rows), np.max(cols), np.max(rows)])

    def __getitem__(self, idx):
        """Load image ``idx`` and build one sample dict per leaf instance.

        Returns:
            list[dict]: one entry per non-zero mask value, with keys ``bbox``,
            ``bbox_transformed``, ``mask``, ``image``, ``input_size`` and
            ``original_image_size``. The list is empty when the mask contains
            only background.

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        file_name = self.files[idx]
        img_path = os.path.join(self.path, "images", file_name)
        mask_path = os.path.join(self.path, "leaf_instances", file_name)

        # cv2.imread reports failure by returning None instead of raising.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read instance mask: {mask_path}")
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")

        unique_categories = np.unique(mask)
        unique_categories = unique_categories[unique_categories > 0]  # Exclude background (assumed to be 0)

        # OpenCV loads BGR; convert before handing the image to SAM's transform.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        resized_image = transform.apply_image(image)
        resized_image_torch = torch.as_tensor(resized_image, device=device)
        # HWC -> 1xCxHxW
        transformed_image = resized_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        # The preprocessed image is identical for every instance of this file.
        input_image = sam_model.preprocess(transformed_image).squeeze()
        original_image_size = image.shape[:2]
        input_size = tuple(transformed_image.shape[-2:])

        # Build a fresh dict per instance. (`[{}] * n` would repeat one shared
        # dict n times, so every entry would end up describing the last instance.)
        data = []
        for category_id in unique_categories:
            bbox = self._instance_bbox(mask, category_id)
            data.append({
                "bbox": bbox,
                "bbox_transformed": transform.apply_boxes(bbox, original_image_size),
                "mask": (mask == category_id).squeeze(),
                "image": input_image,
                "input_size": input_size,
                "original_image_size": original_image_size,
            })

        return data

In [44]:
# Optimiser and loss for fine-tuning. Only the mask decoder's parameters are
# handed to Adam; the learning rate and weight decay are untuned starting
# values, so hyperparameter tuning will improve performance here.
lr, wd = 1e-4, 0
optimizer = torch.optim.Adam(
    params=sam_model.mask_decoder.parameters(),
    lr=lr,
    weight_decay=wd,
)

# Binary cross-entropy between the predicted mask and the ground-truth mask.
loss_fn = torch.nn.BCELoss()

In [None]:
from statistics import mean

from tqdm import tqdm
from torch.nn.functional import threshold, normalize

from torch.utils.data import DataLoader

num_epochs = 100
losses = []
# Default batch size of 1: every item yielded by the loader is one image,
# given as a list with one dict per leaf instance.
data_loader = DataLoader(LeafInstanceDataset("_data/combined/train"))

for epoch in range(num_epochs):
    epoch_losses = []
    for data in data_loader:
        if len(data) == 0:
            # Image without any labelled instance: nothing to train on.
            continue

        # The encoders are frozen, so no gradients are needed here. Every
        # instance dict of an image carries the same preprocessed image, so it
        # is encoded once per image instead of once per instance.
        with torch.no_grad():
            image_embedding = sam_model.image_encoder(data[0]['image'].to(device))

        for sample in data:
            # After collation the size tuples may hold 1-element tensors
            # rather than ints; postprocess_masks is given plain ints.
            input_size = tuple(int(v) for v in sample['input_size'])
            original_image_size = tuple(int(v) for v in sample['original_image_size'])

            with torch.no_grad():
                box_torch = torch.as_tensor(
                    sample["bbox_transformed"], dtype=torch.float, device=device
                ).reshape(-1, 4)
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=box_torch,
                    masks=None,
                )

            # Only the mask decoder runs with gradients enabled.
            low_res_masks, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size)
            # Turn logits into probabilities for the BCE loss. The previous
            # normalize(threshold(...)) produced a hard 0/1 mask whose gradient
            # with respect to the logits is zero, so the decoder never learned.
            mask_probs = torch.sigmoid(upscaled_masks)

            gt_mask = torch.as_tensor(sample["mask"], device=device)
            gt_binary_mask = (gt_mask > 0).to(torch.float32).reshape(1, 1, *gt_mask.shape[-2:])

            loss = loss_fn(mask_probs, gt_binary_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())

    losses.append(epoch_losses)
    print(f'EPOCH: {epoch}')
    # statistics.mean raises on an empty list (e.g. no labelled instances).
    print(f'Mean loss: {mean(epoch_losses) if epoch_losses else float("nan")}')

In [10]:
from statistics import mean

from tqdm import tqdm
from torch.nn.functional import threshold, normalize

from torch.utils.data import DataLoader

# NOTE(review): this cell looks like an earlier draft of the training loop.
# It iterates over `keys` and reads `transformed_data`, `bbox_coords` and
# `ground_truth_masks`, none of which are defined in the code visible in this
# review, so it presumably fails on a fresh kernel. Confirm and remove, or
# restore the cells that define those names.
num_epochs = 100
losses = []
# NOTE(review): `data_loader` is created here but the loop below never uses it.
data_loader = DataLoader(LeafInstanceDataset("_data/combined"))

for epoch in range(num_epochs):
  epoch_losses = []
  # Trains on every entry of `keys` (originally described as "the first 20
  # examples" - TODO confirm against wherever `keys` is built).
  for k in keys:
    input_image = transformed_data[k]['image'].to(device)
    input_size = transformed_data[k]['input_size']
    original_image_size = transformed_data[k]['original_image_size']

    # No grad here as we don't want to optimise the encoders
    with torch.no_grad():
      image_embedding = sam_model.image_encoder(input_image)

      # Box prompt for this sample, mapped through the same resize transform.
      prompt_box = bbox_coords[k]
      box = transform.apply_boxes(prompt_box, original_image_size)
      box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
      box_torch = box_torch[None, :]

      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )
    # The mask decoder runs outside the no_grad block: it is the part being trained.
    low_res_masks, iou_predictions = sam_model.mask_decoder(
      image_embeddings=image_embedding,
      image_pe=sam_model.prompt_encoder.get_dense_pe(),
      sparse_prompt_embeddings=sparse_embeddings,
      dense_prompt_embeddings=dense_embeddings,
      multimask_output=False,
    )

    upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
    # NOTE(review): threshold followed by normalize appears to yield a hard
    # 0/1 mask, which would pass no useful gradient back to the decoder - confirm.
    binary_mask = normalize(threshold(upscaled_masks, 0.0, 0))

    # Ground truth reshaped to (1, 1, H, W) and binarised to float 0/1.
    gt_mask_resized = torch.from_numpy(np.resize(ground_truth_masks[k], (1, 1, ground_truth_masks[k].shape[0], ground_truth_masks[k].shape[1]))).to(device)
    gt_binary_mask = torch.as_tensor(gt_mask_resized > 0, dtype=torch.float32)

    loss = loss_fn(binary_mask, gt_binary_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_losses.append(loss.item())
  # Keep the per-step losses of each epoch and report the epoch mean.
  losses.append(epoch_losses)
  print(f'EPOCH: {epoch}')
  print(f'Mean loss: {mean(epoch_losses)}')

SyntaxError: '(' was never closed (852363337.py, line 10)