-
Notifications
You must be signed in to change notification settings - Fork 7
/
flatten_flow.py
58 lines (43 loc) · 1.76 KB
/
flatten_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from prefect import flow, get_run_logger, task
@task
def setup() -> str:
    """Log the message "setup" through the run logger and return "done"."""
    get_run_logger().info("setup")
    return "done"
@task
def fetch_batches() -> list[str]:
    """Return ten batch labels, "batch 0" through "batch 9"."""
    batch_total = 10
    return [f"batch {number}" for number in range(batch_total)]
@task
def repartition_batches(batch: str) -> list[list[str]]:
    """Expand one batch label into rows and regroup them into fixed-size chunks.

    Eight row labels of the form ``"<batch>-<i>"`` are generated and then
    sliced into consecutive chunks of four, so the result holds two chunks of
    four labels each.

    Args:
        batch: Label of the batch to expand.

    Returns:
        The row labels grouped into chunks, in their original order.
    """
    run_log = get_run_logger()
    chunk_len = 4
    rows = [f"{batch}-{n}" for n in range(8)]

    # Walk the rows in strides of chunk_len, cutting one chunk per stride.
    chunks: list[list[str]] = []
    for start in range(0, len(rows), chunk_len):
        chunks.append(rows[start : start + chunk_len])

    run_log.info(f"Breaking result batch of size {len(rows)} into {len(chunks)} batches")
    return chunks
@task
def count_rows(batch: list[str]) -> int:
    """Log *batch* and return the number of items it contains."""
    get_run_logger().info(batch)
    row_total = len(batch)
    return row_total
@task
def summary(count: list[int]) -> tuple[int, int]:
    """Log and return how many entries *count* has and what they add up to.

    Args:
        count: Per-batch row counts.

    Returns:
        ``(len(count), sum(count))``.
    """
    run_log = get_run_logger()
    batch_total = len(count)
    row_total = sum(count)
    run_log.info(f"{count} Num batches: {batch_total} Num rows: {row_total}")
    return batch_total, row_total
@flow
def flatten() -> None:
    """Count rows across repartitioned batches and log the overall totals.

    The flow fetches the batch labels, maps ``repartition_batches`` over them,
    maps ``count_rows`` over every resulting sub-batch, collects all count
    futures into one flat list, and hands that list to ``summary``. The two
    numbers returned by ``summary`` are logged as ``"<batches>,<rows>"``.
    """
    run_log = get_run_logger()

    # A single task run produces the 10 source batch labels.
    source_batches = fetch_batches.submit()

    # Each batch is split into 2 sub-batches with 4 rows in each sub-batch.
    split_batches = repartition_batches.map(source_batches)  # type: ignore  # see https://github.com/PrefectHQ/prefect/issues/6922

    # Submit setup once, outside the loop, so it is not re-submitted per batch;
    # every count_rows mapping below is given this same future via wait_for.
    setup_done = setup.submit()

    # Map count_rows over the sub-batches of each batch and gather all of the
    # resulting futures into one flat list.
    row_counts = []
    for sub_batches in split_batches:  # type: ignore
        row_counts.extend(count_rows.map(sub_batches, wait_for=[setup_done]))  # type: ignore

    # Expected summary: twenty counts of 4 -> Num batches: 20 Num rows: 80
    batch_total, row_total = summary.submit(row_counts).result()
    run_log.info(f"{batch_total},{row_total}")
if __name__ == "__main__":
    # Executing this module as a script runs the flow once.
    flatten()