diff --git a/dvc_ssh/__init__.py b/dvc_ssh/__init__.py index 466534f..5574430 100644 --- a/dvc_ssh/__init__.py +++ b/dvc_ssh/__init__.py @@ -41,9 +41,14 @@ def _strip_protocol(cls, path: str) -> str: def unstrip_protocol(self, path: str) -> str: host = self.fs_args["host"] - port = self.fs_args["port"] + port = self.fs_args.get("port") path = path.lstrip("/") - return f"ssh://{host}:{port}/{path}" + + url = f"ssh://{host}" + if port: + url += f":{port}" + url += f"/{path}" + return url def _prepare_credentials(self, **config): from .client import InteractiveSSHClient diff --git a/dvc_ssh/tests/test_fs.py b/dvc_ssh/tests/test_fs.py index 3cbf02f..19c6bf5 100644 --- a/dvc_ssh/tests/test_fs.py +++ b/dvc_ssh/tests/test_fs.py @@ -103,3 +103,17 @@ def test_ssh_keyfile(config, expected_keyfile): def test_ssh_gss_auth(config, expected_gss_auth): fs = SSHFileSystem(**config) assert fs.fs_args["gss_auth"] == expected_gss_auth + + +@pytest.mark.parametrize( + "config,path,expected_path", + [ + ({"host": "example.com"}, "path", "ssh://example.com/path"), + ({"host": "example.com"}, "/path", "ssh://example.com/path"), + ({"host": "example.com", "port": 1234}, "path", "ssh://example.com:1234/path"), + ({"host": "example.com", "port": 1234}, "/path", "ssh://example.com:1234/path"), + ], +) +def test_unstrip_protocol(mocker, config, path, expected_path): + fs = SSHFileSystem(**config, fs=mocker.MagicMock()) + assert fs.unstrip_protocol(path) == expected_path