Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nll: correctly deal with bivariance #104411

Merged
merged 4 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use super::{
MemPlaceMeta, Memory, MemoryKind, Operand, Place, PlaceTy, PointerArithmetic, Provenance,
Scalar, StackPopJump,
};
use crate::transform::validate::equal_up_to_regions;
use crate::util;

pub struct InterpCx<'mir, 'tcx, M: Machine<'mir, 'tcx>> {
/// Stores the `Machine` instance.
Expand Down Expand Up @@ -354,8 +354,8 @@ pub(super) fn mir_assign_valid_types<'tcx>(
// Type-changing assignments can happen when subtyping is used. While
// all normal lifetimes are erased, higher-ranked types with their
// late-bound lifetimes are still around and can lead to type
// differences. So we compare ignoring lifetimes.
if equal_up_to_regions(tcx, param_env, src.ty, dest.ty) {
// differences.
if util::is_subtype(tcx, param_env, src.ty, dest.ty) {
// Make sure the layout is equal, too -- just to be safe. Miri really
// needs layout equality. For performance reason we skip this check when
// the types are equal. Equal types *can* have different layouts when
Expand Down
59 changes: 2 additions & 57 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
Expand All @@ -12,8 +11,7 @@ use rustc_middle::mir::{
ProjectionElem, RuntimePhase, Rvalue, SourceScope, Statement, StatementKind, Terminator,
TerminatorKind, UnOp, START_BLOCK,
};
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeFoldable, TypeVisitable};
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeVisitable};
use rustc_mir_dataflow::impls::MaybeStorageLive;
use rustc_mir_dataflow::storage::always_storage_live_locals;
use rustc_mir_dataflow::{Analysis, ResultsCursor};
Expand Down Expand Up @@ -70,44 +68,6 @@ impl<'tcx> MirPass<'tcx> for Validator {
}
}

/// Returns whether the two types are equal up to lifetimes.
/// All lifetimes, including higher-ranked ones, get ignored for this comparison.
/// (This is unlike the `erasing_regions` methods, which keep higher-ranked lifetimes for soundness reasons.)
///
/// The point of this function is to approximate "equal up to subtyping". However,
/// the approximation is incorrect as variance is ignored.
pub fn equal_up_to_regions<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
src: Ty<'tcx>,
dest: Ty<'tcx>,
) -> bool {
// Fast path.
if src == dest {
return true;
}

// Normalize lifetimes away on both sides, then compare.
let normalize = |ty: Ty<'tcx>| {
tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty).fold_with(
&mut BottomUpFolder {
tcx,
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
// since one may have an `impl SomeTrait for fn(&32)` and
// `impl SomeTrait for fn(&'static u32)` at the same time which
// specify distinct values for Assoc. (See also #56105)
lt_op: |_| tcx.lifetimes.re_erased,
// Leave consts and types unchanged.
ct_op: |ct| ct,
ty_op: |ty| ty,
},
)
};
tcx.infer_ctxt().build().can_eq(param_env, normalize(src), normalize(dest)).is_ok()
}

struct TypeChecker<'a, 'tcx> {
when: &'a str,
body: &'a Body<'tcx>,
Expand Down Expand Up @@ -183,22 +143,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
return true;
}

// Normalize projections and things like that.
// Type-changing assignments can happen when subtyping is used. While
// all normal lifetimes are erased, higher-ranked types with their
// late-bound lifetimes are still around and can lead to type
// differences. So we compare ignoring lifetimes.

// First, try with reveal_all. This might not work in some cases, as the predicates
// can be cleared in reveal_all mode. We try the reveal first anyways as it is used
// by some other passes like inlining as well.
let param_env = self.param_env.with_reveal_all_normalized(self.tcx);
if equal_up_to_regions(self.tcx, param_env, src, dest) {
return true;
}

// If this fails, we can try it without the reveal.
equal_up_to_regions(self.tcx, self.param_env, src, dest)
crate::util::is_subtype(self.tcx, self.param_env, src, dest)
}
}

Expand Down
63 changes: 63 additions & 0 deletions compiler/rustc_const_eval/src/util/compare_types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//! Routines to check for relations between fully inferred types.
//!
//! FIXME: Move this to a more general place. The utility of this extends to
//! other areas of the compiler as well.

