From 877afb83590be9a2f1140296ab4f5a3e4c3e28fd Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Mon, 1 Feb 2021 16:59:09 +0300 Subject: [PATCH 1/3] config: Support ~/.aws/config parsing --- dvc/config_schema.py | 1 + dvc/tree/s3.py | 68 ++++++++++++++++++++++++-- dvc/utils/conversions.py | 24 ++++++++++ tests/func/test_s3.py | 72 ++++++++++++++++++++++++++++ tests/unit/utils/test_conversions.py | 30 ++++++++++++ 5 files changed, 190 insertions(+), 5 deletions(-) create mode 100644 dvc/utils/conversions.py create mode 100644 tests/unit/utils/test_conversions.py diff --git a/dvc/config_schema.py b/dvc/config_schema.py index a94b50f921..0537eb8aa6 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -138,6 +138,7 @@ class RelPath(str): "region": str, "profile": str, "credentialpath": str, + "configpath": str, "endpointurl": str, "access_key_id": str, "secret_access_key": str, diff --git a/dvc/tree/s3.py b/dvc/tree/s3.py index eb62786fd4..cdd22d1734 100644 --- a/dvc/tree/s3.py +++ b/dvc/tree/s3.py @@ -11,12 +11,14 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm from dvc.scheme import Schemes -from dvc.utils import error_link +from dvc.utils import conversions, error_link from .base import BaseTree logger = logging.getLogger(__name__) +_AWS_CONFIG_PATH = os.path.expanduser("~/.aws/config") + class S3Tree(BaseTree): scheme = Schemes.S3 @@ -60,6 +62,54 @@ def __init__(self, repo, config): if shared_creds: os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + config_path = config.get("configpath") + if config_path: + os.environ.setdefault("AWS_CONFIG_FILE", config_path) + self._transfer_config = None + + # https://github.com/aws/aws-cli/blob/0376c6262d6b15dc36c82e6da6e1aad10249cc8c/awscli/customizations/s3/transferconfig.py#L107-L113 + _TRANSFER_CONFIG_ALIASES = { + "max_queue_size": "max_io_queue", + "max_concurrent_requests": "max_concurrency", + "multipart_threshold": "multipart_threshold", + "multipart_chunksize": "multipart_chunksize", + } + + def _transform_config(self, s3_config): + """Splits the general s3 config into 2 different config + objects, one for transfer.TransferConfig and other is the + general session config""" + + config, transfer_config = {}, {} + for key, value in s3_config.items(): + if key in self._TRANSFER_CONFIG_ALIASES: + if key in {"multipart_chunksize", "multipart_threshold"}: + # cast human readable sizes (like 24MiB) to integers + value = conversions.human_readable_to_bytes(value) + else: + value = int(value) + transfer_config[self._TRANSFER_CONFIG_ALIASES[key]] = value + else: + config[key] = value + + return config, transfer_config + + def _process_config(self): + from boto3.s3.transfer import TransferConfig + from botocore.configloader import load_config + + config = load_config( + os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH) + ) + profile = config["profiles"].get(self.profile or "default") + if not profile: + return None + + s3_config = profile.get("s3", {}) + s3_config, transfer_config = self._transform_config(s3_config) + self._transfer_config = TransferConfig(**transfer_config) + return s3_config + @wrap_prop(threading.Lock()) @cached_property def s3(self): @@ -78,12 +128,15 @@ def s3(self): session_opts["aws_session_token"] = self.session_token session = boto3.session.Session(**session_opts) + s3_config = self._process_config() return session.resource( "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl, - config=boto3.session.Config(signature_version="s3v4"), + config=boto3.session.Config( + signature_version="s3v4", s3=s3_config + ), ) @contextmanager @@ -355,7 +408,7 @@ def get_file_hash(self, path_info): def _upload_fobj(self, fobj, to_info): with self._get_obj(to_info) as obj: - obj.upload_fileobj(fobj) + obj.upload_fileobj(fobj, Config=self._transfer_config) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs @@ -366,7 +419,10 @@ def _upload( disable=no_progress_bar, total=total, bytes=True, desc=name ) as pbar: obj.upload_file( - from_file, Callback=pbar.update, ExtraArgs=self.extra_args, + from_file, + Callback=pbar.update, + ExtraArgs=self.extra_args, + Config=self._transfer_config, ) def _download(self, from_info, to_file, name=None, no_progress_bar=False): @@ -377,4 +433,6 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): bytes=True, desc=name, ) as pbar: - obj.download_file(to_file, Callback=pbar.update) + obj.download_file( + to_file, Callback=pbar.update, Config=self._transfer_config + ) diff --git a/dvc/utils/conversions.py b/dvc/utils/conversions.py new file mode 100644 index 0000000000..5fe2950d6d --- /dev/null +++ b/dvc/utils/conversions.py @@ -0,0 +1,24 @@ +# https://github.com/aws/aws-cli/blob/5aa599949f60b6af554fd5714d7161aa272716f7/awscli/customizations/s3/utils.py + +MULTIPLIERS = { + "kb": 1024, + "mb": 1024 ** 2, + "gb": 1024 ** 3, + "tb": 1024 ** 4, + "kib": 1024, + "mib": 1024 ** 2, + "gib": 1024 ** 3, + "tib": 1024 ** 4, +} + + +def human_readable_to_bytes(value): + value = value.lower() + suffix = None + if value.endswith(tuple(MULTIPLIERS.keys())): + size = 2 + size += value[-2] == "i" # KiB, MiB etc + value, suffix = value[:-size], value[-size:] + + multiplier = MULTIPLIERS.get(suffix, 1) + return int(value) * multiplier diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index 8f6f4c680c..0c002e3bc6 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -1,3 +1,4 @@ +import textwrap from functools import wraps import boto3 @@ -130,3 +131,74 @@ def test_s3_upload_fobj(tmp_dir, dvc, s3): tree.upload_fobj(stream, to_info, 1) assert to_info.read_text() == "foo" + + +KB = 1024 +MB = KB ** 2 +GB = KB ** 3 + + +def test_s3_aws_config(tmp_dir, dvc, s3, monkeypatch): + config_file = tmp_dir / "aws_config.ini" + config_file.write_text( + textwrap.dedent( + """\ + [default] + s3 = + max_concurrent_requests = 20000 + max_queue_size = 1000 + multipart_threshold = 1000KiB + multipart_chunksize = 64MB + use_accelerate_endpoint = true + addressing_style = path + """ + ) + ) + monkeypatch.setenv("AWS_CONFIG_FILE", config_file) + + tree = S3Tree(dvc, s3.config) + assert tree._transfer_config is None + + with tree._get_s3() as s3: + s3_config = s3.meta.client.meta.config.s3 + assert s3_config["use_accelerate_endpoint"] + assert s3_config["addressing_style"] == "path" + + transfer_config = tree._transfer_config + assert transfer_config.max_io_queue_size == 1000 + assert transfer_config.multipart_chunksize == 64 * MB + assert transfer_config.multipart_threshold == 1000 * KB + assert transfer_config.max_request_concurrency == 20000 + + +def test_s3_aws_config_different_profile(tmp_dir, dvc, s3, monkeypatch): + config_file = tmp_dir / "aws_config.ini" + config_file.write_text( + textwrap.dedent( + """\ + [default] + extra = keys + s3 = + addressing_style = auto + use_accelerate_endpoint = true + multipart_threshold = ThisIsNotGoingToBeCasted! + [profile dev] + some_extra = keys + s3 = + addresing_style = virtual + multipart_threshold = 2GiB + """ + ) + ) + monkeypatch.setenv("AWS_CONFIG_FILE", config_file) + + tree = S3Tree(dvc, {**s3.config, "profile": "dev"}) + assert tree._transfer_config is None + + with tree._get_s3() as s3: + s3_config = s3.meta.client.meta.config.s3 + assert s3_config["addresing_style"] == "virtual" + assert "use_accelerate_endpoint" not in s3_config + + transfer_config = tree._transfer_config + assert transfer_config.multipart_threshold == 2 * GB diff --git a/tests/unit/utils/test_conversions.py b/tests/unit/utils/test_conversions.py new file mode 100644 index 0000000000..063dd45565 --- /dev/null +++ b/tests/unit/utils/test_conversions.py @@ -0,0 +1,30 @@ +import pytest + +from dvc.utils.conversions import human_readable_to_bytes + +KB = 1024 +MB = KB ** 2 +GB = KB ** 3 +TB = KB ** 4 + + +@pytest.mark.parametrize( + "test_input, expected", + [ + ("10", 10), + ("10 ", 10), + ("1kb", 1 * KB), + ("2kb", 2 * KB), + ("1000mib", 1000 * MB), + ("20gB", 20 * GB), + ("10Tib", 10 * TB), + ], +) +def test_conversions_human_readable_to_bytes(test_input, expected): + assert human_readable_to_bytes(test_input) == expected + + +@pytest.mark.parametrize("invalid_input", ["foo", "10XB", "1000Pb", "fooMiB"]) +def test_conversions_human_readable_to_bytes_invalid(invalid_input): + with pytest.raises(ValueError): + human_readable_to_bytes(invalid_input) From 8b0895181da4fba1ecd302896a88bad800eb4782 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Mon, 1 Feb 2021 18:48:09 +0300 Subject: [PATCH 2/3] skip when the ~/.aws/config doesn't exist --- dvc/tree/s3.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dvc/tree/s3.py b/dvc/tree/s3.py index cdd22d1734..99094d0526 100644 --- a/dvc/tree/s3.py +++ b/dvc/tree/s3.py @@ -98,9 +98,11 @@ def _process_config(self): from boto3.s3.transfer import TransferConfig from botocore.configloader import load_config - config = load_config( - os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH) - ) + config_path = os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH) + if not os.path.exists(config_path): + return None + + config = load_config(config_path) profile = config["profiles"].get(self.profile or "default") if not profile: return None From ab57275b4fa7798a682e948ef52cbb15d30c7ad1 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 4 Feb 2021 15:32:10 +0300 Subject: [PATCH 3/3] Expand windows paths properly --- dvc/tree/s3.py | 2 +- tests/func/test_s3.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/dvc/tree/s3.py b/dvc/tree/s3.py index 99094d0526..fa7cc19f7e 100644 --- a/dvc/tree/s3.py +++ b/dvc/tree/s3.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -_AWS_CONFIG_PATH = os.path.expanduser("~/.aws/config") +_AWS_CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".aws", "config") class S3Tree(BaseTree): diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index 0c002e3bc6..b148be84f5 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -1,3 +1,5 @@ +import importlib +import sys import textwrap from functools import wraps @@ -139,8 +141,9 @@ def test_s3_upload_fobj(tmp_dir, dvc, s3): def test_s3_aws_config(tmp_dir, dvc, s3, monkeypatch): - config_file = tmp_dir / "aws_config.ini" - config_file.write_text( + config_directory = tmp_dir / ".aws" + config_directory.mkdir() + (config_directory / "config").write_text( textwrap.dedent( """\ [default] @@ -154,9 +157,16 @@ def test_s3_aws_config(tmp_dir, dvc, s3, monkeypatch): """ ) ) - monkeypatch.setenv("AWS_CONFIG_FILE", config_file) - tree = S3Tree(dvc, s3.config) + if sys.platform == "win32": + var = "USERPROFILE" + else: + var = "HOME" + monkeypatch.setenv(var, str(tmp_dir)) + + # Fresh import to see the effects of changing HOME variable + s3_mod = importlib.reload(sys.modules[S3Tree.__module__]) + tree = s3_mod.S3Tree(dvc, s3.config) assert tree._transfer_config is None with tree._get_s3() as s3: