Skip to content

Commit

Permalink
AdjacencyMap::reverse_topological (+ fixes) (#5527)
Browse files Browse the repository at this point in the history
### Description

This adds `AdjacencyMap::reverse_topological`, which is similar to
`AdjacencyMap::into_reverse_topological` but doesn't consume the graph.

This also:
* Makes `AdjacencyMap` storable in `turbo_tasks::value`s;
* Fixes ValueDebugFormat and TraceRawVcs derive macros so they support
generic argument and bounds properly.

### Testing Instructions

N/A
  • Loading branch information
alexkirsz committed Jul 17, 2023
1 parent 8433a32 commit b069545
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
13 changes: 9 additions & 4 deletions crates/turbo-tasks-macros/src/derive/trace_raw_vcs_macro.rs
Expand Up @@ -11,14 +11,19 @@ fn filter_field(field: &Field) -> bool {
}

pub fn derive_trace_raw_vcs(input: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
let mut derive_input = parse_macro_input!(input as DeriveInput);
let ident = &derive_input.ident;
let generics = &derive_input.generics;

for type_param in derive_input.generics.type_params_mut() {
type_param
.bounds
.push(syn::parse_quote!(turbo_tasks::trace::TraceRawVcs));
}
let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();

let trace_items = match_expansion(&derive_input, &trace_named, &trace_unnamed, &trace_unit);
let generics_params = &generics.params.iter().collect::<Vec<_>>();
quote! {
impl #generics turbo_tasks::trace::TraceRawVcs for #ident #generics #(where #generics_params: turbo_tasks::trace::TraceRawVcs)* {
impl #impl_generics turbo_tasks::trace::TraceRawVcs for #ident #ty_generics #where_clause {
fn trace_raw_vcs(&self, __context__: &mut turbo_tasks::trace::TraceRawVcsContext) {
#trace_items
}
Expand Down
17 changes: 14 additions & 3 deletions crates/turbo-tasks-macros/src/derive/value_debug_format_macro.rs
Expand Up @@ -16,17 +16,28 @@ fn filter_field(field: &Field) -> bool {
/// Fields annotated with `#[debug_ignore]` will not appear in the
/// `ValueDebugFormat` representation of the type.
pub fn derive_value_debug_format(input: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
let mut derive_input = parse_macro_input!(input as DeriveInput);

let ident = &derive_input.ident;

for type_param in derive_input.generics.type_params_mut() {
type_param
.bounds
.push(syn::parse_quote!(turbo_tasks::debug::ValueDebugFormat));
type_param.bounds.push(syn::parse_quote!(std::fmt::Debug));
type_param.bounds.push(syn::parse_quote!(std::marker::Send));
type_param.bounds.push(syn::parse_quote!(std::marker::Sync));
}
let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();

let formatting_logic =
match_expansion(&derive_input, &format_named, &format_unnamed, &format_unit);

let value_debug_format_ident = get_value_debug_format_ident(ident);

quote! {
#[doc(hidden)]
impl #ident {
impl #impl_generics #ident #ty_generics #where_clause {
#[doc(hidden)]
#[allow(non_snake_case)]
async fn #value_debug_format_ident(&self, depth: usize) -> anyhow::Result<turbo_tasks::Vc<turbo_tasks::debug::ValueDebugString>> {
Expand All @@ -40,7 +51,7 @@ pub fn derive_value_debug_format(input: TokenStream) -> TokenStream {
}
}

impl turbo_tasks::debug::ValueDebugFormat for #ident {
impl #impl_generics turbo_tasks::debug::ValueDebugFormat for #ident #ty_generics #where_clause {
fn value_debug_format<'a>(&'a self, depth: usize) -> turbo_tasks::debug::ValueDebugFormatString<'a> {
turbo_tasks::debug::ValueDebugFormatString::Async(
Box::pin(async move {
Expand Down
41 changes: 30 additions & 11 deletions crates/turbo-tasks/src/graph/adjacency_map.rs
@@ -1,8 +1,13 @@
use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
use turbo_tasks_macros::{TraceRawVcs, ValueDebugFormat};

use super::graph_store::{GraphNode, GraphStore};
use crate as turbo_tasks;

/// A graph traversal that builds an adjacency map
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TraceRawVcs, ValueDebugFormat)]
pub struct AdjacencyMap<T>
where
T: Eq + std::hash::Hash + Clone,
Expand Down Expand Up @@ -68,10 +73,10 @@ impl<T> AdjacencyMap<T>
where
T: Eq + std::hash::Hash + Clone,
{
/// Returns an iterator over the nodes in reverse topological order,
/// Returns an owned iterator over the nodes in reverse topological order,
/// starting from the roots.
pub fn into_reverse_topological(self) -> ReverseTopologicalIter<T> {
ReverseTopologicalIter {
pub fn into_reverse_topological(self) -> IntoReverseTopologicalIter<T> {
IntoReverseTopologicalIter {
adjacency_map: self.adjacency_map,
stack: self
.roots
Expand All @@ -82,13 +87,27 @@ where
}
}

/// Returns an iterator over the nodes in reverse topological order,
/// starting from the roots.
pub fn reverse_topological(&self) -> ReverseTopologicalIter<T> {
ReverseTopologicalIter {
adjacency_map: &self.adjacency_map,
stack: self
.roots
.iter()
.map(|root| (ReverseTopologicalPass::Pre, root))
.collect(),
visited: HashSet::new(),
}
}

/// Returns an iterator over the nodes in reverse topological order,
/// starting from the given node.
pub fn reverse_topological_from_node<'graph>(
&'graph self,
node: &'graph T,
) -> ReverseTopologicalFromNodeIter<'graph, T> {
ReverseTopologicalFromNodeIter {
) -> ReverseTopologicalIter<'graph, T> {
ReverseTopologicalIter {
adjacency_map: &self.adjacency_map,
stack: vec![(ReverseTopologicalPass::Pre, node)],
visited: HashSet::new(),
Expand All @@ -104,7 +123,7 @@ enum ReverseTopologicalPass {

/// An iterator over the nodes of a graph in reverse topological order, starting
/// from the roots.
pub struct ReverseTopologicalIter<T>
pub struct IntoReverseTopologicalIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand All @@ -113,7 +132,7 @@ where
visited: HashSet<T>,
}

impl<T> Iterator for ReverseTopologicalIter<T>
impl<T> Iterator for IntoReverseTopologicalIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand Down Expand Up @@ -153,8 +172,8 @@ where
}

/// An iterator over the nodes of a graph in reverse topological order, starting
/// from a given node.
pub struct ReverseTopologicalFromNodeIter<'graph, T>
/// from the roots.
pub struct ReverseTopologicalIter<'graph, T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand All @@ -163,7 +182,7 @@ where
visited: HashSet<&'graph T>,
}

impl<'graph, T> Iterator for ReverseTopologicalFromNodeIter<'graph, T>
impl<'graph, T> Iterator for ReverseTopologicalIter<'graph, T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand All @@ -178,7 +197,7 @@ where
break current;
}
ReverseTopologicalPass::Pre => {
if self.visited.contains(&current) {
if self.visited.contains(current) {
continue;
}

Expand Down

0 comments on commit b069545

Please sign in to comment.