Skip to content

Commit

Permalink
Make child workflow awaits more specific
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Dec 5, 2023
1 parent 7173e03 commit 71ec0e4
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
39 changes: 33 additions & 6 deletions workers/go/kitchensink/kitchen_sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,16 @@ func (ws *KSWorkflowState) handleAction(
if child.WorkflowType != "" {
childType = child.WorkflowType
}
err := withAwaitableChoice(ctx, func(ctx workflow.Context) workflow.Future {
err := withAwaitableChoiceCustom(ctx, func(ctx workflow.Context) workflow.ChildWorkflowFuture {
return workflow.ExecuteChildWorkflow(ctx, childType, child.GetInput()[0])
}, child.AwaitableChoice)
}, child.AwaitableChoice,
func(ctx workflow.Context, fut workflow.ChildWorkflowFuture) error {
return fut.GetChildWorkflowExecution().Get(ctx, nil)
},
func(ctx workflow.Context, fut workflow.ChildWorkflowFuture) error {
return fut.Get(ctx, nil)
},
)
return nil, err
} else if patch := action.GetSetPatchMarker(); patch != nil {
if workflow.GetVersion(ctx, patch.GetPatchId(), workflow.DefaultVersion, 1) == 1 {
Expand Down Expand Up @@ -209,10 +216,27 @@ func launchActivity(ctx workflow.Context, act *kitchensink.ExecuteActivityAction
}
}

func withAwaitableChoice(
func withAwaitableChoice[F workflow.Future](
ctx workflow.Context,
starter func(workflow.Context) workflow.Future,
starter func(workflow.Context) F,
awaitChoice *kitchensink.AwaitableChoice,
) error {
return withAwaitableChoiceCustom(ctx, starter, awaitChoice,
func(ctx workflow.Context, fut F) error {
_ = workflow.Sleep(ctx, 1)
return nil
},
func(ctx workflow.Context, fut F) error {
return fut.Get(ctx, nil)
})
}

func withAwaitableChoiceCustom[F workflow.Future](
ctx workflow.Context,
starter func(workflow.Context) F,
awaitChoice *kitchensink.AwaitableChoice,
afterStartedWaiter func(workflow.Context, F) error,
afterCompletedWaiter func(workflow.Context, F) error,
) error {
cancelCtx, cancel := workflow.WithCancel(ctx)
fut := starter(cancelCtx)
Expand All @@ -225,12 +249,15 @@ func withAwaitableChoice(
didCancel = true
err = fut.Get(ctx, nil)
} else if awaitChoice.GetCancelAfterStarted() != nil {
_ = workflow.Sleep(ctx, 1)
err = afterStartedWaiter(ctx, fut)
if err != nil {
return err
}
cancel()
didCancel = true
err = fut.Get(ctx, nil)
} else if awaitChoice.GetCancelAfterCompleted() != nil {
res := fut.Get(ctx, nil)
res := afterCompletedWaiter(ctx, fut)
cancel()
err = res
} else {
Expand Down
28 changes: 23 additions & 5 deletions workers/python/kitchen_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
from datetime import timedelta
from typing import Any, Coroutine, Optional, Union
from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar, Union

import temporalio.workflow
from temporalio import exceptions, workflow
Expand All @@ -13,7 +13,7 @@
SearchAttributeKey,
SearchAttributeUpdate,
)
from temporalio.workflow import ActivityHandle
from temporalio.workflow import ActivityHandle, ChildWorkflowHandle

from protos.kitchen_sink_pb2 import (
Action,
Expand Down Expand Up @@ -129,6 +129,8 @@ async def handle_action(self, action: Action) -> Optional[Payload]:
child, id=child_action.workflow_id, args=args
),
child_action.awaitable_choice,
after_started_fn=wait_task_complete,
after_completed_fn=wait_child_wf_complete,
)
elif action.HasField("set_patch_marker"):
if action.set_patch_marker.deprecated:
Expand Down Expand Up @@ -220,8 +222,24 @@ def launch_activity(execute_activity: ExecuteActivityAction) -> ActivityHandle:
return activity_task


async def brief_wait(_: asyncio.Task):
await asyncio.sleep(0.001)


async def wait_task_complete(task: asyncio.Task):
await task


async def wait_child_wf_complete(task: asyncio.Task[ChildWorkflowHandle]):
res = await task
await res


async def handle_awaitable_choice(
awaitable: Union[Coroutine, asyncio.Task], choice: AwaitableChoice
awaitable: Union[Coroutine, asyncio.Task],
choice: AwaitableChoice,
after_started_fn: Callable[[asyncio.Task], Awaitable] = brief_wait,
after_completed_fn: Callable[[asyncio.Task], Awaitable] = wait_task_complete,
):
if isinstance(awaitable, asyncio.Task):
task = awaitable
Expand All @@ -239,12 +257,12 @@ async def handle_awaitable_choice(
did_cancel = True
await task
elif choice.HasField("cancel_after_started"):
await asyncio.sleep(0.001)
await after_started_fn(task)
task.cancel()
did_cancel = True
await task
elif choice.HasField("cancel_after_completed"):
await task
await after_completed_fn(task)
task.cancel()
did_cancel = True
else:
Expand Down

0 comments on commit 71ec0e4

Please sign in to comment.