From 338868e45afcbd9b40d1a0bb387f676f1ac811cc Mon Sep 17 00:00:00 2001
From: Biant Shkololli
Date: Thu, 19 Oct 2023 20:01:53 +0200
Subject: [PATCH 1/2] Rewrite _postprocess_keypoints_multi_class to support
 TFLite and TFJS conversion

---
 .../center_net_meta_arch.py                   | 107 +++++++-----------
 1 file changed, 42 insertions(+), 65 deletions(-)

diff --git a/research/object_detection/meta_architectures/center_net_meta_arch.py b/research/object_detection/meta_architectures/center_net_meta_arch.py
index 66a87431d04..8605f710042 100644
--- a/research/object_detection/meta_architectures/center_net_meta_arch.py
+++ b/research/object_detection/meta_architectures/center_net_meta_arch.py
@@ -4433,10 +4433,8 @@ def _postprocess_keypoints_multi_class(self, prediction_dict, classes,
 
     This is the most general keypoint postprocessing function which supports
     multiple keypoint tasks (e.g. human and dog keypoints) and multiple object
-    detection classes. Note that it is the most expensive postprocessing logics
-    and is currently not tf.lite/tf.js compatible. See
-    _postprocess_keypoints_single_class if you plan to export the model in more
-    portable format.
+    detection classes. Note that it is a more expensive postprocessing logic
+    compared to _postprocess_keypoints_single_class.
 
     Args:
       prediction_dict: a dictionary holding predicted tensors, returned from the
@@ -4460,19 +4458,14 @@
       keypoint_scores: a [batch_size, max_detections, num_total_keypoints]
         float32 tensor with keypoint scores.
     """
-    total_num_keypoints = sum(len(kp_dict.keypoint_indices) for kp_dict
-                              in self._kp_params_dict.values())
-    batch_size, max_detections = _get_shape(classes, 2)
-    kpt_coords_for_example_list = []
-    kpt_scores_for_example_list = []
+    kpt_coords_combined = []
+    kpt_scores_combined = []
+    batch_size, _ = _get_shape(classes, 2)
+
     for ex_ind in range(batch_size):
-      # The tensors that host the keypoint coordinates and scores for all
-      # instances and all keypoints. They will be updated by scatter_nd_add for
-      # each keypoint tasks.
-      kpt_coords_for_example_all_det = tf.zeros(
-          [max_detections, total_num_keypoints, 2])
-      kpt_scores_for_example_all_det = tf.zeros(
-          [max_detections, total_num_keypoints])
+      kpt_coords_for_example = []
+      kpt_scores_for_example = []
+
       for task_name, kp_params in self._kp_params_dict.items():
         keypoint_heatmap = prediction_dict[
             get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1]
@@ -4480,70 +4473,54 @@
         keypoint_offsets = prediction_dict[
             get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
         keypoint_regression = prediction_dict[
             get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
-        instance_inds = self._get_instance_indices(
-            classes, num_detections, ex_ind, kp_params.class_id)
-
-        # Gather the feature map locations corresponding to the object class.
-        y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
-        x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
-        if boxes is None:
-          boxes_for_kpt_class = None
-        else:
-          boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)
-
-        # Postprocess keypoints and scores for class and single image. Shapes
-        # are [1, num_instances_i, num_keypoints_i, 2] and
-        # [1, num_instances_i, num_keypoints_i], respectively. Note that
-        # num_instances_i and num_keypoints_i refers to the number of
-        # instances and keypoints for class i, respectively.
+        # Postprocess keypoints and scores for class and single image.
+        # Shapes are [1, max_detections, num_keypoints, 2] and
+        # [1, max_detections, num_keypoints], respectively.
         (kpt_coords_for_class, kpt_scores_for_class, _) = (
             self._postprocess_keypoints_for_class_and_image(
                 keypoint_heatmap,
                 keypoint_offsets,
                 keypoint_regression,
                 classes,
-                y_indices_for_kpt_class,
-                x_indices_for_kpt_class,
-                boxes_for_kpt_class,
+                y_indices,
+                x_indices,
+                boxes,
                 ex_ind,
                 kp_params,
             ))
-        # Prepare the indices for scatter_nd. The resulting combined_inds has
-        # the shape of [num_instances_i * num_keypoints_i, 2], where the first
-        # column corresponds to the instance IDs and the second column
-        # corresponds to the keypoint IDs.
-        kpt_inds = tf.constant(kp_params.keypoint_indices, dtype=tf.int32)
-        kpt_inds = tf.expand_dims(kpt_inds, axis=0)
-        instance_inds_expand = tf.expand_dims(instance_inds, axis=-1)
-        kpt_inds_expand = kpt_inds * tf.ones_like(instance_inds_expand)
-        instance_inds_expand = instance_inds_expand * tf.ones_like(kpt_inds)
-        combined_inds = tf.stack(
-            [instance_inds_expand, kpt_inds_expand], axis=2)
-        combined_inds = tf.reshape(combined_inds, [-1, 2])
-
-        # Reshape the keypoint coordinates/scores to [num_instances_i *
-        # num_keypoints_i, 2]/[num_instances_i * num_keypoints_i] to be used
-        # by scatter_nd_add.
-        kpt_coords_for_class = tf.reshape(kpt_coords_for_class, [-1, 2])
-        kpt_scores_for_class = tf.reshape(kpt_scores_for_class, [-1])
-        kpt_coords_for_example_all_det = tf.tensor_scatter_nd_add(
-            kpt_coords_for_example_all_det,
-            combined_inds, kpt_coords_for_class)
-        kpt_scores_for_example_all_det = tf.tensor_scatter_nd_add(
-            kpt_scores_for_example_all_det,
-            combined_inds, kpt_scores_for_class)
-
-      kpt_coords_for_example_list.append(
-          tf.expand_dims(kpt_coords_for_example_all_det, axis=0))
-      kpt_scores_for_example_list.append(
-          tf.expand_dims(kpt_scores_for_example_all_det, axis=0))
+        # Set all keypoint coordinates and scores to zeros except for those
+        # whose class corresponds to the task in the current iteration.
+        mask_for_class = classes[ex_ind] == kp_params.class_id
+        mask_scores_for_class = mask_for_class[..., tf.newaxis]
+        mask_coords_for_class = mask_scores_for_class[..., tf.newaxis]
+        kpt_coords_for_class = tf2.where(mask_coords_for_class,
+                                         kpt_coords_for_class,
+                                         tf.zeros_like(
+                                             kpt_coords_for_class))
+        kpt_scores_for_class = tf2.where(mask_scores_for_class,
+                                         kpt_scores_for_class,
+                                         tf.zeros_like(
+                                             kpt_scores_for_class))
+
+        kpt_coords_for_example.append(kpt_coords_for_class)
+        kpt_scores_for_example.append(kpt_scores_for_class)
+
+      # Concatenate keypoints and scores from all classes in the example.
+      # Shapes are [1, max_detections, num_total_keypoints, 2] and
+      # [1, max_detections, num_total_keypoints], respectively.
+      kpt_coords_for_example = tf.concat(kpt_coords_for_example, axis=2)
+      kpt_scores_for_example = tf.concat(kpt_scores_for_example, axis=2)
+
+      kpt_coords_combined.append(kpt_coords_for_example)
+      kpt_scores_combined.append(kpt_scores_for_example)
 
     # Concatenate all keypoints and scores from all examples in the batch.
     # Shapes are [batch_size, max_detections, num_total_keypoints, 2] and
    # [batch_size, max_detections, num_total_keypoints], respectively.
-    keypoints = tf.concat(kpt_coords_for_example_list, axis=0)
-    keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0)
+    keypoints = tf.concat(kpt_coords_combined, axis=0)
+    keypoint_scores = tf.concat(kpt_scores_combined, axis=0)
     return keypoints, keypoint_scores

From d9a4816b48396897f0286e6404b8f5e7764d197e Mon Sep 17 00:00:00 2001
From: Biant Shkololli
Date: Thu, 19 Oct 2023 20:32:46 +0200
Subject: [PATCH 2/2] Add test for _postprocess_keypoints_multi_class

---
 .../center_net_meta_arch_tf2_test.py          | 161 ++++++++++++++++++
 1 file changed, 161 insertions(+)

diff --git a/research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py b/research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
index 02d38d12678..b76d457c261 100644
--- a/research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
+++ b/research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
@@ -2529,6 +2529,167 @@ def graph_fn():
     self.assertAllClose(detections['detection_scores'][0][:num_detections],
                         [0.675])
 
+  @parameterized.parameters(
+      {
+          'candidate_ranking_mode': 'min_distance',
+          'argmax_postprocessing': False
+      },
+      {
+          'candidate_ranking_mode': 'score_distance_ratio',
+          'argmax_postprocessing': True
+      })
+  def test_postprocess_multi_class(self, candidate_ranking_mode,
+                                   argmax_postprocessing):
+    """Test the postprocess function for multiple classes."""
+    feature_extractor = DummyFeatureExtractor(
+        channel_means=(1.0, 2.0, 3.0),
+        channel_stds=(10., 20., 30.),
+        bgr_ordering=False,
+        num_feature_outputs=2,
+        stride=4)
+    image_resizer_fn = functools.partial(
+        preprocessor.resize_to_range,
+        min_dimension=128,
+        max_dimension=128,
+        pad_to_max_dimesnion=True)
+
+    kp_params_1 = cnma.KeypointEstimationParams(
+        task_name='kpt_task_1',
+        class_id=0,
+        keypoint_indices=[0, 1],
+        keypoint_std_dev=[0.00001] * 2,
+        classification_loss=losses.WeightedSigmoidClassificationLoss(),
+        localization_loss=losses.L1LocalizationLoss(),
+        keypoint_candidate_score_threshold=0.1,
+        candidate_ranking_mode=candidate_ranking_mode,
+        argmax_postprocessing=argmax_postprocessing)
+    kp_params_2 = cnma.KeypointEstimationParams(
+        task_name='kpt_task_2',
+        class_id=1,
+        keypoint_indices=[2, 3, 4],
+        keypoint_std_dev=[0.00001] * 3,
+        classification_loss=losses.WeightedSigmoidClassificationLoss(),
+        localization_loss=losses.L1LocalizationLoss(),
+        keypoint_candidate_score_threshold=0.1,
+        candidate_ranking_mode=candidate_ranking_mode,
+        argmax_postprocessing=argmax_postprocessing)
+    model = cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=False,
+        num_classes=2,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=image_resizer_fn,
+        object_center_params=get_fake_center_params(),
+        object_detection_params=get_fake_od_params(),
+        keypoint_params_dict={
+            'kpt_task_1': kp_params_1,
+            'kpt_task_2': kp_params_2,
+        })
+    max_detection = model._center_params.max_box_predictions
+    kp_params_dict = model._kp_params_dict
+    num_keypoints_task_1 = len(kp_params_dict['kpt_task_1'].keypoint_indices)
+    num_keypoints_task_2 = len(kp_params_dict['kpt_task_2'].keypoint_indices)
+    num_keypoints = num_keypoints_task_1 + num_keypoints_task_2
+
+    class_center = np.zeros((1, 32, 32, 2), dtype=np.float32)
+    height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
+    offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
+
+    class_probs = np.zeros(2)
+    class_probs[0] = _logit(0.75)
+    class_probs[1] = _logit(0.75)
+    class_center[0, 16, 16] = class_probs
+    height_width[0, 16, 16] = [5, 10]
+    offset[0, 16, 16] = [.25, .5]
+
+    keypoint_heatmaps_task_1 = np.ones(
+        (1, 32, 32, num_keypoints_task_1), dtype=np.float32) * _logit(0.01)
+    keypoint_offsets_task_1 = np.zeros(
+        (1, 32, 32, num_keypoints_task_1 * 2), dtype=np.float32)
+    keypoint_regression_task_1 = np.random.randn(1, 32, 32,
+                                                 num_keypoints_task_1 * 2)
+
+    keypoint_regression_task_1[0, 16, 16] = [
+        -1., -1.,
+        -1., 1.]
+    keypoint_heatmaps_task_1[0, 14, 14, 0] = _logit(0.9)
+    keypoint_heatmaps_task_1[0, 14, 18, 1] = _logit(0.05)  # Note the low score.
+
+    keypoint_heatmaps_task_2 = np.ones(
+        (1, 32, 32, num_keypoints_task_2), dtype=np.float32) * _logit(0.01)
+    keypoint_offsets_task_2 = np.zeros(
+        (1, 32, 32, num_keypoints_task_2 * 2), dtype=np.float32)
+    keypoint_regression_task_2 = np.random.randn(1, 32, 32,
+                                                 num_keypoints_task_2 * 2)
+
+    keypoint_regression_task_2[0, 16, 16] = [
+        -1., -1.,
+        -1., 1.,
+        1, -1]
+    keypoint_heatmaps_task_2[0, 14, 14, 0] = _logit(0.9)
+    keypoint_heatmaps_task_2[0, 14, 18, 1] = _logit(0.9)
+    keypoint_heatmaps_task_2[0, 14, 18, 2] = _logit(0.05)  # Note the low score.
+
+    class_center = tf.constant(class_center)
+    height_width = tf.constant(height_width)
+    offset = tf.constant(offset)
+    keypoint_heatmaps_task_1 = tf.constant(
+        keypoint_heatmaps_task_1, dtype=tf.float32)
+    keypoint_offsets_task_1 = tf.constant(
+        keypoint_offsets_task_1, dtype=tf.float32)
+    keypoint_regression_task_1 = tf.constant(
+        keypoint_regression_task_1, dtype=tf.float32)
+    keypoint_heatmaps_task_2 = tf.constant(
+        keypoint_heatmaps_task_2, dtype=tf.float32)
+    keypoint_offsets_task_2 = tf.constant(
+        keypoint_offsets_task_2, dtype=tf.float32)
+    keypoint_regression_task_2 = tf.constant(
+        keypoint_regression_task_2, dtype=tf.float32)
+
+    prediction_dict = {
+        cnma.OBJECT_CENTER: [class_center],
+        cnma.BOX_SCALE: [height_width],
+        cnma.BOX_OFFSET: [offset],
+        cnma.get_keypoint_name(kp_params_1.task_name, cnma.KEYPOINT_HEATMAP):
+            [keypoint_heatmaps_task_1],
+        cnma.get_keypoint_name(kp_params_1.task_name, cnma.KEYPOINT_OFFSET):
+            [keypoint_offsets_task_1],
+        cnma.get_keypoint_name(kp_params_1.task_name, cnma.KEYPOINT_REGRESSION):
+            [keypoint_regression_task_1],
+        cnma.get_keypoint_name(kp_params_2.task_name, cnma.KEYPOINT_HEATMAP):
+            [keypoint_heatmaps_task_2],
+        cnma.get_keypoint_name(kp_params_2.task_name, cnma.KEYPOINT_OFFSET):
+            [keypoint_offsets_task_2],
+        cnma.get_keypoint_name(kp_params_2.task_name, cnma.KEYPOINT_REGRESSION):
+            [keypoint_regression_task_2]
+    }
+
+    def graph_fn():
+      detections = model.postprocess(prediction_dict,
+                                     tf.constant([[128, 128, 3]]))
+      return detections
+
+    detections = self.execute_cpu(graph_fn, [])
+
+    self.assertAllClose(detections['detection_boxes'][0, 0],
+                        np.array([55, 46, 75, 86]) / 128.0)
+    self.assertAllClose(detections['detection_scores'][0],
+                        [.75, .75, .5, .5, .5])
+
+    self.assertAllEqual(detections['detection_classes'][0], [0, 1, 0, 1, 0])
+    self.assertEqual(detections['num_detections'], [5])
+    self.assertAllEqual([1, max_detection, num_keypoints, 2],
+                        detections['detection_keypoints'].shape)
+    self.assertAllClose(
+        [[0.4375, 0.4375], [0.46875, 0.53125], [0, 0], [0, 0], [0, 0]],
+        detections['detection_keypoints'][0, 0, :, :])
+    self.assertAllClose(
+        [[0, 0], [0, 0], [0.4375, 0.4375], [0.4375, 0.5625],
+         [0.53125, 0.46875]],
+        detections['detection_keypoints'][0, 1, :, :])
+    self.assertAllEqual([1, max_detection, num_keypoints],
+                        detections['detection_keypoint_scores'].shape)
+
   @parameterized.parameters(
       {
           'candidate_ranking_mode': 'min_distance',
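
Reviewer note: the core of PATCH 1/2 is replacing the dynamically shaped
gather/scatter path (instance_inds from _get_instance_indices, tf.gather,
tf.tensor_scatter_nd_add), which the old docstring flagged as not
tf.lite/tf.js compatible, with a mask-and-concat pattern built only from
tf.where and tf.concat over statically shaped, max_detections-padded tensors.
Below is a minimal standalone sketch of that pattern; the shapes, class ids,
and tensor names are invented for illustration and are not part of the patch:

    import tensorflow as tf

    max_detections = 5
    classes = tf.constant([0, 1, 0, 1, 0])  # per-detection class ids

    # Keypoint coordinates from two tasks, both padded to max_detections:
    # task 1 (class 0, 2 keypoints) and task 2 (class 1, 3 keypoints).
    kpt_task_1 = tf.random.uniform([1, max_detections, 2, 2])
    kpt_task_2 = tf.random.uniform([1, max_detections, 3, 2])

    per_task = []
    for class_id, kpt_coords in [(0, kpt_task_1), (1, kpt_task_2)]:
      # Zero out detections whose class does not match this task's class;
      # the mask broadcasts over the keypoint and coordinate axes.
      mask = (classes == class_id)[tf.newaxis, :, tf.newaxis, tf.newaxis]
      per_task.append(tf.where(mask, kpt_coords, tf.zeros_like(kpt_coords)))

    # Concatenating along the keypoint axis yields the combined
    # [1, max_detections, num_total_keypoints, 2] tensor.
    keypoints = tf.concat(per_task, axis=2)
    print(keypoints.shape)  # (1, 5, 5, 2)

Because every detection slot is kept (zeroed rather than gathered away), no
tensor ever has a data-dependent shape, which is what the scatter_nd_add
formulation required.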
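As a quick convertibility sanity check (an assumed reviewer workflow, not part
of the patch), one can confirm that the op mix used by the rewrite converts
with the stock TFLite converter; the function below is a hypothetical
stand-in for the rewritten postprocessing, not the actual model export:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([1, 5], tf.int32)])
    def mask_and_concat(classes):
      # Same op mix as the rewritten postprocessing: equal, where, concat.
      scores_1 = tf.ones([1, 5, 2])
      scores_2 = tf.ones([1, 5, 3])
      mask_1 = (classes == 0)[..., tf.newaxis]
      mask_2 = (classes == 1)[..., tf.newaxis]
      return tf.concat(
          [tf.where(mask_1, scores_1, tf.zeros_like(scores_1)),
           tf.where(mask_2, scores_2, tf.zeros_like(scores_2))], axis=2)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [mask_and_concat.get_concrete_function()])
    tflite_model = converter.convert()  # should succeed: builtin ops only

The same graph with tf.tensor_scatter_nd_add and dynamically sized gathers
would be the shape-inference pain point the commit message refers to.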