In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import os
import sys
import logging
from itertools import chain

import requests
from tqdm import tqdm, tqdm_notebook, tnrange
#tqdm = tqdm_notebook

import vigra
import numpy as np
import pandas as pd

from dvidutils import LabelMapper
from libdvid import DVIDNodeService

from neuclease.dvid import *
from neuclease.util import Timer
from neuclease.misc import find_best_plane

In [3]:
from DVIDSparkServices.spark_launch_scripts.janelia_lsf.lsf_utils import get_hostgraph_url

In [4]:
# Send all log records to the notebook's stdout, discarding any handlers
# that were previously attached to the root logger.
root_logger = logging.getLogger()
handler = logging.StreamHandler(sys.stdout)
root_logger.handlers = [handler]
root_logger.setLevel(logging.INFO)

# The kafka client is noisy at INFO; only show its warnings and errors.
logging.getLogger('kafka').setLevel(logging.WARNING)

In [5]:
cd /nrs/flyem/bergs/complete-ffn-agglo/

/nrs/flyem/bergs/complete-ffn-agglo


In [6]:
!uname -n

h10u12.int.janelia.org


In [7]:
pwd

'/nrs/flyem/bergs/complete-ffn-agglo'

In [8]:
sc

In [9]:
def closest_approach(sv_vol, id_a, id_b):
    """
    Given a segmentation volume and two label IDs which it contains,
    find the two coordinates within id_a and id_b, respectively,
    which mark the two objects' closest approach, i.e. where the objects
    come closest to touching, even if they don't actually touch.

    Args:
        sv_vol: Label volume (e.g. a block of supervoxel IDs).
        id_a, id_b: Two labels, both of which must be present in sv_vol.

    Returns:
        (coord_a, coord_b), where coord_a is a tuple of voxel indices lying
        within id_a, and coord_b is an integer ndarray of indices pointing at
        the id_b voxel nearest to coord_a.

    Raises:
        ValueError: if id_a or id_b does not appear in sv_vol.
    """
    mask_a = (sv_vol == id_a)
    mask_b = (sv_vol == id_b)

    # Without these checks, np.argmin() over an all-inf array silently returns
    # index 0, producing a "closest point" that isn't in id_a at all.
    if not mask_a.any():
        raise ValueError(f"Label {id_a} is not present in the volume")
    if not mask_b.any():
        raise ValueError(f"Label {id_b} is not present in the volume")

    # For all voxels, find the shortest vector toward id_b
    to_b_vectors = vigra.filters.vectorDistanceTransform(mask_b.astype(np.uint32))

    # Magnitude of those vectors == distance to id_b
    to_b_distances = np.linalg.norm(to_b_vectors, axis=-1)

    # We're only interested in the voxels within id_a;
    # everything else is infinite distance
    to_b_distances[~mask_a] = np.inf

    # Find the point within id_a with the smallest vector
    point_a = np.unravel_index(np.argmin(to_b_distances), to_b_distances.shape)

    # Its closest point in id_b is indicated by the corresponding vector.
    # Round rather than truncate, so floating-point error in the vector
    # can't shift the result by one voxel.
    point_b = np.rint(point_a + to_b_vectors[point_a]).astype(int)

    return (point_a, point_b)

