## test hierarchical merge in FOF algorithm

#### the idea is this: 

* after the local FOF stage, each partition reports the particles it holds in the overlap region
* do a reduceByKey or treeAggregate of some sort to collect the groups belonging to the same particles
* produce a mapping $G \rightarrow G_1$ and distribute it to all hosts in the form of a broadcast lookup table

In [1]:
import numpy as np
import sys
sys.setrecursionlimit(sys.getrecursionlimit()*10)


# import matplotlib.pylab as plt
# %matplotlib inline
# import matplotlib.patches as patches
# plt.style.use('bmh')

In [2]:
%load_ext line_profiler
import line_profiler

from Cython.Compiler.Options import directive_defaults

directive_defaults['linetrace'] = True
directive_defaults['binding'] = True

In [3]:
import spark_fof
import spark_fof_c
from fof import fof
%load_ext Cython

In [4]:
def plot_rectangle(rec, ax=None):
    """
    Draw the outline of one rectangle, or of every rectangle in a (possibly
    nested) list/tuple, onto a matplotlib axes.

    `rec` is either an object exposing `mins` and `maxes`, or a list/tuple of
    such objects. If `ax` is None a new equal-aspect subplot is created.

    NOTE: relies on `plt` and `patches` being available as globals; the
    matplotlib imports at the top of this notebook are commented out.
    """
    axes = plt.subplot(aspect='equal') if ax is None else ax

    if not isinstance(rec, (list, tuple)):
        # single rectangle: width/height come from the first two extents
        extent = rec.maxes - rec.mins
        outline = patches.Rectangle(rec.mins, extent[0], extent[1], fill=False, zorder=-1)
        axes.add_patch(outline)
        return

    # a collection: draw every member on the same axes
    for item in rec:
        plot_rectangle(item, axes)

## Set up data

In [5]:
# create the arrays
from spark_fof_c import pdt
pdt_tipsy = np.dtype([('mass', 'f4'),('pos', 'f4', 3),('vel', 'f4', 3), ('eps', 'f4'), ('phi', 'f4')])
# nps = 1000000
# ngs = 1
# particles = np.zeros(nps, dtype=pdt)
# done_ps = 0
# #centers = np.random.rand(ngs,3)*1.7 - 0.85
# centers = np.array([0,0,0]).reshape(1,3)
# for group, center in zip(range(ngs), centers): 
#     print group, center
#     group_ps = nps/ngs
#     if nps - (done_ps + group_ps) < group_ps:
#         group_ps = nps - done_ps 
#     particles['pos'][done_ps:done_ps+group_ps] = \
#         np.random.multivariate_normal(center, [[.5,0,0],[0,.5,0],[0,0,.5]], group_ps)
#     done_ps += group_ps
   
# particles['iOrder'] = range(nps)

In [6]:
from spark_fof_c import pdt

## Start Spark

In [7]:
import findspark
findspark.init()

In [8]:
import os
os.environ['SPARK_CONF_DIR'] = './conf'
os.environ['SPARK_DRIVER_MEMORY'] = '4G'

In [9]:
import pyspark
from pyspark import SparkContext, SparkConf
import pynbody



In [10]:
# Spark configuration used to create the SparkContext below
conf = SparkConf()

# turn on the PySpark worker profiler and set executor/driver memory
conf.set('spark.python.profile', 'true')
conf.set('spark.executor.memory', '3G')
conf.set('spark.driver.memory', '4G')

<pyspark.conf.SparkConf at 0x1172c1c50>

In [11]:
sc = SparkContext(master='local[4]', conf=conf)

In [12]:
sc.addPyFile('spark_fof.py')
sc.addPyFile('spark_fof_c.pyx')
sc.addPyFile('spark_fof_c.c')
sc.addPyFile('spark_fof_c.so')
sc.addPyFile('fof.so')

## Set up the domains

