XRT branch cherry-pick 07/10 (#5293)
* Skip calling as_strided in empty_strided_symint if the input has dynamic dimensions. (#5239)

* Skip calling as_strided in empty_strided_symint.

* only return empty_symint conditionally.

* add a comment

* Add XRT nightly builds (#5261)

* Add XRT nightly builds

* remove space

* Add ToString method for both PjrtData and PjrtShardedData (#5265)

* Add ToString method for both PjrtData and PjrtShardedData

* On CPU the same config will become replicated; don't check the actual op sharding type

* fix xrt tostring

* Update Sharded graph HLO dumping (#5266)

* Disable Bazel remote cache for forked PR (#5259)

* disable bazel remote cache if gcloud key is empty

* remove remote cache from setup.py

* experiment with debug msg

* fix flag

* add more logs

* skip remote cache if the credential file is empty

* add comment

* add logs

* add check in test and coverage script

* fix condition in coverage test

* advance branch pr

* allow remote cache if the gcloud key file isn't specified explicitly

* remove dummy comment

* Suppress debug symbols in OpenXLA code (#5269)

* [SPMD] Sharding n-d tensor on (n+1)-d Mesh (#5268)

* Make TPU detection more robust (#5271)

* Clean bazel stuff on distutils clean. (#5274)

* Clean bazel stuff on distutils clean

* Fix python formatting

* fix conflict

* Fix the error when export_torch_model is given a non-tensor (#5277)

However, the generated StableHLO graph still hardcodes the non-tensor value.
This is not correct and will be fixed later.

* Disable test_simple_model_with_different_input_shape since it is currently broken by pytorch (#5282)

* Always do build_ext in python setup.py develop (#5273)

Bazel should figure out whether _XLAC.so is up to date
and trigger a rebuild if any cpp files changed.

* Remove or improve several hardcoded TPU test conditions (#5272)

* Remove or improve several hardcoded TPU test conditions

* Fix test condition

* Add `runtime.host_index` (#5283)
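A minimal usage sketch for the new call (single-host assumption; `xr.device_type()` and `xr.host_index()` are the functions exercised by the new test in test/pjrt/test_runtime.py):

```python
# Query the runtime for host-level placement information.
import torch_xla.runtime as xr

print(xr.device_type())  # e.g. 'TPU' or 'CPU', depending on PJRT_DEVICE
print(xr.host_index())   # 0 on a single host; the host's index in a multi-host setup
```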

* Make it an error if calling sizes() on a dynamic tensor. (#4998)

* Err if calling sizes() on dynamic tensor

* try to set has_symbolic_sizes_strides_

* resolve merge conflict

* enable CONTINUE_ON_ERROR

* fixed the python test test_SizeEq_should_not_compile_for_identical_symints

* fix test_index_types

* set CONTINUE_ON_ERROR to true

* remove some unwanted code.

* add a print

* directly set has_symbolic_sizes_strides_ = true

* make some fixes.

* fix empty_strided_symint

* ran linter

* change error type in the test.

* fix comments

* ran linter

* Fix the error where mark_step does not materialize tensors on SPMD:0 (#5281)

* Fix the error where mark_step does not materialize tensors on SPMD:0

* typo

* fix test_non_tensor_scalar

* Disable torch._dynamo.config.automatic_dynamic_shapes (#5285)

* Set torch._dynamo.config.automatic_dynamic_shapes to False

* Enable DynamoInferenceBasicTest.test_simple_model_with_different_input_shape
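A hedged sketch of what the flag changes, runnable with plain PyTorch (the XLA dynamo bridge wiring itself is not shown):

```python
# With automatic_dynamic_shapes disabled, a new input shape triggers a fresh
# static-shape compilation instead of a symbolic-shape recompilation.
import torch
import torch._dynamo as dynamo

dynamo.config.automatic_dynamic_shapes = False  # mirrors the setting in the XLA dynamo bridge

@torch.compile
def f(x):
  return x * 2 + 1

f(torch.ones(4))
f(torch.ones(8))  # recompiled with static shapes rather than being marked dynamic
```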

* [Traceable Collective] Hide token for all_gather (#5232)

Summary:
This pull request does the following:
1. It hides the token for all_gather.
2. It folds the out-of-place all_gather into the regular all_gather.
3. It fixes an issue in the last all_reduce_in_place PR, which forgot to set the token.

Test Plan:
PJRT_DEVICE=TPU python test/test_mp_all_gather.py
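A hedged multi-process sketch in the spirit of that test plan (assumes torch_xla on a PjRt device; it is not a copy of test/test_mp_all_gather.py):

```python
# Each replica contributes its ordinal; all_gather concatenates along dim 0.
# Token plumbing now happens inside the lowering rather than at the call site.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  value = torch.tensor([xm.get_ordinal()], dtype=torch.float32, device=device)
  result = xm.all_gather(value, dim=0)
  xm.mark_step()
  print(result.cpu().tolist())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```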

* Lower squeeze.dims (#5286)
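A short hedged example of the overload this lowers (requires PyTorch >= 2.0 for the multi-dim squeeze signature plus a torch_xla device):

```python
# Squeezing a tuple of dims should route to squeeze_copy.dims under torch_xla's
# functionalization, the op added to xla_native_functions.yaml below.
import torch
import torch_xla.core.xla_model as xm

x = torch.rand(2, 1, 3, 1, device=xm.xla_device())
y = torch.squeeze(x, (1, 3))  # drops the size-1 dims, giving shape (2, 3)
print(y.shape)
```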

* avoid copy proto in PrepareOutputShardingPropagation (#5287)

* Revert "Suppress debug symbols in OpenXLA code (#5269)"

This reverts commit 3967d7b.

* Revert "fix conflict"

This reverts commit e91ad3a.

---------

Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Co-authored-by: Will Cromar <wcromar@google.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: stgpetrovic <stgpetrovic@gmail.com>
Co-authored-by: Mohit Khatwani <118776932+khatwanimohit@users.noreply.github.com>
Co-authored-by: qihqi <hanq@google.com>
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Co-authored-by: Jiewen Tan <jwtan@google.com>
Co-authored-by: Baole Ai <baoleai01@gmail.com>
10 people committed Jul 12, 2023
1 parent 3a8db63 commit 4b1742e
Showing 35 changed files with 612 additions and 176 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -311,6 +311,7 @@ supported:
- sqrt
- squeeze_copy
- squeeze_copy.dim
- squeeze_copy.dims
- stack
- std
- std.correction
12 changes: 12 additions & 0 deletions infra/tpu-pytorch-releases/artifacts.auto.tfvars
@@ -21,6 +21,18 @@ nightly_builds = [
}
]

# TODO: Remove this after the 2.1 release
xrt_nightly_builds = [
{
accelerator = "tpu"
python_version = "3.10"
},
{
accelerator = "cuda"
cuda_version = "12.0"
},
]

# Built on push to specific tag.
versioned_builds = [
{
70 changes: 70 additions & 0 deletions infra/tpu-pytorch-releases/artifacts_builds.tf
@@ -15,6 +15,20 @@ variable "nightly_builds" {
default = []
}

// TODO: Remove this after the 2.1 release
variable "xrt_nightly_builds" {
type = list(
object({
accelerator = string
cuda_version = optional(string, "11.8")
python_version = optional(string, "3.8")
arch = optional(string, "amd64")
})
)

default = []
}

variable "versioned_builds" {
type = list(
object({
@@ -39,6 +53,15 @@ locals {
) => b
}

// TODO: Remove this after the 2.1 release
xrt_nightly_builds_dict = {
for b in var.xrt_nightly_builds :
format("%s_%s",
b.python_version,
b.accelerator == "tpu" ? "tpuvm" : format("cuda_%s", b.cuda_version)
) => b
}

versioned_builds_dict = {
for b in var.versioned_builds :
format("r%s_%s_%s",
@@ -95,6 +118,53 @@ module "nightly_builds" {
docker_repo_url = module.docker_registry.url
}

// TODO: Remove this after the 2.1 release
module "xrt_nightly_builds" {
source = "../terraform_modules/xla_docker_build"
for_each = local.xrt_nightly_builds_dict

ansible_vars = merge(each.value, {
package_version = var.nightly_package_version
nightly_release = true
pytorch_git_rev = "main"
xla_git_rev = "$COMMIT_SHA"
})

trigger_on_schedule = { schedule = "0 0 * * *", branch = "xrt" }

trigger_name = "nightly-xrt-${replace(each.key, "/[_.]/", "-")}"
image_name = "xla"
image_tags = [
"nightly_xrt_${each.key}",
# Append _YYYYMMDD suffix to nightly image name.
"nightly_xrt_${each.key}_$(date +%Y%m%d)",
]

description = join(" ", [
"Builds nightly xla:nightly_${each.key}' ${
each.value.accelerator == "tpu"
? "TPU"
: format("CUDA %s", each.value.cuda_version)
} docker image and corresponding wheels for PyTorch/XLA.",
"Trigger managed by Terraform setup in",
"infra/tpu-pytorch-releases/artifacts_builds.tf."
])

wheels_dest = "${module.releases_storage_bucket.url}/wheels/xrt/${
each.value.accelerator == "tpu"
? "tpuvm"
: "cuda/${each.value.cuda_version}"
}"
wheels_srcs = ["/dist/*.whl"]
build_args = {
python_version = each.value.python_version
}

scheduler_account_email = module.scheduler_account.email
worker_pool_id = module.worker_pool.id
docker_repo_url = module.docker_registry.url
}

module "versioned_builds" {
source = "../terraform_modules/xla_docker_build"
for_each = local.versioned_builds_dict
13 changes: 11 additions & 2 deletions scripts/run_bazel_coverage.sh
@@ -2,11 +2,20 @@
export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"
export XRT_WORKERS="localservice:0;grpc://localhost:40934"
export XLA_EXPERIMENTAL="nonzero:masked_select"
bazel coverage --config=remote_cache --remote_default_exec_properties=cache-silo-key=cache-silo-coverage //...

BAZEL_REMOTE_CACHE_CONFIG="--config=remote_cache --remote_default_exec_properties=cache-silo-key=cache-silo-coverage"
if [ ! -z "$GCLOUD_SERVICE_KEY_FILE" ]; then
file_size=$(stat -c%s "$GCLOUD_SERVICE_KEY_FILE")
if [ "$file_size" -le 1 ]; then
BAZEL_REMOTE_CACHE_CONFIG=""
fi
fi

bazel coverage $BAZEL_REMOTE_CACHE_CONFIG //...
cp "$(bazel info output_path)/_coverage/_coverage_report.dat" /tmp/cov_xrt.dat

export PJRT_DEVICE="CPU"
bazel coverage --config=remote_cache --remote_default_exec_properties=cache-silo-key=cache-silo-coverage //test/...
bazel coverage $BAZEL_REMOTE_CACHE_CONFIG //test/...
cp "$(bazel info output_path)/_coverage/_coverage_report.dat" /tmp/cov_pjrt.dat

# requires `apt-get install lcov`
28 changes: 22 additions & 6 deletions setup.py
@@ -49,10 +49,10 @@
# CXX_ABI=""
# value for cxx_abi flag; if empty, it is inferred from `torch._C`.
#

from __future__ import print_function

from setuptools import setup, find_packages, distutils, Extension, command
from setuptools.command import develop
from torch.utils.cpp_extension import BuildExtension
import posixpath
import contextlib
@@ -164,6 +164,9 @@ def maybe_bundle_libtpu(base_dir):

class Clean(distutils.command.clean.clean):

def bazel_clean_(self):
self.spawn(['bazel', 'clean', '--expunge'])

def run(self):
import glob
import re
@@ -184,6 +187,8 @@ def run(self):
except OSError:
shutil.rmtree(filename, ignore_errors=True)

self.execute(self.bazel_clean_, (), msg="Cleaning bazel outputs")

# It's an old-style class in Python 2.7...
distutils.command.clean.clean.run(self)

@@ -257,12 +262,15 @@ def bazel_build(self, ext):
bazel_argv.append('--config=disable_xrt')

# Remote cache authentication.
if _check_env_flag('BAZEL_REMOTE_CACHE'):
bazel_argv.append('--config=remote_cache')

if GCLOUD_KEY_FILE:
bazel_argv.append('--google_credentials=%s' % GCLOUD_KEY_FILE)
if not _check_env_flag('BAZEL_REMOTE_CACHE'):
# Temporary workaround to allow PRs from forked repo to run CI. See details at (#5259).
# TODO: Remove the check once self-hosted GHA workers are available to CPU/GPU CI.
gcloud_key_file_size = os.path.getsize(GCLOUD_KEY_FILE)
if gcloud_key_file_size > 1:
bazel_argv.append('--google_credentials=%s' % GCLOUD_KEY_FILE)
bazel_argv.append('--config=remote_cache')
else:
if _check_env_flag('BAZEL_REMOTE_CACHE'):
bazel_argv.append('--config=remote_cache')
if CACHE_SILO_NAME:
bazel_argv.append('--remote_default_exec_properties=cache-silo-key=%s' %
@@ -298,6 +306,13 @@ def bazel_build(self, ext):
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)


class Develop(develop.develop):

def run(self):
self.run_command("build_ext")
super().run()


setup(
name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'),
version=version,
@@ -329,4 +344,5 @@ def bazel_build(self, ext):
cmdclass={
'build_ext': BuildBazelExtension,
'clean': Clean,
'develop': Develop,
})
7 changes: 7 additions & 0 deletions test/cpp/run_tests.sh
@@ -61,6 +61,13 @@ if [[ -n "${CXX_ABI}" ]]; then
EXTRA_FLAGS="${EXTRA_FLAGS} --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI}"
fi

# Override BAZEL_REMOTE_CACHE if GCLOUD_SERVICE_KEY_FILE is at most 1 byte
if [ ! -z "$GCLOUD_SERVICE_KEY_FILE" ]; then
file_size=$(stat -c%s "$GCLOUD_SERVICE_KEY_FILE")
if [ "$file_size" -le 1 ]; then
BAZEL_REMOTE_CACHE=0
fi
fi
# Handle remote builds and remote cache. Use a CI-private cache silo to avoid cache pollution.
if [[ "$BAZEL_REMOTE_CACHE" == "1" ]]; then
EXTRA_FLAGS="$EXTRA_FLAGS --config=remote_cache"
24 changes: 24 additions & 0 deletions test/cpp/test_aten_xla_tensor_4.cpp
@@ -300,6 +300,18 @@ TEST_F(AtenXlaTensorTest, TestSymSizes) {
});
}

TEST_F(AtenXlaTensorTest, TestGettingSizeOnDynamicTensor) {
// Make sure that calling tensor.sizes() in C++ on a dynamic tensor fails.
ForEachDevice([&](const torch::Device& device) {
torch::Tensor b = torch::tensor({{0.0, 1.0}, {0.0, 0.0}},
torch::TensorOptions(torch::kFloat));
torch::Tensor xla_b = CopyToDevice(b, device);
torch::Tensor xla_nonzero = torch::nonzero(xla_b);
EXPECT_THROW(xla_nonzero.sizes(), std::runtime_error);
EXPECT_NO_THROW(xla_nonzero.sym_sizes());
});
}

TEST_F(AtenXlaTensorTest, TestMul) {
torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
@@ -932,6 +944,18 @@ TEST_F(AtenXlaTensorTest, TestSqueezeOne) {
}
}

TEST_F(AtenXlaTensorTest, TestSqueezeMultipleDims) {
torch::Tensor input =
torch::rand({2, 1, 3, 1}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> dims = {1, 2, 3};
torch::Tensor output = torch::squeeze(input, dims);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::squeeze(xla_input, dims);
AllClose(output, xla_output);
});
}

TEST_F(AtenXlaTensorTest, TestSqueezeOneInPlace) {
int rank = 4;
for (int dim = -rank; dim < rank; ++dim) {
5 changes: 4 additions & 1 deletion test/ds/test_dynamic_shapes.py
@@ -598,7 +598,10 @@ def test_SizeEq_should_not_compile_for_identical_symints(self):
dyn_size = t2.shape[0]
self.assertEqual(dyn_size, dyn_size)
# Without the code change, met.metric_data('CompileTime')[0] returns 1.
self.assertIsNone(met.metric_data('CompileTime'))
# self.assertIsNone(met.metric_data('CompileTime'))
# TODO(ds): Uncomment the line above after we implement 0/1 specialization.
# The extra compilation comes from the call `set_sizes_and_strides` in XLATensorImpl::XLATensorImpl when we compare a SymInt with 0.
self.assertEqual(met.metric_data('CompileTime')[0], 1)


if __name__ == '__main__':
8 changes: 8 additions & 0 deletions test/pjrt/test_internal_tpu.py
@@ -84,6 +84,14 @@ def test_task_id(self, task_id, expected):

self.assertEqual(i, expected)

@parameterized.parameters(('0', 0), ('1', 1), ('15', 15))
def test_worker_id(self, worker_id, expected):
with mock.patch.object(
tpu, 'get_tpu_env', return_value={xenv.WORKER_ID: worker_id}):
i = tpu.worker_id()

self.assertEqual(i, expected)

@parameterized.named_parameters(
('v4',
textwrap.dedent("""
3 changes: 3 additions & 0 deletions test/pjrt/test_runtime.py
@@ -85,6 +85,9 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
else:
self.assertIsNone(xr.device_type())

def test_host_index(self):
self.assertEqual(xr.host_index(), 0)


if __name__ == '__main__':
absltest.main()