Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions library/core/src/iter/adapters/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use core::ops::ControlFlow;

use crate::fmt;
use crate::iter::adapters::SourceIter;
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused};
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
use crate::num::NonZero;
use crate::ops::Try;

Expand Down Expand Up @@ -138,7 +138,13 @@ where
move |x| predicate(&x) as usize
}

self.iter.map(to_usize(self.predicate)).sum()
let before = self.iter.size_hint().1.unwrap_or(usize::MAX);
let total = self.iter.map(to_usize(self.predicate)).sum();
// SAFETY: `total` and `before` came from the same iterator of type `I`
unsafe {
<I as SpecAssumeCount>::assume_count_le_upper_bound(total, before);
}
total
}

#[inline]
Expand Down Expand Up @@ -214,3 +220,34 @@ unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
}

trait SpecAssumeCount {
/// # Safety
///
/// `count` must be an number of items actually read from the iterator.
///
/// `upper` must either:
/// - have come from `size_hint().1` on the iterator, or
/// - be `usize::MAX` which will vacuously do nothing.
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize);
}

impl<I: Iterator> SpecAssumeCount for I {
#[inline]
#[rustc_inherit_overflow_checks]
default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
// In the default we can't trust the `upper` for soundness
// because it came from an untrusted `size_hint`.

// In debug mode we might as well check that the size_hint wasn't too small
let _ = upper - count;
}
}

impl<I: TrustedLen> SpecAssumeCount for I {
#[inline]
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
// SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator.
unsafe { crate::hint::assert_unchecked(count <= upper) }
}
}
34 changes: 34 additions & 0 deletions tests/codegen-llvm/lib-optimizations/iter-filter-count-assume.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//@ compile-flags: -Copt-level=3
//@ edition: 2024

#![crate_type = "lib"]

// Similar to how we `assume` that `slice::Iter::position` is within the length,
// check that `count` also does that for `TrustedLen` iterators.
// See https://rust-lang.zulipchat.com/#narrow/channel/122651-general/topic/Overflow-chk.20removed.20for.20array.20of.2059.2C.20but.20not.2060.2C.20elems/with/561070780

// CHECK-LABEL: @filter_count_untrusted
#[unsafe(no_mangle)]
pub fn filter_count_untrusted(bar: &[u8; 1234]) -> u16 {
// CHECK-NOT: llvm.assume
// CHECK: call void @{{.+}}unwrap_failed
// CHECK-NOT: llvm.assume
let mut iter = bar.iter();
let iter = std::iter::from_fn(|| iter.next()); // Make it not TrustedLen
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
}

// CHECK-LABEL: @filter_count_trusted
#[unsafe(no_mangle)]
pub fn filter_count_trusted(bar: &[u8; 1234]) -> u16 {
// CHECK-NOT: unwrap_failed
// CHECK: %[[ASSUME:.+]] = icmp ult {{i64|i32|i16}} %{{.+}}, 1235
// CHECK-NEXT: tail call void @llvm.assume(i1 %[[ASSUME]])
// CHECK-NOT: unwrap_failed
let iter = bar.iter();
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
}

// CHECK: ; core::result::unwrap_failed
// CHECK-NEXT: Function Attrs
// CHECK-NEXT: declare{{.+}}void @{{.+}}unwrap_failed
34 changes: 34 additions & 0 deletions tests/ui/iterators/iter-filter-count-debug-check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//@ run-pass
//@ needs-unwind
//@ ignore-backends: gcc
//@ compile-flags: -C overflow-checks

use std::panic;

struct Lies(usize);

impl Iterator for Lies {
type Item = usize;

fn next(&mut self) -> Option<usize> {
if self.0 == 0 {
None
} else {
self.0 -= 1;
Some(self.0)
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(2))
}
}

fn main() {
let r = panic::catch_unwind(|| {
// This returns more items than its `size_hint` said was possible,
// which `Filter::count` detects via `overflow-checks`.
let _ = Lies(10).filter(|&x| x > 3).count();
});
assert!(r.is_err());
}
Loading