In [13]:
# Domain decomposition parameters.
# NOTE(review): N presumably gives 2**N bins per dimension (the commented-out
# plotting cell passes 2**N to get_bin_cython) -- confirm against spark_fof.
N = 2
# tau is reused below by ghost_mask, the partitioner and the FOF runs.
# 7.8125e-4 == 0.2/256; presumably 0.2 x the mean inter-particle spacing of a
# 256^3 particle unit box -- TODO confirm.
tau = 7.8125e-4
# presumably the lower and upper corners of the simulation box
mins = np.array([-.5,-.5,-.5])
maxs= np.array([.5,.5,.5])
# note the (maxs, mins) argument order here, unlike the (mins, maxs) order
# used in the other spark_fof / spark_fof_c calls in this notebook
domain_containers = spark_fof.setup_domain(N,tau,maxs,mins)

In [14]:
# f, ax = plt.subplots(subplot_kw={'aspect':'equal'}, figsize=(15,15))
# pynbody.plot.image(s.d, width=1, units = 'Msol Mpc^-2', cmap=plt.cm.Greys, show_cbar=False, subplot=ax)
# #plot_rectangle(domain_containers[0].bufferRectangle, ax=ax)
# for p in particles[::1000000]: 
#     plot_rectangle(domain_containers[spark_fof.get_bin_cython(p['pos'],2**N, np.array(mins), np.array(maxs))], ax=ax)
#     plot_rectangle(domain_containers[spark_fof.get_bin_cython(p['pos'], 2**N, np.array(mins),np.array(maxs))].bufferRectangle, ax=ax)
#     ax.plot(p['pos'][0], p['pos'][1], '.')
# plt.draw()
# ax.set_xlim(-.5,.5)
# ax.set_ylim(-.5,.5)

### Make the base RDD

In [15]:
from pyspark.accumulators import AccumulatorParam

class dictAdd(AccumulatorParam):
    """
    Accumulator parameter that merges dictionaries of numeric values by
    summing them key by key.
    """

    def zero(self, value):
        """Return a dict mapping 0 .. len(value)-1 to 0."""
        return dict.fromkeys(range(len(value)), 0)

    def addInPlace(self, val1, val2):
        """
        Add every entry of `val2` into `val1` and return `val1`.

        `val1` is modified in place. Keys present only in `val2` are treated
        as starting from 0 instead of raising a KeyError.
        """
        # .items() exists on both Python 2 and Python 3 dicts
        for key, increment in val2.items():
            val1[key] = val1.get(key, 0) + increment
        return val1
    
def read_tipsy_output(filename, chunksize = 2048): 
    """
    Read a tipsy file and set the sequential particle IDs
    
    This scans through the data twice -- first to get partition particle counts
    and a second time to actually set the particle IDs.

    Parameters
    ----------
    filename : path handed to `sc.binaryRecords`
    chunksize : number of `pdt_tipsy` records packed into one binary record;
        the record length is `pdt_tipsy.itemsize*chunksize` bytes

    Returns
    -------
    RDD of numpy arrays with dtype `pdt`; only 'pos' is copied from the file,
    'iOrder' is filled with sequential IDs, everything else stays zero.

    Uses the module-level names `sc`, `pdt`, `pdt_tipsy` and `dictAdd`.
    """
    
    # helper functions
    def convert_to_fof_particle(s): 
        """Reinterpret the raw bytes `s` as tipsy records and copy their positions into a new `pdt` array."""
        p_arr = np.frombuffer(s, pdt_tipsy)

        # fresh zero-filled array: fields other than 'pos' are left at zero here
        new_arr = np.zeros(len(p_arr), dtype=pdt)
        new_arr['pos'] = p_arr['pos']    
        return new_arr

    def convert_to_fof_particle_partition(index, iterator): 
        """Convert each binary record of a partition and, while `count` is true, add its length to the accumulator."""
        for s in iterator: 
            a = convert_to_fof_particle(s)
            # `count` is a variable of the enclosing function, toggled further down
            if count: 
                npart_acc.add({index: len(a)})
            yield a

    def set_particle_IDs_partition(index, iterator): 
        """Assign consecutive 'iOrder' values, offset by the particle counts of all lower-indexed partitions."""
        p_counts = partition_counts.value
        local_index = 0
        # first ID of this partition = total number of particles in partitions 0 .. index-1
        start_index = sum([p_counts[i] for i in range(index)])
        for arr in iterator:
            arr['iOrder'] = range(start_index + local_index, start_index + local_index + len(arr))
            local_index += len(arr)
            yield arr
    
    rec_rdd = sc.binaryRecords(filename, pdt_tipsy.itemsize*chunksize)
    nPartitions = rec_rdd.getNumPartitions()
    # set the partition count accumulator
    npart_acc = sc.accumulator({i:0 for i in range(nPartitions)}, dictAdd())
    count=True
    # read the data and count the particles per partition
    rec_rdd = rec_rdd.mapPartitionsWithIndex(convert_to_fof_particle_partition)
    # action that forces the first pass over the data so the accumulator gets filled
    rec_rdd.count()
    # NOTE(review): flipping the flag is presumably meant to stop the accumulator
    # from being updated again when the returned RDD is evaluated; whether the
    # workers see the new value depends on when Spark serializes the closure
    # -- TODO confirm.
    count=False

    # ship the per-partition counts to the workers for the ID offsets
    partition_counts = sc.broadcast(npart_acc.value)

    # rec_rdd is not cached here, so evaluating the result goes through the conversion step again
    return rec_rdd.mapPartitionsWithIndex(set_particle_IDs_partition)

