#include <gmock/gmock.h>
#include <gtest/gtest.h>

+ #include <random>
#include <set>
+ #include <string>
#include <unordered_map>
#include <vector>

#include "absl/status/status.h"
#include "absl/types/span.h"
+ #include "torch_xla/csrc/runtime/debug_macros.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/status_matchers.h"
@@ -46,7 +49,7 @@ absl::StatusOr<MessageType> ParseTextProto(const std::string& text_proto) {
  return parsed_proto;
}

- TEST(XlaUtilrest, CreateModule) {
+ TEST(XlaUtilTest, CreateModule) {
  TF_ASSERT_OK_AND_ASSIGN(
      xla::HloModuleProto hlo_module_proto,
      ParseTextProto<xla::HloModuleProto>(
@@ -102,7 +105,7 @@ TEST(XlaUtilrest, CreateModule) {
  EXPECT_EQ((*got)->computation_count(), 1);
}

- TEST(XlaUtilrest, XlaToHlo) {
+ TEST(XlaUtilTest, XlaToHlo) {
  xla::Shape input_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
  xla::XlaBuilder builder("AddComputation");
@@ -116,6 +119,150 @@ TEST(XlaUtilrest, XlaToHlo) {
                          HasSubstr("ROOT %add.3"))));
}

+ TEST(XlaUtilTest, TestDeterministicModuleProtoSerializationEmptyProto) {
+   xla::HloModuleProto empty_proto;
+   auto result =
+       ::ConsumeValue(GetDeterministicSerializedModuleProto(empty_proto));
+   // Verify that the result is an empty string
+   EXPECT_TRUE(result.empty());
+ }
+
+ TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) {
+   // Create a test HLO module with a known structure
+   TF_ASSERT_OK_AND_ASSIGN(
+       xla::HloModuleProto hlo_module_proto,
+       ParseTextProto<xla::HloModuleProto>(
+           R"pb(
+             name: "myname"
+             id: 9
+             entry_computation_name: "MyCustomName.9"
+             entry_computation_id: 9
+             computations {
+               id: 9
+               name: "MyCustomName.9"
+               instructions: {
+                 name: "p0.1"
+                 id: 1
+                 opcode: "parameter"
+                 shape: {
+                   element_type: S64
+                   layout { tail_padding_alignment_in_elements: 1 }
+                 }
+                 metadata {
+                   op_type: "xla__device_data"
+                   op_name: "xla__device_data"
+                   source_file: "/ansible/pytorch/xla/small_test.py"
+                   source_line: 14
+                   stack_frame_id: 1
+                 }
+               }
+               instructions: {
+                 name: "p1.2"
+                 id: 2
+                 opcode: "parameter"
+                 parameter_number: 1
+                 shape: {
+                   element_type: S64
+                   layout { tail_padding_alignment_in_elements: 1 }
+                 }
+                 metadata {
+                   op_type: "xla__device_data"
+                   op_name: "xla__device_data"
+                   source_file: "/ansible/pytorch/xla/small_test.py"
+                   source_line: 13
+                   stack_frame_id: 2
+                 }
+               }
+               instructions: {
+                 name: "call.7"
+                 id: 7
+                 opcode: "call"
+                 shape: {
+                   element_type: S64
+                   layout { tail_padding_alignment_in_elements: 1 }
+                 }
+                 metadata {
+                   op_type: "xla___op_some_op"
+                   op_name: "xla___op_some_op"
+                   source_file: "/ansible/pytorch/xla/torch_xla/core/xla_op_registry.py"
+                   source_line: 44
+                   stack_frame_id: 4
+                 }
+                 called_computation_ids: 3
+                 operand_ids: 2
+                 operand_ids: 1
+               }
+               instructions: {
+                 name: "tuple.8"
+                 id: 8
+                 opcode: "tuple"
+                 shape: {
+                   element_type: TUPLE
+                   tuple_shapes {
+                     element_type: S64
+                     layout { tail_padding_alignment_in_elements: 1 }
+                   }
+                 }
+                 operand_ids: 7
+               }
+               root_id: 8
+             }
+             host_program_shape: {
+               parameters {
+                 element_type: S64
+                 layout { tail_padding_alignment_in_elements: 1 }
+               }
+               parameters {
+                 element_type: S64
+                 layout { tail_padding_alignment_in_elements: 1 }
+               }
+               result {
+                 element_type: TUPLE
+                 tuple_shapes {
+                   element_type: S64
+                   layout { tail_padding_alignment_in_elements: 1 }
+                 }
+               }
+               parameter_names: "p0"
+               parameter_names: "p1"
+             }
+           )pb"));
+
+   // Define a set of dummy fixed key-value pairs for frontend attributes.
+   std::vector<std::pair<std::string, std::string>> attr_pairs = {
+       {"key1", "value1"},
+       {"key2", "value2"},
+       {"key3", "value3"},
+       {"key4", "value4"}};
+
+   auto shuffle_and_hash = [&attr_pairs](xla::HloModuleProto hlo_module_proto) {
+     // Create a random number generator for shuffling.
+     std::random_device random_device;
+     std::mt19937 random_generator(random_device());
+
+     for (auto& computation : *hlo_module_proto.mutable_computations()) {
+       for (auto& instruction : *computation.mutable_instructions()) {
+         std::shuffle(attr_pairs.begin(), attr_pairs.end(), random_generator);
+         auto* frontend_attrs = instruction.mutable_frontend_attributes();
+         // Add the dummy shuffled pairs to the frontend attributes.
+         for (const auto& pair : attr_pairs) {
+           (*frontend_attrs->mutable_map())[pair.first] = pair.second;
+         }
+       }
+     }
+     std::string serialized_proto =
+         ::ConsumeValue(GetDeterministicSerializedModuleProto(hlo_module_proto));
+     return torch::lazy::Hash(serialized_proto);
+   };
+
+   // Compute hashes with different random orderings of attributes
+   torch::lazy::hash_t hash1 = shuffle_and_hash(hlo_module_proto);
+   torch::lazy::hash_t hash2 = shuffle_and_hash(hlo_module_proto);
+   // Verify that different orderings produce the same hash
+   ASSERT_EQ(hash1, hash2)
+       << "Hashes should match regardless of the frontend attribute ordering";
+ }
+
} // namespace util
} // namespace runtime
} // namespace torch_xla
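For context on what these tests exercise, below is a minimal sketch of one common way to serialize a protobuf message deterministically, using protobuf's deterministic serialization mode so that map fields (such as frontend_attributes) are emitted in a stable order within a given build. This is an illustration only: the helper name SerializeDeterministically is hypothetical, and this is an assumption about the kind of technique GetDeterministicSerializedModuleProto could rely on, not a description of its actual implementation.

// Sketch (assumption): deterministic protobuf serialization via
// CodedOutputStream. Map fields are written in a stable order, so two
// protos with identical content produce identical bytes in one build.
#include <string>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message_lite.h"

std::string SerializeDeterministically(const google::protobuf::MessageLite& proto) {
  std::string serialized;
  google::protobuf::io::StringOutputStream stream(&serialized);
  google::protobuf::io::CodedOutputStream coded_stream(&stream);
  // Request deterministic output so map entries are emitted in sorted order.
  coded_stream.SetSerializationDeterministic(true);
  proto.SerializeToCodedStream(&coded_stream);
  // An empty message serializes to zero bytes, which matches the empty-proto
  // expectation in the first test above.
  return serialized;
}

Hashing the bytes returned by such a helper (as shuffle_and_hash does with torch::lazy::Hash) then yields the same value regardless of the insertion order of the attribute map, which is what the second test asserts.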