diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 217ce91fc4..b991195dfe 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -4,10 +4,10 @@ import logging import os import re -from datetime import datetime -from datetime import timedelta +from datetime import datetime, timedelta +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.path_info import CloudURLInfo @@ -64,6 +64,7 @@ def __init__(self, repo, config): else self.path_cls.from_parts(scheme=self.scheme, netloc=bucket) ) + @wrap_prop(threading.Lock()) @cached_property def blob_service(self): from azure.storage.blob import BlockBlobService diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 8ba47c2923..6f389ec4f6 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -5,8 +5,9 @@ from functools import wraps import io import os.path +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.exceptions import DvcException @@ -91,6 +92,7 @@ def __init__(self, repo, config): self.projectname = config.get(Config.SECTION_GCP_PROJECTNAME, None) self.credentialpath = config.get(Config.SECTION_GCP_CREDENTIALPATH) + @wrap_prop(threading.Lock()) @cached_property def gs(self): from google.cloud.storage import Client diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 7f96e0b8ef..c224166254 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import logging +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.config import ConfigError @@ -81,6 +82,7 @@ def get_file_checksum(self, path_info): return etag + @wrap_prop(threading.Lock()) @cached_property def _session(self): import requests diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index ca41e68db1..9445fd0b2b 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -3,6 +3,9 @@ import logging import os +import threading + +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.path_info import CloudURLInfo @@ -61,30 +64,29 @@ def __init__(self, repo, config): or "defaultSecret" ) - self._bucket = None - - @property + @wrap_prop(threading.Lock()) + @cached_property def oss_service(self): import oss2 - if self._bucket is None: - logger.debug("URL {}".format(self.path_info)) - logger.debug("key id {}".format(self.key_id)) - logger.debug("key secret {}".format(self.key_secret)) - auth = oss2.Auth(self.key_id, self.key_secret) - self._bucket = oss2.Bucket( - auth, self.endpoint, self.path_info.bucket + logger.debug("URL {}".format(self.path_info)) + logger.debug("key id {}".format(self.key_id)) + logger.debug("key secret {}".format(self.key_secret)) + + auth = oss2.Auth(self.key_id, self.key_secret) + bucket = oss2.Bucket(auth, self.endpoint, self.path_info.bucket) + + # Ensure bucket exists + try: + bucket.get_bucket_info() + except oss2.exceptions.NoSuchBucket: + bucket.create_bucket( + oss2.BUCKET_ACL_PUBLIC_READ, + oss2.models.BucketCreateConfig( + oss2.BUCKET_STORAGE_CLASS_STANDARD + ), ) - try: # verify that bucket exists - self._bucket.get_bucket_info() - except oss2.exceptions.NoSuchBucket: - self._bucket.create_bucket( - oss2.BUCKET_ACL_PUBLIC_READ, - oss2.models.BucketCreateConfig( - oss2.BUCKET_STORAGE_CLASS_STANDARD - ), - ) - return self._bucket + return bucket def remove(self, path_info): if path_info.scheme != self.scheme: diff --git a/dvc/remote/pool.py b/dvc/remote/pool.py index bec579cc47..59f16b01b8 100644 --- a/dvc/remote/pool.py +++ b/dvc/remote/pool.py @@ -1,7 +1,8 @@ from collections import deque from contextlib import contextmanager +import threading -from funcy import memoize +from funcy import memoize, wrap_with @contextmanager @@ -17,6 +18,7 @@ def get_connection(conn_func, *args, **kwargs): pool.release(conn) +@wrap_with(threading.Lock()) @memoize def get_pool(conn_func, *args, **kwargs): return Pool(conn_func, *args, **kwargs) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 952db8fa4f..5594806f03 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -3,8 +3,9 @@ import logging import os +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.exceptions import DvcException @@ -56,6 +57,7 @@ def __init__(self, repo, config): if shared_creds: os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + @wrap_prop(threading.Lock()) @cached_property def s3(self): import boto3 diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 589bea8d1e..7509d75040 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -8,8 +8,9 @@ import os import threading from concurrent.futures import ThreadPoolExecutor -from contextlib import closing -from contextlib import contextmanager +from contextlib import closing, contextmanager + +from funcy import memoize, wrap_with import dvc.prompt as prompt from dvc.config import Config @@ -24,8 +25,15 @@ logger = logging.getLogger(__name__) -saved_passwords = {} -saved_passwords_lock = threading.Lock() +@wrap_with(threading.Lock()) +@memoize +def ask_password(host, user, port): + return prompt.password( + "Enter a private key passphrase or a password for " + "host '{host}' port '{port}' user '{user}'".format( + host=host, port=port, user=user + ) + ) class RemoteSSH(RemoteBASE): @@ -120,21 +128,11 @@ def _try_get_ssh_config_keyfile(user_ssh_config): def ensure_credentials(self, path_info=None): if path_info is None: path_info = self.path_info - host, user, port = path_info.host, path_info.user, path_info.port + # NOTE: we use the same password regardless of the server :( if self.ask_password and self.password is None: - with saved_passwords_lock: - server_key = (host, user, port) - password = saved_passwords.get(server_key) - - if password is None: - saved_passwords[server_key] = password = prompt.password( - "Enter a private key passphrase or a password for " - "host '{host}' port '{port}' user '{user}'".format( - host=host, port=port, user=user - ) - ) - self.password = password + host, user, port = path_info.host, path_info.user, path_info.port + self.password = ask_password(host, user, port) def ssh(self, path_info): self.ensure_credentials(path_info) diff --git a/setup.py b/setup.py index ae460aa3c8..d029ee4d5d 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ def run(self): "humanize>=0.5.1", "PyYAML>=5.1.2", "ruamel.yaml>=0.16.1", - "funcy>=1.12", + "funcy>=1.14", "pathspec>=0.6.0", "shortuuid>=0.5.0", "tqdm>=4.38.0,<5",