In [7]:
from burr.core import State, action, ApplicationContext, ApplicationBuilder
from burr.core.parallelism import MapActions, RunnableGraph
from burr.core.graph import GraphBuilder

from typing import Generator, Dict, Any, List
import random
import time

def sleep_random(max_sleep: float=1):
    """Block the caller for a random amount of time.

    The pause lasts ``random.random() * max_sleep`` seconds.

    Args:
        max_sleep: Scale factor, in seconds, applied to the random draw.

    Returns:
        The return value of ``time.sleep`` (``None``).
    """
    # Swap the two lines below for a bare `return` to run the demo without delays.
    pause_seconds = random.random() * max_sleep
    return time.sleep(pause_seconds)
    

In [8]:
@action(reads=[], writes=["num"])
def process_input(state: State, num: int) -> State:
    """Seed the state with the number the rest of the graph works on.

    Args:
        state: Current state.
        num: Value to store under the ``"num"`` key.

    Returns:
        The state with ``num`` set.
    """
    seeded_state = state.update(num=num)
    return seeded_state
    
# The declared `writes` now names the key this action actually updates
# ("output"); it previously declared "num", which the body never writes.
@action(reads=["num"], writes=["output"])
def increment(state: State, increment_by: int=1) -> State:
    """Store ``num + increment_by`` under ``"output"`` after a random pause.

    Args:
        state: Current state; must contain ``"num"``.
        increment_by: Amount added to ``state["num"]``.

    Returns:
        The state with ``output`` set to the incremented value. ``num`` itself
        is left untouched.
    """
    # NOTE(review): `multiply` reads "num" rather than "output", so this result
    # is not consumed by the next step in the graph -- confirm intent.
    sleep_random()
    return state.update(output=state["num"] + increment_by)

# The declared `writes` now names the key this action actually updates
# ("output"); it previously declared "num", which the body never writes.
@action(reads=["num"], writes=["output"])
def multiply(state: State, multiply_by: int=2) -> State:
    """Store ``num * multiply_by`` under ``"output"`` after a random pause.

    Args:
        state: Current state; must contain ``"num"``.
        multiply_by: Factor applied to ``state["num"]``.

    Returns:
        The state with ``output`` set to the product. ``num`` itself is left
        untouched, and any earlier ``output`` value is overwritten.
    """
    sleep_random()
    return state.update(output=state["num"] * multiply_by)


@action(reads=["num"], writes=["nums"])
def explode(state: State, chunks: int) -> State:
    """Fan ``num`` out into ``chunks`` random integers.

    Each value is drawn with ``random.randint(0, state["num"])``, so both
    endpoints are possible.

    Args:
        state: Current state; ``"num"`` is read once per generated value.
        chunks: How many random values to generate.

    Returns:
        The state with ``nums`` set to the list of generated values.
    """
    samples = []
    for _ in range(chunks):
        samples.append(random.randint(0, state["num"]))
    return state.update(nums=samples)


# The declared `writes` now names the key this action actually updates
# ("final_result"); it previously declared "num", which the body never writes.
@action(reads=["nums"], writes=["final_result"])
def sum_nums(state: State) -> State:
    """Collapse ``nums`` into a single total.

    Args:
        state: Current state; must contain the iterable ``"nums"``.

    Returns:
        The state with ``final_result`` set to ``sum(state["nums"])``.
    """
    return state.update(final_result=sum(state["nums"]))

