feat/fix: Update loss function for new model output shape
* The model output tensor shape changed with the introduction of the
AnchorBox layer. This updates the loss function accordingly.
* A new argument `n_neg_min` was introduced. It allows you to optionally
specify a minimum number of negatives to enter the loss computation in
batches that contain very few positives, or even none at all.
* Fix: Fixed a bug where the loss function would divide by zero if
there were no positives in a batch. The function is now robust to this
case.
pierluigiferrari committed Apr 6, 2017
1 parent 0ab5a34 commit d533214
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions keras_ssd_loss.py
@@ -4,38 +4,33 @@

class SSD_Loss:
'''
The SSD loss. This implementation has an important difference to the loss
function used in https://arxiv.org/abs/1512.02325: The paper regresses to
`(cx, cy, w, h)`, i.e. to the box center x and y coordinates and the box width
and height, while this implementation regresses to `(xmin, xmax, ymin, ymax)`,
i.e. to the horizontal and vertical min and max box coordinates. This is relevant
for the normalization performed in `smooth_L1_loss()`. If it weren't for this
normalization, the format of the four box coordinates wouldn't matter for this
loss function as long as it is consistent between `y_true` and `y_pred`.
The SSD loss, see https://arxiv.org/abs/1512.02325.
'''

def __init__(self,
loc_norm,
neg_pos_ratio=3,
n_neg_min=0,
alpha=1.0):
'''
Arguments:
loc_norm (array): A Numpy array with shape `(batch_size, #boxes, 4)`,
where the last dimension contains the default box widths and heights
in the format `(width, width, height, height)`. This is used for
normalization in `smooth_L1_loss`.
neg_pos_ratio (int): The maximum number of negative (i.e. background)
neg_pos_ratio (int, optional): The maximum number of negative (i.e. background)
ground truth boxes to include in the loss computation. There are no
actual background ground truth boxes of course, but `y_true`
contains default boxes labeled with the background class. Since
the number of background boxes in `y_true` will usually exceed
the number of positive boxes by far, it is necessary to balance
their influence on the loss. Defaults to 3 following the paper.
alpha (float): A factor to weight the localization loss in the
n_neg_min (int, optional): The minimum number of negative ground truth boxes to
enter the loss computation *per batch*. This argument can be used to make
sure that the model learns from a minimum number of negatives in batches
in which there are very few, or even none at all, positive ground truth
boxes. It defaults to 0 and, if used, should be set to a value that
stands in reasonable proportion to the batch size used for training.
alpha (float, optional): A factor to weight the localization loss in the
computation of the total loss. Defaults to 1.0 following the paper.
'''
self.loc_norm = loc_norm
self.neg_pos_ratio = tf.constant(neg_pos_ratio)
self.n_neg_min = tf.constant(n_neg_min)
self.alpha = tf.constant(alpha)
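
For illustration, here is a minimal usage sketch (not part of the commit) showing how this loss class plugs into Keras training. The tiny Dense/Reshape model, all shapes, and the `n_neg_min` value below are made up; the toy model only serves to produce an output of the expected `(batch_size, n_boxes, n_classes + 8)` shape:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Reshape
from keras_ssd_loss import SSD_Loss

n_boxes, n_classes = 10, 3                   # made-up numbers; n_classes includes the background class
loc_norm = np.ones((1, n_boxes, 4))          # placeholder for the required `loc_norm` argument

ssd_loss = SSD_Loss(loc_norm=loc_norm, neg_pos_ratio=3, n_neg_min=8, alpha=1.0)

# A toy model whose output merely has the expected (batch_size, n_boxes, n_classes + 8)
# shape, used only to show how `compute_loss` is passed to `model.compile()`.
model = Sequential([
    Dense(n_boxes * (n_classes + 8), input_shape=(100,)),
    Reshape((n_boxes, n_classes + 8)),
])
model.compile(optimizer='adam', loss=ssd_loss.compute_loss)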

def smooth_L1_loss(self, y_true, y_pred):
@@ -57,11 +52,8 @@ def smooth_L1_loss(self, y_true, y_pred):
References:
https://arxiv.org/abs/1504.08083
'''
# In order to normalize the localization loss, we perform element-wise division by the default box widths and heights.
# Deviations in xmin and xmax are divided by their respective default box widths, deviations in ymin and ymax are divided
# by their respective default box heights.
absolute_loss = tf.abs(y_true - y_pred) / self.loc_norm
square_loss = 0.5 * (y_true - y_pred)**2 / self.loc_norm
absolute_loss = tf.abs(y_true - y_pred)
square_loss = 0.5 * (y_true - y_pred)**2
l1_loss = tf.where(tf.less(absolute_loss, 1.0), square_loss, absolute_loss - 0.5)
return tf.reduce_sum(l1_loss, axis=-1)
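
As a quick standalone check (illustration only, not part of the diff), the piecewise definition implemented above is 0.5*x**2 for |x| < 1 and |x| - 0.5 otherwise:

import numpy as np
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
smooth_l1 = np.where(np.abs(x) < 1.0, 0.5 * x**2, np.abs(x) - 0.5)
print(smooth_l1)   # -> [1.5   0.125 0.    0.125 1.5  ]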

@@ -91,38 +83,41 @@ def compute_loss(self, y_true, y_pred):
Compute the loss of the SSD model prediction against the ground truth.
Arguments:
y_true (array): A Numpy array of shape `(batch_size, #boxes, #classes + 4)`,
y_true (array): A Numpy array of shape `(batch_size, #boxes, #classes + 8)`,
where `#boxes` is the total number of boxes that the model predicts
per image. Be careful to make sure that the index of each given
box in `y_true` is the same as the index for the corresponding
box in `y_pred`. The last dimension must contain
`[classes 1-hot encoded, 4 box coordinates]` in this order,
including the background class.
box in `y_pred`. The last axis must have length `#classes + 8` and contain
`[classes one-hot encoded, 4 ground truth box coordinates, 4 arbitrary entries]`
in this order, including the background class. The last four entries of the
last axis are not used by this function and therefore their contents are
irrelevant; they only exist so that `y_true` has the same shape as `y_pred`,
where the last four entries of the last axis contain the anchor box
coordinates, which are needed during inference. Important: Boxes that
you want the cost function to ignore need to have a one-hot
class vector of all zeros.
y_pred (Keras tensor): The model prediction. The shape is identical
to that of `y_true`.
Returns:
A scalar, the total multitask loss for classification and localization.
'''
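
As an illustration only (not part of the diff), one row along the last axis of `y_true` could look like this for a hypothetical setup with 3 classes (background plus two object classes); the exact coordinate format depends on how the ground truth was encoded:

import numpy as np
#               one-hot classes | 4 ground truth box coordinates | 4 entries ignored by the loss
row = np.array([0., 1., 0.,       0.2, 0.6, 0.1, 0.5,              0., 0., 0., 0.])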
batch_size = tf.shape(y_pred)[0] # tf.int32
n_boxes = tf.shape(y_pred)[1] # tf.int32
depth = tf.shape(y_pred)[2] # tf.int32
batch_size = tf.shape(y_pred)[0] # Output dtype: tf.int32
n_boxes = tf.shape(y_pred)[1] # Output dtype: tf.int32, note that `n_boxes` in this context denotes the total number of boxes per image, not the number of boxes per cell

# 1: Compute the losses for class and box predictions for each default box
# 1: Compute the losses for class and box predictions for every box

classification_loss = tf.to_float(self.log_loss(y_true[:,:,:-4], y_pred[:,:,:-4])) # Tensor of shape (batch_size, n_boxes)
localization_loss = tf.to_float(self.smooth_L1_loss(y_true[:,:,-4:], y_pred[:,:,-4:])) # Tensor of shape (batch_size, n_boxes)
classification_loss = tf.to_float(self.log_loss(y_true[:,:,:-8], y_pred[:,:,:-8])) # Output shape: (batch_size, n_boxes)
localization_loss = tf.to_float(self.smooth_L1_loss(y_true[:,:,-8:-4], y_pred[:,:,-8:-4])) # Output shape: (batch_size, n_boxes)

# 2: Compute the classification losses for the positive and negative targets

# Count the number of positive (classes [1:]) and negative (class 0) boxes in y_true across the whole batch
n_boxes_batch = batch_size * n_boxes # tf.int32
n_negative = tf.to_int32(tf.reduce_sum(y_true[:,:,0]))
n_positive = n_boxes_batch - n_negative

# Create masks for the positive and negative ground truth classes
negatives = y_true[:,:,0] # Tensor of shape (batch_size, n_boxes)
positives = 1 - negatives # Tensor of shape (batch_size, n_boxes)
positives = tf.to_float(tf.reduce_max(y_true[:,:,1:-8], axis=-1)) # Tensor of shape (batch_size, n_boxes)

# Count the number of positive boxes (classes 1 to n) in y_true across the whole batch
n_positive = tf.reduce_sum(positives)
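
A standalone illustration (not part of the diff) of how the `positives` mask falls out of the one-hot class vectors: the maximum over the non-background classes is 1.0 for a positive box and 0.0 for a background box.

import tensorflow as tf
y_true_classes = tf.constant([[[1., 0., 0.],     # background box
                               [0., 1., 0.],     # positive box, class 1
                               [0., 0., 1.]]])   # positive box, class 2
positives = tf.reduce_max(y_true_classes[:, :, 1:], axis=-1)
with tf.Session() as sess:
    print(sess.run(positives))   # -> [[0. 1. 1.]]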

# Now mask all negative boxes and sum up the losses for the positive boxes PER batch item
# (Keras loss functions must output one scalar loss value PER batch item, rather than just
@@ -134,17 +129,25 @@ def compute_loss(self, y_true, y_pred):
# First, compute the classification loss for all negative boxes
neg_class_loss_all = classification_loss * negatives # Tensor of shape (batch_size, n_boxes)
n_neg_losses = tf.count_nonzero(neg_class_loss_all, dtype=tf.int32) # The number of non-zero loss entries in `neg_class_loss_all`
# What's the point of `n_neg_losses`? For the next step, which will be to compute which negative boxes enter the classification
# loss, we don't just want to know how many negative ground truth boxes there are, but for how many of those there actually is
# a positive (i.e. non-zero) loss. This is necessary because `tf.nn.top_k()` in the function below will pick the top k boxes with
# the highest losses no matter what, even if it receives a vector where all losses are zero. In the unlikely event that all negative
# classification losses ARE actually zero though, this behavior might lead to `tf.nn.top_k()` returning the indices of positive
# boxes, leading to an incorrect negative classification loss computation, and hence an incorrect overall loss computation.
# We therefore need to make sure that `n_negative_keep`, which assumes the role of the `k` argument in `tf.nn.top_k()`,
# is at most the number of negative boxes for which there is a positive classification loss.

# Compute the number of negative examples we want to account for in the loss
# We'll keep at most `self.neg_pos_ratio` times the number of positives in `y_true`, but at least `self.n_neg_min` (unless `n_neg_losses` is smaller)
n_negative_keep = tf.minimum(tf.maximum(self.neg_pos_ratio * tf.to_int32(n_positive), self.n_neg_min), n_neg_losses)

# In the unlikely case when either (1) there are no negative ground truth boxes at all
# or (2) the classification loss for all negative boxes is zero, return zero as the `neg_class_loss`
def f1():
return tf.zeros([batch_size])
# Otherwise compute the negative loss
def f2():
# Compute the number of negative examples we want to account for in the loss
# We'll keep at most `self.neg_pos_ratio` times the number of positives in `y_true`
n_negative_keep = tf.to_int32(tf.minimum(self.neg_pos_ratio * n_positive, n_neg_losses))

# Now we'll identify the top-k (where k == `n_negative_keep`) boxes with the highest confidence loss that
# belong to the background class in the ground truth data. Note that this doesn't necessarily mean that the model
# predicted the wrong class for those boxes, it just means that the loss for those boxes is the highest.
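
A tiny standalone sketch (not the repository code) of the top-k selection idea described above: keep only the `k` negative boxes with the largest classification loss and turn the result into a mask.

import tensorflow as tf
neg_losses = tf.constant([0.1, 2.3, 0.0, 0.7, 1.5])   # made-up per-box losses
k = 2
values, indices = tf.nn.top_k(neg_losses, k)           # the k largest losses
keep_mask = tf.scatter_nd(tf.expand_dims(indices, axis=1),
                          tf.ones_like(values),
                          shape=tf.shape(neg_losses))  # 1.0 at the kept positions
with tf.Session() as sess:
    print(sess.run(keep_mask))   # -> [0. 1. 0. 0. 1.]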
@@ -162,7 +165,7 @@ def f2():

neg_class_loss = tf.cond(tf.equal(n_neg_losses, tf.constant(0)), f1, f2)

class_loss = pos_class_loss + neg_class_loss
class_loss = pos_class_loss + neg_class_loss # Tensor of shape (batch_size,)

# 3: Compute the localization loss for the positive targets
# We don't penalize localization loss for negative predicted boxes (obviously: there are no ground truth boxes they would correspond to)
@@ -171,6 +174,6 @@ def f2():

# 4: Compute the total loss

total_loss = (class_loss + self.alpha * loc_loss) / tf.to_float(n_positive)
total_loss = (class_loss + self.alpha * loc_loss) / tf.maximum(1.0, n_positive) # In case `n_positive == 0`

return total_loss
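
To round this off, a minimal smoke-test sketch (not part of the commit; all shapes and values are made up, and it assumes `keras_ssd_loss.py` is importable) that exercises the case this commit fixes: a batch without any positive boxes.

import numpy as np
import tensorflow as tf
from keras_ssd_loss import SSD_Loss

batch_size, n_boxes, n_classes = 2, 4, 3                # n_classes includes the background class
loc_norm = np.ones((batch_size, n_boxes, 4), dtype=np.float32)

y_true = np.zeros((batch_size, n_boxes, n_classes + 8), dtype=np.float32)
y_true[:, :, 0] = 1.0                                   # every box is background -> no positives
y_pred = np.zeros((batch_size, n_boxes, n_classes + 8), dtype=np.float32)
y_pred[:, :, :n_classes] = 1.0 / n_classes              # uniform class "probabilities"
y_pred[:, :, n_classes:] = 0.5                          # arbitrary box and anchor entries

loss_fn = SSD_Loss(loc_norm=loc_norm, neg_pos_ratio=3, n_neg_min=2)
total_loss = loss_fn.compute_loss(tf.constant(y_true), tf.constant(y_pred))
with tf.Session() as sess:
    print(sess.run(total_loss))   # finite per-item losses even with zero positives in the batch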
