From 66a0338738f82acc0c59b4858240ae786e1a566a Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Mon, 24 Apr 2023 15:57:35 +0200 Subject: [PATCH 1/2] Allow iterables in traversal --- test/test_utils.py | 4 ++++ yt_dlp/utils.py | 7 +++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index d4a301583f0..2bb8e5e650b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2016,6 +2016,8 @@ def test_traverse_obj(self): msg='nested `...` queries should work') self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), msg='`...` query result should be flattened') + self.assertEqual(traverse_obj(range(4), ...), range(4), + msg='`...` should accept iterables') # Test function as key self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), @@ -2023,6 +2025,8 @@ def test_traverse_obj(self): msg='function as query key should perform a filter based on (key, value)') self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, msg='exceptions in the query function should be catched') + self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2], + msg='function key should accept iterables') if __debug__: with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): traverse_obj(_TEST_DATA, lambda a: ...) diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 746a2885d63..f69311462da 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5528,7 +5528,6 @@ def traverse_obj( If no `default` is given and the last path branches, a `list` of results is always returned. If a path ends on a `dict` that result will always be a `dict`. """ - is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes)) casefold = lambda k: k.casefold() if isinstance(k, str) else k if isinstance(expected_type, type): @@ -5564,7 +5563,7 @@ def apply_key(key, obj, is_last): branching = True if isinstance(obj, collections.abc.Mapping): result = obj.values() - elif is_sequence(obj): + elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): result = obj elif isinstance(obj, re.Match): result = obj.groups() @@ -5578,7 +5577,7 @@ def apply_key(key, obj, is_last): branching = True if isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() - elif is_sequence(obj): + elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): iter_obj = enumerate(obj) elif isinstance(obj, re.Match): iter_obj = itertools.chain( @@ -5614,7 +5613,7 @@ def apply_key(key, obj, is_last): result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) elif isinstance(key, (int, slice)): - if is_sequence(obj): + if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)): branching = isinstance(key, slice) with contextlib.suppress(IndexError): result = obj[key] From 119a53ec1c3b563bb19a3bb0b80081ce1613fbac Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Mon, 24 Apr 2023 16:06:04 +0200 Subject: [PATCH 2/2] Fix test --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 2bb8e5e650b..f2f3b8170a2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2016,7 +2016,7 @@ def test_traverse_obj(self): msg='nested `...` queries should work') self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), msg='`...` query result should be flattened') - self.assertEqual(traverse_obj(range(4), ...), range(4), + self.assertEqual(traverse_obj(range(4), ...), list(range(4)), msg='`...` should accept iterables') # Test function as key