# Hierarchical algorithm optimisation

## Data generation

In [20]:
import rustworkx as rx
from collections import Counter

def verify_components(table) -> dict:
    """
    Fast verification of connected components using rustworkx.
    
    Args:
        table: PyArrow table with 'left', 'right' columns
    
    Returns:
        Dictionary containing basic component statistics
    """
    # Create graph directly from arrays
    graph = rx.PyDiGraph()
    
    # Add all unique nodes at once
    unique_nodes = set(table['left'].to_numpy()) | set(table['right'].to_numpy())
    graph.add_nodes_from(range(len(unique_nodes)))
    
    # Create node mapping and edges in one pass
    node_to_idx = {node: idx for idx, node in enumerate(unique_nodes)}
    edges = [(node_to_idx[left], node_to_idx[right], prob) 
            for left, right, prob in zip(table['left'].to_numpy(), 
                                       table['right'].to_numpy(),
                                       table['probability'].to_numpy())]
    
    # Add all edges at once
    graph.add_edges_from(edges)
    
    # Get components and their sizes
    components = rx.weakly_connected_components(graph)
    component_sizes = Counter(len(component) for component in components)
    
    return {
        'num_components': len(components),
        'total_nodes': len(unique_nodes),
        'total_edges': len(edges),
        'component_sizes': component_sizes,
        'min_component_size': min(component_sizes.keys()),
        'max_component_size': max(component_sizes.keys())
    }

In [22]:
def calculate_max_possible_edges(n_nodes: int, num_components: int) -> int:
    """
    Calculate the maximum possible number of edges given n nodes split into k components.
    
    Args:
        n_nodes: Total number of nodes
        num_components: Number of components to split into
        
    Returns:
        Maximum possible number of edges
    """
    nodes_per_component = n_nodes // num_components
    max_edges_per_component = nodes_per_component * nodes_per_component  # Complete bipartite graph
    return max_edges_per_component * num_components


In [29]:
import numpy as np
import pyarrow as pa
import rustworkx as rx
from typing import List, Tuple
from decimal import Decimal

def split_values_into_components(values: List[int], num_components: int) -> List[np.ndarray]:
    """
    Split values into non-overlapping groups for each component.
    
    Args:
        values: List of values to split
        num_components: Number of components to create
        
    Returns:
        List of arrays, one for each component
    """
    values = np.array(values)
    np.random.shuffle(values)
    return np.array_split(values, num_components)


def generate_arrow_data(
    left_values: List[int],
    right_values: List[int],
    prob_range: Tuple[float, float],
    num_components: int,
    total_rows: int
) -> pa.Table:
    """
    Generate dummy arrow data with guaranteed isolated components.
    
    Args:
        left_values: List of integers to use for left column
        right_values: List of integers to use for right column
        prob_range: Tuple of (min_prob, max_prob) to constrain probabilities
        num_components: Number of distinct connected components to generate
        total_rows: Total number of rows to generate
    
    Returns:
        PyArrow Table with 'left', 'right', and 'probability' columns
    """
    if len(left_values) < 2 or len(right_values) < 2:
        raise ValueError("Need at least 2 possible values for both left and right")
    if num_components > min(len(left_values), len(right_values)):
        raise ValueError("Cannot have more components than minimum of left/right values")
    
    # Calculate maximum possible edges
    min_nodes = min(len(left_values), len(right_values))
    max_possible_edges = calculate_max_possible_edges(min_nodes, num_components)
    
    if total_rows > max_possible_edges:
        raise ValueError(
            f"Cannot generate {total_rows:,} edges with {num_components:,} components. "
            f"Maximum possible edges is {max_possible_edges:,} given {min_nodes:,} nodes. "
            "Either increase the number of nodes, decrease the number of components, "
            "or decrease the total edges requested."
        )
    
    # Convert probability range to integers (60-80 for 0.60-0.80)
    prob_min = int(prob_range[0] * 100)
    prob_max = int(prob_range[1] * 100)
    
    # Split values into completely separate groups for each component
    left_components = split_values_into_components(left_values, num_components)
    right_components = split_values_into_components(right_values, num_components)
    
    # Calculate base number of edges per component
    base_edges_per_component = total_rows // num_components
    remaining_edges = total_rows % num_components
    
    all_edges = []
    
    # Generate edges for each component
    for comp_idx in range(num_components):
        comp_left_values = left_components[comp_idx]
        comp_right_values = right_components[comp_idx]
        
        # Calculate edges for this component
        edges_in_component = base_edges_per_component
        if comp_idx < remaining_edges:  # Distribute remaining edges
            edges_in_component += 1
            
        # Ensure basic connectivity within the component
        base_edges = []
        
        # Create a spanning tree-like structure
        for i in range(len(comp_left_values)):
            base_edges.append((
                comp_left_values[i],
                comp_right_values[i % len(comp_right_values)],
                np.random.randint(prob_min, prob_max + 1)
            ))
        
        # Generate remaining random edges strictly within this component
        remaining_edges = edges_in_component - len(base_edges)
        if remaining_edges > 0:
            random_lefts = np.random.choice(comp_left_values, size=remaining_edges)
            random_rights = np.random.choice(comp_right_values, size=remaining_edges)
            random_probs = np.random.randint(prob_min, prob_max + 1, size=remaining_edges)
            
            component_edges = base_edges + list(zip(random_lefts, random_rights, random_probs))
        else:
            component_edges = base_edges
            
        all_edges.extend(component_edges)
    
    # Convert to arrays
    lefts, rights, probs = zip(*all_edges)
    
    # Create PyArrow arrays
    left_array = pa.array(lefts, type=pa.int64())
    right_array = pa.array(rights, type=pa.int64())
    decimal_probs = [Decimal(str(p/100)) for p in probs]
    prob_array = pa.array(decimal_probs, type=pa.decimal128(precision=3, scale=2))
    
    return pa.table([left_array, right_array, prob_array],
                   names=['left', 'right', 'probability'])



