In [1]:
import findspark
import numpy as np

In [2]:
findspark.init()
import pyspark

In [3]:
from cPickle import dumps, loads
import cPickle

In [None]:
def print_types(obj): 
    if isinstance(obj, (list, tuple)):
        return ' '.join(map(print_types, obj))
    else: 
        return str(type(obj))

import cPickle as pickle
import marshal 

def check_for_numpy(obj): 
    if isinstance(obj, np.ndarray): 
        return True
    elif isinstance(obj, (list,tuple)):
        return any(map(check_for_numpy, obj))
    else: 
        return False
    
class AutoSerializerNumpy(pyspark.serializers.FramedSerializer):

    """
    Choose marshal or pickle as serialization protocol automatically
    """

    def __init__(self):
        pyspark.serializers.FramedSerializer.__init__(self)
        self._type = None

    def dumps(self, obj):
        if self._type is not None or check_for_numpy(obj):
            return b'P' + pickle.dumps(obj, 2)
        try:
            return b'M' + marshal.dumps(obj, 2)
        except Exception as e:
            self._type = b'P'
            return b'P' + pickle.dumps(obj, 2)

    def loads(self, obj):
        _type = obj[0]
        if _type == b'M':
            return marshal.loads(obj[1:])
        elif _type == b'P':
            return pickle.loads(obj[1:])
        else:
            raise ValueError("invalid serialization type: %s" % _type)

In [195]:
import numpy as np
import ujson

class NumpySerializer(pyspark.serializers.PickleSerializer):
    def dumps(self, obj):
        if isinstance(obj,list):
            res = 'NPALIST'
            res += '>><<'.join([self.dumps(x) for x in obj])
            return res
        if isinstance(obj,np.ndarray):
            md = dict(
                dtype = str(obj.dtype),
                shape = obj.shape,
                )
            return 'NPA'+'<<>>'.join([ujson.dumps(md),obj.tostring()])
        else: 
            return super(NumpySerializer, self).dumps(obj)
    def loads(self, string): 
        loc = string.find('NPALIST')
        if loc >= 0: 
            return [self.loads(s) for s in string[loc+len('NPALIST'):].split('>><<')]
        elif string[:3] == 'NPA':
            dict_str, arr_str = string[3:].split('<<>>')
            md = ujson.loads(dict_str)
            return np.fromstring(arr_str,md['dtype']).reshape(md['shape'])
        else: 
            return super(NumpySerializer, self).loads(string)

In [106]:
import json, ujson

In [101]:
md = dict(dtype=str(a.dtype), shape=a.shape)

In [114]:
%timeit json.dumps(md)

The slowest run took 9.63 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.22 µs per loop


In [196]:
ns = NumpySerializer()

In [197]:
%load_ext line_profiler

In [201]:
%lprun -f NumpySerializer.dumps ns.dumps(arrays[0])

In [209]:
%prun x = ns.dumps(arrays)

 

In [210]:
%prun x=dumps(arrays, 2)

 

In [146]:
%timeit ns.dumps(arrays)

1 loop, best of 3: 1.44 s per loop


In [147]:
%timeit dumps(arrays,2)

1 loop, best of 3: 1.04 s per loop


In [155]:
%%timeit x=ns.dumps(arrays)
ns.loads(x)

1 loop, best of 3: 2.93 s per loop


In [156]:
%%timeit x = dumps(arrays,2)
loads(x)

1 loop, best of 3: 545 ms per loop


In [158]:
import re

In [189]:
x = ns.dumps(arrays)

In [190]:
get_arrays = re.compile('>><<(.*)>><<')

In [191]:
x = 'NPALIST>><<'+x[7:]

In [192]:
x[:100000]

