Skip to content

Commit

Permalink
Make SearchGraph fully generic
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Jun 18, 2024
1 parent af3d100 commit dba4147
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 95 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/traits/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::ty::{

mod cache;

pub use cache::{CacheData, EvaluationCache};
pub use cache::EvaluationCache;

pub type Goal<'tcx, P> = ir::solve::Goal<TyCtxt<'tcx>, P>;
pub type QueryInput<'tcx, P> = ir::solve::QueryInput<TyCtxt<'tcx>, P>;
Expand Down
28 changes: 11 additions & 17 deletions compiler/rustc_middle/src/traits/solve/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock;
use rustc_query_system::cache::WithDepNode;
use rustc_query_system::dep_graph::DepNodeIndex;
use rustc_session::Limit;
use rustc_type_ir::solve::CacheData;

/// The trait solver cache used by `-Znext-solver`.
///
/// FIXME(@lcnr): link to some official documentation of how
Expand All @@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> {
map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
}

#[derive(Debug, PartialEq, Eq)]
pub struct CacheData<'tcx> {
pub result: QueryResult<'tcx>,
pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>>,
pub additional_depth: usize,
pub encountered_overflow: bool,
}

impl<'tcx> EvaluationCache<'tcx> {
impl<'tcx> rustc_type_ir::inherent::EvaluationCache<TyCtxt<'tcx>> for &'tcx EvaluationCache<'tcx> {
/// Insert a final result into the global cache.
pub fn insert(
fn insert(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
Expand All @@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> {
if cfg!(debug_assertions) {
drop(map);
let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow };
let actual = self.get(tcx, key, [], Limit(additional_depth));
let actual = self.get(tcx, key, [], additional_depth);
if !actual.as_ref().is_some_and(|actual| expected == *actual) {
bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}");
}
Expand All @@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> {
/// and handling root goals of coinductive cycles.
///
/// If this returns `Some` the cache result can be used.
pub fn get(
fn get(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>,
available_depth: Limit,
) -> Option<CacheData<'tcx>> {
available_depth: usize,
) -> Option<CacheData<TyCtxt<'tcx>>> {
let map = self.map.borrow();
let entry = map.get(&key)?;

Expand All @@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> {
}

if let Some(ref success) = entry.success {
if available_depth.value_within_limit(success.additional_depth) {
if Limit(available_depth).value_within_limit(success.additional_depth) {
let QueryData { result, proof_tree } = success.data.get(tcx);
return Some(CacheData {
result,
Expand All @@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> {
}
}

entry.with_overflow.get(&available_depth.0).map(|e| {
entry.with_overflow.get(&available_depth).map(|e| {
let QueryData { result, proof_tree } = e.get(tcx);
CacheData {
result,
proof_tree,
additional_depth: available_depth.0,
additional_depth: available_depth,
encountered_overflow: true,
}
})
Expand Down
21 changes: 21 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ use rustc_target::abi::{FieldIdx, Layout, LayoutS, TargetDataLayout, VariantIdx}
use rustc_target::spec::abi;
use rustc_type_ir::fold::TypeFoldable;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::solve::SolverMode;
use rustc_type_ir::TyKind::*;
use rustc_type_ir::{CollectAndApply, Interner, TypeFlags, WithCachedTypeInfo};
use tracing::{debug, instrument};
Expand Down Expand Up @@ -139,10 +140,30 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
type Clause = Clause<'tcx>;
type Clauses = ty::Clauses<'tcx>;

type DepNodeIndex = DepNodeIndex;
fn with_cached_task<T>(self, task: impl FnOnce() -> T) -> (T, DepNodeIndex) {
self.dep_graph.with_anon_task(self, crate::dep_graph::dep_kinds::TraitSelect, task)
}

type EvaluationCache = &'tcx solve::EvaluationCache<'tcx>;
fn evaluation_cache(self, mode: SolverMode) -> &'tcx solve::EvaluationCache<'tcx> {
match mode {
SolverMode::Normal => &self.new_solver_evaluation_cache,
SolverMode::Coherence => &self.new_solver_coherence_evaluation_cache,
}
}

fn expand_abstract_consts<T: TypeFoldable<TyCtxt<'tcx>>>(self, t: T) -> T {
self.expand_abstract_consts(t)
}

fn mk_external_constraints(
self,
data: ExternalConstraintsData<Self>,
) -> ExternalConstraints<'tcx> {
self.mk_external_constraints(data)
}

fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
self.mk_canonical_var_infos(infos)
}
Expand Down
30 changes: 9 additions & 21 deletions compiler/rustc_trait_selection/src/solve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
//! FIXME(@lcnr): Write that section. If you read this before then ask me
//! about it on zulip.
use rustc_hir::def_id::DefId;
use rustc_infer::infer::canonical::{Canonical, CanonicalVarValues};
use rustc_infer::infer::canonical::Canonical;
use rustc_infer::infer::InferCtxt;
use rustc_infer::traits::query::NoSolution;
use rustc_macros::extension;
use rustc_middle::bug;
use rustc_middle::infer::canonical::CanonicalVarInfos;
use rustc_middle::traits::solve::{
CanonicalResponse, Certainty, ExternalConstraintsData, Goal, GoalSource, QueryResult, Response,
};
use rustc_middle::ty::{
self, AliasRelationDirection, CoercePredicate, RegionOutlivesPredicate, SubtypePredicate, Ty,
TyCtxt, TypeOutlivesPredicate, UniverseIndex,
};
use rustc_type_ir::solve::SolverMode;
use rustc_type_ir::{self as ir, Interner};

mod alias_relate;
mod assembly;
Expand Down Expand Up @@ -57,19 +58,6 @@ pub use select::InferCtxtSelectExt;
/// recursion limit again. However, this feels very unlikely.
const FIXPOINT_STEP_LIMIT: usize = 8;

#[derive(Debug, Clone, Copy)]
enum SolverMode {
/// Ordinary trait solving, using everywhere except for coherence.
Normal,
/// Trait solving during coherence. There are a few notable differences
/// between coherence and ordinary trait solving.
///
/// Most importantly, trait solving during coherence must not be incomplete,
/// i.e. return `Err(NoSolution)` for goals for which a solution exists.
/// This means that we must not make any guesses or arbitrary choices.
Coherence,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum GoalEvaluationKind {
Root,
Expand Down Expand Up @@ -314,17 +302,17 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
}
}

fn response_no_constraints_raw<'tcx>(
tcx: TyCtxt<'tcx>,
fn response_no_constraints_raw<I: Interner>(
tcx: I,
max_universe: UniverseIndex,
variables: CanonicalVarInfos<'tcx>,
variables: I::CanonicalVars,
certainty: Certainty,
) -> CanonicalResponse<'tcx> {
Canonical {
) -> ir::solve::CanonicalResponse<I> {
ir::Canonical {
max_universe,
variables,
value: Response {
var_values: CanonicalVarValues::make_identity(tcx, variables),
var_values: ir::CanonicalVarValues::make_identity(tcx, variables),
// FIXME: maybe we should store the "no response" version in tcx, like
// we do for tcx.types and stuff.
external_constraints: tcx.mk_external_constraints(ExternalConstraintsData::default()),
Expand Down
92 changes: 40 additions & 52 deletions compiler/rustc_trait_selection/src/solve/search_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ use std::mem;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_index::Idx;
use rustc_index::IndexVec;
use rustc_infer::infer::InferCtxt;
use rustc_middle::dep_graph::dep_kinds;
use rustc_middle::traits::solve::CacheData;
use rustc_middle::traits::solve::EvaluationCache;
use rustc_middle::ty::TyCtxt;
use rustc_next_trait_solver::solve::CacheData;
use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult};
use rustc_session::Limit;
use rustc_type_ir::inherent::*;
use rustc_type_ir::InferCtxtLike;
use rustc_type_ir::Interner;

use super::inspect;
Expand Down Expand Up @@ -240,34 +237,26 @@ impl<I: Interner> SearchGraph<I> {
!entry.is_empty()
});
}
}

impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
/// The trait solver behavior is different for coherence
/// so we use a separate cache. Alternatively we could use
/// a single cache and share it between coherence and ordinary
/// trait solving.
pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
match self.mode {
SolverMode::Normal => &tcx.new_solver_evaluation_cache,
SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
}
pub(super) fn global_cache(&self, tcx: I) -> I::EvaluationCache {
tcx.evaluation_cache(self.mode)
}

/// Probably the most involved method of the whole solver.
///
/// Given some goal which is proven via the `prove_goal` closure, this
/// handles caching, overflow, and coinductive cycles.
pub(super) fn with_new_goal(
pub(super) fn with_new_goal<Infcx: InferCtxtLike<Interner = I>>(
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<TyCtxt<'tcx>>,
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
mut prove_goal: impl FnMut(
&mut Self,
&mut ProofTreeBuilder<InferCtxt<'tcx>>,
) -> QueryResult<TyCtxt<'tcx>>,
) -> QueryResult<TyCtxt<'tcx>> {
tcx: I,
input: CanonicalInput<I>,
inspect: &mut ProofTreeBuilder<Infcx>,
mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>,
) -> QueryResult<I> {
self.check_invariants();
// Check for overflow.
let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
Expand Down Expand Up @@ -361,21 +350,20 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
// not tracked by the cache key and from outside of this anon task, it
// must not be added to the global cache. Notably, this is the case for
// trait solver cycles participants.
let ((final_entry, result), dep_node) =
tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || {
for _ in 0..FIXPOINT_STEP_LIMIT {
match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) {
StepResult::Done(final_entry, result) => return (final_entry, result),
StepResult::HasChanged => debug!("fixpoint changed provisional results"),
}
let ((final_entry, result), dep_node) = tcx.with_cached_task(|| {
for _ in 0..FIXPOINT_STEP_LIMIT {
match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) {
StepResult::Done(final_entry, result) => return (final_entry, result),
StepResult::HasChanged => debug!("fixpoint changed provisional results"),
}
}

debug!("canonical cycle overflow");
let current_entry = self.pop_stack();
debug_assert!(current_entry.has_been_used.is_empty());
let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false));
(current_entry, result)
});
debug!("canonical cycle overflow");
let current_entry = self.pop_stack();
debug_assert!(current_entry.has_been_used.is_empty());
let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false));
(current_entry, result)
});

let proof_tree = inspect.finalize_canonical_goal_evaluation(tcx);

Expand Down Expand Up @@ -423,16 +411,17 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
/// Try to fetch a previously computed result from the global cache,
/// making sure to only do so if it would match the result of reevaluating
/// this goal.
fn lookup_global_cache(
fn lookup_global_cache<Infcx: InferCtxtLike<Interner = I>>(
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<TyCtxt<'tcx>>,
tcx: I,
input: CanonicalInput<I>,
available_depth: Limit,
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
) -> Option<QueryResult<TyCtxt<'tcx>>> {
inspect: &mut ProofTreeBuilder<Infcx>,
) -> Option<QueryResult<I>> {
let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self
.global_cache(tcx)
.get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?;
// TODO: Awkward `Limit -> usize -> Limit`.
.get(tcx, input, self.stack.iter().map(|e| e.input), available_depth.0)?;

// If we're building a proof tree and the current cache entry does not
// contain a proof tree, we do not use the entry but instead recompute
Expand Down Expand Up @@ -465,21 +454,22 @@ enum StepResult<I: Interner> {
HasChanged,
}

impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
impl<I: Interner> SearchGraph<I> {
/// When we encounter a coinductive cycle, we have to fetch the
/// result of that cycle while we are still computing it. Because
/// of this we continuously recompute the cycle until the result
/// of the previous iteration is equal to the final result, at which
/// point we are done.
fn fixpoint_step_in_task<F>(
fn fixpoint_step_in_task<Infcx, F>(
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<TyCtxt<'tcx>>,
inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>,
tcx: I,
input: CanonicalInput<I>,
inspect: &mut ProofTreeBuilder<Infcx>,
prove_goal: &mut F,
) -> StepResult<TyCtxt<'tcx>>
) -> StepResult<I>
where
F: FnMut(&mut Self, &mut ProofTreeBuilder<InferCtxt<'tcx>>) -> QueryResult<TyCtxt<'tcx>>,
Infcx: InferCtxtLike<Interner = I>,
F: FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>,
{
let result = prove_goal(self, inspect);
let stack_entry = self.pop_stack();
Expand Down Expand Up @@ -533,15 +523,13 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
}

fn response_no_constraints(
tcx: TyCtxt<'tcx>,
goal: CanonicalInput<TyCtxt<'tcx>>,
tcx: I,
goal: CanonicalInput<I>,
certainty: Certainty,
) -> QueryResult<TyCtxt<'tcx>> {
) -> QueryResult<I> {
Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty))
}
}

impl<I: Interner> SearchGraph<I> {
#[allow(rustc::potential_query_instability)]
fn check_invariants(&self) {
if !cfg!(debug_assertions) {
Expand Down
Loading

0 comments on commit dba4147

Please sign in to comment.