Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 129 additions & 29 deletions src/stacky/stacky.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import shlex
import subprocess
import sys
import time
from argparse import ArgumentParser
from typing import List, Optional

Expand All @@ -38,6 +39,8 @@

_LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s"

# 2 minutes ought to be enough for anybody ;-)
MAX_SSH_MUX_LIFETIME = 120
COLOR_STDOUT = os.isatty(1)
COLOR_STDERR = os.isatty(2)
IS_TERMINAL = os.isatty(1) and os.isatty(2)
Expand All @@ -61,6 +64,7 @@ class StackyConfig:
skip_confirm: bool = False
change_to_main: bool = False
change_to_adopted: bool = False
share_ssh_session: bool = False

def read_one_config(self, config_path: str):
rawconfig = configparser.ConfigParser()
Expand All @@ -75,6 +79,9 @@ def read_one_config(self, config_path: str):
self.change_to_adopted = rawconfig.get(
"UI", "change_to_adopted", fallback=self.change_to_adopted
)
self.share_ssh_session = rawconfig.get(
"UI", "share_ssh_session", fallback=self.share_ssh_session
)


def read_config() -> StackyConfig:
Expand Down Expand Up @@ -123,7 +130,21 @@ def __init__(self, fmt, *args, **kwargs):
super().__init__(fmt.format(*args, **kwargs))


def stop_muxed_ssh(remote: str = "origin"):
if CONFIG.share_ssh_session:
hostish = get_remote_type(remote)
if hostish is not None:
cmd = gen_ssh_mux_cmd()
cmd.append("-O")
cmd.append("exit")
cmd.append(hostish)
subprocess.Popen(cmd, stderr=subprocess.DEVNULL)


def die(*args, **kwargs):
# We are taking a wild guess at what is the remote ...
# TODO (mpatou) fix the assumption about the remote
stop_muxed_ssh()
raise ExitException(*args, **kwargs)


Expand Down Expand Up @@ -726,6 +747,7 @@ def create_gh_pr(b, prefix):


def do_push(forest, *, force=False, pr=False, remote_name="origin"):
start_muxed_ssh(remote_name)
if pr:
load_pr_info_for_forest(forest)
print_forest(forest)
Expand Down Expand Up @@ -833,9 +855,16 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"):
elif pr_action == PR_CREATE:
create_gh_pr(b, prefix)

stop_muxed_ssh(remote_name)


def cmd_stack_push(stack, args):
do_push(get_current_stack_as_forest(stack), force=args.force, pr=args.pr)
do_push(
get_current_stack_as_forest(stack),
force=args.force,
pr=args.pr,
remote_name=args.remote_name,
)


def do_sync(forest):
Expand Down Expand Up @@ -980,7 +1009,12 @@ def cmd_upstack_info(stack, args):


def cmd_upstack_push(stack, args):
do_push(get_current_upstack_as_forest(stack), force=args.force, pr=args.pr)
do_push(
get_current_upstack_as_forest(stack),
force=args.force,
pr=args.pr,
remote_name=args.remote_name,
)


def cmd_upstack_sync(stack, args):
Expand Down Expand Up @@ -1024,7 +1058,12 @@ def cmd_downstack_info(stack, args):


def cmd_downstack_push(stack, args):
do_push(get_current_downstack_as_forest(stack), force=args.force, pr=args.pr)
do_push(
get_current_downstack_as_forest(stack),
force=args.force,
pr=args.pr,
remote_name=args.remote_name,
)


def cmd_downstack_sync(stack, args):
Expand All @@ -1038,34 +1077,57 @@ def get_bottom_level_branches_as_forest(stack):
]


def cmd_update(stack, args):
remote = "origin"
info("Fetching from {}", remote)
run(["git", "fetch", remote])

# TODO(tudor): We should rebase instead of silently dropping
# everything you have on local master. Oh well.
global CURRENT_BRANCH
for b in stack.bottoms:
run(
[
"git",
"update-ref",
"refs/heads/{}".format(b.name),
"refs/remotes/{}/{}".format(remote, b.remote_branch),
]
def get_remote_type(remote: str = "origin") -> Optional[str]:
out = run(["git", "remote", "-v"])
for l in out.split("\n"):
match = re.match(
r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l
)
if b.name == CURRENT_BRANCH:
run(["git", "reset", "--hard", "HEAD"])
if match:
sshish_host = match.group(1)
return sshish_host


def gen_ssh_mux_cmd() -> List[str]:
args = [
"ssh",
"-o",
"ControlMaster=auto",
"-o",
f"ControlPersist={MAX_SSH_MUX_LIFETIME}",
"-o",
"ControlPath=~/.ssh/stacky-%C",
]

# We treat origin as the source of truth for bottom branches (master), and
# the local repo as the source of truth for everything else. So we can only
# track PR closure for branches that are direct descendants of master.
return args

info("Checking if any PRs have been merged and can be deleted")
forest = get_bottom_level_branches_as_forest(stack)
load_pr_info_for_forest(forest)

def start_muxed_ssh(remote: str = "origin"):
if not CONFIG.share_ssh_session:
return
hostish = get_remote_type(remote)
if hostish is not None:
info("Creating a muxed ssh connection")
cmd = gen_ssh_mux_cmd()
os.environ["GIT_SSH_COMMAND"] = " ".join(cmd)
cmd.append("-MNf")
cmd.append(hostish)
# We don't want to use the run() wrapper because
# we don't want to wait for the process to finish

p = subprocess.Popen(cmd, stderr=subprocess.PIPE)
# Wait a little bit for the connection to establish
# before carrying on
while p.poll() is None:
time.sleep(1)
if p.returncode != 0:
error = p.stderr.read()
die(
f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}"
)


def get_branches_to_delete(forest):
deletes = []
for b in depth_first(forest):
if not b.parent or b.open_pr_info:
Expand All @@ -1087,10 +1149,11 @@ def cmd_update(stack, args):
b.parent.name,
)
break
return deletes

if deletes and not args.force:
confirm()

def delete_branches(stack, deletes):
global CURRENT_BRANCH
# Make sure we're not trying to delete the current branch
for b in deletes:
for c in b.children:
Expand All @@ -1106,6 +1169,43 @@ def cmd_update(stack, args):
run(["git", "branch", "-D", b.name])


def cmd_update(stack, args):
remote = "origin"
start_muxed_ssh(remote)
info("Fetching from {}", remote)
run(["git", "fetch", remote])

# TODO(tudor): We should rebase instead of silently dropping
# everything you have on local master. Oh well.
global CURRENT_BRANCH
for b in stack.bottoms:
run(
[
"git",
"update-ref",
"refs/heads/{}".format(b.name),
"refs/remotes/{}/{}".format(remote, b.remote_branch),
]
)
if b.name == CURRENT_BRANCH:
run(["git", "reset", "--hard", "HEAD"])

# We treat origin as the source of truth for bottom branches (master), and
# the local repo as the source of truth for everything else. So we can only
# track PR closure for branches that are direct descendants of master.

info("Checking if any PRs have been merged and can be deleted")
forest = get_bottom_level_branches_as_forest(stack)
load_pr_info_for_forest(forest)

deletes = get_branches_to_delete(forest)
if deletes and not args.force:
confirm()

delete_branches(stack, deletes)
stop_muxed_ssh(remote)


def cmd_import(stack, args):
# Importing has to happen based on PR info, rather than local branch
# relationships, as that's the only place Graphite populates.
Expand Down