Skip to content

Commit

Permalink
Auto merge of #87375 - fee1-dead:move-constness-to-traitpred, r=oli-obk
Browse files Browse the repository at this point in the history
Try filtering out non-const impls when we expect const impls

**TL;DR**: Associated types on const impls are now bounded; we now disallow calling a const function with bounds when the specified type param only has a non-const impl.

r? `@oli-obk`
  • Loading branch information
bors committed Aug 14, 2021
2 parents fa26929 + 74627c1 commit 136eaa1
Show file tree
Hide file tree
Showing 68 changed files with 517 additions and 211 deletions.
23 changes: 21 additions & 2 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2751,6 +2751,15 @@ pub enum Constness {
NotConst,
}

impl fmt::Display for Constness {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match *self {
Self::Const => "const",
Self::NotConst => "non-const",
})
}
}

#[derive(Copy, Clone, Encodable, Debug, HashStable_Generic)]
pub struct FnHeader {
pub unsafety: Unsafety,
Expand Down Expand Up @@ -3252,8 +3261,13 @@ impl<'hir> Node<'hir> {
}
}

/// Returns `Constness::Const` when this node is a const fn/impl.
pub fn constness(&self) -> Constness {
/// Returns `Constness::Const` when this node is a const fn/impl/item,
///
/// HACK(fee1-dead): or an associated type in a trait. This works because
/// only typeck cares about const trait predicates, so although the predicates
/// query would return const predicates when it does not need to be const,
/// it wouldn't have any effect.
pub fn constness_for_typeck(&self) -> Constness {
match self {
Node::Item(Item {
kind: ItemKind::Fn(FnSig { header: FnHeader { constness, .. }, .. }, ..),
Expand All @@ -3269,6 +3283,11 @@ impl<'hir> Node<'hir> {
})
| Node::Item(Item { kind: ItemKind::Impl(Impl { constness, .. }), .. }) => *constness,

Node::Item(Item { kind: ItemKind::Const(..), .. })
| Node::TraitItem(TraitItem { kind: TraitItemKind::Const(..), .. })
| Node::TraitItem(TraitItem { kind: TraitItemKind::Type(..), .. })
| Node::ImplItem(ImplItem { kind: ImplItemKind::Const(..), .. }) => Constness::Const,

_ => Constness::NotConst,
}
}
Expand Down
18 changes: 18 additions & 0 deletions compiler/rustc_infer/src/traits/engine.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::infer::InferCtxt;
use crate::traits::Obligation;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_middle::ty::{self, ToPredicate, Ty, WithConstness};

Expand Down Expand Up @@ -49,11 +50,28 @@ pub trait TraitEngine<'tcx>: 'tcx {
infcx: &InferCtxt<'_, 'tcx>,
) -> Result<(), Vec<FulfillmentError<'tcx>>>;

fn select_all_with_constness_or_error(
&mut self,
infcx: &InferCtxt<'_, 'tcx>,
_constness: hir::Constness,
) -> Result<(), Vec<FulfillmentError<'tcx>>> {
self.select_all_or_error(infcx)
}

fn select_where_possible(
&mut self,
infcx: &InferCtxt<'_, 'tcx>,
) -> Result<(), Vec<FulfillmentError<'tcx>>>;

// FIXME this should not provide a default body for chalk as chalk should be updated
fn select_with_constness_where_possible(
&mut self,
infcx: &InferCtxt<'_, 'tcx>,
_constness: hir::Constness,
) -> Result<(), Vec<FulfillmentError<'tcx>>> {
self.select_where_possible(infcx)
}

fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>>;
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl Elaborator<'tcx> {

