# Hierarchical algorithm optimisation

## Data generation

In [3]:
import rustworkx as rx
from collections import Counter

def verify_components(table) -> dict:
    """
    Fast verification of connected components using rustworkx.
    
    Args:
        table: PyArrow table with 'left', 'right' columns
    
    Returns:
        dictionary containing basic component statistics
    """
    # Create graph directly from arrays
    graph = rx.PyDiGraph()
    
    # Add all unique nodes at once
    unique_nodes = set(table['left'].to_numpy()) | set(table['right'].to_numpy())
    graph.add_nodes_from(range(len(unique_nodes)))
    
    # Create node mapping and edges in one pass
    node_to_idx = {node: idx for idx, node in enumerate(unique_nodes)}
    edges = [(node_to_idx[left], node_to_idx[right], prob) 
            for left, right, prob in zip(table['left'].to_numpy(), 
                                       table['right'].to_numpy(),
                                       table['probability'].to_numpy())]
    
    # Add all edges at once
    graph.add_edges_from(edges)
    
    # Get components and their sizes
    components = rx.weakly_connected_components(graph)
    component_sizes = Counter(len(component) for component in components)
    
    return {
        'num_components': len(components),
        'total_nodes': len(unique_nodes),
        'total_edges': len(edges),
        'component_sizes': component_sizes,
        'min_component_size': min(component_sizes.keys()),
        'max_component_size': max(component_sizes.keys())
    }

In [None]:
def calculate_max_possible_edges(n_nodes: int, num_components: int) -> int:
    """
    Calculate the maximum possible number of edges given n nodes split into k components.
    
    Args:
        n_nodes: Total number of nodes
        num_components: Number of components to split into
        
    Returns:
        Maximum possible number of edges
    """
    nodes_per_component = n_nodes // num_components
    max_edges_per_component = nodes_per_component * nodes_per_component  # Complete bipartite graph
    return max_edges_per_component * num_components


In [None]:
import numpy as np
import pyarrow as pa
import rustworkx as rx
from typing import List, Tuple
from decimal import Decimal

def split_values_into_components(values: List[int], num_components: int) -> List[np.ndarray]:
    """
    Split values into non-overlapping groups for each component.
    
    Args:
        values: List of values to split
        num_components: Number of components to create
        
    Returns:
        List of arrays, one for each component
    """
    values = np.array(values)
    np.random.shuffle(values)
    return np.array_split(values, num_components)


