### 1 Import relevant libraries

In [None]:
# import relevant libraries
import numpy as np # for data manipulation
import torch # for PyTorch framework

### 2 Define prompt function for new images, i.e., inference

At inference time no ground-truth mask is available for a new image, so a bounding-box prompt cannot be derived from one. Instead, a regular grid of points covering the image is passed as the prompt. The points mark candidate locations where retinal vessels could be, which improves the model's predictions.

In [None]:
# function to create input points (prompt) for new images
def generate_input_points(image_size: int = 1024, grid_size: int = 100, batch_size: int = 1, point_batch_size: int = 1):
    '''
    Generates a tensor of 2D spatial points arranged in a grid, suitable for model input.

    Args:
        image_size (int): Size of the square image over which to distribute points.
        grid_size (int): Number of points along one dimension of the grid, resulting in grid_size x grid_size points.
        batch_size (int): Number of images being processed in a single batch.
        point_batch_size (int): Number of point sets for each image.

    Returns:
        input_points (torch.FloatTensor): Input tensor in shape (batch_size, point_batch_size, num_points_per_image, 2): last dimension represents x and y coordinates of each point.
    '''

    # generate evenly spaced grid points
    x = np.linspace(0, image_size-1, grid_size)
    y = np.linspace(0, image_size-1, grid_size)
    # create grid
    xv, yv = np.meshgrid(x, y)

    # combine x and y coordinates into shape (grid_size*grid_size, 2), truncated to integer pixel positions
    points = np.stack([xv, yv], axis=-1).reshape(-1, 2).astype(np.int64)

    # get required tensor shape (batch_size, point_batch_size, num_points_per_image, 2)
    # last dimension of 2 represents x and y coordinates of each point
    # the same grid is repeated for every image and point set; a plain .view() would only
    # succeed when batch_size * point_batch_size == 1
    input_points = torch.tensor(points, dtype=torch.float).repeat(batch_size, point_batch_size, 1, 1)

    return input_points
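
As a quick sanity check of the grid layout that `generate_input_points` produces, the NumPy-only sketch below builds a tiny 3 x 3 grid. The toy sizes (`image_size = 5`, `grid_size = 3`) are assumptions chosen for readability and are not values used in this notebook. It only illustrates the point ordering and the final shape.

```python
import numpy as np

# Hypothetical toy sizes for illustration; the function defaults are 1024 and 100.
image_size, grid_size = 5, 3

# Evenly spaced coordinates along each axis, as in generate_input_points
coords = np.linspace(0, image_size - 1, grid_size)
xv, yv = np.meshgrid(coords, coords)

# Flatten to (grid_size * grid_size, 2) with columns (x, y)
points = np.stack([xv, yv], axis=-1).reshape(-1, 2).astype(int)

# Add the leading batch and point-batch axes described in the docstring
prompt = points[None, None, :, :]

print(prompt.shape)          # (1, 1, 9, 2)
print(points[:4].tolist())   # [[0, 0], [2, 0], [4, 0], [0, 2]]
```

The points run row by row: x varies fastest and y advances once per row. With the function defaults the grid holds 100 x 100 = 10,000 points per image.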