Skip to content

Commit

Permalink
Auto merge of #123572 - Mark-Simulacrum:vtable-methods, r=oli-obk
Browse files Browse the repository at this point in the history
Increase vtable layout size

This improves LLVM's codegen by allowing vtable loads to be hoisted out of loops (as just one example). The calculation here is an under-approximation but works for simple trait hierarchies (e.g., FnMut will be improved). We have a runtime assert that the approximation is accurate, so there's no risk of UB as a result of getting this wrong.

```rust
#[no_mangle]
pub fn foo(elements: &[u32], callback: &mut dyn Callback) {
    for element in elements.iter() {
        if *element != 0 {
            callback.call(*element);
        }
    }
}

pub trait Callback {
    fn call(&mut self, _: u32);
}
```

Simplifying a bit (e.g., numbering ends up different):

```diff
 ; Function Attrs: nonlazybind uwtable
-define void `@foo(ptr` noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(24) %callback.1) unnamed_addr #0 {
+define void `@foo(ptr` noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(32) %callback.1) unnamed_addr #0 {
 start:
   %_15 = getelementptr inbounds i32, ptr %elements.0, i64 %elements.1
`@@` -13,4 +13,5 `@@`
 bb4.lr.ph:                                        ; preds = %start
   %1 = getelementptr inbounds i8, ptr %callback.1, i64 24
+  %2 = load ptr, ptr %1, align 8, !nonnull !3
   br label %bb4

 bb6:                                              ; preds = %bb4
-  %4 = load ptr, ptr %1, align 8, !invariant.load !3, !nonnull !3
-  tail call void %4(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
+  tail call void %2(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
   br label %bb7
 }
```
  • Loading branch information
bors committed Jun 1, 2024
2 parents acaf0ae + 95e0732 commit f2208b3
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 127 deletions.
3 changes: 1 addition & 2 deletions compiler/rustc_hir_analysis/src/coherence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use rustc_middle::query::Providers;
use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
use rustc_session::parse::feature_err;
use rustc_span::{sym, ErrorGuaranteed};
use rustc_trait_selection::traits;

mod builtin;
mod inherent_impls;
Expand Down Expand Up @@ -199,7 +198,7 @@ fn check_object_overlap<'tcx>(
// With the feature enabled, the trait is not implemented automatically,
// so this is valid.
} else {
let mut supertrait_def_ids = traits::supertrait_def_ids(tcx, component_def_id);
let mut supertrait_def_ids = tcx.supertrait_def_ids(component_def_id);
if supertrait_def_ids.any(|d| d == trait_def_id) {
let span = tcx.def_span(impl_def_id);
return Err(struct_span_code_err!(
Expand Down
25 changes: 7 additions & 18 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,25 +827,14 @@ where
});
}

let mk_dyn_vtable = || {
let mk_dyn_vtable = |principal: Option<ty::PolyExistentialTraitRef<'tcx>>| {
let min_count = ty::vtable_min_entries(tcx, principal);
Ty::new_imm_ref(
tcx,
tcx.lifetimes.re_static,
Ty::new_array(tcx, tcx.types.usize, 3),
// FIXME: properly type (e.g. usize and fn pointers) the fields.
Ty::new_array(tcx, tcx.types.usize, min_count.try_into().unwrap()),
)
/* FIXME: use actual fn pointers
Warning: naively computing the number of entries in the
vtable by counting the methods on the trait + methods on
all parent traits does not work, because some methods can
be not object safe and thus excluded from the vtable.
Increase this counter if you tried to implement this but
failed to do it without duplicating a lot of code from
other places in the compiler: 2
Ty::new_tup(tcx,&[
Ty::new_array(tcx,tcx.types.usize, 3),
Ty::new_array(tcx,Option<fn()>),
])
*/
};

let metadata = if let Some(metadata_def_id) = tcx.lang_items().metadata_type()
Expand All @@ -864,16 +853,16 @@ where
// `std::mem::uninitialized::<&dyn Trait>()`, for example.
if let ty::Adt(def, args) = metadata.kind()
&& Some(def.did()) == tcx.lang_items().dyn_metadata()
&& args.type_at(0).is_trait()
&& let ty::Dynamic(data, _, ty::Dyn) = args.type_at(0).kind()
{
mk_dyn_vtable()
mk_dyn_vtable(data.principal())
} else {
metadata
}
} else {
match tcx.struct_tail_erasing_lifetimes(pointee, cx.param_env()).kind() {
ty::Slice(_) | ty::Str => tcx.types.usize,
ty::Dynamic(_, _, ty::Dyn) => mk_dyn_vtable(),
ty::Dynamic(data, _, ty::Dyn) => mk_dyn_vtable(data.principal()),
_ => bug!("TyAndLayout::field({:?}): not applicable", this),
}
};
Expand Down
62 changes: 62 additions & 0 deletions compiler/rustc_middle/src/ty/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt;
use crate::mir::interpret::{alloc_range, AllocId, Allocation, Pointer, Scalar};
use crate::ty::{self, Instance, PolyTraitRef, Ty, TyCtxt};
use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def_id::DefId;
use rustc_macros::HashStable;

#[derive(Clone, Copy, PartialEq, HashStable)]
Expand Down Expand Up @@ -40,12 +42,69 @@ impl<'tcx> fmt::Debug for VtblEntry<'tcx> {
impl<'tcx> TyCtxt<'tcx> {
pub const COMMON_VTABLE_ENTRIES: &'tcx [VtblEntry<'tcx>] =
&[VtblEntry::MetadataDropInPlace, VtblEntry::MetadataSize, VtblEntry::MetadataAlign];

pub fn supertrait_def_ids(self, trait_def_id: DefId) -> SupertraitDefIds<'tcx> {
SupertraitDefIds {
tcx: self,
stack: vec![trait_def_id],
visited: Some(trait_def_id).into_iter().collect(),
}
}
}

