Merged
2 changes: 0 additions & 2 deletions test/test_dynamic_shape_models.py
@@ -45,7 +45,6 @@ def forward(self, x):
)
class TestDynamicShapeModels(unittest.TestCase):

@unittest.skip("Regresssion")
def test_forward_pass_dynamic_input_correctness(self):
losses = []
for _ in range(2):
@@ -67,7 +66,6 @@ def test_forward_pass_dynamic_input_correctness(self):
np.testing.assert_allclose(losses[0], losses[1], rtol=1e-2, atol=1e-2)
print('Test passed.')

@unittest.skip("Regresssion")
def test_forward_pass_dynamic_input_compile_once(self):
met.clear_metrics()
num_compilation_recorded = False
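For context, the two re-enabled tests check that a model fed inputs with a dynamic leading dimension both trains consistently and compiles only once. A minimal sketch of that pattern (hypothetical shapes, not the test's exact model; assumes the experimental dynamic-shape support is enabled in the build):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

dev = xm.xla_device()
met.clear_metrics()
for _ in range(2):
    x = torch.zeros(10, 2, device=dev)
    x[3][0] = 1
    # nonzero() produces a tensor whose leading dimension is dynamic.
    dyn = torch.nonzero(x)
    xm.mark_step()
# If dynamic shapes are handled, both iterations reuse one compiled
# program, so CompileTime should record a single compilation.
print(met.metric_data('CompileTime'))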
6 changes: 0 additions & 6 deletions test/test_dynamic_shapes.py
@@ -12,7 +12,6 @@

class TestDynamicShapes(test_utils.XlaTestCase):

@unittest.skip("Regression")
def test_simple_expand(self):
size1 = 5
size2 = 2
@@ -27,7 +26,6 @@ def test_simple_expand(self):
t6_cpu = t6.cpu()
self.assertEqual(t6_cpu.shape[0], 2)

@unittest.skip("Regression")
def test_simple_expand_on_2d_tensor(self):
size1 = 5
size2 = 2
@@ -62,7 +60,6 @@ def test_simple_expand_on_2d_tensor(self):
# the python dispatcher.
self.assertGreater(met.counter_value("xla::size_clone"), 0)

@unittest.skip("Regression")
def test_simple_expand_add_dimension(self):
size1 = 5
size2 = 2
@@ -88,7 +85,6 @@ def test_wrap(self):
a3 = a2.shape[0] + 3 # tests wrap
self.assertIsInstance(a3, torch.SymInt)

@unittest.skip("Regression")
def test_sizeAdd(self):
size1 = 5
size2 = 2
@@ -109,7 +105,6 @@ def test_sizeAdd(self):
t4 = t3.expand(dyn_size)
self.assertEqual(t4.size(0), 3)

@unittest.skip("Regression")
def test_sizeSub(self):
size1 = 5
size2 = 2
@@ -170,7 +165,6 @@ def test_nonzero_cast(self):
t2 = torch.nonzero(t1.int()).float()
xm.mark_step()

@unittest.skip("Regression")
def test_expand_symint_correctness(self):
dev = xm.xla_device()
size1 = 5
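These re-enabled tests all follow the same shape-propagation pattern: nonzero() yields an upper-bounded dynamic dimension, and downstream ops consume it as a torch.SymInt. A hedged sketch of test_simple_expand's shape of usage (values illustrative; assumes experimental dynamic shapes):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t1 = torch.zeros(5, 2, device=dev)
t1[3][0] = 1
t1[4][0] = 1
t2 = torch.nonzero(t1)       # static shape [<=10, 2], dynamic [2, 2]
t5 = torch.ones(1, device=dev)
t6 = t5.expand(t2.shape[0])  # t2.shape[0] is a torch.SymInt
assert t6.cpu().shape[0] == 2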
1 change: 0 additions & 1 deletion test/test_operations.py
@@ -1119,7 +1119,6 @@ def test_scatter_add_bool(self):
xla_b.scatter_add_(0, xla_index, xla_a)
self.assertEqual(b, xla_b)

@unittest.skip("DS Regressions")
def test_squeeze_nonzero(self):

def test_fn(a):
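test_squeeze_nonzero exercises the interaction of a dynamic dimension with squeeze. A rough sketch of the idea (not the test's exact test_fn):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
a = torch.arange(6, device=dev)
nz = torch.nonzero(a >= 3)  # shape [N, 1] with dynamic N
out = nz.squeeze(1)         # the static index dimension is dropped
xm.mark_step()
print(out.cpu())            # tensor([3, 4, 5])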
48 changes: 48 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -751,6 +751,52 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) {
<< " has not been implemented.";
}

c10::SymNode XLASymNodeImpl::sym_or(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}

c10::SymNode XLASymNodeImpl::sym_and(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}

c10::SymNode XLASymNodeImpl::sym_not() {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}
// NB: self is ignored here, only the arguments are used
c10::SymNode XLASymNodeImpl::is_contiguous(at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}
c10::SymNode XLASymNodeImpl::is_channels_last_contiguous_2d(
at::ArrayRef<c10::SymNode> sizes, at::ArrayRef<c10::SymNode> strides) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}
c10::SymNode XLASymNodeImpl::is_channels_last_contiguous_3d(
at::ArrayRef<c10::SymNode> sizes, at::ArrayRef<c10::SymNode> strides) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}
c10::SymNode XLASymNodeImpl::is_channels_last_strides_2d(
at::ArrayRef<c10::SymNode> sizes, at::ArrayRef<c10::SymNode> strides) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}
c10::SymNode XLASymNodeImpl::is_channels_last_strides_3d(
at::ArrayRef<c10::SymNode> sizes, at::ArrayRef<c10::SymNode> strides) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}
c10::SymNode XLASymNodeImpl::is_non_overlapping_and_dense(
at::ArrayRef<c10::SymNode> sizes, at::ArrayRef<c10::SymNode> strides) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
}

c10::SymNode XLASymNodeImpl::clone() {
TORCH_LAZY_FN_COUNTER("xla::size_");
return c10::make_intrusive<XLASymNodeImpl>(node());
@@ -796,6 +842,8 @@ bool XLASymNodeImpl::bool_() {
return dn->getDynamicValue() != 0;
}

bool XLASymNodeImpl::has_hint() { return true; }

std::string XLASymNodeImpl::str() {
return "<=" + std::to_string(DimCast(node().get())->getStaticValue());
}
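The new overrides are deliberate XLA_CHECK(false) placeholders, so unsupported contiguity/stride queries fail loudly rather than silently falling through, while has_hint() returns true, presumably because an XLA symbolic size always carries a concrete runtime value it can report. A hedged Python-side sketch of where these SymNode hooks fire (hypothetical session; assumes experimental dynamic shapes):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t = torch.zeros(5, 2, device=dev)
t[3][0] = 1
s = torch.nonzero(t).shape[0]       # SymInt backed by XLASymNodeImpl
print(isinstance(s, torch.SymInt))  # True
print(s + 3)                        # arithmetic dispatches via SymNode
if s == 1:                          # eq(), then bool_() on the result
    print('one nonzero element')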
22 changes: 22 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -51,6 +51,27 @@ class TORCH_API XLASymNodeImpl : public c10::SymNodeImpl {
c10::SymNode neg() override;
c10::SymNode sym_min(const c10::SymNode& other) override;
c10::SymNode sym_max(const c10::SymNode& other) override;
c10::SymNode sym_or(const c10::SymNode& other) override;
c10::SymNode sym_and(const c10::SymNode& other) override;
c10::SymNode sym_not() override;
// NB: self is ignored here, only the arguments are used
c10::SymNode is_contiguous(at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) override;
c10::SymNode is_channels_last_contiguous_2d(
at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) override;
c10::SymNode is_channels_last_contiguous_3d(
at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) override;
c10::SymNode is_channels_last_strides_2d(
at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) override;
c10::SymNode is_channels_last_strides_3d(
at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) override;
c10::SymNode is_non_overlapping_and_dense(
at::ArrayRef<c10::SymNode> sizes,
at::ArrayRef<c10::SymNode> strides) override;
c10::SymNode clone() override;
c10::SymNode sym_float() override;
c10::SymNode wrap_int(int64_t num) override;
@@ -60,6 +81,7 @@ class TORCH_API XLASymNodeImpl : public c10::SymNodeImpl {
bool guard_bool(const char* file, int64_t line) override;
int64_t int_() override;
bool bool_() override;
bool has_hint() override;
std::string str() override;

torch::lazy::NodePtr node() { return node_; }