diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index a99e6e3a50c1..366f67a3be44 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -3227,16 +3227,28 @@ static inferSqueezeGeometry(const Tensor &tensor, std::bitset d namespace { // Named type instead of a pair/tuple so that we can be sure to // construct the vectors in place and get NRVO. +template struct InferUnsqueezeGeometryResult { - DimVector sizes; - DimVector strides; - InferUnsqueezeGeometryResult(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) + SmallVectorsizes; + SmallVector strides; + InferUnsqueezeGeometryResult(ArrayRef tensor_sizes, ArrayRef tensor_strides) : sizes(tensor_sizes.begin(), tensor_sizes.end()) , strides(tensor_strides.begin(), tensor_strides.end()) {} }; -InferUnsqueezeGeometryResult + +InferUnsqueezeGeometryResult +inferUnsqueezeGeometry_symint(const Tensor& tensor, int64_t dim) { + InferUnsqueezeGeometryResult result(tensor.sym_sizes(), tensor.sym_strides()); + c10::SymInt new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim]; + result.sizes.insert(result.sizes.begin() + dim, 1); + result.strides.insert(result.strides.begin() + dim, new_stride); + + return result; +} + +InferUnsqueezeGeometryResult inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) { - InferUnsqueezeGeometryResult result(tensor.sizes(), tensor.strides()); + InferUnsqueezeGeometryResult result(tensor.sizes(), tensor.strides()); int64_t new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim]; result.sizes.insert(result.sizes.begin() + dim, 1); result.strides.insert(result.strides.begin() + dim, new_stride); @@ -3377,8 +3389,8 @@ Tensor _unsafe_view(const Tensor& self, IntArrayRef size) { Tensor unsqueeze(const Tensor& self, int64_t dim) { dim = maybe_wrap_dim(dim, self.dim() + 1); - auto g = inferUnsqueezeGeometry(self, dim); - return self.as_strided(g.sizes, g.strides); + auto g = inferUnsqueezeGeometry_symint(self, dim); + return self.as_strided_symint(g.sizes, g.strides); } Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) { diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 40075eb24e04..079d2211a6e8 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -805,6 +805,19 @@ def forward(self, input): ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),)) self.assertTrue(isinstance(ep, torch.export.ExportedProgram)) + def test_unsqueeze_copy(self): + shape_env = ShapeEnv() + t1 = torch.ones(2, 2, 768) + with FakeTensorMode(shape_env=shape_env) as fake_mode: + t = fake_mode.from_tensor( + t1, + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.STATIC, DimDynamic.STATIC], + ) + ) + + self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0]) + def test_alias_call(self): fwAD = torch.autograd.forward_ad