diff --git a/datastream/backends/mongodb.py b/datastream/backends/mongodb.py
index 5739c09..9c9b27d 100644
--- a/datastream/backends/mongodb.py
+++ b/datastream/backends/mongodb.py
@@ -32,6 +32,8 @@
 
 DECIMAL_PRECISION = 64
 
+MAINTENANCE_LOCK_DURATION = 120
+
 
 def deserialize_numeric_value(value):
     """
@@ -946,6 +948,7 @@ class Stream(mongoengine.Document):
         choices=[downsampler.name for downsampler in ValueDownsamplers.values],
     ))
     downsample_state = mongoengine.MapField(mongoengine.EmbeddedDocumentField(DownsampleState))
+    downsample_count = mongoengine.IntField(default=0)
     highest_granularity = GranularityField()
     derived_from = mongoengine.EmbeddedDocumentField(DerivedStreamDescriptor)
     derive_state = mongoengine.DynamicField()
@@ -955,6 +958,9 @@ class Stream(mongoengine.Document):
     latest_datapoint = mongoengine.DateTimeField()
     tags = mongoengine.DictField()
 
+    # Maintenance operations lock
+    _lock_mt = mongoengine.DateTimeField(default=datetime.datetime.min)
+
     meta = dict(
         db_alias=DATABASE_ALIAS,
         collection='streams',
@@ -1746,22 +1752,42 @@ def _downsample_check(self, stream, until_timestamp, return_datapoints):
 
         new_datapoints = []
 
-        # Last datapoint timestamp of one higher granularity
-        higher_last_ts = stream.latest_datapoint or self._min_timestamp
+        # Lock the stream for downsampling
+        now = datetime.datetime.utcnow()
+        locked_until = now + datetime.timedelta(seconds=MAINTENANCE_LOCK_DURATION)
+        locked_stream = Stream._get_collection().find_and_modify(
+            {"_id": stream.pk, "_lock_mt": {"$lt": now}, "downsample_count": stream.downsample_count},
+            {"$set": {"_lock_mt": locked_until}, "$inc": {"downsample_count": 1}}
+        )
+        if not locked_stream:
+            # Skip downsampling of this stream as we have failed to acquire the lock
+            return new_datapoints
 
-        for granularity in api.Granularity.values[api.Granularity.values.index(stream.highest_granularity) + 1:]:
-            state = stream.downsample_state.get(granularity.name, None)
-            rounded_timestamp = granularity.round_timestamp(min(until_timestamp, higher_last_ts))
-            # TODO: Why "can't compare offset-naive and offset-aware datetimes" is sometimes thrown here?
-            if state is None or state.timestamp is None or rounded_timestamp > state.timestamp:
-                try:
-                    result = self._downsample(stream, granularity, rounded_timestamp, return_datapoints)
-                    if return_datapoints:
-                        new_datapoints += result
-                except exceptions.InvalidTimestamp:
-                    break
-
-            higher_last_ts = stream.downsample_state[granularity.name].timestamp or higher_last_ts
+        try:
+            # Last datapoint timestamp of one higher granularity
+            higher_last_ts = stream.latest_datapoint or self._min_timestamp
+
+            for granularity in api.Granularity.values[api.Granularity.values.index(stream.highest_granularity) + 1:]:
+                state = stream.downsample_state.get(granularity.name, None)
+                rounded_timestamp = granularity.round_timestamp(min(until_timestamp, higher_last_ts))
+                # TODO: Why "can't compare offset-naive and offset-aware datetimes" is sometimes thrown here?
+                if state is None or state.timestamp is None or rounded_timestamp > state.timestamp:
+                    try:
+                        result, locked_until = self._downsample(stream, granularity, rounded_timestamp, return_datapoints, locked_until)
+                        if return_datapoints:
+                            new_datapoints += result
+                    except exceptions.InvalidTimestamp:
+                        break
+
+                higher_last_ts = stream.downsample_state[granularity.name].timestamp or higher_last_ts
+        except:
+            # Only unlock the stream but do not save the descriptor as it might be corrupted
+            Stream.objects(pk=stream.pk).update(set___lock_mt=datetime.datetime.min)
+            raise
+        else:
+            # Ensure that the stream is unlocked and all changes are saved
+            stream._lock_mt = datetime.datetime.min
+            stream.save()
 
         return new_datapoints
 
@@ -1811,7 +1837,7 @@ def _generate_timed_stream_object_id(self, timestamp, stream_id):
         oid += stream_id
         return objectid.ObjectId(oid)
 
-    def _downsample(self, stream, granularity, until_timestamp, return_datapoints):
+    def _downsample(self, stream, granularity, until_timestamp, return_datapoints, locked_until):
         """
         Performs downsampling on the given stream and granularity.
 
@@ -1820,6 +1846,7 @@ def _downsample(self, stream, granularity, until_timestamp, return_datapoints):
         :param until_timestamp: Timestamp until which to downsample, not including datapoints
                                 at a timestamp, rounded to the specified granularity
         :param return_datapoints: Should the added datapoints be stored
+        :param locked_until: Timestamp when the maintenance lock on this datastream expires
        """
 
        assert granularity.round_timestamp(until_timestamp) == until_timestamp
@@ -1873,7 +1900,16 @@ def _downsample(self, stream, granularity, until_timestamp, return_datapoints):
 
         new_datapoints = []
 
-        def store_downsampled_datapoint(timestamp):
+        def store_downsampled_datapoint(timestamp, locked_until):
+            # Check if we need to lengthen the lock
+            now = datetime.datetime.utcnow()
+            if locked_until < now:
+                # Lock has expired while we were processing; abort immediately
+                raise exceptions.LockExpiredMidMaintenance
+            elif (locked_until - now).total_seconds() <= MAINTENANCE_LOCK_DURATION // 2:
+                locked_until = now + datetime.timedelta(seconds=MAINTENANCE_LOCK_DURATION)
+                Stream.objects(pk=stream.pk).update(set___lock_mt=locked_until)
+
             value = {}
             time = {}
             for x in value_downsamplers:
@@ -1892,9 +1928,8 @@ def store_downsampled_datapoint(timestamp):
             point_id = self._generate_timed_stream_object_id(timestamp, stream_id)
             datapoint = {'_id': point_id, 'm': stream.id, 'v': value, 't': time}
 
-            # We want to process each granularity period only once, we want it to fail if there is an error in this
             # TODO: We should probably create some API function which reprocesses everything and fixes any inconsistencies
-            downsampled_points.insert(datapoint, w=1)
+            downsampled_points.update({'_id': point_id}, datapoint, upsert=True, w=1)
 
             # Process contributions to other streams
             self._process_contributes_to(stream, datapoint['t'], value, granularity)
@@ -1906,6 +1941,8 @@ def store_downsampled_datapoint(timestamp):
                     'datapoint': self._format_datapoint(datapoint),
                 })
 
+            return locked_until
+
         # TODO: Use generator here, not concatenation
         for x in value_downsamplers + time_downsamplers:
             x.initialize()
@@ -1927,7 +1964,7 @@ def store_downsampled_datapoint(timestamp):
               current_granularity_period_timestamp != new_granularity_period_timestamp:
                 # All datapoints for current granularity period have been processed, we store the new datapoint
                 # This happens when interval ("datapoints" query) contains multiple not-yet-downsampled granularity periods
-                store_downsampled_datapoint(current_granularity_period_timestamp)
+                locked_until = store_downsampled_datapoint(current_granularity_period_timestamp, locked_until)
 
             current_granularity_period_timestamp = new_granularity_period_timestamp
 
@@ -1938,7 +1975,7 @@ def store_downsampled_datapoint(timestamp):
                 for x in time_downsamplers:
                     x.update(middle_timestamp(current_null_bucket, granularity), None)
 
-                store_downsampled_datapoint(current_null_bucket)
+                locked_until = store_downsampled_datapoint(current_null_bucket, locked_until)
 
                 # Move to next bucket
                 current_null_bucket = granularity.round_timestamp(
@@ -1958,7 +1995,7 @@ def store_downsampled_datapoint(timestamp):
                 x.update(ts, datapoint['v'])
 
         if current_granularity_period_timestamp is not None:
-            store_downsampled_datapoint(current_granularity_period_timestamp)
+            locked_until = store_downsampled_datapoint(current_granularity_period_timestamp, locked_until)
 
         # Insert NULL values into empty buckets
         while current_null_bucket < until_timestamp:
@@ -1967,7 +2004,7 @@ def store_downsampled_datapoint(timestamp):
             for x in time_downsamplers:
                 x.update(middle_timestamp(current_null_bucket, granularity), None)
 
-            store_downsampled_datapoint(current_null_bucket)
+            locked_until = store_downsampled_datapoint(current_null_bucket, locked_until)
 
             # Move to next bucket
             current_null_bucket = granularity.round_timestamp(
@@ -1976,14 +2013,13 @@ def store_downsampled_datapoint(timestamp):
 
         # At the end, update the timestamp until which we have processed everything
         state.timestamp = until_timestamp
-        stream.save()
 
         # And call test callback for all new datapoints
         if self._test_callback is not None:
             for kwargs in new_datapoints:
                 self._test_callback(**kwargs)
 
-        return new_datapoints
+        return new_datapoints, locked_until
 
     def _backprocess_stream(self, stream):
         """
diff --git a/datastream/exceptions.py b/datastream/exceptions.py
index a0c40da..6bc1865 100644
--- a/datastream/exceptions.py
+++ b/datastream/exceptions.py
@@ -106,6 +106,13 @@ class InvalidOperatorArguments(DatastreamException, ValueError):
     pass
 
 
+class LockExpiredMidMaintenance(DatastreamException):
+    """
+    Raised when a maintenance lock expires inside a maintenance operation.
+    """
+    pass
+
+
 class DatastreamWarning(RuntimeWarning):
     """
     The base class for all datastream API runtime warnings.
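Note on the locking scheme above: the acquisition step can be read independently of mongoengine. The sketch below is a minimal standalone illustration, not code from the patch; it assumes "streams" is a pymongo 2.x Collection handle for the streams collection, and the helper name try_lock_stream and its stream_doc parameter are invented for this example. The field names (_lock_mt, downsample_count) and MAINTENANCE_LOCK_DURATION mirror the patch.

# --- Illustrative sketch (not part of the patch): atomic lock acquisition ---
import datetime

MAINTENANCE_LOCK_DURATION = 120  # seconds, same value the patch introduces

def try_lock_stream(streams, stream_doc):
    # The filter only matches when the previous lock has expired and downsample_count
    # still equals the value in the descriptor we read, so exactly one worker wins.
    now = datetime.datetime.utcnow()
    locked_until = now + datetime.timedelta(seconds=MAINTENANCE_LOCK_DURATION)
    claimed = streams.find_and_modify(
        {
            '_id': stream_doc['_id'],
            '_lock_mt': {'$lt': now},
            'downsample_count': stream_doc.get('downsample_count', 0),
        },
        {'$set': {'_lock_mt': locked_until}, '$inc': {'downsample_count': 1}},
    )
    # find_and_modify returns the matched (pre-update) document, or None when another
    # worker holds the lock or has bumped the counter since the descriptor was loaded.
    return locked_until if claimed else None

Because _lock_mt defaults to datetime.datetime.min, a never-locked stream always passes the $lt check, and an expired lock can simply be stolen without a separate cleanup pass.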
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 015c304..a3947e7 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -1,9 +1,12 @@
+import collections
 import datetime
 import decimal
 import random
 import time
 import unittest
 import warnings
+import sys
+import threading
 
 import pytz
 
@@ -1035,6 +1038,41 @@ def test_downsample_freeze(self):
         self.datastream.append(stream_id, 1, datetime.datetime(2000, 1, 10, 12, 0, 0))
         self.datastream.append(stream_id, 1, datetime.datetime(2000, 1, 10, 12, 0, 1))
 
+    def test_concurrent(self):
+        for i in xrange(10):
+            stream_id = self.datastream.ensure_stream({'name': i}, {}, self.value_downsamplers, datastream.Granularity.Seconds)
+            ts = datetime.datetime(2000, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
+            for j in xrange(1000):
+                self.datastream.append(stream_id, 1, ts)
+                ts += datetime.timedelta(seconds=4)
+
+        def worker(results):
+            try:
+                datapoints = self.datastream.downsample_streams(return_datapoints=True)
+                results.append(len(datapoints))
+            except:
+                results.append(sys.exc_info())
+
+        threads = []
+        results = collections.deque()
+        for i in xrange(5):
+            t = threading.Thread(target=worker, args=(results,))
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        if results:
+            all_datapoints = 0
+            for result in results:
+                if isinstance(result, int):
+                    all_datapoints += result
+                else:
+                    raise result[1], None, result[2]
+
+            self.assertEqual(all_datapoints, 4720)
+
     def test_downsamplers(self):
         # Test with floats that have issues with exact representation
         stream_id = self.datastream.ensure_stream({'name': 'small'}, {}, self.value_downsamplers, datastream.Granularity.Seconds)
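The renew-or-abort check that store_downsampled_datapoint now performs before every write is the other half of the scheme; together with the acquisition step it is what lets test_concurrent sum the datapoint counts reported by five competing workers and still assert one fixed total. Below is an illustrative restatement outside mongoengine; refresh_lock, streams and stream_id are invented names, and the local exception class merely stands in for datastream.exceptions.LockExpiredMidMaintenance.

# --- Illustrative sketch (not part of the patch): renew the lease or abort ---
import datetime

MAINTENANCE_LOCK_DURATION = 120  # seconds

class LockExpiredMidMaintenance(Exception):
    """Stand-in for datastream.exceptions.LockExpiredMidMaintenance."""

def refresh_lock(streams, stream_id, locked_until):
    now = datetime.datetime.utcnow()
    if locked_until < now:
        # The lease lapsed mid-operation; another worker may already own the stream,
        # so stop writing immediately instead of risking conflicting updates.
        raise LockExpiredMidMaintenance()
    if (locked_until - now).total_seconds() <= MAINTENANCE_LOCK_DURATION // 2:
        # Less than half the lease remains, so push the expiry forward before continuing.
        locked_until = now + datetime.timedelta(seconds=MAINTENANCE_LOCK_DURATION)
        streams.update({'_id': stream_id}, {'$set': {'_lock_mt': locked_until}})
    return locked_until

The patch's switch from insert to an upsert-style update on downsampled_points appears to serve the same goal: a worker that re-processes a granularity period after a lock handover overwrites the bucket instead of failing on a duplicate _id.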