-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a converter from dcn_slack_analysis.proto to GViz DataTable format.
PiperOrigin-RevId: 559508558
- Loading branch information
1 parent
f3e3719
commit 4bfdd0a
Showing
7 changed files
with
325 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
plugin/tensorboard_plugin_profile/convert/dcn_collective_stats_proto_to_gviz.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""For conversion of Dcn Collective Stats page protos to GViz DataTables. | ||
Usage: | ||
gviz_data_tables = generate_all_chart_tables(dcn_slack_analysis) | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import gviz_api | ||
|
||
from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2 | ||
|
||
|
||
def get_dcn_collective_stats_table_args(dcn_slack_analysis):
  """Builds the arguments for the Dcn Collective Stats gviz DataTable.

  This function does not construct the DataTable itself; it returns the
  pieces needed to do so, which lets callers inspect the raw rows.

  Args:
    dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis. Only its
      `dcn_slack_summary` entries are read.

  Returns:
    A `(table_description, data, custom_properties)` tuple suitable for
    passing positionally to `gviz_api.DataTable`:
      table_description: list of (column id, column type, column label).
      data: list of rows, one per entry in `dcn_slack_summary`, with the
        microsecond fields converted to milliseconds.
      custom_properties: always an empty list.
  """
  table_description = [
      ("dcnCollectiveName", "string", "Dcn Collective Name"),
      ("recvOpName", "string", "Recv Op Name"),
      ("sendOpName", "string", "Send Op Name"),
      ("slackTime", "number", "Slack Time (ms)"),
      ("observedDuration", "number", "Observed Duration (ms)"),
      ("stallDuration", "number", "Stall Duration (ms)"),
      ("occurrences", "number", "Occurrences"),
  ]

  # The proto stores durations in microseconds while the table columns are
  # labelled in milliseconds, hence the division by 1000.
  data = [
      [
          slack.rendezvous,
          slack.recv_op_name,
          slack.send_op_name,
          slack.slack_us / 1000,
          slack.observed_duration_us / 1000,
          slack.stall_duration_us / 1000,
          slack.occurrences,
      ]
      for slack in dcn_slack_analysis.dcn_slack_summary
  ]

  return (table_description, data, [])
|
||
|
||
def generate_dcn_collective_stats_table(dcn_slack_analysis):
  """Creates the Dcn Collective Stats gviz DataTable.

  Args:
    dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis.

  Returns:
    A gviz_api.DataTable built from the table description, rows and custom
    properties produced by get_dcn_collective_stats_table_args.
  """
  table_args = get_dcn_collective_stats_table_args(dcn_slack_analysis)
  return gviz_api.DataTable(*table_args)
|
||
|
||
def generate_all_chart_tables(dcn_slack_analysis):
  """Converts a DcnSlackAnalysis proto to a list of gviz DataTables.

  Args:
    dcn_slack_analysis: dcn_slack_analysis_pb2.DcnSlackAnalysis.

  Returns:
    A single-element list holding the Dcn Collective Stats DataTable.
  """
  collective_stats_table = generate_dcn_collective_stats_table(
      dcn_slack_analysis
  )
  return [collective_stats_table]
|
||
|
||
def to_json(raw_data):
  """Converts a serialized DcnSlackAnalysis proto to a JSON string.

  Args:
    raw_data: serialized dcn_slack_analysis_pb2.DcnSlackAnalysis message.

  Returns:
    A JSON array string with one element per chart table, in the order
    returned by generate_all_chart_tables. A falsy table is emitted as "{}".
  """
  analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis()
  analysis.ParseFromString(raw_data)

  json_parts = []
  for table in generate_all_chart_tables(analysis):
    json_parts.append(table.ToJSon() if table else "{}")

  return "[" + ",".join(json_parts) + "]"
139 changes: 139 additions & 0 deletions
139
plugin/tensorboard_plugin_profile/convert/dcn_collective_stats_proto_to_gviz_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Tests for dcn_collective_stats_proto_to_gviz.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import csv | ||
import enum | ||
import io | ||
|
||
import gviz_api | ||
import tensorflow as tf | ||
|
||
from tensorboard_plugin_profile.convert import dcn_collective_stats_proto_to_gviz | ||
from tensorboard_plugin_profile.protobuf import dcn_slack_analysis_pb2 | ||
|
||
|
||
class StrEnum(str, enum.Enum):
  """Enum base class whose members are also `str` instances."""
