# Hierarchical algorithm optimisation

## Data generation

In [20]:
import rustworkx as rx
from collections import Counter

def verify_components(table) -> dict:
    """
    Fast verification of connected components using rustworkx.
    
    Args:
        table: PyArrow table with 'left', 'right' columns
    
    Returns:
        Dictionary containing basic component statistics
    """
    # Create graph directly from arrays
    graph = rx.PyDiGraph()
    
    # Add all unique nodes at once
    unique_nodes = set(table['left'].to_numpy()) | set(table['right'].to_numpy())
    graph.add_nodes_from(range(len(unique_nodes)))
    
    # Create node mapping and edges in one pass
    node_to_idx = {node: idx for idx, node in enumerate(unique_nodes)}
    edges = [(node_to_idx[left], node_to_idx[right], prob) 
            for left, right, prob in zip(table['left'].to_numpy(), 
                                       table['right'].to_numpy(),
                                       table['probability'].to_numpy())]
    
    # Add all edges at once
    graph.add_edges_from(edges)
    
    # Get components and their sizes
    components = rx.weakly_connected_components(graph)
    component_sizes = Counter(len(component) for component in components)
    
    return {
        'num_components': len(components),
        'total_nodes': len(unique_nodes),
        'total_edges': len(edges),
        'component_sizes': component_sizes,
        'min_component_size': min(component_sizes.keys()),
        'max_component_size': max(component_sizes.keys())
    }

In [22]:
def calculate_max_possible_edges(n_nodes: int, num_components: int) -> int:
    """
    Calculate the maximum possible number of edges given n nodes split into k components.
    
    Args:
        n_nodes: Total number of nodes
        num_components: Number of components to split into
        
    Returns:
        Maximum possible number of edges
    """
    nodes_per_component = n_nodes // num_components
    max_edges_per_component = nodes_per_component * nodes_per_component  # Complete bipartite graph
    return max_edges_per_component * num_components


In [29]:
import numpy as np
import pyarrow as pa
import rustworkx as rx
from typing import List, Tuple
from decimal import Decimal

def split_values_into_components(values: List[int], num_components: int) -> List[np.ndarray]:
    """
    Split values into non-overlapping groups for each component.
    
    Args:
        values: List of values to split
        num_components: Number of components to create
        
    Returns:
        List of arrays, one for each component
    """
    values = np.array(values)
    np.random.shuffle(values)
    return np.array_split(values, num_components)


