In [1]:
from process_cube.dat import read_dat_file, Header
import numpy as np
# Read the header file, then use it to load the image file; the result is cast to float32.
header = Header.read_header("self_test_rad.hdr")
img = read_dat_file("self_test_rad.img", header).astype(np.float32)
# Path of the on-disk cache that later cells read as a packed float16 array.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
cache_path = "D:\\SingleLinkage\\diff.cache"


In [None]:
import numpy as np
from tqdm import tqdm


class Tree:
    """Cluster node that remembers which leaf values it covers.

    A node is built from a list of child nodes (``members``); the set of
    covered leaf values is the union of the children's sets. Every node
    created receives a unique, monotonically increasing ``index``.
    """

    # Running counter used to hand out ``index`` values; shared by all nodes.
    count = 0

    def __init__(self, members: list):
        """Create a node whose leaf set is the union of ``members``' leaf sets.

        An empty ``members`` list produces an empty leaf set.
        """
        self.members = members
        # set().union() with no arguments is simply an empty set, so the
        # empty-members case needs no special handling.
        self._contains = set().union(*(member._contains for member in members))
        self.index = Tree.count
        Tree.count += 1

    def join(self, other):
        """Return a new node that has ``self`` and ``other`` as its two members."""
        return Tree([self, other])

    def clear(self):
        """Drop the leaf set of this node.

        After this call membership tests, iteration and ``len()`` raise
        ``TypeError`` because the set has been replaced by ``None``.
        """
        self._contains = None

    def num_of_members(self):
        """Return the number of direct members (children), not leaves."""
        return len(self.members)

    def __iter__(self):
        """Iterate over the covered leaf values (in set order)."""
        yield from self._contains

    def __contains__(self, i):
        """Return True when leaf value ``i`` is covered by this node."""
        return i in self._contains

    def __len__(self):
        """Return the number of covered leaf values."""
        # Ask the set directly instead of copying every element into a tuple.
        return len(self._contains)

class Leaf(Tree):
    """Terminal node that covers exactly one value."""

    def __init__(self, value) -> None:
        """Create a member-less node whose leaf set is ``{value}``."""
        # Initialise as an empty Tree first so the node gets its index,
        # then replace the (empty) leaf set with the single wrapped value.
        super().__init__(members=[])
        self.value = value
        self._contains = {value}


def triangle_offset(i, n):
    """Return the flat offset of row ``i`` in a packed strictly upper-triangular matrix.

    Row ``r`` of an ``n`` x ``n`` strictly upper-triangular matrix stores
    ``n - 1 - r`` entries, so the offset is the number of entries held by
    rows ``0 .. i-1``: ``i * (2n - i - 1) / 2`` (always an integer).
    """
    return i * (2 * n - i - 1) // 2

# Iterate through a strictly upper-triangular matrix line by line.
def iterate_triangle(cache_path, height, batch_size=100):
    """Yield ``(i, row)`` for every row of the packed triangle stored in ``cache_path``.

    The file is read as float16 values, row after row, where row ``i``
    occupies ``triangle_offset(i, height) .. triangle_offset(i + 1, height)``.
    Rows are loaded from disk in batches of ``batch_size`` rows to limit the
    number of reads; each yielded ``row`` is a slice of the current batch buffer.

    Parameters
    ----------
    cache_path : path of the binary cache file.
    height : number of vectors ``n``; rows ``0 .. n-2`` are yielded.
    batch_size : number of rows fetched per read (capped at ``height``).
    """
    dtype = np.dtype(np.float16)
    batch_size = min(batch_size, height)
    partition = np.fromfile(cache_path, dtype=dtype, offset=0, count=triangle_offset(batch_size, height))
    partition_offset = 0
    for i in tqdm(range(height - 1)):
        start = triangle_offset(i, height)
        end = triangle_offset(i + 1, height)
        if (i + 1) % batch_size == 0:
            # Clamp the last row of the batch to the matrix height:
            # triangle_offset() decreases again for rows past `height`, which
            # would make the element count too small (or negative) and
            # truncate the rows of the final batch.
            last_row = min(i + batch_size, height)
            partition = np.fromfile(
                cache_path,
                dtype=dtype,
                offset=start * dtype.itemsize,  # file offset is in bytes
                count=triangle_offset(last_row, height) - start,
            )
            partition_offset = start
        yield i, partition[start - partition_offset:end - partition_offset]

