Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement refinement check for RPITITs #111931

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/rustc_const_eval/src/transform/validate.rs
Expand Up @@ -50,6 +50,9 @@ impl<'tcx> MirPass<'tcx> for Validator {
let param_env = match mir_phase.reveal() {
Reveal::UserFacing => tcx.param_env(def_id),
Reveal::All => tcx.param_env_reveal_all_normalized(def_id),
Reveal::HideReturnPositionImplTraitInTrait => {
unreachable!("only used during refinement checks")
}
};

let always_live_locals = always_storage_live_locals(body);
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_feature/src/active.rs
Expand Up @@ -529,6 +529,8 @@ declare_features! (
(active, proc_macro_hygiene, "1.30.0", Some(54727), None),
/// Allows `&raw const $place_expr` and `&raw mut $place_expr` expressions.
(active, raw_ref_op, "1.41.0", Some(64490), None),
/// Allows use of the `#![refine]` attribute, and checks items for accidental refinements.
(incomplete, refine, "CURRENT_RUSTC_VERSION", Some(1), None),
/// Allows using the `#[register_tool]` attribute.
(active, register_tool, "1.41.0", Some(66079), None),
/// Allows the `#[repr(i128)]` attribute for enums.
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_feature/src/builtin_attrs.rs
Expand Up @@ -497,6 +497,9 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
experimental!(cfi_encoding)
),

// RFC 3245
gated!(refine, Normal, template!(Word), WarnFollowing, experimental!(refine)),

// ==========================================================================
// Internal attributes: Stability, deprecation, and unsafe:
// ==========================================================================
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_hir_analysis/messages.ftl
Expand Up @@ -219,6 +219,10 @@ hir_analysis_return_type_notation_on_non_rpitit =
.note = function returns `{$ty}`, which is not compatible with associated type return bounds
.label = this function must be `async` or return `impl Trait`

hir_analysis_rpitit_refined = impl method signature does not match trait method signature
.suggestion = replace the return type so that it matches the trait
.label = return type from trait method defined here
.unmatched_bound_label = this bound is stronger than that defined on the trait
hir_analysis_self_in_impl_self =
`Self` is not valid in the self type of an impl block
.note = replace `Self` with a different type
Expand Down
17 changes: 9 additions & 8 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Expand Up @@ -28,6 +28,8 @@ use rustc_trait_selection::traits::{
use std::borrow::Cow;
use std::iter;

mod refine;

/// Checks that a method from an impl conforms to the signature of
/// the same method as declared in the trait.
///
Expand All @@ -53,6 +55,12 @@ pub(super) fn compare_impl_method<'tcx>(
impl_trait_ref,
CheckImpliedWfMode::Check,
)?;
refine::compare_impl_trait_in_trait_predicate_entailment(
tcx,
impl_m,
trait_m,
impl_trait_ref,
)?;
};
}

Expand Down Expand Up @@ -823,14 +831,7 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
// of the `impl Sized`. Insert that here, so we don't ICE later.
for assoc_item in tcx.associated_types_for_impl_traits_in_associated_fn(trait_m.def_id) {
if !remapped_types.contains_key(assoc_item) {
remapped_types.insert(
*assoc_item,
ty::EarlyBinder::bind(Ty::new_error_with_message(
tcx,
return_span,
"missing synthetic item for RPITIT",
)),
);
return Err(tcx.sess.delay_span_bug(return_span, "missing synthetic item for RPITIT"));
}
}

Expand Down
240 changes: 240 additions & 0 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item/refine.rs
@@ -0,0 +1,240 @@
use std::ops::ControlFlow;

use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::Obligation;
use rustc_middle::traits::{ObligationCause, Reveal};
use rustc_middle::ty::{
self, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor,
};
use rustc_span::ErrorGuaranteed;
use rustc_span::{sym, Span};
use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt;
use rustc_trait_selection::traits::{normalize_param_env_or_error, ObligationCtxt};
use rustc_type_ir::fold::TypeFoldable;

/// Check that an implementation does not refine an RPITIT from a trait method signature.
pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: ty::AssocItem,
trait_m: ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) -> Result<(), ErrorGuaranteed> {
if !tcx.impl_method_has_trait_impl_trait_tys(impl_m.def_id)
|| tcx.has_attr(impl_m.def_id, sym::refine)
{
return Ok(());
}

let hidden_tys = tcx.collect_return_position_impl_trait_in_trait_tys(impl_m.def_id)?;

let impl_def_id = impl_m.container_id(tcx);
let trait_def_id = trait_m.container_id(tcx);
let trait_m_to_impl_m_args = ty::GenericArgs::identity_for_item(tcx, impl_m.def_id)
.rebase_onto(tcx, impl_def_id, impl_trait_ref.args);

let infcx = tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);

