Shows a 40x+ performance improvement over DuckDB for our joinall op when both the join and the zip are implemented with pyarrow.

Example output for make_tables(5, 100000, 10):
```
Test data construction time: 0.023s

Total duck time: 3.076s

Hybrid join time: 0.120s
Hybrid zip time 0.030s
Total hybrid time: 0.156s

Pyarrow join time: 0.024s
Pyarrow zip time: 0.043s
Total pyarrow time: 0.069s

Hybrid speedup over duck: 19.74x
Pyarrow speedup over duck: 44.84x
```

For very large inputs I saw the duck and hybrid versions converge to about the same time, while the pyarrow version achieved a speedup of ~150x. My guess is that the duck join implementation is less memory-efficient and starts to swap at some point.

In [None]:
import time

import duckdb
import numpy as np
import pyarrow as pa

In [None]:
# Generate test tables


def make_tables(n_tables, n_rows, n_cols):
    """Build ``n_tables`` synthetic pyarrow tables that share a join key.

    Every table has ``n_rows`` rows and columns named ``c0`` .. ``c{n_cols-1}``.

    * ``c0`` is the int64 sequence ``0 .. n_rows - 1`` in every table, so it
      can serve as the join key.
    * Every other column is a struct with fields ``tag`` and ``val``. ``val``
      is the key shifted by an offset derived from the (table, column)
      position, so each struct column holds its own block of ``n_rows``
      consecutive integers; ``tag`` is the negation of ``val``.

    Returns a list of ``pa.Table`` objects, one per table.
    """
    key_col = np.arange(n_rows, dtype="int64")
    col_names = [f"c{c}" for c in range(n_cols)]

    def struct_col(table_idx, col_idx):
        # Shift the key so values are distinct per (table, column) pair.
        offset = col_idx * n_rows + table_idx * n_cols * n_rows
        vals = key_col + offset
        return pa.StructArray.from_arrays([-vals, vals], names=["tag", "val"])

    return [
        pa.Table.from_arrays(
            [key_col if c == 0 else struct_col(t, c) for c in range(n_cols)],
            names=col_names,
        )
        for t in range(n_tables)
    ]

In [None]:
# joinall implemented fully using duckdb. This is the baseline.


def duck_query(tables):
    """Return the SQL for the all-DuckDB joinall baseline.

    The tables are expected to be registered as ``t0``, ``t1``, ... and to
    have columns ``c0``, ``c1``, ...; the column count is read from
    ``tables[0].num_columns``. Every table is inner-joined to ``t0`` on
    ``c0``, and each output column ``c{k}`` is a ``list_value`` of that
    column taken from every table.

    NOTE: the result is deliberately left unsorted. Appending
    ``ORDER BY t0.c0 ASC`` would give stable results for comparison, but it
    crashes the notebook kernel when the tables are over a certain size.
    """
    table_ids = range(len(tables))
    n_cols = tables[0].num_columns

    def list_expr(col):
        members = ", ".join(f"t{t}.c{col}" for t in table_ids)
        return f"list_value({members}) as c{col}"

    select = ", ".join(list_expr(col) for col in range(n_cols))
    # t0 is the left side of every join, so only t1.. appear here.
    join = " ".join(f"inner join t{t} on t0.c0 = t{t}.c0" for t in table_ids[1:])
    return f"SELECT {select} FROM t0 {join}"


def duck_join(tables):
    """Run the all-DuckDB joinall baseline and return the result via ``.arrow()``.

    Each input table is registered on a fresh DuckDB connection under the
    name ``t{i}`` (its position in ``tables``), then the SQL produced by
    ``duck_query`` is executed against those names.
    """
    connection = duckdb.connect()
    for position, table in enumerate(tables):
        connection.register(f"t{position}", table)
    return connection.execute(duck_query(tables)).arrow()


# tables = make_tables(2, 100, 3)
# old_res = duck_join(tables)
# old_res['c0'].to_pylist()

