VNNI dialect operands extended with memref type #181

Merged (2 commits), Nov 23, 2022
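In short, this change widens vnni.matmul so its operands and result accept either BF16 tensors or BF16 memrefs, adds a BufferizableOpInterface external model so one-shot bufferization can lower the tensor form, registers that model in tpp-opt, and adds tests for both forms. A minimal sketch of the memref form the change enables (function name hypothetical; op usage and shapes mirror the new test/BF16/vnni-ops.mlir case below):

func.func @matmul_on_memrefs(%a: memref<2x2x2xbf16>,
                             %b: memref<2x2xbf16>,
                             %c: memref<4x2xbf16>) {
  vnni.matmul ins(%a: memref<2x2x2xbf16>, %b: memref<2x2xbf16>) out(%c: memref<4x2xbf16>)
  return
}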
22 changes: 22 additions & 0 deletions include/TPP/Dialect/VNNI/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,22 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VNNI_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_VNNI_BUFFERIZABLEOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;
} // namespace mlir

namespace mlir {
namespace vnni {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace vnni
} // namespace mlir

#endif // MLIR_DIALECT_VNNI_BUFFERIZABLEOPINTERFACEIMPL_H
6 changes: 4 additions & 2 deletions include/TPP/Dialect/VNNI/VNNIOps.td
@@ -12,6 +12,8 @@

include "mlir/IR/OpBase.td"

def VNNIOperand : AnyTypeOf<[BF16Tensor, BF16MemRef]>;

def VNNI_Dialect : Dialect {
let name = "vnni";
let cppNamespace = "::mlir::vnni";
@@ -37,8 +39,8 @@ def VNNI_MatmulOp : Op<VNNI_Dialect, "matmul",
```
}];

let arguments = (ins BF16Tensor:$matrixA, BF16Tensor:$matrixB, BF16Tensor:$matrixC);
let results = (outs BF16Tensor:$dest);
let arguments = (ins VNNIOperand:$matrixA, VNNIOperand:$matrixB, VNNIOperand:$matrixC);
let results = (outs VNNIOperand:$dest);
let assemblyFormat = "`ins` `(` $matrixA `:` type($matrixA) `,` $matrixB `:` type($matrixB) `)` `out` `(` $matrixC `:` type($dest) `)` attr-dict";
}

92 changes: 92 additions & 0 deletions lib/TPP/Dialect/VNNI/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,92 @@
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Dialect/VNNI/BufferizableOpInterfaceImpl.h"
#include "TPP/Dialect/VNNI/VNNIDialect.h"
#include "TPP/Dialect/VNNI/VNNIOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::vnni;

namespace mlir {
namespace vnni {
namespace {

struct MatmulLayoutInterface
: public BufferizableOpInterface::ExternalModel<MatmulLayoutInterface,
vnni::MatmulOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return opOperand.getOperandNumber() == 0 ||
opOperand.getOperandNumber() == 1;
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return opOperand.getOperandNumber() == 2;
}

bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}

SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
if (opOperand.getOperandNumber() < 1)
return {};
return {op->getResult(0)};
}

BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
vnni::MatmulOp matmulOp = cast<vnni::MatmulOp>(op);

FailureOr<Value> maybeDestBuffer =
getBuffer(rewriter, matmulOp.getMatrixC(), options);
if (failed(maybeDestBuffer))
return failure();
Value destBuffer = *maybeDestBuffer;

FailureOr<Value> maybeSrcBufferA =
getBuffer(rewriter, matmulOp.getMatrixA(), options);
if (failed(maybeSrcBufferA))
return failure();
Value srcBufferA = *maybeSrcBufferA;

FailureOr<Value> maybeSrcBufferB =
getBuffer(rewriter, matmulOp.getMatrixB(), options);
if (failed(maybeSrcBufferB))
return failure();
Value srcBufferB = *maybeSrcBufferB;

rewriter.create<vnni::MatmulOp>(op->getLoc(), destBuffer.getType(),
srcBufferA, srcBufferB, destBuffer);
replaceOpWithBufferizedValues(rewriter, op, destBuffer);
return success();
}
};

} // namespace
} // namespace vnni
} // namespace mlir

void mlir::vnni::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, vnni::VNNIDialect *dialect) {
MatmulOp::attachInterface<vnni::MatmulLayoutInterface>(*ctx);
});
}
1 change: 1 addition & 0 deletions lib/TPP/Dialect/VNNI/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(TPPVNNIDialect
# Ops and dialect
BufferizableOpInterfaceImpl.cpp
VNNIDialect.cpp
VNNIOps.cpp

15 changes: 15 additions & 0 deletions test/BF16/vnni-bufferization.mlir
@@ -0,0 +1,15 @@
// RUN: tpp-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs function-boundary-type-conversion=identity-layout-map" | FileCheck %s

// CHECK-LABEL: @myfunc(
// CHECK: %[[ARG0:.+]]: memref<2x2x2xbf16>,
// CHECK: %[[ARG1:.+]]: memref<2x2xbf16>,
// CHECK: %[[ARG2:.+]]: memref<4x2xbf16>) -> memref<4x2xbf16> {
func.func @myfunc(%arg0: tensor<2x2x2xbf16>,
%arg1: tensor<2x2xbf16>,
%arg2: tensor<4x2xbf16>) -> tensor<4x2xbf16> {
// CHECK: %[[ALLOC:.+]] = memref.alloc() {alignment = 128 : i64} : memref<4x2xbf16>
// CHECK: %[[RET:.+]] = vnni.matmul ins(%[[ARG0]] : memref<2x2x2xbf16>, %[[ARG1]] : memref<2x2xbf16>) out(%[[ALLOC]] : memref<4x2xbf16>)
%vnni_result = vnni.matmul ins(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2xbf16>) out(%arg2: tensor<4x2xbf16>)
// CHECK: return %[[ALLOC]]
return %vnni_result:tensor<4x2xbf16>
}
10 changes: 10 additions & 0 deletions test/BF16/vnni-ops.mlir
@@ -8,3 +8,13 @@ func.func @myfunc(%arg0: tensor<2x2x2xbf16>,
%vnni_result = vnni.matmul ins(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2xbf16>) out(%arg2: tensor<4x2xbf16>)
return %vnni_result:tensor<4x2xbf16>
}


// CHECK-LABEL: @myfunc2
func.func @myfunc2(%arg0: memref<2x2x2xbf16>,
%arg1: memref<2x2xbf16>,
%arg2: memref<4x2xbf16>){
// CHECK: vnni.matmul
vnni.matmul ins(%arg0: memref<2x2x2xbf16>, %arg1: memref<2x2xbf16>) out(%arg2: memref<4x2xbf16>)
return
}
2 changes: 2 additions & 0 deletions tpp-opt/tpp-opt.cpp
@@ -25,6 +25,7 @@
#include "TPP/Dialect/LinalgX/LinalgXDialect.h"
#include "TPP/Dialect/LinalgX/TransformOps/LinalgXTransformOps.h"
#include "TPP/Dialect/Tpp/TppDialect.h"
#include "TPP/Dialect/VNNI/BufferizableOpInterfaceImpl.h"
#include "TPP/Dialect/VNNI/VNNIDialect.h"
#include "TPP/Dialect/Xsmm/XsmmDialect.h"
#include "TPP/Passes.h"
@@ -42,6 +43,7 @@ int main(int argc, char **argv) {
mlir::linalgx::registerTransformDialectExtension(registry);
mlir::linalgx::registerBufferizableOpInterfaceExternalModels(registry);
mlir::check::registerBufferizableOpInterfaceExternalModels(registry);
mlir::vnni::registerBufferizableOpInterfaceExternalModels(registry);
// Add the following to include *all* MLIR Core dialects, or selectively
// include what you need like above. You only need to register dialects that
// will be *parsed* by the tool, not the ones generated