Skip to content

Commit

Permalink
Update on "Fused attention patterns"
Browse files Browse the repository at this point in the history
Patterns based on #94729 mainly as a forcing function for implementing joint graph replacements.

Up until now, we had two places to do pattern matching
1) Pre-grad has janky infra (graph not normalized or functional), but is
   desirable for many types of passes where you want your change to
   affect grad formulas.
2) Post-grad has good infra, but can't change grad formulas.

This PR adds a third place to do pattern matching: the joint
forward+backwards graph.  The idea is to take the patterns and lower
them to a joint graph and replace both the forwards+backwards before
we partition them.  This allows us to do something similar to pre-grad
transforms, but run after normalization and functionalization.

Note that we don't seem to have kernels for all of these patterns, some get decomposed in the dispatcher.

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
  • Loading branch information
jansel committed Apr 9, 2023
2 parents edd46af + 02b1bff commit 75c89eb
Show file tree
Hide file tree
Showing 162 changed files with 4,583 additions and 1,868 deletions.
10 changes: 9 additions & 1 deletion .ci/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
# shellcheck source=./common.sh
source "$(dirname "${BASH_SOURCE[0]}")/common.sh"

# Retry helper for the flaky ONNX test: run the command, and on failure
# wait 60s and run it once more (two attempts total). The second attempt
# runs in a subshell; its exit status becomes the function's status.
retry () {
"$@" || (sleep 60 && "$@")
}

if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
pip -q install --user "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx"
# TODO: This can be removed later once vision is also part of the Docker image
pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
# JIT C++ extensions require ninja, so put it into PATH.
export PATH="/var/lib/jenkins/.local/bin:$PATH"
"$ROOT_DIR/scripts/onnx/test.sh"
# NB: The ONNX test is fast (~15m), so it's OK to retry it a few more times to avoid flaky issues; we
# need to bring this to the standard PyTorch run_test eventually. The issue will be tracked in
# https://github.com/pytorch/pytorch/issues/98626
retry "$ROOT_DIR/scripts/onnx/test.sh"
fi
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
781f512b01bc2324d7fdd11f0901f60571fc476f
27b8491640aac89a08624f3f70a270ee88542984
3 changes: 2 additions & 1 deletion .github/scripts/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def gh_fetch_json(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
method: Optional[str] = None,
) -> List[Dict[str, Any]]:
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
Expand All @@ -63,7 +64,7 @@ def gh_fetch_json(
)
return cast(
List[Dict[str, Any]],
gh_fetch_url(url, headers=headers, data=data, reader=json.load),
gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
)


Expand Down
13 changes: 10 additions & 3 deletions .github/scripts/label_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, List, Tuple, TYPE_CHECKING, Union
from urllib.request import Request, urlopen

from github_utils import gh_fetch_json, GitHubComment
from github_utils import gh_fetch_url, GitHubComment

