Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eagerly instantiate closure/coroutine-like bounds with placeholders to deal with binders correctly #122267

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
120 changes: 56 additions & 64 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::traits::{
BuiltinDerivedObligation, ImplDerivedObligation, ImplDerivedObligationCause, ImplSource,
ImplSourceUserDefinedData, Normalized, Obligation, ObligationCause, PolyTraitObligation,
PredicateObligation, Selection, SelectionError, SignatureMismatch, TraitNotObjectSafe,
Unimplemented,
TraitObligation, Unimplemented,
};

use super::BuiltinImplConditions;
Expand Down Expand Up @@ -678,17 +678,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
fn_host_effect: ty::Const<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
debug!(?obligation, "confirm_fn_pointer_candidate");
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let tcx = self.tcx();

let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
// FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
// but we do not currently. Luckily, such a bound is not
// particularly useful, so we don't expect users to write
// them often.
return Err(SelectionError::Unimplemented);
};

let sig = self_ty.fn_sig(tcx);
let trait_ref = closure_trait_ref_and_return_type(
tcx,
Expand All @@ -700,7 +693,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
)
.map_bound(|(trait_ref, _)| trait_ref);

let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
let mut nested =
self.equate_trait_refs(obligation.with(tcx, placeholder_predicate), trait_ref)?;
let cause = obligation.derived_cause(BuiltinDerivedObligation);

// Confirm the `type Output: Sized;` bound that is present on `FnOnce`
Expand Down Expand Up @@ -748,10 +742,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -760,23 +752,17 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

let coroutine_sig = args.as_coroutine().sig();

// NOTE: The self-type is a coroutine type and hence is
// in fact unparameterized (or at least does not reference any
// regions bound in the obligation).
let self_ty = obligation
.predicate
.self_ty()
.no_bound_vars()
.expect("unboxed closure type should not capture bound vars from the predicate");

let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
self_ty,
coroutine_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "coroutine candidate obligations");

Ok(nested)
Expand All @@ -786,10 +772,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -801,11 +785,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
self_ty,
coroutine_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "future candidate obligations");

Ok(nested)
Expand All @@ -815,10 +802,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -830,11 +815,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
self_ty,
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
Expand All @@ -844,10 +832,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -859,11 +845,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
self_ty,
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
Expand All @@ -874,14 +863,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on closure types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let trait_ref = match *self_ty.kind() {
ty::Closure(_, args) => {
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
}
ty::Closure(..) => self.closure_trait_ref_unnormalized(
self_ty,
obligation.predicate.def_id(),
self.tcx().consts.true_,
),
ty::CoroutineClosure(_, args) => {
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
ty::TraitRef::new(
Expand All @@ -896,16 +886,18 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
};

self.confirm_poly_trait_refs(obligation, trait_ref)
self.equate_trait_refs(obligation.with(self.tcx(), placeholder_predicate), trait_ref)
}

#[instrument(skip(self), level = "debug")]
fn confirm_async_closure_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let tcx = self.tcx();
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());

let mut nested = vec![];
let (trait_ref, kind_ty) = match *self_ty.kind() {
Expand Down Expand Up @@ -972,7 +964,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
_ => bug!("expected callable type for AsyncFn candidate"),
};

nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
nested.extend(
self.equate_trait_refs(obligation.with(tcx, placeholder_predicate), trait_ref)?,
);

let goal_kind =
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
Expand Down Expand Up @@ -1025,42 +1019,40 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
/// selection of the impl. Therefore, if there is a mismatch, we
/// report an error to the user.
#[instrument(skip(self), level = "trace")]
fn confirm_poly_trait_refs(
fn equate_trait_refs(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
obligation: TraitObligation<'tcx>,
found_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let obligation_trait_ref =
self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
self_ty_trait_ref,
found_trait_ref,
);
// Normalize the obligation and expected trait refs together, because why not
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
ensure_sufficient_stack(|| {
normalize_with_depth(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
(obligation_trait_ref, self_ty_trait_ref),
(obligation.predicate.trait_ref, found_trait_ref),
)
});

// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
self.infcx
.at(&obligation.cause, obligation.param_env)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
.map(|InferOk { mut obligations, .. }| {
obligations.extend(nested);
obligations
})
.map_err(|terr| {
SignatureMismatch(Box::new(SignatureMismatchData {
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
found_trait_ref: ty::Binder::dummy(found_trait_ref),
terr,
}))
})
Expand Down
20 changes: 6 additions & 14 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2679,26 +2679,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
#[instrument(skip(self), level = "debug")]
fn closure_trait_ref_unnormalized(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
args: GenericArgsRef<'tcx>,
self_ty: Ty<'tcx>,
fn_trait_def_id: DefId,
fn_host_effect: ty::Const<'tcx>,
) -> ty::PolyTraitRef<'tcx> {
let ty::Closure(_, args) = *self_ty.kind() else {
bug!("expected closure, found {self_ty}");
};
let closure_sig = args.as_closure().sig();

debug!(?closure_sig);

// NOTE: The self-type is an unboxed closure type and hence is
// in fact unparameterized (or at least does not reference any
// regions bound in the obligation).
let self_ty = obligation
.predicate
.self_ty()
.no_bound_vars()
.expect("unboxed closure type should not capture bound vars from the predicate");

closure_trait_ref_and_return_type(
self.tcx(),
obligation.predicate.def_id(),
fn_trait_def_id,
self_ty,
closure_sig,
util::TupleArgumentsFlag::No,
Expand Down
58 changes: 58 additions & 0 deletions tests/ui/higher-ranked/builtin-closure-like-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//@ edition:2024
//@ compile-flags: -Zunstable-options
//@ revisions: current next
//@[next] compile-flags: -Znext-solver
//@ check-pass

compiler-errors marked this conversation as resolved.
Show resolved Hide resolved
// Makes sure that we support closure/coroutine goals where the signature of
// the item references higher-ranked lifetimes from the *predicate* binder,
// not its own internal signature binder.
//
// This was fixed in <https://github.com/rust-lang/rust/pull/122267>.

#![feature(unboxed_closures, gen_blocks)]

trait Dispatch {
fn dispatch(self);
}

struct Fut<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Fut<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Future,
{
fn dispatch(self) {
(self.0)(&());
}
}

struct Gen<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Gen<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Iterator,
{
fn dispatch(self) {
(self.0)(&());
}
}

struct Closure<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Closure<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Fn<(&'a (),)>,
{
fn dispatch(self) {
(self.0)(&())(&());
}
}

fn main() {
async fn foo(_: &()) {}
Fut(foo).dispatch();

gen fn bar(_: &()) {}
Gen(bar).dispatch();

fn uwu<'a>(x: &'a ()) -> impl Fn(&'a ()) { |_| {} }
Closure(uwu).dispatch();
}