def generate_arrow_data(
    left_values: List[int],
    right_values: List[int],
    prob_range: Tuple[float, float],
    num_components: int,
    total_rows: int
) -> pa.Table:
    """
    Generate dummy arrow data with guaranteed isolated components.
    
    Args:
        left_values: List of integers to use for left column
        right_values: List of integers to use for right column
        prob_range: Tuple of (min_prob, max_prob) to constrain probabilities
        num_components: Number of distinct connected components to generate
        total_rows: Total number of rows to generate
    
    Returns:
        PyArrow Table with 'left', 'right', and 'probability' columns
    """
    if len(left_values) < 2 or len(right_values) < 2:
        raise ValueError("Need at least 2 possible values for both left and right")
    if num_components > min(len(left_values), len(right_values)):
        raise ValueError("Cannot have more components than minimum of left/right values")
    
    # Calculate maximum possible edges
    min_nodes = min(len(left_values), len(right_values))
    max_possible_edges = calculate_max_possible_edges(min_nodes, num_components)
    
    if total_rows > max_possible_edges:
        raise ValueError(
            f"Cannot generate {total_rows:,} edges with {num_components:,} components. "
            f"Maximum possible edges is {max_possible_edges:,} given {min_nodes:,} nodes. "
            "Either increase the number of nodes, decrease the number of components, "
            "or decrease the total edges requested."
        )
    
    # Convert probability range to integers (60-80 for 0.60-0.80)
    prob_min = int(prob_range[0] * 100)
    prob_max = int(prob_range[1] * 100)
    
    # Split values into completely separate groups for each component
    left_components = split_values_into_components(left_values, num_components)
    right_components = split_values_into_components(right_values, num_components)
    
    # Calculate base number of edges per component
    base_edges_per_component = total_rows // num_components
    remaining_edges = total_rows % num_components
    
    all_edges = []
    
    # Generate edges for each component
    for comp_idx in range(num_components):
        comp_left_values = left_components[comp_idx]
        comp_right_values = right_components[comp_idx]
        
        # Calculate edges for this component
        edges_in_component = base_edges_per_component
        if comp_idx < remaining_edges:  # Distribute remaining edges
            edges_in_component += 1
            
        # Ensure basic connectivity within the component
        base_edges = []
        
        # Create a spanning tree-like structure
        for i in range(len(comp_left_values)):
            base_edges.append((
                comp_left_values[i],
                comp_right_values[i % len(comp_right_values)],
                np.random.randint(prob_min, prob_max + 1)
            ))
        
        # Generate remaining random edges strictly within this component
        remaining_edges = edges_in_component - len(base_edges)
        if remaining_edges > 0:
            random_lefts = np.random.choice(comp_left_values, size=remaining_edges)
            random_rights = np.random.choice(comp_right_values, size=remaining_edges)
            random_probs = np.random.randint(prob_min, prob_max + 1, size=remaining_edges)
            
            component_edges = base_edges + list(zip(random_lefts, random_rights, random_probs))
        else:
            component_edges = base_edges
            
        all_edges.extend(component_edges)
    
    # Convert to arrays
    lefts, rights, probs = zip(*all_edges)
    
    # Create PyArrow arrays
    left_array = pa.array(lefts, type=pa.int64())
    right_array = pa.array(rights, type=pa.int64())
    decimal_probs = [Decimal(str(p/100)) for p in probs]
    prob_array = pa.array(decimal_probs, type=pa.decimal128(precision=3, scale=2))
    
    return pa.table([left_array, right_array, prob_array],
                   names=['left', 'right', 'probability'])



In [413]:
left_values = list(range(10_000))
right_values = list(range(10_000, 20_000))
prob_range = (0.6, 0.8)
num_components = 10
total_rows = 1_000_000

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x10d274c90>>
Traceback (most recent call last):
  File "/Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


Number of components found: 10
Total nodes: 20000
Total edges: 1000009

Component sizes:
Size 2000: 10 components


In [None]:
left_values = list(range(int(2e7)))
right_values = list(range(int(2e7), int(4e7)))
prob_range = (0.7, 1.0)
num_components = 200_000
total_rows = int(1e8)

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [None]:
results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Number of components found: 206763
Total nodes: 40000000
Total edges: 100000400

Component sizes:
Size 2: 6755 components
Size 4: 8 components
Size 194: 1 components
Size 196: 124 components
Size 198: 6520 components
Size 200: 193355 components


Number of components found: 206763
* Total nodes: 40000000
* Total edges: 100000400

Component sizes:
* Size 2: 6755 components
* Size 4: 8 components
* Size 194: 1 components
* Size 196: 124 components
* Size 198: 6520 components
* Size 200: 193355 components

In [None]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table, Path.cwd() / 'hierarchical_cc200k.parquet')

In [None]:
left_values = list(range(int(2e5)))
right_values = list(range(int(2e5), int(4e5)))
prob_range = (0.7, 1.0)
num_components = 2_000
total_rows = int(1e6)

table2 = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [None]:
results2 = verify_components(table2)
print(f"Number of components found: {results2['num_components']:,}")
print(f"Total nodes: {results2['total_nodes']:,}")
print(f"Total edges: {results2['total_edges']:,}")
print("\nComponent sizes:")
for size, count in sorted(results2['component_sizes'].items()):
    print(f"Size {size:,}: {count:,} components")

Number of components found: 2,080
Total nodes: 400,000
Total edges: 1,000,400

Component sizes:
Size 2: 80 components
Size 196: 1 components
Size 198: 78 components
Size 200: 1,921 components


Number of components found: 2,080
* Total nodes: 400,000
* Total edges: 1,000,400

Component sizes:
* Size 2: 80 components
* Size 196: 1 components
* Size 198: 78 components
* Size 200: 1,921 components

In [None]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table2, Path.cwd() / 'hierarchical_cc2k.parquet')

## Algorithm

In [1]:
import pyarrow.parquet as pq