let bound_predicate = obligation.predicate.kind();
match bound_predicate.skip_binder() {
ty::PredicateKind::Trait(data, _) => {
ty::PredicateKind::Trait(data) => {
// Get predicates declared on the trait.
let predicates = tcx.super_predicates_of(data.def_id());

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_lint/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<'tcx> LateLintPass<'tcx> for DropTraitConstraints {
let predicates = cx.tcx.explicit_predicates_of(item.def_id);
for &(predicate, span) in predicates.predicates {
let trait_predicate = match predicate.kind().skip_binder() {
Trait(trait_predicate, _constness) => trait_predicate,
Trait(trait_predicate) => trait_predicate,
_ => continue,
};
let def_id = trait_predicate.trait_ref.def_id;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_lint/src/unused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl<'tcx> LateLintPass<'tcx> for UnusedResults {
let mut has_emitted = false;
for &(predicate, _) in cx.tcx.explicit_item_bounds(def) {
// We only look at the `DefId`, so it is safe to skip the binder here.
if let ty::PredicateKind::Trait(ref poly_trait_predicate, _) =
if let ty::PredicateKind::Trait(ref poly_trait_predicate) =
predicate.kind().skip_binder()
{
let def_id = poly_trait_predicate.trait_ref.def_id;
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use rustc_hir::def_id::DefId;
use rustc_query_system::cache::Cache;

pub type SelectionCache<'tcx> = Cache<
ty::ParamEnvAnd<'tcx, ty::TraitRef<'tcx>>,
ty::ConstnessAnd<ty::ParamEnvAnd<'tcx, ty::TraitRef<'tcx>>>,
SelectionResult<'tcx, SelectionCandidate<'tcx>>,
>;

pub type EvaluationCache<'tcx> =
Cache<ty::ParamEnvAnd<'tcx, ty::PolyTraitRef<'tcx>>, EvaluationResult>;
Cache<ty::ParamEnvAnd<'tcx, ty::ConstnessAnd<ty::PolyTraitRef<'tcx>>>, EvaluationResult>;

/// The selection process begins by considering all impls, where
/// clauses, and so forth that might resolve an obligation. Sometimes
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/ty/assoc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ pub enum AssocItemContainer {
}

impl AssocItemContainer {
pub fn impl_def_id(&self) -> Option<DefId> {
match *self {
ImplContainer(id) => Some(id),
_ => None,
}
}

/// Asserts that this is the `DefId` of an associated item declared
/// in a trait, and returns the trait `DefId`.
pub fn assert_trait(&self) -> DefId {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2169,7 +2169,7 @@ impl<'tcx> TyCtxt<'tcx> {
let generic_predicates = self.super_predicates_of(trait_did);

for (predicate, _) in generic_predicates.predicates {
if let ty::PredicateKind::Trait(data, _) = predicate.kind().skip_binder() {
if let ty::PredicateKind::Trait(data) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
Expand Down
12 changes: 9 additions & 3 deletions compiler/rustc_middle/src/ty/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl<T> ExpectedFound<T> {
#[derive(Clone, Debug, TypeFoldable)]
pub enum TypeError<'tcx> {
Mismatch,
ConstnessMismatch(ExpectedFound<hir::Constness>),
UnsafetyMismatch(ExpectedFound<hir::Unsafety>),
AbiMismatch(ExpectedFound<abi::Abi>),
Mutability,
Expand Down Expand Up @@ -106,6 +107,9 @@ impl<'tcx> fmt::Display for TypeError<'tcx> {
CyclicTy(_) => write!(f, "cyclic type of infinite size"),
CyclicConst(_) => write!(f, "encountered a self-referencing constant"),
Mismatch => write!(f, "types differ"),
ConstnessMismatch(values) => {
write!(f, "expected {} fn, found {} fn", values.expected, values.found)
}
UnsafetyMismatch(values) => {
write!(f, "expected {} fn, found {} fn", values.expected, values.found)
}
Expand Down Expand Up @@ -213,9 +217,11 @@ impl<'tcx> TypeError<'tcx> {
pub fn must_include_note(&self) -> bool {
use self::TypeError::*;
match self {
CyclicTy(_) | CyclicConst(_) | UnsafetyMismatch(_) | Mismatch | AbiMismatch(_)
| FixedArraySize(_) | ArgumentSorts(..) | Sorts(_) | IntMismatch(_)
| FloatMismatch(_) | VariadicMismatch(_) | TargetFeatureCast(_) => false,
CyclicTy(_) | CyclicConst(_) | UnsafetyMismatch(_) | ConstnessMismatch(_)
| Mismatch | AbiMismatch(_) | FixedArraySize(_) | ArgumentSorts(..) | Sorts(_)
| IntMismatch(_) | FloatMismatch(_) | VariadicMismatch(_) | TargetFeatureCast(_) => {
false
}

Mutability
| ArgumentMutability(_)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ impl FlagComputation {

fn add_predicate_atom(&mut self, atom: ty::PredicateKind<'_>) {
match atom {
ty::PredicateKind::Trait(trait_pred, _constness) => {
ty::PredicateKind::Trait(trait_pred) => {
self.add_substs(trait_pred.trait_ref.substs);
}
ty::PredicateKind::RegionOutlives(ty::OutlivesPredicate(a, b)) => {
Expand Down
24 changes: 16 additions & 8 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ pub enum PredicateKind<'tcx> {
/// A trait predicate will have `Constness::Const` if it originates
/// from a bound on a `const fn` without the `?const` opt-out (e.g.,
/// `const fn foobar<Foo: Bar>() {}`).
Trait(TraitPredicate<'tcx>, Constness),
Trait(TraitPredicate<'tcx>),

/// `where 'a: 'b`
RegionOutlives(RegionOutlivesPredicate<'tcx>),
Expand Down Expand Up @@ -617,6 +617,11 @@ impl<'tcx> Predicate<'tcx> {
#[derive(HashStable, TypeFoldable)]
pub struct TraitPredicate<'tcx> {
pub trait_ref: TraitRef<'tcx>,

/// A trait predicate will have `Constness::Const` if it originates
/// from a bound on a `const fn` without the `?const` opt-out (e.g.,
/// `const fn foobar<Foo: Bar>() {}`).
pub constness: hir::Constness,
}

pub type PolyTraitPredicate<'tcx> = ty::Binder<'tcx, TraitPredicate<'tcx>>;
Expand Down Expand Up @@ -750,24 +755,27 @@ impl ToPredicate<'tcx> for PredicateKind<'tcx> {

impl<'tcx> ToPredicate<'tcx> for ConstnessAnd<TraitRef<'tcx>> {
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Predicate<'tcx> {
PredicateKind::Trait(ty::TraitPredicate { trait_ref: self.value }, self.constness)
.to_predicate(tcx)
PredicateKind::Trait(ty::TraitPredicate {
trait_ref: self.value,
constness: self.constness,
})
.to_predicate(tcx)
}
}

impl<'tcx> ToPredicate<'tcx> for ConstnessAnd<PolyTraitRef<'tcx>> {
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Predicate<'tcx> {
self.value
.map_bound(|trait_ref| {
PredicateKind::Trait(ty::TraitPredicate { trait_ref }, self.constness)
PredicateKind::Trait(ty::TraitPredicate { trait_ref, constness: self.constness })
})
.to_predicate(tcx)
}
}

impl<'tcx> ToPredicate<'tcx> for ConstnessAnd<PolyTraitPredicate<'tcx>> {
impl<'tcx> ToPredicate<'tcx> for PolyTraitPredicate<'tcx> {
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Predicate<'tcx> {
self.value.map_bound(|value| PredicateKind::Trait(value, self.constness)).to_predicate(tcx)
self.map_bound(PredicateKind::Trait).to_predicate(tcx)
}
}

Expand All @@ -793,8 +801,8 @@ impl<'tcx> Predicate<'tcx> {
pub fn to_opt_poly_trait_ref(self) -> Option<ConstnessAnd<PolyTraitRef<'tcx>>> {
let predicate = self.kind();
match predicate.skip_binder() {
PredicateKind::Trait(t, constness) => {
Some(ConstnessAnd { constness, value: predicate.rebind(t.trait_ref) })
PredicateKind::Trait(t) => {
Some(ConstnessAnd { constness: t.constness, value: predicate.rebind(t.trait_ref) })
}
PredicateKind::Projection(..)
| PredicateKind::Subtype(..)
Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ pub trait PrettyPrinter<'tcx>:
for (predicate, _) in bounds {
let predicate = predicate.subst(self.tcx(), substs);
let bound_predicate = predicate.kind();
if let ty::PredicateKind::Trait(pred, _) = bound_predicate.skip_binder() {
if let ty::PredicateKind::Trait(pred) = bound_predicate.skip_binder() {
let trait_ref = bound_predicate.rebind(pred.trait_ref);
// Don't print +Sized, but rather +?Sized if absent.
if Some(trait_ref.def_id()) == self.tcx().lang_items().sized_trait() {
Expand Down Expand Up @@ -2264,10 +2264,7 @@ define_print_and_forward_display! {

ty::PredicateKind<'tcx> {
match *self {
ty::PredicateKind::Trait(ref data, constness) => {
if let hir::Constness::Const = constness {
p!("const ");
}
ty::PredicateKind::Trait(ref data) => {
p!(print(data))
}
ty::PredicateKind::Subtype(predicate) => p!(print(predicate)),
Expand Down
32 changes: 31 additions & 1 deletion compiler/rustc_middle/src/ty/relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,33 @@ impl<'tcx> Relate<'tcx> for ty::FnSig<'tcx> {
}
}

impl<'tcx> Relate<'tcx> for ast::Constness {
fn relate<R: TypeRelation<'tcx>>(
relation: &mut R,
a: ast::Constness,
b: ast::Constness,
) -> RelateResult<'tcx, ast::Constness> {
if a != b {
Err(TypeError::ConstnessMismatch(expected_found(relation, a, b)))
} else {
Ok(a)
}
}
}

impl<'tcx, T: Relate<'tcx>> Relate<'tcx> for ty::ConstnessAnd<T> {
fn relate<R: TypeRelation<'tcx>>(
relation: &mut R,
a: ty::ConstnessAnd<T>,
b: ty::ConstnessAnd<T>,
) -> RelateResult<'tcx, ty::ConstnessAnd<T>> {
Ok(ty::ConstnessAnd {
constness: relation.relate(a.constness, b.constness)?,
value: relation.relate(a.value, b.value)?,
})
}
}

impl<'tcx> Relate<'tcx> for ast::Unsafety {
fn relate<R: TypeRelation<'tcx>>(
relation: &mut R,
Expand Down Expand Up @@ -767,7 +794,10 @@ impl<'tcx> Relate<'tcx> for ty::TraitPredicate<'tcx> {
a: ty::TraitPredicate<'tcx>,
b: ty::TraitPredicate<'tcx>,
) -> RelateResult<'tcx, ty::TraitPredicate<'tcx>> {
Ok(ty::TraitPredicate { trait_ref: relation.relate(a.trait_ref, b.trait_ref)? })
Ok(ty::TraitPredicate {
trait_ref: relation.relate(a.trait_ref, b.trait_ref)?,
constness: relation.relate(a.constness, b.constness)?,
})
}
}

Expand Down
18 changes: 8 additions & 10 deletions compiler/rustc_middle/src/ty/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ impl fmt::Debug for ty::ParamConst {

impl fmt::Debug for ty::TraitPredicate<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let hir::Constness::Const = self.constness {
write!(f, "const ")?;
}
write!(f, "TraitPredicate({:?})", self.trait_ref)
}
}
Expand All @@ -174,12 +177,7 @@ impl fmt::Debug for ty::Predicate<'tcx> {
impl fmt::Debug for ty::PredicateKind<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
ty::PredicateKind::Trait(ref a, constness) => {
if let hir::Constness::Const = constness {
write!(f, "const ")?;
}
a.fmt(f)
}
ty::PredicateKind::Trait(ref a) => a.fmt(f),
ty::PredicateKind::Subtype(ref pair) => pair.fmt(f),
ty::PredicateKind::RegionOutlives(ref pair) => pair.fmt(f),
ty::PredicateKind::TypeOutlives(ref pair) => pair.fmt(f),
Expand Down Expand Up @@ -366,7 +364,8 @@ impl<'a, 'tcx> Lift<'tcx> for ty::ExistentialPredicate<'a> {
impl<'a, 'tcx> Lift<'tcx> for ty::TraitPredicate<'a> {
type Lifted = ty::TraitPredicate<'tcx>;
fn lift_to_tcx(self, tcx: TyCtxt<'tcx>) -> Option<ty::TraitPredicate<'tcx>> {
tcx.lift(self.trait_ref).map(|trait_ref| ty::TraitPredicate { trait_ref })
tcx.lift(self.trait_ref)
.map(|trait_ref| ty::TraitPredicate { trait_ref, constness: self.constness })
}
}

Expand Down Expand Up @@ -419,9 +418,7 @@ impl<'a, 'tcx> Lift<'tcx> for ty::PredicateKind<'a> {
type Lifted = ty::PredicateKind<'tcx>;
fn lift_to_tcx(self, tcx: TyCtxt<'tcx>) -> Option<Self::Lifted> {
match self {
ty::PredicateKind::Trait(data, constness) => {
tcx.lift(data).map(|data| ty::PredicateKind::Trait(data, constness))
}
ty::PredicateKind::Trait(data) => tcx.lift(data).map(ty::PredicateKind::Trait),
ty::PredicateKind::Subtype(data) => tcx.lift(data).map(ty::PredicateKind::Subtype),
ty::PredicateKind::RegionOutlives(data) => {
tcx.lift(data).map(ty::PredicateKind::RegionOutlives)
Expand Down Expand Up @@ -584,6 +581,7 @@ impl<'a, 'tcx> Lift<'tcx> for ty::error::TypeError<'a> {

Some(match self {
Mismatch => Mismatch,
ConstnessMismatch(x) => ConstnessMismatch(x),
UnsafetyMismatch(x) => UnsafetyMismatch(x),
AbiMismatch(x) => AbiMismatch(x),
Mutability => Mutability,
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,10 @@ impl<'tcx> PolyTraitRef<'tcx> {
}

pub fn to_poly_trait_predicate(&self) -> ty::PolyTraitPredicate<'tcx> {
self.map_bound(|trait_ref| ty::TraitPredicate { trait_ref })
self.map_bound(|trait_ref| ty::TraitPredicate {
trait_ref,
constness: hir::Constness::NotConst,
})
}
}

Expand Down
Loading

0 comments on commit 136eaa1

Please sign in to comment.