In [32]:
left_values = list(range(10_000))
right_values = list(range(10_000, 20_000))
prob_range = (0.6, 0.8)
num_components = 10
total_rows = 1_000_000

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Number of components found: 10
Total nodes: 20000
Total edges: 1000009

Component sizes:
Size 2000: 10 components


In [36]:
left_values = list(range(int(2e7)))
right_values = list(range(int(2e7), int(4e7)))
prob_range = (0.7, 1.0)
num_components = 200_000
total_rows = int(1e8)

table = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [37]:
results = verify_components(table)
print(f"Number of components found: {results['num_components']}")
print(f"Total nodes: {results['total_nodes']}")
print(f"Total edges: {results['total_edges']}")
print("\nComponent sizes:")
for size, count in sorted(results['component_sizes'].items()):
    print(f"Size {size}: {count} components")

Number of components found: 206763
Total nodes: 40000000
Total edges: 100000400

Component sizes:
Size 2: 6755 components
Size 4: 8 components
Size 194: 1 components
Size 196: 124 components
Size 198: 6520 components
Size 200: 193355 components


Number of components found: 206763
* Total nodes: 40000000
* Total edges: 100000400

Component sizes:
* Size 2: 6755 components
* Size 4: 8 components
* Size 194: 1 components
* Size 196: 124 components
* Size 198: 6520 components
* Size 200: 193355 components

In [38]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table, Path.cwd() / 'hierarchical_cc200k.parquet')

In [41]:
left_values = list(range(int(2e5)))
right_values = list(range(int(2e5), int(4e5)))
prob_range = (0.7, 1.0)
num_components = 2_000
total_rows = int(1e6)

table2 = generate_arrow_data(
    left_values=left_values,
    right_values=right_values,
    prob_range=prob_range,
    num_components=num_components,
    total_rows=total_rows
)

In [44]:
results2 = verify_components(table2)
print(f"Number of components found: {results2['num_components']:,}")
print(f"Total nodes: {results2['total_nodes']:,}")
print(f"Total edges: {results2['total_edges']:,}")
print("\nComponent sizes:")
for size, count in sorted(results2['component_sizes'].items()):
    print(f"Size {size:,}: {count:,} components")

Number of components found: 2,080
Total nodes: 400,000
Total edges: 1,000,400

Component sizes:
Size 2: 80 components
Size 196: 1 components
Size 198: 78 components
Size 200: 1,921 components


Number of components found: 2,080
* Total nodes: 400,000
* Total edges: 1,000,400

Component sizes:
* Size 2: 80 components
* Size 196: 1 components
* Size 198: 78 components
* Size 200: 1,921 components

In [46]:
from pathlib import Path
import pyarrow.parquet as pq

pq.write_table(table2, Path.cwd() / 'hierarchical_cc2k.parquet')

## Algorithm

In [47]:
import pyarrow.parquet as pq

h2 = pq.read_table('hierarchical_cc2k.parquet')

In [48]:
h2.schema

left: int64
right: int64
probability: decimal128(3, 2)

The plan.

* Find components and their sizes at lowest threshold (rustworkx)
* Use this to dask.groupby the data for parallel per-component processing
* Ensure we implement early stopping!