'NPALIST>><<NPA{"dtype":"float64","shape":[10,1]}<<>>\xdf\xfb\xb7\xe8\xe9I\xe6?\xb8$\r\x8f\xf44\xc5?\xd8J\x11A\x0f\x8d\xef?\xee/~U\x81\x80\xeb?\x1eMy<(V\xd6?Pc\x8c\xdb\x0e\x9f\xc8?h1dj\x1b\xdb\xe3?\xe1duB\xfa\xdc\xeb?\x00c\xb5\x17l&\xe1?\xd3\xd7\xee\x1c\xf7\x98\xe3?>><<NPA{"dtype":"float64","shape":[10,1]}<<>>\x86\xf4A\x8fXY\xea?l\x98(J\xe6\'\xdc?]c96#(\xe9?\x80\x84Q\xa4\xd35\xa1?&\xef\xa5\xdd\x96\xd1\xed?\x0c\x92=x\x99\xd2\xc6?x\xec\xe5g\xfa5\xcc?\xf0\xaa\xd8\xb3\x81t\xe1?\x0f\x91wb\xea\xb7\xe4?@\x80E6\x95\xeb\x9f?>><<NPA{"dtype":"float64","shape":[10,1]}<<>>\xec\rw\xfaR6\xd8?f\xb4\x9ce"6\xd5?@M\x8f\xe1.\xcc\xa4?T\xde{\n\xe1\x03\xe1?\x16\x02\x92\xc21\x1c\xd5?\xd1!\xbf\xec?\xae\xed?\x80i\x85\xca\xedxp?\xf8\x16\xe6\xbf\xec\xff\xb8?\x16\xe1{\x9c\x96\xeb\xd3?\x88\xe62\x87\x18=\xcf?>><<NPA{"dtype":"float64","shape":[10,1]}<<>>P&\xb4\xbc4\x1b\xe3?w\xbaO\xcf\xde]\xe9?\xa8\'\xcb_\x1a#\xd1?$L\x8cS-\xc5\xdc?\xfe\x93\xe0\xed~a\xd9?\xf4(n\xc5\x1f-\xe0?\xcc\xd7\xb2\xc3\xbdn\xd3?p\x90\xe4\xc5\x86U\

In [194]:
get_arrays.findall(x)

