diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc index 504c54b3c8b0da..3e6b2556f3fe40 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc @@ -131,6 +131,16 @@ void MaybeAddEventUniqueId(std::vector& events) { } // namespace +int GetLevelForDuration(uint64_t duration_ps) { + int i = 0; + for (; i < NumLevels(); ++i) { + if (duration_ps > kLayerResolutions[i]) { + return i; + } + } + return i; +} + std::vector MergeEventTracks( const std::vector& event_tracks) { std::vector events; diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h index a94961f6b3311b..494fba4c25c93c 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h @@ -73,6 +73,9 @@ absl::Status ReadFileTraceMetadata(std::string& filepath, Trace* trace); std::vector> GetEventsByLevel( const Trace& trace, std::vector& events); +// Returns the level that an event with `duration_ps` would go into. +int GetLevelForDuration(uint64_t duration_ps); + struct EventFactory { TraceEvent* Create() { events.push_back(std::make_unique());