|
||
|
||
class MockValues(StrEnum):
  """Fixture values used to build and verify the mock DcnSlackSummary.

  Because the base class mixes in `str`, every member (including the ones
  declared with integer literals) is a string; the tests convert the numeric
  members with int()/float() before use.
  """

  DCN_COLLECTIVE_NAME = "collective-1"
  RECV_OP_NAME = "recv-done"
  SEND_OP_NAME = "send"
  # NOTE(review): the three duration members below are multiplied by 1000
  # before being stored in the proto's *_us fields and are compared against the
  # millisecond table columns, so despite the mixed _US/_MS suffixes they all
  # appear to be millisecond values -- confirm before renaming.
  SLACK_US = 2
  OBSERVED_DURATION_US = 12
  STALL_DURATION_MS = 5
  OCCURRENCES = 6
|
||
|
||
class ProtoToGvizTest(tf.test.TestCase):
  """Tests for the DcnSlackAnalysis to gviz DataTable conversion."""

  def create_empty_dcn_slack_analysis(self):
    """Returns a DcnSlackAnalysis proto with no fields set."""
    return dcn_slack_analysis_pb2.DcnSlackAnalysis()

  def create_mock_dcn_slack_summary(self):
    """Returns a DcnSlackSummary populated from MockValues."""
    # The duration mock values are scaled by 1000 before being written to the
    # proto's *_us fields; the converter divides by 1000 again.
    return dcn_slack_analysis_pb2.DcnSlackSummary(
        rendezvous=MockValues.DCN_COLLECTIVE_NAME,
        recv_op_name=MockValues.RECV_OP_NAME,
        send_op_name=MockValues.SEND_OP_NAME,
        slack_us=int(MockValues.SLACK_US) * 1000,
        observed_duration_us=int(MockValues.OBSERVED_DURATION_US) * 1000,
        stall_duration_us=int(MockValues.STALL_DURATION_MS) * 1000,
        occurrences=int(MockValues.OCCURRENCES),
    )

  def create_mock_dcn_slack_analysis(self):
    """Returns a DcnSlackAnalysis holding three identical mock summaries."""
    analysis = dcn_slack_analysis_pb2.DcnSlackAnalysis()
    for _ in range(3):
      analysis.dcn_slack_summary.append(self.create_mock_dcn_slack_summary())
    return analysis

  def test_dcn_collective_stats_empty(self):
    """An empty proto yields a table with all columns and no rows."""
    empty_analysis = self.create_empty_dcn_slack_analysis()
    converter = dcn_collective_stats_proto_to_gviz
    data_table = converter.generate_dcn_collective_stats_table(empty_analysis)

    self.assertEqual(0, data_table.NumberOfRows())
    self.assertLen(data_table.columns, 7)

  def test_dcn_collective_stats_table(self):
    """A populated proto yields one row per summary with the mock values."""
    mock_analysis = self.create_mock_dcn_slack_analysis()
    converter = dcn_collective_stats_proto_to_gviz
    table_args = converter.get_dcn_collective_stats_table_args(mock_analysis)
    (table_description, data, custom_properties) = table_args
    data_table = gviz_api.DataTable(table_description, data, custom_properties)

    self.assertLen(data, 3)
    self.assertEqual(3, data_table.NumberOfRows())
    self.assertLen(table_description, 7)
    self.assertLen(data_table.columns, 7)

    reader = csv.reader(io.StringIO(data_table.ToCsv()))

    expected = [
        MockValues.DCN_COLLECTIVE_NAME,
        MockValues.RECV_OP_NAME,
        MockValues.SEND_OP_NAME,
        MockValues.SLACK_US,
        MockValues.OBSERVED_DURATION_US,
        MockValues.STALL_DURATION_MS,
        MockValues.OCCURRENCES,
    ]

    for row_index, row_values in enumerate(reader):
      if row_index == 0:
        # DataTable columns match schema defined in table_description.
        for col_index, column_header in enumerate(row_values):
          self.assertEqual(table_description[col_index][2], column_header)
        continue

      for col_index, cell_str in enumerate(row_values):
        # CSV row N corresponds to data row N - 1 because of the header row.
        raw_value = data[row_index - 1][col_index]
        value_type = table_description[col_index][1]

        # Only number and strings are used in the DataTable schema.
        self.assertIn(value_type, ["number", "string"])

        # Encode in similar fashion as DataTable.ToCsv().
        expected_value = gviz_api.DataTable.CoerceValue(raw_value, value_type)
        self.assertNotIsInstance(expected_value, tuple)
        self.assertEqual(expected_value, raw_value)

        # Check against expected values we have set in our mock table.
        expected_cell = expected[col_index]
        if value_type == "string":
          expected_str = expected_cell
        elif expected_cell == MockValues.OCCURRENCES:
          # Occurrences is passed through unchanged, so it renders as an int.
          expected_str = str(int(expected_cell))
        else:
          # Durations are divided by 1000 and therefore render as floats.
          expected_str = str(float(expected_cell))
        self.assertEqual(expected_str, cell_str)
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
  tf.test.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
