Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ dependencies:
- zlib=1.2.11=h166bdaf_1014
- zstd=1.5.2=ha95c52a_0
- pip:
- acres
- bids-validator==1.9.3
- docopt==0.6.2
- formulaic==0.3.4
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ classifiers = [
license = {file = "LICENSE"}
requires-python = ">=3.9"
dependencies = [
"acres >= 0.5.0",
"pybids >= 0.15.2",
"importlib_resources >= 5.7; python_version < '3.11'",
"requests",
"tqdm",
]
Expand Down
193 changes: 0 additions & 193 deletions templateflow/_loader.py

This file was deleted.

22 changes: 22 additions & 0 deletions templateflow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def get(template, raise_empty=False, **kwargs):
if raise_empty and not out_file:
raise Exception('No results found')

# Truncate possible S3 error files from previous attempts
_truncate_s3_errors(out_file)

# Try DataLad first
dl_missing = [p for p in out_file if not p.is_file()]
if TF_USE_DATALAD and dl_missing:
Expand Down Expand Up @@ -320,6 +323,8 @@ def _s3_get(filepath):
print(f'Downloading {url}', file=stderr)
# Streaming, so we can iterate over the response.
r = requests.get(url, stream=True, timeout=TF_GET_TIMEOUT)
if r.status_code != 200:
raise RuntimeError(f'Failed to download {url} with status code {r.status_code}')

# Total size in bytes.
total_size = int(r.headers.get('content-length', 0))
Expand Down Expand Up @@ -393,3 +398,20 @@ def _normalize_ext(value):
if isinstance(value, str):
return f'{"" if value.startswith(".") else "."}{value}'
return [_normalize_ext(v) for v in value]


def _truncate_s3_errors(filepaths):
"""
Truncate XML error bodies saved by previous versions of TemplateFlow.

Parameters
----------
filepaths : list of Path
List of file paths to check and truncate if necessary.
"""
for filepath in filepaths:
if filepath.is_file(follow_symlinks=False) and 0 < filepath.stat().st_size < 1024:
with open(filepath, 'rb') as f:
content = f.read(100)
if content.startswith(b'<?xml') and b'<Error><Code>' in content:
filepath.write_bytes(b'') # Truncate file to zero bytes
4 changes: 2 additions & 2 deletions templateflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from click.decorators import FC, Option, _param_memo

from templateflow import __package__, api
from templateflow._loader import Loader as _Loader
from acres import Loader as _Loader
from templateflow.conf import TF_AUTOUPDATE, TF_HOME, TF_USE_DATALAD

load_data = _Loader(__package__)
Expand All @@ -58,7 +58,7 @@ def _nulls(s):
def entity_opts():
"""Attaches all entities as options to the command."""

entities = json.loads(Path(load_data('conf/config.json')).read_text())['entities']
entities = json.loads(load_data('conf/config.json').read_text())['entities']

args = [
(f'--{e["name"]}', *ENTITY_SHORTHANDS.get(e['name'], ()))
Expand Down
4 changes: 2 additions & 2 deletions templateflow/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from pathlib import Path
from warnings import warn

from .._loader import Loader
from acres import Loader

load_data = Loader(__package__)
load_data = Loader(__spec__.name)


def _env_to_bool(envvar: str, default: bool) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions templateflow/tests/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from acres import Loader

load_data = Loader(__spec__.name)
2 changes: 2 additions & 0 deletions templateflow/tests/data/error_response.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
<?xml version="1.0" encoding="UTF-8"?>
<Error><Code>InvalidArgument</Code><Message>Invalid version id specified</Message><ArgumentName>versionId</ArgumentName><ArgumentValue>test</ArgumentValue><RequestId>BKT12MP069SFQGH3</RequestId><HostId>DIljS3MUsLCEa27wSyqAxsZZE3MhqEWYf3lRbH2Rl19VV0pGe/61Hh3MzSBeS45VltnZDzliHaTMxjnGPvKOOk+SY/it3Ond</HostId></Error>
62 changes: 62 additions & 0 deletions templateflow/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@
from importlib import reload
from pathlib import Path

import pytest
import requests

from templateflow import api as tf
from templateflow import conf as tfc

from .data import load_data


def test_get_skel_file(tmp_path, monkeypatch):
"""Exercise the skeleton file generation."""
Expand Down Expand Up @@ -87,3 +91,61 @@ def test_update_s3(tmp_path, monkeypatch):
for p in (newhome / 'tpl-MNI152NLin6Sym').glob('*.nii.gz'):
p.unlink()
assert tfc._s3.update(newhome, local=False, overwrite=False)


def mock_get(*args, **kwargs):
class MockResponse:
status_code = 400

return MockResponse()


def test_s3_400_error(monkeypatch):
"""Simulate a 400 error when fetching the skeleton file."""

reload(tfc)
reload(tf)

monkeypatch.setattr(requests, 'get', mock_get)
with pytest.raises(RuntimeError, match=r'Failed to download .* code 400'):
tf._s3_get(
Path(tfc.TF_LAYOUT.root)
/ 'tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz'
)


def test_bad_skeleton(tmp_path, monkeypatch):
newhome = (tmp_path / 's3-update').resolve()
monkeypatch.setattr(tfc, 'TF_USE_DATALAD', False)
monkeypatch.setattr(tfc, 'TF_HOME', newhome)
monkeypatch.setattr(tfc, 'TF_LAYOUT', None)

tfc._init_cache()
tfc.init_layout()

assert tfc.TF_LAYOUT is not None
assert tfc.TF_LAYOUT.root == str(newhome)

# Instead of reloading
monkeypatch.setattr(tf, 'TF_LAYOUT', tfc.TF_LAYOUT)

paths = tf.ls('MNI152NLin2009cAsym', resolution='02', suffix='T1w', desc=None)
assert paths
path = Path(paths[0])
assert path.read_bytes() == b''

error_file = load_data.readable('error_response.xml')
path.write_bytes(error_file.read_bytes())

# Test directly before testing through API paths
tf._truncate_s3_errors(paths)
assert path.read_bytes() == b''

path.write_bytes(error_file.read_bytes())

monkeypatch.setattr(requests, 'get', mock_get)
with pytest.raises(RuntimeError):
tf.get('MNI152NLin2009cAsym', resolution='02', suffix='T1w', desc=None)

# Running get clears bad files before attempting to download
assert path.read_bytes() == b''
Loading