Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harden the pre-tyctxt query system against accidental recomputation #105603

Merged
merged 1 commit into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions compiler/rustc_data_structures/src/steal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ impl<T> Steal<T> {
ReadGuard::map(borrow, |opt| opt.as_ref().unwrap())
}

#[track_caller]
pub fn get_mut(&mut self) -> &mut T {
self.value.get_mut().as_mut().expect("attempt to read from stolen value")
}

#[track_caller]
pub fn steal(&self) -> T {
let value_ref = &mut *self.value.try_write().expect("stealing value which is locked");
Expand Down
11 changes: 6 additions & 5 deletions compiler/rustc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ fn run_compiler(

if let Some(ppm) = &sess.opts.pretty {
if ppm.needs_ast_map() {
let expanded_crate = queries.expansion()?.peek().0.clone();
queries.global_ctxt()?.peek_mut().enter(|tcx| {
let expanded_crate = queries.expansion()?.borrow().0.clone();
queries.global_ctxt()?.enter(|tcx| {
pretty::print_after_hir_lowering(
tcx,
compiler.input(),
Expand All @@ -321,7 +321,7 @@ fn run_compiler(
Ok(())
})?;
} else {
let krate = queries.parse()?.take();
let krate = queries.parse()?.steal();
pretty::print_after_parsing(
sess,
compiler.input(),
Expand All @@ -343,7 +343,8 @@ fn run_compiler(
}

{
let (_, lint_store) = &*queries.register_plugins()?.peek();
let plugins = queries.register_plugins()?;
let (_, lint_store) = &*plugins.borrow();

// Lint plugins are registered; now we can process command line flags.
if sess.opts.describe_lints {
Expand Down Expand Up @@ -371,7 +372,7 @@ fn run_compiler(
return early_exit();
}

queries.global_ctxt()?.peek_mut().enter(|tcx| {
queries.global_ctxt()?.enter(|tcx| {
let result = tcx.analysis(());
if sess.opts.unstable_opts.save_analysis {
let crate_name = tcx.crate_name(LOCAL_CRATE);
Expand Down
108 changes: 60 additions & 48 deletions compiler/rustc_interface/src/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::passes::{self, BoxedResolver, QueryContext};
use rustc_ast as ast;
use rustc_codegen_ssa::traits::CodegenBackend;
use rustc_codegen_ssa::CodegenResults;
use rustc_data_structures::steal::Steal;
use rustc_data_structures::svh::Svh;
use rustc_data_structures::sync::{Lrc, OnceCell, WorkerLocal};
use rustc_hir::def_id::LOCAL_CRATE;
Expand All @@ -19,43 +20,53 @@ use rustc_session::{output::find_crate_name, Session};
use rustc_span::symbol::sym;
use rustc_span::Symbol;
use std::any::Any;
use std::cell::{Ref, RefCell, RefMut};
use std::cell::{RefCell, RefMut};
use std::rc::Rc;
use std::sync::Arc;

/// Represent the result of a query.
///
/// This result can be stolen with the [`take`] method and generated with the [`compute`] method.
/// This result can be stolen once with the [`steal`] method and generated with the [`compute`] method.
///
/// [`take`]: Self::take
/// [`steal`]: Steal::steal
/// [`compute`]: Self::compute
pub struct Query<T> {
result: RefCell<Option<Result<T>>>,
/// `None` means no value has been computed yet.
result: RefCell<Option<Result<Steal<T>>>>,
}

impl<T> Query<T> {
fn compute<F: FnOnce() -> Result<T>>(&self, f: F) -> Result<&Query<T>> {
self.result.borrow_mut().get_or_insert_with(f).as_ref().map(|_| self).map_err(|&err| err)
fn compute<F: FnOnce() -> Result<T>>(&self, f: F) -> Result<QueryResult<'_, T>> {
RefMut::filter_map(
self.result.borrow_mut(),
|r: &mut Option<Result<Steal<T>>>| -> Option<&mut Steal<T>> {
r.get_or_insert_with(|| f().map(Steal::new)).as_mut().ok()
},
)
.map_err(|r| *r.as_ref().unwrap().as_ref().map(|_| ()).unwrap_err())
.map(QueryResult)
}
}

pub struct QueryResult<'a, T>(RefMut<'a, Steal<T>>);

impl<'a, T> std::ops::Deref for QueryResult<'a, T> {
type Target = RefMut<'a, Steal<T>>;

/// Takes ownership of the query result. Further attempts to take or peek the query
/// result will panic unless it is generated by calling the `compute` method.
pub fn take(&self) -> T {
self.result.borrow_mut().take().expect("missing query result").unwrap()
fn deref(&self) -> &Self::Target {
&self.0
}
}

/// Borrows the query result using the RefCell. Panics if the result is stolen.
pub fn peek(&self) -> Ref<'_, T> {
Ref::map(self.result.borrow(), |r| {
r.as_ref().unwrap().as_ref().expect("missing query result")
})
impl<'a, T> std::ops::DerefMut for QueryResult<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

/// Mutably borrows the query result using the RefCell. Panics if the result is stolen.
pub fn peek_mut(&self) -> RefMut<'_, T> {
RefMut::map(self.result.borrow_mut(), |r| {
r.as_mut().unwrap().as_mut().expect("missing query result")
})
impl<'a, 'tcx> QueryResult<'a, QueryContext<'tcx>> {
pub fn enter<T>(mut self, f: impl FnOnce(TyCtxt<'tcx>) -> T) -> T {
(*self.0).get_mut().enter(f)
}
}

Expand Down Expand Up @@ -111,24 +122,24 @@ impl<'tcx> Queries<'tcx> {
self.compiler.codegen_backend()
}

fn dep_graph_future(&self) -> Result<&Query<Option<DepGraphFuture>>> {
fn dep_graph_future(&self) -> Result<QueryResult<'_, Option<DepGraphFuture>>> {
self.dep_graph_future.compute(|| {
let sess = self.session();
Ok(sess.opts.build_dep_graph().then(|| rustc_incremental::load_dep_graph(sess)))
})
}

pub fn parse(&self) -> Result<&Query<ast::Crate>> {
pub fn parse(&self) -> Result<QueryResult<'_, ast::Crate>> {
self.parse.compute(|| {
passes::parse(self.session(), &self.compiler.input)
.map_err(|mut parse_error| parse_error.emit())
})
}

pub fn register_plugins(&self) -> Result<&Query<(ast::Crate, Lrc<LintStore>)>> {
pub fn register_plugins(&self) -> Result<QueryResult<'_, (ast::Crate, Lrc<LintStore>)>> {
self.register_plugins.compute(|| {
let crate_name = *self.crate_name()?.peek();
let krate = self.parse()?.take();
let crate_name = *self.crate_name()?.borrow();
let krate = self.parse()?.steal();

let empty: &(dyn Fn(&Session, &mut LintStore) + Sync + Send) = &|_, _| {};
let (krate, lint_store) = passes::register_plugins(
Expand All @@ -150,11 +161,11 @@ impl<'tcx> Queries<'tcx> {
})
}

pub fn crate_name(&self) -> Result<&Query<Symbol>> {
pub fn crate_name(&self) -> Result<QueryResult<'_, Symbol>> {
self.crate_name.compute(|| {
Ok({
let parse_result = self.parse()?;
let krate = parse_result.peek();
let krate = parse_result.borrow();
// parse `#[crate_name]` even if `--crate-name` was passed, to make sure it matches.
find_crate_name(self.session(), &krate.attrs, &self.compiler.input)
})
Expand All @@ -163,11 +174,12 @@ impl<'tcx> Queries<'tcx> {

pub fn expansion(
&self,
) -> Result<&Query<(Lrc<ast::Crate>, Rc<RefCell<BoxedResolver>>, Lrc<LintStore>)>> {
) -> Result<QueryResult<'_, (Lrc<ast::Crate>, Rc<RefCell<BoxedResolver>>, Lrc<LintStore>)>>
{
trace!("expansion");
self.expansion.compute(|| {
let crate_name = *self.crate_name()?.peek();
let (krate, lint_store) = self.register_plugins()?.take();
let crate_name = *self.crate_name()?.borrow();
let (krate, lint_store) = self.register_plugins()?.steal();
let _timer = self.session().timer("configure_and_expand");
let sess = self.session();
let mut resolver = passes::create_resolver(
Expand All @@ -183,10 +195,10 @@ impl<'tcx> Queries<'tcx> {
})
}

fn dep_graph(&self) -> Result<&Query<DepGraph>> {
fn dep_graph(&self) -> Result<QueryResult<'_, DepGraph>> {
self.dep_graph.compute(|| {
let sess = self.session();
let future_opt = self.dep_graph_future()?.take();
let future_opt = self.dep_graph_future()?.steal();
let dep_graph = future_opt
.and_then(|future| {
let (prev_graph, prev_work_products) =
Expand All @@ -199,10 +211,11 @@ impl<'tcx> Queries<'tcx> {
})
}

pub fn prepare_outputs(&self) -> Result<&Query<OutputFilenames>> {
pub fn prepare_outputs(&self) -> Result<QueryResult<'_, OutputFilenames>> {
self.prepare_outputs.compute(|| {
let (krate, boxed_resolver, _) = &*self.expansion()?.peek();
let crate_name = *self.crate_name()?.peek();
let expansion = self.expansion()?;
let (krate, boxed_resolver, _) = &*expansion.borrow();
let crate_name = *self.crate_name()?.borrow();
passes::prepare_outputs(
self.session(),
self.compiler,
Expand All @@ -213,12 +226,12 @@ impl<'tcx> Queries<'tcx> {
})
}

pub fn global_ctxt(&'tcx self) -> Result<&Query<QueryContext<'tcx>>> {
pub fn global_ctxt(&'tcx self) -> Result<QueryResult<'_, QueryContext<'tcx>>> {
self.global_ctxt.compute(|| {
let crate_name = *self.crate_name()?.peek();
let outputs = self.prepare_outputs()?.take();
let dep_graph = self.dep_graph()?.peek().clone();
let (krate, resolver, lint_store) = self.expansion()?.take();
let crate_name = *self.crate_name()?.borrow();
let outputs = self.prepare_outputs()?.steal();
let dep_graph = self.dep_graph()?.borrow().clone();
let (krate, resolver, lint_store) = self.expansion()?.steal();
Ok(passes::create_global_ctxt(
self.compiler,
lint_store,
Expand All @@ -235,9 +248,9 @@ impl<'tcx> Queries<'tcx> {
})
}

pub fn ongoing_codegen(&'tcx self) -> Result<&Query<Box<dyn Any>>> {
pub fn ongoing_codegen(&'tcx self) -> Result<QueryResult<'_, Box<dyn Any>>> {
self.ongoing_codegen.compute(|| {
self.global_ctxt()?.peek_mut().enter(|tcx| {
self.global_ctxt()?.enter(|tcx| {
tcx.analysis(()).ok();

// Don't do code generation if there were any errors
Expand Down Expand Up @@ -293,12 +306,10 @@ impl<'tcx> Queries<'tcx> {
let sess = self.session().clone();
let codegen_backend = self.codegen_backend().clone();

let dep_graph = self.dep_graph()?.peek().clone();
let (crate_hash, prepare_outputs) = self
.global_ctxt()?
.peek_mut()
.enter(|tcx| (tcx.crate_hash(LOCAL_CRATE), tcx.output_filenames(()).clone()));
let ongoing_codegen = self.ongoing_codegen()?.take();
let (crate_hash, prepare_outputs, dep_graph) = self.global_ctxt()?.enter(|tcx| {
(tcx.crate_hash(LOCAL_CRATE), tcx.output_filenames(()).clone(), tcx.dep_graph.clone())
});
let ongoing_codegen = self.ongoing_codegen()?.steal();

Ok(Linker {
sess,
Expand Down Expand Up @@ -382,6 +393,7 @@ impl Compiler {
// NOTE: intentionally does not compute the global context if it hasn't been built yet,
// since that likely means there was a parse error.
if let Some(Ok(gcx)) = &mut *queries.global_ctxt.result.borrow_mut() {
let gcx = gcx.get_mut();
// We assume that no queries are run past here. If there are new queries
// after this point, they'll show up as "<unknown>" in self-profiling data.
{
Expand Down
8 changes: 2 additions & 6 deletions src/librustdoc/doctest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ pub(crate) fn run(options: RustdocOptions) -> Result<(), ErrorGuaranteed> {
let (tests, unused_extern_reports, compiling_test_count) =
interface::run_compiler(config, |compiler| {
compiler.enter(|queries| {
let mut global_ctxt = queries.global_ctxt()?.take();

let collector = global_ctxt.enter(|tcx| {
let collector = queries.global_ctxt()?.enter(|tcx| {
let crate_attrs = tcx.hir().attrs(CRATE_HIR_ID);

let opts = scrape_test_config(crate_attrs);
Expand Down Expand Up @@ -156,9 +154,7 @@ pub(crate) fn run(options: RustdocOptions) -> Result<(), ErrorGuaranteed> {

let unused_extern_reports = collector.unused_extern_reports.clone();
let compiling_test_count = collector.compiling_test_count.load(Ordering::SeqCst);
let ret: Result<_, ErrorGuaranteed> =
Ok((collector.tests, unused_extern_reports, compiling_test_count));
ret
Ok((collector.tests, unused_extern_reports, compiling_test_count))
})
})?;

Expand Down
5 changes: 3 additions & 2 deletions src/librustdoc/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,8 @@ fn main_args(at_args: &[String]) -> MainResult {
// FIXME(#83761): Resolver cloning can lead to inconsistencies between data in the
// two copies because one of the copies can be modified after `TyCtxt` construction.
let (resolver, resolver_caches) = {
let (krate, resolver, _) = &*abort_on_err(queries.expansion(), sess).peek();
let expansion = abort_on_err(queries.expansion(), sess);
let (krate, resolver, _) = &*expansion.borrow();
let resolver_caches = resolver.borrow_mut().access(|resolver| {
collect_intra_doc_links::early_resolve_intra_doc_links(
resolver,
Expand All @@ -817,7 +818,7 @@ fn main_args(at_args: &[String]) -> MainResult {
sess.fatal("Compilation failed, aborting rustdoc");
}

let mut global_ctxt = abort_on_err(queries.global_ctxt(), sess).peek_mut();
let global_ctxt = abort_on_err(queries.global_ctxt(), sess);

global_ctxt.enter(|tcx| {
let (krate, render_opts, mut cache) = sess.time("run_global_ctxt", || {
Expand Down
2 changes: 1 addition & 1 deletion src/tools/miri/src/bin/miri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl rustc_driver::Callbacks for MiriCompilerCalls {
) -> Compilation {
compiler.session().abort_if_errors();

queries.global_ctxt().unwrap().peek_mut().enter(|tcx| {
queries.global_ctxt().unwrap().enter(|tcx| {
init_late_loggers(tcx);
if !tcx.sess.crate_types().contains(&CrateType::Executable) {
tcx.sess.fatal("miri only makes sense on bin crates");
Expand Down
2 changes: 1 addition & 1 deletion tests/run-make-fulldeps/obtain-borrowck/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl rustc_driver::Callbacks for CompilerCalls {
queries: &'tcx Queries<'tcx>,
) -> Compilation {
compiler.session().abort_if_errors();
queries.global_ctxt().unwrap().peek_mut().enter(|tcx| {
queries.global_ctxt().unwrap().enter(|tcx| {
// Collect definition ids of MIR bodies.
let hir = tcx.hir();
let mut bodies = Vec::new();
Expand Down