Skip to content

Commit

Permalink
[spirv] Enable vectorizing tensor.extract into vector.gather (iree-or…
Browse files Browse the repository at this point in the history
…g#13626)

This also adds a check in `SPIRVVectorizePass` to make sure that we
don't have remaining linalg ops after vectorization to avoid suprises.
  • Loading branch information
antiagainst authored and nhasabni committed Aug 24, 2023
1 parent e154adf commit 5c43522
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
26 changes: 21 additions & 5 deletions compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,16 @@ std::optional<SmallVector<int64_t>> getNativeVectorShape(

/// Add patterns to vectorize any supported Linalg ops.
void populateVectorizationPatterns(RewritePatternSet &patterns) {
IREE::LinalgExt::LinalgTransformationFilter f;
IREE::LinalgExt::LinalgVectorizationOptions vectorizationOptions;
IREE::LinalgExt::LinalgTransformationFilter filter;
IREE::LinalgExt::LinalgVectorizationOptions options;
// Enable vectorizing tensor.extract in Linalg ops.
options.vectorizeGatherAccesses = true;
VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
patterns, vectorizationOptions, f);
patterns, options, filter);
linalg::populateConvolutionVectorizationPatterns(patterns);
patterns.add<LinalgVectorizationPattern>(
patterns.getContext(), vectorizationOptions,
f.addOpFilter<linalg::ContractionOpInterface>());
patterns.getContext(), options,
filter.addOpFilter<linalg::ContractionOpInterface>());
}

/// Adds patterns to unroll vector ops to SPIR-V native vector size.
Expand Down Expand Up @@ -324,6 +326,20 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
llvm::dbgs() << "\n\n";
});

{
auto result = funcOp.walk([&](linalg::LinalgOp op) {
// linalg.generic ops for copy are fine to not vectorize; they will be
// handled in later steps.
if (isa<linalg::YieldOp>(op.getBlock()->begin())) {
return WalkResult::advance();
}
// Other ones should error out.
op.emitOpError("should not remain after vectorization");
return WalkResult::interrupt();
});
if (result.wasInterrupted()) return signalPassFailure();
}

// Special peephole optimizations to clean up IR before further processing.
{
RewritePatternSet patterns(context);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
// RUN: iree-opt --split-input-file --iree-spirv-vectorize -canonicalize %s | FileCheck %s

func.func @tensor_extract(%arg0: tensor<6x4xf32>, %arg1: tensor<6xi32>, %data: tensor<1x2x512xf32>, %init: tensor<6x4xf32>, %i : index, %j: index) -> tensor<6x4xf32> {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.000000e+00 : f32
%generic = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%arg0, %arg1 : tensor<6x4xf32>, tensor<6xi32>) outs(%init : tensor<6x4xf32>) {
^bb0(%in: f32, %in_2: i32, %out: f32):
%0 = linalg.index 1 : index
%1 = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 32)>(%0, %j)[%i]
%2 = arith.addf %out, %f0 : f32
%3 = arith.index_cast %in_2 : i32 to index
%extracted = tensor.extract %data[%c0, %3, %1] : tensor<1x2x512xf32>
%4 = arith.addf %2, %in : f32
%5 = arith.addf %4, %extracted : f32
linalg.yield %5 : f32
} -> tensor<6x4xf32>
return %generic : tensor<6x4xf32>
}

// CHECK-LABEL: func.func @tensor_extract
// CHECK-NOT: linalg.generic
// CHECK-COUNT-24: tensor.extract {{.+}} : tensor<1x2x512xf32>

// -----

func.func @vector_gather(%arg0: memref<16x1082x1922xi8>, %index_vec: vector<16xindex>) -> vector<16xi8> {
%c0 = arith.constant 0 : index
%mask = arith.constant dense<true> : vector<16xi1>
Expand Down Expand Up @@ -34,4 +60,3 @@ func.func @vector_gather(%arg0: memref<16x1082x1922xi8>, %index_vec: vector<16xi

// CHECK: vector.insert_strided_slice %[[INSERT3]], %[[INIT]] {offsets = [0], strides = [1]} : vector<4xi8> into vector<16xi8>
// CHECK-12: vector.load %[[ARG0]][%[[C0]], %[[C0]], %{{.*}}] : memref<16x1082x1922xi8>, vector<1xi8>

0 comments on commit 5c43522

Please sign in to comment.