def generate_arrow_data(
    left_values: List[int],
    right_values: List[int],
    prob_range: Tuple[float, float],
    num_components: int,
    total_rows: int
) -> pa.Table:
    """
    Generate dummy arrow data with guaranteed isolated components.
    
    Args:
        left_values: List of integers to use for left column
        right_values: List of integers to use for right column
        prob_range: Tuple of (min_prob, max_prob) to constrain probabilities
        num_components: Number of distinct connected components to generate
        total_rows: Total number of rows to generate
    
    Returns:
        PyArrow Table with 'left', 'right', and 'probability' columns
    """
    if len(left_values) < 2 or len(right_values) < 2:
        raise ValueError("Need at least 2 possible values for both left and right")
    if num_components > min(len(left_values), len(right_values)):
        raise ValueError("Cannot have more components than minimum of left/right values")
    
    # Calculate maximum possible edges
    min_nodes = min(len(left_values), len(right_values))
    max_possible_edges = calculate_max_possible_edges(min_nodes, num_components)
    
    if total_rows > max_possible_edges:
        raise ValueError(
            f"Cannot generate {total_rows:,} edges with {num_components:,} components. "
            f"Maximum possible edges is {max_possible_edges:,} given {min_nodes:,} nodes. "
            "Either increase the number of nodes, decrease the number of components, "
            "or decrease the total edges requested."
        )
    
    # Convert probability range to integers (60-80 for 0.60-0.80)
    prob_min = int(prob_range[0] * 100)
    prob_max = int(prob_range[1] * 100)
    
    # Split values into completely separate groups for each component
    left_components = split_values_into_components(left_values, num_components)
    right_components = split_values_into_components(right_values, num_components)
    
    # Calculate base number of edges per component
    base_edges_per_component = total_rows // num_components
    remaining_edges = total_rows % num_components
    
    all_edges = []
    
    # Generate edges for each component
    for comp_idx in range(num_components):
        comp_left_values = left_components[comp_idx]
        comp_right_values = right_components[comp_idx]
        
        # Calculate edges for this component
        edges_in_component = base_edges_per_component
        if comp_idx < remaining_edges:  # Distribute remaining edges
            edges_in_component += 1
            
        # Ensure basic connectivity within the component
        base_edges = []
        
        # Create a spanning tree-like structure
        for i in range(len(comp_left_values)):
            base_edges.append((
                comp_left_values[i],
                comp_right_values[i % len(comp_right_values)],
                np.random.randint(prob_min, prob_max + 1)
            ))
        
        # Generate remaining random edges strictly within this component
        remaining_edges = edges_in_component - len(base_edges)
        if remaining_edges > 0:
            random_lefts = np.random.choice(comp_left_values, size=remaining_edges)
            random_rights = np.random.choice(comp_right_values, size=remaining_edges)
            random_probs = np.random.randint(prob_min, prob_max + 1, size=remaining_edges)
            
            component_edges = base_edges + list(zip(random_lefts, random_rights, random_probs))
        else:
            component_edges = base_edges
            
        all_edges.extend(component_edges)
    
    # Convert to arrays
    lefts, rights, probs = zip(*all_edges)
    
    # Create PyArrow arrays
    left_array = pa.array(lefts, type=pa.int64())
    right_array = pa.array(rights, type=pa.int64())
    decimal_probs = [Decimal(str(p/100)) for p in probs]
    prob_array = pa.array(decimal_probs, type=pa.decimal128(precision=3, scale=2))
    
    return pa.table([left_array, right_array, prob_array],
                   names=['left', 'right', 'probability'])



In [32]:
left_values = list(range(10_000))
right_values = list(range(10_000, 20_000))
prob_range = (0.6, 0.8)
num_components = 10
total_rows = 1_000_000

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Number of components found: 10
Total nodes: 20000
Total edges: 1000009

Component sizes:
Size 2000: 10 components


In [36]:
left_values = list(range(int(2e7)))
right_values = list(range(int(2e7), int(4e7)))
prob_range = (0.7, 1.0)
num_components = 200_000
total_rows = int(1e8)

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [37]:
results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Number of components found: 206763
Total nodes: 40000000
Total edges: 100000400

Component sizes:
Size 2: 6755 components
Size 4: 8 components
Size 194: 1 components
Size 196: 124 components
Size 198: 6520 components
Size 200: 193355 components


Number of components found: 206763
* Total nodes: 40000000
* Total edges: 100000400

Component sizes:
* Size 2: 6755 components
* Size 4: 8 components
* Size 194: 1 components
* Size 196: 124 components
* Size 198: 6520 components
* Size 200: 193355 components

In [38]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table, Path.cwd() / 'hierarchical_cc200k.parquet')

In [41]:
left_values = list(range(int(2e5)))
right_values = list(range(int(2e5), int(4e5)))
prob_range = (0.7, 1.0)
num_components = 2_000
total_rows = int(1e6)

table2 = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [44]:
results2 = verify_components(table2)
print(f"Number of components found: {results2['num_components']:,}")
print(f"Total nodes: {results2['total_nodes']:,}")
print(f"Total edges: {results2['total_edges']:,}")
print("\nComponent sizes:")
for size, count in sorted(results2['component_sizes'].items()):
    print(f"Size {size:,}: {count:,} components")

Number of components found: 2,080
Total nodes: 400,000
Total edges: 1,000,400

Component sizes:
Size 2: 80 components
Size 196: 1 components
Size 198: 78 components
Size 200: 1,921 components


Number of components found: 2,080
* Total nodes: 400,000
* Total edges: 1,000,400

Component sizes:
* Size 2: 80 components
* Size 196: 1 components
* Size 198: 78 components
* Size 200: 1,921 components

In [46]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table2, Path.cwd() / 'hierarchical_cc2k.parquet')

## Algorithm

In [47]:
import pyarrow.parquet as pq

