In [1]:
from __future__ import annotations

import heapq
from dataclasses import dataclass
from typing import Iterator, Tuple, Optional, List

import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

In [2]:
# Load the per-coordinate count arrays. NpzFile keeps its archive open, so use it as a
# context manager; indexing with ['counts'] reads the array fully into memory first.
with np.load('x5-500000.npz') as x_npz:
    xcounts = x_npz['counts']
with np.load('z5-400000.npz') as z_npz:
    zcounts = z_npz['counts']

In [3]:
# Expected count per location; the printed product scores below are divided by it.
baseline = 4 * 5

In [4]:
# Coordinate of index 0 in xcounts / zcounts: index k corresponds to coordinate k + xmin (k + zmin).
xmin = -1_875_000
zmin = xmin

In [5]:
def _value_counts_0_to_max(x: np.ndarray, max_val: int) -> np.ndarray:
    """
    Count occurrences of each integer value in [0, max_val] in a 1D numpy array.
    Returns int64 counts of shape (max_val+1,).
    """
    x = np.asarray(x)
    if x.ndim != 1:
        x = x.ravel()

    # bincount requires non-negative ints; you said values are in 0..100
    # Ensure integer dtype without copying if already int.
    if x.dtype.kind not in ("i", "u"):
        x = x.astype(np.int64, copy=False)

    counts = np.bincount(x, minlength=max_val + 1)
    if counts.shape[0] < max_val + 1:
        # Shouldn't happen with minlength, but keep it robust.
        counts = np.pad(counts, (0, max_val + 1 - counts.shape[0]), mode="constant")

    return counts.astype(np.int64, copy=False)


def _suffix_ge(counts: np.ndarray) -> np.ndarray:
    """
    Given counts[b] for b=0..max_val, return suffix_ge[k] = sum_{b>=k} counts[b]
    for k=0..max_val+1, where suffix_ge[max_val+1] = 0.
    """
    counts = np.asarray(counts, dtype=np.int64)
    max_val = counts.shape[0] - 1
    out = np.zeros(max_val + 2, dtype=np.int64)
    # cumulative sum from the end
    out[:max_val + 1] = np.cumsum(counts[::-1], dtype=np.int64)[::-1]
    out[max_val + 1] = 0
    return out


def _count_greater_and_equal_from_hists(
    countA: np.ndarray,
    countB: np.ndarray,
    p: int,
) -> Tuple[int, int]:
    """
    Given histograms countA[a], countB[b] (a,b >= 0) and integer target p,
    compute:
      G = #pairs (a,b) with a*b > p
      E = #pairs (a,b) with a*b == p
    using O(max_val) work.
    """
    countA = np.asarray(countA, dtype=np.int64)
    countB = np.asarray(countB, dtype=np.int64)
    max_val = countA.shape[0] - 1
    assert countB.shape[0] - 1 == max_val

    nB = int(countB.sum())
    suffixB = _suffix_ge(countB)

    G = np.int64(0)
    E = np.int64(0)

    # Handle each possible value of a (only 0..max_val)
    for a in range(max_val + 1):
        ca = countA[a]
        if ca == 0:
            continue

        if a == 0:
            # product is always 0
            if p == 0:
                E += ca * nB
            elif p < 0:
                # not expected in your setting, but included for completeness
                G += ca * nB
            continue

        # Count b such that a*b > p  <=>  b > floor(p/a)
        t = p // a  # floor
        if t < max_val:
            # number of b in [t+1..max_val]
            G += ca * suffixB[t + 1]
        # else: t >= max_val => no b in range can exceed, contribute 0

        # Count equality: a*b == p => p divisible by a and b=p/a within range
        if p % a == 0:
            b = p // a
            if 0 <= b <= max_val:
                E += ca * countB[b]

    return int(G), int(E)


