<a href="https://colab.research.google.com/github/bnsreenu/python_for_microscopists/blob/master/331_fine_tune_SAM_mito.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### 1 Import relevant libraries

In [None]:
# import relevant libraries
import torch # for deep learning
import numpy as np # for data manipulation
from sklearn.metrics import accuracy_score, precision_score, recall_score # for calculation of metrics

### 2 Define evaluation metrics without built-in implementations (Dice, IoU) and helper functions

In [None]:
# define dice score
def dice_score(y_pred, y_true, *, empty_score=1.0):
    '''
    Calculate the dice score between predicted masks and ground truth masks.

    Dice = 2 * sum(y_pred * y_true) / (sum(y_pred) + sum(y_true)), computed over
    all elements of the inputs at once (no per-sample or per-channel reduction).

    Args:
        y_pred (torch.Tensor or np.ndarray): Predicted segmentation masks.
        y_true (torch.Tensor or np.ndarray): Ground truth segmentation masks.
        empty_score (float): Value returned when both masks are empty, i.e. the
            denominator is zero. Defaults to 1.0: two empty masks agree perfectly.
            Without this guard the formula would evaluate 0 / 0 and yield NaN.

    Returns:
        dice (float): Dice score. It is a scalar of the inputs' library type (e.g. a
            numpy scalar or a 0-d tensor), or ``empty_score`` when both masks are empty.
    '''

    # calculate intersection and total area (separate)
    intersection = (y_pred * y_true).sum()
    total_area = y_pred.sum() + y_true.sum()

    # both masks empty: the ratio is undefined, so report the configured score
    if total_area == 0:
        return empty_score

    # use formula
    dice = (2 * intersection) / total_area

    return dice

In [None]:
# define iou
def iou_score(y_pred, y_true, *, empty_score=1.0):
    '''
    Calculate the Intersection over Union (IoU) score between predicted masks and ground truth masks.

    IoU = sum(y_pred * y_true) / (sum(y_pred) + sum(y_true) - sum(y_pred * y_true)),
    computed over all elements of the inputs at once (no per-sample reduction).

    Args:
        y_pred (torch.Tensor or np.ndarray): Predicted segmentation masks.
        y_true (torch.Tensor or np.ndarray): Ground truth segmentation masks.
        empty_score (float): Value returned when both masks are empty, i.e. the
            union is zero. Defaults to 1.0: two empty masks agree perfectly.
            Without this guard the formula would evaluate 0 / 0 and yield NaN.

    Returns:
        iou (float): IoU score. It is a scalar of the inputs' library type (e.g. a
            numpy scalar or a 0-d tensor), or ``empty_score`` when both masks are empty.
    '''

    # calculate intersection and union
    intersection = (y_pred * y_true).sum()
    union = y_pred.sum() + y_true.sum() - intersection

    # both masks empty: the ratio is undefined, so report the configured score
    if union == 0:
        return empty_score

    # use formula
    iou = intersection / union

    return iou

In [None]:
# calculate metrics for given thresholds
def calculate_metrics(y_true, y_prob, thresholds):
    '''
    Calculate evaluation metrics for predicted segmentation masks compared to ground truth masks at different thresholds.

    For each threshold, ``y_prob`` is binarised with a strict ``>`` comparison into a
    uint8 mask. That mask is compared element-wise against the flattened ``y_true``.

    Args:
        y_true (np.array): Ground truth segmentation masks (binary array).
        y_prob (np.array): Predicted probabilities of segmentation masks. Must contain
            the same number of elements as ``y_true``.
        thresholds (list): List of thresholds for binary conversion of probabilities.

    Returns:
        metrics (dict): Dictionary containing evaluation metrics calculated for each threshold.
                        Keys are ``f'threshold_{threshold}'``. Values are dictionaries containing metrics.
                        Metrics include accuracy, precision, recall, dice score, and Intersection over Union (IoU).
    '''
    # Flatten once: neither input changes between thresholds.
    y_true_flat = y_true.flatten()
    y_prob_flat = y_prob.flatten()

    metrics = {}
    for threshold in thresholds:
        y_pred = (y_prob_flat > threshold).astype(np.uint8)
        metrics[f'threshold_{threshold}'] = {
            'accuracy': accuracy_score(y_true_flat, y_pred),
            # zero_division=0 on both: report 0 instead of warning when there are
            # no predicted positives (precision) or no true positives (recall)
            'precision': precision_score(y_true_flat, y_pred, zero_division=0),
            'recall': recall_score(y_true_flat, y_pred, zero_division=0),
            'dice': dice_score(y_pred, y_true_flat),
            'iou': iou_score(y_pred, y_true_flat),
        }
    return metrics

In [None]:
# sigmoid function
def sigmoid(x):
    '''
    Apply the logistic sigmoid 1 / (1 + exp(-x)) element-wise.

    Uses the numerically stable split form. Let z = exp(-|x|), which always lies in (0, 1].
    Then sigmoid(x) = 1 / (1 + z) for x >= 0 and z / (1 + z) for x < 0. This gives the
    same values as the textbook formula. It avoids the "overflow encountered in exp"
    warning (or error under np.errstate(over='raise')) that np.exp(-x) triggers for
    large negative inputs such as strongly negative logits.

    Args:
        x (scalar or array-like): Input values (e.g. logits). Expected to be NumPy
            arrays or scalars.

    Returns:
        np.ndarray or numpy scalar: Sigmoid of ``x``, with the same shape (and float
            dtype) as ``x``.
    '''
    x = np.asarray(x)
    z = np.exp(-np.abs(x))  # never overflows: exponent is <= 0
    out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    # `[()]` unwraps a 0-d result into a numpy scalar (matching the old formula on
    # scalar input) and is a no-op view for n-d arrays.
    return out[()]