Add a `force` option to the Git backend `add()` methods.

The abstract base, dulwich, gitpython and pygit2 backends all gain a
`force: bool = False` parameter (pygit2 still raises NotImplementedError).
When `force` is false, the dulwich backend drops paths for which
`ignore_manager.is_ignored()` is true and the gitpython backend drops paths
for which `self.is_ignored()` is true; when `force` is true those checks are
skipped. In the gitpython backend, a request whose paths are all ignored
stages nothing: it returns before invoking `git add`, because `git add -u`
without a pathspec would update every tracked file.

NOTE(review): the post-image blob ids on the `index` lines of the dulwich and
gitpython sections were not recomputed after the hunks were edited.

diff --git a/src/scmrepo/git/backend/base.py b/src/scmrepo/git/backend/base.py
index 794e0584..71b06f5d 100644
--- a/src/scmrepo/git/backend/base.py
+++ b/src/scmrepo/git/backend/base.py
@@ -64,7 +64,12 @@ def dir(self) -> str:
         pass
 
     @abstractmethod
-    def add(self, paths: Union[str, Iterable[str]], update=False):
+    def add(
+        self,
+        paths: Union[str, Iterable[str]],
+        update: bool = False,
+        force: bool = False,
+    ):
         pass
 
     @abstractmethod
diff --git a/src/scmrepo/git/backend/dulwich/__init__.py b/src/scmrepo/git/backend/dulwich/__init__.py
index 8fc09f59..99646370 100644
--- a/src/scmrepo/git/backend/dulwich/__init__.py
+++ b/src/scmrepo/git/backend/dulwich/__init__.py
@@ -12,6 +12,7 @@
     Callable,
     Dict,
     Iterable,
+    Iterator,
     List,
     Mapping,
     Optional,
@@ -248,17 +249,41 @@ def init(path: str, bare: bool = False) -> None:
     def dir(self) -> str:
         return self.repo.commondir()
 
-    def add(self, paths: Union[str, Iterable[str]], update=False):
+    def add(
+        self,
+        paths: Union[str, Iterable[str]],
+        update: bool = False,
+        force: bool = False,
+    ):
         assert paths or update
 
-        if isinstance(paths, str):
-            paths = [paths]
+        paths = [paths] if isinstance(paths, str) else list(paths)
 
         if update and not paths:
             self.repo.stage(list(self.repo.open_index()))
             return
 
-        files: List[bytes] = []
+        files: List[bytes] = [
+            os.fsencode(fpath) for fpath in self._expand_paths(paths, force=force)
+        ]
+        if update:
+            index = self.repo.open_index()
+            if os.name == "nt":
+                # NOTE: we need git/unix separator to compare against index
+                # paths but repo.stage() expects to be called with OS paths
+                self.repo.stage(
+                    [fname for fname in files if fname.replace(b"\\", b"/") in index]
+                )
+            else:
+                self.repo.stage([fname for fname in files if fname in index])
+        else:
+            self.repo.stage(files)
+
+    def _expand_paths(self, paths: List[str], force: bool = False) -> Iterator[str]:
+        """Yield root_dir-relative file paths for `paths`, walking directories.
+
+        Paths reported as ignored are skipped unless `force` is set.
+        """
         for path in paths:
             if not os.path.isabs(path) and self._submodules:
                 # NOTE: If path is inside a submodule, Dulwich expects the
@@ -275,27 +300,15 @@ def add(self, paths: Union[str, Iterable[str]], update=False):
                         )
                         break
             if os.path.isdir(path):
-                files.extend(
-                    os.fsencode(relpath(os.path.join(root, fpath), self.root_dir))
-                    for root, _, fs in os.walk(path)
-                    for fpath in fs
-                )
+                for root, _, fs in os.walk(path):
+                    for fpath in fs:
+                        rel = relpath(os.path.join(root, fpath), self.root_dir)
+                        if force or not self.ignore_manager.is_ignored(rel):
+                            yield rel
             else:
-                files.append(os.fsencode(relpath(path, self.root_dir)))
-
-        # NOTE: this doesn't check gitignore, same as GitPythonBackend.add
-        if update:
-            index = self.repo.open_index()
-            if os.name == "nt":
-                # NOTE: we need git/unix separator to compare against index
-                # paths but repo.stage() expects to be called with OS paths
-                self.repo.stage(
-                    [fname for fname in files if fname.replace(b"\\", b"/") in index]
-                )
-            else:
-                self.repo.stage([fname for fname in files if fname in index])
-        else:
-            self.repo.stage(files)
+                rel = relpath(path, self.root_dir)
+                if force or not self.ignore_manager.is_ignored(rel):
+                    yield rel
 
     def commit(self, msg: str, no_verify: bool = False):
         from dulwich.errors import CommitError