pub const COMMON_VTABLE_ENTRIES_DROPINPLACE: usize = 0;
pub const COMMON_VTABLE_ENTRIES_SIZE: usize = 1;
pub const COMMON_VTABLE_ENTRIES_ALIGN: usize = 2;

pub struct SupertraitDefIds<'tcx> {
tcx: TyCtxt<'tcx>,
stack: Vec<DefId>,
visited: FxHashSet<DefId>,
}

impl Iterator for SupertraitDefIds<'_> {
type Item = DefId;

fn next(&mut self) -> Option<DefId> {
let def_id = self.stack.pop()?;
let predicates = self.tcx.super_predicates_of(def_id);
let visited = &mut self.visited;
self.stack.extend(
predicates
.predicates
.iter()
.filter_map(|(pred, _)| pred.as_trait_clause())
.map(|trait_ref| trait_ref.def_id())
.filter(|&super_def_id| visited.insert(super_def_id)),
);
Some(def_id)
}
}

// Note that we don't have access to a self type here, this has to be purely based on the trait (and
// supertrait) definitions. That means we can't call into the same vtable_entries code since that
// returns a specific instantiation (e.g., with Vacant slots when bounds aren't satisfied). The goal
// here is to do a best-effort approximation without duplicating a lot of code.
//
// This function is used in layout computation for e.g. &dyn Trait, so it's critical that this
// function is an accurate approximation. We verify this when actually computing the vtable below.
pub(crate) fn vtable_min_entries<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
) -> usize {
let mut count = TyCtxt::COMMON_VTABLE_ENTRIES.len();
let Some(trait_ref) = trait_ref else {
return count;
};

// This includes self in supertraits.
for def_id in tcx.supertrait_def_ids(trait_ref.def_id()) {
count += tcx.own_existential_vtable_entries(def_id).len();
}

count
}

/// Retrieves an allocation that represents the contents of a vtable.
/// Since this is a query, allocations are cached and not duplicated.
pub(super) fn vtable_allocation_provider<'tcx>(
Expand All @@ -63,6 +122,9 @@ pub(super) fn vtable_allocation_provider<'tcx>(
TyCtxt::COMMON_VTABLE_ENTRIES
};

