     'test_pow_xla_float64': 0.0045,
 }
 
-DISABLED_TORCH_TESTS = {
+DISABLED_TORCH_TESTS_ANY = {
     # test_torch.py
     # TestDevicePrecision
     'test_sum_cpu_device_mismatch',  # doesn't raise
@@ ... @@
     'test_min_max_binary_op_nan',  # XLA min/max ignores NaNs.
     'test_copy_noncontig',
     'test_copy_broadcast',
-    'test_digamma',  # Precision issue at the first assert, then NaN handling (both on TPU)
 
     # TestTensorDeviceOps
     'test_block_diag_scipy',  # FIXME: RuntimeError: Error while lowering: f32[1,6]{1,0} xla::unselect, dim=1, start=2, end=2, stride=0
-    'test_cumprod_xla',  # FIXME: TPU X64Rewriter doesn't support reduce-window
-    'test_cumprod_neg_dim_xla',  # FIXME: TPU X64Rewriter doesn't support reduce-window
     'test_mean_64bit_indexing_xla',  # protobuf limit exceeded
-    'test_pow_inplace_xla',  # (TPU) 0.0032 vs 0.001
-    'test_pow_inplace_3_xla',  # (TPU) 0.0028 vs 0.001
-    'test_pow_3_xla',  # (TPU) 0.0028 vs 0.001
-    'test_pow_-2_xla',  # (TPU) 0.391 vs 0.001
-    'test_topk_neg_dim_sort_xla',  # (TPU) unimplemented HLO for X64
-    'test_topk_dim_sort_xla',  # (TPU) unimplemented HLO for X64
-    'test_topk_dim_desc_sort_xla',  # (TPU) unimplemented HLO for X64
-    'test_sort_xla',  # (TPU) unimplemented HLO for X64
-    'test_sort_neg_dim_xla',  # (TPU) unimplemented HLO for X64
-    'test_sort_neg_dim_descending_xla',  # (TPU) unimplemented HLO for X64
-    'test_sort_dim_xla',  # (TPU) unimplemented HLO for X64
-    'test_sort_dim_descending_xla',  # (TPU) unimplemented HLO for X64
-    'test_kthvalue_xla',  # (TPU) unimplemented HLO for X64
-    'test_kthvalue_neg_dim_xla',  # (TPU) unimplemented HLO for X64
-    'test_kthvalue_dim_xla',  # (TPU) unimplemented HLO for X64
-    'test_eig_with_eigvec_xla_float64',  # Precision: tensor(1.1798, dtype=torch.float64) not less than or equal to 0.001
 
     # TestTorchDeviceType
     'test_addmm_sizes',  # FIXME: very slow compile
@@ ... @@
     'test_stft',  # librosa (?!?) missing
     'test_tensor_shape_empty',  # LLVM OOM in CI
     'test_cholesky_inverse',  # precision (1e-6)
-    'test_cholesky_solve_batched_broadcasting',  # (TPU) 0.0039 vs 0.001
-    'test_cholesky_solve_batched_many_batches',  # (TPU) 0.36 vs 0.001
-    'test_cholesky_solve_batched',  # (TPU) precision (1e-5)
-    'test_cholesky_solve',  # (TPU) precision (1e-5)
-    'test_lu_solve_batched',  # (TPU) precision (1e-6)
-    'test_lu_solve',  # (TPU) precision (1e-7)
-    'test_solve_batched',  # (TPU) precision (1e-6)
-    'test_solve',  # (TPU) precision (1e-7)
-    'test_triangular_solve_batched',  # (TPU) precision (1e-6)
-    'test_triangular_solve_batched_many_batches',  # (TPU) 1.02 vs 0.001
-    'test_triangular_solve',  # (TPU) precision (1e-7)
-    'test_triangular_solve_batched_broadcasting',  # (TPU) 1.5 vs 0.001
+    'test_cholesky_solve_batched',  # precision (2e-12)
+    'test_cholesky_solve',  # precision (1e-12)
+    'test_lu_solve_batched',  # precision (1e-12)
+    'test_lu_solve',  # precision (1e-12)
+    'test_solve_batched',  # precision (1e-12)
+    'test_solve',  # precision (1e-12)
+    'test_triangular_solve_batched',  # precision (3e-12)
+    'test_triangular_solve',  # precision (4e-12)
     'test_scalar_check',  # runtime error
     'test_argminmax_large_axis',  # OOM, and the test is grepping "memory" in the exception message
     'test_trapz',  # precision (1e-5), test uses np.allclose
-    'test_random_from_to_xla_int32',  # precision, TPU does not have real F64
     'test_randn_xla_float32',  # xla doesn't support manual_seed, as_stride
     'test_randn_xla_float64',  # xla doesn't support manual_seed, as_stride
     'test_rand_xla_float32',  # xla doesn't support manual_seed, as_stride
@@ ... @@
     'test_complex_scalar_mult_tensor_promotion',  # complex support
 }
 
+DISABLED_TORCH_TESTS_TPU = DISABLED_TORCH_TESTS_ANY | {
+    # test_torch.py
+    # TestDevicePrecision
+    'test_digamma',  # Precision issue at the first assert, then NaN handling (both on TPU)
+
+    # TestTensorDeviceOps
+    'test_pow_inplace_xla',  # (TPU) 0.0032 vs 0.001
+    'test_pow_inplace_3_xla',  # (TPU) 0.0028 vs 0.001
+    'test_pow_3_xla',  # (TPU) 0.0028 vs 0.001
+    'test_pow_-2_xla',  # (TPU) 0.391 vs 0.001
+    'test_topk_dim_sort_xla',  # (TPU) unimplemented HLO for X64
+    'test_topk_dim_desc_sort_xla',  # (TPU) unimplemented HLO for X64
+    'test_sort_xla',  # (TPU) unimplemented HLO for X64
+    'test_sort_neg_dim_xla',  # (TPU) unimplemented HLO for X64
+    'test_sort_neg_dim_descending_xla',  # (TPU) unimplemented HLO for X64
+    'test_sort_dim_xla',  # (TPU) unimplemented HLO for X64
+    'test_sort_dim_descending_xla',  # (TPU) unimplemented HLO for X64
+    'test_kthvalue_xla',  # (TPU) unimplemented HLO for X64
+    'test_kthvalue_neg_dim_xla',  # (TPU) unimplemented HLO for X64
+    'test_kthvalue_dim_xla',  # (TPU) unimplemented HLO for X64
+    'test_eig_with_eigvec_xla_float64',  # Precision: tensor(1.1798, dtype=torch.float64) not less than or equal to 0.001
+    'test_cumprod_xla',  # FIXME: TPU X64Rewriter doesn't support reduce-window
+    'test_cumprod_neg_dim_xla',  # FIXME: TPU X64Rewriter doesn't support reduce-window
+    'test_topk_neg_dim_sort_xla',  # (TPU) unimplemented HLO for X64
+
+    # TestTorchDeviceType
+    'test_cholesky_solve_batched_broadcasting',  # (TPU) 0.0039 vs 0.001
+    'test_cholesky_solve_batched_many_batches',  # (TPU) 0.36 vs 0.001
+    'test_triangular_solve_batched_many_batches',  # (TPU) 1.02 vs 0.001
+    'test_triangular_solve_batched_broadcasting',  # (TPU) 1.5 vs 0.001
+    'test_random_from_to_xla_int32',  # precision, TPU does not have real F64
+}
+
+DISABLED_TORCH_TESTS_CPU = DISABLED_TORCH_TESTS_ANY
+DISABLED_TORCH_TESTS_GPU = DISABLED_TORCH_TESTS_ANY
+
+DISABLED_TORCH_TESTS = {
+    'TPU': DISABLED_TORCH_TESTS_TPU,
+    'CPU': DISABLED_TORCH_TESTS_CPU,
+    'GPU': DISABLED_TORCH_TESTS_GPU,
+}
+
 
 class XLATestBase(DeviceTypeTestBase):
   device_type = 'xla'
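
A note on the hunk above: DISABLED_TORCH_TESTS_ANY is an ordinary Python set of test names, and the TPU table extends it with the set-union operator, so skips shared by every backend live in one place while TPU-only entries stay separate. A minimal sketch of the pattern (hypothetical names, not part of the diff):

    # Common skips apply on every backend; the TPU set is a superset of them.
    DISABLED_ANY = {'test_a', 'test_b'}
    DISABLED_TPU = DISABLED_ANY | {'test_c'}  # '|' builds a new, combined set

    assert 'test_c' in DISABLED_TPU
    assert 'test_c' not in DISABLED_ANY  # the shared set is left untouched
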
@@ -251,13 +269,16 @@ def _alt_lookup(d, keys, defval):
   @classmethod
   def instantiate_test(cls, name, test):
     test_name = name + '_' + cls.device_type
+    real_device_type = xm.xla_device_hw(str(xm.xla_device()))
+    assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
+    disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]
 
     @wraps(test)
     def disallowed_test(self, test=test):
       raise unittest.SkipTest('skipped on XLA')
       return test(self, cls.device_type)
 
-    if test_name in DISABLED_TORCH_TESTS or test.__name__ in DISABLED_TORCH_TESTS:
+    if test_name in disabled_torch_tests or test.__name__ in disabled_torch_tests:
       assert not hasattr(
           cls, test_name), 'Redefinition of test {0}'.format(test_name)
       setattr(cls, test_name, disallowed_test)
@@ -284,7 +305,7 @@ def skipped_test(self, *args, reason=reason, **kwargs):
           cls, dtype_test_name), 'Redefinition of test {0}'.format(
               dtype_test_name)
         setattr(cls, dtype_test_name, skipped_test)
-      elif dtype_test_name in DISABLED_TORCH_TESTS:
+      elif dtype_test_name in disabled_torch_tests:
         setattr(cls, dtype_test_name, disallowed_test)
       else:
         xla_dtypes.append(dtype)
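
Taken together, the last two hunks switch the membership checks from the old flat DISABLED_TORCH_TESTS set to a per-backend skip set chosen at test-instantiation time. Roughly, the selection boils down to the sketch below (tests_to_skip is a hypothetical helper, not part of the diff; it assumes xm.xla_device_hw returns 'TPU', 'GPU', or 'CPU', which is what the assert in instantiate_test expects):

    import torch_xla.core.xla_model as xm

    def tests_to_skip():
      # Resolve the generic 'xla' device to the real hardware behind it,
      # e.g. 'TPU' when running on a TPU VM.
      hw = xm.xla_device_hw(str(xm.xla_device()))
      # Pick the matching skip set; unknown backends fail fast.
      assert hw in DISABLED_TORCH_TESTS, 'Unsupported device type:' + hw
      return DISABLED_TORCH_TESTS[hw]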