2222
2323from tensorflow .python .framework import constant_op
2424from tensorflow .python .ops import array_ops
25+ from tensorflow .python .ops import gradient_checker
2526from tensorflow .python .ops import linalg_ops
2627from tensorflow .python .ops import math_ops
2728from tensorflow .python .ops import random_ops
@@ -140,11 +141,11 @@ def Test(self):
140141 x_reshape = np .reshape (x_np , (- 1 , x_np .shape [- 2 ], x_np .shape [- 1 ]))
141142 for i in range (new_first_dim ):
142143 if full_matrices_ :
143- np_q_reshape [i ,:, :], _ = \
144- np . linalg . qr ( x_reshape [i ,:, :], mode = "complete" )
144+ np_q_reshape [i , :, :], _ = np . linalg . qr (
145+ x_reshape [i , :, :], mode = "complete" )
145146 else :
146- np_q_reshape [i ,:, :], _ = \
147- np . linalg . qr ( x_reshape [i ,:, :], mode = "reduced" )
147+ np_q_reshape [i , :, :], _ = np . linalg . qr (
148+ x_reshape [i , :, :], mode = "reduced" )
148149 np_q = np .reshape (np_q_reshape , q_dims )
149150 CompareOrthogonal (self , np_q , q_tf_val , min (shape_ [- 2 :]))
150151 CheckApproximation (self , x_np , q_tf_val , r_tf_val )
@@ -153,6 +154,46 @@ def Test(self):
153154 return Test
154155
155156
157+ class QrGradOpTest (test .TestCase ):
158+ pass
159+
160+
161+ def _GetQrGradOpTest (dtype_ , shape_ , full_matrices_ ):
162+
163+ def Test (self ):
164+ np .random .seed (42 )
165+ a = np .random .uniform (low = - 1.0 , high = 1.0 , size = shape_ ).astype (dtype_ )
166+ if dtype_ in [np .complex64 , np .complex128 ]:
167+ a += 1j * np .random .uniform (
168+ low = - 1.0 , high = 1.0 , size = shape_ ).astype (dtype_ )
169+ # Optimal stepsize for central difference is O(epsilon^{1/3}).
170+ epsilon = np .finfo (dtype_ ).eps
171+ delta = 0.1 * epsilon ** (1.0 / 3.0 )
172+ if dtype_ in [np .float32 , np .complex64 ]:
173+ tol = 3e-2
174+ else :
175+ tol = 1e-6
176+ with self .test_session (use_gpu = True ):
177+ tf_a = constant_op .constant (a )
178+ tf_b = linalg_ops .qr (tf_a , full_matrices = full_matrices_ )
179+ for b in tf_b :
180+ x_init = np .random .uniform (
181+ low = - 1.0 , high = 1.0 , size = shape_ ).astype (dtype_ )
182+ if dtype_ in [np .complex64 , np .complex128 ]:
183+ x_init += 1j * np .random .uniform (
184+ low = - 1.0 , high = 1.0 , size = shape_ ).astype (dtype_ )
185+ theoretical , numerical = gradient_checker .compute_gradient (
186+ tf_a ,
187+ tf_a .get_shape ().as_list (),
188+ b ,
189+ b .get_shape ().as_list (),
190+ x_init_value = x_init ,
191+ delta = delta )
192+ self .assertAllClose (theoretical , numerical , atol = tol , rtol = tol )
193+
194+ return Test
195+
196+
156197if __name__ == "__main__" :
157198 for dtype in np .float32 , np .float64 , np .complex64 , np .complex128 :
158199 for rows in 1 , 2 , 5 , 10 , 32 , 100 :
@@ -168,4 +209,21 @@ def Test(self):
168209 _AddTest (QrOpTest , "Qr" , name ,
169210 _GetQrOpTest (dtype , shape , full_matrices ,
170211 use_static_shape ))
212+
213+ # TODO(pfau): Get working with complex types.
214+ # TODO(pfau): Get working with full_matrices when rows != cols
215+ # TODO(pfau): Get working when rows < cols
216+ # TODO(pfau): Get working with shapeholders (dynamic shapes)
217+ for full_matrices in False , True :
218+ for dtype in np .float32 , np .float64 :
219+ for rows in 1 , 2 , 5 , 10 :
220+ for cols in 1 , 2 , 5 , 10 :
221+ if rows == cols or (not full_matrices and rows > cols ):
222+ for batch_dims in [(), (3 ,)] + [(3 , 2 )] * (max (rows , cols ) < 10 ):
223+ shape = batch_dims + (rows , cols )
224+ name = "%s_%s_full_%s" % (dtype .__name__ ,
225+ "_" .join (map (str , shape )),
226+ full_matrices )
227+ _AddTest (QrGradOpTest , "QrGrad" , name ,
228+ _GetQrGradOpTest (dtype , shape , full_matrices ))
171229 test .main ()
0 commit comments