/
base.py
211 lines (174 loc) · 6.66 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""Foundational tools to set up a VCS manager in libvcs.sync."""
import logging
import pathlib
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional
from urllib import parse as urlparse
from libvcs._internal.run import _CMD, CmdLoggingAdapter, ProgressCallbackProtocol, run
from libvcs._internal.types import StrPath
# Module-level logger; BaseSync.__init__ wraps it in a CmdLoggingAdapter
# for each instance (see ``self.log``).
logger = logging.getLogger(__name__)
class VCSLocation(NamedTuple):
    """Generic VCS Location (URL and optional revision)."""

    # Repository URL with the ``<vcs>+`` prefix, any ``@rev`` path suffix and
    # any ``#fragment`` removed.
    url: str
    # Revision parsed from an ``@rev`` suffix on the URL path, or ``None``.
    rev: Optional[str]


def convert_pip_url(pip_url: str) -> VCSLocation:
    """Split a pip-style VCS URL into a plain URL and an optional revision.

    The expected shape is ``<vcs>+<protocol>://<url>[@<rev>][#<fragment>]``,
    e.g. ``git+https://github.com/org/repo.git@v1.0``.

    Parameters
    ----------
    pip_url : str
        pip-style URL. Must contain a ``+`` separating the VCS name from the
        actual URL.

    Returns
    -------
    VCSLocation
        ``url`` is everything after the first ``+``, with any ``@rev`` suffix
        on the path and any ``#fragment`` (e.g. ``#egg=...``) removed.
        ``rev`` is the text after the last ``@`` in the path, or ``None``.

    Raises
    ------
    AssertionError
        If ``pip_url`` contains no ``+``. Raised explicitly (not via an
        ``assert`` statement) so the check still runs under ``python -O``.

    >>> convert_pip_url("git+https://github.com/org/repo.git@v1.0")
    VCSLocation(url='https://github.com/org/repo.git', rev='v1.0')
    """
    error_message = (
        "Sorry, '%s' is a malformed VCS url. "
        "The format is <vcs>+<protocol>://<url>, "
        "e.g. svn+http://myrepo/svn/MyApp#egg=MyApp"
    )
    if "+" not in pip_url:
        # Same exception type as the former ``assert`` so existing callers
        # keep working, but an explicit raise is not stripped by -O.
        raise AssertionError(error_message % pip_url)

    url = pip_url.split("+", 1)[1]
    scheme, netloc, path, query, _frag = urlparse.urlsplit(url)

    rev: Optional[str] = None
    # Only the path is searched for "@", so credentials in the netloc
    # (e.g. ``ssh://git@host/...``) are not mistaken for a revision.
    if "@" in path:
        path, rev = path.rsplit("@", 1)

    # Rebuild without the fragment (pip's ``#egg=`` metadata).
    url = urlparse.urlunsplit((scheme, netloc, path, query, ""))
    return VCSLocation(url=url, rev=rev)
class BaseSync:
    """Base class for repositories."""

    log_in_real_time: Optional[bool] = None
    """Default for ``log_in_real_time`` in :meth:`run` when the caller passes None"""

    bin_name: str = ""
    """VCS app name, e.g. 'git'"""

    schemes: tuple[str, ...] = ()
    """Schemes to register in ``urlparse.uses_netloc`` (and ``uses_fragment``)"""

    def __init__(
        self,
        *,
        url: str,
        path: StrPath,
        progress_callback: Optional[ProgressCallbackProtocol] = None,
        **kwargs: Any,
    ) -> None:
        r"""Initialize a tool to manage a local VCS Checkout, Clone, Copy, or Work tree.

        Parameters
        ----------
        url : str
            Remote location of the repository.
        path : str or :class:`pathlib.Path`
            Local directory to check out into. Stored as :class:`pathlib.Path`.
        progress_callback : func
            Retrieve live progress from ``sys.stderr`` (useful for certain vcs commands
            like ``git pull``). Use ``progress_callback``:

            >>> import os
            >>> import sys
            >>> def progress_cb(output, timestamp):
            ...     sys.stdout.write(output)
            ...     sys.stdout.flush()
            >>> class Project(BaseSync):
            ...     bin_name = 'git'
            ...     def obtain(self, *args, **kwargs):
            ...         self.ensure_dir()
            ...         self.run(
            ...             ['clone', '--progress', self.url, self.path],
            ...             log_in_real_time=True
            ...         )
            >>> r = Project(
            ...     url=f'file://{create_git_remote_repo()}',
            ...     path=str(tmp_path),
            ...     progress_callback=progress_cb
            ... )
            >>> r.obtain()
            Cloning into '...'...
            remote: Enumerating objects: ...
            remote: Counting objects: ...% (...)...
            ...
            remote: Total ... (delta 0), reused 0 (delta 0), pack-reused 0
            ...
            Receiving objects: ...% (...)...
            ...
            >>> assert r.path.exists()
            >>> assert pathlib.Path(r.path / '.git').exists()
        **kwargs
            Only ``rev`` is read here; when present it is stored as ``self.rev``.
            Other keys are ignored by this base class.
        """
        self.url = url

        #: Callback for run updates
        self.progress_callback = progress_callback

        #: Directory to check out
        self.path: pathlib.Path
        if isinstance(path, pathlib.Path):
            self.path = path
        else:
            self.path = pathlib.Path(path)

        if "rev" in kwargs:
            self.rev = kwargs["rev"]

        # Register this VCS's URL schemes with urllib.parse, which consults
        # these module-level lists when (re)assembling URLs. The lists are
        # process-global, so append only schemes not yet present; extending
        # unconditionally made them grow on every instantiation.
        for scheme in self.schemes:
            if scheme not in urlparse.uses_netloc:
                urlparse.uses_netloc.append(scheme)
        # Guarded lookup kept for interpreters lacking ``uses_fragment``.
        uses_fragment = getattr(urlparse, "uses_fragment", None)
        if uses_fragment:
            for scheme in self.schemes:
                if scheme not in uses_fragment:
                    uses_fragment.append(scheme)

        #: Logging attribute
        self.log: CmdLoggingAdapter = CmdLoggingAdapter(
            bin_name=self.bin_name,
            keyword=self.repo_name,
            logger=logger,
            extra={},
        )

    @property
    def repo_name(self) -> str:
        """Return the short name of a repo checkout (the stem of ``self.path``)."""
        return self.path.stem

    @classmethod
    def from_pip_url(cls, pip_url: str, **kwargs: Any) -> "BaseSync":
        """Create synchronization object from pip-style URL.

        ``pip_url`` is split by :func:`convert_pip_url` into ``url`` and ``rev``,
        which are passed to the constructor along with ``kwargs``.
        """
        url, rev = convert_pip_url(pip_url)
        return cls(url=url, rev=rev, **kwargs)

    def run(
        self,
        cmd: _CMD,
        cwd: Optional[StrPath] = None,
        check_returncode: bool = True,
        log_in_real_time: Optional[bool] = None,
        *args: Any,
        **kwargs: Any,
    ) -> str:
        """Return combined stderr/stdout from a command.

        This method will also prefix the VCS command bin_name. By default runs
        using the cwd `libvcs.sync.base.BaseSync.path` of the repo.

        Parameters
        ----------
        cmd : str or list
            Arguments after ``bin_name``. A single string/bytes/path is passed
            as one argument; any other sequence is spread into the argv.
        cwd : str
            dir command is run from, defaults to `libvcs.sync.base.BaseSync.path`.
        check_returncode : bool
            Indicate whether a :exc:`~exc.CommandError` should be raised if return code
            is different from 0.
        log_in_real_time : bool, optional
            ``None`` falls back to the class-level ``log_in_real_time``; an
            explicit ``True``/``False`` always wins.

        Extra ``*args``/``**kwargs`` are accepted for compatibility and ignored.

        Returns
        -------
        str
            combined stdout/stderr in a big string, newlines retained
        """
        if cwd is None:
            cwd = getattr(self, "path", None)

        # ``str`` and ``bytes`` are themselves Sequences; without this check
        # ``run("status")`` would become ``["git", "s", "t", "a", ...]``.
        if isinstance(cmd, (str, bytes)) or not isinstance(cmd, Sequence):
            cmd = [self.bin_name, cmd]
        else:
            cmd = [self.bin_name, *cmd]

        # Honor an explicit False instead of letting a truthy class default
        # override it (the old ``a or b or False`` chain could not).
        if log_in_real_time is None:
            log_in_real_time = bool(self.log_in_real_time)

        return run(
            cmd,
            callback=(
                self.progress_callback if callable(self.progress_callback) else None
            ),
            check_returncode=check_returncode,
            log_in_real_time=log_in_real_time,
            cwd=cwd,
        )

    def ensure_dir(self, *args: Any, **kwargs: Any) -> bool:
        """Assure destination path exists. If not, create directories.

        Missing parent directories are created too. Extra arguments are
        accepted and ignored.

        Returns
        -------
        bool
            Always ``True`` (an error is raised if creation fails).
        """
        if self.path.exists():
            return True

        self.log.debug(
            f"Project directory for {self.repo_name} does not exist @ {self.path}",
        )
        # exist_ok tolerates the directory appearing between the exists()
        # check above and this call.
        self.path.mkdir(parents=True, exist_ok=True)
        return True

    def update_repo(self, *args: Any, **kwargs: Any) -> None:
        """Pull latest changes to here from remote repository.

        Must be implemented by subclasses; the base version raises
        :exc:`NotImplementedError`.
        """
        raise NotImplementedError

    def obtain(self, *args: Any, **kwargs: Any) -> None:
        """Checkout initial VCS repository or working copy from remote repository.

        Must be implemented by subclasses; the base version raises
        :exc:`NotImplementedError`.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Representation of a VCS management object."""
        return f"<{self.__class__.__name__} {self.repo_name}>"