In [10]:
def split_events_to_mapping(split_events, leaves_only=False):
    """
    Convert the given split_events,
    into a mapping, from all split fragment supervoxel IDs to their ROOT supervoxel ID,
    i.e. the supervoxel from which they came originally.

    Args:
        split_events:
            As produced by fetch_supervoxel_splits()

        leaves_only:
            If True, do not include intermediate supervoxels in the mapping;
            only include fragment IDs that have not been further split,
            i.e. they still exist in the volume.

    Returns:
        pd.Series (uint64), where index is fragment ID (named 'fragment_sv'),
        data is root ID (series named 'root_sv').
        If there are no split events, the Series is empty but has the same names.
    """
    if len(split_events) == 0:
        # Return an empty Series (not an ndarray) so the return type
        # matches the non-empty case.
        empty = np.zeros(0, np.uint64)
        return pd.Series(data=empty, index=pd.Index(empty, name='fragment_sv'), name='root_sv')

    split_tables = [np.asarray(t, np.uint64) for t in split_events.values()]
    split_table = np.concatenate(split_tables)

    old_svs = split_table[:, SplitEvent._fields.index('old')]
    remain_fragment_svs = split_table[:, SplitEvent._fields.index('remain')]
    split_fragment_svs = split_table[:, SplitEvent._fields.index('split')]

    if leaves_only:
        # Leaves are fragments that were never themselves split (never 'old').
        leaf_fragment_svs = (set(remain_fragment_svs) | set(split_fragment_svs)) - set(old_svs)
        fragment_svs = np.fromiter(leaf_fragment_svs, np.uint64)
    else:
        fragment_svs = np.concatenate((remain_fragment_svs, split_fragment_svs))

    g = split_events_to_graph(split_events)
    root_svs = np.fromiter((find_root(g, sv) for sv in fragment_svs), np.uint64, len(fragment_svs))

    return pd.Series(data=root_svs, index=pd.Index(fragment_svs, name='fragment_sv'), name='root_sv')


In [11]:
@sanitize_server
def expand_uuid(server, uuid, repo_uuid=None):
    """
    Expand a (possibly abbreviated) uuid into the full uuid of a node in the repo.

    The repo's DAG nodes (from fetch_repo_info) are filtered with uuids_match().

    Args:
        server: DVID server.
        uuid: Abbreviated or full node uuid to expand.
        repo_uuid: Any uuid of the repo whose DAG should be searched.
            If not given (or empty), ``uuid`` itself is used.

    Returns:
        The single full uuid that matches.

    Raises:
        RuntimeError: if no node uuid matches, or more than one does.
    """
    if not repo_uuid:
        repo_uuid = uuid

    dag_nodes = fetch_repo_info(server, repo_uuid)["DAG"]["Nodes"]
    matching_uuids = [full_uuid for full_uuid in dag_nodes.keys()
                      if uuids_match(uuid, full_uuid)]

    if not matching_uuids:
        raise RuntimeError(f"No matching uuid for '{uuid}'")
    if len(matching_uuids) > 1:
        raise RuntimeError(f"Multiple ({len(matching_uuids)}) uuids match '{uuid}': {matching_uuids}")

    (only_match,) = matching_uuids
    return only_match


### Hostgraph URLs

In [12]:
# LSF host-graph (resource usage) pages for the job running this notebook
# and for the Spark cluster job, looked up by their LSF job IDs.
print("This notebook:")
print(get_hostgraph_url(os.environ["LSB_JOBID"]))
print("Cluster:")
print(get_hostgraph_url(os.environ["MASTER_BJOB_ID"]))

This notebook:
http://lsf-rtm/cacti/plugins/grid/grid_bjobs.php?action=viewjob&tab=hostgraph&clusterid=1&indexid=0&jobid=44106114&submit_time=1532486124
Cluster:
http://lsf-rtm/cacti/plugins/grid/grid_bjobs.php?action=viewjob&tab=hostgraph&clusterid=1&indexid=0&jobid=44106116&submit_time=1532486137


### UUIDs

In [13]:
# Every node below uses the same server and the same 'segmentation' instance;
# only the node uuid differs.
DVID_SERVER = 'emdata3:8900'
SEGMENTATION_INSTANCE = 'segmentation'

# The starting agglo
initial_agglo = DvidInstanceInfo(DVID_SERVER, 'ac901', SEGMENTATION_INSTANCE)

# The uuid used when loading the neo4j instance (for 'important bodies')
neo4j_reference = DvidInstanceInfo(DVID_SERVER, '52f9', SEGMENTATION_INSTANCE)

# The last supervoxel splits: one node past the neo4j node
analysis_node = DvidInstanceInfo(DVID_SERVER, '662e', SEGMENTATION_INSTANCE)

# We won't be using this...
current_master = DvidInstanceInfo(DVID_SERVER, 'f545', SEGMENTATION_INSTANCE)

### Load split SVs

In [14]:
# Split history at the analysis node, read from the kafka log,
# and the fragment -> root supervoxel mapping derived from it.
split_events = fetch_supervoxel_splits(analysis_node, 'kafka')
split_mapping = split_events_to_mapping(split_events)

