Skip to content

Commit

Permalink
Auto merge of rust-lang#73180 - matthewjasper:predicate-cache, r=niko…
Browse files Browse the repository at this point in the history
…matsakis

Cache flags and escaping vars for predicates

With predicates becoming interned (rust-lang/compiler-team#285) this is now possible and could be a perf win. It would become an even larger win once we have recursive predicates.

cc @lcnr @nikomatsakis

r? @ghost
  • Loading branch information
bors committed Jun 21, 2020
2 parents a8cf399 + 6e12272 commit 1a4e2b6
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 100 deletions.
25 changes: 23 additions & 2 deletions src/librustc_metadata/rmeta/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,36 @@ impl<'a, 'tcx> TyDecoder<'tcx> for DecodeContext<'a, 'tcx> {

let key = ty::CReaderCacheKey { cnum: self.cdata().cnum, pos: shorthand };

if let Some(&ty) = tcx.rcache.borrow().get(&key) {
if let Some(&ty) = tcx.ty_rcache.borrow().get(&key) {
return Ok(ty);
}

let ty = or_insert_with(self)?;
tcx.rcache.borrow_mut().insert(key, ty);
tcx.ty_rcache.borrow_mut().insert(key, ty);
Ok(ty)
}

fn cached_predicate_for_shorthand<F>(
&mut self,
shorthand: usize,
or_insert_with: F,
) -> Result<ty::Predicate<'tcx>, Self::Error>
where
F: FnOnce(&mut Self) -> Result<ty::Predicate<'tcx>, Self::Error>,
{
let tcx = self.tcx();

let key = ty::CReaderCacheKey { cnum: self.cdata().cnum, pos: shorthand };

if let Some(&pred) = tcx.pred_rcache.borrow().get(&key) {
return Ok(pred);
}

let pred = or_insert_with(self)?;
tcx.pred_rcache.borrow_mut().insert(key, pred);
Ok(pred)
}

fn with_position<F, R>(&mut self, pos: usize, f: F) -> R
where
F: FnOnce(&mut Self) -> R,
Expand Down
27 changes: 11 additions & 16 deletions src/librustc_metadata/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ where
}
}

impl<'b, 'tcx> SpecializedEncoder<ty::Predicate<'b>> for EncodeContext<'tcx> {
fn specialized_encode(&mut self, predicate: &ty::Predicate<'b>) -> Result<(), Self::Error> {
debug_assert!(self.tcx.lift(predicate).is_some());
let predicate =
unsafe { std::mem::transmute::<&ty::Predicate<'b>, &ty::Predicate<'tcx>>(predicate) };
ty_codec::encode_with_shorthand(self, predicate, |encoder| {
&mut encoder.predicate_shorthands
})
}
}

impl<'tcx> SpecializedEncoder<interpret::AllocId> for EncodeContext<'tcx> {
fn specialized_encode(&mut self, alloc_id: &interpret::AllocId) -> Result<(), Self::Error> {
use std::collections::hash_map::Entry;
Expand All @@ -256,22 +267,6 @@ impl<'tcx> SpecializedEncoder<interpret::AllocId> for EncodeContext<'tcx> {
}
}

impl<'a, 'b, 'tcx> SpecializedEncoder<&'a [(ty::Predicate<'b>, Span)]> for EncodeContext<'tcx> {
fn specialized_encode(
&mut self,
predicates: &&'a [(ty::Predicate<'b>, Span)],
) -> Result<(), Self::Error> {
debug_assert!(self.tcx.lift(*predicates).is_some());
let predicates = unsafe {
std::mem::transmute::<
&&'a [(ty::Predicate<'b>, Span)],
&&'tcx [(ty::Predicate<'tcx>, Span)],
>(predicates)
};
ty_codec::encode_spanned_predicates(self, &predicates, |ecx| &mut ecx.predicate_shorthands)
}
}