let mut hybrid_preds = tcx.predicates_of(impl_def_id).instantiate_identity(tcx).predicates;
hybrid_preds.extend(
tcx.predicates_of(trait_m.def_id)
.instantiate_own(tcx, trait_m_to_impl_m_args)
.map(|(pred, _)| pred),
);
let normalize_cause =
ObligationCause::misc(tcx.def_span(impl_m.def_id), impl_m.def_id.expect_local());
let unnormalized_param_env = ty::ParamEnv::new(
tcx.mk_clauses(&hybrid_preds),
Reveal::HideReturnPositionImplTraitInTrait,
);
let param_env = normalize_param_env_or_error(tcx, unnormalized_param_env, normalize_cause);

let bound_trait_m_sig = tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_m_to_impl_m_args);
let unnormalized_trait_m_sig =
tcx.liberate_late_bound_regions(impl_m.def_id, bound_trait_m_sig);
let trait_m_sig = ocx.normalize(&ObligationCause::dummy(), param_env, unnormalized_trait_m_sig);

let mut visitor = ImplTraitInTraitCollector { tcx, types: FxIndexMap::default() };
trait_m_sig.visit_with(&mut visitor);

let mut reverse_mapping = FxIndexMap::default();
let mut bounds_to_prove = vec![];
for (rpitit_def_id, rpitit_args) in visitor.types {
let hidden_ty =
hidden_tys.get(&rpitit_def_id).expect("expected hidden type for RPITIT").instantiate(
tcx,
rpitit_args.rebase_onto(
tcx,
trait_def_id,
ty::GenericArgs::identity_for_item(tcx, impl_def_id),
),
);
reverse_mapping.insert(hidden_ty, Ty::new_projection(tcx, rpitit_def_id, rpitit_args));

let ty::Alias(ty::Opaque, opaque_ty) = *hidden_ty.kind() else {
return Err(report_mismatched_rpitit_signature(
tcx,
trait_m_sig,
trait_m.def_id,
impl_m.def_id,
None,
));
};

// Check that this is an opaque that comes from our impl fn
if !tcx.hir().get_if_local(opaque_ty.def_id).map_or(false, |node| {
matches!(
node.expect_item().expect_opaque_ty().origin,
hir::OpaqueTyOrigin::AsyncFn(def_id) | hir::OpaqueTyOrigin::FnReturn(def_id)
if def_id == impl_m.def_id.expect_local()
)
}) {
return Err(report_mismatched_rpitit_signature(
tcx,
trait_m_sig,
trait_m.def_id,
impl_m.def_id,
None,
));
}

bounds_to_prove.extend(
tcx.explicit_item_bounds(opaque_ty.def_id)
.iter_instantiated_copied(tcx, opaque_ty.args),
);
}

ocx.register_obligations(
bounds_to_prove.fold_with(&mut ReverseMapper { tcx, reverse_mapping }).into_iter().map(
|(pred, span)| {
Obligation::new(tcx, ObligationCause::dummy_with_span(span), param_env, pred)
},
),
);

let errors = ocx.select_all_or_error();
if !errors.is_empty() {
let span = errors.first().unwrap().obligation.cause.span;
return Err(report_mismatched_rpitit_signature(
tcx,
trait_m_sig,
trait_m.def_id,
impl_m.def_id,
Some(span),
));
}

let mut wf_tys = FxIndexSet::default();
wf_tys.extend(unnormalized_trait_m_sig.inputs_and_output);
wf_tys.extend(trait_m_sig.inputs_and_output);
let outlives_env = OutlivesEnvironment::with_bounds(
param_env,
ocx.infcx.implied_bounds_tys(param_env, impl_m.def_id.expect_local(), wf_tys.clone()),
);
let errors = ocx.infcx.resolve_regions(&outlives_env);
if !errors.is_empty() {
return Err(report_mismatched_rpitit_signature(
tcx,
trait_m_sig,
trait_m.def_id,
impl_m.def_id,
None,
));
}

Ok(())
}

struct ImplTraitInTraitCollector<'tcx> {
tcx: TyCtxt<'tcx>,
types: FxIndexMap<DefId, ty::GenericArgsRef<'tcx>>,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitCollector<'tcx> {
type BreakTy = !;

fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let ty::Alias(ty::Projection, proj) = *ty.kind()
&& self.tcx.is_impl_trait_in_trait(proj.def_id)
{
if self.types.insert(proj.def_id, proj.args).is_none() {
for (pred, _) in self
.tcx
.explicit_item_bounds(proj.def_id)
.iter_instantiated_copied(self.tcx, proj.args)
{
pred.visit_with(self)?;
}
}
ControlFlow::Continue(())
} else {
ty.super_visit_with(self)
}
}
}

