Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions servicex/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
ignore_cache_opt = typer.Option(
None, "--ignore-cache", help="Ignore local cache and always submit to ServiceX"
)
hide_results_opt = typer.Option(
False,
"--hide-results",
help="Exclude printing results to the console",
)


def show_version(show: bool):
Expand All @@ -75,19 +80,20 @@ def deliver(
config_path: Optional[str] = config_file_option,
spec_file: str = spec_file_arg,
ignore_cache: Optional[bool] = ignore_cache_opt,
hide_results: bool = hide_results_opt,
):
"""
Deliver a file to the ServiceX cache.
"""

print(f"Delivering {spec_file} to ServiceX cache")
results = servicex_client.deliver(
servicex_client.deliver(
spec_file,
servicex_name=backend,
config_path=config_path,
ignore_local_cache=ignore_cache,
display_results=not hide_results,
)
rich.print(results)


if __name__ == "__main__":
Expand Down
84 changes: 73 additions & 11 deletions servicex/servicex_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from collections.abc import Sequence, Coroutine
from enum import Enum
import traceback
from rich.table import Table

T = TypeVar("T")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -233,6 +234,58 @@ def _output_handler(
return out_dict


def _get_progress_options(progress_bar: ProgressBarFormat) -> dict:
"""Get progress options based on progress bar format."""
if progress_bar == ProgressBarFormat.expanded:
return {}
elif progress_bar == ProgressBarFormat.compact:
return {"overall_progress": True}
elif progress_bar == ProgressBarFormat.none:
return {"display_progress": False}
else:
raise ValueError(f"Invalid value {progress_bar} for progress_bar provided")


def _display_results(out_dict):
"""Display the delivery results using rich styling."""
from rich import get_console

console = get_console()

console.print("\n[bold green]✓ ServiceX Delivery Complete![/bold green]\n")

table = Table(
title="Delivered Files", show_header=True, header_style="bold magenta"
)
table.add_column("Sample", style="cyan", no_wrap=True)
table.add_column("File Count", justify="right", style="green")
table.add_column("Files", style="dim")

total_files = 0
for sample_name, files in out_dict.items():
if isinstance(files, GuardList) and files.valid():
file_list = list(files)
file_count = len(file_list)
total_files += file_count

# Show first few files with ellipsis if many
if file_count <= 3:
files_display = "\n".join(str(f) for f in file_list)
else:
files_display = "\n".join(str(f) for f in file_list[:2])
files_display += f"\n... and {file_count - 2} more files"

table.add_row(sample_name, str(file_count), files_display)
else:
# Handle error case
table.add_row(
sample_name, "[red]Error[/red]", "[red]Failed to retrieve files[/red]"
)

console.print(table)
console.print(f"\n[bold blue]Total files delivered: {total_files}[/bold blue]\n")


async def deliver_async(
spec: Union[ServiceXSpec, Mapping[str, Any], str, Path],
config_path: Optional[str] = None,
Expand All @@ -241,6 +294,7 @@ async def deliver_async(
fail_if_incomplete: bool = True,
ignore_local_cache: bool = False,
progress_bar: ProgressBarFormat = ProgressBarFormat.default,
display_results: bool = True,
concurrency: int = 10,
):
r"""
Expand All @@ -263,6 +317,8 @@ async def deliver_async(
will have its own progress bars; :py:const:`ProgressBarFormat.compact` gives one
summary progress bar for all transformations; :py:const:`ProgressBarFormat.none`
switches off progress bars completely.
:param display_results: Specifies whether the results should be displayed to the console.
Defaults to True.
:param concurrency: specify how many downloads to run in parallel (default is 8).
:return: A dictionary mapping the name of each :py:class:`Sample` to a :py:class:`.GuardList`
with the file names or URLs for the outputs.
Expand All @@ -282,26 +338,32 @@ async def deliver_async(

group = DatasetGroup(datasets)

if progress_bar == ProgressBarFormat.expanded:
progress_options = {}
elif progress_bar == ProgressBarFormat.compact:
progress_options = {"overall_progress": True}
elif progress_bar == ProgressBarFormat.none:
progress_options = {"display_progress": False}
else:
raise ValueError(f"Invalid value {progress_bar} for progress_bar provided")
progress_options = _get_progress_options(progress_bar)

if config.General.Delivery not in [
General.DeliveryEnum.URLs,
General.DeliveryEnum.LocalCache,
]:
raise ValueError(
f"unexpected value for config.general.Delivery: {config.General.Delivery}"
)

if config.General.Delivery == General.DeliveryEnum.URLs:
results = await group.as_signed_urls_async(
return_exceptions=return_exceptions, **progress_options
)
return _output_handler(config, datasets, results)

elif config.General.Delivery == General.DeliveryEnum.LocalCache:
else:
results = await group.as_files_async(
return_exceptions=return_exceptions, **progress_options
)
return _output_handler(config, datasets, results)

output_dict = _output_handler(config, datasets, results)

if display_results:
_display_results(output_dict)

return output_dict


deliver = make_sync(deliver_async)
Expand Down
30 changes: 28 additions & 2 deletions tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from unittest.mock import Mock, patch


Expand All @@ -45,6 +46,31 @@ def test_deliver(script_runner):
assert result.returncode == 0
result_rows = result.stdout.split("\n")
assert result_rows[0] == "Delivering foo.yaml to ServiceX cache"
assert (
result_rows[1] == "{'UprootRaw_YAML': ['/tmp/foo.root', '/tmp/bar.root']}"
mock_servicex_client.deliver.assert_called_once_with(
"foo.yaml",
servicex_name=None,
config_path=None,
ignore_local_cache=None,
display_results=True,
)


def test_deliver_hide_results(script_runner):
with patch("servicex.app.main.servicex_client") as mock_servicex_client:
mock_servicex_client.deliver = Mock(
return_value={"UprootRaw_YAML": ["/tmp/foo.root", "/tmp/bar.root"]}
)
result = script_runner.run(
["servicex", "deliver", "foo.yaml", "--hide-results"]
)
assert result.returncode == 0
result_rows = result.stdout.split("\n")
assert result_rows[0] == "Delivering foo.yaml to ServiceX cache"
# Verify that servicex_client.deliver was called with display_results=False
mock_servicex_client.deliver.assert_called_once_with(
"foo.yaml",
servicex_name=None,
config_path=None,
ignore_local_cache=None,
display_results=False,
)
92 changes: 92 additions & 0 deletions tests/test_display_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2022, IRIS-HEP
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from unittest.mock import Mock, patch
import pytest

from servicex.servicex_client import (
GuardList,
_display_results,
_get_progress_options,
ProgressBarFormat,
ServiceXClient,
)


def test_display_results_basic():
"""Test _display_results basic functionality - covers most code paths."""
# Test with valid files (covers main path)
valid_files = GuardList(["/tmp/file1.root", "/tmp/file2.root"])
# Test with error case (covers error path)
error_files = GuardList(ValueError("Test error"))

out_dict = {"ValidSample": valid_files, "ErrorSample": error_files}

with patch("rich.get_console") as mock_get_console:
mock_console = Mock()
mock_get_console.return_value = mock_console

with patch("servicex.servicex_client.Table"):
_display_results(out_dict)

# Just verify it was called - don't over-test internal details
mock_get_console.assert_called_once()
assert mock_console.print.call_count >= 2 # At least completion + total


def test_essential_valueerrors():
"""Test the most important ValueError cases in one simple test."""
# Test progress options
assert _get_progress_options(ProgressBarFormat.expanded) == {}
with pytest.raises(ValueError, match="Invalid value"):
_get_progress_options("invalid")

# Test ServiceX client errors - simplest possible
with pytest.raises(ValueError, match="Only specify backend or url"):
with patch("servicex.servicex_client.Configuration") as mock_config_class:
mock_config = Mock()
mock_config.endpoint_dict.return_value = {}
mock_config.default_endpoint = None
mock_config_class.read.return_value = mock_config
ServiceXClient(backend="test", url="http://test.com")


def test_guardlist_basics():
"""Test GuardList basic functionality."""
# Valid case
valid_list = GuardList([1, 2, 3])
assert len(valid_list) == 3
assert valid_list[0] == 1
assert valid_list.valid()

# Error case
from servicex.servicex_client import ReturnValueException

error_list = GuardList(ValueError("error"))
assert not error_list.valid()
with pytest.raises(ReturnValueException):
_ = error_list[0]
64 changes: 64 additions & 0 deletions tests/test_servicex_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,67 @@ def test_invalid_backend_raises_error_with_filename():
ServiceXClient(backend="badname", config_path=config_file)

assert f"Backend badname not defined in {expected} file" in str(err.value)


def test_display_results_with_many_files():
from servicex.servicex_client import _display_results, GuardList
from unittest.mock import patch, MagicMock

# Mock GuardList with more than 3 files to trigger lines 275-276
mock_guard_list = MagicMock(spec=GuardList)
mock_guard_list.valid.return_value = True
mock_guard_list.__iter__.return_value = iter(
[
"file1.parquet",
"file2.parquet",
"file3.parquet",
"file4.parquet",
"file5.parquet",
]
)

out_dict = {"sample1": mock_guard_list}

with patch("rich.get_console") as mock_get_console:
mock_console = MagicMock()
mock_get_console.return_value = mock_console

with patch("servicex.servicex_client.Table") as mock_table:
mock_table_instance = MagicMock()
mock_table.return_value = mock_table_instance

_display_results(out_dict)

# Verify that add_row was called with the truncated file list
mock_table_instance.add_row.assert_called_once()
call_args = mock_table_instance.add_row.call_args[0]
assert call_args[0] == "sample1"
assert call_args[1] == "5"
assert "... and 3 more files" in call_args[2]


@pytest.mark.asyncio
async def test_deliver_async_invalid_delivery_config():
from servicex.servicex_client import deliver_async
from unittest.mock import patch, MagicMock

# Mock the config loading to return invalid delivery type
with patch("servicex.servicex_client._load_ServiceXSpec") as mock_load_spec:
with patch("servicex.servicex_client._build_datasets") as mock_build_datasets:
with patch("servicex.minio_adapter.init_s3_config"):
mock_config = MagicMock()
mock_config.General.Delivery = (
"INVALID_DELIVERY" # Invalid delivery type
)
mock_config.General.IgnoreLocalCache = False
mock_config.Sample = []
mock_load_spec.return_value = mock_config
mock_build_datasets.return_value = []

with pytest.raises(ValueError) as exc_info:
await deliver_async("test_spec.yaml")

assert (
"unexpected value for config.general.Delivery: INVALID_DELIVERY"
in str(exc_info.value)
)
Loading