class Recursive(MapActions):
    """Map step that launches one sub-graph per entry of ``state["nums"]``.

    Every sub-graph is produced by ``create_graph`` with this instance's
    recursion depth, so nesting continues until ``create_graph`` decides the
    maximum depth has been reached. The ``"num"`` value of each finished
    sub-graph state is collected back into ``nums``.

    Args:
        recursion_depth: Depth passed to the sub-graphs this step creates.
        max_recursion_depth: Depth limit forwarded to ``create_graph``.
    """

    def __init__(self, recursion_depth: int, max_recursion_depth: int):
        super().__init__()
        self.recursion_depth = recursion_depth
        self.max_recursion_depth = max_recursion_depth

    def actions(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[RunnableGraph, None, None]:
        """Yield one runnable sub-graph for each number in ``state["nums"]``.

        Progress is printed to stdout: once for the incoming state and once
        per sub-graph created.
        """
        print("generating actions", state)
        for input_num in state["nums"]:
            print(self.recursion_depth, self.max_recursion_depth, input_num)
            yield create_graph(
                recursion_depth=self.recursion_depth,
                max_recursion_depth=self.max_recursion_depth,
                input_num=input_num
            )

    def state(self, state: State, inputs: Dict[str, Any]) -> State:
        """Return the starting state handed to each sub-graph.

        An empty ``State`` is returned; the parent state is not forwarded.
        """
        # Per the original author: nothing is needed here because the input
        # number is bound onto the sub-graph's entry action instead.
        return State()

    @property
    def reads(self) -> List[str]:
        """State keys this step reads."""
        return ["nums"]

    @property
    def writes(self) -> List[str]:
        """State keys this step writes."""
        return ["nums"]

    def reduce(self, state: State, states: Generator[State, None, None]) -> State:
        """Replace ``nums`` with the ``"num"`` value of every sub-graph state.

        Values are collected in the order ``states`` yields them.
        """
        return state.update(nums=[sub_state["num"] for sub_state in states])

In [9]:
def create_graph(recursion_depth: int, max_recursion_depth: int, input_num: int=None) -> RunnableGraph:
    """Build the runnable graph for one level of the recursion demo.

    While ``recursion_depth < max_recursion_depth`` the graph is
    ``process_input -> increment -> multiply -> explode -> recur -> sum_nums``,
    where ``recur`` is a ``Recursive`` step one level deeper. At the depth
    limit the graph is only ``process_input -> increment -> multiply`` and
    ``"not_recurring"`` is printed.

    Args:
        recursion_depth: Depth of the graph being built.
        max_recursion_depth: Depth at which recursion stops.
        input_num: Value bound to ``process_input`` in the recursive branch.
            When ``None``, ``process_input`` is left unbound in that branch.

    Returns:
        A ``RunnableGraph`` entering at ``process_input`` and halting after
        ``sum_nums`` (recursive branch) or ``multiply`` (leaf branch).
    """
    should_recur = recursion_depth < max_recursion_depth
    if should_recur:
        g = (
            GraphBuilder()
            .with_actions(
                process_input.bind(num=input_num) if input_num is not None else process_input,
                increment,
                multiply,
                explode.bind(chunks=3),
                sum_nums,
                recur=Recursive(
                    recursion_depth=recursion_depth+1,
                    max_recursion_depth=max_recursion_depth)
            )
            .with_transitions(
                ("process_input", "increment"),
                ("increment", "multiply"),
                ("multiply", "explode"),
                ("explode", "recur"),
                ("recur", "sum_nums")
            )
            .build()
        )
        halt_after = ["sum_nums"]
    else:
        print("not_recurring")
        g = (
            GraphBuilder()
            .with_actions(
                # NOTE(review): the leaf ignores `input_num` and always binds
                # 40 -- looks like a leftover placeholder; confirm intent.
                process_input.bind(num=40),
                increment,
                multiply,
            )
            .with_transitions(
                ("process_input", "increment"),
                ("increment", "multiply")
            )
            .build()
        )
        halt_after = ["multiply"]
    return RunnableGraph(
        graph=g,
        entrypoint="process_input",
        halt_after=halt_after
    )

In [12]:
# Build the top-level application. `input_num=None` leaves `process_input`
# unbound in the root graph, so `num` has to be supplied when the app runs.
root_runnable = create_graph(
    recursion_depth=0,
    max_recursion_depth=6,
    input_num=None,
)
app = (
    ApplicationBuilder()
    .with_graph(root_runnable.graph)
    .with_tracker(project="test_recursion")
    .with_entrypoint("process_input")
    .build()
)

In [13]:
# Run the root graph until `sum_nums` finishes. `num` and `increment_by`
# match the parameter names of `process_input` and `increment`.
run_inputs = {"num" : 10, "increment_by" : 1}
action_, _, state = app.run(halt_after=["sum_nums"], inputs=run_inputs)

generating actions {'num': 10, 'output': 20, 'nums': [3, 0, 6]}
1 6 3
1 6 0
1 6 6
generating actions {'num': 3, 'output': 6, 'nums': [3, 2, 2]}
2 6 3
2 6 2
2 6 2
generating actions {'num': 0, 'output': 0, 'nums': [0, 0, 0]}
2 6 0
2 6 0
2 6 0
generating actions {'num': 6, 'output': 12, 'nums': [6, 4, 5]}
2 6 6
2 6 4
2 6 5
generating actions {'num': 0, 'output': 0, 'nums': [0, 0, 0]}
3 6 0
3 6 0
3 6 0
generating actions {'num': 2, 'output': 4, 'nums': [0, 1, 2]}
3 6 0
3 6 1
3 6 2
generating actions {'num': 0, 'output': 0, 'nums': [0, 0, 0]}
4 6 0
4 6 0
4 6 0
generating actions {'num': 4, 'output': 8, 'nums': [3, 2, 3]}
3 6 3
3 6 2
3 6 3
generating actions {'num': 5, 'output': 10, 'nums': [0, 1, 1]}
3 6 0
3 6 1
3 6 1
generating actions {'num': 0, 'output': 0, 'nums': [0, 0, 0]}
3 6 0
3 6 0
3 6 0
generating actions {'num': 3, 'output': 6, 'nums': [1, 2, 3]}
3 6 1
3 6 2
3 6 3
generating actions {'num': 2, 'output': 4, 'nums': [1, 0, 2]}
3 6 1
3 6 0
3 6 2
generating actions {'num': 6, 'outpu

In [6]:
# Display the state returned by `app.run`.
# NOTE(review): this cell's execution count (6) is lower than the run cell's
# (13), so the output shown below may be stale -- re-run after `app.run`.
state

{'__SEQUENCE_ID': 5, 'num': 10, '__PRIOR_STEP': 'sum_nums', 'output': 20, 'nums': [9, 8, 9], 'final_result': 26}