3 changes: 3 additions & 0 deletions aten/src/ATen/NestedTensorImpl.h
@@ -68,6 +68,9 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
int64_t size_custom(int64_t d) const override {
return this->size(d);
}
c10::SymInt sym_size_custom(int64_t d) const override {
return c10::SymInt{this->size(d)};
}
IntArrayRef sizes_custom() const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymIntArrayRef sym_sizes() const override;
6 changes: 1 addition & 5 deletions aten/src/ATen/core/TensorBase.h
@@ -157,11 +157,7 @@ class TORCH_API TensorBase {
}

c10::SymInt sym_size(int64_t dim) const {
const auto sizes = this->sym_sizes();
const auto ndim = static_cast<int64_t>(sizes.size());
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];

return impl_->sym_size(dim);
}

int64_t size(int64_t dim) const {
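With this change TensorBase::sym_size forwards straight to impl_->sym_size(dim), so dim wrapping now lives in TensorImpl for both the plain and symbolic paths. A quick sanity check of the behavior being preserved, written against the public Python API with an ordinary (concrete-size) tensor; this is illustration only, not part of the diff:

import torch

t = torch.randn(5, 4, 3)
# size(dim) canonicalizes negative dims the same way indexing into .shape does,
# so both accessors must keep agreeing after the refactor.
assert t.size(-1) == t.shape[-1] == 3
assert t.size(0) == t.shape[0] == 5
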
20 changes: 20 additions & 0 deletions c10/core/TensorImpl.h
@@ -594,6 +594,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return sizes_and_strides_.size_at_unchecked(d).as_int_unchecked();
}

c10::SymInt sym_size(int64_t d) const {
if (C10_UNLIKELY(
sizes_strides_policy_ >=
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
return sym_size_custom(d);
}
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
const auto sizes = this->sym_sizes();
return sizes[d];
}

/**
* Return the stride of a tensor at some dimension, wrapping the dimension
* if necessary.
@@ -697,6 +708,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
}

virtual c10::SymInt sym_size_custom(int64_t d) const {
// TODO: We could add support to Python dispatch here.
// TODO: We could call into aten::size.int instead of
// sym_sizes_custom()[d] and enable use of the dispatcher.
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
}

virtual IntArrayRef sizes_custom() const;
virtual Device device_custom() const;
virtual Layout layout_custom() const;
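The new TensorImpl::sym_size(d) mirrors size(d): it dispatches to the sym_size_custom hook when the sizes/strides policy is CustomSizes, and otherwise canonicalizes d with maybe_wrap_dim and indexes sym_sizes(). A minimal Python sketch of the wrapping rule maybe_wrap_dim applies with wrap_scalar=false (the function name and message below are mine, for illustration only):

def wrap_dim(d: int, ndim: int) -> int:
    # A dim in [-ndim, ndim) is canonicalized into [0, ndim); anything
    # outside that range is rejected, matching wrap_scalar=false.
    if not -ndim <= d < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {d})")
    return d + ndim if d < 0 else d

assert wrap_dim(-1, 3) == 2  # sym_size(-1) reads the last dimension
assert wrap_dim(0, 3) == 0
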
2 changes: 2 additions & 0 deletions test/cpp/lazy/test_lazy_ops.cpp
@@ -95,6 +95,7 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
}

TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
FLAGS_ltc_enable_symbolic_shapes = true;
auto xc = torch::rand({10});
auto x = xc.to(kLazy);
const size_t Y_DIM = 3;
@@ -105,6 +106,7 @@ TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
// shape inference assumes narrow_copy can copy the whole tensor
AllClose(z.cpu(), zc);
FLAGS_ltc_enable_symbolic_shapes = false;
}

TEST_F(LazyOpsTest, TestScalarTensor) {
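The test flips FLAGS_ltc_enable_symbolic_shapes on at the top and resets it by hand at the end. A reusable guard that restores the previous value even when an assertion throws could look like the following Python sketch (the get_flag/set_flag callables are placeholders, not a real torch API):

import contextlib

@contextlib.contextmanager
def flag_guard(get_flag, set_flag, value=True):
    # Set a flag for the duration of a block and always restore the old
    # value on exit, even if the body raises.
    old = get_flag()
    set_flag(value)
    try:
        yield
    finally:
        set_flag(old)
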
128 changes: 68 additions & 60 deletions test/test_dynamic_shapes.py
@@ -32,7 +32,7 @@ def add_func(op):

@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
return a.new_empty(a.sym_size())
return a.new_empty(a.shape)


@register_meta(aten.cat.default)
@@ -53,7 +53,7 @@ def cat_meta(tensors, dim=0):
@register_meta([aten.narrow_copy.SymInt])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.sym_size()):
for i, x in enumerate(a.shape):
if i == dim:
shape.append(length)
else:
@@ -165,6 +165,14 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
self = args[0]
return self.sym_shape

# some calls can be redirected to `sym_size` rather than
# `sym_sizes`. `sym_size` uses `dim` to canonicalize an index
# so we need to implement both `sym_size` and `dim` for python
# tensors
if func_overload == torch.ops.aten.dim.default:
self = args[0]
return len(self.sym_shape)

if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
@@ -174,7 +182,7 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):


def create_symbolic_tensor(name, arg, shape_env):
sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())])
sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.size())])
sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)

@@ -188,22 +196,22 @@ class TestPySymInt(TestCase):
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.sym_size(0), PySymInt))
self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS))
self.assertTrue(not isinstance(x.shape[0], PySymInt))
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))

self.assertEqual(int(x.sym_size(0)), 5)
self.assertEqual(int(x.sym_size(1)), 4)
self.assertEqual(int(x.sym_size(2)), 3)
self.assertEqual(int(x.shape[0]), 5)
self.assertEqual(int(x.shape[1]), 4)
self.assertEqual(int(x.shape[2]), 3)

self.assertEqual(int(x.sym_size()[0]), 5)
self.assertEqual(int(x.sym_size()[1]), 4)
self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS))
self.assertEqual(int(x.sym_size()[2]), 3)
self.assertEqual(int(x.size()[0]), 5)
self.assertEqual(int(x.size()[1]), 4)
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
self.assertEqual(int(x.size()[2]), 3)

self.assertEqual(int(x.sym_size(0)), 5)
self.assertEqual(int(x.sym_size(1)), 4)
self.assertEqual(int(x.sym_size(2)), 3)
self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS))
self.assertEqual(int(x.size(0)), 5)
self.assertEqual(int(x.size(1)), 4)
self.assertEqual(int(x.size(2)), 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))

@skipIfNoSympy
def test_binary(self):
@@ -212,33 +220,33 @@ def test_binary(self):
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

z = x + y
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

# broadcasting
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

@skipIfNoSympy
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM))
self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2)))
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertEqual(int(z.shape[2]), int(y.shape[2]))

# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM))
self.assertEqual(int(z.sym_size(2)), 2)
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
self.assertEqual(int(z.shape[2]), 2)

# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1)
self.assertEqual(int(z.sym_size(2)), 2)
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertEqual(int(z.shape[2]), 2)

@skipIfNoSympy
def test_symint_vargs(self):
@@ -247,67 +255,67 @@ def test_symint_vargs(self):
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

# varargs
z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

# shape list
z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2)))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

# mixed python symints and ints
z = y.expand(x.sym_size(0), y.sym_size(1), 3)
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
z = y.expand(x.shape[0], y.shape[1], 3)
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

# mixed python symints and ints in a list
z = y.expand((x.sym_size(0), y.sym_size(1), 3))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
z = y.expand((x.shape[0], y.shape[1], 3))
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

# mixed python symints and ints
z = y.expand(5, y.sym_size(1), x.sym_size(2))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
z = y.expand(5, y.shape[1], x.shape[2])
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

# mixed python ints and symints in a list
z = y.expand((5, y.sym_size(1), x.sym_size(2)))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
z = y.expand((5, y.shape[1], x.shape[2]))
self.assertEqual(int(z.shape[0]), 5)
self.assertEqual(int(z.shape[1]), 4)
self.assertEqual(int(z.shape[2]), 3)

@skipIfNoSympy
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.sym_size(0), x.sym_size(0))
if expand_x.sym_size(0) > 3:
expand_x = x.expand(x.shape[0], x.shape[0])
if expand_x.shape[0] > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x

gt_op = shape_env.guards[0][0]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0]))
self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0)))
self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0)))
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))

@skipIfNoSympy
def test_aten_ops(self):

shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.sym_size(0))
torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0])

shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.SymInt(x, [x.sym_size(0), x.sym_size(1), x.sym_size(2)])
torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]])

def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
@@ -327,7 +335,7 @@ def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)
r = torch.empty(a0, device='meta')
self.assertIsInstance(r.sym_size(0), CPP_SYMINT_CLASS)
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)


if __name__ == '__main__':
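These test updates replace the removed Tensor.sym_size() Python method with the ordinary x.shape[i] / x.size(i) accessors, which now reach the symbolic path on the C++ side. For a plain tensor with concrete sizes those accessors keep returning Python ints that agree with each other; a quick illustration against the public API (not part of the diff):

import torch

x = torch.randn(5, 4, 3)
# .shape indexing, size(dim), and size()[dim] are the same query three ways.
assert x.shape[1] == x.size(1) == x.size()[1] == 4
assert isinstance(x.shape[0], int)  # concrete sizes stay plain ints
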
8 changes: 4 additions & 4 deletions test/test_python_dispatch.py
@@ -1794,7 +1794,7 @@ def __new__(cls, data, wrapper):
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.size:
if func.overloadpacket == torch.ops.aten.sym_size:
return (5, 3)
return NotImplemented

@@ -1807,13 +1807,13 @@ def __new__(cls, data, wrapper):
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.size:
if func.overloadpacket == torch.ops.aten.sym_size:
return None
return NotImplemented

err_msg = "no implementation found for 'torch.ops.aten.size'"
err_msg = "no implementation found for 'torch.ops.aten.sym_size'"
e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
with self.assertRaisesRegex(RuntimeError, err_msg):
e.size()

e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
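Because Tensor.size() is now answered through aten.sym_size under __torch_dispatch__, a subclass that used to match torch.ops.aten.size has to match the new overload packet, and a missing handler surfaces as a RuntimeError rather than a TypeError. A minimal sketch of the dispatch check, assuming a build where torch.ops.aten.sym_size is registered (the helper name is mine):

import torch

def answers_size_query(func_overload) -> bool:
    # Compare on the overload packet, as the tests above do, so any
    # sym_size overload matches rather than one specific signature.
    return func_overload.overloadpacket == torch.ops.aten.sym_size
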
46 changes: 2 additions & 44 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -95,43 +95,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
END_HANDLE_TH_ERRORS
}

// TODO: FIXME This should be super temporary until we fix the XLA issue.
static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"sym_size(int64_t dim)",
"sym_size()",
"sym_size(Dimname dim)",
});
auto& self_ = THPVariable_Unpack(self);
ParsedArgs<3> parsed_args;
auto r = parser.parse(self, args, kwargs, parsed_args);

if(r.has_torch_function()){
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}
if (r.idx == 0) {
if (jit::tracer::isTracing()) {
// will error out if a tensor has symints
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return torch::toPyObject(self_.sym_size(r.toInt64(0)));
}
} else if (r.idx == 1) {
return THPSize_NewFromSymSizes(self_);
}
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
}
return wrap(self_.size(r.dimname(0)));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}


static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@@ -152,14 +115,10 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
// will error out if a tensor has symints
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return wrap(self_.size(r.toInt64(0)));
//return torch::toPyObject(self_.sym_size(r.toInt64(0)));
return torch::toPyObject(self_.sym_size(r.toInt64(0)));
}
} else if (r.idx == 1) {
// we can't do the normal wrapping here because IntArrayRef maps to both
// torch.Size and tuple in python.
return THPSize_New(self_);
//return THPSize_NewFromSymSizes(self_);
return THPSize_NewFromSymSizes(self_);
}
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
@@ -1322,7 +1281,6 @@ PyMethodDef variable_methods[] = {
{"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
{"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
{"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"_storage", THPVariable_storage, METH_NOARGS, NULL},
{"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
{"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},
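With the temporary sym_size binding removed, Tensor.size itself now returns torch::toPyObject(self_.sym_size(dim)) for a single dim and THPSize_NewFromSymSizes(self_) for the full shape. For tensors whose sizes are concrete the user-visible behavior should be unchanged; a quick check against the public API (illustration only):

import torch

t = torch.ones(2, 3)
assert t.size() == torch.Size([2, 3])            # still a torch.Size
assert t.size(0) == 2 and isinstance(t.size(0), int)
assert t.size(-1) == 3                           # negative dims still wrap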