In [1]:
import pynbody
import numpy as np
import os, subprocess

In [2]:
import matplotlib.pylab as plt
plt.style.use('fivethirtyeight')
%matplotlib inline

In [3]:
import findspark
findspark.init()

import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import Row, SQLContext
from pynbody import util, units, family
from pyspark.sql.types import ArrayType, DoubleType, FloatType, NullType, StructType, StructField, StringType, ByteType

# Driver-side Spark configuration: environment variables consumed when the
# SparkContext is launched, followed by creation of the context itself.

# set up config directory
os.environ['SPARK_CONF_DIR'] = os.path.realpath('./spark_config')

# set up the submit arguments
files = "--files %s/metrics.properties"%os.environ['SPARK_CONF_DIR']
packages = "--packages com.databricks:spark-csv_2.10:1.3.0"
shell = "pyspark-shell"

os.environ['PYSPARK_SUBMIT_ARGS'] = " ".join([files,packages,shell])

# how many cores do we have for the driver
# (taken from the LSB_DJOB_NUMPROC environment variable, defaulting to 1)
ncores = int(os.environ.get('LSB_DJOB_NUMPROC', 1)) 

# here we set the memory we want spark to use for the driver JVM
#os.environ['SPARK_DRIVER_MEMORY'] = '%dG'%(ncores*2*0.7)
os.environ['SPARK_DRIVER_MEMORY'] = '2g'
# we have to tell spark which python executable we are using
os.environ['PYSPARK_PYTHON'] = subprocess.check_output('which python', shell=True).rstrip()


conf = SparkConf()
exec_cores = 4
num_execs = 100

conf.set('spark.executor.instances', str(num_execs))
conf.set('spark.executor.cores', str(exec_cores))

# Stop a context left over from a previous run of this cell, if there is one.
try:
    sc.stop()
except NameError:
    # fresh kernel: no SparkContext has been created yet
    pass
except Exception as err:
    # still best-effort, but no longer silent and no longer swallowing
    # KeyboardInterrupt / SystemExit like the former bare `except:`
    print('could not stop the previous SparkContext: %s' % err)
sc = SparkContext(master='yarn-client', conf=conf)
sqc = SQLContext(sc)

### Some ugly code for reading in the particles...