use rustc_infer::infer::{DefiningAnchor, TyCtxtInferExt};
use rustc_infer::traits::ObligationCause;
use rustc_middle::ty::{ParamEnv, Ty, TyCtxt};
use rustc_trait_selection::traits::ObligationCtxt;

/// Returns whether the two types are equal up to subtyping.
///
/// This is used in case we don't know the expected subtyping direction
/// and still want to check whether anything is broken.
pub fn is_equal_up_to_subtyping<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
src: Ty<'tcx>,
dest: Ty<'tcx>,
) -> bool {
// Fast path.
if src == dest {
return true;
}

// Check for subtyping in either direction.
is_subtype(tcx, param_env, src, dest) || is_subtype(tcx, param_env, dest, src)
}

/// Returns whether `src` is a subtype of `dest`, i.e. `src <: dest`.
///
/// This mostly ignores opaque types as it can be used in constraining contexts
/// while still computing the final underlying type.
pub fn is_subtype<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
src: Ty<'tcx>,
dest: Ty<'tcx>,
) -> bool {
if src == dest {
return true;
}

let mut builder =
tcx.infer_ctxt().ignoring_regions().with_opaque_type_inference(DefiningAnchor::Bubble);
let infcx = builder.build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy();
let src = ocx.normalize(cause.clone(), param_env, src);
let dest = ocx.normalize(cause.clone(), param_env, dest);
match ocx.sub(&cause, param_env, src, dest) {
Ok(()) => {}
Err(_) => return false,
};
let errors = ocx.select_all_or_error();
// With `Reveal::All`, opaque types get normalized away, with `Reveal::UserFacing`
// we would get unification errors because we're unable to look into opaque types,
// even if they're constrained in our current function.
//
// It seems very unlikely that this hides any bugs.
let _ = infcx.inner.borrow_mut().opaque_type_storage.take_opaque_types();
errors.is_empty()
}
2 changes: 2 additions & 0 deletions compiler/rustc_const_eval/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ pub mod aggregate;
mod alignment;
mod call_kind;
pub mod collect_writes;
mod compare_types;
mod find_self_call;
mod might_permit_raw_init;
mod type_name;

pub use self::aggregate::expand_aggregate;
pub use self::alignment::is_disaligned;
pub use self::call_kind::{call_kind, CallDesugaringKind, CallKind};
pub use self::compare_types::{is_equal_up_to_subtyping, is_subtype};
pub use self::find_self_call::find_self_call;
pub use self::might_permit_raw_init::might_permit_raw_init;
pub use self::type_name::type_name;
4 changes: 2 additions & 2 deletions compiler/rustc_hir_analysis/src/check/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ fn check_opaque_meets_bounds<'tcx>(

let misc_cause = traits::ObligationCause::misc(span, hir_id);

match infcx.at(&misc_cause, param_env).eq(opaque_ty, hidden_ty) {
Ok(infer_ok) => ocx.register_infer_ok_obligations(infer_ok),
match ocx.eq(&misc_cause, param_env, opaque_ty, hidden_ty) {
Ok(()) => {}
Err(ty_err) => {
tcx.sess.delay_span_bug(
span,
Expand Down
12 changes: 4 additions & 8 deletions compiler/rustc_hir_analysis/src/check/compare_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,8 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
unnormalized_trait_sig.inputs_and_output.iter().chain(trait_sig.inputs_and_output.iter()),
);

match infcx.at(&cause, param_env).eq(trait_return_ty, impl_return_ty) {
Ok(infer::InferOk { value: (), obligations }) => {
ocx.register_obligations(obligations);
}
match ocx.eq(&cause, param_env, trait_return_ty, impl_return_ty) {
Ok(()) => {}
Err(terr) => {
let mut diag = struct_span_err!(
tcx.sess,
Expand Down Expand Up @@ -442,10 +440,8 @@ pub fn collect_trait_impl_trait_tys<'tcx>(
// the lifetimes of the return type, but do this after unifying just the
// return types, since we want to avoid duplicating errors from
// `compare_predicate_entailment`.
match infcx.at(&cause, param_env).eq(trait_fty, impl_fty) {
Ok(infer::InferOk { value: (), obligations }) => {
ocx.register_obligations(obligations);
}
match ocx.eq(&cause, param_env, trait_fty, impl_fty) {
Ok(()) => {}
Err(terr) => {
// This function gets called during `compare_predicate_entailment` when normalizing a
// signature that contains RPITIT. When the method signatures don't match, we have to
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_infer/src/infer/nll_relate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,9 @@ where
self.ambient_variance_info = self.ambient_variance_info.xform(info);

debug!(?self.ambient_variance);

let r = self.relate(a, b)?;
// In a bivariant context this always succeeds.
let r =
if self.ambient_variance == ty::Variance::Bivariant { a } else { self.relate(a, b)? };

self.ambient_variance = old_ambient_variance;

Expand Down
12 changes: 6 additions & 6 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! Inlining pass for MIR functions
use crate::deref_separator::deref_finder;
use rustc_attr::InlineAttr;
use rustc_const_eval::transform::validate::equal_up_to_regions;
use rustc_index::bit_set::BitSet;
use rustc_index::vec::Idx;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
Expand All @@ -14,7 +13,8 @@ use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
use rustc_target::abi::VariantIdx;
use rustc_target::spec::abi::Abi;

use super::simplify::{remove_dead_blocks, CfgSimplifier};
use crate::simplify::{remove_dead_blocks, CfgSimplifier};
use crate::util;
use crate::MirPass;
use std::iter;
use std::ops::{Range, RangeFrom};
Expand Down Expand Up @@ -180,7 +180,7 @@ impl<'tcx> Inliner<'tcx> {
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
let output_type = callee_body.return_ty();
if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) {
if !util::is_subtype(self.tcx, self.param_env, output_type, destination_ty) {
trace!(?output_type, ?destination_ty);
return Err("failed to normalize return type");
}
Expand All @@ -200,7 +200,7 @@ impl<'tcx> Inliner<'tcx> {
arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args))
{
let input_type = callee_body.local_decls[input].ty;
if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize tuple argument type");
}
Expand All @@ -209,7 +209,7 @@ impl<'tcx> Inliner<'tcx> {
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
let input_type = callee_body.local_decls[input].ty;
let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize argument type");
}
Expand Down Expand Up @@ -847,7 +847,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
let check_equal = |this: &mut Self, f_ty| {
if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
if !util::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) {
trace!(?ty, ?f_ty);
this.validation = Err("failed to normalize projection type");
return;
Expand Down
26 changes: 19 additions & 7 deletions compiler/rustc_trait_selection/src/traits/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,32 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
.map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
}

/// Checks whether `expected` is a subtype of `actual`: `expected <: actual`.
pub fn sub<T: ToTrace<'tcx>>(
&self,
cause: &ObligationCause<'tcx>,
param_env: ty::ParamEnv<'tcx>,
expected: T,
actual: T,
) -> Result<(), TypeError<'tcx>> {
self.infcx
.at(cause, param_env)
.sup(expected, actual)
.map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
}

/// Checks whether `expected` is a supertype of `actual`: `expected :> actual`.
pub fn sup<T: ToTrace<'tcx>>(
&self,
cause: &ObligationCause<'tcx>,
param_env: ty::ParamEnv<'tcx>,
expected: T,
actual: T,
) -> Result<(), TypeError<'tcx>> {
match self.infcx.at(cause, param_env).sup(expected, actual) {
Ok(InferOk { obligations, value: () }) => {
self.register_obligations(obligations);
Ok(())
}
Err(e) => Err(e),
}
self.infcx
.at(cause, param_env)
.sup(expected, actual)
.map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
}

pub fn select_where_possible(&self) -> Vec<FulfillmentError<'tcx>> {
Expand Down
26 changes: 26 additions & 0 deletions src/test/ui/mir/important-higher-ranked-regions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// check-pass
// compile-flags: -Zvalidate-mir

// This test checks that bivariant parameters are handled correctly
// in the mir.
#![allow(coherence_leak_check)]
trait Trait {
type Assoc;
}

struct Foo<T, U>(T)
where
T: Trait<Assoc = U>;

impl Trait for for<'a> fn(&'a ()) {
type Assoc = u32;
}
impl Trait for fn(&'static ()) {
type Assoc = String;
}

fn foo(x: Foo<for<'a> fn(&'a ()), u32>) -> Foo<fn(&'static ()), String> {
x
}

fn main() {}