Skip to content

Commit

Permalink
Auto merge of rust-lang#117749 - aliemjay:perf-canon-cache, r=<try>
Browse files Browse the repository at this point in the history
[perf] canonicalize param env only once

Canonicalize ParamEnv only once and store it. Then whenever we try to canonicalize `ParamEnvAnd<'tcx, T>` we only have to canonicalize `T` and then merge the results.

Prelimiary results show ~3-4% savings in diesel and serde benchmarks.

r? `@ghost`
  • Loading branch information
bors committed Dec 4, 2023
2 parents 0e2dac8 + bcbfbe6 commit 3d74566
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 13 deletions.
110 changes: 104 additions & 6 deletions compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,35 @@ impl<'tcx> InferCtxt<'tcx> {
/// [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html#canonicalizing-the-query
pub fn canonicalize_query<V>(
&self,
value: V,
value: ty::ParamEnvAnd<'tcx, V>,
query_state: &mut OriginalQueryValues<'tcx>,
) -> Canonical<'tcx, V>
) -> Canonical<'tcx, ty::ParamEnvAnd<'tcx, V>>
where
V: TypeFoldable<TyCtxt<'tcx>>,
{
Canonicalizer::canonicalize(value, self, self.tcx, &CanonicalizeAllFreeRegions, query_state)
let (param_env, value) = value.into_parts();
let base = self.tcx.canonical_param_env_cache.get_or_insert(
param_env,
query_state,
|query_state| {
Canonicalizer::canonicalize(
param_env,
self,
self.tcx,
&CanonicalizeFreeRegionsOtherThanStatic,
query_state,
)
},
);
Canonicalizer::canonicalize_continue(
base,
value,
self,
self.tcx,
&CanonicalizeAllFreeRegions,
query_state,
)
.unchecked_map(|(param_env, value)| param_env.and(value))
}

/// Like [Self::canonicalize_query], but preserves distinct universes. For
Expand Down Expand Up @@ -126,19 +148,35 @@ impl<'tcx> InferCtxt<'tcx> {
/// handling of `'static` regions (e.g. trait evaluation).
pub fn canonicalize_query_keep_static<V>(
&self,
value: V,
value: ty::ParamEnvAnd<'tcx, V>,
query_state: &mut OriginalQueryValues<'tcx>,
) -> Canonical<'tcx, V>
) -> Canonical<'tcx, ty::ParamEnvAnd<'tcx, V>>
where
V: TypeFoldable<TyCtxt<'tcx>>,
{
Canonicalizer::canonicalize(
let (param_env, value) = value.into_parts();
let base = self.tcx.canonical_param_env_cache.get_or_insert(
param_env,
query_state,
|query_state| {
Canonicalizer::canonicalize(
param_env,
self,
self.tcx,
&CanonicalizeFreeRegionsOtherThanStatic,
query_state,
)
},
);
Canonicalizer::canonicalize_continue(
base,
value,
self,
self.tcx,
&CanonicalizeFreeRegionsOtherThanStatic,
query_state,
)
.unchecked_map(|(param_env, value)| param_env.and(value))
}
}

Expand Down Expand Up @@ -615,6 +653,66 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
Canonical { max_universe, variables: canonical_variables, value: out_value }
}

fn canonicalize_continue<U, V>(
base: Canonical<'tcx, U>,
value: V,
infcx: &InferCtxt<'tcx>,
tcx: TyCtxt<'tcx>,
canonicalize_region_mode: &dyn CanonicalizeMode,
query_state: &mut OriginalQueryValues<'tcx>,
) -> Canonical<'tcx, (U, V)>
where
V: TypeFoldable<TyCtxt<'tcx>>,
{
let needs_canonical_flags = if canonicalize_region_mode.any() {
TypeFlags::HAS_INFER | TypeFlags::HAS_PLACEHOLDER | TypeFlags::HAS_FREE_REGIONS
} else {
TypeFlags::HAS_INFER | TypeFlags::HAS_PLACEHOLDER
};

// Fast path: nothing that needs to be canonicalized.
if !value.has_type_flags(needs_canonical_flags) {
return base.unchecked_map(|b| (b, value));
}

let mut canonicalizer = Canonicalizer {
infcx,
tcx,
canonicalize_mode: canonicalize_region_mode,
needs_canonical_flags,
variables: SmallVec::from_slice(base.variables),
query_state,
indices: FxHashMap::default(),
binder_index: ty::INNERMOST,
};
if canonicalizer.query_state.var_values.spilled() {
canonicalizer.indices = canonicalizer
.query_state
.var_values
.iter()
.enumerate()
.map(|(i, &kind)| (kind, BoundVar::new(i)))
.collect();
}
let out_value = value.fold_with(&mut canonicalizer);

// Once we have canonicalized `out_value`, it should not
// contain anything that ties it to this inference context
// anymore.
debug_assert!(!out_value.has_infer() && !out_value.has_placeholders());

let canonical_variables =
tcx.mk_canonical_var_infos(&canonicalizer.universe_canonicalized_variables());

let max_universe = canonical_variables
.iter()
.map(|cvar| cvar.universe())
.max()
.unwrap_or(ty::UniverseIndex::ROOT);

Canonical { max_universe, variables: canonical_variables, value: (base.value, out_value) }
}

