diff --git a/src/stacky/stacky.py b/src/stacky/stacky.py index 2e5a1f1..d3404d2 100755 --- a/src/stacky/stacky.py +++ b/src/stacky/stacky.py @@ -29,6 +29,7 @@ import shlex import subprocess import sys +import time from argparse import ArgumentParser from typing import List, Optional @@ -38,6 +39,8 @@ _LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s" +# 2 minutes ought to be enough for anybody ;-) +MAX_SSH_MUX_LIFETIME = 120 COLOR_STDOUT = os.isatty(1) COLOR_STDERR = os.isatty(2) IS_TERMINAL = os.isatty(1) and os.isatty(2) @@ -61,6 +64,7 @@ class StackyConfig: skip_confirm: bool = False change_to_main: bool = False change_to_adopted: bool = False + share_ssh_session: bool = False def read_one_config(self, config_path: str): rawconfig = configparser.ConfigParser() @@ -75,6 +79,9 @@ def read_one_config(self, config_path: str): self.change_to_adopted = rawconfig.get( "UI", "change_to_adopted", fallback=self.change_to_adopted ) + self.share_ssh_session = rawconfig.get( + "UI", "share_ssh_session", fallback=self.share_ssh_session + ) def read_config() -> StackyConfig: @@ -123,7 +130,21 @@ def __init__(self, fmt, *args, **kwargs): super().__init__(fmt.format(*args, **kwargs)) +def stop_muxed_ssh(remote: str = "origin"): + if CONFIG.share_ssh_session: + hostish = get_remote_type(remote) + if hostish is not None: + cmd = gen_ssh_mux_cmd() + cmd.append("-O") + cmd.append("exit") + cmd.append(hostish) + subprocess.Popen(cmd, stderr=subprocess.DEVNULL) + + def die(*args, **kwargs): + # We are taking a wild guess at what is the remote ... + # TODO (mpatou) fix the assumption about the remote + stop_muxed_ssh() raise ExitException(*args, **kwargs) @@ -715,6 +736,8 @@ def create_gh_pr(b): def do_push(forest, *, force=False, pr=False): + remote = "origin" + start_muxed_ssh(remote) if pr: load_pr_info_for_forest(forest) print_forest(forest) @@ -810,6 +833,8 @@ def do_push(forest, *, force=False, pr=False): elif pr_action == PR_CREATE: create_gh_pr(b) + stop_muxed_ssh(remote) + def cmd_stack_push(stack, args): do_push(get_current_stack_as_forest(stack), force=args.force, pr=args.pr) @@ -1015,34 +1040,57 @@ def get_bottom_level_branches_as_forest(stack): ] -def cmd_update(stack, args): - remote = "origin" - info("Fetching from {}", remote) - run(["git", "fetch", remote]) - - # TODO(tudor): We should rebase instead of silently dropping - # everything you have on local master. Oh well. - global CURRENT_BRANCH - for b in stack.bottoms: - run( - [ - "git", - "update-ref", - "refs/heads/{}".format(b.name), - "refs/remotes/{}/{}".format(remote, b.remote_branch), - ] +def get_remote_type(remote: str = "origin") -> Optional[str]: + out = run(["git", "remote", "-v"]) + for l in out.split("\n"): + match = re.match( + r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l ) - if b.name == CURRENT_BRANCH: - run(["git", "reset", "--hard", "HEAD"]) + if match: + sshish_host = match.group(1) + return sshish_host + + +def gen_ssh_mux_cmd() -> List[str]: + args = [ + "ssh", + "-o", + "ControlMaster=auto", + "-o", + f"ControlPersist={MAX_SSH_MUX_LIFETIME}", + "-o", + "ControlPath=~/.ssh/stacky-%C", + ] - # We treat origin as the source of truth for bottom branches (master), and - # the local repo as the source of truth for everything else. So we can only - # track PR closure for branches that are direct descendants of master. + return args + + +def start_muxed_ssh(remote: str = "origin"): + if not CONFIG.share_ssh_session: + return + hostish = get_remote_type(remote) + if hostish is not None: + info("Creating a muxed ssh connection") + cmd = gen_ssh_mux_cmd() + os.environ["GIT_SSH_COMMAND"] = " ".join(cmd) + cmd.append("-MNf") + cmd.append(hostish) + # We don't want to use the run() wrapper because + # we don't want to wait for the process to finish + + p = subprocess.Popen(cmd, stderr=subprocess.PIPE) + # Wait a little bit for the connection to establish + # before carrying on + while p.poll() is None: + time.sleep(1) + if p.returncode != 0: + error = p.stderr.read() + die( + f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}" + ) - info("Checking if any PRs have been merged and can be deleted") - forest = get_bottom_level_branches_as_forest(stack) - load_pr_info_for_forest(forest) +def get_branches_to_delete(forest): deletes = [] for b in depth_first(forest): if not b.parent or b.open_pr_info: @@ -1064,10 +1112,11 @@ def cmd_update(stack, args): b.parent.name, ) break + return deletes - if deletes and not args.force: - confirm() +def delete_branches(stack, deletes): + global CURRENT_BRANCH # Make sure we're not trying to delete the current branch for b in deletes: for c in b.children: @@ -1083,6 +1132,43 @@ def cmd_update(stack, args): run(["git", "branch", "-D", b.name]) +def cmd_update(stack, args): + remote = "origin" + start_muxed_ssh(remote) + info("Fetching from {}", remote) + run(["git", "fetch", remote]) + + # TODO(tudor): We should rebase instead of silently dropping + # everything you have on local master. Oh well. + global CURRENT_BRANCH + for b in stack.bottoms: + run( + [ + "git", + "update-ref", + "refs/heads/{}".format(b.name), + "refs/remotes/{}/{}".format(remote, b.remote_branch), + ] + ) + if b.name == CURRENT_BRANCH: + run(["git", "reset", "--hard", "HEAD"]) + + # We treat origin as the source of truth for bottom branches (master), and + # the local repo as the source of truth for everything else. So we can only + # track PR closure for branches that are direct descendants of master. + + info("Checking if any PRs have been merged and can be deleted") + forest = get_bottom_level_branches_as_forest(stack) + load_pr_info_for_forest(forest) + + deletes = get_branches_to_delete(forest) + if deletes and not args.force: + confirm() + + delete_branches(stack, deletes) + stop_muxed_ssh(remote) + + def cmd_import(stack, args): # Importing has to happen based on PR info, rather than local branch # relationships, as that's the only place Graphite populates.