add serialization benchmark & forking pickler
ppwwyyxx committed Nov 10, 2019
1 parent 23ab700 commit 6166340
Showing 3 changed files with 116 additions and 1 deletion.
examples/FasterRCNN/dataset/dataset.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ def training_roidbs(self):
             boxes: numpy array of kx4 floats, each row is [x1, y1, x2, y2]
             class: numpy array of k integers, in the range of [1, #categories], NOT [0, #categories)
             is_crowd: k booleans. Use k False if you don't know what it means.
-            segmentation: k lists of numpy arrays (one for each instance).
+            segmentation: k lists of numpy arrays.
                 Each list of numpy arrays corresponds to the mask for one instance.
                 Each numpy array in the list is a polygon of shape Nx2,
                 because one mask can be represented by N polygons.
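For reference, here is a minimal sketch (not part of the commit) of one roidb entry in the format the docstring above describes; the field names come from the docstring, the values and dtypes are illustrative:

import numpy as np

roidb_entry = {
    # kx4 floats, each row [x1, y1, x2, y2]
    "boxes": np.array([[10., 20., 110., 220.]], dtype="float32"),
    # k integers in [1, #categories]
    "class": np.array([1], dtype="int32"),
    # k booleans
    "is_crowd": [False],
    # k lists of Nx2 polygon arrays; one inner list per instance
    "segmentation": [[np.array([[10, 20], [110, 20], [60, 220]], dtype="float32")]],
}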
tensorpack/utils/serialize.py (17 changes: 17 additions & 0 deletions)
@@ -4,6 +4,7 @@
import os

import pickle
+from multiprocessing.reduction import ForkingPickler
import msgpack
import msgpack_numpy

@@ -92,6 +93,7 @@ def loads(buf):
        return pickle.loads(buf)


+# Define the default serializer to be used that dumps data to bytes
_DEFAULT_S = os.environ.get('TENSORPACK_SERIALIZE', 'msgpack')

if _DEFAULT_S == "pyarrow":
@@ -103,3 +105,18 @@ def loads(buf):
else:
    dumps = MsgpackSerializer.dumps
    loads = MsgpackSerializer.loads

+# Define the default serializer to be used for passing data
+# among a pair of peers. In this case the deserialization is
+# known to happen only once
+_DEFAULT_S = os.environ.get('TENSORPACK_ONCE_SERIALIZE', 'pickle')
+
+if _DEFAULT_S == "pyarrow":
+    dumps_once = PyarrowSerializer.dumps
+    loads_once = PyarrowSerializer.loads
+elif _DEFAULT_S == "pickle":
+    dumps_once = ForkingPickler.dumps
+    loads_once = ForkingPickler.loads
+else:
+    dumps_once = MsgpackSerializer.dumps
+    loads_once = MsgpackSerializer.loads
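As a quick illustration (not from the commit) of how the new pair behaves: with the default TENSORPACK_ONCE_SERIALIZE=pickle, dumps_once is ForkingPickler.dumps, which returns a bytes-like buffer, and loads_once is plain pickle.loads, so a round-trip looks like:

import numpy as np
from tensorpack.utils.serialize import dumps_once, loads_once

data = {"img": np.zeros((4, 4), dtype="float32"), "label": 3}
buf = dumps_once(data)      # bytes-like buffer (a memoryview under ForkingPickler)
restored = loads_once(buf)  # by design, deserialization happens only once per buffer
assert restored["label"] == 3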
tests/benchmark-serializer.py (98 changes: 98 additions & 0 deletions)
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

import numpy as np
import argparse
import pyarrow as pa
from tabulate import tabulate
import operator
from tensorpack.utils import logger
from tensorpack.utils.serialize import (
MsgpackSerializer,
PyarrowSerializer,
PickleSerializer,
ForkingPickler,
)
from tensorpack.utils.timer import Timer


def benchmark_serializer(dumps, loads, data, num):
    # Warm-up run, excluded from the timers.
    buf = dumps(data)

    enc_timer = Timer()
    dec_timer = Timer()
    enc_timer.pause()
    dec_timer.pause()

    for k in range(num):
        enc_timer.resume()
        buf = dumps(data)
        enc_timer.pause()

        dec_timer.resume()
        loads(buf)
        dec_timer.pause()

    dumps_time = enc_timer.seconds() / num
    loads_time = dec_timer.seconds() / num
    return dumps_time, loads_time


def display_results(name, results):
    logger.info("Encoding benchmark for {}:".format(name))
    data = sorted([(x, y[0]) for x, y in results], key=operator.itemgetter(1))
    print(tabulate(data, floatfmt='.5f'))

    logger.info("Decoding benchmark for {}:".format(name))
    data = sorted([(x, y[1]) for x, y in results], key=operator.itemgetter(1))
    print(tabulate(data, floatfmt='.5f'))


def benchmark_all(name, serializers, data, num=30):
    logger.info("Benchmarking {} ...".format(name))
    results = []
    for serializer_name, dumps, loads in serializers:
        results.append((serializer_name, benchmark_serializer(dumps, loads, data, num=num)))
    display_results(name, results)


def fake_json_data():
    return {
        'words': """
            Lorem ipsum dolor sit amet, consectetur adipiscing
            elit. Mauris adipiscing adipiscing placerat.
            Vestibulum augue augue,
            pellentesque quis sollicitudin id, adipiscing.
        """ * 100,
        'list': list(range(100)) * 500,
        'dict': dict((str(i), 'a') for i in range(50000)),
        'dict2': dict((i, 'a') for i in range(50000)),
        'int': 3000,
        'float': 100.123456
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("task")
    args = parser.parse_args()

    serializers = [
        ("msgpack", MsgpackSerializer.dumps, MsgpackSerializer.loads),
        ("pyarrow-buf", PyarrowSerializer.dumps, PyarrowSerializer.loads),
        ("pyarrow-bytes", PyarrowSerializer.dumps_bytes, PyarrowSerializer.loads),
        ("pickle", PickleSerializer.dumps, PickleSerializer.loads),
        ("forking-pickle", ForkingPickler.dumps, ForkingPickler.loads),
    ]

    if args.task == "numpy":
        numpy_data = [np.random.rand(64, 224, 224, 3).astype("float32"), np.random.rand(64).astype('int32')]
        benchmark_all("numpy data", serializers, numpy_data)
    elif args.task == "json":
        benchmark_all("json data", serializers, fake_json_data(), num=50)
    elif args.task == "torch":
        import torch
        from pyarrow.lib import _default_serialization_context

        pa.register_torch_serialization_handlers(_default_serialization_context)
        torch_data = [torch.rand(64, 224, 224, 3), torch.rand(64).to(dtype=torch.int32)]
        # Skip msgpack (serializers[0]), which cannot handle torch tensors.
        benchmark_all("torch data", serializers[1:], torch_data)