h2 = pq.read_table('hierarchical_cc2k.parquet')

In [48]:
h2.schema

left: int64
right: int64
probability: decimal128(3, 2)

The plan.

* Find components and their sizes at lowest threshold (rustworkx)
* Use this to dask.groupby the data for parallel per-component processing
* Ensure we implement early stopping!

### Find components

In [49]:
def combine_integers(*n: int) -> int:
    """
    Combine n integers into a single negative integer.

    Used to create a symmetric deterministic hash of two integers that populates the
    range of integers efficiently and reduces the likelihood of collisions.

    Aims to vectorise amazingly when used in Arrow.

    Does this by:

    * Using a Mersenne prime as a modulus
    * Making negative integers positive with modulo, sped up with bitwise operations
    * Combining using symmetric operations with coprime multipliers

    Args:
        *args: Variable number of integers to combine

    Returns:
        A negative integer
    """
    P = 2147483647

    total = 0
    product = 1

    for x in sorted(n):
        x_pos = (x ^ (x >> 31)) - (x >> 31)
        total = (total + x_pos) % P
        product = (product * x_pos) % P

    result = (31 * total + 37 * product) % P

    return -result

In [143]:
import pyarrow as pa
import pyarrow.compute as pc
import rustworkx as rx

def attach_independent_components(table: pa.Table) -> pa.Table:
    """
    Returns the original table with an additional 'component' column indicating
    which connected component each edge belongs to.
    """
    # Create dictionary array from sorted unique values
    unique = pc.unique(
        pa.concat_arrays([
            table['left'].combine_chunks(),
            table['right'].combine_chunks()
        ])
    )
    
    # Get indices into unique array for graph construction
    left_indices = pc.index_in(table['left'], unique)
    right_indices = pc.index_in(table['right'], unique)
    
    # Create and process graph
    n_nodes = len(unique)
    n_edges = len(table)
    
    graph = rx.PyGraph(
        node_count_hint=n_nodes,
        edge_count_hint=n_edges
    )
    graph.add_nodes_from(range(n_nodes))

    edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy()))
    graph.add_edges_from_no_data(edges)
    
    # Get components and create mapping array
    components = rx.connected_components(graph)
    
    # Convert components to numpy arrays
    component_indices = np.concatenate([np.array(list(c)) for c in components])
    component_labels = np.repeat(np.arange(len(components)), [len(c) for c in components])
    
    # Create mapping array and fill with component labels
    node_to_component = np.zeros(len(unique), dtype=np.int64)
    node_to_component[component_indices] = component_labels
    
    # Use the indices we already have to map back to components  
    edge_components = pa.array(node_to_component[left_indices.to_numpy()])
    
    return table.append_column('component', edge_components).sort_by(
        [('component', 'ascending'), ('probability', 'descending')]
    )

cc = attach_independent_components(table2)
len(pc.unique(cc.column('component'))), cc

(2080,
 pyarrow.Table
 left: int64
 right: int64
 probability: decimal128(3, 2)
 component: int64
 ----
 left: [[197413,116407,114551,160857,6412,...,39429,156175,197197,48177,121674]]
 right: [[326505,384163,344025,248700,258884,...,311452,278755,204144,357956,378111]]
 probability: [[1.00,1.00,1.00,1.00,1.00,...,0.70,0.70,0.70,0.70,0.70]]
 component: [[0,0,0,0,0,...,2079,2079,2079,2079,2079]])

In [None]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
components = attach_independent_components(table2)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         29986 function calls in 0.775 seconds

   Ordered by: cumulative time
   List reduced from 35 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.874    0.437 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    0.874    0.437 {built-in method builtins.exec}
        1    0.373    0.373    0.775    0.775 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/4265471472.py:6(find_independent_components)
        1    0.102    0.102    0.103    0.103 {connected_components}
        1    0.097    0.097    0.097    0.097 {method 'add_edges_from_no_data' of 'rustworkx.PyGraph' objects}
        2    0.090    0.045    0.090    0.045 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    0.074    0.074    0.074    0.074 /Users/willlangdale/DS/match

<pstats.Stats at 0x15e22c1d0>