In [None]:
def arrow_zip(*arrs):
    """Zip several pyarrow arrays row-wise into a single list array.

    Row ``i`` of the result is the list
    ``[arrs[0][i], arrs[1][i], ..., arrs[-1][i]]``. The output has as many
    rows as the shortest input; longer inputs are truncated.

    Args:
        *arrs: One or more pyarrow arrays that ``pa.concat_arrays`` can
            concatenate (i.e. of the same type).

    Returns:
        A ``pa.ListArray`` whose every list has ``len(arrs)`` items.

    Raises:
        ValueError: If no arrays are given.
    """
    n_arrs = len(arrs)
    if n_arrs == 0:
        raise ValueError("arrow_zip requires at least one array")
    output_len = min(len(a) for a in arrs)

    # Trim every input to the common length so that, after concatenation,
    # input k starts exactly at k * output_len. Without this the index
    # arithmetic below is wrong whenever the inputs differ in length.
    concatted = pa.concat_arrays([a.slice(0, output_len) for a in arrs])

    # Output position i * n_arrs + k reads row i of input k. The row indexes
    # are built with integer repeat rather than a float-stepped arange, whose
    # fractional step is subject to rounding (e.g. 49 * (1.0 / 49) < 1.0) and
    # can therefore floor to the wrong row or yield the wrong length.
    item_indexes = np.repeat(np.arange(output_len, dtype="int64"), n_arrs)
    array_indexes = np.tile(np.arange(n_arrs, dtype="int64"), output_len)
    indexes = item_indexes + array_indexes * output_len

    interleaved = concatted.take(indexes)
    # Each output list spans n_arrs consecutive items of `interleaved`.
    offsets = np.arange(0, len(interleaved) + n_arrs, n_arrs, dtype="int64")
    return pa.ListArray.from_arrays(offsets, interleaved)


def new_duck_query(tables):
    """Return the SQL for the join-only step of the hybrid implementation.

    The tables are expected to be registered as ``t0``, ``t1``, ... with
    columns ``c0``, ``c1``, ...; the column count is read from
    ``tables[0].num_columns``. Every table is inner-joined to ``t0`` on
    ``c0``. Unlike ``duck_query`` no lists are built in SQL: each source
    column is selected flat under the alias ``t{table}c{col}``, ordered by
    column first and table second.

    NOTE: the result is deliberately left unsorted. Appending
    ``ORDER BY t0.c0 ASC`` would give stable results for comparison, but it
    crashes the notebook kernel when the tables are over a certain size.
    """
    table_ids = range(len(tables))
    n_cols = tables[0].num_columns

    select = ", ".join(
        f"t{t}.c{col} as t{t}c{col}" for col in range(n_cols) for t in table_ids
    )
    # t0 is the left side of every join, so only t1.. appear here.
    join = " ".join(f"inner join t{t} on t0.c0 = t{t}.c0" for t in table_ids[1:])
    return f"SELECT {select} FROM t0 {join}"


# Hybrid variant: DuckDB performs the join, pyarrow performs the zip step.
def hybrid_join(tables):
    """Join ``tables`` on ``c0`` with DuckDB, then zip the columns with pyarrow.

    The join and zip phases are timed separately and each elapsed time is
    printed. The returned ``pa.Table`` has columns ``c0`` ..
    ``c{n_cols-1}``, where each is the ``arrow_zip`` of that column taken
    from every input table.
    """
    table_ids = range(len(tables))
    col_ids = range(tables[0].num_columns)

    # Phase 1: flat inner join in DuckDB (one output column per table/column).
    join_started = time.time()
    con = duckdb.connect()
    for idx, table in enumerate(tables):
        con.register(f"t{idx}", table)
    joined = con.execute(new_duck_query(tables)).arrow()
    print("Hybrid join time: %.03fs" % (time.time() - join_started))

    # Phase 2: for each logical column, zip its per-table copies into lists.
    zip_started = time.time()
    columns = []
    names = []
    for c in col_ids:
        # arrow_zip concatenates its inputs, so flatten each chunked column.
        per_table = [joined[f"t{t}c{c}"].combine_chunks() for t in table_ids]
        columns.append(arrow_zip(*per_table))
        names.append(f"c{c}")
    result = pa.Table.from_arrays(columns, names=names)
    print("Hybrid zip time %.03fs" % (time.time() - zip_started))
    return result