impl<'tcx> SpecializedEncoder<Fingerprint> for EncodeContext<'tcx> {
fn specialized_encode(&mut self, f: &Fingerprint) -> Result<(), Self::Error> {
f.encode_opaque(&mut self.opaque)
Expand Down
1 change: 1 addition & 0 deletions src/librustc_middle/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ macro_rules! arena_types {

// Interned types
[] tys: rustc_middle::ty::TyS<$tcx>, rustc_middle::ty::TyS<'_x>;
[] predicates: rustc_middle::ty::PredicateInner<$tcx>, rustc_middle::ty::PredicateInner<'_x>;

// HIR query types
[few] indexed_hir: rustc_middle::hir::map::IndexedHir<$tcx>, rustc_middle::hir::map::IndexedHir<'_x>;
Expand Down
82 changes: 48 additions & 34 deletions src/librustc_middle/ty/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::arena::ArenaAllocatable;
use crate::infer::canonical::{CanonicalVarInfo, CanonicalVarInfos};
use crate::mir::{self, interpret::Allocation};
use crate::ty::subst::SubstsRef;
use crate::ty::{self, List, ToPredicate, Ty, TyCtxt};
use crate::ty::{self, List, Ty, TyCtxt};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def_id::{CrateNum, DefId};
use rustc_serialize::{opaque, Decodable, Decoder, Encodable, Encoder};
Expand Down Expand Up @@ -95,23 +95,6 @@ where
Ok(())
}

pub fn encode_spanned_predicates<'tcx, E, C>(
encoder: &mut E,
predicates: &[(ty::Predicate<'tcx>, Span)],
cache: C,
) -> Result<(), E::Error>
where
E: TyEncoder,
C: for<'b> Fn(&'b mut E) -> &'b mut FxHashMap<ty::Predicate<'tcx>, usize>,
{
predicates.len().encode(encoder)?;
for (predicate, span) in predicates {
encode_with_shorthand(encoder, predicate, &cache)?;
span.encode(encoder)?;
}
Ok(())
}

pub trait TyDecoder<'tcx>: Decoder {
fn tcx(&self) -> TyCtxt<'tcx>;

Expand All @@ -127,6 +110,14 @@ pub trait TyDecoder<'tcx>: Decoder {
where
F: FnOnce(&mut Self) -> Result<Ty<'tcx>, Self::Error>;

fn cached_predicate_for_shorthand<F>(
&mut self,
shorthand: usize,
or_insert_with: F,
) -> Result<ty::Predicate<'tcx>, Self::Error>
where
F: FnOnce(&mut Self) -> Result<ty::Predicate<'tcx>, Self::Error>;

fn with_position<F, R>(&mut self, pos: usize, f: F) -> R
where
F: FnOnce(&mut Self) -> R;
Expand Down Expand Up @@ -188,6 +179,26 @@ where
}
}

#[inline]
pub fn decode_predicate<D>(decoder: &mut D) -> Result<ty::Predicate<'tcx>, D::Error>
where
D: TyDecoder<'tcx>,
{
// Handle shorthands first, if we have an usize > 0x80.
if decoder.positioned_at_shorthand() {
let pos = decoder.read_usize()?;
assert!(pos >= SHORTHAND_OFFSET);
let shorthand = pos - SHORTHAND_OFFSET;

decoder.cached_predicate_for_shorthand(shorthand, |decoder| {
decoder.with_position(shorthand, ty::Predicate::decode)
})
} else {
let tcx = decoder.tcx();
Ok(tcx.mk_predicate(ty::PredicateKind::decode(decoder)?))
}
}

