[mlir][memref] Verify out-of-bounds access for memref.subview #133086

Merged (1 commit) on Mar 31, 2025
133 changes: 44 additions & 89 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1859,11 +1859,11 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
]> {
let summary = "memref subview operation";
let description = [{
The "subview" operation converts a memref type to another memref type
which represents a reduced-size view of the original memref as specified by
the operation's offsets, sizes and strides arguments.
The `subview` operation converts a memref type to a memref type which
represents a reduced-size view of the original memref as specified by the
operation's offsets, sizes and strides arguments.

The SubView operation supports the following arguments:
The `subview` operation supports the following arguments:

* source: the "base" memref on which to create a "view" memref.
* offsets: memref-rank number of offsets into the "base" memref at which to
@@ -1876,118 +1876,73 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
The representation based on offsets, sizes and strides support a
partially-static specification via attributes specified through the
`static_offsets`, `static_sizes` and `static_strides` arguments. A special
sentinel value ShapedType::kDynamic encodes that the corresponding entry has
a dynamic value.
sentinel value `ShapedType::kDynamic` encodes that the corresponding entry
has a dynamic value.

A subview operation may additionally reduce the rank of the resulting view
by removing dimensions that are statically known to be of size 1.
A `subview` operation may additionally reduce the rank of the resulting
view by removing dimensions that are statically known to be of size 1.

In the absence of rank reductions, the resulting memref type is computed
as follows:
```
result_sizes[i] = size_operands[i]
result_strides[i] = src_strides[i] * stride_operands[i]
result_offset = src_offset + dot_product(offset_operands, src_strides)
```

The offset, size and stride operands must be in-bounds with respect to the
source memref. When possible, the static operation verifier will detect
out-of-bounds subviews. Subviews that cannot be confirmed to be in-bounds
or out-of-bounds based on compile-time information are valid. However,
performing an out-of-bounds subview at runtime is undefined behavior.

Example 1:

```mlir
%0 = memref.alloc() : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>

// Create a sub-view of "base" memref '%0' with offset arguments '%c0',
// dynamic sizes for each dimension, and stride arguments '%c1'.
%1 = memref.subview %0[%c0, %c0][%size0, %size1][%c1, %c1]
: memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>>
// Subview of static memref with strided layout at static offsets, sizes
// and strides.
%1 = memref.subview %0[4, 2][8, 2][3, 2]
: memref<64x4xf32, strided<[7, 9], offset: 91>> to
memref<8x2xf32, strided<[21, 18], offset: 137>>
```
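
A worked check of Example 1 against the formula above (illustrative only, not part of the diff): with source strides `[7, 9]` and offset `91`, offsets `[4, 2]`, sizes `[8, 2]` and strides `[3, 2]` give

```
result_sizes   = [8, 2]
result_strides = [7 * 3, 9 * 2] = [21, 18]
result_offset  = 91 + 4 * 7 + 2 * 9 = 137
```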

Example 2:

```mlir
%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>

// Create a sub-view of "base" memref '%0' with dynamic offsets, sizes,
// Subview of static memref with identity layout at dynamic offsets, sizes
// and strides.
// Note that dynamic offsets are represented by the linearized dynamic
// offset symbol 's0' in the subview memref layout map, and that the
// dynamic strides operands, after being applied to the base memref
// strides in each dimension, are represented in the view memref layout
// map as symbols 's1', 's2' and 's3'.
%1 = memref.subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z]
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
memref<?x?x?xf32,
affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
%1 = memref.subview %0[%off0, %off1][%sz0, %sz1][%str0, %str1]
: memref<64x4xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
```

Example 3:

```mlir
%0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>

// Subview with constant offsets, sizes and strides.
%1 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>
// Subview of dynamic memref with strided layout at dynamic offsets and
// strides, but static sizes.
%1 = memref.subview %0[%off0, %off1][4, 4][%str0, %str1]
: memref<?x?xf32, strided<[?, ?], offset: ?>> to
memref<4x4xf32, strided<[?, ?], offset: ?>>
```

Example 4:

```mlir
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>

// Subview with constant size, but dynamic offsets and
// strides. The resulting memref has a static shape, but if the
// base memref has an affine map to describe the layout, the result
// memref also uses an affine map to describe the layout. The
// strides of the result memref is computed as follows:
//
// Let #map1 represents the layout of the base memref, and #map2
// represents the layout of the result memref. A #mapsubview can be
// constructed to map an index from the result memref to the base
// memref (note that the description below uses more convenient
// naming for symbols, while in affine maps, symbols are
// represented as unsigned numbers that identify that symbol in the
// given affine map.
//
// #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1)
//
// where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then,
//
// #map2 = #map1.compose(#mapsubview)
//
// If the layout map is represented as
//
// #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
//
// then,
//
// #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] ->
// (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0)
//
// Representing this canonically
//
// #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)
//
// where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
%1 = memref.subview %0[%i, %j][4, 4][%x, %y] :
: memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>> to
memref<4x4xf32, affine_map<(d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>>

// Note that the subview op does not guarantee that the result
// memref is "inbounds" w.r.t to base memref. It is upto the client
// to ensure that the subview is accessed in a manner that is
// in-bounds.
// Rank-reducing subviews.
%1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1]
: memref<8x16x4xf32> to memref<16x4xf32>
%3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1]
: memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
```

Example 5:

```mlir
// Rank-reducing subview.
%1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
memref<8x16x4xf32> to memref<16x4xf32>

// Original layout:
// (d0, d1, d2) -> (64 * d0 + 16 * d1 + d2)
// Subviewed layout:
// (d0, d1, d2) -> (64 * (d0 + 3) + 4 * (d1 + 4) + d2 + 2) = (64 * d0 + 4 * d1 + d2 + 210)
// After rank reducing:
// (d0, d1) -> (4 * d0 + d1 + 210)
%3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
// Identity subview. The subview is the full source memref.
%1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
: memref<8x16x4xf32> to memref<8x16x4xf32>
```

}];

let arguments = (ins AnyMemRef:$source,
17 changes: 7 additions & 10 deletions mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -76,8 +76,7 @@ SliceBoundsVerificationResult verifyInBoundsSlice(
/// returns the new result type of the op, based on the new offsets, sizes and
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
/// the op has changed.
template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
bool CheckInBounds = false>
template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
@@ -95,14 +94,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
failed(foldDynamicIndexList(mixedStrides)))
return failure();

if (CheckInBounds) {
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();
}
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();

// Compute the new result type.
auto resultType =
13 changes: 12 additions & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2977,6 +2977,9 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
LogicalResult SubViewOp::verify() {
MemRefType baseType = getSourceType();
MemRefType subViewType = getType();
ArrayRef<int64_t> staticOffsets = getStaticOffsets();
ArrayRef<int64_t> staticSizes = getStaticSizes();
ArrayRef<int64_t> staticStrides = getStaticStrides();

// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@@ -2991,7 +2994,7 @@ LogicalResult SubViewOp::verify() {
// Compute the expected result type, assuming that there are no rank
// reductions.
MemRefType expectedType = SubViewOp::inferResultType(
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
baseType, staticOffsets, staticSizes, staticStrides);

// Verify all properties of a shaped type: rank, element type and dimension
// sizes. This takes into account potential rank reductions.
@@ -3025,6 +3028,14 @@ LogicalResult SubViewOp::verify() {
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);

// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the base memref.
SliceBoundsVerificationResult boundsResult =
verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
staticStrides, /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);

return success();
}
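
With this check in place, a subview whose static offsets, sizes and strides run past the source shape is rejected at verification time. A minimal sketch of such an op (this is the old form of the `subview_const_stride_and_offset` test updated further below; `%src` and `%bad` are placeholder names): the last accessed index in dimension 1 is 8 + (3 - 1) * 1 = 10, which is past the source size 4, so the new verifier flags it.

```mlir
// Rejected by the new verifier: out of bounds in dimension 1.
%bad = memref.subview %src[0, 8] [62, 3] [1, 1]
    : memref<64x4xf32, strided<[4, 1], offset: 0>>
    to memref<62x3xf32, strided<[4, 1], offset: 8>>
```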

8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2617,10 +2617,10 @@ struct SliceCanonicalizer {

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer,
SliceCanonicalizer, /*CheckInBounds=*/true>,
ExtractSliceOpCastFolder>(context);
results.add<
OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
ExtractSliceOpCastFolder>(context);
}

//
@@ -192,7 +192,7 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>

// CHECK-LABEL: func @subview_const_stride_and_offset(
// CHECK-SAME: %[[MEM:.*]]: memref<{{.*}}>
func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>) -> memref<62x3xf32, strided<[4, 1], offset: 8>> {
func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]

@@ -201,21 +201,21 @@ func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, strided<[4, 1],
// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST_OFF]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_SIZE0:.*]] = llvm.mlir.constant(62 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST_SIZE0]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64
// CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST_STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_SIZE1:.*]] = llvm.mlir.constant(3 : index) : i64
// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

%1 = memref.subview %0[0, 8][62, 3][1, 1] :
memref<64x4xf32, strided<[4, 1], offset: 0>>
to memref<62x3xf32, strided<[4, 1], offset: 8>>
return %1 : memref<62x3xf32, strided<[4, 1], offset: 8>>
%1 = memref.subview %0[0, 2][62, 3][1, 1] :
memref<64x8xf32, strided<[8, 1], offset: 0>>
to memref<62x3xf32, strided<[8, 1], offset: 2>>
return %1 : memref<62x3xf32, strided<[8, 1], offset: 2>>
}

// -----
@@ -238,7 +238,7 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index
// CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
// CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow<nsw> : i64
// CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64
// CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index
// CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64
@@ -253,7 +253,7 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
// CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

%1 = memref.subview %0[%arg1, 8][62, %arg2][%arg0, 1] :
%1 = memref.subview %0[%arg1, 2][62, %arg2][%arg0, 1] :
memref<64x4xf32, strided<[4, 1], offset: 0>>
to memref<62x?xf32, strided<[?, 1], offset: ?>>
return %1 : memref<62x?xf32, strided<[?, 1], offset: ?>>
28 changes: 14 additions & 14 deletions mlir/test/Dialect/Linalg/promote.mlir
Original file line number Diff line number Diff line change
@@ -287,18 +287,18 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func.func @linalg_generic_update_all_function_inputs_outputs(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xf32, 1>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> {
func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf32, 1>, %arg1: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>

%alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1>
%subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
%subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
%subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
// CHECK-SAME: %[[VAL_0:.*]]: memref<8x4xf32, 1>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<8x4xf32, 1>) -> memref<8x4xf32, 1> {
func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<8x4xf32, 1>, %arg1: memref<8x4xf32, 1>) -> memref<8x4xf32, 1> {
// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x4xf32, 1>
// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
// CHECK: %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>

%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4xf32, 1>
%subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
%subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>
%subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1>

// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index
@@ -376,10 +376,10 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
// CHECK: memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
// CHECK: return %[[VAL_2]] : memref<3x4xf32, 1>
// CHECK: return %[[VAL_2]] : memref<8x4xf32, 1>
// CHECK: }

return %alloc : memref<3x4xf32, 1>
return %alloc : memref<8x4xf32, 1>
}
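
The promoted subviews in this test read a [4, 3] region starting at [0, 0], so the leading dimension of the source memrefs must be at least 4; the old `3x4` types were statically out of bounds under the new verifier, hence the switch to `8x4`. The arithmetic (illustrative only):

```
// dim 0, old source memref<3x4xf32, 1>: last index = 0 + (4 - 1) * 1 = 3 >= 3  -> out of bounds
// dim 0, new source memref<8x4xf32, 1>: last index = 3 < 8                     -> in bounds
```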