plugin/tensorboard_plugin_profile/protobuf/dcn_slack_analysis.proto
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
syntax = "proto3"; | ||
|
||
package tensorboard_plugin_profile; | ||
|
||
// Slack measurement for a single DCN collective.
message DcnSlack {
  // Rendezvous name for the collective.
  string rendezvous = 1;
  // Xprof observed send start time, in microseconds.
  uint64 send_start_time_us = 2;
  // Xprof observed recv_done end time, in microseconds.
  uint64 recv_done_end_time_us = 3;

  // Slack is defined as the time the collective has to send and recv data
  // without stalling the TPU. The effects of the network and of other
  // overlapping collectives are removed from the collective of interest.
  //
  //
  // HOST 1 :
  // |--------|SEND1|-------|SEND1.DONE|-------|RECV1|------|RECV1.DONE|-------
  // HOST 2:
  // |------|SEND2|-------|SEND2.DONE|-------|RECV2|------|RECV2.DONE |-----
  //
  // Slack is computed as
  // RECV2.DONE.StartTime - SEND2.StartTime - (Overlapping Communication)
  // In this case, overlapping communication is the duration of SEND2,
  // SEND2.DONE and RECV2. In cases where other collectives are interspersed
  // within this collective, the overlapping duration would include their
  // durations as well. Host 1 is ignored while computing the slack, as we
  // assume that similar ops are executing on each core. This also prevents
  // clock drift from affecting the analysis.
  uint64 slack_us = 4;

  // Bytes transmitted over the network.
  uint64 bytes_transmitted_over_network = 5;

  // Duration the collective stalled the TPU, in microseconds.
  uint64 stall_duration_us = 6;

  // Recv op name.
  string recv_op_name = 7;

  // Send op name.
  string send_op_name = 8;
}
|
||
// Summary of the slack statistics for one collective over the sampled
// duration.
// NOTE(review): presumably the time fields are aggregated (e.g. averaged)
// across `occurrences` -- confirm against the code that fills this message.
message DcnSlackSummary {
  // Rendezvous name for the collective.
  string rendezvous = 1;
  // Slack time in microseconds.
  uint64 slack_us = 2;
  // Number of occurrences in the sampled duration.
  uint64 occurrences = 3;
  // Bytes transmitted over the network.
  uint64 bytes_transmitted_over_network = 4;
  // Duration the collective stalled the TPU, in microseconds.
  uint64 stall_duration_us = 5;
  // Observed duration, in microseconds.
  uint64 observed_duration_us = 6;
  // Recv op name.
  string recv_op_name = 7;

  // Send op name.
  string send_op_name = 8;
}
|
||
// Top-level container for the DCN slack analysis results.
message DcnSlackAnalysis {
  // Individual slack records.
  repeated DcnSlack dcn_slack = 1;
  // Slack summaries; each entry becomes one row of the Dcn Collective Stats
  // table.
  repeated DcnSlackSummary dcn_slack_summary = 2;
}