diff --git a/.ci/docker/common/install_cache.sh b/.ci/docker/common/install_cache.sh index 6b48f11e438b..d1aa2ff48a20 100644 --- a/.ci/docker/common/install_cache.sh +++ b/.ci/docker/common/install_cache.sh @@ -36,14 +36,11 @@ if [ -n "$ROCM_VERSION" ]; then curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache else ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') - case "$ID" in - ubuntu) - install_ubuntu - ;; - *) - install_binary - ;; - esac + # TODO: Install the pre-built binary from S3 as building from source + # https://github.com/pytorch/sccache has started failing mysteriously + # in which sccache server couldn't start with the following error: + # sccache: error: Invalid argument (os error 22) + install_binary fi chmod a+x /opt/cache/bin/sccache diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 2196c92fe99a..f669963ba0de 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -244,9 +244,9 @@ unittest-xml-reporting<=3.2.0,>=2.0.0 #Pinned versions: #test that import: -lintrunner==0.9.2 -#Description: all about linters -#Pinned versions: 0.9.2 +lintrunner==0.10.7 +#Description: all about linters! +#Pinned versions: 0.10.7 #test that import: rockset==1.0.3 diff --git a/.ci/pytorch/build-asan.sh b/.ci/pytorch/build-asan.sh index 91953c322f22..51f02f137f20 100755 --- a/.ci/pytorch/build-asan.sh +++ b/.ci/pytorch/build-asan.sh @@ -14,17 +14,13 @@ clang --version python tools/stats/export_test_times.py -# detect_leaks=0: Python is very leaky, so we need suppress it -# symbolize=1: Gives us much better errors when things go wrong -export ASAN_OPTIONS=detect_leaks=0:detect_stack_use_after_return=1:symbolize=1:detect_odr_violation=0 if [ -n "$(which conda)" ]; then export CMAKE_PREFIX_PATH=/opt/conda fi -# TODO: Make the ASAN flags a centralized env var and unify with USE_ASAN option CC="clang" CXX="clang++" LDSHARED="clang --shared" \ - CFLAGS="-fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all -fsanitize-address-use-after-scope -shared-libasan" \ USE_ASAN=1 USE_CUDA=0 USE_MKLDNN=0 \ + UBSAN_FLAGS="-fno-sanitize-recover=all" \ python setup.py bdist_wheel pip_install_whl "$(echo dist/*.whl)" diff --git a/.ci/pytorch/build-tsan.sh b/.ci/pytorch/build-tsan.sh index e10edb310d81..9e532caca321 100755 --- a/.ci/pytorch/build-tsan.sh +++ b/.ci/pytorch/build-tsan.sh @@ -19,7 +19,6 @@ if [ -n "$(which conda)" ]; then fi CC="clang" CXX="clang++" LDSHARED="clang --shared" \ - CFLAGS="-fsanitize=thread" \ USE_TSAN=1 USE_CUDA=0 USE_MKLDNN=0 \ python setup.py bdist_wheel pip_install_whl "$(echo dist/*.whl)" diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index f8328bc3e828..a40f8973578f 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -149,6 +149,14 @@ function clone_pytorch_xla() { fi } +function install_matplotlib() { + pip_install matplotlib +} + +function install_tabulate() { + pip_install tabulate +} + function setup_torchdeploy_deps(){ conda install -y -n "py_${ANACONDA_PYTHON_VERSION}" "libpython-static=${ANACONDA_PYTHON_VERSION}" local CC diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 7e93da195078..75e514f827f7 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -126,7 +126,7 @@ fi # if you're not careful. 
Check this if you made some changes and the # ASAN test is not working if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then - export ASAN_OPTIONS=detect_leaks=0:symbolize=1:detect_stack_use_after_return=1:strict_init_order=true:detect_odr_violation=0 + export ASAN_OPTIONS=detect_leaks=0:symbolize=1:detect_stack_use_after_return=1:strict_init_order=true:detect_odr_violation=1:detect_container_overflow=0 export UBSAN_OPTIONS=print_stacktrace=1 export PYTORCH_TEST_WITH_ASAN=1 export PYTORCH_TEST_WITH_UBSAN=1 @@ -255,6 +255,30 @@ test_inductor() { python test/run_test.py --include inductor/test_torchinductor inductor/test_torchinductor_opinfo --verbose } +# "Global" flags for inductor benchmarking controlled by TEST_CONFIG +# For example 'dynamic_aot_eager_torchbench' TEST_CONFIG means we run +# the benchmark script with '--dynamic-shapes --backend aot_eager --device cuda' +# The matrix of test options is specified in .github/workflows/periodic.yml +# and .github/workflows/inductor.yml +DYNAMO_BENCHMARK_FLAGS=() + +if [[ "${TEST_CONFIG}" == *aot_eager* ]]; then + DYNAMO_BENCHMARK_FLAGS+=(--backend aot_eager) +elif [[ "${TEST_CONFIG}" == *inductor* ]]; then + DYNAMO_BENCHMARK_FLAGS+=(--inductor) +fi + +if [[ "${TEST_CONFIG}" == *dynamic* ]]; then + # TODO: make specialize_int = False default, then remove this + DYNAMO_BENCHMARK_FLAGS+=(--dynamic-shapes --unspecialize-int) +fi + +if [[ "${TEST_CONFIG}" == *cpu_accuracy* ]]; then + DYNAMO_BENCHMARK_FLAGS+=(--device cpu) +else + DYNAMO_BENCHMARK_FLAGS+=(--device cuda) +fi + test_single_dynamo_benchmark() { # Usage: test_single_dynamo_benchmark inductor_inference huggingface 0 --args-for-script @@ -277,143 +301,66 @@ test_single_dynamo_benchmark() { partition_flags=( --total-partitions 2 --partition-id "$shard_id" ) fi - # Feel free to remove --device cuda if you ever decide to need to - # test CPU as well in CI - python "benchmarks/dynamo/$suite.py" \ - --ci --accuracy --timing --explain \ - "$@" "${partition_flags[@]}" \ - --output "$TEST_REPORTS_DIR/${name}_${suite}.csv" - python benchmarks/dynamo/check_csv.py \ - -f "$TEST_REPORTS_DIR/${name}_${suite}.csv" -} - -test_aot_eager_benchmark() { - # Usage: test_dynamo_benchmark huggingface 0 - - local exit_status=0 - - # Check inference with --float32 - test_single_dynamo_benchmark "aot_eager_inference" "$@" --backend aot_eager --device cuda || exit_status=$? - - # Check training with --amp - test_single_dynamo_benchmark "aot_eager_training" "$@" --backend aot_eager --device cuda --training --amp || exit_status=$? 
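The DYNAMO_BENCHMARK_FLAGS block added to test.sh above maps substrings of TEST_CONFIG to benchmark-script flags. A minimal Python sketch of that mapping, assuming only the config names that appear in periodic.yml and inductor.yml (the helper itself is illustrative, not part of the patch):

# Hypothetical helper mirroring the DYNAMO_BENCHMARK_FLAGS shell logic.
def dynamo_benchmark_flags(test_config: str) -> list:
    flags = []
    if "aot_eager" in test_config:
        flags += ["--backend", "aot_eager"]
    elif "inductor" in test_config:
        flags += ["--inductor"]
    if "dynamic" in test_config:
        # test.sh also passes --unspecialize-int until specialize_int=False becomes the default
        flags += ["--dynamic-shapes", "--unspecialize-int"]
    flags += ["--device", "cpu"] if "cpu_accuracy" in test_config else ["--device", "cuda"]
    return flags

assert dynamo_benchmark_flags("dynamic_aot_eager_torchbench") == [
    "--backend", "aot_eager", "--dynamic-shapes", "--unspecialize-int", "--device", "cuda"]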
- - if [[ $exit_status -ne 0 ]]; then - echo "Some benchmarks failed; scroll up for details" + if [[ "${TEST_CONFIG}" == *perf* ]]; then + # MKL_THREADING_LAYER=GNU to mitigate https://github.com/pytorch/pytorch/issues/37377 + MKL_THREADING_LAYER=GNU python benchmarks/dynamo/runner.py --suites="$suite" \ + --base-sha="$BASE_SHA" --output-dir="$TEST_REPORTS_DIR" "${partition_flags[@]}" \ + --no-graphs --no-update-archive --no-gh-comment "$@" + else + python "benchmarks/dynamo/$suite.py" \ + --ci --accuracy --timing --explain \ + "${DYNAMO_BENCHMARK_FLAGS[@]}" \ + "$@" "${partition_flags[@]}" \ + --output "$TEST_REPORTS_DIR/${name}_${suite}.csv" + python benchmarks/dynamo/check_csv.py \ + -f "$TEST_REPORTS_DIR/${name}_${suite}.csv" fi - return $exit_status } -test_inductor_benchmark() { +test_dynamo_benchmark() { # Usage: test_dynamo_benchmark huggingface 0 - local device="$1" + local suite="$1" + shift + local shard_id="$1" shift - if [[ $device == "cpu" ]]; then - # TODO: Add training and dynamic shape test - test_single_dynamo_benchmark "inductor_inference" "$@" --inductor --float32 --device cpu + if [[ "${TEST_CONFIG}" == *perf* ]]; then + # Performance test training only, for float32 and amp + test_single_dynamo_benchmark "amp" "$suite" "$shard_id" --training --dtypes=amp "$@" + test_single_dynamo_benchmark "float32" "$suite" "$shard_id" --training --dtypes=float32 "$@" else # Check inference with --float32 - test_single_dynamo_benchmark "inductor_inference" "$@" --inductor --device cuda + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --float32 "$@" - # Check training with --amp - test_single_dynamo_benchmark "inductor_training" "$@" --inductor --training --amp --device cuda - - # Check inference with --dynamic-shapes - test_single_dynamo_benchmark "dynamic_inductor-inference" "$@" --inductor --dynamic-shapes --device cuda + if [[ "${TEST_CONFIG}" != *cpu_accuracy* ]]; then + # Check training with --amp + test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@" + fi fi } -test_inductor_benchmark_perf() { - # Use test-reports directory under test folder will allow the CI to automatically pick up - # the test reports and upload them to S3. 
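The perf branch above shells out to benchmarks/dynamo/runner.py with MKL_THREADING_LAYER=GNU to work around pytorch/pytorch#37377. A rough Python equivalent of that invocation, with the suite and output directory as placeholder values rather than the script's real arguments:

# Sketch only: set MKL_THREADING_LAYER for the child process, then run the perf runner.
import os
import subprocess

env = dict(os.environ, MKL_THREADING_LAYER="GNU")
subprocess.run(
    ["python", "benchmarks/dynamo/runner.py",
     "--suites=torchbench", "--output-dir=test/test-reports",
     "--no-graphs", "--no-update-archive", "--no-gh-comment"],
    check=True,
    env=env,
)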
Need to use full path here otherwise the script - # will bark about file not found later on +test_inductor_torchbench_smoketest_perf() { TEST_REPORTS_DIR=$(pwd)/test/test-reports - PARTITION_FLAGS="" - if [[ -n "$NUM_TEST_SHARDS" && -n "$2" ]]; then - PARTITION_FLAGS="--total-partitions 2 --partition-id $2" - fi mkdir -p "$TEST_REPORTS_DIR" - # Check training with --amp - # Not checking accuracy for perf test for now - # shellcheck disable=SC2086 - if [[ "$1" == *smoketest* ]]; then - python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \ - --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \ - --output "$TEST_REPORTS_DIR"/inductor_training_$1.csv - # the reference speedup value is hardcoded in check_hf_bert_perf_csv.py - # this value needs to be actively maintained to make this check useful - python benchmarks/dynamo/check_hf_bert_perf_csv.py -f "$TEST_REPORTS_DIR"/inductor_training_$1.csv - - # Check memory compression ratio for a few models - for test in hf_Albert timm_efficientdet timm_vision_transformer; do - python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --amp --training \ - --disable-cudagraphs --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" \ - --only $test --output "$TEST_REPORTS_DIR"/inductor_training_$1_$test.csv - cat "$TEST_REPORTS_DIR"/inductor_training_$1_$test.csv - python benchmarks/dynamo/check_memory_compression_ratio.py --actual \ - "$TEST_REPORTS_DIR"/inductor_training_$1_$test.csv \ - --expected benchmarks/dynamo/expected_ci_perf_inductor_torchbench.csv - done - else - python benchmarks/dynamo/$1.py --ci --training --performance --disable-cudagraphs\ - --device cuda --inductor --amp $PARTITION_FLAGS --output "$TEST_REPORTS_DIR"/inductor_training_$1.csv - fi -} - -# No sharding for the periodic job, we don't care if latency is bad -test_aot_eager_all() { - local exit_status=0 - PYTHONPATH=$(pwd)/torchbench test_aot_eager_benchmark torchbench "" "$@" || exit_status=$? - test_aot_eager_benchmark huggingface "" "$@" || exit_status=$? - test_aot_eager_benchmark timm_models "" "$@" || exit_status=$? 
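check_hf_bert_perf_csv.py, used in the smoketest path here, compares the measured hf_Bert speedup against a hardcoded reference value that has to be kept up to date. A minimal sketch of that kind of gate, assuming a `speedup` column and an illustrative threshold (neither is taken from the actual script):

import csv
import sys

EXPECTED_SPEEDUP = 1.5  # hypothetical reference value, maintained by hand

def check_speedup(csv_path: str) -> None:
    # Fail the job when the measured speedup regresses below the reference.
    with open(csv_path) as f:
        row = next(csv.DictReader(f))
    measured = float(row["speedup"])  # assumed column name
    if measured < EXPECTED_SPEEDUP:
        sys.exit(f"hf_Bert speedup {measured:.2f}x is below the expected {EXPECTED_SPEEDUP:.2f}x")

# check_speedup("test/test-reports/inductor_training_smoketest.csv")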
- if [[ $exit_status -ne 0 ]]; then - echo "Some benchmarks failed; scroll up for details" - fi - return $exit_status -} - -test_inductor_huggingface() { - local device=$1 - shift - test_inductor_benchmark "$device" huggingface "" -} - -test_inductor_huggingface_perf() { - test_inductor_benchmark_perf huggingface -} - -test_inductor_timm_shard() { - if [[ -z "$NUM_TEST_SHARDS" ]]; then - echo "NUM_TEST_SHARDS must be defined to run a Python test shard" - exit 1 - fi - local device=$1 - shift - test_inductor_benchmark "$device" timm_models "$1" -} -test_inductor_timm_perf_shard() { - if [[ -z "$NUM_TEST_SHARDS" ]]; then - echo "NUM_TEST_SHARDS must be defined to run a Python test shard" - exit 1 - fi - test_inductor_benchmark_perf timm_models "$1" -} - -test_inductor_torchbench() { - local device=$1 - shift - PYTHONPATH=$(pwd)/torchbench test_inductor_benchmark "$device" torchbench "" -} - -test_inductor_torchbench_perf() { - PYTHONPATH=$(pwd)/torchbench test_inductor_benchmark_perf torchbench -} - -test_inductor_torchbench_smoketest_perf(){ - PYTHONPATH=$(pwd)/torchbench test_inductor_benchmark_perf smoketest + python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \ + --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \ + --output "$TEST_REPORTS_DIR/inductor_training_smoketest.csv" + # the reference speedup value is hardcoded in check_hf_bert_perf_csv.py + # this value needs to be actively maintained to make this check useful + python benchmarks/dynamo/check_hf_bert_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_training_smoketest.csv" + + # Check memory compression ratio for a few models + for test in hf_Albert timm_efficientdet timm_vision_transformer; do + python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --amp --training \ + --disable-cudagraphs --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" \ + --only $test --output "$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv" + cat "$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv" + python benchmarks/dynamo/check_memory_compression_ratio.py --actual \ + "$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv" \ + --expected benchmarks/dynamo/expected_ci_perf_inductor_torchbench.csv + done } test_python_gloo_with_tls() { @@ -842,6 +789,12 @@ test_executorch() { assert_git_not_dirty } +# TODO: Include this in the Docker image +if [[ "${TEST_CONFIG}" == *_perf* ]]; then + install_matplotlib + install_tabulate +fi + if ! 
[[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* || "${BUILD_ENVIRONMENT}" == *-tsan* ]]; then (cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.parallel_info())") @@ -878,81 +831,24 @@ elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHAR elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then install_torchvision test_dynamo_shard 2 -elif [[ "${TEST_CONFIG}" == *aot_eager_all* ]]; then - install_torchtext +elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then install_torchvision - checkout_install_torchbench install_huggingface - install_timm - if [[ "${TEST_CONFIG}" == *dynamic* ]]; then - # NB: This code path is currently dead because dynamic shapes takes - # too long to run unsharded - test_aot_eager_all --dynamic-shapes - else - test_aot_eager_all - fi -elif [[ "${TEST_CONFIG}" == *aot_eager_huggingface* ]]; then - install_torchvision - install_huggingface - if [[ "${TEST_CONFIG}" == *dynamic* ]]; then - test_aot_eager_benchmark huggingface "" --dynamic-shapes - else - test_aot_eager_benchmark huggingface "" - fi -elif [[ "${TEST_CONFIG}" == *aot_eager_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then - install_torchvision - install_timm - id=$((SHARD_NUMBER-1)) - if [[ "${TEST_CONFIG}" == *dynamic* ]]; then - test_aot_eager_benchmark timm_models "$id" --dynamic-shapes - else - test_aot_eager_benchmark timm_models "$id" - fi -elif [[ "${TEST_CONFIG}" == *aot_eager_torchbench* ]]; then - install_torchtext - install_torchvision - checkout_install_torchbench - if [[ "${TEST_CONFIG}" == *dynamic* ]]; then - PYTHONPATH=$(pwd)/torchbench test_aot_eager_benchmark torchbench "" --dynamic-shapes - else - PYTHONPATH=$(pwd)/torchbench test_aot_eager_benchmark torchbench "" - fi -elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then - install_torchvision - install_huggingface - if [[ "${TEST_CONFIG}" == *inductor_huggingface_perf* ]]; then - test_inductor_huggingface_perf - elif [[ "${TEST_CONFIG}" == *inductor_huggingface_cpu_accuracy* ]]; then - test_inductor_huggingface cpu - else - test_inductor_huggingface cuda - fi -elif [[ "${TEST_CONFIG}" == *inductor_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then + test_dynamo_benchmark huggingface "" +elif [[ "${TEST_CONFIG}" == *timm* ]]; then install_torchvision install_timm id=$((SHARD_NUMBER-1)) - if [[ "${TEST_CONFIG}" == *inductor_timm_perf* && $NUM_TEST_SHARDS -gt 1 ]]; then - test_inductor_timm_perf_shard $id - elif [[ "${TEST_CONFIG}" == *inductor_timm_cpu_accuracy* && $NUM_TEST_SHARDS -gt 1 ]]; then - test_inductor_timm_shard cpu $id - else - test_inductor_timm_shard cuda $id - fi -elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then + test_dynamo_benchmark timm_models "$id" +elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchtext install_torchvision - if [[ "${TEST_CONFIG}" == *inductor_torchbench_perf* ]]; then - checkout_install_torchbench - test_inductor_torchbench_perf - elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_accuracy* ]]; then - checkout_install_torchbench - test_inductor_torchbench cpu - elif [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then + if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then checkout_install_torchbench hf_Bert hf_Albert timm_efficientdet timm_vision_transformer - test_inductor_torchbench_smoketest_perf + PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf else 
checkout_install_torchbench - test_inductor_torchbench cuda + PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "" fi elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 1 ]]; then install_torchvision diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 671105bd7da5..03d1456a6659 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -beb4bb706b5e13009cb5d5586505c6d2896d184a +5850f370c03d941f97c7bd53f99a83abb0b9dd01 diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 0fc4f694e003..8df19ed021ae 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -f9963f6c2d34b9662f93e5518adb15949be05f65 +015ebcba441dbd5dd21dc02ef12af2c29791a7f0 diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 822a6fdde457..796f8cb65a77 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -7,7 +7,7 @@ # .ci/docker/requirements-ci.txt boto3==1.19.12 jinja2==3.0.1 -lintrunner==0.9.2 +lintrunner==0.10.7 ninja==1.10.0.post1 nvidia-ml-py==11.525.84 pyyaml==6.0 diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index 9d99c0eef7b8..521db7be9bfa 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -6,9 +6,8 @@ import sys import warnings from typing import Any, Dict, List, Set -from urllib.request import urlopen +from urllib.request import Request, urlopen -import requests import yaml PREFIX = "test-config/" @@ -92,27 +91,24 @@ def get_labels(pr_number: int) -> Set[str]: Dynamical get the latest list of labels from the pull request """ # From https://docs.github.com/en/actions/learn-github-actions/environment-variables - PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch") - PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}" - GITHUB_TOKEN = os.environ["GITHUB_TOKEN"] + pytorch_repo = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch") + pytorch_github_api = f"https://api.github.com/repos/{pytorch_repo}" + github_token = os.environ["GITHUB_TOKEN"] - REQUEST_HEADERS = { + headers = { "Accept": "application/vnd.github.v3+json", - "Authorization": "token " + GITHUB_TOKEN, + "Authorization": f"token {github_token}", } - - response = requests.get( - f"{PYTORCH_GITHUB_API}/issues/{pr_number}/labels", - headers=REQUEST_HEADERS, + json_response = download_json( + url=f"{pytorch_github_api}/issues/{pr_number}/labels", + headers=headers, ) - if response.status_code != requests.codes.ok: - warnings.warn( - f"Failed to get the labels for #{pr_number} (status code {response.status_code})" - ) + if not json_response: + warnings.warn(f"Failed to get the labels for #{pr_number}") return set() - return {label.get("name") for label in response.json() if label.get("name")} + return {label.get("name") for label in json_response if label.get("name")} def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]: @@ -208,7 +204,7 @@ def remove_disabled_jobs( # The result will be stored here filtered_test_matrix: Dict[str, List[Any]] = {"include": []} - for _, record in download_json(DISABLED_JOBS_URL).items(): + for _, record in download_json(url=DISABLED_JOBS_URL, headers={}).items(): ( author, _, @@ -286,10 +282,11 @@ def remove_disabled_jobs( return test_matrix -def download_json(url: str, num_retries: int = 3) -> Any: +def download_json(url: str, 
headers: Dict[str, str], num_retries: int = 3) -> Any: for _ in range(num_retries): try: - content = urlopen(url, timeout=5).read().decode("utf-8") + req = Request(url=url, headers=headers) + content = urlopen(req, timeout=5).read().decode("utf-8") return json.loads(content) except Exception as e: warnings.warn(f"Could not download {url}: {e}") diff --git a/.github/scripts/test_filter_test_configs.py b/.github/scripts/test_filter_test_configs.py index 4bd91c13822c..f0fc7ba4da0a 100755 --- a/.github/scripts/test_filter_test_configs.py +++ b/.github/scripts/test_filter_test_configs.py @@ -2,10 +2,9 @@ import json import os -from typing import Any, Dict +from typing import Any from unittest import main, mock, TestCase -import requests import yaml from filter_test_configs import ( filter, @@ -16,7 +15,6 @@ SUPPORTED_PERIODICAL_MODES, VALID_TEST_CONFIG_LABELS, ) -from requests.models import Response MOCKED_DISABLED_JOBS = { @@ -85,19 +83,7 @@ "build (dynamo)", ], } - - -def mocked_gh_get_labels_failed(url: str, headers: Dict[str, str]) -> Response: - mocked_response = Response() - mocked_response.status_code = requests.codes.bad_request - return mocked_response - - -def mocked_gh_get_labels(url: str, headers: Dict[str, str]) -> Response: - mocked_response = Response() - mocked_response.status_code = requests.codes.ok - mocked_response._content = b'[{"name": "foo"}, {"name": "bar"}, {}, {"name": ""}]' - return mocked_response +MOCKED_LABELS = [{"name": "foo"}, {"name": "bar"}, {}, {"name": ""}] class TestConfigFilter(TestCase): @@ -106,15 +92,15 @@ def setUp(self) -> None: if os.getenv("GITHUB_OUTPUT"): del os.environ["GITHUB_OUTPUT"] - @mock.patch("filter_test_configs.requests.get", side_effect=mocked_gh_get_labels) - def test_get_labels(self, mocked_gh: Any) -> None: + @mock.patch("filter_test_configs.download_json") + def test_get_labels(self, mock_download_json: Any) -> None: + mock_download_json.return_value = MOCKED_LABELS labels = get_labels(pr_number=12345) self.assertSetEqual({"foo", "bar"}, labels) - @mock.patch( - "filter_test_configs.requests.get", side_effect=mocked_gh_get_labels_failed - ) - def test_get_labels_failed(self, mocked_gh: Any) -> None: + @mock.patch("filter_test_configs.download_json") + def test_get_labels_failed(self, mock_download_json: Any) -> None: + mock_download_json.return_value = {} labels = get_labels(pr_number=54321) self.assertFalse(labels) diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml index 4f81c365b11d..3c78c5afedba 100644 --- a/.github/workflows/_bazel-build-test.yml +++ b/.github/workflows/_bazel-build-test.yml @@ -84,6 +84,10 @@ jobs: with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG + uses: pytorch/test-infra/.github/actions/setup-nvidia@main + if: ${{ inputs.cuda-version != 'cpu' }} + - name: Output disk space left run: | sudo df -H @@ -118,7 +122,9 @@ jobs: CUDA_VERSION: ${{ inputs.cuda-version }} run: | # detached container should get cleaned up by teardown_ec2_linux + # shellcheck disable=SC2086 container_name=$(docker run \ + ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ @@ -136,7 +142,7 @@ jobs: -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" ) - docker exec -t "${container_name}" sh -c 'sudo chown -R jenkins . 
&& sudo chown -R jenkins /dev && .ci/pytorch/build.sh' + docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh' - name: Test id: test @@ -174,7 +180,9 @@ jobs: # TODO: Stop building test binaries as part of the build phase # Make sure we copy test results from bazel-testlogs symlink to # a regular directory ./test/test-reports + # shellcheck disable=SC2086,SC2090 container_name=$(docker run \ + ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ -e GITHUB_ACTIONS \ -e GIT_DEFAULT_BRANCH="$GIT_DEFAULT_BRANCH" \ @@ -198,7 +206,7 @@ jobs: -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" ) - docker exec -t "${container_name}" sh -c 'sudo chown -R jenkins . && sudo chown -R jenkins /dev && .ci/pytorch/test.sh && cp -Lr ./bazel-testlogs ./test/test-reports' + docker exec -t "${container_name}" sh -c '.ci/pytorch/test.sh && cp -Lr ./bazel-testlogs ./test/test-reports' - name: Print remaining test logs shell: bash diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index f5f66ae5129b..5445f9bda794 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -44,6 +44,11 @@ on: An option JSON description of what test configs to run later on. This is moved here from the Linux test workflow so that we can apply filter logic using test-config labels earlier and skip unnecessary builds + sccache-use-gha: + required: false + type: boolean + default: false + description: If true, use the Github cache as the storage option for sccache instead of S3. outputs: test-matrix: @@ -71,6 +76,7 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} BUILD_ENVIRONMENT: ${{ inputs.build-environment }} + SCCACHE_USE_GHA: ${{ inputs.sccache-use-gha }} # this is placed here instead of the sccache step to appease actionlint outputs: build-outcome: ${{ steps.build.outcome }} steps: @@ -119,10 +125,17 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + if [[ "${SCCACHE_USE_GHA}" == "true" ]]; then + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v0.4.0-pre.7 --output /usr/local/bin/sccache + echo "ACTIONS_CACHE_URL=${ACTIONS_CACHE_URL}" >> "${GITHUB_ENV}" + echo "ACTIONS_RUNTIME_TOKEN=${ACTIONS_RUNTIME_TOKEN}" >> "${GITHUB_ENV}" + echo "SCCACHE_GHA_ENABLED=on" >> "${GITHUB_ENV}" + else + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" + echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" + fi sudo chmod +x /usr/local/bin/sccache - echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" - name: Get workflow job id id: get-job-id diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index d53d90fca6e9..6493f0447cfe 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -13,12 +13,12 @@ concurrency: cancel-in-progress: true jobs: - linux-bionic-cuda11_7-py3_10-gcc7-inductor-build: - name: cuda11.7-py3.10-gcc7-sm80 + linux-bionic-cuda11_8-py3_10-gcc7-inductor-build: + name: cuda11.8-py3.10-gcc7-sm80 uses: ./.github/workflows/_linux-build.yml with: - 
build-environment: linux-bionic-cuda11.7-py3.10-gcc7-sm80 - docker-image-name: pytorch-linux-bionic-cuda11.7-cudnn8-py3-gcc7 + build-environment: linux-bionic-cuda11.8-py3.10-gcc7-sm80 + docker-image-name: pytorch-linux-bionic-cuda11.8-cudnn8-py3-gcc7 cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -28,12 +28,13 @@ jobs: { config: "inductor_torchbench_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100.large" }, ]} - linux-bionic-cuda11_7-py3_10-gcc7-inductor-test: - name: cuda11.7-py3.10-gcc7-sm80 + linux-bionic-cuda11_8-py3_10-gcc7-inductor-test: + name: cuda11.8-py3.10-gcc7-sm80 uses: ./.github/workflows/_linux-test.yml - needs: linux-bionic-cuda11_7-py3_10-gcc7-inductor-build + needs: linux-bionic-cuda11_8-py3_10-gcc7-inductor-build with: - build-environment: linux-bionic-cuda11.7-py3.10-gcc7-sm80 - docker-image: ${{ needs.linux-bionic-cuda11_7-py3_10-gcc7-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-bionic-cuda11_7-py3_10-gcc7-inductor-build.outputs.test-matrix }} + build-environment: linux-bionic-cuda11.8-py3.10-gcc7-sm80 + docker-image: ${{ needs.linux-bionic-cuda11_8-py3_10-gcc7-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-bionic-cuda11_8-py3_10-gcc7-inductor-build.outputs.test-matrix }} use-gha: anything-non-empty-to-use-gha + timeout-minutes: 1200 diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index a9c9945f33e9..711b0f86a424 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -28,6 +28,10 @@ jobs: { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_huggingface_dynamic", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm_dynamic", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm_dynamic", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench_dynamic", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, ]} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 29987ae5e31d..bf51c680265a 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -67,8 +67,10 @@ jobs: cuda-arch-list: '8.6' test-matrix: | { include: [ - { config: "aot_eager_all", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - # These jobs run too slowly so they must be sharded, unfortunately + { config: "aot_eager_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, diff --git a/.github/workflows/trunk.yml 
b/.github/workflows/trunk.yml index 704b05314293..9933fc4033d1 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -179,6 +179,7 @@ jobs: xcode-version: "13.3.1" runner-type: macos-12-xl build-generates-artifacts: true + sccache-use-gha: true test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 3, runner: "macos-12" }, @@ -210,6 +211,7 @@ jobs: xcode-version: "13.3.1" runner-type: macos-12-xl build-generates-artifacts: false + sccache-use-gha: true test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1, runner: "macos-12" }, diff --git a/.github/workflows/upload-contrib-stats.yml b/.github/workflows/upload-contrib-stats.yml index 5980a8c64740..0a895edcac69 100644 --- a/.github/workflows/upload-contrib-stats.yml +++ b/.github/workflows/upload-contrib-stats.yml @@ -5,10 +5,14 @@ on: # Choose a random time near midnight PST because it may be delayed if there are high loads - cron: 37 7 * * * +concurrency: + group: ${{ github.workflow }}-${{ github.event_name == 'schedule' || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + jobs: upload-contribution-stats: - runs-on: ubuntu-latest + runs-on: [self-hosted, linux.2xlarge] steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -22,6 +26,7 @@ jobs: env: ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | - echo "Uploading external contribution stats for $(date -v-1d +%F)" - python3 -m tools.stats.upload_external_contrib_stats --startDate "$(date -v-1d +%F)" + echo "Uploading external contribution stats for" "$(date -d yesterday '+%Y-%m-%d')" + python3 -m tools.stats.upload_external_contrib_stats --startDate "$(date -d yesterday '+%Y-%m-%d')" \ No newline at end of file diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 0f1a74a5d9e1..fb4bca8d64f7 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -2,7 +2,7 @@ name: Upload test stats on: workflow_run: - workflows: [pull, trunk, periodic, inductor, inductor-A100-perf] + workflows: [pull, trunk, periodic, inductor] types: - completed diff --git a/.github/workflows/upload-torch-dynamo-perf-stats.yml b/.github/workflows/upload-torch-dynamo-perf-stats.yml new file mode 100644 index 000000000000..6a1d3d8af74b --- /dev/null +++ b/.github/workflows/upload-torch-dynamo-perf-stats.yml @@ -0,0 +1,63 @@ +name: Upload torch dynamo performance stats + +on: + workflow_run: + workflows: [inductor-A100-perf] + types: + - completed + branches: + - master + - main + +jobs: + get-conclusion: + runs-on: ubuntu-latest + outputs: + conclusion: ${{ fromJson(steps.get-conclusion.outputs.data).conclusion }} + steps: + - name: Get workflow run conclusion + uses: octokit/request-action@v2.1.0 + id: get-conclusion + with: + route: GET /repos/${{ github.repository }}/actions/runs/${{ github.event.workflow_run.id }}/attempts/${{ github.event.workflow_run.run_attempt }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + upload-perf-stats: + needs: get-conclusion + if: github.event.workflow_run.conclusion == 'success' || needs.get-conclusion.outputs.conclusion == 'success' || + github.event.workflow_run.conclusion == 'failure' || needs.get-conclusion.outputs.conclusion == 'failure' + runs-on: [self-hosted, linux.2xlarge] + name: Upload dynamo performance stats for ${{ github.event.workflow_run.id }}, attempt ${{ 
github.event.workflow_run.run_attempt }} + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + with: + submodules: false + fetch-depth: 1 + + - run: | + pip3 install requests==2.26 rockset==1.0.3 boto3==1.19.12 + + - name: Upload torch dynamo performance stats to S3 + id: upload-s3 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WORKFLOW_ARTIFACTS_URL: ${{ github.event.workflow_run.artifacts_url }} + WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} + WORKFLOW_RUN_ATTEMPT: ${{ github.event.workflow_run.run_attempt }} + REPO_FULLNAME: ${{ github.event.workflow_run.repository.full_name }} + run: | + # Upload perf test reports from GHA to S3, which can now be downloaded + # on HUD + python3 -m tools.stats.upload_artifacts --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" + + - name: Upload torch dynamo performance stats to Rockset + if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' + env: + ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} + WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} + WORKFLOW_RUN_ATTEMPT: ${{ github.event.workflow_run.run_attempt }} + REPO_FULLNAME: ${{ github.event.workflow_run.repository.full_name }} + run: | + python3 -m tools.stats.upload_dynamo_perf_stats --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" diff --git a/.gitignore b/.gitignore index 9f7128d495a9..8b13ab22b9bb 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,7 @@ third_party/build/ tools/coverage_plugins_package/pip-wheel-metadata/ tools/shared/_utils_internal.py tools/fast_nvcc/wrap_nvcc.sh +tools/fast_nvcc/wrap_nvcc.bat tools/fast_nvcc/tmp/ torch.egg-info/ torch/_C/__init__.pyi diff --git a/.lintrunner.toml b/.lintrunner.toml index 9450cddbe0d4..a420579713d3 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1,3 +1,5 @@ +merge_base_with = "master" + [[linter]] code = 'FLAKE8' include_patterns = ['**/*.py'] diff --git a/CMakeLists.txt b/CMakeLists.txt index b9addcf005b3..b5c09ca05baf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -931,15 +931,6 @@ else() append_cxx_flag_if_supported("/wd4273" CMAKE_CXX_FLAGS) endif() -if(USE_ASAN) - string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fsanitize=address -fsanitize=undefined") - string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fsanitize=address -fsanitize=undefined") -endif() - -if(USE_TSAN) - string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fsanitize=thread") - string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fsanitize=thread") -endif() if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") include(CheckCSourceCompiles) diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h index 31c061400e62..6e3984ab34c9 100644 --- a/aten/src/ATen/SparseTensorUtils.h +++ b/aten/src/ATen/SparseTensorUtils.h @@ -8,6 +8,7 @@ #include #else #include +#include #endif namespace at { @@ -119,5 +120,65 @@ TORCH_API Tensor flatten_indices_by_dims( // Find the CSR representation for a row `indices` from the COO format TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz); +template +class TensorGeometryHolder { + using geometry_holder_t = std::array; + + public: + explicit TensorGeometryHolder( + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options = {}) { + std::copy(sizes.begin(), sizes.end(), t_sizes.begin()); + std::copy(strides.begin(), strides.end(), t_strides.begin()); + } + + explicit TensorGeometryHolder(const Tensor& t) + : 
TensorGeometryHolder(t.sizes(), t.strides()) {} + + auto operator*() const { + return std::make_tuple(t_sizes, t_strides); + } + + private: + geometry_holder_t t_sizes; + geometry_holder_t t_strides; +}; + +template <> +class TensorGeometryHolder<0> { + using geometry_holder_t = Tensor; + + public: + explicit TensorGeometryHolder( + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options) { + const int64_t t_ndims = sizes.size(); + const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU); + Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options); + t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options)); + t_sizes_and_strides_cpu.select(0, 1).copy_( + at::tensor(strides, cpu_options)); + const Tensor t_sizes_and_strides = + t_sizes_and_strides_cpu.to(options.device()); + t_sizes = t_sizes_and_strides.select(0, 0); + t_strides = t_sizes_and_strides.select(0, 1); + } + + explicit TensorGeometryHolder(const Tensor& t) + : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {} + + auto operator*() const { + return std::make_tuple( + t_sizes.template data_ptr(), + t_strides.template data_ptr()); + } + + private: + geometry_holder_t t_sizes; + geometry_holder_t t_strides; +}; + } // namespace sparse } // namespace at diff --git a/aten/src/ATen/core/GeneratorForPrivateuseone.cpp b/aten/src/ATen/core/GeneratorForPrivateuseone.cpp new file mode 100644 index 000000000000..734ea90de029 --- /dev/null +++ b/aten/src/ATen/core/GeneratorForPrivateuseone.cpp @@ -0,0 +1,28 @@ +#include +#include + +namespace at { + +c10::optional& GetGeneratorPrivate() { + static c10::optional generator_privateuse1 = c10::nullopt; + return generator_privateuse1; +} + +std::mutex _generator_mutex_lock; +_GeneratorRegister::_GeneratorRegister(GeneratorFuncType func) { + _generator_mutex_lock.lock(); + TORCH_CHECK(!GetGeneratorPrivate().has_value(), + "Only can register a generator to the PrivateUse1 dispatch key once!"); + auto& m_generator = GetGeneratorPrivate(); + m_generator = func; + _generator_mutex_lock.unlock(); +} + +at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) { + TORCH_CHECK(GetGeneratorPrivate().has_value(), + "Please register a generator to the PrivateUse1 dispatch key, \ + using the REGISTER_GENERATOR_PRIVATEUSE1 macro."); + return GetGeneratorPrivate().value()(device_index); +} + +} diff --git a/aten/src/ATen/core/GeneratorForPrivateuseone.h b/aten/src/ATen/core/GeneratorForPrivateuseone.h new file mode 100644 index 000000000000..eb74484081cc --- /dev/null +++ b/aten/src/ATen/core/GeneratorForPrivateuseone.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +namespace at { + +using GeneratorFuncType = std::function; + +c10::optional& GetGeneratorPrivate(); + +class TORCH_API _GeneratorRegister{ +public: + _GeneratorRegister(GeneratorFuncType func); +}; + +TORCH_API at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index); + +/** + * This is used to register Generator to PyTorch for `privateuse1` key. + * Usage: REGISTER_GENERATOR_PRIVATEUSE1(GeneratorForPrivateuse1) + * GeneratorForPrivateuse1 func must return a argument with type of at::Generator. + * class CustomGeneratorImpl : public c10::GeneratorImpl { + * CustomGeneratorImpl(DeviceIndex device_index = -1); + * ~CustomGeneratorImpl() override = default; + * ... 
+ * } + * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) { + * return at::make_generator(id); + * } + * REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1) + */ +#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \ + auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate); + +} diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index 9960845809c2..b5aaeb124080 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -135,7 +135,7 @@ class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< // padding). If 't' is lower-dimensional than 'pad', the remaining // dimensions (on the right) are padded with ones. This doesn't // affect the underlying data layout. This is particularly useful for - // dealing with a pecularity of the CuDNN API, which is that broadcasting in CuDNN is + // dealing with a peculiarity of the CuDNN API, which is that broadcasting in CuDNN is // done in two steps: first, the client code is expected to pad out // (the dimensions) input tensors to be the same dimension as the // target broadcast, and then second, CuDNN takes of actually @@ -245,7 +245,7 @@ struct TORCH_CUDA_CPP_API DropoutDescriptor // NB: seed doesn't matter when dropout = 0, because no random number // initialization actually takes place when there is no dropout. // NB: Empirically, cudnnSetDropoutDescriptor is cheap when - // dropoot == 0 + // dropout == 0 AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */)); } }; diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index daa3b6bd5739..ebc37e1f85b3 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -259,7 +259,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(std_mean, dim); OP_DECOMPOSE(swapaxes); OP_DECOMPOSE2(subtract, Tensor); - OP_DECOMPOSE(sum_to_size); + m.impl("sum_to_size", native::sum_to_size_symint); OP_DECOMPOSE(svd); OP_DECOMPOSE(swapdims); OP_DECOMPOSE(take_along_dim); @@ -281,7 +281,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(where, ScalarOther); OP_DECOMPOSE2(where, ScalarSelf); OP_DECOMPOSE(orgqr); - OP_DECOMPOSE2(unflatten, int); + m.impl("unflatten.int", native::unflatten_symint); m.impl("_convolution_double_backward", native::_convolution_double_backward); OP_DECOMPOSE(conv_transpose1d); OP_DECOMPOSE2(conv_transpose2d, input); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 804b91705306..39998c357315 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1557,7 +1557,8 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T r += s2[k] * m1[k][j]; } } else { - r *= beta; + // For beta == 0, the r's value will be ignored, especially for nan value. + r = beta == scalar_t(0) ? scalar_t(0) : beta * r; for (const auto k : c10::irange(ks)) { r += alpha * s2[k] * m1[k][j]; } @@ -1798,61 +1799,41 @@ Tensor& vdot_out(const Tensor& self, const Tensor& other, Tensor& result) { } bool should_fold(const Tensor& tensor1, const Tensor& tensor2) { - // We check that we can fold the larger tensor into a matrix and dispatch to mm or mv rather than - // to bmm. 
We want to make sure we can do so without incurring in any extra copy - const auto tensor1_larger = tensor1.dim() >= tensor2.dim(); - - // We order the tensors. t1 will be the larger tensor - // We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul) - // and tensor1_larger iff tensor2.dim() > tensor1.dim(9 - const auto t1 = tensor1_larger ? MaybeOwned::borrowed(tensor1) - : MaybeOwned::owned(tensor2.mT()); - const int64_t dim_t1 = t1->dim(); - const auto dim_t2 = tensor1_larger ? tensor2.dim() - : tensor1.dim(); - - // Just fold for dim_t1 >= 3 and (dim_t2 == 1 || dim_t2 == 2) - if (!(dim_t1 >= 3 && dim_t2 <= 2)) { - return false; - } - - // In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward - // Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer - // t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded) - // The issue appears in the backward. - // The output gradient g of this operation would have shape [b, m, k] - // The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k] - // Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor - // of shape [b, n, k] unnacessarily, which may cause a large memory footprint, and in the - // worst case, an OOM - bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad(); - if (t2_requires_grad) { - return true; - } - - // Don't fold in this case, as we would have to call mm on the transposed tensor, the result - // would be contiguous, and then we would need to transpose it and call contiguous on it, thus - // having to copy the tensor - if (tensor1.dim() == 2) { - return false; - } - - // Can always fold if the tensor is empty - // This serves as a precondition for the code below - if (t1->numel() == 0) { - return true; - } - - // t1->view(-1, t1->size(-1)) does not copy only when the first n-1 dimensions are contiguous - // in the sense that t1_stride[i] = t1_stride[i+1]*t1_shape[i+1] - const auto t1_shape = t1->sizes(); - const auto t1_strides = t1->strides(); - for (auto i = int64_t{0}; i < dim_t1 - int64_t{2}; ++i) { - if (t1_strides[i] != t1_strides[i+1] * t1_shape[i+1]) { + const auto dim_tensor1 = tensor1.dim(); + const auto dim_tensor2 = tensor2.dim(); + if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) { + // Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer + // t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded) + // The issue appears in the backward. + // The output gradient g of this operation would have shape [b, m, k] + // The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k] + // Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor + // of shape [b, n, k] unnacessarily, which may cause a large memory footprint, and in the + // worst case, an OOM + if (tensor2.requires_grad()) { + return true; + } + const auto t1_sizes_ptr = tensor1.sizes().cbegin(); + const auto t1_strides = tensor1.strides(); + if (dim_tensor1 == 3 && dim_tensor2 == 2 && + t1_strides.back() != 1 && + t1_strides.front() == t1_sizes_ptr[1] * t1_sizes_ptr[2]) { + // First dim is slowest moving, and then the following two dims are + // transposed. This can happen for example by permute(0, 2, 1). 
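The folding that should_fold gates can be seen directly from Python: for the [b, m, n] @ [n, k] case described in the comment, the batch of the larger tensor is folded into its leading matrix dimension so a single mm replaces an expanded bmm. A small illustration with concrete shapes (not from the patch):

import torch

b, m, n, k = 4, 8, 16, 32
t1 = torch.randn(b, m, n)
t2 = torch.randn(n, k)

# Fold the batch of t1 into its leading matrix dimension, run one mm, unfold the result.
folded = t1.reshape(b * m, n).mm(t2).reshape(b, m, k)
assert torch.allclose(folded, t1 @ t2, atol=1e-6)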
+ // First 2 dims could be folded to use mm but would require permutation + // with actual data movement, which can be instead handled by BMM with each + // GEMM transposed. + // This can be generalized to a tensor with dim X + Y + Z where X, Y, and Z + // dims are contiguous, Y dims and Z dims are transposed, and X, Y, Z > 0. + // For example, this can happen by permute(0, 1, 5, 2, 3, 4), where X = 2, + // Y = 3, and Z = 1. return false; + } else { + return true; } + } else { + return false; } - return true; } /* @@ -1894,12 +1875,10 @@ Tensor _matmul_impl( : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0); } else if (dim_tensor1 == 2 && dim_tensor2 == 2) { return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2); - } else if (should_fold(tensor1, tensor2)) { + } else if (should_fold(tensor1, tensor2) || should_fold(tensor2, tensor1)) { // dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) || // dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2) - // and at least one of the following two conditions hold - // - the small tensor requires grad (see should_fold for the why) - // - we can fold the larger tensor t1 into a matrix as t1.view(-1, t1.size(-1)) without copying + // and some condition on the strides is fulfilled // optimization: use mm instead of bmm by folding the batch of the larger tensor // into its leading matrix dimension @@ -1925,38 +1904,41 @@ Tensor _matmul_impl( if (t2_is_matrix) { output_shape.push_back(t2->sizes()[1]); } - // This will almost always be a view. - // It may not be a view if t2->requires_grad(). See should_fold for an explanation const auto t1_folded = t1->reshape({folded_dim1, sizes_1.back()}); if (!has_out) { if (t2_is_matrix) { + // FIXME This path always does an unnecessary copy when transpose == true as the returned + // result from BLAS is already C-transposed const auto output = at::_unsafe_view(t1_folded.mm(*t2), output_shape); - // This copies if we perform a 2D @ 3D and the first tensor requires_grad - // See should_fold for why. - // If mm_out were differentiable, we could use it here, and pass a result with the - // correct strides to avoid this unnecessary copy. return transpose ? output.mT().contiguous() : output; } else { return at::_unsafe_view(t1_folded.mv(*t2), output_shape); } } else { - // See the !has_out branch for an explanation - TORCH_INTERNAL_ASSERT(!(transpose && t2_is_matrix)); - // Resize output into the correct shape - at::native::resize_output(out, output_shape); + const auto transpose_out = transpose && t2_is_matrix; + if (transpose_out) { + // Swap last two elements of output_shape + std::iter_swap(output_shape.end() - 2, output_shape.end() - 1); + at::native::resize_output(out, output_shape); + std::iter_swap(output_shape.end() - 2, output_shape.end() - 1); + } else { + at::native::resize_output(out, output_shape); + } + const auto out_ = transpose_out ? c10::MaybeOwned::owned(out.mT()) + : c10::MaybeOwned::borrowed(out); // We then reshape the output to the expected shape and call mm/mv // and transpose back if necessary - auto reshaped_out = t2_is_matrix ? out.reshape({folded_dim1, t2->sizes().back()}) - : out.reshape({folded_dim1}); + auto reshaped_out = t2_is_matrix ? 
out_->reshape({folded_dim1, t2->sizes().back()}) + : out_->reshape({folded_dim1}); if (t2_is_matrix) { at::mm_out(reshaped_out, t1_folded, *t2); } else { at::mv_out(reshaped_out, t1_folded, *t2); } if (!reshaped_out.is_alias_of(out)) { - out.copy_(reshaped_out); + out_->copy_(reshaped_out.view_as(*out_)); } return out; } @@ -1965,8 +1947,9 @@ Tensor _matmul_impl( // We track m1 vs m2 separately even though they must match for nicer error messages const int64_t n = dim_tensor1 > 1 ? tensor1.sizes().cend()[-2] : 1LL; const int64_t m1 = tensor1.sizes().back(); - auto batch_tensor1 = tensor1.sizes().slice(0, std::max(dim_tensor1 - 2, 0LL)); - const int64_t m2 = dim_tensor2 > 1 ? tensor2.sizes().cend()[-2] : tensor2.sizes().front(); + const IntArrayRef batch_tensor1(tensor1.sizes().data(), + std::max(dim_tensor1 - 2, 0LL)); + const int64_t m2 = dim_tensor2 > 1 ? tensor2.sizes().cend()[-2] : tensor2.sizes().back(); const int64_t p = dim_tensor2 > 1 ? tensor2.sizes().back() : 1LL; const IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max(dim_tensor2 - 2, 0LL)); @@ -1983,33 +1966,21 @@ Tensor _matmul_impl( } auto output_shape = infer_size_dimvector(batch_tensor1, batch_tensor2); - const int64_t expand_batch_product = c10::multiply_integers(output_shape); - // flatten expanded batches const auto tensor1_expand_size = [&output_shape, n, m1]{ DimVector ret(output_shape); ret.append({n, m1}); return ret; }(); + const auto tensor2_expand_size = [&output_shape, m2, p]{ DimVector ret(output_shape); + ret.append({m2, p}); + return ret; }(); + + const int64_t expand_batch_product = c10::multiply_integers(output_shape); + + // flatten expanded batches const auto tensor1_expanded = tensor1.expand(tensor1_expand_size) .reshape({expand_batch_product, n, m1}); - // We need to treat the dim_tensor2 == 1 case separately as broadcasting would not convert - // a vector of shape (n,) into a batch of matrices of shape (*, n, 1) - auto vector_rhs = dim_tensor2 == 1; - const auto tensor2_expand_size = [&output_shape, m2, p, vector_rhs]{ - DimVector ret(output_shape); - if (vector_rhs) { - ret.push_back(m2); - } else { - ret.append({m2, p}); - } - return ret; - }(); - auto tensor2_expanded = tensor2.expand(tensor2_expand_size); - if (vector_rhs) { - tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2}).unsqueeze(2); - } else { - tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2, p}); - } - + const auto tensor2_expanded = tensor2.expand(tensor2_expand_size) + .reshape({expand_batch_product, m2, p}); if (dim_tensor1 > 1) { output_shape.push_back(n); } @@ -2018,18 +1989,11 @@ Tensor _matmul_impl( } if (!has_out) { - if (vector_rhs) { - return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded).squeeze(-1), output_shape); - } else { - return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape); - } + return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape); } else { at::native::resize_output(out, output_shape); auto reshaped_out = out.reshape({expand_batch_product, n, p}); at::bmm_out(reshaped_out, tensor1_expanded, tensor2_expanded); - if (vector_rhs) { - reshaped_out = reshaped_out.squeeze(-1); - } if (!reshaped_out.is_alias_of(out)) { out.copy_(reshaped_out.view_as(out)); } diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 6b2b985bdd92..c2dbb997ad8b 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -1422,7 +1422,7 @@ std::tuple lstm( return std::make_tuple(std::move(output), 
std::move(hy), std::move(cy)); } #ifdef USE_MPS - if (_input.is_mps() && !bidirectional) { + if (_input.is_mps()) { std::tuple output = at::_lstm_mps(_input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first); std::tuple return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output)); diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 24ea40652e82..aaa5f8ff113e 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -5,7 +5,7 @@ // index(Tensor self, indices) -> Tensor // index_put_(Tensor self, indices, value, accumulate=false) // -// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte +// The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte // tensors (boolean masks) are expanded to long tensors via nonzero(). Null // tensors signify that the dimension is not indexed. // @@ -1842,8 +1842,8 @@ Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & sour static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) { NoNamesGuard guard; - TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool, - "masked_select: expected BoolTensor or ByteTensor for mask"); + TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, + "masked_select: expected BoolTensor for mask"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "masked_select(): self and result must have the same scalar type"); @@ -1851,11 +1851,6 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, at::assert_no_overlap(result, self); at::assert_no_overlap(result, mask); - if (mask.dtype() == at::ScalarType::Byte) { - TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \ - "please use a mask with dtype torch.bool instead."); - } - c10::MaybeOwned _mask, _self; std::tie(_mask, _self) = expand_outplace(mask, self); @@ -1880,7 +1875,7 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, _self->is_contiguous() && _mask->is_contiguous(); if (use_serial_kernel) { auto iter = TensorIteratorConfig() - .set_check_mem_overlap(false) // result is intenionally zero-strided above + .set_check_mem_overlap(false) // result is intentionally zero-strided above .check_all_same_dtype(false) .resize_outputs(false) .add_output(result_strided) @@ -1899,12 +1894,12 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, auto mask_long_data = mask_long.data_ptr(); auto mask_prefix_sum_data = mask_prefix_sum.data_ptr(); // TODO: Here can only use std::partial_sum for C++14, - // use std::exclusive_scan when PyTorch upgrades to C++17, which have better peformance. + // use std::exclusive_scan when PyTorch upgrades to C++17, which have better performance. 
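The masked_select changes in this file (and the matching CUDA change further down) drop support for torch.uint8 masks; only torch.bool masks are accepted and the old deprecation warning is removed. For reference, the supported usage from Python:

import torch

x = torch.arange(6).reshape(2, 3)
mask = x % 2 == 0                        # bool mask
print(torch.masked_select(x, mask))      # tensor([0, 2, 4])

# With this change, passing mask.to(torch.uint8) is rejected instead of merely warning.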
// std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0); std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data); auto iter = TensorIteratorConfig() - .set_check_mem_overlap(false) // result is intenionally zero-strided above + .set_check_mem_overlap(false) // result is intentionally zero-strided above .check_all_same_dtype(false) .resize_outputs(false) .add_output(result_strided) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 26b03289494c..80c5c509d9bb 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1118,8 +1118,8 @@ Tensor expand_as(const Tensor& self, const Tensor& other) { return self.expand_symint(other.sym_sizes()); } -Tensor sum_to_size(const Tensor& self, IntArrayRef size) { - TORCH_CHECK(is_expandable_to(size, self.sizes()), +Tensor sum_to_size_symint(const Tensor& self, SymIntArrayRef size) { + TORCH_CHECK(is_expandable_to(size, self.sym_sizes()), "size {", size, "} is not expandable to size {", self.sizes(), "}."); return sum_to(self, size); @@ -3397,7 +3397,7 @@ Tensor ravel(const Tensor& self) { static inline void handle_unflatten_exception(const std::runtime_error &e, const Tensor &self, int64_t dim, - IntArrayRef sizes, + SymIntArrayRef sizes, c10::optional names) { if (!strstr(e.what(), "is invalid for input of size")) { TORCH_CHECK(false, "unflatten got an unexpected error:\n", e.what()); @@ -3406,16 +3406,16 @@ static inline void handle_unflatten_exception(const std::runtime_error &e, if (self.has_names()) { TORCH_CHECK(false, "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", - dim, " (", self.names()[dim], ": ", self.size(dim), ") in Tensor", self.names()); + dim, " (", self.names()[dim], ": ", self.sym_size(dim), ") in Tensor", self.names()); } else { TORCH_CHECK(false, "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", - dim, " (", self.size(dim), ") in the input tensor"); + dim, " (", self.sym_size(dim), ") in the input tensor"); } } -Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional names) { +Tensor unflatten_impl(const Tensor& self, int64_t dim, SymIntArrayRef sizes, c10::optional names) { dim = maybe_wrap_dim(dim, self.dim()); TORCH_CHECK(!sizes.empty(), "unflatten: sizes must be non-empty"); @@ -3424,9 +3424,9 @@ Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::o TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes"); } - DimVector inferred_size; + SymDimVector inferred_size; try { - inferred_size = at::infer_size_dv(sizes, self.size(dim)); + inferred_size = at::infer_size_dv(sizes, self.sym_size(dim)); } catch (const std::runtime_error& e) { // at::infer_size would throw std::runtime_error for invalid size, // catch the runtime_error and display the error message in a more user-friendly way @@ -3434,14 +3434,14 @@ Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::o handle_unflatten_exception(e, self, dim, sizes, names); } - DimVector shape(self.sizes().begin(), self.sizes().end()); + SymDimVector shape(self.sym_sizes().begin(), self.sym_sizes().end()); shape.erase(shape.begin() + dim); shape.insert(shape.begin() + dim, inferred_size.begin(), inferred_size.end()); Tensor result; { NoNamesGuard guard; - result = self.view(shape); + result = self.view_symint(shape); } if (names) { @@ -3454,11 
+3454,11 @@ Tensor unflatten_impl(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::o return result; } -Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes) { +Tensor unflatten_symint(const Tensor& self, int64_t dim, SymIntArrayRef sizes) { return native::unflatten_impl(self, dim, sizes, c10::nullopt); } -Tensor unflatten(const Tensor& self, Dimname dim, IntArrayRef sizes, DimnameList names) { +Tensor unflatten_dimname_symint(const Tensor& self, Dimname dim, SymIntArrayRef sizes, DimnameList names) { return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names); } diff --git a/aten/src/ATen/native/cuda/IndexKernel.cpp b/aten/src/ATen/native/cuda/IndexKernel.cpp index b5e92a197700..b0337b96fa80 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cpp +++ b/aten/src/ATen/native/cuda/IndexKernel.cpp @@ -24,8 +24,8 @@ namespace at::native { static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) { NoNamesGuard guard; - TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool, - "masked_select: expected BoolTensor or ByteTensor for mask"); + TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, + "masked_select: expected BoolTensor for mask"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "masked_select(): self and result must have the same scalar type"); diff --git a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu index ead1ff6326ea..87f8c191a6dc 100644 --- a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu +++ b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu @@ -4,6 +4,7 @@ #include #include #include +#include namespace at::native { @@ -28,10 +29,10 @@ FUNCAPI INLINE bool MulOp::apply(bool a, bool b) { return a && b; } -struct LhsProjOp { +struct RhsProjOp { template static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) { - return a; + return b; } }; @@ -68,11 +69,12 @@ template void binary_op_intersection_kernel( TensorIterator& iter, int64_t lhs_nnz_stride, - int64_t rhs_nnz_stride) { + int64_t rhs_nnz_stride, + const Tensor& argsort) { if (!iter.can_use_32bit_indexing()) { for (auto& sub_iter : iter.with_32bit_indexing()) { binary_op_intersection_kernel( - sub_iter, lhs_nnz_stride, rhs_nnz_stride); + sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort); } return; } @@ -82,7 +84,8 @@ void binary_op_intersection_kernel( const auto* RESTRICT ptr_lhs_select_idx_bytes = reinterpret_cast(iter.data_ptr(2)); const auto* RESTRICT ptr_rhs_values_bytes = reinterpret_cast(iter.data_ptr(3)); const auto* RESTRICT ptr_rhs_select_idx_bytes = reinterpret_cast(iter.data_ptr(4)); - const auto* RESTRICT ptr_match_bytes = reinterpret_cast(iter.data_ptr(5)); + const auto* RESTRICT ptr_intersction_counts_bytes = reinterpret_cast(iter.data_ptr(5)); + const auto* RESTRICT ptr_argsort = argsort.data_ptr(); auto offset_calc = make_offset_calculator<6>(iter); auto loop = [=] FUNCAPI (int i) { @@ -93,15 +96,22 @@ void binary_op_intersection_kernel( const auto lhs_nnz_idx = *reinterpret_cast(ptr_lhs_select_idx_bytes + offsets[2]); const auto* RESTRICT ptr_rhs_values = reinterpret_cast(ptr_rhs_values_bytes + offsets[3]); const auto rhs_nnz_idx = *reinterpret_cast(ptr_rhs_select_idx_bytes + offsets[4]); - const auto match = *reinterpret_cast(ptr_match_bytes + offsets[5]); - - if (match) { - *ptr_res_values = binary_op_t::apply( - *(ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride), - 
*(ptr_rhs_values + rhs_nnz_idx * rhs_nnz_stride)); - } else { - *ptr_res_values = 0; + const auto count = *reinterpret_cast(ptr_intersction_counts_bytes + offsets[5]); + + const auto* RESTRICT ptr_lhs_begin = ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride; + const auto* RESTRICT ptr_rhs_sorted_nnz_idx = ptr_argsort + rhs_nnz_idx; + + using accscalar_t = at::acc_type; + accscalar_t res_values = 0; + accscalar_t lhs_values = static_cast(*ptr_lhs_begin); + accscalar_t rhs_values; + index_t rhs_sorted_nnz_idx; + for (int64_t c = 0; c < count; ++c) { + rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++; + rhs_values = static_cast(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride)); + res_values += binary_op_t::apply(lhs_values, rhs_values); } + *ptr_res_values = static_cast(res_values); }; launch_kernel(iter.numel(), loop); @@ -115,13 +125,14 @@ struct CUDAValueSelectionIntersectionKernel { const Tensor& lhs_select_idx, const Tensor& rhs_values, const Tensor& rhs_select_idx, - const c10::optional& match_mask = c10::nullopt) { + const Tensor& intersection_counts, + const Tensor& argsort) { auto iter = make_value_selection_intersection_iter( lhs_values, lhs_select_idx, rhs_values, rhs_select_idx, - match_mask); + intersection_counts); auto res_values = iter.tensor(0); // If res_values is empty, we can return it right away. @@ -136,11 +147,10 @@ struct CUDAValueSelectionIntersectionKernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, res_values.scalar_type(), "binary_op_intersection_cpu", [&] { - AT_DISPATCH_INDEX_TYPES(lhs_select_idx.scalar_type(), - "binary_op_intersection_cpu", [&] { - binary_op_intersection_kernel( - iter, lhs_nnz_stride, rhs_nnz_stride); - }); + // COO indices are only 64-bit for now. 
+ using index_t = int64_t; + binary_op_intersection_kernel( + iter, lhs_nnz_stride, rhs_nnz_stride, argsort); }); return res_values; @@ -161,8 +171,8 @@ void sparse_mask_intersection_out_cuda_kernel( Tensor& result, const Tensor& x, const Tensor& y) { - using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel; - _sparse_binary_op_intersection_kernel_out( + using CUDAValueRhsProjKernel = CUDAValueSelectionIntersectionKernel; + _sparse_binary_op_intersection_kernel_out( result, x, y, true ); } diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 4547512b4953..32084d4b4a30 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -404,6 +404,9 @@ void resize_tensor(Tensor* output) { // this is meant to suppress the availability warning on castTensor // we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) { + if ([tensor dataType] == toType) { + return tensor; + } return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"]; } diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 65ec86f533a1..2ae120e5abbc 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -233,7 +233,9 @@ void div_mode_template(const Tensor& self, const Tensor& other, void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name) { if (alpha.toDouble() == 0.0) { - const_cast(output) = self.clone(); + if (!self.is_alias_of(output)) { // if inplace, no-op + const_cast(output) = self.clone(); + } return; } diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index 0af63e1a4a06..74d6bdec6948 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -189,8 +189,8 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) { NoNamesGuard guard; - TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool, - "masked_select: expected BoolTensor or ByteTensor for mask"); + TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, + "masked_select: expected BoolTensor for mask"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "masked_select(): self and result must have the same scalar type"); diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index 92109c64caf1..034786e3be3f 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -6,21 +6,21 @@ // scope the MPS's internal methods to not expose them to at::native namespace mps { -Tensor& addc_mul_div_out_mps(const Tensor& self, +void addc_mul_div_out_mps(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value_opt, // default value = 1.0 - Tensor& output, + const Tensor& output, const bool is_div, const string op_name) { if (value_opt.toDouble() == 0.0) { output.copy_(self); - return output; + return; } if(output.numel() == 0) { - return output; + return; } MPSStream* mpsStream = getCurrentMPSStream(); 
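
A minimal Python sketch of the user-visible behaviour the PointwiseOps.mm hunks here are concerned with: the value == 0 early return in the hunk above and the common-dtype casting added in the hunk below. The sketch is illustrative only (not part of the patch); it assumes an MPS-enabled build, and the tensor values and tolerances are arbitrary.

```python
# Illustrative sketch only -- assumes an MPS-enabled PyTorch build.
import torch

if torch.backends.mps.is_available():
    self_t = torch.rand(4, device="mps")
    t1 = torch.rand(4, device="mps", dtype=torch.float16)  # deliberately a different dtype
    t2 = torch.rand(4, device="mps") + 0.5                 # keep the divisor away from zero

    # value == 0 takes the early-return path and simply copies `self_t`.
    out_zero = torch.addcmul(self_t, t1, t2, value=0.0)
    assert torch.equal(out_zero, self_t)

    # Mixed input dtypes rely on the castMPSTensor-based promotion in the next hunk.
    out_mixed = torch.addcdiv(self_t, t1, t2, value=0.5)
    ref = torch.addcdiv(self_t.cpu(), t1.cpu(), t2.cpu(), value=0.5)
    torch.testing.assert_close(out_mixed.cpu(), ref, rtol=2e-3, atol=2e-3)
```
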
@@ -38,10 +38,11 @@ CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { + if (!cachedGraph) { cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph* newCachedGraph = nil; + ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type())); @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -53,22 +54,25 @@ // the tensor to be optionally multiplied by value_scalar MPSGraphTensor *multiplicandTensor = nil; + auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype); + auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype); if (is_div) { - multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->firstTensor - secondaryTensor:newCachedGraph->secondTensor + multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor + secondaryTensor:secondTensor name:nil]; } else { - multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:newCachedGraph->firstTensor - secondaryTensor:newCachedGraph->secondTensor + multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor + secondaryTensor:secondTensor name:nil]; } // the tensor to be added to input_tensor MPSGraphTensor *addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor - secondaryTensor:newCachedGraph->valueTensor + secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype) name:nil]; - newCachedGraph->outputTensor = [mpsGraph additionWithPrimaryTensor:newCachedGraph->inputTensor - secondaryTensor:addendTensor - name:nil]; + auto outputTensor = [mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype) + secondaryTensor:addendTensor + name:nil]; + newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type()); } return newCachedGraph; }); @@ -95,8 +99,6 @@ runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results); } - - return output; } } // namespace mps @@ -105,13 +107,13 @@ TORCH_IMPL_FUNC(addcmul_out_mps) (const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) { - mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, const_cast(output), false, "addcmul_out_mps"); + mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, false, "addcmul_out_mps"); } TORCH_IMPL_FUNC(addcdiv_out_mps) (const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) { - mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, const_cast(output), true, "addcdiv_out_mps"); + mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, true, "addcdiv_out_mps"); } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm index 9e59a6cf7021..fbf0a99d20a1 100644 --- a/aten/src/ATen/native/mps/operations/RnnOps.mm +++ b/aten/src/ATen/native/mps/operations/RnnOps.mm @@ -23,6 +23,85 @@ return output_dimensions; } +/** + * Accepts tensors in Pytorch API format and returns tensors in MPS API format + * @return tuple of tensors to use with MPS API in order: + * stateTensor, cellStateTensor, recurrentWeight, inputWeight, biasTensor + */ +static std::tuple + getMPSTensorsFromPytorchTensors(MPSGraph* mpsGraph, + MPSGraphTensor* stateTensor, MPSGraphTensor* cellStateTensor, + NSMutableArray 
*recurrentKernelWeightsList, + NSMutableArray *kernelWeightsList, + NSMutableArray *kernelBiasList, + NSMutableArray *recurrentBiasList, + bool has_biases, bool bidirectional, size_t layer_no) { + MPSGraphTensor* biasTensor_ = nil; + MPSGraphTensor* stateTensor_ = nil, *cellStateTensor_ = nil; + MPSGraphTensor* recurrentWeight_ = nil, *inputWeight_ = nil; + + if (bidirectional) { + stateTensor_ = [mpsGraph sliceTensor:stateTensor + dimension:0 + start:layer_no * 2 + length:2 + name:nil]; + // [2, N, H] -> [N, 2, H] + stateTensor_ = [mpsGraph transposeTensor:stateTensor_ dimension: 0 withDimension: 1 name:nil]; + // [N, 2, H] -> [N, 2 * H] + stateTensor_ = [mpsGraph flatten2DTensor:stateTensor_ axis:1 name:nil]; + cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor + dimension:0 + start:layer_no * 2 + length:2 + name:nil]; + cellStateTensor_ = [mpsGraph transposeTensor:cellStateTensor_ dimension: 0 withDimension: 1 name:nil]; + cellStateTensor_ = [mpsGraph flatten2DTensor:cellStateTensor_ axis:1 name:nil]; + + recurrentWeight_ = [mpsGraph + concatTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2] axis: 0 name: nil] + withTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2 + 1] axis: 0 name: nil] + dimension: 0 + name: nil + ]; + inputWeight_ = [mpsGraph + concatTensor: kernelWeightsList[layer_no * 2] + withTensor: kernelWeightsList[layer_no * 2 + 1] + dimension: 0 + name: nil + ]; + if (has_biases) { + auto biasTensorFwd_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2] + secondaryTensor:recurrentBiasList[layer_no * 2] + name:nil]; + auto biasTensorBack_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2 + 1] + secondaryTensor:recurrentBiasList[layer_no * 2 + 1] + name:nil]; + + biasTensor_ = [mpsGraph concatTensor:biasTensorFwd_ withTensor:biasTensorBack_ dimension:0 name:nil]; + } + } else { + stateTensor_ = [mpsGraph sliceTensor:stateTensor + dimension:0 + start:layer_no + length:1 + name:nil]; + cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor + dimension:0 + start:layer_no + length:1 + name:nil]; + recurrentWeight_ = recurrentKernelWeightsList[layer_no]; + inputWeight_ = kernelWeightsList[layer_no]; + if (has_biases) { + biasTensor_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no] + secondaryTensor:recurrentBiasList[layer_no] + name:nil]; + } + } + return std::make_tuple(stateTensor_, cellStateTensor_, recurrentWeight_, inputWeight_, biasTensor_); +} + std::tuple _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { using namespace mps; @@ -38,15 +117,17 @@ std::vector recurrent_kernel_weights; std::vector biases; std::vector recurrent_biases; - for (size_t i = 0; i < num_layers; i+=1) { + + const int64_t total_layers = num_layers * (bidirectional ? 2 : 1); + + for (const auto i : c10::irange(total_layers)) { + const int stride = (has_biases ? 
4 : 2); + kernel_weights.push_back(params[i*stride]); + recurrent_kernel_weights.push_back(params[i*stride+1]); + if (has_biases) { - kernel_weights.push_back(params[i*4]); - recurrent_kernel_weights.push_back(params[i*4+1]); - biases.push_back(params[i*4+2]); - recurrent_biases.push_back(params[i*4+3]); - } else { - kernel_weights.push_back(params[i*2]); - recurrent_kernel_weights.push_back(params[i*2+1]); + biases.push_back(params[i*stride+2]); + recurrent_biases.push_back(params[i*stride+3]); } } @@ -65,7 +146,7 @@ MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers); + string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + std::to_string(batch_first); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -81,7 +162,7 @@ NSMutableArray *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; NSMutableArray *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers]; - for (size_t i = 0; i < num_layers; i += 1) { + for (const auto i : c10::irange(total_layers)) { [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))]; [recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))]; if(has_biases) { @@ -100,7 +181,7 @@ MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(hx[1])); std::vector inputTensors = {inputTensor, stateTensor, cellStateTensor,}; - if(batch_first) { + if (batch_first) { inputTensor = [mpsGraph transposeTensor:inputTensor dimension:0 withDimension:1 @@ -113,49 +194,61 @@ NSMutableArray* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; NSMutableArray* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; NSMutableArray* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - for(int i = 0; i < num_layers; i++) { - MPSGraphTensor* biasTensor = nil; - if(has_biases) { - biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i] - secondaryTensor:recurrentBiasList[i] - name:nil]; - } - MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor - dimension:0 - start:i - length:1 - name:nil]; - MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor - dimension:0 - start:i - length:1 - name:nil]; + for (int i = 0; i < num_layers; i++) { + auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor, + recurrentKernelWeightsList, kernelWeightsList, + kernelBiasList, recurrentBiasList, has_biases, + bidirectional, i); + MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData); + MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData); + MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData); + + outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_ - 
recurrentWeight:recurrentKernelWeightsList[i] - inputWeight:kernelWeightsList[i] - bias:biasTensor + recurrentWeight:recurrentWeight_ + inputWeight:inputWeight_ + bias:biasTensor_ initState:stateTensor_ initCell:cellStateTensor_ descriptor:opDesc name:nil]; inputTensor_ = [outputs objectAtIndex:0]; - // no need to keep a final layer output copy as it is + // no need to keep the final layer output copy as it is // returned anyway and not used in backprop - if(i != num_layers - 1) { + if (i != num_layers - 1) { [layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_ axis:0 name:nil]]; } - if(dropout_p>0.0 && train && (i!=num_layers-1)) { + if (dropout_p>0.0 && train && (i!=num_layers-1)) { inputTensor_ = [mpsGraph dropoutTensor:inputTensor_ rate:dropout_p name:nil]; } - [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]]; - [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]]; + if (bidirectional) { + // [1, N, 2 * H] + auto stateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]; + auto stateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:0 length:1 name:nil]; + // [1, N, H] ([1, N, 0:H]) + auto stateForward = [mpsGraph sliceTensor:stateLastT dimension: -1 start:0 length:hx[0].sizes()[2] name:nil]; + // [1, N, H] ([1, N, H:2H]) + auto stateBack = [mpsGraph sliceTensor:stateFirstT dimension: -1 start:hx[0].sizes()[2] length:hx[0].sizes()[2] name:nil]; + [outputStateArray addObject:stateForward]; + [outputStateArray addObject:stateBack]; + + auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]; + auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:0 length:1 name:nil]; + auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT dimension: -1 start:0 length:hx[1].sizes()[2] name:nil]; + auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT dimension: -1 start:hx[1].sizes()[2] length:hx[1].sizes()[2] name:nil]; + [outputCellStateArray addObject:cellStateForward]; + [outputCellStateArray addObject:cellStateBack]; + } else { + [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]]; + [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]]; + } [outputCellStateFwdArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1] axis:0 name:nil]]; @@ -205,21 +298,18 @@ NSMutableArray *biasList = cachedGraph->biasList_; NSMutableArray *recurrentBiasList = cachedGraph->recurrentBiasList_; - Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias; - NSMutableDictionary *feeds = [[[NSMutableDictionary alloc] init] autorelease]; - for (size_t i = 0; i < num_layers; i+=1) { - kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); - recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); + for (const auto i : c10::irange(total_layers)) { + Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); + Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()]; [feeds 
setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()]; - if(has_biases) { - bias = Placeholder([biasList objectAtIndex:i], biases[i]); - recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); + if (has_biases) { + Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]); + Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()]; [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()]; } - } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input); Placeholder selfState = Placeholder(cachedGraph->inputTensors_[1], hx[0]); @@ -274,22 +364,22 @@ std::vector recurrent_kernel_weights; std::vector biases; std::vector recurrent_biases; - for (size_t i = 0; i < num_layers; i+=1) { - if(has_biases) { - kernel_weights.push_back(params[i*4]); - recurrent_kernel_weights.push_back(params[i*4+1]); - biases.push_back(params[i*4+2]); - recurrent_biases.push_back(params[i*4+3]); - } else { - kernel_weights.push_back(params[i*2]); - recurrent_kernel_weights.push_back(params[i*2+1]); - } + + const int64_t total_layers = num_layers * (bidirectional ? 2 : 1); + + for (const auto i : c10::irange(total_layers)) { + const int stride = (has_biases ? 4 : 2); + kernel_weights.push_back(params[i*stride]); + recurrent_kernel_weights.push_back(params[i*stride+1]); + if(has_biases) { + biases.push_back(params[i*stride + 2]); + recurrent_biases.push_back(params[i*stride + 3]); + } } struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} std::vector inputTensors_; - std::vector outputTensors_; NSMutableArray *kernelWeightsList_ = nil; NSMutableArray *recurrentKernelWeightsList_ = nil; NSMutableArray *biasList_ = nil; @@ -308,9 +398,9 @@ MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers); + string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" + std::to_string(batch_first); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { + if (!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -323,10 +413,10 @@ NSMutableArray *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; NSMutableArray *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - for (size_t i = 0; i < num_layers; i += 1) { + for (const auto i : c10::irange(total_layers)) { [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))]; [recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))]; - if(has_biases) { + if (has_biases) { [kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))]; 
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))]; } @@ -377,6 +467,8 @@ NSMutableArray* gradStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; NSMutableArray* gradCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + auto hidden_size = hx[0].sizes()[2]; + for (int i = num_layers - 1; i >= 0; i--) { MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor dimension:0 @@ -394,37 +486,47 @@ cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd axis:0 name:nil]; - MPSGraphTensor* biasTensor = nil; - if(has_biases) { - biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i] - secondaryTensor:recurrentBiasList[i] - name:nil]; + auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor, + recurrentKernelWeightsList, kernelWeightsList, + kernelBiasList, recurrentBiasList, has_biases, + bidirectional, i); + MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData); + MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData); + MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData); + + MPSGraphTensor* gradientHyTensor_ = nil, *gradientCyTensor_ = nil; + if (bidirectional) { + gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor + dimension:0 + start:i * 2 + length:2 + name:nil]; + // [2, N, H] -> [N, 2, H] + gradientHyTensor_ = [mpsGraph transposeTensor:gradientHyTensor_ dimension: 0 withDimension: 1 name:nil]; + // [N, 2, H] -> [N, 2 * H] + gradientHyTensor_ = [mpsGraph flatten2DTensor:gradientHyTensor_ axis:1 name:nil]; + + + gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor + dimension:0 + start:i * 2 + length:2 + name:nil]; + gradientCyTensor_ = [mpsGraph transposeTensor:gradientCyTensor_ dimension: 0 withDimension: 1 name:nil]; + gradientCyTensor_ = [mpsGraph flatten2DTensor:gradientCyTensor_ axis:1 name:nil]; } else { - biasTensor = [mpsGraph constantWithScalar:0.0 - dataType:inputTensor.dataType]; - } + gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor + dimension:0 + start:i + length:1 + name:nil]; - MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor - dimension:0 - start:i - length:1 - name:nil]; - MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor - dimension:0 - start:i - length:1 - name:nil]; - MPSGraphTensor* gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor - dimension:0 - start:i - length:1 - name:nil]; - - MPSGraphTensor* gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor - dimension:0 - start:i - length:1 - name:nil]; + gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor + dimension:0 + start:i + length:1 + name:nil]; + } MPSGraphTensor* iterationInputTensor_ = nil; if (i == 0) { @@ -432,8 +534,9 @@ } else { iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor dimension: 0 - // last element in layersOutputsTensor contains - // **inputs** for the last layer + // the last element in layersOutputsTensor + // contains **inputs** for the **last** layer + // and so on start: i - num_layers length: 1 name: nil]; @@ -443,14 +546,14 @@ } outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_ - recurrentWeight: recurrentKernelWeightsList[i] + recurrentWeight: recurrentWeight_ sourceGradient: gradientTensor_ zState: zState cellOutputFwd: cellStateFwd stateGradient: gradientHyTensor_ cellGradient: gradientCyTensor_ - inputWeight: 
kernelWeightsList[i] - bias: biasTensor + inputWeight: inputWeight_ + bias: biasTensor_ initState: stateTensor_ initCell: cellStateTensor_ mask: nil @@ -459,14 +562,103 @@ name: nil]; gradientTensor_ = [outputs objectAtIndex:0]; - [gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0]; - [gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0]; - [gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0]; - [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil] atIndex:0]; - [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0]; - } - std::vector outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]}; + if (bidirectional) { + int outputIter = 1; + auto gradRecWeightsBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradRecWeightFwd = [mpsGraph sliceTensor:gradRecWeightsBidirectional + dimension: 0 + start: 0 + length: 1 + name: nil]; + gradRecWeightFwd = [mpsGraph squeezeTensor:gradRecWeightFwd axis:0 name: nil]; + auto gradRecWeightBack = [mpsGraph sliceTensor:gradRecWeightsBidirectional + dimension: 0 + start: 1 + length: 1 + name: nil]; + gradRecWeightBack = [mpsGraph squeezeTensor:gradRecWeightBack axis:0 name: nil]; + + // inverse order + [gradRecWeightsArray insertObject:gradRecWeightBack atIndex:0]; + [gradRecWeightsArray insertObject:gradRecWeightFwd atIndex:0]; + auto gradWeightsBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradWeightFwd = [mpsGraph sliceTensor:gradWeightsBidirectional + dimension: 0 + start: 0 + length: hidden_size * 4 + name: nil]; + auto gradWeightBack = [mpsGraph sliceTensor:gradWeightsBidirectional + dimension: 0 + start: hidden_size * 4 + length: hidden_size * 4 + name: nil]; + + [gradWeightsArray insertObject:gradWeightBack atIndex:0]; + [gradWeightsArray insertObject:gradWeightFwd atIndex:0]; + + if (has_biases) { + // has shape [1, 1, 8H] vs [8H] as should be + // so, squeeze these two first dimensions + auto gradBiasBidirectional = [outputs objectAtIndex:outputIter++]; + gradBiasBidirectional = [mpsGraph squeezeTensor: gradBiasBidirectional + axes: @[@0, @1] + name: nil]; + auto gradBiasFwd = [mpsGraph sliceTensor:gradBiasBidirectional + dimension: 0 + start: 0 + length: hidden_size * 4 + name: nil]; + auto gradBiasBack = [mpsGraph sliceTensor:gradBiasBidirectional + dimension: 0 + start: hidden_size * 4 + length: hidden_size * 4 + name: nil]; + + [gradBiasArray insertObject: gradBiasBack atIndex:0]; + [gradBiasArray insertObject: gradBiasFwd atIndex:0]; + } + + auto gradStateBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradStateFwd = [mpsGraph sliceTensor:gradStateBidirectional + dimension: 1 + start: 0 + length: hidden_size + name: nil]; + auto gradStateBack = [mpsGraph sliceTensor:gradStateBidirectional + dimension: 1 + start: hidden_size + length: hidden_size + name: nil]; + + [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateBack axis:0 name:nil] atIndex:0]; + [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateFwd axis:0 name:nil] atIndex:0]; + + auto gradCellStateBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradCellStateFwd = [mpsGraph sliceTensor:gradCellStateBidirectional + dimension: 1 + start: 0 + length: hidden_size + name: nil]; + auto gradCellStateBack = [mpsGraph 
sliceTensor:gradCellStateBidirectional + dimension: 1 + start: hidden_size + length: hidden_size + name: nil]; + + [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil] atIndex:0]; + [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil] atIndex:0]; + } else { + int outputIter = 1; + [gradRecWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; + [gradWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; + if (has_biases) { + [gradBiasArray insertObject: [outputs objectAtIndex:outputIter++] atIndex:0]; + } + [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0]; + [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0]; + } + } if (batch_first) { MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_ dimension: 0 @@ -477,7 +669,6 @@ newCachedGraph->gradOutput_ = gradientTensor_; } - newCachedGraph->outputTensors_ = outputTensors; newCachedGraph->gradRecWeights_ = gradRecWeightsArray; newCachedGraph->gradWeights_ = gradWeightsArray; newCachedGraph->gradBias_ = gradBiasArray; @@ -514,18 +705,15 @@ NSMutableArray *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_; NSMutableArray *biasList = cachedGraph->biasList_; NSMutableArray *recurrentBiasList = cachedGraph->recurrentBiasList_; - Placeholder kernelWeight; - Placeholder recurrentKernelWeight; - Placeholder bias; - Placeholder recurrentBias; - for (size_t i = 0; i < num_layers; i+=1) { - kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); - recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); + + for (const auto i : c10::irange(total_layers)) { + Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); + Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()]; [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()]; - if(has_biases) { - bias = Placeholder([biasList objectAtIndex:i], biases[i]); - recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); + if (has_biases) { + Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]); + Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()]; [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()]; } @@ -556,25 +744,32 @@ Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder; std::vector weights; - for (int i = 0; i < num_layers; i++) { + for (const auto i : c10::irange(total_layers)) { Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]); Tensor grad_weights = at::empty_like(kernel_weights[i]); - Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options()); + weights.push_back(grad_weights); weights.push_back(grad_rec_weights); - if(has_biases) { - weights.push_back(grad_bias); - weights.push_back(grad_bias); - } - gradRecWeightsPlaceholder = 
Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights); gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights); - gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias); - [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()]; [results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()]; [results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()]; + + if (has_biases) { + Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options()); + + // In PyTorch LSTM API there are two biases. The second bias is included for CuDNN compatibility. + // In this implementation these two biases are added together and used further. + // Therefore, they have equal gradient, and it is pushed + // twice for each of two bias vectors. + weights.push_back(grad_bias); + weights.push_back(grad_bias); + + gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias); + [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()]; + } } runMPSGraph(stream, cachedGraph->graph(), feeds, results); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e4d718705a42..3c6d2e2dfc86 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -171,7 +171,7 @@ CUDA: _assert_async_cuda -- func: _assert_tensor_metadata(Tensor a, int[]? size=None, int[]? stride=None, ScalarType? dtype=None) -> () +- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> () - func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) variants: method @@ -610,13 +610,13 @@ MPS: addr_out_mps CompositeExplicitAutograd: math_addr_out -- func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor +- func: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor variants: function dispatch: CompositeExplicitAutograd: affine_grid_generator autogen: affine_grid_generator.out -- func: affine_grid_generator_backward(Tensor grad, int[] size, bool align_corners) -> Tensor +- func: affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor variants: function - func: _is_all_true(Tensor self) -> Tensor @@ -2285,7 +2285,7 @@ autogen: new_ones.out # other overrides are to provide a more helpful error message that dtype is required -- func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor +- func: _empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor dispatch: CPU: empty_affine_quantized_other_backends_stub QuantizedCPU, QuantizedCUDA: empty_affine_quantized @@ -2293,7 +2293,7 @@ # it's a factory function receiving a tensor argument, thus overriding explicitly # other overrides are to provide a more helpful error message that dtype is required -- func: _empty_per_channel_affine_quantized(int[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor +- func: _empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor category_override: factory dispatch: CPU: empty_per_channel_affine_quantized_other_backends_stub @@ -2318,7 +2318,7 @@ # This is a utility function to enable users to resize out tensor while registering kernels for out variants. # Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration # to make it easy to register out variants for ops. -- func: _resize_output_(Tensor(a!) self, int[] size, Device device) -> Tensor(a!) +- func: _resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: function dispatch: @@ -2520,11 +2520,15 @@ - func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) variants: function, method -- func: unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a) +- func: unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) variants: function, method + dispatch: + CompositeImplicitAutograd: unflatten_symint -- func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a) +- func: unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) variants: function, method + dispatch: + CompositeImplicitAutograd: unflatten_dimname_symint - func: fill.Scalar(Tensor self, Scalar value) -> Tensor variants: function @@ -5389,10 +5393,12 @@ CPU, CUDA: nansum_out MPS: nansum_out_mps -- func: sum_to_size(Tensor self, int[] size) -> Tensor +- func: sum_to_size(Tensor self, SymInt[] size) -> Tensor variants: method device_check: NoCheck device_guard: False + dispatch: + CompositeImplicitAutograd: sum_to_size_symint - func: sqrt(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -6061,7 +6067,7 @@ CompositeExplicitAutograd: zeros autogen: zeros.names_out -- func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor dispatch: CPU: _efficientzerotensor CUDA: _efficientzerotensor_cuda diff --git a/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp b/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp index 8e3739a78d6f..f1375b17b4d4 100644 --- a/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp +++ b/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp @@ -15,6 +15,7 @@ #else #include #include +#include #include #include #endif @@ -65,13 +66,20 @@ std::tuple> PackedLinearWeightsQnnp:: return std::tuple>(orig_weight, bias_); } else{ - TORCH_WARN( - "Original weight is freed, we are converting pre-packed weight to original weight."); - uint8_t* kernel = w->unpackWeights(w_zero_points.data(), n_elements); - at::Tensor original_tensor = at::from_blob(kernel, weight_sizes, c10::kByte).clone().toType(c10::kQInt8); - original_tensor.sub_(128); - free(kernel); - return std::tuple>(original_tensor, bias_); + float* weight_scales_data = w_scales.data_ptr(); + at::Tensor weight_origin; + weight_origin = at::empty(weight_sizes, at::device(c10::kCPU).dtype(at::kChar)); + int8_t* weight_ptr_int8 = + reinterpret_cast(weight_origin.data_ptr()); + w->unpackWeights(w_zero_points.data(), weight_ptr_int8); + // See for the subtraction 128 + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L319 + weight_origin.sub_(128); + // As of now, we are supporting only per tensor quantizer + // TO-DO : Support a per channel as well. + at::Tensor original_quantized_tensor = at::_make_per_tensor_quantized_tensor(weight_origin, weight_scales_data[0], w_zero_points[0]); + TORCH_CHECK(original_quantized_tensor.qscheme() == c10::kPerTensorAffine); + return std::tuple>(original_quantized_tensor, bias_); } } #endif // USE_PYTORCH_QNNPACK diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index cfa9dcdb7028..39eb047edc17 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -50,7 +50,6 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase { w_scales(std::move(w_scales)), w_zero_points(std::move(w_zps)) { weight_sizes = this->orig_weight.sizes().vec(); - n_elements = std::accumulate(std::begin(weight_sizes), std::end(weight_sizes), 1, std::multiplies()); } std::unique_ptr w; @@ -62,7 +61,6 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase { std::vector w_zero_points; std::vector requantization_scales; std::vector weight_sizes; - int n_elements; at::Tensor apply( at::Tensor input, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h b/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h index eeadbaf91181..10bbc000192d 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h @@ -66,9 +66,9 @@ class PackBMatrix final { return packed_weights_; } - uint8_t* unpackWeights( + void unpackWeights( const uint8_t* kernel_zero_points, - int n_elements + int8_t* kernel ) const; size_t getInputChannels() const diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc index 2b2922d2bf37..ce5e1fec7d4e 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc @@ -32,7 +32,6 @@ PackBMatrix::PackBMatrix( const 
uint32_t n_stride = (output_channels + (nr - 1)) & -nr; const uint32_t k_stride = (input_channels + (kr - 1)) & -kr; - input_channels_ = input_channels; output_channels_ = output_channels; packed_weights_ = diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-unpack.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-unpack.cc index d142567b90ef..02610c42c7b3 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-unpack.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-unpack.cc @@ -8,9 +8,9 @@ namespace qnnpack { // For runtime quantization unpacking. -uint8_t* PackBMatrix::unpackWeights( +void PackBMatrix::unpackWeights( const uint8_t* kernel_zero_points, - int n_elements + int8_t* kernel ) const { union { void* const as_void_ptr; @@ -18,8 +18,6 @@ uint8_t* PackBMatrix::unpackWeights( int32_t* as_int32_ptr; } packed = {packed_weights_}; - uint8_t* kernel = (uint8_t*)malloc(n_elements * sizeof(uint8_t));; - // C = A * B // A = M*K // B = K*N @@ -67,7 +65,6 @@ uint8_t* PackBMatrix::unpackWeights( } } - return kernel; } } // namespace qnnpack diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index 9b2a8be7ef9a..88677aa5d50f 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -13,9 +13,7 @@ #else #include #include -#include #include -#include #include #endif @@ -25,23 +23,6 @@ #define NAME "sparse_binary_op_intersection_cpu" #endif -#define CALL(...) __VA_ARGS__(); -#define EXPAND(b, n, ...) \ - if (b) { \ - using index_t ## n = int32_t; \ - __VA_ARGS__ \ - } \ - else { \ - using index_t ## n = int64_t; \ - __VA_ARGS__ \ - } -#define BOOL_TO_INDEX_TYPE1(b0, ...) \ - EXPAND(b0, 0, CALL(__VA_ARGS__)) -#define BOOL_TO_INDEX_TYPE2(b1, b0, ...) \ - EXPAND(b1, 1, BOOL_TO_INDEX_TYPE1(b0, __VA_ARGS__)) -#define BOOL_TO_INDEX_TYPE3(b2, b1, b0, ...) 
\ - EXPAND(b2, 2, BOOL_TO_INDEX_TYPE2(b1, b0, __VA_ARGS__)) - namespace at { namespace native { @@ -99,7 +80,7 @@ TensorIterator make_value_selection_intersection_iter( const Tensor& lhs_select_idx, const Tensor& rhs_values, const Tensor& rhs_select_idx, - const c10::optional& match_mask_opt = c10::nullopt) { + const Tensor& intersection_counts) { const auto res_values_sizes = [&]() -> std::vector { auto sizes = infer_size( // keep nnz dim @@ -128,14 +109,6 @@ TensorIterator make_value_selection_intersection_iter( return values.as_strided(values_sizes, values_strides); }; - const auto match_mask = [&match_mask_opt, &lhs_select_idx]() -> Tensor { - if (match_mask_opt.has_value()) { - return *match_mask_opt; - } else { - return at::ones_like(lhs_select_idx); - } - }(); - auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) .check_all_same_dtype(false) @@ -145,7 +118,7 @@ TensorIterator make_value_selection_intersection_iter( .add_owned_input(restride_idx(lhs_select_idx)) .add_owned_input(restride_values(rhs_values)) .add_owned_input(restride_idx(rhs_select_idx)) - .add_owned_input(restride_idx(match_mask)) + .add_owned_input(restride_idx(intersection_counts)) .build(); return iter; @@ -155,15 +128,14 @@ template < template class kernel_t, typename value_selection_intersection_kernel_t, typename index_t = int64_t, - typename hash_t = int64_t, - typename offset_t = int64_t> + int64_t max_static_len = 0> void _sparse_binary_op_intersection_kernel_impl( Tensor& res, const Tensor& x_, const Tensor& y_, const std::vector broadcasted_shape, const bool restrict_indices_to_rhs = false, - const bool commutes_with_sum = true + const bool distributive_with_sum = true ) { // The common dtype check is relevant when op is done in-place. // This is because binary_of_t produces new values and it could be that @@ -176,12 +148,12 @@ void _sparse_binary_op_intersection_kernel_impl( using KernelLauncher = KernelLauncher; - // If the op and sum are not commutative, coalesce is required. + // If the op is not distributive with respect to the sum, coalesce is required. // If restrict_indices_to_rhs is true, x needs to be coalesced so that // (x.coalesce() intersection y union y).indices().counts() == y.indices().counts(). - const Tensor x = (!commutes_with_sum || restrict_indices_to_rhs) ? x_.coalesce() : x_; + const Tensor x = (!distributive_with_sum || restrict_indices_to_rhs) ? x_.coalesce() : x_; const Tensor y = [&]() -> Tensor { - auto rhs = commutes_with_sum ? y_ : y_.coalesce(); + auto rhs = distributive_with_sum ? y_ : y_.coalesce(); if (restrict_indices_to_rhs) { // x is coalesced and y is marked as uncoalesced so that the intersection result // respects the order of indices in y. @@ -262,25 +234,19 @@ void _sparse_binary_op_intersection_kernel_impl( // which is implicit in the definition of hash_coeffs, // it could be shown that the hash function is actually bijective and, hence, // is a perfect hash function (no collisions ever). - const auto kHash = std::is_same::value ? kLong : kInt; - const auto hash_coeffs = [&]() -> Tensor { + + // Need owning storage in case of the Tensor class. 
+ const auto hash_coeffs_storage = [&]() -> auto { const auto broadcasted_sparse_dim_shape = std::vector( broadcasted_shape.begin(), broadcasted_shape.begin() + probably_coalesced.sparse_dim() ); - auto strides = contiguous_strides(broadcasted_sparse_dim_shape); - auto strides_len = static_cast(strides.size()); - auto hash_coeffs = at::empty( - {strides_len}, - probably_coalesced._indices().options().device(kCPU).dtype(kHash)); - // Copy with a potential casting. Is there a nicer way? - for (const auto i : c10::irange(strides_len)) { - hash_coeffs[i] = strides[i]; - } - hash_coeffs = hash_coeffs.to(probably_coalesced.device()); - return hash_coeffs; + auto strides = c10::contiguous_strides(broadcasted_sparse_dim_shape); + return at::sparse::TensorGeometryHolder(strides, strides, probably_coalesced.options()); }(); + const auto hash_coeffs = std::get<0>(*hash_coeffs_storage); + const auto nnz_arange = at::arange( std::max(probably_coalesced._nnz(), source._nnz()), source._indices().options()); @@ -288,8 +254,6 @@ void _sparse_binary_op_intersection_kernel_impl( // non-const because of gcc-5/clang-5 issues auto sparse_dim = probably_coalesced.sparse_dim(); - // non-const because of gcc-5/clang-5 issues - auto sdim = static_cast(sparse_dim); // Apply the hash function to probably_coalesced.indices const auto probably_coalesced_indices_hash = [&]() -> Tensor { @@ -299,10 +263,9 @@ void _sparse_binary_op_intersection_kernel_impl( auto indices_nnz_stride = indices.stride(1); auto hash = at::empty({probably_coalesced._nnz()}, - indices.options().dtype(kHash)); + indices.options().dtype(kLong)); auto iter = TensorIteratorConfig() - // Hash has hash_t type while probably_coalesced_nnz_arange is index_t. .check_all_same_dtype(false) .add_output(hash) .add_input(probably_coalesced_nnz_arange) @@ -310,17 +273,15 @@ void _sparse_binary_op_intersection_kernel_impl( { const auto* RESTRICT ptr_indices = indices.data_ptr(); - const auto* RESTRICT ptr_hash_coeffs = hash_coeffs.template data_ptr(); KernelLauncher::launch(iter, // NOTE: capture by value required by CUDA - [=] FUNCAPI (index_t nnz_idx) -> hash_t { + [=] FUNCAPI (index_t nnz_idx) -> int64_t { const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride; - auto hash = hash_t {0}; - for (uint32_t dim = 0; dim < sdim; ++dim) { - // use only int32_t operations when hash_t == int32_t - const auto dim_hash_coeff = ptr_hash_coeffs[dim]; - const auto dim_index = static_cast(ptr_indices_dim[dim * indices_dim_stride]); + int64_t hash = 0; + for (int64_t dim = 0; dim < sparse_dim; ++dim) { + const auto dim_hash_coeff = hash_coeffs[dim]; + const auto dim_index = ptr_indices_dim[dim * indices_dim_stride]; hash += dim_index * dim_hash_coeff; } return hash; @@ -384,11 +345,10 @@ void _sparse_binary_op_intersection_kernel_impl( { const auto* RESTRICT ptr_indices = source_indices.data_ptr(); - const auto* RESTRICT ptr_sorted_hash = sorted_hash.data_ptr(); + const auto* RESTRICT ptr_sorted_hash = sorted_hash.data_ptr(); const auto sorted_hash_len = sorted_hash.numel(); - const auto* RESTRICT ptr_hash_coeffs = hash_coeffs.template data_ptr(); - auto* RESTRICT ptr_intersection_count = intersection_count.data_ptr(); - auto* RESTRICT ptr_intersection_first_idx = intersection_first_idx.data_ptr(); + auto* RESTRICT ptr_intersection_count = intersection_count.data_ptr(); + auto* RESTRICT ptr_intersection_first_idx = intersection_first_idx.data_ptr(); // Fusing hash computation with hash intersection. 
KernelLauncher::launch(iter, @@ -396,22 +356,21 @@ void _sparse_binary_op_intersection_kernel_impl( [=] FUNCAPI (index_t nnz_idx) -> index_t { // Compute hash value const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride; - auto hash = hash_t {0}; - for (uint32_t dim = 0; dim < sdim; ++dim) { - // Use only int32_t operations when hash_t == int32_t. - const auto dim_hash_coeff = ptr_hash_coeffs[dim]; - const auto dim_index = static_cast(ptr_indices_dim[dim * indices_dim_stride]); + int64_t hash = 0; + for (int64_t dim = 0; dim < sparse_dim; ++dim) { + const auto dim_hash_coeff = hash_coeffs[dim]; + const auto dim_index = ptr_indices_dim[dim * indices_dim_stride]; hash += dim_index * dim_hash_coeff; } // Perform hash values intersection - const auto* RESTRICT lb = find_bound( + const auto* RESTRICT lb = find_bound( ptr_sorted_hash, ptr_sorted_hash + sorted_hash_len, hash ); - const auto* RESTRICT ub = find_bound( + const auto* RESTRICT ub = find_bound( ptr_sorted_hash, ptr_sorted_hash + sorted_hash_len, hash @@ -427,165 +386,25 @@ void _sparse_binary_op_intersection_kernel_impl( return std::make_tuple(intersection_count, intersection_first_idx); }(); - // Intersection is all we need in such a case. - if (restrict_indices_to_rhs) { - const auto res_indices = source._indices().clone(); - const auto res_values = value_selection_intersection_kernel_t::apply( - probably_coalesced._values(), - intersection_first_idx.to(nnz_arange.scalar_type()), - source._values(), - nnz_arange.narrow(-1, 0, source._nnz()), - intersection_count.ge(1)); - const auto res_sparse_dim = source.sparse_dim(); - const auto res_dense_dim = source.dense_dim(); - const auto& res_shape = broadcasted_shape; - const auto res_nnz = source._nnz(); - - auto* res_sparse_impl = get_sparse_impl(res); - res_sparse_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape); - res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values); - res_sparse_impl->set_nnz_and_narrow(res_nnz); - res._coalesced_(y_.is_coalesced() || !commutes_with_sum); - return; - } - - // Using intersection_count and intersection_first_idx, - // form indices selected_source and selected_probably_coalesced such that - // res.values = op( - // source.values.index_select(0, selected_source), - // probably_coalesced.values.index_select(0, selected_probably_coalesced)) and - // res.indices = selected_source_sparse_indices, which is also equivalent to - // res.indices = source.indices.index_select(1, selected_source). - Tensor selected_source, selected_source_sparse_indices, selected_probably_coalesced; - std::tie(selected_source, selected_source_sparse_indices, selected_probably_coalesced) - = [&]() -> std::tuple { - // Thread offset = shifted_offset - shift. - // This computation is fused in kernels below. - - // hash_t might not be enough to store offset values, so we use - // offset_t which is at least sizeof(hash_t). - const auto kOffset = std::is_same::value ? kInt : kLong; - const auto shifted_offset = intersection_count.cumsum(-1, kOffset); - - // NOTE: unavoidable sync to get to know the result's shape. - const auto intersection_nnz = static_cast( - // shifted_offset is a 1-dim tensor, potentially empty - shifted_offset.size(0) - ? 
shifted_offset.select(-1, -1).template item() - : 0); - - auto selected_buffer = at::empty({2, intersection_nnz}, intersection_count.options()); - auto selected_source = selected_buffer.select(0, 0); - auto selected_probably_coalesced = selected_buffer.select(0, 1); - const auto source_sparse_indices = source._indices(); - auto selected_source_sparse_indices = at::empty({source.sparse_dim(), intersection_nnz}, - source_sparse_indices.options().memory_format(at::MemoryFormat::Contiguous)); - const auto source_idx = nnz_arange.narrow(-1, 0, source._nnz()); - auto dummy = at::empty({1}, source_idx.options()); - - auto iter = TensorIteratorConfig() - .set_check_mem_overlap(false) - .check_all_same_dtype(false) - .add_owned_output(dummy.expand_as(source_idx)) - .add_input(source_idx) // index_t - .add_input(intersection_count) // hash_t - .add_input(intersection_first_idx) // hash_t - .add_input(shifted_offset) // offset_t - .build(); - - { - auto* RESTRICT ptr_selected_source = selected_source.data_ptr(); - auto* RESTRICT ptr_selected_probably_coalesced = selected_probably_coalesced.data_ptr(); - const auto* RESTRICT ptr_argsort = argsort_hash.data_ptr(); - - auto* RESTRICT ptr_selected_source_sparse_indices = selected_source_sparse_indices.data_ptr(); - // Non-const because of Gcc5/Clang5 issues - auto selected_source_sparse_indices_nnz_stride = static_cast( - selected_source_sparse_indices.stride(1)); - auto selected_source_sparse_indices_dim_stride = static_cast( - selected_source_sparse_indices.stride(0)); - - const auto* RESTRICT ptr_source_sparse_indices = source_sparse_indices.data_ptr(); - // Non-const because of Gcc5/Clang5 issues - auto source_sparse_indices_nnz_stride = static_cast( - source_sparse_indices.stride(1)); - auto source_sparse_indices_dim_stride = static_cast( - source_sparse_indices.stride(0)); - - KernelLauncher::launch(iter, - // NOTE: capture by value required by CUDA - [=] FUNCAPI ( - index_t idx, - hash_t count, - hash_t first_match_idx, - offset_t shifted_offset) -> index_t { - const auto offset = shifted_offset - static_cast(count); - auto* RESTRICT ptr_selected_source_idx_out = ptr_selected_source + offset; - auto* RESTRICT ptr_selected_probably_coalesced_idx_out = ptr_selected_probably_coalesced + offset; - const auto* RESTRICT ptr_argsort_idx = ptr_argsort + first_match_idx; - - auto* RESTRICT ptr_selected_source_sparse_indices_out = - ptr_selected_source_sparse_indices + offset * selected_source_sparse_indices_nnz_stride; - const auto* RESTRICT ptr_source_sparse_indices_in = - ptr_source_sparse_indices + idx * source_sparse_indices_nnz_stride; - - for (hash_t i = 0; i < count; ++i) { - *ptr_selected_source_idx_out++ = idx; - *ptr_selected_probably_coalesced_idx_out++ = *ptr_argsort_idx++; - - // res_indices = source._indices().index_select(1, selected_source) - // The code below fuses this computation with forming - // selected_source and selected_probably_coalesced. 
- for (uint32_t d = 0; d < sdim; ++d) { - ptr_selected_source_sparse_indices_out[d * selected_source_sparse_indices_dim_stride] - = ptr_source_sparse_indices_in[d * source_sparse_indices_dim_stride]; - } - ptr_selected_source_sparse_indices_out += selected_source_sparse_indices_nnz_stride; - } - - return 0; - }); - } - - return std::make_tuple(selected_source, selected_source_sparse_indices, selected_probably_coalesced); - }(); - - const auto res_indices = selected_source_sparse_indices; - - // Value intersection - const auto binary_op_res_dtype = at::result_type( - source._values(), - probably_coalesced._values()); - auto res_values = value_selection_intersection_kernel_t::apply( - source._values().to(binary_op_res_dtype), // promote for better accuracy - selected_source, - probably_coalesced._values().to(binary_op_res_dtype), // promote for better accuracy - selected_probably_coalesced); - // Convert back if the promoted dtype is different from res.dtype. - // This could happen for in-place usage cases. - res_values = res_values.to(res.scalar_type()); - + const auto res_indices = source._indices().clone(); + const auto binary_op_res_dtype = at::result_type(source._values(), probably_coalesced._values()); + const auto res_values = value_selection_intersection_kernel_t::apply( + source._values().to(binary_op_res_dtype), + nnz_arange.narrow(-1, 0, source._nnz()), + probably_coalesced._values().to(binary_op_res_dtype), + intersection_first_idx.to(nnz_arange.scalar_type()), + intersection_count, + argsort_hash).to(res.scalar_type()); const auto res_sparse_dim = source.sparse_dim(); - const auto res_dense_dim = res_values.dim() - 1; + const auto res_dense_dim = source.dense_dim(); const auto& res_shape = broadcasted_shape; - const auto res_nnz = selected_source.numel(); + const auto res_nnz = source._nnz(); auto* res_sparse_impl = get_sparse_impl(res); res_sparse_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape); res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values); res_sparse_impl->set_nnz_and_narrow(res_nnz); - // Result is coalesced iff arguments are coalesced, conditioned on the fact - // that we do not check that intersection hash values are sorted and unique. - // <= : intersection contains only unique indices (or empty), and the algorithm's - // behavior is order-preserving. So, the result has only unique indices (or empty) which are sorted. - // => : proof by contraposition. The contrapositive statement reads - // `there is an uncoalesced argument => result is not coalesced`. - // If both arguments are uncoalesced, the result is clearly uncoalesced again - // thanks to the order-preserving behavior of the algorithm. - // Otherwise we have a coalesced argument `probably_coalesced` and an uncoalesced `source`. - // Since the matching beahavior of the algorithm respects the order of `source`, the result - // will be as coalesced as `source` is, which is uncoalesced. - res._coalesced_(source.is_coalesced() && probably_coalesced.is_coalesced()); + res._coalesced_(source.is_coalesced()); } template < @@ -601,9 +420,9 @@ void _sparse_binary_op_intersection_kernel_out( // and it also requires less kernel calls compared to // a generic intersection. const bool restrict_indices_to_rhs = false, - // If op commutes with the sum, the arguments are processed as is, + // If op distributes with the sum, the arguments are processed as is, // without the calls to coalesce(). 
- const bool commutes_with_sum = true + const bool distributive_with_sum = true ) { TORCH_CHECK( (x.is_sparse() && y.is_sparse()) @@ -617,33 +436,23 @@ void _sparse_binary_op_intersection_kernel_out( const auto broadcasted_shape = infer_size(x.sizes(), y.sizes()); - int64_t max_hash_val = 1; - for (const auto d : c10::irange(x.sparse_dim())) { - max_hash_val *= broadcasted_shape[d]; + // 8 sparse dims should be more than enough? + constexpr int64_t max_sparse_dims = 8; + + // COO indices are only 64-bit integers for now. + using index_t = int64_t; + + if (max_sparse_dims > x.sparse_dim()) { + _sparse_binary_op_intersection_kernel_impl< + // For some reason MSVC complaints about passing constexpr max_sparse_dims + // as a template parameter claiming as if it is not know at compile time. + kernel_t, value_selection_intersection_kernel_t, index_t, 8>( + res, x, y, broadcasted_shape, restrict_indices_to_rhs, distributive_with_sum); + } else { + _sparse_binary_op_intersection_kernel_impl< + kernel_t, value_selection_intersection_kernel_t, index_t>( + res, x, y, broadcasted_shape, restrict_indices_to_rhs, distributive_with_sum); } - - const auto is_32bit_indexing = x._indices().scalar_type() == at::kInt; - // Optimization: use 32-bit hash values when possible. - const auto is_max_hash_32bits = max_hash_val <= std::numeric_limits::max(); - // Intersection nnz could get larger than nnz of either arguments. - // Example: probably_coalesced and source have only one unique and shared index, - // then the size of intersection is exactly the product of their nnzs. - // This nnz defines offsets per thread which are computed using cumsum on values - // of hash dtype. This becomes a problem when hash_t=int32_t and res_nnz > max(int32_t). - const auto is_max_offset_32bits = (x._nnz() * y._nnz()) <= std::numeric_limits::max(); - - BOOL_TO_INDEX_TYPE3(is_32bit_indexing, is_max_hash_32bits, is_max_offset_32bits, [&]() { - // Given 3 booleans b0, b1, b2, index_t is defined as - // index_t = int32_t if b<2 - i> is true else int64_t. - // The goal is to use int32_t whenever possible for better - // performance. - // NOTE: order of types given booleans is reversed. 
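Taken together, the rewritten path amounts to: hash both sets of indices, sort one side once, binary-search the sorted hashes (the two find_bound calls) to get a first-match position and a match count per nnz on the other side, then let the value-selection kernel reduce the op over all matches. A simplified PyTorch sketch under those assumptions, with hash values standing in for flattened indices and mul as the op (an illustration, not the actual kernel):

    import torch

    def intersect_and_reduce(lhs_hash, lhs_vals, rhs_hash, rhs_vals, op=torch.mul):
        # Sort rhs hashes once; argsort maps sorted positions back to rhs nnz indices.
        argsort = torch.argsort(rhs_hash)
        sorted_hash = rhs_hash[argsort]
        # Lower/upper bound search, i.e. the two find_bound calls in the kernel.
        first = torch.searchsorted(sorted_hash, lhs_hash, right=False)
        count = torch.searchsorted(sorted_hash, lhs_hash, right=True) - first
        out = torch.zeros_like(lhs_vals)
        for i in range(lhs_hash.numel()):
            # Accumulate op(lhs, rhs) over every rhs entry sharing this hash; an
            # uncoalesced rhs can contribute several values for the same index.
            lo, n = int(first[i]), int(count[i])
            matches = argsort[lo : lo + n]
            out[i] = op(lhs_vals[i], rhs_vals[matches]).sum()
        return out

    lhs_hash = torch.tensor([3, 5, 14])
    lhs_vals = torch.tensor([1.0, 2.0, 3.0])
    rhs_hash = torch.tensor([5, 3, 5])              # the index hashed to 5 appears twice
    rhs_vals = torch.tensor([10.0, 20.0, 30.0])
    print(intersect_and_reduce(lhs_hash, lhs_vals, rhs_hash, rhs_vals))
    # tensor([20., 80., 0.]): 1*20, 2*(10+30), and no match for hash 14
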
- using index_t = index_t2; - using hash_t = index_t1; - using offset_t = index_t0; - _sparse_binary_op_intersection_kernel_impl( - res, x, y, broadcasted_shape, restrict_indices_to_rhs, commutes_with_sum); - }); } } // anonymous namespace diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index 32bb075da504..d48d31168d02 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace at { namespace native { @@ -28,10 +29,10 @@ bool MulOp::apply(bool a, bool b) { return a && b; } -struct LhsProjOp { +struct RhsProjOp { template static scalar_t apply(scalar_t a, scalar_t b) { - return a; + return b; } }; @@ -42,13 +43,14 @@ struct CPUValueSelectionIntersectionKernel { const Tensor& lhs_select_idx, const Tensor& rhs_values, const Tensor& rhs_select_idx, - const c10::optional& match_mask = c10::nullopt) { + const Tensor& intersection_counts, + const Tensor& argsort) { auto iter = make_value_selection_intersection_iter( lhs_values, lhs_select_idx, rhs_values, rhs_select_idx, - match_mask); + intersection_counts); auto res_values = iter.tensor(0); auto lhs_nnz_stride = lhs_values.stride(0); @@ -57,45 +59,51 @@ struct CPUValueSelectionIntersectionKernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, res_values.scalar_type(), "binary_op_intersection_cpu", [&] { - AT_DISPATCH_INDEX_TYPES(lhs_select_idx.scalar_type(), - "binary_op_intersection_cpu", [&] { - auto loop = [&](char** data, const int64_t* strides, int64_t n) { - auto* ptr_res_values_bytes = data[0]; - const auto* ptr_lhs_values_bytes = data[1]; - const auto* ptr_lhs_select_idx_bytes = data[2]; - const auto* ptr_rhs_values_bytes = data[3]; - const auto* ptr_rhs_select_idx_bytes = data[4]; - const auto* ptr_match_bytes = data[5]; - - for (int64_t i = 0; i < n; ++i) { - // Exctract data - auto* RESTRICT ptr_res_values = reinterpret_cast(ptr_res_values_bytes); - const auto* ptr_lhs_values = reinterpret_cast(ptr_lhs_values_bytes); - const auto lhs_nnz_idx = *reinterpret_cast(ptr_lhs_select_idx_bytes); - const auto* ptr_rhs_values = reinterpret_cast(ptr_rhs_values_bytes); - const auto rhs_nnz_idx = *reinterpret_cast(ptr_rhs_select_idx_bytes); - const auto match = *reinterpret_cast(ptr_match_bytes); - - // Apply op - if (match) { - *ptr_res_values = binary_op_t::apply( - *(ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride), - *(ptr_rhs_values + rhs_nnz_idx * rhs_nnz_stride)); - } else { - *ptr_res_values = 0; - } - - // Advance - ptr_res_values_bytes += strides[0]; - ptr_lhs_values_bytes += strides[1]; - ptr_lhs_select_idx_bytes += strides[2]; - ptr_rhs_values_bytes += strides[3]; - ptr_rhs_select_idx_bytes += strides[4]; - ptr_match_bytes += strides[5]; - } - }; - iter.for_each(loop, at::internal::GRAIN_SIZE); - }); + // COO indices are only 64-bit for now. 
+ using index_t = int64_t; + auto loop = [&](char** data, const int64_t* strides, int64_t n) { + auto* ptr_res_values_bytes = data[0]; + const auto* ptr_lhs_values_bytes = data[1]; + const auto* ptr_lhs_select_idx_bytes = data[2]; + const auto* ptr_rhs_values_bytes = data[3]; + const auto* ptr_rhs_select_idx_bytes = data[4]; + const auto* ptr_intersection_counts_bytes = data[5]; + const auto* ptr_argsort = argsort.data_ptr(); + + for (int64_t i = 0; i < n; ++i) { + // Exctract data + auto* ptr_res_values = reinterpret_cast(ptr_res_values_bytes); + const auto* ptr_lhs_values = reinterpret_cast(ptr_lhs_values_bytes); + const auto lhs_nnz_idx = *reinterpret_cast(ptr_lhs_select_idx_bytes); + const auto* ptr_rhs_values = reinterpret_cast(ptr_rhs_values_bytes); + const auto rhs_nnz_idx = *reinterpret_cast(ptr_rhs_select_idx_bytes); + const auto count = *reinterpret_cast(ptr_intersection_counts_bytes); + + const auto* ptr_lhs_begin = ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride; + const auto* ptr_rhs_sorted_nnz_idx = ptr_argsort + rhs_nnz_idx; + + using accscalar_t = at::acc_type; + accscalar_t res_values = 0; + accscalar_t lhs_values = static_cast(*ptr_lhs_begin); + accscalar_t rhs_values; + index_t rhs_sorted_nnz_idx; + for (int64_t c = 0; c < count; ++c) { + rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++; + rhs_values = static_cast(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride)); + res_values += binary_op_t::apply(lhs_values, rhs_values); + } + *ptr_res_values = static_cast(res_values); + + // Advance + ptr_res_values_bytes += strides[0]; + ptr_lhs_values_bytes += strides[1]; + ptr_lhs_select_idx_bytes += strides[2]; + ptr_rhs_values_bytes += strides[3]; + ptr_rhs_select_idx_bytes += strides[4]; + ptr_intersection_counts_bytes += strides[5]; + } + }; + iter.for_each(loop, at::internal::GRAIN_SIZE); }); return res_values; @@ -116,8 +124,8 @@ void sparse_mask_intersection_out_cpu_kernel( Tensor& result, const Tensor& x, const Tensor& y) { - using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel; - _sparse_binary_op_intersection_kernel_out( + using CPUValueRhsProjKernel = CPUValueSelectionIntersectionKernel; + _sparse_binary_op_intersection_kernel_out( result, x, y, true ); } diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h index 9b2ef61df5fe..3778ca7c3220 100644 --- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h @@ -4,6 +4,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -190,7 +191,8 @@ template < class kernel_t, template class vec_kernel_t = EmptyVecKernel, - template class Vec = DummyVec> + template class Vec = DummyVec, + size_t static_shape_max_len = 0> void _validate_compressed_sparse_indices_kernel( const Tensor& cidx, const Tensor& idx, @@ -269,14 +271,10 @@ void _validate_compressed_sparse_indices_kernel( at::arange(batch_count, cidx.options()).view(batch_dims).unsqueeze_(-1); const auto idx_ndims = idx.dim(); - const auto cpu_options = idx.options().dtype(kLong).device(kCPU); - Tensor idx_sizes_and_strides_cpu = at::empty({2, idx_ndims}, cpu_options); - idx_sizes_and_strides_cpu.select(0, 0).copy_( - at::tensor(idx.sizes(), cpu_options)); - idx_sizes_and_strides_cpu.select(0, 1).copy_( - at::tensor(idx.strides(), cpu_options)); - const Tensor idx_sizes_and_strides = - idx_sizes_and_strides_cpu.to(idx.device()); + + const auto idx_geometry_holder = 
at::sparse::TensorGeometryHolder(idx); + const auto idx_sizes = std::get<0>(*idx_geometry_holder); + const auto idx_strides = std::get<1>(*idx_geometry_holder); auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) @@ -291,11 +289,8 @@ void _validate_compressed_sparse_indices_kernel( AT_DISPATCH_INDEX_TYPES( idx.scalar_type(), NAME, - [&iter, &idx, dim, nnz, idx_ndims, &idx_sizes_and_strides]() { + [&iter, &idx, dim, nnz, idx_ndims, &idx_sizes, &idx_strides]() { const auto* RESTRICT ptr_idx = idx.data_ptr(); - const int64_t* RESTRICT idx_sizes = - idx_sizes_and_strides.data_ptr(); - const int64_t* RESTRICT idx_strides = idx_sizes + idx_ndims; const auto zero = index_t{0}; KernelLauncher::launch( iter, @@ -348,18 +343,41 @@ void validate_compressed_sparse_indices_kernel( const int64_t cdim, const int64_t dim, const int64_t nnz) { + constexpr size_t idx_max_ndims = 8; // up to 7-dim batch. + const size_t idx_ndims = static_cast(idx.dim()); + if (is_crow) { - _validate_compressed_sparse_indices_kernel< - CDimName::CRow, - kernel_t, - vec_kernel_t, - Vec>(cidx, idx, cdim, dim, nnz); + if (idx_ndims <= idx_max_ndims) { + _validate_compressed_sparse_indices_kernel< + CDimName::CRow, + kernel_t, + vec_kernel_t, + Vec, + idx_max_ndims>(cidx, idx, cdim, dim, nnz); + } + else { + _validate_compressed_sparse_indices_kernel< + CDimName::CRow, + kernel_t, + vec_kernel_t, + Vec>(cidx, idx, cdim, dim, nnz); + } } else { - _validate_compressed_sparse_indices_kernel< - CDimName::CCol, - kernel_t, - vec_kernel_t, - Vec>(cidx, idx, cdim, dim, nnz); + if (idx_ndims <= idx_max_ndims) { + _validate_compressed_sparse_indices_kernel< + CDimName::CCol, + kernel_t, + vec_kernel_t, + Vec, + idx_max_ndims>(cidx, idx, cdim, dim, nnz); + } + else { + _validate_compressed_sparse_indices_kernel< + CDimName::CCol, + kernel_t, + vec_kernel_t, + Vec>(cidx, idx, cdim, dim, nnz); + } } } diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 984bcf058d10..63edba1632c3 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -204,12 +204,15 @@ class CI(NamedTuple): CI_SKIP[CI("aot_eager", training=False, dynamic=True)] = [ *CI_SKIP[CI("aot_eager", training=False)], # torchbench - "vision_maskrcnn", # 'literal' is an illegal expression for augmented assignment + "vision_maskrcnn", # sympy RecursionError ] CI_SKIP[CI("aot_eager", training=True, dynamic=True)] = [ *CI_SKIP[CI("aot_eager", training=True)], *CI_SKIP[CI("aot_eager", training=False, dynamic=True)], + # timm_models + "botnet26t_256", # sympy RecursionError + "eca_botnext26ts_256", # sympy RecursionError ] CI_SKIP[CI("inductor", training=False, dynamic=True)] = [ @@ -218,7 +221,11 @@ class CI(NamedTuple): # torchbench "functorch_dp_cifar10", # timeout "opacus_cifar10", # timeout + "fastNLP_Bert", # AssertionError: 1900: , 256: + "speech_transformer", # AssertionError: 2040: , 256: + "yolov3", # AssertionError: 2304: , 32: # timm_models + "convit_base", # TypeError: Cannot convert symbols to int "pnasnet5large", # ceiling is not defined "volo_d1_224", # ceiling is not defined ] @@ -228,7 +235,23 @@ class CI(NamedTuple): # *CI_SKIP[CI("aot_eager", training=True, dynamic=True)], *CI_SKIP[CI("inductor", training=False, dynamic=True)], *CI_SKIP[CI("inductor", training=True)], - # TODO: Fill this in + # torchbench + "drq", # 'NoneType' object has no attribute '_has_symbolic_sizes_strides' + "pytorch_unet", # TypeError: unhashable type: 'SymInt' + "soft_actor_critic", # 'NoneType' object has no attribute 
'_has_symbolic_sizes_strides' + # huggingface + "PegasusForCausalLM", # 'NoneType' object has no attribute '_has_symbolic_sizes_strides' + "PegasusForConditionalGeneration", # 'NoneType' object has no attribute '_has_symbolic_sizes_strides' + "T5ForConditionalGeneration", # 'NoneType' object has no attribute '_has_symbolic_sizes_strides' + "T5Small", # 'NoneType' object has no attribute '_has_symbolic_sizes_strides' + "XLNetLMHeadModel", # 'NoneType' object has no attribute '_has_symbolic_sizes_strides' + # timm_models + "eca_botnext26ts_256", # 'float' object has no attribute '_has_symbolic_sizes_strides' + "dla102", # Accuracy failed for key name base_layer.1.bias.grad + "mixnet_l", # 'float' object has no attribute '_has_symbolic_sizes_strides' + "tf_efficientnet_b0", # 'float' object has no attribute '_has_symbolic_sizes_strides' + "tf_mixnet_l", # 'float' object has no attribute '_has_symbolic_sizes_strides' + "visformer_small", # 'float' object has no attribute '_has_symbolic_sizes_strides' ] @@ -1350,8 +1373,8 @@ def warmup(fn, model, example_inputs, mode, niters=5): total = psutil.virtual_memory().total percentage = psutil.Process(os.getpid()).memory_percent() peak_mem = percentage * total / 10**9 - except Exception as e: - log.exception(f"Failed for {mode} {e}") + except Exception: + log.exception(f"Backend {mode} failed in warmup()") return sys.exit(-1) dynamo_stats = get_dynamo_stats() dynamo_stats.subtract(start_stats) @@ -1958,7 +1981,13 @@ def run(runner, args, original_dir=None): # TODO - Using train mode for timm_models. Move to train mode for HF and Torchbench as well. args.use_eval_mode = True inductor_config.fallback_random = True - torch.use_deterministic_algorithms(True) + if args.only is not None and args.only not in { + "pytorch_CycleGAN_and_pix2pix", + "pytorch_unet", + "Super_SloMo", + }: + # some of the models do not support use_deterministic_algorithms + torch.use_deterministic_algorithms(True) os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.backends.cudnn.deterministic = True torch.backends.cudnn.allow_tf32 = False diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 8db152daadc3..76131b080ebe 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -216,6 +216,21 @@ def parse_args(): "--training", action="store_true", help="Only run training related tasks" ) + parser.add_argument( + "--base-sha", + help="commit id for the tested pytorch", + ) + parser.add_argument( + "--total-partitions", + type=int, + help="Total number of partitions, to be passed to the actual benchmark script", + ) + parser.add_argument( + "--partition-id", + type=int, + help="ID of partition, to be passed to the actual benchmark script", + ) + parser.add_argument( "--update-dashboard", action="store_true", @@ -244,7 +259,7 @@ def parse_args(): "--update-dashboard-test", action="store_true", default=False, - help="does all of --no-graphs, --no-update-lookup, and --no-gh-comment", + help="does all of --no-graphs, --no-update-archive, and --no-gh-comment", ) parser.add_argument( "--dashboard-image-uploader", @@ -385,6 +400,12 @@ def generate_commands(args, dtypes, suites, devices, compilers, output_dir): if args.threads is not None: cmd = f"{cmd} --threads {args.threads}" + + if args.total_partitions is not None: + cmd = f"{cmd} --total-partitions {args.total_partitions}" + + if args.partition_id is not None: + cmd = f"{cmd} --partition-id {args.partition_id}" lines.append(cmd) lines.append("") runfile.writelines([line + "\n" for line in 
lines]) @@ -403,12 +424,15 @@ def generate_dropdown_comment(title, body): def build_summary(args): - import git - out_io = io.StringIO() def print_commit_hash(path, name): - if exists(path): + if args.base_sha is not None: + if name == "pytorch": + out_io.write(f"{name} commit: {args.base_sha}\n") + elif exists(path): + import git + repo = git.Repo(path, search_parent_directories=True) sha = repo.head.object.hexsha date = repo.head.object.committed_datetime @@ -431,7 +455,6 @@ def env_var(name): out_io.write("\n") out_io.write("### Commit hashes ###\n") print_commit_hash("../pytorch", "pytorch") - print_commit_hash("../functorch", "functorch") print_commit_hash("../torchbenchmark", "torchbench") out_io.write("\n") diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index 905ea324c255..799fe67e82d3 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -75,6 +75,10 @@ def pip_install(package): "levit_128", } +SKIP_TRAIN = { + # segfault: Internal Triton PTX codegen error + "eca_halonext26ts", +} MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = { "cait_m36_384": 4, diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index 48a7da1d2d55..997027adceb1 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -84,6 +84,8 @@ def setup_torchbench_cwd(): # Unusual training setup "opacus_cifar10", "maml", + # segfault: Internal Triton PTX codegen error + "timm_efficientdet", } SKIP_TRAIN.update(DETECTRON2_MODELS) diff --git a/buckbuild.bzl b/buckbuild.bzl index dd12c242ecaa..0769ee527578 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -1463,6 +1463,7 @@ def define_buck_targets( "torch/csrc/jit/mobile/train/random.cpp", "torch/csrc/jit/mobile/train/sequential.cpp", ":gen_aten_libtorch[autograd/generated/Functions.cpp]", + "torch/csrc/quantized/quantized_backward.cpp", ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 854e365e9e0b..883db4ccb1d0 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -84,6 +84,29 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_BUILD_MOBILE) enable_ubsan() endif() +if(USE_ASAN OR USE_TSAN) + find_package(Sanitizer REQUIRED) + if(USE_ASAN) + if(TARGET Sanitizer::address) + list(APPEND Caffe2_PUBLIC_DEPENDENCY_LIBS Sanitizer::address) + else() + message(WARNING "Not ASAN found. Suppress this warning with -DUSE_ASAN=OFF.") + caffe2_update_option(USE_ASAN OFF) + endif() + if(TARGET Sanitizer::undefined) + list(APPEND Caffe2_PUBLIC_DEPENDENCY_LIBS Sanitizer::undefined) + endif() + endif() + if(USE_TSAN) + if(TARGET Sanitizer::thread) + list(APPEND Caffe2_PUBLIC_DEPENDENCY_LIBS Sanitizer::thread) + else() + message(WARNING "Not TSAN found. Suppress this warning with -DUSE_TSAN=OFF.") + caffe2_update_option(USE_TSAN OFF) + endif() + endif() +endif() + # ---[ Threads find_package(Threads REQUIRED) if(TARGET Threads::Threads) diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 0f0fd3ff5bc7..d5b5cd3ddbce 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -10,37 +10,6 @@ include(CheckCXXSourceCompiles) include(CheckCXXCompilerFlag) include(CMakePushCheckState) -if(NOT INTERN_BUILD_MOBILE) - # ---[ Check that our programs run. This is different from the native CMake - # compiler check, which just tests if the program compiles and links. 
This is - # important because with ASAN you might need to help the compiled library find - # some dynamic libraries. - cmake_push_check_state(RESET) - if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") - list(APPEND CMAKE_REQUIRED_FLAGS "-arch ${CMAKE_HOST_SYSTEM_PROCESSOR}") - endif() - if(CMAKE_CROSSCOMPILING) - CHECK_C_SOURCE_COMPILES(" - int main() { return 0; } - " COMPILER_WORKS) - else() - CHECK_C_SOURCE_RUNS(" - int main() { return 0; } - " COMPILER_WORKS) - endif() - if(NOT COMPILER_WORKS) - # Force cmake to retest next time around - unset(COMPILER_WORKS CACHE) - message(FATAL_ERROR - "Could not run a simple program built with your compiler. " - "If you are trying to use -fsanitize=address, make sure " - "libasan is properly installed on your system (you can confirm " - "if the problem is this by attempting to build and run a " - "small program.)") - endif() - cmake_pop_check_state() -endif() - set(CAFFE2_USE_EXCEPTION_PTR 1) # ---[ Check if we want to turn off deprecated warning due to glog. @@ -150,29 +119,6 @@ if(IOS AND (${IOS_ARCH} MATCHES "armv7*")) add_definitions("-Wno-deprecated-declarations") endif() -# ---[ If we use asan, turn on the flags. -# TODO: This only works with new style gcc and clang (not the old -faddress-sanitizer). -# Change if necessary on old platforms. -if(USE_ASAN) - set(CAFFE2_ASAN_COMPILER_FLAGS "-fsanitize=address -fPIE") - set(CAFFE2_ASAN_LINKER_FLAGS "-pie") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CAFFE2_ASAN_COMPILER_FLAGS}") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CAFFE2_ASAN_COMPILER_FLAGS}") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${CAFFE2_ASAN_LINKER_FLAGS}") - set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} ${CAFFE2_ASAN_LINKER_FLAGS}") - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${CAFFE2_ASAN_LINKER_FLAGS}") -endif() - -if(USE_TSAN) - set(CAFFE2_TSAN_COMPILER_FLAGS "-fsanitize=thread -fPIE") - set(CAFFE2_TSAN_LINKER_FLAGS "-pie") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CAFFE2_TSAN_COMPILER_FLAGS}") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CAFFE2_TSAN_COMPILER_FLAGS}") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${CAFFE2_TSAN_LINKER_FLAGS}") - set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} ${CAFFE2_TSAN_LINKER_FLAGS}") - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${CAFFE2_TSAN_LINKER_FLAGS}") -endif() - # ---[ Create CAFFE2_BUILD_SHARED_LIBS for macros.h.in usage. 
set(CAFFE2_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) diff --git a/cmake/Modules/FindSanitizer.cmake b/cmake/Modules/FindSanitizer.cmake new file mode 100644 index 000000000000..6858010d26b7 --- /dev/null +++ b/cmake/Modules/FindSanitizer.cmake @@ -0,0 +1,116 @@ +# Find sanitizers +# +# This module sets the following targets: +# Sanitizer::address +# Sanitizer::thread +# Sanitizer::undefined +# Sanitizer::leak +# Sanitizer::memory +include_guard(GLOBAL) + +option(UBSAN_FLAGS "additional UBSAN flags" OFF) + +get_property(languages GLOBAL PROPERTY ENABLED_LANGUAGES) + +set(_source_code + [==[ + #include + int main() { + printf("hello world!"); + return 0; + } + ]==]) + +include(CMakePushCheckState) +cmake_push_check_state(RESET) +foreach(sanitizer_name IN ITEMS address thread undefined leak memory) + if(TARGET Sanitizer::${sanitizer_name}) + continue() + endif() + + set(CMAKE_REQUIRED_FLAGS + "-fsanitize=${sanitizer_name};-fno-omit-frame-pointer") + if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" OR CMAKE_C_COMPILER_ID STREQUAL + "MSVC") + if(sanitizer_name STREQUAL "address") + set(CMAKE_REQUIRED_FLAGS "/fsanitize=${sanitizer_name}") + else() + continue() + endif() + endif() + if(sanitizer_name STREQUAL "address") + if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_C_COMPILER_ID STREQUAL + "Clang") + list(APPEND CMAKE_REQUIRED_FLAGS "-shared-libasan") + endif() + endif() + if(sanitizer_name STREQUAL "undefined" AND UBSAN_FLAGS) + list(APPEND CMAKE_REQUIRED_FLAGS "${UBSAN_FLAGS}") + endif() + if(sanitizer_name STREQUAL "memory") + list(APPEND CMAKE_REQUIRED_FLAGS "-fsanitize-memory-track-origins=2") + endif() + + set(CMAKE_REQUIRED_QUIET ON) + set(_run_res 0) + include(CheckSourceRuns) + foreach(lang IN LISTS languages) + if(lang STREQUAL CXX OR lang STREQUAL C) + check_source_runs(${lang} "${_source_code}" + __${lang}_${sanitizer_name}_res) + if(__${lang}_${sanitizer_name}_res) + set(_run_res 1) + endif() + endif() + endforeach() + if(_run_res) + add_library(Sanitizer::${sanitizer_name} INTERFACE IMPORTED GLOBAL) + target_compile_options( + Sanitizer::${sanitizer_name} + INTERFACE + $<$,$>:${CMAKE_REQUIRED_FLAGS}> + $<$,$>:${CMAKE_REQUIRED_FLAGS}> + ) + if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND NOT CMAKE_C_COMPILER_ID + STREQUAL "MSVC") + target_link_options( + Sanitizer::${sanitizer_name} + INTERFACE + $<$,$>:${CMAKE_REQUIRED_FLAGS}> + $<$,$>:${CMAKE_REQUIRED_FLAGS}> + ) + else() + target_link_options( + Sanitizer::${sanitizer_name} + INTERFACE + $<$,$>:/INCREMENTAL:NO> + $<$,$>:/INCREMENTAL:NO> + ) + endif() + + if(sanitizer_name STREQUAL "address") + target_compile_definitions( + Sanitizer::${sanitizer_name} + INTERFACE + $<$,$>:_GLIBCXX_SANITIZE_VECTOR> + $<$,$>:_GLIBCXX_SANITIZE_STD_ALLOCATOR> + ) + target_link_options( + Sanitizer::${sanitizer_name} + INTERFACE + $<$,$,$>:-lasan> + $<$,$,$>:-lasan> + ) + endif() + if(sanitizer_name STREQUAL "undefined") + target_link_options( + Sanitizer::${sanitizer_name} + INTERFACE + $<$,$,$>:-lubsan> + $<$,$,$>:-lubsan> + ) + endif() + endif() +endforeach() + +cmake_pop_check_state() diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake index 7f45cd098447..839c43ea0482 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake @@ -768,13 +768,17 @@ endif() # FAST_NVCC if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN) set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}") + set(EXTENSION "sh") + if (MSVC) + 
set(EXTENSION "bat") + endif() set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py") - configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh") - file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh" + configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}") + file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}" DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/" FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE ) - set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh") + set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}") endif() mark_as_advanced(CUDA_NVCC_EXECUTABLE) diff --git a/docs/source/nested.rst b/docs/source/nested.rst index 69cf693ac6a1..2b4e3fad981b 100644 --- a/docs/source/nested.rst +++ b/docs/source/nested.rst @@ -212,3 +212,4 @@ NestedTensor and any constraints they have. :func:`torch.Tensor.reshape_as`; "Similar constraint as for ``reshape``." :func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``." :func:`torch.Tensor.view`; "Rules for the new shape are similar to that of ``reshape``." + :func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input." diff --git a/test/backends/xeon/test_launch.py b/test/backends/xeon/test_launch.py index c3585ba7429d..9e5f4def951a 100644 --- a/test/backends/xeon/test_launch.py +++ b/test/backends/xeon/test_launch.py @@ -53,7 +53,7 @@ def test_cpu_info(self): def test_multi_threads(self): num = 0 with subprocess.Popen(f"python -m torch.backends.xeon.run_cpu --ninstances 4 --use-default-allocator \ - --disable-iomp --disable-numactl --log-path {self._test_dir} --no-python pwd", + --disable-iomp --disable-numactl --disable-taskset --log-path {self._test_dir} --no-python pwd", shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p: for line in p.stdout.readlines(): segs = str(line, "utf-8").strip().split("-") diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index ad036109903d..7a139a9bd871 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -7,7 +7,7 @@ #include #include #include - +#include static uint64_t add_counter = 0; static uint64_t last_saved_value = 0; @@ -108,6 +108,25 @@ bool custom_add_called() { return called; } +class PrivateGeneratorImpl : public at::CPUGeneratorImpl { +public: + // Constructors + PrivateGeneratorImpl(c10::DeviceIndex device_index) { + device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); + key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); + } + ~PrivateGeneratorImpl() override = default; +}; + +// this is used to register generator +at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) { + return at::make_generator(device_index); +} + +void register_genertor() { + REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1) +} + // Here, we're exposing a custom device object that corresponds to our custom backend. // We do this using pybind: exposing an "extension_name.custom_device()" function in python, // that's implemented in C++. 
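For context, the extension above is exercised from Python through the pybind bindings it defines; a hedged usage sketch follows (the module name, source path, and build step are assumptions for illustration, not part of this patch):

    import torch
    import torch.utils.cpp_extension

    # Build and load the test extension JIT-style (assumed name and path).
    ext = torch.utils.cpp_extension.load(
        name="open_registration_extension",
        sources=["test/cpp_extensions/open_registration_extension.cpp"],
        verbose=True,
    )

    ext.register_genertor()       # registers the PrivateUse1 generator factory (name as bound above)
    dev = ext.custom_device()     # device object for the custom backend
    print(dev, ext.custom_add_called())
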
@@ -115,4 +134,5 @@ bool custom_add_called() { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_device", &get_custom_device, "get custom device object"); m.def("custom_add_called", &custom_add_called, "check if our custom add function was called"); + m.def("register_genertor", ®ister_genertor, "register generator for custom device"); } diff --git a/test/distributed/_shard/sharded_tensor/ops/test_linear.py b/test/distributed/_shard/sharded_tensor/ops/test_linear.py deleted file mode 100644 index 77d3b1035b47..000000000000 --- a/test/distributed/_shard/sharded_tensor/ops/test_linear.py +++ /dev/null @@ -1,274 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -import copy -import sys - -import torch -import torch.distributed as dist -from torch.distributed._shard.api import ( - shard_parameter, - _collect_local_shard, - _reshard_output, -) -from torch.distributed._shard.sharded_optim import ( - ShardedOptimizer, -) -from torch.distributed._shard.sharded_tensor import ( - empty, -) -from torch.distributed._shard.sharding_spec import ( - ChunkShardingSpec, - EnumerableShardingSpec, - ShardMetadata, -) -from torch.testing._internal.common_distributed import ( - requires_nccl, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - run_tests, -) -from torch.testing._internal.distributed._shard.sharded_tensor import ( - TEST_GPU_NUM, - ShardedTensorTestBase, - with_comms, -) -from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import ( - clone_module_parameter, - generate_chunk_sharding_specs_for_test, - generate_local_weight_sharding_params_for_test, -) - -if TEST_WITH_DEV_DBG_ASAN: - print( - "Skip dev-asan as torch + multiprocessing spawn have known issues", - file=sys.stderr, - ) - sys.exit(0) - - -class TestShardedTensorOpsLinear(ShardedTensorTestBase): - def _run_sharded_linear( - self, spec, input_size, linear_size, sharded_dim, dtype - ): - # Use same seed. - torch.manual_seed(0) - local_linear = torch.nn.Linear(*linear_size, dtype=dtype).cuda(self.rank) - sharded_linear = torch.nn.Linear(*linear_size, dtype=dtype) - - # Copy the weights and bias from local linear - sharded_linear.weight = clone_module_parameter(local_linear, "weight") - sharded_linear.bias = clone_module_parameter(local_linear, "bias") - - # Shard the parameter. - shard_parameter(sharded_linear, "weight", spec) - - # Run sharded computation - torch.manual_seed(self.rank) # inputs different on each rank - inp = torch.rand(*input_size, dtype=dtype).cuda(self.rank) - reshard_spec = copy.deepcopy(spec) - reshard_spec.dim = 0 - reshard_spec.placements.sort(key=lambda placement: placement.rank()) - sharded_linear = _collect_local_shard( - _reshard_output(sharded_linear, reshard_spec) - ) - sharded_output = sharded_linear(inp) - - # Run local computation - local_output = local_linear(inp) - - # Verify - self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3) - - # Validate for torch.nn.functional.linear version. - local_output = torch.nn.functional.linear( - inp, local_linear.weight, local_linear.bias - ) - sharded_output = torch.nn.functional.linear( - inp, sharded_linear.weight, sharded_linear.bias - ) - sharded_output = sharded_output.reshard(reshard_spec).local_tensor() - # When local tensor only has one dimension, we increase one more dimension - # for reshard. We need to squeeze the # of dimensions manually. 
- if inp.dim() == 1: - sharded_output = sharded_output.squeeze(reshard_spec.dim) - self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3) - - # Compute loss and run backward pass. - local_output.sum().backward() - sharded_output.sum().backward() - local_grad = local_linear.weight.grad - - # Verify that both weight and bias in the sharded linear has non-None grad. - sharded_weight = sharded_linear.weight.local_tensor() - self.assertNotEqual(sharded_linear.bias.grad, None) - self.assertNotEqual(sharded_weight.grad, None) - - # Shard the local linear's weight grad so that we can compare. - dist.all_reduce(local_grad) - (start_pos, chunk_size) = generate_local_weight_sharding_params_for_test( - local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank - ) - local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos, chunk_size) - local_bias_grad = local_linear.bias.grad - dist.all_reduce(local_bias_grad) - - # Test backward gradient calculation. - self.assertEqual(sharded_linear.bias.grad, local_bias_grad, atol=1e-3, rtol=1e-3) - self.assertEqual(sharded_weight.grad, local_grad_narrowed, atol=1e-3, rtol=1e-3) - - # Test optimizer. - previous = local_linear.weight.clone().detach() - optim = torch.optim.SGD(local_linear.parameters(), lr=0.1) - optim.step() - self.assertNotEqual(previous, local_linear.weight) - previous_sharded_weight = sharded_weight.clone() - previous_sharded_bias = sharded_linear.bias.clone() - sharded_optim = ShardedOptimizer( - dict(sharded_linear.named_parameters()), - torch.optim.SGD, - lr=0.1, - ) - sharded_optim.step() - sharded_weight = sharded_linear.weight.local_tensor() - local_weight_narrowed = local_linear.weight.narrow( - sharded_dim, start_pos, chunk_size - ) - self.assertEqual(sharded_weight.size(), local_weight_narrowed.size()) - self.assertNotEqual(previous_sharded_weight, sharded_weight) - self.assertEqual(sharded_weight, local_weight_narrowed, atol=1e-3, rtol=1e-3) - self.assertNotEqual(previous_sharded_bias, sharded_linear.bias) - self.assertEqual(sharded_linear.bias, local_linear.bias, atol=1e-3, rtol=1e-3) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_sharded_linear_colwise(self): - for spec in generate_chunk_sharding_specs_for_test(0): - self._run_sharded_linear(spec, [2, 17], [17, 12], 0, torch.float16) - self._run_sharded_linear(spec, [8, 21], [21, 11], 0, torch.float32) - self._run_sharded_linear(spec, [7, 23], [23, 13], 0, torch.float64) - self._run_sharded_linear(spec, [4, 15], [15, 14], 0, torch.float16) - - # Test multiple input dims - self._run_sharded_linear(spec, [10, 2, 17], [17, 12], 0, torch.float32) - self._run_sharded_linear(spec, [13, 8, 21], [21, 11], 0, torch.float64) - self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0, torch.float16) - self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0, torch.float32) - - # Test single input dim - self._run_sharded_linear(spec, [17], [17, 12], 0, torch.float64) - self._run_sharded_linear(spec, [21], [21, 11], 0, torch.float16) - self._run_sharded_linear(spec, [23], [23, 13], 0, torch.float32) - self._run_sharded_linear(spec, [15], [15, 14], 0, torch.float64) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_sharded_linear_rowwise(self): - for spec in generate_chunk_sharding_specs_for_test(1): - # Test even split. - self._run_sharded_linear(spec, [8, 16], [16, 11], 1, torch.float16) - - # Test uneven split. 
- self._run_sharded_linear(spec, [5, 19], [19, 11], 1, torch.float32) - self._run_sharded_linear(spec, [10, 21], [21, 11], 1, torch.float64) - - # Test multiple input dims - self._run_sharded_linear(spec, [13, 8, 16], [16, 11], 1, torch.float16) - self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1, torch.float32) - self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1, torch.float64) - - # Test single input dim - self._run_sharded_linear(spec, [16], [16, 11], 1, torch.float16) - self._run_sharded_linear(spec, [19], [19, 11], 1, torch.float32) - self._run_sharded_linear(spec, [21], [21, 11], 1, torch.float64) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_sharded_linear_errors(self): - for spec in generate_chunk_sharding_specs_for_test(0): - fc1 = torch.nn.Linear(10, 10).cuda(self.rank) - shard_parameter(fc1, "weight", spec) - shard_parameter(fc1, "bias", spec) - with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'): - fc1(torch.rand(10, 10).cuda(self.rank)) - - fc2 = torch.nn.Linear(10, 10).cuda(self.rank) - shard_parameter(fc2, "weight", spec) - with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'): - fc2(torch.tensor(1).cuda(self.rank)) - - fc3 = torch.nn.Linear(10, 10).cuda(self.rank) - fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank)) - shard_parameter(fc3, "weight", spec) - with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'): - fc3(torch.rand(10, 10).cuda(self.rank)) - - fc4 = torch.nn.Linear(10, 10).cuda(self.rank) - fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank)) - shard_parameter(fc4, "weight", spec) - with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'): - fc4(torch.rand(10, 10).cuda(self.rank)) - - fc5 = torch.nn.Linear(7, 10).cuda(self.rank) - shard_parameter(fc5, "weight", spec) - with self.assertRaisesRegex(ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'): - fc5(torch.rand(20, 10, 13).cuda(self.rank)) - - fc6 = torch.nn.Linear(10, 10).cuda(self.rank) - del fc6.weight - enumerable_spec = EnumerableShardingSpec([ - ShardMetadata( - shard_offsets=[0, 0], - shard_sizes=[5, 5], - placement="rank:0/cuda:0", - ), - ShardMetadata( - shard_offsets=[0, 5], - shard_sizes=[5, 5], - placement="rank:1/cuda:1", - ), - ShardMetadata( - shard_offsets=[5, 0], - shard_sizes=[5, 5], - placement="rank:2/cuda:2", - ), - ShardMetadata( - shard_offsets=[5, 5], - shard_sizes=[5, 5], - placement="rank:3/cuda:3", - ) - ]) - - fc6.weight = empty(enumerable_spec, 10, 10) - # Sharded Tensor metadata has parenthesis imbalance issue when using re.compile - error_msg = r"torch function 'linear', with args: (?s).* " - r"and kwargs: None not supported for ShardedTensor!" 
- with self.assertRaisesRegex(RuntimeError, error_msg): - fc6(torch.rand(10, 10).cuda(self.rank)) - - fc7 = torch.nn.Linear(10, 80).cuda(self.rank) - multiple_local_shard_spec = ChunkShardingSpec( - dim=0, - placements=[ - "rank:0/cuda:0", - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:2/cuda:2", - "rank:3/cuda:3", - "rank:3/cuda:3", - ], - ) - del fc7.weight - fc7.weight = empty(multiple_local_shard_spec, 80, 10) - with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'): - fc7(torch.rand(10, 10).cuda(self.rank)) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/_shard/test_partial_tensor.py b/test/distributed/_shard/test_partial_tensor.py deleted file mode 100644 index 24ea79651367..000000000000 --- a/test/distributed/_shard/test_partial_tensor.py +++ /dev/null @@ -1,198 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -import sys - -import torch -import torch.distributed as dist -from torch.distributed._shard.partial_tensor import ( - _PartialTensor, -) -from torch.distributed._shard.sharding_spec import ( - ChunkShardingSpec, - EnumerableShardingSpec, - ShardMetadata, -) -from torch.testing._internal.common_distributed import ( - requires_nccl, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - run_tests, -) -from torch.testing._internal.distributed._shard.sharded_tensor import ( - ShardedTensorTestBase, - with_comms, - TEST_GPU_NUM -) -from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import ( - _chunk_sharding_specs_list_for_test, -) - -if TEST_WITH_DEV_DBG_ASAN: - print( - "Skip dev-asan as torch + multiprocessing spawn have known issues", - file=sys.stderr, - ) - sys.exit(0) - - -class TestPartialTensorReshard(ShardedTensorTestBase): - def _run_partial_tensor_n_reshard( - self, reshard_spec, input_size, world_size, reduce_op, dtype=torch.float, pg=None - ): - results_compare = [] - local_result = [] - pg = pg if pg is not None else dist.distributed_c10d._get_default_group() - for rank in range(pg.size()): - torch.manual_seed(rank) - results = [] - for _ in range(world_size): - tensor = torch.rand(*input_size, dtype=dtype).cuda(self.rank) - results.append(tensor) - if self.rank % pg.size() == rank: - local_result.append(tensor.clone().detach()) - results_compare.append(torch.cat(results)) - parital_tensor = _PartialTensor( - torch.cat(local_result), pg, reduce_op=reduce_op - ) - local_sharded_result = parital_tensor.reshard(reshard_spec) - local_shards = local_sharded_result.local_shards() - results_compare = torch.stack(results_compare) - if reduce_op == dist.ReduceOp.SUM: - results_compare = torch.sum(results_compare, dim=0) - else: - results_compare = torch.max(results_compare, dim=0).values - rank_idx = None - for idx, placement in enumerate(reshard_spec.placements): - if placement.rank() == self.rank % pg.size(): - rank_idx = idx - local_result_compare = results_compare.chunk(pg.size())[rank_idx] - self.assertEqual(1, len(local_shards)) - self.assertEqual(local_shards[0].tensor, local_result_compare) - - def _reshard_spec_for_subgroup(self, rank): - if rank in [0, 1]: - return ChunkShardingSpec( - dim=0, - placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - ], - ) - else: - return ChunkShardingSpec( - dim=0, - placements=[ - "rank:0/cuda:2", - "rank:1/cuda:3", - ], - ) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_partial_tensor_reshard(self): - specs = 
_chunk_sharding_specs_list_for_test([0], seed=7) - spec = specs[0] - self._run_partial_tensor_n_reshard(spec, [13, 21], 4, dist.ReduceOp.SUM) - self._run_partial_tensor_n_reshard(spec, [12, 22], 4, dist.ReduceOp.MAX) - self._run_partial_tensor_n_reshard(spec, [13, 21], 3, dist.ReduceOp.SUM) - self._run_partial_tensor_n_reshard(spec, [17, 21], 2, dist.ReduceOp.MAX) - sub_pgs = [dist.new_group([0, 1]), dist.new_group([2, 3])] - pg = sub_pgs[self.rank // 2] - spec = self._reshard_spec_for_subgroup(self.rank) - self._run_partial_tensor_n_reshard(spec, [12, 22], 4, dist.ReduceOp.MAX, pg=pg) - self._run_partial_tensor_n_reshard(spec, [13, 22], 3, dist.ReduceOp.SUM, pg=pg) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_partial_tensor_reshard_errors(self): - enumerable_sharding_spec = EnumerableShardingSpec( - [ - ShardMetadata( - shard_offsets=[0, 0], - shard_sizes=[5, 5], - placement="rank:0/cuda:0", - ), - ShardMetadata( - shard_offsets=[5, 0], - shard_sizes=[5, 5], - placement="rank:1/cuda:1", - ), - ] - ) - with self.assertRaisesRegex( - NotImplementedError, "Only ChunkShardingSpec supported for reshard." - ): - self._run_partial_tensor_n_reshard( - enumerable_sharding_spec, [13, 21], 4, dist.ReduceOp.SUM - ) - self._run_partial_tensor_n_reshard( - enumerable_sharding_spec, [12, 22], 4, dist.ReduceOp.MAX - ) - specs = _chunk_sharding_specs_list_for_test([0], seed=7) - spec = specs[0] - with self.assertRaisesRegex( - NotImplementedError, "Only real partial tensor supported for reshard." - ): - self._run_partial_tensor_n_reshard( - spec, [13, 21], 4, dist.ReduceOp.SUM, dtype=torch.cfloat - ) - self._run_partial_tensor_n_reshard( - spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat - ) - -class TestPartialTensorOps(ShardedTensorTestBase): - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_transpose(self): - partial_tensor = _PartialTensor(torch.rand(5, 10)) - partial_tensor = partial_tensor.transpose(0, 1) - self.assertEqual(partial_tensor.size(), torch.Size((10, 5))) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_cat(self): - t1 = torch.rand(5, 10) - t2 = torch.rand(3, 10) - t3 = torch.rand(4, 10) - partial_tensors = [_PartialTensor(t1), _PartialTensor(t2), _PartialTensor(t3)] - partial_concat = torch.cat(partial_tensors) - local_concat = torch.cat([t1, t2, t3]) - self.assertEqual(local_concat.size(), partial_concat.size()) - - # Test dim kwarg - t1 = torch.rand(5, 10) - t2 = torch.rand(5, 12) - t3 = torch.rand(5, 11) - partial_tensors = [_PartialTensor(t1), _PartialTensor(t2), _PartialTensor(t3)] - partial_concat = torch.cat(partial_tensors, dim=1) - local_concat = torch.cat([t1, t2, t3], dim=1) - self.assertEqual(local_concat.size(), partial_concat.size()) - - @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() - def test_cat_errors(self): - with self.assertRaisesRegex( - RuntimeError, 'All inputs need to be an instance of _PartialTensor' - ): - torch.cat([_PartialTensor(torch.rand(10)), torch.rand(10)]) - - with self.assertRaisesRegex( - RuntimeError, 'reduce_ops need to be the same' - ): - torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10), reduce_op=dist.ReduceOp.MAX)]) - - with self.assertRaisesRegex( - RuntimeError, '"out" kwarg is not supported' - ): - torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10))], out=torch.rand(10)) - - -if __name__ == "__main__": - run_tests() 
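The DTensor test updates that follow switch to compute_local_shape/compute_local_offset from torch.distributed._tensor._utils. A plain-Python sketch of the torch.chunk-style splitting those tests appear to assume (the helper name and placement encoding are mine):

    import math

    def local_shape(global_shape, mesh_shape, placements, coord):
        # placements[i] describes mesh dim i: "replicate" or ("shard", tensor_dim).
        # Splitting follows torch.chunk: chunks of ceil(n/m), with a shorter tail chunk.
        shape = list(global_shape)
        for mesh_dim, placement in enumerate(placements):
            if placement == "replicate":
                continue
            _, tensor_dim = placement
            n, m = shape[tensor_dim], mesh_shape[mesh_dim]
            full = math.ceil(n / m)
            start = min(coord[mesh_dim] * full, n)
            shape[tensor_dim] = min(full, n - start)
        return shape

    # 4 x 2 mesh, global size [8, 6], sharded on both mesh dims -> [2, 3] per rank.
    print(local_shape([8, 6], (4, 2), [("shard", 0), ("shard", 1)], (0, 0)))  # [2, 3]
    # Uneven split: size 7 over a mesh dim of 2 -> 4 on the first rank, 3 on the second.
    print(local_shape([7, 7], (4, 2), ["replicate", ("shard", 0)], (0, 1)))   # [3, 7]
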
diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index a694db767ec2..afbd15b265ea 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -409,11 +409,14 @@ def test_dtensor_spec_local_shard_offset(self): ), ] + from torch.distributed._tensor._utils import compute_local_offset + # loop through all sharding specs and check local shard offsets logical_tensor = torch.randn(tensor_shape) for shard_spec, expected_shard_offsets in shard_spec_and_offsets: dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec) - self.assertEqual(expected_shard_offsets, dtensor._spec.local_offsets) + offset = compute_local_offset(dtensor.shape, device_mesh, dtensor.placements) + self.assertEqual(expected_shard_offsets, offset) @with_comms def test_from_local_sub_mesh(self): diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 50e01a5ad538..cc4cbffa507e 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -3,7 +3,7 @@ import torch from torch.distributed._tensor.device_mesh import DeviceMesh from torch.distributed._tensor.placement_types import Replicate, Shard -from torch.distributed._tensor.utils import compute_local_tensor_size +from torch.distributed._tensor._utils import compute_local_shape from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -18,7 +18,7 @@ def world_size(self): return 8 @with_comms - def test_compute_local_tensor_size_2d(self): + def test_compute_local_shape_2d(self): # mesh: 4 * 2 mesh_tensor = torch.arange(self.world_size).reshape(4, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) @@ -26,21 +26,21 @@ def test_compute_local_tensor_size_2d(self): # replicate, replicate placements1 = [Replicate(), Replicate()] - local_size1 = compute_local_tensor_size(size, mesh, placements1) + local_size1 = compute_local_shape(size, mesh, placements1) self.assertEqual(local_size1, torch.Size([8, 6])) # replicate, shard placements2 = [Replicate(), Shard(0)] - local_size2 = compute_local_tensor_size(size, mesh, placements2) + local_size2 = compute_local_shape(size, mesh, placements2) self.assertEqual(local_size2, torch.Size([4, 6])) # shard, shard placements3 = [Shard(0), Shard(1)] - local_size3 = compute_local_tensor_size(size, mesh, placements3) + local_size3 = compute_local_shape(size, mesh, placements3) self.assertEqual(local_size3, torch.Size([2, 3])) @with_comms - def test_compute_local_tensor_size_2d_not_evenly(self): + def test_compute_local_shape_2d_uneven(self): # mesh: 4 * 2 mesh_tensor = torch.arange(self.world_size).reshape(4, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) @@ -49,7 +49,7 @@ def test_compute_local_tensor_size_2d_not_evenly(self): # replicate, shard placements2 = [Replicate(), Shard(0)] - local_size2 = compute_local_tensor_size(size, mesh, placements2) + local_size2 = compute_local_shape(size, mesh, placements2) if rank_coordinates[1] < 1: self.assertEqual(local_size2, torch.Size([4, 7])) else: @@ -57,7 +57,7 @@ def test_compute_local_tensor_size_2d_not_evenly(self): # shard, shard placements3 = [Shard(0), Shard(1)] - local_size3 = compute_local_tensor_size(size, mesh, placements3) + local_size3 = compute_local_shape(size, mesh, placements3) # first dim if rank_coordinates[0] < 3: self.assertEqual(local_size3[0], 2) diff --git a/test/distributed/test_c10d_nccl.py 
b/test/distributed/test_c10d_nccl.py index 2d12f19a58bf..b91e86d21050 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1029,6 +1029,45 @@ def test_nccl_dist_backend_error(self): self.assertIsInstance(cm.exception, RuntimeError) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") + def test_abort_pg(self): + # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically + # abort the process group. + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + + store = c10d.FileStore(self.file_name, self.world_size) + self._create_process_group_nccl(store, self.opts()) + device = self.rank_to_GPU[self.rank][0] + + t = torch.rand(10, 10, device=device) + # First allreduce to initialize state. + dist.all_reduce(t) + + def abortpg(): + c10d.distributed_c10d._get_default_group()._get_backend(torch.device(device))._abort() + + # Initialize DDP to ensure "destroy_process_group" will not call + # ProcessGroupNCCL destructor since DDP holds a reference to process group. + # Run a single iteration of DDP to initialize state. + model = DistributedDataParallel( + torch.nn.Linear(10, 10).to(device), device_ids=[device] + ) + model(t).sum().backward() + + # Now simulate collective getting stuck and abort gets us unstuck + if self.rank == 0: + dist.all_reduce(t) + + # Schedule thread before we get stuck to abort pg. + thread = threading.Thread(target=abortpg) + thread.start() + + # We would get stuck here due to d2h if we didn't abort. + t_cpu = t.cpu() + + thread.join() + class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase ): diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index b5db3ba7eaa8..572024ec7e02 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -283,6 +283,7 @@ def func(inp, *, tag, ranks, group_size): assert same(out, correct) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @patch.object(torch._inductor.config.triton, "descriptive_names", False) def test_inductor_doesnt_mutate_shared(self): """ make sure that an intermediate that's going to be reuse isn't mutated unless copied @@ -344,7 +345,7 @@ def func(inp, *, tag, ranks, group_size): input = torch.ones(4, 4, device="cuda", requires_grad=True) # TODO implement backwards - with self.assertRaisesRegex(RuntimeError, "derivative for aten::all_reduce is not implemented"): + with self.assertRaisesRegex(RuntimeError, "element 0 of tensors does not require grad and does not have a grad_fn"): compiled = torch.compile(func, backend="aot_eager") # inductor bug with single-op allreduce graph out = compiled(input, **self.get_world_trs()) out.sum().backward() diff --git a/test/distributed/test_nccl.py b/test/distributed/test_nccl.py index a90288bd05fe..6991b77a4ee6 100644 --- a/test/distributed/test_nccl.py +++ b/test/distributed/test_nccl.py @@ -13,6 +13,7 @@ load_tests, TEST_WITH_ROCM, skip_but_pass_in_sandcastle_if, + NoTest, ) from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_device_type import ( @@ -34,7 +35,7 @@ nGPUs = torch.cuda.device_count() if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 datatypes = [torch.float] diff --git a/test/distributions/test_distributions.py 
b/test/distributions/test_distributions.py index db364296e3b7..7fbf4e88e0e1 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -2589,6 +2589,18 @@ def test_gumbel(self): self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1)) self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ()) self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,)) + self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32), + torch.tensor(1.0, dtype=torch.float32), + validate_args=False).cdf(20.0), 1.0, atol=1e-4, rtol=0) + self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64), + torch.tensor(1.0, dtype=torch.float64), + validate_args=False).cdf(50.0), 1.0, atol=1e-4, rtol=0) + self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32), + torch.tensor(1.0, dtype=torch.float32), + validate_args=False).cdf(-5.0), 0.0, atol=1e-4, rtol=0) + self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64), + torch.tensor(1.0, dtype=torch.float64), + validate_args=False).cdf(-10.0), 0.0, atol=1e-8, rtol=0) def ref_log_prob(idx, x, log_prob): l = loc.view(-1)[idx].detach() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 4bd69fdcf489..4f3356597914 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -1,17 +1,19 @@ # Owner(s): ["module: dynamo"] +import inspect import operator import unittest from enum import Enum -from typing import Dict, List +from typing import Dict, List, Sequence from unittest.mock import patch import torch - +import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing from functorch.experimental.control_flow import cond from torch._dynamo import config from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal import common_utils class ExportTests(torch._dynamo.test_case.TestCase): @@ -935,6 +937,7 @@ def forward(self, x): self.assertTrue(node.meta["nn_module_stack"] is not None) self.assertTrue(node.meta["source_fn"] is not None) self.assertTrue(node.meta["val"] is not None) + self.assertTrue(node.meta["original_aten"] is not None) def test_export_preserves_nn_module_stack_for_get_attr(self): inp = torch.randn(4, 4) @@ -1750,22 +1753,25 @@ def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs): pos0 = torch.randn(4) myargs = [torch.randn(4), torch.randn(4)] - torch._dynamo.reset() - exported = torch._dynamo.export( + expected_argument_names = [ + "pos0", + "tuple0", + "myargs_0", + "myargs_1", + "mykw0", + "input0", + "input1", + ] + self._test_export_preserving_original_signature( fn_with_kwargs, + expected_argument_names, pos0, tuple0, *myargs, - aten_graph=False, mykw0=mykw0, **mykwargs, ) - out_graph = exported[0] - dynamo_result = out_graph(pos0, tuple0, *myargs, mykw0=mykw0, **mykwargs) - real_result = fn_with_kwargs(pos0, tuple0, *myargs, mykw0=mykw0, **mykwargs) - self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) - def test_export_with_kwargs_and_empty_args(self): def fn_with_kwargs(mykw0=None, **mykwargs): out = mykw0 @@ -1775,19 +1781,11 @@ def fn_with_kwargs(mykw0=None, **mykwargs): mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)} mykw0 = torch.randn(4) - torch._dynamo.reset() - exported = torch._dynamo.export( - fn_with_kwargs, - aten_graph=False, - mykw0=mykw0, - **mykwargs, + expected_argument_names = ["mykw0"] + list(mykwargs.keys()) + self._test_export_preserving_original_signature( + fn_with_kwargs, expected_argument_names, mykw0, **mykwargs ) - out_graph = 
exported[0] - dynamo_result = out_graph(mykw0=mykw0, **mykwargs) - real_result = fn_with_kwargs(mykw0=mykw0, **mykwargs) - self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) - def test_export_with_args_and_empty_kwargs(self): def fn_with_kwargs(pos0, tuple0, *myargs): out = pos0 @@ -1801,16 +1799,139 @@ def fn_with_kwargs(pos0, tuple0, *myargs): pos0 = torch.randn(4) myargs = [torch.randn(4), torch.randn(4)] + expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"] + self._test_export_preserving_original_signature( + fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs + ) + + @common_utils.parametrize( + "default_value", + [ + common_utils.subtest(None, name="None"), + common_utils.subtest(42.0, name="float"), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + torch.randn(4), + name="tensor", + decorators=[unittest.expectedFailure], + ), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + (torch.randn(4),), + name="tuple", + decorators=[unittest.expectedFailure], + ), + ], + ) + def test_export_with_args_with_default(self, default_value): + def fn(pos0, pos1_default=default_value): + out = pos0 + if pos1_default is None: + pos1_default = torch.randn(4) + if isinstance(pos1_default, tuple): + pos1_default = pos1_default[0] + out *= pos1_default + return out + + pos0 = torch.randn(4) + expected_argument_names = ["pos0"] + self._test_export_preserving_original_signature( + fn, expected_argument_names, pos0 + ) + + @common_utils.parametrize( + "default_value", + [ + common_utils.subtest(None, name="None"), + common_utils.subtest(42.0, name="float"), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + torch.randn(4), + name="tensor", + decorators=[unittest.expectedFailure], + ), + common_utils.subtest( + # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output + (torch.randn(4),), + name="tuple", + decorators=[unittest.expectedFailure], + ), + ], + ) + def test_export_with_kwargs_with_default(self, default_value): + def fn(pos0, *, kw0, kw1_default=default_value, **kwargs): + out = pos0 + out += kw0 + if kw1_default is None: + kw1_default = torch.randn(4) + elif isinstance(kw1_default, tuple): + kw1_default = kw1_default[0] + out += kw1_default + out += kwargs["kw2"] + return out + + pos0 = torch.randn(4) + kw0 = torch.randn(4) + kw2 = torch.randn(4) + + args = (pos0,) + kwargs = {"kw0": kw0, "kw2": kw2} + expected_argument_names = ["pos0", "kw0", "kw2"] + self._test_export_preserving_original_signature( + fn, expected_argument_names, *args, **kwargs + ) + + def test_export_with_wrapped_fn(self): + # To ensure dynamo.export is robust to wrapped functions + # when it cannot use `inspect` to retrieve original signature + # info. 
+ def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs): + out = pos0 + out += pos1 + out += kw0 + out += kw1 + for arg in args: + out += arg + for kwarg in kwargs.values(): + out += kwarg + return out + + def wrapped_fn(*args, **kwargs): + return _fn(*args, **kwargs) + + pos0 = torch.randn(4) + kw0 = torch.randn(4) + args = (pos0, torch.randn(4), torch.randn(4)) + kwargs = {"kw0": kw0, "kw2": torch.randn(4)} + expected_argument_names = [f"args_{i}" for i in range(len(args))] + list( + kwargs.keys() + ) + + self._test_export_preserving_original_signature( + wrapped_fn, expected_argument_names, *args, **kwargs + ) + + def _test_export_preserving_original_signature( + self, fn, expected_argument_names: Sequence[str], *args, **kwargs + ): torch._dynamo.reset() exported = torch._dynamo.export( - fn_with_kwargs, pos0, tuple0, *myargs, aten_graph=False + fn, + *args, + **kwargs, + aten_graph=False, ) out_graph = exported[0] - dynamo_result = out_graph(pos0, tuple0, *myargs) - real_result = fn_with_kwargs(pos0, tuple0, *myargs) + dynamo_result = out_graph(*args, **kwargs) + real_result = fn(*args, **kwargs) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + # Check that the exported graph preserves same argument names. + self.assertEqual( + inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names + ) + def test_export_meta(self): class MyModule(torch.nn.Module): def __init__(self): @@ -2033,6 +2154,44 @@ def forward(self, input): count += 1 self.assertEqual(count, 1) + def test_export_pass_arg_by_name(self): + class BasicModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.my_lin = torch.nn.Linear(3, 4, bias=True) + + def forward(self, x): + return self.my_lin(x) + + mod, input_tensor = BasicModule(), torch.randn(2, 3) + gm, guard = torch._dynamo.export(mod, input_tensor, aten_graph=True) + ref = mod(x=input_tensor) + res = gm(x=input_tensor) + self.assertTrue(torch._dynamo.utils.same(ref, res)) + + def test_export_pass_arg_by_name_star_args(self): + class BasicModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.my_lin = torch.nn.Linear(3, 4, bias=True) + + def forward(self, *args): + return self.my_lin(args[0]) * self.my_lin(args[1]) + + mod, input_tensor, input_tensor2 = ( + BasicModule(), + torch.randn(2, 3), + torch.randn(2, 3), + ) + gm, guard = torch._dynamo.export( + mod, input_tensor, input_tensor2, aten_graph=True + ) + ref = mod(input_tensor, input_tensor2) + res = gm(input_tensor, input_tensor2) + self.assertTrue(torch._dynamo.utils.same(ref, res)) + + +common_utils.instantiate_parametrized_tests(ExportTests) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 745b39ed0b0b..1ee9150fa3e3 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -732,6 +732,12 @@ def test_islice_chain(a, b): c = next(itertools.islice(tmp1, 1, None)) return a - b / c + @make_test + def test_namedtuple(a, b): + mytuple = collections.namedtuple("mytuple", ["x", "y", "xy"]) + tmp = mytuple(a, b, a + b) + return mytuple(tmp.x, tmp[1], tmp.xy + b) + @make_test def test_is_quantized(a, b): if not a.is_quantized: diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 8ffdba9565b4..a2202f6c1c0f 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11,6 +11,7 @@ import sys import typing import unittest +import unittest.mock as mock import weakref from unittest.mock import 
patch @@ -21,6 +22,7 @@ import torch._dynamo.testing import torch.onnx.operators from torch._dynamo import bytecode_transformation, graph_break +from torch._dynamo.eval_frame import enable_cache_lookup_profiler from torch._dynamo.output_graph import OutputGraph from torch._dynamo.testing import ( CompileCounter, @@ -910,6 +912,28 @@ def fn(a, b): opt_fn(a, b) self.assertEqual(cnts.frame_count, 2) + def test_nested_grad_mode_graph_break(self): + def fn(x): + before = torch.is_grad_enabled() + with torch.set_grad_enabled(False): + torch._dynamo.graph_break() + with torch.set_grad_enabled(True): + x = torch.mul(x, 5) + torch._dynamo.graph_break() + x = torch.sqrt(x) + assert torch.is_grad_enabled() + assert not torch.is_grad_enabled() + assert torch.is_grad_enabled() == before + return x + + a = torch.randn([3, 4]) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + + for _ in range(10): + opt_fn(a) + self.assertEqual(cnts.frame_count, 3) + def test_build_tuple_unpack(self): def fn1(a, b, c): return a - b / c @@ -1981,6 +2005,48 @@ def fn(x): self.assertTrue(same(ref, res)) self.assertEqual(cnts.frame_count, 2) + def test_profiler_cache_lookup(self): + def fn(x): + y = x**2 + y = y + 2 + z = y**3 + return z + + x = torch.randn((2, 2), requires_grad=True) + ref = fn(x) + opt_fn = torch.compile(fn, backend="aot_eager") + + # warmup + opt_fn(x) + + enable_cache_lookup_profiler(True) + with torch.autograd.profiler.profile() as prof: + res = opt_fn(x) + events = list( + filter( + lambda event: event.name == "TorchDynamo Cache Lookup", + prof.function_events, + ) + ) + + self.assertTrue(same(ref, res)) + self.assertTrue( + len(events) == 1, "Expected one lookup profiler event for one opt_fn run" + ) + + enable_cache_lookup_profiler(False) + with torch.autograd.profiler.profile() as prof: + res = opt_fn(x) + events = list( + filter( + lambda event: event.name == "TorchDynamo Cache Lookup", + prof.function_events, + ) + ) + + self.assertTrue(same(ref, res)) + self.assertTrue(len(events) == 0, "Expected disabled profiling") + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_cuda_stream_context_manager1(self): def fn(x): @@ -3382,6 +3448,109 @@ def fn(x): self.assertEqual(res.dtype, torch.bfloat16) self.assertEqual(opt_res.dtype, torch.bfloat16) + def test_autocast_cpu_graph_break_inner_fn(self): + class MyModule(torch.nn.Module): + @staticmethod + def mm_breaks(x, y): + torch._dynamo.graph_break() + return torch.mm(x, y) + + def forward(self, x): + a_float32 = torch.rand((8, 8), device="cpu") + b_float32 = torch.rand((8, 8), device="cpu") + + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + torch._dynamo.graph_break() + with torch.autocast( + device_type="cpu", dtype=torch.bfloat16, enabled=False + ): + torch._dynamo.graph_break() + g_float32 = torch.mm(a_float32, b_float32) + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + # Check that nested with non-inlineable function with graph break + torch._dynamo.graph_break() + f_float16_1 = self.mm_breaks(a_float32, b_float32) + # We remember to exit the inner autocast correctly to outer + # even after graph breaks + f_float16 = self.mm_breaks(a_float32, b_float32) + assert f_float16.dtype == f_float16_1.dtype + return f_float16, g_float32 + + module = MyModule() + real_16, real_32 = module(torch.tensor([0.5])) + real_device_16 = real_16.device + real_dtype_16 = real_16.dtype + real_device_32 = real_32.device + real_dtype_32 = real_32.dtype + + graph = 
torch._dynamo.optimize("eager")(module) + out_16, out_32 = graph(torch.tensor([0.5])) + self.assertEqual(out_16.device, real_device_16) + self.assertEqual(out_16.dtype, real_dtype_16) + self.assertEqual(out_32.device, real_device_32) + self.assertEqual(out_32.dtype, real_dtype_32) + + self.assertEqual(out_16.device.type, "cpu") + self.assertEqual(out_16.dtype, torch.bfloat16) + self.assertEqual(out_32.device.type, "cpu") + self.assertEqual(out_32.dtype, torch.float32) + + def test_autocast_graph_break_method(self): + class MyModule(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.bias = bias + + def mm_not_break(self, x, y): + return torch.mm(x, y) + self.bias + + def mm_breaks(self, x, y): + torch._dynamo.graph_break() + return torch.mm(x, y) + self.bias + + def forward(self, x): + a_float32 = torch.rand((8, 8), device="cpu") + b_float32 = torch.rand((8, 8), device="cpu") + + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + with torch.autocast( + device_type="cpu", dtype=torch.bfloat16, enabled=False + ): + g_float32 = torch.mm(a_float32, b_float32) + f_float16 = self.mm_breaks(a_float32, b_float32) + + assert ( + f_float16[0][0] == self.mm_not_break(a_float32, b_float32)[0][0] + ) + return f_float16, g_float32 + + module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16)) + + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + # Autocast doesn't work on addition, so we need the bias to be `bfloat16` + res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand( + (8, 8), device="cpu", dtype=torch.bfloat16 + ) + self.assertEqual(res.dtype, torch.float32) + + real_16, real_32 = module(torch.tensor([0.5])) + real_device_16 = real_16.device + real_dtype_16 = real_16.dtype + real_device_32 = real_32.device + real_dtype_32 = real_32.dtype + + graph = torch._dynamo.optimize("eager")(module) + out_16, out_32 = graph(torch.tensor([0.5])) + self.assertEqual(out_16.device, real_device_16) + self.assertEqual(out_16.dtype, real_dtype_16) + self.assertEqual(out_32.device, real_device_32) + self.assertEqual(out_32.dtype, real_dtype_32) + + self.assertEqual(out_16.device.type, "cpu") + self.assertEqual(out_16.dtype, torch.bfloat16) + self.assertEqual(out_32.device.type, "cpu") + self.assertEqual(out_32.dtype, torch.float32) + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_autocast_float64(self): class MyModule(torch.nn.Module): @@ -4726,6 +4895,74 @@ def training_step(self): model.training_step() + def test_torch_guards_stack_frame_register(self): + y = torch.nn.Parameter(torch.tensor([0.25, 0.25])) + x = torch.tensor([0.5, 0.5]) + + class encoder(torch.nn.Module): + def __init__(self, y): + super().__init__() + self.register_parameter("param", y) + + @torch._dynamo.disable + def helper(self, x, y): + return x * y + + def forward(self, a, *args): + x = a + a + return self.helper(x, self.param) + + e = encoder(y) + + seen_frames = [] + import contextlib + + @contextlib.contextmanager + def global_context_capture_fn(frame_summary): + seen_frames.append(frame_summary) + yield + + with mock.patch( + "torch._guards.TracingContext.current_frame", + side_effect=global_context_capture_fn, + ): + torch._dynamo.optimize("eager")(e)(x) + + self.assertEqual(len(seen_frames), 1) + self.assertEqual(seen_frames[0].line, "def forward(self, a, *args):") + + def test_torch_guards_stack_frame_register_inlining(self): + x = torch.tensor([0.5, 0.5]) + y = torch.tensor([0.75, 0.75, 0.75, 0.75]) + z = 
torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) + + def uwu_inline_me(x, y, z): + r = torch.cat((x, x)) + y + r2 = torch.cat((y, y)) + z + return r, r2 + + def fn(x, y, z): + r, r2 = uwu_inline_me(x, y, z) + return torch.mul(r, r), torch.mul(r2, r2) + + seen_frames = [] + import contextlib + + @contextlib.contextmanager + def global_context_capture_fn(frame_summary): + seen_frames.append(frame_summary) + yield + + with mock.patch( + "torch._guards.TracingContext.current_frame", + side_effect=global_context_capture_fn, + ): + torch._dynamo.optimize("eager")(fn)(x, y, z) + + self.assertEqual(len(seen_frames), 2) + self.assertEqual(seen_frames[0].name, "fn") + self.assertEqual(seen_frames[1].line, "def uwu_inline_me(x, y, z):") + class CustomFunc1(torch.autograd.Function): @staticmethod diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index ee7902bef512..3fbba176db2a 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -5,6 +5,7 @@ from typing import Tuple from unittest.mock import patch +import pytest import torch import torch._dynamo.test_case @@ -682,6 +683,22 @@ def forward(self, x): return x * self.scale +class ModuleGuardNameIsValid(torch.nn.ModuleDict): + # Guard names should be valid python identifier as we use eval() to get + # corresponding guard value. Some guard names come from source(module path) + # where special symbols are valid. But they are not valid python identifier, + # we should identify these pattern and rewrite them with getattr. + def __init__(self): + super().__init__() + for i in range(2): + self.add_module("l@yer-%d" % (i + 1), BasicModule()) + + def forward(self, x): + for _, layer in self.items(): + x = layer(x) + return x + + class ModulePatch1(torch.nn.Module): pass @@ -746,6 +763,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): test_forward_directly = make_test(CallForwardDirectly()) test_module_name_string = make_test(ModuleNameString()) test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) + test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid()) def test_module_forward_has_graph_break(self): m = ModuleForwardHasGraphBreak() @@ -1151,6 +1169,20 @@ def forward(self, x): # There will be a graph break for the inner mod being OptimizedModule self.assertEqual(cnt.frame_count, 2) + def test_torchscript_failure(self): + model = BasicModule() + compile_model = torch.compile(model) + example_forward_input = torch.rand(10, 10) + with pytest.raises(AttributeError): + c_model_scripted = torch.jit.script(compile_model, example_forward_input) + + def test_torchtrace_failure(self): + model = BasicModule() + compile_model = torch.compile(model) + example_forward_input = torch.rand(10, 10) + with pytest.raises(AttributeError): + c_model_traced = torch.jit.trace(compile_model, example_forward_input) + def test_module_patch(self): mod = ModulePatch1() mod.forward = types.MethodType(ModulePatch2.forward, mod) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 28fa92085806..82ad5d8af309 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1298,6 +1298,32 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) + def test_with_on_graph_break_nested(self): + def reversible(x): + torch._dynamo.graph_break() # Cause graph break so inline fails + return torch.sin(torch.cos(x)) + + def fn(x): + # nested context manager failed previously + with torch.no_grad(): + with torch.enable_grad(): + a = torch.sin(x) + b = reversible(a) + 
c = torch.sigmoid(b) + c.sum().backward() + return x.grad + + x = torch.randn(3, requires_grad=True) + x.grad = None + with torch.no_grad(): + ref = fn(x) + + x.grad = None + opt_fn = torch._dynamo.optimize("eager")(fn) + with torch.no_grad(): + res = opt_fn(x) + self.assertTrue(same(ref, res)) + # https://github.com/pytorch/torchdynamo/issues/1446 def test_grad_mode_carrying_correct_state_after_graph_break(self): def fn(x): diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 6bfae5b29d7d..ed3a344a1f6b 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -578,8 +578,6 @@ aten::alias_copy.out aten::all_gather_into_tensor aten::all_reduce aten::allclose -aten::aminmax -aten::aminmax.out aten::angle aten::angle.out aten::argmax @@ -960,8 +958,6 @@ aten::nanmedian aten::nanmedian.dim aten::nanmedian.dim_values aten::nanmedian.out -aten::nansum -aten::nansum.out aten::native_group_norm.out aten::native_norm aten::native_norm.ScalarOpt_dim_dtype diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 65938f28e1c9..4b9c8bd664ac 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -23,7 +23,7 @@ import itertools from functools import partial from torch.nn.utils.rnn import PackedSequence -from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed from torch.testing._internal.common_modules import module_db, modules from functorch import ( @@ -245,6 +245,7 @@ def verify_aot_autograd( *, test_mutation: bool = False, decompositions: Optional[Dict] = None, + dynamic: bool = False, ): for keep_input_mutations in [True, False]: # Some tests pass in a callable for inp, to generate the inputs @@ -294,7 +295,8 @@ def verify_aot_autograd( fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=nop, decompositions=decompositions, - keep_inference_input_mutations=keep_input_mutations + keep_inference_input_mutations=keep_input_mutations, + dynamic=dynamic ) else: compiled_f = aot_function( @@ -302,7 +304,8 @@ def verify_aot_autograd( fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=nop, decompositions=decompositions, - keep_inference_input_mutations=keep_input_mutations + keep_inference_input_mutations=keep_input_mutations, + dynamic=dynamic ) ref_out, ref_grad = outs_and_grads(f, graph_inps, inp) test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy) @@ -366,8 +369,6 @@ def f(a, b): self.verify_aot_autograd(f, inp) # Test for bug occurring at the intersection of fake tensors & functionalization. 
- @patch("torch._functorch.config.use_dynamic_shapes", True) - @patch("torch._functorch.config.use_fake_tensor", True) def test_squeeze_mutation(self): def f(a): b = a.clone().squeeze(-1) @@ -375,12 +376,10 @@ def f(a): return a + b inp = [torch.randn(3, 1, requires_grad=True)] - self.verify_aot_autograd(f, inp) + self.verify_aot_autograd(f, inp, dynamic=True) inp = [torch.randn(3, 1, requires_grad=False)] - self.verify_aot_autograd(f, inp) + self.verify_aot_autograd(f, inp, dynamic=True) - @patch("torch._functorch.config.use_dynamic_shapes", True) - @patch("torch._functorch.config.use_fake_tensor", True) def test_embedding_bag_view(self): # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper; # test that this works even though the sparse tensor has no storage. @@ -395,7 +394,7 @@ def forward(self, x, y): x = torch.arange(3) y = torch.arange(3) - self.verify_aot_autograd(F(), [x, y]) + self.verify_aot_autograd(F(), [x, y], dynamic=True) @patch("functorch.compile.config.use_fake_tensor", True) def test_input_mutation_simple(self): @@ -1331,6 +1330,31 @@ def forward(self, primals_1, primals_2): unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0) return [t, as_strided_2, view_1, t_1, unsqueeze, add]""") # noqa: B950 + @patch("functorch.compile.config.use_fake_tensor", True) + def test_dynamic_shape_output_not_in_bw_graph(self): + def f(x): + return [x + 1, x.shape[0]] + inp = torch.ones(5, requires_grad=True) + bw_graph_cell = [None] + compiled_f = aot_function( + f, + fw_compiler=nop, + bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), + decompositions={}, + keep_inference_input_mutations=False, + dynamic=True, + ) + out = compiled_f(inp) + out[0].sum().backward() + # The important bit: the forward fn returns 2 outputs, + # but one of them is a symint so we should only see + # 1 grad_output as an input to the backward graph. + # (Otherwise, autograd will plumb a None as the value of the grad_output, + # which causes inductor to complain). + self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ +def forward(self, arg0_1): + return (arg0_1,)""") + @patch("functorch.compile.config.use_fake_tensor", True) def test_no_grad_input_output(self): def f(a, b): @@ -1714,8 +1738,6 @@ def bn(x): for a, b in zip(ref, res): assert torch.allclose(a, b) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) def test_output_op_depending_on_symint(self): """ It won't be obvious from reading this test what it's testing for. 
We should probably make it into a more @@ -1738,12 +1760,10 @@ def f(x): # TODO: assert outputs of fwd graph trace to correct symint # e2e test that fails without symint clone fix - af = aot_function(f, nop, partition_fn=partial(min_cut_rematerialization_partition, compiler="inductor")) + af = aot_function(f, nop, partition_fn=partial(min_cut_rematerialization_partition, compiler="inductor"), dynamic=True) out = af(inp) self.assertEqual(out, f(inp)) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) def test_default_partitioner_saves_symints_not_tensors_for_bw(self): """ In this test, the important thing is that primals_1 is **only** needed in the backward @@ -1764,7 +1784,7 @@ def f(a): d = b.masked_fill_(c, 0) return d - compiled_f = aot_function(f, nop) + compiled_f = aot_function(f, nop, dynamic=True) inp_ref = torch.ones(2, 2, requires_grad=True) inp_test = torch.ones(2, 2, requires_grad=True) @@ -1859,14 +1879,15 @@ def get_num_ins_outs(fx_g): return tuple(len(i) for i in get_ins_outs(fx_g)) -def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition): +def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False): fw_graph_cell = [None] bw_graph_cell = [None] aot_function(f, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=partitioner, - decompositions=default_decompositions)(*inps).sum().backward() + decompositions=default_decompositions, + dynamic=dynamic)(*inps).sum().backward() return (fw_graph_cell[0], bw_graph_cell[0]) @@ -1933,8 +1954,6 @@ def f(x, mod_weight, mod_bias): self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_min_cut_partitioner_save_shape(self): @@ -1943,7 +1962,7 @@ def f(x): return s inp = [torch.ones([10, 10], requires_grad=True)] - fw_graph, bw_graph = get_fw_bw_graph(f, inp) + fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) _, fw_output = get_ins_outs(fw_graph) self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) @@ -1968,14 +1987,12 @@ def f(a, b, c): x = sb[0] + sc[0] a_sz = (x, a.size(0)) return torch.cat([a.expand(a_sz), b, c]) - fw_graph, bw_graph = get_fw_bw_graph(f, inp) + fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) self.assertEqual(get_num_ins_outs(fw_graph), (3, 5)) self.assertEqual(get_num_ins_outs(bw_graph), (5, 3)) _, outs = get_ins_outs(fw_graph) self.assertTrue(all([is_sym_node(n) for n in outs[1:]])) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) def test_default_partitioner_output_tensor_shape_tensor(self): inp = [ @@ -2004,7 +2021,8 @@ def f(a, b, c, d): fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=default_partition, - decompositions=default_decompositions)(*inp) + decompositions=default_decompositions, + dynamic=True)(*inp) fw_graph = fw_graph_cell[0] (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() bw_graph = bw_graph_cell[0] @@ -2013,11 +2031,12 @@ def f(a, b, c, d): # - 5 original outputs (sb is a tuple, gets expanded to 2 symints) # - 8 
saved outputs for backward: 5 tensors, 3 symints self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) - # in the bwd graph, 12 inputs (grad outs) because: + # in the bwd graph, 10 inputs (grad outs) because: # - The fwd graph had 13 outputs # - 1 was a view of an input, which gets regenerated outside of the graph # and doesn't participate in the backward - self.assertEqual(get_num_ins_outs(bw_graph), (12, 4)) + # - 2 user outs were symints (b.size()), which don't get tangents in the backward + self.assertEqual(get_num_ins_outs(bw_graph), (10, 4)) _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, @@ -2037,8 +2056,6 @@ def f(a, b, c, d): # TODO(whc) we should learn to return torch.Sizes self.assertFalse(isinstance(compiled_outs[1], torch.Size)) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_min_cut_partitioner_output_tensor_shape_tensor(self): @@ -2068,13 +2085,14 @@ def f(a, b, c, d): fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=min_cut_rematerialization_partition, - decompositions=default_decompositions)(*inp) + decompositions=default_decompositions, + dynamic=True)(*inp) fw_graph = fw_graph_cell[0] (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() bw_graph = bw_graph_cell[0] self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) - self.assertEqual(get_num_ins_outs(bw_graph), (12, 4)) + self.assertEqual(get_num_ins_outs(bw_graph), (10, 4)) _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, @@ -2394,8 +2412,6 @@ def forward(self, x): skip('linalg.householder_product'), # flaky decorate('matmul', decorator=unittest.skipIf(IS_ARM64, 'flaky')), decorate('__rmatmul__', decorator=unittest.skipIf(IS_ARM64, 'flaky')), - # overrides atol=1e-4, rtol=1e-5 would do as well - decorate('svd_lowrank', decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)})), } symbolic_aot_autograd_failures = { @@ -2441,6 +2457,7 @@ def forward(self, x): xfail('gradient', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition + xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -2552,7 +2569,6 @@ def forward(self, x): xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ... 
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('sum_to_size', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('svd', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('svd_lowrank', ''), # could not find kernel xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -2562,7 +2578,6 @@ def forward(self, x): xfail('trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trapz', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de... - xfail('unflatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList xfail('vsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides } @@ -2617,7 +2632,7 @@ def create_new_arg(x): except DynamicOutputShapeException: self.skipTest("Dynamic output shape operation in trace") -def _test_aot_autograd_helper(self, device, dtype, op): +def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False): if not op.supports_autograd: self.skipTest("Op does not support autograd") @@ -2639,7 +2654,7 @@ def f(args): c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec) return op.op(*c_args, **c_kwargs) - compiled_f = compiled_function(f, nop, nop) + compiled_f = compiled_function(f, nop, nop, dynamic=dynamic) try: _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args) except GuardOnDataDependentSymNode: @@ -2651,7 +2666,7 @@ def f(args): else: raise -def _test_aot_autograd_module_helper(self, device, dtype, training, module_info): +def _test_aot_autograd_module_helper(self, device, dtype, training, module_info, *, dynamic=False): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=True, training=training) @@ -2696,7 +2711,7 @@ def f(params_buffers_args): named_params = dict(m.named_parameters(remove_duplicate=False)) named_buffers = dict(m.named_buffers(remove_duplicate=False)) num_params_buffers = len(named_params) + len(named_buffers) - compiled_f = aot_function(f, nop, num_params_buffers=num_params_buffers) + compiled_f = aot_function(f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic) params_buffers_args = [named_params, named_buffers, args] _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, params_buffers_args) @@ -2708,13 +2723,11 @@ def test_aot_autograd_exhaustive(self, device, dtype, op): _test_aot_autograd_helper(self, device, dtype, op) @ops(op_db, allowed_dtypes=(torch.float,)) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_functionalize", True) @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive', aot_autograd_failures | symbolic_aot_autograd_failures) def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op): - _test_aot_autograd_helper(self, device, dtype, op) + _test_aot_autograd_helper(self, device, dtype, op, dynamic=True) aot_autograd_module_failures = set({ @@ -2738,8 +2751,6 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op): symbolic_aot_autograd_module_failures = { torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a 
mask input to a mask producing a bool torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool - torch.nn.TransformerEncoderLayer, # RuntimeError: tried to get Double out of SymFloat - torch.nn.TransformerDecoderLayer, # RuntimeError: tried to get Double out of SymFloat torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool torch.nn.CrossEntropyLoss, # Cannot call sizes() on tensor with symbolic sizes/strides torch.nn.Bilinear, # Cannot call sizes() on tensor with symbolic sizes/strides @@ -2754,13 +2765,11 @@ def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_in _test_aot_autograd_module_helper(self, device, dtype, training, module_info) @modules(module_db, allowed_dtypes=(torch.float,)) - @patch("functorch.compile.config.use_dynamic_shapes", True) - @patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_functionalize", True) @decorateForModules(unittest.expectedFailure, aot_autograd_module_failures | symbolic_aot_autograd_module_failures) def test_aot_autograd_symbolic_module_exhaustive(self, device, dtype, training, module_info): - _test_aot_autograd_module_helper(self, device, dtype, training, module_info) + _test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True) only_for = ("cpu") diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 0e4d80707234..38fc695c549a 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1195,6 +1195,7 @@ def test(): xfail("_native_batch_norm_legit"), xfail("native_dropout_backward"), xfail("_upsample_bilinear2d_aa"), # hit vmap fallback, which is disabled + xfail("index_fill"), # aten::_unique hit the vmap fallback which is currently disabled })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): if not op.supports_autograd: diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index df00c89ee800..03cf3ea215bd 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4259,6 +4259,7 @@ def f(e_): # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2) # https://github.com/pytorch/pytorch/runs/8110653462?check_suite_focus=true # but it passes locally + xfail('linalg.diagonal'), skip('linalg.matrix_norm', ''), skip('linalg.ldl_solve', ''), }) diff --git a/test/fx/test_verifier.py b/test/fx/test_verifier.py new file mode 100644 index 000000000000..71aee0fd69a9 --- /dev/null +++ b/test/fx/test_verifier.py @@ -0,0 +1,203 @@ +# Owner(s): ["module: fx"] +import os +import sys +import unittest +from torch.fx.verifier import ( + SpecViolationError, + check_valid_aten_dialect, + check_valid, + is_valid_aten_dialect, + is_valid, +) + + +from typing import Tuple + + +from torch.testing._internal.common_utils import TestCase +import torch # noqa: F401 +import torch.nn as nn +from torch import Tensor +import torch._dynamo as torchdynamo +import copy +from functorch import make_fx +from functorch.experimental import functionalize + + +@torch.no_grad() +def capture(f, args): + torchdynamo.config.capture_scalar_outputs = True + torchdynamo.config.guard_nn_modules = True + torchdynamo.config.dynamic_shapes = True + torchdynamo.config.allow_rnn = True + torchdynamo.config.verbose = True + torchdynamo.reset() + graphmodule, _ = torchdynamo.export( + f, + *copy.deepcopy(args), + aten_graph=True, + tracing_mode='fake', + ) + + def graph_with_interpreter(*args): + 
with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(graphmodule).run(*args) + + functionalized_callable = functionalize( + graph_with_interpreter, + remove='mutations_and_views', + ) + gm = make_fx(functionalized_callable, tracing_mode='fake', _allow_non_fake_inputs=True)(*args) + return gm + + +class Transpose(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor: + return x.transpose(dim0, dim1) + + +class Mul(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input: Tensor, other: Tensor) -> Tensor: + # or return torch.mul(input, other) + return input * other + + def get_random_inputs(self) -> Tuple[Tensor, Tensor]: + return (torch.randn(3, 2), torch.randn(3, 2)) + + +class ElementwiseAdd(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + return x + y + + def get_random_inputs(self) -> Tuple[Tensor, Tensor]: + return (torch.randn(1, 3), torch.randn(1, 3)) + + +class Cat(nn.Module): + def __init__(self) -> None: + super().__init__() + + # def forward(self, tensors, dim=0): + def forward(self, *args: Tensor, dim: int) -> Tensor: + tensors = args[:-1] + return torch.cat(tensors, dim) + + +class FeedForwardBlock(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int) -> None: + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.layer_norm = nn.LayerNorm(input_dim) + + self.relu = nn.ReLU() + + self.linear1 = nn.Linear(input_dim, hidden_dim) + self.dropout1 = nn.Dropout() + + self.linear2 = nn.Linear(hidden_dim, input_dim) + self.dropout2 = nn.Dropout() + + def forward(self, x: Tensor) -> Tensor: + # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout + y = self.layer_norm(x) + y = self.linear1(y) + y = self.dropout1(y) + y = self.relu(y) + y = self.linear2(y) + y = self.dropout2(y) + return y + + +def skip_condition(): + return sys.version_info >= (3, 11) or os.name == 'nt' + +class VerifierTest(TestCase): + + @unittest.skipIf(skip_condition(), "dynamo doesnt support 3.11") + def test_verifier(self) -> None: + m = ElementwiseAdd() + egm = capture(m, (torch.randn(100), torch.randn(100))) + # assert not throw + check_valid(egm) + self.assertTrue(is_valid(egm)) + + @unittest.skipIf(skip_condition(), "dynamo doesnt support 3.11") + def testr_verifier_call_module(self) -> None: + m = FeedForwardBlock(10, 10) + gm = torch.fx.symbolic_trace(m) + # this would have modules that are not delegates + with self.assertRaises(SpecViolationError): + check_valid(gm) + self.assertFalse(is_valid(gm)) + + @unittest.skipIf(skip_condition(), "dynamo doesnt support 3.11") + def test_verifier_no_functional(self) -> None: + m = ElementwiseAdd() + egm = capture(m, (torch.randn(100), torch.randn(100))) + for node in egm.graph.nodes: + if node.target == torch.ops.aten.add.Tensor: + node.target = torch.ops.aten.add.out + with self.assertRaises(SpecViolationError): + check_valid(egm) + self.assertFalse(is_valid(egm)) + + @unittest.skipIf(skip_condition(), "dynamo doesnt support 3.11") + def test_aten_dialect(self) -> None: + m = ElementwiseAdd() + egm = capture(m, (torch.randn(100), torch.randn(100))) + check_valid_aten_dialect(egm) + self.assertTrue(is_valid_aten_dialect(egm)) + + @unittest.skipIf(skip_condition(), "dynamo doesnt support 3.11") + def test_aten_wrong_mem_format(self) -> None: + class TestModel(torch.nn.Module): + def __init__(self): + 
super().__init__() + self.a = torch.nn.parameter.Parameter( + torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last) + ) + + def forward(self, x): + return self.a + x + + m = TestModel() + egm = capture(m, (torch.randn(1, 3, 100, 100),)) + egm._apply(lambda t: t.to(memory_format=torch.channels_last)) + with self.assertRaises(SpecViolationError): + check_valid_aten_dialect(egm) + self.assertFalse(is_valid_aten_dialect(egm)) + + @unittest.skipIf(skip_condition(), "dynamo doesnt support 3.11") + def test_aten_wrong_mem_format_buffer(self) -> None: + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "a", + torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last), + ) + + def forward(self, x): + return self.a + x + + m = TestModel() + egm = capture(m, (torch.randn(1, 3, 100, 100),)) + egm._apply(lambda t: t.to(memory_format=torch.channels_last)) + with self.assertRaises(SpecViolationError): + check_valid_aten_dialect(egm) + self.assertFalse(is_valid_aten_dialect(egm)) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3422cbccc352..a5a2eb112a64 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -20,6 +20,7 @@ import torch import torch._dynamo +import torch.nn as nn from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same @@ -1009,6 +1010,7 @@ def fn(x): x.argmax(-1), x.amin(-1), x.amax(-1), + x.aminmax(), ) with config.patch(unroll_reductions_threshold=8): @@ -1043,9 +1045,35 @@ def fn(x, y): def test_min_max_reduction(self): def fn(a, b): - return ((a + b).max(), (a + b).min(), torch.amax(a + 1, keepdim=True)) + return ( + (a + b).max(), + (a + b).min(), + torch.amax(a + 1, keepdim=True), + torch.amin(b + 1, keepdim=True), + ) - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) + for dtype in [torch.float, torch.bfloat16, torch.float16]: + self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype))) + + def test_fmin_fmax(self): + def fn(a, b): + return ( + torch.fmin(a, b), + torch.fmax(a, b), + torch.fmax(a + 1, torch.tensor(0.0)), + ) + + self.common( + fn, + ( + torch.tensor( + [-10.0, 10.0, float("nan"), float("nan"), float("nan"), 3, 4] + ), + torch.tensor( + [float("nan"), float("nan"), -10.0, 10.0, float("nan"), 4, 3] + ), + ), + ) def test_sum_int(self): def fn(x): @@ -2175,6 +2203,7 @@ def test_conv_transpose2d_packed(self): (v,), ) + @slow() def test_conv_transpose2d_unary(self): if self.device == "cuda": raise unittest.SkipTest("only support cpu conv_transpose2d unary test") @@ -4019,7 +4048,7 @@ def fn(x, y): self.assertTrue(same(out, out_eager)) @config.patch( - {"triton.ordered_kernel_names": True, "triton.descriptive_kernel_names": False} + {"triton.unique_kernel_names": True, "triton.descriptive_names": False} ) def test_kernel_names(self): @torch._dynamo.optimize("inductor") @@ -4262,22 +4291,25 @@ def fn(x): self.common(fn, [torch.randn(64) * 10]) def test_baddbmm(self): - def fn(a, b, c): - return aten.baddbmm(a, b, c) + def fn(a, b, c, beta): + return aten.baddbmm(a, b, c, beta=beta) - self.common( - fn, - [ - torch.randn(6, 1, 100), - torch.randn(6, 128, 64), - torch.randn(6, 64, 100), - ], - # Mismatched elements: 1212 / 76800 (1.6%) - # Greatest absolute difference: 0.001953125 at index (0, 0, 93) (up to 1e-05 allowed) - # Greatest 
relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed) - atol=0.002, - rtol=0.001, + b = torch.randn(6, 128, 64) + c = torch.randn(6, 64, 100) + options = itertools.product( + [torch.randn(6, 1, 100), torch.randn(6, 1, 100).fill_(torch.nan)], + [0.0, 1.0], ) + for a, beta in options: + self.common( + fn, + [a, b, c, beta], + # Mismatched elements: 1212 / 76800 (1.6%) + # Greatest absolute difference: 0.001953125 at index (0, 0, 93) (up to 1e-05 allowed) + # Greatest relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed) + atol=0.002, + rtol=0.001, + ) @config.patch({"triton.max_tiles": 2}) def test_fuse_tiled(self): @@ -4453,7 +4485,11 @@ def fn(a): def test_narrow(self): def fn(x): - return aten.narrow(x, 1, 10, 16), aten.narrow(x + 2, 0, 10, 16) + 1 + return ( + aten.narrow(x, 1, 10, 16), + aten.narrow(x + 2, 0, 10, 16) + 1, + aten.narrow_copy(x, 1, 10, 16), + ) self.common(fn, [torch.randn(64, 64)]) @@ -6626,6 +6662,22 @@ def fn(a): with self.assertRaises(RuntimeError): torch.compile(fn)(a) + def test_ir_node_str(self): + @torch.compile + def fn(x: torch.Tensor) -> torch.Tensor: + return x.sin(), torch.nn.Softmax(dim=1)(x.cos()) + + def run_node_alt(*args, **kwargs): + rv = run_node(*args, **kwargs) + strings.append(str(rv)) + return rv + + strings = [] + run_node = GraphLowering.run_node + with patch.object(GraphLowering, "run_node", run_node_alt): + fn(torch.randn([8, 128])) + self.assertGreater(len(strings), 3) + if HAS_CUDA and not TEST_WITH_ASAN: import triton @@ -7361,17 +7413,72 @@ def fn(): self.assertEqual(fn_opt(), fn()) - def test_split_op_with_sym(self): - for dynamic_shapes in [True, False]: - torch._dynamo.config.dynamic_shapes = dynamic_shapes + def test_kernel_names_descriptive(self): + @torch._dynamo.optimize("inductor") + def fn1(x): + return x.cos().sin() + + @torch._dynamo.optimize("inductor") + def fn2(x): + x = torch.mm(x, x) + x = torch.softmax(x, dim=1) + return x + + mod = nn.Sequential( + nn.Linear(4, 4), + nn.LayerNorm(4), + nn.ReLU(), + ).cuda() - def fn(x: torch.Tensor) -> torch.Tensor: - # split(tensor, sympy.Integer), split(tensor, sympy.Expr) - return torch.split(x, x.shape[0]), torch.split(x, x.shape[0] // 2) + @torch._dynamo.optimize("inductor") + def fn3(x): + return mod(x) - fn_opt = torch._dynamo.optimize("inductor", dynamic=dynamic_shapes)(fn) - inps = torch.randn([5, 5]) - fn_opt(inps) + func_and_kernel_aten = [ + (fn1, "triton_fused_cos_sin", (torch.randn(8, device="cuda"),)), + (fn2, "triton_fused__softmax", (torch.randn(4, 4, device="cuda"),)), + ( + fn3, + "triton_fused_native_layer_norm_relu", + (torch.randn(4, 4, device="cuda"),), + ), + ] + func_and_kernel_torch = [ + (fn1, "triton_fused_cos_sin", (torch.randn(8, device="cuda"),)), + (fn2, "triton_fused_softmax", (torch.randn(4, 4, device="cuda"),)), + ( + fn3, + "triton_fused_LayerNorm_ReLU", + (torch.randn(4, 4, device="cuda"),), + ), + ] + + def test_funcs(func_and_kernel): + with torch.no_grad(): + for fn, kernel_name, inps in func_and_kernel: + code = run_and_get_triton_code(fn, *inps) + if kernel_name not in code: + print(code) + self.assertTrue(kernel_name in code) + + test_funcs(func_and_kernel_aten) + patch.object(config.triton, "descriptive_names", "torch")(test_funcs)( + func_and_kernel_torch + ) + + def test_split_op_with_sym(self): + def fn(x: torch.Tensor) -> torch.Tensor: + # split(tensor, sympy.Integer), split(tensor, sympy.Expr) + return torch.split(x, x.shape[0]), torch.split(x, x.shape[0] // 2) + + for dynamic_shapes in [True, False]: + with 
torch._dynamo.config.patch(dynamic_shapes=dynamic_shapes): + torch._dynamo.reset() + fn_opt = torch._dynamo.optimize("inductor", dynamic=dynamic_shapes)( + fn + ) + inps = torch.randn([5, 5]) + fn_opt(inps) class ExprPrinterTests(TestCase): diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 95df82d322ab..fb7edeb21446 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -1,11 +1,15 @@ # Owner(s): ["module: inductor"] +import contextlib import importlib import os import sys import unittest +from functools import partial +from unittest.mock import patch import torch from torch._dynamo.testing import make_test_cls_with_patches +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import ( IS_CI, IS_WINDOWS, @@ -41,8 +45,6 @@ "test_cudnn_rnn_dynamic_shapes": ("cuda",), "test_gather3_dynamic_shapes": ("cpu", "cuda"), "test_kwargs_dynamic_shapes": ("cpu",), - "test_lowmem_dropout2_dynamic_shapes": ("cpu", "cuda"), - "test_rand_like_deterministic_dynamic_shapes": ("cpu", "cuda"), "test_randn_like_empty_dynamic_shapes": ("cpu", "cuda"), # test_roi_align uses torchvision, which doesn't work with dynamic shapes "test_roi_align_dynamic_shapes": ("cpu", "cuda"), @@ -81,6 +83,59 @@ class DynamicShapesCudaTests(TestCase): copy_tests(DynamicShapesCommonTemplate, DynamicShapesCudaTests, "cuda", test_skips) +class TestInductorDynamic(TestCase): + + compile_fn = partial(torch.compile, dynamic=True) + + def setUp(self): + # HAS_CUDA also checks compute capability to skip tests + # on older devices + if self.device_type == "cuda" and not HAS_CUDA: + self.skipTest("Triton not available") + torch._dynamo.reset() + super(TestCase, self).setUp() + # this should be in setUpClass, but device-generic tests + # don't work with setUpClass well (non-deterministically the wrong setUpClass is resolved), + # so put it in test setUp, it's cheap + self._stack = contextlib.ExitStack() + self._stack.enter_context( + torch._inductor.config.patch( + { + "debug": False, + "cpp.min_chunk_size": 1, + "triton.autotune_pointwise": False, # too slow + "implicit_fallbacks": False, + } + ) + ) + + def tearDown(self): + self._stack.close() + super(TestCase, self).tearDown() + torch._dynamo.reset() + + @patch.object(torch._dynamo.config, "specialize_int", False) + def test_arange_dynamic(self, device): + def fn(a): + batch_size = a.numel() + max_len = a.max() + return ~( + torch.arange(0, max_len, device=a.device) + .type_as(a) + .repeat(batch_size, 1) + .lt(a.unsqueeze(1)) + ) + + a = torch.randint(10, 30, (10,), device=device) + a[0] = 29 # fix max_len + opt = self.compile_fn(fn) + res = opt(a) + ref = fn(a) + self.assertEqual(res, ref) + + +instantiate_device_type_tests(TestInductorDynamic, globals()) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 8d9dff20780b..01495f436a12 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -178,9 +178,7 @@ def process(device_type): "masked.var": {f16}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, - ("max", "reduction_no_dim"): {f16}, ("max", "reduction_with_dim"): {b8}, - ("min", "reduction_no_dim"): {f16}, ("min", "reduction_with_dim"): {b8}, "multinomial": 
{f32, f64}, "nanquantile": {f32, f64}, @@ -224,7 +222,6 @@ def process(device_type): "var": {f16}, "var_mean": {f16}, "view_as_complex": {f16}, - ("norm", "inf"): {f16}, "fft.fft": {b8, f16, f32, f64, i32, i64}, "fft.fft2": {b8, f16, f32, f64, i32, i64}, "fft.fftn": {b8, f16, f32, f64, i32, i64}, @@ -472,6 +469,7 @@ def wrapper_set_seed(op, *args, **kwargs): "mT", "mH", "rsub", + "triu", } diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py index 6937af9f2927..40603a734621 100644 --- a/test/jit/test_cuda.py +++ b/test/jit/test_cuda.py @@ -9,7 +9,7 @@ from typing import NamedTuple from torch.testing import FileCheck from torch.testing._internal.jit_utils import JitTestCase -from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf +from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf, NoTest # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -23,7 +23,7 @@ # If GPU is not available, then do not run the tests if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) - JitTestCase = object # noqa: F811 + JitTestCase = NoTest # noqa: F811 TEST_LARGE_TENSOR = TEST_CUDA diff --git a/test/jit/test_pdt.py b/test/jit/test_pdt.py index dd8c00685114..5fa39b8cac35 100644 --- a/test/jit/test_pdt.py +++ b/test/jit/test_pdt.py @@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED from typing import List, Dict, Tuple, Any, Optional, NamedTuple # noqa: F401 +from torch.testing._internal.common_utils import NoTest # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -13,7 +14,7 @@ if not _IS_MONKEYTYPE_INSTALLED: print("monkeytype is not installed. 
Skipping tests for Profile-Directed Typing", file=sys.stderr) - JitTestCase = object # type: ignore[misc, assignment] # noqa: F811 + JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 if __name__ == "__main__": raise RuntimeError( diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index b35e66182e7c..dd1b10ecafb3 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -4144,6 +4144,21 @@ def forward(self, x): additional_test_inputs=[x2], ) + class Model(torch.nn.Module): + def forward(self, x): + aa = torch.tensor([[0], [1], [2]]) + return aa.expand_as(x) + + x = torch.ones(3, 2) + x2 = torch.randn(3, 5) + self.run_test( + Model(), + (x,), + input_names=["x"], + dynamic_axes={"x": [0, 1]}, + additional_test_inputs=[x2], + ) + def test_multinomial(self): class Multinomial(torch.nn.Module): def forward(self, weight): diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 19f2d12337f3..c48efeedd9a7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -78,6 +78,7 @@ ) from torch.ao.quantization.backend_config import ( + get_fbgemm_backend_config, get_qnnpack_backend_config, BackendConfig, BackendPatternConfig, @@ -8348,6 +8349,84 @@ def forward(self, x): } self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) + def test_pixel_unshuffle(self): + class MyBias(nn.Module): + def __init__(self): + super().__init__() + self.bias = nn.Parameter(torch.randn(64)) + + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(8, 8, 1, bias=False) + self.bias = MyBias() + + def forward(self, x): + x = self.conv(x) + x = nn.functional.pixel_unshuffle(x, 2) + bias = self.bias.bias + return x + bias + + for backend in ["fbgemm", "qnnpack"]: + if backend == "fbgemm": + backend_config = get_fbgemm_backend_config() + else: + backend_config = get_qnnpack_backend_config() + qconfig_mapping = get_default_qconfig_mapping(backend) + model = MyModel() + m = prepare_fx( + model, + qconfig_mapping=qconfig_mapping, + example_inputs=(torch.randn(1, 8, 6, 6),), + backend_config=backend_config + ) + m = convert_fx(m) + expected_occurrence = { + ns.call_function(torch.quantize_per_tensor): 2, + ns.call_method("dequantize"): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) + + + + def test_narrow(self): + class MyBias(nn.Module): + def __init__(self): + super().__init__() + self.bias = nn.Parameter(torch.randn(4)) + + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(8, 8, 1, bias=False) + self.bias = MyBias() + + def forward(self, x): + x = self.conv(x) + x = torch.narrow(x, 1, 0, 4) + bias = self.bias.bias + return x + bias + + for backend in ["fbgemm", "qnnpack"]: + if backend == "fbgemm": + backend_config = get_fbgemm_backend_config() + else: + backend_config = get_qnnpack_backend_config() + qconfig_mapping = get_default_qconfig_mapping(backend) + model = MyModel() + m = prepare_fx( + model, + qconfig_mapping=qconfig_mapping, + example_inputs=(torch.randn(1, 8, 3, 3),), + backend_config=backend_config + ) + m = convert_fx(m) + expected_occurrence = { + ns.call_function(torch.quantize_per_tensor): 2, + ns.call_method("dequantize"): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) + class TestQuantizeFxModels(QuantizationTestCase): 
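The two FX quantization tests added above (test_pixel_unshuffle and test_narrow) assert on expected_node_occurrence, which is essentially a count of quantize/dequantize nodes in the GraphModule produced by convert_fx. A minimal standalone sketch of that count, for illustration only: the helper name count_quant_nodes and the simplified matching below are assumptions, not the checkGraphModuleNodes implementation used by the test harness.

import torch
from torch import fx

def count_quant_nodes(gm: fx.GraphModule) -> dict:
    # Walk the converted graph and tally call_function nodes targeting
    # torch.quantize_per_tensor and call_method nodes named "dequantize".
    counts = {"quantize_per_tensor": 0, "dequantize": 0}
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.quantize_per_tensor:
            counts["quantize_per_tensor"] += 1
        elif node.op == "call_method" and node.target == "dequantize":
            counts["dequantize"] += 1
    return counts

# For the converted module `m` in the tests above, the assertion amounts
# roughly to: count_quant_nodes(m) == {"quantize_per_tensor": 2, "dequantize": 1}
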
@skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDA, "gpu is not available.") diff --git a/test/run_test.py b/test/run_test.py index 708d22c9a37b..53d7245d7208 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -673,6 +673,7 @@ def run_doctests(test_module, test_directory, options): import pathlib pkgpath = pathlib.Path(torch.__file__).parent + exclude_module_list = [] enabled = { # TODO: expose these options to the user # For now disable all feature-conditional tests @@ -687,6 +688,7 @@ def run_doctests(test_module, test_directory, options): 'autograd_profiler': 0, 'cpp_ext': 0, 'monitor': 0, + "onnx": "auto", } # Resolve "auto" based on a test to determine if the feature is available. @@ -710,6 +712,17 @@ def run_doctests(test_module, test_directory, options): else: enabled['qengine'] = True + if enabled["onnx"] == "auto": + try: + import onnx # NOQA + import onnxscript # NOQA + import onnxruntime # NOQA + except ImportError: + exclude_module_list.append("torch.onnx._internal.fx.*") + enabled["onnx"] = False + else: + enabled["onnx"] = True + # Set doctest environment variables if enabled['cuda']: os.environ['TORCH_DOCTEST_CUDA'] = '1' @@ -732,6 +745,9 @@ def run_doctests(test_module, test_directory, options): if enabled['monitor']: os.environ['TORCH_DOCTEST_MONITOR'] = '1' + if enabled["onnx"]: + os.environ['TORCH_DOCTEST_ONNX'] = '1' + if 0: # TODO: could try to enable some of these os.environ['TORCH_DOCTEST_QUANTIZED_DYNAMIC'] = '1' @@ -739,7 +755,6 @@ def run_doctests(test_module, test_directory, options): os.environ['TORCH_DOCTEST_AUTOGRAD'] = '1' os.environ['TORCH_DOCTEST_HUB'] = '1' os.environ['TORCH_DOCTEST_DATALOADER'] = '1' - os.environ['TORCH_DOCTEST_ONNX'] = '1' os.environ['TORCH_DOCTEST_FUTURES'] = '1' pkgpath = os.path.dirname(torch.__file__) @@ -757,7 +772,8 @@ def run_doctests(test_module, test_directory, options): xdoctest_verbose = max(1, options.verbose) run_summary = xdoctest.runner.doctest_module( os.fspath(pkgpath), config=xdoctest_config, verbose=xdoctest_verbose, - command=options.xdoctest_command, argv=[]) + command=options.xdoctest_command, argv=[], + exclude=exclude_module_list) result = 1 if run_summary.get('n_failed', 0) else 0 return result @@ -887,21 +903,10 @@ def run_test_ops(test_module, test_directory, options): PYTEST_BLOCKLIST = [ "test_package", - "test_nccl", "inductor/test_torchinductor", - "test_cuda", "test_quantization", - "test_cuda_nvml_based_avail", - "test_cuda_primary_ctx", - "test_cuda_sanitizer", - "test_cuda_trace", "test_fx", - "test_jiterator", - "test_mps", - "test_cuda_trace", "profiler/test_profiler", - "test_jit", - "test_jit_legacy", "dynamo/test_repros", # skip_if_pytest "dynamo/test_optimizers", # skip_if_pytest "dynamo/test_dynamic_shapes", # needs change to check_if_enable for disabled test issues diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 61dda497cb64..00459eb202a4 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -100,5 +100,20 @@ def test_open_device_registration(self): # None of our CPU operations should call the custom add function. 
self.assertFalse(module.custom_add_called()) + # check generator registered befor use + with self.assertRaisesRegex(RuntimeError, + "Please register a generator to the PrivateUse1 dispatch key"): + gen_ = torch.Generator(device=device) + + module.register_genertor() + + gen = torch.Generator(device=device) + self.assertTrue(gen.device == device) + + # generator can be registered only once + with self.assertRaisesRegex(RuntimeError, + "Only can register a generator to the PrivateUse1 dispatch key once"): + module.register_genertor() + if __name__ == "__main__": common.run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index 72862c118f34..072577583b25 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -24,6 +24,7 @@ import torch.cuda.comm as comm from torch.cuda._memory_viz import profile_plot from torch.cuda._memory_viz import trace_plot +from torch.cuda._memory_viz import segment_plot from torch import inf, nan from torch.nn.parallel import scatter_gather @@ -31,7 +32,7 @@ from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \ NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \ slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \ - get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest, IS_JETSON, gcIfJetson + get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest, IS_JETSON, gcIfJetson, NoTest from torch.testing._internal.autocast_test_lists import AutocastTestLists # load_tests from common_utils is used to automatically filter tests for @@ -47,7 +48,7 @@ if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) - TestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 TEST_CUDAMALLOCASYNC = TEST_CUDA and (torch.cuda.get_allocator_backend() == "cudaMallocAsync") TEST_LARGE_TENSOR = TEST_CUDA @@ -5009,7 +5010,7 @@ def test_memory_profiler_viz(self): self.assertTrue('"elements_category": [' in plot) @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync") - def test_memory_trace_plot(self): + def test_memory_plots(self): for record_context in (True, False): try: torch.cuda.memory.empty_cache() @@ -5025,9 +5026,12 @@ def run(): run() ss = torch.cuda.memory._snapshot() - plot = trace_plot(ss) - self.assertTrue(record_context == ("test_memory_trace_plot" in plot)) - self.assertTrue(str(128 * 128 * 4) in plot) + tplot = trace_plot(ss) + self.assertTrue(record_context == ("test_memory_plots" in tplot)) + self.assertTrue(str(128 * 128 * 4) in tplot) + splot = segment_plot(ss) + self.assertTrue(record_context == ("test_memory_plots" in splot)) + self.assertTrue(str(128 * 128 * 4) in splot) torch.cuda.memory._record_memory_history(False) finally: torch.cuda.memory._record_memory_history(False) diff --git a/test/test_cuda_nvml_based_avail.py b/test/test_cuda_nvml_based_avail.py index 04bad0ff86af..012c047563ec 100644 --- a/test/test_cuda_nvml_based_avail.py +++ b/test/test_cuda_nvml_based_avail.py @@ -13,7 +13,7 @@ # Before executing the desired tests, we need to disable CUDA initialization and fork_handler additions that would # otherwise be triggered by the `torch.testing._internal.common_utils` module import from torch.testing._internal.common_utils import (parametrize, instantiate_parametrized_tests, run_tests, TestCase, - IS_WINDOWS, IS_JETSON) + IS_WINDOWS, IS_JETSON, NoTest) # NOTE: Because `remove_device_and_dtype_suffixes` initializes CUDA context 
(triggered via the import of # `torch.testing._internal.common_device_type` which imports `torch.testing._internal.common_cuda`) we need # to bypass that method here which should be irrelevant to the parameterized tests in this module. @@ -22,7 +22,7 @@ TEST_CUDA = torch.cuda.is_available() if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) - TestCase = object # type: ignore[misc, assignment] # noqa: F811 + TestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 class TestExtendedCUDAIsAvail(TestCase): diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 74005515a495..b0fa0e14c792 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -1,7 +1,7 @@ # Owner(s): ["module: cuda"] import torch -from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocmVersionLessThan +from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocmVersionLessThan, NoTest import sys import unittest @@ -16,7 +16,7 @@ if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) - TestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 class TestCudaPrimaryCtx(TestCase): diff --git a/test/test_cuda_sanitizer.py b/test/test_cuda_sanitizer.py index f8733ba43a42..b52fafb94aed 100644 --- a/test/test_cuda_sanitizer.py +++ b/test/test_cuda_sanitizer.py @@ -8,7 +8,7 @@ import torch import torch.cuda._sanitizer as csan from torch.cuda._sanitizer import StreamId, DataPtr, EventId -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, NoTest # We cannot import TEST_CUDA from torch.testing._internal.common_cuda here, @@ -19,7 +19,7 @@ if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 class TestArgumentHandler(TestCase): diff --git a/test/test_cuda_trace.py b/test/test_cuda_trace.py index 07ba30d27f41..6581ffffec06 100644 --- a/test/test_cuda_trace.py +++ b/test/test_cuda_trace.py @@ -6,7 +6,7 @@ import torch import torch.utils._cuda_trace as cuda_trace -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, NoTest # NOTE: Each test needs to be run in a brand new process, to reset the registered hooks # and make sure the CUDA streams are initialized for each test that uses them. 
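Several CUDA-only test files in this patch (test_cuda.py, test_cuda_nvml_based_avail.py, test_cuda_primary_ctx.py, test_cuda_sanitizer.py, test_cuda_trace.py, test_jiterator.py) replace the old `TestCase = object` fallback with `TestCase = NoTest`, and the same files are dropped from PYTEST_BLOCKLIST in run_test.py, presumably so pytest can collect them safely when no GPU is present. A hypothetical sketch of the idea behind such a placeholder; the real NoTest lives in torch.testing._internal.common_utils and may be defined differently.

class NoTestPlaceholder:
    # pytest skips collecting any class whose __test__ attribute is falsy.
    # Binding TestCase to a class like this (rather than to plain `object`)
    # means Test*-named subclasses inherit __test__ = False and are not
    # picked up by pytest when the accelerator is unavailable.
    __test__ = False
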
@@ -19,7 +19,7 @@ if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 class TestCudaTrace(TestCase): diff --git a/test/test_datapipe.py b/test/test_datapipe.py index bc03000542d2..a77d6adba6b0 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -30,6 +30,7 @@ import numpy as np import torch +import torch.nn as nn import torch.utils.data.datapipes as dp import torch.utils.data.graph import torch.utils.data.graph_settings @@ -663,6 +664,16 @@ def _mod_3_test(x): lambda_fn3 = lambda x: x >= 5 # noqa: E731 +class Add1Module(nn.Module): + def forward(self, x): + return x + 1 + + +class Add1Callable: + def __call__(self, x): + return x + 1 + + class TestFunctionalIterDataPipe(TestCase): def _serialization_test_helper(self, datapipe, use_dill): @@ -1326,6 +1337,10 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0) _helper(lambda data: (data[0], data[1], int(data[2])), int, 2) + # Handle nn.Module and Callable (without __name__ implemented) + _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0) + _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0) + @suppress_warnings # Suppress warning for lambda fn def test_map_dict_with_col_iterdatapipe(self): def fn_11(d): diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index fce82fc3d9cc..28d5e7406243 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -422,6 +422,13 @@ def test_duck_shape(self): assert a0 == a1 self.assertEqual(len(shape_env.guards), 1) + def test_int_bool(self): + # See https://github.com/pytorch/pytorch/issues/95981 + shape_env = ShapeEnv(duck_shape=True) + a0 = create_symint(shape_env, 5) + assert a0 + self.assertEqual(len(shape_env.guards), 0) + def test_symint_as_scalar(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) diff --git a/test/test_fx.py b/test/test_fx.py index 49ea19a88a12..1ac205d5e53d 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -46,6 +46,7 @@ from fx.test_common_passes import TestCommonPass # noqa: F401 from fx.test_cse_pass import TestCSEPass # noqa: F401 from fx.test_matcher_utils import TestMatcher # noqa: F401 +from fx.test_verifier import VerifierTest # noqa: F401 from fx.test_gradual_type import AnnotationsTest # noqa: F401 from fx.test_gradual_type import TypeCheckerTest # noqa: F401 diff --git a/test/test_jit.py b/test/test_jit.py index 0fbcf5b20d78..2d1161d74669 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -343,6 +343,72 @@ def __init__(self): self.bar = torch.jit.ScriptModule() +@skipIfTorchDynamo() +class TestJitProfiler(JitTestCase): + """ + This runs tests that requires setting some global states like torch._C._set_graph_executor_optimize + and restore the values afterward, i.e. test_profiler. This is to address the flaky issue in + https://github.com/pytorch/pytorch/issues/91483 in which test_profiler was flaky and failed in the + middle without the chance to restore torch._C._set_graph_executor_optimize to its original value. + This causes issues for all future tests running after. + + Using a separate test class here, so that there is no need to run setup and teardown for all tests + in TestJit. 
+ """ + + def setUp(self): + super().setUp() + self.graph_executor_optimize_opt = torch._C._get_graph_executor_optimize() + + def tearDown(self): + super().tearDown() + # Resetting + torch._C._set_graph_executor_optimize( + self.graph_executor_optimize_opt + ) + + @unittest.skipIf(IS_WINDOWS, 'TODO: fix occasional windows failure') + def test_profiler(self): + torch._C._set_graph_executor_optimize(False) + + def other_fn(x): + return x * 2 + + x = torch.rand(3, 4) + traced_other_fn = torch.jit.trace(other_fn, x) + + def fn(x): + y = traced_other_fn(x) + fut = torch.jit._fork(traced_other_fn, x) + y = torch.jit._wait(fut) + return y + + traced_fn = torch.jit.trace(fn, x) + with torch.autograd.profiler.profile() as prof: + traced_fn(x) + + # expecting to see other_fn TS function call + # with cpu time >= mul cpu time and + # a forked other_fn + + mul_events = defaultdict(int) + other_fn_events = defaultdict(int) + for e in prof.function_events: + if e.name == "aten::mul": + self.assertTrue(e.thread not in mul_events) + mul_events[e.thread] = e.time_range.elapsed_us() + elif e.name == "other_fn": + self.assertTrue(e.thread not in other_fn_events) + other_fn_events[e.thread] = e.time_range.elapsed_us() + + self.assertTrue(len(mul_events) == 2) + self.assertTrue(len(other_fn_events) == 2) + + for thread, mul_time in mul_events.items(): + self.assertTrue(thread in other_fn_events) + self.assertTrue(other_fn_events[thread] >= mul_time) + + @skipIfTorchDynamo() class TestJit(JitTestCase): @unittest.skip("Requires a lot of RAM") @@ -2926,50 +2992,6 @@ def test_print_torch_ops_modules(self): s = str(torch._ops.ops.atan) self.assertRegex(s, r'torch.ops') - @unittest.skipIf(IS_WINDOWS, 'TODO: fix occasional windows failure') - def test_profiler(self): - prev_opt = torch._C._get_graph_executor_optimize() - torch._C._set_graph_executor_optimize(False) - - def other_fn(x): - return x * 2 - - x = torch.rand(3, 4) - traced_other_fn = torch.jit.trace(other_fn, x) - - def fn(x): - y = traced_other_fn(x) - fut = torch.jit._fork(traced_other_fn, x) - y = torch.jit._wait(fut) - return y - - traced_fn = torch.jit.trace(fn, x) - with torch.autograd.profiler.profile() as prof: - traced_fn(x) - - # expecting to see other_fn TS function call - # with cpu time >= mul cpu time and - # a forked other_fn - - mul_events = defaultdict(int) - other_fn_events = defaultdict(int) - for e in prof.function_events: - if e.name == "aten::mul": - self.assertTrue(e.thread not in mul_events) - mul_events[e.thread] = e.time_range.elapsed_us() - elif e.name == "other_fn": - self.assertTrue(e.thread not in other_fn_events) - other_fn_events[e.thread] = e.time_range.elapsed_us() - - self.assertTrue(len(mul_events) == 2) - self.assertTrue(len(other_fn_events) == 2) - - for thread, mul_time in mul_events.items(): - self.assertTrue(thread in other_fn_events) - self.assertTrue(other_fn_events[thread] >= mul_time) - - torch._C._set_graph_executor_optimize(prev_opt) - def test_hide_source_ranges_context_manager(self): @torch.jit.script def foo(x): diff --git a/test/test_jiterator.py b/test/test_jiterator.py index f995e5408873..1acba982f3af 100644 --- a/test/test_jiterator.py +++ b/test/test_jiterator.py @@ -5,7 +5,7 @@ from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn import sys from itertools import product -from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA +from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest 
from torch.testing._internal.common_dtype import all_types_and_complex_and from torch.testing._internal.common_device_type import ( skipCUDAIfRocm, skipCUDAIf, instantiate_device_type_tests, dtypes, toleranceOverride, tol) @@ -13,7 +13,7 @@ if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) - TestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 code_string = "template T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }" diff --git a/test/test_linalg.py b/test/test_linalg.py index d1e1e76762d3..4e66073dd050 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -5565,6 +5565,21 @@ def test_addmm_baddbmm_overflow(self, device, dtype): self.assertTrue((out == 10000.).all()) torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig + @dtypes(torch.float) + def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): + for shape in [[3, 2, 2], [2, 20, 20]]: + mat1, mat2 = [torch.randn(shape, dtype=dtype, device=device) for _ in range(2)] + inputs = [torch.randn(shape, dtype=dtype, device=device), + torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)] + outs = [None, torch.randn(shape, dtype=dtype, device=device), + torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)] + options = itertools.product(inputs, outs) + for input, out in options: + y_ref = torch.bmm(mat1, mat2) + y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) + self.assertEqual(y_ref, y) + + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyCUDA def test_matmul_45724(self, device): diff --git a/test/test_mps.py b/test/test_mps.py index 062a51329efb..b1138dcf4c8d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -10,7 +10,6 @@ import subprocess import tempfile import os -import pprint import copy import gc import torch @@ -22,8 +21,8 @@ from torch.nn import Parameter from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \ - (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, - TEST_WITH_UBSAN, dtype_abbrs, skipIfSlowGradcheckEnv, TEST_WITH_ASAN, suppress_warnings) + (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest, + TEST_WITH_UBSAN, skipIfSlowGradcheckEnv, TEST_WITH_ASAN, suppress_warnings) from torch.testing import make_tensor from torch.testing._comparison import TensorLikePair from torch.testing._internal.common_dtype import get_all_dtypes, integral_types @@ -58,18 +57,174 @@ ) ) -def mps_ops_modifier(ops): - # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 - MACOS_13_X_XFAILLIST = { - 'masked.softmax': [torch.float32], +def mps_ops_grad_modifier(ops): + XFAILLIST_GRAD = { + # Top 60 + # CPU: empty is returning all 0's and there is a mismatch with MPS + # allocation (MacOS 13). According to + # https://pytorch.org/docs/2.0/generated/torch.empty.html + # PyTorch `empty`, Returns a tensor filled with uninitialized data. 
+ 'empty': [torch.float16, torch.float32], + + # CPU Error: RuntimeError: "addmv_impl_cpu" not implemented for 'Half' + 'addr': [torch.float16], + + # Unimplemented ops + '__getitem__': [torch.float16], + 'prod': [torch.float32], # The operator 'aten::cumprod.out' + 'sgn': [torch.float16, torch.float32], + '_segment_reduce': [torch.float16, torch.float32], + 'unfold_copy': [torch.float16, torch.float32], # unfold_backward is not implemented + 'unfold': [torch.float16, torch.float32], + 'trace': [torch.float32], # missing in place aten::index_fill_.int_Tensor + 'sparse.mmreduce': [torch.float32], # csr not supported + 'unique_consecutive': [torch.float16, torch.float32], + 'special_modified_bessel_i0': [torch.float16, torch.float32], + 'scalar_tensor': [torch.float16, torch.float32], + 'cdist': [torch.float32], + 'masked.scatter': [torch.float16, torch.float32], + + # Correctness issues + 'atanh': [torch.float32], + + # Random output + 'exponential': [torch.float16, torch.float32], + + # CPU errors + # derivative for aten::floor_divide is not implemented on CPU + 'floor_divide': [torch.float16, torch.float32], + # derivative for aten::narrow_copy is not implemented on CPU + 'narrow_copy': [torch.float16, torch.float32], + # RuntimeError: "log_vml_cpu" not implemented for 'Half' + '__rpow__': [torch.float16], + 'pow': [torch.float16], + # 'bool' object is not iterable + 'allclose': [torch.float16, torch.float32], + 'equal': [torch.float16, torch.float32], + # "mse_backward_cpu_out" not implemented for 'Half' + 'nn.functional.mse_loss': [torch.float16], + # "smooth_l1_backward_cpu_out" not implemented for 'Half' + 'nn.functional.smooth_l1_loss': [torch.float16], + # cpu error: grad requires non-empty inputs + 'randn': [torch.float16, torch.float32], + 'signal.windows.bartlett': [torch.float32], + 'signal.windows.blackman': [torch.float32], + 'signal.windows.cosine': [torch.float32], + 'signal.windows.exponential': [torch.float32], + 'signal.windows.gaussian': [torch.float32], + 'signal.windows.general_cosine': [torch.float32], + 'signal.windows.general_hamming': [torch.float32], + 'signal.windows.hamming': [torch.float32], + 'signal.windows.hann': [torch.float32], + 'signal.windows.kaiser': [torch.float32], + 'signal.windows.nuttall': [torch.float32], + 'empty_permuted': [torch.float16, torch.float32], + 'eye': [torch.float16, torch.float32], + + # trunc_tensor not working properly for float16 + 'divtrunc_rounding': [torch.float16], + 'fmod': [torch.float16], + } + + MACOS_12_3_XFAILLIST_GRAD = { + # Unsupported Border padding mode, forward pass success as fallback to cpu + 'grid_sampler_2d': [torch.float32], + # Unimplemented + 'logaddexp2': [torch.float32], + + # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721. + # fixed in macOS 13. We are not raising error. + '__rpow__': [torch.float32], + 'pow': [torch.float32], + } + + MACOS_BEFORE_13_3_XFAILLIST_GRAD = { + # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ 'masked.softmin': [torch.float32], + 'masked.softmax': [torch.float32], 'masked.log_softmax': [torch.float32], + + # Unsupported Border padding mode, forward pass success as fallback to cpu + 'grid_sampler_2d': [torch.float32], + + # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). + # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. 
+ # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. + # Running `msort` with stable `sort` passes. + 'msort': [torch.float16], + + # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721. + # fixed in macOS 13. We are not raising error. + 'pow': [torch.float32], + '__rpow__': [torch.float32], + } + + XPASSLIST_GRAD = { + 'nn.functional.pairwise_distance': [torch.float16], + } + + MACOS_13_3_XFAILLIST_GRAD = { + # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). + # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. + # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. + # Running `msort` with stable `sort` passes. + 'msort': [torch.float16], } - MACOS_12_X_XFAILLIST = { + + def addDecorator(op, d) -> None: + op.decorators = list(op.decorators) if op.decorators is not None else [] + op.decorators.append(d) + + for op in ops: + key = op.name + op.variant_test_name + if key in XFAILLIST_GRAD: + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=XFAILLIST_GRAD[key])) + + if key in XPASSLIST_GRAD: + addDecorator(op, DecorateInfo( + unittest.skip, + dtypes=XPASSLIST_GRAD[key])) + + if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()): + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_12_3_XFAILLIST_GRAD[key])) + + if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3): + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key])) + + if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3): + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_13_3_XFAILLIST_GRAD[key])) + yield op + +def mps_ops_modifier(ops): + # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 + MACOS_12_3_XFAILLIST = { + # Top 60 + # expected failures + # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721. + # fixed in macOS 13.3. Currently error is not raised. + 'pow': [torch.int16, torch.int64, torch.uint8, torch.int8], + # expected failures + '__rpow__': [torch.uint8, torch.int8], + + # Failures due to precision issues (due to fast-math). 
These has been fixed in MacOS 13.3+ + 'cdist': [torch.float32], + 'tan': [torch.uint8, torch.float32], + + # Data type support starts from macOS 13 + 'nn.functional.avg_pool1d': [torch.int64], + 'nn.functional.avg_pool2d': [torch.int64], + 'nn.functional.local_response_norm': [torch.int64], '__radd__': [torch.uint8], '__rdiv__': [torch.uint8], '__rmul__': [torch.uint8], - '__rpow__': [torch.uint8], 'abs': [torch.uint8], 'acos': [torch.uint8], 'acosh': [torch.uint8], @@ -78,20 +233,30 @@ def mps_ops_modifier(ops): 'asinh': [torch.uint8], 'atan': [torch.uint8], 'atanh': [torch.uint8], + 'ceil': [torch.uint8], + 'corrcoef': [torch.uint8], 'cos': [torch.uint8], 'cosh': [torch.uint8], + 'cov': [torch.uint8], + 'cumulative_trapezoid': [torch.uint8], 'deg2rad': [torch.uint8], 'diff': [torch.uint8], + 'eq': [torch.uint8], 'equal': [torch.uint8], 'erf': [torch.uint8], 'exp2': [torch.uint8], 'exp': [torch.uint8], + 'expm1': [torch.uint8], + 'floor': [torch.uint8], 'fmax': [torch.uint8], 'fmin': [torch.uint8], 'fmod': [torch.uint8], + 'ge': [torch.uint8], + 'gt': [torch.uint8], 'isclose': [torch.uint8], 'isnan': [torch.uint8], 'kron': [torch.uint8], + 'le': [torch.uint8], 'log10': [torch.uint8], 'log1p': [torch.uint8], 'log2': [torch.uint8], @@ -100,55 +265,455 @@ def mps_ops_modifier(ops): 'logical_or': [torch.uint8], 'logical_xor': [torch.uint8], 'logit': [torch.uint8], + 'lt': [torch.uint8], 'masked.mean': [torch.uint8], 'masked.std': [torch.uint8], 'masked.var': [torch.uint8], - 'nn.functional.avg_pool1d': [torch.int64], - 'nn.functional.avg_pool2d': [torch.int64], + 'maximum': [torch.uint8], + 'minimum': [torch.uint8], + 'mul': [torch.uint8], + 'ne': [torch.uint8], + 'neg': [torch.uint8], 'nn.functional.cosine_embedding_loss': [torch.uint8], + 'nn.functional.margin_ranking_loss': [torch.uint8], 'nn.functional.poisson_nll_loss': [torch.uint8], 'nn.functional.softsign': [torch.uint8], 'nn.functional.tanhshrink': [torch.uint8], - 'pow': [torch.int16, torch.int64, torch.uint8], + 'nn.functional.triplet_margin_loss': [torch.uint8], + 'nn.functional.triplet_margin_with_distance_loss': [torch.uint8], + 'nn.functional.pairwise_distance': [torch.uint8, torch.float16], + 'outer': [torch.uint8], 'rad2deg': [torch.uint8], 'reciprocal': [torch.uint8], 'remainder': [torch.uint8], + 'round': [torch.uint8], 'rsqrt': [torch.uint8], 'sigmoid': [torch.uint8], 'sign': [torch.uint8], + 'signbit': [torch.uint8], 'sin': [torch.uint8], 'sinh': [torch.uint8], 'special.ndtr': [torch.uint8], 'sqrt': [torch.uint8], 'sub': [torch.uint8], - 'tan': [torch.uint8], 'tanh': [torch.uint8], + 'trapezoid': [torch.uint8], + 'trapz': [torch.uint8], 'true_divide': [torch.uint8], + 'trunc': [torch.uint8], 'xlogy': [torch.uint8], - # Weird - 'square': [torch.uint8, torch.bool, torch.int16, torch.int32, torch.int64], + 'minbinary': [torch.uint8], + 'maxbinary': [torch.uint8], + 'divtrunc_rounding': [torch.uint8], + 'divfloor_rounding': [torch.uint8], + 'divno_rounding_mode': [torch.uint8], + 'floor_divide': [torch.uint8], + 'ldexp': [torch.uint8], + # square internally calls into power, and will type cast to int64, which supports starting from macOS 13 + 'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + + # cpu not giving nan for x/0.0 + 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + # fill tensors with uninitialized data, causing mismatch with CPU + 'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16, + torch.int32, 
torch.int64, torch.uint8, torch.int8], + 'empty': [torch.bool, torch.float16, torch.float32, torch.int16, + torch.int32, torch.int64, torch.uint8, torch.int8], + 'dist': [torch.float16], # cpu result off, showing inf values } + MACOS_BEFORE_13_3_XFAILLIST = { + # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ + 'tan': [torch.float32], + 'cdist': [torch.float32], + + # CPU Error: cpu not giving nan for x/0.0 + 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], devuce='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + 'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour. + 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16], + # Unsupported dtypes + 'cumsum': [torch.int64], + 'cumulative_trapezoid': [torch.int64], + 'masked.cumsum': [torch.int64], + } + + MACOS_13_3_XFAILLIST = { + # before macOS 13.3 it falls back to cpu and pass the forward pass + 'grid_sampler_2d': [torch.float32], # Unsupported Border padding mode + + # Failure due to precision issue for fp16 + # on both cpu and mps there are test cases that might produce inf result + # 'nn.functional.pairwise_distance': [torch.float16], + + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], devuce='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + 'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour. 
+ 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16], + } # Those ops are not expected to work - XFAILLIST = { - '__rpow__': [torch.int16, torch.int32, torch.int64], + UNIMPLEMENTED_XFAILLIST = { + # Failures due to lack of op implementation on MPS backend + 'login': None, + 'log_sigmoid': None, + 'log_sigmoid_forward': None, + 'linalg.eig': None, + 'linalg.eigvals': None, + 'fft.fft': None, + 'fft.fft2': None, + 'fft.fftn': None, + 'fft.hfft': None, + 'fft.hfft2': None, + 'fft.hfftn': None, + 'fft.ifft': None, + 'fft.ifft2': None, + 'fft.ifftn': None, + 'fft.ihfft': None, + 'fft.ihfft2': None, + 'fft.ihfftn': None, + 'fft.irfft': None, + 'fft.irfft2': None, + 'fft.irfftn': None, + 'fft.rfft': None, + 'fft.rfft2': None, + 'fft.rfftn': None, + 'put': None, + 'stft': None, + 'nn.functional.conv_transpose3d': None, + 'rounddecimals_neg_3': None, + 'rounddecimals_3': None, + 'rounddecimals_0': None, + '__rsub__': None, + 'aminmax': None, + 'angle': None, + 'bucketize': None, + 'cauchy_': None, + 'cauchy': None, + 'cholesky': None, + 'cholesky_inverse': None, + 'cholesky_solve': None, + 'cummax': None, + 'cummin': None, + 'cumprod': None, + 'digamma': None, + 'erfc': None, + 'erfinv': None, + 'frexp': None, + 'gcd': None, + 'geqrf': None, + 'nn.functional.grid_sample': None, # Unsupported Border padding mode + 'heaviside': None, + 'histc': None, + 'histogram': None, + 'histogramdd': None, + 'i0': None, + 'igamma': None, + 'igammac': None, + 'index_copy': None, + 'index_fill': None, + 'index_reduce': None, + 'isin': None, + 'isneginf': None, + 'isposinf': None, + 'kthvalue': None, + 'lcm': None, + 'lerp': None, + 'lgamma': None, + 'linalg.cholesky': None, + 'linalg.cholesky_ex': None, + 'linalg.cond': None, + 'linalg.detsingular': None, + 'linalg.det': None, + 'linalg.eigh': None, + 'linalg.eigvalsh': None, + 'linalg.householder_product': None, + 'linalg.ldl_factor': None, + 'linalg.ldl_factor_ex': None, + 'linalg.ldl_solve': None, + 'linalg.lstsq': None, + 'linalg.lstsqgrad_oriented': None, + 'linalg.lu': None, + 'linalg.lu_factor': None, + 'linalg.lu_factor_ex': None, + 'linalg.lu_solve': None, + 'linalg.matrix_norm': [torch.float32], + 'linalg.norm': [torch.float32], + 'linalg.normsubgradients_at_zero': [torch.float32], + 'linalg.qr': None, + 'linalg.slogdet': None, + 'linalg.solve': None, + 'linalg.solve_ex': None, + 'linalg.svdvals': None, + 'linalg.tensorsolve': None, + 'linalg.vander': None, + 'linalg.vecdot': None, + 'logcumsumexp': None, + 'logdet': None, + 'lu': None, + 'lu_solve': None, + 'lu_unpack': None, + 'masked.cumprod': None, + 'masked.median': None, + 'matrix_exp': None, + 'mode': None, + 'mvlgamma': None, + 'mvlgammamvlgamma_p_1': None, + 'mvlgammamvlgamma_p_3': None, + 'mvlgammamvlgamma_p_5': None, + 'nanquantile': None, + 'nanmedian': None, + 'native_dropout_backward': None, + 'nextafter': None, + 'normnuc': None, + 'nn.functional.fractional_max_pool2d': None, + 'nn.functional.fractional_max_pool3d': None, + 'nn.functional.adaptive_avg_pool3d': None, + 'nn.functional.adaptive_max_pool3d': None, + 'nn.functional.interpolatearea': None, + 'nn.functional.interpolatebicubic': None, + 'nn.functional.interpolatelinear': None, + 'nn.functional.interpolatetrilinear': None, + 'nn.functional.max_unpool1dgrad': None, + 'nn.functional.max_unpool2dgrad': None, + 'nn.functional.max_unpool3dgrad': None, + 'nn.functional.avg_pool3d': None, + 'nn.functional.ctc_loss': None, + 'nn.functional.embedding_bag': None, + 'nn.functional.hardshrink': None, + 'nn.functional.max_pool3d': 
None, + 'nn.functional.max_unpool1d': None, + 'nn.functional.max_unpool2d': None, + 'nn.functional.max_unpool3d': None, + 'nn.functional.mish': None, + 'nn.functional.multi_margin_loss': None, + 'nn.functional.multilabel_margin_loss': None, + 'nn.functional.pdist': None, + 'nn.functional.rrelu': None, + 'nn.functional.softshrink': None, + 'nn.functional.norm': None, + 'ormqr': None, + 'pca_lowrank': None, + 'pinverse': None, + 'polar': None, + 'polygamma': None, + 'polygammapolygamma_n_0': None, + 'polygammapolygamma_n_1': None, + 'polygammapolygamma_n_2': None, + 'polygammapolygamma_n_3': None, + 'polygammapolygamma_n_4': None, + 'qr': None, + 'quantile': None, + 'renorm': None, + 'rsub': None, + 'scatter_reduceamax': None, + 'scatter_reduceamin': None, + 'scatter_reducemin': None, + 'scatter_reducemean': None, + 'scatter_reduceprod': None, + 'scatter_reducesum': None, + 'searchsorted': None, + 'segment_reduce': None, + '_segment.reduce': None, + 'segment.reduce': None, + 'segment_reduce_offsets': None, + '_segment_reduce_offsets': None, + '_segment_reduce_lengths': None, + '_segment_reducelengths': None, + '_segment_reduceoffsets': None, + 'sinc': None, + 'sparse.mm': None, + 'sparse.mmreduce': None, + 'special.airy_ai': None, + 'special.bessel_j0': None, + 'special.bessel_j1': None, + 'special.bessel_y0': None, + 'special.bessel_y1': None, + 'special.chebyshev_polynomial_t': None, + 'special.chebyshev_polynomial_u': None, + 'special.entr': None, + 'special.erfcx': None, + 'special.hermite_polynomial_h': None, + 'special.hermite_polynomial_he': None, + 'special.i0e': None, + 'special.i1': None, + 'special.i1e': None, + 'special.laguerre_polynomial_l': None, + 'special.log_ndtr': None, + 'special.modified_bessel_i0': None, + 'special.modified_bessel_i1': None, + 'special.modified_bessel_k0': None, + 'special.modified_bessel_k1': None, + 'special.ndtri': None, + 'special.polygamma': None, + 'special.polygammaspecial_polygamma_n_0': None, + 'special.scaled_modified_bessel_k0': None, + 'special.scaled_modified_bessel_k1': None, + 'special.spherical_bessel_j0': None, + 'special.xlog1py': None, + 'special.zeta': None, + 'std_mean': None, + 'std_meanunbiased': None, + 'svd_lowrank': None, + 'symeig': None, + 'take': None, + 'to': None, + 'to_sparse': None, + 'unique': None, + 'vdot': None, + 'view_as_complex': None, + 'segment_reduce': None, + 'segment_reduce_': None, + '_segment_reduce_lengths': None, + '_upsample_bilinear2d_aa': None, + 'geometric' : None, + 'geometric_': None, + 'log_normal_': None, + 'log_normal': None, + 'bfloat16': None, + 'cdouble': None, + 'cfloat': None, + 'complex': None, + 'double': None, 'chalf': None, + 'nn.functional.softminwith_dtype': None, + 'log_softmaxwith_dtype': None, + 'softmaxwith_dtype': None, + 'float_power': None, + 'full_like': None, + 'linalg.matrix_rank': None, + 'linalg.matrix_rankhermitian': None, + 'linalg.pinv': None, + 'linalg.pinvhermitian': None, + + # MPS: input sizes must be divisible by output sizes + 'nn.functional.adaptive_avg_pool1d': None, + 'nn.functional.adaptive_avg_pool2d': None, + # Unsupported dtypes - 'dot': [torch.int64], - 'index_add': [torch.int64], + # bmm is not supported for integral types + 'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + # Cannot convert a MPS Tensor to float64 dtype. 
The tensors + # input data is created with double in common_methods_invocations.py + 'nn.functional.batch_norm': [torch.float32], + 'ones_like': None, + 'zeros_like': None, + + # Convolution for integral types is not supported on MPS 'nn.functional.conv1d': [torch.int64], 'nn.functional.conv2d': [torch.int64], 'nn.functional.conv_transpose1d': [torch.int64], 'nn.functional.conv_transpose2d': [torch.int64], - # 'remainder': [torch.int64], + + # Unsupported dtypes + 'dot': [torch.int64], + 'index_add': [torch.int64], + 'log1p': [torch.int64], 'sigmoid': [torch.int64], - # failures due to lack of op implementation on MPS backend - 'put': None, - # Weird + 'atan2': [torch.int64], + + # GEMM on MPS is not supported for integral types + 'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + + # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as + # the MPS framework doesn't support float64 + 'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + # returned output on CPU is float64 + 'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + + # trunc_tensor not working properly for float16 + 'divtrunc_rounding': [torch.float16], + 'fmod': [torch.float16], + } + + UNDEFINED_XFAILLIST = { + # Top 60 operators + # topk fails with duplicate indices + 'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + + # Failures due to random output that they generate using + # Philox engine causing mismatch with CPU results + 'multinomial': [torch.float32], # random results + 'uniform': [torch.float16, torch.float32], + 'rand_like': [torch.float16, torch.float32], + 'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'randn_like': [torch.float16, torch.float32], + 'bernoulli': [torch.float32], + 'exponential': [torch.float16, torch.float32], + 'nn.functional.feature_alpha_dropoutwith_train': [torch.float32], + 'normal': [torch.float16, torch.float32, torch.float16, torch.float32], + 
'normalin_place': [torch.float16, torch.float32], + 'normalnumber_mean': [torch.float16, torch.float32], + 'nn.functional.alpha_dropout': [torch.float32], + 'nn.functional.dropout': [torch.float32], + 'nn.functional.dropout2d': [torch.float32], + 'nn.functional.dropout3d': [torch.float32], + + # these fill tensors with uninitialized data, causing mismatch with CPU + 'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + # 'empty': [torch.int8], + 'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, + torch.int32, torch.int64, torch.uint8, torch.int8], + # duplicate indices are used in the testcase - undefined behaviour + 'index_put': None, + # zero to negative integer powers are undefined + '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64], + 'resize_': [torch.float16, torch.float32], + 'resize_as_': [torch.float16, torch.float32], + + # CPU Errors: + 'addr': [torch.bool, torch.int16, torch.int32, + torch.int64, torch.uint8, torch.int8], # "addmv_impl_cpu" not implemented for 'Half' + 'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16, + torch.int32, torch.int64, torch.uint8, torch.int8], # cpu result off, showing random values + 'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16, + torch.int32, torch.int64, torch.uint8, torch.int8], # cpu result off, showing random values + + # random results + # mps vs cpu: + # Mismatched elements: 40 / 96 (41.7%) + # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed) + # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed) + # cuda(2.0.0.dev20230301+cu117) vs cpu: + # Mismatched elements: 56 / 96 (58.3%) + # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed) + # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed) + 'nn.functional.scaled_dot_product_attention': [torch.float32], + + # Failures due to casting negative float to uint8 is undefined 'byte': [torch.float16, torch.float32], - 'nn.functional.adaptive_avg_pool1d': [torch.float32], - 'nn.functional.adaptive_avg_pool2d': [torch.float32], } def addDecorator(op, d) -> None: @@ -157,26 +722,33 @@ def addDecorator(op, d) -> None: for op in ops: key = op.name + op.variant_test_name - if key in XFAILLIST: + for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST]: + if key in xfaillist: + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=xfaillist[key])) + + if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3): addDecorator(op, DecorateInfo( unittest.expectedFailure, - dtypes=XFAILLIST[key])) + dtypes=MACOS_BEFORE_13_3_XFAILLIST[key])) - if key in MACOS_13_X_XFAILLIST and torch.backends.mps.is_macos13_or_newer(): + if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3): addDecorator(op, DecorateInfo( unittest.expectedFailure, - dtypes=MACOS_13_X_XFAILLIST[key])) - if key in MACOS_12_X_XFAILLIST and not torch.backends.mps.is_macos13_or_newer(): + dtypes=MACOS_13_3_XFAILLIST[key])) + + if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()): addDecorator(op, DecorateInfo( unittest.expectedFailure, - dtypes=MACOS_12_X_XFAILLIST[key])) + dtypes=MACOS_12_3_XFAILLIST[key])) yield op # 
Same logic as test_cuda.py if not torch.backends.mps.is_available(): print('MPS not available, skipping tests', file=sys.stderr) - TestCase = object # noqa: F811 - NNTestCase = object # noqa: F811 + TestCase = NoTest # noqa: F811 + NNTestCase = NoTest # noqa: F811 product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2])) @@ -1846,15 +2418,19 @@ def helper(shape): # Test addcmul def test_addcmul(self): - def helper(shape, value): + def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None): + def rand_helper(dtype): + if dtype.is_floating_point: + return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False) - cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + cpu_x = rand_helper(xtype) x = cpu_x.detach().clone().to('mps') - cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + cpu_y = rand_helper(ytype if ytype is not None else xtype) y = cpu_y.detach().clone().to('mps') - cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + cpu_z = rand_helper(ztype if ztype is not None else xtype) z = cpu_z.detach().clone().to('mps') y = torch.addcmul(x, y, z, value=value) @@ -1866,6 +2442,16 @@ def helper(shape, value): helper((2, 8, 4, 5), 0.1) helper((2, 3, 4, 5), 0.2) helper((2, 8, 4, 5), 0.2) + # Integral types + helper((2, 2), 1.0, xtype=torch.int32) + helper((2, 2), 2.0, xtype=torch.int16) + + # Mixed types + helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32) + helper((3, 2), 1.0, ytype=torch.float16) + helper((2, 3), 1.0, ztype=torch.float16) + helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8) + helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8) # Test addcdiv def test_addcdiv(self): @@ -6361,24 +6947,6 @@ def helper2(dim): for dim in [0, 1, 2, 3, -1, -2, -3]: helper(shape, dim, channels_last) - # Test sub - def test_sub(self): - def helper(shape, alpha): - cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - x = cpu_x.detach().clone().to('mps') - - cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - y = cpu_y.detach().clone().to('mps') - - cpu_out = torch.sub(cpu_x, cpu_y, alpha=alpha) - out = torch.sub(x, y, alpha=alpha) - - self.assertEqual(out, cpu_out) - - helper((2, 8, 4, 5), 0.1) - helper((2, 8, 3, 5), 0.1) - helper((2, 8, 3, 5), 0.2) - def test_nan_to_num(self): inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() @@ -6611,8 +7179,13 @@ def test_exponential_1(self): self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,)) # Test add - def test_add_binary_op(self): - def helper(shape, alpha): + def test_add_sub(self): + def helper(shape, alpha, op_name, inplace): + if op_name == "add": + op = torch.Tensor.add_ if inplace else torch.add + elif op_name == "sub": + op = torch.Tensor.sub_ if inplace else torch.sub + for dtype in [torch.float16, torch.float32]: cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) mps_x = cpu_x.detach().clone().to('mps') @@ -6620,25 +7193,32 @@ def helper(shape, alpha): cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) mps_y = cpu_y.detach().clone().to('mps') - cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha) - mps_out = torch.add(mps_x, mps_y, alpha=alpha) + cpu_out = op(cpu_x, cpu_y, alpha=alpha) + mps_out = 
op(mps_x, mps_y, alpha=alpha) # fp16 isn't accurate when alpha is passed # TODO: remove or fix 'tol' when we fix problems with fp16 - tol = 1e-3 if dtype is torch.float16 else None + tol = 2e-3 if dtype is torch.float16 else None self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol) + if not (cpu_y.shape != () and inplace): # in-place output cannot be broadcasted. + # create a scalar tensor + cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False) + mps_s = cpu_s.detach().clone().to('mps') + # primary tensor is scalar + self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y)) # create a scalar tensor cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False) mps_s = cpu_s.detach().clone().to('mps') - # primary tensor is scalar - self.assertEqual(torch.add(cpu_s, cpu_y), torch.add(mps_s, mps_y)) # secondary tensor is scalar - self.assertEqual(torch.add(cpu_x, cpu_s), torch.add(mps_x, mps_s)) + self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol) - helper((2, 8, 4, 5), 1.0) - helper((2, 8, 4, 5), 0.0) - helper((2, 8, 4, 5), 0.1) - helper((2, 8, 3, 5), 0.1) - helper((2, 8, 3, 5), 0.2) + + for op_name, inplace in product(["add", "sub"], [True, False]): + helper((), 0.0, op_name, inplace) + helper((2, 8, 4, 5), 0.0, op_name, inplace) + helper((2, 8, 4, 5), 0.1, op_name, inplace) + helper((2, 8, 4, 5), 1.0, op_name, inplace) + helper((2, 8, 3, 5), 0.1, op_name, inplace) + helper((2, 8, 3, 5), 0.2, op_name, inplace) # Test add def test_add_scalars(self): @@ -9255,92 +9835,91 @@ def test_cpu_indices(self, device="mps"): self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0) class TestRNNMPS(TestCaseMPS): - def test_lstm_1(self, device="mps", dtype=torch.float32): - for layers in [1] if product_version < 13.0 else [1, 2, 5]: - torch.random.manual_seed(42) - rnn = nn.LSTM(7, 4, layers, device="cpu") - input = torch.randn(2, 3, 7, device="cpu") - hx = torch.randn(layers, 3, 4, device="cpu") - cx = torch.randn(layers, 3, 4, device="cpu") - - cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx)) - - rnn = rnn.to(device) - input = input.to(device) - hx = hx.to(device) - cx = cx.to(device) - output, (hn, cn) = rnn(input, (hx, cx)) - - self.assertEqual(cpu_output, output) - self.assertEqual(cpu_hn, hn) - self.assertEqual(cpu_cn, cn) - - # test batch_first - rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True) - input = torch.randn(3, 2, 7, device="cpu") - hx = torch.randn(layers, 3, 4, device="cpu") - cx = torch.randn(layers, 3, 4, device="cpu") - cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx)) - - rnn = rnn.to(device) - input = input.to(device) - hx = hx.to(device) - cx = cx.to(device) - output, (hn, cn) = rnn(input, (hx, cx)) - - self.assertEqual(cpu_output, output) - self.assertEqual(cpu_hn, hn) - self.assertEqual(cpu_cn, cn) - - def test_lstm_backward(self, device="mps", dtype=torch.float32): - for layers in [1] if product_version < 13.0 else [1, 2, 5]: - lstm = nn.LSTM(2, 4, layers) # initialized globally for consistent parameters init - lstm.train() - - def get_results(device, inp, hx, cx): - rnn = lstm.to(device) - inp, hx, cx = inp.to(device), hx.to(device), cx.to(device) - - output, _ = rnn(inp, (hx, cx)) - f = output.sum() + def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False, + seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False): + rnn = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + 
bidirectional=bidirectional, + batch_first=batch_first, + device="cpu" + ) + bidirectional_mul = 2 if bidirectional else 1 + + if batch_first: + input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward) + hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, + requires_grad=backward) + cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, + requires_grad=backward) + else: + input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward) + hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, + requires_grad=backward) + cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, + requires_grad=backward) - param_names, params = zip(*rnn.named_parameters()) - param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True)) + cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx)) - input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx]) - return output, param_grads, input_grad, hx_grad, cx_grad + rnn = rnn.to(device) + input = input.to(device) + hx = hx.to(device) + cx = cx.to(device) + output, (hn, cn) = rnn(input, (hx, cx)) - inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device) - hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device) - cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device) + self.assertEqual(cpu_output, output) + self.assertEqual(cpu_hn, hn) + self.assertEqual(cpu_cn, cn) - cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx) - mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx) + def get_backward_results(rnn, device, inp, hx, cx): + rnn = rnn.to(device) + inp, hx, cx = inp.to(device), hx.to(device), cx.to(device) - self.assertEqual(cpu_hx_grad, mps_hx_grad) - self.assertEqual(cpu_cx_grad, mps_cx_grad) - self.assertEqual(cpu_output, mps_output) - self.assertEqual(cpu_input_grad, mps_input_grad) - for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad): - self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}") + output, _ = rnn(inp, (hx, cx)) + f = 3 * output.sum() + (hx * cx).sum() - # test batch_first backward - lstm = nn.LSTM(2, 4, layers, batch_first=True) - lstm.train() + param_names, params = zip(*rnn.named_parameters()) + param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True)) - hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device) - cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device) + input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx]) + return output, param_grads, input_grad, hx_grad, cx_grad - cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx) - mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx) + if backward: + cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\ + get_backward_results(rnn, "cpu", input, hx, cx) + mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\ + get_backward_results(rnn, device, input, hx, cx) self.assertEqual(cpu_hx_grad, mps_hx_grad) 
self.assertEqual(cpu_cx_grad, mps_cx_grad) self.assertEqual(cpu_output, mps_output) self.assertEqual(cpu_input_grad, mps_input_grad) for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad): - self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}") + self.assertEqual(cpu_weight_grad, mps_weight_grad, + f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}") + + LSTM_TEST_CASES = [ + dict(), # default + dict(batch_first=True), + dict(bias=False), + dict(bidirectional=True), + dict(batch_first=True, bias=False), + dict(bidirectional=True, bias=False), + dict(bidirectional=True, batch_first=True), + dict(bidirectional=True, batch_first=True, bias=False) + ] + + def test_lstm_forward(self, device="mps", dtype=torch.float32): + for num_layers in [1] if product_version < 13.0 else [1, 2, 5]: + for test_options in self.LSTM_TEST_CASES: + self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options) + def test_lstm_backward(self, device="mps", dtype=torch.float32): + for num_layers in [1] if product_version < 13.0 else [1, 2, 5]: + for test_options in self.LSTM_TEST_CASES: + self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options) def test_RNN_cell_no_broadcasting(self): def test(cell_module, input, hx, input_size, hidden_size): @@ -9547,6 +10126,8 @@ def test_serialization_map_location(self): for t in [torch.double, torch.cdouble, torch.cfloat, torch.bfloat16]: del MPS_DTYPES[MPS_DTYPES.index(t)] +MPS_GRAD_DTYPES = [torch.float32, torch.float16] + class TestConsistency(TestCaseMPS): # TODO: This is only used while some ops are being added. @@ -9554,629 +10135,34 @@ class TestConsistency(TestCaseMPS): # This can be generated automatically in the `new_mps_allowlist.txt` file # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU` # You most likely do NOT want to modify this manually - ALLOWLIST_OP = { - '__getitem__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - '__radd__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - '__rand__': ['b8', 'i16', 'i32', 'i64', 'u8'], - '__rdiv__': ['f16', 'f32', 'i16', 'i32', 'u8'], - '__rmatmul__': ['f32'], - '__rmul__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - '__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'], - '__rpow__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - '__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'masked.argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.log_softmax': ['f32'], - 'masked.logaddexp': ['f32'], - 'masked.logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.norm': ['f16', 'f32'], - 'masked.normalize': ['f16', 'f32'], - 'masked.softmax': ['f32'], - 'masked.softmin': ['f32'], - 'masked.std': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.var': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'abs': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'], - 'acos': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'acosh': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'addbmm': ['f32'], - 'addcdiv': ['f32'], - 'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'addmm': ['f32'], - 'addmv': ['f32'], - 'addr': ['f32'], - 'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'allclose': ['f16', 'f32'], - 'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'arange': ['f16', 'f32', 'i16', 'i32', 'i64', 
'u8'], - 'argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'amix': ['f32'], - 'asin': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'asinh': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'atan': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'atan2': ['f32'], - 'atanh': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'atleast_1d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'atleast_2d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'atleast_3d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'baddbmm': ['f32'], - 'bitwise_and': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'bitwise_left_shift': ['i16', 'i32', 'i64', 'u8'], - 'bitwise_not': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'bitwise_or': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'bitwise_right_shift': ['i16', 'i32', 'i64', 'u8'], - 'bitwise_xor': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'bmm': ['f32'], - 'broadcast_shapes': ['f32'], - 'byte': None, - 'cat': None, - 'ceil': ['f32', 'int32', 'int64', 'f16'], - 'chalf': None, - 'char': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp_min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clone': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'column_stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'combinations': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'conj_physical': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'constant_pad_nd': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'contiguous': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'copysign': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'corrcoef': ['f32'], - 'cos': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'], - 'cosh': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'], - 'cov': ['f32'], - 'cumsum': ['i8', 'b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'diag': ['f32', 'i32'], - 'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'diagflat': ['f32', 'i32'], - 'diagonal_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'diff': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'dist': ['f32'], - 'dot': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'einsum': ['f32'], - 'equal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'erf': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'exp': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'exp2': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'], - 'eye': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'flatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'flip': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'fliplr': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'floor': ['f32', 'f16', 'i16', 'i32', 'i64'], - 'floor_divide': ['f32', 'f16'], - 'fmax': ['b8', 'f32', 'f16', 'i16', 'i32', 'i64', 'u8'], - 'fmin': ['b8', 'f32', 'f16', 'i16', 'i32', 'i64', 'u8'], - 'fmod': ['f32', 'f16', 'i16', 'i32', 'i64', 'u8'], - 'frac': ['f16', 'f32'], - 'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'gradient': ['f16', 'f32', 'i16'], - 'ge': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'gt': ['b8', 'f16', 'f32', 'i16', 
'i32', 'i64', 'u8'], - 'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'hypot': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'index_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'isclose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'isfinite': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'isinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'isnan': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'linalg.matrix_norm': ['f16'], - 'linalg.matrix_power': ['f32'], - 'linalg.svd': ['f32'], - 'linalg.vector_norm': ['f16', 'f32'], - 'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'log': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'log10': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'log1p': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'log2': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'log_softmax': ['f32'], - 'logaddexp': ['f16', 'f32'], - 'logaddexp2': ['f16', 'f32'], - 'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'long': None, - 'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked_scatter': ['i8', 'b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'matmul': ['f32'], - 'mm': ['f32'], - 'mv': ['f32'], - 'nan_to_num': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'nn.functional.adaptive_max_pool1d': ['f32'], - 'nn.functional.adaptive_max_pool2d': ['f32'], - 'nn.functional.adaptive_avg_pool1d': ['f32'], - 'nn.functional.adaptive_avg_pool2d': ['f32'], - 'nn.functional.avg_pool1d': ['f32', 'i64'], - 'nn.functional.avg_pool2d': ['f32', 'i64'], - 'nn.functional.binary_cross_entropy': ['f32'], - 'nn.functional.binary_cross_entropy_with_logits': ['f32'], - 'nn.functional.celu': ['f32'], - 'nn.functional.conv1d': ['f32'], - 'nn.functional.conv2d': ['f32'], - 'nn.functional.conv_transpose1d': ['f32'], - 'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.cosine_similarity': ['f32'], - 'nn.functional.elu': ['f32'], - 'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.embedding': ['f16', 'f32'], - 'nn.functional.gaussian_nll_loss': ['f32'], - 'nn.functional.glu': ['f32'], - 'nn.functional.group_norm': ['f32'], - 'nn.functional.hardsigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.hardtanh': ['f32', 'i16', 'i32', 'i64'], - 'nn.functional.hinge_embedding_loss': ['f32'], - 'nn.functional.huber_loss': ['f16', 'f32'], - 'nn.functional.instance_norm': ['f32'], - 'nn.functional.kl_div': ['f32', 'i16', 'i32', 'i64'], - 'nn.functional.l1_loss': ['f16', 'f32'], - 'nn.functional.leaky_relu': ['f32'], - 'nn.functional.linear': ['f32'], - 'nn.functional.local_response_norm': ['f32'], - 'nn.functional.logsigmoid': ['b8', 'f16', 'f32', 
'i16', 'i32', 'i64', 'u8'], - 'nn.functional.margin_ranking_loss': ['f32', 'i16', 'i32'], - 'nn.functional.max_pool1d': ['f32'], - 'nn.functional.max_pool2d': ['f32'], - 'max_pool2d_with_indices_backward': ['f32'], - 'nn.functional.mse_loss': ['f16', 'f32'], - 'nn.functional.nll_loss': ['f32'], - 'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.padconstant': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.padreflect': ['f32'], - 'nn.functional.padreplicate': ['f32'], - # TODO: add f16 test case after solve the accuracy issue, - # see https://github.com/pytorch/pytorch/pull/95166#issuecomment-1439359181. - 'nn.functional.pairwise_distance': ['f32', 'i16', 'i32', 'i64'], - 'nn.functional.poisson_nll_loss': ['f32', 'i16', 'i32', 'u8'], - 'nn.functional.prelu': ['f32'], - 'nn.functional.relu': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.relu6': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.selu': ['f32'], - 'nn.functional.silu': ['f32'], - 'nn.functional.smooth_l1_loss': ['f16', 'f32'], - 'nn.functional.soft_margin_loss': ['f32'], - 'nn.functional.softmin': ['f32'], - 'nn.functional.softplus': ['f32'], - 'nn.functional.softsign': ['f16', 'f32', 'i16', 'u8'], - 'nn.functional.tanhshrink': ['f32', 'i16', 'i32', 'u8'], - 'nn.functional.threshold': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.triplet_margin_loss': ['f32', 'i16', 'i32', 'i64'], - 'nn.functional.triplet_margin_with_distance_loss': ['f32', 'i16', 'i32', 'i64'], - 'nn.functional.upsample_bilinear': ['f32'], - 'nn.functional.upsample_nearest': ['f32'], - 'norm': ['f32', 'f16'], - 'positive': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'pow': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'put': None, - 'rad2deg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'reciprocal': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'], - 'remainder' : None, - 'repeat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'repeat_interleave': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'resize_': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'resize_as_': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'resolve_conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'resolve_neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'roll': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'rot90': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'round': ['f32', 'f16', 'i16', 'i32', 'i64'], - 'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'select_scatter': None, - 'sgn': None, - 'short': None, - 'sigmoid': None, - 'sign': None, - 'sin': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'sinh': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'slice_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'softmax': ['f32'], - 'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sqrt': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'square': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'squeeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sub': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sum_to_size': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'svd': ['f32'], - 't': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'tan': ['b8', 'i16', 'i32', 'u8'], - 'tanh': ['b8', 'f32', 'i16', 
'i32', 'u8'], - 'tensordot': ['f32'], - 'tensor_split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'tile': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'topk': ['f32', 'f16'], - 'trapz': ['f16', 'f32', 'i16', 'i32', 'i64'], - 'sort': ['f32', 'i16', 'i32', 'i64'], - 'argsort': ['f32', 'i16', 'i32', 'i64'], - 'tril': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'tril_indices': ['i32', 'i64'], - 'triu': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'triu_indices': ['i32', 'i64'], - 'true_divide': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'trunc': ['f32'], - 'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'unsqueeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'view': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'view_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'where': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'nonzero': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'std': ['f16', 'f32'], - 'var': ['f16', 'f32'], - 'amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'mean': ['f16', 'f32'], - 'count_nonzero': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'xlogy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.mean': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'masked.sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'native_layer_norm': ['torch.float32'], - 'nn.functional.layer_norm': ['torch.float32'], - 'nn.functional.bilinear': ['f32'], - 'linalg.solve_triangular': ['f32'], - 'triangular_solve': ['f32'], - 'trace': None, - '_native_batch_norm_legit': ['f32'], - 'native_batch_norm': ['f32'], - 'minreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'maxreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'linalg.inv': ['f32'], - 'linalg.inv_ex': ['f32'], - 'mH': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'mT': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'T': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'H': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - } - - - ALLOWLIST_OP_GRAD = { - '__radd__': ['f16', 'f32'], - '__rdiv__': ['f16', 'f32'], - '__rmatmul__': ['f32'], - '__rmul__': ['f16', 'f32'], - '__rpow__': ['f32'], - 'masked.log_softmax': ['f32'], - 'masked.logaddexp': ['f32'], - 'masked.softmax': ['f32'], - 'masked.softmin': ['f32'], - 'masked.std': ['f32'], - 'masked_scatter': ['f16', 'f32'], - 'abs': ['f16', 'f32'], - 'acos': ['f32'], - 'acosh': ['f32'], - 'add': ['f16', 'f32'], - 'addbmm': ['f32'], - 'addcdiv': ['f32'], - 'addcmul': ['f32'], - 'addmm': ['f32'], - 'addmv': ['f32'], - 'addr': ['f32'], - 'all': ['f16', 'f32'], - 'any': ['f16', 'f32'], - 'arange': ['f16', 'f32'], - 'argmax': ['f16', 'f32'], - 
'argmin': ['f16', 'f32'], - 'asin': ['f32'], - 'asinh': ['f32'], - 'atan': ['f32'], - 'atan2': ['f32'], - 'atleast_1d': ['f16', 'f32'], - 'atleast_2d': ['f16', 'f32'], - 'atleast_3d': ['f16', 'f32'], - 'baddbmm': ['f32'], - 'block_diag': ['f16', 'f32'], - 'bmm': ['f32'], - 'broadcast_shapes': ['f32'], - 'ceil': ['f32'], - 'chunk': ['f16', 'f32'], - 'clone': ['f16', 'f32'], - 'column_stack': ['f16', 'f32'], - 'conj': ['f16', 'f32'], - 'conj_physical': ['f16', 'f32'], - 'contiguous': ['f16', 'f32'], - 'copysign': ['f16', 'f32'], - 'corrcoef': ['f32'], - 'cos': ['f32'], - 'cosh': ['f32'], - 'cumsum': ['f16', 'f32'], - 'deg2rad': ['f16', 'f32'], - 'diag': ['f32'], - 'diag_embed': ['f16', 'f32'], - 'diagflat': ['f32'], - 'diagonal_scatter': ['f16', 'f32'], - 'diff': ['f16', 'f32'], - 'dist': ['f32'], - 'dot': ['f32'], - 'einsum': ['f32'], - 'erf': ['f32'], - 'exp': ['f32'], - 'exp2': ['f16', 'f32'], - 'fill': ['f16', 'f32'], - 'flatten': ['f16', 'f32'], - 'flip': ['f16', 'f32'], - 'fliplr': ['f16', 'f32'], - 'flipud': ['f16', 'f32'], - 'float': ['f32'], - 'floor': ['f32'], - 'fmax': ['f16', 'f32'], - 'fmin': ['f16', 'f32'], - 'gradient': ['f32'], - 'half': ['f16'], - 'hstack': ['f16', 'f32'], - 'hypot': ['f16', 'f32'], - 'index_select': ['f16', 'f32'], - 'index_add': ['f16', 'f32'], - 'isclose': ['f16', 'f32'], - 'isfinite': ['f16', 'f32'], - 'isinf': ['f16', 'f32'], - 'isnan': ['f16', 'f32'], - 'isreal': ['f16', 'f32'], - 'kron': ['f32'], - 'linalg.matrix_norm': ['f16'], - 'linalg.svd': ['f32'], - 'linspace': ['f16', 'f32'], - 'log': ['f32'], - 'log10': ['f32'], - 'log1p': ['f32'], - 'log2': ['f32'], - 'log_softmax': ['f32'], - 'logaddexp': ['f32'], - 'logical_not': ['f16', 'f32'], - 'logit': ['f16', 'f32'], - 'logspace': ['f32'], - 'matmul': ['f32'], - 'mm': ['f32'], - 'mv': ['f32'], - 'neg': ['f16', 'f32'], - 'nn.functional.adaptive_max_pool1d': ['f32'], - 'nn.functional.adaptive_max_pool2d': ['f32'], - 'nn.functional.adaptive_avg_pool1d': ['f32'], - 'nn.functional.adaptive_avg_pool2d': ['f32'], - 'nn.functional.avg_pool1d': ['f32'], - 'nn.functional.avg_pool2d': ['f32'], - 'nn.functional.binary_cross_entropy': ['f32'], - 'nn.functional.celu': ['f32'], - 'nn.functional.conv1d': ['f32'], - 'nn.functional.conv2d': ['f32'], - 'nn.functional.conv_transpose1d': ['f32'], - 'nn.functional.cosine_embedding_loss': ['f32'], - 'nn.functional.elu': ['f32'], - 'nn.functional.feature_alpha_dropout': ['f16', 'f32'], - 'nn.functional.glu': ['f32'], - 'nn.functional.hardsigmoid': ['f16', 'f32'], - 'nn.functional.hardtanh': ['f32'], - 'nn.functional.hinge_embedding_loss': ['f32'], - 'nn.functional.huber_loss': ['f16', 'f32'], - 'nn.functional.instance_norm': ['f32'], - 'nn.functional.kl_div': ['f32'], - 'nn.functional.l1_loss': ['f16', 'f32'], - 'nn.functional.leaky_relu': ['f32'], - 'nn.functional.local_response_norm': ['f32'], - 'nn.functional.logsigmoid': ['f16', 'f32'], - 'nn.functional.margin_ranking_loss': ['f32'], - 'nn.functional.max_pool1d': ['f32'], - 'nn.functional.max_pool2d': ['f32'], - 'nn.functional.mse_loss': ['f32'], - 'nn.functional.nll_loss': ['f32'], - 'nn.functional.pad': ['f16', 'f32', 'i16', 'i32', 'i64'], - # TODO: add f16 test case after solve the accuracy issue, - # see https://github.com/pytorch/pytorch/pull/95166#issuecomment-1439359181. 
- 'nn.functional.pairwise_distance': ['f32'], - 'nn.functional.poisson_nll_loss': ['f32'], - 'nn.functional.relu': ['f32'], - 'nn.functional.relu6': ['f32'], - 'nn.functional.selu': ['f32'], - 'nn.functional.silu': ['f32'], - 'nn.functional.soft_margin_loss': ['f32'], - 'nn.functional.softmin': ['f32'], - 'nn.functional.softplus': ['f32'], - 'nn.functional.softsign': ['f16', 'f32'], - 'nn.functional.smooth_l1_loss': ['f32'], - 'nn.functional.threshold': ['f32'], - 'nn.functional.triplet_margin_loss': ['f32'], - 'nn.functional.triplet_margin_with_distance_loss': ['f32'], - 'nn.functional.upsample_bilinear': ['f32'], - 'norm': ['f32', 'f16'], - 'positive': ['f16', 'f32'], - 'pow': ['f32'], - 'rad2deg': ['f16', 'f32'], - 'real': ['f16', 'f32'], - 'reciprocal': ['f16', 'f32'], - 'repeat': ['f16', 'f32'], - 'repeat_interleave': ['f16', 'f32'], - 'resolve_conj': ['f16', 'f32'], - 'resolve_neg': ['f16', 'f32'], - 'roll': ['f16', 'f32'], - 'round': ['f32'], - 'rsqrt': ['f32'], - 'select_scatter': ['f16', 'f32'], - 'sign': ['f16', 'f32'], - 'sin': ['f32'], - 'sinh': ['f32'], - 'slice_scatter': ['f16', 'f32'], - 'softmax': ['f32'], - 'split': ['f16', 'f32'], - 'sqrt': ['f32'], - 'square': ['f16', 'f32'], - 'squeeze': ['f16', 'f32'], - 'stack': ['f16', 'f32'], - 'sub': ['f32'], - 'sum_to_size': ['f16', 'f32'], - 'svd': ['f32'], - 't': ['f16', 'f32'], - 'tanh': ['f32'], - 'tensordot': ['f32'], - 'tile': ['f16', 'f32'], - 'tril': ['f16', 'f32'], - 'triu': ['f16', 'f32'], - 'true_divide': ['f16', 'f32'], - 'trunc': ['f32'], - 'unbind': ['f16', 'f32'], - 'unflatten': ['f16', 'f32'], - 'unsqueeze': ['f16', 'f32'], - 'view': ['f16', 'f32'], - 'view_as': ['f16', 'f32'], - 'vsplit': ['f16', 'f32'], - 'vstack': ['f16', 'f32'], - 'xlogy': ['f16', 'f32'], - 'zero_': ['f16', 'f32'], - 'linalg.solve_triangular': ['f32'], - 'triangular_solve': ['f32'], - '_native_batch_norm_legit': ['f32'], - 'native_batch_norm': ['f32'], - 'native_layer_norm': ['f32'], - 'nn.functional.gelu': ['f32'], - 'nn.functional.bilinear': ['f32'], - 'nn.functional.prelu': ['f32'], - } - - # These ops that are problematic. So never run them even when - # generating the new allowlist. - # If the dtype list is None, all dtypes are excluded. 
- # All the entries in this list should be removed - BLOCKLIST = { - # Functions that hang - 'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool], - # + forward when requires_grad=True or running backward - 'masked.mean': [torch.bool, torch.float16], - 'masked.prod': [torch.bool], - 'masked.sum': [torch.bool], - - # Functions that hard crash - 'std': [torch.float16], - 'stft': [torch.float32], 'var': [torch.float16], - # + forward when requires_grad=True or running backward - 'nn.functional.embedding': [torch.float32, torch.float16], - - 'as_strided_scatter': [torch.uint8], - 'atan2': [torch.int64], - 'bfloat16': None, - 'block_diag': [torch.uint8], - 'diag_embed': [torch.uint8], - 'diagonal_scatter': [torch.uint8], - 'nn.functional.conv_transpose3d': [torch.int64, torch.float32], - 'nn.functional.local_response_norm': [torch.int64], - 'nn.functional.padcircular': [torch.uint8], - - - - # These were moved from ALLOWLIST to BLOCK as they are not working - # locally - 'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - '__radd__': ['torch.bool', 'torch.uint8'], - '__rmul__': ['torch.uint8'], - 'neg': ['torch.uint8'], - 'add': ['torch.bool', 'torch.uint8'], - 'addr': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'diag': ['torch.int64'], - 'diagflat': ['torch.int64'], - - # Functions that are flaky - # These are detected as "ok" by the expect case but actually fail to run sometimes - 'as_strided': None, - 'broadcast_tensors': None, - 'broadcast': None, - 'broadcast_to': None, - 'diagonal': None, - 'divfloor_rounding': None, - 'divno_rounding_mode': None, - 'divtrunc_rounding': None, - 'dsplit': None, - 'hsplit': None, - 'empty': None, - 'expand_as': None, - 'expand': None, - 'ge': None, - 'ne': None, - 'le': None, - 'lt': None, - 'gt': None, - 'transpose': None, - 'splitlist_args': None, - 'select': None, - 'reshape': None, - 'reshape_as': None, - 'permute': None, - 'norm': None, - 'nn.functional.pixel_unshuffle': None, - 'nn.functional.pixel_shuffle': None, - 'nn.functional.cross_entropy': None, - 'nn.functional.one_hot': None, - 'narrow': None, - 'movedim': None, - 'minreduction_with_dim': None, - 'minreduction_no_dim': None, - 'minbinary': None, - 'meshgridvariadic_tensors': None, - 'meshgridlist_of_tensors': None, - 'maxreduction_with_dim': None, - 'maxreduction_no_dim': None, - 'maxbinary': None, - 'maximum': None, - 'minimum': None, - 'outer': None, - 'softmaxwith_dtype': None, - 'rounddecimals_neg_3': None, - 'rounddecimals_3': None, - 'rounddecimals_0': None, - 'normnuc': None, - 'nn.functional.softminwith_dtype': None, - 'nn.functional.feature_alpha_dropoutwith_train': None, - 'log_softmaxwith_dtype': None, - 'split_with_sizes': None, - 'trapezoid': None, - 'eq': None, - 'mul': None, - 'cartesian_prod': None, - 'bool': None, - 'inner': None, - 'dstack': None, - 'take_along_dim': None, - } FP16_LOW_PRECISION_LIST = { 'add', 'sub', 'div', '__rdiv__', '__rmul__', 'nn.functional.huber_loss', 'true_divide', 'kron', - 'gradient', 'var', 'std', + 'gradient', 'var', 'std', 'ldexp', 'linalg.vector_norm', - 'masked.sum', 'masked.std', - 'masked.var', + 'addr', 'var_mean', + 'var_mean_unbiased', + + # for macOS 12 + 'masked.normalize', 'masked.sum', 'masked.var', + 'outer', + 'sum_to_size', 'sum', + 'mul', + 'nansum', 'nanmean', + 'norm', + } + + FP32_LOW_PRECISION_LIST = { + # conv2d and conv_transpose2d results have a very small + # difference compared to CPU/CUDA, so we use lower precision on 
FP32 + 'nn.functional.conv2d', + 'nn.functional.conv_transpose2d', + 'matmul', '__rmatmul__', + 'linalg.multi_dot', + 'addbmm', } # Used for accept mode only @@ -10187,29 +10173,60 @@ class TestConsistency(TestCaseMPS): def test_output_match(self, device, dtype, op): self.assertEqual(device, "cpu") key = op.name + op.variant_test_name + run_grad_test = True - if key in self.BLOCKLIST: - if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]: - self.skipTest(f"Running test with {op.name} hangs so skipping") + def get_samples(): + return op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex)) + cpu_samples = get_samples() - # Make this an expecttest manually - # When this env variable is set, generate a new ALLOWLIST_OP - # that reflects the current state of what passes or not - if os.environ.get("EXPECTTEST_ACCEPT", None) == "1": - generate_new_truth = True - else: - generate_new_truth = False + all_backward_pass = True + for cpu_sample in cpu_samples: + # + # Forward check + # + mps_sample = cpu_sample.transform( + lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x) + + cpu_args = [cpu_sample.input] + list(cpu_sample.args) + cpu_kwargs = cpu_sample.kwargs + mps_args = [mps_sample.input] + list(mps_sample.args) + mps_kwargs = mps_sample.kwargs + + # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only + if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)): + mps_args[1] = cpu_args[1] + + cpu_out = op(*cpu_args, **cpu_kwargs) + mps_out = op(*mps_args, **mps_kwargs) + + if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32: + atol = 1e-4 + rtol = 3e-5 + elif op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16: + atol = 1e-2 + rtol = 1e-2 + elif op.name == "masked.mean": + atol = 7e-4 + rtol = 2e-3 + elif op.name == "native_layer_norm": + atol = 1e-4 + rtol = 1.3e-5 + elif op.name in ["pow", "__rpow__"]: + atol = 1e-6 + rtol = 4e-6 + else: + atol = None + rtol = None - run_grad_test = True - if not generate_new_truth: - if op.name not in self.ALLOWLIST_OP: - self.skipTest(f"{op.name} is not in the allow list for test on MPS") - elif self.ALLOWLIST_OP[op.name] is not None: - if dtype_abbrs[dtype] not in self.ALLOWLIST_OP[op.name]: - self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded") + self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol) - if op.name not in self.ALLOWLIST_OP_GRAD or dtype_abbrs[dtype] not in self.ALLOWLIST_OP_GRAD[op.name]: - run_grad_test = False + + @ops(mps_ops_grad_modifier(copy.deepcopy(op_db)), allowed_dtypes=MPS_GRAD_DTYPES) + def test_output_grad_match(self, device, dtype, op): + self.assertEqual(device, "cpu") + key = op.name + op.variant_test_name + + run_grad_test = True def get_samples(): return op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex)) @@ -10226,8 +10243,6 @@ def get_samples(): mps_sample = cpu_sample.transform( lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x) - # TODO: This checks only the function variant. 
We should also check the method and inplace version - # when they exist cpu_args = [cpu_sample.input] + list(cpu_sample.args) cpu_kwargs = cpu_sample.kwargs mps_args = [mps_sample.input] + list(mps_sample.args) @@ -10240,21 +10255,26 @@ def get_samples(): cpu_out = op(*cpu_args, **cpu_kwargs) mps_out = op(*mps_args, **mps_kwargs) - if op.name == "nn.functional.conv2d" and dtype == torch.float32: + if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32: atol = 1e-4 rtol = 3e-5 - elif op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16: + elif op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32: + atol = 1e-4 + rtol = 3e-5 + elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16: atol = 1e-2 rtol = 1e-2 - elif op.name == "masked.mean": + elif (op.name == "masked.mean"): atol = 7e-4 rtol = 2e-3 - elif op.name == "native_layer_norm": + elif (op.name == "native_layer_norm"): atol = 1e-4 rtol = 1.3e-5 - elif op.name in ["pow", "__rpow__"]: - atol = 1e-6 - rtol = 4e-6 + elif op.name == "norm" and dtype == torch.float16: + atol = 7e-4 + rtol = 1.5e-3 + elif op.name == "unique" and cpu_kwargs["sorted"] is False: + continue else: atol = None rtol = None @@ -10262,82 +10282,44 @@ def get_samples(): self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol) except Exception as e: - if any(s in str(e).lower() for s in ["int64", "float16", "div truc rounding"]): - self.skipTest(f"Expected Runtime Error: {str(e)}") - - if not generate_new_truth: - raise e + raise e forward_failed = True all_forward_pass = False - if not (dtype.is_floating_point or dtype.is_complex): - # Maybe we should error here instead? - continue - # # Backward check # + if forward_failed: + # We would've failed immediately anyway, but this error is clearer + # We error instead of continuing so that all_backward_pass would not be True + raise RuntimeError("Forward pass already failed") - # Skip the grad test if it is not part of the allow list - if not generate_new_truth and not run_grad_test: - # TODO: maybe there is a way to print only when we have -v - # if i == 0: - # print(f"Skipping gradient check because {op.name} is not on the allow list") - continue - - try: - if forward_failed: - # We would've failed immediately anyway, but this error is clearer - # We error instead of continuing so that all_backward_pass would not be True - raise RuntimeError("Forward pass already failed") - - cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out) - mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out) - - def req_grad(t): - return isinstance(t, torch.Tensor) and t.requires_grad + cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out) + mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out) - diff_cpu_out = tuple(t for t in cpu_out if req_grad(t)) - diff_mps_out = tuple(t for t in mps_out if req_grad(t)) - diff_cpu_arg = tuple(t for t in pytree.tree_flatten((cpu_args, cpu_kwargs))[0] if req_grad(t)) - diff_mps_arg = tuple(t for t in pytree.tree_flatten((mps_args, mps_kwargs))[0] if req_grad(t)) - self.assertEqual(len(diff_cpu_out), len(diff_mps_out)) - self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg)) + def req_grad(t): + return isinstance(t, torch.Tensor) and t.requires_grad - if len(diff_cpu_out) == 0: - continue - # rand_like does not work with certain dtypes, so cast to double and cast back - cpu_grad_outputs = 
tuple(torch.rand_like(t.to(dtype=torch.double)).to(dtype=dtype) for t in diff_cpu_out) - mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs) + diff_cpu_out = tuple(t for t in cpu_out if req_grad(t)) + diff_mps_out = tuple(t for t in mps_out if req_grad(t)) + diff_cpu_arg = tuple(t for t in pytree.tree_flatten((cpu_args, cpu_kwargs))[0] if req_grad(t)) + diff_mps_arg = tuple(t for t in pytree.tree_flatten((mps_args, mps_kwargs))[0] if req_grad(t)) + self.assertEqual(len(diff_cpu_out), len(diff_mps_out)) + self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg)) - # Compare computed gradients with cpu given random grad_output vector - # Sometimes when the derivative is 0, we just don't bother creating the graph - # allow_unused is needed in those cases. - cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) - mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) + if len(diff_cpu_out) == 0: + continue + # rand_like does not work with certain dtypes, so cast to double and cast back + cpu_grad_outputs = tuple(torch.rand_like(t.to(dtype=torch.double)).to(dtype=dtype) for t in diff_cpu_out) + mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs) - self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) - except Exception as e: - if not generate_new_truth: - raise e - all_backward_pass = False - - if all_forward_pass and generate_new_truth: - if dtype_abbrs[dtype] not in self.NEW_ALLOW_LIST[op.name]: - self.NEW_ALLOW_LIST[op.name].append(dtype_abbrs[dtype]) - # We could write it only once. But I don't know how to detect that the current test is the last one - # So each test append to the dict and write it. - with open("new_mps_allowlist.txt", "w") as f: - pprint.pprint(self.NEW_ALLOW_LIST, stream=f) - - if all_backward_pass and generate_new_truth and dtype.is_floating_point: - if dtype_abbrs[dtype] not in self.NEW_ALLOW_LIST_GRAD[op.name]: - self.NEW_ALLOW_LIST_GRAD[op.name].append(dtype_abbrs[dtype]) - # We could write it only once. But I don't know how to detect that the current test is the last one - # So each test append to the dict and write it. - with open("new_mps_allowlist_grad.txt", "w") as f: - pprint.pprint(self.NEW_ALLOW_LIST_GRAD, stream=f) + # Compare computed gradients with cpu given random grad_output vector + # Sometimes when the derivative is 0, we just don't bother creating the graph + # allow_unused is needed in those cases. 
+ cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) + mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) + self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) # Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS @skipIfSlowGradcheckEnv diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 52c53d3f244c..d8f2d6f1c0b4 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -970,6 +970,16 @@ def forward(self, crop_camera_1, mask_1): index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None return None""") + def test_unbacked_slice(self): + def f(x, m): + x = x[m] + return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)] + + make_fx(f, tracing_mode="symbolic")( + torch.randn((12, 3, 3)), + torch.randint(0, 2, (12,), dtype=torch.bool) + ) + @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") def test_unbacked_batch_resnet(self): mod = torchvision.models.resnet18() @@ -1323,7 +1333,6 @@ def f(a, b, c, d, e): skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition - xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... @@ -1357,7 +1366,6 @@ def f(a, b, c, d, e): xfail('fft.rfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.rfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.rfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('unflatten', ''), # RuntimeError: Trying to call aten.size on a tensor with symbolic shapes... xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition xfail('gradient', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1471,7 +1479,6 @@ def f(a, b, c, d, e): xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... 
- xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition xfail('take_along_dim', ''), # dtype of indices should be Long but got Float xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 41a12c0932e1..a9e8ddce09ee 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -2099,14 +2099,14 @@ def _test_spadd_shape(fn, nnz, shape): res_sparse_sparse = fn(y, x) res_dense_sparse = fn(y.to_dense(), x) res_sparse_dense = fn(y, x.to_dense()) - expected = fn(y.to_dense(), x.to_dense()).to_sparse_csr() + expected = fn(y.to_dense(), x.to_dense()) self.assertEqual(res_sparse_sparse, expected) # TODO: While result of mul(dense, csr) is csr, it is not fully compressed. # That means it may contain materialized zeros, since the dense argument # is converted according to the sparsity pattern of csr. In the future # we might require the result to be fully compressed. - self.assertEqual(res_dense_sparse.to_dense(), expected.to_dense()) - self.assertEqual(res_sparse_dense.to_dense(), expected.to_dense()) + self.assertEqual(res_dense_sparse, expected) + self.assertEqual(res_sparse_dense, expected) # Grad comparison x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 03b0bc99dba8..1200a6df8825 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -2,6 +2,7 @@ # Owner(s): ["oncall: pt2"] import itertools +import sys import sympy from torch.testing._internal.common_utils import ( @@ -50,6 +51,8 @@ 2**24, 2**32, 2**37 - 1, + sys.maxsize - 1, + sys.maxsize, ] # less constants for N^2 situations LESS_CONSTANTS = [-1, 0, 1, 2, 100] diff --git a/test/test_torch.py b/test/test_torch.py index 12dea3ba8433..076f764eb276 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -55,7 +55,7 @@ tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN) from torch.testing._internal.common_dtype import ( floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, - all_types_and, floating_types, floating_and_complex_types, + all_types_and, floating_types, floating_and_complex_types, integral_types_and ) # Protects against includes accidentally setting the default dtype @@ -629,12 +629,6 @@ def test_scalar_check(self, device): zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device) one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - self.assertEqual((1,), torch.masked_select(zero_d_uint8, zero_d_uint8).shape) - self.assertEqual((1,), torch.masked_select(zero_d_uint8, one_d_uint8).shape) - self.assertEqual((1,), torch.masked_select(one_d_uint8, zero_d_uint8).shape) - # mode self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)]) self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)]) @@ -3688,16 +3682,17 @@ def test_masked_select(self, device, dtype): warn = 'masked_select received a mask with dtype torch.uint8,' else: warn = 'indexing with dtype torch.uint8 is now deprecated, pl' - for maskType in [torch.uint8, torch.bool]: + for maskType in integral_types_and(torch.bool): num_src = 10 src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device) mask = 
torch.randint(2, (num_src,), device=device, dtype=maskType) - with warnings.catch_warnings(record=True) as w: + if maskType is not torch.bool: + with self.assertRaisesRegex(RuntimeError, r'expected BoolTensor for mask'): + dst = src.masked_select(mask) + continue + else: dst = src.masked_select(mask) - if maskType is torch.uint8: - self.assertEqual(len(w), 1) - self.assertEqual(str(w[0].message)[0:53], str(warn)) dst2 = [] for i in range(num_src): if mask[i]: diff --git a/test/test_transformers.py b/test/test_transformers.py index 105222767905..b098bfd09ea1 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -510,6 +510,36 @@ def perm_fn(x): with cm: _test(batch_first, training, enable_nested_tensor) + + def test_padding_and_src_mask_bool(self): + encoder_layer = nn.TransformerEncoderLayer( + d_model=16, + nhead=2, + dim_feedforward=32, + dropout=0.1, + activation='relu', + batch_first=True, + ) + encoder_norm = nn.LayerNorm(16) + encoder = nn.TransformerEncoder( + encoder_layer, 2, encoder_norm + ) + + inputs = torch.randn(2, 3, 16) + + src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1) + input_seq_len = torch.tensor([3, 2]) + padding_mask = ( + torch.arange(3)[None, :].cpu() >= input_seq_len[:, None] + ) + + encoder( + inputs, + mask=src_mask, + src_key_padding_mask=padding_mask, + ) + + @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found") @unittest.skipIf(not TEST_CUDA, 'CUDA not available') def test_decoder_only_layer(self): diff --git a/third_party/kineto b/third_party/kineto index e121ba84c711..9380d6405513 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit e121ba84c71102656d011338bcb616419a241ad1 +Subproject commit 9380d64055137e609709b4b72230143848ca3465 diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index 3d685f4bab15..24b2ba562f6d 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -288,3 +288,20 @@ def define_tools_targets( ":autograd", ], ) + + python_test( + name = "test_torchgen_executorch", + srcs = [ + "test/test_executorch_custom_ops.py", + "test/test_executorch_gen.py", + "test/test_executorch_signatures.py", + "test/test_executorch_types.py", + "test/test_executorch_unboxing.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + torchgen_deps, + "fbsource//third-party/pypi/expecttest:expecttest", + ], + ) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index de13e63d75d0..cf5f7473dd49 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -259,8 +259,8 @@ vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj()) result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha) -- name: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor - theta: affine_grid_generator_backward(grad, size, align_corners) +- name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + theta: affine_grid_generator_backward_symint(grad, size, align_corners) - name: alias(Tensor(a) self) -> Tensor(a) self: grad @@ -2800,7 +2800,7 @@ AutogradCUDA: self: grad.reshape_as(self) + 1 -- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- name: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor output_differentiability: [False] - name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor @@ -2983,40 +2983,6 @@ - name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i]) -# Definitions below would be able to be generated by `torchgen` e.g. , but currently I see some weird numerical errors. -- name: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] - self: mul_tensor_backward(grads[i], other[i], self[i].scalar_type()) - other: mul_tensor_backward(grads[i], self[i], other[i].scalar_type()) - -- name: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] - self: handle_r_to_c(self[i].scalar_type(), grads[i]) - other: handle_r_to_c(other[i].scalar_type(), maybe_multiply(-grads[i], alpha.conj())) - -- name: _foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] - self: where(self[i] >= other[i], grads[i], at::scalar_tensor(0., grads[i].options())) - other: where(self[i] < other[i], grads[i], at::scalar_tensor(0., grads[i].options())) - -- name: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] - self: where(self[i] <= other[i], grads[i], at::scalar_tensor(0., grads[i].options())) - other: where(self[i] > other[i], grads[i], at::scalar_tensor(0., grads[i].options())) - -- name: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] - self: at::where(self[i] == other[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > other[i], 0) - other: at::where(self[i] == other[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < other[i], 0) - -- name: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] - self: at::where(self[i] == other[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < other[i], 0) - other: at::where(self[i] == other[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > other[i], 0) - -- name: _foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] - self: grads[i] * (1 - weights[i]).conj() - tensors1: grads[i] * weights[i].conj() - weights: grads[i] * (tensors1[i] - self[i]).conj() - -- name: _foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] - self: "weight.isComplex() ? grads[i] * (1 - weight.conj().toComplexDouble()) : grads[i] * (1 - weight.toDouble())" - tensors1: grads[i] * weight.conj() - # note(crcrpar): following definitions seem necessary because the reference native functions # of `maximum` and `minimum` don't have the overload def with Scalar as their second argument. 
- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -3030,3 +2996,12 @@ - name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0) + +- name: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor + self: non_differentiable + +- name: all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor + shard: non_differentiable + +- name: wait_tensor(Tensor self) -> Tensor + self: non_differentiable diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4c709d29068a..019f334c5268 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -550,6 +550,10 @@ # _nested_tensor_size() should never actually be called with requires_grad=True tensor "_nested_tensor_size", "_nested_tensor_strides", + # Functional collectives keep an internal ref through the Work object + "all_reduce", + "all_gather_into_tensor", + "wait_tensor", } DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = { diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py index 659d91ae3c1f..285a2032dfbd 100755 --- a/tools/fast_nvcc/fast_nvcc.py +++ b/tools/fast_nvcc/fast_nvcc.py @@ -78,6 +78,8 @@ # regex for temporary file names re_tmp = r"(? None: @@ -141,7 +143,10 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData: print(result.stdout, end="") env = {} commands = [] - for line in result.stderr.splitlines(): + output = result.stderr + if os.name == "nt": + output = result.stdout + for line in output.splitlines(): match = re.match(r"^#\$ (.*)$", line) if match: (stripped,) = match.groups() @@ -213,9 +218,11 @@ def uniqueify(s: Match[str]) -> str: line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line) if arr: (filename,) = arr + if os.name == "nt": + filename = "%TEMP%\\" + filename if not module_id: module_id = module_id_contents(shlex.split(line)) - uniqueified.append(f"echo -n '{module_id}' > '{filename}'") + uniqueified.append(f"echo -n '{module_id}' > \"{filename}\"") uniqueified.append(line) return uniqueified @@ -261,6 +268,8 @@ def files_mentioned(command: str) -> List[str]: """ Return fully-qualified names of all tmp files referenced by command. """ + if os.name == "nt": + return [f"/%TEMP%/{match.group(1)}" for match in re.finditer(re_tmp, command)] return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)] @@ -294,7 +303,9 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph: fatbins[i].add(tmp) else: tmp_files[tmp] = i - if line.startswith("rm ") and not deps: + if (line.startswith("rm ") or line.startswith("erase ")) and not deps: + if os.name == "nt": + commands[i] = line.replace("/", "\\") deps.add(i - 1) graph.append(deps) return graph @@ -421,6 +432,8 @@ async def run_graph( """ Return outputs/errors (and optionally time/file info) from commands. 
""" + if os.name == "nt": + env.update(os.environ.copy()) tasks: List[Awaitable[Result]] = [] for i, (command, indices) in enumerate(zip(commands, graph)): deps = {tasks[j] for j in indices} diff --git a/tools/fast_nvcc/wrap_nvcc.bat.in b/tools/fast_nvcc/wrap_nvcc.bat.in new file mode 100644 index 000000000000..f02a751e3a4f --- /dev/null +++ b/tools/fast_nvcc/wrap_nvcc.bat.in @@ -0,0 +1 @@ +python "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- %* diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 22bf230865d9..6ce97e47c758 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -203,6 +203,7 @@ def generate( # CMakeLists.txt. var: var for var in ( + "UBSAN_FLAGS", "BLAS", "WITH_BLAS", "BUILDING_WITH_TORCH_LIBS", diff --git a/tools/stats/upload_dynamo_perf_stats.py b/tools/stats/upload_dynamo_perf_stats.py new file mode 100644 index 000000000000..52d2bfc1a49b --- /dev/null +++ b/tools/stats/upload_dynamo_perf_stats.py @@ -0,0 +1,108 @@ +import argparse +import csv +import os +import re +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, Dict, List + +from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset + + +ARTIFACTS = [ + "test-reports", +] +ARTIFACT_REGEX = re.compile( + r"test-reports-test-(?P\w+)-\d+-\d+-(?P[\w\.]+)_(?P\d+).zip" +) + + +def upload_dynamo_perf_stats_to_rockset( + repo: str, workflow_run_id: int, workflow_run_attempt: int +) -> List[Dict[str, Any]]: + perf_stats = [] + with TemporaryDirectory() as temp_dir: + print("Using temporary directory:", temp_dir) + os.chdir(temp_dir) + + for artifact in ARTIFACTS: + artifact_paths = download_s3_artifacts( + artifact, workflow_run_id, workflow_run_attempt + ) + + # Unzip to get perf stats csv files + for path in artifact_paths: + m = ARTIFACT_REGEX.match(str(path)) + if not m: + print(f"Test report {path} has an invalid name. 
Skipping") + continue + + test_name = m.group("name") + runner = m.group("runner") + job_id = m.group("job") + + # Extract all files + unzip(path) + + for csv_file in Path(".").glob("**/*.csv"): + filename = os.path.splitext(os.path.basename(csv_file))[0] + print(f"Processing {filename} from {path}") + + with open(csv_file) as csvfile: + reader = csv.DictReader(csvfile, delimiter=",") + + for row in reader: + # If the row doesn't have a dev and a name column, it's not + # a torch dynamo perf stats csv file + if "dev" not in row or "name" not in row: + break + + row.update( + { + "workflow_id": workflow_run_id, # type: ignore[dict-item] + "run_attempt": workflow_run_attempt, # type: ignore[dict-item] + "test_name": test_name, + "runner": runner, + "job_id": job_id, + "filename": filename, + } + ) + perf_stats.append(row) + + # Done processing the file, removing it + os.remove(csv_file) + + return perf_stats + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Upload dynamo perf stats from S3 to Rockset" + ) + parser.add_argument( + "--workflow-run-id", + type=int, + required=True, + help="id of the workflow to get perf stats from", + ) + parser.add_argument( + "--workflow-run-attempt", + type=int, + required=True, + help="which retry of the workflow this is", + ) + parser.add_argument( + "--repo", + type=str, + required=True, + help="which GitHub repo this workflow run belongs to", + ) + args = parser.parse_args() + perf_stats = upload_dynamo_perf_stats_to_rockset( + args.repo, args.workflow_run_id, args.workflow_run_attempt + ) + upload_to_rockset( + collection="torch_dynamo_perf_stats", + docs=perf_stats, + workspace="inductor", + ) diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index e175b7edf365..f5bbf5f9965b 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -106,12 +106,18 @@ def download_gha_artifacts( return paths -def upload_to_rockset(collection: str, docs: List[Any]) -> None: +def upload_to_rockset( + collection: str, docs: List[Any], workspace: str = "commons" +) -> None: print(f"Writing {len(docs)} documents to Rockset") client = rockset.RocksetClient( host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"] ) - client.Documents.add_documents(collection=collection, data=docs) + client.Documents.add_documents( + collection=collection, + data=docs, + workspace=workspace, + ) print("Done!") diff --git a/tools/test/test_executorch_gen.py b/tools/test/test_executorch_gen.py index 25bd01973475..5d10fe43f1e5 100644 --- a/tools/test/test_executorch_gen.py +++ b/tools/test/test_executorch_gen.py @@ -203,3 +203,27 @@ def test_operators_with_different_namespaces_are_grouped_correctly(self) -> None """ in declarations ) + + def test_aten_lib_has_context_arg(self) -> None: + declarations = gen_functions_declarations( + native_functions=[ + self.custom_1_native_function, + ], + static_dispatch_idx=self.static_dispatch_idx, + selector=SelectiveBuilder.get_nop_selector(), + use_aten_lib=True, + ) + print(declarations) + self.assertTrue( + """ +namespace custom_1 { + +// custom_1::op_1() -> bool +TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) { + return at::op_1(); +} + +} // namespace custom_1 + """ + in declarations + ) diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 462cf1cfa995..f3e064ad7d8d 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,6 +1,11 @@ import types -from 
torch._dynamo.types import DynamoCallback, DynamoGuardHook +from torch._dynamo.types import ( + DynamoCallback, + DynamoGuardHook, + ProfilerEndHook, + ProfilerStartHook, +) def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def reset_code(code: types.CodeType) -> None: ... @@ -8,3 +13,5 @@ def unsupported(obj1: object, obj2: object) -> object: ... def skip_code(code: types.CodeType) -> None: ... def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ... def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_profiler_hooks(start: ProfilerStartHook, end: ProfilerEndHook) -> None: ... +def clear_profiler_hooks() -> None: ... diff --git a/torch/__init__.py b/torch/__init__.py index 524785a4fa57..798f6dc1076d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -246,7 +246,7 @@ def __init__(self, node): self.node = node def __bool__(self): - return self.node.bool_() + return builtins.bool(self != 0) def __int__(self): return self.node.int_() @@ -291,8 +291,6 @@ class SymFloat: """ def __init__(self, node): - from torch.fx.experimental.symbolic_shapes import SymNode - assert isinstance(node, SymNode) # This field MUST be named node; C++ binding code assumes that this # class has a field named node that stores SymNode self.node = node @@ -340,8 +338,6 @@ class SymBool: """ def __init__(self, node): - from torch.fx.experimental.symbolic_shapes import SymNode - assert isinstance(node, SymNode) # This field MUST be named node; C++ binding code assumes that this # class has a field named node that stores SymNode self.node = node @@ -349,6 +345,9 @@ def __init__(self, node): def __bool__(self): return self.node.bool_() + def __int__(self): + return builtins.int(self.node.bool_()) + # Magic methods installed by torch.fx.experimental.symbolic_shapes def __and__(self, other) -> "SymBool": raise AssertionError("type stub not overridden") diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index bb801139d918..0e26757200d0 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -182,6 +182,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]: aten.addcmul, aten.addcmul_, aten.addr, + aten.aminmax, aten.avg_pool2d_backward, aten.binary_cross_entropy, aten.binary_cross_entropy_backward, @@ -254,13 +255,14 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]: aten.mse_loss_backward, aten.mv, aten.mvlgamma, + aten.nansum, aten.nan_to_num, aten.narrow, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, - aten._native_batch_norm_legit_no_training, aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, aten.native_dropout_backward, aten.native_group_norm, aten.native_group_norm_backward, @@ -309,8 +311,10 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]: aten.trace, aten.transpose.int, aten.tril.default, + aten.triu.default, aten.unfold, aten.unfold_backward, + aten.unfold_copy, aten.upsample_bilinear2d, aten.upsample_bilinear2d.vec, aten.upsample_nearest2d_backward, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index f9a8d132477e..26043e78ce97 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -3332,6 +3332,30 @@ def upsample_bicubic2d_vec( return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) +@register_decomposition(aten.aminmax) +@out_wrapper("min", "max") +def aminmax(self, *, dim=None, keepdim=False): + amin = torch.amin(self, dim=dim, 
keepdim=keepdim) + amax = torch.amax(self, dim=dim, keepdim=keepdim) + if ( + keepdim + and dim is not None + and self.ndimension() == 0 + and self.device.type == "cpu" + ): + # the behavior of aminmax differs from amin/amax for 0D tensors on CPU + # https://github.com/pytorch/pytorch/issues/96042 + amin = amin.expand([1]) + amax = amax.expand([1]) + return amin, amax + + +@register_decomposition(aten.nansum) +@out_wrapper() +def nansum(self, dim=None, keepdim=False, *, dtype=None): + return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype) + + def register_inplace(aten_op, outplace_op): @register_decomposition(aten_op) def inplace_op(*args, **kwargs): diff --git a/torch/_dynamo/backends/onnxrt.py b/torch/_dynamo/backends/onnxrt.py index cd10d2610538..df0a0ef114d5 100644 --- a/torch/_dynamo/backends/onnxrt.py +++ b/torch/_dynamo/backends/onnxrt.py @@ -1,4 +1,5 @@ import importlib +import logging import os import tempfile @@ -25,6 +26,9 @@ _np_dtype = None +log = logging.getLogger(__name__) + + def default_provider(device_type): if "ONNXRT_PROVIDER" in os.environ: return os.environ["ONNXRT_PROVIDER"] @@ -78,8 +82,14 @@ def onnxrt(gm, example_inputs, *, filename=None, provider=None): def _call(*initial_args): binding = session.io_binding() + active_inputs = {inp.name for inp in session.get_inputs()} args = [a.contiguous() for a in initial_args] for name, value in zip(input_names, args): + if name not in active_inputs: + log.warning( + f"input {name} skipped as not found in onnx inference session" + ) + continue dev = value.device binding.bind_input( name, diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index e63a62a75905..61aa8e65ef8c 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -124,12 +124,20 @@ def to_tvm_tensor(torch_tensor): def exec_tvm(*i_args): args = [a.contiguous() for a in i_args] + shape_info, _ = m.get_input_info() + active_inputs = {name for name, _ in shape_info.items()} for idx, arg in enumerate(args, 0): if arg.dim() != 0: if arg.requires_grad: arg = arg.detach() + inp_name = f"inp_{idx}" + if inp_name not in active_inputs: + log.warning( + f"input {inp_name} skipped as not found in tvm's runtime library" + ) + continue m.set_input( - f"inp_{idx}", + inp_name, to_tvm_tensor(arg), ) m.run() diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 582983709a96..2d2919052152 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -369,9 +369,6 @@ def load_import_from(self, module_name, object_name): ) ) - def create_begin_finally(self): - return create_instruction("BEGIN_FINALLY") - def create_call_function_kw(self, nargs, kw_names, push_null): if sys.version_info >= (3, 11): output = create_call_function(nargs, push_null) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b393a893b87f..d0896c7d5208 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -4,6 +4,7 @@ from os.path import abspath, dirname import torch + from . import external_utils from .logging import get_loggers_level, set_loggers_level @@ -24,7 +25,7 @@ log_file_name = None # Verbose will print full stack traces on warnings and errors -verbose = False +verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1" # If true, traced graph outputs will be outputted as Python GraphModule code. # If false, traced graph outputs will be outputted in tabular form. 
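As a quick illustration of the aminmax and nansum decompositions registered just above in torch/_decomp/decompositions.py: both reduce to existing primitives, and their behavior can be checked against eager mode with a standalone sketch like the following (not part of this patch).

import torch

# nansum decomposes to "zero out the NaNs, then sum", matching eager torch.nansum.
x = torch.tensor([1.0, float("nan"), 2.0])
decomposed = torch.sum(torch.where(torch.isnan(x), torch.zeros_like(x), x))
assert torch.equal(decomposed, torch.nansum(x))  # both are tensor(3.)

# aminmax decomposes to an (amin, amax) pair.
y = torch.arange(6.0).reshape(2, 3)
assert torch.equal(torch.amin(y, dim=1), torch.aminmax(y, dim=1).min)
assert torch.equal(torch.amax(y, dim=1), torch.aminmax(y, dim=1).max)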
@@ -200,12 +201,21 @@ # root folder of the project base_dir = dirname(dirname(dirname(abspath(__file__)))) +# If True, record autograd profiler events for dynamo cache lookups (guards) +# TODO can we default this to True? +# and how can we cause registration/deregestration to be sensitive to runtime change of this flag? +profile_cache_lookup = False + def is_fbcode(): return not hasattr(torch.version, "git_version") -if is_fbcode(): +DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" + +if DEBUG_DIR_VAR_NAME in os.environ: + debug_dir_root = os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug") +elif is_fbcode(): debug_dir_root = os.path.join(tempfile.gettempdir(), "torch_compile_debug") else: debug_dir_root = os.path.join(os.getcwd(), "torch_compile_debug") diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 76ee5bb34590..680b44d1ed5a 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -7,6 +7,7 @@ from typing import Dict, Optional, Set import torch +from torch._guards import tracing from torch.fx.graph_module import _forward_from_src as original_forward_from_src from . import config, exc @@ -308,7 +309,8 @@ def transform(instructions, code_options): export, mutated_closure_cell_contents, ) - tracer.run() + with tracing(tracer.output.tracing_context): + tracer.run() output = tracer.output assert output is not None assert output.output_instructions diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index a311178b30bc..b597bd27a9d1 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -39,9 +39,12 @@ extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps]) +BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"] + + class BuckTargetWriter: def __init__(self, filename): - self.subdir, self.py_file = os.path.split(filename) + self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) self.target = self.py_file.replace(".py", "") # Get main_module path from fbcode @@ -82,12 +85,12 @@ def write(self, print_msg=True): with open(target_file, "w") as fd: fd.write(self.build()) # log.warning(f"Wrote isolation TARGETS file at {target_file}") - cmd = ["buck2", "run", "@mode/dev-nosan", self.cmd_line_path] + cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path] if print_msg: log.warning( - f'Found an example that reproduces the error. Run this cmd to repro - {" ".join(cmd)}' + f"Found an example that reproduces the error. 
Run this cmd to repro - {' '.join(cmd_split)}" ) - return cmd + return cmd_split def minifier_dir(): @@ -1004,7 +1007,17 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): def debug_wrapper(gm, example_inputs, **kwargs): compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) assert config.repro_after in ("dynamo", "aot", None) + if config.repro_after == "dynamo": + + def add_paths(exc): + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") + if use_buck: + exc.buck_command = " ".join( + BUCK_CMD_PREFIX + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] + ) + if config.repro_level == 3: dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) @@ -1022,9 +1035,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): compiler_name, ) exc = AccuracyError("Bad accuracy detected.") - exc.minifier_path = os.path.join( - minifier_dir(), "minifier_launcher.py" - ) + add_paths(exc) raise exc else: try: @@ -1047,9 +1058,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - exc.minifier_path = os.path.join( - minifier_dir(), "minifier_launcher.py" - ) + add_paths(exc) raise else: compiled_gm = compiler_fn(gm, example_inputs) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index d38cfb902415..546585325754 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import functools import inspect @@ -10,11 +12,13 @@ import types import warnings from enum import Enum -from typing import Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union from unittest.mock import patch import torch +import torch.fx import torch.utils._pytree as pytree +from torch import _guards from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.nn.parallel.distributed import DistributedDataParallel @@ -24,10 +28,12 @@ if TYPE_CHECKING: from torch._C._dynamo.eval_frame import ( # noqa: F401 + clear_profiler_hooks, reset_code, set_eval_frame, set_guard_error_hook, set_guard_fail_hook, + set_profiler_hooks, skip_code, unsupported, ) @@ -56,6 +62,23 @@ class Unset(Enum): token = 0 +def enable_cache_lookup_profiler(enable: bool): + if enable: + + def _profiler_start(name): + return torch.ops.profiler._record_function_enter_new(name, None) + + def _profiler_end(record): + torch.ops.profiler._record_function_exit._RecordFunction(record) + + set_profiler_hooks(_profiler_start, _profiler_end) + else: + clear_profiler_hooks() + + +# TODO can we enable by default? (check perf CI) otherwise, guard behind config +enable_cache_lookup_profiler(config.profile_cache_lookup) + unset = Unset.token compile_lock = threading.RLock() @@ -559,8 +582,15 @@ def guard_export_print(guards): def export( - f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs -): + f: Callable[..., Any], + *args, + aten_graph: bool = False, + decomposition_table: Optional[ + Dict[torch._ops.OpOverload, Callable[..., Any]] + ] = None, + tracing_mode: str = "real", + **kwargs, +) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]: """ Export an input function f to a format that can be executed outside of PyTorch using the FX graph. 
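The new type annotations above pin down what torch._dynamo.export returns: a (GraphModule, guards) pair. A rough usage sketch under that signature (the example function and inputs here are made up for illustration):

import torch
import torch._dynamo as dynamo

def f(x, y):
    return (x + y).relu()

# export() traces f with the given example inputs and returns the captured
# FX graph plus the guards accumulated while tracing.
gm, guards = dynamo.export(f, torch.randn(4), torch.randn(4))
print(gm.code)      # Python source of the captured torch.fx.GraphModule
print(len(guards))  # Set[torch._guards.Guard]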
@@ -635,7 +665,7 @@ def produce_matching(source_args, candidate_args): return matched_elements_positions - def guard_export_print(guards): + def guard_export_print(guards: Set[_guards.Guard]): nonlocal out_guards assert out_guards is None, "whole graph export entails exactly one guard export" out_guards = guards @@ -644,7 +674,6 @@ def dynamo_normalization_capturing_compiler( gm: torch.fx.GraphModule, example_inputs ): nonlocal graph - assert ( graph is None ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." @@ -744,10 +773,50 @@ def graph_with_interpreter(*args): ).transform() # Make dynamo graph to have same input/output spec as user code - input_strs = [f"orig_arg_{i}" for i in range(len(args))] + list(kwargs.keys()) + def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]: + call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f + fullargspec = inspect.getfullargspec(call_to_inspect) + if inspect.ismethod(call_to_inspect): + fullargspec.args.pop(0) + + # 1. Map `args` 1-to-1 to positional arguments in original signature. + input_strs = fullargspec.args[: len(args)] + + if len(args) > len(fullargspec.args): + # 2. If there are more arguments left in `args`, they map to varargs in original + # signature. Assign names as {varargs}_0, {varargs}_1, ... + assert fullargspec.varargs is not None, "More arguments than expected" + input_strs += [ + f"{fullargspec.varargs}_{i}" + for i in range(0, len(args) - len(input_strs)) + ] + elif len(args) < len(fullargspec.args): + # 3. If there are fewer arguments in `args` than `fullargspec.args`, + # it implies these are arguments either with default values, or provided in + # `kwargs`. The former can be safely ignored. Because Dynamo.export does not + # export them as part of the function signature. The latter will be handled + # in the next step. + for unprovided_arg in fullargspec.args[ + len(args) : -len(fullargspec.defaults or []) + ]: + assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}" + + # 4. Keyword arguments provided in `kwargs`. + input_strs += list(kwargs.keys()) + + # 5. Keyword-only arguments with default values if not provided are not exported + # as part of the function signature. + for kwonly_arg in fullargspec.kwonlyargs: + kwonlydefaults = fullargspec.kwonlydefaults or {} + assert ( + kwonly_arg in kwargs or kwonly_arg in kwonlydefaults + ), f"Missing keyword only argument {kwonly_arg}" + + return input_strs + new_graph.graph._codegen = _PyTreeCodeGen( _PyTreeInfo( - input_strs, + argument_names(f, *args, **kwargs), in_spec, out_spec_traced, ) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 56c867e7acb0..01dd2940cf44 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -44,7 +44,7 @@ class BackendCompilerFailed(TorchDynamoException): def __init__(self, backend_fn, inner_exception): self.backend_name = getattr(backend_fn, "__name__", "?") self.inner_exception = inner_exception - msg = f"{self.backend_name} raised {type(inner_exception).__name__}: {inner_exception}" + msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}" super().__init__(msg) @@ -103,16 +103,23 @@ def augment_exc_message(exc, msg="\n"): msg += f"\nLast frame execution written to {exc.record_filename}. 
To run only this frame while debugging, run\ torch._dynamo.replay('{exc.record_filename}').\n" - if not config.verbose: - msg += "\nSet torch._dynamo.config.verbose=True for more information\n" + if not config.verbose and hasattr(exc, "real_stack"): + msg += "\nSet torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information\n" if hasattr(exc, "inner_exception") and hasattr( exc.inner_exception, "minifier_path" ): - msg += ( - f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " - "this script to find the smallest traced graph which reproduces this error.\n" - ) + if hasattr(exc.inner_exception, "buck_command"): + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + f"this buck command to find the smallest traced graph " + f"which reproduces this error: {exc.inner_exception.buck_command}\n" + ) + else: + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + "this script to find the smallest traced graph which reproduces this error.\n" + ) if not config.suppress_errors: msg += ( diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 7a1dd8579166..47a032cedfd9 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -11,13 +11,7 @@ import torch.nn from torch import fx -from torch._guards import ( - Checkpointable, - Guard, - GuardsCheckpointState, - tracing, - TracingContext, -) +from torch._guards import Checkpointable, Guard, GuardsCheckpointState, TracingContext from torch.fx.experimental.symbolic_shapes import ShapeEnv from . import config, logging as torchdynamo_logging, variables @@ -356,17 +350,13 @@ def new_var(self, name="tmp"): for i in itertools.count(): var = f"___{name}_{i}" if var not in existing: - self.code_options["co_varnames"] = self.code_options["co_varnames"] + ( - var, - ) + self.code_options["co_varnames"] += (var,) return var def update_co_names(self, name): """Ensure self.code_options.co_names contains name""" if name not in self.code_options["co_names"]: - self.code_options["co_names"] = tuple(self.code_options["co_names"]) + ( - name, - ) + self.code_options["co_names"] += (name,) @staticmethod def module_has_hooks(mod, only_check_unsupported=False): @@ -622,8 +612,7 @@ def compile_and_call_fx_graph(self, tx, rv, root): name = unique_id("__compiled_fn") assert_no_fake_params_or_buffers(gm) - with tracing(self.tracing_context): - compiled_fn = self.call_user_compiler(gm) + compiled_fn = self.call_user_compiler(gm) compiled_fn = disable(compiled_fn) counters["stats"]["unique_graphs"] += 1 @@ -709,8 +698,9 @@ def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: _step_logger()(logging.INFO, f"done compiler function {name}") assert callable(compiled_fn), "compiler_fn did not return callable" except Exception as e: - compiled_fn = gm.forward - raise BackendCompilerFailed(self.compiler_fn, e) from e + raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( + e.__traceback__ + ) from None return compiled_fn def fake_example_inputs(self) -> List[torch.Tensor]: @@ -813,6 +803,9 @@ def create_proxy( if kind in {"call_function", "call_method"}: rv.node.meta["source_fn"] = target + elif kind == "call_module": + # For modules we store the class + rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1] frame_summaries: List[traceback.FrameSummary] = [] while tx: diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index a4d06b81c9f5..526ce28272f2 100644 --- 
a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -10,6 +10,7 @@ create_jump_absolute, Instruction, transform_code_object, + unique_id, ) from .codegen import PyCodegen from .utils import ExactWeakKeyDictionary @@ -32,6 +33,79 @@ class ReenterWith: stack_index: int = None target_values: Optional[Tuple] = None + # If we do not want to destroy the stack, we can do the same thing as a + # `SETUP_WITH` block, only that we store the context manager in a local_symbol + def try_except(self, code_options, cleanup: List[Instruction]): + load_args = [] + if self.target_values: + load_args = [ + create_instruction( + "LOAD_CONST", + PyCodegen.get_const_index(code_options, val), + val, + ) + for val in self.target_values + ] + ctx_name = unique_id(f"___context_manager_{self.stack_index}") + if ctx_name not in code_options["co_varnames"]: + code_options["co_varnames"] += (ctx_name,) + for name in ["__enter__", "__exit__"]: + if name not in code_options["co_names"]: + code_options["co_names"] += (name,) + + except_jump_target = create_instruction("NOP") + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally = [ + *load_args, + create_instruction("CALL_FUNCTION", len(load_args)), + create_instruction( + "STORE_FAST", code_options["co_varnames"].index(ctx_name), ctx_name + ), + create_instruction( + "LOAD_FAST", code_options["co_varnames"].index(ctx_name), ctx_name + ), + create_instruction("LOAD_METHOD", "__enter__"), + create_instruction("CALL_METHOD", 0), + create_instruction("POP_TOP"), + create_instruction("SETUP_FINALLY", target=except_jump_target), + ] + + reset = [ + create_instruction( + "LOAD_FAST", code_options["co_varnames"].index(ctx_name), ctx_name + ), + create_instruction("LOAD_METHOD", "__exit__"), + create_instruction( + "LOAD_CONST", PyCodegen.get_const_index(code_options, None), None + ), + create_instruction("DUP_TOP"), + create_instruction("DUP_TOP"), + create_instruction("CALL_METHOD", 3), + create_instruction("POP_TOP"), + ] + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("BEGIN_FINALLY"), + except_jump_target, + *reset, + create_instruction("END_FINALLY"), + ] + else: + epilogue = [ + create_instruction("POP_BLOCK"), + *reset, + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *reset, + create_instruction("RERAISE"), + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + def __call__(self, code_options, cleanup): load_args = [] if self.target_values: @@ -95,7 +169,6 @@ def __call__(self, code_options, cleanup): create_instruction("SETUP_WITH", target=with_except_start), create_instruction("POP_TOP"), ] - else: pop_top_after_with_except_start = create_instruction("POP_TOP") cleanup_complete_jump_target = create_instruction("NOP") @@ -193,7 +266,7 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]): freevars = tuple(code_options["co_cellvars"] or []) + tuple( code_options["co_freevars"] or [] ) - code_options["co_name"] = f"" + code_options["co_name"] = f"" if sys.version_info >= (3, 11): code_options[ "co_qualname" @@ -250,13 +323,11 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]): @staticmethod def unreachable_codes(code_options): """Codegen a `raise None` to make analysis work for unreachable code""" - if None not in code_options["co_consts"]: - code_options["co_consts"] = tuple(code_options["co_consts"]) + (None,) return [ 
create_instruction( "LOAD_CONST", argval=None, - arg=code_options["co_consts"].index(None), + arg=PyCodegen.get_const_index(code_options, None), ), create_instruction("RAISE_VARARGS", 1), ] diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 036da4bbe741..baa56c467f3a 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -136,7 +136,7 @@ def guard_source(self): return self.base.guard_source() def name(self): - if self.member.isnumeric(): + if not self.member.isidentifier(): return f"getattr({self.base.name()}, {self.member!r})" return f"{self.base.name()}.{self.member}" diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 19f198237fd2..1c60c6f00e5f 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -17,7 +17,7 @@ from unittest.mock import patch import torch -from torch._guards import Checkpointable +from torch._guards import Checkpointable, TracingContext from . import ( allowed_functions, @@ -76,7 +76,6 @@ ClosureVariable, ContextWrappingVariable, GetAttrVariable, - GradModeVariable, NullVariable, PythonModuleVariable, UnknownVariable, @@ -268,7 +267,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): msg = ( "Skipping frame because there is a graph break in a for/while loop" ) - log.debug(msg) + log.info(msg) raise exc.SkipFrame(msg) self.push(value) @@ -286,9 +285,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if_jump = self.create_call_resume_at(inst.target) self.output.add_output_instructions( - [(create_instruction(inst.opname, target=if_jump[0]))] - + if_next - + if_jump + [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump ) elif isinstance(value, NNModuleVariable): # Equivalent of "self.nn_module is not None" @@ -350,7 +347,7 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): except Unsupported as excp: if self.has_backedge() and self.should_compile_partial_graph(): msg = "Skipping frame because there is a graph break in a for/while loop" - log.debug(msg) + log.info(msg) raise exc.SkipFrame(msg) from excp if not self.should_compile_partial_graph(): @@ -376,35 +373,6 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): excp.add_to_stats("graph_break") reason = GraphCompileReason(excp.msg, user_stack) self.restore_graphstate(state) - self.output.compile_subgraph(self, reason=reason) - if sys.version_info >= (3, 11) and inst.opname == "CALL": - # stack effect for PRECALL + CALL is split between the two instructions - stack_effect = dis.stack_effect( - dis.opmap["PRECALL"], inst.arg - ) + dis.stack_effect(dis.opmap["CALL"], inst.arg) - else: - stack_effect = dis.stack_effect(inst.opcode, inst.arg) - self.popn(push - stack_effect) - - for _ in range(push): - self.push(UnknownVariable()) - - resume_call_insts = self.create_call_resume_at(self.next_instruction) - # Check if there is a block stack entry with GradModeVariable. And - # wrap the instruction causing the graph break inside a try..finally - # block. 
See more details at - # https://github.com/pytorch/torchdynamo/issues/207 - cleanup = [] - if len(self.block_stack) == 1 and isinstance( - self.block_stack[0].with_context, GradModeVariable - ): - ctx_variable = self.block_stack[0].with_context - - cg = PyCodegen(self) - setup_finally, cleanup = ctx_variable.reconstruct( - cg, resume_call_insts[0] - ) - self.output.add_output_instructions(setup_finally) if sys.version_info >= (3, 11) and inst.opname == "CALL": kw_names = self.kw_names.value if self.kw_names is not None else () @@ -417,17 +385,26 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): ), ] ) + self.output.compile_subgraph(self, reason=reason) + cg = PyCodegen(self) + cleanup: List[Instruction] = [] + # Reconstruct the context variables in the block stack + for b in self.block_stack: self.output.add_output_instructions( - create_call_function(inst.arg, False) + [ + *b.with_context.reconstruct(cg), + *b.resume_fn().try_except(cg.code_options, cleanup), + ] ) - # no need to reset self.kw_names since self should not continue to run - else: - self.output.add_output_instructions([inst]) - - # Add the cleanup instructions from try..finally block + self.output.add_output_instructions([inst]) self.output.add_output_instructions(cleanup) + + self.popn(push - dis.stack_effect(inst.opcode, inst.arg)) + + for _ in range(push): + self.push(UnknownVariable()) self.output.add_output_instructions( - resume_call_insts, + self.create_call_resume_at(self.next_instruction) ) return wrapper @@ -611,29 +588,30 @@ def step(self): ) def run(self): - try: - self.output.push_tx(self) - while ( - self.instruction_pointer is not None - and not self.output.should_exit - and self.step() - ): - pass - except BackendCompilerFailed: - raise - except Exception as e: - if config.replay_record_enabled: - e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] - raise - finally: - self.output.pop_tx() - # Cleanup the outputGraph to delete the held tensors. We perform the - # cleanup only for InstructionTranslator and not - # InliningInstructionTranslator. The InliningInstructionTranslator - # mutates the output object and is restored to original state if - # there was an exception. - if isinstance(self, InstructionTranslator): - self.output.cleanup() + with TracingContext.current_frame(self.frame_summary()): + try: + self.output.push_tx(self) + while ( + self.instruction_pointer is not None + and not self.output.should_exit + and self.step() + ): + pass + except BackendCompilerFailed: + raise + except Exception as e: + if config.replay_record_enabled: + e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] + raise + finally: + self.output.pop_tx() + # Cleanup the outputGraph to delete the held tensors. We perform the + # cleanup only for InstructionTranslator and not + # InliningInstructionTranslator. The InliningInstructionTranslator + # mutates the output object and is restored to original state if + # there was an exception. 
+ if isinstance(self, InstructionTranslator): + self.output.cleanup() def push(self, val: Optional[VariableTracker]): assert val is None or isinstance( @@ -1955,11 +1933,7 @@ def inline_call(cls, parent, func, args, kwargs): return cls.inline_call_(parent, func, args, kwargs) @staticmethod - def inline_call_(parent, func, args, kwargs): - assert isinstance( - func, - (UserFunctionVariable, NestedUserFunctionVariable), - ) + def check_inlineable(func): if func.has_self(): unimplemented("inline with __self__") @@ -1979,6 +1953,15 @@ def inline_call_(parent, func, args, kwargs): f"inline in skipfiles: {func.fn.__qualname__} | {func.get_name()} {func.get_filename()}" ) + @staticmethod + def inline_call_( + parent, func: VariableTracker, args: List[VariableTracker], kwargs + ): + assert isinstance( + func, + (UserFunctionVariable, NestedUserFunctionVariable), + ) + InliningInstructionTranslator.check_inlineable(func) try: sub_locals, closure_cells = func.bind_args(parent, args, kwargs) except TypeError as e: @@ -1995,7 +1978,10 @@ def inline_call_(parent, func, args, kwargs): if code.co_name in ("__setitem__", "__setattr__"): unimplemented(f"inline {code.co_name}") - log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n") + suffix = "" + if config.output_code: + suffix = f"\n{dis.Bytecode(code).dis()}" + log.debug(f"INLINING {code}{suffix}") tracer: InliningInstructionTranslator if is_generator(code): diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py index 4ef9af8625ea..bc15f75a5f3b 100644 --- a/torch/_dynamo/types.py +++ b/torch/_dynamo/types.py @@ -2,6 +2,7 @@ import sys import types from typing import ( + Any, Callable, Dict, List, @@ -68,3 +69,17 @@ def __call__( last: bool, ) -> None: ... + + +class ProfilerStartHook(Protocol): + def __call__( + self, + name: str, + # TODO(whc) how do I annotate a _RecordFunction here? + ) -> Any: + ... + + +class ProfilerEndHook(Protocol): + def __call__(self, record: Any) -> None: + ... 
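ProfilerStartHook and ProfilerEndHook above are typing.Protocol classes, so any callable with a matching signature satisfies them structurally; no subclassing is required. A self-contained sketch of that pattern (the hook bodies here are placeholders, not the real profiler hooks):

from typing import Any, Protocol

class StartHook(Protocol):
    def __call__(self, name: str) -> Any: ...

class EndHook(Protocol):
    def __call__(self, record: Any) -> None: ...

def my_start(name: str) -> Any:
    # placeholder: a real hook would open a profiler record here
    return {"name": name}

def my_end(record: Any) -> None:
    # placeholder: a real hook would close that record
    print("closed", record["name"])

start: StartHook = my_start  # accepted purely by signature (structural typing)
end: EndHook = my_end
end(start("cache_lookup"))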
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 3205bae8bf08..8fd70d3aa04f 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -835,7 +835,7 @@ def wrap_unspecialized_primitive(self, value): GraphArg( self.get_source(), wrapped_value, - True, + isinstance(wrapped_value, torch.Tensor), fake_tensor_value, is_tensor=False, ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 25a9eae8989d..6813a9a5faf9 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -770,11 +770,15 @@ def is_supported_call_dict_arg(tx, arg): ) @staticmethod - def call_dict_helper(tx, user_cls, arg): + def call_dict_helper(tx, user_cls, arg, **options): if arg is None: - return ConstDictVariable({}, user_cls, mutable_local=MutableLocal()) + return ConstDictVariable( + {}, user_cls, mutable_local=MutableLocal() + ).add_options(options) elif isinstance(arg, variables.ConstDictVariable): - return arg.clone(user_cls=user_cls, mutable_local=MutableLocal()) + return arg.clone( + user_cls=user_cls, mutable_local=MutableLocal() + ).add_options(options) elif isinstance( arg, ( @@ -788,7 +792,9 @@ def call_dict_helper(tx, user_cls, arg): k = x.unpack_var_sequence(tx)[0].as_python_constant() v = x.unpack_var_sequence(tx)[1] items.update({k: v}) - return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) + return ConstDictVariable( + items, user_cls, mutable_local=MutableLocal() + ).add_options(options) else: raise AssertionError("call_dict_helper with illegal arg") diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index dd4a41f3078f..584b275a3dd9 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1,6 +1,6 @@ import collections +import functools import inspect -import sys import types from typing import Dict, List @@ -12,8 +12,8 @@ from ..exc import unimplemented from ..guards import GuardBuilder from ..source import AttrSource -from ..utils import identity, proxy_args_kwargs -from .base import VariableTracker +from ..utils import check_constant_args, identity, proxy_args_kwargs +from .base import MutableLocal, VariableTracker from .functions import ( NestedUserFunctionVariable, UserFunctionVariable, @@ -181,122 +181,11 @@ def exit(self, tx, *args): self._call_func(tx, self.initial_values) return variables.ConstantVariable(None, **VariableTracker.propagate(self)) - def reconstruct(self, codegen, target_inst=None): - """ - Generate following Python Bytecode, with a `torch._C._set_grad_enable` call - Python 3.8 - 0 LOAD_GLOBAL 0 (torch) - 2 LOAD_ATTR 1 (_C) - 4 LOAD_METHOD 2 (_set_grad_enable) - 6 LOAD_CONST 1 (False) - 8 CALL_METHOD 1 - 10 POP_TOP - - 12 SETUP_FINALLY 10 (to 24) - - 14 LOAD_GLOBAL 3 (user_inst) - 16 CALL_FUNCTION 0 - 18 POP_TOP - 20 POP_BLOCK - 22 BEGIN_FINALLY - - 24 LOAD_GLOBAL 0 (torch) - 26 LOAD_ATTR 1 (_C) - 28 LOAD_METHOD 2 (_set_grad_enable) - 30 LOAD_CONST 2 (True) - 32 CALL_METHOD 1 - 34 POP_TOP - 36 END_FINALLY - 38 LOAD_CONST 0 (None) - 40 RETURN_VALUE - - Instructions 0-10 and 24-34 call torch._C.set_grad_enable(True/False) - - Python 3.9, 3.10 - 0 LOAD_GLOBAL 0 (torch) - 2 LOAD_ATTR 1 (_C) - 4 LOAD_METHOD 2 (_set_grad_enable) - 6 LOAD_CONST 1 (False) - 8 CALL_METHOD 1 - 10 POP_TOP - - 12 SETUP_FINALLY 22 (to 36) - - 14 LOAD_GLOBAL 3 (user_inst) - 16 CALL_FUNCTION 0 - 18 POP_TOP - 20 POP_BLOCK - - 22 LOAD_GLOBAL 0 (torch) - 24 LOAD_ATTR 1 (_C) - 26 LOAD_METHOD 2 
(_set_grad_enable) - 28 LOAD_CONST 2 (True) - 30 CALL_METHOD 1 - 32 POP_TOP - - 34 JUMP_FORWARD 14 (to 50) - - 36 LOAD_GLOBAL 0 (torch) - 38 LOAD_ATTR 1 (_C) - 40 LOAD_METHOD 2 (_set_grad_enable) - 42 LOAD_CONST 2 (True) - 44 CALL_METHOD 1 - 46 POP_TOP - 48 RERAISE - - 50 LOAD_CONST 0 (None) - 52 RETURN_VALUE - - """ - if self.target_values == self.initial_values: - return ([], []) - - def set_context_insts(values): - attr_source = AttrSource( - codegen.tx.import_source(self.module_name()), self.fn_name() - ) - load_set_context_enabling_insts = attr_source.reconstruct(codegen) - - if values: - loads = [codegen.create_load_const(val) for val in values] - else: - loads = [] - - return [ - *load_set_context_enabling_insts, - *loads, - *create_call_function(len(loads), True), - create_instruction("POP_TOP"), - ] - - init_block = set_context_insts(self.target_values) - finally_block = set_context_insts(self.initial_values) - setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0]) - prologue = init_block + [setup_final_inst] - - # Generate the epilogue - starts with 20 POP_BLOCK and ends at 34 POP_TOP - if sys.version_info < (3, 9): - # Generate the prologue that ends with setup_finally - epilogue = [ - create_instruction("POP_BLOCK"), - codegen.create_begin_finally(), - *finally_block, - create_instruction("END_FINALLY"), - ] - else: - except_block = set_context_insts(self.initial_values) - epilogue = [ - create_instruction("POP_BLOCK"), - *except_block, - create_instruction("JUMP_FORWARD", target=target_inst), - *finally_block, - create_instruction("RERAISE"), - ] - - return (prologue, epilogue) - - def _call_func(self, tx, initial_values): - raise NotImplementedError("_call_func called on base") + def reconstruct(self, codegen): + attr_source = AttrSource( + codegen.tx.import_source(self.module_name()), self.fn_name() + ) + return attr_source.reconstruct(codegen) def module_name(self): raise NotImplementedError("module_name called on base") @@ -797,11 +686,20 @@ def python_type(self): def as_python_constant(self): return self.value + @staticmethod + @functools.lru_cache(None) + def fold_through_function_to_wrapper(): + return { + collections.namedtuple: variables.UserDefinedClassVariable, + } + def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": from .builtin import BuiltinVariable + options = VariableTracker.propagate(self, args, kwargs.values()) + if inspect.getattr_static(self.value, "_torchdynamo_disable", False): unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}") # Allowlist a few popular classes(e.g, collections.OrderedDict) calls in skip files. 
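The fold_through_function_to_wrapper mapping added above (and used in the call_function hunk that follows) targets user code that builds a namedtuple from purely constant arguments inside a compiled region. A hypothetical repro-style snippet of that pattern (not taken from this PR's tests; whether it compiles without a graph break depends on the rest of this change):

import collections
import torch

def fn(x):
    # namedtuple called with constant args only, so it can be folded at trace time
    Point = collections.namedtuple("Point", ["a", "b"])
    p = Point(1, 2)
    return x * p.a + p.b

out = torch.compile(fn)(torch.ones(3))
print(out)  # tensor([3., 3., 3.])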
@@ -811,7 +709,23 @@ def call_function( and BuiltinVariable.is_supported_call_dict_arg(tx, args[0]) ): return BuiltinVariable.call_dict_helper( - tx, collections.OrderedDict, None if len(args) == 0 else args[0] + tx, + collections.OrderedDict, + None if len(args) == 0 else args[0], + **options, + ) + # Fold through the functions(e.g, collections.namedtuple) + # that inputs & outputs are all python constants + elif ( + self.value in self.fold_through_function_to_wrapper().keys() + and check_constant_args(args, kwargs) + ): + value = self.value( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + return self.fold_through_function_to_wrapper().get(self.value)( + value, mutable_local=MutableLocal(), **options ) else: try: diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index c7bafcff7fec..c3cf408c6522 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -116,19 +116,9 @@ def set_state_proxies(state_args): num_params_buffers=params_len, aot_id=-1, keep_inference_input_mutations=False, + dynamic_shapes=True, ) - @contextlib.contextmanager - def setup_dynamic_shape(): - prev, torch._functorch.config.use_dynamic_shapes = ( - torch._functorch.config.use_dynamic_shapes, - True, - ) - try: - yield - finally: - torch._functorch.config.use_dynamic_shapes = prev - def exported_call(*args): state_args = args[:params_len] unwrapped_state_args = _unwrap_all_tensors_from_functional( @@ -141,7 +131,7 @@ def exported_call(*args): outputs, out_spec = pytree.tree_flatten(outputs) return outputs - with torch.enable_grad(), setup_dynamic_shape(): + with torch.enable_grad(): create_aot_dispatcher_function( exported_call, full_args, diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index a465d4aa7a09..c7390eb98936 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -390,6 +390,8 @@ class OutputAliasInfo: # as a graph output # (6) an alias of an intermediate, where that intermediate is also a user output output_type: OutputType + # The raw type of the output (torch.Tensor, SymInt, etc) + raw_type: type # If (1) above, then # - base_idx is None # If (2) or (3) above, then @@ -712,6 +714,7 @@ def inner(*flat_args): out_info = OutputAliasInfo( output_type=output_type, + raw_type=type(o), base_idx=base_idx, ) output_info.append(out_info) @@ -737,7 +740,7 @@ def inner(*flat_args): f_output_tangents = [ o for o, info in zip(flat_f_outs, output_info) - if info.output_type == OutputType.non_alias + if info.output_type == OutputType.non_alias and issubclass(info.raw_type, torch.Tensor) ] # intermediate bases are also included in the backward graph f_tangents = f_input_tangents + f_output_tangents + intermediate_bases @@ -956,6 +959,9 @@ def forward_or_joint( x for (i, x) in enumerate(outs) if meta.fw_metadata.output_info[i].output_type == OutputType.non_alias + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.fw_metadata.output_info[i].raw_type, torch.Tensor) ] # Pass any (non-aliased) mutated inputs in as tangents, since they'll be returned as outputs in the fw # Important: the traced joint fw/bw will return updated inputs with data mutations, @@ -1272,6 +1278,7 @@ class AOTConfig: num_params_buffers: int aot_id: int keep_inference_input_mutations: bool + dynamic_shapes: bool = False def aot_dispatch_base(flat_fn, flat_args: List[Tensor], 
aot_config: AOTConfig): with enable_python_dispatcher(): @@ -2262,6 +2269,7 @@ def backward(ctx, *flat_args): ) assert len(flat_args) == expected_grad_outs + out_info = CompiledFunction.metadata.fw_metadata.output_info if ( CompiledFunction.metadata.num_mutated_metadata_only_inputs > 0 or CompiledFunction.metadata.num_outputs_aliased > 0 @@ -2283,11 +2291,10 @@ def backward(ctx, *flat_args): if input_info[info_idx].mutates_data ] # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates - out_info = CompiledFunction.metadata.fw_metadata.output_info out_tangents_filtered = [ x for x, info in zip(out_tangents, out_info) - if info.output_type == OutputType.non_alias + if info.output_type == OutputType.non_alias and issubclass(info.raw_type, torch.Tensor) ] # intermediate bases always require gradients, and always participate in the backward graph. flat_bw_args = itertools.chain(inp_tangents_filtered, out_tangents_filtered, intermediate_base_tangents) @@ -2302,7 +2309,13 @@ def backward(ctx, *flat_args): # assert all(x is None for x in metadata_only_inps) # assert all(x is None for x in aliased_outputs) else: - flat_bw_args = flat_args + # filter out non-tensor grad_outputs (aka due to ints being returned as outputs in the forward) + num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs + mutated_inp_args = flat_args[:num_mutated_inps] if num_mutated_inps > 0 else [] + user_tangents = flat_args[num_mutated_inps:] + assert len(user_tangents) == len(out_info) + filtered_user_tangents = [x for x, info in zip(user_tangents, out_info) if issubclass(info.raw_type, torch.Tensor)] + flat_bw_args = tuple(mutated_inp_args) + tuple(filtered_user_tangents) contiguous_args = [ t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args @@ -2315,12 +2328,14 @@ def backward(ctx, *flat_args): def call_compiled_backward(): if CompiledFunction.compiled_bw is None: - if config.use_dynamic_shapes: + if aot_config.dynamic_shapes: all_args_list = list(all_args) CompiledFunction.compiled_bw = create_aot_dispatcher_function( bw_module, all_args_list, AOTConfig( aot_config.bw_compiler, None, None, - aot_config.decompositions, 0, aot_config.aot_id, aot_config.keep_inference_input_mutations + aot_config.decompositions, 0, aot_config.aot_id, + aot_config.keep_inference_input_mutations, + aot_config.dynamic_shapes ) ) else: @@ -2453,7 +2468,7 @@ def create_aot_dispatcher_function( shape_env = fake_mode.shape_env break else: - shape_env = ShapeEnv() if config.use_dynamic_shapes else None + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None fake_mode = ( FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor @@ -2552,7 +2567,10 @@ def aot_function( num_params_buffers: int = 0, hasher_type=None, # deprecated static_argnums: Optional[Tuple[int]] = None, # deprecated - keep_inference_input_mutations: bool = False + keep_inference_input_mutations: bool = False, + *, + # Whether or not to trace with dynamic shapes + dynamic=False, ) -> Callable: """ Traces the forward and backward graph of :attr:`fn` using torch dispatch @@ -2618,7 +2636,8 @@ def aot_function( decompositions=decompositions, num_params_buffers=num_params_buffers, aot_id=next(AOT_COUNTER), - keep_inference_input_mutations=keep_inference_input_mutations + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic, ) cached_res = None @@ -2801,6 +2820,17 @@ def functional_call(*args, **kwargs): assert static_argnums is None if bw_compiler is None: bw_compiler = 
fw_compiler + + full_args = [] + full_args.extend(params_flat) + full_args.extend(args) + + dynamic_shapes = False + for x in full_args: + if isinstance(x, FakeTensor): + dynamic_shapes = x.fake_mode.shape_env is not None + break + aot_config = AOTConfig( fw_compiler=fw_compiler, bw_compiler=bw_compiler, @@ -2809,12 +2839,9 @@ def functional_call(*args, **kwargs): num_params_buffers=params_len, aot_id=next(AOT_COUNTER), keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic_shapes ) - full_args = [] - full_args.extend(params_flat) - full_args.extend(args) - compiled_fn = create_aot_dispatcher_function( functional_call, full_args, diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 40703ba653d7..f4e87dac5fdd 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -31,8 +31,6 @@ # Prints out joint graph traced, before partitioning debug_joint = os.environ.get("AOT_FX_GRAPHS_JOINT", False) -use_dynamic_shapes = os.getenv("AOT_DYNAMIC_SHAPES", False) - static_weight_shapes = True # Applies CSE to the graph before partitioning diff --git a/torch/_guards.py b/torch/_guards.py index 5e2fb89b904e..decf3983c541 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import enum import logging @@ -325,6 +326,20 @@ def get() -> Optional["TracingContext"]: def __init__(self, fake_mode): self.guards_context = GuardsContext() self.fake_mode = fake_mode + self.frame_summary_stack = [] + + @staticmethod + @contextlib.contextmanager + def current_frame(frame_summary): + tc = TracingContext.get() + assert ( + tc is not None + ), "Frame context manager must be called within an ongoing trace." + tc.frame_summary_stack.append(frame_summary) + try: + yield + finally: + tc.frame_summary_stack.pop() """ diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c2316529ffd2..34206b9cfee3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -596,7 +596,12 @@ def load(cls, source_code, extra=""): key, path = write(source_code, "py", extra) if key not in cls.cache: with open(path) as f: - code = compile(f.read(), path, "exec") + try: + code = compile(f.read(), path, "exec") + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) mod = types.ModuleType(f"{__name__}.{key}") mod.__file__ = path mod.key = key diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index f9b929026696..0cf929ed1790 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -69,6 +69,10 @@ def reduction_init(reduction_type, dtype): + if dtype in (torch.float16, torch.bfloat16): + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 if reduction_type in ("sum", "any"): return 0 if reduction_type in {"max", "argmax"}: @@ -139,20 +143,6 @@ def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar): return prefix -def float16_reduction_prefix(rtype): - # TODO: This user-defined reduction uses float16 accumulation for sum. To reduce numerical - # errors, float32 accumulation should be used instead. 
- assert rtype in ( - "sum", - "any", - ), f"float16 user-defined reduction only supports 'sum' and 'any' but got {rtype}" - prefix = [ - f"#pragma omp declare reduction({RTYPE_TO_CPP[rtype]}:{DTYPE_TO_CPP[torch.float16]}:" - + f"omp_out = omp_out {RTYPE_TO_CPP[rtype]} omp_in)" - ] - return prefix - - def parallel_num_threads(): threads = config.cpp.threads if threads < 1: @@ -415,6 +405,14 @@ def acos(x): def asin(x): return f"{x}.asin()" + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + @staticmethod def log10(x): return f"{x}.log10()" @@ -514,7 +512,7 @@ def maximum(a, b): @staticmethod def square(a): - return f"{a}.pow(2)" + return f"{a} * {a}" @staticmethod def where(a, b, c): @@ -712,6 +710,14 @@ def acos(x): def acosh(x): return f"std::acosh({x})" + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + @staticmethod def asin(x): return f"std::asin({x})" @@ -926,13 +932,14 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value): ], ) else: - if dtype == torch.float16: - self.reduction_prefix.writelines( - float16_reduction_prefix(reduction_type) + if dtype in (torch.float16, torch.bfloat16): + self.reduction_prefix.writeline( + f"float {tmpvar} = {reduction_init(reduction_type, dtype)};" + ) + else: + self.reduction_prefix.writeline( + f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};" ) - self.reduction_prefix.writeline( - f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};" - ) self.stores.writeline( None, f"{reduction_combine(reduction_type, tmpvar, value)};" ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 38965930d12d..eefdd4f7a46c 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1224,13 +1224,7 @@ def codegen_kernel_benchmark(self): result.writelines(["\n", "\n", "if __name__ == '__main__':"]) with result.indent(): - result.writeline( - "from torch._C import _cuda_getCurrentRawStream as get_cuda_stream" - ) - result.writeline("from torch._dynamo.testing import rand_strided") result.writeline("from torch._inductor.utils import get_num_bytes") - result.writeline("import torch") - result.writeline("from torch._inductor.triton_ops.autotune import grid") result.writeline("from triton.testing import do_bench") result.writeline("") @@ -1273,6 +1267,15 @@ def codegen_kernel(self, name=None): from torch._inductor.utils import instance_descriptor """ ) + if config.benchmark_kernel: + code.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream + import torch + from torch._inductor.triton_ops.autotune import grid + """ + ) argdefs, _, signature = self.args.python_argdefs() # maps actual expression to SizeArg if its in sizevars replacements @@ -1635,17 +1638,12 @@ def define_kernel(self, src_code, node_schedule): else: fused_name = ( get_fused_kernel_name(node_schedule) - if config.triton.descriptive_kernel_names + if config.triton.descriptive_names else "" ) kernel_name = "_".join(["triton", fused_name, wrapper.next_kernel_suffix()]) wrapper.kernels[src_code] = kernel_name - subs_name = ( - kernel_name - if config.triton.ordered_kernel_names - or config.triton.descriptive_kernel_names - else "triton_" - ) + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" src_code = src_code.replace("KERNEL_NAME", subs_name) # 
TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 60c698cdd02b..80351b2016e1 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -12,7 +12,13 @@ from .. import codecache, config, ir from ..codecache import code_hash, cpp_compile_command, get_code_path -from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product +from ..utils import ( + cache_on_self, + get_benchmark_name, + has_triton, + sympy_dot, + sympy_product, +) from ..virtualized import V from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter @@ -549,13 +555,7 @@ def generate(self): return result.getvalue() - def add_benchmark_harness(self, output): - """ - Append a benchmark harness to generated code for debugging - """ - if not config.benchmark_harness: - return - + def benchmark_compiled_module(self, output): def add_fake_input(name, shape, stride, device, dtype): output.writeline( f"{name} = rand_strided(" @@ -567,7 +567,7 @@ def add_fake_input(name, shape, stride, device, dtype): def add_expr_input(name, val): output.writeline(f"{name} = {val}") - output.writelines(["", "", 'if __name__ == "__main__":']) + output.writelines(["", "", "def benchmark_compiled_module():"]) with output.indent(): output.splice( """ @@ -596,6 +596,35 @@ def add_expr_input(name, val): f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))" ) + def add_benchmark_harness(self, output): + """ + Append a benchmark harness to generated code for debugging + """ + if not config.benchmark_harness: + return + + self.benchmark_compiled_module(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + output.writelines( + [ + "import argparse", + "from torch._inductor.utils import benchmark_all_kernels", + "", + "parser = argparse.ArgumentParser()", + 'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")', # noqa: B950, line too long + "args = parser.parse_args()", + "", + "if args.benchmark_kernels:", + ] + ) + with output.indent(): + output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')") + output.writeline("else:") + with output.indent(): + output.writeline("benchmark_compiled_module()") + def define_kernel(self, name: str, kernel: str, kernel_path: str = None): kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else "" self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}") diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index cf91496017c9..238f15005a2a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -186,10 +186,16 @@ class triton: tiling_prevents_reduction_fusion = True # should we give different names to kernels - ordered_kernel_names = False + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique. + unique_kernel_names = False # should we put op names in kernel names - descriptive_kernel_names = False + # False: No special names (just triton__1, triton__2, etc.) + # "torch": Maps to the fx node in the Dynamo graph (module name, method name, etc.) + # "aten": Maps to the highest-level aten op (i.e. 
pre-decompositions) + descriptive_names = "aten" # use alternate codegen for smaller reductions persistent_reductions = True diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 9ede1d6dfcbd..4b62fdd6793d 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -14,6 +14,7 @@ log = logging.getLogger(__name__) aten = torch.ops.aten +prims = torch.ops.prims inductor_decompositions = get_decompositions( [ @@ -363,6 +364,8 @@ def baddbmm(self, batch1, batch2, beta=1, alpha=1): result = torch.bmm(batch1, batch2) if not isinstance(alpha, numbers.Number) or alpha != 1: result = result * alpha + if beta == 0: + return result if not isinstance(beta, numbers.Number) or beta != 1: self = self * beta return self + result @@ -385,6 +388,36 @@ def bernoulli(self, *, generator=None): return torch.rand_like(self, dtype=torch.float32) < self +@register_decomposition([aten.fmin, prims.fmin]) +def fmin(self, other): + return torch.where(torch.isnan(other) | (other > self), self, other) + + +@register_decomposition([aten.fmax, prims.fmax]) +def fmax(self, other): + return torch.where(torch.isnan(other) | (other < self), self, other) + + +@register_decomposition([aten.narrow_copy]) +def narrow_copy(self, dim, start, length): + return torch.narrow(self, dim, start, length).clone() + + +@register_decomposition([aten.expand_copy]) +def expand_copy(self, size, *, implicit=False): + return aten.expand(self, size, implicit=implicit).clone() + + +@register_decomposition([aten.view_copy.default]) +def view_copy_default(self, size): + return aten.view(self, size).clone() + + +@register_decomposition([aten.view_copy.dtype]) +def view_copy_dtype(self, dtype): + return self.to(dtype).clone() + + """ Some decomps result in differences from eager related to randomness. 
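Editor's note: the fmin/fmax decompositions registered above encode IEEE semantics (prefer the non-NaN operand, return NaN only when both are NaN) purely with torch.where and torch.isnan. A minimal standalone sketch, not part of the patch, that checks the where-based form against eager torch.fmin:

import torch

def fmin_decomp(a, b):
    # Mirrors the decomposition registered above: keep `a` when `b` is NaN
    # or strictly greater, otherwise take `b`.
    return torch.where(torch.isnan(b) | (b > a), a, b)

a = torch.tensor([1.0, float("nan"), float("nan"), 4.0])
b = torch.tensor([2.0, 3.0, float("nan"), float("nan")])
print(fmin_decomp(a, b))   # tensor([1., 3., nan, 4.])
print(torch.fmin(a, b))    # same values
assert torch.allclose(fmin_decomp(a, b), torch.fmin(a, b), equal_nan=True)
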
We put these decomps in a separate table `extra_random_decomps` to allow diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 7ae6fee46cde..9a324bedc9b2 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -116,6 +116,7 @@ def __init__( graph_id=None, ): super().__init__(gm) + self.extra_traceback = False # we do our own error wrapping if shape_env is None: shape_env = ShapeEnv() self.reuse_shape_env = False @@ -285,7 +286,7 @@ def constant_name(self, name: str, device_override: torch.device): return alt_name def placeholder(self, target: str, args, kwargs): - example: torch.Tensor = super().placeholder(target, args, kwargs) + example = super().placeholder(target, args, kwargs) if isinstance(example, SymTypes): expr = example.node.expr self.graph_inputs[target] = expr @@ -319,43 +320,43 @@ def placeholder(self, target: str, args, kwargs): return tensor def call_function(self, target, args, kwargs): - with ir.IRNode.current_origins(gather_origins(args, kwargs)): - if target is operator.getitem and isinstance(args[0], (list, tuple)): - return super().call_function(target, args, kwargs) - - if hasattr(target, "_inductor_lowering_function"): - # passthrough lowerings from .pattern_matcher - return target(*args, **kwargs) - - if target not in lowerings: - base_name = target.name().split(".")[0] - if base_name in FALLBACK_ALLOW_LIST: - make_fallback(target) - elif config.implicit_fallbacks: - error = ( - MissingOperatorWithDecomp - if get_decompositions([target]) - else MissingOperatorWithoutDecomp - ) - log.info( - "Creating implicit fallback for:\n%s", - error.operator_str(target, args, kwargs), - ) - make_fallback(target) - elif get_decompositions([target]): - # There isn't a good way to dynamically patch this in - # since AOT Autograd already ran. The error message tells - # the user how to fix it. - raise MissingOperatorWithDecomp(target, args, kwargs) - else: - raise MissingOperatorWithoutDecomp(target, args, kwargs) + if target is operator.getitem and isinstance(args[0], (list, tuple)): + return super().call_function(target, args, kwargs) + + if hasattr(target, "_inductor_lowering_function"): + # passthrough lowerings from .pattern_matcher + return target(*args, **kwargs) + + if target not in lowerings: + base_name = target.name().split(".")[0] + if base_name in FALLBACK_ALLOW_LIST: + make_fallback(target) + elif config.implicit_fallbacks: + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.info( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target) + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. 
+ raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) - try: - out = lowerings[target](*args, **kwargs) - return out - except Exception as e: - log.exception("Error from lowering") - raise LoweringException(e, target, args, kwargs) from e + try: + out = lowerings[target](*args, **kwargs) + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs).with_traceback( + e.__traceback__ + ) from None def get_attr(self, target, args, kwargs): # this is a constant @@ -422,7 +423,11 @@ def finalize(self): buf.decide_layout() def run_node(self, n: torch.fx.Node): - with ir.IRNode.current_origins({n}): + origins = {n} + if n.op == "call_function": + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ir.IRNode.current_origins(origins): if n.op == "call_function" and n.target in layout_constraints: args, kwargs = self.fetch_args_kwargs_from_env(n) args, kwargs = layout_constraints[n.target](n, *args, **kwargs) @@ -476,12 +481,24 @@ def run_node(self, n: torch.fx.Node): # # When we do a better job selecting layout, we should # revisit this. - if user.target in ( + need_fixed_layout = [ torch.ops.aten.convolution.default, torch.ops.aten.convolution_backward.default, torch.ops.aten.mm.default, torch.ops.aten._int_mm.default, - ): + ] + if torch._C.has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: result = ir.ExternKernel.require_stride_order( result, ir.get_stride_order(n.meta["val"].stride()) ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 02ea9260152b..bcedd6661025 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -325,8 +325,10 @@ class IRNode: def current_origins(origins: Set[torch.fx.Node]): old = IRNode._current_origins IRNode._current_origins = old | origins - yield - IRNode._current_origins = old + try: + yield + finally: + IRNode._current_origins = old def __post_init__(self): self.origins = set(self._current_origins) @@ -394,12 +396,8 @@ def _index(ranges, prefix="i"): @cache_on_self def inner_fn_str(self): - formatter = V.KernelFormatterHandler(V.MockHandler()) - with V.set_ops_handler(formatter), patch.object( - FlexibleLayout, "allow_indexing", True - ): - result = self.inner_fn(self._index(self.ranges)) - return formatter.getvalue(result) + index = self._index(self.ranges) + return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index) def is_zero_elements(self): return any(r == 0 for r in self.ranges) @@ -515,15 +513,13 @@ def index_length(self): @cache_on_self def inner_fn_str(self): - formatter = V.KernelFormatterHandler(V.MockHandler()) - with V.set_ops_handler(formatter), patch.object( - FlexibleLayout, "allow_indexing", True - ): - result = self.inner_fn( - self._index(self.ranges), - self._index(self.reduction_ranges, "r"), - ) - return formatter.getvalue(result) + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, "r") + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, + index, + rindex, + ) def constant_to_device(self, device): """Move this to a given 
device. Requires that all reads are to constants.""" diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 026f0cae3728..38d1d3542018 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -509,6 +509,11 @@ def squeeze(x, dim=None): return view(x, new_shape) if new_shape != x.get_size() else x +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + @register_lowering([aten.squeeze_]) def squeeze_(x, dim=None): val = squeeze(x, dim) @@ -842,6 +847,12 @@ def glu(x, dim=-1): def register_onednn_fusion_ops(): if torch._C.has_mkldnn: + cpu_needs_realized_inputs = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + ] @register_lowering(torch.ops.mkldnn._convolution_pointwise) def convolution_unary( @@ -982,6 +993,7 @@ def convolution_transpose_unary( ) if torch._C.has_mkl: + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) @register_lowering(torch.ops.mkl._mkl_linear) def mkl_packed_linear( @@ -998,6 +1010,7 @@ def mkl_packed_linear( result = add(result, b) return result + add_needs_realized_inputs(cpu_needs_realized_inputs) else: pass @@ -1250,7 +1263,6 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.adaptive_max_pool3d) make_fallback(aten.addbmm) make_fallback(aten.addmv) -make_fallback(aten.aminmax) make_fallback(aten.avg_pool3d) make_fallback(aten.block_diag) make_fallback(aten._cdist_forward) @@ -1262,13 +1274,10 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.diagonal_copy, warn=False) make_fallback(aten.diagonal_scatter, warn=False) make_fallback(aten.digamma, warn=False) -make_fallback(aten.dist) make_fallback(aten._efficientzerotensor) make_fallback(aten._embedding_bag_per_sample_weights_backward) make_fallback(aten.erfc, warn=False) make_fallback(aten.erfinv, warn=False) -make_fallback(aten.fmax, warn=False) -make_fallback(aten.fmin, warn=False) make_fallback(aten.dist) make_fallback(aten._efficientzerotensor) make_fallback(aten._embedding_bag_per_sample_weights_backward) @@ -1281,8 +1290,6 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.igamma, warn=False) make_fallback(aten.igammac, warn=False) make_fallback(aten.isin) -make_fallback(aten.isneginf, warn=False) -make_fallback(aten.isposinf, warn=False) make_fallback(aten.kthvalue) make_fallback(aten.linalg_cholesky_ex) make_fallback(aten.linalg_cross) @@ -1302,8 +1309,6 @@ def apply_constraint(arg, fx_arg): make_fallback(aten._linalg_svd) make_fallback(aten.logaddexp2) make_fallback(aten.logcumsumexp) -make_fallback(aten.log_sigmoid_forward, warn=False) -make_fallback(aten.logspace, warn=False) make_fallback(aten.lu_unpack) make_fallback(aten.max_pool3d_with_indices) make_fallback(aten.max_unpool2d) @@ -1313,8 +1318,6 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.multilabel_margin_loss_forward) make_fallback(aten.multi_margin_loss) make_fallback(aten.nanmedian) -make_fallback(aten.nansum) -make_fallback(aten.narrow_copy, warn=False) make_fallback(aten.ormqr) make_fallback(aten._pdist_forward) make_fallback(aten.pixel_shuffle) @@ -1356,15 +1359,11 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.special_spherical_bessel_j0, warn=False) make_fallback(aten.special_zeta, warn=False) make_fallback(aten.take) -make_fallback(aten.threshold, warn=False) -make_fallback(aten.trace, warn=False) make_fallback(aten._trilinear) 
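Editor's note: several ops removed from the make_fallback list in this hunk (fmin, fmax, isposinf, isneginf, ...) are now handled by decompositions or lowerings instead of extern fallbacks. A hedged sketch of how one might spot-check that, using the run_and_get_triton_code helper reworked later in this patch (assumes a CUDA build with triton available; the "aten.fmax" substring check is only a heuristic):

import torch
from torch._inductor.utils import run_and_get_triton_code

def f(a, b):
    return torch.fmax(a, b) + 1

a = torch.randn(128, device="cuda")
b = torch.randn(128, device="cuda")
code = run_and_get_triton_code(torch.compile(f), a, b)
# With the fallback gone, fmax should be decomposed and fused into the
# generated kernel rather than appearing as an extern aten call.
print("aten.fmax" in code)  # expected: False
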
-make_fallback(aten.unfold_copy, warn=False) make_fallback(aten.uniform, warn=False) make_fallback(aten.unsafe_split, warn=False) make_fallback(aten.vdot) make_fallback(aten.view_as_complex) -make_fallback(aten.view_copy) make_fallback(aten._adaptive_avg_pool3d_backward) make_fallback(aten.adaptive_max_pool2d_backward) make_fallback(aten.adaptive_max_pool3d_backward) @@ -1385,7 +1384,6 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.smooth_l1_loss_backward) make_fallback(aten.soft_margin_loss_backward, warn=False) make_fallback(aten.softshrink_backward, warn=False) -make_fallback(aten.squeeze_copy) make_fallback(aten.linalg_pinv.atol_rtol_tensor) make_fallback(aten.segment_reduce.default) make_fallback(aten._segment_reduce_backward.default) @@ -1400,7 +1398,6 @@ def apply_constraint(arg, fx_arg): make_fallback(aten.masked_scatter) make_fallback(aten.to_sparse) make_fallback(aten.triangular_solve) -make_fallback(aten.expand_copy) make_fallback(aten.gcd.default, warn=False) make_fallback(aten._linalg_eigh) make_fallback(aten.zeros.names) @@ -1504,30 +1501,6 @@ def fn(index): ) -@register_lowering(aten.triu) -def triu(x, diagonal=0): - x_loader = x.make_loader() - dtype = x.get_dtype() - - def inner_fn(index): - *_, i, j = index - return ops.where( - ops.ge( - ops.index_expr(j - i - diagonal, torch.int32), - ops.constant(0, torch.int32), - ), - x_loader(index), - ops.constant(0, dtype), - ) - - return Pointwise.create( - device=x.get_device(), - dtype=dtype, - inner_fn=inner_fn, - ranges=list(x.get_size()), - ) - - @register_lowering(aten.select_scatter, type_promotion_kind=None) def select_scatter(x, src, dim: int, index: int): assert x.get_dtype() == src.get_dtype() diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index f82c77e3f3a1..eb53d73cc960 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -461,7 +461,7 @@ def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: boo philox_rand_like = _prims._make_prim( - schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor", + schema="philox_rand_like(Tensor input, Tensor seed, SymInt offset) -> Tensor", return_type=_prims.RETURN_TYPE.NEW, meta=_philox_rand_like_meta, impl_aten=_philox_rand_like, diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py index a38a3fabb14d..61945858232b 100644 --- a/torch/_inductor/triton_ops/autotune.py +++ b/torch/_inductor/triton_ops/autotune.py @@ -8,6 +8,7 @@ import operator import os import os.path +import re import threading from typing import List diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3920f225e354..e4d28fbc99e0 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1,13 +1,13 @@ import collections import contextlib import functools -import glob import itertools import logging import math import operator import os import shutil +import sys import tempfile import textwrap import time @@ -234,32 +234,44 @@ def wrapper(self): def get_fused_kernel_name(node_schedule): - return "_".join( - ["fused"] - + sorted( - [ - str(origin.name) - for origin in functools.reduce( - operator.or_, - [ - node.node.origins - for node in node_schedule - if hasattr(node, "node") - ], - ) - if origin.op == "call_function" - ] - )[0 : config.kernel_name_max_ops] + all_origins = functools.reduce( + operator.or_, + [node.node.origins for node in node_schedule if hasattr(node, "node")], ) + if config.triton.descriptive_names == "aten": + # Bases the 
kernel name off of the top-level aten operator (i.e. pre-decompositions) + sources = [ + origin.meta["original_aten"]._overloadpacket.__name__ + for origin in all_origins + if origin.op == "call_function" and "original_aten" in origin.meta + ] + elif config.triton.descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function" and "source_fn" in origin.meta: + if isinstance(origin.meta["source_fn"], str): + sources.append(origin.meta["source_fn"]) + else: + sources.append(origin.meta["source_fn"].__name__) + else: + raise NotImplementedError + sources = set(sources) + sources = sorted(sources)[: config.kernel_name_max_ops] + return "_".join(["fused"] + sources) def gather_origins(args, kwargs): import itertools - from .ir import ComputedBuffer, IRNode + from . import ir def is_unrealized_node(n): - return isinstance(n, IRNode) and not isinstance(n, ComputedBuffer) + if isinstance(n, ir.TensorBox): + return is_unrealized_node(n.data) + if isinstance(n, ir.StorageBox): + return is_unrealized_node(n.data) + return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] @@ -525,33 +537,32 @@ def __exit__(self, *args): torch._dynamo.config.debug_dir_root = self.prev_debug_name -def run_and_get_triton_code(fn, *args, **kwargs): - from torch._inductor.debug import DebugContext - from torch._inductor.virtualized import V - - torch._dynamo.reset() - - context = DebugContext() +def run_and_get_code(fn, *args, **kwargs): + from .graph import GraphLowering - with DebugDirManager(), mock.patch.object( - config.trace, "enabled", True - ), context, V.set_debug_handler(context): + compile_to_module = GraphLowering.compile_to_module + source_codes = [] - dir_name = "/".join(context._path.split("/")[:-1]) + "/" - fil = dir_name + "*inference*" - existing_dirs = glob.glob(fil) + def patched_compile_to_module(self): + mod = compile_to_module(self) + with open(mod.__file__, "r") as f: + source_codes.append(f.read()) + return mod + with mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ): + torch._dynamo.reset() fn(*args, **kwargs) + return source_codes - assert context._path is not None - - dir_dbg = [x for x in glob.glob(fil) if x not in existing_dirs] - - assert len(dir_dbg) == 1, f"{dir_dbg}, {context._path}" - full_name = os.path.join(dir_dbg[0], "output_code.py") - with open(full_name, "r") as f: - return f.read() +def run_and_get_triton_code(fn, *args, **kwargs): + source_codes = run_and_get_code(fn, *args, **kwargs) + assert ( + len(source_codes) == 1 + ), f"expected exactly one code output got {len(source_codes)}" + return source_codes[0] def developer_warning(msg): @@ -575,3 +586,71 @@ def get_num_bytes(*args): for arg in args if isinstance(arg, torch.Tensor) ) + + +def get_benchmark_name(): + """ + An experimental API used only when config.benchmark_kernel is true. + + The benchmark name is only available at codegen time. So we can not + directly call it in benchmark_all_kernels which is run after codegen. + + The function assumes the argument after --only is the benchmark name. + It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc + scripts, this function may return None. + + There are 2 flavors of --only argument we need handle: + 1. --only model_name + 2. 
--only=model_name + """ + try: + idx = sys.argv.index("--only") + if ( + idx + 1 < len(sys.argv) + and len(sys.argv[idx + 1]) > 0 + and sys.argv[idx + 1][0] != "-" + ): + return sys.argv[idx + 1] + except ValueError: + pass + + for arg in sys.argv: + if arg.startswith("--only="): + return arg[len("--only=") :] + + +def benchmark_all_kernels(benchmark_name): + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. + """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_key, kernel_mod in PyCodeCache.cache.items(): + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + args = kernel_mod.get_args() + ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0] + num_gb = get_num_bytes(*args) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + + # follow what we do in DebugAutotuner + info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s" + import colorama + + if ms > 0.012 and gb_per_s < 650: + print(colorama.Fore.RED + info_str + colorama.Fore.RESET) + else: + print(info_str) + + nfound += 1 + if nfound == 0: + print( + "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True" + ) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 4aec976561f7..006d6c85fa47 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from itertools import chain from threading import local +from unittest.mock import patch import sympy @@ -66,13 +67,13 @@ def __getattr__(self, name): def inner(*args, **kwargs): fargs = [_arg_str(a) for a in args] fargs.extend(f"{k}={v}" for k, v in kwargs.items()) - return f"{name}({', '.join(fargs)})" + return f"ops.{name}({', '.join(fargs)})" return inner @staticmethod def masked(mask, body, other): - return f"masked({mask}, {body()}, {other})" + return f"ops.masked({mask}, {body()}, {other})" @staticmethod def indirect_indexing(index_var): @@ -96,9 +97,35 @@ def inner(*args): class KernelFormatterHandler: def __init__(self, parent_handler): self.parent_handler = parent_handler - self.output = IndentedBuffer() + self.output = IndentedBuffer(1) self.var_counter = itertools.count() + @staticmethod + def ir_to_string(ir_fn, index, rindex=None): + from .ir import FlexibleLayout + + args = [index, rindex] if rindex is not None else [index] + names = ["index", "rindex"] if rindex is not None else ["index"] + formatter = KernelFormatterHandler(MockHandler()) + + with formatter.output.indent(-1): + formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + for name, arg in zip(names, args): + if arg: + lhs = ", ".join( + [ + str("_" if isinstance(v, (int, sympy.Integer)) else v) + for v in arg + ] + ) + formatter.output.writeline(f"{lhs} = {name}") + + with V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + result = ir_fn(*args) + return formatter.getvalue(result) + def __getattr__(self, name): def inner(*args, **kwargs): line = getattr(self.parent_handler, name)(*args, **kwargs) diff --git a/torch/_ops.py b/torch/_ops.py index afba4d38d4a2..9cb5815c160b 100644 --- a/torch/_ops.py +++ 
b/torch/_ops.py @@ -22,11 +22,14 @@ def dl_open_guard(): Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a shared library to load custom operators. """ - if _SET_GLOBAL_FLAGS: - old_flags = sys.getdlopenflags() - sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) - yield - if _SET_GLOBAL_FLAGS: + if not _SET_GLOBAL_FLAGS: + yield + return + old_flags = sys.getdlopenflags() + sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) + try: + yield + finally: sys.setdlopenflags(old_flags) diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 92e9b699519e..5f8ee013d7be 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -22,6 +22,7 @@ from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper __all__ = [ + "diagonal", "svd", "vector_norm", "matrix_norm", @@ -57,6 +58,16 @@ def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name from torch._decomp import register_decomposition +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + @register_decomposition(torch._ops.ops.aten.linalg_vector_norm) @out_wrapper(exact_dtype=True) def vector_norm( diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 359885a9f2f0..42d002c8df89 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -426,6 +426,8 @@ def nonzero(fake_mode, func, arg): raise DynamicOutputShapeException(func) if arg.nonzero_memo is None: + import sys + from torch.fx.experimental.symbolic_shapes import constrain_range nnz = fake_mode.shape_env.create_unbacked_symint() @@ -438,9 +440,7 @@ def nonzero(fake_mode, func, arg): # disjoint with what can actually occur. But this is fine: # remember, the hypothesis is that if your later code works # with N >= 2, it will work with N = 1 and N = 0. 
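Editor's note: the dl_open_guard rewrite above applies the same fix as the IRNode.current_origins change earlier in this patch: restore the saved state in a finally block so an exception in the body cannot leak the temporary setting. A minimal, Unix-only sketch of the pattern in isolation (not the patched function itself):

import ctypes
import sys
from contextlib import contextmanager

@contextmanager
def rtld_global_flags():
    # Save, modify, and unconditionally restore the dlopen flags; `finally`
    # guarantees restoration even if the body raises.
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)

with rtld_global_flags():
    pass  # e.g. load a shared library that registers custom operators
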
- lower = 2 - upper = None - constrain_range(nnz, min=lower, max=upper) + constrain_range(nnz, min=2, max=sys.maxsize - 1) arg._nonzero_memo = nnz arg._nonzero_memo_vc = arg._version diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index 84944e1e8658..69a102da4073 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -328,6 +328,14 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: { F.pixel_shuffle, }, + # pixel unshuffle + { + F.pixel_unshuffle, + }, + # narrow + { + torch.narrow, + }, ] # for each floating point op, add versions of the op added by @@ -529,6 +537,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: F.max_pool3d, F.relu6, F.pixel_shuffle, + F.pixel_unshuffle, torch.avg_pool1d, torch._C._nn.avg_pool2d, torch._C._nn.avg_pool3d, @@ -540,6 +549,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: torch.max, torch.mean, torch.min, + torch.narrow, torch.repeat_interleave, torch.sort, torch.squeeze, diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 4872d418d559..d6eca86d78b5 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -504,6 +504,7 @@ def _get_share_qprams_op_backend_config(op): torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d, torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, torch.nn.functional.relu, torch.nn.functional.relu6, torch.avg_pool1d, @@ -512,6 +513,7 @@ def _get_share_qprams_op_backend_config(op): torch.clamp, torch.flatten, torch.mean, + torch.narrow, torch.repeat_interleave, torch.transpose, torch.squeeze, diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index e1be7d7ec2ce..afa4ddd10a18 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -140,12 +140,14 @@ def is_copy_node(node, modules): def is_general_tensor_shape_node(node, modules): func_list = [ + torch.narrow, torch.transpose, torch.repeat_interleave, torch.squeeze, torch.stack, torch.unsqueeze, torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, ] method_list = [ "contiguous", diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 0a5774ff2319..bdeb7a9f5e55 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -282,6 +282,18 @@ def add_lib_preload(self, lib_type): return lib_set or lib_find + def is_numactl_available(self): + numactl_available = False + try: + cmd = ["numactl", "-C", "0", "-m", "0", "hostname"] + r = subprocess.run(cmd, env=os.environ, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + if r.returncode == 0: + numactl_available = True + except Exception: + pass + return numactl_available + + def set_memory_allocator(self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False): """ Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc. 
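Editor's note: is_numactl_available above probes a harmless numactl binding; the launcher hunks that follow use that probe to fall back to taskset for core pinning. An illustrative sketch of the probe-and-fallback idea (the helper name and the "0-3" core list are placeholders, not the launcher's actual code path):

import os
import subprocess

def probe_numactl() -> bool:
    # Same probe as is_numactl_available: run a no-op numactl binding; any
    # failure (missing binary, unsupported system) means "unavailable".
    try:
        r = subprocess.run(
            ["numactl", "-C", "0", "-m", "0", "hostname"],
            env=os.environ,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return r.returncode == 0
    except Exception:
        return False

core_range = "0-3"  # illustrative core list
prefix = (
    ["numactl", "-C", core_range, "-m", "0"]
    if probe_numactl()
    else ["taskset", "-c", core_range]  # memory binding is lost in this mode
)
print(" ".join(prefix))
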
@@ -373,6 +385,7 @@ def set_multi_thread_and_allocator(self, ncores_per_instance, def launch(self, args): cores = [] set_kmp_affinity = True + enable_taskset = False if args.core_list: # user specify what cores will be used by params cores = [int(x) for x in args.core_list.split(",")] if args.ncores_per_instance == -1: @@ -458,6 +471,22 @@ def launch(self, args): if args.ninstances > 1 and args.rank != -1: logger.info(f"assigning {args.ncores_per_instance} cores for instance {args.rank}") + if not args.disable_numactl: + numactl_available = self.is_numactl_available() + if not numactl_available: + if not args.disable_taskset: + logger.warning("Core binding with numactl is not available. Disabling numactl and using taskset instead. \ + This may affect performance in multi-socket system; please use numactl if memory binding is needed.") + args.disable_numactl = True + enable_taskset = True + else: + logger.warning("Core binding with numactl is not available, and --disable_taskset is set. \ + Please unset --disable_taskset to use taskset insetad of numactl.") + exit(-1) + + if not args.disable_taskset: + enable_taskset = True + self.set_multi_thread_and_allocator(args.ncores_per_instance, args.disable_iomp, set_kmp_affinity, @@ -471,8 +500,11 @@ def launch(self, args): for i in range(args.ninstances): cmd = [] cur_process_cores = "" - if not args.disable_numactl: - cmd = ["numactl"] + if not args.disable_numactl or enable_taskset: + if not args.disable_numactl: + cmd = ["numactl"] + elif enable_taskset: + cmd = ["taskset"] cores = sorted(cores) if args.rank == -1: # sequentially assign ncores_per_instance to ninstances core_list = cores[i * args.ncores_per_instance : (i + 1) * args.ncores_per_instance] @@ -494,10 +526,14 @@ def launch(self, args): for r in core_ranges: cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']}," cur_process_cores = cur_process_cores[:-1] - numa_params = f"-C {cur_process_cores} " - numa_ids = ",".join([str(numa_id) for numa_id in self.cpuinfo.numa_aware_check(core_list)]) - numa_params += f"-m {numa_ids}" - cmd.extend(numa_params.split()) + if not args.disable_numactl: + numa_params = f"-C {cur_process_cores} " + numa_ids = ",".join([str(numa_id) for numa_id in self.cpuinfo.numa_aware_check(core_list)]) + numa_params += f"-m {numa_ids}" + cmd.extend(numa_params.split()) + elif enable_taskset: + taskset_params = f"-c {cur_process_cores} " + cmd.extend(taskset_params.split()) with_python = not args.no_python if with_python: cmd.append(sys.executable) @@ -562,6 +598,8 @@ def _add_multi_instance_params(parser): help="Whether only use physical cores") group.add_argument("--disable-numactl", "--disable_numactl", action="store_true", default=False, help="Disable numactl") + group.add_argument("--disable-taskset", "--disable_taskset", action="store_true", default=False, + help="Disable taskset") group.add_argument("--core-list", "--core_list", metavar="\b", default=None, type=str, help="Specify the core list as \"core_id, core_id, ....\", otherwise, all the cores will be used.") group.add_argument("--log-path", "--log_path", metavar="\b", default="", type=str, diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 44a0d0242b93..3e3b2819fc5b 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -68,7 +69,9 @@ static PyObject* THPGenerator_pynew( self->cdata = make_generator(); } #endif - else { + else if (device.type() == at::kPrivateUse1) { + self->cdata = 
at::GetGeneratorForPrivateuse1(device.index()); + } else { AT_ERROR( "Device type ", c10::DeviceTypeName(device.type()), diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index a30b7f519e77..3a5d1bdc8623 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -77,15 +77,15 @@ Tensor toNonOptPrimal(const c10::optional& t) { } void copy_range(variable_list& out, IndexRange range, const Tensor& t) { - AT_ASSERT(range.second <= out.size()); - AT_ASSERTM( + TORCH_CHECK(range.second <= out.size()); + TORCH_CHECK( range.second - range.first == 1, "inconsistent range for Tensor output"); out[range.first] = t; } void copy_range(variable_list& out, IndexRange range, at::ArrayRef t) { - AT_ASSERT(range.second <= out.size()); - AT_ASSERTM( + TORCH_CHECK(range.second <= out.size()); + TORCH_CHECK( range.second - range.first == t.size(), "inconsistent range for TensorList output"); std::copy(t.begin(), t.end(), out.begin() + range.first); @@ -1684,7 +1684,7 @@ Tensor cholesky_jvp(const Tensor& dA, const Tensor& L, bool upper) { // L^{-1}dA(L^{-H}) = L^{-1}dL + (L^{-1}dL)^H // = sym(L^{-1}dL) // where sym(X) = X + X^H - // A short computaiton gives that the inverse of sym is given by + // A short computation gives that the inverse of sym is given by // \pi(X) = X.tril() - 0.5*diag(X) // so // dL = L\pi(L^{-1}dA(L^{-H})) @@ -1787,8 +1787,8 @@ Tensor cholesky_inverse_jvp( // of Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH(). To prove that, // note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by taking the // product between the SVD decompositions of A and Ap. Consider the -// conjugate-tranposed [2]: (A Ap A)^H = A^H (A Ap) = A^H. By differentiating it -// we get: dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the +// conjugate-transposed [2]: (A Ap A)^H = A^H (A Ap) = A^H. By differentiating +// it we get: dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the // left by Ap^H and using Ap^H A^H = (A Ap)^H = A Ap: Ap^H dA^H A Ap + A Ap dA // Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left by Ap and by // applying [1] and [2] repeatedly until impossible we get: Ap Ap^H dA^H A Ap + @@ -2368,7 +2368,7 @@ Tensor softplus_double_backward( // this later) // 4. Return the as_strided view of the storage tensor using input geometry. // -// In step (2), if the output tensor does't have overlapping memory, we can +// In step (2), if the output tensor doesn't have overlapping memory, we can // safely scatter (`storage.as_strided(output_geometry).copy_(grad)`); // otherwise, we must use `index_add` as gradients at different indices may need // to be summed to a single location. @@ -2501,12 +2501,12 @@ Tensor softplus_double_backward( // // Note that all values in `S(n)` are the same (they point to the same // memory location anyways, so this step doesn't change anything, but -// effectively avoids having the denpendency on the layout of `input`. +// effectively avoids having the dependency on the layout of `input`. // I.e., the result holds fixed regardless of the layout of `input`, as // long as `input_stride` is fixed. // -// NOTE: for forward pass, we can equivalently simply selet any one of -// `S(n)` as `storage[n]`. However, cosnidering this as an average +// NOTE: for forward pass, we can equivalently simply select any one of +// `S(n)` as `storage[n]`. 
However, considering this as an average // operation makes backward easier (so all values in set // `{ grad_input[i] : i in S(n) }` are the same, and it can use the // same geometry as input). @@ -2645,7 +2645,7 @@ Tensor softplus_double_backward( // stride[B[j]] // // Then the invariant is obviously satisfied at every dimension -// in this block if it is satisfied at dimnesion B[-1]. It only +// in this block if it is satisfied at dimension B[-1]. It only // remains to show that it is satisfied at the last dimension in // each block. // @@ -3212,7 +3212,7 @@ Tensor svd_backward( // where CP(n-1) is the complex projective space of dimension n-1. // In other words, M is just the complex projective space, and pi is (pretty // similar to) the usual principal bundle from S^{2n-1} to CP(n-1). The case k - // > 1 is the same, but requiring a linear inependence condition between the + // > 1 is the same, but requiring a linear independence condition between the // vectors from the different S^{2n-1} or CP(n-1). // // Note that this is a U(1)^k-bundle. In plain words, this means that the @@ -3672,14 +3672,14 @@ Tensor linalg_qr_backward( const Tensor& Q, const Tensor& R, const c10::string_view mode) { - // Nb. We won't be too formal below, as writing this proof formaly is a pain + // Nb. We won't be too formal below, as writing this proof formally is a pain // We'll link here a formal writing of all this at some point in the future // // Case m >= n // dQ = dAR^{-1} - Qsyminv(sym(Q^H dA R^{-1})) // dR = syminv(sym(Q^H dA R^{-1}))R // - // With the notation from the JVP formla, the only two computations that we + // With the notation from the JVP formula, the only two computations that we // need are syminv*(R) = 0.5 * (R.triu() + R.triu()^H - Re diag(R)) sym*(X) = // 2 * X Using these, after a few simplifications we get that gA = (gQ + // syminvadj(triu(gR R^H - Q^H gQ)))R^{-H} @@ -4712,14 +4712,14 @@ std::tuple _trilinear_backward( } Tensor log1p_backward(const Tensor& grad, const Tensor& self) { - // We must conditionally initalize this using to_dense if sparse, sparse + // We must conditionally initialize this using to_dense if sparse, sparse // addition is not supported without exact shape match Tensor self_p1_conj; if (self.layout() == c10::kSparse || self.layout() == c10::kSparseCsr || self.layout() == c10::kSparseCsc || self.layout() == c10::kSparseBsr || self.layout() == c10::kSparseBsc) { // The warning only applies to the sparsity of self, dense grad is never - // materialized so if self is strided and grad is sparse nothing unepected + // materialized so if self is strided and grad is sparse nothing unexpected // happens memory wise TORCH_WARN( "log1p_backward: received self with sparse layout, but backward requires materialization of a dense tensor with this shape"); @@ -4959,7 +4959,7 @@ std::tuple householder_product_backward( // better performance bool modify_K_in_place = !at::GradMode::is_enabled(); - // This method exploites that at k-th iteration vector v_k has only elements + // This method exploits that at k-th iteration vector v_k has only elements // v_k[k:] which are non-zero. auto update_grad = [&m]( int64_t k, @@ -5217,7 +5217,7 @@ std::tuple ormqr_backward( if (self_requires_grad || tau_requires_grad) { if (left ^ transpose) { // Assume left = true and transpose = false. 
The case with - // left = false and tranpose = true is very much similar with just + // left = false and transpose = true is very much similar with just // transposed arguments passed into householder_product_backward. // Ormqr computes B = H_1 * ... * H_k * A. // The sensivity wrt H_i is given by (see notes in @@ -6068,7 +6068,7 @@ Tensor gather_with_keepdimed_indices( // P^T dA1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by // U1^{-1}] L^{-1} P^T dA1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**). Note, L is // lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular. -// Also, since the diagonal of L (all ones) is never exposed explicity (packed +// Also, since the diagonal of L (all ones) is never exposed explicitly (packed // representation), the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0. // Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular. // Combining these observations we conclude: @@ -6351,7 +6351,7 @@ Tensor logsumexp_jvp( const Tensor& self_t, IntArrayRef dim, bool keepdim) { - // NB: for simplicitly, we recompute some values that can be reused from + // NB: for simplicity, we recompute some values that can be reused from // forward auto self_p_exp = [&self_p, &dim]() { if (self_p.sym_numel() > 0) { diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 34eda5378721..a308aea3e0a1 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -330,7 +330,7 @@ inline std::vector as_view( "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); } if (is_fw_differentiable) { - // Check if base is a forward differentiabble view + // Check if base is a forward differentiable view auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base); if (diff_view_meta && diff_view_meta->has_fw_view()) { const auto& base_fw_info = diff_view_meta->get_forward_view(); diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index 968bc5139141..900a5e69944b 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -276,7 +276,7 @@ const Variable& AutogradMeta::fw_grad( return ForwardGrad::undef_grad(); } - // Ensure that concurent fw_grad() "reads" are thread safe + // Ensure that concurrent fw_grad() "reads" are thread safe std::lock_guard lock(mutex_); const auto& direct_fw_grad = diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index 890a7fa3e6e9..0c0e07c84ccc 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -187,7 +187,7 @@ void autogradNotImplementedFallbackImpl( t.use_count() <= 1, op_name); // Okay to return undefined tensor // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors and // each Tensor shares a storage of a hidden, intermediate 1D Tensor - // created inside the CUDA implemenetation. This is because the + // created inside the CUDA implementation. This is because the // reference implementation of nvidia/apex repo returns this 1D Tensor // where each element represents the norm of corresponding input Tensor, // here I want to return the same number of Tensors as the input @@ -357,7 +357,7 @@ void autogradNotImplementedInplaceOrViewFallbackImpl( ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? 
CreationMeta::MULTI_OUTPUT_NODE : CreationMeta::NO_GRAD_MODE)); - // ^ pass in creation meta unecessarily even if not isDifferentiableType, + // ^ pass in creation meta unnecessarily even if not isDifferentiableType, // but we don't have that // information here anyway. stack->at(stack->size() - num_returns + aliased_output_idx) = result; diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index 05b3642c1572..527a87a87aa3 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -29,7 +29,7 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { } // This function has two main goals: -// 1) Use the user-provided jvp function to populate the the outputs' forward +// 1) Use the user-provided jvp function to populate the outputs' forward // gradient 2) Perform error checking to ensure that view and inplace ops are // properly handled // diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 965c2dc109ae..61078b22d0c4 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -411,7 +411,7 @@ std::vector get_current_graph_task_execution_order() { } // We could potentially check if there is only a single device here - // but explicitly require this context doens't seem bad either + // but explicitly require this context doesn't seem bad either TORCH_CHECK( !c10::AutogradState::get_tls_state().get_multithreading_enabled(), "get_current_graph_task_execution_order expects the current backward to be " @@ -849,7 +849,7 @@ void validate_outputs( if (grad.layout() != metadata.layout()) { // TODO: Currently we only support (*, Sparse) combination for // (tensor.layout(), tensor.grad.layout()) In future, there will be an - // oppportunity to support more combinations of layouts if they are + // opportunity to support more combinations of layouts if they are // composable (example., operations like addition etc., are well defined // between tensors of different layouts.), as well as all parts of // autograd like AccumulateGrad correctly handle this. We allow grad to be @@ -1501,7 +1501,7 @@ void GraphTask::init_to_execute( // recursion, but the actual code does this iteratively. Refer to the // numbering to see how the actual code corresponds. A difference to note is // that in the iterative version, when you are working with the current Node, - // you are reponsible to update your parent's is_needed after all your + // you are responsible to update your parent's is_needed after all your // children have been updated. // // is_needed = {fn: True for fn in outputs} # (0) diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 05ba3edecf07..8fbf4104a1fb 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -560,7 +560,7 @@ struct TORCH_API Node : std::enable_shared_from_this { variable_list traced_apply(variable_list inputs); // Sequence number used to correlate backward nodes with forward ops in the - // profiler and provide determinisim in the engine. + // profiler and provide determinism in the engine. 
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const uint64_t sequence_nr_; diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 15c266faed54..4e3768d33492 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -138,7 +138,7 @@ struct TORCH_API AccumulateGrad : public Node { // shallow copy. We need a shallow copy so that modifying the original // grad tensor doesn't modify the grad we accumulate. // We only skip clone if indices and values themselves are contiguous - // for backward compatiblity reasons. Since without this optimization, + // for backward compatibility reasons. Since without this optimization, // earlier we would clone the entire SparseTensor which cloned indices // and values. // For details see https://github.com/pytorch/pytorch/issues/34375. diff --git a/torch/csrc/autograd/graph_task.h b/torch/csrc/autograd/graph_task.h index 4d0a7ea84fe7..6256e44b3987 100644 --- a/torch/csrc/autograd/graph_task.h +++ b/torch/csrc/autograd/graph_task.h @@ -143,7 +143,7 @@ struct GraphTask : std::enable_shared_from_this { // The value of worker_device in the thread that created this task. // See Note [Reentrant backwards] - // Safe to read owner_ and reentrant_depth_ without synchronizaton + // Safe to read owner_ and reentrant_depth_ without synchronization int owner_; // The number of parent graph tasks for this graph task const int reentrant_depth_; diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index a8d8b9880faa..f1e3b6981719 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -148,7 +148,7 @@ void InputBuffer::add( // (4) var is a CUDA variable and it shares a device with the producer but // not the consumer: // (4a) Uses the producer device's default stream as the accumulation - // stream (4b) Syncs the accumulation stream with the the producer's + // stream (4b) Syncs the accumulation stream with the producer's // stream (4c) Accumulates. // (5) var is a CUDA variable and it does not share a device with the // consumer or producer. diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index 37764c480e8a..69277c90d186 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -109,7 +109,7 @@ struct TORCH_API ProfilerResult { * For example, if part of the model is lowered to a dsp backend, then * the execution of that part of the model is delegated to the backend. * When backend finishes execution it has an option to provide profiling - * information (latency only at th emoment) corresponding to different operators + * information (latency only at the moment) corresponding to different operators * that were executed in the backend. * When such events are recorded by backend using this API, the event * records will be collected by active kineto profiler. If no kineto profiler diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 35b8fac7e876..388695957e45 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -44,13 +44,13 @@ namespace profiler { // mapping. A corresponding entry is removed when the guard is destroyed, // potentially revealing the previously set value for the same slot. 
// -// For the async tasks, slots previuosly set in the main thread before +// For the async tasks, slots previously set in the main thread before // launching of an async task are shared and visible in the async task. // // On the other hand, any adding or overwriting of the mapping by the // async task is not visible to the main thread and any modification // (including removal of the entries) in the main thread is not visible -// to the async task if it happends after launching the task. +// to the async task if it happens after launching the task. // // We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store profiler config, // as well as a list of events that happen during profiling. diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index d438205e8947..52fab04c336f 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -196,7 +196,7 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { } // The version counter is correct. - // Additionnally, if we deal with a non-leaf variable, we have its correct + // Additionally, if we deal with a non-leaf variable, we have its correct // grad_fn. // If we have the original variable, we simply return it diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h index 6861f2f2f690..8100e6e2bb4f 100644 --- a/torch/csrc/autograd/saved_variable.h +++ b/torch/csrc/autograd/saved_variable.h @@ -56,7 +56,7 @@ class TORCH_API SavedVariable { // we fall into the second case and its metadata is also saved separately. // In that case, the grad_fn must be passed in to the unpack function when // reconstructing the Variable (except when we are doing an inplace operation - // on a view, see below). The field saved_orignal_ below reflects the two + // on a view, see below). The field saved_original_ below reflects the two // cases: its value is true in the first case and false in the second case. // The value data_.defined() can be false in three cases: // 1. SavedVariable was constructed without a Tensor (the value to save is diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index f6fcb1083d6e..dacbe90d13be 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -664,14 +664,14 @@ const std::shared_ptr& VariableHooks::grad_fn( // self = inplace_op(self) // // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to - // represent the chain of view backward ops for effienciency. + // represent the chain of view backward ops for efficiency. // // However in XLA backend we don't have full support of // AsStridedBackward0, we instead run a full forward pass with a tensor // that requires gradient to get proper grad_fn setup, then save it to // DifferentiableViewMeta for future use. This is fairly cheap for XLA // lazy tensor approach (but would be really expensive for CPU/CUDA). XLA - // Tensor only run thorugh VariableType dispatch and lower the forward + // Tensor only run through VariableType dispatch and lower the forward // pass to a XLA HLO graph, then we take grad_fn and never materialize the // tensor content. So we only construct the graph but not execute it, // which is a fairly cheap operation to do. 
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 6ca5c3d148b9..81d7c01f6d0f 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -31,7 +31,7 @@ static constexpr const char* kNumAutogradContexts = "num_autograd_contexts"; // This hook does 3 things: // 1. Call pre hooks of the original AccumulateGrad to modify the input grad. -// 2. Accumuate the guard to RPC context. +// 2. Accumurate the guard to RPC context. // 3. Call post hooks of the original AccumulateGrad. class DistAccumulateGradCaptureHook : public GraphTask::ExecInfo::Capture::GradCaptureHook { diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.h b/torch/csrc/distributed/autograd/engine/dist_engine.h index f8102a796595..9124fe9fe17d 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.h +++ b/torch/csrc/distributed/autograd/engine/dist_engine.h @@ -96,7 +96,7 @@ class TORCH_API DistEngine { // traverse the GraphTask instead of using the GraphTask embedded // cpu_ready_queue, this is because dist engine might run the same GraphTask // from different SendFunctions concurrently in different threads. The method - // will only mark the GraphTask as completed when it needes to, which means it + // will only mark the GraphTask as completed when it needs to, which means it // might not mark as completed for every call as dist engine would like to // keep the GraphTask alive when it not receives all gradients. // diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h index f3d9dd2362d8..fef0055e04be 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h @@ -18,7 +18,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase { std::vector profiledEvents, rpc::ProfilingId profilingId); - // For receving RPCs. Used in from message when converting a message received + // For receiving RPCs. Used in from message when converting a message received // over the wire. RpcWithProfilingResp( rpc::MessageType messageType, diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 70452b32287c..c7d49cf6acf8 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -113,7 +113,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { } // Gathers a single tensor inputBuffer into a single buffer outputBuffer that - // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. + // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. // Note: this function will be deprecated in near future. virtual c10::intrusive_ptr _allgather_base( diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 6966e640aa91..008bb84827ac 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -226,7 +226,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { } // Gathers a single tensor inputBuffer into a single buffer outputBuffer that - // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. 
+ // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. // Note: this function will be deprecated in near future. virtual c10::intrusive_ptr _allgather_base( diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp index fe81192f448c..429b33af9854 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp @@ -71,8 +71,8 @@ struct WorkEntry { // MPI_THREAD_SERIALIZED, ProcessGroupMPI will only support a singe process // group. In other words, no more than 1 process group can be created globally. // -// If you would like to use multiple ProcessGroupMPI, it requres your MPI -// implemenation to have a thread support value of MPI_THREAD_MULTIPLE, that is, +// If you would like to use multiple ProcessGroupMPI, it requires your MPI +// implementation to have a thread support value of MPI_THREAD_MULTIPLE, that is, // multiple threads may call MPI, with no restriction. // // Also note that ProcessGroupMPI only supports a single Tensor operation. In @@ -229,7 +229,7 @@ class TORCH_API ProcessGroupMPI : public Backend { c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized + // Creating a new ProcessGroupMPI, will initialize MPI if not initialized static c10::intrusive_ptr createProcessGroupMPI( std::vector ranks = {}); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 6a11bacb376a..7786815c0df5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -499,7 +499,7 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( // So explicitly abort ncclComms here before throwing this timed out // exception to users, after this, ncclCommWatchdog can detect nccl // communicators are aborted and clean up devNCCLCommMap_ accordingly. - // if throwing timed out excepiton without aborting nccl communicators + // if throwing timed out exception without aborting nccl communicators // here, it was observed that CUDA GPU will have 100% utilization and // can not run new events successfully. 
@@ -776,6 +776,17 @@ uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { return seq_; } +void ProcessGroupNCCL::abort(c10::optional abortReason) { + std::lock_guard lock(mutex_); + for (auto& it : devNCCLCommMap_) { + auto& ncclComms = it.second; + + for (const auto& ncclComm : ncclComms) { + ncclComm->ncclCommAbort(abortReason); + } + } +} + ProcessGroupNCCL::~ProcessGroupNCCL() { terminateProcessGroup_.store(true); @@ -789,19 +800,9 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { workCleanupThread_.join(); } - { - // Abort all NCCL Communicators on Process Group Destruction - std::lock_guard lock(mutex_); - for (auto& it : devNCCLCommMap_) { - auto& ncclComms = it.second; - - for (const auto& ncclComm : ncclComms) { - std::string abortReason = - c10::str("Process Group destroyed on rank ", rank_); - ncclComm->ncclCommAbort(abortReason); - } - } - } + // Abort all NCCL Communicators on Process Group Destruction + std::string abortReason = c10::str("Process Group destroyed on rank ", rank_); + abort(abortReason); } void ProcessGroupNCCL::abortTimedOutCollectives( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index e9a0e5585832..f9c845ea3cd5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -136,7 +136,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { void synchronizeStreams(); // Helper function used in CUDA Stream callbacks to complete WorkNCCL - // objects and throw exceptions when neeeded. + // objects and throw exceptions when needed. void handleNCCLGuard(ErrorHandlingMode asyncErrorHandling); // Helper function that checks if the NCCL kernels have finished @@ -437,6 +437,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Tests if the UCC fallback path is available bool isUCCAvailable() const; + // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) + // instead of relying on ProcessGroupNCCL destructor. + void abort(c10::optional abortReason = c10::nullopt); + protected: // Helper that broadcasts nccl unique ID to all ranks through the store void broadcastUniqueNCCLID( @@ -497,7 +501,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective - // communicaiton primitives. + // communication primitives. template c10::intrusive_ptr pointToPoint( std::vector& tensor, @@ -644,7 +648,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Add Work Pointer to workVector void workEnqueue(c10::intrusive_ptr); - // The CUDA steams used by NCCL kernels + // The CUDA streams used by NCCL kernels std::unordered_map> ncclStreams_; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index abc4359e7dda..acf8a98940ed 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1494,7 +1494,7 @@ that adds a prefix to each key inserted to the store. processGroup, "Options", R"( -Base class for all processs group options implementations, such as the nccl +Base class for all processes group options implementations, such as the nccl options :class:`~torch.distributed.ProcessGroupNCCL.Options`). )") .def( @@ -1944,6 +1944,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). 
py::arg("size"), py::arg("timeout") = kProcessGroupDefaultTimeout, py::call_guard()) + .def( + "_abort", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, + const c10::optional& abortReason) { + return self->abort(abortReason); + }, + py::arg("abort_reason") = py::none(), + py::call_guard()) .def_property_readonly( "options", &::c10d::ProcessGroupNCCL::getOptions) .def_property_readonly( @@ -2096,7 +2104,7 @@ Example:: ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the callback and a ``CUDAEvent`` that recorded the callback stream. - 1. For CPU work, ``fut.done()`` returns true when work has been complted and value() + 1. For CPU work, ``fut.done()`` returns true when work has been completed and value() tensors are ready. 2. For GPU work, ``fut.done()`` returns true only whether the operation has been enqueued. 3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), ``fut.done()`` returns diff --git a/torch/csrc/distributed/c10d/logger.hpp b/torch/csrc/distributed/c10d/logger.hpp index 82acd5d202d4..827c1a1dd574 100644 --- a/torch/csrc/distributed/c10d/logger.hpp +++ b/torch/csrc/distributed/c10d/logger.hpp @@ -69,7 +69,7 @@ class TORCH_API Logger { ); // Set stats that can be collected only during // training loop. It is called at the beginning of forward call - // to record the run time stats of sampled iterations that previouly ran. + // to record the run time stats of sampled iterations that previously ran. // GPU performance stats are collected only for single process // single device program and single device module right now. // TODO to support single process multiple devices and multi device modules, diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index df11c6444f3c..bb7754c47284 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -1178,7 +1178,7 @@ void Reducer::initialize_bucket_views(Reducer::Bucket& bucket) { if (grad.defined() && !grad.is_alias_of(bucket_view)) { bucket_view.copy_(grad); grad = bucket_view; - // The grad is modefied and needs to be written back. + // The grad is modified and needs to be written back. return true; } // The grad is not modified and does not need to be written back. diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index eafc2c826f37..f1bc2557d338 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -73,8 +73,8 @@ class TORCH_API Reducer { // a call to this function can simply be omitted. void prepare_for_backward(const std::vector& outputs); - // Called at the begginning of forward() inside DistributedDataParallel, - // right now it caputures the starting time of forward in each iteration. + // Called at the beginning of forward() inside DistributedDataParallel, + // right now it captures the starting time of forward in each iteration. void prepare_for_forward(); // Returns the relative time in nanoseconds when gradients were ready, @@ -153,7 +153,7 @@ class TORCH_API Reducer { // An function for users to set sample_rate of collecting // runtime stats. The time stats will be recorded for the - // first 10 iterations, after 10 iteratons time stats will be + // first 10 iterations, after 10 iterations time stats will be // recorded once every "sample_rate" training iterations. 
void set_ddp_runtime_logging_sample_rate(int sample_rate); @@ -504,7 +504,7 @@ class TORCH_API Reducer { // Retrieves parameter names that have not been marked as ready as part of // previous iteration. std::vector getUnmarkedParamsForIteration(); - // Retrives parameter indices that have not been marked as ready as part of + // Retrieves parameter indices that have not been marked as ready as part of // previous iteration. std::vector getUnmarkedParamIndicesForIteration(); // Raises appropriate error if mark_variable_ready is called on the same diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index 17a7808912b1..6ef573cf14ff 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -98,7 +98,7 @@ enum MessageType { // to determine how to serialize them. This design is helpful for // communicating super large tensors where serializing all the data at // once leads to excessively large memory footprint. An implementation -// can then serialize and send tensors chunck-by-chunk, in the streaming +// can then serialize and send tensors chunk-by-chunk, in the streaming // fashion. // type (MessageType): type of the message. // id (int64_t): message id, this is used to match request and response. diff --git a/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h b/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h index 917c5dcf0a0a..61418d9bf278 100644 --- a/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h +++ b/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h @@ -76,7 +76,7 @@ TORCH_API extern mutexType currentStateStackEntryMutex; // This class is used to implement a stack of ``State``s. // It has 2 members. -// One is `prevPtr`, a shared_ptr poiniting to previous elememnt in the +// One is `prevPtr`, a shared_ptr pointing to previous element in the // stack. // The other is ``statePtr``, a shared_ptr pointing to ``State``. class StateStackEntry { diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp index cc9da4e97724..a98b74a0a799 100644 --- a/torch/csrc/distributed/rpc/request_callback.cpp +++ b/torch/csrc/distributed/rpc/request_callback.cpp @@ -14,7 +14,7 @@ c10::intrusive_ptr RequestCallback::operator()( std::vector streams) const { // NB: cannot clear autograd context id here because the processMessage method // might pause waiting for all RRefs in the arguments to be confirmed by their - // owners and resumne processing in a different thread. Hence, the + // owners and resume processing in a different thread. Hence, the // thread_local context id needs to be set and cleared in the thread that // indeed carries out the processing logic. return processMessage(request, std::move(streams)); diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 22f020e67a92..80c9944e2bbf 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -125,7 +125,7 @@ c10::intrusive_ptr RequestCallbackImpl::runPythonFunction( return asFuture(std::current_exception()); } - // After sync exection or failed async execution return the value as-is. + // After sync execution or failed async execution return the value as-is. 
if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) { return asFuture( c10::ivalue::ConcretePyObjectHolder::create(result), diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 9e16061e0ad4..10930ffa4134 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -78,7 +78,7 @@ c10::intrusive_ptr RequestCallbackNoPython::processMessage( // of 10us. auto serverProcessGlobalProfilerStateStackEntryPtr = profiler::processglobal::StateStackEntry::current(); - // If server global profiler is enabled, we futher pay the + // If server global profiler is enabled, we further pay the // cost of thread local profiler state initialization. if (serverProcessGlobalProfilerStateStackEntryPtr) { // Initialize thread-local profiler state from process-global diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 451ec4b7598c..e49c8f0b12d6 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -178,7 +178,7 @@ void RpcAgent::retryExpiredRpcs() { } // If there are no more RPC's set to be retried at the current timepoint, - // we can remove the corresponsing unordered_set from the retry map. + // we can remove the corresponding unordered_set from the retry map. if (earliestRpcList.empty()) { rpcRetryMap_.erase(earliestTimeout); } diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index df5563e4136d..12fb1314b103 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -32,7 +32,7 @@ using steady_clock_time_point = std::chrono::time_point; // Input is qualified name string, output is JIT StrongTypePtr // Same as jit::TypeResolver, did not import jit::TypeResolver to here -// because it could instroduce cyclic dependencies. +// because it could introduce cyclic dependencies. using TypeResolver = std::function; @@ -153,7 +153,7 @@ class TORCH_API RpcAgent { const DeviceMap& deviceMap = {}) = 0; // Retries sending the message up to maxRetries times until an ACK is - // receieved. The duration between consecutive sends is increased over + // received. The duration between consecutive sends is increased over // time using an exponential backoff algorithm. // // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a @@ -232,7 +232,7 @@ class TORCH_API RpcAgent { // Retrieve metrics as KV map virtual std::unordered_map getMetrics() = 0; - // Retrive debug info in addition to metrics as KV map + // Retrieve debug info in addition to metrics as KV map virtual std::unordered_map getDebugInfo(); // Flag to control whether GIL wait times diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index 70a2b31f6897..87ffd4f868e3 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -180,7 +180,7 @@ class TORCH_API RRefContext { // been confirmed (i.e. is no longer in the pendingUsers_ map). c10::intrusive_ptr getPendingUser(const ForkId& forkId); - // Start recroding new pending UserRRefs. All pending UserRRefs introduced + // Start recording new pending UserRRefs. All pending UserRRefs introduced // after this point will be put into the thread_local userTable_, which will // then be consumed and cleared in waitForThreadLocalPendingRRefs(). 
void recordThreadLocalPendingRRefs(); @@ -264,7 +264,7 @@ class TORCH_API RRefContext { RRefId::Hash> forks_; - // This cond var is used by deleteAllUsers(), a event notificaton is sent if + // This cond var is used by deleteAllUsers(), a event notification is sent if // number of pending UserRRef or UserRRef children is reduced, or // number of owned OwnerRRef is reduced. std::condition_variable deleteAllUsersCV_; diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index 4ce8066dfe1f..4c00fb7f235b 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -111,7 +111,7 @@ class TORCH_API PythonRRefFetchRet final : public RRefFetchRet { const Message& message); }; -// UserRRef (regardless it's the creator or not) uses this message to notiify +// UserRRef (regardless it's the creator or not) uses this message to notify // OwnerRRef on delete. class TORCH_API RRefUserDelete final : public ForkMessageBase { public: diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index 0ec1779702f9..460bc7352bd1 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -15,7 +15,7 @@ using torch::jit::Operator; // A ScriptRemoteCall instance represents an invocation of `dist.remote` on a // builtin operator. Currently, it does not support using RRef as arguments yet. // Besides the operator and a vector of arguments, ScriptRemoteCall also -// caontains the RRefId and the ForkId of the return value RRef. +// contains the RRefId and the ForkId of the return value RRef. class TORCH_API ScriptRemoteCall final : public ScriptCall { public: // Constructor for builitin operator call. diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 2db60ed59c6e..4a5ca5b6abdb 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -304,6 +304,9 @@ static PyObject* noargs = NULL; /* cached empty tuple */ static PyObject* dotzerokey = NULL; /* ".0" */ static PyObject* guard_fail_hook = NULL; static PyObject* guard_error_hook = NULL; +static PyObject* profiler_start_hook = NULL; +static PyObject* profiler_end_hook = NULL; +static PyObject* guard_profiler_name_str = NULL; /* cached py str */ size_t extra_index = -1; @@ -476,6 +479,27 @@ static PyObject* call_guard_fail_hook( return result; } +static PyObject* call_profiler_start_hook(PyObject* name_str) { + if (profiler_start_hook == NULL) return NULL; + if (name_str == NULL) return NULL; + PyObject* args = PyTuple_Pack(1, name_str); + if (args == NULL) return NULL; + PyObject* result = PyObject_CallObject(profiler_start_hook, args); + Py_DECREF(args); + return result; +} + +static void call_profiler_end_hook(PyObject* record) { + // 'record' obj is the return value of calling _start_hook() + if (profiler_end_hook == NULL) return; + if (record == NULL) return; + PyObject* args = PyTuple_Pack(1, record); + if (args == NULL) return; + PyObject* result = PyObject_CallObject(profiler_end_hook, args); + Py_XDECREF(result); + Py_DECREF(args); +} + // Return value: borrowed reference // Is either Py_None or a PyCodeObject static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev) { @@ -640,7 +664,11 @@ static PyObject* _custom_eval_frame( // we never compile. 
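The profiler hook globals and call helpers introduced above let the eval-frame shim time the guard/cache lookup; they are registered through set_profiler_hooks defined further below. One plausible way to wire them up from Python, using torch.profiler.record_function as the start/end pair (the import path and this wiring are assumptions for illustration, not the Python side of this patch):

import torch
from torch._C._dynamo import eval_frame as _eval_frame  # assumed import path

def _start(name):
    # Receives the cached name string ("TorchDynamo Cache Lookup") and must
    # return a record object that is later handed back to the end hook.
    rf = torch.profiler.record_function(name)
    rf.__enter__()
    return rf

def _end(record):
    record.__exit__(None, None, None)

_eval_frame.set_profiler_hooks(_start, _end)
# ... run torch.compile'd code under torch.profiler; cache lookups show up as ranges ...
_eval_frame.clear_profiler_hooks()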
if (callback == Py_False) { DEBUG_TRACE("In run only mode %s", name(frame)); + PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str); PyObject* maybe_cached_code = lookup(extra, frame, NULL); + call_profiler_end_hook(hook_record); + Py_XDECREF(hook_record); + if (maybe_cached_code == NULL) { // guard eval failed, keep propagating return NULL; @@ -662,7 +690,10 @@ static PyObject* _custom_eval_frame( // in the shim. eval_frame_callback_set(Py_None); + PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str); PyObject* maybe_cached_code = lookup(extra, frame, NULL); + call_profiler_end_hook(hook_record); + Py_XDECREF(hook_record); if (maybe_cached_code == NULL) { // Python error return NULL; @@ -840,6 +871,37 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* args) { Py_RETURN_NONE; } +static PyObject* clear_profiler_hooks(PyObject* dummy, PyObject* args) { + Py_XDECREF(profiler_start_hook); + profiler_start_hook = NULL; + Py_XDECREF(profiler_end_hook); + profiler_end_hook = NULL; + Py_XDECREF(guard_profiler_name_str); + guard_profiler_name_str = NULL; + Py_RETURN_NONE; +} + +static PyObject* set_profiler_hooks(PyObject* dummy, PyObject* args) { + PyObject* start = NULL; + PyObject* end = NULL; + if (!PyArg_ParseTuple(args, "OO", &start, &end)) { + return NULL; + } + Py_XDECREF(profiler_start_hook); + Py_XDECREF(profiler_end_hook); + if (start == Py_None || end == Py_None) { + clear_profiler_hooks(NULL, NULL); + } else { + profiler_start_hook = start; + profiler_end_hook = end; + Py_INCREF(profiler_start_hook); + Py_INCREF(profiler_end_hook); + } + Py_XDECREF(guard_profiler_name_str); + guard_profiler_name_str = Py_BuildValue("s", "TorchDynamo Cache Lookup"); + Py_RETURN_NONE; +} + static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame_py, METH_VARARGS, NULL}, {"reset_code", reset_code, METH_VARARGS, NULL}, @@ -847,6 +909,8 @@ static PyMethodDef _methods[] = { {"skip_code", skip_code, METH_VARARGS, NULL}, {"set_guard_fail_hook", set_guard_fail_hook, METH_VARARGS, NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_VARARGS, NULL}, + {"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL}, + {"clear_profiler_hooks", clear_profiler_hooks, METH_VARARGS, NULL}, {NULL, NULL, 0, NULL}}; static struct PyModuleDef _module = { diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp index 031946494945..e2146ed31ed1 100644 --- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp +++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp @@ -2000,6 +2000,154 @@ def conv_forwards(input: List[int], _11 = result_size return _11 +)=====") ++ std::string(R"=====(def stack(tensors: List[List[int]], + dim: int) -> List[int]: + _0 = "AssertionError: Tensors must have same number of dimensions" + _1 = "AssertionError: Sizes of tensors must match except in dimension" + unsqueezed_tensors = annotate(List[List[int]], []) + for _2 in range(torch.len(tensors)): + tensor = tensors[_2] + _3 = torch.add(torch.len(tensor), 1) + if torch.le(_3, 0): + dim_post_expr = 1 + else: + dim_post_expr = _3 + min = torch.neg(dim_post_expr) + max = torch.sub(dim_post_expr, 1) + if torch.lt(dim, min): + _4 = True + else: + _4 = torch.gt(dim, max) + if torch.__not__(_4): + pass + else: + ops.prim.RaiseException("AssertionError: ") + if torch.lt(dim, 0): + dim0 = torch.add(dim, dim_post_expr) + else: + dim0 = dim + unsqueezed = annotate(List[int], []) + for 
_5 in range(torch.len(tensor)): + elem = tensor[_5] + _6 = torch.append(unsqueezed, elem) + torch.insert(unsqueezed, dim0, 1) + _7 = torch.append(unsqueezed_tensors, unsqueezed) + for _8 in range(torch.len(unsqueezed_tensors)): + tensor0 = unsqueezed_tensors[_8] + if torch.gt(torch.len(tensor0), 0): + pass + else: + ops.prim.RaiseException("AssertionError: ") + out_dim: Optional[int] = None + for _9 in range(torch.len(unsqueezed_tensors)): + size = unsqueezed_tensors[_9] + if torch.eq(torch.len(size), 1): + _10 = torch.eq(size[0], 0) + else: + _10 = False + if torch.__not__(_10): + if torch.__is__(out_dim, None): + _11 = torch.len(size) + if torch.le(_11, 0): + dim_post_expr0 = 1 + else: + dim_post_expr0 = _11 + min0 = torch.neg(dim_post_expr0) + max0 = torch.sub(dim_post_expr0, 1) + if torch.lt(dim, min0): + _12 = True + else: + _12 = torch.gt(dim, max0) + if torch.__not__(_12): + pass + else: + ops.prim.RaiseException("AssertionError: ") + if torch.lt(dim, 0): + dim1 = torch.add(dim, dim_post_expr0) + out_dim2 = dim1 + else: + out_dim2 = dim + out_dim1 = out_dim2 + else: + out_dim1 = unchecked_cast(int, out_dim) + out_dim0 : Optional[int] = out_dim1 + else: + out_dim0 = out_dim + out_dim = out_dim0 + if torch.__is__(out_dim, None): + dim2 = dim + else: + dim2 = unchecked_cast(int, out_dim) + _13 = torch.gt(torch.len(unsqueezed_tensors), 0) + if _13: + pass + else: + ops.prim.RaiseException("AssertionError: ") + not_skipped_tensor: Optional[List[int]] = None + for _14 in range(torch.len(unsqueezed_tensors)): + tensor1 = unsqueezed_tensors[_14] + numel = 1 + for _15 in range(torch.len(tensor1)): + elem0 = tensor1[_15] + numel = torch.mul(numel, elem0) + if torch.eq(numel, 0): + _16 = torch.eq(torch.len(tensor1), 1) + else: + _16 = False + if torch.__not__(_16): + not_skipped_tensor0 : Optional[List[int]] = tensor1 + else: + not_skipped_tensor0 = not_skipped_tensor + not_skipped_tensor = not_skipped_tensor0 + _17 = torch.__is__(not_skipped_tensor, None) + if _17: + _18 = [0] + else: + not_skipped_tensor1 = unchecked_cast(List[int], not_skipped_tensor) + cat_dim_size = 0 + for i in range(torch.len(unsqueezed_tensors)): + tensor2 = unsqueezed_tensors[i] + numel0 = 1 + for _19 in range(torch.len(tensor2)): + elem1 = tensor2[_19] + numel0 = torch.mul(numel0, elem1) + if torch.eq(numel0, 0): + _20 = torch.eq(torch.len(tensor2), 1) + else: + _20 = False + if torch.__not__(_20): + first_dims = torch.len(not_skipped_tensor1) + second_dims = torch.len(tensor2) + _21 = torch.eq(first_dims, second_dims) + if _21: + pass + else: + ops.prim.RaiseException(_0) + _22 = torch.__range_length(0, first_dims, 1) + for _23 in range(_22): + dim3 = torch.__derive_index(_23, 0, 1) + if torch.ne(dim3, dim2): + _24 = torch.eq(not_skipped_tensor1[dim3], tensor2[dim3]) + if _24: + pass + else: + ops.prim.RaiseException(_1) + else: + pass + cat_dim_size1 = torch.add(cat_dim_size, tensor2[dim2]) + cat_dim_size0 = cat_dim_size1 + else: + cat_dim_size0 = cat_dim_size + cat_dim_size = cat_dim_size0 + result_size = annotate(List[int], []) + for _25 in range(torch.len(not_skipped_tensor1)): + elem2 = not_skipped_tensor1[_25] + _26 = torch.append(result_size, elem2) + _27 = torch._set_item(result_size, dim2, cat_dim_size) + _18 = result_size + return _18 + )=====") + std::string(R"=====(def permute(input: List[int], dims: List[int]) -> List[int]: @@ -2955,6 +3103,7 @@ const OperatorMap& GetShapeFunctionMappings() { {"aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? 
bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", "conv_transpose2d_input"}, {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"}, {"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"}, + {"aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "stack"}, {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"}, {"aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "movedim"}, {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"}, diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index e70574108218..1530ca5eebcc 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -679,7 +679,12 @@ void StaticModule::prepareFunctionsAndConstants( if (node->kind() == prim::Constant) { auto* v = node->output(); - TORCH_CHECK(v->type()->kind() != FunctionType::Kind); + TORCH_CHECK( + v->type()->kind() != FunctionType::Kind, + "got ", + typeKindToString(v->type()->kind()), + " instead of ", + typeKindToString(FunctionType::Kind)); value_to_index.emplace(v, constants_.size()); constants_.emplace_back(toIValue(v).value()); continue; @@ -948,10 +953,24 @@ void BlockRunner::set_inputs( if (!is_root_block_ || C10_UNLIKELY(!schema)) { TORCH_CHECK( - kwargs.empty(), "Schema is not available, but BlockRunner got kwargs."); + kwargs.empty(), + "BlockRunner got kwargs; is_root_block: ", + std::to_string(is_root_block_), + "schema: ", + schema ? schema->name() : "(not available)"); const auto total_num_inputs = args.size() + first_input_is_self_; - TORCH_CHECK(total_num_inputs == block_info_.num_inputs()); + TORCH_CHECK( + total_num_inputs == block_info_.num_inputs(), + "Block runner got ", + std::to_string(total_num_inputs), + " inputs; ", + " first_input_is_self: ", + std::to_string(first_input_is_self_), + "; SR block expects ", + std::to_string(block_info_.num_inputs()), + " inputs for schema ", + schema ? schema->name() : "(not available)"); for (size_t i = 0; i < args.size(); ++i) { set_arg(i, std::forward(args)); @@ -964,7 +983,12 @@ void BlockRunner::set_inputs( DCHECK(!schema_args.empty()); TORCH_CHECK( args.size() < schema_args.size(), - "Static runtime got too many arguments"); + "Static runtime got ", + std::to_string(args.size()), + " arguments, expects ", + std::to_string(schema_args.size() - 1), + " for schema ", + schema->name()); for (size_t i = 0; i < schema_args.size() - 1; ++i) { // Start at 1 since the schema always contains `self`. const auto& schema_arg = schema_args[i + 1]; @@ -990,9 +1014,20 @@ void BlockRunner::set_inputs( } TORCH_CHECK( - false, "Static runtime is missing required kwarg ", schema_arg.name()); + false, + "Static runtime is missing required kwarg ", + schema_arg.name(), + " for schema ", + schema->name()); } - TORCH_CHECK(consumed_kwargs == kwargs.size()); + TORCH_CHECK( + consumed_kwargs == kwargs.size(), + "kwargs size mismatch (consumed ", + std::to_string(consumed_kwargs), + ", expected ", + std::to_string(kwargs.size()), + " for schema ", + schema->name()); } void BlockRunner::create_memory_planner() { @@ -1277,7 +1312,7 @@ c10::intrusive_ptr BlockRunner::run_impl_async( const KeywordArgs& kwargs) { // run the graph inline in the caller thread. 
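Returning to the aten::stack shape function serialized above: stripped of its zero-numel special cases, it amounts to checking that every input has identical sizes and inserting a new dimension of length len(tensors) at dim. A simplified Python paraphrase of that rule (an illustration only, not the registered TorchScript):

from typing import List

def stack_shape(tensors: List[List[int]], dim: int) -> List[int]:
    assert len(tensors) > 0
    first = tensors[0]
    for sizes in tensors[1:]:
        assert sizes == first, "Sizes of tensors must match except in dimension"
    ndim = len(first) + 1          # stacking adds one dimension
    if dim < 0:
        dim += ndim                # negative dims wrap as usual
    out = list(first)
    out.insert(dim, len(tensors))  # new dimension has length len(tensors)
    return out

print(stack_shape([[2, 3], [2, 3], [2, 3], [2, 3]], dim=1))  # [2, 4, 3]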
Async ops will be // executed on taskLauncher attached to the metadata of ProcessedNodes - c10::IValue output = run_impl(args, kwargs); + c10::IValue output = run_impl(std::forward(args), kwargs); // If the output is of type future, return it if (output.isFuture()) { @@ -1958,7 +1993,14 @@ StaticNodeInfo::StaticNodeInfo( fn_(fn), inputs_(std::move(inputs)), outputs_offset_(outputs_offset) { - TORCH_CHECK(num_outputs() == node->outputs().size()); + TORCH_CHECK( + num_outputs() == node->outputs().size(), + "Node ", + node->kind().toQualString(), + " has ", + std::to_string(num_outputs()), + " outputs, expected ", + std::to_string(node->outputs().size())); } std::vector ProcessedNode::inputs_ivalue_vec() const { diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 1ecdaf2a7d77..1277d66ce5fe 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -600,9 +600,9 @@ void Pickler::endTypeTag(const IValue& ivalue) { TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList()); // Push the dict type - TORCH_INTERNAL_ASSERT(ivalue.type()); - auto type = ivalue.type(); + TORCH_INTERNAL_ASSERT(type); + auto annot_str = type->annotation_str(type_printer); pushString(annot_str); diff --git a/torch/csrc/quantized/quantized_backward.cpp b/torch/csrc/quantized/quantized_backward.cpp new file mode 100644 index 000000000000..a4d94def9ad8 --- /dev/null +++ b/torch/csrc/quantized/quantized_backward.cpp @@ -0,0 +1,77 @@ +#include +#include +#include + +namespace { +using namespace torch::autograd; +using namespace at; +// This class is a custom gradient function that enables quantized tensor to +// pass input gradient back to the previous layers This function can be used +// when the user is adapting mixed precision for traninig after quantization +// From torch layer, we have no access to linear_dynamic operator which needs to +// access via redispatching mechanism TO-DO : currently we are supporting per +// tensor quantization only, will expand to per channel later on +class PackedLinearWeightDynamicBackward + : public Function { + public: + static torch::Tensor forward( + AutogradContext* ctx, + at::Tensor input, + const c10::intrusive_ptr& packed_weight, + bool reduce_range) { + static auto op = + at::Dispatcher::singleton() + .findSchemaOrThrow("quantized::linear_dynamic", "") + .typed> const&, + bool)>(); + auto output = op.redispatch( + DispatchKeySet({DispatchKey::CPU}), input, packed_weight, reduce_range); + // TO-DO: passing packed_weight as saved_data requires more work in adding + // LinearPackedParamsBase in ivalue For now, we can simply pass a weight + // itself. 
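The new quantized_backward.cpp above registers an Autograd-key kernel for quantized::linear_dynamic so that a dynamically quantized linear layer can propagate a gradient back to its float input. A sketch of the intended effect, assuming per-tensor dynamic quantization via the standard workflow (the module names below come from that workflow, not from this patch):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 4))
qmodel = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

x = torch.randn(2, 8, requires_grad=True)
out = qmodel(x)            # dispatches to quantized::linear_dynamic
out.sum().backward()       # backward now reaches the input via the unpacked weight
print(x.grad.shape)        # expected: torch.Size([2, 8])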
Referenced : + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/ivalue.h + auto unpacked_parameters = packed_weight->unpack(); + ctx->saved_data["weight"] = std::get<0>(unpacked_parameters); + return output; + } + + static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) { + auto original_weight = ctx->saved_data["weight"].toTensor(); + original_weight = at::permute(original_weight, {1, 0}); + auto grad_output = grad_outputs[0]; + static auto op = at::Dispatcher::singleton() + .findSchemaOrThrow("quantized::linear_prepack", "") + .typed( + at::Tensor, c10::optional)>(); + auto prepacked_weight = op.call(original_weight, nullopt); + auto grad_input = prepacked_weight->apply_dynamic(grad_output); + return {grad_input, torch::Tensor(), torch::Tensor()}; + } +}; + +at::Tensor packed_linear_weight_grad( + c10::DispatchKeySet ks, + at::Tensor input, + const c10::intrusive_ptr& packed_weight, + bool reduce_range) { + return PackedLinearWeightDynamicBackward::apply( + input, packed_weight, reduce_range); +} +} // namespace + +namespace at { +namespace native { +namespace { +TORCH_LIBRARY_IMPL(quantized, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), + TORCH_FN(packed_linear_weight_grad)); +} +} // namespace +} // namespace native +} // namespace at diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index a6e19d851930..dfb33da300bf 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -5,7 +5,7 @@ import subprocess import json from functools import lru_cache -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Any, Dict cache = lru_cache(None) @@ -124,8 +124,10 @@ def calc_active(seg): def _report_free(free_external, free_internal): total = free_external + free_internal - pct = (free_internal / total) * 100 - suffix = f' ({pct:.1f}% internal)' + suffix = '' + if total != 0: + pct = (free_internal / total) * 100 + suffix = f' ({pct:.1f}% internal)' return f'{Bytes(total)}{suffix}' PAGE_SIZE = 1024 * 1024 * 20 @@ -385,6 +387,15 @@ def to_html(): self.to_html = to_html self.categories = categories +def _choose_device(data, device): + if device is None: + for i, t in enumerate(data['device_traces']): + if len(t) > 0: + if device is not None: + raise ValueError(f'Both device {device} and {i} have traces, use --device to specify which trace.') + device = i + return device + def trace_plot(data, device=None, plot_segments=False): """Generate a visualization over time of the memory usage recorded by the trace as an html file. 
@@ -399,38 +410,39 @@ def trace_plot(data, device=None, plot_segments=False): """ w = PlotWriter() addr_to_alloc = {} - + device = _choose_device(data, device) if device is None: - for i, t in enumerate(data['device_traces']): - if len(t) > 0: - if device is not None: - raise ValueError(f'Both device {device} and {i} have traces, use --device to specify which trace.') - device = i - if device is None: - raise ValueError('No trace information was recorded.') + raise ValueError('No trace information was recorded.') trace = data['device_traces'][device] if plot_segments: + addr_prefix = 's' alloc = 'segment_alloc' free = 'segment_free' else: + addr_prefix = 'b' alloc = 'alloc' free = 'free_completed' - def add_element(size, frames, extra=()): - frames = [f"{_format_size(size)} allocation", *extra, *(_frame_fmt(f, full_filename=True) for f in frames)] + addr_versions: Dict[int, int] = {} + + def add_element(addr, size, frames, extra=()): + next_version = addr_versions[addr] = addr_versions.get(addr, 0) + 1 + frames = [f"{addr_prefix}{addr:x}_{next_version - 1} {_format_size(size)} allocation ({size} bytes)", + *extra, + *(_frame_fmt(f, full_filename=True) for f in frames)] return w.add_element(size, frames) for i, e in enumerate(trace): if e['action'] == alloc: - elemid = add_element(e['size'], e.get('frames', [])) + elemid = add_element(e['addr'], e['size'], e.get('frames', [])) addr_to_alloc[e['addr']] = elemid w.allocate(elemid) elif e['action'] == free: idx = addr_to_alloc.pop(e['addr'], None) if idx is None: - idx = add_element(e['size'], e.get('frames', []), extra=('alloc not recorded, stack trace for free:',)) + idx = add_element(e['addr'], e['size'], e.get('frames', []), extra=('alloc not recorded, stack trace for free:',)) w.initially_allocated(idx) w.free(idx) return w.to_html() @@ -481,7 +493,8 @@ def add_element(size, tensor_key, version): stack = allocation_stacks.get(tensor_key, ()) assert w.categories is not None return w.add_element(size, - [f"{_format_size(size)} allocation ({w.categories[category]})", *(p.name for p in stack)], + [f"{_format_size(size)} ({size} bytes) allocation ({w.categories[category]})", + *(p.name for p in stack)], category) kv_to_elem = {} @@ -824,6 +837,586 @@ def add_element(size, tensor_key, version): """ +def segment_plot(data: Any, device=None): + device = _choose_device(data, device) + if device is None: + trace = [] + else: + trace = data['device_traces'][device] + + string_table: List[str] = [] + suffix_table: List[Tuple[int, Optional[int]]] = [] + + @cache + def intern_str(s): + string_table.append(s) + return len(string_table) - 1 + + @cache + def intern_suffix(sid, restid): + suffix_table.append((sid, restid)) + return len(suffix_table) - 1 + + def intern_stack(frames): + next_id = None + for f in reversed(frames): + next_id = intern_suffix(intern_str(f), next_id) + return next_id + + def format_frames(frames): + return intern_stack([_frame_fmt(f, full_filename=True) for f in frames]) + + result: Any = { + 'string_table': string_table, + 'suffix_table': suffix_table, + 'events': { + 'action': [], # reference to string table + 'addr': [], # for OOM, this will hold device_free value + 'size': [], + 'stream': [], + 'frames': [] # reference to suffix_table + }, + 'segments': { + 'addr': [], + 'size': [], + 'stream': [] + }, + 'blocks': { + 'addr': [], + 'size': [], + 'real_size': [], + 'frames': [], # reference to string table + 'pending_free': [], + } + } + + def fold_free(ts): + # turn a free_requested/free_completed pair into a single free 
event + i = 0 + while i < len(ts): + t = ts[i] + if i + 1 < len(ts): + tnext = ts[i + 1] + if t['action'] == 'free_requested' and tnext['action'] == 'free_completed' and t['addr'] == tnext['addr']: + yield {**t, 'action': 'free'} + i += 2 + continue + if t['action'] == 'oom': + yield {**t, 'addr': t['device_free']} + else: + yield t + i += 1 + + preproc: Any = { + 'action': intern_str, + 'frames': format_frames, + } + + events: Any = result['events'] + for event in fold_free(trace): + for k in events.keys(): + # stack frames not recorded on event + # happens for snapshot even when + # frames are recorded for other things. + if k == 'frames' and k not in event: + events[k].append(None) + continue + events[k].append(preproc.get(k, lambda x: x)(event[k])) + + segments = result['segments'] + blocks = result['blocks'] + + segment_names = { + 'addr': 'address', + 'size': 'total_size', + } + + for seg in data['segments']: + for k in segments.keys(): + sk = segment_names.get(k, k) + segments[k].append(preproc.get(k, lambda x: x)(seg[sk])) + addr = seg['address'] + for b in seg['blocks']: + if b['state'] in ('active_pending_free', 'active_allocated'): + if 'history' in b: + frames = b['history'][0].get('frames', []) + real_size = b['history'][0]['real_size'] + else: + real_size = b['size'] + frames = [] + blocks['addr'].append(addr) + blocks['size'].append(b['size']) + blocks['real_size'].append(real_size) + blocks['frames'].append(format_frames(frames)) + blocks['pending_free'].append(1 if b['state'] == 'active_pending_free' else 0) + addr += b['size'] + + plot_data = json.dumps(result) + return _events_template.replace('$PLOT_DATA', plot_data) + +_events_template = r""" + + + + + + + + + +""" + if __name__ == "__main__": import os.path thedir = os.path.realpath(os.path.dirname(__file__)) @@ -866,15 +1459,20 @@ def _output(p): compare_a.add_argument('after', help=pickled) _output(compare_a) - description = "Generate a visualization over time of the memory usage recorded by the trace as an html file." 
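Both plotting entry points wired into the CLI here (trace_plot, plus the new segment_plot added above) take a pickled allocator snapshot and return an HTML string. A small usage sketch of the Python API, assuming a snapshot captured beforehand with torch.cuda.memory._record_memory_history(...) and torch.cuda.memory._snapshot() (those helpers are named from memory and sit outside this patch):

import pickle
from torch.cuda._memory_viz import trace_plot, segment_plot

with open("snapshot.pkl", "rb") as f:   # pickled snapshot, e.g. from torch.cuda.memory._snapshot()
    data = pickle.load(f)

# Per-allocation timeline; pass plot_segments=True to plot segment-level events instead.
with open("trace.html", "w") as f:
    f.write(trace_plot(data, device=None, plot_segments=False))

# How live blocks are packed into allocator segments across the trace.
with open("segments.html", "w") as f:
    f.write(segment_plot(data, device=None))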
- trace_plot_a = subparsers.add_parser('trace_plot', description=description) - trace_plot_a.add_argument('input', help=pickled) - help = 'visualize trace from this device (default: chooses the only device with trace info or errors)' - trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help) - help = 'path to save the visualization(default: output.html)' - trace_plot_a.add_argument('-o', '--output', default='output.html', help=help) - help = 'visualize change to segments rather than individual allocations' - trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help) + plots = ( + ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."), + ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.") + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument('input', help=pickled) + help = 'visualize trace from this device (default: chooses the only device with trace info or errors)' + trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help) + help = 'path to save the visualization(default: output.html)' + trace_plot_a.add_argument('-o', '--output', default='output.html', help=help) + if cmd == "trace_plot": + help = 'visualize change to segments rather than individual allocations' + trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help) args = parser.parse_args() @@ -912,3 +1510,6 @@ def _write(name, data): elif args.action == 'trace_plot': data = _read(args.input) _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments)) + elif args.action == 'segment_plot': + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index a85b129d6d07..668b292085e4 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -72,11 +72,12 @@ def _register_tensor_work(tensor, work): data_ptr_to_work[tensor.data_ptr()] = (work_version, work) work_version += 1 -def _clear_tensor(data_ptr, version): +def _wait_and_clear_tensor(data_ptr, version): global data_ptr_to_work version_and_work = data_ptr_to_work.get(data_ptr) if version_and_work is not None and version_and_work[0] == version: + version_and_work[1].wait() del data_ptr_to_work[data_ptr] def _register_wrapper_tensor(tensor_wrapper, tensor): @@ -87,15 +88,15 @@ def _register_wrapper_tensor(tensor_wrapper, tensor): "Trying to register finalizers to AsyncCollectiveTensor but the inner tensor is already gone" ) else: - weakref.finalize(tensor_wrapper, _clear_tensor, tensor.data_ptr(), version) + # We force the collective to be waited in the case this tensor goes away to reduce the change of deadlocks. 
+ weakref.finalize(tensor_wrapper, _wait_and_clear_tensor, tensor.data_ptr(), version) def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor: global data_ptr_to_work data_ptr = tensor.data_ptr() version_and_work = data_ptr_to_work.get(data_ptr) if version_and_work is not None: - version_and_work[1].wait() - _clear_tensor(data_ptr, version_and_work[0]) + _wait_and_clear_tensor(data_ptr, version_and_work[0]) return tensor diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index cd318103550f..c3f8cd2aa64c 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -5,7 +5,6 @@ from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ( ShardedTensor, - _PartialTensor ) from .sharding_spec import ( ShardingSpec, @@ -126,7 +125,7 @@ def shard_parameter( @contextmanager def load_with_process_group(process_group): """ - Context manager to set the process group with which to load a ShardedTensor/ReplicatedTensor. + Context manager to set the process group with which to load a ShardedTensor. """ global _CURRENT_PROCESS_GROUP if _CURRENT_PROCESS_GROUP is not None: @@ -166,7 +165,7 @@ def _reshard_output( A :class:`torch.nn.Module` object with reshard API hooked. """ def hook_func(_module, _input, output): - if isinstance(output, (ShardedTensor, _PartialTensor)): + if isinstance(output, ShardedTensor): return output.reshard(resharding_spec) return output module.register_forward_hook(hook_func) diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index 7ef88965eecb..dd23f0a766e3 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ -6,7 +6,6 @@ def _basic_validation(op, args=(), kwargs=None): """ Common validation across all ops go in here. 
""" - from torch.distributed._shard.partial_tensor import _PartialTensor from torch.distributed._shard.sharded_tensor import ShardedTensor if len(args) == 0 and (kwargs is None or len(kwargs) == 0): @@ -17,7 +16,7 @@ def _basic_validation(op, args=(), kwargs=None): def is_distributed_tensor(e): nonlocal has_distributed_tensor - if isinstance(e, (_PartialTensor, ShardedTensor)): + if isinstance(e, ShardedTensor): has_distributed_tensor = True tree_map(is_distributed_tensor, args) @@ -34,7 +33,7 @@ def is_distributed_tensor(e): def validate_pg(e): nonlocal cur_pg - if isinstance(e, (_PartialTensor, ShardedTensor)): + if isinstance(e, ShardedTensor): if cur_pg is not None and e._process_group is not cur_pg: raise RuntimeError( 'All distributed tensors should use the ' diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py deleted file mode 100644 index 9c1aefbf2d3f..000000000000 --- a/torch/distributed/_shard/partial_tensor.py +++ /dev/null @@ -1,321 +0,0 @@ -import functools -import warnings -from typing import Callable, Dict, TYPE_CHECKING - -import torch -import torch.distributed as dist -import torch.distributed._shard.sharding_spec as shard_spec -from torch.distributed._shard._utils import ( - DEPRECATE_MSG, -) -from torch.distributed import distributed_c10d -from torch.distributed.nn.functional import ( - reduce_scatter, -) -from torch.distributed._shard.common_op_utils import _register_default_op -from torch.distributed._shard.op_registry_utils import _decorator_func -from torch.utils._pytree import tree_map - -if TYPE_CHECKING: - # Only include ShardedTensor when do type checking, exclude it - # from run-time to resolve circular dependency. - from torch.distributed._shard.sharded_tensor import ShardedTensor - -# Custom PartialTensor ops -_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {} - -def _custom_partial_tensor_op(func): - """ - Decorate for custom partial tensor op - Args: - func(Callable): Torch function for which we want to provide a PartialTensor - implementation (ex: torch.nn.functional.linear) - """ - return functools.partial( - _decorator_func, - op=func, - op_table=_PARTIAL_TENSOR_OPS - ) - -class _PartialTensor(torch.Tensor): - """ - PartialTensor is an abstraction to represent Tensors that need - aggregation across multiple devices and multiple processes. - - PartialTensor is initialized in an SPMD like fashion where each rank - initializes the PartialTensor. The PartialTensor object on each rank - then only stores the local partial shard, process group and the - aggregation way to get a full tensor. - - PartialTensor doesn't provide any Tensor like operations but is a - wrapper providing the Tensor representing the local partial shard. - - We assume the size of each local tensor to be exactly the same. - - Users can apply custom distributed sharded computations on top of - this primitive. - - Args: - local_partial_shard (Tensor): Partial result stored across ranks. - process_group (ProcessGroup): The process group to aggregate on. - reduce_op (distributed_c10d.ReduceOp): Way to aggregate the partial result. - Default: ``distributed_c10d.ReduceOp.SUM`` - - Examples: - >>> # All tensors below are of torch.int64 type. - >>> # We have 2 process groups, 2 ranks. 
- >>> # xdoctest: +SKIP - >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank - >>> tensor = torch.cat([tensor, tensor + 2]) - >>> tensor - tensor([1, 2, 3, 4]) # Rank 0 - tensor([3, 4, 5, 6]) # Rank 1 - >>> partial_tensor = _PartialTensor(tensor, distributed_c10d.ReduceOp.MAX) - >>> sharding_dim = 0 - >>> collect_spec = shard_spec.ChunkShardingSpec( - dim=sharding_dim, - placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - ], - ) - >>> complete_tensor = partial_tensor.reshard(collect_spec) - >>> complete_tensor - ShardedTensor( - ShardedTensorMetadata( - shards_metadata=[ - ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), - ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)], - size=torch.Size([4]) - ) - >>> complete_tensor.local_tensor() - tensor([3, 4]) # Rank 0 - tensor([5, 6]) # Rank 1 - - >>> # All tensors below are of torch.cfloat type. - >>> # We have 2 process groups, 2 ranks. - >>> tensor = torch.tensor([1, 2]) + 2 * rank - >>> tensor = torch.cat([tensor, tensor + 2]) - >>> tensor - tensor([1, 2, 3, 4]) # Rank 0 - tensor([3, 4, 5, 6]) # Rank 1 - >>> partial_tensor = _PartialTensor(tensor) - >>> complete_tensor = partial_tensor.reshard(collect_spec) - >>> complete_tensor - ShardedTensor( - ShardedTensorMetadata( - shards_metadata=[ - ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), - ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)], - size=torch.Size([4]) - ) - >>> complete_tensor.local_tensor() - tensor([4, 6]) # Rank 0 - tensor([8, 10]) # Rank 1 - """ - - _process_group: distributed_c10d.ProcessGroup - _local_shard: torch.Tensor - _reduce_op: distributed_c10d.ReduceOp - - __slots__ = ["_process_group", "_local_shard", "_reduce_op"] - - def __new__(cls, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM): - warnings.warn(DEPRECATE_MSG) - r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, - local_shard.size(), - dtype=local_shard.dtype, - layout=local_shard.layout, - pin_memory=local_shard.is_pinned(), - requires_grad=local_shard.requires_grad) # type: ignore[arg-type] - r._process_group = ( # type: ignore[attr-defined] - process_group - if process_group is not None - else distributed_c10d._get_default_group() - ) - r._reduce_op = reduce_op - r._local_shard = local_shard - return r - - def __post_init__(self): - if not isinstance(self._reduce_op, distributed_c10d.ReduceOp): - raise ValueError( - "reduce_op needs to be a member of distributed_c10d.ReduceOp." - ) - - def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> "ShardedTensor": - """ - The reshard happens in two steps logically: - - 1. Aggregate all the shards of the partial tensor. - 2. Shard this tensor according to the provided spec. - - In reality, for the sake of performance, we consolidate all partial tensors - across multiple ranks and covert to a sharded tensor in one step. - - Args: - resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): - The specification describing how we reshard the aggregated local result. - - Returns: - A :class:`ShardedTensor` filled with local aggregated result. 
- """ - from torch.distributed._shard.sharded_tensor.api import ShardedTensor - - warnings.warn(DEPRECATE_MSG) - if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec): - raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") - if self._local_shard.is_complex(): - raise NotImplementedError("Only real partial tensor supported for reshard.") - sharding_dim = int(resharding_spec.dim) # type: ignore[attr-defined] - chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size() - local_shard = self._local_shard - # Add padding when the size is not divisible by the world size. - if chunk_mode_res != 0: - padding = [0] * (local_shard.dim() * 2) - padding[-1] = self._process_group.size() - chunk_mode_res - local_shard = torch.nn.functional.pad( - local_shard, - tuple(padding), - "constant", - 0, - ) - current_rank = dist.get_rank(self._process_group) # type: ignore[attr-defined] - rank_idx = None - rearrange_local_shards = False - indices = [0] * self._process_group.size() - for idx, placement in enumerate(resharding_spec.placements): # type: ignore[attr-defined] - if placement.rank() == current_rank: # type: ignore[index, union-attr] - rank_idx = idx # type: ignore[attr-defined] - if placement.rank() != idx: # type: ignore[index, union-attr] - rearrange_local_shards = True - indices[placement.rank()] = idx # type: ignore[index, union-attr] - - local_shards = local_shard.chunk(self._process_group.size(), dim=sharding_dim) - if rearrange_local_shards: - # Need to re-arrange original shard_dim of output_tensor_list. - local_shards = [local_shards[idx] for idx in indices] # type: ignore[call-overload] - local_result = reduce_scatter( - torch.empty_like(local_shards[0]), - list(local_shards), - op=self._reduce_op, - group=self._process_group, - ) - - sharded_tensor_size = self._local_shard.size() - # Remove padding when the size is not divisible by the world size. - if chunk_mode_res != 0: - uneven_local_shards = self._local_shard.chunk( - self._process_group.size(), dim=sharding_dim - ) - expected_size = uneven_local_shards[rank_idx].size() # type: ignore[index] - if local_result.size() != expected_size: - local_result = local_result.narrow( - sharding_dim, - 0, - expected_size[sharding_dim], - ) - return ShardedTensor._init_from_local_tensor( - local_result, - resharding_spec, - sharded_tensor_size, - process_group=self._process_group, - ) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - warnings.warn(DEPRECATE_MSG) - # Find process_group - process_group = None - - def find_process_group(e): - nonlocal process_group - if process_group is None and isinstance(e, _PartialTensor): - process_group = e._process_group - - tree_map(find_process_group, args) - tree_map(find_process_group, kwargs) - - if func in _PARTIAL_TENSOR_OPS: - return _PARTIAL_TENSOR_OPS[func](types, args, kwargs, process_group) - - # Need to disable all dispatch to print args and kwargs appropriately. 
- guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] - try: - with torch._C.DisableTorchFunctionSubclass(): - raise RuntimeError( - f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for PartialTensor!") - finally: - del guard - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - raise RuntimeError( - f"A {cls.__name__} object is being used from c++ " - f"while calling {func.__module__}.{func.__name__} " - "but the there is no custom __torch_dispatch__ implementation for it." - ) - - def __repr__(self): - return f"PartialTensor({super().__repr__()})" - -def _transpose_impl(types, args=(), kwargs=None, process_group=None): - partial_tensor = args[0] - input = partial_tensor._local_shard - dim0 = args[1] - dim1 = args[2] - return _PartialTensor( - torch.transpose(input, dim0, dim1), - process_group, - partial_tensor._reduce_op - ) - -@_custom_partial_tensor_op(torch.Tensor.transpose) -def partial_transpose(types, args=(), kwargs=None, process_group=None): - return _transpose_impl(types, args, kwargs, process_group) - -@_custom_partial_tensor_op(torch.transpose) -def partial_torch_transpose(types, args=(), kwargs=None, process_group=None): - return _transpose_impl(types, args, kwargs, process_group) - -@_custom_partial_tensor_op(torch.cat) -def partial_cat(types, args=(), kwargs=None, process_group=None): - input_list = args[0] - if len(input_list) == 0: - raise RuntimeError('Empty list of tensors to torch.cat!') - - local_shards = [] - for idx, input in enumerate(input_list): - if not isinstance(input, _PartialTensor): - raise RuntimeError('All inputs need to be an instance of _PartialTensor') - if idx == 0: - reduce_op = input._reduce_op - elif reduce_op != input._reduce_op: - raise RuntimeError( - 'All _PartialTensor reduce_ops need to be the same, found: ' - '{reduce_op} and {input._reduce_op}' - ) - - local_shards.append(input._local_shard) - - if kwargs is None: - dim = 0 - else: - if 'out' in kwargs: - raise RuntimeError('"out" kwarg is not supported!') - dim = kwargs['dim'] if 'dim' in kwargs else 0 - - return _PartialTensor(torch.cat(local_shards, dim), process_group, input._reduce_op) - -# Tensor properties access -_register_default_op(torch.Tensor.requires_grad.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined] -_register_default_op(torch.Tensor.shape.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined] -_register_default_op(torch.Tensor.dtype.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined] -_register_default_op(torch.Tensor.layout.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined] -_register_default_op(torch.Tensor.size, _custom_partial_tensor_op) -_register_default_op(torch.Tensor.dim, _custom_partial_tensor_op) -_register_default_op(torch.Tensor.ndim.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined] -_register_default_op(torch.Tensor.is_contiguous, _custom_partial_tensor_op) -_register_default_op(torch.Tensor.contiguous, _custom_partial_tensor_op) diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py deleted file mode 100644 index 6a4217940d82..000000000000 --- a/torch/distributed/_shard/replicated_tensor.py +++ /dev/null @@ -1,173 +0,0 @@ -import warnings -import torch -import torch.distributed as dist - -from torch.distributed._shard.sharded_tensor.api import ShardedTensor -from torch.distributed._shard._utils import ( - DEPRECATE_MSG, -) -from torch.distributed 
import distributed_c10d -from torch.overrides import get_default_nowrap_functions - -_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [ - # List of ops where if parameters are a combination of ReplicatedTensors - # and non-tensors, we can still return a ReplicatedTensor as the result. - torch.unsqueeze, - torch.Tensor.unsqueeze, - torch.Tensor.__getitem__, -] - -warnings.warn(DEPRECATE_MSG) - -class ReplicatedTensor(torch.Tensor): - """ - ReplicatedTensor represents a tensor which is replicated across the `world_size` and - has the same value on each rank. - - ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it could be used together - with ShardedTensor/Tensor together to express different types of computation. The - inter-op rules defined as (using torch.add as an example op): - ReplicatedTensor + ReplicatedTensor = ReplicatedTensor - ReplicatedTensor + torch.Tensor = torch.Tensor - ReplicatedTensor + ShardedTensor = ShardedTensor - ReplicatedTensor + other type (i.e. Scalar) = other type - - NOTE: We do not gurantee equal content of ReplicatedTensor across nodes after its - construction. Although we defined proper inter-op rules to make sure ReplicatedTensor - stays the same, there's no enforcement on it (i.e. if you manually modify content on - some ranks, the modified value will not automatically get synced to other nodes). If - you wish to manually validate tensors are the same across ranks, use `validate()`. - - """ - _process_group: distributed_c10d.ProcessGroup - - __slots__ = ["_process_group"] - - def __new__(cls, data=None, process_group=None): - if data is None: - data = torch.empty(0) - r = torch.Tensor._make_subclass(cls, data, data.requires_grad) # type: ignore[arg-type] - r._process_group = ( # type: ignore[attr-defined] - process_group - if process_group is not None - else distributed_c10d._get_default_group() - ) - return r - - def __deepcopy__(self, memo): - if id(self) in memo: - return memo[id(self)] - else: - result = type(self)(self.data.clone(memory_format=torch.preserve_format), self._process_group) - memo[id(self)] = result - return result - - def __repr__(self): - return f"ReplicatedTensor({super().__repr__()})" - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - # We will re-dispatch the execution to ShardedTensor __torch_function__ - # if we find there're ShardedTensor operands. We will also check if args/kwargs - # are all replicated tensor operands, we have to do this to ensure we do not - # converting results back to ReplicatedTensor if not all operands are replicated. - all_replicated = True - replicated_with_non_tensor = True - replicated_pg = None - - def dispatch_arg(arg): - # This function returns a tuple, first element represents whether the op been - # executed, the second element represents the result of the execution - nonlocal replicated_pg, all_replicated, replicated_with_non_tensor - if isinstance(arg, ShardedTensor): - # redispatch to ShardedTensor - # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor - return True, arg.__torch_function__(func, types, args, kwargs) - if isinstance(arg, ReplicatedTensor): - if replicated_pg is None: - replicated_pg = arg._process_group - elif replicated_pg != arg._process_group: - raise RuntimeError( - f"ReplicatedTensor operands must be in the same process group " - f"in torch function '{func.__name__}', but found at least two " - f"ReplicatedTensor operands in different process groups! 
") - elif isinstance(arg, torch.Tensor): - replicated_with_non_tensor = False - all_replicated = False - else: - all_replicated = False - - return False, None - - for arg in args: - redispatched, res = dispatch_arg(arg) - if redispatched: - return res - - if kwargs is not None: - for k, v in kwargs.items(): - redispatched, res = dispatch_arg(v) - if redispatched: - return res - - # We cann't do super().__torch_function__() as it implicitly convert the result - # back to tensor subclasses, where in our case, we need to control the output type - # base on the inter-op rules we defined. - with torch._C.DisableTorchFunctionSubclass(): - rs = func(*args, **kwargs) - if func in get_default_nowrap_functions(): - return rs - - result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor) - should_convert_to_replicated = all_replicated or ( - replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST - ) - if result_not_replicated and should_convert_to_replicated: - # if all operands are ReplicatedTensors and does not get dispatched to ShardedTensor - # __torch_function__, result is a torch.Tensor, then we convert and return a - # ReplicatedTensor according to our inter-op rule - rs = rs.as_subclass(ReplicatedTensor) # type: ignore[arg-type] - # propagate the process_group field to result - rs._process_group = replicated_pg # type: ignore[attr-defined] - - return rs - - def validate(self) -> bool: - """ - Validate the ReplicatedTensor is legit by all gathering tensors on all ranks - and check to make sure they are the same. - - If there's some ranks with different values, a ValueError will be raised. - - Keyword args: - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - - Returns: - True if validation succeed. 
- """ - world_size = dist.get_world_size(self._process_group) - current_rank = dist.get_rank(self._process_group) - - tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)] - - dist.all_gather(tensors_on_rank, self, group=self._process_group) - # validate and check if all tensors are equal - for rank, tensor in enumerate(tensors_on_rank): - if not torch.allclose(self, tensor): - raise ValueError( - f"ReplicatedTensor have different values on rank {current_rank} and {rank}") - - return True - - def __setstate__(self, state): - with torch._C.DisableTorchFunctionSubclass(): - self.data = state - self.requires_grad = state.requires_grad - from torch.distributed._shard.api import _get_current_process_group - self._process_group = _get_current_process_group() - - def __getstate__(self): - return self.data diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index cba094076f6b..386e53885a9f 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -1,12 +1,10 @@ # coding=utf-8 -import copy import functools from typing import List import torch import torch.distributed._shard.sharding_spec as shard_spec -from torch.distributed._shard.partial_tensor import _PartialTensor from .api import ( _CUSTOM_SHARDED_OPS, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py index 1243a1d2396d..7fdb56048a11 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -9,7 +9,6 @@ from .init import kaiming_uniform_, normal_, uniform_, constant_ # Import all ChunkShardingSpec ops -from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.softmax import sharded_softmax diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py deleted file mode 100644 index e38f1dc15e7c..000000000000 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import List - -import torch -import torch.distributed as dist -from torch.autograd import Function -from torch.distributed.nn.functional import ( - _all_gather_base, - all_to_all_single, -) -from torch.distributed._shard.partial_tensor import _PartialTensor -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor, -) -from torch.distributed._shard.sharding_spec import ChunkShardingSpec -from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op -from torch.distributed._shard.sharding_spec._internals import ( - get_split_size, - get_chunked_dim_size, - get_chunk_sharding_params, -) - -from ._common import ( - _result_distribute_with_col_rearrange, -) - - -@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear) -def sharded_linear(types, args, kwargs, pg): - """ - Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a sharded linear and has the following limitations: - - 1. Supports only sharding of ``weight``. 
- 2. Supports only ``ChunkShardingSpec``. - 3. Supports only a single local shard per rank. - 4. Tailored for Megatron-LM style model(tensor) parallelism. Further API - calls are needed if a fully synced local tensor is needed. - Megatron-LM paper link: https://arxiv.org/abs/1909.08053 - - Based on the dimension that the weight is sharded on, there are two - algorithms: - - ROWWISE SHARDING - ================ - For row-wise sharding the weight is sharded on dimension 1, but this is - row-wise since the actual computation for the linear layer involves - transposing the weight: :math:`y = xA^T + b` - - The overall algorithm can be best explained with an example. Let's assume - the dims for x are (13 x 16) and A are (17 x 16) and A is sharded across - 4 GPUs creating shards of (17 x 4). The algorithm is as follows: - - 1. First the input is split on the column dimension to create shards of - (13 x 4) and communicated to all other ranks. Since we are running in - an SPMD mode with each rank having distinct input, this is done via - an all2all run on all ranks. - 2. Now each (13 x 4) shard on each GPU is multiplied with the local shard - (4 x 17) (transposed) resulting in a (13 x 17) matrix which is the same - size that we need for the global result which would be (13 x 16) - multiplied by (16 x 17). But the final result needs to be aggregated - across the rest of the ranks. - 3. Here we just return the partial result here. One can call API - aggregate_partial_tensor_list to get the aggregated final result. - The API uses a reduce_scatter operation ensuring each rank - aggregates its own result. This is essentially a sum operation across - all the (13 x 17) local computations we did for each rank. - 4. For partial result, we only add 1 / n of the bias term to the partial - result. n is # of all GPUs. - - COLWISE SHARDING - ================ - For col-wise sharding the weight is sharded on dimension 0, but this is - col-wise since the actual computation for the linear layer involves - transposing the weight: :math:`y = xA^T + b` - - The overall algorithm can be best explained with an example. Let's assume - the dims for x are (13 x 17) and A are (16 x 17) and A is sharded across - 4 GPUs creating shards of (4 x 17). The algorithm is as follows: - - 1. First the input is broadcasted to all ranks, since this is SPMD we - actually do an all_gather for all the inputs resulting in 4 (13 x 17) - inputs on each rank. - 2. Next we perform local matmuls by multiplying each input (13 x 17) - with the local shard (17 x 4) (transposed). This results in 4 (13 x 4) - matrices on each rank. - 3. Next, we stack them into a (4 x 13 x 4) tensor and build a sharded - tensor across 4 ranks. - 4. To merge them into a fully-sync local tensor, one can call API - merge_sharded_local_results. - This API concat these 4 matrices and perform an all2all to share the - appropriate (13 x 4) matrices to each rank. Specifically, each rank - receives a (13 x 16) matrix which is basically the size of the result. - 5. If placements are not in order any appropriate rearrangement of rows - are done for the (13 x 16) matrix and finally the bias term is added. 
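# Illustrative aside (not part of the patch): a single-process arithmetic check of
# the two sharding schemes described above. It verifies the math only -- no
# communication, no bias handling -- and the shapes are modelled on the docstring's
# example; all names are illustrative.
import torch

world_size = 4
x = torch.randn(13, 16)
A = torch.randn(17, 16)                    # weight; y = x @ A.t() has shape (13, 17)
reference = x @ A.t()

# Col-wise sharding: weight split on dim 0. Each "rank" produces a slice of the
# output columns; concatenating the local results recovers the full output.
col_shards = torch.chunk(A, world_size, dim=0)
col_out = torch.cat([x @ shard.t() for shard in col_shards], dim=-1)
assert torch.allclose(reference, col_out, atol=1e-5)

# Row-wise sharding: weight split on dim 1. Each "rank" sees only a column slice
# of the input and produces a full-shape partial result; the partials sum to the
# full output, which is exactly the reduction a _PartialTensor defers.
row_shards = torch.chunk(A, world_size, dim=1)
x_slices = torch.chunk(x, world_size, dim=1)
row_out = sum(xs @ ws.t() for xs, ws in zip(x_slices, row_shards))
assert torch.allclose(reference, row_out, atol=1e-5)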
- """ - # Validate input params - _validate_linear_op_param(args, kwargs) - input = args[0] - weight = args[1] - bias = args[2] - - local_shard = weight.local_tensor() - local_shard_t = local_shard.t() - sharding_dim = weight._sharding_spec.dim - world_size = dist.get_world_size(pg) - rank = dist.get_rank(pg) - - if sharding_dim == 1 and isinstance(input, ShardedTensor): - return _handle_row_wise_sharding_sharded_tensor( - input, world_size, weight, local_shard_t, bias, pg - ) - elif sharding_dim == 1 and isinstance(input, torch.Tensor): - return _handle_row_wise_sharding_tensor( - input, world_size, weight, rank, local_shard_t, bias, pg - ) - elif sharding_dim == 0: - return _handle_col_wise_sharding( - input, world_size, weight, rank, local_shard_t, bias, pg - ) - else: - raise RuntimeError( - f"nn.Linear weight sharded on dim {sharding_dim} not supported!" - ) - - -def _validate_linear_op_param(args, kwargs): - """ - Validate input params of sharded linear op. - - Args: - input: input of the linear layer. - weight: sharded weight tensor. - kwargs: same as normal Linear. - - Return: None. - """ - input = args[0] - weight = args[1] - bias = args[2] - - # Validate types - if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor): - raise TypeError("input needs to be either torch.Tensor or ShardedTensor") - if type(bias) != torch.Tensor and type(bias) != torch.nn.Parameter: - raise TypeError("bias needs to be torch.Tensor") - if not isinstance(weight, ShardedTensor): - raise TypeError("weight needs to be ShardedTensor") - if len(input.size()) < 1: # type: ignore[arg-type] - raise ValueError("Input needs to have at least 1 dim") - weight_size = weight.size() - if len(weight_size) != 2: - raise ValueError("Weight needs to have exactly 2 dims") - if len(bias.size()) != 1: - raise ValueError("Bias needs to have exactly 1 dim") - if input.size()[-1] != weight_size[1]: # type: ignore[index] - raise ValueError( - f"Input dim: {input.size()[-1]} does not match " # type: ignore[index] - f"appropriate weight dim: {weight_size[1]}" - ) - if not isinstance(weight._sharding_spec, ChunkShardingSpec): - raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") - if len(weight.local_shards()) != 1: - raise ValueError("Only one local shard supported!") - - -def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg): - """ - Entry-point function to handle the logic of col-wise sharding of weight - for Linear. (Detailed explanations of the logic can be found in the - comment for sharded_linear.) - - When the local tensor only has one dimension, we increase one more dimension - for reshard. We need to do squeeze manually to reduce the dimension later-on. - - For example, if we have: - input: size[15] - weight: size[15, 16] - world_size: 4 - - In each rank, we will have 4 * [4] tensors. We then stack them into a [4, 4] - tensor and generate a sharded tenor sharded by dim 1. - - For the rest situations, we just simply concatenate local tensors. No more actions - are needed afterward. - - Args: - input: matrix to be multiplied with the sharded weight. - world_size: number of ranks. - weight: sharded weight tensor. - rank: # of cuda process. - local_shard_t: row-wise shared local weight used for lookup. - bias: bias term of linear op. - pg: process group. - - Returns: - A :class:`ShardedTensor` object which filled with local intermediate results. - """ - # allgather the inputs first. 
- out_size = list(input.size()) - out_size[0] = input.size(0) * dist.get_world_size(pg) - output = torch.empty(out_size, device=input.device, dtype=input.dtype) - output = _all_gather_base(output, input, group=pg) - - # Adjust bias and perform local matmul. - (start_pos, chunk_size) = get_chunk_sharding_params( - bias.size(0), world_size, weight._sharding_spec, rank - ) - local_bias = _BiasTensorNarrow.apply( - world_size, start_pos, chunk_size, weight, pg, bias - ) - - if output.dim() == 1: - output = output.view(dist.get_world_size(pg), -1) - - if output.dim() <= 2: - # Use fused version if possible. - result = torch.addmm(local_bias, output, local_shard_t) - else: - result = output.matmul(local_shard_t) + local_bias - - # Build ShardedTensor as result. - st_size = list(result.size()) - st_size[-1] = weight.size(0) - new_sharding_spec = ChunkShardingSpec( - dim=-1, - placements=weight.sharding_spec().placements - ) - return ShardedTensor._init_from_local_tensor( - result, - new_sharding_spec, - *st_size, # type: ignore[arg-type] - process_group=pg, - ) - - -def _handle_row_wise_sharding_tensor( - input, world_size, weight, rank, local_shard_t, bias, pg -): - """ - Entry-point function to handle the logic of row-wise sharding of weight - for Linear. (Detailed explanations of the logic can be found in the - comment for sharded_linear.) - - Args: - input: matrix to be multiplied with the sharded weight. - world_size: number of ranks. - weight: sharded weight tensor. - rank: # of cuda process. - local_shard_t: row-wise shared local weight used for lookup. - bias: bias term of linear op. - pg: process group. - - Returns: - A :class:`_PartialTensor` object which stores the partial local result. - """ - # alltoall to gather all the appropriate inputs. - input_t = input.transpose(0, -1).contiguous() - input_t_size = input_t.size() - - # Compute expected size - split_size = get_split_size(input_t_size[0], world_size) - input_split_sizes = [0] * world_size - rearrange_rows = False - - for idx, placement in enumerate(weight._sharding_spec.placements): - sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx) - input_split_sizes[placement.rank()] = sharded_dim_size - if placement.rank() != idx: - rearrange_rows = True - - if rearrange_rows: - # Need to re-arrange rows of input_t for all2all. - indices: List[List[int]] = [[0]] * world_size - # When we do the chunk split, we always ensure the first N - 1 chunks get max out - # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4] - # are not possible. The expected split size will be [4, 4, 4, 1]. 
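# Illustrative aside (not part of the patch): the split convention the comment
# above describes, as a tiny standalone helper -- the first N - 1 chunks take the
# ceiling size and the last chunk takes whatever remains. The helper name is made
# up; it is not the library's get_split_size/get_chunked_dim_size.
def chunk_split_sizes(dim_size: int, num_chunks: int) -> list:
    full = -(-dim_size // num_chunks)          # ceiling division
    sizes = []
    remaining = dim_size
    for _ in range(num_chunks):
        take = min(full, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

print(chunk_split_sizes(13, 4))   # [4, 4, 4, 1] -- never [3, 3, 3, 4]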
- sharded_dim_size_max = max(input_split_sizes) - for idx, placement in enumerate(weight._sharding_spec.placements): - split_size = input_split_sizes[placement.rank()] - offset_start_idx = idx * sharded_dim_size_max - indices[placement.rank()] = list( - range(offset_start_idx, offset_start_idx + split_size) - ) - indices_flatten = [idx for indice in indices for idx in indice] - - input_t = input_t.index_select( - 0, torch.tensor(indices_flatten, device=input_t.device) - ) - - gathered_input_size = [input_split_sizes[rank] * world_size] + list( - input_t_size[1:] - ) - gathered_input = torch.empty(gathered_input_size, device=input_t.device, dtype=input_t.dtype) - - # Perform autograd enabled alltoall - all_to_all_single( - gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg - ) - - # Reshape gathered_input appropriately for matmul - shard_size = local_shard_t.size()[0] - reshaped_inputs = [ - torch.narrow(gathered_input, 0, r * shard_size, shard_size).transpose(0, -1) - for r in range(world_size) - ] - reshaped_input = torch.cat(reshaped_inputs) - if reshaped_input.dim() == 1: - reshaped_input = reshaped_input.view(-1, local_shard_t.size(0)) - - # Perform appropriate local matmul - if reshaped_input.dim() <= 2: - result = torch.addmm(_BiasTensorPartial.apply(world_size, bias), reshaped_input, local_shard_t) - else: - result = reshaped_input.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias) - - # Return the partial local result. - return _PartialTensor(result, pg) - - -def _handle_row_wise_sharding_sharded_tensor( - input, world_size, weight, local_shard_t, bias, pg -): - """ - Entry-point function to handle the logic of row-wise sharding of weight - for Linear when the input is a sharded tensor. (Detailed explanations - of the logic can be found in the comment for sharded_linear.) - - Args: - input: matrix to be multiplied with the sharded weight. - world_size: number of ranks. - weight: sharded weight tensor. - local_shard_t: row-wise shared local weight used for lookup. - bias: bias term of linear op. - pg: process group. - - Returns: - A :class:`_PartialTensor` object which stores the partial local result. - """ - local_input = input.local_tensor() - if input.sharding_spec().dim not in (-1, len(input.size()) - 1): - raise NotImplementedError( - "The case when the input does not come from col-wise sharded " - "linear is not supported for row-wise sharded linear." - ) - - # Use fused version if possible. - if local_input.dim() <= 2: - result = torch.addmm(_BiasTensorPartial.apply(world_size, bias), local_input, local_shard_t) - else: - result = local_input.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias) - - # Return the partial local result. - return _PartialTensor(result, pg) - - -class _BiasTensorNarrow(Function): - """ - Since we now return the intermediate results in a col-wise sharding. We - need to narrow the bias term in the forward while doing backward, we need - to gather all gradients of narrowed bias across all ranks. 
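# Illustrative aside (not part of the patch): a single-process sketch of the trick
# _BiasTensorPartial below relies on -- scale the bias down by world_size in
# forward (each rank adds its own copy to a partial result), but pass the gradient
# through untouched in backward so the bias still receives its full gradient.
# Class name and sizes are illustrative.
import torch
from torch.autograd import Function

class _ScaleForwardOnly(Function):
    @staticmethod
    def forward(ctx, world_size, bias):
        ctx.world_size = world_size
        return bias / world_size

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for world_size; the bias gradient is deliberately not rescaled.
        return None, grad_output

bias = torch.ones(3, requires_grad=True)
out = _ScaleForwardOnly.apply(4, bias)     # forward value is bias / 4
out.sum().backward()
print(bias.grad)                           # all ones -- not divided by world_size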
- """ - - @staticmethod - def forward(ctx, world_size, start_pos, chunk_size, weight, pg, bias): - ctx.weight = weight - ctx.pg = pg - ctx.world_size = world_size - return torch.narrow(bias, 0, start_pos, chunk_size) - - @staticmethod - def backward(ctx, grad_output): - results = [grad_output.clone()] * ctx.world_size - return (None, None, None, None, None) + ( - _result_distribute_with_col_rearrange( - results, grad_output, ctx.world_size, ctx.weight, ctx.pg - ), - ) - - -class _BiasTensorPartial(Function): - """ - Since we now only return partial results in a row-wise sharding. We need to - divide the bias term by the world size in the forward while doing backward, - we need to skip this division op. - """ - - @staticmethod - def forward(ctx, world_size, bias): - ctx.world_size = world_size - return torch.div(bias, world_size) - - @staticmethod - def backward(ctx, grad_output): - return (None, grad_output) diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index 5f7ec96497b9..a6788b037bf0 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -7,7 +7,7 @@ from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh from torch.distributed._tensor.placement_types import Placement, Replicate, Shard -from torch.distributed._tensor.utils import compute_local_tensor_size +from torch.distributed._tensor._utils import compute_local_shape # All public APIs from dtensor package __all__ = [ @@ -65,12 +65,12 @@ def zeros( assert layout == torch.strided, "layout value not supported!" torch_stride = torch._prims_common.make_contiguous_strides_for(torch_size) - local_size = compute_local_tensor_size(torch_size, device_mesh, placements) - if local_size is None: + local_shape = compute_local_shape(torch_size, device_mesh, placements) + if len(local_shape) == 0: local_tensor = torch.tensor([], dtype=dtype, requires_grad=requires_grad) else: local_tensor = torch.zeros( - local_size, + local_shape, device=device_mesh.device_type, dtype=dtype, layout=layout, diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/_tensor/_utils.py new file mode 100644 index 000000000000..710fae0e630a --- /dev/null +++ b/torch/distributed/_tensor/_utils.py @@ -0,0 +1,74 @@ +from typing import Tuple, Sequence +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import Shard, Placement +from torch._prims_common import ShapeType + + +def compute_local_shape( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ + Compute the shape of a local shard of the given DTensor on its current + coordinate of the mesh. + """ + if mesh.get_coordinate() is None: + # if rank not in the mesh, return empty shape + return () + else: + local_shape = list(global_shape) # start with global shape + ndim = len(global_shape) + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + my_coordinate = mesh.get_coordinate() + assert my_coordinate is not None, "Rank not part of mesh!" 
+ if isinstance(placement, Shard): + shard_dim = placement.dim + assert ( + shard_dim < ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" + local_shard_size, _ = placement._local_shard_size_on_dim( + local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] + ) + assert isinstance(local_shard_size, int) + local_shape[shard_dim] = local_shard_size + + return tuple(local_shape) + + +def compute_local_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ + Compute the offsets of a local shard of the given DTensor on its current + global rank. This is mostly used by distributed checkpointing to know the + exact offsets of the local shard. + """ + if mesh.get_coordinate() is None: + # if rank not in the mesh, return empty offset + return () + else: + local_offsets = [0] * len(global_shape) + local_shape = list(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + my_coordinate = mesh.get_coordinate() + assert my_coordinate is not None, "Rank not part of mesh!" + if isinstance(placement, Shard): + shard_dim = placement.dim + assert ( + shard_dim < len(local_shape) + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + local_shape[shard_dim] = shard_size + local_offsets[shard_dim] = shard_offset + return tuple(local_offsets) diff --git a/torch/distributed/_tensor/ops/common_rules.py b/torch/distributed/_tensor/ops/common_rules.py index caf96dcf9320..55e843938768 100644 --- a/torch/distributed/_tensor/ops/common_rules.py +++ b/torch/distributed/_tensor/ops/common_rules.py @@ -5,6 +5,7 @@ from torch.fx.passes.shape_prop import TensorMetadata from torch.distributed._tensor.op_schema import OpSchema, OutputSharding from torch.distributed._tensor.ops.utils import prod +from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.placement_types import DTensorSpec @@ -181,9 +182,14 @@ def merge_sharding(dim: str, a: int, b: int) -> int: d in input_dim and input_spec.dim_map[input_dim.index(d)] == mesh_dim ): - cost += prod(input_spec.local_shape) * input_spec.mesh.size( - mesh_dim + assert input_spec.tensor_meta is not None + global_shape = input_spec.tensor_meta.shape + local_shape = compute_local_shape( + global_shape, + input_spec.mesh, + input_spec.placements ) + cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) costs.append(cost) d_to_keep_sharding = dims[costs.index(max(costs))] for d in dims: diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index ea04dfdef4c5..4edc7a69a928 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -13,6 +13,7 @@ prod, register_prop_rule, ) +from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate @@ -594,7 +595,8 @@ def register_prop_rule_map( @register_prop_rule(aten_op_overload) def reshape_prop(op_schema: OpSchema) -> OutputSharding: rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) - input_dtensor_spec = op_schema.args_schema[0] + input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0]) + mesh = input_dtensor_spec.mesh assert isinstance( input_dtensor_spec, DTensorSpec @@ -606,16 +608,13 @@ def 
reshape_prop(op_schema: OpSchema) -> OutputSharding: input_dtensor_spec.placements, tuple(global_in_shape), rules, - tuple(input_dtensor_spec.mesh.mesh.shape), + tuple(mesh.mesh.shape), ) if shard_out is not None: # no reshard needed - output_dtensor_spec = DTensorSpec( - mesh=input_dtensor_spec.mesh, - placements=shard_out, - ) - local_out_shape = output_dtensor_spec._local_shape_from_global_shape(list(global_out_shape)) + output_dtensor_spec = DTensorSpec(mesh=mesh, placements=shard_out) + local_out_shape = compute_local_shape(list(global_out_shape), mesh, shard_out) # We only need the local shape to lower the call into the local op args = op_schema.args_schema diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 129bdc65e786..25c9ba18175c 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -304,11 +304,14 @@ class DTensorSpec: mesh: DeviceMesh placements: Sequence[Placement] + # tensor meta will only be set during sharding propagation tensor_meta: Optional[TensorMetadata] = None def __hash__(self) -> int: - # TODO: tensor meta should all be part of the hash function, but we only - # use shape for now, need to fix this later + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements and shape + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them, if self.tensor_meta is not None: return hash((self.mesh, tuple(self.placements), self.tensor_meta.shape)) else: @@ -384,67 +387,6 @@ def sums(self) -> List[int]: if placement.is_partial() ] - def _local_shape_from_global_shape( - self, global_shape: List[int] - ) -> Tuple[int, ...]: - local_shape = global_shape # start with global shape - ndim = len(global_shape) - for idx, placement in enumerate(self.placements): - mesh_dim_size = self.mesh.size(idx) - my_coordinate = self.mesh.get_coordinate() - assert my_coordinate is not None, "Rank not part of mesh!" - if isinstance(placement, Shard): - shard_dim = placement.dim - assert ( - shard_dim < ndim - ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" - local_shard_size, _ = placement._local_shard_size_on_dim( - local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] - ) - assert isinstance(local_shard_size, int) - local_shape[shard_dim] = local_shard_size - - return tuple(local_shape) - - @property - def local_shape(self) -> Tuple[int, ...]: - """ - Compute the shape of a local shard of the given DTensor on its current - coordinate of the mesh. - """ - assert self.tensor_meta is not None, "DTensorSpec does not contain tensor meta." - return self._local_shape_from_global_shape(list(self.tensor_meta.shape)) - - @property - def local_offsets(self) -> Tuple[int, ...]: - """ - Compute the offsets of a local shard of the given DTensor on its current - global rank. This is mostly used by distributed checkpointing to know the - exact offsets of the local shard. - """ - assert self.tensor_meta is not None, "DTensorSpec does not contain tensor meta." - local_offsets = [0] * len(self.tensor_meta.shape) - local_shape = list(self.tensor_meta.shape) - - for idx, placement in enumerate(self.placements): - mesh_dim_size = self.mesh.size(idx) - my_coordinate = self.mesh.get_coordinate() - assert my_coordinate is not None, "Rank not part of mesh!" 
- if isinstance(placement, Shard): - shard_dim = placement.dim - assert ( - shard_dim < len(local_shape) - ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - shard_size, shard_offset = placement._local_shard_size_on_dim( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[idx], - return_offset=True, - ) - local_shape[shard_dim] = shard_size - local_offsets[shard_dim] = shard_offset - return tuple(local_offsets) - @classmethod def from_dim_map( cls, diff --git a/torch/distributed/_tensor/utils.py b/torch/distributed/_tensor/utils.py deleted file mode 100644 index 7a7808e8c3b2..000000000000 --- a/torch/distributed/_tensor/utils.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Optional, Sequence - -import torch -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import Placement, Replicate, Shard - - -def compute_local_tensor_size( - size: torch.Size, device_mesh: DeviceMesh, placements: Sequence[Placement] -) -> Optional[torch.Size]: - """ - Args: - size(torch.Size): define the shape of the whole Dtensor. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placement: a sequence of :class:`Placement` type: Shard, Replicate - - Returns: - A :class:`torch.Size` for the local tensor on the device_mesh - """ - if device_mesh.get_coordinate() is None: - return None - else: - local_size = list(size) - rank_coordinates = device_mesh.get_coordinate() - if rank_coordinates is None: - return None - for idx, placement in enumerate(placements): - if isinstance(placement, Replicate): - continue - elif isinstance(placement, Shard): - curr_coordinate = rank_coordinates[idx] - shard_dim = placement.dim - len_before_shard = local_size[shard_dim] - num_chucks = device_mesh.size(idx) - - len_after_shard, _ = placement._local_shard_size_on_dim( - len_before_shard, num_chucks, curr_coordinate - ) - local_size[shard_dim] = len_after_shard - else: - raise RuntimeError(f"placement type {type(placement)} not supported!") - - return torch.Size(local_size) diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index d154bd1f5877..16e0536432f0 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -8,6 +8,7 @@ from torch.distributed._shard.sharded_tensor.metadata import TensorProperties from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._tensor import DTensor +from torch.distributed._tensor._utils import compute_local_shape, compute_local_offset from torch.distributed._shard.sharding_spec._internals import ( _check_shard_metadata_pair_overlap, @@ -70,8 +71,8 @@ def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem: device_mesh.ndim == 1 ), "Only 1D DeviceMeshes can currently be handled." - sizes = torch.Size(tensor._spec.local_shape) - offsets = torch.Size(tensor._spec.local_offsets) + sizes = torch.Size(compute_local_shape(tensor.shape, device_mesh, tensor.placements)) + offsets = torch.Size(compute_local_offset(tensor.shape, device_mesh, tensor.placements)) return WriteItem( index=MetadataIndex(fqn, offsets), @@ -233,8 +234,8 @@ def _create_shard_from_dtensor(tensor: DTensor) -> Shard: device_mesh.ndim == 1 ), "Only 1D DeviceMeshes can currently be handled." 
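# Illustrative aside (not part of the patch): the shape/offset arithmetic that
# compute_local_shape/compute_local_offset (added in torch/distributed/_tensor/_utils.py
# above) boil down to for a 1D mesh with a single Shard(dim) placement, under the
# even-chunk convention used elsewhere in this patch (leading ranks take the
# ceiling size, the trailing rank takes the remainder). This is a sketch of the
# arithmetic, not the library implementation.
def local_shape_and_offset(global_shape, shard_dim, world_size, rank):
    full = -(-global_shape[shard_dim] // world_size)       # ceiling division
    start = min(rank * full, global_shape[shard_dim])
    end = min(start + full, global_shape[shard_dim])
    local = list(global_shape)
    local[shard_dim] = end - start
    offset = [0] * len(global_shape)
    offset[shard_dim] = start
    return tuple(local), tuple(offset)

# Global shape (13, 16) sharded on dim 0 across 4 ranks:
for r in range(4):
    print(r, local_shape_and_offset((13, 16), 0, 4, r))
# ranks 0-2 get (4, 16) at offsets 0, 4, 8; rank 3 gets (1, 16) at offset 12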
- sizes = tensor._spec.local_shape - offsets = tensor._spec.local_offsets + sizes = torch.Size(compute_local_shape(tensor.shape, device_mesh, tensor.placements)) + offsets = torch.Size(compute_local_offset(tensor.shape, device_mesh, tensor.placements)) return Shard( tensor=tensor.to_local(), metadata=ShardMetadata( diff --git a/torch/distributed/fsdp/_exec_order_utils.py b/torch/distributed/fsdp/_exec_order_utils.py index 8c136d49507e..1ad2025e3d47 100644 --- a/torch/distributed/fsdp/_exec_order_utils.py +++ b/torch/distributed/fsdp/_exec_order_utils.py @@ -227,6 +227,8 @@ def _check_order(self, handles_key: _HandlesKey, is_training: bool) -> None: local_num_valid_indices, group=self.process_group, ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_num_valid_indices = world_num_valid_indices.cpu() # Check that all ranks plan to all-gather the same number of # parameters # TODO (awgu): Since every module has at most one handle in the @@ -251,6 +253,8 @@ def _check_order(self, handles_key: _HandlesKey, is_training: bool) -> None: dist.all_gather_into_tensor( world_indices, local_indices, group=self.process_group ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_indices = world_indices.cpu() # Check that all ranks plan to all-gather the same index parameters for (r1, i1), (r2, i2) in itertools.combinations( ( diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 6cb4055cf30a..0bf25aca8604 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1,7 +1,7 @@ import copy import functools import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ( Any, cast, @@ -50,6 +50,7 @@ def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]: yield k, dictionary[k] +@dataclass class _ConsolidatedOptimState: """ This holds the consolidated optimizer state on the target rank. Positive- @@ -70,9 +71,9 @@ class _ConsolidatedOptimState: name to its value. """ - tensor_state: Dict[str, torch.Tensor] = {} - zero_dim_tensor_state: Dict[str, torch.Tensor] = {} - non_tensor_state: Dict[str, Any] = {} + tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict) + zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict) + non_tensor_state: Dict[str, Any] = field(default_factory=dict) class _PosDimTensorInfo(NamedTuple): diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 6e222cd42b52..b9deb7aa6065 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -73,10 +73,11 @@ class BackwardPrefetch(Enum): This configures explicit backward prefetching, which can improve throughput but may slightly increase peak memory usage. - For NCCL backend, any collectives, even if issued in different streams, - contend for the same per-device NCCL stream, which is why the relative - order in which the collectives are issued matters for overlapping. The - different backward prefetching settings correspond to different orderings. + For a single process group using NCCL backend, any collectives, even if + issued in different streams, contend for the same per-device NCCL stream, + which is why the relative order in which the collectives are issued matters + for overlapping. The different backward prefetching settings correspond to + different orderings. 
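# Illustrative aside (not part of the patch): how the prefetching knob documented
# above is typically selected. This is a hedged sketch -- it assumes a default
# process group has already been initialized (e.g. under torchrun) and that a
# CUDA device is available; the model is a stand-in.
import torch
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).cuda()
fsdp_model = FSDP(
    model,
    # Issue the next parameter all-gather before the current gradient computation:
    # better compute/communication overlap at the cost of some extra memory.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)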
- ``BACKWARD_PRE``: This prefetches the next set of parameters before the current set of parameter's gradient computation. This improves backward diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index ae272c54159d..eeb89b15e6f2 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -31,10 +31,11 @@ def __init__(self, loc, scale, validate_args=None): self.loc, self.scale = broadcast_all(loc, scale) finfo = torch.finfo(self.loc.dtype) if isinstance(loc, Number) and isinstance(scale, Number): - base_dist = Uniform(finfo.tiny, 1 - finfo.eps) + base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args) else: base_dist = Uniform(torch.full_like(self.loc, finfo.tiny), - torch.full_like(self.loc, 1 - finfo.eps)) + torch.full_like(self.loc, 1 - finfo.eps), + validate_args=validate_args) transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)), ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)] super().__init__(base_dist, transforms, validate_args=validate_args) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 48696c1a086a..9f7598ce5b2e 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -24,6 +24,7 @@ from torch._subclasses import FakeTensor from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode from torch.fx import Proxy +import torch.fx.traceback as fx_traceback from torch import SymInt, SymFloat, SymBool from torch.utils.weak import WeakTensorKeyDictionary @@ -471,6 +472,23 @@ def wrapped(*proxies): return wrapped +ORIGINAL_ATEN = None +@contextmanager +def set_original_aten_op(func): + global ORIGINAL_ATEN + if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): + ORIGINAL_ATEN = func + fx_traceback.current_meta['original_aten'] = func + try: + yield + finally: + ORIGINAL_ATEN = None + fx_traceback.current_meta['original_aten'] = None + else: + yield + + + class ProxyTorchDispatchMode(TorchDispatchMode): def __init__(self, tracer, tracing_mode): @@ -483,7 +501,7 @@ def __init__(self, tracer, tracing_mode): @count def __torch_dispatch__(self, func, types, args=(), kwargs=None): - with self.sym_mode.enable(False): + with self.sym_mode.enable(False), set_original_aten_op(func): return self.inner_torch_dispatch(func, types, args, kwargs) def __enter__(self): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 0145dc341c85..6ddc3b7d5087 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1334,7 +1334,7 @@ def create_unbacked_symfloat(self): def create_unbacked_symint(self): symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True) self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1])) - self.var_to_range[symbol] = ValueRanges.unknown() + self.var_to_range[symbol] = ValueRanges(-sys.maxsize - 1, sys.maxsize) return SymInt(SymNode(symbol, self, int, None)) # This is guaranteed to return a symbol or its negation is a sympy.Symbol, @@ -1361,7 +1361,10 @@ def create_symbol(self, val: int, source: Source, dyn=False) -> "sympy.Expr": # We also infer that it must be not 0/1 lower = 2 if self.specialize_zero_one else 0 - self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo) + # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT + # as a sentinel sometimes. 
Your sizevar isn't going to be + # anywhere near the max 64-bit integer anyway. + self.var_to_range[sympy_expr] = ValueRanges(lower, sys.maxsize - 1) if not dyn and self.duck_shape: # This implements duck-shaping: input sizes that match are assigned @@ -1577,12 +1580,20 @@ def _verify(expr, potential_expr): if not _simplified: for symbol, sources in symbol_to_source.items(): assert sources + assert symbol.is_integer r = self.var_to_range[symbol] bounds = [] if r.lower != -sympy.oo: bounds.append(str(r.lower)) bounds.append(source_ref(sources[0])) - if r.upper != sympy.oo: + # NB: This looks like an off-by-one error but it's not: the + # upper bound may be sys.maxsize - 1 because we intentionally + # exclude sys.maxsize from our bounds to deal with direct + # == INT_MAX guards, but it's still dumb to actually test it. + # Note that you can be off by a pretty large constant and it + # won't matter because sizes in practice will be no where near + # the 64-bit limit. + if r.upper != sympy.oo and r.upper < sys.maxsize - 1: bounds.append(str(r.upper)) if len(bounds) > 1: exprs.append(" <= ".join(bounds)) @@ -1781,14 +1792,10 @@ def size_hint(self, expr: "sympy.Expr"): def _make_data_dependent_error(self, expr, unhinted_expr): # TODO: in a Dynamo context, having user code, and having the # name of the local, will be much better - accesses = '\n\n'.join( - f"Data dependent variable '{s}' allocated at:\n{self.var_to_stack[s]}" - for s in expr.free_symbols - ) + for s in expr.free_symbols: + log.debug(f"Data dependent variable '{s}' allocated at:\n{self.var_to_stack[s]}") return GuardOnDataDependentSymNode( - f"\n\n{accesses}\n" - "GuardOnDataDependentSymNode: It appears that you're trying to get " - "a value out of symbolic int/float " + "It appears that you're trying to get a value out of symbolic int/float " "whose value is data-dependent (and thus we do not know the true value.) " f"The expression we were trying to evaluate is {expr} (unhinted: {unhinted_expr}). " "Scroll up to see where each of these data-dependent accesses originally occurred." 
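# Illustrative aside (not part of the patch): the kind of code the error message
# above is pointing at -- pulling a concrete Python value out of a tensor whose
# contents are only known at runtime, then using it to decide a shape. Depending
# on version and configuration this graph-breaks or raises during tracing; the
# function below is a hedged sketch, not a reproduction of any specific test.
import torch

@torch.compile(fullgraph=True)
def make_buffer(x):
    n = int(x.sum())         # data-dependent value extracted from a tensor
    return torch.zeros(n)    # output shape now depends on tensor *data*

# make_buffer(torch.ones(4, dtype=torch.int64))  # typically fails to trace with fullgraph=True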
diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 586dd3bf75a5..6cffc4ce51f3 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -76,6 +76,7 @@ def __init__(self, module : GraphModule, garbage_collect_values : bool = True): self.env : Dict[Node, Any] = {} self.name = "Interpreter" self.garbage_collect_values = garbage_collect_values + self.extra_traceback = True if self.garbage_collect_values: # Run through reverse nodes and record the first instance of a use @@ -135,12 +136,13 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_p try: self.env[node] = self.run_node(node) except Exception as e: - msg = f"While executing {node.format_node()}" - msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg) - msg += f"\nOriginal traceback:\n{node.stack_trace}" - e.args = (msg,) + e.args[1:] - if isinstance(e, KeyError): - raise RuntimeError(*e.args) from e + if self.extra_traceback: + msg = f"While executing {node.format_node()}" + msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) from e raise if self.garbage_collect_values: diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index cbce8f24cd04..bcc601d0cea6 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -92,6 +92,28 @@ def __init__( ) def get_dot_graph(self, submod_name=None) -> pydot.Dot: + """ + Visualize a torch.fx.Graph with graphviz + Example: + >>> # xdoctest: +REQUIRES(module:pydot) + >>> # define module + >>> class MyModule(torch.nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.linear = torch.nn.Linear(4, 5) + >>> def forward(self, x): + >>> return self.linear(x).clamp(min=0.0, max=1.0) + >>> module = MyModule() + >>> # trace the module + >>> symbolic_traced = torch.fx.symbolic_trace(module) + >>> # setup output file + >>> import ubelt as ub + >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() + >>> fpath = dpath / 'linear.svg' + >>> # draw the graph + >>> g = FxGraphDrawer(symbolic_traced, "linear") + >>> g.get_dot_graph().write_svg(fpath) + """ if submod_name is None: return self.get_main_dot_graph() else: diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 4c933a15a326..dc574452dbae 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -130,19 +130,15 @@ def create_node(self, kind : str, target : Target, if fx_traceback.has_preserved_node_meta(): current_meta: Dict[str, Any] = fx_traceback.get_current_meta() - # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta - # If other meta fields are needed, they can be added here stack_trace = current_meta.get("stack_trace") if stack_trace: node.stack_trace = stack_trace - - nn_module_stack = current_meta.get("nn_module_stack") - if nn_module_stack: - node.meta["nn_module_stack"] = nn_module_stack - - source_fn = current_meta.get("source_fn") - if source_fn: - node.meta["source_fn"] = source_fn + # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta + # If other meta fields are needed, they can be added here + copy_meta_fields = ["nn_module_stack", "source_fn", "original_aten"] + for field in copy_meta_fields: + if field in current_meta: + node.meta[field] = current_meta[field] elif self.module_stack: node.meta['nn_module_stack'] = copy.copy(self.module_stack) return node diff --git 
a/torch/fx/verifier.py b/torch/fx/verifier.py new file mode 100644 index 000000000000..20f1393f7fa8 --- /dev/null +++ b/torch/fx/verifier.py @@ -0,0 +1,164 @@ +import itertools +import operator +from collections.abc import Iterable + +import torch +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx import GraphModule + + +ALLOWED_META_KEYS = {"spec", "stack_trace"} + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class SpecViolationError(Exception): + pass + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + def _check_is_fake_tensor(val): + if isinstance(val, FakeTensor): + return True + if isinstance(val, Iterable): + return all(_check_is_fake_tensor(x) for x in val) + return False + + val = node.meta.get("val") + if not _check_is_fake_tensor(val): + raise SpecViolationError("Node.meta {} is missing val field.".format(node.name)) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def check_valid(gm: GraphModule) -> None: # noqa: C901 + + for node in gm.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op == "call_method": + # what is delegates in ATen dialect? + raise SpecViolationError( + "call_module can only be used for delegates, got a object of class '{}.{}' instead".format( + type(node.args[0]).__module__, type(node.args[0]).__name__ + ), + ) + + if node.op == "call_module": + raise SpecViolationError( + "call_module is not valid: got a class '{}' ".format(node.target), + ) + + if node.op == "call_function": + _check_has_fake_tensor(node) + op_name = ( + node.target.name + if hasattr(node.target, "name") + else node.target.__name__ + ) + is_builtin_func = (node.target == operator.getitem or node.target.__name__ in [ + 'while_loop', + 'cond', + ]) + if not isinstance(node.target, OpOverload) and not is_builtin_func: + raise SpecViolationError( + "Operator '{}' is not a registered Op".format(op_name), + ) + # All ops functional + # TODO(qihan): use node.target.is_functional: when PR/83134 lands + if not is_builtin_func and not is_functional(node.target): + raise SpecViolationError( + "operator '{}' is not functional".format(op_name), + ) + + if isinstance(node.target, OpOverload): + stacktrace = node.meta.get("stack_trace") + + if stacktrace is None: + raise SpecViolationError( + "node of name '{}' for operator '{}' is missing stackstrace".format( + node.name, op_name + ), + ) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def is_valid(gm: GraphModule) -> bool: + try: + check_valid(gm) + return True + except SpecViolationError: + return False + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def check_valid_aten_dialect(gm: GraphModule) -> None: + """Raises exception if gm is not in aten dialect. + + Args: + gm: GraphModule + """ + # need to be first valid + check_valid(gm) + # Operators be aten cannonical + for n in gm.graph.nodes: + if n.op == "call_function" and isinstance(n.target, OpOverload): + if ( + torch.Tag.core not in n.target.tags # type: ignore[attr-defined] + and torch.Tag.view_copy not in n.target.tags # type: ignore[attr-defined] + ): + # NOTE(qihan): whether view_copy operators are marked as canonical is still under + # discussion. 
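# Illustrative aside (not part of the patch): the "functional" test used by
# is_functional() above is just schema mutability -- out-of-place overloads have
# no mutable arguments, while in-place variants do.
import torch

print(torch.ops.aten.add.Tensor._schema.is_mutable)    # False: functional
print(torch.ops.aten.add_.Tensor._schema.is_mutable)   # True: mutates self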
+ raise SpecViolationError( + "Operator {}.{} is not Aten Canonical.".format( + n.target.__module__, n.target.__name__ + ) + ) + + # Tensors be of contiguous format + for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()): + if isinstance(param, torch.Tensor): + if not param.is_contiguous(): + raise SpecViolationError( + f"Tensors in Aten dialect must be contiguous, {name} is not contiguous" + ) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def is_valid_aten_dialect(gm: GraphModule) -> bool: + try: + check_valid_aten_dialect(gm) + return True + except SpecViolationError: + return False + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def check_valid_edge_dialect(gm: GraphModule) -> None: + check_valid_aten_dialect(gm) + + # Additionally, edge dialect's operator must have same input dtype + for n in gm.graph.nodes: + if n.op == "call_function" and isinstance(n.target, OpOverload): + _check_has_fake_tensor(n) + dtypes = set() + for arg in n.args: + if isinstance(arg, torch.Tensor): + dtypes.add(arg.dtype) + if isinstance(arg, torch.fx.Node): + dtypes.add(arg.meta["val"].dtype) + if len(dtypes) > 1: + raise SpecViolationError( + "Operators of Edge dialect in should work on tensors of same dtype" + ) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def is_valid_edge_dialect(gm: GraphModule) -> bool: + try: + check_valid_edge_dialect(gm) + return True + except SpecViolationError: + return False diff --git a/torch/jit/_script.py b/torch/jit/_script.py index fd0fa1f22a05..75af39edd241 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -15,7 +15,6 @@ import warnings from typing import Any, Dict, List, Set, Tuple, Union, Callable - import torch import torch._jit_internal as _jit_internal from torch.utils import set_module @@ -1249,6 +1248,10 @@ def forward(self, a) -> MyModule: return obj if isinstance(obj, ScriptFunction): return obj + from torch._dynamo.eval_frame import OptimizedModule + if isinstance(obj, OptimizedModule): + raise AttributeError("it is not possible to torch.jit.script() a torch.compile() model") + if example_inputs: # If MonkeyType is installed, enable profile directed type annotation diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index 12d8326baec4..91ff4a39d60f 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -12,7 +12,7 @@ # After regenerating files, compile PyTorch. # Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic # If you have enabled opinfo testing for the op, also run: -# python test/test_ops_jit.py TestJitCPU::test_variant_consistency_jit_[FAILING_OP]_cpu_float32 +# python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32 # to reproduce errors from opinfo tests. # Example PR: https://github.com/pytorch/pytorch/pull/80860/files @@ -545,6 +545,14 @@ def cat(tensors: List[List[int]], dim: int): return result_size +def stack(tensors: List[List[int]], dim: int): + unsqueezed_tensors: List[List[int]] = [] + for tensor in tensors: + unsqueezed = unsqueeze(tensor, dim) + unsqueezed_tensors.append(unsqueezed) + return cat(unsqueezed_tensors, dim) + + def select(self: List[int], dim: int, index: int): ndim = len(self) assert ndim != 0 @@ -1100,6 +1108,7 @@ def add_bounded_compute_mapping(operator_schema: str, lower_bound_func: Callable add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? 
bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input) add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten) add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat) +add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack) add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute) add_shape_compute_mapping("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim) add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view) diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 4afe73496900..c9b9302cfc9b 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -1003,6 +1003,10 @@ def weighted_kernel_sum(self, weight): if not isinstance(mod, torch.nn.Module): raise AttributeError("expected torch.nn.Module as the first argument") + from torch._dynamo.eval_frame import OptimizedModule + if isinstance(mod, OptimizedModule): + raise AttributeError("it is not possible to torch.jit.trace() a torch.compile() model") + if not isinstance(inputs, dict): raise AttributeError("expected a dictionary of (method_name, input) pairs") diff --git a/torch/nn/functional.py b/torch/nn/functional.py index d7b31fd54d80..4d0ee34090da 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5206,8 +5206,8 @@ def multi_head_attention_forward( attn_mask = _canonical_mask( mask=attn_mask, mask_name="attn_mask", - other_type=_none_or_dtype(key_padding_mask), - other_name="key_padding_mask", + other_type=None, + other_name="", target_type=q.dtype, check_other=False, ) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 1e92dc0852e2..ba1fc1c2a8d3 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1244,8 +1244,8 @@ def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Te attn_mask = F._canonical_mask( mask=attn_mask, mask_name="attn_mask", - other_type=F._none_or_dtype(key_padding_mask), - other_name="key_padding_mask", + other_type=None, + other_name="", target_type=query.dtype, check_other=False, ) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 560028ad53c7..8a26c6ddd13a 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -219,6 +219,15 @@ def forward( target_type=src.dtype ) + mask = F._canonical_mask( + mask=mask, + mask_name="mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + output = src convert_to_nested = False first_layer = self.layers[0] @@ -492,6 +501,15 @@ def forward( target_type=src.dtype ) + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf why_not_sparsity_fast_path = '' if not src.dim() == 3: diff --git a/torch/onnx/_internal/fx/__init__.py b/torch/onnx/_internal/fx/__init__.py index 57fbf56c5284..a9f379e2fd2d 100644 --- a/torch/onnx/_internal/fx/__init__.py +++ b/torch/onnx/_internal/fx/__init__.py @@ -1,10 +1,7 @@ from .context import FxToOnnxContext -from .exporter import ( - export, - export_after_normalizing_args_and_kwargs, - export_without_parameters_and_buffers, - save_model_with_external_data, -) +from .exporter import export, export_after_normalizing_args_and_kwargs +from .serialization import save_model_with_external_data +from .symbolic_exporter import export_without_parameters_and_buffers __all__ = [ diff --git a/torch/onnx/_internal/fx/exporter.py b/torch/onnx/_internal/fx/exporter.py index 1d18cb8ab07b..e6193cdf501a 100644 --- a/torch/onnx/_internal/fx/exporter.py +++ b/torch/onnx/_internal/fx/exporter.py @@ -1,11 +1,9 @@ from __future__ import annotations import copy -import functools import inspect import itertools import operator -import os import re import warnings from types import FunctionType @@ -75,117 +73,6 @@ def _onnx_function_diagnose_call_append_symbolic_source_location( ) -class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer): - """Tracer to create ONNX-exporting friendly FX graph. - - This tracer traces models into operators. That is, - the traced graph mostly contains call_function nodes and - has no call_module nodes. The call_module nodes - are problematic to the use of make_fx(...) in ONNX - exporter. - """ - - @_beartype.beartype - def is_leaf_module( - self, module: torch.nn.Module, module_qualified_name: str - ) -> bool: - # This returns False so that all sub-modules are considered as not leaves - # and therefore expanded into operators in - # torch.fx._symbolic_trace.Tracer.call_module. - return False - - @_beartype.beartype - def to_bool(self, obj: "torch.fx.Proxy") -> bool: - # This is a hack to tracing through if-else Python blocks. - # It may generate incorrect ONNX graphs if the if-else block - return False - - -# Functions directly wrapped to produce torch.fx.Proxy so that symbolic -# data can flow through those functions. Python functions (e.g., `torch.arange`) -# not defined by pybind11 in C++ do not go though Python dispatcher, so -# they are not automatically patched by FX's Python dispatcher. -# The list below means `torch.arange`, `torch.tensor`, and so on will be -# patched. -_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = ( - "arange", - "tensor", - "finfo", - "full", - "empty", -) - - -def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]: - """This function wraps ```target`` for symbolic tracing. - - This function wraps ```target``` so that its wrapper produces - torch.fx.Proxy in symbolic computation. The returned values are - the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`, - this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs. 
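# Illustrative aside (not part of the patch): the tracer idea described above --
# overriding is_leaf_module to always return False makes torch.fx inline every
# submodule, so the traced graph contains call_function nodes and no call_module
# nodes. Class and module names are illustrative.
import torch
import torch.fx

class InlineEverythingTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, module_qualified_name):
        return False   # never treat a submodule as a leaf; expand it instead

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

graph = InlineEverythingTracer().trace(Net())
print([node.op for node in graph.nodes])   # no "call_module" entries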
- """ - - @functools.wraps(target) - def wrapper(*args, **kwargs): - proxy = None - - def check_has_proxy(v): - if isinstance(v, torch.fx.Proxy): - nonlocal proxy - proxy = v - - torch.fx.node.map_aggregate(args, check_has_proxy) - torch.fx.node.map_aggregate(kwargs, check_has_proxy) - - if proxy is not None: - return proxy.tracer.create_proxy("call_function", target, args, kwargs) - else: - return target(*args, **kwargs) - - return wrapper, target - - -@_beartype.beartype -def _module_expansion_symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, -) -> "torch.fx.GraphModule": - """Trace a callable into FX graph. - - When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be - expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph - structure. - """ - # For functions doesn't support symbolic tracing, create wrappers - # which produce symbolic results during tracing. - patched_torch_methods = { - target_name: _wrap_for_symbolic_trace(getattr(torch, target_name)) - for target_name in _TORCH_METHODS_TO_PATCH - } - - # Set the symbolic-tracing friendly functions so that `tracer.trace` below - # can work. - for name, (wrapper, _) in patched_torch_methods.items(): - setattr(torch, name, wrapper) - - try: - # Set up a tracer. - tracer = ModuleExpansionTracer() - # Trace the model. - graph = tracer.trace(root, concrete_args) - name = ( - root.__class__.__name__ - if isinstance(root, torch.nn.Module) - else root.__name__ - ) - return torch.fx.GraphModule(tracer.root, graph, name) - finally: - # Revert the patches for symbolic tracing. - for name, (_, wrapped) in patched_torch_methods.items(): - # wrapped is the original version of `torch.name`. - setattr(torch, name, wrapped) - - def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value): """Map FX value to TorchScript value. @@ -780,309 +667,6 @@ def compile(self, graph_module: "torch.fx.GraphModule", _): ) -@_beartype.beartype -def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None: - """ - This function move all placeholder nodes to the front of the graph node list. - In torch.fx.Graph, placeholder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - - -@_beartype.beartype -def _replace_get_attr_with_placeholder( - graph_module: "torch.fx.GraphModule", -) -> Tuple[torch.Tensor, ...]: - """ - Replace get_attr with placeholder. - The parameters and buffers accessed by the original get_attr are returned; - they are useful when creating random inputs for the modified graph_module. - """ - graph = graph_module.graph - replaced_attrs: List[torch.Tensor] = [] - for node in graph.nodes: - if node.op == "get_attr": - replaced_attr: Optional[torch.Tensor] = None - # get_attr could retrieve either parameter or buffer, so - # we need to try both. 
- try: - replaced_attr = graph_module.get_parameter(node.target) - except AttributeError: - # It's possible that model author use buffer instead of - # parameter to store trainable weights. In this case, - # 1. get_parameter will throw something like - # AttributeError: `bias` is not an nn.Parameter. - # 2. get_buffer should work. - replaced_attr = graph_module.get_buffer(node.target) - - # Reassign op type so that get_attr node becomes placeholder node. - node.op = "placeholder" - # The target name in placeholder must be a valid Python identifier. - # Thus, we replace, e.g., "module.submodule.weight" with - # "module_submodule_weight". - node.target = node.target.replace(".", "_") - # Default value is None. This is needed as long as the "graph_module" - # has optional inputs. Assume the original forward signature is - # def forward(self, x, y=None) - # and the replaced get_attr node has target "z". Then, the modified - # signature should be - # def forward(self, x, y=None, z=None) - # Without the following line, the signature will be - # def forward(self, x, y=None, z) - # , which is not valid Python code. - node.args = (None,) - - replaced_attrs.append(replaced_attr) - - return tuple(replaced_attrs) - - -@_beartype.beartype -def _trace_into_fx_graph_via_fx_symbolic_trace( - module: torch.nn.Module, - *args, - # kwargs are the keyword arguments to call "module"; that is, - # module(*args, **kwargs) must run. - **kwargs, -) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]: - signature = inspect.signature(module.forward) - - # We hope the input kwargs will be mapped to bound.args after binding. - # If not, we will raise an error. - bound = signature.bind(*args, **kwargs) - bound.apply_defaults() - # After apply_defaults, all non keyword-only arguments are in bound.args. - # Because below code do not support keyword-word arguments, bound.kwargs - # must be empty. - assert len(bound.kwargs) == 0, bound.kwargs - - # Create inputs to call symbolic trace (torch.fx.symbolic_trace) - # Example content of concrete_args: - # concrete_args["x"] = torch.fx._symbolic_trace.PH - # concrete_args["b"] = 1 - # where "x" and "b" are argument names in "signature". - concrete_args = {} - for param_name, param_value in bound.arguments.items(): - if isinstance(param_value, torch.Tensor): - # param_value can be, e.g., a real tensor or a fake tensor. - # param_value is treated as substitutable tensor symbol (aka placeholder). - concrete_args[param_name] = torch.fx._symbolic_trace.PH - else: - concrete_args[param_name] = param_value - - return ( - _module_expansion_symbolic_trace(module, concrete_args=concrete_args), - bound.args, - ) - - -@_beartype.beartype -def export_without_parameters_and_buffers( - module: torch.nn.Module, - *args, - decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None, - use_binary_format: bool = True, - opset_version: int = _constants.ONNX_DEFAULT_OPSET, - op_level_debug: bool = False, - # kwargs are the keyword arguments to call "module"; that is, - # module(*args, **kwargs) must run. - **kwargs, -) -> Tuple[ - Union["onnx.ModelProto", bytes], - "torch.fx.GraphModule", - Tuple[Any, ...], - Tuple[Any, ...], -]: - - graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace( - module, *args, **kwargs - ) - - # Make sure all placeholder nodes are executed before get_attr nodes. - # Otherwise, inputs can interleave with initializers in the final ModeoProto.graph.input. 
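The signature-binding step in _trace_into_fx_graph_via_fx_symbolic_trace above maps every tensor argument to the tracing placeholder and bakes everything else in as a concrete value. A standalone sketch of that logic, with the string "PH" standing in for torch.fx._symbolic_trace.PH:

import inspect
import torch

def forward(x, scale=2.0):
    return x * scale

bound = inspect.signature(forward).bind(torch.ones(3))
bound.apply_defaults()  # 'scale' now appears in bound.arguments with its default

concrete_args = {
    name: "PH" if isinstance(value, torch.Tensor) else value
    for name, value in bound.arguments.items()
}
print(concrete_args)  # {'x': 'PH', 'scale': 2.0}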
- # Basically, we want - # ModeoProto.graph.input = - # [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m] - # and we don't want - # ModeoProto.graph.input = - # [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m] - _move_placeholder_to_front(graph_module) - # To save memory, move get_attr to input so that the generated model doesn't - # have weigh tensors. "replaced_attrs" are the list of replaced weight tensors. - replaced_attrs = _replace_get_attr_with_placeholder(graph_module) - # Move all newly created placeholder nodes to the front of the graph. - _move_placeholder_to_front(graph_module) - # Finalize the graph editing. - graph_module.recompile() - return ( - _export( - graph_module, - (*bound_args, *replaced_attrs), - opset_version=opset_version, - decomposition_table=decomposition_table, - use_binary_format=use_binary_format, - op_level_debug=op_level_debug, - ), - graph_module, - bound_args, - replaced_attrs, - ) - - -@_beartype.beartype -def _create_tensor_proto_with_external_data( - tensor: torch.Tensor, name: str, location: str, basepath: str -) -> "onnx.TensorProto": - """Create a TensorProto with external data from a PyTorch tensor. - The external data is saved to os.path.join(basepath, location). - - Args: - tensor: Tensor to be saved. - name: Name of the tensor (i.e., initializer name in ONNX graph). - location: Relative location of the external data file - (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). - basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). - - - Reference for ONNX's external data format: - How to load? - https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 - How to save? - https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 - How to set ONNX fields? - https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 - """ - tensor_proto = onnx.TensorProto() - tensor_proto.name = name - tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment] - torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype] - ] - tensor_proto.dims.extend(tensor.shape) - tensor_proto.data_location = onnx.TensorProto.EXTERNAL - - # Settings for saving one tensor per file. - # Offset is zero because there is no other tensor in the same file. - key_value_pairs = { - "location": location, - "offset": 0, - "length": tensor.untyped_storage().nbytes(), - } - for k, v in key_value_pairs.items(): - entry = tensor_proto.external_data.add() - entry.key = k - entry.value = str(v) - - # Actual path to write content of tensor. - external_data_file_path = os.path.join(basepath, location) - if os.path.exists(external_data_file_path): - os.remove(external_data_file_path) - - # Create external data's folder if not exists. - external_data_dir_path = os.path.dirname(external_data_file_path) - if not os.path.exists(external_data_dir_path): - # if the demo_folder directory is not present - # then create it. - os.makedirs(external_data_dir_path) - - # Create a fresh file. - with open(external_data_file_path, "xb") as data_file: - # No need to call "seek" because offset is 0. - # data_file.seek(0) - # Write tensor content to the file. 
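Because each initializer is written to its own file at offset 0, the external data file is just the tensor's raw bytes, exactly what the write above emits. A small round-trip check under that assumption (dtype and shape are invented):

import numpy as np
import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
raw = t.numpy().tobytes()                       # what the per-tensor data file holds
restored = np.frombuffer(raw, dtype=np.float32).reshape(2, 3)
assert np.array_equal(restored, t.numpy())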
- data_file.write(tensor.numpy().tobytes()) - - return tensor_proto - - -@_beartype.beartype -def save_model_with_external_data( - basepath: str, - model_location: str, - initializer_location: str, - torch_load_paths: Tuple[str, ...], - onnx_model: "onnx.ModelProto", -) -> None: - """Load PyTorch tensors from files and add to "onnx_model" as external initializers. - - Output files: - ONNX model file path: - ONNX initializer folder: os.path.join(basepath, initializer_location) - - After running this function, you can do - ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) - to execute the model. - - Arguments: - basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model"). - model_location: Relative location of the ONNX model file. - E.g., "model.onnx" so that the model file is saved to - "/tmp/large-onnx-model/model.onnx". - initializer_location: Relative location of the ONNX initializer folder. - E.g., "initializers" so that the initializers are saved to - "/tmp/large-onnx-model/initializers". - torch_load_paths: Files which containing serialized PyTorch tensors to be saved - as ONNX initializers. They are loaded by torch.load. - onnx_model: ONNX model to be saved with external initializers. - If an input name matches a tensor loaded from "torch_load_paths", - the tensor will be saved as that input's external initializer. - """ - onnx_model_with_initializers = onnx.ModelProto() - onnx_model_with_initializers.CopyFrom(onnx_model) - onnx_input_names = [input.name for input in onnx_model.graph.input] - - for path in torch_load_paths: - state_ditc = torch.load(path) - for name, tensor in state_ditc.items(): - # Basically, "transformer.attention.self.query.weight" is mapped - # to "transformer_attention_self_query_weight" for mimicking the - # name-modifying code in FX-to-ONNX exporter. - # See function _replace_get_attr_with_placeholder for details. - refined_name = name.replace(".", "_") - - # For each refined PyTorch tensor name loaded by torch.load, - # 1. Search its best match in ONNX model. E.g., the match of - # "transformer_attention_weight" could be "attention_weight". - # 2. Set "tensor" as the initializer of the matched ONNX input. - # E.g., "tensor" is stored as the initializer of "attention_weight". - # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary - # loaded by torch.load. - for onnx_input_name in onnx_input_names: - if onnx_input_name.endswith(refined_name) or refined_name.endswith( - onnx_input_name - ): - # Find a match. Change refined_name to the matched ONNX input name, so that we - # create initializer with the right ONNX name. - refined_name = onnx_input_name - break - - relative_tensor_file_path = os.path.join(initializer_location, refined_name) - # Create one file per tensor. - # tensor_proto.raw_data is stored to external file at - # os.path.join(basepath, relative_tensor_file_path). - tensor_proto = _create_tensor_proto_with_external_data( - tensor, refined_name, relative_tensor_file_path, basepath - ) - # Add the tensor_proto to the ONNX model as an initializer with external data. - onnx_model_with_initializers.graph.initializer.append(tensor_proto) - - # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". 
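save_model_with_external_data, shown here before its move to serialization.py, matches checkpoint keys to ONNX graph inputs by replacing dots with underscores and then comparing suffixes in both directions. A minimal sketch of that matching loop with invented names:

onnx_input_names = ["attention_self_query_weight", "input_ids"]
checkpoint_key = "transformer.attention.self.query.weight"

refined = checkpoint_key.replace(".", "_")
for candidate in onnx_input_names:
    if candidate.endswith(refined) or refined.endswith(candidate):
        refined = candidate  # adopt the ONNX-side name for the initializer
        break
print(refined)  # attention_self_query_weight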
- onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) - - @_beartype.beartype def _validate_op_between_ort_torch( node: torch.fx.Node, diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py new file mode 100644 index 000000000000..75aba61edbab --- /dev/null +++ b/torch/onnx/_internal/fx/serialization.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import os +from typing import Tuple + +import onnx + +import torch +from torch.onnx._internal import _beartype + + +@_beartype.beartype +def _create_tensor_proto_with_external_data( + tensor: torch.Tensor, name: str, location: str, basepath: str +) -> "onnx.TensorProto": + """Create a TensorProto with external data from a PyTorch tensor. + The external data is saved to os.path.join(basepath, location). + + Args: + tensor: Tensor to be saved. + name: Name of the tensor (i.e., initializer name in ONNX graph). + location: Relative location of the external data file + (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). + basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). + + + Reference for ONNX's external data format: + How to load? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 + How to save? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 + How to set ONNX fields? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 + """ + tensor_proto = onnx.TensorProto() + tensor_proto.name = name + tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[ # type: ignore[assignment] + torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype] + ] + tensor_proto.dims.extend(tensor.shape) + tensor_proto.data_location = onnx.TensorProto.EXTERNAL + + # Settings for saving one tensor per file. + # Offset is zero because there is no other tensor in the same file. + key_value_pairs = { + "location": location, + "offset": 0, + "length": tensor.untyped_storage().nbytes(), + } + for k, v in key_value_pairs.items(): + entry = tensor_proto.external_data.add() + entry.key = k + entry.value = str(v) + + # Actual path to write content of tensor. + external_data_file_path = os.path.join(basepath, location) + if os.path.exists(external_data_file_path): + os.remove(external_data_file_path) + + # Create external data's folder if not exists. + external_data_dir_path = os.path.dirname(external_data_file_path) + if not os.path.exists(external_data_dir_path): + # if the demo_folder directory is not present + # then create it. + os.makedirs(external_data_dir_path) + + # Create a fresh file. + with open(external_data_file_path, "xb") as data_file: + # No need to call "seek" because offset is 0. + # data_file.seek(0) + # Write tensor content to the file. + data_file.write(tensor.numpy().tobytes()) + + return tensor_proto + + +@_beartype.beartype +def save_model_with_external_data( + basepath: str, + model_location: str, + initializer_location: str, + torch_load_paths: Tuple[str, ...], + onnx_model: onnx.ModelProto, +) -> None: + """Load PyTorch tensors from files and add to "onnx_model" as external initializers. 
+ + Output files: + ONNX model file path: + ONNX initializer folder: os.path.join(basepath, initializer_location) + + After running this function, you can do + ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) + to execute the model. + + Arguments: + basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model"). + model_location: Relative location of the ONNX model file. + E.g., "model.onnx" so that the model file is saved to + "/tmp/large-onnx-model/model.onnx". + initializer_location: Relative location of the ONNX initializer folder. + E.g., "initializers" so that the initializers are saved to + "/tmp/large-onnx-model/initializers". + torch_load_paths: Files which containing serialized PyTorch tensors to be saved + as ONNX initializers. They are loaded by torch.load. + onnx_model: ONNX model to be saved with external initializers. + If an input name matches a tensor loaded from "torch_load_paths", + the tensor will be saved as that input's external initializer. + """ + onnx_model_with_initializers = onnx.ModelProto() + onnx_model_with_initializers.CopyFrom(onnx_model) + onnx_input_names = [input.name for input in onnx_model.graph.input] + + for path in torch_load_paths: + state_ditc = torch.load(path) + for name, tensor in state_ditc.items(): + # Basically, "transformer.attention.self.query.weight" is mapped + # to "transformer_attention_self_query_weight" for mimicking the + # name-modifying code in FX-to-ONNX exporter. + # See function _replace_get_attr_with_placeholder for details. + refined_name = name.replace(".", "_") + + # For each refined PyTorch tensor name loaded by torch.load, + # 1. Search its best match in ONNX model. E.g., the match of + # "transformer_attention_weight" could be "attention_weight". + # 2. Set "tensor" as the initializer of the matched ONNX input. + # E.g., "tensor" is stored as the initializer of "attention_weight". + # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary + # loaded by torch.load. + for onnx_input_name in onnx_input_names: + if onnx_input_name.endswith(refined_name) or refined_name.endswith( + onnx_input_name + ): + # Find a match. Change refined_name to the matched ONNX input name, so that we + # create initializer with the right ONNX name. + refined_name = onnx_input_name + break + + relative_tensor_file_path = os.path.join(initializer_location, refined_name) + # Create one file per tensor. + # tensor_proto.raw_data is stored to external file at + # os.path.join(basepath, relative_tensor_file_path). + tensor_proto = _create_tensor_proto_with_external_data( + tensor, refined_name, relative_tensor_file_path, basepath + ) + # Add the tensor_proto to the ONNX model as an initializer with external data. + onnx_model_with_initializers.graph.initializer.append(tensor_proto) + + # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". 
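Since the package __init__ earlier in this diff re-exports the moved symbol, existing call sites keep working. Both import paths below should resolve to the same function; this is only a smoke test of the private module layout at this commit and assumes onnx is installed.

from torch.onnx._internal.fx import save_model_with_external_data as from_package
from torch.onnx._internal.fx.serialization import (
    save_model_with_external_data as from_module,
)

assert from_package is from_module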
+ onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) diff --git a/torch/onnx/_internal/fx/symbolic_exporter.py b/torch/onnx/_internal/fx/symbolic_exporter.py new file mode 100644 index 000000000000..a6e05253a53a --- /dev/null +++ b/torch/onnx/_internal/fx/symbolic_exporter.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import functools + +import inspect + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import onnx + +import torch +import torch.fx + +from torch.onnx import _constants +from torch.onnx._internal import _beartype +from torch.onnx._internal.fx import exporter + +# Functions directly wrapped to produce torch.fx.Proxy so that symbolic +# data can flow through those functions. Python functions (e.g., `torch.arange`) +# not defined by pybind11 in C++ do not go though Python dispatcher, so +# they are not automatically patched by FX's Python dispatcher. +# The list below means `torch.arange`, `torch.tensor`, and so on will be +# patched. +_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = ( + "arange", + "tensor", + "finfo", + "full", + "empty", +) + + +class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer): + """Tracer to create ONNX-exporting friendly FX graph. + + This tracer traces models into operators. That is, + the traced graph mostly contains call_function nodes and + has no call_module nodes. The call_module nodes + are problematic to the use of make_fx(...) in ONNX + exporter. + """ + + @_beartype.beartype + def is_leaf_module( + self, module: torch.nn.Module, module_qualified_name: str + ) -> bool: + # This returns False so that all sub-modules are considered as not leaves + # and therefore expanded into operators in + # torch.fx._symbolic_trace.Tracer.call_module. + return False + + @_beartype.beartype + def to_bool(self, obj: "torch.fx.Proxy") -> bool: + # FIXME: This is a hack to tracing through if-else Python blocks. + # It may generate incorrect ONNX graphs if the if-else block + return False + + +@_beartype.beartype +def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: + """ + This function move all placeholder nodes to the front of the graph node list. + In torch.fx.Graph, placeholder is a special assignment node. If it's not + executed in the beginning, it could overwrite values computed by upstream + nodes. + """ + + graph = graph_module.graph + placeholders = [] + first_not_placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholders.append(node) + if first_not_placeholder is None and node.op != "placeholder": + first_not_placeholder = node + if first_not_placeholder is None: + return + for placeholder in placeholders: + first_not_placeholder.prepend(placeholder) + + +@_beartype.beartype +def _replace_get_attr_with_placeholder( + graph_module: torch.fx.GraphModule, +) -> Tuple[torch.Tensor, ...]: + """ + Replace get_attr with placeholder. + The parameters and buffers accessed by the original get_attr are returned; + they are useful when creating random inputs for the modified graph_module. + """ + graph = graph_module.graph + replaced_attrs: List[torch.Tensor] = [] + for node in graph.nodes: + if node.op == "get_attr": + replaced_attr: Optional[torch.Tensor] = None + # get_attr could retrieve either parameter or buffer, so + # we need to try both. + try: + replaced_attr = graph_module.get_parameter(node.target) + except AttributeError: + # It's possible that model author use buffer instead of + # parameter to store trainable weights. 
In this case, + # 1. get_parameter will throw something like + # AttributeError: `bias` is not an nn.Parameter. + # 2. get_buffer should work. + replaced_attr = graph_module.get_buffer(node.target) + + # Reassign op type so that get_attr node becomes placeholder node. + node.op = "placeholder" + # The target name in placeholder must be a valid Python identifier. + # Thus, we replace, e.g., "module.submodule.weight" with + # "module_submodule_weight". + node.target = node.target.replace(".", "_") + # Default value is None. This is needed as long as the "graph_module" + # has optional inputs. Assume the original forward signature is + # def forward(self, x, y=None) + # and the replaced get_attr node has target "z". Then, the modified + # signature should be + # def forward(self, x, y=None, z=None) + # Without the following line, the signature will be + # def forward(self, x, y=None, z) + # , which is not valid Python code. + node.args = (None,) + + replaced_attrs.append(replaced_attr) + + return tuple(replaced_attrs) + + +@_beartype.beartype +def _trace_into_fx_graph_via_fx_symbolic_trace( + module: torch.nn.Module, + *args, + # kwargs are the keyword arguments to call "module"; that is, + # module(*args, **kwargs) must run. + **kwargs, +) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]: + signature = inspect.signature(module.forward) + + # We hope the input kwargs will be mapped to bound.args after binding. + # If not, we will raise an error. + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + # After apply_defaults, all non keyword-only arguments are in bound.args. + # Because below code do not support keyword-word arguments, bound.kwargs + # must be empty. + assert len(bound.kwargs) == 0, bound.kwargs + + # Create inputs to call symbolic trace (torch.fx.symbolic_trace) + # Example content of concrete_args: + # concrete_args["x"] = torch.fx._symbolic_trace.PH + # concrete_args["b"] = 1 + # where "x" and "b" are argument names in "signature". + concrete_args = {} + for param_name, param_value in bound.arguments.items(): + if isinstance(param_value, torch.Tensor): + # param_value can be, e.g., a real tensor or a fake tensor. + # param_value is treated as substitutable tensor symbol (aka placeholder). + concrete_args[param_name] = torch.fx._symbolic_trace.PH + else: + concrete_args[param_name] = param_value + + return ( + _module_expansion_symbolic_trace(module, concrete_args=concrete_args), + bound.args, + ) + + +def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]: + """This function wraps ```target`` for symbolic tracing. + + This function wraps ```target``` so that its wrapper produces + torch.fx.Proxy in symbolic computation. The returned values are + the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`, + this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs. 
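The get_parameter/get_buffer fallback above is needed because nn.Module.get_parameter() raises AttributeError when the target resolves to a buffer rather than a Parameter. A small demonstration with an invented module:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(2))
        self.register_buffer("running_mean", torch.zeros(2))

m = M()
print(type(m.get_parameter("weight")))   # <class 'torch.nn.parameter.Parameter'>
try:
    m.get_parameter("running_mean")      # a buffer, not an nn.Parameter
except AttributeError as err:
    print(err)
print(m.get_buffer("running_mean"))      # tensor([0., 0.])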
+ """ + + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal proxy + proxy = v + + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy("call_function", target, args, kwargs) + else: + return target(*args, **kwargs) + + return wrapper, target + + +@_beartype.beartype +def _module_expansion_symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> torch.fx.GraphModule: + """Trace a callable into FX graph. + + When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be + expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph + structure. + """ + # For functions doesn't support symbolic tracing, create wrappers + # which produce symbolic results during tracing. + patched_torch_methods = { + target_name: _wrap_for_symbolic_trace(getattr(torch, target_name)) + for target_name in _TORCH_METHODS_TO_PATCH + } + + # Set the symbolic-tracing friendly functions so that `tracer.trace` below + # can work. + for name, (wrapper, _) in patched_torch_methods.items(): + setattr(torch, name, wrapper) + + try: + # Set up a tracer. + tracer = ModuleExpansionTracer() + # Trace the model. + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ + if isinstance(root, torch.nn.Module) + else root.__name__ + ) + return torch.fx.GraphModule(tracer.root, graph, name) + finally: + # Revert the patches for symbolic tracing. + for name, (_, wrapped) in patched_torch_methods.items(): + # wrapped is the original version of `torch.name`. + setattr(torch, name, wrapped) + + +@_beartype.beartype +def export_without_parameters_and_buffers( + module: torch.nn.Module, + *args, + decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None, + use_binary_format: bool = True, + opset_version: int = _constants.ONNX_DEFAULT_OPSET, + op_level_debug: bool = False, + # kwargs are the keyword arguments to call "module"; that is, + # module(*args, **kwargs) must run. + **kwargs, +) -> Tuple[ + Union[onnx.ModelProto, bytes], + torch.fx.GraphModule, + Tuple[Any, ...], + Tuple[Any, ...], +]: + + graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace( + module, *args, **kwargs + ) + + # Make sure all placeholder nodes are executed before get_attr nodes. + # Otherwise, inputs can interleave with initializers in the final ModeoProto.graph.input. + # Basically, we want + # ModeoProto.graph.input = + # [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m] + # and we don't want + # ModeoProto.graph.input = + # [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m] + _move_placeholder_to_front(graph_module) + # To save memory, move get_attr to input so that the generated model doesn't + # have weigh tensors. "replaced_attrs" are the list of replaced weight tensors. + replaced_attrs = _replace_get_attr_with_placeholder(graph_module) + # Move all newly created placeholder nodes to the front of the graph. + _move_placeholder_to_front(graph_module) + # Finalize the graph editing. 
+ graph_module.recompile() + return ( + exporter._export( + graph_module, + (*bound_args, *replaced_attrs), + opset_version=opset_version, + decomposition_table=decomposition_table, + use_binary_format=use_binary_format, + op_level_debug=op_level_debug, + ), + graph_module, + bound_args, + replaced_attrs, + ) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 2b62021833c1..8a857373d0d1 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -940,7 +940,9 @@ def expand_as(g: jit_utils.GraphContext, self, other): for d in range(self_t.dim()): if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): dims.append(d) - self = g.op("Constant", value_t=self_t.mean(dims).to(orig_type)) + self = g.op( + "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) + ) shape = g.op("Shape", other) return g.op("Expand", self, shape) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f36ae38a4a46..d596480926ea 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -122,6 +122,8 @@ sample_inputs_svd, sample_inputs_linalg_det_logdet_slogdet, sample_inputs_linalg_lu, + sample_inputs_diagonal_diag_embed, + error_inputs_diagonal_diag_embed, ) from torch.testing._internal.opinfo.definitions.special import ( sample_inputs_i0_i1, @@ -4495,7 +4497,14 @@ def make_idx(n): else: alphas = (None,) - for shape, alpha in product(shapes, alphas): + if fill: + # A weird number to catch errors. + # The former one tests `index_fill.int_Scalar`, and the latter one tests `index_fill.int_Tensor`. + values = (make_arg((1,)).item(), make_arg(())) + else: + values = (None,) + + for shape, alpha, value in product(shapes, alphas, values): t = make_arg(shape) args = [] @@ -4510,8 +4519,7 @@ def make_idx(n): if copy or add: args.append(make_arg(shape)) elif fill: - # A weird number to catch errors - args.append(make_arg((1,)).item()) + args.append(value) args = tuple(args) kwargs = {} if alpha is None else {"alpha": alpha} @@ -5716,23 +5724,6 @@ def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): for tensor, arg in product(tensors, args): yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg) -def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) - - # Shapes for 2D Tensors - shapes_2d = ((S, S), (3, 5), (5, 3)) - - # Shapes for 3D Tensors - shapes_3d = ((S, S, S),) - - kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1)) - kwargs_3d = (dict(offset=1, dim1=1, dim2=2), - dict(offset=2, dim1=0, dim2=1), - dict(offset=-2, dim1=0, dim2=1)) - - for shape, kwarg in chain(product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)): - yield SampleInput(make_arg(shape), kwargs=kwarg) - def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): yield from sample_inputs_diagonal_diag_embed( op_info, device, dtype, requires_grad, **kwargs) @@ -5772,62 +5763,6 @@ def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, continue yield SampleInput(input=make_arg(shape), kwargs=kwargs) -def error_inputs_diagonal_diag_embed(op_info, device, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=torch.float32) - - shapes1d = (0, 1, (0,), (1,)) - shapes2d = ((M, L),) - shapes3d = ((M, S, 
L),) - - kwargs1d = {} - - kwargs2d = ( - # dim1 == dim2 is not allowed - dict(dim1=1, dim2=1), - # out of bounds dims are not allowed - dict(dim1=10000), - dict(dim2=10000), - ) - - kwargs3d = kwargs2d - - samples1d = product(shapes1d, kwargs1d) - samples2d = product(shapes2d, kwargs2d) - samples3d = product(shapes3d, kwargs3d) - - for shape, kwargs in chain(samples1d, samples2d, samples3d): - arg = make_arg(shape) - sample = SampleInput(input=arg, kwargs=kwargs) - - dim1 = kwargs.get('dim1') - dim2 = kwargs.get('dim2') - - if 'diagonal' in op_info.name: - num_dim = arg.dim() - elif op_info.name in ('diag_embed', '_refs.diag_embed'): - # these are valid inputs for diag_embed - if shape in ((0,), (1,)): - continue - num_dim = arg.dim() + 1 - else: - raise RuntimeError("should be unreachable") - - bound1 = -num_dim - bound2 = num_dim - 1 - dim_range = range(bound1, bound2 + 1) - dim1_cond = dim1 and dim1 not in dim_range - dim2_cond = dim2 and dim2 not in dim_range - - if dim1 == dim2: - err = f"diagonal dimensions cannot be identical {dim1}, {dim2}" - yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) - elif dim1_cond or dim2_cond: - err_dim = dim1 if dim1_cond else dim2 - err = (r"Dimension out of range \(expected to be in range of " - rf"\[{bound1}, {bound2}\], but got {err_dim}\)") - yield ErrorInput(sample, error_regex=err, error_type=IndexError) - else: - raise RuntimeError("should be unreachable") def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -5925,10 +5860,9 @@ def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): low = low + domain_eps high = high - domain_eps - make_arg = partial( - make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=low, high=high) make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S, S))) yield SampleInput(make_arg((S, S, S)), 0.2) yield SampleInput(make_arg(())) @@ -10331,9 +10265,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): reference_inputs_func=reference_inputs_diagonal_diag_embed, error_inputs_func=error_inputs_diagonal_diag_embed), OpInfo('diagonal', - # They are not strictly aliases as they have diverging defaults, but we can see them as aliases for testing purposes - # If we add tests that test the function against the alias, make linalg.diagonal into its own OpInfo - aliases=('linalg.diagonal',), aten_backward_name='diagonal_backward', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), supports_out=False, @@ -15078,6 +15009,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, # https://github.com/pytorch/pytorch/issues/66357 check_batched_forward_grad=False, + skips=( + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! 
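The symbolic_opset9.py change above adds keepdim=True so the folded constant keeps its rank before the ONNX Expand. Without it, reducing over a non-trailing dimension yields a tensor that no longer broadcasts against `other`, as this small repro with invented shapes shows:

import torch

t = torch.tensor([[1., 1., 1.],
                  [2., 2., 2.]])    # constant along dim 1
other = torch.zeros(2, 3)

kept = t.mean(1, keepdim=True)      # shape (2, 1): expands cleanly
print(kept.expand_as(other).shape)  # torch.Size([2, 3])

dropped = t.mean(1)                 # shape (2,): rank is lost
try:
    dropped.expand_as(other)
except RuntimeError as err:
    print("expand failed:", err)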
+ DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'), + ), sample_inputs_func=sample_inputs_index, reference_inputs_func=partial(sample_inputs_index, reference=True)), OpInfo('index_copy', @@ -20155,7 +20092,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_nvfuser=False, skips=( # no _refs support for Tensor.__setitem__ - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), ), PythonRefInfo( "_refs.index_add", @@ -20164,7 +20102,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_nvfuser=False, skips=( # no _refs support for Tensor.__setitem__ - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), ), PythonRefInfo( "_refs.index_fill", diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 0a8b49960ec5..569c2cb4c88a 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -16,7 +16,7 @@ from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_nn import nllloss_reference, get_reduction from torch.testing._internal.common_utils import ( - freeze_rng_state, set_single_threaded_if_parallel_tbb, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS) + freeze_rng_state, set_single_threaded_if_parallel_tbb, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM) from types import ModuleType from typing import List, Tuple, Type, Set, Dict @@ -1470,11 +1470,6 @@ def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, train ModuleInfo(torch.nn.TransformerEncoderLayer, train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer, - decorators=[ - DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), - 'TestModule', 'test_non_contiguous_tensors', - device_type='cpu', active_if=IS_WINDOWS), - ], skips=( # No channels_last support for TransformerEncoderLayer currently. DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 55d6200eb1d3..78569f6b9b39 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1686,8 +1686,6 @@ def remove_device_and_dtype_suffixes(test_name: str) -> str: def check_if_enable(test: unittest.TestCase): test_suite = str(test.__class__).split('\'')[1] - if "USING_PYTEST" in os.environ: - test_suite = f"__main__.{test_suite.split('.')[1]}" raw_test_name = f'{test._testMethodName} ({test_suite})' if raw_test_name in slow_tests_dict: getattr(test, test._testMethodName).__dict__['slow_test'] = True @@ -1976,6 +1974,11 @@ def set_warn_always_context(new_val: bool): torch.set_warn_always(old_val) +class NoTest(): + # causes pytest to not recognize this class as a test + __test__ = False + + class TestCase(expecttest.TestCase): # NOTE: "precision" lets classes and generated tests set minimum # atol values when comparing tensors. 
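The new NoTest helper above leans on pytest's collection rule: a class whose __test__ attribute is falsy is never collected, even when its name starts with "Test". Helper base classes can opt out like this (names are invented):

class NoTest:
    # pytest skips collection when __test__ is falsy
    __test__ = False

class TestUtilsBase(NoTest):
    # despite the Test* prefix, this class is not collected as a test case
    def helper(self):
        return 42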
Used by @precisionOverride and @toleranceOverride, for diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index bbed3e70f1e3..75e9cfb5078b 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -120,7 +120,8 @@ def init_pg(self, backend: str = "nccl") -> None: def destroy_pg(self) -> None: # Wait for all ranks to reach here before starting shutdown. - dist.barrier() + # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895 + dist.all_reduce(torch.zeros((1,), device="cuda" if torch.cuda.is_available() else "cpu")) dist.destroy_process_group() def setUp(self) -> None: diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index 616c8cf42f4b..a3fe21a212da 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -1,7 +1,7 @@ import itertools import unittest from functools import partial -from itertools import product +from itertools import chain, product from typing import Iterable, List import numpy as np @@ -44,6 +44,8 @@ DecorateInfo, ErrorInput, gradcheck_wrapper_hermitian_input, + L, + M, OpInfo, ReductionOpInfo, S, @@ -790,6 +792,90 @@ def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs): ) +def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Shapes for 2D Tensors + shapes_2d = ((S, S), (3, 5), (5, 3)) + + # Shapes for 3D Tensors + shapes_3d = ((S, S, S),) + + kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1)) + kwargs_3d = ( + dict(offset=1, dim1=1, dim2=2), + dict(offset=2, dim1=0, dim2=1), + dict(offset=-2, dim1=0, dim2=1), + ) + + for shape, kwarg in chain( + product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d) + ): + yield SampleInput(make_arg(shape), kwargs=kwarg) + + +def error_inputs_diagonal_diag_embed(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + shapes1d = (0, 1, (0,), (1,)) + shapes2d = ((M, L),) + shapes3d = ((M, S, L),) + + kwargs1d = {} + + kwargs2d = ( + # dim1 == dim2 is not allowed + dict(dim1=1, dim2=1), + # out of bounds dims are not allowed + dict(dim1=10000), + dict(dim2=10000), + ) + + kwargs3d = kwargs2d + + samples1d = product(shapes1d, kwargs1d) + samples2d = product(shapes2d, kwargs2d) + samples3d = product(shapes3d, kwargs3d) + + for shape, kwargs in chain(samples1d, samples2d, samples3d): + arg = make_arg(shape) + sample = SampleInput(input=arg, kwargs=kwargs) + + dim1 = kwargs.get("dim1") + dim2 = kwargs.get("dim2") + + if "diagonal" in op_info.name: + num_dim = arg.dim() + elif op_info.name in ("diag_embed", "_refs.diag_embed"): + # these are valid inputs for diag_embed + if shape in ((0,), (1,)): + continue + num_dim = arg.dim() + 1 + else: + raise RuntimeError("should be unreachable") + + bound1 = -num_dim + bound2 = num_dim - 1 + dim_range = range(bound1, bound2 + 1) + dim1_cond = dim1 and dim1 not in dim_range + dim2_cond = dim2 and dim2 not in dim_range + + if dim1 == dim2: + err = f"diagonal dimensions cannot be identical {dim1}, {dim2}" + yield ErrorInput(sample, error_regex=err, error_type=RuntimeError) + elif dim1_cond or dim2_cond: + err_dim = dim1 if 
dim1_cond else dim2 + err = ( + r"Dimension out of range \(expected to be in range of " + rf"\[{bound1}, {bound2}\], but got {err_dim}\)" + ) + yield ErrorInput(sample, error_regex=err, error_type=IndexError) + else: + raise RuntimeError("should be unreachable") + + def sample_inputs_linalg_cholesky( op_info, device, dtype, requires_grad=False, **kwargs ): @@ -1172,6 +1258,19 @@ def make_input(): ), ), ), + OpInfo( + "linalg.diagonal", + aten_name="linalg_diagonal", + aten_backward_name="diagonal_backward", + dtypes=all_types_and_complex_and( + torch.bool, torch.bfloat16, torch.float16, torch.chalf + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed, + ), OpInfo( "linalg.cholesky", aten_name="linalg_cholesky", @@ -2180,6 +2279,13 @@ def make_input(): # # torch.linalg # + PythonRefInfo( + "_refs.linalg.diagonal", + torch_opinfo_name="linalg.diagonal", + supports_out=False, + supports_nvfuser=False, + op_db=op_db, + ), ReductionPythonRefInfo( "_refs.linalg.vector_norm", torch_opinfo_name="linalg.vector_norm", diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 0599da7860d5..e39d67ee6c81 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -78,9 +78,9 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list] continue if isinstance(fn, functools.partial): - fn_name = fn.func.__name__ + fn_name = getattr(fn.func, "__name__", repr(fn.func)) else: - fn_name = fn.__name__ + fn_name = getattr(fn, "__name__", repr(fn)) if len(non_default_kw_only) > 0: raise ValueError( diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 5ff9e1ad7a55..94f3532e4f75 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -16,6 +16,7 @@ ) from torchgen.utils import IDENT_REGEX + # Represents a saved attribute involved in backward calculation. # Note that it can be a derived property of an input argument, e.g.: # we could save `other.scalar_type()` instead of the entire `other` tensor. @@ -305,6 +306,161 @@ def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str: return "use_type" +def is_foreach_func(f: NativeFunction) -> bool: + base_op_name = f.func.name.name + return base_op_name.base.startswith("_foreach_") and not base_op_name.inplace + + +# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function +# reference to generate derivatives. 
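The torchgen refactor in this hunk hoists foreach derivative generation into gen_foreach_derivativeinfo. Two of its moving parts are easy to show in isolation: deriving the reference op name from the _foreach_ prefix, and rewriting the reference formula so it indexes into the tensor lists. The formula text below is invented:

base = "_foreach_add".split("_foreach_")[-1]    # reference native op: "add"

formula = "grad * other + result.sgn()"
modified = formula.replace("grad", "grads[i]").replace("result", "result[i]")
print(base, "|", modified)
# add | grads[i] * other + result[i].sgn()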
+def is_reference_for_foreach( + f: NativeFunction, + function_schema: FunctionSchema, +) -> bool: + return ( + f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base + and not function_schema.name.name.inplace + and all( + ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) + for arg, ref_arg in zip( + f.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ) + ) + + +def gen_foreach_derivativeinfo( + foreach_function: NativeFunction, + differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], + functional_info_by_signature: Dict[ + FunctionSchema, Dict[str, DifferentiabilityInfo] + ], +) -> Optional[DifferentiabilityInfo]: + ref_diff_info: Optional[DifferentiabilityInfo] = None + for function_schema in functional_info_by_signature: + if not is_reference_for_foreach(foreach_function, function_schema): + continue + if function_schema in differentiability_infos: + ref_diff_info = differentiability_infos[function_schema]["Default"] + elif ( + function_schema.signature(strip_default=True) + in functional_info_by_signature + ): + ref_diff_info = functional_info_by_signature[ + function_schema.signature(strip_default=True) + ]["Default"] + else: + raise RuntimeError( + "Reference `DifferentiabilityInfo` for {} not found".format( + foreach_function.func + ) + ) + if ref_diff_info is not None: + break + if ref_diff_info is None: + return None + + map_refarg2foreacharg, map_name2arg = {}, {} + for i, (arg, ref_arg) in enumerate( + zip( + foreach_function.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ): + map_refarg2foreacharg[ref_arg.name] = arg.name + map_name2arg[arg.name] = arg + + all_saved_inputs, all_saved_outputs, all_var_names = [], [], [] + modified_derivative_formulas = [] + for i, derivative in enumerate(ref_diff_info.derivatives): + modified_formula = derivative.formula.replace("grad", "grads[i]").replace( + "result", "result[i]" + ) + saved_inputs, saved_outputs = [], [] + # note(crcrpar): This context seems necessary to call `cpp.argument_type` + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + for ref_input in derivative.saved_inputs: + ref_input_jit_name = ref_input.expr.split(".")[0] + mapped_name = map_refarg2foreacharg[ref_input_jit_name] + if isinstance(map_name2arg[mapped_name].type, ListType): + mapped_expr = mapped_name + "[i]" + else: + mapped_expr = mapped_name + new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr) + modified_formula = modified_formula.replace( + cast(str, ref_input.nctype.name), new_expr + ) + + nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name) + canonical_nctype = NamedCType( + nctype.name, nctype.type.remove_const_ref() + ) + saved_inputs.append( + SavedAttribute(nctype=canonical_nctype, expr=mapped_name) + ) + for ref_output in derivative.saved_outputs: + if ref_output.nctype.name == "result": + saved_outputs.append( + SavedAttribute( + nctype=NamedCType( + name="result", type=BaseCType(tensorListT) + ), + expr="result", + ) + ) + else: + raise RuntimeError("") + var_names = [map_refarg2foreacharg[var] for var in derivative.var_names] + all_var_names.extend(var_names) + all_saved_inputs.extend(saved_inputs) + all_saved_outputs.extend(saved_outputs) + modified_derivative = Derivative( + formula=modified_formula, + original_formula=derivative.formula, + 
var_names=tuple(var_names), + saved_inputs=tuple(saved_inputs), + saved_outputs=tuple(saved_outputs), + named_gradients=set(), + ) + modified_derivative_formulas.append(modified_derivative) + + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + args_with_derivatives = [ + Binding( + name=arg.name, + nctype=cpp.argument_type(arg, binds=arg.name), + argument=arg, + default=None, + ) + for arg in foreach_function.func.arguments.flat_non_out + if arg.name in all_var_names + ] + return DifferentiabilityInfo( + name=foreach_function.func.name.name.base, + func=foreach_function, + op="Foreach{}{}".format( + ref_diff_info.op, foreach_function.func.name.overload_name + ), + derivatives=modified_derivative_formulas, + forward_derivatives=[], + all_saved_inputs=tuple(set(all_saved_inputs)), + all_saved_outputs=tuple(set(all_saved_outputs)), + available_named_gradients=(), + used_named_gradients=set(), + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=[], + output_differentiability=None, + output_differentiability_conditions=None, + ) + + def match_differentiability_info( native_functions: List[NativeFunction], differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], @@ -325,31 +481,6 @@ def match_differentiability_info( if schema.kind() != SchemaKind.functional } - def is_foreach_func(f: NativeFunction) -> bool: - base_op_name = f.func.name.name - return base_op_name.base.startswith("_foreach_") and not base_op_name.inplace - - def is_reference_for_foreach( - f: NativeFunction, - function_schema: FunctionSchema, - ) -> bool: - return ( - f.func.name.name.base.split("_foreach_")[-1] - == function_schema.name.name.base - and not function_schema.name.name.inplace - and ( - True - if len(f.func.arguments.post_self_positional) == 0 - else all( - ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) - for arg, ref_arg in zip( - f.func.arguments.flat_non_out, - function_schema.arguments.flat_non_out, - ) - ) - ) - ) - def find_info( f: NativeFunction, ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]: @@ -386,136 +517,17 @@ def find_info( return info_dict, False # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml` - base_op_name = f.func.name.name if is_foreach_func(f): - for function_schema in functional_info_by_signature: - if not is_reference_for_foreach(f, function_schema): - continue - if function_schema in differentiability_infos: - ref_diff_info = differentiability_infos[function_schema]["Default"] - elif ( - function_schema.signature(strip_default=True) - in functional_info_by_signature - ): - ref_diff_info = functional_info_by_signature[ - function_schema.signature(strip_default=True) - ]["Default"] - else: - raise RuntimeError( - f"Reference `DifferentiabilityInfo` for {f.func} not found: query: {function_schema}" - ) - - map_refarg2foreacharg = {} - map_name2arg = {} - for arg, ref_arg in zip( - f.func.arguments.flat_non_out, - function_schema.arguments.flat_non_out, - ): - map_refarg2foreacharg[ref_arg.name] = arg.name - map_name2arg[arg.name] = arg - - all_saved_inputs: List[SavedAttribute] = [] - all_saved_outputs: List[SavedAttribute] = [] - modified_derivative_formulas: List[Derivative] = [] - all_var_names: List[str] = [] - for derivative in ref_diff_info.derivatives: - # note(crcrpar): Assumption: `grads` and `result` always are a 
sequence of Tensors. - modified_formula = derivative.formula.replace( - "grad", "grads[i]" - ).replace("result", "result[i]") - - saved_inputs, saved_outputs = [], [] - with local.parametrize( - use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, - use_ilistref_for_tensor_lists=f.part_of_structured_group, - ): - for ref_input in derivative.saved_inputs: - ref_input_jit_name = ref_input.expr.split(".")[0] - mapped_name = map_refarg2foreacharg[ref_input_jit_name] - if isinstance(map_name2arg[mapped_name].type, ListType): - mapped_expr = mapped_name + "[i]" - else: - mapped_expr = mapped_name - new_expr = ref_input.expr.replace( - ref_input_jit_name, mapped_expr - ) - modified_formula = modified_formula.replace( - cast(str, ref_input.nctype.name), new_expr - ) - - nctype = cpp.argument_type( - map_name2arg[mapped_name], binds=mapped_name - ) - canonical_nctype = NamedCType( - nctype.name, nctype.type.remove_const_ref() - ) - saved_inputs.append( - SavedAttribute( - nctype=canonical_nctype, expr=mapped_name - ) - ) - for ref_output in derivative.saved_outputs: - if ref_output.nctype.name == "result": - saved_outputs.append( - SavedAttribute( - nctype=NamedCType( - name="result", type=BaseCType(tensorListT) - ), - expr="result", - ) - ) - else: - raise RuntimeError( - f"Counterpart of {ref_output} not found" - ) - var_names = [ - map_refarg2foreacharg[var] for var in derivative.var_names - ] - all_var_names.extend(var_names) - all_saved_inputs.extend(saved_inputs) - all_saved_outputs.extend(saved_outputs) - modified_derivative = Derivative( - formula=modified_formula, - original_formula=derivative.formula, - var_names=tuple(var_names), - saved_inputs=tuple(saved_inputs), - saved_outputs=tuple(saved_outputs), - named_gradients=set(), - ) - modified_derivative_formulas.append(modified_derivative) - with local.parametrize( - use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, - use_ilistref_for_tensor_lists=f.part_of_structured_group, - ): - args_with_derivatives = [ - Binding( - name=var, - nctype=cpp.argument_type(map_name2arg[var], binds=var), - argument=map_name2arg[var], - default=None, - ) - for var in all_var_names - ] - diff_info = DifferentiabilityInfo( - name=base_op_name.base, - func=f, - op=f"Foreach{ref_diff_info.op}{f.func.name.overload_name}", - derivatives=modified_derivative_formulas, - forward_derivatives=[], - all_saved_inputs=tuple(set(all_saved_inputs)), - all_saved_outputs=tuple(set(all_saved_outputs)), - available_named_gradients=(), - used_named_gradients=set(), - args_with_derivatives=args_with_derivatives, - non_differentiable_arg_names=[], - output_differentiability=None, - output_differentiability_conditions=None, - ) - diff_info_dict = {"Default": diff_info} - if f.func not in differentiability_infos: - differentiability_infos[f.func] = diff_info_dict - functional_info_by_signature[f.func] = diff_info_dict - return diff_info_dict, True + diff_info = gen_foreach_derivativeinfo( + f, differentiability_infos, functional_info_by_signature + ) + if diff_info is None: + return None, False + diff_info_dict = {"Default": diff_info} + if f.func not in differentiability_infos: + differentiability_infos[f.func] = diff_info_dict + functional_info_by_signature[f.func] = diff_info_dict + return diff_info_dict, True return None, False diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index e10b07742dbb..e6c99780059b 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -48,6 +48,22 @@ ) +def 
_sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str: + """ + A wrapper function to basically get `sig.decl(include_context=True)`. + For ATen kernel, the codegen has no idea about ET contextArg, so we + use this wrapper to add it. + """ + if isinstance(sig, ExecutorchCppSignature): + return sig.decl() + + returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type() + cpp_args = [a.decl() for a in sig.arguments()] + cpp_args_str = ", ".join([contextArg.decl()] + cpp_args) + sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})" + return sig_decl + + def static_dispatch( sig: Union[CppSignature, ExecutorchCppSignature], f: NativeFunction, @@ -80,7 +96,7 @@ def static_dispatch( """ return f""" // {f.namespace}::{f.func} -TORCH_API inline {sig.decl()} {{ +TORCH_API inline {_sig_decl_wrapper(sig)} {{ {static_block} }} """ @@ -116,10 +132,10 @@ def __call__(self, f: NativeFunction) -> Optional[str]: return f""" // {f.namespace}::{f.func} -TORCH_API inline {sig.decl()} {{ +TORCH_API inline {_sig_decl_wrapper(sig)} {{ return at::{sig.name()}({comma.join(e.name for e in sig.arguments())}); }} - """ +""" else: return static_dispatch( @@ -188,11 +204,10 @@ def __call__(self, f: NativeFunction) -> str: Operator( "{f.namespace}::{f.func.name}", []({contextArg.defn()}, EValue** stack) {{ - {"(void)context;" if self.use_aten_lib else ""} {code_connector.join(code_list)} EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}"); - {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"" if self.use_aten_lib else "context, "}{args_str}); + {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"context, "}{args_str}); {return_assignment} }}
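_sig_decl_wrapper above prepends the ExecuTorch context argument only for ATen CppSignature kernels, whose signatures know nothing about it. The sketch below mirrors the string it assembles; the context declaration and argument strings are invented stand-ins for contextArg.decl() and sig.arguments():

returns_type = "at::Tensor &"
name = "add_out"
context_decl = "torch::executor::RuntimeContext & context"   # stand-in for contextArg.decl()
cpp_args = ["const at::Tensor & self", "const at::Tensor & other", "at::Tensor & out"]

sig_decl = f"{returns_type} {name}({', '.join([context_decl] + cpp_args)})"
print(sig_decl)
# at::Tensor & add_out(torch::executor::RuntimeContext & context, const at::Tensor & self, ...)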