#[inline]
pub fn decode_spanned_predicates<D>(
decoder: &mut D,
Expand All @@ -198,20 +209,7 @@ where
let tcx = decoder.tcx();
Ok(tcx.arena.alloc_from_iter(
(0..decoder.read_usize()?)
.map(|_| {
// Handle shorthands first, if we have an usize > 0x80.
let predicate_kind = if decoder.positioned_at_shorthand() {
let pos = decoder.read_usize()?;
assert!(pos >= SHORTHAND_OFFSET);
let shorthand = pos - SHORTHAND_OFFSET;

decoder.with_position(shorthand, ty::PredicateKind::decode)
} else {
ty::PredicateKind::decode(decoder)
}?;
let predicate = predicate_kind.to_predicate(tcx);
Ok((predicate, Decodable::decode(decoder)?))
})
.map(|_| Decodable::decode(decoder))
.collect::<Result<Vec<_>, _>>()?,
))
}
Expand Down Expand Up @@ -421,7 +419,6 @@ macro_rules! implement_ty_decoder {
// FIXME(#36588): These impls are horribly unsound as they allow
// the caller to pick any lifetime for `'tcx`, including `'static`.

rustc_hir::arena_types!(impl_arena_allocatable_decoders, [$DecoderName [$($typaram),*]], 'tcx);
arena_types!(impl_arena_allocatable_decoders, [$DecoderName [$($typaram),*]], 'tcx);

impl<$($typaram),*> SpecializedDecoder<CrateNum>
Expand All @@ -436,7 +433,24 @@ macro_rules! implement_ty_decoder {
where &'_x ty::TyS<'_y>: UseSpecializedDecodable
{
fn specialized_decode(&mut self) -> Result<&'_x ty::TyS<'_y>, Self::Error> {
unsafe { transmute::<Result<ty::Ty<'tcx>, Self::Error>, Result<&'_x ty::TyS<'_y>, Self::Error>>(decode_ty(self)) }
unsafe {
transmute::<
Result<ty::Ty<'tcx>, Self::Error>,
Result<&'_x ty::TyS<'_y>, Self::Error>,
>(decode_ty(self))
}
}
}

impl<'_x, $($typaram),*> SpecializedDecoder<ty::Predicate<'_x>>
for $DecoderName<$($typaram),*> {
fn specialized_decode(&mut self) -> Result<ty::Predicate<'_x>, Self::Error> {
unsafe {
transmute::<
Result<ty::Predicate<'tcx>, Self::Error>,
Result<ty::Predicate<'_x>, Self::Error>,
>(decode_predicate(self))
}
}
}

Expand Down
65 changes: 52 additions & 13 deletions src/librustc_middle/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ use crate::ty::TyKind::*;
use crate::ty::{
self, query, AdtDef, AdtKind, BindingMode, BoundVar, CanonicalPolyFnSig, Const, ConstVid,
DefIdTree, ExistentialPredicate, FloatVar, FloatVid, GenericParamDefKind, InferConst, InferTy,
IntVar, IntVid, List, ParamConst, ParamTy, PolyFnSig, Predicate, PredicateKind, ProjectionTy,
Region, RegionKind, ReprOptions, TraitObjectVisitor, Ty, TyKind, TyS, TyVar, TyVid, TypeAndMut,
IntVar, IntVid, List, ParamConst, ParamTy, PolyFnSig, Predicate, PredicateInner, PredicateKind,
ProjectionTy, Region, RegionKind, ReprOptions, TraitObjectVisitor, Ty, TyKind, TyS, TyVar,
TyVid, TypeAndMut,
};
use rustc_ast::ast;
use rustc_ast::expand::allocator::AllocatorKind;
Expand Down Expand Up @@ -76,7 +77,7 @@ pub struct CtxtInterners<'tcx> {
canonical_var_infos: InternedSet<'tcx, List<CanonicalVarInfo>>,
region: InternedSet<'tcx, RegionKind>,
existential_predicates: InternedSet<'tcx, List<ExistentialPredicate<'tcx>>>,
predicate_kind: InternedSet<'tcx, PredicateKind<'tcx>>,
predicate: InternedSet<'tcx, PredicateInner<'tcx>>,
predicates: InternedSet<'tcx, List<Predicate<'tcx>>>,
projs: InternedSet<'tcx, List<ProjectionKind>>,
place_elems: InternedSet<'tcx, List<PlaceElem<'tcx>>>,
Expand All @@ -95,7 +96,7 @@ impl<'tcx> CtxtInterners<'tcx> {
region: Default::default(),
existential_predicates: Default::default(),
canonical_var_infos: Default::default(),
predicate_kind: Default::default(),
predicate: Default::default(),
predicates: Default::default(),
projs: Default::default(),
place_elems: Default::default(),
Expand Down Expand Up @@ -123,6 +124,23 @@ impl<'tcx> CtxtInterners<'tcx> {
})
.0
}

#[inline(never)]
fn intern_predicate(&self, kind: PredicateKind<'tcx>) -> &'tcx PredicateInner<'tcx> {
self.predicate
.intern(kind, |kind| {
let flags = super::flags::FlagComputation::for_predicate(&kind);

let predicate_struct = PredicateInner {
kind,
flags: flags.flags,
outer_exclusive_binder: flags.outer_exclusive_binder,
};

Interned(self.arena.alloc(predicate_struct))
})
.0
}
}

pub struct CommonTypes<'tcx> {
Expand Down Expand Up @@ -938,8 +956,9 @@ pub struct GlobalCtxt<'tcx> {
/// via `extern crate` item and not `--extern` option or compiler built-in.
pub extern_prelude: FxHashMap<Symbol, bool>,

// Internal cache for metadata decoding. No need to track deps on this.
pub rcache: Lock<FxHashMap<ty::CReaderCacheKey, Ty<'tcx>>>,
// Internal caches for metadata decoding. No need to track deps on this.
pub ty_rcache: Lock<FxHashMap<ty::CReaderCacheKey, Ty<'tcx>>>,
pub pred_rcache: Lock<FxHashMap<ty::CReaderCacheKey, Predicate<'tcx>>>,

/// Caches the results of trait selection. This cache is used
/// for things that do not have to do with the parameters in scope.
Expand Down Expand Up @@ -1128,7 +1147,8 @@ impl<'tcx> TyCtxt<'tcx> {
definitions,
def_path_hash_to_def_id,
queries: query::Queries::new(providers, extern_providers, on_disk_query_result_cache),
rcache: Default::default(),
ty_rcache: Default::default(),
pred_rcache: Default::default(),
selection_cache: Default::default(),
evaluation_cache: Default::default(),
crate_name: Symbol::intern(crate_name),
Expand Down Expand Up @@ -1625,7 +1645,7 @@ macro_rules! nop_list_lift {
nop_lift! {type_; Ty<'a> => Ty<'tcx>}
nop_lift! {region; Region<'a> => Region<'tcx>}
nop_lift! {const_; &'a Const<'a> => &'tcx Const<'tcx>}
nop_lift! {predicate_kind; &'a PredicateKind<'a> => &'tcx PredicateKind<'tcx>}
nop_lift! {predicate; &'a PredicateInner<'a> => &'tcx PredicateInner<'tcx>}

nop_list_lift! {type_list; Ty<'a> => Ty<'tcx>}
nop_list_lift! {existential_predicates; ExistentialPredicate<'a> => ExistentialPredicate<'tcx>}
Expand Down Expand Up @@ -1984,6 +2004,26 @@ impl<'tcx> Borrow<TyKind<'tcx>> for Interned<'tcx, TyS<'tcx>> {
&self.0.kind
}
}
// N.B., an `Interned<PredicateInner>` compares and hashes as a `PredicateKind`.
impl<'tcx> PartialEq for Interned<'tcx, PredicateInner<'tcx>> {
fn eq(&self, other: &Interned<'tcx, PredicateInner<'tcx>>) -> bool {
self.0.kind == other.0.kind
}
}

impl<'tcx> Eq for Interned<'tcx, PredicateInner<'tcx>> {}

impl<'tcx> Hash for Interned<'tcx, PredicateInner<'tcx>> {
fn hash<H: Hasher>(&self, s: &mut H) {
self.0.kind.hash(s)
}
}

impl<'tcx> Borrow<PredicateKind<'tcx>> for Interned<'tcx, PredicateInner<'tcx>> {
fn borrow<'a>(&'a self) -> &'a PredicateKind<'tcx> {
&self.0.kind
}
}

// N.B., an `Interned<List<T>>` compares and hashes as its elements.
impl<'tcx, T: PartialEq> PartialEq for Interned<'tcx, List<T>> {
Expand Down Expand Up @@ -2050,11 +2090,10 @@ macro_rules! direct_interners {
}
}

direct_interners!(
direct_interners! {
region: mk_region(RegionKind),
const_: mk_const(Const<'tcx>),
predicate_kind: intern_predicate_kind(PredicateKind<'tcx>),
);
}

macro_rules! slice_interners {
($($field:ident: $method:ident($ty:ty)),+) => (
Expand Down Expand Up @@ -2125,8 +2164,8 @@ impl<'tcx> TyCtxt<'tcx> {

#[inline]
pub fn mk_predicate(&self, kind: PredicateKind<'tcx>) -> Predicate<'tcx> {
let kind = self.intern_predicate_kind(kind);
Predicate { kind }
let inner = self.interners.intern_predicate(kind);
Predicate { inner }
}

pub fn mk_mach_int(self, tm: ast::IntTy) -> Ty<'tcx> {
Expand Down
Loading

0 comments on commit 1a4e2b6

Please sign in to comment.