From 73eac3b6cc0ac8ebd3cca488e1d8d41868591753 Mon Sep 17 00:00:00 2001 From: Alexander Schepanovski Date: Thu, 28 Nov 2019 21:05:20 +0100 Subject: [PATCH 1/2] remote: protect all remote client/session creation code with locks With @cached_property last write wins, so all threads will use the same session. However, several might be created concurrently wasting resources and even potentially asking something from the user. This change guarantees only one client instance will be created for a remote. --- dvc/remote/azure.py | 7 ++++--- dvc/remote/gs.py | 4 +++- dvc/remote/http.py | 4 +++- dvc/remote/oss.py | 42 ++++++++++++++++++++++-------------------- dvc/remote/pool.py | 4 +++- dvc/remote/s3.py | 4 +++- setup.py | 2 +- 7 files changed, 39 insertions(+), 28 deletions(-) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 217ce91fc4..b991195dfe 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -4,10 +4,10 @@ import logging import os import re -from datetime import datetime -from datetime import timedelta +from datetime import datetime, timedelta +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.path_info import CloudURLInfo @@ -64,6 +64,7 @@ def __init__(self, repo, config): else self.path_cls.from_parts(scheme=self.scheme, netloc=bucket) ) + @wrap_prop(threading.Lock()) @cached_property def blob_service(self): from azure.storage.blob import BlockBlobService diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 8ba47c2923..6f389ec4f6 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -5,8 +5,9 @@ from functools import wraps import io import os.path +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.exceptions import DvcException @@ -91,6 +92,7 @@ def __init__(self, repo, config): self.projectname = config.get(Config.SECTION_GCP_PROJECTNAME, None) self.credentialpath = config.get(Config.SECTION_GCP_CREDENTIALPATH) + @wrap_prop(threading.Lock()) @cached_property def gs(self): from google.cloud.storage import Client diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 7f96e0b8ef..c224166254 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import logging +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.config import ConfigError @@ -81,6 +82,7 @@ def get_file_checksum(self, path_info): return etag + @wrap_prop(threading.Lock()) @cached_property def _session(self): import requests diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index ca41e68db1..9445fd0b2b 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -3,6 +3,9 @@ import logging import os +import threading + +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.path_info import CloudURLInfo @@ -61,30 +64,29 @@ def __init__(self, repo, config): or "defaultSecret" ) - self._bucket = None - - @property + @wrap_prop(threading.Lock()) + @cached_property def oss_service(self): import oss2 - if self._bucket is None: - logger.debug("URL {}".format(self.path_info)) - logger.debug("key id {}".format(self.key_id)) - logger.debug("key secret {}".format(self.key_secret)) - auth = oss2.Auth(self.key_id, self.key_secret) - self._bucket = oss2.Bucket( - auth, self.endpoint, self.path_info.bucket + logger.debug("URL {}".format(self.path_info)) + logger.debug("key id {}".format(self.key_id)) + logger.debug("key secret {}".format(self.key_secret)) + + auth = oss2.Auth(self.key_id, self.key_secret) + bucket = oss2.Bucket(auth, self.endpoint, self.path_info.bucket) + + # Ensure bucket exists + try: + bucket.get_bucket_info() + except oss2.exceptions.NoSuchBucket: + bucket.create_bucket( + oss2.BUCKET_ACL_PUBLIC_READ, + oss2.models.BucketCreateConfig( + oss2.BUCKET_STORAGE_CLASS_STANDARD + ), ) - try: # verify that bucket exists - self._bucket.get_bucket_info() - except oss2.exceptions.NoSuchBucket: - self._bucket.create_bucket( - oss2.BUCKET_ACL_PUBLIC_READ, - oss2.models.BucketCreateConfig( - oss2.BUCKET_STORAGE_CLASS_STANDARD - ), - ) - return self._bucket + return bucket def remove(self, path_info): if path_info.scheme != self.scheme: diff --git a/dvc/remote/pool.py b/dvc/remote/pool.py index bec579cc47..59f16b01b8 100644 --- a/dvc/remote/pool.py +++ b/dvc/remote/pool.py @@ -1,7 +1,8 @@ from collections import deque from contextlib import contextmanager +import threading -from funcy import memoize +from funcy import memoize, wrap_with @contextmanager @@ -17,6 +18,7 @@ def get_connection(conn_func, *args, **kwargs): pool.release(conn) +@wrap_with(threading.Lock()) @memoize def get_pool(conn_func, *args, **kwargs): return Pool(conn_func, *args, **kwargs) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 952db8fa4f..5594806f03 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -3,8 +3,9 @@ import logging import os +import threading -from funcy import cached_property +from funcy import cached_property, wrap_prop from dvc.config import Config from dvc.exceptions import DvcException @@ -56,6 +57,7 @@ def __init__(self, repo, config): if shared_creds: os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + @wrap_prop(threading.Lock()) @cached_property def s3(self): import boto3 diff --git a/setup.py b/setup.py index ae460aa3c8..d029ee4d5d 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ def run(self): "humanize>=0.5.1", "PyYAML>=5.1.2", "ruamel.yaml>=0.16.1", - "funcy>=1.12", + "funcy>=1.14", "pathspec>=0.6.0", "shortuuid>=0.5.0", "tqdm>=4.38.0,<5", From 8af1aeefb4b51e956e32dfabf58fdd6a44f56dd5 Mon Sep 17 00:00:00 2001 From: Alexander Schepanovski Date: Thu, 28 Nov 2019 21:10:35 +0100 Subject: [PATCH 2/2] remote: refactor ssh ask password code --- dvc/remote/ssh/__init__.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 589bea8d1e..7509d75040 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -8,8 +8,9 @@ import os import threading from concurrent.futures import ThreadPoolExecutor -from contextlib import closing -from contextlib import contextmanager +from contextlib import closing, contextmanager + +from funcy import memoize, wrap_with import dvc.prompt as prompt from dvc.config import Config @@ -24,8 +25,15 @@ logger = logging.getLogger(__name__) -saved_passwords = {} -saved_passwords_lock = threading.Lock() +@wrap_with(threading.Lock()) +@memoize +def ask_password(host, user, port): + return prompt.password( + "Enter a private key passphrase or a password for " + "host '{host}' port '{port}' user '{user}'".format( + host=host, port=port, user=user + ) + ) class RemoteSSH(RemoteBASE): @@ -120,21 +128,11 @@ def _try_get_ssh_config_keyfile(user_ssh_config): def ensure_credentials(self, path_info=None): if path_info is None: path_info = self.path_info - host, user, port = path_info.host, path_info.user, path_info.port + # NOTE: we use the same password regardless of the server :( if self.ask_password and self.password is None: - with saved_passwords_lock: - server_key = (host, user, port) - password = saved_passwords.get(server_key) - - if password is None: - saved_passwords[server_key] = password = prompt.password( - "Enter a private key passphrase or a password for " - "host '{host}' port '{port}' user '{user}'".format( - host=host, port=port, user=user - ) - ) - self.password = password + host, user, port = path_info.host, path_info.user, path_info.port + self.password = ask_password(host, user, port) def ssh(self, path_info): self.ensure_credentials(path_info)