In [None]:
import asyncio
import logging
import time
from typing import AsyncGenerator, AsyncIterable, Callable, Coroutine, Generator, List

# Configure logging: set the root logger's threshold to DEBUG so that the
# debug/info messages emitted below are not filtered out.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Type alias for one unit of work flowing through the pipeline (a plain int).
TASK = int


async def produce() -> AsyncIterable[TASK]:
    """Emit the pipeline's input tasks: the integers 0 through 9, in order.

    Once every task has been yielded, a single INFO record is logged.
    """
    task_count = 10  # Replace with the actual number of tasks
    for task in range(task_count):
        # An ``await asyncio.sleep(...)`` could be placed here to simulate
        # I/O or computation delay before each task is emitted.
        yield task
    log.info("Produced tasks")


async def worker_spread(task: TASK) -> AsyncGenerator[TASK, None]:
    """Fan one task out into ``task`` follow-up tasks.

    For every ``offset`` in ``range(task)`` the generator sleeps for
    ``0.2 * offset`` seconds and then yields ``task + offset``. A task of
    ``0`` therefore yields nothing. An INFO record is logged once the
    generator is exhausted.
    """
    step_delay = 0.2  # seconds of simulated processing time per offset unit
    for offset in range(task):
        await asyncio.sleep(step_delay * offset)
        yield task + offset
    log.info(f"Spread tasks: {task}")


async def worker_apply(task: TASK) -> AsyncGenerator[TASK, None]:
    """Map one task to exactly one follow-up task: its square.

    Sleeps for 2 seconds to simulate processing time, yields ``task**2``,
    and logs an INFO record when the generator is resumed after the yield.

    Written as an async generator (rather than a coroutine returning a
    value) so that it has the same shape as ``worker_spread`` and fits the
    ``WorkStage`` signature.
    """
    await asyncio.sleep(2)  # Simulate processing time
    yield task**2
    log.info(f"Apply task: {task}")


async def consume(task: TASK) -> None:
    """Terminal stage: report a fully processed task at INFO level."""
    message = f"Consumed task: {task}"
    log.info(message)


# Signature of the producer stage: no arguments, returns an async iterable of tasks.
ProduceStage = Callable[[], AsyncIterable[TASK]]
# Signature of a work stage: takes one task, returns an async generator of
# follow-up tasks (zero, one, or many per input task).
WorkStage = Callable[[TASK], AsyncGenerator[TASK, None]]


class Pipeline:
    """Queue-connected async pipeline: one producer, N work stages, one consumer.

    The stages are linked by ``len(work_stages) + 1`` unbounded queues:
    ``queues[0]`` carries the producer's output, ``queues[k + 1]`` carries the
    output of work stage ``k``, and ``queues[-1]`` feeds the consumer. With no
    work stages the producer and the consumer share the single queue.

    End of stream is signalled by putting the ``stop`` sentinel on a queue.
    Each stage forwards the sentinel downstream once it has finished, and
    also when it fails, so that downstream stages terminate instead of
    waiting on ``get()`` forever.
    """

    # One queue per hand-off point between adjacent stages (see class docstring).
    # Items are tasks or the ``stop`` sentinel.
    queues: list[asyncio.Queue]

    def __init__(
        self,
        produce_stage: ProduceStage,
        work_stages: list[WorkStage],
        consume_stage: Callable[[TASK], Coroutine],
    ):
        """Store the stages and create the queues connecting them.

        Args:
            produce_stage: Called once with no arguments; its async iterable
                supplies the initial tasks.
            work_stages: Async-generator functions applied in order; each maps
                one task to any number of tasks for the next stage.
            consume_stage: Coroutine function called once per task leaving the
                last work stage.
        """
        self.queues = [asyncio.Queue() for _ in range(len(work_stages) + 1)]
        self.produce_stage = produce_stage
        self.work_stages = work_stages
        self.consume_stage = consume_stage
        # Unique end-of-stream marker; compared by identity so it can never
        # collide with a real task value.
        self.stop = object()

    async def run(self):
        """Run producer, all work stages and consumer concurrently until done.

        Raises:
            Exception: The first exception raised by any stage.
        """
        await asyncio.gather(self.produce(), self.workers(), self.consume())

    async def produce(self):
        """Feed every task from the producer stage into the first queue."""
        log.debug("Producing tasks")
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments.
        t_start = time.perf_counter()
        try:
            async for task in self.produce_stage():
                await self.queues[0].put(task)
        finally:
            # Always signal end of stream, even if the producer raised,
            # so the first work stage does not block forever.
            await self.queues[0].put(self.stop)
        log.debug(f"Producing tasks took {time.perf_counter() - t_start}s")

    async def workers(self):
        """Run one ``worker`` coroutine per work stage, all concurrently."""
        log.info("Doing work")
        t_start = time.perf_counter()
        stage_coros = [
            self.worker(i, work_stage)
            for i, work_stage in enumerate(self.work_stages)
        ]
        await asyncio.gather(*stage_coros)
        log.info(f"Doing workers took {time.perf_counter() - t_start}s")

    async def worker(self, i: int, work_stage: WorkStage):
        """Run work stage ``i``: read ``queues[i]``, write ``queues[i + 1]``.

        A separate asyncio task is started for every incoming task, so the
        items of one stage are processed concurrently. The stop sentinel is
        forwarded only after all of those tasks have finished or, on failure,
        been cancelled.

        Args:
            i: Index of the stage; selects the input and output queues.
            work_stage: Async-generator function applied to each input task.
        """
        inbox = self.queues[i]
        outbox = self.queues[i + 1]

        async def fan_out(task: TASK):
            # One input task may produce many tasks for the next stage.
            async for next_task in work_stage(task):
                await outbox.put(next_task)

        running = []
        try:
            while True:
                prev_task = await inbox.get()  # waits until an item arrives
                if prev_task is self.stop:
                    break
                running.append(asyncio.create_task(fan_out(prev_task)))

            await asyncio.gather(*running)
        finally:
            # On failure or cancellation, stop the siblings so nothing is
            # queued after the sentinel. cancel() is a no-op for tasks that
            # have already finished, so the success path is unaffected.
            for pending in running:
                pending.cancel()
            await outbox.put(self.stop)

    async def consume(self):
        """Start ``consume_stage`` for every task on the last queue.

        Returns:
            The list of ``consume_stage`` results, in arrival order.
        """
        running = []
        while True:
            task = await self.queues[-1].get()
            if task is self.stop:
                break
            running.append(asyncio.create_task(self.consume_stage(task)))

        return await asyncio.gather(*running)


# Create the pipeline: produce -> worker_spread -> worker_apply -> consume
pipeline = Pipeline(
    produce_stage=produce,
    work_stages=[worker_spread, worker_apply],
    consume_stage=consume,
)

# Run the pipeline
# NOTE: top-level ``await`` is only valid where an event loop is already
# running (e.g. a notebook cell); in a plain script this would have to be
# ``asyncio.run(pipeline.run())``.
await pipeline.run()