Skip to content

Commit

Permalink
SymInt-ify unsqueeze_copy
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed May 13, 2024
1 parent 96bdb7a commit 523a4d0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
26 changes: 19 additions & 7 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3227,16 +3227,28 @@ static inferSqueezeGeometry(const Tensor &tensor, std::bitset<dim_bitset_size> d
namespace {
// Named type instead of a pair/tuple so that we can be sure to
// construct the vectors in place and get NRVO.
template <typename T>
struct InferUnsqueezeGeometryResult {
DimVector sizes;
DimVector strides;
InferUnsqueezeGeometryResult(IntArrayRef tensor_sizes, IntArrayRef tensor_strides)
SmallVector<T, kDimVectorStaticSize>sizes;
SmallVector<T, kDimVectorStaticSize> strides;
InferUnsqueezeGeometryResult(ArrayRef<T> tensor_sizes, ArrayRef<T> tensor_strides)
: sizes(tensor_sizes.begin(), tensor_sizes.end())
, strides(tensor_strides.begin(), tensor_strides.end()) {}
};
InferUnsqueezeGeometryResult

InferUnsqueezeGeometryResult<c10::SymInt>
inferUnsqueezeGeometry_symint(const Tensor& tensor, int64_t dim) {
InferUnsqueezeGeometryResult<c10::SymInt> result(tensor.sym_sizes(), tensor.sym_strides());
c10::SymInt new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
result.sizes.insert(result.sizes.begin() + dim, 1);
result.strides.insert(result.strides.begin() + dim, new_stride);

return result;
}

InferUnsqueezeGeometryResult<int64_t>
inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
InferUnsqueezeGeometryResult result(tensor.sizes(), tensor.strides());
InferUnsqueezeGeometryResult<int64_t> result(tensor.sizes(), tensor.strides());
int64_t new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
result.sizes.insert(result.sizes.begin() + dim, 1);
result.strides.insert(result.strides.begin() + dim, new_stride);
Expand Down Expand Up @@ -3377,8 +3389,8 @@ Tensor _unsafe_view(const Tensor& self, IntArrayRef size) {

Tensor unsqueeze(const Tensor& self, int64_t dim) {
dim = maybe_wrap_dim(dim, self.dim() + 1);
auto g = inferUnsqueezeGeometry(self, dim);
return self.as_strided(g.sizes, g.strides);
auto g = inferUnsqueezeGeometry_symint(self, dim);
return self.as_strided_symint(g.sizes, g.strides);
}

Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) {
Expand Down
13 changes: 13 additions & 0 deletions test/test_fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,19 @@ def forward(self, input):
ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
self.assertTrue(isinstance(ep, torch.export.ExportedProgram))

def test_unsqueeze_copy(self):
shape_env = ShapeEnv()
t1 = torch.ones(2, 2, 768)
with FakeTensorMode(shape_env=shape_env) as fake_mode:
t = fake_mode.from_tensor(
t1,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.STATIC, DimDynamic.STATIC],
)
)

self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0])

def test_alias_call(self):
fwAD = torch.autograd.forward_ad

Expand Down

0 comments on commit 523a4d0

Please sign in to comment.