In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from patchify import patchify  #Only to handle large images
import random
from scipy import ndimage
from datasets import Dataset as ds
from torch.utils.data import Dataset, DataLoader
from transformers import SamModel
import torch
from torch.optim import Adam
import monai
from tqdm import tqdm
import torch.nn.functional as F
from torch.nn.functional import threshold, normalize
import cv2 as cv
import csv
import time

In [None]:
# The following code is based on the following github project:
# Bhattiprolu, S. (2023). Fine-tune Segment Anything Model (SAM) - 
# Mitochondria Segmentation [Jupyter notebook]. 
# GitHub. https://github.com/bnsreenu/python_for_microscopists/blob/master/331_fine_tune_SAM_mito.ipynb

# Portions of this code were generated with the assistance of ChatGPT (OpenAI, 2025) and subsequently modified by the author.
# OpenAI. (2025). ChatGPT (May 2025 version) [Large language model]. https://chat.openai.com

trainDirImage = "/content/BrG_training_data/all_im/"
#trainDirImage2 = "D:/Thesis/datasets/DRIVE/training/images/"
trainDirMask = "/content/BrG_training_data/all_mask_cup/"
#trainDirMask2 = "D:/Thesis/datasets/DRIVE/training/1st_manual/"

# Regular files only, sorted by name: the i-th image is paired with the i-th
# mask purely by this ordering, so image and mask names must sort alike.
files1 = sorted(f for f in os.listdir(trainDirImage) if os.path.isfile(os.path.join(trainDirImage, f)))
files2 = sorted(f for f in os.listdir(trainDirMask) if os.path.isfile(os.path.join(trainDirMask, f)))

large_images = []
large_masks = []

# NOTE(review): h and w are not used in this cell; kept in case other cells read them.
h = 511
w = 503

# Iterate the filtered name lists themselves so that non-file entries in the
# directory (e.g. sub-directories) cannot push an index past the list's end.
for name in files1:
    if not name.endswith('.png'):
        continue
    path = os.path.join(trainDirImage, name)
    img = cv.imread(path)  # BGR uint8, or None if the file cannot be decoded
    if img is None:
        raise FileNotFoundError(f"Could not read training image: {path}")
    # Keep only the green channel (index 1 in OpenCV's BGR order, same as in RGB).
    large_images.append(np.ascontiguousarray(img[:, :, 1]))

for name in files2:
    if not name.endswith('.png'):
        continue
    path = os.path.join(trainDirMask, name)
    mask = cv.imread(path)
    if mask is None:
        raise FileNotFoundError(f"Could not read training mask: {path}")
    large_masks.append(cv.cvtColor(mask, cv.COLOR_BGR2GRAY))

# Images and masks are paired by position, so their counts must agree.
if len(large_images) != len(large_masks):
    raise ValueError(
        f"Found {len(large_images)} training images but {len(large_masks)} masks; "
        "every image needs exactly one mask.")

#Desired patch size for smaller images and step size.
patch_size = 250
# step == patch_size gives non-overlapping tiles; patchify presumably drops any
# border strip that does not fill a whole tile (confirm for non-multiple sizes).
step = 250


def _split_into_patches(arrays, size, stride):
    """Tile each 2-D array with patchify and return all tiles as one flat list.

    Tiles are emitted array by array and, within an array, row by row, so image
    and mask tiles produced from same-shaped paired inputs stay index-aligned.
    """
    tiles = []
    for arr in arrays:
        grid = patchify(np.asarray(arr), (size, size), step=stride)
        for i in range(grid.shape[0]):
            for j in range(grid.shape[1]):
                tiles.append(grid[i, j, :, :])
    return tiles


all_img_patches = _split_into_patches(large_images, patch_size, step)
images = np.array(all_img_patches)

# Same tiling for the masks, scaled from 0/255 to 0/1.
# NOTE: the cast truncates, so only pixels equal to 255 become 1 and any
# intermediate value becomes 0 - fine for strictly binary 0/255 masks.
all_mask_patches = [
    (patch / 255.).astype(np.uint8)
    for patch in _split_into_patches(large_masks, patch_size, step)
]
masks = np.array(all_mask_patches)

# Image and mask tiles are paired purely by position.
if images.shape[0] != masks.shape[0]:
    raise ValueError(
        f"Got {images.shape[0]} image patches but {masks.shape[0]} mask patches; "
        "image and mask sizes must match for the patches to pair up.")

# Keep only tiles whose mask contains at least one foreground pixel.
valid_indices = [i for i, mask in enumerate(masks) if mask.max() != 0]
if not valid_indices:
    raise ValueError("No mask patch contains foreground pixels; nothing to train on.")

filtered_images = images[valid_indices]
filtered_masks = masks[valid_indices]
print("Image shape:", filtered_images.shape)  # (num_patches, patch_size, patch_size): green channel only
print("Mask shape:", filtered_masks.shape)

# Wrap every NumPy patch in a PIL image; the Hugging Face dataset stores
# images under "image" and the matching masks under "label".
dataset_dict = {
    "image": list(map(Image.fromarray, filtered_images)),
    "label": list(map(Image.fromarray, filtered_masks)),
}

# Build the in-memory datasets.Dataset from the column dictionary.
dataset = ds.from_dict(dataset_dict)