# TODO: this is a temp workaround to avoid circular dependencies,
# and should be removed once GitHubPR is refactored out of trymerge script.
Expand Down Expand Up @@ -77,12 +77,19 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
def gh_add_labels(
org: str, repo: str, pr_num: int, labels: Union[str, List[str]]
) -> None:
gh_fetch_json(
f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels",
gh_fetch_url(
url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels",
data={"labels": labels},
)


def gh_remove_label(org: str, repo: str, pr_num: int, label: str) -> None:
    """Remove a single label from a pull request via the GitHub REST API.

    Issues a DELETE against the issue-labels endpoint; PRs share the
    issue label namespace, so ``pr_num`` is used as the issue number.
    """
    endpoint = (
        f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels/{label}"
    )
    gh_fetch_url(url=endpoint, method="DELETE")


def get_release_notes_labels(org: str, repo: str) -> List[str]:
return [
label
Expand Down
39 changes: 29 additions & 10 deletions .github/scripts/trymerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,18 @@
GitRepo,
patterns_to_regex,
)
from label_utils import gh_add_labels, has_required_labels, LABEL_ERR_MSG
from label_utils import (
gh_add_labels,
gh_remove_label,
has_required_labels,
LABEL_ERR_MSG,
)
from trymerge_explainer import get_revert_message, TryMergeExplainer

# labels
MERGE_IN_PROGRESS_LABEL = "merging"
MERGE_COMPLETE_LABEL = "merged"


class JobCheckState(NamedTuple):
name: str
Expand Down Expand Up @@ -1031,12 +1040,18 @@ def gen_commit_message(self, filter_ghstack: bool = False) -> str:
return msg

def add_numbered_label(self, label_base: str) -> None:
labels = self.get_labels()
label = label_base
for i in range(len(labels) if labels is not None else 0):
if label in labels:
label = f"{label_base}X{i+2}"
gh_add_labels(self.org, self.project, self.pr_num, [label])
labels = self.get_labels() if self.labels is not None else []
full_label = label_base
count = 0
for label in labels:
if label_base in label:
count += 1
full_label = f"{label_base}X{count}"
gh_add_labels(self.org, self.project, self.pr_num, [full_label])

def remove_label(self, label: str) -> None:
    """Remove ``label`` from this PR if it is currently present.

    No-op when the PR has no labels or the label is absent, so it is
    safe to call unconditionally (e.g. from a ``finally`` block).
    """
    # Fetch the label list once: the original called get_labels() twice,
    # repeating the lookup and leaving a window where the two reads could
    # disagree.
    labels = self.get_labels()
    if labels is not None and label in labels:
        gh_remove_label(self.org, self.project, self.pr_num, label)

def merge_into(
self,
Expand All @@ -1061,9 +1076,9 @@ def merge_into(

repo.push(self.default_branch(), dry_run)
if not dry_run:
self.add_numbered_label("merged")
self.add_numbered_label(MERGE_COMPLETE_LABEL)
for pr in additional_merged_prs:
pr.add_numbered_label("merged")
pr.add_numbered_label(MERGE_COMPLETE_LABEL)

if comment_id and self.pr_num:
# When the merge process reaches this part, we can assume that the commit
Expand Down Expand Up @@ -1751,6 +1766,9 @@ def merge(
initial_commit_sha = pr.last_commit()["oid"]
print(f"Attempting merge of {initial_commit_sha}")

if MERGE_IN_PROGRESS_LABEL not in pr.get_labels():
gh_add_labels(org, project, pr_num, [MERGE_IN_PROGRESS_LABEL])

explainer = TryMergeExplainer(
skip_mandatory_checks, pr.get_labels(), pr.pr_num, org, project, ignore_current
)
Expand Down Expand Up @@ -1983,7 +2001,6 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
message += '\nIf those updates are intentional, please add "submodule" keyword to PR title/description.'
gh_post_pr_comment(org, project, args.pr_num, message, dry_run=args.dry_run)
return

try:
merge(
args.pr_num,
Expand Down Expand Up @@ -2020,6 +2037,8 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
)
else:
print("Missing comment ID or PR number, couldn't upload to Rockset")
finally:
pr.remove_label(MERGE_IN_PROGRESS_LABEL)


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions .github/scripts/update_commit_hashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def parse_args() -> Any:
parser = ArgumentParser("Rebase PR into branch")
parser.add_argument("--repo-name", type=str)
parser.add_argument("--branch", type=str)
parser.add_argument("--pin-folder", type=str)
return parser.parse_args()


Expand Down Expand Up @@ -135,7 +136,7 @@ def main() -> None:
.stdout.decode("utf-8")
.strip()
)
with open(f".github/ci_commit_pins/{args.repo_name}.txt", "r+") as f:
with open(f"{args.pin_folder}/{args.repo_name}.txt", "r+") as f:
old_hash = f.read().strip()
subprocess.run(f"git checkout {old_hash}".split(), cwd=args.repo_name)
f.seek(0)
Expand All @@ -144,7 +145,7 @@ def main() -> None:
if is_newer_hash(hash, old_hash, args.repo_name):
# if there was an update, push to branch
subprocess.run(f"git checkout -b {branch_name}".split())
subprocess.run(f"git add .github/ci_commit_pins/{args.repo_name}.txt".split())
subprocess.run(f"git add {args.pin_folder}/{args.repo_name}.txt".split())
subprocess.run(
"git commit -m".split() + [f"update {args.repo_name} commit hash"]
)
Expand Down
53 changes: 50 additions & 3 deletions .github/workflows/_linux-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ !contains(matrix.runner, 'gcp.a100') }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |
Expand Down Expand Up @@ -284,9 +285,7 @@ jobs:
# As both the root cause and recovery path are unclear, let's take the runner out of
# service so that it doesn't get any more jobs
- name: Check NVIDIA driver installation step
if:
failure() &&
((steps.install-nvidia-driver.conclusion && steps.install-nvidia-driver.conclusion == 'failure') || (contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu')))
if: failure() && steps.install-nvidia-driver.outcome && steps.install-nvidia-driver.outcome != 'skipped'
shell: bash
env:
RUNNER_WORKSPACE: ${{ runner.workspace }}
Expand All @@ -295,10 +294,58 @@ jobs:
set -x
nvidia-smi
# NB: Surprisingly, nvidia-smi command returns successfully with return code 0 even in
# the case where the driver has already crashed as it still can get the driver version
# and some basic information like the bus ID. However, the rest of the information
# would be missing (ERR!), for example:
#
# +-----------------------------------------------------------------------------+
# | NVIDIA-SMI 525.89.02 Driver Version: 525.89.02 CUDA Version: 12.0 |
# |-------------------------------+----------------------+----------------------+
# | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
# | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
# | | | MIG M. |
# |===============================+======================+======================|
# | 0 ERR! Off | 00000000:00:1E.0 Off | ERR! |
# |ERR! ERR! ERR! ERR! / ERR! | 4184MiB / 23028MiB | ERR! Default |
# | | | ERR! |
# +-------------------------------+----------------------+----------------------+
#
# +-----------------------------------------------------------------------------+
# | Processes: |
# | GPU GI CI PID Type Process name GPU Memory |
# | ID ID Usage |
# |=============================================================================|
# +-----------------------------------------------------------------------------+
#
# This should be reported as a failure instead as it will guarantee to fail when
# Docker tries to run with --gpus all
#
# So, the correct check here is to query one of the missing piece of info like
# GPU name, so that the command can fail accordingly
nvidia-smi --query-gpu=gpu_name --format=csv,noheader --id=0
NVIDIA_SMI_STATUS=$?
# These are acceptable return code from nvidia-smi as copied from setup-nvidia GitHub action
if [ "$NVIDIA_SMI_STATUS" -ne 0 ] && [ "$NVIDIA_SMI_STATUS" -ne 14 ]; then
echo "NVIDIA driver installation has failed, shutting down the runner..."
.github/scripts/stop_runner_service.sh
fi
# For runner with multiple GPUs, we also want to confirm that the number of GPUs are the
# power of 2, i.e. 1, 2, 4, or 8. This is to avoid flaky test issue when one GPU fails
# https://github.com/pytorch/test-infra/issues/4000
GPU_COUNT=$(nvidia-smi --list-gpus | wc -l)
NVIDIA_SMI_STATUS=$?
# These are acceptable return code from nvidia-smi as copied from setup-nvidia GitHub action
if [ "$NVIDIA_SMI_STATUS" -ne 0 ] && [ "$NVIDIA_SMI_STATUS" -ne 14 ]; then
echo "NVIDIA driver installation has failed, shutting down the runner..."
.github/scripts/stop_runner_service.sh
fi
# Check the GPU count to be a power of 2
if [ "$GPU_COUNT" -le 8 ] && [ "$GPU_COUNT" -ne 1 ] && [ "$GPU_COUNT" -ne 2 ] && [ "$GPU_COUNT" -ne 4 ] && [ "$GPU_COUNT" -ne 8 ]; then
echo "NVIDIA driver detects $GPU_COUNT GPUs. The runner has a broken GPU, shutting it down..."
.github/scripts/stop_runner_service.sh
fi
19 changes: 6 additions & 13 deletions .github/workflows/_mac-test-mps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ jobs:
environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }}
pip-requirements-file: .github/requirements/pip-requirements-${{ runner.os }}.txt

- name: Install PyTorch
- name: Install PyTorch and run MPS tests
id: test
env:
ENV_NAME: conda-test-env-${{ github.run_id }}
PY_VERS: 3.9
PR_BODY: ${{ github.event.pull_request.body }}
PYTORCH_RETRY_TEST_CASES: 1
PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1
CONTINUE_THROUGH_ERROR: ${{ needs.filter.outputs.keep-going }}
shell: arch -arch arm64 bash {0}
run: |
# shellcheck disable=SC1090
Expand All @@ -116,18 +121,6 @@ jobs:
ORIG_WHLNAME=$(ls -1 dist/*.whl); ARM_WHLNAME=${ORIG_WHLNAME/x86_64/arm64}; mv ${ORIG_WHLNAME} ${ARM_WHLNAME}
${CONDA_RUN} python3 -mpip install --no-index --no-deps dist/*.whl
- name: Run MPS tests
id: test
env:
ENV_NAME: conda-test-env-${{ github.run_id }}
PR_BODY: ${{ github.event.pull_request.body }}
PYTORCH_RETRY_TEST_CASES: 1
PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1
CONTINUE_THROUGH_ERROR: ${{ needs.filter.outputs.keep-going }}
shell: arch -arch arm64 bash {0}
run: |
# shellcheck disable=SC1090
set -ex
${CONDA_RUN} python3 test/run_test.py --mps --verbose
- name: Print remaining test logs
Expand Down
19 changes: 15 additions & 4 deletions .github/workflows/_update-commit-hash.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,24 @@ name: update-commit-hash
on:
workflow_call:
inputs:
repo-owner:
required: false
type: string
description: Name of repository's owner.
default: pytorch
repo-name:
required: true
type: string
description: Name of the repository we're updating commit hash for. Must be in pytorch org.
description: Name of the repository we're updating commit hash for.
branch:
required: true
type: string
description: Branch to fetch commit of
pin-folder:
type: string
description: Path to folder with commit pin
required: false
default: .github/ci_commit_pins
secrets:
MERGEBOT_TOKEN:
required: true
Expand All @@ -23,7 +33,7 @@ env:
NEW_BRANCH_NAME: update-${{ inputs.repo-name }}-commit-hash/${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}

jobs:
update-xla-commit-hash:
update-commit-hash:
runs-on: ubuntu-latest
steps:
- name: Checkout repo
Expand All @@ -36,18 +46,19 @@ jobs:
- name: Checkout
shell: bash
run: |
git clone https://github.com/pytorch/${{ inputs.repo-name }}.git --quiet
git clone https://github.com/${{ inputs.repo-owner }}/${{ inputs.repo-name }}.git --quiet
- name: Check if there already exists a PR
shell: bash
env:
REPO_NAME: ${{ inputs.repo-name }}
BRANCH: ${{ inputs.branch }}
PIN_FOLDER: ${{ inputs.pin-folder }}
MERGEBOT_TOKEN: ${{ secrets.MERGEBOT_TOKEN }}
PYTORCHBOT_TOKEN: ${{ secrets.PYTORCHBOT_TOKEN }}
run: |
# put this here instead of the script to prevent accidentally changing the config when running the script locally
git config --global user.name "PyTorch MergeBot"
git config --global user.email "pytorchmergebot@users.noreply.github.com"
python .github/scripts/update_commit_hashes.py --repo-name "${REPO_NAME}" --branch "${BRANCH}"
python .github/scripts/update_commit_hashes.py --repo-name "${REPO_NAME}" --branch "${BRANCH}" --pin-folder "${PIN_FOLDER}"
2 changes: 2 additions & 0 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
{ config: "docs_test", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
]}
linux-focal-py3_8-gcc7-test:
Expand Down
21 changes: 0 additions & 21 deletions .github/workflows/unstable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,6 @@ jobs:
echo "Once the jobs are deemed stable enough (% red signal < 5% and TTS < 3h),"
echo " they can graduate and move back to pull or trunk."
linux-focal-py3_8-gcc7-build:
name: linux-focal-py3.8-gcc7
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-py3.8-gcc7
docker-image-name: pytorch-linux-focal-py3.8-gcc7
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
]}
linux-focal-py3_8-gcc7-test:
name: linux-focal-py3.8-gcc7
uses: ./.github/workflows/_linux-test.yml
needs: linux-focal-py3_8-gcc7-build
with:
build-environment: linux-focal-py3.8-gcc7
docker-image: ${{ needs.linux-focal-py3_8-gcc7-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-py3_8-gcc7-build.outputs.test-matrix }}

linux-vulkan-bionic-py3_11-clang9-build:
name: linux-vulkan-bionic-py3.11-clang9
uses: ./.github/workflows/_linux-build.yml
Expand Down

0 comments on commit 75c89eb

Please sign in to comment.