struct ReverseMapper<'tcx> {
tcx: TyCtxt<'tcx>,
reverse_mapping: FxIndexMap<Ty<'tcx>, Ty<'tcx>>,
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ReverseMapper<'tcx> {
fn interner(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if let Some(ty) = self.reverse_mapping.get(&ty) { *ty } else { ty.super_fold_with(self) }
}
}

fn report_mismatched_rpitit_signature<'tcx>(
tcx: TyCtxt<'tcx>,
trait_m_sig: ty::FnSig<'tcx>,
trait_m_def_id: DefId,
impl_m_def_id: DefId,
unmatched_bound: Option<Span>,
) -> ErrorGuaranteed {
let mapping = std::iter::zip(
tcx.fn_sig(trait_m_def_id).skip_binder().bound_vars(),
tcx.fn_sig(impl_m_def_id).skip_binder().bound_vars(),
)
.filter_map(|(impl_bv, trait_bv)| {
if let ty::BoundVariableKind::Region(impl_bv) = impl_bv
&& let ty::BoundVariableKind::Region(trait_bv) = trait_bv
{
Some((impl_bv, trait_bv))
} else {
None
}
})
.collect();

let return_ty =
trait_m_sig.output().fold_with(&mut super::RemapLateBound { tcx, mapping: &mapping });

let (span, impl_return_span, sugg) =
match tcx.hir().get_by_def_id(impl_m_def_id.expect_local()).fn_decl().unwrap().output {
hir::FnRetTy::DefaultReturn(span) => {
(tcx.def_span(impl_m_def_id), span, format!("-> {return_ty} "))
}
hir::FnRetTy::Return(ty) => (ty.span, ty.span, format!("{return_ty}")),
};
let trait_return_span =
tcx.hir().get_if_local(trait_m_def_id).map(|node| match node.fn_decl().unwrap().output {
hir::FnRetTy::DefaultReturn(_) => tcx.def_span(trait_m_def_id),
hir::FnRetTy::Return(ty) => ty.span,
});

tcx.sess.emit_err(crate::errors::ReturnPositionImplTraitInTraitRefined {
span,
impl_return_span,
trait_return_span,
sugg,
unmatched_bound,
})
}
14 changes: 14 additions & 0 deletions compiler/rustc_hir_analysis/src/errors.rs
Expand Up @@ -918,3 +918,17 @@ pub struct UnusedAssociatedTypeBounds {
#[suggestion(code = "")]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(hir_analysis_rpitit_refined)]
pub(crate) struct ReturnPositionImplTraitInTraitRefined {
#[primary_span]
pub span: Span,
#[suggestion(applicability = "maybe-incorrect", code = "{sugg}")]
pub impl_return_span: Span,
#[label]
pub trait_return_span: Option<Span>,
pub sugg: String,
#[label(hir_analysis_unmatched_bound_label)]
pub unmatched_bound: Option<Span>,
}
9 changes: 9 additions & 0 deletions compiler/rustc_lint_defs/src/builtin.rs
Expand Up @@ -3436,6 +3436,7 @@ declare_lint_pass! {
TYVAR_BEHIND_RAW_POINTER,
UNCONDITIONAL_PANIC,
UNCONDITIONAL_RECURSION,
UNDECLARED_REFINE,
UNDEFINED_NAKED_FUNCTION_ABI,
UNFULFILLED_LINT_EXPECTATIONS,
UNINHABITED_STATIC,
Expand Down Expand Up @@ -4422,6 +4423,14 @@ declare_lint! {
@feature_gate = sym::type_privacy_lints;
}

declare_lint! {
/// Refine
pub UNDECLARED_REFINE,
Warn,
"yeet",
@feature_gate = sym::refine;
}

declare_lint! {
/// The `unknown_diagnostic_attributes` lint detects unrecognized diagnostic attributes.
///
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/traits/mod.rs
Expand Up @@ -64,6 +64,11 @@ pub enum Reveal {
/// type-checking.
UserFacing,

// Same as user-facing reveal, but do not project ("reveal") return-position
// impl trait in traits. This is only used for checking that an RPITIT is not
// refined by an implementation.
HideReturnPositionImplTraitInTrait,

/// At codegen time, all monomorphic projections will succeed.
/// Also, `impl Trait` is normalized to the concrete type,
/// which has to be already collected by type-checking.
Expand Down