def get_kth_min(cache_path, vector_index, k, height):
    """Return ``(position, value)`` of the k-th smallest entry of one cached row.

    The row ``vector_index`` is read from ``cache_path`` as float16 values.
    ``position`` is the index inside that row (0-based), ``value`` the entry
    stored there.

    When ``k`` is not smaller than the row length there is no such entry;
    ``(k, np.inf)`` is returned instead of letting ``np.argpartition`` raise.
    """
    start = triangle_offset(vector_index, height)
    end = triangle_offset(vector_index + 1, height)
    # float16 -> 2 bytes per element, hence the byte offset start*2.
    vector = np.fromfile(cache_path, dtype=np.float16, count=(end - start), offset=start * 2)
    if k >= len(vector):
        return k, np.inf
    kth_argmin = np.argpartition(vector, k)[k]
    kth_min = vector[kth_argmin]
    return kth_argmin, kth_min

# Flatten the image so that every row of `fimg` is one vector of length img.shape[2].
fimg = img.reshape(-1, img.shape[2])

# Create one Leaf per vector, keyed by the index each Leaf was assigned.
trees = {}
for i in range(fimg.shape[0]):
    leaf = Leaf(i)
    trees[leaf.index] = leaf

# Per-row minimum of the cached values and the position of that minimum inside the row.
# The arrays have one slot per vector, but the loop below only yields rows 0..n-2,
# so the last slot of each array keeps its initial zero.
_current_min = np.zeros(fimg.shape[0],)
_current_min_pairs = np.zeros(fimg.shape[0], dtype=int)

for i, vector in iterate_triangle(cache_path, fimg.shape[0], batch_size=10000):
    _current_min_pairs[i] = np.argmin(vector).astype(int)
    _current_min[i] = vector[_current_min_pairs[i]]

# Persist both arrays (int32 positions, float64 values) so later cells can reload them.
_current_min_pairs.astype(np.int32).tofile("initial_current_min_pairs.cache")
_current_min.astype(np.float64).tofile("initial_current_min.cache")

In [1]:
import numpy as np
from tqdm.notebook import tqdm
from process_cube.dat import read_dat_file, Header
from functools import lru_cache

# Read the header file, then use it to load the image file; the result is cast to float32.
header = Header.read_header("self_test_rad.hdr")
img = read_dat_file("self_test_rad.img", header).astype(np.float32)
# One row per vector: the leading axes are flattened, the last axis (img.shape[2]) is kept.
fimg = img.reshape(-1, img.shape[2])
# Path of the on-disk cache read below as a packed float16 array.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
cache_path = "D:\\SingleLinkage\\diff.cache"

class Tree:
    """Cluster node that remembers which leaf values it covers.

    A node is built from a list of child nodes (``members``); the set of
    covered leaf values is the union of the children's sets. Every node
    created receives a unique, monotonically increasing ``index``.
    """

    # Running counter used to hand out ``index`` values; shared by all nodes.
    count = 0

    def __init__(self, members: list):
        """Create a node whose leaf set is the union of ``members``' leaf sets.

        An empty ``members`` list produces an empty leaf set.
        """
        self.members = members
        # set().union() with no arguments is simply an empty set, so the
        # empty-members case needs no special handling.
        self._contains = set().union(*(member._contains for member in members))
        self.index = Tree.count
        Tree.count += 1

    def join(self, other):
        """Return a new node that has ``self`` and ``other`` as its two members."""
        return Tree([self, other])

    def clear(self):
        """Drop the leaf set of this node.

        After this call membership tests, iteration and ``len()`` raise
        ``TypeError`` because the set has been replaced by ``None``.
        """
        self._contains = None

    def num_of_members(self):
        """Return the number of direct members (children), not leaves."""
        return len(self.members)

    def __iter__(self):
        """Iterate over the covered leaf values (in set order)."""
        yield from self._contains

    def __contains__(self, i):
        """Return True when leaf value ``i`` is covered by this node."""
        return i in self._contains

    def __len__(self):
        """Return the number of covered leaf values."""
        # Ask the set directly instead of copying every element into a tuple.
        return len(self._contains)

