# Hierarchical algorithm optimisation

## Data generation

In [3]:
import rustworkx as rx
from collections import Counter

def verify_components(table) -> dict:
    """
    Fast verification of connected components using rustworkx.
    
    Args:
        table: PyArrow table with 'left', 'right' columns
    
    Returns:
        dictionary containing basic component statistics
    """
    # Create graph directly from arrays
    graph = rx.PyDiGraph()
    
    # Add all unique nodes at once
    unique_nodes = set(table['left'].to_numpy()) | set(table['right'].to_numpy())
    graph.add_nodes_from(range(len(unique_nodes)))
    
    # Create node mapping and edges in one pass
    node_to_idx = {node: idx for idx, node in enumerate(unique_nodes)}
    edges = [(node_to_idx[left], node_to_idx[right], prob) 
            for left, right, prob in zip(table['left'].to_numpy(), 
                                       table['right'].to_numpy(),
                                       table['probability'].to_numpy())]
    
    # Add all edges at once
    graph.add_edges_from(edges)
    
    # Get components and their sizes
    components = rx.weakly_connected_components(graph)
    component_sizes = Counter(len(component) for component in components)
    
    return {
        'num_components': len(components),
        'total_nodes': len(unique_nodes),
        'total_edges': len(edges),
        'component_sizes': component_sizes,
        'min_component_size': min(component_sizes.keys()),
        'max_component_size': max(component_sizes.keys())
    }

In [None]:
def calculate_max_possible_edges(n_nodes: int, num_components: int) -> int:
    """
    Calculate the maximum possible number of edges given n nodes split into k components.
    
    Args:
        n_nodes: Total number of nodes
        num_components: Number of components to split into
        
    Returns:
        Maximum possible number of edges
    """
    nodes_per_component = n_nodes // num_components
    max_edges_per_component = nodes_per_component * nodes_per_component  # Complete bipartite graph
    return max_edges_per_component * num_components


In [None]:
import numpy as np
import pyarrow as pa
import rustworkx as rx
from typing import List, Tuple
from decimal import Decimal

def split_values_into_components(values: List[int], num_components: int) -> List[np.ndarray]:
    """
    Split values into non-overlapping groups for each component.
    
    Args:
        values: List of values to split
        num_components: Number of components to create
        
    Returns:
        List of arrays, one for each component
    """
    values = np.array(values)
    np.random.shuffle(values)
    return np.array_split(values, num_components)


def generate_arrow_data(
    left_values: List[int],
    right_values: List[int],
    prob_range: Tuple[float, float],
    num_components: int,
    total_rows: int
) -> pa.Table:
    """
    Generate dummy arrow data with guaranteed isolated components.
    
    Args:
        left_values: List of integers to use for left column
        right_values: List of integers to use for right column
        prob_range: Tuple of (min_prob, max_prob) to constrain probabilities
        num_components: Number of distinct connected components to generate
        total_rows: Total number of rows to generate
    
    Returns:
        PyArrow Table with 'left', 'right', and 'probability' columns
    """
    if len(left_values) < 2 or len(right_values) < 2:
        raise ValueError("Need at least 2 possible values for both left and right")
    if num_components > min(len(left_values), len(right_values)):
        raise ValueError("Cannot have more components than minimum of left/right values")
    
    # Calculate maximum possible edges
    min_nodes = min(len(left_values), len(right_values))
    max_possible_edges = calculate_max_possible_edges(min_nodes, num_components)
    
    if total_rows > max_possible_edges:
        raise ValueError(
            f"Cannot generate {total_rows:,} edges with {num_components:,} components. "
            f"Maximum possible edges is {max_possible_edges:,} given {min_nodes:,} nodes. "
            "Either increase the number of nodes, decrease the number of components, "
            "or decrease the total edges requested."
        )
    
    # Convert probability range to integers (60-80 for 0.60-0.80)
    prob_min = int(prob_range[0] * 100)
    prob_max = int(prob_range[1] * 100)
    
    # Split values into completely separate groups for each component
    left_components = split_values_into_components(left_values, num_components)
    right_components = split_values_into_components(right_values, num_components)
    
    # Calculate base number of edges per component
    base_edges_per_component = total_rows // num_components
    remaining_edges = total_rows % num_components
    
    all_edges = []
    
    # Generate edges for each component
    for comp_idx in range(num_components):
        comp_left_values = left_components[comp_idx]
        comp_right_values = right_components[comp_idx]
        
        # Calculate edges for this component
        edges_in_component = base_edges_per_component
        if comp_idx < remaining_edges:  # Distribute remaining edges
            edges_in_component += 1
            
        # Ensure basic connectivity within the component
        base_edges = []
        
        # Create a spanning tree-like structure
        for i in range(len(comp_left_values)):
            base_edges.append((
                comp_left_values[i],
                comp_right_values[i % len(comp_right_values)],
                np.random.randint(prob_min, prob_max + 1)
            ))
        
        # Generate remaining random edges strictly within this component
        remaining_edges = edges_in_component - len(base_edges)
        if remaining_edges > 0:
            random_lefts = np.random.choice(comp_left_values, size=remaining_edges)
            random_rights = np.random.choice(comp_right_values, size=remaining_edges)
            random_probs = np.random.randint(prob_min, prob_max + 1, size=remaining_edges)
            
            component_edges = base_edges + list(zip(random_lefts, random_rights, random_probs))
        else:
            component_edges = base_edges
            
        all_edges.extend(component_edges)
    
    # Convert to arrays
    lefts, rights, probs = zip(*all_edges)
    
    # Create PyArrow arrays
    left_array = pa.array(lefts, type=pa.int64())
    right_array = pa.array(rights, type=pa.int64())
    decimal_probs = [Decimal(str(p/100)) for p in probs]
    prob_array = pa.array(decimal_probs, type=pa.decimal128(precision=3, scale=2))
    
    return pa.table([left_array, right_array, prob_array],
                   names=['left', 'right', 'probability'])



In [413]:
left_values = list(range(10_000))
right_values = list(range(10_000, 20_000))
prob_range = (0.6, 0.8)
num_components = 10
total_rows = 1_000_000

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x10d274c90>>
Traceback (most recent call last):
  File "/Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


Number of components found: 10
Total nodes: 20000
Total edges: 1000009

Component sizes:
Size 2000: 10 components


In [None]:
left_values = list(range(int(2e7)))
right_values = list(range(int(2e7), int(4e7)))
prob_range = (0.7, 1.0)
num_components = 200_000
total_rows = int(1e8)

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [None]:
results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Number of components found: 206763
Total nodes: 40000000
Total edges: 100000400

Component sizes:
Size 2: 6755 components
Size 4: 8 components
Size 194: 1 components
Size 196: 124 components
Size 198: 6520 components
Size 200: 193355 components


Number of components found: 206763
* Total nodes: 40000000
* Total edges: 100000400

Component sizes:
* Size 2: 6755 components
* Size 4: 8 components
* Size 194: 1 components
* Size 196: 124 components
* Size 198: 6520 components
* Size 200: 193355 components

In [None]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table, Path.cwd() / 'hierarchical_cc200k.parquet')

In [None]:
left_values = list(range(int(2e5)))
right_values = list(range(int(2e5), int(4e5)))
prob_range = (0.7, 1.0)
num_components = 2_000
total_rows = int(1e6)

table2 = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [None]:
results2 = verify_components(table2)
print(f"Number of components found: {results2['num_components']:,}")
print(f"Total nodes: {results2['total_nodes']:,}")
print(f"Total edges: {results2['total_edges']:,}")
print("\nComponent sizes:")
for size, count in sorted(results2['component_sizes'].items()):
    print(f"Size {size:,}: {count:,} components")

Number of components found: 2,080
Total nodes: 400,000
Total edges: 1,000,400

Component sizes:
Size 2: 80 components
Size 196: 1 components
Size 198: 78 components
Size 200: 1,921 components


Number of components found: 2,080
* Total nodes: 400,000
* Total edges: 1,000,400