# Supervoxels that were split ("retired"), kept as a set for fast membership tests.
leaf_fragment_svs, retired_svs = fetch_supervoxel_fragments(analysis_node, 'kafka')
retired_svs = set(retired_svs)

Reading kafka messages from kafka.int.janelia.org:9092 for emdata3:8900 / 662e / segmentation
Reading 166737 kafka messages took 7.26911473274231 seconds
Reading kafka messages from kafka.int.janelia.org:9092 for emdata3:8900 / 662e / segmentation
Reading 166737 kafka messages took 6.93674635887146 seconds


### Load neo4j-defined important bodies; append final splits

In [16]:
# This list was generated from node 52f9
# NOTE(review): the `ls` at the end of this notebook lists
# 'bodies-0.5-from-neuprint-52f9.csv' but not this 'including-psds' file -- confirm the path.
important_bodies_path = '/nrs/flyem/bergs/complete-ffn-agglo/bodies-0.5-including-psds-from-neuprint-52f9.csv'
important_bodies = set(
    pd.read_csv(important_bodies_path, header=0, usecols=['bodyid'], dtype=np.uint64)['bodyid']
)

# Bodies touched by splits at the analysis node (after neo4j was loaded):
# both the split target and the newly created label.
msgs = read_kafka_messages(analysis_node, 'split', 'leaf-only')
final_new_bodies = {body for msg in msgs for body in (msg['Target'], msg['NewLabel'])}

# Append final set
important_bodies |= final_new_bodies

Reading kafka messages from kafka.int.janelia.org:9092 for emdata3:8900 / 662e / segmentation
Reading 166737 kafka messages took 6.982581615447998 seconds


### CSV definitions

In [17]:
# Columns of the (header-less) edge CSVs, in file order:
#
# label_a, label_b -- the two supervoxel IDs (read here as 'id_a' and 'id_b',
#                     for consistency with our other code)
# xa, ya, za -- point from which segmentation of 'a' was started, 8 nm coordinates
# xb, yb, zb -- point from which segmentation of 'b' was started, 8 nm coordinates
# caa, cab, cba, cbb -- cXY means: fraction of voxels from the original segment Y recovered when seeding from X
# iou -- Jaccard index of the two local segmentations
# da, db -- dX means: fraction of voxels that changed value from >0.8 to <0.5 when segmenting & seeding from X;
#                     the higher this value is, the more "internally inconsistent" the segmentation resolution
#                     potentially is; higher thresholds for iou, cXY might be warranted
#
# Key order matters: it is used as the CSV column order.
csv_dtypes = dict(
    id_a=np.uint64, id_b=np.uint64,
    xa=np.int32, ya=np.int32, za=np.int32,
    xb=np.int32, yb=np.int32, zb=np.int32,
    caa=np.float32, cab=np.float32, cba=np.float32, cbb=np.float32,
    iou=np.float32,
    da=np.float32, db=np.float32,
)

TOTAL_EDGE_COUNT = 0

### Convert to numpy

In [54]:
# Absolute paths of the 100 CSV shards for each resolution (in nm).
csv_paths = {
    res: [os.path.abspath(f'{res}nm/data-000{i:02d}-of-00100.csv') for i in range(100)]
    for res in [32, 16, 8]
}

def save_as_npy(resolution, csv_path):
    """
    Convert the given CSV edge table to .npy format,
    and append a column for 'resolution' in the process.

    The .npy file is written next to the CSV, with the same base name.

    Args:
        resolution: Resolution (nm) of this table; stored as uint8 in every row.
        csv_path: Path to a header-less CSV whose columns are described by csv_dtypes.

    Returns:
        The number of rows (edges) in the table.
    """
    npy_path = os.path.splitext(csv_path)[0] + '.npy'
    df = pd.read_csv(csv_path, header=None, names=list(csv_dtypes.keys()), dtype=csv_dtypes)

    # Use the 'resolution' argument. The previous code read the notebook's global
    # loop variable 'res', which only happened to hold the right value.
    df['resolution'] = np.uint8(resolution)
    np.save(npy_path, df.to_records(index=False))
    return len(df)

