
Commit

Merge pull request #64 from svenkreiss/retries
Retries
svenkreiss committed May 18, 2017
2 parents 7a386ae + 64392f0 commit c094fad
Showing 4 changed files with 137 additions and 72 deletions.
109 changes: 79 additions & 30 deletions pysparkling/context.py
@@ -9,6 +9,7 @@
import pickle
import struct
import time
import traceback

from . import __version__ as PYSPARKLING_VERSION
from .broadcast import Broadcast
@@ -26,11 +27,39 @@ def unit_fn(arg):
return arg


def _run_task(task_context, rdd, func, partition):
    """Run a task, aka compute a partition.
    :param TaskContext task_context: this task context
    :param RDD rdd: rdd this partition is a part of
    :param func: a function
    :param Partition partition: partition to process
    """
    task_context.attempt_number += 1

    try:
        return func(task_context, rdd.compute(partition, task_context))
    except Exception:
        log.warn('Attempt {} failed for partition {} of {}.'
                 ''.format(task_context.attempt_number, partition.index,
                           rdd.name()))
        traceback.print_exc()
        if task_context.attempt_number == task_context.max_retries:
            log.error('Partition {} of {} failed.'
                      ''.format(partition.index, rdd.name()))
            return []

        if task_context.retry_wait:
            time.sleep(task_context.retry_wait)
        return _run_task(task_context, rdd, func, partition)
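
For readers skimming the diff: the recursion in _run_task amounts to a bounded retry loop with an optional sleep between attempts, falling back to an empty result once max_retries attempts are used up. A minimal standalone sketch of the same control flow (run_with_retries and the compute callable are illustrative names, not part of pysparkling):

import time


def run_with_retries(compute, max_retries=3, retry_wait=0.0):
    """Call compute() until it succeeds or max_retries attempts are exhausted."""
    for attempt in range(1, max_retries + 1):
        try:
            return compute()
        except Exception:
            if attempt == max_retries:
                # give up, mirroring _run_task returning an empty partition result
                return []
            if retry_wait:
                time.sleep(retry_wait)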


def runJob_map(i):
t_start = time.clock()
(deserializer, data_serializer, data_deserializer,
serialized, serialized_data, cache_manager) = i
serialized_func_rdd, serialized_task_context,
serialized_data, cache_manager) = i

t_start = time.clock()
if cache_manager:
if not CacheManager.singleton__:
CacheManager.singleton__ = data_deserializer(cache_manager)
@@ -39,25 +68,33 @@ def runJob_map(i):
data_deserializer(cache_manager).cache_obj
)
cm_state = CacheManager.singleton().stored_idents()
t_cache_init = time.clock()
t_cache_init = time.clock() - t_start

t_start = time.clock()
func, rdd = deserializer(serialized_func_rdd)
t_deserialize_func = time.clock() - t_start

func, rdd = deserializer(serialized)
t_deserialize_func = time.clock()
t_start = time.clock()
partition = data_deserializer(serialized_data)
t_deserialize_data = time.clock()
t_deserialize_data = time.clock() - t_start

task_context = TaskContext(stage_id=0, partition_id=partition.index)
result = func(task_context, rdd.compute(partition, task_context))
t_exec = time.clock()
t_start = time.clock()
task_context = deserializer(serialized_task_context)
t_deserialize_task_context = time.clock() - t_start

t_start = time.clock()
result = _run_task(task_context, rdd, func, partition)
t_exec = time.clock() - t_start

return data_serializer((
result,
CacheManager.singleton().get_not_in(cm_state),
{
'map_cache_init': t_cache_init - t_start,
'map_deserialize_func': t_deserialize_func - t_cache_init,
'map_deserialize_data': t_deserialize_data - t_deserialize_func,
'map_exec': t_exec - t_deserialize_data,
'map_cache_init': t_cache_init,
'map_deserialize_func': t_deserialize_func,
'map_deserialize_task_context': t_deserialize_task_context,
'map_deserialize_data': t_deserialize_data,
'map_exec': t_exec,
}
))
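
The timing bookkeeping in runJob_map also changes here: instead of recording cumulative timestamps and subtracting neighbouring values when building the stats dict, each phase is now timed individually and the durations are reported directly. A small sketch of the new pattern; time.perf_counter() is used as a modern stand-in for the time.clock() calls in this 2017 code (time.clock() was removed in Python 3.8):

import time

t_start = time.perf_counter()
result = sum(x * x for x in range(10000))   # stand-in for the actual phase of work
t_exec = time.perf_counter() - t_start      # a duration, reported directly, e.g. as 'map_exec'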

@@ -68,29 +105,25 @@ class Context(object):
The variable `_stats` contains measured timing information about data and
function (de)serialization and workload execution to benchmark your jobs.
:param pool:
An instance with a ``map(func, iterable)`` method.
:param pool: An instance with a ``map(func, iterable)`` method.
:param serializer:
Serializer for functions. Examples are ``pickle.dumps`` and
``dill.dumps``.
:param deserializer:
Deserializer for functions. Examples are ``pickle.loads`` and
``dill.loads``.
:param data_serializer:
Serializer for the data.
:param data_deserializer:
Deserializer for the data.
:param data_serializer: Serializer for the data.
:param data_deserializer: Deserializer for the data.
:param int max_retries: maximum number of times a partition is retried
:param float retry_wait: seconds to wait between retries
"""

__last_rdd_id = 0

def __init__(self, pool=None, serializer=None, deserializer=None,
data_serializer=None, data_deserializer=None):
data_serializer=None, data_deserializer=None,
max_retries=3, retry_wait=0.0):
if not pool:
pool = DummyPool()
if not serializer:
@@ -101,6 +134,8 @@ def __init__(self, pool=None, serializer=None, deserializer=None,
data_serializer = unit_fn
if not data_deserializer:
data_deserializer = unit_fn
self.max_retries = max_retries
self.retry_wait = retry_wait

self._pool = pool
self._serializer = serializer
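
A usage sketch for the new constructor arguments; the values below are arbitrary and the mapped function never fails (see test_retry further down for a job that actually exercises the retry path):

import pysparkling

sc = pysparkling.Context(max_retries=5, retry_wait=0.1)
rdd = sc.parallelize(range(10), 2)
print(rdd.map(lambda x: x * x).collect())
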
@@ -241,27 +276,29 @@ def runJob(self, rdd, func, partitions=None, allowLocal=False,
map_result = self._runJob_local(rdd, func, partitions)
else:
map_result = self._runJob_distributed(rdd, func, partitions)
log.debug('Map jobs generated.')

if resultHandler is not None:
return resultHandler(map_result)

return list(map_result) # convert to list to execute on all partitions

def _runJob_local(self, rdd, func, partitions):
for partition in partitions:
task_context = TaskContext(
stage_id=0,
partition_id=partition.index,
max_retries=self.max_retries,
retry_wait=self.retry_wait,
)
yield func(task_context, rdd.compute(partition, task_context))
yield _run_task(task_context, rdd, func, partition)

def _runJob_distributed(self, rdd, func, partitions):
cm = CacheManager.singleton()
serialized_func_rdd = self._serializer((func, rdd))

def prepare(p):
def prepare(partition):
t_start = time.clock()
cm_clone = cm.clone_contains(lambda i: i[1] == p.index)
cm_clone = cm.clone_contains(lambda i: i[1] == partition.index)
self._stats['driver_cache_clone'] += (time.clock() -
t_start)

@@ -271,7 +308,18 @@ def prepare(p):
t_start)

t_start = time.clock()
serialized_p = self._data_deserializer(p)
task_context = TaskContext(
stage_id=0,
partition_id=partition.index,
max_retries=self.max_retries,
retry_wait=self.retry_wait,
)
serialized_task_context = self._serializer(task_context)
self._stats['driver_serialize_task_context'] += (time.clock() -
t_start)

t_start = time.clock()
serialized_partition = self._data_deserializer(partition)
self._stats['driver_serialize_data'] += (time.clock() -
t_start)

Expand All @@ -280,7 +328,8 @@ def prepare(p):
self._data_serializer,
self._data_deserializer,
serialized_func_rdd,
serialized_p,
serialized_task_context,
serialized_partition,
cm_serialized,
)

6 changes: 5 additions & 1 deletion pysparkling/task_context.py
@@ -4,12 +4,16 @@


class TaskContext(object):
def __init__(self, stage_id=0, partition_id=0):
def __init__(self, stage_id=0, partition_id=0,
max_retries=3, retry_wait=0):
log.debug('Running stage {0} for partition {1}'
''.format(stage_id, partition_id))

self.stage_id = stage_id
self.partition_id = partition_id
self.max_retries = max_retries
self.retry_wait = retry_wait

self.attempt_number = 0
self.is_completed = False
self.is_running_locally = True
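
The retry settings now travel with the task context itself, so a worker can consult them without access to the driver-side Context. A brief sketch of constructing one directly (normally the Context creates these for you):

from pysparkling.task_context import TaskContext

tc = TaskContext(stage_id=0, partition_id=3, max_retries=5, retry_wait=0.5)
assert tc.attempt_number == 0   # incremented by _run_task on every attempt
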
53 changes: 53 additions & 0 deletions tests/test_context.py
@@ -0,0 +1,53 @@
from __future__ import print_function

import pysparkling
import unittest


class Context(unittest.TestCase):
    def test_broadcast(self):
        b = pysparkling.Context().broadcast([1, 2, 3])
        self.assertEqual(b.value[0], 1)

    def test_parallelize_single_element(self):
        my_rdd = pysparkling.Context().parallelize([7], 100)
        self.assertEqual(my_rdd.collect(), [7])

    def test_parallelize_matched_elements(self):
        my_rdd = pysparkling.Context().parallelize([1, 2, 3, 4, 5], 5)
        self.assertEqual(my_rdd.collect(), [1, 2, 3, 4, 5])

    def test_parallelize_empty_partitions_at_end(self):
        my_rdd = pysparkling.Context().parallelize(range(3529), 500)
        print(my_rdd.getNumPartitions())
        my_rdd.foreachPartition(lambda p: print(sum(1 for _ in p)))
        self.assertEqual(my_rdd.getNumPartitions(), 500)
        self.assertEqual(my_rdd.count(), 3529)

    def test_retry(self):

        class EverySecondCallFails(object):
            def __init__(self):
                self.attempt = 0

            def __call__(self, value):
                self.attempt += 1
                if self.attempt % 2 == 1:
                    raise Exception
                return value

        data = list(range(6))
        rdd = pysparkling.Context().parallelize(data, 3)
        result = rdd.mapPartitions(EverySecondCallFails()).collect()
        self.assertEqual(result, data)

    def test_union(self):
        sc = pysparkling.Context()
        rdd1 = sc.parallelize(['Hello'])
        rdd2 = sc.parallelize(['World'])
        union = sc.union([rdd1, rdd2]).collect()
        print(union)
        self.assertEqual(union, ['Hello', 'World'])

    def test_version(self):
        self.assertTrue(isinstance(pysparkling.Context().version, str))
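
With the default local pool and max_retries=3, test_retry makes every partition fail on its first attempt and succeed on the retry, so the collected result matches the input. The file can be run with, e.g., python -m pytest tests/test_context.py (assuming pytest is installed; the standard unittest runner works as well).
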
41 changes: 0 additions & 41 deletions tests/test_context_unit.py

This file was deleted.