class Leaf(Tree):
    """Terminal node that covers exactly one value."""

    def __init__(self, value) -> None:
        """Create a member-less node whose leaf set is ``{value}``."""
        # The base initialiser assigns the index and an empty leaf set;
        # the leaf set is then replaced by the single wrapped value.
        super().__init__(members=[])
        self.value = value
        self._contains = {value}

def get_vector(cache_path, vector_index, height):
    """Read row ``vector_index`` of the packed triangle from ``cache_path``.

    Returns a float16 array with the entries of that row, i.e. the elements
    between ``triangle_offset(vector_index, height)`` and
    ``triangle_offset(vector_index + 1, height)``.

    NOTE(review): an unreachable second ``return`` that recomputed
    ``np.linalg.norm(fimg[vector_index] - fimg[vector_index+1:], axis=-1)``
    was removed; presumably the cache stores exactly those distances.
    """
    start = triangle_offset(vector_index, height)
    end = triangle_offset(vector_index + 1, height)
    count = end - start
    # float16 -> 2 bytes per element, hence the byte offset start*2.
    return np.fromfile(cache_path, dtype=np.float16, count=count, offset=start * 2)

@lru_cache(20000)
def get_argsort(cache_path, vector_index, height):
    """Return ``(order, row)`` for one cached row, memoised per argument triple.

    ``row`` is the array returned by ``get_vector`` and ``order`` its
    ``np.argsort``. Up to 20000 results are kept by the LRU cache.
    """
    # A full sort is used on purpose: argpartition would be faster for a
    # single query, but the gap is negligible for k > 1000 (a common case)
    # and the sorted order can be reused from the cache for every later k.
    row = get_vector(cache_path, vector_index, height)
    order = np.argsort(row)
    return order, row

def get_kth_min(cache_path, vector_index, k, height):
    """Return ``(position, value)`` of the k-th smallest entry of one cached row.

    ``position`` is the index inside the row. When the row has no k-th entry
    (``k`` is at least the row length) ``(k, np.inf)`` is returned.
    """
    order, row = get_argsort(cache_path, vector_index, height)
    if k < len(row):
        position = order[k]
        return position, row[position]
    # Row exhausted: report an infinite value so it never wins a minimum search.
    return k, np.inf

def triangle_offset(i, n):
    return (2*i*n-i**2-i)//2


In [2]:
# Reload the per-row minima written by the earlier cell, dropping the unused last slot
# so that only rows 0..n-2 remain.
current_min_pairs = np.fromfile("initial_current_min_pairs.cache", dtype=np.int32)[:fimg.shape[0] - 1]
current_min = np.fromfile("initial_current_min.cache", dtype=np.float64)[:fimg.shape[0] - 1]
# For every row: rank (0 = smallest) of the candidate currently stored in current_min.
current_k = np.zeros(fimg.shape[0], dtype=int)

# Restart index numbering so that leaf i receives index i.
Tree.count = 0
# NOTE(review): Tree.__init__ reads and writes Tree.count only, so this assignment
# merely creates a separate attribute on Leaf and has no effect on indexing.
Leaf.count = 0
trees = {}
for i in range(fimg.shape[0]):
    leaf = Leaf(i)
    trees[leaf.index] = leaf

# NOTE(review): out_trees and max_len are not used by any cell visible here.
out_trees = set()
    
max_len = 0

