diff --git a/.circleci/build.sh b/.circleci/build.sh index 8891e2201ac..04a093109d7 100755 --- a/.circleci/build.sh +++ b/.circleci/build.sh @@ -34,7 +34,7 @@ pip install ninja pip install lark-parser # Install Pytorch -patch -p1 < xla/pytorch.patch +xla/scripts/apply_patches.sh python setup.py build develop # Bazel doesn't work with sccache gcc. https://github.com/bazelbuild/bazel/issues/3642 diff --git a/README.md b/README.md index 12ebb5960bf..1720b6e428e 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,10 @@ To build: git checkout $(cat xla/.torch_commit_id) ``` -* Apply the `pytorch.patch` to the current `xla` code. From within the _pytorch_ source folder: +* Apply PyTorch patches: ``` - patch -p1 < xla/pytorch.patch + xla/scripts/apply_patches.sh ``` * Install the Lark parser used for automatic code generation: diff --git a/kokoro/ubuntu/common.sh b/kokoro/ubuntu/common.sh index 627d1623250..dc1d8e48990 100755 --- a/kokoro/ubuntu/common.sh +++ b/kokoro/ubuntu/common.sh @@ -62,7 +62,7 @@ cd pytorch # TODO(jysohn): remove following patching once pytorch JIT bug is fixed git checkout $(cat xla/.torch_commit_id) -git apply xla/pytorch.patch +xla/scripts/apply_patches.sh # Build and install torch wheel and collect artifact export NO_CUDA=1 python setup.py bdist_wheel diff --git a/scripts/apply_patches.sh b/scripts/apply_patches.sh new file mode 100755 index 00000000000..f3cdbe18b26 --- /dev/null +++ b/scripts/apply_patches.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +CDIR=$(dirname $0) +XDIR=$CDIR/.. +PTDIR=$XDIR/.. + +python $CDIR/cond_patch.py \ + $XDIR/torch_patches \ + $PTDIR diff --git a/scripts/cond_patch.py b/scripts/cond_patch.py new file mode 100755 index 00000000000..deb297e28a9 --- /dev/null +++ b/scripts/cond_patch.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +from __future__ import print_function + +import argparse +import glob +import os +import re +import subprocess +import sys + + +def get_log(repo_folder, depth): + return subprocess.check_output( + ['git', '-C', repo_folder, 'log', '-{}'.format(depth)]).decode('utf-8') + + +def is_applied(log, revno): + revrx = 'Pull Request resolved: .*[/#]{}'.format(revno) + return re.search(revrx, log) + + +def select_patches(patch_folder, repo_folder, depth): + log = get_log(repo_folder, depth) + files = sorted(glob.glob(os.path.join(patch_folder, '*.diff'))) + selected = [] + for ppath in files: + revno = os.path.splitext(os.path.basename(ppath))[0] + # Patches which are not all digits (PR numbers) are always applied. + if not re.match(r'\d+$', revno) or not is_applied(log, revno): + selected.append(ppath) + return selected + + +def apply_patch(ppath, repo_folder, level): + return subprocess.call([ + 'patch', '-d', repo_folder, '-p{}'.format(level), '-i', ppath, '-E', '-l', + '-r', '-', '-s', '--no-backup-if-mismatch' + ]) + + +def patch_repo(args): + patches = select_patches( + os.path.normpath(args.patch_folder), os.path.normpath(args.repo_folder), + args.log_depth) + for ppath in patches: + print('Applying patch file: {}'.format(ppath), file=sys.stderr) + apply_patch(ppath, os.path.normpath(args.repo_folder), args.level) + + +if __name__ == '__main__': + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--level', type=int, default=1) + arg_parser.add_argument('--log_depth', type=int, default=1000) + arg_parser.add_argument( + 'patch_folder', + type=str, + metavar='PATCH_FOLDER', + help='The path to the folder containing the patches') + arg_parser.add_argument( + 'repo_folder', + type=str, + metavar='REPO_FOLDER', + help='The path to the root folder of the repo to be patched') + args, files = arg_parser.parse_known_args() + patch_repo(args) diff --git a/torch_patches/.gitignore b/torch_patches/.gitignore new file mode 100644 index 00000000000..e69de29bb2d diff --git a/torch_patches/README.md b/torch_patches/README.md new file mode 100644 index 00000000000..694716d6b1e --- /dev/null +++ b/torch_patches/README.md @@ -0,0 +1,17 @@ +# Guidelines For Patch File Names + +The only files which are considered by the apply script are the ones with extension '.diff'. + +A file for PyTorch PR _N_ needs to be named 'N.diff'. + +Patch files which are not related to PyTorch PRs, should begin with an 'X' character, +followed by a two digit number, followed by a dash ('-'), a name, and '.diff'. +Example: + +``` +X10-optimizer.diff +``` + +Patch file are alphabetically ordered, so PyTorch PR patches are always applied +before the non PyTorch ones. + diff --git a/pytorch.patch b/torch_patches/X10-optimizer.diff similarity index 100% rename from pytorch.patch rename to torch_patches/X10-optimizer.diff