Component sizes:
* Size 2: 80 components
* Size 196: 1 components
* Size 198: 78 components
* Size 200: 1,921 components

In [None]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table2, Path.cwd() / 'hierarchical_cc2k.parquet')

## Hierarchical vs cluster representation

Dejan has suggested we store clusters in in a format that minimises recursion. 

Consider that bc forms at a lower threshold than ab, so we'll need to hold both.

Representation A -- proposed in the notebook, used in repo.

| parent | child |
| --- | --- |
| ab | a |
| ab | b |
| bc | b |
| bc | c |
|abc | ab |
|abc | bc |

Representation B -- Dejan's suggestion.

| parent | child |
| --- | --- |
| ab | a |
| ab | b |
| bc | b |
| bc | c |
|abc | a |
|abc | b |
|abc | c |

How does the space required to do this change in the "happy path" version of this data, where there's lots of components?

In [231]:
# Nice and easy to calc for cc2k -- it's what we've been building through the notebook

f"Representation A count of 2,000 connected components over 1m rows: {len(h_out):,}"

'Representation A count of 2,000 connected components over 1m rows: 2,533,537'

Let's adapt `to_clusters()` to get the answer for representation B.

In [235]:
import rustworkx as rx

def to_clusters(results: pa.Table) -> pa.Table:
    """
    Converts probabilities into a list of connected components formed at each threshold.

    Returns:
        Probabilities sorted by threshold descending.
    """
    G = rx.PyGraph()
    added: dict[bytes, int] = {}
    components: dict[str, list] = {"parent": [], "child": [], "threshold": []}

    # Sort probabilities descending and group by probability
    edges_df = results.select(['left', 'right', 'probability']).sort_by([("probability", "descending")])
    
    # Get unique probability thresholds, sorted
    thresholds = pa.compute.unique(edges_df.column('probability'))

    # Process edges grouped by probability threshold
    for prob in thresholds.to_pylist():
        mask = pa.compute.equal(edges_df.column('probability'), prob)
        threshold_edges = edges_df.filter(mask)
        # Get state before adding this batch of edges
        old_components = {frozenset(comp) for comp in rx.connected_components(G)}

        # Add all nodes and edges at this probability threshold
        edge_values = zip(
            threshold_edges.column('left').to_pylist(),
            threshold_edges.column('right').to_pylist()
        )

        for left, right in edge_values:
            for hash_val in (left, right):
                if hash_val not in added:
                    idx = G.add_node(hash_val)
                    added[hash_val] = idx

            G.add_edge(added[left], added[right], None)

        new_components = {frozenset(comp) for comp in rx.connected_components(G)}
        changed_components = new_components - old_components

        # For each changed component, add ALL members at current threshold
        for comp in changed_components:
            children = [G.get_node_data(n) for n in comp]
            parent = combine_integers(*children)

            components["parent"].extend([parent] * len(children))
            components["child"].extend(children)
            components["threshold"].extend([prob] * len(children))

    return pa.Table.from_pydict(components)

hout_b = to_clusters(h2)

In [237]:
f"Representation B count of 2,000 connected components over 1m rows: {len(hout_b):,}"

'Representation B count of 2,000 connected components over 1m rows: 6,385,847'

## Algorithm

In [3]:
import pyarrow.parquet as pq

h2 = pq.read_table('hierarchical_cc2k.parquet')

In [None]:
h2.schema

left: int64
right: int64
probability: decimal128(3, 2)

The plan.

* Find components and their sizes at lowest threshold (rustworkx)
* Use this to dask.groupby the data for parallel per-component processing
* Ensure we implement early stopping!

### Find components

In [4]:
import pyarrow as pa
import pyarrow.compute as pc
import rustworkx as rx
import numpy as np

def attach_independent_components(table: pa.Table) -> pa.Table:
    """
    Returns the original table with an additional 'component' column indicating
    which connected component each edge belongs to.
    """
    # Create dictionary array from sorted unique values
    unique = pc.unique(
        pa.concat_arrays([
            table['left'].combine_chunks(),
            table['right'].combine_chunks()
        ])
    )
    
    # Get indices into unique array for graph construction
    left_indices = pc.index_in(table['left'], unique)
    right_indices = pc.index_in(table['right'], unique)
    
    # Create and process graph
    n_nodes = len(unique)
    n_edges = len(table)
    
    graph = rx.PyGraph(
        node_count_hint=n_nodes,
        edge_count_hint=n_edges
    )
    graph.add_nodes_from(range(n_nodes))

    edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy()))
    graph.add_edges_from_no_data(edges)
    
    # Get components and create mapping array
    components = rx.connected_components(graph)
    
    # Convert components to numpy arrays
    component_indices = np.concatenate([np.array(list(c)) for c in components])
    component_labels = np.repeat(np.arange(len(components)), [len(c) for c in components])
    
    # Create mapping array and fill with component labels
    node_to_component = np.zeros(len(unique), dtype=np.int64)
    node_to_component[component_indices] = component_labels
    
    # Use the indices we already have to map back to components  
    edge_components = pa.array(node_to_component[left_indices.to_numpy()])
    
    return table.append_column('component', edge_components).sort_by(
        [('component', 'ascending'), ('probability', 'descending')]
    )

cc2 = attach_independent_components(h2)
len(pc.unique(cc2.column('component'))), cc2

(2080,
 pyarrow.Table
 left: int64
 right: int64
 probability: decimal128(3, 2)
 component: int64
 ----
 left: [[197413,116407,114551,160857,6412,...,39429,156175,197197,48177,121674]]
 right: [[326505,384163,344025,248700,258884,...,311452,278755,204144,357956,378111]]
 probability: [[1.00,1.00,1.00,1.00,1.00,...,0.70,0.70,0.70,0.70,0.70]]
 component: [[0,0,0,0,0,...,2079,2079,2079,2079,2079]])

