Skip to content

Commit

Permalink
Small fix to avoid crashes.
Browse files Browse the repository at this point in the history
Reverts e5c11f3

PiperOrigin-RevId: 621211025
  • Loading branch information
tensorflower-gardener committed Apr 2, 2024
1 parent 5bda2ca commit 487f956
Show file tree
Hide file tree
Showing 49 changed files with 655 additions and 522 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,25 @@ FailureOr<std::string> QuantizationMethodToTextProto(const Method& method) {
// TODO: b/307620778 - Support more advanced selective quantization methods.
LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec,
ModuleOp module_op) {
func::FuncOp main_func = FindMainFuncOp(module_op);
if (!main_func) return failure();

const Method& quantization_method = spec.method();

FailureOr<std::string> quantization_method_txtpb =
QuantizationMethodToTextProto(quantization_method);
if (failed(quantization_method_txtpb)) return failure();

const FunctionNameMatcher matcher(spec.matcher().function_name());
for (auto xla_call_module_op : main_func.getOps<TF::XlaCallModuleOp>()) {
if (!matcher.Match(xla_call_module_op)) continue;

// Set the text representation of `Method` to matched `TF::XlaCallModuleOp`.
xla_call_module_op->setAttr(
kQuantizationMethodAttr,
StringAttr::get(module_op.getContext(),
std::move(*quantization_method_txtpb)));
// Iterate over all XlaCallModuleOp in all FuncOps.
for (auto func : module_op.getOps<func::FuncOp>()) {
for (auto xla_call_module_op : func.getOps<TF::XlaCallModuleOp>()) {
if (!matcher.Match(xla_call_module_op)) continue;

// Set the text representation of `Method` to matched
// `TF::XlaCallModuleOp`.
xla_call_module_op->setAttr(
kQuantizationMethodAttr,
StringAttr::get(module_op.getContext(),
std::move(*quantization_method_txtpb)));
}
}
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,35 @@ func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _original_entry_function
// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY-NOT: _quantization_method
// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _tfl_quant_trait = "fully_quantizable"

// -----

// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-all" \
// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-ALL

// STATIC-RANGE-PTQ-TO-ALL-LABEL: @some_func
func.func @some_func(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> {
%0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32>
%1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32>
return %1 : tensor<1x1x64xf32>
}
// Tests that XlaCallModuleOp in non-main function has attributes set correctly.

// STATIC-RANGE-PTQ-TO-ALL: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00>
// STATIC-RANGE-PTQ-TO-ALL: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]])

// Check that the `_quantization_method` attribute contains the quantization
// method in textproto format, enabling static-range PTQ.
// STATIC-RANGE-PTQ-TO-ALL-SAME: _entry_function = @composite_dot_general_fn_1
// STATIC-RANGE-PTQ-TO-ALL-SAME: _original_entry_function
// STATIC-RANGE-PTQ-TO-ALL-SAME: _quantization_method = "static_range_ptq { }"
// STATIC-RANGE-PTQ-TO-ALL-SAME: _tfl_quant_trait = "fully_quantizable"

// STATIC-RANGE-PTQ-TO-ALL: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32>
// STATIC-RANGE-PTQ-TO-ALL: }

// STATIC-RANGE-PTQ-TO-ALL-LABEL: private @composite_dot_general_fn_1
// STATIC-RANGE-PTQ-TO-ALL-SAME: tf_quant.composite_function
// STATIC-RANGE-PTQ-TO-ALL: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1
// STATIC-RANGE-PTQ-TO-ALL: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32>
// STATIC-RANGE-PTQ-TO-ALL: }
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ class HloProtoBufferWrapper {
// to obtain the buffer allocation index ourselves.
if (heap_simulator_traces[i].events().empty()) continue;
int logical_buffer_id = heap_simulator_traces[i].events(0).buffer_id();
if (!id_to_logical_buffer_.contains(logical_buffer_id)) continue;
auto* logical_buffer = id_to_logical_buffer_[logical_buffer_id].get();
auto buffer_allocation_index = logical_buffer->buffer_allocation.index();
id_to_buffer_allocation_[buffer_allocation_index]
Expand Down
47 changes: 0 additions & 47 deletions third_party/triton/cl609333259.patch

This file was deleted.

1 change: 0 additions & 1 deletion third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@ def repo():
"//third_party/triton:cl617812302.patch",
"//third_party/triton:cl619146327.patch",
"//third_party/triton:cl619443019.patch",
"//third_party/triton:cl609333259.patch",
],
)
47 changes: 0 additions & 47 deletions third_party/xla/third_party/triton/cl609333259.patch

This file was deleted.

1 change: 0 additions & 1 deletion third_party/xla/third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@ def repo():
"//third_party/triton:cl617812302.patch",
"//third_party/triton:cl619146327.patch",
"//third_party/triton:cl619443019.patch",
"//third_party/triton:cl609333259.patch",
],
)
5 changes: 4 additions & 1 deletion third_party/xla/xla/ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ struct ArgDecoding<BufferBase> {

auto* buf = reinterpret_cast<XLA_FFI_Buffer*>(arg);

size_t size_bytes = primitive_util::ByteWidth(PrimitiveType(buf->dtype));
for (int64_t i = 0; i < buf->rank; ++i) size_bytes *= buf->dims[i];

BufferBase buffer;
buffer.dtype = PrimitiveType(buf->dtype);
buffer.data = se::DeviceMemoryBase(buf->data);
buffer.data = se::DeviceMemoryBase(buf->data, size_bytes);
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
}
Expand Down
67 changes: 61 additions & 6 deletions third_party/xla/xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ cc_library(
],
compatible_with = get_compatible_with_portable(),
deps = [
":device_proto_cc",
":dtype_proto_cc",
":serdes",
":types_proto_cc",
":shape_proto_cc",
":sharding_proto_cc",
"//xla:status",
"//xla:statusor",
"//xla:util",
Expand All @@ -87,6 +90,7 @@ cc_library(
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
Expand Down Expand Up @@ -160,7 +164,7 @@ xla_cc_test(
srcs = ["shape_test.cc"],
deps = [
":ifrt",
":types_proto_cc",
":shape_proto_cc",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:status_matchers",
Expand Down Expand Up @@ -366,8 +370,11 @@ cc_library(
":ifrt",
":serdes",
":sharding_proto_cc",
"//xla:statusor",
":sharding_serdes_proto_cc",
"//xla:util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:statusor",
],
Expand All @@ -385,17 +392,65 @@ xla_cc_test(
":sharding_test_util",
"@com_google_absl//absl/functional:bind_front",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:statusor",
],
)

tf_proto_library(
name = "types_proto",
srcs = ["types.proto"],
name = "device_proto",
srcs = ["device.proto"],
)

xla_cc_test(
name = "device_test",
size = "small",
srcs = ["device_test.cc"],
deps = [
":device_proto_cc",
":ifrt",
":sharding_test_util",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:statusor",
],
)

tf_proto_library(
name = "dtype_proto",
srcs = ["dtype.proto"],
)

xla_cc_test(
name = "dtype_test",
size = "small",
srcs = ["dtype_test.cc"],
deps = [
":dtype_proto_cc",
":ifrt",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)

tf_proto_library(
name = "shape_proto",
srcs = ["shape.proto"],
)

tf_proto_library(
name = "sharding_proto",
srcs = ["sharding.proto"],
protodeps = [":types_proto"],
protodeps = [":serdes_proto"],
)

tf_proto_library(
name = "sharding_serdes_proto",
srcs = ["sharding_serdes.proto"],
protodeps = [
":device_proto",
":dtype_proto",
":shape_proto",
],
)
17 changes: 16 additions & 1 deletion third_party/xla/xla/python/ifrt/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ limitations under the License.
#include "xla/python/ifrt/device.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "xla/python/ifrt/types.pb.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/python/ifrt/device.pb.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -52,6 +57,16 @@ DeviceListProto DeviceList::ToProto() const {
return proto;
}

std::string DeviceList::DebugString() const {
return absl::StrCat("[",
absl::StrJoin(devices(), ",",
[](std::string* out, Device* device) {
absl::StrAppend(out,
device->DebugString());
}),
"]");
}

std::vector<int> GetDeviceIds(DeviceList device_list) {
std::vector<int> ids;
ids.reserve(device_list.devices().size());
Expand Down
6 changes: 5 additions & 1 deletion third_party/xla/xla/python/ifrt/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ limitations under the License.
#define XLA_PYTHON_IFRT_DEVICE_H_

#include <memory>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/python/ifrt/types.pb.h"
#include "xla/python/ifrt/device.pb.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -89,6 +91,8 @@ class DeviceList {
auto end() const { return state().devices.end(); }
auto cend() const { return state().devices.cend(); }

std::string DebugString() const;

private:
// Internal state that may be shared across `DeviceList` instances.
struct State {
Expand Down

0 comments on commit 487f956

Please sign in to comment.