In [16]:
import spark_tipsy
# pick up edits to spark_tipsy without restarting the kernel
reload(spark_tipsy)
# NOTE: this calls spark_tipsy.read_tipsy_output (which takes `sc` explicitly),
# not the notebook-local read_tipsy_output defined in the previous cell.
# Every array read from the file is then passed through spark_fof_c.ghost_mask.
p_rdd = (spark_tipsy.read_tipsy_output(sc, '/Users/rok/polybox/euclid256.nat_no_header', chunksize=1024*4)
         .map(lambda x: spark_fof_c.ghost_mask(x, domain_containers, tau, mins, maxs)))

In [17]:
ps = np.concatenate(p_rdd.collect())

In [18]:
assert(len(ps) == len(pynbody.load('/Users/rok/polybox/euclid256.nat')))


nMinMembers = 8
n_groups = fof.run(ps, tau, nMinMembers)
print 'number of groups to %d particle = %d'%(nMinMembers, n_groups)



number of groups to 8 particle = 105761


### Partition particles into domains and set the partition part of local group ID

In [19]:
def partition_wrapper(particle_iterator):
    """
    Generator for use with `mapPartitions`: runs
    `spark_fof_c.new_partitioning_cython` on each particle array of the
    partition and flattens the per-array results into one stream.

    Reads `domain_containers`, `tau`, `mins` and `maxs` from the notebook
    globals. The yielded items are presumably (key, value) pairs, since the
    result is fed to `partitionBy` -- confirm against spark_fof_c.
    """
    for chunk in particle_iterator:
        for record in spark_fof_c.new_partitioning_cython(
                chunk, domain_containers, tau, mins, maxs):
            yield record

In [20]:
# shuffle the keyed output of partition_wrapper into one Spark partition per
# domain container, then drop the keys and keep only the values
part_rdd = (p_rdd.mapPartitions(partition_wrapper)).partitionBy(len(domain_containers)).values()

In [21]:
# ps = part_rdd.glom().collect()

In [22]:
# part_rdd2 = (p_rdd.mapPartitions(lambda particles: spark_fof_c.partition_particles_cython(particles, domain_containers, tau, mins, maxs))
#                  .partitionBy(len(domain_containers))
#                  .values())

In [23]:
# part_rdd2.cache().count()

In [24]:
#part_rdd.map(lambda x: len(x)).reduce(lambda a,b:a+b)

### Run the local FOF

In [25]:
from fof import fof

