In [1]:
import numpy as np
import treeswift
from flowtree import Quadtree, Distribution, QuadNode, QuadLeaf

d = 2   # number of coordinates per sampled point (second axis of `points`)
N = 16  # number of sampled points (first axis of `points`)
s = 4   # second argument of uniform_distribution — presumably the support size; TODO confirm

# Pin the seed so that Restart Kernel -> Run All reproduces the same point cloud.
# NOTE(review): this only seeds NumPy's global generator. If Distribution draws
# its randomness from another source, seed that one as well — TODO confirm.
SEED = 0
np.random.seed(SEED)

points = np.random.normal(size=(N, d))
quadtree = Quadtree(points)
mu = Distribution.uniform_distribution(N, s)
nu = Distribution.uniform_distribution(N, s)

  from .autonotebook import tqdm as notebook_tqdm


Normalizing distribution's mass from 1.790162537072419 to 1
Normalizing distribution's mass from 2.7773390393441915 to 1


In [2]:
from typing import Dict, Tuple, Union
from treeswift import Tree
from copy import deepcopy
from collections import defaultdict
from numpy import ndarray as Array

# Numerical tolerance for leftover mass: resolve_demand removes a demand entry
# once a transaction leaves less than EPS on it.
EPS = 1e-12

def compute_optimal_flow(
    qt: Quadtree,
    mu: Distribution,
    nu: Distribution,
) -> Dict[Tuple[int, int], float]:
    """Load both distributions onto the quadtree, then resolve the demand tree.

    This is a two-step pipeline: ``compute_flow_tree`` writes the masses of
    ``mu`` and ``nu`` onto the leaves of ``qt``, and ``resolve_dtree`` turns the
    resulting per-node demand into a flow.

    Args:
        qt: Quadtree whose nodes hold the demand; it is modified in place.
        mu: Distribution stored in slot 0 of the leaf demand.
        nu: Distribution stored in slot 1 of the leaf demand.

    Returns:
        The mapping produced by ``resolve_dtree(qt)``: label pair ``(i, j)`` to
        the amount assigned to that pair.
    """
    compute_flow_tree(qt, mu, nu)
    return resolve_dtree(qt)

def compute_flow_tree(
    qt: Quadtree,
    mu: Distribution,
    nu: Distribution,
) -> None:
    """Write the mass of ``mu`` and ``nu`` onto the leaves of ``qt``.

    For every leaf yielded by ``qt.traverse_leaves()``, the leaf's label is
    looked up in each distribution. When the label is contained in ``mu`` the
    value ``mu(label)`` is stored in ``leaf.demand[0][label]``; likewise
    ``nu(label)`` goes to ``leaf.demand[1][label]``.

    Args:
        qt: Quadtree whose leaves are updated in place.
        mu: Distribution feeding slot 0 of each leaf's demand.
        nu: Distribution feeding slot 1 of each leaf's demand.
    """
    for leaf in qt.traverse_leaves():
        label = leaf.label
        # Slot 0 belongs to mu, slot 1 to nu; enumerate keeps that pairing.
        for slot, dist in enumerate((mu, nu)):
            if label in dist:
                leaf.demand[slot][label] = dist(label)

def find_transaction(demand) -> Tuple:
    """Select the smallest entry on each side of a demand pair.

    Args:
        demand: Pair ``(demand[0], demand[1])`` of mappings from label to amount.

    Returns:
        A tuple ``(amount, i, j)``. ``i`` and ``j`` are the labels holding the
        smallest amount in ``demand[0]`` and ``demand[1]`` respectively (the
        first one in iteration order wins on ties), and ``amount`` is the
        smaller of those two amounts. An empty side is reported with the label
        ``-1`` and does not constrain ``amount``; if both sides are empty the
        amount is ``inf``.
    """
    def _smallest(side):
        # Return (label, amount) of the minimum entry of one side.
        if not side:
            # Keep the historical -1 label for an empty side. Infinity replaces
            # the former hard-coded bound of 2, which silently produced the
            # label -1 whenever every amount on a side was 2 or larger.
            return -1, float("inf")
        return min(side.items(), key=lambda entry: entry[1])

    am0, m0 = _smallest(demand[0])
    am1, m1 = _smallest(demand[1])
    return (min(m0, m1), am0, am1)

