Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ function run_torch_xla_cpp_tests() {
"test_aten_xla_tensor_2"
"test_aten_xla_tensor_3"
"test_aten_xla_tensor_4"
"pjrt_computation_client_test"
"ifrt_computation_client_test")
"pjrt_computation_client_test")
# Disable IFRT test as it currently crashes
#"ifrt_computation_client_test")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it crash? leave a note?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Saw it in the PR comment, //test/cpp:test_xla_sharding is the failed test right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that's the thing I don't know, it crashed right away for any XLA graph execution. I decided to not digging too much giving no one is using ifrt right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ifrt_computation_client_test is the only test that we have for ifrt and it is crashing. It also crashed on any ifrt execution.

test_names2=("test_aten_xla_tensor_5"
"test_aten_xla_tensor_6"
"test_ir"
Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4'
xla_hash = 'eef7ee50d0980848436f0b4f402cec8c5bf86f21'

http_archive(
name = "xla",
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240913'
_date = '20241015'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl'
_jax_version = f'0.4.33.dev{_date}'
_jax_version = f'0.4.35.dev{_date}'


def _get_build_mode():
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ptxla_cc_test(
"//torch_xla/csrc:aten_cuda_functions",
"@com_google_googletest//:gtest_main",
"@xla//xla:xla_data_proto_cc",
"@tsl//tsl/profiler/utils:session_manager",
"@xla//xla/tsl/profiler/utils:session_manager",
],
)

Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ ptxla_cc_library(
"@xla//xla/client/lib:slicing",
"@xla//xla/client/lib:sorting",
"@xla//xla/client/lib:svd",
"@xla//xla/hlo/pass:hlo_pass_pipeline",
"@xla//xla/stream_executor:dnn",
"@tsl//tsl/platform:errors",
"@tsl//tsl/profiler/lib:traceme",
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@
#include "torch_xla/csrc/xla_sharding_util.h"
#include "tsl/platform/env.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/python/profiler/internal/traceme_wrapper.h"
#include "xla/service/hlo_parser.h"

namespace torch_xla {
namespace {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "torch_xla/csrc/shape_helper.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/qr.h"
#include "xla/hlo/builder/lib/qr.h"
#include "xla/shape_util.h"
#include "xla/util.h"

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/embedding_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "torch_xla/csrc/xla_lower_util.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/logdet.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/logdet.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/ops_lower_fn.cpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/logdet.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/hlo/builder/lib/logdet.h"

namespace torch_xla {
torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/logdet.h"
#include "xla/hlo/builder/lib/logdet.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/qr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "torch_xla/csrc/lowering_context.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/qr.h"
#include "xla/hlo/builder/lib/qr.h"

namespace torch_xla {
namespace {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/randperm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/loops.h"
#include "xla/hlo/builder/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
#include "torch_xla/csrc/xla_lower_util.h"
#include "xla/client/lib/arithmetic.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/pooling.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/loops.h"

namespace torch_xla {
namespace {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/prng.h"
#include "xla/hlo/builder/lib/prng.h"

namespace torch_xla {
namespace {
Expand Down
6 changes: 1 addition & 5 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,11 +560,7 @@ IfrtComputationClient::ExecuteReplicated(
counter.Wait();
}

xla::ExecuteOptions execute_options;
execute_options.untuple_result = options.explode_tuple;
execute_options.strict_shape_checking = true;
// TODO(yeounoh) currently only support single-slice execution
execute_options.multi_slice_config = nullptr;
xla::ifrt::ExecuteOptions execute_options;

TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
<< spmd_device_str;
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "xla/python/pjrt_ifrt/pjrt_array.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
#include "xla/python/pjrt_ifrt/xla_compiler.h"
#include "xla/shape.h"

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name,
args.function_name = fn_name.c_str();
args.function_name_size = fn_name.size();
args.api_version = 0;
args.custom_call_function = function_ptr;
args.handler_execute = function_ptr;
PJRT_Error* error =
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(&args);
if (error) {
Expand Down
4 changes: 3 additions & 1 deletion torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ ComputationClient* GetComputationClient() {

std::unique_ptr<ComputationClient> client;

static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
// Disable IFRT right now as it currently crashes.
// static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
static bool use_ifrt = false;
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
if (use_ifrt) {
client = std::make_unique<IfrtComputationClient>();
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#include "xla/client/lib/arithmetic.h"
#include "xla/client/lib/comparators.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/loops.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/slicing.h"
#include "xla/hlo/builder/lib/loops.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/dnn.h"
#include "xla/util.h"
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_op_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/tensor_util.h"
#include "xla/client/lib/logdet.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/pooling.h"
#include "xla/hlo/builder/lib/logdet.h"
#include "xla/primitive_util.h"
#include "xla/shape_util.h"

Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
#include "tsl/profiler/lib/traceme.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/protobuf_util.h"
#include "xla/service/hlo_parser.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/hlo_verifier.h"
#include "xla/service/sharding_propagation.h"
#include "xla/service/spmd/spmd_partitioner.h"
Expand Down
Loading