diff --git a/core/src/worker/activities/local_activities.rs b/core/src/worker/activities/local_activities.rs index 31357575f..106b7d49a 100644 --- a/core/src/worker/activities/local_activities.rs +++ b/core/src/worker/activities/local_activities.rs @@ -490,7 +490,7 @@ impl LocalActivityManager { dispatch_time: Instant::now(), attempt, _permit: permit.into_used(LocalActivitySlotInfo { - activity_type: new_la.workflow_type.clone(), + activity_type: sa.activity_type.clone(), }), }, ); diff --git a/tests/integ_tests/worker_tests.rs b/tests/integ_tests/worker_tests.rs index 728ba3222..38c7223da 100644 --- a/tests/integ_tests/worker_tests.rs +++ b/tests/integ_tests/worker_tests.rs @@ -7,15 +7,17 @@ use futures_util::FutureExt; use std::{ cell::Cell, sync::{ - Arc, + Arc, Mutex, atomic::{AtomicBool, Ordering::Relaxed}, }, time::Duration, }; use temporal_client::WorkflowOptions; -use temporal_sdk::{ActivityOptions, WfContext, interceptors::WorkerInterceptor}; +use temporal_sdk::{ + ActivityOptions, LocalActivityOptions, WfContext, interceptors::WorkerInterceptor, +}; use temporal_sdk_core::{ - CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, init_worker, + CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, TunerBuilder, init_worker, test_help::{ FakeWfResponses, MockPollCfg, ResponseType, TEST_Q, build_mock_pollers, drain_pollers_and_shutdown, hist_to_poll_resp, mock_worker, mock_worker_client, @@ -24,7 +26,11 @@ use temporal_sdk_core::{ use temporal_sdk_core_api::{ Worker, errors::WorkerValidationError, - worker::{PollerBehavior, WorkerConfigBuilder, WorkerVersioningStrategy}, + worker::{ + ActivitySlotKind, LocalActivitySlotKind, PollerBehavior, SlotInfo, SlotInfoTrait, + SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext, SlotSupplier, + SlotSupplierPermit, WorkerConfigBuilder, WorkerVersioningStrategy, WorkflowSlotKind, + }, }; use temporal_sdk_core_protos::{ DEFAULT_WORKFLOW_TYPE, TestHistoryBuilder, canned_histories, @@ -571,3 +577,282 @@ async fn sets_build_id_from_wft_complete() { .unwrap(); worker.run_until_done().await.unwrap(); } + +#[derive(Debug, Clone)] +enum SlotEvent { + ReserveSlot { + slot_type: &'static str, + }, + TryReserveSlot { + slot_type: &'static str, + }, + MarkSlotUsed { + slot_type: &'static str, + is_sticky: bool, + workflow_type: Option, + activity_type: Option, + }, + ReleaseSlot { + slot_type: &'static str, + }, +} + +struct TrackingSlotSupplier { + events: Arc>>, + slot_type: &'static str, + _phantom: std::marker::PhantomData, +} + +impl TrackingSlotSupplier { + fn new(slot_type: &'static str) -> Self { + Self { + events: Arc::new(Mutex::new(Vec::new())), + slot_type, + _phantom: std::marker::PhantomData, + } + } + + fn get_events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + + fn add_event(&self, event: SlotEvent) { + self.events.lock().unwrap().push(event); + } + + fn extract_slot_info(info: &dyn SlotInfoTrait) -> (bool, Option, Option) { + match info.downcast() { + SlotInfo::Workflow(w) => (w.is_sticky, Some(w.workflow_type.clone()), None), + SlotInfo::Activity(a) => (false, None, Some(a.activity_type.clone())), + SlotInfo::LocalActivity(a) => (false, None, Some(a.activity_type.clone())), + SlotInfo::Nexus(_) => (false, None, None), + } + } +} + +#[async_trait::async_trait] +impl SlotSupplier for TrackingSlotSupplier +where + SK: temporal_sdk_core_api::worker::SlotKind + Send + Sync, + SK::Info: SlotInfoTrait, +{ + type SlotKind = SK; + + async fn reserve_slot(&self, _ctx: &dyn SlotReservationContext) -> SlotSupplierPermit { + self.add_event(SlotEvent::ReserveSlot { + slot_type: self.slot_type, + }); + SlotSupplierPermit::with_user_data(()) + } + + fn try_reserve_slot(&self, _ctx: &dyn SlotReservationContext) -> Option { + self.add_event(SlotEvent::TryReserveSlot { + slot_type: self.slot_type, + }); + Some(SlotSupplierPermit::with_user_data(())) + } + + fn mark_slot_used(&self, ctx: &dyn SlotMarkUsedContext) { + let (is_sticky, workflow_type, activity_type) = Self::extract_slot_info(ctx.info()); + self.add_event(SlotEvent::MarkSlotUsed { + slot_type: self.slot_type, + is_sticky, + workflow_type, + activity_type, + }); + } + + fn release_slot(&self, _ctx: &dyn SlotReleaseContext) { + self.add_event(SlotEvent::ReleaseSlot { + slot_type: self.slot_type, + }); + } +} + +#[tokio::test] +async fn test_custom_slot_supplier_simple() { + let wf_supplier = Arc::new(TrackingSlotSupplier::::new("workflow")); + let activity_supplier = Arc::new(TrackingSlotSupplier::::new("activity")); + let local_activity_supplier = Arc::new(TrackingSlotSupplier::::new( + "local_activity", + )); + + let mut starter = CoreWfStarter::new("test_custom_slot_supplier_simple"); + starter.worker_config.clear_max_outstanding_opts(); + + let mut tb = TunerBuilder::default(); + tb.workflow_slot_supplier(wf_supplier.clone()); + tb.activity_slot_supplier(activity_supplier.clone()); + tb.local_activity_slot_supplier(local_activity_supplier.clone()); + starter.worker_config.tuner(Arc::new(tb.build())); + + let mut worker = starter.worker().await; + + worker.register_activity( + "SlotSupplierActivity", + |_: temporal_sdk::ActContext, _: ()| async move { Ok(()) }, + ); + worker.register_wf( + "SlotSupplierWorkflow".to_owned(), + |ctx: WfContext| async move { + let _result = ctx + .activity(ActivityOptions { + activity_type: "SlotSupplierActivity".to_string(), + start_to_close_timeout: Some(Duration::from_secs(10)), + ..Default::default() + }) + .await; + let _result = ctx + .local_activity(LocalActivityOptions { + activity_type: "SlotSupplierActivity".to_string(), + start_to_close_timeout: Some(Duration::from_secs(10)), + ..Default::default() + }) + .await; + Ok(().into()) + }, + ); + + worker + .submit_wf( + "test-wf".to_owned(), + "SlotSupplierWorkflow".to_owned(), + vec![], + Default::default(), + ) + .await + .unwrap(); + + worker.run_until_done().await.unwrap(); + + // Collect all events + let wf_events = wf_supplier.get_events(); + let activity_events = activity_supplier.get_events(); + let local_activity_events = local_activity_supplier.get_events(); + + // Verify workflow slot events - should have reserve, mark used, and release events + assert!(wf_events.iter().any( + |e| matches!(e, SlotEvent::ReserveSlot { slot_type, .. } if *slot_type == "workflow") + )); + assert!(wf_events.iter().any( + |e| matches!(e, SlotEvent::MarkSlotUsed { slot_type, .. } if *slot_type == "workflow") + )); + assert!( + wf_events + .iter() + .any(|e| matches!(e, SlotEvent::ReleaseSlot { slot_type } if *slot_type == "workflow")) + ); + + // Verify activity slot events - should have reserve, try_reserve (for eager execution), mark + // used, and release + assert!(activity_events.iter().any( + |e| matches!(e, SlotEvent::ReserveSlot { slot_type, .. } if *slot_type == "activity") + )); + assert!( + activity_events.iter().any( + |e| matches!(e, SlotEvent::TryReserveSlot { slot_type } if *slot_type == "activity") + ) + ); + assert!(activity_events.iter().any( + |e| matches!(e, SlotEvent::MarkSlotUsed { slot_type, .. } if *slot_type == "activity") + )); + assert!( + activity_events + .iter() + .any(|e| matches!(e, SlotEvent::ReleaseSlot { slot_type } if *slot_type == "activity")) + ); + + // Verify local activity slot events + assert!(local_activity_events.iter().any( + |e| matches!(e, SlotEvent::ReserveSlot { slot_type, .. } if *slot_type == "local_activity") + )); + assert!(local_activity_events.iter().any( + |e| matches!(e, SlotEvent::MarkSlotUsed { slot_type, .. } if *slot_type == "local_activity") + )); + assert!(local_activity_events.iter().any( + |e| matches!(e, SlotEvent::ReleaseSlot { slot_type } if *slot_type == "local_activity") + )); + + assert!( + wf_events + .iter() + .any(|e| matches!(e, SlotEvent::MarkSlotUsed { + slot_type: "workflow", + workflow_type: Some(wf_type), + .. + } if wf_type == "SlotSupplierWorkflow")) + ); + assert!( + activity_events + .iter() + .any(|e| matches!(e, SlotEvent::MarkSlotUsed { + slot_type: "activity", + activity_type: Some(act_type), + .. + } if act_type == "SlotSupplierActivity")) + ); + assert!( + local_activity_events + .iter() + .any(|e| matches!(e, SlotEvent::MarkSlotUsed { + slot_type: "local_activity", + activity_type: Some(act_type), + .. + } if act_type == "SlotSupplierActivity")) + ); + assert!(wf_events.iter().any(|e| matches!( + e, + SlotEvent::MarkSlotUsed { + slot_type: "workflow", + is_sticky: false, + .. + } + ))); + + // Verify that the number of reserve/try_reserve events matches the number of release events + let total_reserves = wf_events + .iter() + .filter(|e| { + matches!( + e, + SlotEvent::ReserveSlot { .. } | SlotEvent::TryReserveSlot { .. } + ) + }) + .count() + + activity_events + .iter() + .filter(|e| { + matches!( + e, + SlotEvent::ReserveSlot { .. } | SlotEvent::TryReserveSlot { .. } + ) + }) + .count() + + local_activity_events + .iter() + .filter(|e| { + matches!( + e, + SlotEvent::ReserveSlot { .. } | SlotEvent::TryReserveSlot { .. } + ) + }) + .count(); + + let total_releases = wf_events + .iter() + .filter(|e| matches!(e, SlotEvent::ReleaseSlot { .. })) + .count() + + activity_events + .iter() + .filter(|e| matches!(e, SlotEvent::ReleaseSlot { .. })) + .count() + + local_activity_events + .iter() + .filter(|e| matches!(e, SlotEvent::ReleaseSlot { .. })) + .count(); + + assert_eq!( + total_reserves, total_releases, + "Number of reserves should equal number of releases" + ); +}