diff --git a/workers/go/kitchensink/kitchen_sink.go b/workers/go/kitchensink/kitchen_sink.go index 5d7bb95..95f5351 100644 --- a/workers/go/kitchensink/kitchen_sink.go +++ b/workers/go/kitchensink/kitchen_sink.go @@ -128,9 +128,16 @@ func (ws *KSWorkflowState) handleAction( if child.WorkflowType != "" { childType = child.WorkflowType } - err := withAwaitableChoice(ctx, func(ctx workflow.Context) workflow.Future { + err := withAwaitableChoiceCustom(ctx, func(ctx workflow.Context) workflow.ChildWorkflowFuture { return workflow.ExecuteChildWorkflow(ctx, childType, child.GetInput()[0]) - }, child.AwaitableChoice) + }, child.AwaitableChoice, + func(ctx workflow.Context, fut workflow.ChildWorkflowFuture) error { + return fut.GetChildWorkflowExecution().Get(ctx, nil) + }, + func(ctx workflow.Context, fut workflow.ChildWorkflowFuture) error { + return fut.Get(ctx, nil) + }, + ) return nil, err } else if patch := action.GetSetPatchMarker(); patch != nil { if workflow.GetVersion(ctx, patch.GetPatchId(), workflow.DefaultVersion, 1) == 1 { @@ -209,10 +216,27 @@ func launchActivity(ctx workflow.Context, act *kitchensink.ExecuteActivityAction } } -func withAwaitableChoice( +func withAwaitableChoice[F workflow.Future]( ctx workflow.Context, - starter func(workflow.Context) workflow.Future, + starter func(workflow.Context) F, awaitChoice *kitchensink.AwaitableChoice, +) error { + return withAwaitableChoiceCustom(ctx, starter, awaitChoice, + func(ctx workflow.Context, fut F) error { + _ = workflow.Sleep(ctx, 1) + return nil + }, + func(ctx workflow.Context, fut F) error { + return fut.Get(ctx, nil) + }) +} + +func withAwaitableChoiceCustom[F workflow.Future]( + ctx workflow.Context, + starter func(workflow.Context) F, + awaitChoice *kitchensink.AwaitableChoice, + afterStartedWaiter func(workflow.Context, F) error, + afterCompletedWaiter func(workflow.Context, F) error, ) error { cancelCtx, cancel := workflow.WithCancel(ctx) fut := starter(cancelCtx) @@ -225,12 +249,15 @@ func withAwaitableChoice( didCancel = true err = fut.Get(ctx, nil) } else if awaitChoice.GetCancelAfterStarted() != nil { - _ = workflow.Sleep(ctx, 1) + err = afterStartedWaiter(ctx, fut) + if err != nil { + return err + } cancel() didCancel = true err = fut.Get(ctx, nil) } else if awaitChoice.GetCancelAfterCompleted() != nil { - res := fut.Get(ctx, nil) + res := afterCompletedWaiter(ctx, fut) cancel() err = res } else { diff --git a/workers/python/kitchen_sink.py b/workers/python/kitchen_sink.py index a1cfbd9..711f908 100644 --- a/workers/python/kitchen_sink.py +++ b/workers/python/kitchen_sink.py @@ -2,7 +2,7 @@ import asyncio from datetime import timedelta -from typing import Any, Coroutine, Optional, Union +from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar, Union import temporalio.workflow from temporalio import exceptions, workflow @@ -13,7 +13,7 @@ SearchAttributeKey, SearchAttributeUpdate, ) -from temporalio.workflow import ActivityHandle +from temporalio.workflow import ActivityHandle, ChildWorkflowHandle from protos.kitchen_sink_pb2 import ( Action, @@ -129,6 +129,8 @@ async def handle_action(self, action: Action) -> Optional[Payload]: child, id=child_action.workflow_id, args=args ), child_action.awaitable_choice, + after_started_fn=wait_task_complete, + after_completed_fn=wait_child_wf_complete, ) elif action.HasField("set_patch_marker"): if action.set_patch_marker.deprecated: @@ -220,8 +222,24 @@ def launch_activity(execute_activity: ExecuteActivityAction) -> ActivityHandle: return activity_task +async def brief_wait(_: asyncio.Task): + await asyncio.sleep(0.001) + + +async def wait_task_complete(task: asyncio.Task): + await task + + +async def wait_child_wf_complete(task: asyncio.Task[ChildWorkflowHandle]): + res = await task + await res + + async def handle_awaitable_choice( - awaitable: Union[Coroutine, asyncio.Task], choice: AwaitableChoice + awaitable: Union[Coroutine, asyncio.Task], + choice: AwaitableChoice, + after_started_fn: Callable[[asyncio.Task], Awaitable] = brief_wait, + after_completed_fn: Callable[[asyncio.Task], Awaitable] = wait_task_complete, ): if isinstance(awaitable, asyncio.Task): task = awaitable @@ -239,12 +257,12 @@ async def handle_awaitable_choice( did_cancel = True await task elif choice.HasField("cancel_after_started"): - await asyncio.sleep(0.001) + await after_started_fn(task) task.cancel() did_cancel = True await task elif choice.HasField("cancel_after_completed"): - await task + await after_completed_fn(task) task.cancel() did_cancel = True else: