Skip to content

Commit

Permalink
Adds a proto profile summary formatter to the TFLite benchmark.
Browse files Browse the repository at this point in the history
Adds a Python script to convert benchmark profile protos to a JSON consumable by the model-explorer.

PiperOrigin-RevId: 628143920
  • Loading branch information
tensorflower-gardener committed May 22, 2024
1 parent 98a4c09 commit f10d05f
Show file tree
Hide file tree
Showing 37 changed files with 1,114 additions and 160 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ package_group(
packages = [
"//tensorflow/...",
"//tensorflow_text/...",
"//waymo/ml/compiler/frontend/kernels/...",
"//waymo/onboard/ml/...",
],
)
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/gather_nd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class GatherNdOp : public OpKernel {

Tensor out;
OP_REQUIRES_OK(
c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
c, functor::DoGatherNd<Device, T, Index, /*kDropBadIndices=*/false>(
c, params, indices, &out));
c->set_output(0, out);
}
};
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/gather_nd_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct GatherNdSlice {
typename TTypes<T>::Matrix Tout);
};

template <typename Device, typename T, typename Index>
template <typename Device, typename T, typename Index,
bool kDropBadIndices = false>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
const Tensor& indices, Tensor* out) {
if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
Expand Down Expand Up @@ -151,6 +152,10 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
indices_nd);
}

if constexpr (kDropBadIndices) {
return absl::OkStatus();
}

// bad_i will only be >= 0 on CPUs right now.
if (bad_i >= 0) {
auto shape = indices.shape();
Expand Down
63 changes: 42 additions & 21 deletions tensorflow/core/kernels/scatter_nd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ class IndexFlattener {
namespace {

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
Expand Down Expand Up @@ -925,7 +925,11 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
for (int i = 0; i < IXDIM; ++i) { \
output_shape_prefix[i] = shape.dim_size(i); \
} \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor; \
constexpr bool kShallDropBadIndices = \
kDropBadIndices || std::is_same<Device, GPUDevice>::value; \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM, \
kShallDropBadIndices> \
functor; \
bad_i = \
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
output_matrix, indices_flat, updates_flat, output_matrix); \
Expand All @@ -947,6 +951,9 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
slice_dim);
}
}
if constexpr (kDropBadIndices) {
return absl::OkStatus();
}
if (bad_i >= 0) {
auto slice_shape = indices.shape();
slice_shape.RemoveLastDims(1);
Expand All @@ -970,7 +977,8 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
// back to GPU. This is useful because the CPU implementation is deterministic
// and the GPU implementation is not. Tensor inputs to this function must be on
// the GPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
template <typename T, typename Index, scatter_nd_op::UpdateOp Op,
bool kDropBadIndices>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
Expand Down Expand Up @@ -1015,7 +1023,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
}

TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op, kDropBadIndices>(
c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));

// Copy 'host_out' to device.
Expand All @@ -1033,44 +1041,57 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
} // namespace

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value &&
tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Run on the CPU for integer types, since the GPU implementation uses
// atomics, which are not supported for all integer types.
if constexpr (std::is_same<Device, GPUDevice>::value &&
std::is_integral<T>::value) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
} else {
return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
out, allocate);
return DoScatterNdImpl<Device, T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
}
}
} // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()( \
const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
// For one (T, Index, op, IXDIM) combination, forward-declares the GPUDevice
// specialization of ScatterNdFunctor::operator() and emits a matching
// `extern template` declaration, once with kDropBadIndices=true and once with
// kDropBadIndices=false. The `extern template` lines tell this translation
// unit not to instantiate the functor itself; the definitions are expected to
// come from the GPU translation unit.
// NOTE(review): the scatter_nd_op_gpu.cu.cc hunk in this same commit explicitly
// instantiates only the kDropBadIndices=true variant, and DoScatterNdImpl
// forces the flag to true for GPUDevice. The `false` declarations below look
// unused -- confirm nothing references them, otherwise linking would fail.
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index \
ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, /*kDropBadIndices=*/true>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>; \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scatter_nd_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op, int IXDIM>
scatter_nd_op::UpdateOp op, int IXDIM, bool kDropBadIndices>
struct ScatterNdFunctor {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
Index operator()(
Expand All @@ -63,7 +63,7 @@ struct ScatterNdFunctor {
// right type (T) and shape. This tensor will not be zeroed out
// before the scatter is executed.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
scatter_nd_op::UpdateOp Op, bool kDropBadIndices = false>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate);
Expand Down
52 changes: 32 additions & 20 deletions tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
namespace functor {

// Implementation of update functor for CPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM, kDropBadIndices> {
Index operator()(
const CPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
Expand Down Expand Up @@ -136,33 +137,44 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
i += ix_d * batch_strides[dim];
}
if (TF_PREDICT_FALSE(out_of_bounds)) {
if constexpr (kDropBadIndices) {
continue;
}
error_loc = loc;
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}

return error_loc;
}
};

#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index \
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)
// Explicitly instantiates ScatterNdFunctor<CPUDevice, ...>::operator() for the
// given (T, Index, op) at CPU_PROVIDED_IXDIM, once with kDropBadIndices=false
// and once with kDropBadIndices=true, so both variants are available to
// callers of the CPU functor.
// The expansion deliberately ends without a semicolon: the invoking macro
// (e.g. REGISTER_SCATTER_ND_INDEX) supplies it after the call.
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/true>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)