# Number of merges performed so far, and the progress bar updated once per merge.
edges_count = 0
pb = tqdm(total=fimg.shape[0])

  0%|          | 0/224000 [00:00<?, ?it/s]

In [3]:
%%prun
def get_root_tree(tree_index, trees):
    """Return the root tree currently reachable from ``tree_index``.

    ``trees`` maps an index to the tree that index was last merged into.
    Starting at ``trees[tree_index]`` the chain ``trees[tree.index]`` is
    followed until a tree maps to itself (the root).

    Every key visited on the way is re-pointed directly at the root (path
    compression), so later lookups through any of them take a single step
    instead of re-walking the whole chain.

    Raises ``KeyError`` when an index on the chain is missing from ``trees``.
    """
    visited = [tree_index]
    tree = trees[tree_index]
    # Identity comparison: a root is the very object stored under its own index.
    while trees[tree.index] is not tree:
        visited.append(tree.index)
        tree = trees[tree.index]
    for index in visited:
        trees[index] = tree
    return tree

try:
    # Repeatedly take the globally smallest remaining candidate pair and merge the
    # two trees it connects, until n-1 merges have been made.
    while edges_count < fimg.shape[0]-1:
        vector_with_min_pair = current_min.argmin()
        # Positions inside a row are relative to the first column right of the diagonal.
        other_vector = current_min_pairs[vector_with_min_pair] + 1 + vector_with_min_pair
        tree1 = get_root_tree(vector_with_min_pair, trees)
        tree2 = get_root_tree(other_vector, trees)
        if other_vector in tree1 or vector_with_min_pair in tree2:
            # Both vectors must agree on sharing a tree; anything else means the
            # bookkeeping in `trees` is corrupt.
            if not (other_vector in tree1 and vector_with_min_pair in tree2):
                print(trees[other_vector], trees[vector_with_min_pair])
                raise RuntimeError(
                    "inconsistent tree membership for vectors "
                    f"{vector_with_min_pair} and {other_vector}"
                )

            # The pair is already connected: advance this row to its next-smallest
            # candidate until it points outside the current tree.
            while other_vector in tree1:
                current_k[vector_with_min_pair] += 1
                kth_argmin, kth_min = get_kth_min(cache_path, vector_with_min_pair, current_k[vector_with_min_pair], fimg.shape[0])
                current_min[vector_with_min_pair] = kth_min
                current_min_pairs[vector_with_min_pair] = kth_argmin
                other_vector = current_min_pairs[vector_with_min_pair] + 1 + vector_with_min_pair
            continue


        joined = tree1.join(tree2)
        # The merged node now carries the combined leaf set; drop the old ones.
        tree2.clear()
        tree1.clear()
        trees[tree2.index] = joined
        trees[tree1.index] = joined
        trees[joined.index] = joined
        edges_count += 1

        # The pair just used is consumed: move this row on to its next candidate.
        current_k[vector_with_min_pair] += 1
        kth_argmin, kth_min = get_kth_min(cache_path, vector_with_min_pair, current_k[vector_with_min_pair], fimg.shape[0])
        current_min[vector_with_min_pair] = kth_min
        current_min_pairs[vector_with_min_pair] = kth_argmin
        pb.update()
except KeyboardInterrupt:
    # Deliberate: allow interrupting the long run while keeping the state built so far.
    pass
edges_count

 

         2245025924 function calls in 18081.568 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   674836 5827.382    0.009 5827.382    0.009 {method 'argsort' of 'numpy.ndarray' objects}
   674837 5548.974    0.008 5551.239    0.008 {built-in method numpy.fromfile}
        1 3281.885 3281.885 18081.566 18081.566 <string>:1(<module>)
  1566664 1553.007    0.001 1553.007    0.001 <string>:1(get_root_tree)
