From a83dfc841fa06b27a8f095121e6b7bb670e3ec9b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 26 Mar 2020 17:06:56 -0400 Subject: [PATCH 1/3] Switch XLA to new operator registration API. Signed-off-by: Edward Z. Yang --- scripts/gen.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/gen.py b/scripts/gen.py index 8e8eaa298e7d..99992c1cd123 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -949,7 +949,7 @@ def parse_local_overrides(path): def generate_registrations(fgens, overrides): code = 'void RegisterAtenTypeFunctions() {\n' - code += ' static auto dispatch = torch::RegisterOperators()\n' + code += ' static auto dispatch = torch::import()\n' overridden = set() for fgen in fgens: if not is_overrideable(fgen): @@ -962,11 +962,9 @@ def generate_registrations(fgens, overrides): override_fn = fgen.xfunc if fgen.code else None if override_fn: code += ( - ' .op(torch::RegisterOperators::options().schema("{}")\n ' - '.impl_unboxedOnlyKernel<{}, &{}>(at::DispatchKey::XLATensorId)\n' - ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format( - fgen.aten_sig, fgen.funsig, override_fn, override_fn, - fgen.aten_sig)) + ' .impl("{}", torch::dispatch(at::DispatchKey::XLATensorId, ' + 'CppFunction::makeUnboxedOnly(static_cast<{}>({})))\n'.format( + fgen.aten_sig.split("(")[0], fgen.funsig, override_fn)) return code + ';\n}\n', overridden From 1fe0feb47918e799127a1cea9f2e70fe02769bbf Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Thu, 26 Mar 2020 16:20:56 -0700 Subject: [PATCH 2/3] Fix overload address issue. --- scripts/gen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/gen.py b/scripts/gen.py index 99992c1cd123..3c085f286fed 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -961,10 +961,12 @@ def generate_registrations(fgens, overrides): else: override_fn = fgen.xfunc if fgen.code else None if override_fn: + pos = fgen.funsig.find('(') + overload = fgen.funsig[:pos] + ' (*)' + fgen.funsig[pos:] code += ( ' .impl("{}", torch::dispatch(at::DispatchKey::XLATensorId, ' - 'CppFunction::makeUnboxedOnly(static_cast<{}>({})))\n'.format( - fgen.aten_sig.split("(")[0], fgen.funsig, override_fn)) + 'at::CppFunction::makeUnboxedOnly(static_cast<{}>(&{}))))\n'.format( + fgen.aten_sig.split("(")[0], overload, override_fn)) return code + ';\n}\n', overridden From 08a755e4b0921b60013e8b953713c6d9eb273197 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 27 Mar 2020 00:48:19 +0000 Subject: [PATCH 3/3] Provide an easy way to test against PT PRs. --- scripts/apply_patches.sh | 27 ++++++++++++++++++++++++++- scripts/gen.py | 12 ++++++------ torch_patches/README.md | 16 +++++++++++++++- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/scripts/apply_patches.sh b/scripts/apply_patches.sh index 8e92a2121b84..177b452319a9 100755 --- a/scripts/apply_patches.sh +++ b/scripts/apply_patches.sh @@ -6,6 +6,31 @@ CDIR="$(cd "$(dirname "$0")" ; pwd -P)" XDIR=$CDIR/.. PTDIR=$XDIR/.. +TORCH_PIN="$XDIR/torch_patches/.torch_pin" +if [ -f "$TORCH_PIN" ]; then + CID=$(cat "$TORCH_PIN") + # If starts with # and it's not merged into master, fetch from origin + if [[ $CID = \#* ]]; then + PRNUM="${CID//[!0-9]/}" + set +x + MCHECK=$(git -C $PTDIR log -1000) + if [[ $MCHECK != *"Pull Request resolved: https://github.com/pytorch/pytorch/pull/$PRNUM"* ]]; then + echo "Fetching PyTorch PR #$PRNUM" + pushd "$PTDIR" + git fetch origin "pull/$PRNUM/head:$PRNUM" + git checkout "$PRNUM" + popd + fi + set -x + elif [[ "$CID" != "" ]]; then + echo 'Checking out branch $CID' + pushd "$PTDIR" + git fetch origin "$CID" + git checkout "$CID" + popd + fi +fi + python $CDIR/cond_patch.py \ $XDIR/torch_patches \ - $PTDIR + $PTDIR \ No newline at end of file diff --git a/scripts/gen.py b/scripts/gen.py index 3c085f286fed..8e8eaa298e7d 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -949,7 +949,7 @@ def parse_local_overrides(path): def generate_registrations(fgens, overrides): code = 'void RegisterAtenTypeFunctions() {\n' - code += ' static auto dispatch = torch::import()\n' + code += ' static auto dispatch = torch::RegisterOperators()\n' overridden = set() for fgen in fgens: if not is_overrideable(fgen): @@ -961,12 +961,12 @@ def generate_registrations(fgens, overrides): else: override_fn = fgen.xfunc if fgen.code else None if override_fn: - pos = fgen.funsig.find('(') - overload = fgen.funsig[:pos] + ' (*)' + fgen.funsig[pos:] code += ( - ' .impl("{}", torch::dispatch(at::DispatchKey::XLATensorId, ' - 'at::CppFunction::makeUnboxedOnly(static_cast<{}>(&{}))))\n'.format( - fgen.aten_sig.split("(")[0], overload, override_fn)) + ' .op(torch::RegisterOperators::options().schema("{}")\n ' + '.impl_unboxedOnlyKernel<{}, &{}>(at::DispatchKey::XLATensorId)\n' + ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format( + fgen.aten_sig, fgen.funsig, override_fn, override_fn, + fgen.aten_sig)) return code + ';\n}\n', overridden diff --git a/torch_patches/README.md b/torch_patches/README.md index 694716d6b1ee..78697bc23cd6 100644 --- a/torch_patches/README.md +++ b/torch_patches/README.md @@ -1,6 +1,6 @@ # Guidelines For Patch File Names -The only files which are considered by the apply script are the ones with extension '.diff'. +Files with extension '.diff' are consider as git patches by apply script. A file for PyTorch PR _N_ needs to be named 'N.diff'. @@ -15,3 +15,17 @@ X10-optimizer.diff Patch file are alphabetically ordered, so PyTorch PR patches are always applied before the non PyTorch ones. + +There's a special file `torch_patches/.torch_pin`, which is used to coordinate landing PRs in +`pytorch/pytorch` and `pytorch/xla`. + +To test a `pytorch/xla` PR against a `pytorch/pytorch` PR or branch, +put the PR number or branch name in this file. +Example: + +``` +#32451 +# or +my_awesome_branch # (must live in `pytorch/pytorch`) +``` +