@@ -16,19 +16,19 @@ class TestPreserveDtypeLayer(unittest.TestCase):
1616 def test_preserve_original_dtype (self ):
1717 """Test that the layer preserves original dtype when target_dtype is None."""
1818 layer = PreserveDtypeLayer ()
19-
19+
2020 # Test with string input
2121 string_input = tf .constant (["hello" , "world" ])
2222 output = layer (string_input )
2323 self .assertEqual (output .dtype , tf .string )
2424 np .testing .assert_array_equal (output .numpy (), string_input .numpy ())
25-
25+
2626 # Test with int input
2727 int_input = tf .constant ([1 , 2 , 3 ])
2828 output = layer (int_input )
2929 self .assertEqual (output .dtype , tf .int32 )
3030 np .testing .assert_array_equal (output .numpy (), int_input .numpy ())
31-
31+
3232 # Test with float input
3333 float_input = tf .constant ([1.5 , 2.7 , 3.9 ])
3434 output = layer (float_input )
@@ -43,48 +43,32 @@ def test_cast_to_target_dtype(self):
4343 output = layer (float_input )
4444 self .assertEqual (output .dtype , tf .int32 )
4545 np .testing .assert_array_equal (output .numpy (), [1 , 2 , 3 ])
46-
46+
4747 # Test casting int to float
4848 layer = PreserveDtypeLayer (target_dtype = tf .float32 )
4949 int_input = tf .constant ([1 , 2 , 3 ])
5050 output = layer (int_input )
5151 self .assertEqual (output .dtype , tf .float32 )
5252 np .testing .assert_array_equal (output .numpy (), [1.0 , 2.0 , 3.0 ])
53-
53+
5454 # Test casting to float64
5555 layer = PreserveDtypeLayer (target_dtype = tf .float64 )
5656 float_input = tf .constant ([1.5 , 2.7 , 3.9 ])
5757 output = layer (float_input )
5858 self .assertEqual (output .dtype , tf .float64 )
5959 np .testing .assert_array_almost_equal (output .numpy (), [1.5 , 2.7 , 3.9 ])
6060
61- def test_string_to_other_types (self ):
62- """Test casting string to other types."""
63- string_input = tf .constant (["1" , "2" , "3" ])
64-
65- # String to int
66- layer = PreserveDtypeLayer (target_dtype = tf .int32 )
67- output = layer (string_input )
68- self .assertEqual (output .dtype , tf .int32 )
69- np .testing .assert_array_equal (output .numpy (), [1 , 2 , 3 ])
70-
71- # String to float
72- layer = PreserveDtypeLayer (target_dtype = tf .float32 )
73- output = layer (string_input )
74- self .assertEqual (output .dtype , tf .float32 )
75- np .testing .assert_array_equal (output .numpy (), [1.0 , 2.0 , 3.0 ])
76-
7761 def test_batch_processing (self ):
7862 """Test that the layer works correctly with batched inputs."""
7963 layer = PreserveDtypeLayer ()
80-
64+
8165 # Test with 2D input
8266 batch_input = tf .constant ([[1 , 2 ], [3 , 4 ], [5 , 6 ]])
8367 output = layer (batch_input )
8468 self .assertEqual (output .dtype , tf .int32 )
8569 self .assertEqual (output .shape , (3 , 2 ))
8670 np .testing .assert_array_equal (output .numpy (), batch_input .numpy ())
87-
71+
8872 # Test with 3D input
8973 layer = PreserveDtypeLayer (target_dtype = tf .float32 )
9074 batch_input = tf .constant ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
@@ -95,18 +79,18 @@ def test_batch_processing(self):
9579 def test_serialization (self ):
9680 """Test that the layer can be serialized and deserialized."""
9781 layer = PreserveDtypeLayer (target_dtype = tf .int32 , name = "test_layer" )
98-
82+
9983 # Serialize
10084 config = layer .get_config ()
101-
85+
10286 # Deserialize
10387 new_layer = PreserveDtypeLayer .from_config (config )
104-
88+
10589 # Test that they behave the same
10690 input_tensor = tf .constant ([1.5 , 2.7 , 3.9 ])
10791 original_output = layer (input_tensor )
10892 new_output = new_layer (input_tensor )
109-
93+
11094 self .assertEqual (original_output .dtype , new_output .dtype )
11195 np .testing .assert_array_equal (original_output .numpy (), new_output .numpy ())
11296 self .assertEqual (layer .name , new_layer .name )
@@ -115,20 +99,22 @@ def test_serialization(self):
11599 def test_model_integration (self ):
116100 """Test that the layer works correctly within a Keras model."""
117101 layer = PreserveDtypeLayer (target_dtype = tf .float32 )
118-
102+
119103 # Create a simple model
120104 inputs = tf .keras .Input (shape = (3 ,), dtype = tf .int32 )
121105 outputs = layer (inputs )
122106 model = tf .keras .Model (inputs = inputs , outputs = outputs )
123-
107+
124108 # Test the model
125109 test_input = tf .constant ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
126110 output = model (test_input )
127-
111+
128112 self .assertEqual (output .dtype , tf .float32 )
129113 self .assertEqual (output .shape , (2 , 3 ))
130- np .testing .assert_array_equal (output .numpy (), [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]])
114+ np .testing .assert_array_equal (
115+ output .numpy (), [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]]
116+ )
131117
132118
133119if __name__ == "__main__" :
134- unittest .main ()
120+ unittest .main ()
0 commit comments