730457591  849.731    0.000 12313.941    0.000 1520836299.py:59(get_kth_min)
732942918  379.563    0.000  379.563    0.000 1520836299.py:33(__contains__)
   216000  232.110    0.001  232.110    0.001 {method 'union' of 'set' objects}
   432000  133.316    0.000  133.316    0.000 1520836299.py:23(clear)
   783332  127.750    0.000  127.750    0.000 {method 'argmin' of 'numpy.ndarray' objects}
731943366   65.607    0.000   65.607    0.000 {built-in method builtins.len}
   141086    9.210    0.000    9.210    0.000 socket.py:5

In [74]:
from dash import Dash, html
import dash_cytoscape as cyto

def create_node(i):
    """Build a Cytoscape node element whose id and label are both ``i``."""
    payload = dict(id=i, label=i)
    return dict(data=payload)

def create_edge(i,j):
    """Build a Cytoscape edge element pointing from ``i`` to ``j``."""
    endpoints = dict(source=i, target=j)
    return dict(data=endpoints)

# Presumably required for the non-default 'dagre' layout used below — TODO confirm.
cyto.load_extra_layouts()
app = Dash(__name__)

# One node per key of `trees`, plus one edge from every key to the index of the tree
# it currently maps to (entries that map to themselves get no edge).
app.layout = html.Div([
    html.P("Dash Cytoscape:"),
    cyto.Cytoscape(
        id='cytoscape',
        elements=[
            *(create_node(i) for i in trees),
            *(create_edge(i,tree.index) for i,tree in trees.items() if i != tree.index )
        ],
        layout={'name': 'dagre',"rankDir": 'LR',},
        style={'width': '100%', 'height': '1000px', 'border': '3px solid black'}
    )
])


# Starts the Dash server with debug mode enabled.
# NOTE(review): with the full `trees` dict (one entry per vector) the element list is very large.
app.run_server(debug=True)

In [None]:
# Look up, for every row, the cached value at the position stored in current_min_pairs,
# reading the cache in windows of `batch_size` rows.
# NOTE(review): `diff_cache` is not defined in any cell visible here; it is presumably
# an array view (e.g. a memmap) of the cache file.
min_value = np.zeros(fimg.shape[0]-1)
batch_size = 10
diff_cache_part = diff_cache[0: triangle_offset(min(batch_size, fimg.shape[0]), fimg.shape[0])]
partition_offset = 0
for i in tqdm(range(fimg.shape[0]-1)):
    start = triangle_offset(i, fimg.shape[0])
    end = triangle_offset(i+1, fimg.shape[0])
    if (i+1) % batch_size == 0:
        # Clamp the window end to the matrix height: triangle_offset() decreases again
        # past it, which would cut the last window short.
        window_end = triangle_offset(min(i+batch_size, fimg.shape[0]), fimg.shape[0])
        diff_cache_part = diff_cache[start: window_end]
        partition_offset = start
    min_value[i] = diff_cache_part[start - partition_offset + current_min_pairs[i]]
min_value

In [None]:
# Sanity check: report the first row whose end offset is not strictly greater than
# its start offset (i.e. an empty or inverted row range), then stop.
for i in range(fimg.shape[0]-1):
    start = triangle_offset(i, fimg.shape[0])
    end = triangle_offset(i+1, fimg.shape[0])
    if end <= start:
        print(i,start,end)
        break

In [None]:
import os
# File size in bytes divided by 2, i.e. the number of 2-byte (float16) elements in the cache.
os.stat(cache_path).st_size/2

In [None]:
# Inspect a fixed slice of the cache together with its overall shape.
# NOTE(review): `diff_cache` is not defined in any cell visible here.
diff_cache[25087664554:25087665222], diff_cache.shape

In [None]:
%%prun


def get_distance(p1,p2,img, dist=np.linalg.norm):
    """Return ``dist(img[p2] - img[p1], axis=-1)``.

    ``p1`` and ``p2`` are anything usable as an index into ``img`` (an int or
    a slice); with the default ``dist`` this is the Euclidean norm taken along
    the last axis of the difference.
    """
    return dist(img[p2]-img[p1], axis=-1)

