Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dvc/remote/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging
import os
import re
from datetime import datetime
from datetime import timedelta
from datetime import datetime, timedelta
import threading

from funcy import cached_property
from funcy import cached_property, wrap_prop

from dvc.config import Config
from dvc.path_info import CloudURLInfo
Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(self, repo, config):
else self.path_cls.from_parts(scheme=self.scheme, netloc=bucket)
)

@wrap_prop(threading.Lock())
@cached_property
def blob_service(self):
from azure.storage.blob import BlockBlobService
Expand Down
4 changes: 3 additions & 1 deletion dvc/remote/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from functools import wraps
import io
import os.path
import threading

from funcy import cached_property
from funcy import cached_property, wrap_prop

from dvc.config import Config
from dvc.exceptions import DvcException
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(self, repo, config):
self.projectname = config.get(Config.SECTION_GCP_PROJECTNAME, None)
self.credentialpath = config.get(Config.SECTION_GCP_CREDENTIALPATH)

@wrap_prop(threading.Lock())
@cached_property
def gs(self):
from google.cloud.storage import Client
Expand Down
4 changes: 3 additions & 1 deletion dvc/remote/http.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import unicode_literals

import logging
import threading

from funcy import cached_property
from funcy import cached_property, wrap_prop

from dvc.config import Config
from dvc.config import ConfigError
Expand Down Expand Up @@ -81,6 +82,7 @@ def get_file_checksum(self, path_info):

return etag

@wrap_prop(threading.Lock())
@cached_property
def _session(self):
import requests
Expand Down
42 changes: 22 additions & 20 deletions dvc/remote/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import logging
import os
import threading

from funcy import cached_property, wrap_prop

from dvc.config import Config
from dvc.path_info import CloudURLInfo
Expand Down Expand Up @@ -61,30 +64,29 @@ def __init__(self, repo, config):
or "defaultSecret"
)

self._bucket = None

@property
@wrap_prop(threading.Lock())
@cached_property
def oss_service(self):
import oss2

if self._bucket is None:
logger.debug("URL {}".format(self.path_info))
logger.debug("key id {}".format(self.key_id))
logger.debug("key secret {}".format(self.key_secret))
auth = oss2.Auth(self.key_id, self.key_secret)
self._bucket = oss2.Bucket(
auth, self.endpoint, self.path_info.bucket
logger.debug("URL {}".format(self.path_info))
logger.debug("key id {}".format(self.key_id))
logger.debug("key secret {}".format(self.key_secret))

auth = oss2.Auth(self.key_id, self.key_secret)
bucket = oss2.Bucket(auth, self.endpoint, self.path_info.bucket)

# Ensure bucket exists
try:
bucket.get_bucket_info()
except oss2.exceptions.NoSuchBucket:
bucket.create_bucket(
oss2.BUCKET_ACL_PUBLIC_READ,
oss2.models.BucketCreateConfig(
oss2.BUCKET_STORAGE_CLASS_STANDARD
),
)
try: # verify that bucket exists
self._bucket.get_bucket_info()
except oss2.exceptions.NoSuchBucket:
self._bucket.create_bucket(
oss2.BUCKET_ACL_PUBLIC_READ,
oss2.models.BucketCreateConfig(
oss2.BUCKET_STORAGE_CLASS_STANDARD
),
)
return self._bucket
return bucket

def remove(self, path_info):
if path_info.scheme != self.scheme:
Expand Down
4 changes: 3 additions & 1 deletion dvc/remote/pool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import deque
from contextlib import contextmanager
import threading

from funcy import memoize
from funcy import memoize, wrap_with


@contextmanager
Expand All @@ -17,6 +18,7 @@ def get_connection(conn_func, *args, **kwargs):
pool.release(conn)


@wrap_with(threading.Lock())
@memoize
def get_pool(conn_func, *args, **kwargs):
return Pool(conn_func, *args, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion dvc/remote/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import logging
import os
import threading

from funcy import cached_property
from funcy import cached_property, wrap_prop

from dvc.config import Config
from dvc.exceptions import DvcException
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, repo, config):
if shared_creds:
os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds)

@wrap_prop(threading.Lock())
@cached_property
def s3(self):
import boto3
Expand Down
32 changes: 15 additions & 17 deletions dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing
from contextlib import contextmanager
from contextlib import closing, contextmanager

from funcy import memoize, wrap_with

import dvc.prompt as prompt
from dvc.config import Config
Expand All @@ -24,8 +25,15 @@
logger = logging.getLogger(__name__)


saved_passwords = {}
saved_passwords_lock = threading.Lock()
@wrap_with(threading.Lock())
@memoize
def ask_password(host, user, port):
return prompt.password(
"Enter a private key passphrase or a password for "
"host '{host}' port '{port}' user '{user}'".format(
host=host, port=port, user=user
)
)


class RemoteSSH(RemoteBASE):
Expand Down Expand Up @@ -120,21 +128,11 @@ def _try_get_ssh_config_keyfile(user_ssh_config):
def ensure_credentials(self, path_info=None):
if path_info is None:
path_info = self.path_info
host, user, port = path_info.host, path_info.user, path_info.port

# NOTE: we use the same password regardless of the server :(
if self.ask_password and self.password is None:
with saved_passwords_lock:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

server_key = (host, user, port)
password = saved_passwords.get(server_key)

if password is None:
saved_passwords[server_key] = password = prompt.password(
"Enter a private key passphrase or a password for "
"host '{host}' port '{port}' user '{user}'".format(
host=host, port=port, user=user
)
)
self.password = password
host, user, port = path_info.host, path_info.user, path_info.port
self.password = ask_password(host, user, port)

def ssh(self, path_info):
self.ensure_credentials(path_info)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self):
"humanize>=0.5.1",
"PyYAML>=5.1.2",
"ruamel.yaml>=0.16.1",
"funcy>=1.12",
"funcy>=1.14",
"pathspec>=0.6.0",
"shortuuid>=0.5.0",
"tqdm>=4.38.0,<5",
Expand Down