Commit 0b49efa

JackCaoG authored and dlibenzi committed
Add device specific test disable list
1 parent c7f1703 commit 0b49efa

2 files changed: +69 -35 lines changed

test/pytorch_test_base.py

Lines changed: 56 additions & 35 deletions
@@ -14,7 +14,7 @@
     'test_pow_xla_float64': 0.0045,
 }
 
-DISABLED_TORCH_TESTS = {
+DISABLED_TORCH_TESTS_ANY = {
     # test_torch.py
     # TestDevicePrecision
     'test_sum_cpu_device_mismatch', # doesn't raise
@@ -23,29 +23,10 @@
     'test_min_max_binary_op_nan', # XLA min/max ignores Nans.
     'test_copy_noncontig',
     'test_copy_broadcast',
-    'test_digamma', # Precision issue at the first assert, then NAN handling (both on TPU)
 
     # TestTensorDeviceOps
     'test_block_diag_scipy', #FIXME: RuntimeError: Error while lowering: f32[1,6]{1,0} xla::unselect, dim=1, start=2, end=2, stride=0
-    'test_cumprod_xla', # FIXME: TPU X64Rewriter doesn't support reduce-window
-    'test_cumprod_neg_dim_xla', # FIXME: TPU X64Rewriter doesn't support reduce-window
     'test_mean_64bit_indexing_xla', # protobuf limit exceeded
-    'test_pow_inplace_xla', # (TPU) 0.0032 vs 0.001
-    'test_pow_inplace_3_xla', # (TPU) 0.0028 vs 0.001
-    'test_pow_3_xla', # (TPU) 0.0028 vs 0.001
-    'test_pow_-2_xla', # (TPU) 0.391 vs 0.001
-    'test_topk_neg_dim_sort_xla', # (TPU) unimplemented HLO for X64
-    'test_topk_dim_sort_xla', # (TPU) unimplemented HLO for X64
-    'test_topk_dim_desc_sort_xla', # (TPU) unimplemented HLO for X64
-    'test_sort_xla', # (TPU) unimplemented HLO for X64
-    'test_sort_neg_dim_xla', # (TPU) unimplemented HLO for X64
-    'test_sort_neg_dim_descending_xla', # (TPU) unimplemented HLO for X64
-    'test_sort_dim_xla', # (TPU) unimplemented HLO for X64
-    'test_sort_dim_descending_xla', # (TPU) unimplemented HLO for X64
-    'test_kthvalue_xla', # (TPU) unimplemented HLO for X64
-    'test_kthvalue_neg_dim_xla', # (TPU) unimplemented HLO for X64
-    'test_kthvalue_dim_xla', # (TPU) unimplemented HLO for X64
-    'test_eig_with_eigvec_xla_float64', # Precision: tensor(1.1798, dtype=torch.float64) not less than or equal to 0.001
 
     # TestTorchDeviceType
     'test_addmm_sizes', # FIXME: very slow compile
@@ -147,22 +128,17 @@
     'test_stft', # librosa (?!?) missing
     'test_tensor_shape_empty', # LLVM OOM in CI
     'test_cholesky_inverse', # precision (1e-6)
-    'test_cholesky_solve_batched_broadcasting', # (TPU) 0.0039 vs 0.001
-    'test_cholesky_solve_batched_many_batches', # (TPU) 0.36 vs 0.001
-    'test_cholesky_solve_batched', # (TPU) precision (1e-5)
-    'test_cholesky_solve', # (TPU) precision (1e-5)
-    'test_lu_solve_batched', # (TPU) precision (1e-6)
-    'test_lu_solve', # (TPU) precision (1e-7)
-    'test_solve_batched', # (TPU) precision (1e-6)
-    'test_solve', # (TPU) precison (1e-7)
-    'test_triangular_solve_batched', # (TPU) precision (1e-6)
-    'test_triangular_solve_batched_many_batches', # (TPU) 1.02 vs 0.001
-    'test_triangular_solve', # (TPU) precision (1e-7)
-    'test_triangular_solve_batched_broadcasting', # (TPU) 1.5 vs 0.001
+    'test_cholesky_solve_batched', # precision (2e-12)
+    'test_cholesky_solve', # precision(1e-12)
+    'test_lu_solve_batched', # precision(1e-12)
+    'test_lu_solve', # precision(1e-12)
+    'test_solve_batched', # precision(1e-12)
+    'test_solve', # precision(1e-12)
+    'test_triangular_solve_batched', # precision(3e-12)
+    'test_triangular_solve', # precision (4e-12)
     'test_scalar_check', # runtime error
     'test_argminmax_large_axis', # OOM, and the test is grepping "memory" in the exception message
     'test_trapz', # precision (1e-5), test use np.allClose
-    'test_random_from_to_xla_int32', # precision, TPU does not have real F64
     'test_randn_xla_float32', # xla doesn't support manual_seed, as_stride
     'test_randn_xla_float64', # xla doesn't support manual_seed, as_stride
     'test_rand_xla_float32', # xla doesn't support manual_seed, as_stride
@@ -232,6 +208,48 @@
     'test_complex_scalar_mult_tensor_promotion', # complex support
 }
 
+DISABLED_TORCH_TESTS_TPU = DISABLED_TORCH_TESTS_ANY | {
+    # test_torch.py
+    # TestDevicePrecision
+    'test_digamma', # Precision issue at the first assert, then NAN handling (both on TPU)
+
+    #TestTensorDeviceOps
+    'test_pow_inplace_xla', # (TPU) 0.0032 vs 0.001
+    'test_pow_inplace_3_xla', # (TPU) 0.0028 vs 0.001
+    'test_pow_3_xla', # (TPU) 0.0028 vs 0.001
+    'test_pow_-2_xla', # (TPU) 0.391 vs 0.001
+    'test_topk_dim_sort_xla', # (TPU) unimplemented HLO for X64
+    'test_topk_dim_desc_sort_xla', # (TPU) unimplemented HLO for X64
+    'test_sort_xla', # (TPU) unimplemented HLO for X64
+    'test_sort_neg_dim_xla', # (TPU) unimplemented HLO for X64
+    'test_sort_neg_dim_descending_xla', # (TPU) unimplemented HLO for X64
+    'test_sort_dim_xla', # (TPU) unimplemented HLO for X64
+    'test_sort_dim_descending_xla', # (TPU) unimplemented HLO for X64
+    'test_kthvalue_xla', # (TPU) unimplemented HLO for X64
+    'test_kthvalue_neg_dim_xla', # (TPU) unimplemented HLO for X64
+    'test_kthvalue_dim_xla', # (TPU) unimplemented HLO for X64
+    'test_eig_with_eigvec_xla_float64', # Precision: tensor(1.1798, dtype=torch.float64) not less than or equal to 0.001
+    'test_cumprod_xla', # FIXME: TPU X64Rewriter doesn't support reduce-window
+    'test_cumprod_neg_dim_xla', # FIXME: TPU X64Rewriter doesn't support reduce-window
+    'test_topk_neg_dim_sort_xla', # (TPU) unimplemented HLO for X64
+
+    #TestTorchDeviceType
+    'test_cholesky_solve_batched_broadcasting', # (TPU) 0.0039 vs 0.001
+    'test_cholesky_solve_batched_many_batches', # (TPU) 0.36 vs 0.001
+    'test_triangular_solve_batched_many_batches', # (TPU) 1.02 vs 0.001
+    'test_triangular_solve_batched_broadcasting', # (TPU) 1.5 vs 0.001
+    'test_random_from_to_xla_int32', # precision, TPU does not have real F64
+}
+
+DISABLED_TORCH_TESTS_CPU = DISABLED_TORCH_TESTS_ANY
+DISABLED_TORCH_TESTS_GPU = DISABLED_TORCH_TESTS_ANY
+
+DISABLED_TORCH_TESTS = {
+    'TPU': DISABLED_TORCH_TESTS_TPU,
+    'CPU': DISABLED_TORCH_TESTS_CPU,
+    'GPU': DISABLED_TORCH_TESTS_GPU,
+}
+
 
 class XLATestBase(DeviceTypeTestBase):
   device_type = 'xla'
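
The TPU list above extends the shared set via plain Python set union ('|'), and
DISABLED_TORCH_TESTS then keys the per-backend results by hardware type. A
minimal standalone sketch of that composition (the names below are
illustrative, not the real test lists):

    # Tests disabled on every backend.
    DISABLED_ANY = {'test_a', 'test_b'}
    # TPU disables the shared set plus its own extras; '|' is set union.
    DISABLED_TPU = DISABLED_ANY | {'test_c'}

    DISABLED = {'TPU': DISABLED_TPU, 'CPU': DISABLED_ANY, 'GPU': DISABLED_ANY}

    assert 'test_c' in DISABLED['TPU']      # TPU-only entry
    assert 'test_c' not in DISABLED['CPU']  # CPU inherits only the shared set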
@@ -251,13 +269,16 @@ def _alt_lookup(d, keys, defval):
   @classmethod
   def instantiate_test(cls, name, test):
     test_name = name + '_' + cls.device_type
+    real_device_type = xm.xla_device_hw(str(xm.xla_device()))
+    assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
+    disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]
 
     @wraps(test)
     def disallowed_test(self, test=test):
       raise unittest.SkipTest('skipped on XLA')
       return test(self, cls.device_type)
 
-    if test_name in DISABLED_TORCH_TESTS or test.__name__ in DISABLED_TORCH_TESTS:
+    if test_name in disabled_torch_tests or test.__name__ in disabled_torch_tests:
       assert not hasattr(
           cls, test_name), 'Redefinition of test {0}'.format(test_name)
       setattr(cls, test_name, disallowed_test)
@@ -284,7 +305,7 @@ def skipped_test(self, *args, reason=reason, **kwargs):
             cls, dtype_test_name), 'Redefinition of test {0}'.format(
                 dtype_test_name)
         setattr(cls, dtype_test_name, skipped_test)
-      elif dtype_test_name in DISABLED_TORCH_TESTS:
+      elif dtype_test_name in disabled_torch_tests:
         setattr(cls, dtype_test_name, disallowed_test)
       else:
         xla_dtypes.append(dtype)
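
With the per-backend split, instantiate_test resolves which disable list to
consult when the test class is instantiated, rather than reading one global
table. A rough standalone sketch of that dispatch, with a plain string argument
standing in for the xm.xla_device_hw() lookup (the helper names here are
hypothetical):

    import unittest
    from functools import wraps

    def resolve_disabled(disabled_by_hw, hw_type):
      # Fail loudly on unknown hardware, as the assert in instantiate_test does.
      assert hw_type in disabled_by_hw, 'Unsupported device type:' + hw_type
      return disabled_by_hw[hw_type]

    def maybe_disallow(test, disabled_tests):
      # Replace a disabled test with a wrapper that skips instead of running.
      if test.__name__ in disabled_tests:
        @wraps(test)
        def disallowed(self):
          raise unittest.SkipTest('skipped on XLA')
        return disallowed
      return test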

torch_xla/core/xla_model.py

Lines changed: 13 additions & 0 deletions
@@ -162,6 +162,19 @@ def xla_real_devices(devices):
   return real_devices
 
 
+def xla_device_hw(device):
+  """Returns the hardware type of the given device.
+
+  Args:
+    device (string): The xla device that will be mapped to the real device.
+
+  Returns:
+    A string representation of the hardware type of the given device.
+  """
+  real_device = xla_real_devices([device])[0]
+  return real_device.split(':')[0]
+
+
 def xla_replication_devices(local_devices):
   real_devices = xla_real_devices(local_devices)
   device_types = set()
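
A short usage sketch for the new helper: xla_device_hw() maps an XLA device
string such as 'xla:0' to the hardware type behind it, which is what the test
harness above keys its disable lists on (the printed value depends on the
runtime configuration):

    import torch_xla.core.xla_model as xm

    device = xm.xla_device()            # e.g. device('xla:0')
    hw = xm.xla_device_hw(str(device))  # e.g. 'TPU', 'GPU' or 'CPU'
    print(hw)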