def resolve_demand(
    of: Dict[Tuple[int, int], float],
    node: Union[QuadLeaf, QuadNode],
) -> Dict[Tuple[int, int], float]:
    """Match opposing demand inside ``node`` and pass the remainder to its parent.

    While both sides of ``node.demand`` are non-empty, the transaction chosen
    by ``find_transaction`` is subtracted from both sides and added to
    ``of[(i, j)]``; an entry is dropped once less than ``EPS`` remains on it.
    When one side is exhausted, whatever is left on the other side is added to
    the same side of ``node.parent.demand`` and removed from ``node``.

    Args:
        of: Flow accumulated so far, keyed by ``(i, j)`` label pairs. It is
            updated in place; a plain ``dict`` works as well as a
            ``defaultdict``.
        node: Tree node whose ``demand`` pair is resolved in place.

    Returns:
        The same ``of`` object that was passed in.
    """
    # A loop instead of one recursive call per transaction: the number of
    # transactions at a node grows with the size of the supports and must not
    # be limited by the interpreter's recursion depth. The printed trace is the
    # same as before (one "Opening" line per pass).
    while True:
        print('Opening new node with demand\n', node.demand)
        if not (node.demand[0] and node.demand[1]):
            break
        amount, i, j = find_transaction(node.demand)
        print('Transcation found:', amount, i, j)
        node.demand[0][i] -= amount
        if node.demand[0][i] < EPS:
            node.demand[0].pop(i)
        node.demand[1][j] -= amount
        if node.demand[1][j] < EPS:
            node.demand[1].pop(j)
        # .get keeps this working when `of` is a plain dict without (i, j) yet.
        of[(i, j)] = of.get((i, j), 0) + amount

    parent = getattr(node, "parent", None)
    if parent is None:
        # No parent to hand the remainder to (presumably the root — TODO
        # confirm against Quadtree); unmatched demand stays on this node.
        return of

    for side in (0, 1):
        leftover = node.demand[side]
        inherited = parent.demand[side]
        for label, amount in leftover.items():
            inherited[label] = inherited.get(label, 0) + amount
        # Move rather than copy: demand left behind here would be pushed up a
        # second time if the same tree were resolved again.
        leftover.clear()
    return of

def resolve_dtree(
    dtree: Quadtree, # dtree = demand tree
) -> Dict[Tuple[int, int], float]:
    """Run ``resolve_demand`` on every node of the demand tree.

    Nodes are visited in the order yielded by ``dtree.traverse_postorder()``,
    and a single flow mapping is threaded through all calls.

    Args:
        dtree: Demand tree whose nodes already carry their ``demand`` pairs.

    Returns:
        A ``defaultdict`` (missing keys read as 0) mapping label pairs
        ``(i, j)`` to the amount accumulated for that pair.
    """
    flow: Dict[Tuple[int, int], float] = defaultdict(lambda : 0)
    for visited in dtree.traverse_postorder():
        flow = resolve_demand(flow, visited)
    return flow

In [3]:
# Compute the flow between mu and nu on the quadtree built in the first cell.
# The call fills and resolves the nodes' demand in place on `quadtree`; the
# returned mapping is displayed as this cell's output.
compute_optimal_flow(quadtree, mu, nu)

Opening new node with demand
 ({}, {0: 0.36004598382753605})
Opening new node with demand
 ({}, {})
Opening new node with demand
 ({}, {0: 0.36004598382753605})
Opening new node with demand
 ({3: 0.3261998299426699}, {})
Opening new node with demand
 ({}, {6: 0.30477345706114845})
Opening new node with demand
 ({}, {})
Opening new node with demand
 ({}, {6: 0.30477345706114845})
Opening new node with demand
 ({11: 0.37849961466465515}, {})
Opening new node with demand
 ({3: 0.3261998299426699, 11: 0.37849961466465515}, {0: 0.36004598382753605, 6: 0.30477345706114845})
Transcation found: 0.30477345706114845 3 6
Opening new node with demand
 ({3: 0.021426372881521438, 11: 0.37849961466465515}, {0: 0.36004598382753605})
Transcation found: 0.021426372881521438 3 0
Opening new node with demand
 ({11: 0.37849961466465515}, {0: 0.3386196109460146})
Transcation found: 0.3386196109460146 11 0
Opening new node with demand
 ({11: 0.03988000371864053}, {})
Opening new node with demand
 ({}, {1: 0.

defaultdict(<function __main__.resolve_dtree.<locals>.<lambda>()>,
            {(3, 6): 0.30477345706114845,
             (3, 0): 0.021426372881521438,
             (11, 0): 0.3386196109460146,
             (8, 4): 0.01021637063086835,
             (11, 4): 0.03988000371864053,
             (5, 4): 0.09891468730928846,
             (5, 1): 0.18616949745251807})