#define REGISTER_SCATTER_ND_INDEX(type, op) \
REGISTER_SCATTER_ND_FULL(type, int32, op); \
Expand Down
10 changes: 6 additions & 4 deletions tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ __global__ void ScatterNdOpKernel(
namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, kDropBadIndices> {
Index operator()(
const GPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
Expand Down Expand Up @@ -164,8 +165,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {

} // namespace functor

#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;
// Explicitly instantiates the GPUDevice ScatterNdFunctor for one
// (T, Index, op, IXDIM) combination. Only the kDropBadIndices=true variant is
// instantiated here.
// NOTE(review): presumably sufficient because the caller in scatter_nd_op.cc
// forces kDropBadIndices to true whenever Device is GPUDevice -- confirm no
// GPU code path requests the `false` variant.
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/lite/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,13 @@ if(TFLITE_KERNEL_TEST)
add_subdirectory(${TFLITE_SOURCE_DIR}/kernels)
endif()

# Add the generated headers directory. Required for maintaining the
# tensorflow/lite directory structure for generated headers.
set(TFLITE_GENERATED_HEADERS_DIR ${CMAKE_BINARY_DIR}/tensorflow/lite)

# Add the profiling proto directory.
add_subdirectory(${TFLITE_SOURCE_DIR}/profiling/proto)

# The benchmark tool.
add_subdirectory(${TFLITE_SOURCE_DIR}/tools/benchmark)

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/lite/profiling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ cc_library(
copts = common_copts,
deps = [
"//tensorflow/core/util:stats_calculator_portable",
"//tensorflow/lite/profiling/proto:profiling_info_cc_proto",
"//tensorflow/lite/tools:logging",
],
)

Expand All @@ -202,6 +204,7 @@ cc_test(
srcs = ["profile_summary_formatter_test.cc"],
deps = [
":profile_summary_formatter",
"//tensorflow/lite/profiling/proto:profiling_info_cc_proto",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/profiling/profile_summarizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ void ProfileSummarizer::ProcessProfiles(
if (delegate_internal_total_us > 0) {
delegate_stats_calculator_->UpdateRunTotalUs(delegate_internal_total_us);
}

SetSubgraphNameMap(interpreter);
}

tensorflow::StatsCalculator* ProfileSummarizer::GetStatsCalculator(
Expand Down
19 changes: 15 additions & 4 deletions tensorflow/lite/profiling/profile_summarizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ class ProfileSummarizer {
// Returns a string detailing the accumulated runtime stats in the format of
// summary_formatter_.
std::string GetOutputString() {
return summary_formatter_->GetOutputString(stats_calculator_map_,
*delegate_stats_calculator_);
return summary_formatter_->GetOutputString(
stats_calculator_map_, *delegate_stats_calculator_, subgraph_name_map_);
}

std::string GetShortSummary() {
return summary_formatter_->GetShortSummary(stats_calculator_map_,
*delegate_stats_calculator_);
return summary_formatter_->GetShortSummary(
stats_calculator_map_, *delegate_stats_calculator_, subgraph_name_map_);
}

tensorflow::StatsCalculator* GetStatsCalculator(uint32_t subgraph_index);
Expand All @@ -73,6 +73,17 @@ class ProfileSummarizer {

// Summary formatter for customized output formats.
std::shared_ptr<ProfileSummaryFormatter> summary_formatter_;

std::map<uint32_t, std::string> subgraph_name_map_;

// Rebuilds subgraph_name_map_ from `interpreter`: every existing entry is
// discarded, then each subgraph index in [0, subgraphs_size()) is mapped to
// the name reported by that subgraph's GetName().
void SetSubgraphNameMap(const tflite::Interpreter& interpreter) {
  subgraph_name_map_.clear();
  int index = 0;
  while (index < interpreter.subgraphs_size()) {
    subgraph_name_map_[index] = interpreter.subgraph(index)->GetName();
    ++index;
  }
}
};

} // namespace profiling
Expand Down

0 comments on commit f10d05f

Please sign in to comment.