/// Creates a canonical variable replacing `kind` from the input,
/// or returns an existing variable if `kind` has already been
/// seen. `kind` is expected to be an unbound variable (or
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl<'tcx> InferCtxt<'tcx> {
// two const param's types are able to be equal has to go through a canonical query with the actual logic
// in `rustc_trait_selection`.
let canonical = self.canonicalize_query(
(relation.param_env(), a.ty(), b.ty()),
relation.param_env().and((a.ty(), b.ty())),
&mut OriginalQueryValues::default(),
);
self.tcx.check_tys_might_be_eq(canonical).map_err(|_| {
Expand Down
36 changes: 36 additions & 0 deletions compiler/rustc_middle/src/infer/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,39 @@ impl<'tcx> Index<BoundVar> for CanonicalVarValues<'tcx> {
&self.var_values[value.as_usize()]
}
}

use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::sync::Lock;

#[derive(Default)]
pub struct CanonicalParamEnvCache<'tcx> {
map: Lock<
FxHashMap<
ty::ParamEnv<'tcx>,
(Canonical<'tcx, ty::ParamEnv<'tcx>>, OriginalQueryValues<'tcx>),
>,
>,
}

impl<'tcx> CanonicalParamEnvCache<'tcx> {
pub fn get_or_insert(
&self,
key: ty::ParamEnv<'tcx>,
vals: &mut OriginalQueryValues<'tcx>,
f: impl FnOnce(&mut OriginalQueryValues<'tcx>) -> Canonical<'tcx, ty::ParamEnv<'tcx>>,
) -> Canonical<'tcx, ty::ParamEnv<'tcx>> {
use std::collections::hash_map::Entry;
match self.map.borrow().entry(key) {
Entry::Occupied(e) => {
let (canonical, vals2) = e.get();
vals.clone_from(vals2);
canonical.clone()
}
Entry::Vacant(e) => {
let canonical = f(vals);
e.insert((canonical.clone(), vals.clone()));
canonical
}
}
}
}
4 changes: 3 additions & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2177,7 +2177,9 @@ rustc_queries! {
/// Used in `super_combine_consts` to ICE if the type of the two consts are definitely not going to end up being
/// equal to eachother. This might return `Ok` even if the types are not equal, but will never return `Err` if
/// the types might be equal.
query check_tys_might_be_eq(arg: Canonical<'tcx, (ty::ParamEnv<'tcx>, Ty<'tcx>, Ty<'tcx>)>) -> Result<(), NoSolution> {
query check_tys_might_be_eq(
arg: Canonical<'tcx, ty::ParamEnvAnd<'tcx, (Ty<'tcx>, Ty<'tcx>)>>
) -> Result<(), NoSolution> {
desc { "check whether two const param are definitely not equal to eachother"}
}

Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod tls;

use crate::arena::Arena;
use crate::dep_graph::{DepGraph, DepKindStruct};
use crate::infer::canonical::{CanonicalVarInfo, CanonicalVarInfos};
use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
use crate::lint::struct_lint_level;
use crate::metadata::ModChild;
use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
Expand Down Expand Up @@ -615,6 +615,8 @@ pub struct GlobalCtxt<'tcx> {
pub new_solver_evaluation_cache: solve::EvaluationCache<'tcx>,
pub new_solver_coherence_evaluation_cache: solve::EvaluationCache<'tcx>,

pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,

/// Data layout specification for the current target.
pub data_layout: TargetDataLayout,

Expand Down Expand Up @@ -779,6 +781,7 @@ impl<'tcx> TyCtxt<'tcx> {
evaluation_cache: Default::default(),
new_solver_evaluation_cache: Default::default(),
new_solver_coherence_evaluation_cache: Default::default(),
canonical_param_env_cache: Default::default(),
data_layout,
alloc_map: Lock::new(interpret::AllocMap::new()),
}
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_trait_selection/src/traits/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_infer::infer::canonical::Canonical;
use rustc_infer::infer::{RegionResolutionError, TyCtxtInferExt};
use rustc_infer::traits::query::NoSolution;
use rustc_infer::{infer::outlives::env::OutlivesEnvironment, traits::FulfillmentError};
use rustc_middle::ty::{self, AdtDef, GenericArg, List, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
use rustc_middle::ty::{self, AdtDef, GenericArg, List, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::DUMMY_SP;

use super::outlives_bounds::InferCtxtExt;
Expand Down Expand Up @@ -209,10 +209,10 @@ pub fn all_fields_implement_trait<'tcx>(

pub fn check_tys_might_be_eq<'tcx>(
tcx: TyCtxt<'tcx>,
canonical: Canonical<'tcx, (ParamEnv<'tcx>, Ty<'tcx>, Ty<'tcx>)>,
canonical: Canonical<'tcx, ty::ParamEnvAnd<'tcx, (Ty<'tcx>, Ty<'tcx>)>>,
) -> Result<(), NoSolution> {
let (infcx, (param_env, ty_a, ty_b), _) =
tcx.infer_ctxt().build_with_canonical(DUMMY_SP, &canonical);
let (infcx, key, _) = tcx.infer_ctxt().build_with_canonical(DUMMY_SP, &canonical);
let (param_env, (ty_a, ty_b)) = key.into_parts();
let ocx = ObligationCtxt::new(&infcx);

let result = ocx.eq(&ObligationCause::dummy(), param_env, ty_a, ty_b);
Expand Down

0 comments on commit 3d74566

Please sign in to comment.