In [None]:
import tensorflow as tf
def dice_coefficient(spartial_axis=(1, 2), ignore_empty=False, smooth=1e-6, threshold=None):
    """Build a Dice coefficient metric function.

        Dice coefficient = (2 * TP) / (2 * TP + FP + FN)

    Parameters
    ----------
    spartial_axis : tuple of int
        Axes summed over to form one "sample" (the parameter keeps its
        original spelling for backward compatibility). For example, when
        y_true has shape [batch, depth, height, width, channel], use (2, 3)
        for a per-slice Dice coefficient or (1, 2, 3) for a per-volume one.
    ignore_empty : bool
        If True, samples where both y_true and the binarized y_pred are empty
        are excluded from the average instead of being scored as 1. If every
        sample is empty, the metric returns 0.
    smooth : float
        Smoothing constant added to both numerator and denominator when
        ignore_empty is False. When ignore_empty is True, empty samples are
        detected exactly and divisions are guarded with
        tf.math.divide_no_nan, so smooth is not used in that mode.
    threshold : float or None
        If None (default), y_pred is binarized with tf.math.round (round half
        to even, so 0.5 counts as negative). Otherwise a prediction counts as
        positive when y_pred >= threshold.

    Returns
    -------
    function
        Metric function f(y_true, y_pred) returning a scalar tensor. Example:
        model.compile(..., metrics=[dice_coefficient((2, 3), True, 1e-6)])
    """
    # The inner function deliberately shares the factory's name: Keras
    # typically reports a metric function under its __name__.
    def dice_coefficient(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        if not y_pred.dtype.is_floating:
            y_pred = tf.cast(y_pred, tf.float32)
        # Align dtypes so integer/boolean masks can be combined with float
        # predictions (mixed-dtype multiplication raises in TensorFlow).
        y_true = tf.cast(y_true, y_pred.dtype)

        # Binarize predictions.
        if threshold is None:
            y_pred = tf.math.round(y_pred)
        else:
            y_pred = tf.cast(y_pred >= threshold, y_pred.dtype)

        tp = tf.math.reduce_sum(y_true * y_pred, axis=spartial_axis)        # true positives
        fn = tf.math.reduce_sum(y_true * (1 - y_pred), axis=spartial_axis)  # false negatives
        fp = tf.math.reduce_sum((1 - y_true) * y_pred, axis=spartial_axis)  # false positives

        numerator = 2 * tp
        # Algebraically equal to sum(y_true) + sum(y_pred) per sample, so for
        # non-negative masks it is zero exactly when both masks are empty.
        denominator = 2 * tp + fn + fp

        if ignore_empty:
            # Empty samples yield 0 here and are then excluded from the count,
            # so they do not contribute to the average.
            per_sample = tf.math.divide_no_nan(numerator, denominator)
            non_empty = tf.math.count_nonzero(denominator, dtype=per_sample.dtype)
            # divide_no_nan returns 0 when every sample is empty.
            return tf.math.divide_no_nan(tf.math.reduce_sum(per_sample), non_empty)

        # Standard smoothed Dice: samples where both masks are empty score 1.
        return tf.math.reduce_mean((numerator + smooth) / (denominator + smooth))

    return dice_coefficient