# Convert every CSV shard to .npy on the cluster, tallying the total edge count.
TOTAL_EDGE_COUNT = 0
for res, paths in csv_paths.items():
    with Timer(f"Converting {res}nm files to npy"):
        # Bind 'res' as a default argument. A plain closure looks up the loop
        # variable whenever the lambda is serialized or run, not when it is created.
        counts = sc.parallelize(paths).map(lambda p, res=res: save_as_npy(res, p)).collect()
        TOTAL_EDGE_COUNT += sum(counts)

Converting 32nm files to npy...
Converting 32nm files to npy took 0:00:05.185719
Converting 16nm files to npy...
Converting 16nm files to npy took 0:00:20.241233
Converting 8nm files to npy...
Converting 8nm files to npy took 0:00:38.211370


In [17]:
# Absolute paths of the original .npy tables: all 32nm, then 16nm, then 8nm
# (each group sorted by name).
orig_npy_paths = [
    os.path.abspath(path)
    for res in (32, 16, 8)
    for path in sorted(glob.glob(f'{res}nm/*.npy'))
]

In [18]:
# Fallback when the conversion cell above was not run in this kernel:
# count the edges in the existing .npy files instead.
if TOTAL_EDGE_COUNT == 0:
    def npy_len(path):
        # mmap_mode='r' only parses the .npy header to learn the shape;
        # the (large) array data is never read just to take its length.
        return len(np.load(path, mmap_mode='r'))
    paths = list(map(os.path.abspath, glob.glob('*nm/*.npy')))
    TOTAL_EDGE_COUNT = sc.parallelize(paths).map(npy_len).sum()
print(f"TOTAL_EDGE_COUNT: {TOTAL_EDGE_COUNT}")

KeyboardInterrupt: 

In [20]:
# Estimate the in-memory size of the full edge table from one shard's record dtype.
# mmap_mode='r' reads only the header; the shard's data is not loaded.
first = np.load('32nm/data-00000-of-00100.npy', mmap_mode='r')
TABLE_DTYPE = first.dtype
EDGE_NBYTES = first.dtype.itemsize  # bytes per record; works even if the shard is empty
TOTAL_NBYTES = EDGE_NBYTES * TOTAL_EDGE_COUNT
del first
print(f"Total GB: {TOTAL_NBYTES / 1e9}")

Total GB: 113.233191336


In [18]:
##%time combined_edge_table = np.fromiter(chain(*(np.load(p) for p in all_npy_files)), TABLE_DTYPE, TOTAL_EDGE_COUNT)

In [35]:
# Repair coordinates for split supervoxels.
# Output directories for the repaired tables, one per resolution.
os.makedirs(os.path.join('split-coords-fixed', '32nm'), exist_ok=True)
os.makedirs(os.path.join('split-coords-fixed', '16nm'), exist_ok=True)
os.makedirs(os.path.join('split-coords-fixed', '8nm'), exist_ok=True)

def _repaired_points(row):
    """
    Choose new (z, y, x) points for one edge-table row that mentions a split supervoxel.

    Searches boxes of radius 64, 128, then 256 voxels around the midpoint of the
    row's original points, in initial_agglo's supervoxels. It returns as soon as
    one of these succeeds:
      1. find_best_plane() reports touching points for id_a/id_b, or
      2. both ids appear in the box, in which case their closest approach is used.

    If neither succeeds at any radius, the original points are kept when they
    still land on id_a and id_b respectively. Otherwise they are returned
    negated, which marks the row as unfixable.

    Returns:
        ndarray of shape (2, 3): [point_a, point_b], each in (z, y, x) order.
    """
    coord_a = np.array((row.za, row.ya, row.xa))
    coord_b = np.array((row.zb, row.yb, row.xb))
    avg_coord = (coord_a + coord_b) // 2

    for search_radius in [64, 128, 256]:
        box = np.array(( avg_coord - search_radius,
                         avg_coord + search_radius ))

        sv_vol = fetch_labelarray_voxels(initial_agglo, box, supervoxels=True)

        # Try finding a touch point (find_best_plane signals failure with all -1).
        touching_points = np.array(find_best_plane(sv_vol, row.id_a, row.id_b))
        if not (touching_points == -1).all():
            return touching_points + box[0]

        # Try finding "closest approach" instead.
        if (row.id_a in sv_vol.flat) and (row.id_b in sv_vol.flat):
            # Both ids are present in the volume, but they are not touching.
            # Find the points that minimally separate them.
            point_a, point_b = closest_approach(sv_vol, row.id_a, row.id_b)
            return np.array((point_a, point_b)) + box[0]

    # The bodies are so far apart that we couldn't find a "closest approach".
    # If the original points are at least on the correct supervoxels,
    # settle for that.
    if ( fetch_label_for_coordinate(initial_agglo, coord_a, True) == row.id_a
     and fetch_label_for_coordinate(initial_agglo, coord_b, True) == row.id_b ):
        return np.array((coord_a, coord_b))

    # Couldn't find good points via any method at any radius.
    # Indicate this by negating the coordinates.
    return np.array((-coord_a, -coord_b))


