diff --git a/dvc/path_info.py b/dvc/path_info.py index 6b4c300b5e..613a76c014 100644 --- a/dvc/path_info.py +++ b/dvc/path_info.py @@ -241,3 +241,79 @@ class CloudURLInfo(URLInfo): @property def path(self): return self._spath.lstrip("/") + + +class HTTPURLInfo(URLInfo): + def __init__(self, url): + p = urlparse(url) + stripped = p._replace(params=None, query=None, fragment=None) + super().__init__(stripped.geturl()) + self.params = p.params + self.query = p.query + self.fragment = p.fragment + + @classmethod + def from_parts( + cls, + scheme=None, + host=None, + user=None, + port=None, + path="", + netloc=None, + params=None, + query=None, + fragment=None, + ): + assert bool(host) ^ bool(netloc) + + if netloc is not None: + return cls( + "{}://{}{}{}{}{}".format( + scheme, + netloc, + path, + (";" + params) if params else "", + ("?" + query) if query else "", + ("#" + fragment) if fragment else "", + ) + ) + + obj = cls.__new__(cls) + obj.fill_parts(scheme, host, user, port, path) + obj.params = params + obj.query = query + obj.fragment = fragment + return obj + + @property + def _extra_parts(self): + return (self.params, self.query, self.fragment) + + @property + def parts(self): + return self._base_parts + self._path.parts + self._extra_parts + + @cached_property + def url(self): + return "{}://{}{}{}{}{}".format( + self.scheme, + self.netloc, + self._spath, + (";" + self.params) if self.params else "", + ("?" + self.query) if self.query else "", + ("#" + self.fragment) if self.fragment else "", + ) + + def __eq__(self, other): + if isinstance(other, (str, bytes)): + other = self.__class__(other) + return ( + self.__class__ == other.__class__ + and self._base_parts == other._base_parts + and self._path == other._path + and self._extra_parts == other._extra_parts + ) + + def __hash__(self): + return hash(self.parts) diff --git a/dvc/remote/http.py b/dvc/remote/http.py index d0f35029bb..fa4966685e 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -4,6 +4,7 @@ from funcy import cached_property, memoize, wrap_prop, wrap_with +from dvc.path_info import HTTPURLInfo import dvc.prompt as prompt from dvc.config import ConfigError from dvc.exceptions import DvcException, HTTPError @@ -25,6 +26,7 @@ def ask_password(host, user): class RemoteHTTP(RemoteBASE): scheme = Schemes.HTTP + path_cls = HTTPURLInfo SESSION_RETRIES = 5 SESSION_BACKOFF_FACTOR = 0.1 REQUEST_TIMEOUT = 10 diff --git a/tests/unit/test_path_info.py b/tests/unit/test_path_info.py index 99476ff951..0b202fa124 100644 --- a/tests/unit/test_path_info.py +++ b/tests/unit/test_path_info.py @@ -4,6 +4,7 @@ import pytest from dvc.path_info import CloudURLInfo +from dvc.path_info import HTTPURLInfo from dvc.path_info import PathInfo from dvc.path_info import URLInfo @@ -44,13 +45,23 @@ def test_url_info_parents(cls): ] -@pytest.mark.parametrize("cls", [URLInfo, CloudURLInfo]) +@pytest.mark.parametrize("cls", [URLInfo, CloudURLInfo, HTTPURLInfo]) def test_url_info_deepcopy(cls): u1 = cls("ssh://user@test.com:/test1/test2/test3") u2 = copy.deepcopy(u1) assert u1 == u2 +def test_https_url_info_str(): + url = "https://user@test.com/test1;p=par?q=quer#frag" + u = HTTPURLInfo(url) + assert u.url == url + assert str(u) == u.url + assert u.params == "p=par" + assert u.query == "q=quer" + assert u.fragment == "frag" + + @pytest.mark.parametrize( "path, as_posix, osname", [