// This confirms that the layout computation for &dyn Trait has an accurate sizing.
assert!(vtable_entries.len() >= vtable_min_entries(tcx, poly_trait_ref));

let layout = tcx
.layout_of(ty::ParamEnv::reveal_all().and(ty))
.expect("failed to build vtable representation");
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_trait_selection/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
//! Dealing with trait goals, i.e. `T: Trait<'a, U>`.

use crate::traits::supertrait_def_ids;

use super::assembly::structural_traits::AsyncCallableRelevantTypes;
use super::assembly::{self, structural_traits, Candidate};
use super::{EvalCtxt, GoalSource, SolverMode};
Expand Down Expand Up @@ -837,7 +835,8 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
let a_auto_traits: FxIndexSet<DefId> = a_data
.auto_traits()
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
supertrait_def_ids(self.interner(), principal_def_id)
self.interner()
.supertrait_def_ids(principal_def_id)
.filter(|def_id| self.interner().trait_is_auto(*def_id))
}))
.collect();
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ pub use self::structural_normalize::StructurallyNormalizeExt;
pub use self::util::elaborate;
pub use self::util::{expand_trait_aliases, TraitAliasExpander, TraitAliasExpansionInfo};
pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices};
pub use self::util::{
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item,
SupertraitDefIds,
};
pub use self::util::{supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item};
pub use self::util::{with_replaced_escaping_bound_vars, BoundVarReplacer, PlaceholderReplacer};

pub use rustc_infer::traits::*;
Expand Down
118 changes: 63 additions & 55 deletions compiler/rustc_trait_selection/src/traits/object_safety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use rustc_middle::ty::{TypeVisitableExt, Upcast};
use rustc_session::lint::builtin::WHERE_CLAUSES_OBJECT_SAFETY;
use rustc_span::symbol::Symbol;
use rustc_span::Span;
use rustc_target::abi::Abi;
use smallvec::SmallVec;

use std::iter;
Expand All @@ -44,7 +45,8 @@ pub fn hir_ty_lowering_object_safety_violations(
trait_def_id: DefId,
) -> Vec<ObjectSafetyViolation> {
debug_assert!(tcx.generics_of(trait_def_id).has_self);
let violations = traits::supertrait_def_ids(tcx, trait_def_id)
let violations = tcx
.supertrait_def_ids(trait_def_id)
.map(|def_id| predicates_reference_self(tcx, def_id, true))
.filter(|spans| !spans.is_empty())
.map(ObjectSafetyViolation::SupertraitSelf)
Expand All @@ -58,7 +60,7 @@ fn object_safety_violations(tcx: TyCtxt<'_>, trait_def_id: DefId) -> &'_ [Object
debug!("object_safety_violations: {:?}", trait_def_id);

tcx.arena.alloc_from_iter(
traits::supertrait_def_ids(tcx, trait_def_id)
tcx.supertrait_def_ids(trait_def_id)
.flat_map(|def_id| object_safety_violations_for_trait(tcx, def_id)),
)
}
Expand Down Expand Up @@ -145,6 +147,14 @@ fn object_safety_violations_for_trait(
violations.push(ObjectSafetyViolation::SupertraitNonLifetimeBinder(spans));
}

if violations.is_empty() {
for item in tcx.associated_items(trait_def_id).in_definition_order() {
if let ty::AssocKind::Fn = item.kind {
check_receiver_correct(tcx, trait_def_id, *item);
}
}
}

debug!(
"object_safety_violations_for_trait(trait_def_id={:?}) = {:?}",
trait_def_id, violations
Expand Down Expand Up @@ -493,59 +503,8 @@ fn virtual_call_violations_for_method<'tcx>(
};
errors.push(MethodViolationCode::UndispatchableReceiver(span));
} else {
// Do sanity check to make sure the receiver actually has the layout of a pointer.

use rustc_target::abi::Abi;

let param_env = tcx.param_env(method.def_id);

let abi_of_ty = |ty: Ty<'tcx>| -> Option<Abi> {
match tcx.layout_of(param_env.and(ty)) {
Ok(layout) => Some(layout.abi),
Err(err) => {
// #78372
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!("error: {err}\n while computing layout for type {ty:?}"),
);
None
}
}
};