def repair_coords_on_splits(orig_npy_path):
    """
    Read the given original npy path, repair coordinates for
    edges mentioning anything in the retired_svs set,
    and save the repaired file to a different directory.

    The output path is the input path with a 'split-coords-fixed' directory
    inserted before the resolution directory, e.g. .../32nm/x.npy is saved
    as .../split-coords-fixed/32nm/x.npy.

    Returns:
        The number of rows whose coordinates were (re)computed.
    """
    df = pd.DataFrame(np.load(orig_npy_path))
    print("Selecting retired supervoxels")
    retired_svs # Reference this variable to ensure that it gets captured when pickling this function.
    rows_to_fix = df.eval('(id_a in @retired_svs) or (id_b in @retired_svs)')
    print(f"Found {rows_to_fix.sum()} rows")

    df_to_fix = df[rows_to_fix]
    fixed_points = [_repaired_points(row)
                    for row in tqdm(df_to_fix.itertuples(), total=len(df_to_fix))]

    # Skip the assignment when nothing needs fixing: np.array([]) is 1-D,
    # so the [:, 0, :] slicing below would raise IndexError.
    if fixed_points:
        fixed_points = np.array(fixed_points)
        df.loc[rows_to_fix, ['za', 'ya', 'xa']] = fixed_points[:,0,:]
        df.loc[rows_to_fix, ['zb', 'yb', 'xb']] = fixed_points[:,1,:]

    parts = orig_npy_path.split('/')
    parts.insert(-2, 'split-coords-fixed')
    new_npy_path = '/'.join(parts)
    np.save(new_npy_path, df.to_records(index=False))

    return int(rows_to_fix.sum())


In [33]:
# df = pd.DataFrame(np.load('32nm/data-00000-of-00100.npy'))
# print("Selecting retired supervoxels")
# rows_to_fix = df.eval('(id_a in @retired_svs) or (id_b in @retired_svs)')
# print(f"Found {rows_to_fix.sum()} rows")
# df_to_fix = df[rows_to_fix]
# df_to_fix.iloc[37:38]

In [28]:
#repair_coords_on_splits('32nm/data-00000-of-00100.npy')

In [36]:
sc.parallelize(orig_npy_paths).foreach(repair_coords_on_splits)
print("Done.")

Done.


In [20]:
# Absolute paths of the coordinate-repaired tables: 32nm, then 16nm, then 8nm
# (each group sorted by name).
fixed_npy_paths = [
    os.path.abspath(path)
    for res in (32, 16, 8)
    for path in sorted(glob.glob(f'split-coords-fixed/{res}nm/*.npy'))
]

In [21]:
%%time
def count_unfixable(npy_path):
    """
    Count the rows in a repaired table that were flagged as unfixable.

    repair_coords_on_splits() flags such rows by negating both points.
    Every coordinate is checked, not just 'za', so a flagged row is still
    counted when its z coordinate happened to be 0 (since -0 == 0).
    """
    table = np.load(npy_path)
    coords = np.stack([table[c] for c in ('za', 'ya', 'xa', 'zb', 'yb', 'xb')], axis=1)
    return int((coords < 0).any(axis=1).sum())
unfixable_count = sc.parallelize(fixed_npy_paths).map(count_unfixable).sum()
print(f"Unfixable rows: {unfixable_count}")

CPU times: user 26.1 ms, sys: 10.6 ms, total: 36.7 ms
Wall time: 11 s


### Relabel table SVs from the initial agglo to the analysis node
(and drop bad edges)

