Skip to content

Commit

Permalink
Integrate LLVM at llvm/llvm-project@27cf6ba1d7bc
Browse files Browse the repository at this point in the history
Updates LLVM usage to match
[27cf6ba1d7bc](llvm/llvm-project@27cf6ba1d7bc)

PiperOrigin-RevId: 530473001
  • Loading branch information
tensorflower-gardener committed May 9, 2023
1 parent 9f57eb7 commit c3bb0e5
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -11187,6 +11187,8 @@ underlying graph, and executes each of the partitioned subgraphs as a function.

// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getFAttr(); }
// Sets the callee from the callable.
void setCalleeFromCallable(CallInterfaceCallable callee);

// returns the callee of this operation.
func::FuncOp func() {
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall",

// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getFAttr(); }
// Sets the callee from the callable
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee);

// Returns the resolved callee function of this operation.
// Prefer passing in SymbolTableCollection to reduce lookup costs by
Expand Down Expand Up @@ -570,6 +572,8 @@ underlying graph, and executes each of the partitioned subgraphs as a function.

// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getFAttr(); }
// Sets the callee from the callable
void setCalleeFromCallable(::mlir::CallInterfaceCallable callee);

// Returns the resolved callee function of this operation.
// Prefer passing in SymbolTableCollection to reduce lookup costs by
Expand Down Expand Up @@ -1009,6 +1013,8 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall",

// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getFAttr(); }
// Sets the callee from the callable.
void setCalleeFromCallable(CallInterfaceCallable callee);

// Returns the resolved callee function of this operation.
// Prefer passing in SymbolTableCollection to reduce lookup costs by
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3214,6 +3214,16 @@ LogicalResult LegacyCallOp::verifySymbolUses(
return success();
}

void LegacyCallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
// Direct call.
if (SymbolRefAttr fAttr = getFAttr()) {
SymbolRefAttr calleeAttr = callee.get<SymbolRefAttr>();
return setFAttr(cast<FlatSymbolRefAttr>(calleeAttr));
}
// Indirect call, callee Value is the first operand.
return setOperand(0, callee.get<Value>());
}

//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
Expand Down Expand Up @@ -563,6 +564,31 @@ LogicalResult TPUPartitionedCallOp::verifySymbolUses(
return VerifyPartitionedCall(*this, symbolTable);
}

template <typename CallOpClass>
static void SetPartitionCalleeFromCallable(CallOpClass op,
mlir::CallInterfaceCallable callee) {
// Direct call.
if (SymbolRefAttr fAttr = op.getFAttr()) {
SymbolRefAttr calleeAttr = callee.get<SymbolRefAttr>();
return op.setFAttr(cast<FlatSymbolRefAttr>(calleeAttr));
}
// Indirect call, callee Value is the first operand.
return op.setOperand(0, callee.get<Value>());
}

void PartitionedCallOp::setCalleeFromCallable(
mlir::CallInterfaceCallable callee) {
return SetPartitionCalleeFromCallable(*this, callee);
}
void StatefulPartitionedCallOp::setCalleeFromCallable(
CallInterfaceCallable callee) {
return SetPartitionCalleeFromCallable(*this, callee);
}
void TPUPartitionedCallOp::setCalleeFromCallable(
mlir::CallInterfaceCallable callee) {
return SetPartitionCalleeFromCallable(*this, callee);
}

//===----------------------------------------------------------------------===//
// PowOp
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ bool TFRType::classof(Type type) {
// Custom op methods
//===----------------------------------------------------------------------===//

void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
// Direct call.
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
auto symRef = callee.get<SymbolRefAttr>();
return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
}
// Indirect call, callee Value is the first operand.
return setOperand(0, callee.get<Value>());
}

LogicalResult ConstantTensorOp::verify() {
ConstantTensorOp op = *this;
auto input_type = op.getArg().getType();
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tfr/ir/tfr_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> {

// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getCalleeAttr(); }
// Sets the callee from the callable
void setCalleeFromCallable(CallInterfaceCallable callee);
}];

