Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[utils] traverse_obj: Allow iterables in traversal #6902

Merged
merged 2 commits into from Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/test_utils.py
Expand Up @@ -2016,13 +2016,17 @@ def test_traverse_obj(self):
msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(range(4), ...), list(range(4)),
msg='`...` should accept iterables')

# Test function as key
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
[_TEST_DATA['urls']],
msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
msg='exceptions in the query function should be catched')
self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables')
if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a: ...)
Expand Down
7 changes: 3 additions & 4 deletions yt_dlp/utils.py
Expand Up @@ -5528,7 +5528,6 @@ def traverse_obj(
If no `default` is given and the last path branches, a `list` of results
is always returned. If a path ends on a `dict` that result will always be a `dict`.
"""
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: k.casefold() if isinstance(k, str) else k

if isinstance(expected_type, type):
Expand Down Expand Up @@ -5564,7 +5563,7 @@ def apply_key(key, obj, is_last):
branching = True
if isinstance(obj, collections.abc.Mapping):
result = obj.values()
elif is_sequence(obj):
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
result = obj
elif isinstance(obj, re.Match):
result = obj.groups()
Expand All @@ -5578,7 +5577,7 @@ def apply_key(key, obj, is_last):
branching = True
if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
elif is_sequence(obj):
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
iter_obj = enumerate(obj)
elif isinstance(obj, re.Match):
iter_obj = itertools.chain(
Expand Down Expand Up @@ -5614,7 +5613,7 @@ def apply_key(key, obj, is_last):
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

elif isinstance(key, (int, slice)):
if is_sequence(obj):
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
branching = isinstance(key, slice)
with contextlib.suppress(IndexError):
result = obj[key]
Expand Down