<a href="https://colab.research.google.com/github/bnsreenu/python_for_microscopists/blob/master/331_fine_tune_SAM_mito.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### 1. Import relevant libraries

In [None]:
# import relevant libraries
import numpy as np # for data manipulation
from PIL import Image # read and save images
import torch # for PyTorch framework
import torchvision.transforms.functional as TF # for tensor transformations and normalization
from torch.utils.data import Dataset # for creating custom datasets and data loaders

### 2. Define the U-Net dataset class and create a DataLoader

In [None]:
# create custom dataset class for PyTorch
class UNETDataset(Dataset):
    '''
    PyTorch dataset that pairs images with their segmentation masks.

    Each item is returned as ``(image, mask)``: ``image`` is the tensor
    produced by ``TF.to_tensor`` (then passed through ``TF.normalize``) and
    ``mask`` is a ``torch.long`` tensor of labels.
    '''

    def __init__(self, images:list, masks:list, transforms=None):
        '''
        Initializes images and masks.

        Args:
            images (list): Sequence of images; each item must be something
                ``TF.to_tensor`` accepts (e.g. a PIL image or numpy array).
            masks (list): Sequence of masks, index-aligned with ``images``.
                Items must support ``* 255`` when ``transforms`` is given —
                presumably numpy arrays of integer labels (0/1); TODO confirm.
            transforms (callable): Optional joint transform called as
                ``transforms(image, mask)`` that returns ``(image, mask)``.

        Raises:
            ValueError: If ``images`` and ``masks`` have different lengths.
        '''
        # Fail fast: __len__ only reports len(images), so a length mismatch
        # would otherwise surface later as an IndexError (or unused masks).
        if len(images) != len(masks):
            raise ValueError(
                f"images and masks must have the same length "
                f"(got {len(images)} and {len(masks)})"
            )
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self):
        '''
        Returns the number of items in the dataset.
        '''
        return len(self.images)

    def __getitem__(self, idx):
        '''
        Retrieves an image and its corresponding mask from the dataset.

        Args:
            idx: The index of the item to retrieve.

        Returns:
            image: Image tensor from ``TF.to_tensor`` followed by ``TF.normalize``.
            mask: Mask tensor with dtype ``torch.long``.
        '''
        image = self.images[idx]
        mask = self.masks[idx]

        # apply transformations if defined
        if self.transforms is not None:
            # The mask is scaled by 255 for the joint transform and scaled back
            # afterwards. A transform that interpolates can leave values such as
            # 254 (-> 0.996 after rescaling); round to the nearest label so the
            # integer cast below does not truncate them to 0.
            image, mask = self.transforms(image, mask * 255)
            mask = np.rint(np.asarray(mask, dtype=np.float64) / 255)

        # convert image and mask to PyTorch tensors
        image = TF.to_tensor(image)
        mask = torch.tensor(mask, dtype=torch.long)
        # NOTE: mean=0 / std=1 makes this normalization an identity operation;
        # kept as a placeholder for real dataset statistics.
        image = TF.normalize(image, mean=[0.0], std=[1.0])

        return image, mask