let assemblyFormat = [{
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/compiler/xla/python/ifrt/ir/ifrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ mlir::LogicalResult DisassembleOp::verify() {
mlir::CallInterfaceCallable CallOp::getCallableForCallee() {
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
}
void CallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
// Direct call
if ((*this)->getAttrOfType<mlir::SymbolRefAttr>("callee")) {
(*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
}
// Indirect call, callee Value is the first operand.
return setOperand(0, callee.get<mlir::Value>());
}

mlir::Operation::operand_range CallOp::getArgOperands() { return getInputs(); }

Expand Down Expand Up @@ -305,6 +313,15 @@ mlir::LogicalResult CallOp::verify() {
mlir::CallInterfaceCallable CallLoadedExecutableOp::getCallableForCallee() {
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("callee");
}
void CallLoadedExecutableOp::setCalleeFromCallable(
mlir::CallInterfaceCallable callee) {
// Direct call
if ((*this)->getAttrOfType<mlir::SymbolRefAttr>("callee")) {
(*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
}
// Indirect call, callee Value is the first operand.
return setOperand(0, callee.get<mlir::Value>());
}

mlir::Operation::operand_range CallLoadedExecutableOp::getArgOperands() {
return getInputs();
Expand Down
4 changes: 2 additions & 2 deletions third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "14f0776550b5a49e1c42f49a00213f7f3fa047bf"
LLVM_SHA256 = "939e616f9ff00f5e06f471668c5dd1cbf75d2da05665424eada7338d627e336b"
LLVM_COMMIT = "27cf6ba1d7bc623a5dca5c0ae82af98d0cdfc390"
LLVM_SHA256 = "6b0dd4838b73035659414f167a3a4f252277734642245a33016ccc532f2720cc"

tf_http_archive(
name = name,
Expand Down
41 changes: 41 additions & 0 deletions third_party/triton/cl530389221.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
==== triton/include/triton/Dialect/Triton/IR/TritonOps.td#7 - /google/src/cloud/peiming/mlir_a2ab6a5e2b8d4e10ce29b24db7d6ae18c9acbec1_1683576894/triton/include/triton/Dialect/Triton/IR/TritonOps.td ====
# action=edit type=text
--- triton/include/triton/Dialect/Triton/IR/TritonOps.td 2023-04-24 23:33:26.000000000 -0700
+++ triton/include/triton/Dialect/Triton/IR/TritonOps.td 2023-05-08 13:18:46.000000000 -0700
@@ -591,6 +591,8 @@
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+ /// Set the callee from the callable
+ void setCalleeFromCallable(CallInterfaceCallable);
}];

let assemblyFormat = [{
==== triton/lib/Dialect/Triton/IR/Ops.cpp#8 - /google/src/cloud/peiming/mlir_a2ab6a5e2b8d4e10ce29b24db7d6ae18c9acbec1_1683576894/triton/lib/Dialect/Triton/IR/Ops.cpp ====
# action=edit type=text
--- triton/lib/Dialect/Triton/IR/Ops.cpp 2023-05-02 04:45:24.000000000 -0700
+++ triton/lib/Dialect/Triton/IR/Ops.cpp 2023-05-08 13:48:19.000000000 -0700
@@ -4,6 +4,7 @@
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

@@ -683,6 +684,15 @@

return success();
}
+
+void triton::CallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+ // Direct call
+ if ((*this)->getAttrOfType<mlir::SymbolRefAttr>("callee")) {
+ (*this)->setAttr("callee", callee.get<mlir::SymbolRefAttr>());
+ }
+ // Indirect call, callee Value is the first operand.
+ return setOperand(0, callee.get<mlir::Value>());
+}

// -- ReturnOp --
LogicalResult triton::ReturnOp::verify() {
1 change: 1 addition & 0 deletions third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ def repo():
patch_file = [
"//third_party/triton:cl526173620.patch",
"//third_party/triton:cl528701873.patch",
"//third_party/triton:cl530389221.patch",
],
)

0 comments on commit c3bb0e5

Please sign in to comment.