['NPA{"dtype":"float64","shape":[10,1]}<<>>\xdf\xfb\xb7\xe8\xe9I\xe6?\xb8$\r\x8f\xf44\xc5?\xd8J\x11A\x0f\x8d\xef?\xee/~U\x81\x80\xeb?\x1eMy<(V\xd6?Pc\x8c\xdb\x0e\x9f\xc8?h1dj\x1b\xdb\xe3?\xe1duB\xfa\xdc\xeb?\x00c\xb5\x17l&\xe1?\xd3\xd7\xee\x1c\xf7\x98\xe3?>><<NPA{"dtype":"float64","shape":[10,1]}<<>>\x86\xf4A\x8fXY\xea?l\x98(J\xe6\'\xdc?]c96#(\xe9?\x80\x84Q\xa4\xd35\xa1?&\xef\xa5\xdd\x96\xd1\xed?\x0c\x92=x\x99\xd2\xc6?x\xec\xe5g\xfa5\xcc?\xf0\xaa\xd8\xb3\x81t\xe1?\x0f\x91wb\xea\xb7\xe4?@\x80E6\x95\xeb\x9f?',
 'NPA{"dtype":"float64","shape":[10,1]}<<>>P&\xb4\xbc4\x1b\xe3?w\xbaO\xcf\xde]\xe9?\xa8\'\xcb_\x1a#\xd1?$L\x8cS-\xc5\xdc?\xfe\x93\xe0\xed~a\xd9?\xf4(n\xc5\x1f-\xe0?\xcc\xd7\xb2\xc3\xbdn\xd3?p\x90\xe4\xc5\x86U\xe6?\xba\xa1\xca;\xc8J\xdf?V0\x1a\x19\xad\xcc\xea?>><<NPA{"dtype":"float64","shape":[10,1]}<<>>\x12\xe7\xdc\xf89\x8d\xe4?x\x9e#p\xd1\x8d\xd6?xE\xb2\x033\xe5\xde?[\xbe)^Q\xf8\xea?\x9ci\xdf\x19\\\x12\xca?\x1b\xb6\xb3\xe8\x1b\x05\xe7?\xd8<h.F`\xd1?\xe7\x92+\xf8\x8f\xc8\xeb?#\x87\

In [125]:
sc.stop()

In [126]:
from pyspark import SparkContext, SparkConf

conf = SparkConf()
conf.set('spark.python.profile', 'true')

sc = SparkContext('local[2]', conf=conf, serializer=NumpySerializer())

In [199]:
def generate_arrays(l, s, n): 
    for i in range(n): 
        yield(np.random.rand(l, s))

arrays = list(generate_arrays(100000,1,100))

In [7]:
ns = NumpySerializer()

In [8]:
res = ns.dumps(arrays)

In [9]:
res[:100]

'NPALISTNPA<<>>(I1000\nI3\nt.<<>><f8<<>>\xd0\x05=\x1e\xb5c\xde?\xdf\xaa\x1f*\x90\xa8\xe0?\x99\x16\x8e\x8c1g\xe3?\xf4B\x90K\xb7\xd3\xef?&AO\xce)*\xd2?\xe4w\xf6\xac\x1c\x01\xe6?\x08:\x1e!\xa7x\xc4?m\x1b\xf9\xeb\xa8c\xed'

In [10]:
x = res[len('NPALIST'):].split('>><<')

In [11]:
x[0][:200]

'NPA<<>>(I1000\nI3\nt.<<>><f8<<>>\xd0\x05=\x1e\xb5c\xde?\xdf\xaa\x1f*\x90\xa8\xe0?\x99\x16\x8e\x8c1g\xe3?\xf4B\x90K\xb7\xd3\xef?&AO\xce)*\xd2?\xe4w\xf6\xac\x1c\x01\xe6?\x08:\x1e!\xa7x\xc4?m\x1b\xf9\xeb\xa8c\xed?t\xb0^\x867\x8c\xc6?/\x98I\x1fR\x1c\xef?T\x13\xf3\xdf\x07\x98\xc8?\x93I3\xcbZ\x00\xe2?\x82l\xf4\x14\xbfB\xef?\xf2D\x85\xa6\xee\xc6\xd9?\xc0\x11\xe4\xdd\x82\x1d\x99?>\xb4\x08\xdb@\xf4\xe9?\xf2\xe3\xc0\xf3\xf8Q\xea?\xbaH\xdd<Z\x89\xdc?8\xe3`SCR\xdb?t\xc9\\\xd03\xd4\xcd?\xb0z\xf9v\xaeq\xcd?\x80\x05'

In [12]:
ns.loads(x[0])

array([[ 0.47483566,  0.52057656,  0.60634687],
       [ 0.99459424,  0.28382344,  0.68763574],
       [ 0.15993203,  0.91841551,  0.17615408],
       ..., 
       [ 0.39369234,  0.49611302,  0.33502906],
       [ 0.29454165,  0.97429573,  0.86025263],
       [ 0.69032917,  0.12008566,  0.34762298]])

In [15]:
%timeit x = ns.dumps(arrays)

1000 loops, best of 3: 964 µs per loop


In [16]:
%%timeit 
for a in arrays: a.tostring()

10000 loops, best of 3: 113 µs per loop


In [17]:
pickle_ser = pyspark.serializers.AutoBatchedSerializer(pyspark.serializers.PickleSerializer(), 1)
numpy_ser = pyspark.serializers.AutoBatchedSerializer(NumpySerializer())

In [18]:
import cStringIO

In [19]:
import itertools

In [20]:
class AutoBatchedSerializer2(pyspark.serializers.AutoBatchedSerializer): 
    def dump_stream(self, iterator, stream):
        batch, best = 1, self.bestSize
        iterator = iter(iterator)
        while True:
            vs = list(itertools.islice(iterator, batch))
            if not vs:
                break

            bytes = self.serializer.dumps(vs)
            pyspark.serializers.write_int(len(bytes), stream)
            stream.write(bytes)

            size = len(bytes)
            #print 'batch: ', batch, 'size: ', size
            if size < best:
                batch *= 2
            elif size > best * 10 and batch > 1:
                batch //= 2

In [21]:
sio = cStringIO.StringIO()

In [23]:
%%timeit 
numpy_ser.dump_stream(arrays, sio)
sio.reset()

1000 loops, best of 3: 954 µs per loop


In [24]:
res = sio.getvalue()

In [26]:
%%timeit 
sio.reset()
x = list(numpy_ser.load_stream(sio))

100 loops, best of 3: 5.65 ms per loop


In [27]:
sio2 = cStringIO.StringIO()

In [28]:
%%timeit
pickle_ser.dump_stream(arrays,sio2)
sio2.reset()

100 loops, best of 3: 2.13 ms per loop


In [30]:
%%timeit
sio2.reset()
x = list(pickle_ser.load_stream(sio2))

1000 loops, best of 3: 1.06 ms per loop


In [218]:
%lprun -f pyspark.serializers.AutoBatchedSerializer.dump_stream pickle_ser.dump_stream(arrays, sio)

In [221]:
x = ns.dumps(arrays)

In [222]:
assert(np.all(ns.loads(x) == arr))

  if __name__ == '__main__':


AssertionError: 

In [223]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [226]:
%lprun -f pyspark.serializers.PickleSerializer.dumps ns.dumps(arrays)

In [60]:
rdd = sc.parallelize(arrays, 100)

In [61]:
rdd._jrdd_deserializer

BatchedSerializer(NumpySerializer(), 10)

In [62]:
rdd.count()

1000

In [63]:
sc.show_profiles()

Profile of RDD<id=1>
         12000 function calls (10800 primitive calls) in 0.079 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1100    0.040    0.000    0.040    0.000 {method 'split' of 'str' objects}
     1100    0.011    0.000    0.011    0.000 {method 'find' of 'str' objects}
 1100/100    0.009    0.000    0.064    0.001 <ipython-input-34-eec8db71e4e3>:17(loads)
      300    0.006    0.000    0.006    0.000 {method 'read' of 'file' objects}
     1000    0.002    0.000    0.002    0.000 {numpy.core.multiarray.fromstring}
     1000    0.002    0.000    0.002    0.000 {cPickle.loads}
      100    0.001    0.000    0.004    0.000 serializers.py:259(dump_stream)
     1100    0.001    0.000    0.073    0.000 rdd.py:1004(<genexpr>)
     1000    0.001    0.000    0.001    0.000 {method 'reshape' of 'numpy.ndarray' objects}
      200    0.001    0.000    0.072    0.000 serializers.py:136(load_stream)
 

In [57]:
sc.show_profiles()

Profile of RDD<id=2>
         5200 function calls (5100 primitive calls) in 0.019 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    0.007    0.000    0.007    0.000 {cPickle.loads}
      300    0.004    0.000    0.004    0.000 {method 'read' of 'file' objects}
      100    0.001    0.000    0.003    0.000 serializers.py:259(dump_stream)
      200    0.001    0.000    0.014    0.000 serializers.py:136(load_stream)
     1100    0.001    0.000    0.014    0.000 rdd.py:1004(<genexpr>)
      200    0.001    0.000    0.013    0.000 serializers.py:155(_read_with_length)
      200    0.000    0.000    0.001    0.000 serializers.py:542(read_int)
      100    0.000    0.000    0.019    0.000 worker.py:104(process)
      100    0.000    0.000    0.000    0.000 serializers.py:217(load_stream)
      100    0.000    0.000    0.000    0.000 {cPickle.dumps}
      200    0.000    0.000    0.015    0.000 {sum}
   

In [30]:
arrs = list(generate_arrays(100000,3,100))

In [32]:
%%timeit
for a in arrs:
    a.tostring()

10 loops, best of 3: 22.7 ms per loop


In [44]:
rdd2 = rdd._reserialize(NumpySerializer())

In [45]:
rdd.count()

100

In [59]:
rdd._jrdd_deserializer = pyspark.serializers.BatchedSerializer(pyspark.serializers.PickleSerializer())

In [60]:
rdd.count()

100

In [62]:
rdd._jrdd_deserializer

BatchedSerializer(PickleSerializer(), -1)

In [61]:
sc.show_profiles()

Profile of RDD<id=4>
         4400 function calls (4300 primitive calls) in 0.013 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    0.004    0.000    0.004    0.000 {cPickle.loads}
      100    0.002    0.000    0.004    0.000 serializers.py:259(dump_stream)
      200    0.001    0.000    0.006    0.000 serializers.py:155(_read_with_length)
      300    0.001    0.000    0.001    0.000 {method 'read' of 'file' objects}
      200    0.001    0.000    0.006    0.000 serializers.py:136(load_stream)
      200    0.000    0.000    0.001    0.000 serializers.py:542(read_int)
      100    0.000    0.000    0.000    0.000 {cPickle.dumps}
      100    0.000    0.000    0.013    0.000 worker.py:104(process)
      100    0.000    0.000    0.001    0.000 serializers.py:217(load_stream)
      200    0.000    0.000    0.007    0.000 {sum}
      100    0.000    0.000    0.001    0.000 <ipython-input-45-020d4d54

In [47]:
rdd2.count()
sc.show_profiles()

Profile of RDD<id=12>
         288 function calls (284 primitive calls) in 0.496 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.166    0.083    0.166    0.083 {cPickle.loads}
        6    0.159    0.027    0.159    0.027 {method 'read' of 'file' objects}
      102    0.119    0.001    0.495    0.005 rdd.py:1004(<genexpr>)
        4    0.050    0.013    0.376    0.094 serializers.py:136(load_stream)
        4    0.001    0.000    0.496    0.124 {sum}
        2    0.000    0.000    0.000    0.000 serializers.py:259(dump_stream)
        4    0.000    0.000    0.325    0.081 serializers.py:155(_read_with_length)
        2    0.000    0.000    0.000    0.000 serializers.py:549(write_int)
      6/2    0.000    0.000    0.496    0.248 rdd.py:2345(pipeline_func)
        2    0.000    0.000    0.496    0.248 rdd.py:1004(<lambda>)
        2    0.000    0.000    0.496    0.248 worker.py:104(process)
  

In [64]:
import hickle

In [66]:
sco = cStringIO.StringIO()

In [75]:
%%timeit 
hickle.dump(arrays, sco)
sco.reset()

<type 'cStringIO.StringO'>


FileError: Cannot open file. Please pass either a filename string, a file object, or a h5py.File

In [76]:
ipdb.pm()

> [0;32m/Users/rok/miniconda/lib/python2.7/site-packages/hickle.py[0m(163)[0;36mfile_opener[0;34m()[0m
[0;32m    161 [0;31m    [0;32melse[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m    162 [0;31m        [0;32mprint[0m[0;34m([0m[0mtype[0m[0;34m([0m[0mf[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m--> 163 [0;31m        [0;32mraise[0m [0mFileError[0m[0;34m[0m[0m
[0m[0;32m    164 [0;31m[0;34m[0m[0m
[0m[0;32m    165 [0;31m    [0mh5f[0m[0;34m.[0m[0m__class__[0m [0;34m=[0m [0mH5FileWrapper[0m[0;34m[0m[0m
[0m
ipdb> u
> [0;32m/Users/rok/miniconda/lib/python2.7/site-packages/hickle.py[0m(308)[0;36mdump[0;34m()[0m
[0;32m    306 [0;31m    [0;32mtry[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m    307 [0;31m        [0;31m# Open the file[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 308 [0;31m        [0mh5f[0m [0;34m=[0m [0mfile_opener[0m[0;34m([0m[0mfile_obj[0m[0;34m,[0m [0mmode[0m[0;34m,[0m [0mtrack_times[0m[0;34m)

In [80]:
import marshal

In [81]:
%timeit x = marshal.dumps(arrays)

10 loops, best of 3: 47.7 ms per loop


In [82]:
%timeit x = cPickle.dumps(arrays, 2)

10 loops, best of 3: 28.2 ms per loop


In [84]:
%timeit marshal.dumps(arrays[0])

10000 loops, best of 3: 38.8 µs per loop


In [86]:
%timeit cPickle.dumps(arrays[0], 2)

100000 loops, best of 3: 14 µs per loop


In [87]:
a = arrays[0]

In [89]:
buff = np.getbuffer(a)

In [99]:
%timeit a.tobytes()

The slowest run took 5.07 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 988 ns per loop