def average_rank_for_product(
    A: np.ndarray,
    B: np.ndarray,
    p: int,
    *,
    max_val: int = 1000,
    require_present: bool = False,
    return_counts: bool = False,
) -> float | Tuple[float, int, int]:
    """
    Compute the average 1-indexed rank of product value p among all pairs (i,j),
    ordered by decreasing A[i]*B[j], with ties averaged.

    A and B are arrays of non-negative integers. `max_val` is the expected upper
    bound on their values. If the data exceed it, the histograms are widened to the
    actual maximum, so both histograms share one length and the range check on p
    remains correct.

    Args:
      A, B: non-negative integer arrays (flattened).
      p: integer product value whose rank is wanted.
      max_val: minimum histogram upper bound.
      require_present: raise instead of returning NaN when p cannot be found.
      return_counts: also return the counts G and E.

    Returns:
      avg_rank (float), or (avg_rank, G, E) if return_counts=True

    Where:
      G = number of pairs with product > p
      E = number of pairs with product = p

    If E==0:
      - if require_present=True: raises ValueError
      - else returns np.nan (and counts if requested)

    Raises:
      TypeError: if p is not an integer.
      ValueError: as above, or if A or B contains negative values.
    """
    if not isinstance(p, (int, np.integer)):
        # You can relax this if you want to accept floats that are integral
        raise TypeError("p must be an integer product value")
    p = int(p)

    A = np.asarray(A)
    B = np.asarray(B)

    # The counting helper requires equal-length histograms covering every value,
    # so extend the range to the data maximum when it exceeds max_val.
    data_max = max((int(arr.max()) for arr in (A, B) if arr.size), default=0)
    hist_max = max(int(max_val), data_max)

    countA = _value_counts_0_to_max(A, hist_max)
    countB = _value_counts_0_to_max(B, hist_max)

    # Quick reject if p outside possible range: [0, hist_max^2]
    max_product = hist_max * hist_max
    if p < 0 or p > max_product:
        if require_present:
            raise ValueError(f"p={p} is outside achievable product range [0, {max_product}]")
        avg = np.nan
        return (avg, 0, 0) if return_counts else avg

    G, E = _count_greater_and_equal_from_hists(countA, countB, p)

    if E == 0:
        if require_present:
            raise ValueError(f"No pairs have product p={p} (E=0).")
        avg = np.nan
        return (avg, G, E) if return_counts else avg

    # Average rank of tied block [G+1, G+E] (1-indexed)
    avg_rank = G + (E + 1) / 2.0

    return (avg_rank, G, E) if return_counts else float(avg_rank)

In [6]:
# A known location, used below to measure how highly the heuristic ranks it.
knownx = 51069
knownz = 419124
# Guard the offsets: a negative index would silently wrap to the end of the array.
assert 0 <= knownx - xmin < len(xcounts) and 0 <= knownz - zmin < len(zcounts)
# Multiply as Python ints so a narrow count dtype cannot overflow.
product_of_known_coord = int(xcounts[knownx - xmin]) * int(zcounts[knownz - zmin])
print(f'x={knownx}', f'z={knownz}', f'p={product_of_known_coord/baseline:.1f}')

x=51069 z=419124 p=8.0


In [7]:
# Heuristic quality: the known location's average rank as a fraction of all (x, z) pairs.
# A value of 0.009 means it lands in the top 0.9%, i.e. over a 100x improvement
# compared with checking pairs in arbitrary order.
known_rank = average_rank_for_product(xcounts, zcounts, product_of_known_coord)
known_rank / len(xcounts) / len(zcounts)

0.009524233374589526

In [8]:
def _group_indices_by_value_desc(
    x: np.ndarray,
    *,
    stable: bool = False,
    index_dtype: Optional[np.dtype] = None,
) -> Tuple[List[int], List[np.ndarray]]:
    """
    Group original indices of x by value, returning:
      - values_desc: list of distinct values in descending order (Python ints)
      - groups_desc: list of 1D numpy arrays of original indices for each value

    Notes:
      - Uses one sort of length len(x).
      - Within each group, index order is arbitrary unless stable=True (then stable sort).
    """
    x = np.asarray(x)
    n = int(x.size)
    if n == 0:
        return [], []

    if index_dtype is None:
        index_dtype = np.int32 if n <= np.iinfo(np.int32).max else np.int64

    kind = "mergesort" if stable else "quicksort"
    order = np.argsort(x, kind=kind).astype(index_dtype, copy=False)  # ascending indices by x
    xs = x[order]  # ascending values

    # boundaries where value changes
    # (bool array length n-1, very memory-light)
    change = xs[1:] != xs[:-1]
    boundaries = np.nonzero(change)[0] + 1  # split points

    # group start positions and corresponding values
    starts = np.empty(boundaries.size + 1, dtype=np.int64)
    starts[0] = 0
    starts[1:] = boundaries
    vals_asc = xs[starts]

    groups_asc = np.split(order, boundaries)  # views into 'order' when possible

    # reverse to descending by value
    vals_desc = [int(v) for v in vals_asc[::-1]]
    groups_desc = list(reversed(groups_asc))

    return vals_desc, groups_desc