# All-pyarrow variant: both the join and the zip use pyarrow (fastest version).
def pyarrow_join(tables):
    """Join ``tables`` on ``c0`` and zip their columns using only pyarrow.

    The join runs on narrow helper tables holding just the key (``c0``,
    renamed ``join``) and a row-position column ``index_t{i}`` per input.
    After the inner joins, those index columns say which row of each input
    belongs to each output row; the zip phase uses them to ``take`` the real
    columns and passes the results to ``arrow_zip``.

    The join and zip phases are timed separately and each elapsed time is
    printed. Returns a ``pa.Table`` with columns ``c0`` .. ``c{n_cols-1}``.
    """
    col_ids = range(tables[0].num_columns)

    def keyed_index(table, position):
        # Key column plus the positions of the rows within `table`.
        return pa.Table.from_arrays(
            [table["c0"], np.arange(len(table), dtype="int64")],
            names=["join", f"index_t{position}"],
        )

    # Phase 1: inner-join every table's (key, row position) pair onto t0's.
    join_started = time.time()
    joined = keyed_index(tables[0], 0)
    for position, table in enumerate(tables[1:], start=1):
        joined = joined.join(
            keyed_index(table, position),
            ["join"],
            join_type="inner",
            use_threads=False,
        )
    print("Pyarrow join time: %.03fs" % (time.time() - join_started))

    # Phase 2: gather the matched rows from each input and zip them.
    zip_started = time.time()
    zipped_cols = []
    for c in col_ids:
        name = f"c{c}"
        gathered = [
            table[name].take(joined[f"index_t{t}"]).combine_chunks()
            for t, table in enumerate(tables)
        ]
        zipped_cols.append(arrow_zip(*gathered))
    print("Pyarrow zip time: %.03fs" % (time.time() - zip_started))

    return pa.Table.from_arrays(zipped_cols, names=[f"c{c}" for c in col_ids])


# new_res

In [None]:
# Benchmark driver: build one set of inputs, run the three joinall
# implementations on it in turn, and report wall-clock times and speedups.

# Build the shared test data: 5 tables, 100000 rows, 10 columns each.
start_time = time.time()
tables = make_tables(5, 100000, 10)
print("Test data construction time: %.03fs" % (time.time() - start_time))
print()

# Baseline: join and list construction both done in DuckDB.
start_time = time.time()
old_res = duck_join(tables)
old_time = time.time() - start_time
print("Total duck time: %.03fs" % old_time)
print()

# Hybrid: DuckDB join + pyarrow zip (also prints its own per-phase timings).
start_time = time.time()
hybrid_res = hybrid_join(tables)
hybrid_time = time.time() - start_time
print("Total hybrid time: %.03fs" % hybrid_time)
print()

# Pure pyarrow: join + zip (also prints its own per-phase timings).
start_time = time.time()
pyarrow_res = pyarrow_join(tables)
pyarrow_time = time.time() - start_time
print("Total pyarrow time: %.03fs" % pyarrow_time)
print()

# Speedups are expressed relative to the all-DuckDB baseline.
print("Hybrid speedup over duck: %.02fx" % (old_time / hybrid_time))
print("Pyarrow speedup over duck: %.02fx" % (old_time / pyarrow_time))

In [None]:
# old_res.to_pylist() == hybrid_res.to_pylist()

In [None]:
# old_res.to_pylist() == pyarrow_res.to_pylist()