diff --git a/tracing-attributes/src/attr.rs b/tracing-attributes/src/attr.rs index b74e88afd2..ff875e1797 100644 --- a/tracing-attributes/src/attr.rs +++ b/tracing-attributes/src/attr.rs @@ -12,6 +12,7 @@ pub(crate) struct InstrumentArgs { pub(crate) name: Option, target: Option, pub(crate) parent: Option, + pub(crate) follows_from: Option, pub(crate) skips: HashSet, pub(crate) skip_all: bool, pub(crate) fields: Option, @@ -130,6 +131,12 @@ impl Parse for InstrumentArgs { } let parent = input.parse::>()?; args.parent = Some(parent.value); + } else if lookahead.peek(kw::follows_from) { + if args.target.is_some() { + return Err(input.error("expected only a single `follows_from` argument")); + } + let follows_from = input.parse::>()?; + args.follows_from = Some(follows_from.value); } else if lookahead.peek(kw::level) { if args.level.is_some() { return Err(input.error("expected only a single `level` argument")); @@ -399,6 +406,7 @@ mod kw { syn::custom_keyword!(level); syn::custom_keyword!(target); syn::custom_keyword!(parent); + syn::custom_keyword!(follows_from); syn::custom_keyword!(name); syn::custom_keyword!(err); syn::custom_keyword!(ret); diff --git a/tracing-attributes/src/expand.rs b/tracing-attributes/src/expand.rs index a629af18bc..b563d4bbfe 100644 --- a/tracing-attributes/src/expand.rs +++ b/tracing-attributes/src/expand.rs @@ -88,6 +88,13 @@ fn gen_block( let level = args.level(); + let follows_from = args.follows_from.iter(); + let follows_from = quote! { + #(for cause in #follows_from { + __tracing_attr_span.follows_from(cause); + })* + }; + // generate this inside a closure, so we can return early on errors. let span = (|| { // Pull out the arguments-to-be-skipped first, so we can filter results @@ -261,6 +268,7 @@ fn gen_block( let __tracing_attr_span = #span; let __tracing_instrument_future = #mk_fut; if !__tracing_attr_span.is_disabled() { + #follows_from tracing::Instrument::instrument( __tracing_instrument_future, __tracing_attr_span @@ -287,6 +295,7 @@ fn gen_block( let __tracing_attr_guard; if tracing::level_enabled!(#level) { __tracing_attr_span = #span; + #follows_from __tracing_attr_guard = __tracing_attr_span.enter(); } ); diff --git a/tracing-attributes/src/lib.rs b/tracing-attributes/src/lib.rs index 749affb876..115da88460 100644 --- a/tracing-attributes/src/lib.rs +++ b/tracing-attributes/src/lib.rs @@ -368,6 +368,24 @@ mod expand; /// fn my_method(&self) {} /// } /// ``` +/// Specifying [`follows_from`] relationships: +/// ``` +/// # use tracing_attributes::instrument; +/// #[instrument(follows_from = causes)] +/// pub fn my_function(causes: &[tracing::Id]) { +/// // ... +/// } +/// ``` +/// Any expression of type `impl IntoIterator>>` +/// may be provided to `follows_from`; e.g.: +/// ``` +/// # use tracing_attributes::instrument; +/// #[instrument(follows_from = [cause])] +/// pub fn my_function(cause: &tracing::span::EnteredSpan) { +/// // ... +/// } +/// ``` +/// /// /// To skip recording an argument, pass the argument's name to the `skip`: /// @@ -524,6 +542,8 @@ mod expand; /// [`INFO`]: https://docs.rs/tracing/latest/tracing/struct.Level.html#associatedconstant.INFO /// [empty field]: https://docs.rs/tracing/latest/tracing/field/struct.Empty.html /// [field syntax]: https://docs.rs/tracing/latest/tracing/#recording-fields +/// [`follows_from`]: https://docs.rs/tracing/latest/tracing/struct.Span.html#method.follows_from +/// [`tracing`]: https://github.com/tokio-rs/tracing /// [`fmt::Debug`]: std::fmt::Debug #[proc_macro_attribute] pub fn instrument( diff --git a/tracing-attributes/tests/follows_from.rs b/tracing-attributes/tests/follows_from.rs new file mode 100644 index 0000000000..da0eec6357 --- /dev/null +++ b/tracing-attributes/tests/follows_from.rs @@ -0,0 +1,99 @@ +use tracing::{subscriber::with_default, Id, Level, Span}; +use tracing_attributes::instrument; +use tracing_mock::*; + +#[instrument(follows_from = causes, skip(causes))] +fn with_follows_from_sync(causes: impl IntoIterator>>) {} + +#[instrument(follows_from = causes, skip(causes))] +async fn with_follows_from_async(causes: impl IntoIterator>>) {} + +#[instrument(follows_from = [&Span::current()])] +fn follows_from_current() {} + +#[test] +fn follows_from_sync_test() { + let cause_a = span::mock().named("cause_a"); + let cause_b = span::mock().named("cause_b"); + let cause_c = span::mock().named("cause_c"); + let consequence = span::mock().named("with_follows_from_sync"); + + let (subscriber, handle) = subscriber::mock() + .new_span(cause_a.clone()) + .new_span(cause_b.clone()) + .new_span(cause_c.clone()) + .new_span(consequence.clone()) + .follows_from(consequence.clone(), cause_a) + .follows_from(consequence.clone(), cause_b) + .follows_from(consequence.clone(), cause_c) + .enter(consequence.clone()) + .exit(consequence) + .done() + .run_with_handle(); + + with_default(subscriber, || { + let cause_a = tracing::span!(Level::TRACE, "cause_a"); + let cause_b = tracing::span!(Level::TRACE, "cause_b"); + let cause_c = tracing::span!(Level::TRACE, "cause_c"); + + with_follows_from_sync(&[cause_a, cause_b, cause_c]) + }); + + handle.assert_finished(); +} + +#[test] +fn follows_from_async_test() { + let cause_a = span::mock().named("cause_a"); + let cause_b = span::mock().named("cause_b"); + let cause_c = span::mock().named("cause_c"); + let consequence = span::mock().named("with_follows_from_async"); + + let (subscriber, handle) = subscriber::mock() + .new_span(cause_a.clone()) + .new_span(cause_b.clone()) + .new_span(cause_c.clone()) + .new_span(consequence.clone()) + .follows_from(consequence.clone(), cause_a) + .follows_from(consequence.clone(), cause_b) + .follows_from(consequence.clone(), cause_c) + .enter(consequence.clone()) + .exit(consequence) + .done() + .run_with_handle(); + + with_default(subscriber, || { + block_on_future(async { + let cause_a = tracing::span!(Level::TRACE, "cause_a"); + let cause_b = tracing::span!(Level::TRACE, "cause_b"); + let cause_c = tracing::span!(Level::TRACE, "cause_c"); + + with_follows_from_async(&[cause_a, cause_b, cause_c]).await + }) + }); + + handle.assert_finished(); +} + +#[test] +fn follows_from_current_test() { + let cause = span::mock().named("cause"); + let consequence = span::mock().named("follows_from_current"); + + let (subscriber, handle) = subscriber::mock() + .new_span(cause.clone()) + .enter(cause.clone()) + .new_span(consequence.clone()) + .follows_from(consequence.clone(), cause.clone()) + .enter(consequence.clone()) + .exit(consequence) + .exit(cause) + .done() + .run_with_handle(); + + with_default(subscriber, || { + tracing::span!(Level::TRACE, "cause").in_scope(follows_from_current) + }); + + handle.assert_finished(); +} diff --git a/tracing-mock/src/subscriber.rs b/tracing-mock/src/subscriber.rs index 17e1a7ed73..32f27ada00 100644 --- a/tracing-mock/src/subscriber.rs +++ b/tracing-mock/src/subscriber.rs @@ -23,6 +23,10 @@ use tracing::{ #[derive(Debug, Eq, PartialEq)] pub enum Expect { Event(MockEvent), + FollowsFrom { + consequence: MockSpan, + cause: MockSpan, + }, Enter(MockSpan), Exit(MockSpan), CloneSpan(MockSpan), @@ -35,6 +39,7 @@ pub enum Expect { struct SpanState { name: &'static str, refs: usize, + meta: &'static Metadata<'static>, } struct Running) -> bool> { @@ -97,6 +102,12 @@ where self } + pub fn follows_from(mut self, consequence: MockSpan, cause: MockSpan) -> Self { + self.expected + .push_back(Expect::FollowsFrom { consequence, cause }); + self + } + pub fn event(mut self, event: MockEvent) -> Self { self.expected.push_back(Expect::Event(event)); self @@ -250,8 +261,37 @@ where } } - fn record_follows_from(&self, _span: &Id, _follows: &Id) { - // TODO: it should be possible to expect spans to follow from other spans + fn record_follows_from(&self, consequence_id: &Id, cause_id: &Id) { + let spans = self.spans.lock().unwrap(); + if let Some(consequence_span) = spans.get(consequence_id) { + if let Some(cause_span) = spans.get(cause_id) { + println!( + "[{}] record_follows_from: {} (id={:?}) follows {} (id={:?})", + self.name, consequence_span.name, consequence_id, cause_span.name, cause_id, + ); + match self.expected.lock().unwrap().pop_front() { + None => {} + Some(Expect::FollowsFrom { + consequence: ref expected_consequence, + cause: ref expected_cause, + }) => { + if let Some(name) = expected_consequence.name() { + assert_eq!(name, consequence_span.name); + } + if let Some(name) = expected_cause.name() { + assert_eq!(name, cause_span.name); + } + } + Some(ex) => ex.bad( + &self.name, + format_args!( + "consequence {:?} followed cause {:?}", + consequence_span.name, cause_span.name + ), + ), + } + } + }; } fn new_span(&self, span: &Attributes<'_>) -> Id { @@ -284,6 +324,7 @@ where id.clone(), SpanState { name: meta.name(), + meta, refs: 1, }, ); @@ -415,6 +456,18 @@ where } } } + + fn current_span(&self) -> tracing_core::span::Current { + let stack = self.current.lock().unwrap(); + match stack.last() { + Some(id) => { + let spans = self.spans.lock().unwrap(); + let state = spans.get(id).expect("state for current span"); + tracing_core::span::Current::new(id.clone(), state.meta) + } + None => tracing_core::span::Current::none(), + } + } } impl MockHandle { @@ -442,6 +495,10 @@ impl Expect { "\n[{}] expected event {}\n[{}] but instead {}", name, e, name, what, ), + Expect::FollowsFrom { consequence, cause } => panic!( + "\n[{}] expected consequence {} to follow cause {} but instead {}", + name, consequence, cause, what, + ), Expect::Enter(e) => panic!( "\n[{}] expected to enter {}\n[{}] but instead {}", name, e, name, what,