In [9]:
class PairProductStreamer(Iterator[Tuple[int, int, int]]):
    """
    Streams (i, j, score=A[i]*B[j]) in non-increasing score order, without materializing pairs.

    Strategy:
      1) Compress A and B into value-groups (value -> indices).
      2) Enumerate value-pair buckets (a_val, b_val) in descending product order via a max-heap
         using lazy row activation (k-way merge over rows).
      3) For each bucket, emit the full Cartesian product of the two index groups.

    Requirements:
      - A and B must contain only non-negative values. The merge relies on products being
        non-increasing along every row and column of the (a_val, b_val) grid, which is not
        guaranteed once negative values appear. Such input raises ValueError.

    State:
      - The iterator is stateful. A partially consumed stream resumes where it stopped.

    tqdm compatibility:
      - Implements __len__() as total_pairs = A.size * B.size (Python int).
      - Tracks self.emitted_pairs for custom progress reporting if desired.
    """

    def __init__(
        self,
        A: np.ndarray,
        B: np.ndarray,
        *,
        stable_group_order: bool = False,
    ):
        self.A = np.asarray(A)
        self.B = np.asarray(B)

        self.nA = int(self.A.size)
        self.nB = int(self.B.size)
        self.total_pairs: int = int(self.nA) * int(self.nB)

        # Group indices by value (descending values, as Python ints)
        self.a_vals, self.a_groups = _group_indices_by_value_desc(
            self.A, stable=stable_group_order
        )
        self.b_vals, self.b_groups = _group_indices_by_value_desc(
            self.B, stable=stable_group_order
        )

        # Values are sorted descending, so the last one is the minimum.
        for label, vals in (("A", self.a_vals), ("B", self.b_vals)):
            if vals and vals[-1] < 0:
                raise ValueError(
                    f"{label} must contain only non-negative values (min={vals[-1]})"
                )

        self.ua = len(self.a_vals)
        self.ub = len(self.b_vals)

        # Heap for (negative_product, i_group, j_group); heapq is a min-heap.
        self._heap: List[Tuple[int, int, int]] = []
        self._max_row_activated: int = 0

        # Current bucket emission state (all three set/cleared together)
        self._curA: Optional[np.ndarray] = None
        self._curB: Optional[np.ndarray] = None
        self._score: Optional[int] = None
        self._pos_a: int = 0
        self._pos_b: int = 0

        self.emitted_pairs: int = 0

        if self.ua > 0 and self.ub > 0:
            p0 = self.a_vals[0] * self.b_vals[0]
            heapq.heappush(self._heap, (-p0, 0, 0))

    def __len__(self) -> int:
        return self.total_pairs

    def __iter__(self) -> "PairProductStreamer":
        return self

    def _push_next_in_row(self, i: int, j: int) -> None:
        """Push the cell to the right of (i, j) in the value-pair grid, if any."""
        j2 = j + 1
        if j2 < self.ub:
            p = self.a_vals[i] * self.b_vals[j2]
            heapq.heappush(self._heap, (-p, i, j2))

    def _activate_next_row_if_needed(self, i: int, j: int) -> None:
        """Lazily start row i+1 once the head (i, 0) of the newest row is popped."""
        # a_vals[i+1]*b_vals[0] <= a_vals[i]*b_vals[0], so row i+1 cannot be needed earlier.
        if j == 0 and i == self._max_row_activated:
            nxt = i + 1
            if nxt < self.ua:
                self._max_row_activated = nxt
                p = self.a_vals[nxt] * self.b_vals[0]
                heapq.heappush(self._heap, (-p, nxt, 0))

    def _load_next_bucket(self) -> bool:
        """Pop the highest-product value pair and make it the current bucket."""
        if not self._heap:
            return False

        negp, i, j = heapq.heappop(self._heap)
        self._score = -negp
        self._curA = self.a_groups[i]
        self._curB = self.b_groups[j]
        self._pos_a = 0
        self._pos_b = 0

        # Advance the frontier in the value-pair matrix
        self._push_next_in_row(i, j)
        self._activate_next_row_if_needed(i, j)
        return True

    def __next__(self) -> Tuple[int, int, int]:
        if self.total_pairs == 0:
            raise StopIteration

        if self._curA is None and not self._load_next_bucket():
            raise StopIteration

        curA = self._curA
        curB = self._curB
        score = self._score
        if curA is None or curB is None or score is None:
            # _load_next_bucket always sets all three; reaching here is a bug, not end-of-stream.
            raise RuntimeError("PairProductStreamer: inconsistent bucket state")

        # Emit one pair from the current bucket's Cartesian product (row-major).
        a_idx = int(curA[self._pos_a])
        b_idx = int(curB[self._pos_b])

        self._pos_b += 1
        if self._pos_b >= curB.size:
            self._pos_b = 0
            self._pos_a += 1
            if self._pos_a >= curA.size:
                # Bucket exhausted; the next call loads a new one.
                self._curA = None
                self._curB = None
                self._score = None

        self.emitted_pairs += 1
        return a_idx, b_idx, score

    @property
    def progress(self) -> float:
        """Fraction of total pairs emitted (0..1)."""
        if self.total_pairs == 0:
            return 1.0
        return self.emitted_pairs / self.total_pairs