In [None]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
components = attach_independent_components(table2)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         34170 function calls in 1.108 seconds

   Ordered by: cumulative time
   List reduced from 49 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    1.108    0.554 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    1.108    0.554 {built-in method builtins.exec}
        1    0.058    0.058    1.108    1.108 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/655470289.py:1(<module>)
        1    0.397    0.397    1.050    1.050 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2783446239.py:5(attach_independent_components)
        3    0.296    0.099    0.296    0.099 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    0.118    0.118    0.118    0.118 {connected_components}
        1    0.088    0.088    0.088    0.088 {

<pstats.Stats at 0x17d84cdd0>

In [None]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = attach_independent_components(table)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         3411228 function calls in 121.198 seconds

   Ordered by: cumulative time
   List reduced from 48 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000  139.163   69.581 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.001    0.000  139.162   69.581 {built-in method builtins.exec}
        1   47.532   47.532  121.195  121.195 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2801194547.py:1(find_independent_components)
        1   24.123   24.123   24.123   24.123 {method 'add_edges_from_no_data' of 'rustworkx.PyGraph' objects}
        1   18.713   18.713   18.731   18.731 {connected_components}
        2   14.375    7.187   14.375    7.187 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    7.661    7.661    7.661    7.661 /Users/willlangdale/DS/m

<pstats.Stats at 0x16d1a8190>

### Process a single component

In [238]:
from functools import lru_cache

@lru_cache(maxsize=None)
def combine_integers(*n: int) -> int:
    """
    Combine n integers into a single negative integer.

    Used to create a symmetric deterministic hash of two integers that populates the
    range of integers efficiently and reduces the likelihood of collisions.

    Aims to vectorise amazingly when used in Arrow.

    Does this by:

    * Using a Mersenne prime as a modulus
    * Making negative integers positive with modulo, sped up with bitwise operations
    * Combining using symmetric operations with coprime multipliers

    Args:
        *args: Variable number of integers to combine

    Returns:
        A negative integer
    """
    P = 2147483647

    total = 0
    product = 1

    for x in sorted(n):
        x_pos = (x ^ (x >> 31)) - (x >> 31)
        total = (total + x_pos) % P
        product = (product * x_pos) % P

    result = (31 * total + 37 * product) % P

    return -result

In [194]:
import numpy as np
from functools import lru_cache

@lru_cache(maxsize=None)
def combine_integers(*n: int) -> int:
    """Original function for single tuple of integers"""
    P = 2147483647

    total = 0
    product = 1

    for x in sorted(n):
        x_pos = (x ^ (x >> 31)) - (x >> 31)
        total = (total + x_pos) % P
        product = (product * x_pos) % P

    result = (31 * total + 37 * product) % P

    return -result

def vectorized_combine_integers(arrays: list[list[int]] | np.ndarray) -> np.ndarray:
    """
    Vectorized version of combine_integers that works on arrays of integer arrays.
    
    Args:
        arrays: List of integer arrays or 2D numpy array where each row is a set of integers
               to be combined
               
    Returns:
        1D numpy array of combined negative integers, one for each input array
    """
    P = 2147483647
    arrays = np.asarray(arrays)
    arrays = np.sort(arrays, axis=1)
    
    signs = arrays >> 31
    x_pos = (arrays ^ signs) - signs
    
    totals = np.sum(x_pos, axis=1) % P
    products = np.prod(x_pos, axis=1) % P
    
    results = (31 * totals + 37 * products) % P
    
    return -results

test_arrays = [
    # [1, 2, 3],
    # [-1, -2, -3],
    # [10, 20, 0],
    [10, 20, 0]
]

vectorized_results = vectorized_combine_integers(test_arrays)
original_results = np.array([combine_integers(*arr) for arr in test_arrays])

print("Vectorized results:", vectorized_results)
print("Original results:", original_results)
print("Match:", np.array_equal(vectorized_results, original_results))

Vectorized results: [-930]
Original results: [-930]
Match: True


In [39]:
import numpy as np
import time
from typing import List, Tuple
import random

def generate_test_data(num_arrays: int, array_size: int) -> List[List[int]]:
    """Generate random test arrays"""
    return [
        [random.randint(-1000000, 1000000) for _ in range(array_size)]
        for _ in range(num_arrays)
    ]

def run_benchmark(arrays: List[List[int]]) -> Tuple[float, float]:
    """
    Run speed comparison between original and vectorized versions
    
    Returns:
        Tuple of (original_time, vectorized_time) in seconds
    """
    # Time original version
    start = time.time()
    original_results = [combine_integers(*arr) for arr in arrays]
    original_time = time.time() - start
    
    # Time vectorized version
    start = time.time()
    vectorized_results = vectorized_combine_integers(arrays)
    vectorized_time = time.time() - start
    
    # Verify results match
    assert np.array_equal(original_results, vectorized_results), "Results don't match!"
    
    return original_time, vectorized_time

# Run benchmarks with different sizes
test_sizes = [
    (100, 3),     # 100 arrays of size 3
    (1000, 3),    # 1000 arrays of size 3
    (10000, 3),   # 10000 arrays of size 3
    (100000, 3),  # 100000 arrays of size 3
]

print("Benchmark Results:")
print("-" * 60)
print(f"{'Size':>12} | {'Original (s)':>12} | {'Vectorized (s)':>12} | {'Speedup':>8}")
print("-" * 60)

for num_arrays, array_size in test_sizes:
    # Generate test data
    arrays = generate_test_data(num_arrays, array_size)
    
    # Run benchmark
    orig_time, vec_time = run_benchmark(arrays)
    speedup = orig_time / vec_time
    
    print(f"{num_arrays:>7}x{array_size:<4} | {orig_time:>12.4f} | {vec_time:>12.4f} | {speedup:>8.2f}x")

Benchmark Results:
------------------------------------------------------------
        Size | Original (s) | Vectorized (s) |  Speedup
------------------------------------------------------------
    100x3    |       0.0002 |       0.0003 |     0.75x
   1000x3    |       0.0024 |       0.0016 |     1.46x
  10000x3    |       0.0218 |       0.0057 |     3.81x
 100000x3    |       0.2714 |       0.0585 |     4.64x


In [195]:
from itertools import chain
from functools import lru_cache

@lru_cache(maxsize=None)
def combine_strings(*n: str) -> str:
    """
    Combine n strings into a single string.

    Args:
        *args: Variable number of strings to combine
        
    Returns:
        A single string
    """
    letters = set(chain.from_iterable(n))
    return "".join(sorted(letters))


def vectorized_combine_strings(arrays: list[list[str]] | np.ndarray) -> np.ndarray:
    """
    Vectorized version of combine_strings that works on arrays of string arrays.
    
    Args:
        arrays: List of string arrays or 2D numpy array where each row is a set of 
               strings to be combined
               
    Returns:
        1D numpy array of combined strings, one for each input array
    """
    arrays = np.asarray(arrays, dtype=object)
    
    def process_row(row):
        return combine_strings(*row)
    
    vfunc = np.vectorize(process_row, signature='(n)->()', otypes=[object])
    return vfunc(arrays)

test_arrays = [
    ["abc", "def", "abd"],
    ["xyz", "xyy", "xzz"],
    ["hello", "world", "hello"]
]

vectorized_results = vectorized_combine_strings(test_arrays)
original_results = np.array([combine_strings(*arr) for arr in test_arrays])

print("Vectorized results:", vectorized_results)
print("Original results:", original_results)
print("Match:", np.array_equal(vectorized_results, original_results))

Vectorized results: ['abcdef' 'xyz' 'dehlorw']
Original results: ['abcdef' 'xyz' 'dehlorw']
Match: True


In [17]:
from collections import defaultdict
from typing import TypeVar, Generic, Hashable, Iterator
import pandas as pd

T = TypeVar('T', bound=Hashable)

class UnionFindWithDiff(Generic[T]):
    def __init__(self):
        self.parent: dict[T, T] = {}
        self.rank: dict[T, int] = {}
        self._shadow_parent: dict[T, T] = {}
        self._shadow_rank: dict[T, int] = {}
        self._pending_pairs: list[tuple[T, T]] = []
        
    def make_set(self, x: T) -> None:
        if x not in self.parent:
            self.parent[x] = x
            self.rank[x] = 0
    
    def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T:
        if parent_dict is None:
            parent_dict = self.parent
            
        if x not in parent_dict:
            self.make_set(x)
            if parent_dict is self._shadow_parent:
                self._shadow_parent[x] = x
                self._shadow_rank[x] = 0
        
        while parent_dict[x] != x:
            parent_dict[x] = parent_dict[parent_dict[x]]
            x = parent_dict[x]
        return x
    
    def union(self, x: T, y: T) -> None:
        root_x = self.find(x)
        root_y = self.find(y)
        
        if root_x != root_y:
            self._pending_pairs.append((x, y))
            
            if self.rank[root_x] < self.rank[root_y]:
                root_x, root_y = root_y, root_x
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1
    
    def get_component(self, x: T, parent_dict: dict[T, T] | None = None) -> set[T]:
        if parent_dict is None:
            parent_dict = self.parent
        
        root = self.find(x, parent_dict)
        return {y for y in parent_dict if self.find(y, parent_dict) == root}
    
    def get_components(self, parent_dict: dict[T, T] | None = None) -> list[set[T]]:
        if parent_dict is None:
            parent_dict = self.parent
            
        components = defaultdict(set)
        for x in parent_dict:
            root = self.find(x, parent_dict)
            components[root].add(x)
        return list(components.values())
    
    def diff(self) -> Iterator[tuple[set[T], set[T]]]:
        """
        Returns differences including all pairwise merges that occurred since last diff,
        excluding cases where old_comp == new_comp.
        """
        # Get current state before processing pairs
        current_components = self.get_components()
        reported_pairs = set()
        
        # Process pending pairs
        for x, y in self._pending_pairs:
            # Find the final component containing the pair
            final_component = next(comp for comp in current_components 
                                 if x in comp and y in comp)
            
            # Only report if the pair forms a proper subset of the final component
            pair_component = {x, y}
            if (pair_component != final_component and 
                frozenset((frozenset(pair_component), frozenset(final_component))) not in reported_pairs):
                reported_pairs.add(frozenset((frozenset(pair_component), frozenset(final_component))))
                yield (pair_component, final_component)
        
        self._pending_pairs.clear()
        
        # Handle initial state
        if not self._shadow_parent:
            self._shadow_parent = self.parent.copy()
            self._shadow_rank = self.rank.copy()
            return
        
        # Get old components
        old_components = self.get_components(self._shadow_parent)
        
        # Report changes between old and new states
        for old_comp in old_components:
            if len(old_comp) > 1:  # Only consider non-singleton old components
                sample_elem = next(iter(old_comp))
                new_comp = next(comp for comp in current_components if sample_elem in comp)
                
                # Only yield if the components are different and this pair hasn't been reported
                if (old_comp != new_comp and 
                    frozenset((frozenset(old_comp), frozenset(new_comp))) not in reported_pairs):
                    reported_pairs.add(frozenset((frozenset(old_comp), frozenset(new_comp))))
                    yield (old_comp, new_comp)
        
        # Update shadow copy
        self._shadow_parent = self.parent.copy()
        self._shadow_rank = self.rank.copy()


def component_to_hierarchy(key: str | int | tuple, df: pd.DataFrame) -> pd.DataFrame:
    """
    Convert pairwise probabilities into a hierarchical representation.
    Assumes data is pre-sorted by probability descending.
    
    Args:
        key: Group key (ignored in this implementation)
        df: DataFrame with columns ['left', 'right', 'probability']
    
    Returns:
        DataFrame with columns ['parent', 'child', 'probability'] representing hierarchical merges
    """
    hierarchy: list[tuple[int, int, float]] = []
    uf = UnionFindWithDiff[int]()

    for threshold in df["probability"].unique():
        current_probs = df[df["probability"] == threshold]

        for _, row in current_probs.iterrows():
            uf.union(row["left"], row["right"])
            parent = combine_integers(row["left"], row["right"])
            hierarchy.extend([
                (parent, row["left"], threshold),
                (parent, row["right"], threshold)
            ])

        for old_comp, new_comp in uf.diff():
            if len(old_comp) > 1:
                parent = combine_integers(*new_comp)
                child = combine_integers(*old_comp)
                hierarchy.extend([
                    (parent, child, threshold)
                ])
            else:
                parent = combine_integers(*new_comp)
                hierarchy.extend([
                    (parent, old_comp.pop(), threshold)
                ])

    return pd.DataFrame(hierarchy, columns=['parent', 'child', 'probability'])

In [41]:
df = cc2.to_pandas()
cc_15_pd = df[df['component'] == 15]

In [46]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = component_to_hierarchy(15, cc_15_pd)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         142304 function calls (139675 primitive calls) in 0.136 seconds

   Ordered by: cumulative time
   List reduced from 276 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.136    0.068 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    0.136    0.068 {built-in method builtins.exec}
        1    0.000    0.000    0.136    0.136 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/3996032216.py:1(<module>)
        1    0.007    0.007    0.136    0.136 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/807932433.py:115(component_to_hierarchy)
      532    0.002    0.000    0.059    0.000 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pandas/core/frame.py:1505(iterrows)
      532    0.007    0.000    0.054    0.000 /Users/willlangdale/DS/matchbox/.venv/lib/py

<pstats.Stats at 0x12219f390>

In [None]:
import numpy as np
from typing import TypeVar, Generic, Hashable, Iterator, Optional, Dict
from collections import defaultdict

T = TypeVar('T', bound=Hashable)

class NumpyUnionFindWithDiff(Generic[T]):
    def __init__(self, max_size: int = 1000000):
        # Core numpy arrays for fast operations
        self.parent = np.arange(max_size, dtype=np.int32)
        self.rank = np.zeros(max_size, dtype=np.int32)
        self.size = 0
        
        # Mapping between external IDs and array indices
        self.id_to_idx: Dict[T, int] = {}
        self.idx_to_id: Dict[int, T] = {}
        
        # Shadow state
        self._shadow_parent = np.arange(max_size, dtype=np.int32)
        self._shadow_rank = np.zeros(max_size, dtype=np.int32)
        self._pending_x = np.array([], dtype=np.int32)
        self._pending_y = np.array([], dtype=np.int32)
        
    def _get_idx(self, x: T) -> int:
        if x not in self.id_to_idx:
            idx = self.size
            self.id_to_idx[x] = idx
            self.idx_to_id[idx] = x
            self.size += 1
        return self.id_to_idx[x]
    
    def find_vec(self, indices: np.ndarray, parent_array: Optional[np.ndarray] = None) -> np.ndarray:
        if parent_array is None:
            parent_array = self.parent
            
        paths = indices.reshape(-1, 1)
        positions = np.zeros(len(indices), dtype=np.int32)
        max_length = 32
        full_paths = np.zeros((len(indices), max_length), dtype=np.int32)
        full_paths[:, 0] = indices
        
        while True:
            next_parent = parent_array[paths[:, -1]]
            if np.all(next_parent == paths[:, -1]):
                break
            positions += 1
            full_paths[:, positions] = next_parent
            paths = np.hstack([paths, next_parent.reshape(-1, 1)])
            
        roots = paths[:, -1]
        mask = full_paths != 0
        parent_array[full_paths[mask]] = np.repeat(roots, mask.sum(axis=1))
        
        return roots
    
    def union_vectorized(self, x_vec: np.ndarray, y_vec: np.ndarray) -> None:
        x_idx = np.array([self._get_idx(x) for x in x_vec])
        y_idx = np.array([self._get_idx(y) for y in y_vec])
        
        for i in range(len(x_idx)):
            root_x = self.find_vec(np.array([x_idx[i]]))[0]
            root_y = self.find_vec(np.array([y_idx[i]]))[0]
            
            if root_x != root_y:
                self._pending_x = np.append(self._pending_x, x_idx[i])
                self._pending_y = np.append(self._pending_y, y_idx[i])
                
                if self.rank[root_x] < self.rank[root_y]:
                    root_x, root_y = root_y, root_x
                    
                self.parent[root_y] = root_x
                
                if self.rank[root_x] == self.rank[root_y]:
                    self.rank[root_x] += 1

    def get_components(self, parent_array: Optional[np.ndarray] = None) -> list[set[int]]:
        if parent_array is None:
            parent_array = self.parent
            
        all_indices = np.arange(self.size)
        roots = self.find_vec(all_indices, parent_array)
        unique_roots = np.unique(roots)
        
        components = []
        for root in unique_roots:
            comp_indices = all_indices[roots == root]
            if len(comp_indices) > 1:  # Only include non-singleton components
                components.append({self.idx_to_id[i] for i in comp_indices})
        return components

    def diff(self) -> tuple[np.ndarray, np.ndarray]:
        if len(self._pending_x) == 0:
            if not hasattr(self, '_initialized_shadow'):
                self._shadow_parent[:] = self.parent
                self._shadow_rank[:] = self.rank
                self._initialized_shadow = True
            return np.array([], dtype=object), np.array([], dtype=object)
        
        # Get current final state
        final_components = self.get_components()
        changes = set()
        
        # Process pending pairs
        all_indices = np.arange(self.size)
        current_roots = self.find_vec(all_indices)
        
        for x_idx, y_idx in zip(self._pending_x, self._pending_y):
            # Get the final component containing this pair
            final_root = self.find_vec(np.array([x_idx]))[0]
            final_comp = next(comp for comp in final_components 
                            if self.idx_to_id[x_idx] in comp)
            
            # Add the direct pair change
            pair_comp = frozenset({self.idx_to_id[x_idx], self.idx_to_id[y_idx]})
            if pair_comp != final_comp:
                changes.add((pair_comp, frozenset(final_comp)))

        # Process changes from old state
        if hasattr(self, '_initialized_shadow'):
            old_components = self.get_components(self._shadow_parent)
            
            # For each old component
            for old_comp in old_components:
                # Find which final component contains this old component
                sample_elem = next(iter(old_comp))
                sample_idx = self._get_idx(sample_elem)
                final_root = self.find_vec(np.array([sample_idx]))[0]
                final_comp = next(comp for comp in final_components 
                                if self.idx_to_id[sample_idx] in comp)
                
                if old_comp != final_comp:
                    changes.add((frozenset(old_comp), frozenset(final_comp)))
        
        # Update shadow state and clear pending
        self._shadow_parent[:] = self.parent
        self._shadow_rank[:] = self.rank
        self._initialized_shadow = True
        self._pending_x = np.array([], dtype=np.int32)
        self._pending_y = np.array([], dtype=np.int32)
        
        if not changes:
            return np.array([], dtype=object), np.array([], dtype=object)
        
        # Convert changes to numpy arrays
        changes_list = list(changes)
        old_comps = np.array([np.array(list(old_comp)) for old_comp, _ in changes_list], dtype=object)
        new_comps = np.array([np.array(list(new_comp)) for _, new_comp in changes_list], dtype=object)
        
        return old_comps, new_comps


def component_to_hierarchy_pa(table: pa.Table) -> pa.Table:
    """
    Convert pairwise probabilities into a hierarchical representation.
    Assumes data is pre-sorted by probability descending.
    
    Args:
        key: Group key (ignored in this implementation)
        table: Arrow Table with columns ['left', 'right', 'probability']
    
    Returns:
        Arrow Table with columns ['parent', 'child', 'probability'] representing hierarchical merges
    """
    hierarchy: list[tuple[int, int, float]] = []
    uf = NumpyUnionFindWithDiff[int]()

    # Get unique probabilities
    probs = pc.unique(table['probability'])

    for threshold in probs:
        # Get current probability rows
        mask = pc.equal(table['probability'], threshold)
        current_probs = table.filter(mask)
        threshold_float = float(threshold.as_py())
        
        # Process each row
        left = current_probs['left'].to_numpy()
        right = current_probs['right'].to_numpy()
        uf.union_vectorized(left, right)
        parent = vectorized_combine_integers([left, right])

        hierarchy.extend(
            list(zip(parent, left, [threshold_float] * len(parent))) +
            list(zip(parent, right, [threshold_float] * len(parent)))
        )
        
        # Process UnionFind diffs
        old_comps, new_comps = uf.diff()
        if len(old_comps) > 0:  # only if there are changes
            singles = np.array([len(comp) == 1 for comp in old_comps])
            multi_indices = ~singles
            single_indices = singles

            if multi_indices.any():
                parents = vectorized_combine_integers(new_comps[multi_indices])
                children = vectorized_combine_integers(old_comps[multi_indices])
                hierarchy.extend(list(zip(parents, children, [threshold_float] * len(parents))))

            if single_indices.any():
                parents = vectorized_combine_integers(new_comps[single_indices])
                children = old_comps[single_indices].ravel()  # automatically flattens single-element arrays
                hierarchy.extend(list(zip(parents, children, [threshold_float] * len(parents))))

    parents, children, probs = zip(*hierarchy)
    return pa.table({
        'parent': pa.array(parents, type=pa.int64()),
        'child': pa.array(children, type=pa.int64()),
        'probability': pa.array(probs, type=pa.float64())
    })

component_to_hierarchy_pa(cc_15)

TypeError: Got unexpected argument type <class 'polars.series.series.Series'> for compute function

In [129]:
import numpy as np
from typing import Set, List, Tuple

# Helper function to convert numpy arrays to sets for comparison
def numpy_components_to_sets(old_comps: np.ndarray, new_comps: np.ndarray) -> List[Tuple[Set[int], Set[int]]]:
    if len(old_comps) == 0:
        return []
    return [(set(old.tolist()), set(new.tolist())) for old, new in zip(old_comps, new_comps)]

def test_implementations():
    # Initialize both implementations
    dict_uf = UnionFindWithDiff()
    numpy_uf = NumpyUnionFindWithDiff()
    
    # Test Case 1: Simple chain merge
    print("Test Case 1: Chain merge")
    # Dict version
    dict_uf = UnionFindWithDiff()
    dict_uf.union(1, 2)
    dict_uf.union(2, 3)
    dict_changes1 = list(dict_uf.diff())
    
    # Numpy version
    numpy_uf = NumpyUnionFindWithDiff()
    numpy_uf.union_vectorized(np.array([1, 2]), np.array([2, 3]))
    old_comps, new_comps = numpy_uf.diff()
    numpy_changes1 = numpy_components_to_sets(old_comps, new_comps)
    
    print("Dict version changes:", dict_changes1)
    print("Numpy version changes:", numpy_changes1)
    print("Match:", set(map(str, dict_changes1)) == set(map(str, numpy_changes1)))
    print()
    
    # Test Case 2: Star merge
    print("Test Case 2: Star merge")
    # Dict version
    dict_uf = UnionFindWithDiff()
    dict_uf.union(1, 2)
    dict_uf.union(1, 3)
    dict_uf.union(1, 4)
    dict_changes2 = list(dict_uf.diff())
    
    # Numpy version
    numpy_uf = NumpyUnionFindWithDiff()
    numpy_uf.union_vectorized(np.array([1, 1, 1]), np.array([2, 3, 4]))
    old_comps, new_comps = numpy_uf.diff()
    numpy_changes2 = numpy_components_to_sets(old_comps, new_comps)
    
    print("Dict version changes:", dict_changes2)
    print("Numpy version changes:", numpy_changes2)
    print("Match:", set(map(str, dict_changes2)) == set(map(str, numpy_changes2)))
    print()
    
    # Test Case 3: Merging existing components
    print("Test Case 3: Merging existing components")
    # Dict version
    dict_uf = UnionFindWithDiff()
    dict_uf.union(1, 2)
    dict_uf.union(3, 4)
    list(dict_uf.diff())  # Clear initial changes
    dict_uf.union(2, 3)
    dict_changes3 = list(dict_uf.diff())
    
    # Numpy version
    numpy_uf = NumpyUnionFindWithDiff()
    numpy_uf.union_vectorized(np.array([1, 3]), np.array([2, 4]))
    numpy_uf.diff()  # Clear initial changes
    numpy_uf.union_vectorized(np.array([2]), np.array([3]))
    old_comps, new_comps = numpy_uf.diff()
    numpy_changes3 = numpy_components_to_sets(old_comps, new_comps)
    
    print("Dict version changes:", dict_changes3)
    print("Numpy version changes:", numpy_changes3)
    print("Match:", set(map(str, dict_changes3)) == set(map(str, numpy_changes3)))

test_implementations()

Test Case 1: Chain merge
Dict version changes: [({1, 2}, {1, 2, 3}), ({2, 3}, {1, 2, 3})]
Numpy version changes: [({2, 3}, {1, 2, 3}), ({1, 2}, {1, 2, 3})]
Match: True

Test Case 2: Star merge
Dict version changes: [({1, 2}, {1, 2, 3, 4}), ({1, 3}, {1, 2, 3, 4}), ({1, 4}, {1, 2, 3, 4})]
Numpy version changes: [({1, 2}, {1, 2, 3, 4}), ({1, 4}, {1, 2, 3, 4}), ({1, 3}, {1, 2, 3, 4})]
Match: True

Test Case 3: Merging existing components
Dict version changes: [({2, 3}, {1, 2, 3, 4}), ({1, 2}, {1, 2, 3, 4}), ({3, 4}, {1, 2, 3, 4})]
Numpy version changes: [({1, 2}, {1, 2, 3, 4}), ({3, 4}, {1, 2, 3, 4}), ({2, 3}, {1, 2, 3, 4})]
Match: True


In [130]:
import numpy as np
import time
from typing import Tuple, List
import pandas as pd
import matplotlib.pyplot as plt

def generate_random_pairs(size: int, num_pairs: int) -> Tuple[np.ndarray, np.ndarray]:
    """Generate random pairs of integers for testing"""
    x = np.random.randint(0, size, num_pairs)
    y = np.random.randint(0, size, num_pairs)
    return x, y

def benchmark_dict_version(x: np.ndarray, y: np.ndarray) -> float:
    """Benchmark dictionary-based implementation"""
    uf = UnionFindWithDiff()
    start = time.perf_counter()
    
    for xi, yi in zip(x, y):
        uf.union(int(xi), int(yi))
    list(uf.diff())  # Materialize the changes
    
    return time.perf_counter() - start

def benchmark_numpy_version(x: np.ndarray, y: np.ndarray) -> float:
    """Benchmark numpy-based implementation"""
    uf = NumpyUnionFindWithDiff()
    start = time.perf_counter()
    
    uf.union_vectorized(x, y)
    uf.diff()  # Get the changes
    
    return time.perf_counter() - start

def run_benchmarks() -> pd.DataFrame:
    """Run benchmarks with different input sizes"""
    # Test parameters
    sizes = [100, 1000, 10000]
    pairs_per_size = [10, 100, 1000]
    num_trials = 5
    
    results = []
    
    for size in sizes:
        for num_pairs in pairs_per_size:
            if num_pairs > size:
                continue
                
            dict_times = []
            numpy_times = []
            
            for _ in range(num_trials):
                # Generate random pairs
                x, y = generate_random_pairs(size, num_pairs)
                
                # Run benchmarks
                dict_time = benchmark_dict_version(x, y)
                numpy_time = benchmark_numpy_version(x, y)
                
                dict_times.append(dict_time)
                numpy_times.append(numpy_time)
            
            # Record average results
            results.append({
                'Size': size,
                'Num_Pairs': num_pairs,
                'Dict_Time': np.mean(dict_times),
                'Numpy_Time': np.mean(numpy_times),
                'Dict_Std': np.std(dict_times),
                'Numpy_Std': np.std(numpy_times),
                'Speedup': np.mean(dict_times) / np.mean(numpy_times)
            })
    
    return pd.DataFrame(results)

run_benchmarks()

Unnamed: 0,Size,Num_Pairs,Dict_Time,Numpy_Time,Dict_Std,Numpy_Std,Speedup
0,100,10,7.4e-05,0.00242,1e-05,0.000339,0.030523
1,100,100,0.000676,0.01862,0.000165,0.00074,0.036309
2,1000,10,7.4e-05,0.002496,4e-06,0.000256,0.029668
3,1000,100,0.000583,0.015933,5e-06,0.001204,0.036586
4,1000,1000,0.027857,0.232283,0.00323,0.016529,0.119926
5,10000,10,7.2e-05,0.002457,7e-06,0.000193,0.029138
6,10000,100,0.0006,0.014807,2.6e-05,0.00086,0.040527
7,10000,1000,0.01849,0.185388,0.001771,0.001559,0.099738


In [140]:
cc_15_pa = cc2.filter(pc.equal(cc2['component'], 15))

In [155]:
import cProfile
import pstats
from pstats import SortKey

from cluster import component_to_hierarchy_pa

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = component_to_hierarchy_pa(cc_15_pa)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         26701 function calls in 0.033 seconds

   Ordered by: cumulative time
   List reduced from 54 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.033    0.016 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    0.033    0.016 {built-in method builtins.exec}
        1    0.006    0.006    0.033    0.033 /Users/willlangdale/DS/matchbox/notebooks/cluster.py:208(component_to_hierarchy_pa)
      298    0.003    0.000    0.021    0.000 /Users/willlangdale/DS/matchbox/notebooks/cluster.py:107(diff)
       61    0.007    0.000    0.017    0.000 /Users/willlangdale/DS/matchbox/notebooks/cluster.py:97(get_components)
    11218    0.009    0.000    0.009    0.000 /Users/willlangdale/DS/matchbox/notebooks/cluster.py:62(find)
      501    0.001    0.000    0.002    0.000 /Users/willlangdale/DS/matchbox/notebooks/cl

<pstats.Stats at 0x127c54310>

In [123]:
P = 2147483647

cc_15.with_columns([
    pl.min_horizontal(['left', 'right']).abs().alias('smaller'),
    pl.max_horizontal(['left', 'right']).abs().alias('larger'),
]).with_columns([
    pl.col('smaller').add(pl.col('larger')).mod(P).alias('total2'),
    pl.col('smaller').mul(pl.col('larger')).mod(P).alias('prod2'),
]).with_columns([
    pl.col('total2').mul(31).add(pl.col('prod2').mul(37)).mod(P).neg().alias('parent')
]).select(["left", "right", "parent"])

left,right,parent
i32,i32,i32
81762,389632,-1905962718
5063,289249,-507336890
84647,212390,-1629285124
175160,262073,-1975619425
178519,348267,-435576558
…,…,…
12792,227750,-428260364
71547,331214,-641775315
66430,329216,-1741711938
114817,358133,-1027235085


In [122]:
cc_15.map_rows(lambda t: combine_integers(t[0], t[1]))

map
i64
-1905963266
-507336916
-1629285434
-1975620253
-435577630
…
-428260452
-641775761
-1741712314
-1027235831


In [219]:
import polars as pl

def component_to_hierarchy_pl(df: pl.DataFrame) -> pl.DataFrame:
    """
    Convert pairwise probabilities into a hierarchical representation using Polars.
    Assumes data is pre-sorted by probability descending.
    
    Args:
        key: Group key (ignored in this implementation)
        df: Polars DataFrame with columns ['left', 'right', 'probability']
    
    Returns:
        Polars DataFrame with columns ['parent', 'child', 'probability'] representing hierarchical merges
    """
    hierarchy: list[tuple[int, int, float]] = []
    uf = UnionFindWithDiff[int]()

    thresholds = df["probability"].unique(maintain_order=True)

    for threshold in thresholds:
        # Filter current probability rows and convert to Python objects for processing
        current_probs = df.filter(pl.col('probability') == threshold).select(["left", "right", "probability"])
        
        # Process each row
        for row in current_probs.iter_rows():
            parent = combine_integers(row[0], row[1])
            uf.union(row[0], row[1])
            hierarchy.extend([
                (parent, row[1], threshold),
                (parent, row[0], threshold)
            ])

        # Process union-find differences
        for old_comp, new_comp in uf.diff():
            if len(old_comp) > 1:
                parent = combine_integers(*new_comp)
                child = combine_integers(*old_comp)
                hierarchy.extend([
                    (parent, child, threshold)
                ])
            else:
                parent = combine_integers(*new_comp)
                hierarchy.extend([
                    (parent, old_comp.pop(), threshold)
                ])

    # Convert the results to a Polars DataFrame
    return pl.DataFrame(
        hierarchy,
        schema={
            'parent': pl.Int32,
            'child': pl.Int32,
            'probability': pl.UInt8
        }
    )

In [158]:
import polars as pl

cc_15 = pl.from_arrow(cc2.filter(pc.equal(cc2['component'], 15)))
cc_15 = cc_15.with_columns(
    pl.col("probability").cast(pl.Float32).mul(100).cast(pl.UInt8).alias("probability"),
    pl.col("left").cast(pl.Int32),
    pl.col("right").cast(pl.Int32),
    pl.col("component").cast(pl.UInt32),
)
cc_15

left,right,probability,component
i32,i32,u8,u32
81762,389632,100,15
5063,289249,100,15
84647,212390,100,15
175160,262073,100,15
178519,348267,100,15
…,…,…,…
12792,227750,70,15
71547,331214,70,15
66430,329216,70,15
114817,358133,70,15


In [209]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = component_to_hierarchy_pl(cc_15)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines


# Pure polars was 0.06, with iter_rows, 0.03-4

         32850 function calls in 0.040 seconds

   Ordered by: cumulative time
   List reduced from 196 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.040    0.020 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    0.040    0.020 {built-in method builtins.exec}
        1    0.004    0.004    0.040    0.040 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_72030/483903679.py:3(component_to_hierarchy_pl)
      298    0.003    0.000    0.019    0.000 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_72030/807932433.py:65(diff)
       61    0.008    0.000    0.015    0.000 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_72030/807932433.py:55(get_components)
       31    0.000    0.000    0.007    0.000 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/polars/dataf

  return pl.DataFrame(


<pstats.Stats at 0x126f54a10>

In [218]:
from unittest.mock import patch

test_cases = [
    # Test case 1: Equal probabilities
    (
        {
            "left": ["a", "b", "c"],
            "right": ["b", "c", "d"],
            "probability": [1.0, 1.0, 1.0],
        },
        {
            ("ab", "a", 1.0),
            ("ab", "b", 1.0),
            ("bc", "b", 1.0),
            ("bc", "c", 1.0),
            ("cd", "c", 1.0),
            ("cd", "d", 1.0),
            ("abcd", "ab", 1.0),
            ("abcd", "bc", 1.0),
            ("abcd", "cd", 1.0),
        },
    ),
    # Test case 2: Asymmetric probabilities
    (
        {
            "left": ["w", "x", "y"],
            "right": ["x", "y", "z"],
            "probability": [0.9, 0.85, 0.8],
        },
        {
            ("wx", "w", 0.9),
            ("wx", "x", 0.9),
            ("xy", "x", 0.85),
            ("xy", "y", 0.85),
            ("wxy", "wx", 0.85),
            ("wxy", "xy", 0.85),
            ("yz", "y", 0.8),
            ("yz", "z", 0.8),
            ("wxyz", "wxy", 0.8),
            ("wxyz", "yz", 0.8),
        },
    ),
    # Test case 3: Single two-item component
    (
        {
            "left": ["x"],
            "right": ["y"],
            "probability": [0.9],
        },
        {
            ("xy", "x", 0.9),
            ("xy", "y", 0.9),
        },
    ),
]

for i, (prob_data, expected_relations) in enumerate(test_cases, 1):
    print(f"\nRunning test case {i}...")
    
    with patch('__main__.combine_integers', side_effect=combine_strings):
        # Pandas
        # probabilities = (
        #     pd.DataFrame.from_dict(prob_data)
        #     .assign(probability=lambda df: df['probability'].astype(float))
        #     .sort_values(by="probability", ascending=False)
        # )
        # hierarchy_true = (
        #         pd.DataFrame.from_records(
        #         list(expected_relations), 
        #         columns=["parent", "child", "probability"]
        #     )
        #     .sort_values(by=["probability", "parent", "child"], ascending=[False, True, True])
        #     .dropna(how='all')
        #     .reset_index(drop=True)
        # )

        # hierarchy = component_to_hierarchy(0, probabilities)

        # hierarchy = hierarchy.sort_values(
        #     by=["probability", "parent", "child"],
        #     ascending=[False, True, True]
        # ).reset_index(drop=True)

        # Arrow
        # probabilities = (
        #     pa.Table.from_pydict(prob_data)
        #     .cast(pa.schema([
        #         ('left', pa.string()),
        #         ('right', pa.string()),
        #         ('probability', pa.float64()),
        #     ]))
        #     .sort_by([('probability', 'descending')])
        # )

        # parents, children, probs = zip(*expected_relations)

        # hierarchy_true = (
        #     pa.table(
        #         [parents, children, probs], 
        #         names=['parent', 'child', 'probability']
        #     )
        #     .sort_by([
        #         ('probability', 'descending'),
        #         ('parent', 'ascending'),
        #         ('child', 'ascending')
        #     ])
        #     .filter(pc.is_valid(pc.field('parent')))
        # )

        # hierarchy = (
        #     component_to_hierarchy_pa(probabilities)
        #     .sort_by([
        #         ('probability', 'descending'),
        #         ('parent', 'ascending'),
        #         ('child', 'ascending')
        #     ])
        # )

        # Polars
        probabilities = (
            pl.DataFrame(prob_data)
            .with_columns(
                pl.col("probability").cast(pl.Float32).mul(100).cast(pl.UInt8).alias("probability"),
                pl.col("left").cast(pl.String),
                pl.col("right").cast(pl.String),
            )
            .sort("probability", descending=True)
        )

        # Convert hierarchy_true DataFrame
        hierarchy_true = (
                pd.DataFrame.from_records(
                list(expected_relations), 
                columns=["parent", "child", "probability"]
            )
            .sort_values(by=["probability", "parent", "child"], ascending=[False, True, True])
            .dropna(how='all')
            .reset_index(drop=True)
        )

        # Assuming component_to_hierarchy() function exists and works with polars
        hierarchy = component_to_hierarchy_pl(probabilities)

        # Sort hierarchy, turning back to pandas
        hierarchy = (
            hierarchy
            .with_columns(
                pl.col("probability").cast(pl.Float64).truediv(100).alias("probability")
            )
            .to_pandas()
            .sort_values(
                by=["probability", "parent", "child"],
                ascending=[False, True, True]
            )
            .reset_index(drop=True)
        )
        
        try:
            assert hierarchy.equals(hierarchy_true)
            print(f"✓ Test case {i} passed")
        except AssertionError:
            print(f"✗ Test case {i} failed")
            print("\nExpected DataFrame:")
            print(hierarchy_true)
            print("\nActual DataFrame:")
            print(hierarchy)
            print("\nDifferences:")
            if hierarchy.shape != hierarchy_true.shape:
                print(f"Shape mismatch: Expected {hierarchy_true.shape}, got {hierarchy.shape}")
            else:
                # Show where values differ
                differences = (hierarchy != hierarchy_true).any(axis=1)
                if differences.any():
                    print("\nMismatched rows:")
                    print("Expected:")
                    print(hierarchy_true[differences])
                    print("\nGot:")
                    print(hierarchy[differences])
            # Continue to next test case instead of raising
            continue



Running test case 1...
✓ Test case 1 passed

Running test case 2...
✓ Test case 2 passed

Running test case 3...
✓ Test case 3 passed


  return pl.DataFrame(
  return pl.DataFrame(
  return pl.DataFrame(


### Process all components in parallel

In [17]:
import pyarrow.parquet as pq

h2 = pq.read_table('hierarchical_cc2k.parquet')
cc2 = attach_independent_components(h2)

h200 = pq.read_table('hierarchical_cc200k.parquet')
cc200 = attach_independent_components(h200)

In [15]:
import pyarrow as pa
import dask.dataframe as dd
import pandas as pd
from multiprocessing import cpu_count
from dask.distributed import Client, LocalCluster

def setup_client(port: int = 8787) -> Client:
    """
    Creates a Dask client optimized for local computation.
    Returns the client object.
    """
    try:
        client = Client(f'tcp://localhost:{port}')
    except Exception:
        n_workers = cpu_count()
        cluster = LocalCluster(
            dashboard_address=f':{port}',
            n_workers=n_workers,
            threads_per_worker=1,  # One thread per worker for CPU-bound tasks
            memory_limit='auto'    # Automatically determine memory limits
        )
        client = Client(cluster)
    return client

In [5]:
from cluster import component_to_hierarchy

def to_hierarchical_clusters(client: Client, table: pa.Table) -> pd.DataFrame:
    """
    Perform hierarchical clustering on an Arrow table of pairwise probabilities.
    """
    pdf = table.to_pandas()
    
    ddf = dd.from_pandas(
        pdf, 
        npartitions=len(client.cluster.worker_spec)
    )
    
    result_ddf = ddf.groupby('component').apply(
        component_to_hierarchy, 
        meta={
            'parent': int, 
            'child': int, 
            'probability': float
        }
    )
    
    return result_ddf.compute()

NameError: name 'Client' is not defined

In [17]:
with setup_client(port=8787) as client:
    h_out = to_hierarchical_clusters(client, cc2)

h_out

KeyboardInterrupt: 

In [None]:
for i in cc2.partitioning("component"):
    print(i)
    break

AttributeError: 'pyarrow.lib.Table' object has no attribute 'partition_by'

In [29]:
import pyarrow as pa
import pyarrow.compute as pc
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
from cluster import component_to_hierarchy_pa
import logging
import time
from tqdm.notebook import tqdm
import logging

# def to_hierarchical_clusters_pa(table: pa.Table) -> pa.Table:
#     """
#     Parallel processing of components using Arrow with optimized submission.
#     """
#     logging.info(f"Starting processing with table size: {len(table)}")
#     components = pc.unique(table['component'])
#     logging.info(f"Found {len(components)} unique components")
#     n_cores = multiprocessing.cpu_count() + 4
    
#     # Pre-filter all components at once
#     component_tables = []
#     t_start = time.time()
#     for comp in components:
#         mask = pc.equal(table["component"], comp)
#         component_tables.append(table.filter(mask))
#     logging.info(f"Bulk filtering took {time.time() - t_start:.2f}s")
    
#     results = []
#     with ProcessPoolExecutor(max_workers=n_cores) as executor:
#         # Submit all tasks at once
#         t_start = time.time()
#         futures = [executor.submit(component_to_hierarchy_pa, component_table) 
#                   for component_table in component_tables]
#         logging.info(f"Bulk submission took {time.time() - t_start:.2f}s")
        
#         # Process results as they complete
#         for i, future in enumerate(futures):
#             t_start = time.time()
#             results.append(future.result())
#             logging.info(f"Component {i}: processing took {time.time() - t_start:.2f}s")

#     return pa.concat_tables(results) if results else pa.table({
#         'parent': pa.array([], type=pa.int64()),
#         'child': pa.array([], type=pa.int64()),
#         'probability': pa.array([], type=pa.float64())
#     })


# def to_hierarchical_clusters_pa(table: pa.Table) -> pa.Table:
#     """
#     Parallel processing of components using Arrow with optimized submission.
#     """
#     components = pc.unique(table['component'])
#     n_cores = multiprocessing.cpu_count() + 4
    
#     component_tables = []
#     for comp in components:
#         mask = pc.equal(table["component"], comp)
#         component_tables.append(table.filter(mask))
    
#     results = []
#     with ProcessPoolExecutor(max_workers=n_cores) as executor:
#         futures = [
#             executor.submit(component_to_hierarchy_pa, component_table)
#             for component_table in component_tables
#         ]
        
#         for future in futures:
#             results.append(future.result())

#     return pa.concat_tables(results) if results else pa.table({
#         'parent': pa.array([], type=pa.int64()),
#         'child': pa.array([], type=pa.int64()),
#         'probability': pa.array([], type=pa.float64())
#     })

# def split_table(table: pa.Table) -> list[pa.Table]:
#     """
#     Split a table into component chunks efficiently.
#     """
#     sorted_table = table.sort_by("component")

#     components = []
#     current_component = None
#     start_idx = 0

#     component_array = sorted_table.column("component").to_numpy()
    
#     for idx, component in enumerate(component_array):
#         if component != current_component:
#             if current_component is not None:
#                 components.append(sorted_table.slice(start_idx, idx - start_idx).combine_chunks())
#             current_component = component
#             start_idx = idx
    
#     if current_component is not None:
#         components.append(sorted_table.slice(start_idx, len(component_array) - start_idx).combine_chunks())
    
#     return components

def to_hierarchical_clusters_pa(table: pa.Table) -> pa.Table:
    """
    Parallel processing of components using Arrow with optimized submission and progress tracking.
    """
    # Set up logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    
    table = table.sort_by([("component", "ascending")])
    components = pc.unique(table['component'])
    n_cores = multiprocessing.cpu_count() + 4
    n_components = len(components)
    
    logging.info(f"Processing {n_components} components using {n_cores} workers")

    # 1) Filtering (2k @ 11s) 
    component_tables = []
    for comp in tqdm(components, desc="Preparing components", leave=False):
        mask = pc.equal(table["component"], comp)
        component_tables.append(table.filter(mask))

    # 2) Sort slicing (2k in like 10 minutes)
    # component_tables = split_table(table)
    
    results = []
    with ProcessPoolExecutor(max_workers=n_cores) as executor:
        # Submit all tasks
        futures = [
            executor.submit(component_to_hierarchy_pa, component_table)
            for component_table in component_tables
        ]
        
        # Process results with progress bar
        for future in tqdm(futures, desc="Processing components", total=len(futures)):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                logging.error(f"Error processing component: {str(e)}")
                continue

    logging.info(f"Completed processing {len(results)} components successfully")
    
    # Create empty table if no results
    if not results:
        logging.warning("No results to concatenate")
        return pa.table({
            'parent': pa.array([], type=pa.int64()),
            'child': pa.array([], type=pa.int64()),
            'probability': pa.array([], type=pa.float64())
        })

    return pa.concat_tables(results)

In [28]:
cc2

pyarrow.Table
left: int64
right: int64
probability: decimal128(3, 2)
component: int64
----
left: [[197413,116407,114551,160857,6412,...,39429,156175,197197,48177,121674]]
right: [[326505,384163,344025,248700,258884,...,311452,278755,204144,357956,378111]]
probability: [[1.00,1.00,1.00,1.00,1.00,...,0.70,0.70,0.70,0.70,0.70]]
component: [[0,0,0,0,0,...,2079,2079,2079,2079,2079]]

In [14]:
# Clear existing handlers
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Setup logging
logging.basicConfig(level=logging.INFO)

In [30]:
h_out = to_hierarchical_clusters_pa(cc2)

h_out

2024-12-16 14:01:07,543 - INFO - Processing 2080 components using 20 workers


Preparing components:   0%|          | 0/2080 [00:00<?, ?it/s]

Processing components:   0%|          | 0/2080 [00:00<?, ?it/s]

2024-12-16 14:01:19,547 - INFO - Completed processing 2080 components successfully


pyarrow.Table
parent: int64
child: int64
probability: double
----
parent: [[-1193661193,-1193661193,-1065816097,-1065816097,-2131390865,...,-1988605072,-384900360,-384900360,-1450541918,-1450541918],[-1998262778,-1998262778,-352384079,-352384079,-1971847649,...,-1780983856,-1515770754,-1515770754,-1282464668,-1282464668],...,[-549228011,-549228011,-785366439,-785366439,-1013741990,...,-1759852711,-1842716946,-1842716946,-1756929996,-1756929996],[-75333889,-75333889,-1183850366,-1183850366,-95546434,...,-1299895816,-284056808,-284056808,-1440724029,-1440724029]]
child: [[197413,326505,116407,384163,114551,...,212599,19131,294811,193341,224447],[64284,340314,170279,378400,152470,...,246051,14448,296071,85173,322723],...,[72846,345192,90867,360477,2449,...,324026,51072,220303,113552,257515],[168741,288248,29190,339106,179769,...,204144,48177,357956,121674,378111]]
probability: [[1,1,1,1,1,...,0.7,0.7,0.7,0.7,0.7],[1,1,1,1,1,...,0.7,0.7,0.7,0.7,0.7],...,[1,1,1,1,1,...,0.7,0.7,0.7,0.7,0.7],

In [31]:
h_out = to_hierarchical_clusters_pa(cc200)

h_out

2024-12-16 14:01:51,445 - INFO - Processing 206763 components using 20 workers


Preparing components:   0%|          | 0/206763 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [222]:
import polars as pl

def to_hierarchical_clusters_pl(table: pa.Table) -> pa.Table:
    """
    Parallel processing of components using Arrow.
    """
    table_pl = (
        pl.from_arrow(table)
        .with_columns(
            pl.col("probability").cast(pl.Float32).mul(100).cast(pl.UInt8).alias("probability"),
            pl.col("left").cast(pl.Int32),
            pl.col("right").cast(pl.Int32),
            pl.col("component").cast(pl.UInt32),
        )
    )

    return table_pl.group_by("component").map_groups(component_to_hierarchy_pl)

to_hierarchical_clusters_pl(cc2)

  return pl.DataFrame(


parent,child,probability
i32,i32,u8
-283593755,337091,100
-283593755,15518,100
-126938103,378438,100
-126938103,41264,100
-1702119885,234149,100
…,…,…
-1028611184,182587,70
-1183984842,323264,70
-1183984842,13923,70
-1046343640,268429,70
