In [6]:
import tensorflow as tf
from keras.losses import Loss

@tf.keras.utils.register_keras_serializable(package="Custom", name="ChamferDistanceMasked")
class ChamferDistanceMasked(Loss):
    """Symmetric squared Chamfer distance between two padded point sets.

    A point (one row along the last axis) is treated as padding when *every*
    coordinate is within ``1e-6`` of ``padding_val``. Padding points are
    excluded both as query points and as nearest-neighbour candidates. A
    genuine point whose coordinates all equal ``padding_val`` is therefore
    indistinguishable from padding.

    The returned value is::

        mean_{valid t} min_{valid p} ||t - p||^2 + mean_{valid p} min_{valid t} ||p - t||^2

    Neighbours are searched within each sample. Each mean is taken over all
    contributing points of the whole batch (not per sample), so ``call``
    returns a scalar.

    Args:
        padding_val: Coordinate value used to pad point sets to a common length.
        name: Name of the loss instance.
        **kwargs: Forwarded to the base ``Loss.__init__`` (e.g. ``reduction``),
            which also lets ``from_config`` accept the base-class config keys.
    """

    def __init__(self, padding_val=-1.0, name="chamfer_distance_masked", **kwargs):
        super().__init__(name=name, **kwargs)
        self.padding_val = padding_val

    def get_config(self):
        """Return the serializable config, including ``padding_val``."""
        config = super().get_config()
        config.update({"padding_val": self.padding_val})
        return config

    def call(self, y_true, y_pred):
        """Compute the masked Chamfer distance.

        Args:
            y_true: (B, N, D) tensor; padded rows have every coordinate equal
                to ``padding_val``.
            y_pred: (B, M, D) tensor; padded rows have every coordinate equal
                to ``padding_val``.

        Returns:
            Scalar tensor with the dtype of ``y_pred``. A point whose sample
            has no valid counterpart contributes nothing. If no point
            contributes in a direction, that direction's mean is 0.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        # Work in a single dtype so float64 / mixed-precision inputs do not clash.
        y_true = tf.cast(y_true, y_pred.dtype)
        padding_val = self.padding_val
        epsilon = 1e-6  # tolerance for recognising padding coordinates

        # Valid-point masks: (B, N) and (B, M).
        mask_true = tf.reduce_any(tf.abs(y_true - padding_val) > epsilon, axis=-1)
        mask_pred = tf.reduce_any(tf.abs(y_pred - padding_val) > epsilon, axis=-1)

        # Pairwise squared distances: (B, N, 1, D) - (B, 1, M, D) -> (B, N, M).
        diff = tf.expand_dims(y_true, axis=2) - tf.expand_dims(y_pred, axis=1)
        dist = tf.reduce_sum(tf.square(diff), axis=-1)

        # Only pairs of two valid points are admissible matches. Replace the
        # rest with +inf so a padding row can never be picked as a nearest
        # neighbour by reduce_min.
        pair_mask = tf.logical_and(
            tf.expand_dims(mask_true, axis=2), tf.expand_dims(mask_pred, axis=1)
        )  # (B, N, M)
        inf = tf.constant(float("inf"), dtype=dist.dtype)
        masked_dist = tf.where(pair_mask, dist, inf)

        min_true_to_pred = tf.reduce_min(masked_dist, axis=2)  # (B, N)
        min_pred_to_true = tf.reduce_min(masked_dist, axis=1)  # (B, M)

        # A point contributes only if it is valid AND its sample contains at
        # least one valid counterpart; otherwise its minimum is +inf.
        has_valid_pred = tf.reduce_any(mask_pred, axis=1, keepdims=True)  # (B, 1)
        has_valid_true = tf.reduce_any(mask_true, axis=1, keepdims=True)  # (B, 1)
        use_true = tf.logical_and(mask_true, has_valid_pred)  # (B, N)
        use_pred = tf.logical_and(mask_pred, has_valid_true)  # (B, M)

        # tf.where (not multiplication) so the +inf entries never reach the sum.
        sum_true_to_pred = tf.reduce_sum(
            tf.where(use_true, min_true_to_pred, tf.zeros_like(min_true_to_pred))
        )
        sum_pred_to_true = tf.reduce_sum(
            tf.where(use_pred, min_pred_to_true, tf.zeros_like(min_pred_to_true))
        )

        count_true = tf.reduce_sum(tf.cast(use_true, dist.dtype))
        count_pred = tf.reduce_sum(tf.cast(use_pred, dist.dtype))

        # divide_no_nan yields 0 for a zero count without ever forming x / 0,
        # which keeps both the value and its gradient finite.
        mean_true_to_pred = tf.math.divide_no_nan(sum_true_to_pred, count_true)
        mean_pred_to_true = tf.math.divide_no_nan(sum_pred_to_true, count_pred)

        return mean_true_to_pred + mean_pred_to_true


In [7]:
# Padded toy example: the first true point and the last predicted point are
# padding rows (every coordinate equals -1.0), so they are excluded from the
# averages. Every remaining point lies 0.1 from its nearest valid counterpart
# (squared distance 0.01), so each directional mean is 0.01 and the total is 0.02.
y_true = tf.constant([[[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]])  # (1, 3, 2)
y_pred = tf.constant([[[1.0, 1.1], [0.0, 0.1], [-1.0, -1.0]]])  # (1, 3, 2)
loss_fn = ChamferDistanceMasked(padding_val=-1.0)
loss = loss_fn(y_true, y_pred)
print("Chamfer loss:", loss.numpy())
# Pin the hand-computed value so a regression in the loss is caught on re-run.
assert abs(float(loss.numpy()) - 0.02) < 1e-6, "unexpected Chamfer loss for the toy example"

Chamfer loss: 0.020000005