diff --git a/src/scmrepo/git/backend/gitpython.py b/src/scmrepo/git/backend/gitpython.py
index c8ccdaec..9b9cc5ed 100644
--- a/src/scmrepo/git/backend/gitpython.py
+++ b/src/scmrepo/git/backend/gitpython.py
@@ -6,6 +6,7 @@
 from functools import partial
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
     Dict,
     Iterable,
@@ -211,17 +212,35 @@ def is_sha(rev):
     def dir(self) -> str:
         return self.repo.git_dir
 
-    def add(self, paths: Union[str, Iterable[str]], update=False):
-        # NOTE: GitPython is not currently able to handle index version >= 3.
-        # See https://github.com/iterative/dvc/issues/610 for more details.
+    def add(
+        self,
+        paths: Union[str, Iterable[str]],
+        update: bool = False,
+        force: bool = False,
+    ):
         try:
-            if update:
+            if update or not force:
+                # NOTE: git-python index.add() defines force parameter but
+                # ignores it (index.add() behavior is always force=True)
+                kwargs: Dict[str, Any] = {}
+                if update:
+                    kwargs["update"] = True
                 if isinstance(paths, str):
                     paths = [paths]
-                self.git.add(*paths, update=True)
+                if not force:
+                    requested = list(paths)
+                    paths = [path for path in requested if not self.is_ignored(path)]
+                    if requested and not paths:
+                        # NOTE: every requested path is ignored. Bail out, as
+                        # `git add -u` without a pathspec would update all
+                        # tracked files instead of staging nothing.
+                        return
+                self.git.add(*paths, **kwargs)
             else:
                 self.repo.index.add(paths)
         except AssertionError as exc:
+            # NOTE: GitPython is not currently able to handle index version >= 3.
+            # See https://github.com/iterative/dvc/issues/610 for more details.
             raise UnsupportedIndexFormat from exc
 
     def commit(self, msg: str, no_verify: bool = False):
diff --git a/src/scmrepo/git/backend/pygit2/__init__.py b/src/scmrepo/git/backend/pygit2/__init__.py
index bb889560..af76256c 100644
--- a/src/scmrepo/git/backend/pygit2/__init__.py
+++ b/src/scmrepo/git/backend/pygit2/__init__.py
@@ -183,7 +183,12 @@ def init(path: str, bare: bool = False) -> None:
     def dir(self) -> str:
         raise NotImplementedError
 
-    def add(self, paths: Union[str, Iterable[str]], update=False):
+    def add(
+        self,
+        paths: Union[str, Iterable[str]],
+        update: bool = False,
+        force: bool = False,
+    ):
         raise NotImplementedError
 
     def commit(self, msg: str, no_verify: bool = False):
diff --git a/tests/test_git.py b/tests/test_git.py
index 5e6f4631..79389550 100644
--- a/tests/test_git.py
+++ b/tests/test_git.py
@@ -669,6 +669,22 @@ def test_add(tmp_dir: TmpDir, scm: Git, git: Git):
     assert len(untracked) == 1
 
 
+@pytest.mark.skip_git_backend("pygit2")
+def test_add_force(tmp_dir: TmpDir, scm: Git, git: Git):
+    tmp_dir.gen({"foo": "foo", "bar": "bar", "dir": {"baz": "baz"}})
+    tmp_dir.gen({".gitignore": "foo\ndir/"})
+
+    git.add(["foo", "bar", "dir"])
+    staged, _unstaged, untracked = scm.status()
+    assert set(staged["add"]) == {"bar"}
+    assert set(untracked) == {".gitignore"}
+
+    git.add(["foo", "bar", "dir"], force=True)
+    staged, _unstaged, untracked = scm.status()
+    assert set(staged["add"]) == {"foo", "bar", "dir/baz"}
+    assert set(untracked) == {".gitignore"}
+
+
 @pytest.mark.skip_git_backend("dulwich", "gitpython")
 def test_checkout_subdir(tmp_dir: TmpDir, scm: Git, git: Git):
     tmp_dir.gen("foo", "foo")