def get_min_dist_pair(img, max_vectors=1000):
    """Find the pair of pixels with the smallest distance between their vectors.

    ``img`` is reshaped to ``(-1, img.shape[-1])``; every vector is compared
    with all vectors that follow it in that flattened order.

    Parameters
    ----------
    img : array whose first two axes index pixels and whose last axis is the vector.
    max_vectors : only the first ``max_vectors`` vectors are used as the left
        element of a pair (default 1000, the cap that was previously
        hard-coded); pass ``None`` to scan all of them.

    Returns
    -------
    ``((row1, col1), (row2, col2))`` index tuples into ``img.shape[:2]``, or
    ``None`` when there are fewer than two vectors (or the cap is 0).
    """
    original_shape = img.shape
    img = img.reshape(-1,img.shape[-1])

    # The last vector has no successors, so it is never a left element;
    # including it would call argmin on an empty array.
    n_left = img.shape[0] - 1
    if max_vectors is not None:
        n_left = min(n_left, max_vectors)

    min_distance = np.inf
    min_pair = None
    for vector in range(n_left):
        dist = get_distance(vector, slice(1+vector,None), img)
        arg_min = np.argmin(dist)

        if dist[arg_min] >= min_distance:
            continue
        # `dist` starts at vector+1, so position 0 corresponds to vector+1.
        other = vector + 1 + arg_min
        min_pair = (np.unravel_index(vector, original_shape[:2]), np.unravel_index(other, original_shape[:2]))
        min_distance = dist[arg_min]
    return min_pair

def find_closest_clusters(trees, img):
    """Return the two trees containing the closest pixel pair of ``img``.

    The pair comes from ``get_min_dist_pair``. All of ``trees`` is scanned;
    if several trees contain a point, the last one encountered wins. A slot
    stays ``None`` when no tree contains the corresponding point.
    """
    first_point, second_point = get_min_dist_pair(img)
    first_tree = second_tree = None
    for candidate in trees:
        first_tree = candidate if first_point in candidate else first_tree
        second_tree = candidate if second_point in candidate else second_tree
    return first_tree, second_tree

def single_linkage_heir(img):
    """Prototype of hierarchical single-linkage clustering over the pixels of ``img``.

    Creates one ``Leaf`` per ``(row, col)`` pixel, then performs a single
    merge step: the two closest clusters are looked up, their indices are
    printed, they are joined and removed from the working dict.

    NOTE(review): work in progress — only one step is executed, the joined
    tree is never stored back, and ``memory`` is unused. Returns ``None``.
    """
    leaves = (Leaf((row, col)) for row in range(img.shape[0]) for col in range(img.shape[1]))
    trees = {leaf.index: leaf for leaf in leaves}
    memory = np.ones(img.shape[:2])
    # The original loop body ended in an unconditional break, so at most one
    # iteration ever runs.
    if len(trees) > 1:
        tree1, tree2 = find_closest_clusters(trees.values(), img)
        print(tree1.index, tree2.index)
        tree_joineed = tree1.join(tree2)
        trees.pop(tree1.index)
        trees.pop(tree2.index)

# Run the prototype on the loaded image (profiled by the %%prun magic of this cell).
single_linkage_heir(img)

In [None]:
# Small 4x4 example matrix holding the values 0..15.
A = np.arange(16).reshape(4,4)

# Upper triangle of A; entries below the main diagonal are zeroed.
# Note that np.triu keeps the diagonal by default, unlike the strictly
# upper-triangular layout that triangle_offset assumes.
Au = np.triu(A)
Au

In [None]:
# (n-1) + (n-2) + ... + 1 = n*(n-1)/2  -> number of entries in the strictly upper triangle

In [None]:
# 280 * 800 = 224000, which matches the progress-bar total shown above (one entry per vector).
280*800

#### Note:
Holding the full pairwise-distance matrix in RAM is practically impossible:
(280*800)^2 = 50176000000 entries, which is about 46 GB of RAM even at single-byte precision.