In [49]:
def combine_integers(*n: int) -> int:
    """
    Combine n integers into a single negative integer.

    Used to create a symmetric deterministic hash of two integers that populates the
    range of integers efficiently and reduces the likelihood of collisions.

    Aims to vectorise amazingly when used in Arrow.

    Does this by:

    * Using a Mersenne prime as a modulus
    * Making negative integers positive with modulo, sped up with bitwise operations
    * Combining using symmetric operations with coprime multipliers

    Args:
        *args: Variable number of integers to combine

    Returns:
        A negative integer
    """
    P = 2147483647

    total = 0
    product = 1

    for x in sorted(n):
        x_pos = (x ^ (x >> 31)) - (x >> 31)
        total = (total + x_pos) % P
        product = (product * x_pos) % P

    result = (31 * total + 37 * product) % P

    return -result

In [96]:
import pyarrow as pa
import pyarrow.compute as pc
import rustworkx as rx
from typing import Dict, Set

def find_independent_components(
    table: pa.Table
) -> Dict[int, Set[int]]:
    """
    Find completely independent components at the lowest probability threshold.
    These components will never interact with each other at any threshold.
    
    Args:
        table: PyArrow table with columns (left: int64, right: int64, probability: decimal128(3,2))
    
    Returns:
        Dictionary mapping component IDs to the Arrow indices of the nodes in that component
    """
    unique = pc.unique(
        pa.concat_arrays([
            table['left'].combine_chunks(),
            table['right'].combine_chunks()
        ])
    )
    
    left_indices = pc.index_in(table['left'], unique).to_numpy()
    right_indices = pc.index_in(table['right'], unique).to_numpy()
    
    # Initialize graph with exact size
    n_nodes = len(unique)
    n_edges = len(table)
    
    graph = rx.PyDiGraph(
        node_count_hint=n_nodes,
        edge_count_hint=n_edges
    )
    graph.add_nodes_from(range(n_nodes))
    
    # Convert to sequence for rustworkx
    edges = tuple(zip(
        left_indices,
        right_indices,
        table['probability'].to_numpy()
    ))
    
    # Add all edges at once
    graph.add_edges_from(edges)
    
    return rx.weakly_connected_components(graph)

cc = find_independent_components(table2)
cc

[{0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  200000,
  200001,
  200002,
  200003,
  200004,
  200005,
  200006,
  200007,
  200008,
  200009,
  200010,
  200011,
  200012,
  200013,
  200014,
  200015,
  200016,
  200017,
  200018,
  200019,
  200020,
  200021,
  200022,
  200023,
  200024,
  200025,
  200026,
  200027,
  200028,
  200029,
  200030,
  200031,
  200032,
  200033,
  200034,
  200035,
  200036,
  200037,
  200038,
  200039,
  200040,


In [97]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
components = find_independent_components(table2)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         29986 function calls in 1.419 seconds

   Ordered by: cumulative time
   List reduced from 35 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    1.598    0.799 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.000    0.000    1.598    0.799 {built-in method builtins.exec}
        1    0.780    0.780    1.419    1.419 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2771747583.py:6(find_independent_components)
        1    0.289    0.289    0.289    0.289 {weakly_connected_components}
        1    0.147    0.147    0.147    0.147 {method 'add_edges_from' of 'rustworkx.PyDiGraph' objects}
        2    0.092    0.046    0.092    0.046 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    0.075    0.075    0.075    0.075 /Users/willlangdale/DS/matc

<pstats.Stats at 0x51e36b010>

In [99]:
import cProfile
import pstats
from pstats import SortKey

# Profile the function
pr = cProfile.Profile()
pr.enable()
_ = find_independent_components(table)
pr.disable()

# Print stats sorted by cumulative time
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(20)  # Show top 20 lines

         2997685 function calls in 228.208 seconds

   Ordered by: cumulative time
   List reduced from 35 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.003    0.001  271.928  135.964 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3541(run_code)
        2    0.001    0.000  271.925  135.963 {built-in method builtins.exec}
        1  124.187  124.187  228.200  228.200 /var/folders/14/6nvsrw1n2ls1xncz_bvy2x8m0000gq/T/ipykernel_25768/2771747583.py:6(find_independent_components)
        1   57.866   57.866   57.866   57.866 {method 'add_edges_from' of 'rustworkx.PyDiGraph' objects}
        1   17.807   17.807   17.824   17.824 {weakly_connected_components}
        2   14.311    7.155   14.311    7.156 /Users/willlangdale/DS/matchbox/.venv/lib/python3.11/site-packages/pyarrow/compute.py:249(wrapper)
        1    8.583    8.583    8.583    8.583 /Users/willlangdale/DS/

<pstats.Stats at 0x1665dee10>