
Very long reading or hangs time for lazy slice collection in worker process #14358

Open
2 tasks done
tchaton opened this issue Feb 8, 2024 · 4 comments
Labels
bug (Something isn't working) · needs triage (Awaiting prioritization by a maintainer) · python (Related to Python Polars)

Comments


tchaton commented Feb 8, 2024

Hey @ritchie46, @stinodego, @alexander-beedie

Checks

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of Polars.

Reproducible example

I am trying to distribute reading parquet files across worker processes, and the polars loading time either increases steadily or hangs.

This seems to be a common issue: img2dataset gets around lazy slicing by actually re-generating the shards: https://github.com/rom1504/img2dataset/blob/main/img2dataset/reader.py#L189
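
For reference, that shard-regeneration workaround looks roughly like the sketch below (not what I want to do here; the output directory, shard size, and function name are made up, and this is untested):

import polars as pol

def regenerate_shards(src_path, num_rows, shard_size, out_dir):
    # Materialise every slice of the large file into its own small parquet file,
    # so workers later read whole files instead of collecting lazy slices.
    shard_paths = []
    for shard_idx, offset in enumerate(range(0, num_rows, shard_size)):
        out_path = f"{out_dir}/shard_{shard_idx:05d}.parquet"
        length = min(shard_size, num_rows - offset)
        pol.scan_parquet(src_path).slice(offset, length).collect().write_parquet(out_path)
        shard_paths.append(out_path)
    return shard_paths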

Here is a reproducible script:

import os
import polars as pol
from multiprocessing import Process
from dataclasses import dataclass
import numpy as np
from time import time

input_dir = "the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta" # 36 parquet files of 1.7GB
parquet_files = [os.path.join(root, f) for root, _, filenames in os.walk(input_dir) for f in filenames if f.endswith(".parquet")]


@dataclass
class ParquetSlice:
    """Keep track of a parquet file slice with its filepath, start and end."""
    filepath: str
    start: int
    end: int

def get_num_rows(path):
    # Counting the rows with pyarrow works fine.
    import pyarrow.dataset as ds
    return ds.dataset(path).scanner().count_rows()

    # NOTE: counting the rows with polars instead seems to make the later reads hang
    # (see the issue description below); kept here, unreachable, for reference:
    # df = pol.scan_parquet(path)
    # return df.select(pol.len()).collect().item()


num_rows_per_parquets = [get_num_rows(parquet_file) for parquet_file in parquet_files]
num_rows = sum(num_rows_per_parquets)
num_workers = os.cpu_count()
batch_size = 2048
num_rows_per_worker = num_rows // num_workers

workers_data = {}
worker_idx = 0
worker_rows = 0

# Use a distinct loop variable so the total `num_rows` computed above is not shadowed.
for parquet_file, file_num_rows in zip(parquet_files, num_rows_per_parquets):
    for row_start in range(0, file_num_rows, batch_size):
        row_end = min(row_start + batch_size, file_num_rows)
        if worker_rows >= num_rows_per_worker:
            worker_idx += 1
            worker_rows = 0
        worker_rows += row_end - row_start
        if worker_idx not in workers_data:
            workers_data[worker_idx] = []
        workers_data[worker_idx].append(ParquetSlice(parquet_file, start=row_start, end=row_end))


assert len(workers_data) == num_workers

def target(parquet_slices):
    for parquet_slice in parquet_slices:
        t0 = time()
        # LazyFrame.slice takes (offset, length), so convert the (start, end) pair.
        length = parquet_slice.end - parquet_slice.start
        df = pol.scan_parquet(parquet_slice.filepath).slice(parquet_slice.start, length).collect()
        print(time() - t0)

workers = [Process(target=target, args=(workers_data[worker_idx], )) for worker_idx in range(num_workers)]
for worker in workers:
    worker.start()


for worker in workers:
    worker.join()

Log output

Here are the logs. As you can see, the read time per slice increases dramatically.

~ python read_from_slice.py
0.23226261138916016
0.3072049617767334
0.3433570861816406
0.3518714904785156
0.3415353298187256
0.442699670791626
0.5451216697692871
0.6135847568511963
1.864415168762207
68.32183408737183
68.40467309951782
69.0478003025055
68.78134942054749
69.16423535346985
69.20430397987366
69.25753045082092
69.44716906547546
69.5466377735138
69.36016774177551
69.53426384925842
69.54183912277222
69.09911513328552
69.74612975120544
69.74536037445068
69.80545735359192
69.90860056877136
69.9910020828247
69.64902091026306
71.0997965335846
71.39023017883301
71.73733067512512
71.76467275619507
71.77034497261047
71.53651094436646
71.90763735771179
71.75954151153564
72.58842492103577
72.73675751686096
71.13086915016174
72.96966505050659
74.83749222755432
57.01331090927124
58.48121666908264
58.511826038360596
58.79772067070007
58.977129220962524
60.27041792869568
59.17792892456055
59.18590831756592
58.8665874004364
57.32769322395325
59.9551956653595
59.113983392715454
59.499932050704956
59.80842208862305
60.18164610862732
57.377845287323
58.38949942588806
60.073336601257324
57.579628705978394
59.76506447792053
59.55999445915222
57.7710235118866
57.84200859069824
57.831613540649414
54.57873201370239
58.27272343635559
60.447746992111206
60.56101703643799
56.32220506668091
56.74964261054993
56.98266530036926
57.15247344970703
59.272478103637695
60.10847592353821
59.91492676734924
60.23894929885864
59.58669590950012
59.57875299453735
60.92053747177124
60.933486223220825
61.29428195953369
61.034302949905396
61.01777148246765
61.063395261764526
61.285802364349365
61.302069664001465
61.77613568305969
61.602858781814575
61.47451376914978
62.16115379333496
62.12472677230835
62.90268111228943
63.374130964279175
62.88069176673889
62.86286687850952
62.864922761917114
62.95443248748779
64.20906710624695
63.75752806663513
64.0263454914093
64.09385251998901
65.00613903999329
65.11042881011963
65.4911162853241
53.45730209350586
56.89541268348694
55.55260396003723
55.398075103759766
57.848676681518555
57.6532621383667
58.47229194641113
58.632524728775024
58.474345684051514
57.96194338798523
56.618701457977295
58.633341550827026
58.78440833091736
58.292826414108276
57.83442425727844
58.43972849845886
56.97259163856506
57.188791036605835
57.28998398780823
59.49136400222778
57.94972801208496
58.27231240272522
60.66723847389221
59.365389823913574
57.62423133850098
58.830878019332886
59.62562656402588
59.9276008605957
59.81395125389099
62.17665123939514
64.58833956718445

For comparison, here is the same loop with pyarrow. Still not great, but much better.

import pyarrow.dataset as ds

def target(parquet_slices):
    for parquet_slice in parquet_slices:
        t0 = time()
        scanner = ds.dataset(parquet_slice.filepath).scanner()
        # Take the full row range, not just the two boundary rows.
        table = scanner.take(list(range(parquet_slice.start, parquet_slice.end)))
        print(time() - t0)
~ python read_from_slice.py
2.900271415710449
0.3804919719696045
3.3455705642700195
3.346801996231079
0.4026670455932617
0.44400691986083984
4.9205238819122314
1.588649034500122
5.469159364700317
5.494945526123047
0.5918009281158447
1.909053087234497
5.712141275405884
6.022623062133789
6.006503105163574
6.318320035934448
6.310782432556152
6.941025495529175
3.9690725803375244
7.5848777294158936
7.867754697799683
8.124830484390259
8.758509635925293
3.3831429481506348
3.7338757514953613
9.204360008239746
9.57539176940918
9.908674478530884
10.119199514389038
4.726012229919434
10.352258443832397
10.37306261062622
4.263919115066528
10.683980226516724
10.841665744781494
10.799356698989868
10.965590476989746
11.203482627868652
4.302994966506958
11.754274606704712
5.534180164337158
6.171769618988037
12.366551160812378
12.43632435798645
12.517409086227417
6.545653343200684
12.77980661392212
7.1889026165008545
7.476249694824219
7.609501361846924
7.133400917053223
6.986341714859009
6.956153869628906
8.623024463653564
10.989533185958862
6.136003494262695
9.670828342437744
10.059322118759155
10.663979053497314
13.220695972442627
10.103495597839355
12.12840485572815
10.318628549575806
10.898629188537598
11.721484661102295
11.351822853088379
11.219998836517334
9.451461553573608
12.116554260253906
10.976173877716064
10.738194704055786
13.604996681213379
12.722578763961792
27.70877242088318
26.76956033706665

Issue description

Partial lazy loading of parquet slices is a key component of distributed data processing across workers and machines.

Additionally, if I get the row count using polars instead of pyarrow, the reads seem to hang. This might be a second bug.
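
To make that second issue concrete, here is roughly the pattern that seems to hang for me (a minimal, untested sketch; the path is a placeholder, and the full reproducer is the script above):

import polars as pol
from multiprocessing import Process

def count_rows(path):
    # Counting rows with polars in the parent process...
    return pol.scan_parquet(path).select(pol.len()).collect().item()

def read_slice(path):
    # ...and then collecting a slice in a fork()-ed child process seems to hang.
    print(pol.scan_parquet(path).slice(0, 2048).collect().shape)

if __name__ == "__main__":
    path = "some_file.parquet"  # placeholder
    print(count_rows(path))
    p = Process(target=read_slice, args=(path,))
    p.start()
    p.join()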

Expected behavior

Reading is fast, and the read time stays roughly constant across slices.

Installed versions

polars                    0.20.6
tchaton added the bug, needs triage, and python labels on Feb 8, 2024
tchaton changed the title from "Very long reading or hangs time for slice in processes" to "Very long reading or hangs time for lazy slice collection in worker process" on Feb 8, 2024

tchaton (Author) commented Feb 9, 2024

Hey Team,

I would appreciate an answer :) This is blocking me. I am seriously considering dropping polars as a possible backend for the PyTorch Lightning parquet backend.

Best,
T.C

cmdlineluser (Contributor) commented

I think at the very least you need to use multiprocessing.get_context("spawn").Process instead of Process, as mentioned in the guide from the previous issue.

(I don't have much knowledge on the topic, so I'm not sure if this will fix things.)
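
For the script above, the worker launch would then look something like this (an untested sketch; note that with "spawn" the module-level setup also needs to sit under an if __name__ == "__main__": guard, because child processes re-import the module):

import multiprocessing as mp

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    workers = [ctx.Process(target=target, args=(workers_data[worker_idx],))
               for worker_idx in range(num_workers)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()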

ritchie46 (Member) commented

Read this: https://docs.pola.rs/user-guide/misc/multiprocessing/

It is not something we can fix. Python multiprocessing is badly designed: it assumes the current process doesn't hold any state in mutexes, which is a very, very unsafe assumption.

ritchie46 (Member) commented Feb 10, 2024

To quote the Python multiprocessing documentation:

The parent process uses [os.fork()](https://docs.python.org/3/library/os.html#os.fork) to fork the Python interpreter. 
The child process, when it begins, is effectively identical to the parent process. 
All resources of the parent are inherited by the child process. 
Note that safely forking a multithreaded process is problematic.

In other words, use spawn.
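
For example, applied globally (a minimal sketch; see the multiprocessing page linked above):

import multiprocessing as mp

if __name__ == "__main__":
    # Must be set once, before any Process is created.
    mp.set_start_method("spawn")
    # ... build and start the worker processes as in the script above ...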
