From b68a60a41332f8a22840f8768c3c4eb2c27cb10e Mon Sep 17 00:00:00 2001 From: Matthieu Patou Date: Mon, 15 Jan 2024 09:59:11 -0800 Subject: [PATCH] Add typing to Stacky Summary As indicated in the title, the idea is to have things more strongly typed so that we can run mypy on diff to ensure that we have less risk of regressing. Testing Ran: ``` mypy src ``` --- pyproject.toml | 8 + setup.py | 25 +- src/stacky/stacky.py | 792 ++++++++++++++++++++------------------ src/stacky/stacky_test.py | 32 ++ 4 files changed, 470 insertions(+), 387 deletions(-) create mode 100644 pyproject.toml create mode 100755 src/stacky/stacky_test.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c8b8d08 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.black] +line-length = 120 +target-version = ['py310'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +line_length = 120 diff --git a/setup.py b/setup.py index 535bf12..62de83d 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,29 @@ -from setuptools import setup, find_packages import pathlib +from setuptools import find_packages, setup + here = pathlib.Path(__file__).parent.resolve() # Get the long description from the README file long_description = (here / "README.md").read_text(encoding="utf-8") setup( - name="rockset-stacky", - version="1.0.10", + name="rockset-stacky", + version="1.0.11", description=""" stacky is a tool to manage stacks of PRs. This allows developers to easily manage many smaller, more targeted PRs that depend on each other. - """, - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/rockset/stacky", - author="Rockset", - author_email="tudor@rockset.com", + """, + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/rockset/stacky", + author="Rockset", + author_email="tudor@rockset.com", keywords="github, stack, pr, pull request", - package_dir={"": "src"}, + package_dir={"": "src"}, packages=find_packages(where="src"), python_requires=">=3.8, <4", - install_requires=["asciitree", "ansicolors", "simple-term-menu"], + install_requires=["asciitree", "ansicolors", "simple-term-menu"], entry_points={ "console_scripts": [ "stacky=stacky:main", @@ -32,4 +33,4 @@ "Bug Reports": "https://github.com/rockset/stacky/issues", "Source": "https://github.com/rockset/stacky", }, -) \ No newline at end of file +) diff --git a/src/stacky/stacky.py b/src/stacky/stacky.py index 33704a0..4502e64 100755 --- a/src/stacky/stacky.py +++ b/src/stacky/stacky.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + # GitHub helper for stacked diffs. # # Git maintains all metadata locally. Does everything by forking "git" and "gh" @@ -31,21 +32,57 @@ import sys import time from argparse import ArgumentParser -from typing import List, Optional +from typing import Dict, FrozenSet, Generator, List, NewType, Optional, Tuple, TypedDict, Union + +import asciitree # type: ignore +import colors # type: ignore +from simple_term_menu import TerminalMenu # type: ignore + +BranchName = NewType("BranchName", str) +PathName = NewType("PathName", str) +Commit = NewType("Commit", str) +CmdArgs = NewType("CmdArgs", List[str]) +StackSubTree = Tuple["StackBranch", "BranchesTree"] +TreeNode = Tuple[BranchName, StackSubTree] +BranchesTree = NewType("BranchesTree", Dict[BranchName, StackSubTree]) +BranchesTreeForest = NewType("BranchesTreeForest", List[BranchesTree]) + +JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] + + +class PRInfo(TypedDict): + id: str + number: int + state: str + mergeable: str + url: str + title: str + baseRefName: str + headRefName: str + commits: List[Dict[str, str]] + + +@dataclasses.dataclass +class PRInfos: + all: Dict[str, PRInfo] + open: Optional[PRInfo] + + +@dataclasses.dataclass +class BranchNCommit: + branch: BranchName + parent_commit: Optional[str] -import asciitree -import colors -from simple_term_menu import TerminalMenu _LOGGING_FORMAT = "%(asctime)s %(module)s %(levelname)s: %(message)s" # 2 minutes ought to be enough for anybody ;-) MAX_SSH_MUX_LIFETIME = 120 -COLOR_STDOUT = os.isatty(1) -COLOR_STDERR = os.isatty(2) -IS_TERMINAL = os.isatty(1) and os.isatty(2) -CURRENT_BRANCH = None -STACK_BOTTOMS = frozenset(["master", "main"]) +COLOR_STDOUT: bool = os.isatty(1) +COLOR_STDERR: bool = os.isatty(2) +IS_TERMINAL: bool = os.isatty(1) and os.isatty(2) +CURRENT_BRANCH: BranchName +STACK_BOTTOMS: FrozenSet[BranchName] = frozenset([BranchName("master"), BranchName("main")]) STATE_FILE = os.path.expanduser("~/.stacky.state") TMP_STATE_FILE = STATE_FILE + ".tmp" @@ -70,18 +107,13 @@ def read_one_config(self, config_path: str): rawconfig = configparser.ConfigParser() rawconfig.read(config_path) if rawconfig.has_section("UI"): - self.skip_confirm = rawconfig.get( - "UI", "skip_confirm", fallback=self.skip_confirm - ) - self.change_to_main = rawconfig.get( - "UI", "change_to_main", fallback=self.change_to_main - ) - self.change_to_adopted = rawconfig.get( - "UI", "change_to_adopted", fallback=self.change_to_adopted - ) - self.share_ssh_session = rawconfig.get( - "UI", "share_ssh_session", fallback=self.share_ssh_session - ) + self.skip_confirm = bool(rawconfig.get("UI", "skip_confirm", fallback=self.skip_confirm)) + self.change_to_main = bool(rawconfig.get("UI", "change_to_main", fallback=self.change_to_main)) + self.change_to_adopted = bool(rawconfig.get("UI", "change_to_adopted", fallback=self.change_to_adopted)) + self.share_ssh_session = bool(rawconfig.get("UI", "share_ssh_session", fallback=self.share_ssh_session)) + + +CONFIG: StackyConfig def read_config() -> StackyConfig: @@ -96,7 +128,7 @@ def read_config() -> StackyConfig: return config -def fmt(s, *args, color=False, fg=None, bg=None, style=None, **kwargs): +def fmt(s: str, *args, color: bool = False, fg=None, bg=None, style=None, **kwargs) -> str: s = colors.color(s, fg=fg, bg=bg, style=style) if color else s return s.format(*args, **kwargs) @@ -148,7 +180,7 @@ def die(*args, **kwargs): raise ExitException(*args, **kwargs) -def _check_returncode(sp, cmd): +def _check_returncode(sp: subprocess.CompletedProcess, cmd: CmdArgs): rc = sp.returncode if rc == 0: return @@ -159,7 +191,7 @@ def _check_returncode(sp, cmd): die("Exited with status {}: {}. Stderr was:\n{}", rc, shlex.join(cmd), stderr) -def run_multiline(cmd, *, check=True, null=True, out=False): +def run_multiline(cmd: CmdArgs, *, check: bool = True, null: bool = True, out: bool = False) -> Optional[str]: debug("Running: {}", shlex.join(cmd)) sys.stdout.flush() sys.stderr.flush() @@ -178,29 +210,37 @@ def run_multiline(cmd, *, check=True, null=True, out=False): return sp.stdout.decode("UTF-8") -def run(cmd, **kwargs): +def run_always_return(cmd: CmdArgs, **kwargs) -> str: + out = run(cmd, **kwargs) + assert out is not None + return out + + +def run(cmd: CmdArgs, **kwargs) -> Optional[str]: out = run_multiline(cmd, **kwargs) return None if out is None else out.strip() -def remove_prefix(s, prefix): +def remove_prefix(s: str, prefix: str) -> str: if not s.startswith(prefix): die('Invalid string "{}": expected prefix "{}"', s, prefix) return s[len(prefix) :] # noqa: E203 -def get_current_branch(): - return remove_prefix(run(["git", "symbolic-ref", "-q", "HEAD"]), "refs/heads/") +def get_current_branch() -> Optional[BranchName]: + s = run(CmdArgs(["git", "symbolic-ref", "-q", "HEAD"])) + if s is not None: + return BranchName(remove_prefix(s, "refs/heads/")) + return None -def get_all_branches(): - branches = run_multiline( - ["git", "for-each-ref", "--format", "%(refname:short)", "refs/heads"] - ) - return [b for b in branches.split("\n") if b] +def get_all_branches() -> List[BranchName]: + branches = run_multiline(CmdArgs(["git", "for-each-ref", "--format", "%(refname:short)", "refs/heads"])) + assert branches is not None + return [BranchName(b) for b in branches.split("\n") if b] -def get_real_stack_bottom() -> Optional[str]: +def get_real_stack_bottom() -> Optional[BranchName]: # type: ignore [return] """ return the actual stack bottom for this current repo """ @@ -214,28 +254,36 @@ def get_real_stack_bottom() -> Optional[str]: return candiates.pop() -def get_stack_parent_branch(branch): +def get_stack_parent_branch(branch: BranchName) -> Optional[BranchName]: # type: ignore [return] if branch in STACK_BOTTOMS: return None - p = run(["git", "config", "branch.{}.merge".format(branch)], check=False) + p = run(CmdArgs(["git", "config", "branch.{}.merge".format(branch)]), check=False) if p is not None: p = remove_prefix(p, "refs/heads/") - return p + return BranchName(p) + +def get_top_level_dir() -> PathName: + p = run_always_return(CmdArgs(["git", "rev-parse", "--show-toplevel"])) + return PathName(p) -def get_top_level_dir() -> str: - return run(["git", "rev-parse", "--show-toplevel"]) +def get_stack_parent_commit(branch: BranchName) -> Optional[Commit]: # type: ignore [return] + c = run( + CmdArgs(["git", "rev-parse", "refs/stack-parent/{}".format(branch)]), + check=False, + ) -def get_stack_parent_commit(branch): - return run(["git", "rev-parse", "refs/stack-parent/{}".format(branch)], check=False) + if c is not None: + return Commit(c) -def get_commit(branch): - return run(["git", "rev-parse", "refs/heads/{}".format(branch)], check=False) +def get_commit(branch: BranchName) -> Commit: # type: ignore [return] + c = run_always_return(CmdArgs(["git", "rev-parse", "refs/heads/{}".format(branch)]), check=False) + return Commit(c) -def get_pr_info(branch, *, full=False): +def get_pr_info(branch: BranchName, *, full: bool = False) -> PRInfos: fields = [ "id", "number", @@ -248,37 +296,40 @@ def get_pr_info(branch, *, full=False): ] if full: fields += ["commits"] - fields = ",".join(fields) - infos = json.loads( - run( - [ - "gh", - "pr", - "list", - "--json", - fields, - "--state", - "all", - "--head", - branch, - ] + data = json.loads( + run_always_return( + CmdArgs( + [ + "gh", + "pr", + "list", + "--json", + ",".join(fields), + "--state", + "all", + "--head", + branch, + ] + ) ) ) - infos = {info["id"]: info for info in infos} - open_prs = [info for info in infos.values() if info["state"] == "OPEN"] + raw_infos: List[PRInfo] = data + + infos: Dict[str, PRInfo] = {info["id"]: info for info in raw_infos} + open_prs: List[PRInfo] = [info for info in infos.values() if info["state"] == "OPEN"] if len(open_prs) > 1: die( "Branch {} has more than one open PR: {}", branch, ", ".join([str(pr) for pr in open_prs]), - ) - return infos, open_prs[0] if open_prs else None + ) # type: ignore[arg-type] + return PRInfos(infos, open_prs[0] if open_prs else None) # (remote, remote_branch, remote_branch_commit) -def get_remote_info(branch): +def get_remote_info(branch: BranchName) -> Tuple[str, BranchName, Optional[Commit]]: if branch not in STACK_BOTTOMS: - remote = run(["git", "config", "branch.{}.remote".format(branch)], check=False) + remote = run(CmdArgs(["git", "config", "branch.{}.remote".format(branch)]), check=False) if remote != ".": die("Misconfigured branch {}: remote {}", branch, remote) @@ -287,28 +338,33 @@ def get_remote_info(branch): remote_branch = branch remote_commit = run( - ["git", "rev-parse", "refs/remotes/{}/{}".format(remote, remote_branch)], + CmdArgs(["git", "rev-parse", "refs/remotes/{}/{}".format(remote, remote_branch)]), check=False, ) - return (remote, remote_branch, remote_commit) + # TODO(mpatou): do something when remote_commit is none + commit = None + if remote_commit is not None: + commit = Commit(remote_commit) + + return (remote, BranchName(remote_branch), commit) class StackBranch: def __init__( self, - name, - parent, - parent_commit, + name: BranchName, + parent: "StackBranch", + parent_commit: Commit, ): self.name = name self.parent = parent self.parent_commit = parent_commit - self.children = set() + self.children: set["StackBranch"] = set() self.commit = get_commit(name) self.remote, self.remote_branch, self.remote_commit = get_remote_info(name) - self.pr_info = [] - self.open_pr_info = None + self.pr_info: Dict[str, PRInfo] = {} + self.open_pr_info: Optional[PRInfo] = None self._pr_info_loaded = False def is_synced_with_parent(self): @@ -318,21 +374,26 @@ def is_synced_with_remote(self): return self.commit == self.remote_commit def __repr__(self): - return f"StackBranch: {self.name} {len(self.children)}" + return f"StackBranch: {self.name} {len(self.children)} {self.commit}" def load_pr_info(self): if not self._pr_info_loaded: self._pr_info_loaded = True - self.pr_info, self.open_pr_info = get_pr_info(self.name) + pr_infos = get_pr_info(self.name) + # FIXME maybe store the whole object and use it elsewhere + self.pr_info, self.open_pr_info = ( + pr_infos.all, + pr_infos.open, + ) class StackBranchSet: - def __init__(self): - self.stack = {} - self.tops = set() - self.bottoms = set() + def __init__(self: "StackBranchSet"): + self.stack: Dict[BranchName, StackBranch] = {} + self.tops: set[StackBranch] = set() + self.bottoms: set[StackBranch] = set() - def add(self, name, **kwargs) -> StackBranch: + def add(self, name: BranchName, **kwargs) -> StackBranch: if name in self.stack: s = self.stack[name] assert s.name == name @@ -353,44 +414,56 @@ def add(self, name, **kwargs) -> StackBranch: self.tops.add(s) return s - def add_child(self, s, child): + def __repr__(self) -> str: + out = f"StackBranchSet: {self.stack}" + return out + + def add_child(self, s: StackBranch, child: StackBranch): s.children.add(child) self.tops.discard(s) -def load_current_stack(stack, branch, *, check=True): - branches = [] +def load_stack_for_given_branch( + stack: StackBranchSet, branch: BranchName, *, check: bool = True +) -> Tuple[Optional[StackBranch], List[BranchName]]: + """Given a stack of branch and a branch name, + update the stack with all the parents of the specified branch + if the branch is part of an existing stack. + Return also a list of BranchName of all the branch bellow the specified one + """ + branches: List[BranchNCommit] = [] while branch not in STACK_BOTTOMS: parent = get_stack_parent_branch(branch) parent_commit = get_stack_parent_commit(branch) - branches.append((branch, parent_commit)) + branches.append(BranchNCommit(branch, parent_commit)) if not parent or not parent_commit: if check: die("Branch is not in a stack: {}", branch) - return None, [b for b, _ in branches] + return None, [b.branch for b in branches] branch = parent - branches.append((branch, None)) + branches.append(BranchNCommit(branch, None)) top = None - for name, parent_commit in reversed(branches): + for b in reversed(branches): n = stack.add( - name, + b.branch, parent=top, - parent_commit=parent_commit, + parent_commit=b.parent_commit, ) if top: stack.add_child(top, n) top = n - return top, [b for b, _ in branches] + return top, [b.branch for b in branches] -def load_all_stacks(stack): +def load_all_stacks(stack: StackBranchSet) -> Optional[StackBranch]: + """Given a stack return the top of it, aka the bottom of the tree""" all_branches = set(get_all_branches()) current_branch_top = None while all_branches: b = all_branches.pop() - top, branches = load_current_stack(stack, b, check=False) + top, branches = load_stack_for_given_branch(stack, b, check=False) all_branches -= set(branches) if top is None: if len(branches) > 1: @@ -402,28 +475,29 @@ def load_all_stacks(stack): return current_branch_top -def make_tree_node(b): +def make_tree_node(b: StackBranch) -> TreeNode: return (b.name, (b, make_subtree(b))) -def make_subtree(b): - return dict(make_tree_node(c) for c in sorted(b.children, key=lambda x: x.name)) +def make_subtree(b) -> BranchesTree: + return BranchesTree(dict(make_tree_node(c) for c in sorted(b.children, key=lambda x: x.name))) -def make_tree(b): - return dict([make_tree_node(b)]) +def make_tree(b: StackBranch) -> BranchesTree: + return BranchesTree(dict([make_tree_node(b)])) -def format_name(b, *, color=None): +def format_name(b: StackBranch, *, colorize: bool) -> str: prefix = "" severity = 0 + # TODO: Align things so that we have the same prefix length ? if not b.is_synced_with_parent(): - prefix += fmt("!", color=color, fg="yellow") + prefix += fmt("!", color=colorize, fg="yellow") severity = max(severity, 2) if not b.is_synced_with_remote(): - prefix += fmt("~", color=color, fg="yellow") + prefix += fmt("~", color=colorize, fg="yellow") if b.name == CURRENT_BRANCH: - prefix += fmt("*", color=color, fg="cyan") + prefix += fmt("*", color=colorize, fg="cyan") else: severity = max(severity, 1) if prefix: @@ -432,15 +506,15 @@ def format_name(b, *, color=None): suffix = "" if b.open_pr_info: suffix += " " - suffix += fmt("(#{})", b.open_pr_info["number"], color=color, fg="blue") + suffix += fmt("(#{})", b.open_pr_info["number"], color=colorize, fg="blue") suffix += " " - suffix += fmt("{}", b.open_pr_info["title"], color=color, fg="blue") - return prefix + fmt("{}", b.name, color=color, fg=fg) + suffix + suffix += fmt("{}", b.open_pr_info["title"], color=colorize, fg="blue") + return prefix + fmt("{}", b.name, color=colorize, fg=fg) + suffix -def format_tree(tree, *, color=None): +def format_tree(tree: BranchesTree, *, colorize: bool = False): return { - format_name(branch, color=color): format_tree(children, color=color) + format_name(branch, colorize=colorize): format_tree(children, colorize=colorize) for branch, children in tree.values() } @@ -456,46 +530,46 @@ def format_tree(tree, *, color=None): ASCII_TREE = asciitree.LeftAligned(draw=_ASCII_TREE_STYLE) -def print_tree(tree): +def print_tree(tree: BranchesTree): global ASCII_TREE - s = ASCII_TREE(format_tree(tree, color=COLOR_STDOUT)) + s = ASCII_TREE(format_tree(tree, colorize=COLOR_STDOUT)) lines = s.split("\n") print("\n".join(reversed(lines))) -def print_forest(trees): +def print_forest(trees: List[BranchesTree]): for i, t in enumerate(trees): if i != 0: print() print_tree(t) -def get_all_stacks_as_forest(stack): - return [make_tree(b) for b in stack.bottoms] +def get_all_stacks_as_forest(stack: StackBranchSet) -> BranchesTreeForest: + return BranchesTreeForest([make_tree(b) for b in stack.bottoms]) -def get_current_stack_as_forest(stack): +def get_current_stack_as_forest(stack: StackBranchSet): b = stack.stack[CURRENT_BRANCH] - d = make_tree(b) + d: BranchesTree = make_tree(b) b = b.parent while b: - d = {b.name: (b, d)} + d = BranchesTree({b.name: (b, d)}) b = b.parent return [d] -def get_current_upstack_as_forest(stack): +def get_current_upstack_as_forest(stack: StackBranchSet) -> BranchesTreeForest: b = stack.stack[CURRENT_BRANCH] - return [make_tree(b)] + return BranchesTreeForest([make_tree(b)]) -def get_current_downstack_as_forest(stack): +def get_current_downstack_as_forest(stack: StackBranchSet) -> BranchesTreeForest: b = stack.stack[CURRENT_BRANCH] - d = {} + d: BranchesTree = BranchesTree({}) while b: - d = {b.name: (b, d)} + d = BranchesTree({b.name: (b, d)}) b = b.parent - return [d] + return BranchesTreeForest([d]) def init_git(): @@ -509,19 +583,23 @@ def init_git(): CURRENT_BRANCH = get_current_branch() -def depth_first(forest): - if type(forest) == list: - for tree in forest: - for b in depth_first(tree): - yield b - else: - for _, (branch, children) in forest.items(): - yield branch - for b in depth_first(children): - yield b +def forest_depth_first( + forest: BranchesTreeForest, +) -> Generator[StackBranch, None, None]: + for tree in forest: + for b in depth_first(tree): + yield b -def menu_choose_branch(forest): +def depth_first(tree: BranchesTree) -> Generator[StackBranch, None, None]: + # This is for the regular forest + for _, (branch, children) in tree.items(): + yield branch + for b in depth_first(children): + yield b + + +def menu_choose_branch(forest: BranchesTreeForest): if not IS_TERMINAL: die("May only choose from menu when using a terminal") @@ -544,17 +622,17 @@ def menu_choose_branch(forest): if idx is None: die("Aborted") - branches = list(depth_first(forest)) + branches = list(forest_depth_first(forest)) branches.reverse() return branches[idx] -def load_pr_info_for_forest(forest): - for b in depth_first(forest): +def load_pr_info_for_forest(forest: BranchesTreeForest): + for b in forest_depth_first(forest): b.load_pr_info() -def cmd_info(stack, args): +def cmd_info(stack: StackBranchSet, args): forest = get_all_stacks_as_forest(stack) if args.pr: load_pr_info_for_forest(forest) @@ -566,7 +644,7 @@ def checkout(branch): run(["git", "checkout", branch], out=True) -def cmd_branch_up(stack, args): +def cmd_branch_up(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] if not b.children: info("Branch {} is already at the top of the stack", CURRENT_BRANCH) @@ -584,14 +662,14 @@ def cmd_branch_up(stack, args): len(b.children), fg="green", ) - forest = [{c.name: (c, {})} for c in b.children] + forest = BranchesTreeForest([BranchesTree({BranchName(c.name): (c, BranchesTree({}))}) for c in b.children]) child = menu_choose_branch(forest).name else: child = next(iter(b.children)).name checkout(child) -def cmd_branch_down(stack, args): +def cmd_branch_down(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] if not b.parent: info("Branch {} is already at the bottom of the stack", CURRENT_BRANCH) @@ -603,15 +681,15 @@ def create_branch(branch): run(["git", "checkout", "-b", branch, "--track"], out=True) -def cmd_branch_new(stack, args): +def cmd_branch_new(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] assert b.commit name = args.name create_branch(name) - run(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""]) + run(CmdArgs(["git", "update-ref", "refs/stack-parent/{}".format(name), b.commit, ""])) -def cmd_branch_checkout(stack, args): +def cmd_branch_checkout(stack: StackBranchSet, args): branch_name = args.name if branch_name is None: forest = get_all_stacks_as_forest(stack) @@ -619,14 +697,14 @@ def cmd_branch_checkout(stack, args): checkout(branch_name) -def cmd_stack_info(stack, args): +def cmd_stack_info(stack: StackBranchSet, args): forest = get_current_stack_as_forest(stack) if args.pr: load_pr_info_for_forest(forest) print_forest(forest) -def cmd_stack_checkout(stack, args): +def cmd_stack_checkout(stack: StackBranchSet, args): forest = get_current_stack_as_forest(stack) branch_name = menu_choose_branch(forest).name checkout(branch_name) @@ -647,7 +725,7 @@ def prompt(message: str, default_value: Optional[str]) -> str: return default_value -def confirm(msg="Proceed?"): +def confirm(msg: str = "Proceed?"): if CONFIG.skip_confirm: return if not os.isatty(0): @@ -664,26 +742,29 @@ def confirm(msg="Proceed?"): cout("Please answer yes or no\n", fg="red") -def find_reviewers(b) -> Optional[List[str]]: +def find_reviewers(b: StackBranch) -> Optional[List[str]]: out = run_multiline( - [ - "git", - "log", - "--pretty=format:%b", - "-1", - f"{b.name}", - ], + CmdArgs( + [ + "git", + "log", + "--pretty=format:%b", + "-1", + f"{b.name}", + ] + ), ) + assert out is not None for l in out.split("\n"): reviewer_match = re.match(r"^reviewers?\s*:\s*(.*)", l, re.I) if reviewer_match: reviewers = reviewer_match.group(1).split(",") logging.debug(f"Found the following reviewers: {', '.join(reviewers)}") return reviewers - return + return None -def create_gh_pr(b, prefix): +def create_gh_pr(b: StackBranch, prefix: str): cout("Creating PR for {}\n", b.name, fg="green") parent_prefix = "" if b.parent.name not in STACK_BOTTOMS: @@ -701,34 +782,34 @@ def create_gh_pr(b, prefix): reviewers = find_reviewers(b) if match: out = run_multiline( - ["git", "log", "--pretty=oneline", f"{b.parent.name}..{b.name}"], + CmdArgs(["git", "log", "--pretty=oneline", f"{b.parent.name}..{b.name}"]), ) title = f"[{match.group(1)}] " # Just one line (hence 2 elements with the last one being an empty string when we # split on "\"n ? # Then use the title of the commit as the title of the PR - - if len(out.split("\n")) == 2: + if out is not None and len(out.split("\n")) == 2: out = run( - [ - "git", - "log", - "--pretty=format:%s", - "-1", - f"{b.name}", - ], + CmdArgs( + [ + "git", + "log", + "--pretty=format:%s", + "-1", + f"{b.name}", + ] + ), out=False, ) + if out is None: + out = "" if b.name not in out: title += out else: title = out title = prompt( - ( - fmt("? ", color=COLOR_STDOUT, fg="green") - + fmt("Title ", color=COLOR_STDOUT, style="bold", fg="white") - ), + (fmt("? ", color=COLOR_STDOUT, fg="green") + fmt("Title ", color=COLOR_STDOUT, style="bold", fg="white")), title, ) cmd.extend(["--title", title.strip()]) @@ -741,17 +822,23 @@ def create_gh_pr(b, prefix): cmd.extend(["--reviewer", r]) run( - cmd, + CmdArgs(cmd), out=True, ) -def do_push(forest, *, force=False, pr=False, remote_name="origin"): +def do_push( + forest: BranchesTreeForest, + *, + force: bool = False, + pr: bool = False, + remote_name: str = "origin", +): start_muxed_ssh(remote_name) if pr: load_pr_info_for_forest(forest) print_forest(forest) - for b in depth_first(forest): + for b in forest_depth_first(forest): if not b.is_synced_with_parent(): die( "Branch {} is not synced with parent {}, sync first", @@ -764,7 +851,7 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): PR_FIX_BASE = 1 PR_CREATE = 2 actions = [] - for b in depth_first(forest): + for b in forest_depth_first(forest): if not b.parent: cout("✓ Not pushing base branch {}\n", b.name, fg="green") continue @@ -817,12 +904,12 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): # Figure out if we need to add a prefix to the branch # ie. user:foo # We should call gh repo set-default before doing that - val = run(["git", "config", f"remote.{remote_name}.gh-resolved"], check=False) + val = run(CmdArgs(["git", "config", f"remote.{remote_name}.gh-resolved"]), check=False) if val is not None and "/" in val: # If there is a "/" in the gh-resolved it means that the repo where # the should be created is not the same as the one where the push will # be made, we need to add a prefix to the branch in the gh pr command - val = run(["git", "config", f"remote.{remote_name}.url"]) + val = run_always_return(CmdArgs(["git", "config", f"remote.{remote_name}.url"])) prefix = f'{val.split(":")[1].split("/")[0]}:' else: prefix = "" @@ -830,26 +917,31 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): if push: cout("Pushing {}\n", b.name, fg="green") run( - [ - "git", - "push", - "-f", - b.remote, - "{}:{}".format(b.name, b.remote_branch), - ], + CmdArgs( + [ + "git", + "push", + "-f", + b.remote, + "{}:{}".format(b.name, b.remote_branch), + ] + ), out=True, ) if pr_action == PR_FIX_BASE: cout("Fixing PR base for {}\n", b.name, fg="green") + assert b.open_pr_info is not None run( - [ - "gh", - "pr", - "edit", - str(b.open_pr_info["number"]), - "--base", - b.parent.name, - ], + CmdArgs( + [ + "gh", + "pr", + "edit", + str(b.open_pr_info["number"]), + "--base", + b.parent.name, + ] + ), out=True, ) elif pr_action == PR_CREATE: @@ -858,7 +950,7 @@ def do_push(forest, *, force=False, pr=False, remote_name="origin"): stop_muxed_ssh(remote_name) -def cmd_stack_push(stack, args): +def cmd_stack_push(stack: StackBranchSet, args): do_push( get_current_stack_as_forest(stack), force=args.force, @@ -867,13 +959,13 @@ def cmd_stack_push(stack, args): ) -def do_sync(forest): +def do_sync(forest: BranchesTreeForest): print_forest(forest) - syncs = [] - sync_names = [] - syncs_set = set() - for b in depth_first(forest): + syncs: List[StackBranch] = [] + sync_names: List[BranchName] = [] + syncs_set: set[StackBranch] = set() + for b in forest_depth_first(forest): if not b.parent: cout("✓ Not syncing base branch {}\n", b.name, fg="green") continue @@ -895,10 +987,11 @@ def do_sync(forest): syncs.reverse() sync_names.reverse() + # TODO: use list(syncs_set).reverse() ? inner_do_sync(syncs, sync_names) -def set_parent_commit(branch, new_commit, prev_commit=None): +def set_parent_commit(branch: BranchName, new_commit: Commit, prev_commit: Optional[str] = None): cmd = [ "git", "update-ref", @@ -907,15 +1000,16 @@ def set_parent_commit(branch, new_commit, prev_commit=None): ] if prev_commit is not None: cmd.append(prev_commit) - run(cmd) + run(CmdArgs(cmd)) -def get_commits_between(a, b): - lines = run_multiline(["git", "rev-list", "{}..{}".format(a, b)]) +def get_commits_between(a: Commit, b: Commit): + lines = run_multiline(CmdArgs(["git", "rev-list", "{}..{}".format(a, b)])) + assert lines is not None return [x.strip() for x in lines.split("\n")] -def inner_do_sync(syncs, sync_names): +def inner_do_sync(syncs: List[StackBranch], sync_names: List[BranchName]): print() while syncs: with open(TMP_STATE_FILE, "w") as f: @@ -937,7 +1031,7 @@ def inner_do_sync(syncs, sync_names): else: cout("Rebasing {} on top of {}\n", b.name, b.parent.name, fg="green") r = run( - ["git", "rebase", "--onto", b.parent.name, b.parent_commit, b.name], + CmdArgs(["git", "rebase", "--onto", b.parent.name, b.parent_commit, b.name]), out=True, check=False, ) @@ -949,14 +1043,14 @@ def inner_do_sync(syncs, sync_names): b.commit = get_commit(b.name) set_parent_commit(b.name, b.parent.commit, b.parent_commit) b.parent_commit = b.parent.commit - run(["git", "checkout", CURRENT_BRANCH]) + run(CmdArgs(["git", "checkout", str(CURRENT_BRANCH)])) -def cmd_stack_sync(stack, args): +def cmd_stack_sync(stack: StackBranchSet, args): do_sync(get_current_stack_as_forest(stack)) -def do_commit(stack, *, message=None, amend=False, allow_empty=False, edit=True): +def do_commit(stack: StackBranchSet, *, message=None, amend=False, allow_empty=False, edit=True): b = stack.stack[CURRENT_BRANCH] if not b.parent: die("Do not commit directly on {}", b.name) @@ -980,14 +1074,14 @@ def do_commit(stack, *, message=None, amend=False, allow_empty=False, edit=True) die("--no-edit is only supported with --amend") if message: cmd += ["-m", message] - run(cmd, out=True) + run(CmdArgs(cmd), out=True) # Sync everything upstack b.commit = get_commit(b.name) do_sync(get_current_upstack_as_forest(stack)) -def cmd_commit(stack, args): +def cmd_commit(stack: StackBranchSet, args): do_commit( stack, message=args.message, @@ -997,18 +1091,18 @@ def cmd_commit(stack, args): ) -def cmd_amend(stack, args): +def cmd_amend(stack: StackBranchSet, args): do_commit(stack, amend=True, edit=False) -def cmd_upstack_info(stack, args): +def cmd_upstack_info(stack: StackBranchSet, args): forest = get_current_upstack_as_forest(stack) if args.pr: load_pr_info_for_forest(forest) print_forest(forest) -def cmd_upstack_push(stack, args): +def cmd_upstack_push(stack: StackBranchSet, args): do_push( get_current_upstack_as_forest(stack), force=args.force, @@ -1017,31 +1111,33 @@ def cmd_upstack_push(stack, args): ) -def cmd_upstack_sync(stack, args): +def cmd_upstack_sync(stack: StackBranchSet, args): do_sync(get_current_upstack_as_forest(stack)) -def set_parent(branch, target, *, set_origin=False): +def set_parent(branch: BranchName, target: BranchName, *, set_origin: bool = False): if set_origin: - run(["git", "config", "branch.{}.remote".format(branch), "."]) + run(CmdArgs(["git", "config", "branch.{}.remote".format(branch), "."])) run( - [ - "git", - "config", - "branch.{}.merge".format(branch), - "refs/heads/{}".format(target), - ] + CmdArgs( + [ + "git", + "config", + "branch.{}.merge".format(branch), + "refs/heads/{}".format(target), + ] + ) ) -def cmd_upstack_onto(stack, args): +def cmd_upstack_onto(stack: StackBranchSet, args): b = stack.stack[CURRENT_BRANCH] if not b.parent: die("May not restack {}", b.name) target = stack.stack[args.target] upstack = get_current_upstack_as_forest(stack) - for ub in depth_first(upstack): + for ub in forest_depth_first(upstack): if ub == target: die("Target branch {} is upstack of {}", target.name, b.name) b.parent = target @@ -1057,7 +1153,7 @@ def cmd_downstack_info(stack, args): print_forest(forest) -def cmd_downstack_push(stack, args): +def cmd_downstack_push(stack: StackBranchSet, args): do_push( get_current_downstack_as_forest(stack), force=args.force, @@ -1066,27 +1162,36 @@ def cmd_downstack_push(stack, args): ) -def cmd_downstack_sync(stack, args): +def cmd_downstack_sync(stack: StackBranchSet, args): do_sync(get_current_downstack_as_forest(stack)) -def get_bottom_level_branches_as_forest(stack): - return [ - {bottom.name: (bottom, {b.name: (b, {}) for b in bottom.children})} - for bottom in stack.bottoms - ] +def get_bottom_level_branches_as_forest(stack: StackBranchSet) -> BranchesTreeForest: + return BranchesTreeForest( + [ + BranchesTree( + { + bottom.name: ( + bottom, + BranchesTree({b.name: (b, BranchesTree({})) for b in bottom.children}), + ) + } + ) + for bottom in stack.bottoms + ] + ) def get_remote_type(remote: str = "origin") -> Optional[str]: - out = run(["git", "remote", "-v"]) + out = run_always_return(CmdArgs(["git", "remote", "-v"])) for l in out.split("\n"): - match = re.match( - r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l - ) + match = re.match(r"^{}\s+(?:ssh://)?([^/]*):(?!//).*\s+\(push\)$".format(remote), l) if match: sshish_host = match.group(1) return sshish_host + return None + def gen_ssh_mux_cmd() -> List[str]: args = [ @@ -1121,15 +1226,16 @@ def start_muxed_ssh(remote: str = "origin"): while p.poll() is None: time.sleep(1) if p.returncode != 0: - error = p.stderr.read() - die( - f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}" - ) + if p.stderr is not None: + error = p.stderr.read() + else: + error = b"unknown" + die(f"Failed to start ssh muxed connection, error was: {error.decode('utf-8').strip()}") -def get_branches_to_delete(forest): +def get_branches_to_delete(forest: BranchesTreeForest) -> List[StackBranch]: deletes = [] - for b in depth_first(forest): + for b in forest_depth_first(forest): if not b.parent or b.open_pr_info: continue for pr_info in b.pr_info.values(): @@ -1152,7 +1258,7 @@ def get_branches_to_delete(forest): return deletes -def delete_branches(stack, deletes): +def delete_branches(stack: StackBranchSet, deletes: List[StackBranch]): global CURRENT_BRANCH # Make sure we're not trying to delete the current branch for b in deletes: @@ -1164,31 +1270,33 @@ def delete_branches(stack, deletes): if b.name == CURRENT_BRANCH: new_branch = next(iter(stack.bottoms)) info("About to delete current branch, switching to {}", new_branch.name) - run(["git", "checkout", new_branch.name]) - CURRENT_BRANCH = new_branch - run(["git", "branch", "-D", b.name]) + run(CmdArgs(["git", "checkout", new_branch.name])) + CURRENT_BRANCH = new_branch.name + run(CmdArgs(["git", "branch", "-D", b.name])) -def cmd_update(stack, args): +def cmd_update(stack: StackBranchSet, args): remote = "origin" start_muxed_ssh(remote) info("Fetching from {}", remote) - run(["git", "fetch", remote]) + run(CmdArgs(["git", "fetch", remote])) # TODO(tudor): We should rebase instead of silently dropping # everything you have on local master. Oh well. global CURRENT_BRANCH for b in stack.bottoms: run( - [ - "git", - "update-ref", - "refs/heads/{}".format(b.name), - "refs/remotes/{}/{}".format(remote, b.remote_branch), - ] + CmdArgs( + [ + "git", + "update-ref", + "refs/heads/{}".format(b.name), + "refs/remotes/{}/{}".format(remote, b.remote_branch), + ] + ) ) if b.name == CURRENT_BRANCH: - run(["git", "reset", "--hard", "HEAD"]) + run(CmdArgs(["git", "reset", "--hard", "HEAD"])) # We treat origin as the source of truth for bottom branches (master), and # the local repo as the source of truth for everything else. So we can only @@ -1206,17 +1314,20 @@ def cmd_update(stack, args): stop_muxed_ssh(remote) -def cmd_import(stack, args): +def cmd_import(stack: StackBranchSet, args): # Importing has to happen based on PR info, rather than local branch # relationships, as that's the only place Graphite populates. branch = args.name branches = [] bottoms = set(b.name for b in stack.bottoms) while branch not in bottoms: - _, open_pr = get_pr_info(branch, full=True) + pr_info = get_pr_info(branch, full=True) + open_pr = pr_info.open info("Getting PR information for {}", branch) - if not open_pr: + if open_pr is None: die("Branch {} has no open PR", branch) + # Never reached because the die but makes mypy happy + assert open_pr is not None if open_pr["headRefName"] != branch: die( "Branch {} is misconfigured: PR #{} head is {}", @@ -1227,7 +1338,7 @@ def cmd_import(stack, args): if not open_pr["commits"]: die("PR #{} has no commits", open_pr["number"]) first_commit = open_pr["commits"][0]["oid"] - parent_commit = run(["git", "rev-parse", "{}^".format(first_commit)]) + parent_commit = Commit(run_always_return(CmdArgs(["git", "rev-parse", "{}^".format(first_commit)]))) next_branch = open_pr["baseRefName"] info( "Branch {}: PR #{}, parent is {} at commit {}", @@ -1264,11 +1375,11 @@ def cmd_import(stack, args): branch = b -def get_merge_base(b1, b2): - return run(["git", "merge-base", b1, b2]) +def get_merge_base(b1: BranchName, b2: BranchName): + return run(CmdArgs(["git", "merge-base", str(b1), str(b2)])) -def cmd_adopt(stack, args): +def cmd_adopt(stack: StackBranch, args): """ Adopt a branch that is based on the current branch (which must be a valid stack bottom or the stack bottom (master or main) will be used @@ -1277,10 +1388,11 @@ def cmd_adopt(stack, args): branch = args.name global CURRENT_BRANCH if CURRENT_BRANCH not in STACK_BOTTOMS: + # TODO remove that, the initialisation code is already dealing with that in fact main_branch = get_real_stack_bottom() if CONFIG.change_to_main and main_branch is not None: - run(["git", "checkout", main_branch]) + run(CmdArgs(["git", "checkout", main_branch])) CURRENT_BRANCH = main_branch else: die( @@ -1292,10 +1404,10 @@ def cmd_adopt(stack, args): set_parent(branch, CURRENT_BRANCH, set_origin=True) set_parent_commit(branch, parent_commit) if CONFIG.change_to_adopted: - run(["git", "checkout", branch]) + run(CmdArgs(["git", "checkout", branch])) -def cmd_land(stack, args): +def cmd_land(stack: StackBranchSet, args): forest = get_current_downstack_as_forest(stack) assert len(forest) == 1 branches = [] @@ -1326,6 +1438,7 @@ def cmd_land(stack, args): pr = b.open_pr_info if not pr: die("Branch {} does not have an open PR", b.name) + assert pr is not None if pr["mergeable"] != "MERGEABLE": die( @@ -1353,8 +1466,10 @@ def cmd_land(stack, args): if not args.force: confirm() - head_commit = run(["git", "rev-parse", b.name]) - cmd = ["gh", "pr", "merge", b.name, "--squash", "--match-head-commit", head_commit] + v = run(CmdArgs(["git", "rev-parse", b.name])) + assert v is not None + head_commit = Commit(v) + cmd = CmdArgs(["gh", "pr", "merge", b.name, "--squash", "--match-head-commit", head_commit]) if args.auto: cmd.append("--auto") run(cmd, out=True) @@ -1387,20 +1502,14 @@ def main(): subparsers = parser.add_subparsers(required=True, dest="command") # continue - continue_parser = subparsers.add_parser( - "continue", help="Continue previously interrupted command" - ) + continue_parser = subparsers.add_parser("continue", help="Continue previously interrupted command") continue_parser.set_defaults(func=None) # down - down_parser = subparsers.add_parser( - "down", help="Go down in the current stack (towards master/main)" - ) + down_parser = subparsers.add_parser("down", help="Go down in the current stack (towards master/main)") down_parser.set_defaults(func=cmd_branch_down) # up - up_parser = subparsers.add_parser( - "up", help="Go up in the current stack (away master/main)" - ) + up_parser = subparsers.add_parser("up", help="Go up in the current stack (away master/main)") up_parser.set_defaults(func=cmd_branch_up) # info info_parser = subparsers.add_parser("info", help="Stack info") @@ -1410,73 +1519,43 @@ def main(): # commit commit_parser = subparsers.add_parser("commit", help="Commit") commit_parser.add_argument("-m", help="Commit message", dest="message") - commit_parser.add_argument( - "--amend", action="store_true", help="Amend last commit" - ) - commit_parser.add_argument( - "--allow-empty", action="store_true", help="Allow empty commit" - ) + commit_parser.add_argument("--amend", action="store_true", help="Amend last commit") + commit_parser.add_argument("--allow-empty", action="store_true", help="Allow empty commit") commit_parser.add_argument("--no-edit", action="store_true", help="Skip editor") commit_parser.set_defaults(func=cmd_commit) # amend - amend_parser = subparsers.add_parser( - "amend", help="Shortcut for amending last commit" - ) + amend_parser = subparsers.add_parser("amend", help="Shortcut for amending last commit") amend_parser.set_defaults(func=cmd_amend) # branch - branch_parser = subparsers.add_parser( - "branch", aliases=["b"], help="Operations on branches" - ) - branch_subparsers = branch_parser.add_subparsers( - required=True, dest="branch_command" - ) - branch_up_parser = branch_subparsers.add_parser( - "up", aliases=["u"], help="Move upstack" - ) + branch_parser = subparsers.add_parser("branch", aliases=["b"], help="Operations on branches") + branch_subparsers = branch_parser.add_subparsers(required=True, dest="branch_command") + branch_up_parser = branch_subparsers.add_parser("up", aliases=["u"], help="Move upstack") branch_up_parser.set_defaults(func=cmd_branch_up) - branch_down_parser = branch_subparsers.add_parser( - "down", aliases=["d"], help="Move downstack" - ) + branch_down_parser = branch_subparsers.add_parser("down", aliases=["d"], help="Move downstack") branch_down_parser.set_defaults(func=cmd_branch_down) - branch_new_parser = branch_subparsers.add_parser( - "new", aliases=["create"], help="Create a new branch" - ) + branch_new_parser = branch_subparsers.add_parser("new", aliases=["create"], help="Create a new branch") branch_new_parser.add_argument("name", help="Branch name") branch_new_parser.set_defaults(func=cmd_branch_new) - branch_checkout_parser = branch_subparsers.add_parser( - "checkout", aliases=["co"], help="Checkout a branch" - ) + branch_checkout_parser = branch_subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") branch_checkout_parser.add_argument("name", help="Branch name", nargs="?") branch_checkout_parser.set_defaults(func=cmd_branch_checkout) # stack - stack_parser = subparsers.add_parser( - "stack", aliases=["s"], help="Operations on the full current stack" - ) - stack_subparsers = stack_parser.add_subparsers( - required=True, dest="stack_command" - ) + stack_parser = subparsers.add_parser("stack", aliases=["s"], help="Operations on the full current stack") + stack_subparsers = stack_parser.add_subparsers(required=True, dest="stack_command") - stack_info_parser = stack_subparsers.add_parser( - "info", aliases=["i"], help="Info for current stack" - ) - stack_info_parser.add_argument( - "--pr", action="store_true", help="Get PR info (slow)" - ) + stack_info_parser = stack_subparsers.add_parser("info", aliases=["i"], help="Info for current stack") + stack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") stack_info_parser.set_defaults(func=cmd_stack_info) stack_push_parser = stack_subparsers.add_parser("push", help="Push") - stack_push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - stack_push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + stack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + stack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") stack_push_parser.set_defaults(func=cmd_stack_push) stack_sync_parser = stack_subparsers.add_parser("sync", help="Sync") @@ -1488,36 +1567,22 @@ def main(): stack_checkout_parser.set_defaults(func=cmd_stack_checkout) # upstack - upstack_parser = subparsers.add_parser( - "upstack", aliases=["us"], help="Operations on the current upstack" - ) - upstack_subparsers = upstack_parser.add_subparsers( - required=True, dest="upstack_command" - ) + upstack_parser = subparsers.add_parser("upstack", aliases=["us"], help="Operations on the current upstack") + upstack_subparsers = upstack_parser.add_subparsers(required=True, dest="upstack_command") - upstack_info_parser = upstack_subparsers.add_parser( - "info", aliases=["i"], help="Info for current upstack" - ) - upstack_info_parser.add_argument( - "--pr", action="store_true", help="Get PR info (slow)" - ) + upstack_info_parser = upstack_subparsers.add_parser("info", aliases=["i"], help="Info for current upstack") + upstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") upstack_info_parser.set_defaults(func=cmd_upstack_info) upstack_push_parser = upstack_subparsers.add_parser("push", help="Push") - upstack_push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - upstack_push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + upstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + upstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") upstack_push_parser.set_defaults(func=cmd_upstack_push) upstack_sync_parser = upstack_subparsers.add_parser("sync", help="Sync") upstack_sync_parser.set_defaults(func=cmd_upstack_sync) - upstack_onto_parser = upstack_subparsers.add_parser( - "onto", aliases=["restack"], help="Restack" - ) + upstack_onto_parser = upstack_subparsers.add_parser("onto", aliases=["restack"], help="Restack") upstack_onto_parser.add_argument("target", help="New parent") upstack_onto_parser.set_defaults(func=cmd_upstack_onto) @@ -1525,25 +1590,17 @@ def main(): downstack_parser = subparsers.add_parser( "downstack", aliases=["ds"], help="Operations on the current downstack" ) - downstack_subparsers = downstack_parser.add_subparsers( - required=True, dest="downstack_command" - ) + downstack_subparsers = downstack_parser.add_subparsers(required=True, dest="downstack_command") downstack_info_parser = downstack_subparsers.add_parser( "info", aliases=["i"], help="Info for current downstack" ) - downstack_info_parser.add_argument( - "--pr", action="store_true", help="Get PR info (slow)" - ) + downstack_info_parser.add_argument("--pr", action="store_true", help="Get PR info (slow)") downstack_info_parser.set_defaults(func=cmd_downstack_info) downstack_push_parser = downstack_subparsers.add_parser("push", help="Push") - downstack_push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - downstack_push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + downstack_push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + downstack_push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") downstack_push_parser.set_defaults(func=cmd_downstack_push) downstack_sync_parser = downstack_subparsers.add_parser("sync", help="Sync") @@ -1551,16 +1608,12 @@ def main(): # update update_parser = subparsers.add_parser("update", help="Update repo") - update_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) + update_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") update_parser.set_defaults(func=cmd_update) # import import_parser = subparsers.add_parser("import", help="Import Graphite stack") - import_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) + import_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") import_parser.add_argument("name", help="Foreign stack top") import_parser.set_defaults(func=cmd_import) @@ -1570,12 +1623,8 @@ def main(): adopt_parser.set_defaults(func=cmd_adopt) # land - land_parser = subparsers.add_parser( - "land", help="Land bottom-most PR on current stack" - ) - land_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) + land_parser = subparsers.add_parser("land", help="Land bottom-most PR on current stack") + land_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") land_parser.add_argument( "--auto", "-a", @@ -1586,35 +1635,25 @@ def main(): # shortcuts push_parser = subparsers.add_parser("push", help="Alias for downstack push") - push_parser.add_argument( - "--force", "-f", action="store_true", help="Bypass confirmation" - ) - push_parser.add_argument( - "--no-pr", dest="pr", action="store_false", help="Skip Create PRs" - ) + push_parser.add_argument("--force", "-f", action="store_true", help="Bypass confirmation") + push_parser.add_argument("--no-pr", dest="pr", action="store_false", help="Skip Create PRs") push_parser.set_defaults(func=cmd_downstack_push) sync_parser = subparsers.add_parser("sync", help="Alias for stack sync") sync_parser.set_defaults(func=cmd_stack_sync) - checkout_parser = subparsers.add_parser( - "checkout", aliases=["co"], help="Checkout a branch" - ) + checkout_parser = subparsers.add_parser("checkout", aliases=["co"], help="Checkout a branch") checkout_parser.add_argument("name", help="Branch name", nargs="?") checkout_parser.set_defaults(func=cmd_branch_checkout) - checkout_parser = subparsers.add_parser( - "sco", help="Checkout a branch in this stack" - ) + checkout_parser = subparsers.add_parser("sco", help="Checkout a branch in this stack") checkout_parser.set_defaults(func=cmd_stack_checkout) global CONFIG CONFIG = read_config() args = parser.parse_args() - logging.basicConfig( - format=_LOGGING_FORMAT, level=LOGLEVELS[args.log_level], force=True - ) + logging.basicConfig(format=_LOGGING_FORMAT, level=LOGLEVELS[args.log_level], force=True) global COLOR_STDERR global COLOR_STDOUT @@ -1648,6 +1687,8 @@ def main(): inner_do_sync(syncs, sync_names) else: + # TODO restore the current branch after changing the branch on some commands for + # instance `info` if CURRENT_BRANCH not in stack.stack: main_branch = get_real_stack_bottom() @@ -1657,6 +1698,7 @@ def main(): else: die("Current branch {} is not in a stack", CURRENT_BRANCH) + get_current_stack_as_forest(stack) args.func(stack, args) # Success, delete the state file diff --git a/src/stacky/stacky_test.py b/src/stacky/stacky_test.py new file mode 100755 index 0000000..0471373 --- /dev/null +++ b/src/stacky/stacky_test.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +import unittest +from unittest import mock +from unittest.mock import MagicMock +from stacky import PRInfos, read_config, get_top_level_dir + + +class TestStringMethods(unittest.TestCase): + def test_upper(self): + self.assertEqual("foo".upper(), "FOO") + + def test_isupper(self): + self.assertTrue("FOO".isupper()) + self.assertFalse("Foo".isupper()) + + def test_split(self): + s = "hello world" + self.assertEqual(s.split(), ["hello", "world"]) + # check that s.split fails when the separator is not a string + with self.assertRaises(TypeError): + s.split(2) + + @mock.patch("get_top_level_dir") + def test_read_config(self, mock_get_tld): + patcher = mock.patch("os.path.exists") + mock_thing = patcher.start() + mock_thing.return_value = False + read_config() + + +if __name__ == "__main__": + unittest.main()