diff --git a/utils/update_checkout/update_checkout/update_checkout.py b/utils/update_checkout/update_checkout/update_checkout.py index ea324e6886701..8d104a2a941f9 100755 --- a/utils/update_checkout/update_checkout/update_checkout.py +++ b/utils/update_checkout/update_checkout/update_checkout.py @@ -75,6 +75,20 @@ def check_parallel_results(results, op): def confirm_tag_in_repo(tag, repo_name): + # type: (str, str) -> str | None + """Confirm that a given tag exists in a git repository. This function + assumes that the repository is already a current working directory before + it's called. + + Args: + tag (str): tag to look up in the repository + repo_name (str): name the repository for the look up, used for logging + + Returns: + str | None: returns `tag` argument value or `None` if the tag doesn't + exist. + """ + tag_exists = shell.capture(['git', 'ls-remote', '--tags', 'origin', tag], echo=False) if not tag_exists: @@ -97,6 +111,21 @@ def find_rev_by_timestamp(timestamp, repo_name, refspec): def get_branch_for_repo(config, repo_name, scheme_name, scheme_map, cross_repos_pr): + """Infer, fetch, and return a branch corresponding to a given PR, otherwise + return a branch found in the config for this repository name. + + Args: + config (Dict[str, Any]): deserialized `update-checkout-config.json` + repo_name (str): name of the repository for checking out the branch + scheme_name (str): name of the scheme to look up in the config + scheme_map (Dict[str, str]): map of repo names to branches to check out + cross_repos_pr (Dict[str, str]): map of repo ids to PRs to check out + + Returns: + Tuple[str, bool]: a pair of a checked out branch and a boolean + indicating whether this repo matched any `cross_repos_pr`. + """ + cross_repo = False repo_branch = scheme_name if scheme_map: @@ -240,17 +269,39 @@ def update_single_repository(pool_args): return value -def get_timestamp_to_match(args): - if not args.match_timestamp: +def get_timestamp_to_match(match_timestamp, source_root): + # type: (str | None, str) -> str | None + """Computes a timestamp of the last commit on the current branch in + the `swift` repository. + + Args: + match_timestamp (str | None): value of `--match-timestamp` to check. + source_root (str): directory that contains sources of the Swift project. + + Returns: + str | None: a timestamp of the last commit of `swift` repository if + `match_timestamp` argument has a value, `None` if `match_timestamp` is + falsy. + """ + if not match_timestamp: return None - with shell.pushd(os.path.join(args.source_root, "swift"), + with shell.pushd(os.path.join(source_root, "swift"), dry_run=False, echo=False): return shell.capture(["git", "log", "-1", "--format=%cI"], echo=False).strip() -def update_all_repositories(args, config, scheme_name, cross_repos_pr): - scheme_map = None +def get_scheme_map(config, scheme_name): + """Find a mapping from repository IDs to branches in the config. + + Args: + config (Dict[str, Any]): deserialized `update-checkout-config.json` + scheme_name (str): name of the scheme to look up in `config` + + Returns: + Dict[str, str]: a mapping from repos to branches for the given scheme. + """ + if scheme_name: # This loop is only correct, since we know that each alias set has # unique contents. This is checked by validate_config. Thus the first @@ -258,10 +309,14 @@ def update_all_repositories(args, config, scheme_name, cross_repos_pr): # the only possible correct answer. for v in config['branch-schemes'].values(): if scheme_name in v['aliases']: - scheme_map = v['repos'] - break + return v['repos'] + + return None + + +def update_all_repositories(args, config, scheme_name, scheme_map, cross_repos_pr): pool_args = [] - timestamp = get_timestamp_to_match(args) + timestamp = get_timestamp_to_match(args.match_timestamp, args.source_root) for repo_name in config['repos'].keys(): if repo_name in args.skip_repository_list: print("Skipping update of '" + repo_name + "', requested by user") @@ -471,6 +526,18 @@ def full_target_name(repository, target): def skip_list_for_platform(config, all_repos): + """Computes a list of repositories to skip when updating or cloning, if not + overriden by `--all-repositories` CLI argument. + + Args: + config (Dict[str, Any]): deserialized `update-checkout-config.json` + all_repos (List[str]): repositories not required for current platform. + + Returns: + List[str]: a resulting list of repositories to skip or empty list if + `all_repos` is not empty. + """ + if all_repos: return [] # Do not skip any platform-specific repositories @@ -598,7 +665,7 @@ def main(): clone_with_ssh = args.clone_with_ssh skip_history = args.skip_history skip_tags = args.skip_tags - scheme = args.scheme + scheme_name = args.scheme github_comment = args.github_comment all_repos = args.all_repositories @@ -606,14 +673,6 @@ def main(): config = json.load(f) validate_config(config) - if args.dump_hashes: - dump_repo_hashes(args, config) - return (None, None) - - if args.dump_hashes_config: - dump_repo_hashes(args, config, args.dump_hashes_config) - return (None, None) - cross_repos_pr = {} if github_comment: regex_pr = r'(apple/[-a-zA-Z0-9_]+/pull/\d+|apple/[-a-zA-Z0-9_]+#\d+)' @@ -622,22 +681,51 @@ def main(): repos_with_pr = [pr.replace('/pull/', '#') for pr in repos_with_pr] cross_repos_pr = dict(pr.split('#') for pr in repos_with_pr) + # If branch is None, default to using the default branch alias + # specified by our configuration file. + if scheme_name is None: + scheme_name = config['default-branch-scheme'] + + scheme_map = get_scheme_map(config, scheme_name) + clone_results = None + skip_repo_list = [] if clone or clone_with_ssh: - # If branch is None, default to using the default branch alias - # specified by our configuration file. - if scheme is None: - scheme = config['default-branch-scheme'] - skip_repo_list = skip_list_for_platform(config, all_repos) skip_repo_list.extend(args.skip_repository_list) clone_results = obtain_all_additional_swift_sources(args, config, clone_with_ssh, - scheme, + scheme_name, skip_history, skip_tags, skip_repo_list) + swift_repo_path = os.path.join(args.source_root, 'swift') + if 'swift' not in skip_repo_list and os.path.exists(swift_repo_path): + with shell.pushd(swift_repo_path, dry_run=False, echo=True): + # Check if `swift` repo itself needs to switch to a cross-repo branch. + branch_name, cross_repo = get_branch_for_repo(config, 'swift', + scheme_name, + scheme_map, + cross_repos_pr) + + if cross_repo: + shell.run(['git', 'checkout', branch_name], echo=True, + prefix="[swift] ") + + # Re-read the config after checkout. + with open(args.config) as f: + config = json.load(f) + validate_config(config) + + if args.dump_hashes: + dump_repo_hashes(args, config) + return (None, None) + + if args.dump_hashes_config: + dump_repo_hashes(args, config, args.dump_hashes_config) + return (None, None) + # Quick check whether somebody is calling update in an empty directory directory_contents = os.listdir(args.source_root) if not ('cmark' in directory_contents or @@ -646,8 +734,8 @@ def main(): print("You don't have all swift sources. " "Call this script with --clone to get them.") - update_results = update_all_repositories(args, config, scheme, - cross_repos_pr) + update_results = update_all_repositories(args, config, scheme_name, + scheme_map, cross_repos_pr) fail_count = 0 fail_count += check_parallel_results(clone_results, "CLONE") fail_count += check_parallel_results(update_results, "UPDATE")