From 6bd9d7625d65b14a9c09a9dafa0cf7b9b6263e4f Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Sun, 30 Nov 2025 23:32:07 -0800 Subject: [PATCH] =?UTF-8?q?Assume=20the=20returned=20value=20in=20`.filter?= =?UTF-8?q?(=E2=80=A6).count()`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Similar to how this helps in `slice::Iter::position`, LLVM sometimes loses track of how high this can get, so for `TrustedLen` iterators tell it what the upper bound is. --- library/core/src/iter/adapters/filter.rs | 41 ++++++++++++++++++- .../iter-filter-count-assume.rs | 34 +++++++++++++++ .../iter-filter-count-debug-check.rs | 34 +++++++++++++++ 3 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs create mode 100644 tests/ui/iterators/iter-filter-count-debug-check.rs diff --git a/library/core/src/iter/adapters/filter.rs b/library/core/src/iter/adapters/filter.rs index dd08cd6f61c4c..b22419ccf080a 100644 --- a/library/core/src/iter/adapters/filter.rs +++ b/library/core/src/iter/adapters/filter.rs @@ -4,7 +4,7 @@ use core::ops::ControlFlow; use crate::fmt; use crate::iter::adapters::SourceIter; -use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused}; +use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen}; use crate::num::NonZero; use crate::ops::Try; @@ -138,7 +138,13 @@ where move |x| predicate(&x) as usize } - self.iter.map(to_usize(self.predicate)).sum() + let before = self.iter.size_hint().1.unwrap_or(usize::MAX); + let total = self.iter.map(to_usize(self.predicate)).sum(); + // SAFETY: `total` and `before` came from the same iterator of type `I` + unsafe { + ::assume_count_le_upper_bound(total, before); + } + total } #[inline] @@ -214,3 +220,34 @@ unsafe impl InPlaceIterable for Filter { const EXPAND_BY: Option> = I::EXPAND_BY; const MERGE_BY: Option> = I::MERGE_BY; } + +trait SpecAssumeCount { + /// # Safety + /// + /// `count` must be an number of items actually read from the iterator. + /// + /// `upper` must either: + /// - have come from `size_hint().1` on the iterator, or + /// - be `usize::MAX` which will vacuously do nothing. + unsafe fn assume_count_le_upper_bound(count: usize, upper: usize); +} + +impl SpecAssumeCount for I { + #[inline] + #[rustc_inherit_overflow_checks] + default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) { + // In the default we can't trust the `upper` for soundness + // because it came from an untrusted `size_hint`. + + // In debug mode we might as well check that the size_hint wasn't too small + let _ = upper - count; + } +} + +impl SpecAssumeCount for I { + #[inline] + unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) { + // SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator. + unsafe { crate::hint::assert_unchecked(count <= upper) } + } +} diff --git a/tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs b/tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs new file mode 100644 index 0000000000000..7588b23efc3af --- /dev/null +++ b/tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs @@ -0,0 +1,34 @@ +//@ compile-flags: -Copt-level=3 +//@ edition: 2024 + +#![crate_type = "lib"] + +// Similar to how we `assume` that `slice::Iter::position` is within the length, +// check that `count` also does that for `TrustedLen` iterators. +// See https://rust-lang.zulipchat.com/#narrow/channel/122651-general/topic/Overflow-chk.20removed.20for.20array.20of.2059.2C.20but.20not.2060.2C.20elems/with/561070780 + +// CHECK-LABEL: @filter_count_untrusted +#[unsafe(no_mangle)] +pub fn filter_count_untrusted(bar: &[u8; 1234]) -> u16 { + // CHECK-NOT: llvm.assume + // CHECK: call void @{{.+}}unwrap_failed + // CHECK-NOT: llvm.assume + let mut iter = bar.iter(); + let iter = std::iter::from_fn(|| iter.next()); // Make it not TrustedLen + u16::try_from(iter.filter(|v| **v == 0).count()).unwrap() +} + +// CHECK-LABEL: @filter_count_trusted +#[unsafe(no_mangle)] +pub fn filter_count_trusted(bar: &[u8; 1234]) -> u16 { + // CHECK-NOT: unwrap_failed + // CHECK: %[[ASSUME:.+]] = icmp ult {{i64|i32|i16}} %{{.+}}, 1235 + // CHECK-NEXT: tail call void @llvm.assume(i1 %[[ASSUME]]) + // CHECK-NOT: unwrap_failed + let iter = bar.iter(); + u16::try_from(iter.filter(|v| **v == 0).count()).unwrap() +} + +// CHECK: ; core::result::unwrap_failed +// CHECK-NEXT: Function Attrs +// CHECK-NEXT: declare{{.+}}void @{{.+}}unwrap_failed diff --git a/tests/ui/iterators/iter-filter-count-debug-check.rs b/tests/ui/iterators/iter-filter-count-debug-check.rs new file mode 100644 index 0000000000000..6e3a3f73920e4 --- /dev/null +++ b/tests/ui/iterators/iter-filter-count-debug-check.rs @@ -0,0 +1,34 @@ +//@ run-pass +//@ needs-unwind +//@ ignore-backends: gcc +//@ compile-flags: -C overflow-checks + +use std::panic; + +struct Lies(usize); + +impl Iterator for Lies { + type Item = usize; + + fn next(&mut self) -> Option { + if self.0 == 0 { + None + } else { + self.0 -= 1; + Some(self.0) + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(2)) + } +} + +fn main() { + let r = panic::catch_unwind(|| { + // This returns more items than its `size_hint` said was possible, + // which `Filter::count` detects via `overflow-checks`. + let _ = Lies(10).filter(|&x| x > 3).count(); + }); + assert!(r.is_err()); +}