// e.g., `Rc<()>`
let unit_receiver_ty =
receiver_for_self_ty(tcx, receiver_ty, tcx.types.unit, method.def_id);

match abi_of_ty(unit_receiver_ty) {
Some(Abi::Scalar(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!(
"receiver when `Self = ()` should have a Scalar ABI; found {abi:?}"
),
);
}
}

let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);

// e.g., `Rc<dyn Trait>`
let trait_object_receiver =
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method.def_id);

match abi_of_ty(trait_object_receiver) {
Some(Abi::ScalarPair(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!(
"receiver when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
),
);
}
}
// We confirm that the `receiver_is_dispatchable` is accurate later,
// see `check_receiver_correct`. It should be kept in sync with this code.
}
}

Expand Down Expand Up @@ -606,6 +565,55 @@ fn virtual_call_violations_for_method<'tcx>(
errors
}

/// This code checks that `receiver_is_dispatchable` is correctly implemented.
///
/// This check is outlined from the object safety check to avoid cycles with
/// layout computation, which relies on knowing whether methods are object safe.
pub fn check_receiver_correct<'tcx>(tcx: TyCtxt<'tcx>, trait_def_id: DefId, method: ty::AssocItem) {
if !is_vtable_safe_method(tcx, trait_def_id, method) {
return;
}

let method_def_id = method.def_id;
let sig = tcx.fn_sig(method_def_id).instantiate_identity();
let param_env = tcx.param_env(method_def_id);
let receiver_ty = tcx.liberate_late_bound_regions(method_def_id, sig.input(0));

if receiver_ty == tcx.types.self_param {
// Assumed OK, may change later if unsized_locals permits `self: Self` as dispatchable.
return;
}

// e.g., `Rc<()>`
let unit_receiver_ty = receiver_for_self_ty(tcx, receiver_ty, tcx.types.unit, method_def_id);
match tcx.layout_of(param_env.and(unit_receiver_ty)).map(|l| l.abi) {
Ok(Abi::Scalar(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method_def_id),
format!("receiver {unit_receiver_ty:?} when `Self = ()` should have a Scalar ABI; found {abi:?}"),
);
}
}

let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);

// e.g., `Rc<dyn Trait>`
let trait_object_receiver =
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method_def_id);
match tcx.layout_of(param_env.and(trait_object_receiver)).map(|l| l.abi) {
Ok(Abi::ScalarPair(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method_def_id),
format!(
"receiver {trait_object_receiver:?} when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
),
);
}
}
}

/// Performs a type instantiation to produce the version of `receiver_ty` when `Self = self_ty`.
/// For example, for `receiver_ty = Rc<Self>` and `self_ty = Foo`, returns `Rc<Foo>`.
fn receiver_for_self_ty<'tcx>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let a_auto_traits: FxIndexSet<DefId> = a_data
.auto_traits()
.chain(principal_def_id_a.into_iter().flat_map(|principal_def_id| {
util::supertrait_def_ids(self.tcx(), principal_def_id)
self.tcx()
.supertrait_def_ids(principal_def_id)
.filter(|def_id| self.tcx().trait_is_auto(*def_id))
}))
.collect();
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2591,8 +2591,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
let a_auto_traits: FxIndexSet<DefId> = a_data
.auto_traits()
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
util::supertrait_def_ids(tcx, principal_def_id)
.filter(|def_id| tcx.trait_is_auto(*def_id))
tcx.supertrait_def_ids(principal_def_id).filter(|def_id| tcx.trait_is_auto(*def_id))
}))
.collect();

Expand Down
Loading

0 comments on commit f2208b3

Please sign in to comment.