diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c8b8d08 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.black] +line-length = 120 +target-version = ['py310'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +line_length = 120 diff --git a/setup.py b/setup.py index 535bf12..62de83d 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,29 @@ -from setuptools import setup, find_packages import pathlib +from setuptools import find_packages, setup + here = pathlib.Path(__file__).parent.resolve() # Get the long description from the README file long_description = (here / "README.md").read_text(encoding="utf-8") setup( - name="rockset-stacky", - version="1.0.10", + name="rockset-stacky", + version="1.0.11", description=""" stacky is a tool to manage stacks of PRs. This allows developers to easily manage many smaller, more targeted PRs that depend on each other. - """, - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/rockset/stacky", - author="Rockset", - author_email="tudor@rockset.com", + """, + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/rockset/stacky", + author="Rockset", + author_email="tudor@rockset.com", keywords="github, stack, pr, pull request", - package_dir={"": "src"}, + package_dir={"": "src"}, packages=find_packages(where="src"), python_requires=">=3.8, <4", - install_requires=["asciitree", "ansicolors", "simple-term-menu"], + install_requires=["asciitree", "ansicolors", "simple-term-menu"], entry_points={ "console_scripts": [ "stacky=stacky:main", @@ -32,4 +33,4 @@ "Bug Reports": "https://github.com/rockset/stacky/issues", "Source": "https://github.com/rockset/stacky", }, -) \ No newline at end of file +) diff --git a/src/stacky/stacky.py b/src/stacky/stacky.py index 33704a0..4502e64 100755 --- a/src/stacky/stacky.py +++ b/src/stacky/stacky.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + # GitHub helper for stacked diffs. # # Git maintains all metadata locally. Does everything by forking "git" and "gh" @@ -31,21 +32,57 @@ import sys import time from argparse import ArgumentParser -from typing import List, Optional +from typing import Dict, FrozenSet, Generator, List, NewType, Optional, Tuple, TypedDict, Union + +import asciitree # type: ignore +import colors # type: ignore +from simple_term_menu import TerminalMenu # type: ignore + +BranchName = NewType("BranchName", str) +PathName = NewType("PathName", str) +Commit = NewType("Commit", str) +CmdArgs = NewType("CmdArgs", List[str]) +StackSubTree = Tuple["StackBranch", "BranchesTree"] +TreeNode = Tuple[BranchName, StackSubTree] +BranchesTree = NewType("BranchesTree", Dict[BranchName, StackSubTree]) +BranchesTreeForest = NewType("BranchesTreeForest", List[BranchesTree]) + +JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] + + +class PRInfo(TypedDict): + id: str + number: int + state: str + mergeable: str + url: str + title: str + baseRefName: str + headRefName: str + commits: List[Dict[str, str]] + + +@dataclasses.dataclass +class PRInfos: + all: Dict[str, PRInfo] + open: Optional[PRInfo] + + +@dataclasses.dataclass +class BranchNCommit: + branch: BranchName + parent_commit: Optional[str] -import asciitree -import colors -from simple_term_menu import TerminalMenu _LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s" # 2 minutes ought to be enough for anybody ;-) MAX_SSH_MUX_LIFETIME = 120 -COLOR_STDOUT = os.isatty(1) -COLOR_STDERR = os.isatty(2) -IS_TERMINAL = os.isatty(1) and os.isatty(2) -CURRENT_BRANCH = None -STACK_BOTTOMS = frozenset(["master", "main"]) +COLOR_STDOUT: bool = os.isatty(1) +COLOR_STDERR: bool = os.isatty(2) +IS_TERMINAL: bool = os.isatty(1) and os.isatty(2) +CURRENT_BRANCH: BranchName +STACK_BOTTOMS: FrozenSet[BranchName] = frozenset([BranchName("master"), BranchName("main")]) STATE_FILE = os.path.expanduser("~/.stacky.state") TMP_STATE_FILE = STATE_FILE + ".tmp" @@ -70,18 +107,13 @@ def read_one_config(self, config_path: str): rawconfig = configparser.ConfigParser() rawconfig.read(config_path) if rawconfig.has_section("UI"): - self.skip_confirm = rawconfig.get( - "UI", "skip_confirm", fallback=self.skip_confirm - ) - self.change_to_main = rawconfig.get( - "UI", "change_to_main", fallback=self.change_to_main - ) - self.change_to_adopted = rawconfig.get( - "UI", "change_to_adopted", fallback=self.change_to_adopted - ) - self.share_ssh_session = rawconfig.get( - "UI", "share_ssh_session", fallback=self.share_ssh_session - ) + self.skip_confirm = bool(rawconfig.get("UI", "skip_confirm", fallback=self.skip_confirm)) + self.change_to_main = bool(rawconfig.get("UI", "change_to_main", fallback=self.change_to_main)) + self.change_to_adopted = bool(rawconfig.get("UI", "change_to_adopted", fallback=self.change_to_adopted)) + self.share_ssh_session = bool(rawconfig.get("UI", "share_ssh_session", fallback=self.share_ssh_session)) + + +CONFIG: StackyConfig def read_config() -> StackyConfig: @@ -96,7 +128,7 @@ def read_config() -> StackyConfig: return config -def fmt(s, *args, color=False, fg=None, bg=None, style=None, **kwargs): +def fmt(s: str, *args, color: bool = False, fg=None, bg=None, style=None, **kwargs) -> str: s = colors.color(s, fg=fg, bg=bg, style=style) if color else s return s.format(*args, **kwargs) @@ -148,7 +180,7 @@ def die(*args, **kwargs): raise ExitException(*args, **kwargs) -def _check_returncode(sp, cmd): +def _check_returncode(sp: subprocess.CompletedProcess, cmd: CmdArgs): rc = sp.returncode if rc == 0: return @@ -159,7 +191,7 @@ def _check_returncode(sp, cmd): die("Exited with status {}: {}. Stderr was:\n{}", rc, shlex.join(cmd), stderr) -def run_multiline(cmd, *, check=True, null=True, out=False): +def run_multiline(cmd: CmdArgs, *, check: bool = True, null: bool = True, out: bool = False) -> Optional[str]: debug("Running: {}", shlex.join(cmd)) sys.stdout.flush() sys.stderr.flush() @@ -178,29 +210,37 @@ def run_multiline(cmd, *, check=True, null=True, out=False): return sp.stdout.decode("UTF-8") -def run(cmd, **kwargs): +def run_always_return(cmd: CmdArgs, **kwargs) -> str: + out = run(cmd, **kwargs) + assert out is not None + return out + + +def run(cmd: CmdArgs, **kwargs) -> Optional[str]: out = run_multiline(cmd, **kwargs) return None if out is None else out.strip() -def remove_prefix(s, prefix): +def remove_prefix(s: str, prefix: str) -> str: if not s.startswith(prefix): die('Invalid string "{}": expected prefix "{}"', s, prefix) return s[len(prefix) :] # noqa: E203 -def get_current_branch(): - return remove_prefix(run(["git", "symbolic-ref", "-q", "HEAD"]), "refs/heads/") +def get_current_branch() -> Optional[BranchName]: + s = run(CmdArgs(["git", "symbolic-ref", "-q", "HEAD"])) + if s is not None: + return BranchName(remove_prefix(s, "refs/heads/")) + return None -def get_all_branches(): - branches = run_multiline( - ["git", "for-each-ref", "--format", "%(refname:short)", "refs/heads"] - ) - return [b for b in branches.split("\n") if b] +def get_all_branches() -> List[BranchName]: + branches = run_multiline(CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/heads"])) + assert branches is not None + return [BranchName(b) for b in branches.split("\n") if b] -def get_real_stack_bottom() -> Optional[str]: +def get_real_stack_bottom() -> Optional[BranchName]: # type: ignore [return] """ return the actual stack bottom for this current repo """ @@ -214,28 +254,36 @@ def get_real_stack_bottom() -> Optional[str]: return candiates.pop() -def get_stack_parent_branch(branch): +def get_stack_parent_branch(branch: BranchName) -> Optional[BranchName]: # type: ignore [return] if branch in STACK_BOTTOMS: return None - p = run(["git", "config", "branch.{}.merge".format(branch)], check=False) + p = run(CmdArgs(["git", "config", "branch.{}.merge".format(branch)]), check=False) if p is not None: p = remove_prefix(p, "refs/heads/") - return p + return BranchName(p) + +def get_top_level_dir() -> PathName: + p = run_always_return(CmdArgs(["git", "rev-parse", "--show-toplevel"])) + return PathName(p) -def get_top_level_dir() -> str: - return run(["git", "rev-parse", "--show-toplevel"]) +def get_stack_parent_commit(branch: BranchName) -> Optional[Commit]: # type: ignore [return] + c = run( + CmdArgs(["git", "rev-parse", "refs/stack-parent/{}".format(branch)]), + check=False, + ) -def get_stack_parent_commit(branch): - return run(["git", "rev-parse", "refs/stack-parent/{}".format(branch)], check=False) + if c is not None: + return Commit(c) -def get_commit(branch): - return run(["git", "rev-parse", "refs/heads/{}".format(branch)], check=False) +def get_commit(branch: BranchName) -> Commit: # type: ignore [return] + c = run_always_return(CmdArgs(["git", "rev-parse", "refs/heads/{}".format(branch)]), check=False) + return Commit(c) -def get_pr_info(branch, *, full=False): +def get_pr_info(branch: BranchName, *, full: bool = False) -> PRInfos: fields = [ "id", "number", @@ -248,37 +296,40 @@ def get_pr_info(branch, *, full=False): ] if full: fields += ["commits"] - fields = ",".join(fields) - infos = json.loads( - run( - [ - "gh", - "pr", - "list", - "--json", - fields, - "--state", - "all", - "--head", - branch, - ] + data = json.loads( + run_always_return( + CmdArgs( + [ + "gh", + "pr", + "list", + "--json", + ",".join(fields), + "--state", + "all", + "--head", + branch, + ] + ) ) ) - infos = {info["id"]: info for info in infos} - open_prs = [info for info in infos.values() if info["state"] == "OPEN"] + raw_infos: List[PRInfo] = data + + infos: Dict[str, PRInfo] = {info["id"]: info for info in raw_infos} + open_prs: List[PRInfo] = [info for info in infos.values() if info["state"] == "OPEN"] if len(open_prs) > 1: die( "Branch {} has more than one open PR: {}", branch, ", ".join([str(pr) for pr in open_prs]), - ) - return infos, open_prs[0] if open_prs else None + ) # type: ignore[arg-type] + return PRInfos(infos, open_prs[0] if open_prs else None) # (remote, remote_branch, remote_branch_commit) -def get_remote_info(branch): +def get_remote_info(branch: BranchName) -> Tuple[str, BranchName, Optional[Commit]]: if branch not in STACK_BOTTOMS: - remote = run(["git", "config", "branch.{}.remote".format(branch)], check=False) + remote = run(CmdArgs(["git", "config", "branch.{}.remote".format(branch)]), check=False) if remote != ".": die("Misconfigured branch {}: remote {}", branch, remote) @@ -287,28 +338,33 @@ def get_remote_info(branch): remote_branch = branch remote_commit = run( - ["git", "rev-parse", "refs/remotes/{}/{}".format(remote, remote_branch)], + CmdArgs(["git", "rev-parse", "refs/remotes/{}/{}".format(remote, remote_branch)]), check=False, ) - return (remote, remote_branch, remote_commit) + # TODO(mpatou): do something when remote_commit is none + commit = None + if remote_commit is not None: + commit = Commit(remote_commit) + + return (remote, BranchName(remote_branch), commit) class StackBranch: def __init__( self, - name, - parent, - parent_commit, + name: BranchName, + parent: "StackBranch", + parent_commit: Commit, ): self.name = name self.parent = parent self.parent_commit = parent_commit - self.children = set() + self.children: set["StackBranch"] = set() self.commit = get_commit(name) self.remote, self.remote_branch, self.remote_commit = get_remote_info(name) - self.pr_info = [] - self.open_pr_info = None + self.pr_info: Dict[str, PRInfo] = {} + self.open_pr_info: Optional[PRInfo] = None self._pr_info_loaded = False def is_synced_with_parent(self): @@ -318,21 +374,26 @@ def is_synced_with_remote(self): return self.commit == self.remote_commit def __repr__(self): - return f"StackBranch: {self.name} {len(self.children)}" + return f"StackBranch: {self.name} {len(self.children)} {self.commit}" def load_pr_info(self): if not self._pr_info_loaded: self._pr_info_loaded = True - self.pr_info, self.open_pr_info = get_pr_info(self.name) + pr_infos = get_pr_info(self.name) + # FIXME maybe store the whole object and use it elsewhere + self.pr_info, self.open_pr_info = ( + pr_infos.all, + pr_infos.open, + ) class StackBranchSet: - def __init__(self): - self.stack = {} - self.tops = set() - self.bottoms = set() + def __init__(self: "StackBranchSet"): + self.stack: Dict[BranchName, StackBranch] = {} + self.tops: set[StackBranch] = set() + self.bottoms: set[StackBranch] = set() - def add(self, name, **kwargs) -> StackBranch: + def add(self, name: BranchName, **kwargs) -> StackBranch: if name in self.stack: s = self.stack[name] assert s.name == name @@ -353,44 +414,56 @@ def add(self, name, **kwargs) -> StackBranch: self.tops.add(s) return s - def add_child(self, s, child): + def __repr__(self) -> str: + out = f"StackBranchSet: {self.stack}" + return out + + def add_child(self, s: StackBranch, child: StackBranch): s.children.add(child) self.tops.discard(s) -def load_current_stack(stack, branch, *, check=True): - branches = [] +def load_stack_for_given_branch( + stack: StackBranchSet, branch: BranchName, *, check: bool = True +) -> Tuple[Optional[StackBranch], List[BranchName]]: + """Given a stack of branch and a branch name, + update the stack with all the parents of the specified branch + if the branch is part of an existing stack. + Return also a list of BranchName of all the branch bellow the specified one + """ + branches: List[BranchNCommit] = [] while branch not in STACK_BOTTOMS: parent = get_stack_parent_branch(branch) parent_commit = get_stack_parent_commit(branch) - branches.append((branch, parent_commit)) + branches.append(BranchNCommit(branch, parent_commit)) if not parent or not parent_commit: if check: die("Branch is not in a stack: {}", branch) - return None, [b for b, _ in branches] + return None, [b.branch for b in branches] branch = parent - branches.append((branch, None)) + branches.append(BranchNCommit(branch, None)) top = None - for name, parent_commit in reversed(branches): + for b in reversed(branches): n = stack.add( - name, + b.branch, parent=top, - parent_commit=parent_commit, + parent_commit=b.parent_commit, ) if top: stack.add_child(top, n) top = n - return top, [b for b, _ in branches] + return top, [b.branch for b in branches] -def load_all_stacks(stack): +def load_all_stacks(stack: StackBranchSet) -> Optional[StackBranch]: + """Given a stack return the top of it, aka the bottom of the tree""" all_branches = set(get_all_branches()) current_branch_top = None while all_branches: b = all_branches.pop() - top, branches = load_current_stack(stack, b, check=False) + top, branches = load_stack_for_given_branch(stack, b, check=False) all_branches -= set(branches) if top is None: if len(branches) > 1: @@ -402,28 +475,29 @@ def load_all_stacks(stack): return current_branch_top -def make_tree_node(b): +def make_tree_node(b: StackBranch) -> TreeNode: return (b.name, (b, make_subtree(b))) -def make_subtree(b): - return dict(make_tree_node(c) for c in sorted(b.children, key=lambda x: x.name)) +def make_subtree(b) -> BranchesTree: + return BranchesTree(dict(make_tree_node(c) for c in sorted(b.children, key=lambda x: x.name))) -def make_tree(b): - return dict([make_tree_node(b)]) +def make_tree(b: StackBranch) -> BranchesTree: + return BranchesTree(dict([make_tree_node(b)])) -def format_name(b, *, color=None): +def format_name(b: StackBranch, *, colorize: bool) -> str: prefix = "" severity = 0 + # TODO: Align things so that we have the same prefix length ? if not b.is_synced_with_parent(): - prefix += fmt("!", color=color, fg="yellow") + prefix += fmt("!", color=colorize, fg="yellow") severity = max(severity, 2) if not b.is_synced_with_remote(): - prefix += fmt("~", color=color, fg="yellow") + prefix += fmt("~", color=colorize, fg="yellow") if b.name == CURRENT_BRANCH: - prefix += fmt("*", color=color, fg="cyan") + prefix += fmt("*", color=colorize, fg="cyan") else: severity = max(severity, 1) if prefix: @@ -432,15 +506,15 @@ def format_name(b, *, color=None): suffix = "" if b.open_pr_info: suffix += " " - suffix += fmt("(#{})", b.open_pr_info["number"], color=color, fg="blue") + suffix += fmt("(#{})", b.open_pr_info["number"], color=colorize, fg="blue") suffix += " " - suffix += fmt("{}", b.open_pr_info["title"], color=color, fg="blue") - return prefix + fmt("{}", b.name, color=color, fg=fg) + suffix + suffix += fmt("{}", b.open_pr_info["title"], color=colorize, fg="blue") + return prefix + fmt("{}", b.name, color=colorize, fg=fg) + suffix -def format_tree(tree, *, color=None): +def format_tree(tree: BranchesTree, *, colorize: bool = False): return { - format_name(branch, color=color): format_tree(children, color=color) + format_name(branch, colorize=colorize): format_tree(children, colorize=colorize) for branch, children in tree.values() } @@ -456,46 +530,46 @@ def format_tree(tree, *, color=None): ASCII_TREE = asciitree.LeftAligned(draw=_ASCII_TREE_STYLE) -def print_tree(tree): +def print_tree(tree: BranchesTree): global ASCII_TREE - s = ASCII_TREE(format_tree(tree, color=COLOR_STDOUT)) + s = ASCII_TREE(format_tree(tree, colorize=COLOR_STDOUT)) lines = s.split("\n") print("\n".join(reversed(lines))) -def print_forest(trees): +def print_forest(trees: List[BranchesTree]): for i, t in enumerate(trees): if i != 0: print() print_tree(t) -def get_all_stacks_as_forest(stack): - return [make_tree(b) for b in stack.bottoms] +def get_all_stacks_as_forest(stack: StackBranchSet) -> BranchesTreeForest: + return BranchesTreeForest([make_tree(b) for b in stack.bottoms]) -def get_current_stack_as_forest(stack): +def get_current_stack_as_forest(stack: StackBranchSet): b = stack.stack[CURRENT_BRANCH] - d = make_tree(b) + d: BranchesTree = make_tree(b) b = b.parent while b: - d = {b.name: (b, d)} + d = BranchesTree({b.name: (b, d)}) b = b.parent return [d] -def get_current_upstack_as_forest(stack): +def get_current_upstack_as_forest(stack: StackBranchSet) -> BranchesTreeForest: b = stack.stack[CURRENT_BRANCH] - return [make_tree(b)] + return BranchesTreeForest([make_tree(b)]) -def get_current_downstack_as_forest(stack): +def get_current_downstack_as_forest(stack: StackBranchSet) -> BranchesTreeForest: b = stack.stack[CURRENT_BRANCH] - d = {} + d: BranchesTree = BranchesTree({}) while b: - d = {b.name: (b, d)} + d = BranchesTree({b.name: (b, d)}) b = b.parent - return [d] + return BranchesTreeForest([d]) def init_git(): @@ -509,19 +583,23 @@ def init_git(): CURRENT_BRANCH = get_current_branch() -def depth_first(forest): - if type(forest) == list: - for tree in forest: - for b in depth_first(tree): - yield b - else: - for _, (branch, children) in forest.items(): - yield branch - for b in depth_first(children): - yield b +def forest_depth_first( + forest: BranchesTreeForest, +) -> Generator[StackBranch, None, None]: + for tree in forest: + for b in depth_first(tree): + yield b -def menu_choose_branch(forest): +def depth_first(tree: BranchesTree) -> Generator[StackBranch, None, None]: + # This is for the regular forest + for _, (branch, children) in tree.items(): + yield branch + for b in depth_first(children): + yield b + + +def menu_choose_branch(forest: BranchesTreeForest): if not IS_TERMINAL: die("May only choose from menu when using a terminal") @@ -544,17 +622,17 @@ def menu_choose_branch(forest): if idx is None: die("Aborted") - branches = list(depth_first(forest)) + branches = list(forest_depth_first(forest)) branches.reverse() return branches[idx] -def load_pr_info_for_forest(forest): - for b in depth_first(forest): +def load_pr_info_for_forest(forest: BranchesTreeForest): + for b in forest_depth_first(forest): b.load_pr_info() -def cmd_info(stack, args): +def cmd_info(stack: StackBranchSet, args): forest = get_all_stacks_as_forest(stack) if args.pr: load_pr_info_for_forest(forest) @@ -566,7 +644,7 @@ def checkout(branch): run(["git", "checkout", branch], out=True) -def cmd_branch_up(stack, args): +def cmd_branch_up(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] if not b.children: info("Branch {} is already at the top of the stack", CURRENT_BRANCH) @@ -584,14 +662,14 @@ def cmd_branch_up(stack, args): len(b.children), fg="green", ) - forest = [{c.name: (c, {})} for c in b.children] + forest = BranchesTreeForest([BranchesTree({BranchName(c.name): (c, BranchesTree({}))}) for c in b.children]) child = menu_choose_branch(forest).name else: child = next(iter(b.children)).name checkout(child) -def cmd_branch_down(stack, args): +def cmd_branch_down(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] if not b.parent: info("Branch {} is already at the bottom of the stack", CURRENT_BRANCH) @@ -603,15 +681,15 @@ def create_branch(branch): run(["git", "checkout", "-b", branch, "--track"], out=True) -def cmd_branch_new(stack, args): +def cmd_branch_new(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] assert b.commit name = args.name create_branch(name) - run(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""]) + run(CmdArgs(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""])) -def cmd_branch_checkout(stack, args): +def cmd_branch_checkout(stack: StackBranchSet, args): branch_name = args.name if branch_name is None: forest = get_all_stacks_as_forest(stack) @@ -619,14 +697,14 @@ def cmd_branch_checkout(stack, args): checkout(branch_name) -def cmd_stack_info(stack, args): +def cmd_stack_info(stack: StackBranchSet, args): forest = get_current_stack_as_forest(stack) if args.pr: load_pr_info_for_forest(forest) print_forest(forest) -def cmd_stack_checkout(stack, args): +def cmd_stack_checkout(stack: StackBranchSet, args): forest = get_current_stack_as_forest(stack) branch_name = menu_choose_branch(forest).name checkout(branch_name) @@ -647,7 +725,7 @@ def prompt(message: str, default_value: Optional[str]) -> str: return default_value -def confirm(msg="Proceed?"): +def confirm(msg: str = "Proceed?"): if CONFIG.skip_confirm: return if not os.isatty(0): @@ -664,26 +742,29 @@ def confirm(msg="Proceed?"): cout("Please answer yes or no\n", fg="red") -def find_reviewers(b) -> Optional[List[str]]: +def find_reviewers(b: StackBranch) -> Optional[List[str]]: out = run_multiline( - [ - "git", - "log", - "--pretty=format:%b", - "-1", - f"{b.name}", - ], + CmdArgs( + [ + "git", + "log", + "--pretty=format:%b", + "-1", + f"{b.name}", + ] + ), ) + assert out is not None for l in out.split("\n"): reviewer_match = re.match(r"^reviewers?\s*:\s*(.*)", l, re.I) if reviewer_match: reviewers = reviewer_match.group(1).split(",") logging.debug(f"Found the following reviewers: {', '.join(reviewers)}") return reviewers - return + return None -def create_gh_pr(b, prefix): +def create_gh_pr(b: StackBranch, prefix: str): cout("Creating PR for {}\n", b.name, fg="green") parent_prefix = "" if b.parent.name not in STACK_BOTTOMS: @@ -701,34 +782,34 @@ def create_gh_pr(b, prefix): reviewers = find_reviewers(b) if match: out = run_multiline( - ["git", "log", "--pretty=oneline", f"{b.parent.name}..{b.name}"], + CmdArgs(["git", "log", "--pretty=oneline", f"{b.parent.name}..{b.name}"]), ) title = f"[{match.group(1)}] " # Just one line (hence 2 elements with the last one being an empty string when we # split on "\"n ? # Then use the title of the commit as the title of the PR - - if len(out.split("\n")) == 2: + if out is not None and len(out.split("\n")) == 2: out = run( - [ - "git", - "log", - "--pretty=format:%s", - "-1", - f"{b.name}", - ], + CmdArgs( + [ + "git", + "log", + "--pretty=format:%s", + "-1", + f"{b.name}", + ] + ), out=False, ) + if out is None: + out = "" if b.name not in out: title += out else: title = out title = prompt( - ( - fmt("? ", color=COLOR_STDOUT, fg="green") - + fmt("Title ", color=COLOR_STDOUT, style="bold", fg="white") - ), + (fmt("? ", color=COLOR_STDOUT, fg="green") + fmt("Title ", color=COLOR_STDOUT, style="bold", fg="white")), title, ) cmd.extend(["--title", title.strip()]) @@ -741,17 +822,23 @@ def create_gh_pr(b, prefix): cmd.extend(["--reviewer", r]) run( - cmd, + CmdArgs(cmd), out=True, ) -def do_push(forest, *, force=False, pr=False, remote_name="origin"): +def do_push( + forest: BranchesTreeForest, + *, + force: bool = False, + pr: bool = False, + remote_name: str = "origin", +): start_muxed_ssh(remote_name) if pr: load_pr_info_for_forest(forest) print_forest(forest) - for b in depth_first(forest): + for b in forest_depth_first(forest): if not b.is_synced_with_parent(): die( "Branch {} is not synced with parent {}, sync first", @@ -764,7 +851,7 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): PR_FIX_BASE = 1 PR_CREATE = 2 actions = [] - for b in depth_first(forest): + for b in forest_depth_first(forest): if not b.parent: cout("✓ Not pushing base branch {}\n", b.name, fg="green") continue @@ -817,12 +904,12 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): # Figure out if we need to add a prefix to the branch # ie. user:foo # We should call gh repo set-default before doing that - val = run(["git", "config", f"remote.{remote_name}.gh-resolved"], check=False) + val = run(CmdArgs(["git", "config", f"remote.{remote_name}.gh-resolved"]), check=False) if val is not None and "/" in val: # If there is a "/" in the gh-resolved it means that the repo where # the should be created is not the same as the one where the push will # be made, we need to add a prefix to the branch in the gh pr command - val = run(["git", "config", f"remote.{remote_name}.url"]) + val = run_always_return(CmdArgs(["git", "config", f"remote.{remote_name}.url"])) prefix = f'{val.split(":")[1].split("/")[0]}:' else: prefix = "" @@ -830,26 +917,31 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): if push: cout("Pushing {}\n", b.name, fg="green") run( - [ - "git", - "push", - "-f", - b.remote, - "{}:{}".format(b.name, b.remote_branch), - ], + CmdArgs( + [ + "git", + "push", + "-f", + b.remote, + "{}:{}".format(b.name, b.remote_branch), + ] + ), out=True, ) if pr_action == PR_FIX_BASE: cout("Fixing PR base for {}\n", b.name, fg="green") + assert b.open_pr_info is not None run( - [ - "gh", - "pr", - "edit", - str(b.open_pr_info["number"]), - "--base", - b.parent.name, - ], + CmdArgs( + [ + "gh", + "pr", + "edit", + str(b.open_pr_info["number"]), + "--base", + b.parent.name, + ] + ), out=True, ) elif pr_action == PR_CREATE: @@ -858,7 +950,7 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): stop_muxed_ssh(remote_name) -def cmd_stack_push(stack, args): +def cmd_stack_push(stack: StackBranchSet, args): do_push( get_current_stack_as_forest(stack), force=args.force, @@ -867,13 +959,13 @@ def cmd_stack_push(stack, args): ) -def do_sync(forest): +def do_sync(forest: BranchesTreeForest): print_forest(forest) - syncs = [] - sync_names = [] - syncs_set = set() - for b in depth_first(forest): + syncs: List[StackBranch] = [] + sync_names: List[BranchName] = [] + syncs_set: set[StackBranch] = set() + for b in forest_depth_first(forest): if not b.parent: cout("✓ Not syncing base branch {}\n", b.name, fg="green") continue @@ -895,10 +987,11 @@ def do_sync(forest): syncs.reverse() sync_names.reverse() + # TODO: use list(syncs_set).reverse() ? inner_do_sync(syncs, sync_names) -def set_parent_commit(branch, new_commit, prev_commit=None): +def set_parent_commit(branch: BranchName, new_commit: Commit, prev_commit: Optional[str] = None): cmd = [ "git", "update-ref", @@ -907,15 +1000,16 @@ def set_parent_commit(branch, new_commit, prev_commit=None): ] if prev_commit is not None: cmd.append(prev_commit) - run(cmd) + run(CmdArgs(cmd)) -def get_commits_between(a, b): - lines = run_multiline(["git", "rev-list", "{}..{}".format(a, b)]) +def get_commits_between(a: Commit, b: Commit): + lines = run_multiline(CmdArgs(["git", "rev-list", "{}..{}".format(a, b)])) + assert lines is not None return [x.strip() for x in lines.split("\n")] -def inner_do_sync(syncs, sync_names): +def inner_do_sync(syncs: List[StackBranch], sync_names: List[BranchName]): print() while syncs: with open(TMP_STATE_FILE, "w") as f: @@ -937,7 +1031,7 @@ def inner_do_sync(syncs, sync_names): else: cout("Rebasing {} on top of {}\n", b.name, b.parent.name, fg="green") r = run( - ["git", "rebase", "--onto", b.parent.name, b.parent_commit, b.name], + CmdArgs(["git", "rebase", "--onto", b.parent.name, b.parent_commit, b.name]), out=True, check=False, ) @@ -949,14 +1043,14 @@ def inner_do_sync(syncs, sync_names): b.commit = get_commit(b.name) set_parent_commit(b.name, b.parent.commit, b.parent_commit) b.parent_commit = b.parent.commit - run(["git", "checkout", CURRENT_BRANCH]) + run(CmdArgs(["git", "checkout", str(CURRENT_BRANCH)])) -def cmd_stack_sync(stack, args): +def cmd_stack_sync(stack: StackBranchSet, args): do_sync(get_current_stack_as_forest(stack)) -def do_commit(stack, *, message=None, amend=False, allow_empty=False, edit=True): +def do_commit(stack: StackBranchSet, *, message=None, amend=False, allow_empty=False, edit=True): b = stack.stack[CURRENT_BRANCH] if not b.parent: die("Do not commit directly on {}", b.name) @@ -980,14 +1074,14 @@ def do_commit(stack, *, message=None, amend=False, allow_empty=False, edit=True) die("--no-edit is only supported with --amend") if message: cmd += ["-m", message] - run(cmd, out=True) + run(CmdArgs(cmd), out=True) # Sync everything upstack b.commit = get_commit(b.name) do_sync(get_current_upstack_as_forest(stack)) -def cmd_commit(stack, args): +def cmd_commit(stack: StackBranchSet, args): do_commit( stack, message=args.message, @@ -997,18 +1091,18 @@ def cmd_commit(stack, args): ) -def cmd_amend(stack, args): +def cmd_amend(stack: StackBranchSet, args): do_commit(stack, amend=True, edit=False) -def cmd_upstack_info(stack, args): +def cmd_upstack_info(stack: StackBranchSet, args): forest = get_current_upstack_as_forest(stack) if args.pr: load_pr_info_for_forest(forest) print_forest(forest) -def cmd_upstack_push(stack, args): +def cmd_upstack_push(stack: StackBranchSet, args): do_push( get_current_upstack_as_forest(stack), force=args.force, @@ -1017,31 +1111,33 @@ def cmd_upstack_push(stack, args): ) -def cmd_upstack_sync(stack, args): +def cmd_upstack_sync(stack: StackBranchSet, args): do_sync(get_current_upstack_as_forest(stack)) -def set_parent(branch, target, *, set_origin=False): +def set_parent(branch: BranchName, target: BranchName, *, set_origin: bool = False): if set_origin: - run(["git", "config", "branch.{}.remote".format(branch), "."]) + run(CmdArgs(["git", "config", "branch.{}.remote".format(branch), "."])) run( - [ - "git", - "config", - "branch.{}.merge".format(branch), - "refs/heads/{}".format(target), - ] + CmdArgs( + [ + "git", + "config", + "branch.{}.merge".format(branch), + "refs/heads/{}".format(target), + ] + ) ) -def cmd_upstack_onto(stack, args): +def cmd_upstack_onto(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] if not b.parent: die("May not restack {}", b.name) target = stack.stack[args.target] upstack = get_current_upstack_as_forest(stack) - for ub in depth_first(upstack): + for ub in forest_depth_first(upstack): if ub == target: die("Target branch {} is upstack of {}", target.name, b.name) b.parent = target @@ -1057,7 +1153,7 @@ def cmd_downstack_info(stack, args): print_forest(forest) -def cmd_downstack_push(stack, args): +def cmd_downstack_push(stack: StackBranchSet, args): do_push( get_current_downstack_as_forest(stack), force=args.force, @@ -1066,27 +1162,36 @@ def cmd_downstack_push(stack, args): ) -def cmd_downstack_sync(stack, args): +def cmd_downstack_sync(stack: StackBranchSet, args): do_sync(get_current_downstack_as_forest(stack)) -def get_bottom_level_branches_as_forest(stack): - return [ - {bottom.name: (bottom, {b.name: (b, {}) for b in bottom.children})} - for bottom in stack.bottoms - ] +def get_bottom_level_branches_as_forest(stack: StackBranchSet) -> BranchesTreeForest: + return BranchesTreeForest( + [ + BranchesTree( + { + bottom.name: ( + bottom, + BranchesTree({b.name: (b, BranchesTree({})) for b in bottom.children}), + ) + } + ) + for bottom in stack.bottoms + ] + ) def get_remote_type(remote: str = "origin") -> Optional[str]: - out = run(["git", "remote", "-v"]) + out = run_always_return(CmdArgs(["git", "remote", "-v"])) for l in out.split("\n"): - match = re.match( - r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l - ) + match = re.match(r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l) if match: sshish_host = match.group(1) return sshish_host + return None + def gen_ssh_mux_cmd() -> List[str]: args = [ @@ -1121,15 +1226,16 @@ def start_muxed_ssh(remote: str = "origin"): while p.poll() is None: time.sleep(1) if p.returncode != 0: - error = p.stderr.read() - die( - f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}" - ) + if p.stderr is not None: + error = p.stderr.read() + else: + error = b"unknown" + die(f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}") -def get_branches_to_delete(forest): +def get_branches_to_delete(forest: BranchesTreeForest) -> List[StackBranch]: deletes = [] - for b in depth_first(forest): + for b in forest_depth_first(forest): if not b.parent or b.open_pr_info: continue for pr_info in b.pr_info.values(): @@ -1152,7 +1258,7 @@ def get_branches_to_delete(forest): return deletes -def delete_branches(stack, deletes): +def delete_branches(stack: StackBranchSet, deletes: List[StackBranch]): global CURRENT_BRANCH # Make sure we're not trying to delete the current branch for b in deletes: @@ -1164,31 +1270,33 @@ def delete_branches(stack, deletes): if b.name == CURRENT_BRANCH: new_branch = next(iter(stack.bottoms)) info("About to delete current branch, switching to {}", new_branch.name) - run(["git", "checkout", new_branch.name]) - CURRENT_BRANCH = new_branch - run(["git", "branch", "-D", b.name]) + run(CmdArgs(["git", "checkout", new_branch.name])) + CURRENT_BRANCH = new_branch.name + run(CmdArgs(["git", "branch", "-D", b.name])) -def cmd_update(stack, args): +def cmd_update(stack: StackBranchSet, args): remote = "origin" start_muxed_ssh(remote) info("Fetching from {}", remote) - run(["git", "fetch", remote]) + run(CmdArgs(["git", "fetch", remote])) # TODO(tudor): We should rebase instead of silently dropping # everything you have on local master. Oh well. global CURRENT_BRANCH for b in stack.bottoms: run( - [ - "git", - "update-ref", - "refs/heads/{}".format(b.name), - "refs/remotes/{}/{}".format(remote, b.remote_branch), - ] + CmdArgs( + [ + "git", + "update-ref", + "refs/heads/{}".format(b.name), + "refs/remotes/{}/{}".format(remote, b.remote_branch), + ] + ) ) if b.name == CURRENT_BRANCH: - run(["git", "reset", "--hard", "HEAD"]) + run(CmdArgs(["git", "reset", "--hard", "HEAD"])) # We treat origin as the source of truth for bottom branches (master), and # the local repo as the source of truth for everything else. So we can only @@ -1206,17 +1314,20 @@ def cmd_update(stack, args): stop_muxed_ssh(remote) -def cmd_import(stack, args): +def cmd_import(stack: StackBranchSet, args): # Importing has to happen based on PR info, rather than local branch # relationships, as that's the only place Graphite populates. branch = args.name branches = [] bottoms = set(b.name for b in stack.bottoms) while branch not in bottoms: - _, open_pr = get_pr_info(branch, full=True) + pr_info = get_pr_info(branch, full=True) + open_pr = pr_info.open info("Getting PR information for {}", branch) - if not open_pr: + if open_pr is None: die("Branch {} has no open PR", branch) + # Never reached because the die but makes mypy happy + assert open_pr is not None if open_pr["headRefName"] != branch: die( "Branch {} is misconfigured: PR #{} head is {}", @@ -1227,7 +1338,7 @@ def cmd_import(stack, args): if not open_pr["commits"]: die("PR #{} has no commits", open_pr["number"]) first_commit = open_pr["commits"][0]["oid"] - parent_commit = run(["git", "rev-parse", "{}^".format(first_commit)]) + parent_commit = Commit(run_always_return(CmdArgs(["git", "rev-parse", "{}^".format(first_commit)]))) next_branch = open_pr["baseRefName"] info( "Branch {}: PR #{}, parent is {} at commit {}", @@ -1264,11 +1375,11 @@ def cmd_import(stack, args): branch = b -def get_merge_base(b1, b2): - return run(["git", "merge-base", b1, b2]) +def get_merge_base(b1: BranchName, b2: BranchName): + return run(CmdArgs(["git", "merge-base", str(b1), str(b2)])) -def cmd_adopt(stack, args): +def cmd_adopt(stack: StackBranch, args): """ Adopt a branch that is based on the current branch (which must be a valid stack bottom or the stack bottom (master or main) will be used @@ -1277,10 +1388,11 @@ def cmd_adopt(stack, args): branch = args.name global CURRENT_BRANCH if CURRENT_BRANCH not in STACK_BOTTOMS: + # TODO remove that, the initialisation code is already dealing with that in fact main_branch = get_real_stack_bottom() if CONFIG.change_to_main and main_branch is not None: - run(["git", "checkout", main_branch]) + run(CmdArgs(["git", "checkout", main_branch])) CURRENT_BRANCH = main_branch else: die( @@ -1292,10 +1404,10 @@ def cmd_adopt(stack, args): set_parent(branch, CURRENT_BRANCH, set_origin=True) set_parent_commit(branch, parent_commit) if CONFIG.change_to_adopted: - run(["git", "checkout", branch]) + run(CmdArgs(["git", "checkout", branch])) -def cmd_land(stack, args): +def cmd_land(stack: StackBranchSet, args): forest = get_current_downstack_as_forest(stack) assert len(forest) == 1 branches = [] @@ -1326,6 +1438,7 @@ def cmd_land(stack, args): pr = b.open_pr_info if not pr: die("Branch {} does not have an open PR", b.name) + assert pr is not None if pr["mergeable"] != "MERGEABLE": die( @@ -1353,8 +1466,10 @@ def cmd_land(stack, args): if not args.force: confirm() - head_commit = run(["git", "rev-parse", b.name]) - cmd = ["gh", "pr", "merge", b.name, "--squash", "--match-head-commit", head_commit] + v = run(CmdArgs(["git", "rev-parse", b.name])) + assert v is not None + head_commit = Commit(v) + cmd = CmdArgs(["gh", "pr", "merge", b.name, "--squash", "--match-head-commit", head_commit]) if args.auto: cmd.append("--auto") run(cmd, out=True) @@ -1387,20 +1502,14 @@ def main(): subparsers = parser.add_subparsers(required=True, dest="command") # continue - continue_parser = subparsers.add_parser( - "continue", help="Continue previously interrupted command" - ) + continue_parser = subparsers.add_parser("continue", help="Continue previously interrupted command") continue_parser.set_defaults(func=None) # down - down_parser = subparsers.add_parser( - "down", help="Go down in the current stack (towards master/main)" - ) + down_parser = subparsers.add_parser("down", help="Go down in the current stack (towards master/main)") down_parser.set_defaults(func=cmd_branch_down) # up - up_parser = subparsers.add_parser( - "up", help="Go up in the current stack (away master/main)" - ) + up_parser = subparsers.add_parser("up", help="Go up in the current stack (away master/main)") up_parser.set_defaults(func=cmd_branch_up) # info info_parser = subparsers.add_parser("info", help="Stack info") @@ -1410,73 +1519,43 @@ def main(): # commit commit_parser = subparsers.add_parser("commit", help="Commit") commit_parser.add_argument("-m", help="Commit message", dest="message") - commit_parser.add_argument( - "--amend", action="store_true", help="Amend last commit" - ) - commit_parser.add_argument( - "--allow-empty", action="store_true", help="Allow empty commit" - ) + commit_parser.add_argument("--amend", action="store_true", help="Amend last commit") + commit_parser.add_argument("--allow-empty", action="store_true", help="Allow empty commit") commit_parser.add_argument("--no-edit", action="store_true", help="Skip editor") commit_parser.set_defaults(func=cmd_commit) # amend - amend_parser = subparsers.add_parser( - "amend", help="Shortcut for amending last commit" - ) + amend_parser = subparsers.add_parser("amend", help="Shortcut for amending last commit") amend_parser.set_defaults(func=cmd_amend) # branch - branch_parser = subparsers.add_parser( - "branch", aliases=["b"], help="Operations on branches" - ) - branch_subparsers = branch_parser.add_subparsers( - required=True, dest="branch_command" - ) - branch_up_parser = branch_subparsers.add_parser( - "up", aliases=["u"], help="Move upstack" - ) + branch_parser = subparsers.add_parser("branch", aliases=["b"], help="Operations on branches") + branch_subparsers = branch_parser.add_subparsers(required=True, dest="branch_command") + branch_up_parser = branch_subparsers.add_parser("up", aliases=["u"], help="Move upstack") branch_up_parser.set_defaults(func=cmd_branch_up) - branch_down_parser = branch_subparsers.add_parser( - "down", aliases=["d"], help="Move downstack" - ) + branch_down_parser = branch_subparsers.add_parser("down", aliases=["d"], help="Move downstack") branch_down_parser.set_defaults(func=cmd_branch_down) - branch_new_parser = branch_subparsers.add_parser( - "new", aliases=["create"], help="Create a new branch" - ) + branch_new_parser = branch_subparsers.add_parser("new", aliases=["create"], help="Create a new branch") branch_new_parser.add_argument("name", help="Branch name") branch_new_parser.set_defaults(func=cmd_branch_new) - branch_checkout_parser = branch_subparsers.add_parser( - "checkout", aliases=["co"], help="Checkout a branch" - ) + branch_checkout_parser = branch_subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") branch_checkout_parser.add_argument("name", help="Branch name", nargs="?") branch_checkout_parser.set_defaults(func=cmd_branch_checkout) # stack - stack_parser = subparsers.add_parser( - "stack", aliases=["s"], help="Operations on the full current stack" - ) - stack_subparsers = stack_parser.add_subparsers( - required=True, dest="stack_command" - ) + stack_parser = subparsers.add_parser("stack", aliases=["s"], help="Operations on the full current stack") + stack_subparsers = stack_parser.add_subparsers(required=True, dest="stack_command") - stack_info_parser = stack_subparsers.add_parser( - "info", aliases=["i"], help="Info for current stack" - ) - stack_info_parser.add_argument( - "--pr", action="store_true", help="Get PR info (slow)" - ) + stack_info_parser = stack_subparsers.add_parser("info", aliases=["i"], help="Info for current stack") + stack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") stack_info_parser.set_defaults(func=cmd_stack_info) stack_push_parser = stack_subparsers.add_parser("push", help="Push") - stack_push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - stack_push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + stack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + stack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") stack_push_parser.set_defaults(func=cmd_stack_push) stack_sync_parser = stack_subparsers.add_parser("sync", help="Sync") @@ -1488,36 +1567,22 @@ def main(): stack_checkout_parser.set_defaults(func=cmd_stack_checkout) # upstack - upstack_parser = subparsers.add_parser( - "upstack", aliases=["us"], help="Operations on the current upstack" - ) - upstack_subparsers = upstack_parser.add_subparsers( - required=True, dest="upstack_command" - ) + upstack_parser = subparsers.add_parser("upstack", aliases=["us"], help="Operations on the current upstack") + upstack_subparsers = upstack_parser.add_subparsers(required=True, dest="upstack_command") - upstack_info_parser = upstack_subparsers.add_parser( - "info", aliases=["i"], help="Info for current upstack" - ) - upstack_info_parser.add_argument( - "--pr", action="store_true", help="Get PR info (slow)" - ) + upstack_info_parser = upstack_subparsers.add_parser("info", aliases=["i"], help="Info for current upstack") + upstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") upstack_info_parser.set_defaults(func=cmd_upstack_info) upstack_push_parser = upstack_subparsers.add_parser("push", help="Push") - upstack_push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - upstack_push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + upstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + upstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") upstack_push_parser.set_defaults(func=cmd_upstack_push) upstack_sync_parser = upstack_subparsers.add_parser("sync", help="Sync") upstack_sync_parser.set_defaults(func=cmd_upstack_sync) - upstack_onto_parser = upstack_subparsers.add_parser( - "onto", aliases=["restack"], help="Restack" - ) + upstack_onto_parser = upstack_subparsers.add_parser("onto", aliases=["restack"], help="Restack") upstack_onto_parser.add_argument("target", help="New parent") upstack_onto_parser.set_defaults(func=cmd_upstack_onto) @@ -1525,25 +1590,17 @@ def main(): downstack_parser = subparsers.add_parser( "downstack", aliases=["ds"], help="Operations on the current downstack" ) - downstack_subparsers = downstack_parser.add_subparsers( - required=True, dest="downstack_command" - ) + downstack_subparsers = downstack_parser.add_subparsers(required=True, dest="downstack_command") downstack_info_parser = downstack_subparsers.add_parser( "info", aliases=["i"], help="Info for current downstack" ) - downstack_info_parser.add_argument( - "--pr", action="store_true", help="Get PR info (slow)" - ) + downstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") downstack_info_parser.set_defaults(func=cmd_downstack_info) downstack_push_parser = downstack_subparsers.add_parser("push", help="Push") - downstack_push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - downstack_push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + downstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + downstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") downstack_push_parser.set_defaults(func=cmd_downstack_push) downstack_sync_parser = downstack_subparsers.add_parser("sync", help="Sync") @@ -1551,16 +1608,12 @@ def main(): # update update_parser = subparsers.add_parser("update", help="Update repo") - update_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) + update_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") update_parser.set_defaults(func=cmd_update) # import import_parser = subparsers.add_parser("import", help="Import Graphite stack") - import_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) + import_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") import_parser.add_argument("name", help="Foreign stack top") import_parser.set_defaults(func=cmd_import) @@ -1570,12 +1623,8 @@ def main(): adopt_parser.set_defaults(func=cmd_adopt) # land - land_parser = subparsers.add_parser( - "land", help="Land bottom-most PR on current stack" - ) - land_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) + land_parser = subparsers.add_parser("land", help="Land bottom-most PR on current stack") + land_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") land_parser.add_argument( "--auto", "-a", @@ -1586,35 +1635,25 @@ def main(): # shortcuts push_parser = subparsers.add_parser("push", help="Alias for downstack push") - push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") push_parser.set_defaults(func=cmd_downstack_push) sync_parser = subparsers.add_parser("sync", help="Alias for stack sync") sync_parser.set_defaults(func=cmd_stack_sync) - checkout_parser = subparsers.add_parser( - "checkout", aliases=["co"], help="Checkout a branch" - ) + checkout_parser = subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") checkout_parser.add_argument("name", help="Branch name", nargs="?") checkout_parser.set_defaults(func=cmd_branch_checkout) - checkout_parser = subparsers.add_parser( - "sco", help="Checkout a branch in this stack" - ) + checkout_parser = subparsers.add_parser("sco", help="Checkout a branch in this stack") checkout_parser.set_defaults(func=cmd_stack_checkout) global CONFIG CONFIG = read_config() args = parser.parse_args() - logging.basicConfig( - format=_LOGGING_FORMAT, level=LOGLEVELS[args.log_level], force=True - ) + logging.basicConfig(format=_LOGGING_FORMAT, level=LOGLEVELS[args.log_level], force=True) global COLOR_STDERR global COLOR_STDOUT @@ -1648,6 +1687,8 @@ def main(): inner_do_sync(syncs, sync_names) else: + # TODO restore the current branch after changing the branch on some commands for + # instance `info` if CURRENT_BRANCH not in stack.stack: main_branch = get_real_stack_bottom() @@ -1657,6 +1698,7 @@ def main(): else: die("Current branch {} is not in a stack", CURRENT_BRANCH) + get_current_stack_as_forest(stack) args.func(stack, args) # Success, delete the state file diff --git a/src/stacky/stacky_test.py b/src/stacky/stacky_test.py new file mode 100755 index 0000000..0471373 --- /dev/null +++ b/src/stacky/stacky_test.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +import unittest +from unittest import mock +from unittest.mock import MagicMock +from stacky import PRInfos, read_config, get_top_level_dir + + +class TestStringMethods(unittest.TestCase): + def test_upper(self): + self.assertEqual("foo".upper(), "FOO") + + def test_isupper(self): + self.assertTrue("FOO".isupper()) + self.assertFalse("Foo".isupper()) + + def test_split(self): + s = "hello world" + self.assertEqual(s.split(), ["hello", "world"]) + # check that s.split fails when the separator is not a string + with self.assertRaises(TypeError): + s.split(2) + + @mock.patch("get_top_level_dir") + def test_read_config(self, mock_get_tld): + patcher = mock.patch("os.path.exists") + mock_thing = patcher.start() + mock_thing.return_value = False + read_config() + + +if __name__ == "__main__": + unittest.main()