In [10]:
# Generates (x index, z index, xcounts[i] * zcounts[j]) on demand, highest score first,
# rather than materialising all pairs.
stream = PairProductStreamer(A=xcounts, B=zcounts)

In [11]:
# Show the top-scoring candidates. `stream` is stateful, so count pairs locally:
# checking stream.emitted_pairs would stop after a single pair when this cell is re-run.
n_show = 100
for shown, (i, j, s) in enumerate(tqdm(stream, total=len(stream), mininterval=0.25), start=1):
    print(f'x={i+xmin}', f'z={j+zmin}', f'p={s/baseline:.1f}')
    if shown >= n_show:
        break

  0%|          | 0/14062477500009 [00:00<?, ?it/s]

x=1286212 z=1164393 p=449.5
x=-1259094 z=1164393 p=427.8
x=-98126 z=1164393 p=420.5
x=1228102 z=1164393 p=413.2
x=-127999 z=1164393 p=413.2
x=1286212 z=199417 p=399.9
x=-1005031 z=1164393 p=398.8
x=94064 z=1164393 p=398.8
x=449815 z=1164393 p=398.8
x=-1332712 z=1164393 p=398.8
x=765586 z=1164393 p=391.5
x=-981440 z=1164393 p=384.2
x=-683795 z=1164393 p=384.2
x=1826236 z=1164393 p=384.2
x=-1259094 z=199417 p=380.6
x=-127998 z=1164393 p=377.0
x=-98126 z=199417 p=374.1
x=777495 z=1164393 p=369.8
x=1286212 z=-642230 p=368.9
x=1228102 z=199417 p=367.6
x=-127999 z=199417 p=367.6
x=-1545320 z=1164393 p=362.5
x=-1339156 z=1164393 p=362.5
x=-1605782 z=1164393 p=362.5
x=-12927 z=1164393 p=362.5
x=-1036440 z=1164393 p=355.2
x=-1266683 z=1164393 p=355.2
x=-1833407 z=1164393 p=355.2
x=525888 z=1164393 p=355.2
x=-1005031 z=199417 p=354.8
x=94064 z=199417 p=354.8
x=449815 z=199417 p=354.8
x=-1332712 z=199417 p=354.8
x=-1259094 z=-642230 p=351.1
x=765586 z=199417 p=348.3
x=1527834 z=1164393 p=348.0
x=