h2 = pq.read_table('hierarchical_cc2k.parquet')

In [None]:
h2.schema

left: int64
right: int64
probability: decimal128(3, 2)

The plan.

* Find components and their sizes at lowest threshold (rustworkx)
* Use this to dask.groupby the data for parallel per-component processing
* Ensure we implement early stopping!

### Find components

In [4]:
import pyarrow as pa
import pyarrow.compute as pc
import rustworkx as rx
import numpy as np

def attach_independent_components(table: pa.Table) -> pa.Table:
    """
    Returns the original table with an additional 'component' column indicating
    which connected component each edge belongs to.
    """
    # Create dictionary array from sorted unique values
    unique = pc.unique(
        pa.concat_arrays([
            table['left'].combine_chunks(),
            table['right'].combine_chunks()
        ])
    )
    
    # Get indices into unique array for graph construction
    left_indices = pc.index_in(table['left'], unique)
    right_indices = pc.index_in(table['right'], unique)
    
    # Create and process graph
    n_nodes = len(unique)
    n_edges = len(table)
    
    graph = rx.PyGraph(
        node_count_hint=n_nodes,
        edge_count_hint=n_edges
    )
    graph.add_nodes_from(range(n_nodes))

    edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy()))
    graph.add_edges_from_no_data(edges)
    
    # Get components and create mapping array
    components = rx.connected_components(graph)
    
    # Convert components to numpy arrays
    component_indices = np.concatenate([np.array(list(c)) for c in components])
    component_labels = np.repeat(np.arange(len(components)), [len(c) for c in components])
    
    # Create mapping array and fill with component labels
    node_to_component = np.zeros(len(unique), dtype=np.int64)
    node_to_component[component_indices] = component_labels
    
    # Use the indices we already have to map back to components  
    edge_components = pa.array(node_to_component[left_indices.to_numpy()])
    
    return table.append_column('component', edge_components).sort_by(
        [('component', 'ascending'), ('probability', 'descending')]
    )

cc2 = attach_independent_components(h2)
len(pc.unique(cc2.column('component'))), cc2

(2080,
 pyarrow.Table
 left: int64
 right: int64
 probability: decimal128(3, 2)
 component: int64
 ----
 left: [[197413,116407,114551,160857,6412,...,39429,156175,197197,48177,121674]]
 right: [[326505,384163,344025,248700,258884,...,311452,278755,204144,357956,378111]]
 probability: [[1.00,1.00,1.00,1.00,1.00,...,0.70,0.70,0.70,0.70,0.70]]
 component: [[0,0,0,0,0,...,2079,2079,2079,2079,2079]])

