timed cache manager (#68)
* timed cache manager

* add CacheManager instance to TaskContext
svenkreiss committed Jul 29, 2017
1 parent ddc0d7d commit e8ed676
Showing 7 changed files with 150 additions and 110 deletions.
8 changes: 5 additions & 3 deletions pysparkling/__init__.py
@@ -1,4 +1,4 @@
"""pysparkling module."""
"""pysparkling module"""
# flake8: noqa

__version__ = '0.4.3'
@@ -10,10 +10,12 @@
from .context import Context
from .broadcast import Broadcast
from .stat_counter import StatCounter
from .cache_manager import CacheManager
from .cache_manager import CacheManager, TimedCacheManager

from . import fileio
from . import streaming

__all__ = ['RDD', 'Context', 'Broadcast', 'StatCounter', 'CacheManager',
__all__ = ['FileAlreadyExistsException', 'ConnectionException',
'RDD', 'Context', 'Broadcast', 'StatCounter', 'CacheManager',
'TimedCacheManager',
'fileio', 'streaming']
72 changes: 54 additions & 18 deletions pysparkling/cache_manager.py
@@ -5,6 +5,7 @@

import logging
import pickle
import time
import zlib

log = logging.getLogger(__name__)
@@ -21,21 +22,8 @@ class CacheManager(object):
:param deserializer: Use to deserialize cache objects.
:param checksum: Function returning a checksum.
"""
singleton__ = None

@staticmethod
def singleton(max_mem=1.0,
serializer=None, deserializer=None,
checksum=None):
if CacheManager.singleton__ is None:
CacheManager.singleton__ = CacheManager(max_mem,
serializer, deserializer,
checksum)
return CacheManager.singleton__

def __init__(self,
max_mem=1.0,
serializer=None, deserializer=None,
def __init__(self, max_mem=1.0, serializer=None, deserializer=None,
checksum=None):
self.max_mem = max_mem
self.serializer = serializer if serializer else pickle.dumps
@@ -103,14 +91,12 @@ def stored_idents(self):
v['disk_location'] is not None)]

def clone_contains(self, filter_id):
"""clone contains
"""Clone the cache manager and add a subset of the cache to it.
:param filter_id:
A function returning true for ids that should be returned.
:returns:
A new CacheManager with the entries that contain partial_ident
in the ident.
:rtype: CacheManager
"""
cm = CacheManager(self.max_mem,
self.serializer, self.deserializer,
@@ -133,3 +119,53 @@ def clear(self):
self.cache_cnt = 0
self.cache_mem_size = 0.0
self.cache_disk_size = 0.0


class TimedCacheManager(CacheManager):
def __init__(self,
max_mem=1.0,
serializer=None, deserializer=None,
checksum=None, timeout=600.0, min_gc_interval=60.0):
super(TimedCacheManager, self).__init__(
max_mem, serializer, deserializer, checksum)

self.timeout = timeout
self.last_gc = time.time()
self.min_gc_interval = min_gc_interval

def add(self, ident, obj, storageLevel=None):
super(TimedCacheManager, self).add(ident, obj, storageLevel)
self.cache_obj[ident]['utc_added_s'] = time.time()
self.gc()

def clone_contains(self, filter_id):
"""Clone the timed cache manager and add a subset of the cache to it.
:param filter_id:
A function returning true for ids that should be returned.
:rtype: TimedCacheManager
"""
cm = TimedCacheManager(self.max_mem,
self.serializer, self.deserializer,
self.checksum,
self.timeout, self.min_gc_interval)
cm.cache_obj = {i: c
for i, c in self.cache_obj.items()
if filter_id(i)}
return cm

def gc(self):
if time.time() - self.min_gc_interval < self.last_gc:
return

log.debug('Looking for timed out cache entries.')
threshold_time = time.time() - self.timeout
timed_out_ids = {ident
for ident, cache_obj in self.cache_obj.items()
if cache_obj['utc_added_s'] < threshold_time}
log.debug('Timed out ids: {}'.format(timed_out_ids))
for id_ in timed_out_ids:
self.delete(id_)
log.debug('Clear done.')
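
The new TimedCacheManager stamps every cached entry with the time it was added and drops entries older than timeout seconds; a sweep is attempted on each add() but skipped while min_gc_interval seconds have not yet elapsed since the last recorded collection. A minimal standalone sketch of that behaviour (the ident tuples and the short timeout are illustrative, and it assumes delete() removes the entry from cache_obj):

import time

from pysparkling import TimedCacheManager

# Illustrative values: entries expire after 2 seconds; a sweep may run
# once 0.5 seconds have passed since the last recorded collection.
cm = TimedCacheManager(timeout=2.0, min_gc_interval=0.5)

cm.add(('rdd-1', 0), [1, 2, 3])   # stored with a 'utc_added_s' timestamp
time.sleep(3.0)                   # let the first entry time out
cm.add(('rdd-1', 1), [4, 5, 6])   # add() triggers gc(), deleting timed-out idents

# Expected to list only ('rdd-1', 1), assuming delete() drops the entry
# from cache_obj.
print(list(cm.cache_obj.keys()))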
66 changes: 21 additions & 45 deletions pysparkling/context.py
@@ -62,18 +62,7 @@ def _run_task(task_context, rdd, func, partition):
def runJob_map(i):
(deserializer, data_serializer, data_deserializer,
serialized_func_rdd, serialized_task_context,
serialized_data, cache_manager) = i

t_start = time.clock()
if cache_manager:
if not CacheManager.singleton__:
CacheManager.singleton__ = data_deserializer(cache_manager)
else:
CacheManager.singleton().join(
data_deserializer(cache_manager).cache_obj
)
cm_state = CacheManager.singleton().stored_idents()
t_cache_init = time.clock() - t_start
serialized_data) = i

t_start = time.clock()
func, rdd = deserializer(serialized_func_rdd)
@@ -85,6 +74,7 @@ def runJob_map(i):

t_start = time.clock()
task_context = deserializer(serialized_task_context)
cm_state = task_context.cache_manager.stored_idents()
t_deserialize_task_context = time.clock() - t_start

t_start = time.clock()
@@ -93,9 +83,8 @@ def runJob_map(i):

return data_serializer((
result,
CacheManager.singleton().get_not_in(cm_state),
task_context.cache_manager.get_not_in(cm_state),
{
'map_cache_init': t_cache_init,
'map_deserialize_func': t_deserialize_func,
'map_deserialize_task_context': t_deserialize_task_context,
'map_deserialize_data': t_deserialize_data,
@@ -112,22 +101,23 @@ class Context(object):
:param pool: An instance with a ``map(func, iterable)`` method.
:param serializer:
Serializer for functions. Examples are ``pickle.dumps`` and
Serializer for functions. Examples are `pickle.dumps` and
``dill.dumps``.
:param deserializer:
Deserializer for functions. Examples are ``pickle.loads`` and
Deserializer for functions. Examples are `pickle.loads` and
``dill.loads``.
:param data_serializer: Serializer for the data.
:param data_deserializer: Deserializer for the data.
:param int max_retries: maximum number of times a partition is retried
:param float retry_wait: seconds to wait between retries
:param cache_manager: custom cache manager (like `TimedCacheManager`)
"""

__last_rdd_id = 0

def __init__(self, pool=None, serializer=None, deserializer=None,
data_serializer=None, data_deserializer=None,
max_retries=3, retry_wait=0.0):
max_retries=3, retry_wait=0.0, cache_manager=None):
if not pool:
pool = DummyPool()
if not serializer:
@@ -141,6 +131,7 @@ def __init__(self, pool=None, serializer=None, deserializer=None,
self.max_retries = max_retries
self.retry_wait = retry_wait

self._cache_manager = cache_manager or CacheManager()
self._pool = pool
self._serializer = serializer
self._deserializer = deserializer
@@ -164,10 +155,9 @@ def parallelize(self, x, numPartitions=None):
:param x:
An iterable (e.g. a list) that represents the data.
:param int|None numPartitions: (optional)
:param int numPartitions:
The number of partitions the data should be split into.
A partition is a unit of data that is processed at a time.
Can be ``None``.
:rtype: RDD
"""
@@ -253,25 +243,15 @@ def runJob(self, rdd, func, partitions=None, allowLocal=False,
if you need everything to be executed, the resultHandler needs to be
at least ``lambda x: list(x)`` to trigger execution of the generators.
:param func:
Map function. The signature is
:param func: Map function with signature
func(TaskContext, Iterator over elements).
:param partitions: (optional)
List of partitions that are involved. Default is ``None``, meaning
:param partitions: List of partitions that are involved. `None` means
the map job is applied to all partitions.
:param allowLocal: (optional)
Allows for local execution. Default is False.
:param resultHandler: (optional)
Process the result from the maps.
:returns:
Result of resultHandler.
:param allowLocal: Allows local execution.
:param resultHandler: Process the result from the maps.
:returns: Result of resultHandler.
:rtype: list
"""

if not partitions:
partitions = rdd.partitions()

@@ -289,6 +269,7 @@ def _runJob_local(self, rdd, func, partitions):
def _runJob_local(self, rdd, func, partitions):
for partition in partitions:
task_context = TaskContext(
cache_manager=self._cache_manager,
stage_id=0,
partition_id=partition.index,
max_retries=self.max_retries,
@@ -297,22 +278,18 @@ def _runJob_local(self, rdd, func, partitions):
yield _run_task(task_context, rdd, func, partition)

def _runJob_distributed(self, rdd, func, partitions):
cm = CacheManager.singleton()
serialized_func_rdd = self._serializer((func, rdd))

def prepare(partition):
t_start = time.clock()
cm_clone = cm.clone_contains(lambda i: i[1] == partition.index)
cm_clone = self._cache_manager.clone_contains(
lambda i: i[1] == partition.index)
self._stats['driver_cache_clone'] += (time.clock() -
t_start)

t_start = time.clock()
cm_serialized = self._data_deserializer(cm_clone)
self._stats['driver_cache_serialize'] += (time.clock() -
t_start)

t_start = time.clock()
task_context = TaskContext(
cache_manager=cm_clone,
stage_id=0,
partition_id=partition.index,
max_retries=self.max_retries,
@@ -334,7 +311,6 @@ def prepare(partition):
serialized_func_rdd,
serialized_task_context,
serialized_partition,
cm_serialized,
)

prepared_partitions = (prepare(p) for p in partitions)
@@ -346,7 +322,7 @@ def prepare(partition):

# join cache
t_start = time.clock()
cm.join(cache_result)
self._cache_manager.join(cache_result)
self._stats['driver_cache_join'] += time.clock() - t_start

# collect stats
@@ -412,7 +388,7 @@ def binaryRecords(self, path, recordLength=None):
and multiple expressions separated by ``,``.
:param recordLength:
If ``None`` every file is a record, ``int`` means fixed length
If `None` every file is a record, ``int`` means fixed length
records and a ``string`` is used as a format string to ``struct``
to read the length of variable length binary records.
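
With the cache manager now owned by the Context and handed to every TaskContext, a timed cache can be plugged in when the context is created. A minimal sketch using the default pool; the RDD contents and the ten-minute timeout are illustrative:

from pysparkling import Context, TimedCacheManager

# Illustrative configuration: cached partitions are dropped roughly
# ten minutes after they were added.
context = Context(cache_manager=TimedCacheManager(timeout=600.0))

rdd = context.parallelize(range(100), 4).map(lambda x: x * x).cache()

print(rdd.collect()[:5])  # first pass computes the partitions and fills the cache
print(rdd.collect()[:5])  # later passes can be served from the cache until entries time out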
