In [1]:
import os
import polars as pl
import urllib.request

from codetiming import Timer
from polars import col
from polars_tdigest import estimate_quantile, tdigest, tdigest_cast
from tqdm import tqdm

In [2]:
class DownloadProgressBar(tqdm):
    """tqdm progress bar that can be driven by a download report hook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to the absolute position ``b * bsize``.

        The signature matches the ``reporthook`` callback of
        ``urllib.request.urlretrieve``, which is how ``download_url`` uses it.

        Args:
            b: Block count reported by the hook.
            bsize: Size of one block.
            tsize: Total size; when not ``None`` it replaces ``self.total``.
        """
        if tsize is not None:
            self.total = tsize
        position = b * bsize
        # update() expects an increment, so convert the absolute position
        # into a delta relative to the bar's current counter.
        self.update(position - self.n)


def download_url(url, output_path):
    """Download ``url`` to ``output_path`` while showing a progress bar.

    The payload is first written to a sibling ``<output_path>.part`` file and
    only moved to ``output_path`` once the transfer has finished. This way an
    interrupted download never leaves a truncated file at ``output_path``
    that an ``os.path.exists`` check would mistake for a complete one.

    Args:
        url: Remote location to fetch; its last path segment labels the bar.
        output_path: Local destination path for the downloaded file.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises; the partial file is
        removed before the exception propagates.
    """
    partial_path = f"{output_path}.part"
    with DownloadProgressBar(
        unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
    ) as t:
        try:
            urllib.request.urlretrieve(
                url, filename=partial_path, reporthook=t.update_to
            )
        except BaseException:
            # Also covers KeyboardInterrupt: drop the incomplete file, then
            # let the original error continue to the caller.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    # Same-directory rename, so the final file appears in one step.
    os.replace(partial_path, output_path)

In [7]:
# --- Configuration ----------------------------------------------------------
base_url = "https://d37ci6vzurychx.cloudfront.net/trip-data/"
dataset_files = [
    "yellow_tripdata_2024-02.parquet",
    "yellow_tripdata_2024-03.parquet",
]
local_folder = "/tmp/"
numeric_col = "trip_distance"  # column the digests and quantiles are built on
max_size = 100  # second argument handed to tdigest / tdigest_cast
run_performance_test = True  # set to False to skip every timing loop

# Filled by the loop below, one entry per file in dataset_files order.
datasets = []  # lazy scans of the local parquet files
tdigests = []  # collected results of the tdigest query

# --- Build one t-digest per monthly file -------------------------------------
for dataset in dataset_files:
    local_file = f"{local_folder}{dataset}"
    already_cached = os.path.exists(local_file)
    if not already_cached:
        # Fetch the file only when it is not on disk yet.
        download_url(f"{base_url}{dataset}", local_file)

    df = pl.scan_parquet(local_file)
    datasets.append(df)

    print(f"Dataset {dataset} has {df.select(pl.len()).collect().item()} rows")

    query = df.select(tdigest(numeric_col, max_size))
    query_cast = df.select(tdigest_cast(numeric_col, max_size))

    if run_performance_test:
        # Time both variants, five collects each, plain variant first.
        timed_queries = (
            ("TDigest took: {milliseconds:.0f} ms", query),
            ("TDigest with cast took: {milliseconds:.0f} ms", query_cast),
        )
        for timer_text, lazy_query in timed_queries:
            for _ in range(5):
                with Timer(text=timer_text):
                    lazy_query.collect()

    # Keep the (non-cast) digest of this file for the estimate below.
    tdigests.append(query.collect())

# --- Estimate the median from the per-file digests ---------------------------
df = pl.concat(tdigests)

if run_performance_test:
    for _ in range(5):
        with Timer(text="Estimate median took: {milliseconds:.0f} ms"):
            df.select(estimate_quantile(numeric_col, 0.5))

estimated_median = df.select(estimate_quantile(numeric_col, 0.5)).item()
print("Estimated median =", estimated_median)

Dataset yellow_tripdata_2024-02.parquet has 3007526 rows
TDigest took: 119 ms
TDigest took: 108 ms
TDigest took: 108 ms
TDigest took: 109 ms
TDigest took: 108 ms
TDigest with cast took: 109 ms
TDigest with cast took: 108 ms
TDigest with cast took: 108 ms
TDigest with cast took: 109 ms
TDigest with cast took: 107 ms
Dataset yellow_tripdata_2024-03.parquet has 3582628 rows
TDigest took: 136 ms
TDigest took: 116 ms
TDigest took: 113 ms
TDigest took: 112 ms
TDigest took: 112 ms
TDigest with cast took: 111 ms
TDigest with cast took: 111 ms
TDigest with cast took: 111 ms
TDigest with cast took: 112 ms
TDigest with cast took: 112 ms
Estimate median took: 0 ms
Estimate median took: 0 ms
Estimate median took: 0 ms
Estimate median took: 0 ms
Estimate median took: 0 ms
Estimated median = 1.7201926614756424


In [13]:
# Exact median over all files at once, as a baseline for the t-digest estimate.
all_trips = pl.concat(datasets)
median_query = all_trips.select(col(numeric_col).median())

# Five timed runs of the same lazy query.
for _ in range(5):
    with Timer(text="Median took: {milliseconds:.0f} ms"):
        median_query.collect()

exact_median = median_query.collect().item()
print("Median =", exact_median)

Median took: 36 ms
Median took: 28 ms
Median took: 28 ms
Median took: 30 ms
Median took: 29 ms
Median = 1.71


In [7]:
# Exact median of each file on its own, five timed runs per file.
for partition in datasets:
    for _ in range(5):
        with Timer(text="Median on partition took: {milliseconds:.0f} ms"):
            # The query is built inside the timed region, then executed.
            partition_median = partition.select(col(numeric_col).median())
            partition_median.collect()

Median on partition took: 20 ms
Median on partition took: 20 ms
Median on partition took: 19 ms
Median on partition took: 19 ms
Median on partition took: 19 ms
Median on partition took: 21 ms
Median on partition took: 21 ms
Median on partition took: 20 ms
Median on partition took: 20 ms
Median on partition took: 20 ms
