Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@ shift $(($OPTIND - 1))
export TRIM_GRAPH_SIZE=$MAX_GRAPH_SIZE
export TRIM_GRAPH_CHECK_FREQUENCY=$GRAPH_CHECK_FREQUENCY
export XLA_TEST_DIR=$CDIR
export PYTORCH_TEST_WITH_SLOW=1

if [ "$LOGFILE" != "" ]; then
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
python3 "$CDIR/test_mp_replication.py" "$@" 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA 2>&1 | tee $LOGFILE
else
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA
python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA
python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
python3 "$CDIR/test_mp_replication.py" "$@"
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
fi
57 changes: 56 additions & 1 deletion test/torch_test_meta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
allowed_torch_tests = {
## test_torch.py
'test_addcdiv',
'test_addcmul',
'test_diagonal',
Expand Down Expand Up @@ -78,5 +79,59 @@
'test_int_tensor_pow_neg_ints',
'test_long_tensor_pow_floats',
'test_var_mean_some_dims',
'test_clamp'
'test_clamp',
## test_indexing.py
'test_single_int',
'test_multiple_int',
'test_none',
'test_step',
'test_step_assignment',
#'test_bool_indices',
'test_bool_indices_accumulate',
'test_multiple_bool_indices',
#'test_byte_mask',
#'test_byte_mask_accumulate',
#'test_multiple_byte_mask',
'test_byte_mask2d',
'test_int_indices',
'test_int_indices2d',
'test_int_indices_broadcast',
'test_empty_index',
#'test_empty_ndim_index',
'test_empty_ndim_index_bool',
#'test_empty_slice',
#'test_index_getitem_copy_bools_slices',
'test_index_setitem_bools_slices',
'test_index_scalar_with_bool_mask',
#'test_setitem_expansion_error',
#'test_getitem_scalars',
'test_setitem_scalars',
'test_basic_advanced_combined',
'test_int_assignment',
#'test_byte_tensor_assignment',
'test_variable_slicing',
'test_ellipsis_tensor',
'test_invalid_index',
'test_out_of_bound_index',
'test_zero_dim_index',
'test_index_no_floats',
'test_none_index',
#'test_empty_tuple_index',
#'test_empty_fancy_index',
#'test_ellipsis_index',
'test_single_int_index',
'test_single_bool_index',
#'test_boolean_shape_mismatch',
'test_boolean_indexing_onedim',
#'test_boolean_assignment_value_mismatch',
'test_boolean_indexing_twodim',
#'test_boolean_indexing_weirdness',
#'test_boolean_indexing_weirdness_tensors',
'test_boolean_indexing_alldims',
'test_boolean_list_indexing',
'test_everything_returns_views',
#'test_broaderrors_indexing',
#'test_trivial_fancy_out_of_bounds',
#'test_index_is_larger',
#'test_broadcast_subspace',
}
17 changes: 13 additions & 4 deletions torch_patches/X10-device_test.diff
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
diff --git a/test/common_device_type.py b/test/common_device_type.py
index dda0b8d791..785c30fed7 100644
index 945882b0d7..e8232dee92 100644
--- a/test/common_device_type.py
+++ b/test/common_device_type.py
@@ -156,10 +156,72 @@ class CUDATestBase(DeviceTypeTestBase):
cls.no_magma = not torch.cuda.has_magma
@@ -209,10 +209,81 @@ class CUDATestBase(DeviceTypeTestBase):
cls.primary_device = 'cuda:{0}'.format(torch.cuda.current_device())


+import torch_xla
+assert torch_xla # Silences Flake (unused import)
+import torch_xla.core.xla_model as xm
+
+# Acquires XLA test metadata
+import os
Expand Down Expand Up @@ -66,6 +66,15 @@ index dda0b8d791..785c30fed7 100644
+ test.dtypes[cls.device_type] = xla_dtypes
+ super().instantiate_test(test)
+
+ @classmethod
+ def get_primary_device(cls):
+ return cls.primary_device
+
+ @classmethod
+ def setUpClass(cls):
+ # Sets the primary test device to the xla_device (CPU or TPU)
+ cls.primary_device = str(xm.xla_device())
+
+
# Adds available device-type-specific test base classes
device_type_test_bases.append(CPUTestBase)
Expand Down