Skip to content

Commit

Permalink
Update on "recompile fx.GraphModule lazily"
Browse files Browse the repository at this point in the history
Context: eellison 's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph.  We didn't simply remove the call to `recompile` at that time since that increases the risk that user see or run stale python code. In this PR, I recompile GraphModule lazily without increasing the risk that user see/run stale python code.

When training BertForMaskedLM, the `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds.

By spot checking, I found the main reason we call recompile so frequently is due to inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need trace both src_fn and dst_fn. After tracing is done, we create a GraphModule. The init method of GraphModule will call recompile.

By doing recompile lazily, we reduce the number of calls for `GraphModule.recompile` to 37 times and reduces its total execution time to 0.045s.



[ghstack-poisoned]
  • Loading branch information
shunting314 committed Jul 17, 2023
2 parents 6f2f418 + 322cf80 commit 01a5e32
Show file tree
Hide file tree
Showing 339 changed files with 34,081 additions and 2,479 deletions.
18 changes: 16 additions & 2 deletions .ci/docker/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,20 @@ case "$image" in
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9)
CUDA_VERSION=12.1.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
PROTOBUF=yes
DB=yes
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-py3-clang7-asan)
ANACONDA_PYTHON_VERSION=3.9
CLANG_VERSION=7
Expand Down Expand Up @@ -225,7 +239,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=5.3
ROCM_VERSION=5.4.2
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
Expand All @@ -236,7 +250,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=5.4.2
ROCM_VERSION=5.6
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
Expand Down
8 changes: 8 additions & 0 deletions .ci/docker/common/install_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
pip_install -r /opt/conda/requirements-docs.txt
fi

# HACK HACK HACK
# gcc-9 for ubuntu-18.04 from http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu
# Pulls llibstdc++6 13.1.0-8ubuntu1~18.04 which is too new for conda
# So remove libstdc++6.so.3.29 installed by https://anaconda.org/anaconda/libstdcxx-ng/files?version=11.2.0
if grep 18.04.6 /etc/issue >/dev/null; then
rm /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/lib/libstdc++.so.6
fi

popd
fi
38 changes: 33 additions & 5 deletions .ci/docker/common/install_rocm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,23 @@ install_ubuntu() {
rocprofiler-dev \
roctracer-dev

# precompiled miopen kernels added in ROCm 3.5; search for all unversioned packages
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5
# search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
MIOPENKERNELS=$(apt-cache search --names-only miopenkernels | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENKERNELS}" = x ]]; then
echo "miopenkernels package not available"
if [[ $(ver $ROCM_VERSION) -ge $(ver 5.5) ]]; then
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENHIPGFX}" = x ]]; then
echo "miopen-hip-gfx package not available" && exit 1
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
fi
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENKERNELS}
MIOPENKERNELS=$(apt-cache search --names-only miopenkernels | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENKERNELS}" = x ]]; then
echo "miopenkernels package not available" && exit 1
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENKERNELS}
fi
fi

# Cleanup
Expand Down Expand Up @@ -123,6 +133,24 @@ install_centos() {
rocprofiler-dev \
roctracer-dev

# precompiled miopen kernels; search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
if [[ $(ver $ROCM_VERSION) -ge $(ver 5.5) ]]; then
MIOPENHIPGFX=$(yum -q search miopen-hip-gfx | grep miopen-hip-gfx | awk '{print $1}'| grep -F kdb. || true)
if [[ "x${MIOPENHIPGFX}" = x ]]; then
echo "miopen-hip-gfx package not available" && exit 1
else
yum install -y ${MIOPENHIPGFX}
fi
else
MIOPENKERNELS=$(yum -q search miopenkernels | grep miopenkernels- | awk '{print $1}'| grep -F kdb. || true)
if [[ "x${MIOPENKERNELS}" = x ]]; then
echo "miopenkernels package not available" && exit 1
else
yum install -y ${MIOPENKERNELS}
fi
fi

# Cleanup
yum clean all
rm -rf /var/cache/yum
Expand Down
4 changes: 2 additions & 2 deletions .ci/docker/requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:

mypy==0.960
mypy==1.4.1
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 0.960
#Pinned versions: 1.4.1
#test that import: test_typing.py, test_type_hints.py

networkx==2.8.8
Expand Down
3 changes: 2 additions & 1 deletion .ci/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ test_dynamo_shard() {
test_package \
test_legacy_vmap \
test_custom_op_testing \
test_content_store \
export/test_db \
functorch/test_dims \
functorch/test_aotdispatch \
Expand Down Expand Up @@ -369,7 +370,7 @@ test_perf_for_dashboard() {
if [[ "$DASHBOARD_TAG" == *dynamic-true* ]]; then
python "benchmarks/dynamo/$suite.py" \
"${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --dynamic-shapes \
--dynamic-batch-only --disable-cudagraphs "$@" \
--dynamic-batch-only "$@" \
--output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_${mode}_cuda_${target}.csv"
fi
if [[ "$DASHBOARD_TAG" == *cppwrapper-true* ]] && [[ "$mode" == "inference" ]]; then
Expand Down
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/torchbench.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3df1cc93cccda38506f96d039d80b1888a7d6fb0
745644f391b4d11da107b2c82fe2d7a3eacf561d
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
08c9938f3e69b4fb2d710a038479264f3c23b133
29418e34a94e2c43f861a321265f7f21035e7b19
2 changes: 2 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
- torch/_functorch/aot_autograd.py
- .ci/docker/ci_commit_pins/**
- .github/ci_commit_pins/**
- c10/core/Sym*
- torch/fx/experimental/symbolic_shapes.py

"module: cpu":
- aten/src/ATen/cpu/**
Expand Down
24 changes: 24 additions & 0 deletions .github/scripts/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
import warnings

from dataclasses import dataclass
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
Expand Down Expand Up @@ -143,3 +144,26 @@ def gh_post_commit_comment(
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
url = f"https://api.github.com/repos/{org}/{repo}/issues/comments/{comment_id}"
gh_fetch_url(url, method="DELETE")


def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:
merge_base = ""
# Get the merge base using the GitHub REST API. This is the same as using
# git merge-base without the need to have git. The API doc can be found at
# https://docs.github.com/en/rest/commits/commits?apiVersion=2022-11-28#compare-two-commits
try:
json_data = gh_fetch_url(
f"https://api.github.com/repos/{org}/{repo}/compare/{base}...{head}",
headers={"Accept": "application/vnd.github.v3+json"},
reader=json.load,
)
if json_data:
merge_base = json_data.get("merge_base_commit", {}).get("sha", "")
else:
warnings.warn(
f"Failed to get merge base for {base}...{head}: Empty response"
)
except Exception as error:
warnings.warn(f"Failed to get merge base for {base}...{head}: {error}")

return merge_base

0 comments on commit 01a5e32

Please sign in to comment.