<a href="https://colab.research.google.com/github/bnsreenu/python_for_microscopists/blob/master/331_fine_tune_SAM_mito.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### 1. Import relevant libraries

In [None]:
# import relevant libraries
from PIL import Image # read and save images
import numpy as np # for data manipulation
import torch # for PyTorch framework
import torchvision.transforms.functional as TF # for tensor transformations
from torch.utils.data import Dataset, DataLoader # for creating custom datasets and data loaders

### 2. Define a custom SAM dataset class

In [None]:
class SAMDataset(Dataset):
    """Dataset pairing images with masks and sampled point prompts for SAM fine-tuning.

    Each item is the mapping returned by ``processor`` for one image (with the
    leading batch dimension removed) plus a ``ground_truth_mask`` tensor.
    Point prompts are sampled uniformly from the mask's foreground pixels.
    """

    def __init__(self, images: list, masks: list, processor, transforms=None,
                 num_points: int = 1750):
        """Initialize the SAM dataset.

        Args:
            images (list): Images (e.g. PIL images or arrays); each item is
                converted with ``np.array`` before use.
            masks (list): 2-D masks aligned with ``images``; every pixel > 0
                is treated as foreground.
            processor: Callable preparing model inputs. It is called as
                ``processor(image, input_points=[[points]], return_tensors="pt")``
                and must return a mapping of tensors with a leading batch
                dimension (which is squeezed away).
            transforms (callable, optional): Called as ``transforms(image, mask)``
                with NumPy arrays, where ``mask`` is uint8 with values 0/255.
                Must return ``(image, mask)``, with ``mask`` array-like in the
                0-255 range (e.g. joint data augmentation).
            num_points (int): Number of point prompts sampled per item.
        """
        self.images = images
        self.masks = masks
        self.processor = processor
        self.transforms = transforms
        self.num_points = num_points

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Retrieve processor inputs, point prompts and ground truth for one item.

        Args:
            idx: Index of the item to retrieve.

        Returns:
            dict: The processor outputs with the batch dimension squeezed, plus
            ``ground_truth_mask``, a float32 tensor of shape (1, H, W) with
            values in [0, 1].

        Raises:
            ValueError: If the mask (after transforms) has no foreground pixels,
                so no point prompt can be sampled.
        """
        image = np.array(self.images[idx])
        # Binarize up front so that 0/1, 0/255, bool and float masks all yield
        # the same ground truth (uint8 0/1 masks were previously scaled to 1/255).
        mask = (np.asarray(self.masks[idx]) > 0).astype(np.float32)

        if self.transforms is not None:
            # Transforms receive the mask as uint8 0/255 (no overflow possible since
            # the values are 0/1 before scaling); rescale to [0, 1] afterwards.
            image, mask = self.transforms(image, mask.astype(np.uint8) * 255)
            mask = np.asarray(mask, dtype=np.float32) / 255.0

        # Foreground pixel coordinates: rows are y, columns are x.
        ys, xs = np.nonzero(mask > 0)
        if xs.size == 0:
            raise ValueError(
                f"mask at index {idx} has no foreground pixels; "
                "cannot sample point prompts"
            )

        # Sample with replacement only when the foreground is smaller than
        # num_points, so every item has exactly num_points prompts and the
        # DataLoader's default collate can stack them.
        chosen = np.random.choice(
            xs.size, size=self.num_points, replace=xs.size < self.num_points
        )
        prompt = np.column_stack((xs[chosen], ys[chosen]))  # (num_points, 2) as (x, y)

        # Float32 ground truth of shape (1, H, W); the contiguous copy also handles
        # arrays with negative strides (e.g. from np.flip) returned by transforms.
        mask_tensor = torch.from_numpy(
            np.ascontiguousarray(mask, dtype=np.float32)
        ).unsqueeze(0)

        inputs = self.processor(image, input_points=[[prompt]], return_tensors="pt")

        # Remove the batch dimension the processor adds so the DataLoader can batch items.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        inputs["ground_truth_mask"] = mask_tensor
        return inputs