diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 0fe81edcbb735b..e5a280e1949de9 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -1486,6 +1486,8 @@ def format_map(self, mapping): return self.data.format_map(mapping) def index(self, sub, start=0, end=_sys.maxsize): + if isinstance(sub, UserString): + sub = sub.data return self.data.index(sub, start, end) def isalpha(self): @@ -1554,6 +1556,8 @@ def rfind(self, sub, start=0, end=_sys.maxsize): return self.data.rfind(sub, start, end) def rindex(self, sub, start=0, end=_sys.maxsize): + if isinstance(sub, UserString): + sub = sub.data return self.data.rindex(sub, start, end) def rjust(self, width, *args): diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 8dcfa98f85b075..ed405472574a93 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -90,6 +90,18 @@ def checkcall(self, obj, methodname, *args): args = self.fixtype(args) getattr(obj, methodname)(*args) + def _get_teststrings(self, charset, digits): + base = len(charset) + teststrings = set() + for i in range(base ** digits): + entry = [] + for j in range(digits): + i, m = divmod(i, base) + entry.append(charset[m]) + teststrings.add(''.join(entry)) + teststrings = [self.fixtype(ts) for ts in teststrings] + return teststrings + def test_count(self): self.checkequal(3, 'aaa', 'count', 'a') self.checkequal(0, 'aaa', 'count', 'b') @@ -130,17 +142,7 @@ def test_count(self): # For a variety of combinations, # verify that str.count() matches an equivalent function # replacing all occurrences and then differencing the string lengths - charset = ['', 'a', 'b'] - digits = 7 - base = len(charset) - teststrings = set() - for i in range(base ** digits): - entry = [] - for j in range(digits): - i, m = divmod(i, base) - entry.append(charset[m]) - teststrings.add(''.join(entry)) - teststrings = [self.fixtype(ts) for ts in teststrings] + teststrings = self._get_teststrings(['', 'a', 'b'], 7) for i in teststrings: n = len(i) for j in teststrings: @@ -197,17 +199,7 @@ def test_find(self): # For a variety of combinations, # verify that str.find() matches __contains__ # and that the found substring is really at that location - charset = ['', 'a', 'b', 'c'] - digits = 5 - base = len(charset) - teststrings = set() - for i in range(base ** digits): - entry = [] - for j in range(digits): - i, m = divmod(i, base) - entry.append(charset[m]) - teststrings.add(''.join(entry)) - teststrings = [self.fixtype(ts) for ts in teststrings] + teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5) for i in teststrings: for j in teststrings: loc = i.find(j) @@ -244,17 +236,7 @@ def test_rfind(self): # For a variety of combinations, # verify that str.rfind() matches __contains__ # and that the found substring is really at that location - charset = ['', 'a', 'b', 'c'] - digits = 5 - base = len(charset) - teststrings = set() - for i in range(base ** digits): - entry = [] - for j in range(digits): - i, m = divmod(i, base) - entry.append(charset[m]) - teststrings.add(''.join(entry)) - teststrings = [self.fixtype(ts) for ts in teststrings] + teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5) for i in teststrings: for j in teststrings: loc = i.rfind(j) @@ -295,6 +277,19 @@ def test_index(self): else: self.checkraises(TypeError, 'hello', 'index', 42) + # For a variety of combinations, + # verify that str.index() matches __contains__ + # and that the found substring is really at that location + teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5) + for i in teststrings: + for j in teststrings: + if j in i: + loc = i.index(j) + self.assertGreaterEqual(loc, 0) + self.assertEqual(i[loc:loc+len(j)], j) + else: + self.assertRaises(ValueError, i.index, j) + def test_rindex(self): self.checkequal(12, 'abcdefghiabc', 'rindex', '') self.checkequal(3, 'abcdefghiabc', 'rindex', 'def') @@ -321,6 +316,19 @@ def test_rindex(self): else: self.checkraises(TypeError, 'hello', 'rindex', 42) + # For a variety of combinations, + # verify that str.rindex() matches __contains__ + # and that the found substring is really at that location + teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5) + for i in teststrings: + for j in teststrings: + if j in i: + loc = i.rindex(j) + self.assertGreaterEqual(loc, 0) + self.assertEqual(i[loc:loc+len(j)], j) + else: + self.assertRaises(ValueError, i.rindex, j) + def test_find_periodic_pattern(self): """Cover the special path for periodic patterns.""" def reference_find(p, s): diff --git a/Misc/NEWS.d/next/Library/2025-11-03-17-13-00.gh-issue-140911.7KFvSQ.rst b/Misc/NEWS.d/next/Library/2025-11-03-17-13-00.gh-issue-140911.7KFvSQ.rst new file mode 100644 index 00000000000000..b0b6e4611924c2 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-11-03-17-13-00.gh-issue-140911.7KFvSQ.rst @@ -0,0 +1,3 @@ +:mod:`collections`: Ensure that the methods ``UserString.rindex()`` and +``UserString.index()`` accept :class:`collections.UserString` instances as the +sub argument.