In [26]:
def run_local_fof(partition_index, particle_iter, tau, nMinMembers, batch_size=1024*256):
    """
    Run the serial FOF on all particles of one partition and yield them back
    in batches.

    Parameters
    ----------
    partition_index : index of the partition, passed to `spark_fof_c.encode_gid`
    particle_iter : iterable of particle arrays belonging to this partition
    tau : linking length handed to `fof.run`
    nMinMembers : minimum group size handed to `fof.run`
    batch_size : maximum number of particles per yielded array

    Yields
    ------
    Consecutive slices of the stacked particle array, each at most
    `batch_size` long. A partition that contains no arrays yields nothing.
    """
    # np.hstack needs a real sequence and raises ValueError on an empty one,
    # so materialize the iterator and bail out early for an empty partition
    chunks = list(particle_iter)
    if not chunks:
        return

    part_arr = np.hstack(chunks)

    if len(part_arr) > 0:
        # run fof
        fof.run(part_arr, tau, nMinMembers)

        # encode the groupID
        spark_fof_c.encode_gid(part_arr, partition_index)

    # split points at batch_size, 2*batch_size, ...; no split points -> one array
    split_points = range(batch_size, len(part_arr), batch_size)
    for arr in np.split(part_arr, split_points):
        yield arr

In [27]:
# run the local FOF on every partition and cache the result.
# NOTE: nMinMembers is 1 here, whereas the serial reference run above used nMinMembers = 8
fof_rdd = part_rdd.mapPartitionsWithIndex(lambda index, particles: run_local_fof(index, particles, tau, 1)).cache()

### Group Merging stage

In [28]:
# the box corners are given as literals here; they equal `mins`/`maxs` defined in the domain setup cell
fof_analyzer = spark_fof.FOFAnalyzer(sc, N, tau, fof_rdd, [-.5,-.5,-.5], [.5,.5,.5])

# NOTE(review): the meaning of the argument 0 is not visible in this notebook -- confirm against FOFAnalyzer.merge_groups
merged_rdd = fof_analyzer.merge_groups(0)

# bring every particle array back to the driver
merged = merged_rdd.collect()

merged_arr = np.concatenate(merged)

# distinct values of the 'iGroup' field after merging
groups = np.unique(merged_arr['iGroup'])

In [30]:
n_groups

105761

In [31]:
len(groups)

7251094

In [32]:
fof_analyzer.finalize_groups()

PythonRDD[14] at collect at <ipython-input-28-6803381c4c70>:5

In [36]:
fof_analyzer.sorted_groups

[(1, 70256),
 (2, 60327),
 (3, 54174),
 (4, 50349),
 (5, 48172),
 (6, 45140),
 (7, 41820),
 (8, 40506),
 (9, 31893),
 (10, 31465),
 (11, 24960),
 (12, 24958),
 (13, 23763),
 (14, 23482),
 (15, 21812),
 (16, 20961),
 (17, 20737),
 (18, 20587),
 (19, 20005),
 (20, 19828),
 (21, 19727),
 (22, 19058),
 (23, 18450),
 (24, 17418),
 (25, 17199),
 (26, 17109),
 (27, 16406),
 (28, 16093),
 (29, 15820),
 (30, 15690),
 (31, 15406),
 (32, 15265),
 (33, 15057),
 (34, 14796),
 (35, 14468),
 (36, 14361),
 (37, 14336),
 (38, 13756),
 (39, 13500),
 (40, 13437),
 (41, 13350),
 (42, 13334),
 (43, 13172),
 (44, 13039),
 (45, 12987),
 (46, 12346),
 (47, 12285),
 (48, 12245),
 (49, 12175),
 (50, 12163),
 (51, 12158),
 (52, 12079),
 (53, 12049),
 (54, 11951),
 (55, 11935),
 (56, 11896),
 (57, 11810),
 (58, 11643),
 (59, 11461),
 (60, 11333),
 (61, 11210),
 (62, 10849),
 (63, 10659),
 (64, 10516),
 (65, 10472),
 (66, 10452),
 (67, 10268),
 (68, 10245),
 (69, 10176),
 (70, 10023),
 (71, 9907),
 (72, 9624),
 (7