In [None]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
components = attach_independent_components(table2)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         34170 function calls in 1.108 seconds

   Ordered by: cumulative time
   List reduced from 49 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    1.108    0.554 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    1.108    0.554 {built-in method builtins.exec}
        1    0.058    0.058    1.108    1.108 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/655470289.py:1(<module>)
        1    0.397    0.397    1.050    1.050 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2783446239.py:5(attach_independent_components)
        3    0.296    0.099    0.296    0.099 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    0.118    0.118    0.118    0.118 {connected_components}
        1    0.088    0.088    0.088    0.088 {

<pstats.Stats at 0x17d84cdd0>

In [None]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = attach_independent_components(table)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         3411228 function calls in 121.198 seconds

   Ordered by: cumulative time
   List reduced from 48 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000  139.163   69.581 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.001    0.000  139.162   69.581 {built-in method builtins.exec}
        1   47.532   47.532  121.195  121.195 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2801194547.py:1(find_independent_components)
        1   24.123   24.123   24.123   24.123 {method 'add_edges_from_no_data' of 'rustworkx.PyGraph' objects}
        1   18.713   18.713   18.731   18.731 {connected_components}
        2   14.375    7.187   14.375    7.187 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    7.661    7.661    7.661    7.661 /Users/willlangdale/DS/m

<pstats.Stats at 0x16d1a8190>

### Process a single component

In [24]:
from functools import lru_cache

@lru_cache(maxsize=None)
def combine_integers(*n: int) -> int:
    """
    Combine n integers into a single negative integer.

    Used to create a symmetric deterministic hash of two integers that populates the
    range of integers efficiently and reduces the likelihood of collisions.

    Aims to vectorise amazingly when used in Arrow.

    Does this by:

    * Using a Mersenne prime as a modulus
    * Making negative integers positive with modulo, sped up with bitwise operations
    * Combining using symmetric operations with coprime multipliers

    Args:
        *args: Variable number of integers to combine

    Returns:
        A negative integer
    """
    P = 2147483647

    total = 0
    product = 1

    for x in sorted(n):
        x_pos = (x ^ (x >> 31)) - (x >> 31)
        total = (total + x_pos) % P
        product = (product * x_pos) % P

    result = (31 * total + 37 * product) % P

    return -result

In [25]:
from itertools import chain
from functools import lru_cache

@lru_cache(maxsize=None)
def combine_strings(*n: str) -> str:
    """
    Combine n strings into a single string.

    Args:
        *args: Variable number of strings to combine
        
    Returns:
        A single string
    """
    letters = set(chain.from_iterable(n))
    return "".join(sorted(list(letters)))

In [16]:
from collections import defaultdict
from typing import TypeVar, Generic, Hashable, Iterator
import pandas as pd

T = TypeVar('T', bound=Hashable)

class UnionFindWithDiff(Generic[T]):
    def __init__(self):
        self.parent: dict[T, T] = {}
        self.rank: dict[T, int] = {}
        self._shadow_parent: dict[T, T] = {}
        self._shadow_rank: dict[T, int] = {}
        self._pending_pairs: list[tuple[T, T]] = []
        
    def make_set(self, x: T) -> None:
        if x not in self.parent:
            self.parent[x] = x
            self.rank[x] = 0
    
    def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T:
        if parent_dict is None:
            parent_dict = self.parent
            
        if x not in parent_dict:
            self.make_set(x)
            if parent_dict is self._shadow_parent:
                self._shadow_parent[x] = x
                self._shadow_rank[x] = 0
        
        while parent_dict[x] != x:
            parent_dict[x] = parent_dict[parent_dict[x]]
            x = parent_dict[x]
        return x
    
    def union(self, x: T, y: T) -> None:
        root_x = self.find(x)
        root_y = self.find(y)
        
        if root_x != root_y:
            self._pending_pairs.append((x, y))
            
            if self.rank[root_x] < self.rank[root_y]:
                root_x, root_y = root_y, root_x
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1
    
    def get_component(self, x: T, parent_dict: dict[T, T] | None = None) -> set[T]:
        if parent_dict is None:
            parent_dict = self.parent
        
        root = self.find(x, parent_dict)
        return {y for y in parent_dict if self.find(y, parent_dict) == root}
    
    def get_components(self, parent_dict: dict[T, T] | None = None) -> list[set[T]]:
        if parent_dict is None:
            parent_dict = self.parent
            
        components = defaultdict(set)
        for x in parent_dict:
            root = self.find(x, parent_dict)
            components[root].add(x)
        return list(components.values())
    
    def diff(self) -> Iterator[tuple[set[T], set[T]]]:
        """
        Returns differences including all pairwise merges that occurred since last diff,
        excluding cases where old_comp == new_comp.
        """
        # Get current state before processing pairs
        current_components = self.get_components()
        reported_pairs = set()
        
        # Process pending pairs
        for x, y in self._pending_pairs:
            # Find the final component containing the pair
            final_component = next(comp for comp in current_components 
                                 if x in comp and y in comp)
            
            # Only report if the pair forms a proper subset of the final component
            pair_component = {x, y}
            if (pair_component != final_component and 
                frozenset((frozenset(pair_component), frozenset(final_component))) not in reported_pairs):
                reported_pairs.add(frozenset((frozenset(pair_component), frozenset(final_component))))
                yield (pair_component, final_component)
        
        self._pending_pairs.clear()
        
        # Handle initial state
        if not self._shadow_parent:
            self._shadow_parent = self.parent.copy()
            self._shadow_rank = self.rank.copy()
            return
        
        # Get old components
        old_components = self.get_components(self._shadow_parent)
        
        # Report changes between old and new states
        for old_comp in old_components:
            if len(old_comp) > 1:  # Only consider non-singleton old components
                sample_elem = next(iter(old_comp))
                new_comp = next(comp for comp in current_components if sample_elem in comp)
                
                # Only yield if the components are different and this pair hasn't been reported
                if (old_comp != new_comp and 
                    frozenset((frozenset(old_comp), frozenset(new_comp))) not in reported_pairs):
                    reported_pairs.add(frozenset((frozenset(old_comp), frozenset(new_comp))))
                    yield (old_comp, new_comp)
        
        # Update shadow copy
        self._shadow_parent = self.parent.copy()
        self._shadow_rank = self.rank.copy()


def component_to_hierarchy(key: str | int | tuple, df: pd.DataFrame) -> pd.DataFrame:
    """
    Convert pairwise probabilities into a hierarchical representation.
    Assumes data is pre-sorted by probability descending.
    
    Args:
        key: Group key (ignored in this implementation)
        df: DataFrame with columns ['left', 'right', 'probability']
    
    Returns:
        DataFrame with columns ['parent', 'child', 'probability'] representing hierarchical merges
    """
    hierarchy: list[tuple[int, int, float]] = []
    uf = UnionFindWithDiff[int]()

    for threshold in df["probability"].unique():
        current_probs = df[df["probability"] == threshold]

        for _, row in current_probs.iterrows():
            uf.union(row["left"], row["right"])
            parent = combine_integers(row["left"], row["right"])
            hierarchy.extend([
                (parent, row["left"], threshold),
                (parent, row["right"], threshold)
            ])

        for old_comp, new_comp in uf.diff():
            if len(old_comp) > 1:
                parent = combine_integers(*new_comp)
                child = combine_integers(*old_comp)
                hierarchy.extend([
                    (parent, child, threshold)
                ])
            else:
                parent = combine_integers(*new_comp)
                hierarchy.extend([
                    (parent, old_comp.pop(), threshold)
                ])

    return pd.DataFrame(hierarchy, columns=['parent', 'child', 'probability'])

In [41]:
df = cc2.to_pandas()
cc_15_pd = df[df['component'] == 15]

In [46]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = component_to_hierarchy(15, cc_15_pd)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         142304 function calls (139675 primitive calls) in 0.136 seconds

   Ordered by: cumulative time
   List reduced from 276 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.136    0.068 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    0.136    0.068 {built-in method builtins.exec}
        1    0.000    0.000    0.136    0.136 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/3996032216.py:1(<module>)
        1    0.007    0.007    0.136    0.136 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/807932433.py:115(component_to_hierarchy)
      532    0.002    0.000    0.059    0.000 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pandas/core/frame.py:1505(iterrows)
      532    0.007    0.000    0.054    0.000 /Users/willlangdale/DS/matchbox/.venv/lib/py

<pstats.Stats at 0x12219f390>

In [81]:
import pyarrow as pa

import pyarrow.compute as pc

def component_to_hierarchy_pa(table: pa.Table) -> pa.Table:
    """
    Convert pairwise probabilities into a hierarchical representation.
    Assumes data is pre-sorted by probability descending.
    
    Args:
        key: Group key (ignored in this implementation)
        table: Arrow Table with columns ['left', 'right', 'probability']
    
    Returns:
        Arrow Table with columns ['parent', 'child', 'probability'] representing hierarchical merges
    """
    hierarchy: list[tuple[int, int, float]] = []
    uf = UnionFindWithDiff[int]()

    # Get unique probabilities
    probs = pc.unique(table['probability'])

    for threshold in probs:
        # Get current probability rows
        mask = pc.equal(table['probability'], threshold)
        current_probs = table.filter(mask)
        threshold_float = float(threshold.as_py())
        
        # Process each row
        for row in zip(
            current_probs['left'].to_numpy(),
            current_probs['right'].to_numpy()
        ):
            left, right = row
            uf.union(left, right)
            parent = combine_integers(left, right)
            hierarchy.extend([
                (parent, left, threshold_float),
                (parent, right, threshold_float)
            ])

        # Process UnionFind diffs - exact same logic as original
        for old_comp, new_comp in uf.diff():
            if len(old_comp) > 1:
                parent = combine_integers(*new_comp)
                child = combine_integers(*old_comp)
                hierarchy.extend([
                    (parent, child, threshold_float)
                ])
            else:
                parent = combine_integers(*new_comp)
                hierarchy.extend([
                    (parent, old_comp.pop(), threshold_float)
                ])

    parents, children, probs = zip(*hierarchy)
    return pa.table({
        'parent': pa.array(parents, type=pa.int64()),
        'child': pa.array(children, type=pa.int64()),
        'probability': pa.array(probs, type=pa.float64())
    })

In [9]:
cc_15 = filtered_table = cc2.filter(pc.equal(cc2['component'], 15))

In [40]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = component_to_hierarchy_pa(cc_15)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         26701 function calls in 0.034 seconds

   Ordered by: cumulative time
   List reduced from 54 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.034    0.017 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    0.034    0.017 {built-in method builtins.exec}
        1    0.007    0.007    0.034    0.034 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/2661828505.py:5(component_to_hierarchy_pa)
      298    0.003    0.000    0.021    0.000 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/807932433.py:65(diff)
       61    0.008    0.000    0.017    0.000 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/807932433.py:55(get_components)
    11218    0.009    0.000    0.009    0.000 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_68435/807932433.py:2

<pstats.Stats at 0x11c183250>

In [80]:
from unittest.mock import patch

test_cases = [
    # Test case 1: Equal probabilities
    (
        {
            "left": ["a", "b", "c"],
            "right": ["b", "c", "d"],
            "probability": [1.0, 1.0, 1.0],
        },
        {
            ("ab", "a", 1.0),
            ("ab", "b", 1.0),
            ("bc", "b", 1.0),
            ("bc", "c", 1.0),
            ("cd", "c", 1.0),
            ("cd", "d", 1.0),
            ("abcd", "ab", 1.0),
            ("abcd", "bc", 1.0),
            ("abcd", "cd", 1.0),
        },
    ),
    # Test case 2: Asymmetric probabilities
    (
        {
            "left": ["w", "x", "y"],
            "right": ["x", "y", "z"],
            "probability": [0.9, 0.85, 0.8],
        },
        {
            ("wx", "w", 0.9),
            ("wx", "x", 0.9),
            ("xy", "x", 0.85),
            ("xy", "y", 0.85),
            ("wxy", "wx", 0.85),
            ("wxy", "xy", 0.85),
            ("yz", "y", 0.8),
            ("yz", "z", 0.8),
            ("wxyz", "wxy", 0.8),
            ("wxyz", "yz", 0.8),
        },
    ),
    # Test case 3: Single two-item component
    (
        {
            "left": ["x"],
            "right": ["y"],
            "probability": [0.9],
        },
        {
            ("xy", "x", 0.9),
            ("xy", "y", 0.9),
        },
    ),
]

for i, (prob_data, expected_relations) in enumerate(test_cases, 1):
    print(f"\nRunning test case {i}...")
    
    with patch('__main__.combine_integers', side_effect=combine_strings):
        # Pandas
        # probabilities = (
        #     pd.DataFrame.from_dict(prob_data)
        #     .assign(probability=lambda df: df['probability'].astype(float))
        #     .sort_values(by="probability", ascending=False)
        # )
        # hierarchy_true = (
        #         pd.DataFrame.from_records(
        #         list(expected_relations), 
        #         columns=["parent", "child", "probability"]
        #     )
        #     .sort_values(by=["probability", "parent", "child"], ascending=[False, True, True])
        #     .dropna(how='all')
        #     .reset_index(drop=True)
        # )

        # hierarchy = component_to_hierarchy(0, probabilities)

        # hierarchy = hierarchy.sort_values(
        #     by=["probability", "parent", "child"],
        #     ascending=[False, True, True]
        # ).reset_index(drop=True)

        # Arrow
        probabilities = (
            pa.Table.from_pydict(prob_data)
            .cast(pa.schema([
                ('left', pa.string()),
                ('right', pa.string()),
                ('probability', pa.float64()),
            ]))
            .sort_by([('probability', 'descending')])
        )

        parents, children, probs = zip(*expected_relations)

        hierarchy_true = (
            pa.table(
                [parents, children, probs], 
                names=['parent', 'child', 'probability']
            )
            .sort_by([
                ('probability', 'descending'),
                ('parent', 'ascending'),
                ('child', 'ascending')
            ])
            .filter(pc.is_valid(pc.field('parent')))
        )

        hierarchy = (
            component_to_hierarchy_pa(probabilities)
            .sort_by([
                ('probability', 'descending'),
                ('parent', 'ascending'),
                ('child', 'ascending')
            ])
        )
        
        try:
            assert hierarchy.equals(hierarchy_true)
            print(f"✓ Test case {i} passed")
        except AssertionError:
            print(f"✗ Test case {i} failed")
            print("\nExpected DataFrame:")
            print(hierarchy_true)
            print("\nActual DataFrame:")
            print(hierarchy)
            print("\nDifferences:")
            if hierarchy.shape != hierarchy_true.shape:
                print(f"Shape mismatch: Expected {hierarchy_true.shape}, got {hierarchy.shape}")
            else:
                # Show where values differ
                differences = (hierarchy != hierarchy_true).any(axis=1)
                if differences.any():
                    print("\nMismatched rows:")
                    print("Expected:")
                    print(hierarchy_true[differences])
                    print("\nGot:")
                    print(hierarchy[differences])
            # Continue to next test case instead of raising
            continue



Running test case 1...
✓ Test case 1 passed

Running test case 2...
✓ Test case 2 passed

Running test case 3...
✓ Test case 3 passed


### Process all components in parallel

In [2]:
import pyarrow.parquet as pq

h2 = pq.read_table('hierarchical_cc2k.parquet')
cc2 = attach_independent_components(h2)

# h200 = pq.read_table('hierarchical_cc200k.parquet')
# cc200 = attach_independent_components(h200)

NameError: name 'attach_independent_components' is not defined

In [15]:
import pyarrow as pa
import dask.dataframe as dd
import pandas as pd
from multiprocessing import cpu_count
from dask.distributed import Client, LocalCluster

def setup_client(port: int = 8787) -> Client:
    """
    Creates a Dask client optimized for local computation.
    Returns the client object.
    """
    try:
        client = Client(f'tcp://localhost:{port}')
    except Exception:
        n_workers = cpu_count()
        cluster = LocalCluster(
            dashboard_address=f':{port}',
            n_workers=n_workers,
            threads_per_worker=1,  # One thread per worker for CPU-bound tasks
            memory_limit='auto'    # Automatically determine memory limits
        )
        client = Client(cluster)
    return client

In [6]:
from cluster import component_to_hierarchy

def to_hierarchical_clusters(client: Client, table: pa.Table) -> pd.DataFrame:
    """
    Perform hierarchical clustering on an Arrow table of pairwise probabilities.
    """
    pdf = table.to_pandas()
    
    ddf = dd.from_pandas(
        pdf, 
        npartitions=len(client.cluster.worker_spec)
    )
    
    result_ddf = ddf.groupby('component').apply(
        component_to_hierarchy, 
        meta={
            'parent': int, 
            'child': int, 
            'probability': float
        }
    )
    
    return result_ddf.compute()

In [17]:
with setup_client(port=8787) as client:
    h_out = to_hierarchical_clusters(client, cc2)

h_out

KeyboardInterrupt: 

In [89]:
import pyarrow as pa
import pyarrow.compute as pc
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
from cluster import process_component

def to_hierarchical_clusters_pa(table: pa.Table) -> pa.Table:
    """
    Parallel processing of components using Arrow.
    """
    components = pc.unique(table['component'])
    n_cores = multiprocessing.cpu_count()
    results = []
    
    with ProcessPoolExecutor(max_workers=n_cores) as executor:
        futures = [
            executor.submit(process_component, comp, table) 
            for comp in components
        ]
        results = [future.result() for future in futures]
    
    return pa.concat_tables(results) if results else pa.table({
        'parent': pa.array([], type=pa.int64()),
        'child': pa.array([], type=pa.int64()),
        'probability': pa.array([], type=pa.float64())
    })

In [90]:
h_out = to_hierarchical_clusters_pa(cc2)

h_out

: 

: 