In [135]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = attach_independent_components(table)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         3411228 function calls in 121.198 seconds

   Ordered by: cumulative time
   List reduced from 48 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000  139.163   69.581 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.001    0.000  139.162   69.581 {built-in method builtins.exec}
        1   47.532   47.532  121.195  121.195 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2801194547.py:1(find_independent_components)
        1   24.123   24.123   24.123   24.123 {method 'add_edges_from_no_data' of 'rustworkx.PyGraph' objects}
        1   18.713   18.713   18.731   18.731 {connected_components}
        2   14.375    7.187   14.375    7.187 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    7.661    7.661    7.661    7.661 /Users/willlangdale/DS/m

<pstats.Stats at 0x16d1a8190>

### Process a single component

In [144]:
import pandas as pd
import numpy as np
from collections import defaultdict

def component_to_hierarchy(key: str | int | tuple, df: pd.DataFrame) -> pd.DataFrame:
    """
    Convert pairwise probabilities into a hierarchical representation.
    Assumes data is pre-sorted by probability descending.
    
    Args:
        key: Group key (ignored in this implementation)
        df: DataFrame with columns ['left', 'right', 'probability']
    
    Returns:
        DataFrame with columns ['parent', 'child', 'probability'] representing hierarchical merges
    """
    hierarchy: list[tuple[int, int, float]] = []
    ultimate_parents: dict[int, set[int]] = defaultdict(set)
    
    # Process each probability threshold
    for prob in df['probability'].unique():
        # Process pairs at this threshold
        current_pairs = df[df['probability'] == prob]
        
        for _, row in current_pairs.iterrows():
            left, right = row['left'], row['right']
            
            # Skip if already in same component
            if ultimate_parents[left] & ultimate_parents[right]:
                continue
                
            # Create merged node
            merged = combine_integers(left, right)
            
            # Add relationships to hierarchy
            hierarchy.extend([
                (merged, left, prob),
                (merged, right, prob)
            ])
            
            # Update parent tracking
            ultimate_parents[left].add(merged)
            ultimate_parents[right].add(merged)
            
            # Find all related nodes through shared parents
            parents_to_check = ultimate_parents[left] | ultimate_parents[right]
            related_nodes = {
                node 
                for node, parents in ultimate_parents.items() 
                if parents & parents_to_check and node not in (left, right)
            }
            
            if related_nodes:
                # Create new merged node for all related components
                super_merged = combine_integers(merged, *related_nodes)
                
                # Add relationships
                hierarchy.extend(
                    (super_merged, child, prob)
                    for child in (merged, *related_nodes)
                )
                
                # Update parent tracking for all children
                new_parents = {super_merged}
                for child in (left, right, *related_nodes):
                    ultimate_parents[child] = new_parents
            else:
                # Just update the two merged nodes
                new_parents = {merged}
                ultimate_parents[left] = new_parents
                ultimate_parents[right] = new_parents
        
        # Early stopping - check if everything is merged
        if len(set.union(*ultimate_parents.values())) == 1:
            break
    
    # Convert to DataFrame, already sorted by probability descending
    return pd.DataFrame(hierarchy, columns=['parent', 'child', 'probability'])

In [145]:
cc
# Convert the Arrow table to a Pandas DataFrame
df = cc.to_pandas()

# Filter the DataFrame where component is 0
cc_0 = df[df['component'] == 0]
cc_0

Unnamed: 0,left,right,probability,component
0,197413,326505,1.00,0
1,116407,384163,1.00,0
2,114551,344025,1.00,0
3,160857,248700,1.00,0
4,6412,258884,1.00,0
...,...,...,...,...
495,109907,399129,0.70,0
496,194530,385858,0.70,0
497,5985,212599,0.70,0
498,19131,294811,0.70,0


In [146]:
component_to_hierarchy(0, cc_0)

Unnamed: 0,parent,child,probability
0,-1193661193,197413,1.00
1,-1193661193,326505,1.00
2,-1065816097,116407,1.00
3,-1065816097,384163,1.00
4,-2131390865,114551,1.00
...,...,...,...
9341,-28774449,119799,0.82
9342,-28774449,290808,0.82
9343,-28774449,375290,0.82
9344,-28774449,154622,0.82
