Skip to content

Commit

Permalink
Auto merge of #117749 - aliemjay:perf-canon-cache, r=lcnr
Browse files Browse the repository at this point in the history
cache param env canonicalization

Canonicalize ParamEnv only once and store it. Then whenever we try to canonicalize `ParamEnvAnd<'tcx, T>` we only have to canonicalize `T` and then merge the results.

Prelimiary results show ~3-4% savings in diesel and serde benchmarks.

Best to review commits individually. Some commits have a short description.

Initial implementation had a soundness bug (#117749 (comment)) due to cache invalidation:
- When canonicalizing `Ty<'?0>` we first try to resolve region variables in the current InferCtxt which may have a constraint `?0 == 'static`. This means that we register `Ty<'?0> => Canonical<Ty<'static>>` in the cache, which is obviously incorrect in another inference context.
- This is fixed by not doing region resolution when canonicalizing the query *input* (vs. response), which is the only place where ParamEnv is used, and then in a later commit we *statically* guard against any form of inference variable resolution of the cached canonical ParamEnv's.

r? `@ghost`
  • Loading branch information
bors committed Dec 14, 2023
2 parents e6d1b0e + aa36c35 commit d23e1a6
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 125 deletions.
236 changes: 122 additions & 114 deletions compiler/rustc_infer/src/infer/canonical/canonicalizer.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl<'tcx> InferCtxt<'tcx> {
// two const param's types are able to be equal has to go through a canonical query with the actual logic
// in `rustc_trait_selection`.
let canonical = self.canonicalize_query(
(relation.param_env(), a.ty(), b.ty()),
relation.param_env().and((a.ty(), b.ty())),
&mut OriginalQueryValues::default(),
);
self.tcx.check_tys_might_be_eq(canonical).map_err(|_| {
Expand Down
64 changes: 63 additions & 1 deletion compiler/rustc_middle/src/infer/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@
//!
//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html

use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::sync::Lock;
use rustc_macros::HashStable;
use rustc_type_ir::Canonical as IrCanonical;
use rustc_type_ir::CanonicalVarInfo as IrCanonicalVarInfo;
pub use rustc_type_ir::{CanonicalTyVarKind, CanonicalVarKind};
use smallvec::SmallVec;
use std::collections::hash_map::Entry;
use std::ops::Index;

use crate::infer::MemberConstraint;
use crate::mir::ConstraintCategory;
use crate::ty::GenericArg;
use crate::ty::{self, BoundVar, List, Region, Ty, TyCtxt};
use crate::ty::{self, BoundVar, List, Region, Ty, TyCtxt, TypeFlags, TypeVisitableExt};

pub type Canonical<'tcx, V> = IrCanonical<TyCtxt<'tcx>, V>;

Expand Down Expand Up @@ -291,3 +294,62 @@ impl<'tcx> Index<BoundVar> for CanonicalVarValues<'tcx> {
&self.var_values[value.as_usize()]
}
}

#[derive(Default)]
pub struct CanonicalParamEnvCache<'tcx> {
map: Lock<
FxHashMap<
ty::ParamEnv<'tcx>,
(Canonical<'tcx, ty::ParamEnv<'tcx>>, &'tcx [GenericArg<'tcx>]),
>,
>,
}

impl<'tcx> CanonicalParamEnvCache<'tcx> {
/// Gets the cached canonical form of `key` or executes
/// `canonicalize_op` and caches the result if not present.
///
/// `canonicalize_op` is intentionally not allowed to be a closure to
/// statically prevent it from capturing `InferCtxt` and resolving
/// inference variables, which invalidates the cache.
pub fn get_or_insert(
&self,
tcx: TyCtxt<'tcx>,
key: ty::ParamEnv<'tcx>,
state: &mut OriginalQueryValues<'tcx>,
canonicalize_op: fn(
TyCtxt<'tcx>,
ty::ParamEnv<'tcx>,
&mut OriginalQueryValues<'tcx>,
) -> Canonical<'tcx, ty::ParamEnv<'tcx>>,
) -> Canonical<'tcx, ty::ParamEnv<'tcx>> {
if !key.has_type_flags(
TypeFlags::HAS_INFER | TypeFlags::HAS_PLACEHOLDER | TypeFlags::HAS_FREE_REGIONS,
) {
return Canonical {
max_universe: ty::UniverseIndex::ROOT,
variables: List::empty(),
value: key,
};
}

assert_eq!(state.var_values.len(), 0);
assert_eq!(state.universe_map.len(), 1);
debug_assert_eq!(&*state.universe_map, &[ty::UniverseIndex::ROOT]);

match self.map.borrow().entry(key) {
Entry::Occupied(e) => {
let (canonical, var_values) = e.get();
state.var_values.extend_from_slice(var_values);
canonical.clone()
}
Entry::Vacant(e) => {
let canonical = canonicalize_op(tcx, key, state);
let OriginalQueryValues { var_values, universe_map } = state;
assert_eq!(universe_map.len(), 1);
e.insert((canonical.clone(), tcx.arena.alloc_slice(var_values)));
canonical
}
}
}
}
4 changes: 3 additions & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2177,7 +2177,9 @@ rustc_queries! {
/// Used in `super_combine_consts` to ICE if the type of the two consts are definitely not going to end up being
/// equal to eachother. This might return `Ok` even if the types are not equal, but will never return `Err` if
/// the types might be equal.
query check_tys_might_be_eq(arg: Canonical<'tcx, (ty::ParamEnv<'tcx>, Ty<'tcx>, Ty<'tcx>)>) -> Result<(), NoSolution> {
query check_tys_might_be_eq(
arg: Canonical<'tcx, ty::ParamEnvAnd<'tcx, (Ty<'tcx>, Ty<'tcx>)>>
) -> Result<(), NoSolution> {
desc { "check whether two const param are definitely not equal to eachother"}
}

Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod tls;

use crate::arena::Arena;
use crate::dep_graph::{DepGraph, DepKindStruct};
use crate::infer::canonical::{CanonicalVarInfo, CanonicalVarInfos};
use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
use crate::lint::struct_lint_level;
use crate::metadata::ModChild;
use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
Expand Down Expand Up @@ -653,6 +653,8 @@ pub struct GlobalCtxt<'tcx> {
pub new_solver_evaluation_cache: solve::EvaluationCache<'tcx>,
pub new_solver_coherence_evaluation_cache: solve::EvaluationCache<'tcx>,

pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,

/// Data layout specification for the current target.
pub data_layout: TargetDataLayout,

Expand Down Expand Up @@ -817,6 +819,7 @@ impl<'tcx> TyCtxt<'tcx> {
evaluation_cache: Default::default(),
new_solver_evaluation_cache: Default::default(),
new_solver_coherence_evaluation_cache: Default::default(),
canonical_param_env_cache: Default::default(),
data_layout,
alloc_map: Lock::new(interpret::AllocMap::new()),
}
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_trait_selection/src/traits/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_infer::infer::canonical::Canonical;
use rustc_infer::infer::{RegionResolutionError, TyCtxtInferExt};
use rustc_infer::traits::query::NoSolution;
use rustc_infer::{infer::outlives::env::OutlivesEnvironment, traits::FulfillmentError};
use rustc_middle::ty::{self, AdtDef, GenericArg, List, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
use rustc_middle::ty::{self, AdtDef, GenericArg, List, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::DUMMY_SP;

use super::outlives_bounds::InferCtxtExt;
Expand Down Expand Up @@ -209,10 +209,10 @@ pub fn all_fields_implement_trait<'tcx>(

pub fn check_tys_might_be_eq<'tcx>(
tcx: TyCtxt<'tcx>,
canonical: Canonical<'tcx, (ParamEnv<'tcx>, Ty<'tcx>, Ty<'tcx>)>,
canonical: Canonical<'tcx, ty::ParamEnvAnd<'tcx, (Ty<'tcx>, Ty<'tcx>)>>,
) -> Result<(), NoSolution> {
let (infcx, (param_env, ty_a, ty_b), _) =
tcx.infer_ctxt().build_with_canonical(DUMMY_SP, &canonical);
let (infcx, key, _) = tcx.infer_ctxt().build_with_canonical(DUMMY_SP, &canonical);
let (param_env, (ty_a, ty_b)) = key.into_parts();
let ocx = ObligationCtxt::new(&infcx);

let result = ocx.eq(&ObligationCause::dummy(), param_env, ty_a, ty_b);
Expand Down
3 changes: 1 addition & 2 deletions src/librustdoc/clean/blanket_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ pub(crate) struct BlanketImplFinder<'a, 'tcx> {
impl<'a, 'tcx> BlanketImplFinder<'a, 'tcx> {
pub(crate) fn get_blanket_impls(&mut self, item_def_id: DefId) -> Vec<Item> {
let cx = &mut self.cx;
let param_env = cx.tcx.param_env(item_def_id);
let ty = cx.tcx.type_of(item_def_id);

trace!("get_blanket_impls({ty:?})");
Expand All @@ -40,7 +39,7 @@ impl<'a, 'tcx> BlanketImplFinder<'a, 'tcx> {
let infcx = cx.tcx.infer_ctxt().build();
let args = infcx.fresh_args_for_item(DUMMY_SP, item_def_id);
let impl_ty = ty.instantiate(infcx.tcx, args);
let param_env = EarlyBinder::bind(param_env).instantiate(infcx.tcx, args);
let param_env = ty::ParamEnv::empty();

let impl_args = infcx.fresh_args_for_item(DUMMY_SP, impl_def_id);
let impl_trait_ref = trait_ref.instantiate(infcx.tcx, impl_args);
Expand Down
2 changes: 1 addition & 1 deletion src/librustdoc/clean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use rustc_middle::middle::resolve_bound_vars as rbv;
use rustc_middle::ty::fold::TypeFolder;
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::TypeVisitableExt;
use rustc_middle::ty::{self, AdtKind, EarlyBinder, Ty, TyCtxt};
use rustc_middle::ty::{self, AdtKind, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_span::hygiene::{AstPass, MacroKind};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
Expand Down

0 comments on commit d23e1a6

Please sign in to comment.