In [7]:
def make_batches(n_parts, n_batches):
    """Yield consecutive ``(start, stop)`` index pairs covering ``[0, n_parts)``.

    Each pair is a half-open range, suitable for ``range(*pair)``.

    Parameters
    ----------
    n_parts : int
        Total number of items to cover.
    n_batches : int
        Requested number of batches. The batch size is
        ``n_parts // n_batches`` (at least 1), so one or more extra, shorter
        batches are produced when ``n_parts`` is not a multiple of
        ``n_batches``.

    Yields
    ------
    (int, int)
        ``(start, stop)`` with ``0 <= start < stop <= n_parts``.
    """
    # Floor division keeps the bounds integral on Python 2 and 3; the lower
    # bound of 1 guarantees progress when n_batches exceeds n_parts.
    batch_size = max(n_parts // n_batches, 1)
    curr_start = 0
    while curr_start < n_parts:
        # stop is exclusive, so the last batch must end at n_parts (not n_parts+1)
        yield (curr_start, min(curr_start + batch_size, n_parts))
        curr_start += batch_size

In [8]:

class SparkTipsySnap(pynbody.tipsy.TipsySnap) :
    """TipsySnap subclass whose main-file loader streams the raw record blocks.

    `_load_main_file` is a generator here: each block read from the file is
    yielded to the caller together with its index information and family.
    """

    def _load_main_file(self):
        """Yield ``(buf, buf_index, fam)`` for each block read from the main file.

        The gas, dark-matter and star sections are visited in that order, each
        with its own record dtype. Chunks for which the load control reports no
        memory index are skipped with a relative seek instead of being read.

        Yields
        ------
        buf : numpy structured array
            The records of one chunk, byte-swapped when ``self._byteswap`` is set.
        buf_index
            The buffer index produced by ``self._load_control.iterate`` for
            this chunk.
        fam
            The family the chunk belongs to.
        """
        # Resolved locally so the method does not rely on a notebook-level
        # `logger` global (none is defined in the visible cells).
        import logging
        logging.getLogger('pynbody.tipsy').info(
            "Loading data from main file %s", self._filename)

        f = util.open_(self._filename, 'rb')
        try:
            # presumably skips the fixed-size file header — TODO confirm
            f.seek(32)

            for fam, dtype in ((family.gas, self._g_dtype), (family.dm, self._d_dtype), (family.star, self._s_dtype)):
                st_len = dtype.itemsize
                for readlen, buf_index, mem_index in self._load_control.iterate([fam], [fam], multiskip=True):
                    # Read in the block
                    if mem_index is None:
                        # nothing requested from this chunk: step over it
                        f.seek(st_len * readlen, 1)
                        continue

                    buf = np.fromstring(f.read(st_len * readlen), dtype=dtype)

                    if self._byteswap:
                        buf = buf.byteswap()

                    yield buf, buf_index, fam
        finally:
            # runs when the generator is exhausted, closed or garbage collected,
            # so the file handle is no longer leaked
            f.close()
                
    
# One-letter tag stored in the 'fam' column for each particle family.
fam_lookup = {family.dm:'d', family.gas: 'g', family.star:'s'}

def buf_to_row(buf, buf_index, names, fam) : 
    """Convert the selected records of a structured array into Spark Rows.

    Parameters
    ----------
    buf : numpy structured array
        One chunk of records; its dtype names become Row fields.
    buf_index : slice or iterable of int
        Which records of `buf` to emit. A slice is resolved against
        ``len(buf)``, so open-ended bounds and steps are honoured.
    names : iterable of str
        Every field name a Row must carry; names absent from `buf` are
        filled with NaN.
    fam
        Family of the records; must be a key of `fam_lookup`.

    Yields
    ------
    Row
        One Row per selected record, with every value converted to `float`
        and a 'fam' field holding the family tag.
    """
    if isinstance(buf_index, slice):
        buf_index = xrange(*buf_index.indices(len(buf)))

    # Loop-invariant pieces, computed once per chunk rather than once per row.
    fam_tag = fam_lookup[fam]
    present = buf.dtype.names
    missing = [name for name in names if name not in present]

    for i in buf_index :
        d = {name:float(buf[name][i]) for name in present}
        for name in missing:
            d[name] = np.nan
        # Set unconditionally: previously this only happened when at least one
        # name was missing, leaving 'fam' out of the Row otherwise.
        d['fam'] = fam_tag
        yield Row(**d) 

In [9]:
def load_partition(filename, batch_iter, names) : 
    """Yield one Row per particle for every index range in `batch_iter`.

    Each ``(start, stop)`` pair opens `filename` as a SparkTipsySnap restricted
    to that range, streams its record blocks and converts them with
    `buf_to_row`, padding each Row out to the fields listed in `names`.
    """
    for start_stop in batch_iter:
        snap = SparkTipsySnap(filename, take=xrange(*start_stop))
        for chunk, chunk_index, chunk_fam in snap._load_main_file():
            for row in buf_to_row(chunk, chunk_index, names, chunk_fam):
                yield row
            # drop the block as soon as its rows have been emitted
            del chunk
        # drop the snapshot before opening the next range
        del snap

### Data read and conversion to `DataFrame`

In [10]:
# Snapshot file to convert. It is opened through SparkTipsySnap both here on
# the driver and inside the Spark tasks, so the path has to resolve in both places.
filename = '/cluster/home03/sdid/roskarr/work/testing/cosmo25cmb.768g2_dm.001024'

In [11]:
# Driver-side handle on the snapshot; the following cells use it only for the
# per-family record dtypes and the total length.
s = SparkTipsySnap(filename)

In [12]:
# Union of the field names of the gas, star and dark-matter record dtypes.
names = set(s._g_dtype.names) | set(s._s_dtype.names) | set(s._d_dtype.names)

In [13]:
# Length of the snapshot, used as the upper bound when building index ranges.
n_parts = len(s)

In [14]:
# Lazy generator of (start, stop) index ranges over the snapshot.
# NOTE(review): 799 is presumably chosen so that, together with a remainder
# range, the count lines up with the 800 partitions used in the next cell — confirm.
batches = make_batches(n_parts, 800-1)

In [15]:
# Distribute the index ranges (not the particle data) over 800 partitions.
batches_rdd = sc.parallelize(batches, 800)

In [16]:
#sc.addPyFile('/cluster/home03/sdid/roskarr/src/pynbody/dist/pynbody-0.31-py2.7-linux-x86_64.egg')

In [17]:
# Lazily map every partition of index ranges to the Rows read from the file;
# `filename` and `names` are captured from the notebook namespace by the lambda.
sim_rdd = batches_rdd.mapPartitions(lambda iterator: load_partition(filename, iterator, names))

In [18]:
# Schema of the particle DataFrame: every column is a nullable float except
# 'fam', which is a nullable string.
schema = StructType(fields=list(
    StructField(column, StringType() if column == 'fam' else FloatType(), True)
    for column in ('eps', 'fam', 'mass', 'metals', 'phi', 'rho', 'temp',
                   'tform', 'vx', 'vy', 'vz', 'x', 'y', 'z')))

In [19]:
# Reduced schema: nine nullable float columns, in this order.
schema_new = StructType(fields=list(
    StructField(column, FloatType(), True)
    for column in ('mass', 'x', 'y', 'z', 'vx', 'vy', 'vz', 'eps', 'phi')))

In [20]:
sim_rdd.first()

Row(eps=1.3799999578623101e-05, fam='g', mass=1.0154856816546598e-10, metals=0.0, phi=0.099265918135643, rho=0.0, temp=500.0, tform=nan, vx=0.14497043192386627, vy=-0.048239488154649734, vz=-0.08926118165254593, x=-0.41583994030952454, y=-0.4288076162338257, z=0.42203330993652344)

In [22]:
# Attach the explicit schema to the RDD of Rows.
# NOTE(review): presumably relies on the Row fields arriving in the same
# (alphabetical) order as the schema columns — confirm for the Spark version in use.
df = sim_rdd.toDF(schema)

In [23]:
df.select('mass', 'x', 'y', 'z').show()

+-------------+-----------+-----------+----------+
|         mass|          x|          y|         z|
+-------------+-----------+-----------+----------+
|1.0154857E-10|-0.41583994|-0.42880762| 0.4220333|
|1.0154857E-10| -0.4170401| -0.4275199|0.42231753|
|1.0154857E-10|-0.41655722|-0.42766565| 0.4223259|
|1.0154857E-10|-0.41671997|-0.42697328| 0.4228194|
|1.0154857E-10| -0.4170576|-0.42447266| 0.4226917|
|1.0154857E-10|-0.41692457| -0.4255475|0.42324886|
|1.0154857E-10|-0.41573152| -0.4264287|0.42298603|
|1.0154857E-10|-0.41570926|-0.42558408|0.42263356|
|1.0154857E-10|-0.41670665|-0.42773277| 0.4234212|
|1.0154857E-10| -0.4164624|-0.42720002| 0.4232937|
|1.0154857E-10|-0.41709548|-0.42287406|0.42332387|
|1.0154857E-10|-0.41609466|-0.42538744|0.42336327|
|1.0154857E-10|-0.41849965|-0.42489263|0.42386532|
|1.0154857E-10|-0.41691625|-0.42734033|0.42359728|
|1.0154857E-10|-0.41441837|-0.43146068|0.42128232|
|1.0154857E-10|-0.41666746|-0.42466545|0.42309728|
|1.0154857E-10|-0.41687754| -0.

In [24]:
# Mark the DataFrame for caching and force its evaluation with count().
df.cache().count()

1981808640

In [25]:
# Write the DataFrame as parquet, replacing any existing output at this path.
df.write.mode('overwrite').parquet('/user/roskarr/nbody/cosmo25cmb.768g2_dm/cosmo25cmb.768g2_dm.001024.parquet')

In [26]:
# Release the cached data, then remove the driver-side name so the next
# section starts from the parquet files instead.
df.unpersist()
del(df)

### Reading in the full dataset from a distributed parquet file

Reading in from parquet is now much faster because we don't have to convert and format the data... 

In [27]:
%%time
# Rebuild the DataFrame from the parquet output, then cache it and force
# evaluation with count().
df = sqc.read.parquet('/user/roskarr/nbody/cosmo25cmb.768g2_dm/cosmo25cmb.768g2_dm.001024.parquet')
df.cache().count()

CPU times: user 4 ms, sys: 6 ms, total: 10 ms
Wall time: 34.7 s


In [28]:
df.printSchema()

root
 |-- eps: float (nullable = true)
 |-- fam: string (nullable = true)
 |-- mass: float (nullable = true)
 |-- metals: float (nullable = true)
 |-- phi: float (nullable = true)
 |-- rho: float (nullable = true)
 |-- temp: float (nullable = true)
 |-- tform: float (nullable = true)
 |-- vx: float (nullable = true)
 |-- vy: float (nullable = true)
 |-- vz: float (nullable = true)
 |-- x: float (nullable = true)
 |-- y: float (nullable = true)
 |-- z: float (nullable = true)



In [29]:
%time df.count()

CPU times: user 2 ms, sys: 0 ns, total: 2 ms
Wall time: 2.62 s


1981808640

### Sub-sampling for visualization

To visualize the data, it must be brought to the driver and therefore sub-sampled. 

In [None]:
# sample(withReplacement, fraction, seed): keep a 0.0001 fraction of the rows
# (seed 1) and collect their x, y, z columns to the driver as a pandas DataFrame.
sampled = df.select('x','y','z').sample(False, 0.0001, 1).toPandas()

In [None]:
# x-y scatter of the sub-sampled particles; very low alpha so that dense
# regions show up as darker areas.
plt.figure(figsize=(10,10))
plt.plot(sampled['x'], sampled['y'], '.', alpha = .02)
plt.xlabel('x'); plt.ylabel('y')
plt.xlim(-.5,.5); plt.ylim(-.5,.5);

### Creating a filtered dataset 

Let's say we want to focus on the big blob around `x = 0` and `y = -0.3`:

In [None]:
# SQL predicate selecting |x| < 0.1 and -0.4 < y < -0.2.
# NOTE(review): the literals are presumably cast to float to match the float
# columns of the DataFrame — confirm whether this matters for filter pushdown.
filt_string = 'abs(x)< cast(.1 as float) and y < cast(-.2 as float) and y > cast(-.4 as float)'

In [None]:
%%time
# Apply the predicate to the cached DataFrame and report the number of
# matching rows in exponent notation.
filtered = df.filter(filt_string)
# parenthesised form prints the same text on Python 2 and is valid on Python 3
print('%e' % filtered.count())

In [None]:
# Collect a 0.001 fraction (seed 1) of the filtered x, y columns to the driver
# as a pandas DataFrame; this rebinds `sampled` from the earlier cell.
sampled = filtered.select('x','y').sample(False, 0.001, 1).toPandas()

In [None]:
# x-y scatter of the sub-sampled particles inside the filtered region.
plt.figure(figsize=(10,10))
plt.plot(sampled['x'], sampled['y'], '.', alpha = .01)
# trailing semicolon keeps the matplotlib object repr out of the cell output
plt.xlabel('x'); plt.ylabel('y');

In [None]:
from pyspark.mllib.clustering import KMeans

In [None]:
# RDD with one numpy array (x, y, z, vx, vy, vz) per filtered row.
points = filtered.select('x', 'y', 'z', 'vx', 'vy', 'vz').rdd.map(lambda r: np.array(r))

In [None]:
points.cache().count()

### Data read using filtering at the source

Because each parquet file keeps metadata about data ranges of each column, many files can be completely skipped in some cases, resulting in much more efficient data-ingestion. 

In [None]:
# A second, uncached DataFrame over the same parquet files, with the same
# predicate applied directly to it.
df2 = sqc.read.parquet('/user/roskarr/nbody/cosmo25cmb.768g2_dm/cosmo25cmb.768g2_dm.001024.parquet')
filtered2 = df2.filter(filt_string)

In [None]:
filtered2.explain()

In [None]:
%%time
filtered2.count()

Note that this is faster than what was done above, i.e. filtering the full cached dataset. In the previous case, every element of the dataset still needed to be touched, whereas in this case large chunks of the data were automatically excluded completely. 

In [None]:
# Rebuild `points` from the source-filtered DataFrame (rebinding the earlier
# name), then cache it and force evaluation with count().
points = filtered2.select('x', 'y', 'z', 'vx', 'vy', 'vz').rdd.map(lambda r: np.array(r))
points.cache().count()

### Training a model on the data

For fun, let's train a K-Means model on the position and velocity data in the filtered area.

In [None]:
# Fit k = 100 clusters on the 6-d (position, velocity) vectors. The seed is
# fixed, like the sampling calls elsewhere in this notebook, so that a re-run
# reproduces the same cluster assignment.
model = KMeans.train(points, 100, seed=1)

In [None]:
# Label a 10% sample (seed 1) of the points with their predicted cluster id and
# keep only the pairs whose id is below 10.
# The filter indexes the pair instead of using a tuple-parameter lambda, which
# is a syntax error on Python 3.
cluster_rdd = (points.sample(False, 0.1, 1)
                     .map(lambda vec: (model.predict(vec), vec))
                     .filter(lambda labelled: labelled[0] < 10))

In [None]:
# takeSample(withReplacement, num, seed): bring a sample of 100000 (cluster, vec)
# pairs to the driver as a Python list.
cluster_subsample = cluster_rdd.takeSample(False, 100000, 1)

In [None]:
# Ten hex colours, indexed by cluster id (only ids below 10 are kept above).
colors = [u'#a6cee3', u'#1f78b4', u'#b2df8a', u'#33a02c', u'#fb9a99', u'#e31a1c', u'#fdbf6f', u'#ff7f00', u'#cab2d6', u'#6a3d9a']

In [None]:
from matplotlib.colors import ColorConverter

cc = ColorConverter()

# One RGBA tuple per sampled point, looked up by the point's cluster id.
colors_k100 = [cc.to_rgba(colors[cluster]) for (cluster, vec) in cluster_subsample]

In [None]:
# Stack the sampled vectors into a single (n_samples, 6) array. np.vstack is
# given a list because NumPy deprecates, and newer versions reject, generator input.
vecs = np.vstack([vec for (cluster, vec) in cluster_subsample])

In [None]:
plt.rcParams['font.size'] = 18

In [None]:
# x-y scatter of the sampled points, coloured by cluster id; the axis limits
# match the filtered region.
plt.figure(figsize=(10,10))
plt.scatter(vecs[:,0], vecs[:,1], c= colors_k100, alpha=0.2)
plt.xlabel('x'); plt.ylabel('y')
plt.xlim(-.1,.1); plt.ylim(-.4,-.2);