In [126]:
# Output directories for the relabeled tables, one per resolution.
os.makedirs(os.path.join('updated-tables', '32nm'), exist_ok=True)
os.makedirs(os.path.join('updated-tables', '16nm'), exist_ok=True)
os.makedirs(os.path.join('updated-tables', '8nm'), exist_ok=True)

# Replace old SV ids with updated IDs by sampling from those coordinates.
def remap_split_svs(npy_path):
    """
    Read a coordinate-repaired edge table and replace every retired supervoxel
    ID with the supervoxel found at that edge's point in analysis_node.

    The result is saved to the same path with its 'split-coords-fixed'
    component replaced by 'updated-tables'.

    Returns:
        The number of rows that mentioned a retired supervoxel.
    """
    df = pd.DataFrame(np.load(npy_path))
    assert df['id_a'].dtype == np.uint64
    assert df['id_b'].dtype == np.uint64

    retired_svs # Reference this variable to ensure that it gets captured when pickling this function.
    rows_to_fix = df.eval('(id_a in @retired_svs) or (id_b in @retired_svs)')

    # NOTE(review): rows that repair_coords_on_splits() flagged as unfixable carry
    # negated coordinates, and are sampled here as-is -- confirm that is intended.
    fixed_ids = []
    df_to_fix = df[rows_to_fix]
    for row in tqdm(df_to_fix.itertuples(), total=len(df_to_fix)):
        id_a, id_b = row.id_a, row.id_b
        if id_a in retired_svs:
            id_a = fetch_label_for_coordinate(analysis_node, (row.za, row.ya, row.xa), supervoxels=True)
        if id_b in retired_svs:
            id_b = fetch_label_for_coordinate(analysis_node, (row.zb, row.yb, row.xb), supervoxels=True)
        fixed_ids.append( (id_a, id_b) )

    # Skip the assignment when nothing needs fixing: an empty 1-D array cannot be
    # assigned to two columns. Each column is assigned separately from a uint64
    # array, rather than via one multi-column .loc assignment, to keep the uint64 dtype.
    if fixed_ids:
        fixed_ids = np.array(fixed_ids, np.uint64)
        df.loc[rows_to_fix, 'id_a'] = fixed_ids[:, 0]
        df.loc[rows_to_fix, 'id_b'] = fixed_ids[:, 1]
    assert df['id_a'].dtype == np.uint64
    assert df['id_b'].dtype == np.uint64

    parts = npy_path.split('/')
    assert parts[-3] == 'split-coords-fixed'
    parts[-3] = 'updated-tables'
    new_npy_path = '/'.join(parts)
    np.save(new_npy_path, df.to_records(index=False))

    return int(rows_to_fix.sum())

In [28]:
#remap_split_svs(fixed_npy_paths[0])

In [26]:
%time updated_row_count = sc.parallelize(fixed_npy_paths).map(remap_split_svs).sum()

CPU times: user 204 ms, sys: 46.3 ms, total: 250 ms
Wall time: 13min 29s


In [27]:
updated_row_count

2081409

### Body mapping

In [None]:
# Supervoxel -> body mapping at the analysis node. Its index is used as the
# supervoxel side and its values as the body side of LabelMapper below.
mapping = fetch_mappings(analysis_node)

In [73]:
%time mapper = LabelMapper(mapping.index.values, mapping.values)

CPU times: user 24 s, sys: 471 ms, total: 24.4 s
Wall time: 24.3 s


### Filter

In [None]:
# Output directories for the filtered tables, one per resolution.
os.makedirs(os.path.join('filtered-tables', '32nm'), exist_ok=True)
os.makedirs(os.path.join('filtered-tables', '16nm'), exist_ok=True)
os.makedirs(os.path.join('filtered-tables', '8nm'), exist_ok=True)