# Pull one random sample for a quick visual sanity check.
img_num = random.randrange(filtered_images.shape[0])
example = dataset[img_num]
example_image = example["image"]
example_mask = example["label"]

# Build the point-grid prompt that is fed to SAM for every sample.
def get_bounding_box(ground_truth_map, array_size=250, grid_size=15):
  """Return a regular grid of (x, y) point prompts spanning an array_size square.

  Despite its name, this computes no bounding box and ignores the contents of
  ``ground_truth_map``: the same grid is returned for every mask.

  Args:
    ground_truth_map: Unused; kept so existing callers keep working.
    array_size: Side length of the square the grid spans. Points lie in
      [0, array_size - 1]. The default 250 matches this notebook's patch size.
    grid_size: Number of points along each axis.

  Returns:
    torch.int64 tensor of shape (1, 1, grid_size * grid_size, 2). Points are
    ordered row by row with x varying fastest. Fractional coordinates are
    truncated toward zero.
  """
  axis = np.linspace(0, array_size - 1, grid_size)
  xv, yv = np.meshgrid(axis, axis)
  # (grid, grid, 2) -> (grid*grid, 2); the cast truncates like int() does for
  # these non-negative values.
  points = np.stack([xv, yv], axis=-1).reshape(-1, 2).astype(np.int64)
  return torch.from_numpy(points).view(1, 1, grid_size * grid_size, 2)

class SAMDataset(Dataset):
  """
  Torch dataset that converts rows of a Hugging Face dataset into SAM inputs.

  Each row must provide an "image" (a PIL image) and a "label" (a mask that
  np.array can convert). The processor prepares the image; the point-grid
  prompt and the raw ground-truth mask are attached to the returned dict.
  """
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    record = self.dataset[idx]
    mask = np.array(record["label"])

    # Point prompt from get_bounding_box (the mask is passed in, but the grid
    # it returns does not depend on the mask's contents).
    points = get_bounding_box(mask)

    # Only the image goes through the processor; the prompt is attached afterwards.
    # NOTE(review): because the points bypass the processor, they are not
    # rescaled to the processor's resized-image frame. SamModel presumably
    # expects resized-frame coordinates - confirm this matches how prompts
    # are built at inference time.
    encoded = self.processor(record["image"].convert('RGB'), return_tensors="pt")
    encoded["input_points"] = points

    # The processor output and the prompt both carry a leading batch axis of
    # size 1; drop it so the DataLoader can stack samples itself.
    sample = {key: value.squeeze(0) for key, value in encoded.items()}
    sample["ground_truth_mask"] = mask
    return sample

# Initialize the processor
from transformers import SamProcessor
processor = SamProcessor.from_pretrained("/content/sam-vit-base")

# Create an instance of the SAMDataset
train_dataset = SAMDataset(dataset=dataset, processor=processor)

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=False)

# Load the model
model = SamModel.from_pretrained("/content/sam-vit-base")

# Epoch number of the checkpoint to continue from. Set it to 0 to start from the
# base SAM weights; otherwise /content/sam_b_cup_green_<resume_epoch>.pth is loaded.
resume_epoch = 30
if resume_epoch > 0:
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only runtime;
    # the model is moved to `device` further below.
    model.load_state_dict(
        torch.load(f"/content/sam_b_cup_green_{resume_epoch}.pth", map_location="cpu"))

# Freeze the image and prompt encoders so only the mask decoder is trained.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)


# Initialize the optimizer and the loss function
#optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-6, weight_decay=0)
#Try DiceFocalLoss, FocalLoss, DiceCELoss
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

#Training loop
num_epochs = 10

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device: " + device)
model.to(device)

# Number of the epoch about to run. It is used consistently in the printout,
# the CSV log and the checkpoint file name.
epoch_count = resume_epoch + 1

start_time = time.perf_counter()
model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # forward pass with the grid point prompts
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_points=batch["input_points"].to(device),
                        multimask_output=False)

        # compute loss: upsample the predicted logits to the ground-truth size
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        predicted_masks = outputs.pred_masks.squeeze(1)
        predicted_masks = F.interpolate(predicted_masks, size=tuple(ground_truth_masks.shape[-2:]),
                                        mode='bilinear', align_corners=False)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch_count}')
    print(f'Mean loss: {np.mean(epoch_losses)}')

    write_header = not os.path.exists("sam_green_cup.csv")

    with open("sam_green_cup.csv", mode="a", newline="") as file:
        writer = csv.writer(file)
        if write_header:
            writer.writerow(["time", "epoch", "loss"])
        # "loss" holds the list of per-batch losses for this epoch.
        writer.writerow([time.perf_counter() - start_time, epoch_count, epoch_losses])

    ckpt_name = "sam_b_cup_green_" + str(epoch_count) + ".pth"
    f1 = "/content/" + ckpt_name
    f0 = "/content/sam_b_cup_green_" + str(epoch_count - 1) + ".pth"

    torch.save(model.state_dict(), f1)

    # Keep only the newest checkpoint: remove the previous epoch's file. On the
    # first epoch this is the checkpoint that training resumed from.
    if os.path.exists(f0):
        os.remove(f0)
        print("file: " + ckpt_name + " updated!")
    else:
        print("file: " + ckpt_name + " created!")

    epoch_count += 1