# Map each edge's supervoxels to bodies, then keep only edges between two
# different bodies where at least one of them is an important body.
def apply_mapping_and_filter_to_partition(paths):
    """
    Spark mapPartitions() function.

    For each .npy edge table path in the partition:
      - add 'body_a'/'body_b' columns using the global supervoxel->body ``mapping``,
      - drop edges within a single body,
      - keep only edges touching ``important_bodies`` on at least one end,
      - save the result to the same path with its 'updated-tables'
        component replaced by 'filtered-tables'.

    Returns:
        List with the number of rows kept, one entry per input path.
    """
    # Must create mapper here since it cannot be pickled.
    # It is built once per partition and shared by every file in it.
    mapper = LabelMapper(mapping.index.values, mapping.values)

    def apply_mapping_and_filter(npy_path):
        df = pd.DataFrame(np.load(npy_path))

        # A bug in an earlier step caused the type to be int64. Fix that now.
        df['id_a'] = df['id_a'].astype(np.uint64)
        df['id_b'] = df['id_b'].astype(np.uint64)

        df['body_a'] = mapper.apply(df['id_a'].values, allow_unmapped=True)
        df['body_b'] = mapper.apply(df['id_b'].values, allow_unmapped=True)

        important_bodies # Referenced to ensure capture in this closure

        # Drop internal edges,
        # Filter for important bodies (on at least one end -- capture 1-hop and 2-hop)
        q = '(body_a != body_b) and ((body_a in @important_bodies) or (body_b in @important_bodies))'
        df = df.query(q)

        parts = npy_path.split('/')
        assert parts[-3] == 'updated-tables'
        parts[-3] = 'filtered-tables'
        new_npy_path = '/'.join(parts)
        np.save(new_npy_path, df.to_records(index=False))

        return len(df)

    return [apply_mapping_and_filter(path) for path in paths]

In [None]:
# Absolute paths of the relabeled tables: 32nm, then 16nm, then 8nm
# (each group sorted by name).
updated_npy_paths = [
    os.path.abspath(path)
    for res in (32, 16, 8)
    for path in sorted(glob.glob(f'updated-tables/{res}nm/*.npy'))
]

In [135]:
%%time 
# Map and filter every relabeled table on the cluster. Each partition returns
# one kept-row count per file, and .sum() totals them.
filtered_row_count = (sc.parallelize(updated_npy_paths)
                        .mapPartitions(apply_mapping_and_filter_to_partition)
                        .sum())

CPU times: user 23.6 s, sys: 4.18 s, total: 27.7 s
Wall time: 2min 45s


In [136]:
print(filtered_row_count)

755536110


In [137]:
# Absolute paths of the filtered tables: 32nm, then 16nm, then 8nm
# (each group sorted by name).
filtered_npy_paths = [
    os.path.abspath(path)
    for res in (32, 16, 8)
    for path in sorted(glob.glob(f'filtered-tables/{res}nm/*.npy'))
]

In [138]:
# Load all filtered shards into a single record array.
# The output is preallocated (shard lengths come from the .npy headers via mmap),
# so peak memory is about one copy of the ~64 GB table, not a list of shards plus the copy.
_shards = [np.load(p, mmap_mode='r') for p in filtered_npy_paths]
combined_table = np.empty(sum(len(s) for s in _shards), dtype=_shards[0].dtype)
_offset = 0
for _shard in tqdm(_shards):
    combined_table[_offset:_offset + len(_shard)] = _shard
    _offset += len(_shard)
del _shards

100%|██████████| 300/300 [01:14<00:00,  4.02it/s]


In [147]:
print(combined_table.shape[0] / 1e6, "M")
print(combined_table.nbytes / 1e9, "GB")

755.53611 M
64.22056935 GB


In [None]:
combined_df = pd.DataFrame(combined_table)

In [154]:
%time np.save('combined-filtered-table.npy', combined_table)

CPU times: user 3.28 s, sys: 1min 7s, total: 1min 10s
Wall time: 2min 50s


In [151]:
ls

[0m[38;5;27m16nm[0m/                               [38;5;27mnotebook-cluster--20180723.094528[0m/
[38;5;27m32nm[0m/                               [38;5;27mnotebook-cluster--20180723.094651[0m/
[38;5;27m8nm[0m/                                [38;5;27mnotebook-cluster--20180723.180735[0m/
bodies-0.5-from-neuprint-52f9.csv   spark-focused.ipynb
[38;5;27mfiltered-tables[0m/                    [38;5;27msplit-coords-fixed[0m/
[38;5;27mnotebook-cluster--20180722.201030[0m/  [38;5;27mupdated-tables[0m/
