Fix InfluxDB count.
kostko committed Jul 11, 2017
1 parent 600f5c6 commit 9579d33
Showing 2 changed files with 194 additions and 20 deletions.
161 changes: 142 additions & 19 deletions datastream/backends/influxdb.py
@@ -439,6 +439,107 @@ def update(self, src_stream, datapoint):
self.dst_stream.derive_state = datapoint


class Query(object):
"""
Simple InfluxDB query wrapper.
"""

def __init__(self, select, from_, where=None, group_by=None, order_by=None):
self.select = select
self.from_ = from_
self.where = where
self.group_by = group_by
self.order_by = order_by
self.offset = None
self.limit = None

def clone(self):
"""
Clone this query.
"""

return self.__class__(
self.select,
self.from_,
self.where,
self.group_by,
self.order_by
)

def __copy__(self):
return self.clone()

def count(self):
"""
Return a counting query.
"""

return self.__class__(
['COUNT(*)'],
self
)

def slice(self, limit, offset):
"""
Return a slicing query.
"""

clone = self.clone()
clone.limit = limit
clone.offset = offset
return clone

def _from_clause(self):
"""
Return from clause of the query.
"""

if isinstance(self.from_, Query):
from_clause = '( %s )' % self.from_.compile()
else:
from_clause = '"%s"' % self.from_

return from_clause

def compile(self):
"""
Compile query into a string representation.
"""

query = 'SELECT %(select)s FROM %(from_)s %(where)s %(group_by)s %(order_by)s' % {
'select': ', '.join(self.select),
'from_': self._from_clause(),
'where': ('WHERE %s' % self.where) if self.where else '',
'group_by': self.group_by if self.group_by else '',
'order_by': self.order_by if self.order_by else '',
}

if self.limit is not None:
query += ' LIMIT %d' % self.limit
if self.offset is not None:
query += ' OFFSET %d' % self.offset

return query.strip()


class PostgresQuery(Query):
"""
Simple PostgreSQL query wrapper.
"""

def _from_clause(self):
"""
Return from clause of the query.
"""

if isinstance(self.from_, Query):
from_clause = '( %s ) AS tmp' % self.from_.compile()
else:
from_clause = self.from_

return from_clause
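
For reference, a minimal sketch of roughly what these wrappers compile to; the measurement name, filter and parameters below are hypothetical, not taken from this commit:

    query = Query(['value'], 'a1b2c3d4', where="time >= '2000-01-01T12:00:00Z'", order_by='ORDER BY time DESC')
    query.compile()
    # SELECT value FROM "a1b2c3d4" WHERE time >= '2000-01-01T12:00:00Z' ORDER BY time DESC
    query.count().compile()
    # SELECT COUNT(*) FROM ( SELECT value FROM "a1b2c3d4" WHERE ... ORDER BY time DESC )
    query.slice(10, 50).compile()
    # SELECT value FROM "a1b2c3d4" WHERE ... ORDER BY time DESC LIMIT 10 OFFSET 50

    PostgresQuery('*', 'datastream.streams', 'tags @> %s').count().compile()
    # SELECT COUNT(*) FROM ( SELECT * FROM datastream.streams WHERE tags @> %s ) AS tmp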


class ResultSetIteratorMixin(object):
def __init__(self, backend, query):
self._backend = backend
@@ -457,11 +558,32 @@ def _evaluate(self):
if self._query is None:
self._results = []
else:
-            self._results = list(self._influxdb.query(self._query.strip()).get_points())
+            self._results = list(self._influxdb.query(self._query.compile()).get_points())

    def count(self):
-        self._evaluate()
-        return len(self._results)
+        if self._query is None:
+            return 0
+
+        # If results have already been fetched, just count them.
+        if self._results is not None:
+            return len(self._results)
+
+        clone = self._clone()
+        clone._results = None
+        clone._query = clone._query.count()
+        clone._evaluate()
+        if not clone._results:
+            return 0
+
+        results = clone._results[0]
+        if isinstance(results, dict):
+            result_keys = results.keys()
+            if 'count_value' in result_keys:
+                return results.get('count_value', 0) + results.get('count_value_null', 0)
+            else:
+                return max([results[key] for key in results.keys() if key.startswith('count')])
+        elif isinstance(results, tuple):
+            return results[0]
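
A note on the two branches above: influxdb-python's get_points() yields dictionaries, and an InfluxDB COUNT(*) returns one count per field rather than a single scalar, so a counting point may look roughly like this (field names assumed here, presumably the value and value_null fields of a datapoint):

    {u'time': u'1970-01-01T00:00:00Z', u'count_value': 180, u'count_value_null': 20}

The dictionary branch therefore sums the per-field counts, falling back to the largest count* column; the tuple branch covers the PostgreSQL cursor path, where COUNT(*) comes back as a one-element row.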

def batch(self, size):
clone = self._clone()
@@ -501,24 +623,28 @@ def __getitem__(self, key):

        if isinstance(key, slice):
            if clone._query is not None:
+                limit = None
+                offset = None
+
                if key.stop is not None:
                    limit = key.stop
                    if key.start is not None:
                        limit -= key.start

                    assert limit >= 0
-                    clone._query += ' LIMIT %d' % limit

                if key.start is not None:
                    assert key.start >= 0
-                    clone._query += ' OFFSET %d' % key.start
+                    offset = key.start
+
+                clone._query = clone._query.slice(limit, offset)

            clone._evaluate()
            return clone
        elif isinstance(key, (int, long)):
            if clone._query is not None:
                assert key >= 0
-                clone._query += ' LIMIT 1 OFFSET %d' % key
+                clone._query = clone._query.slice(1, key)

            clone._evaluate()
            return list(clone)[0]
@@ -530,6 +656,12 @@ class Streams(ResultSetIteratorMixin, api.Streams):
def __init__(self, backend, query_tags, raw=False):
super(Streams, self).__init__(backend, '')

+        if query_tags:
+            where = 'tags @> %s'
+        else:
+            where = None
+
+        self._query = PostgresQuery('*', 'datastream.streams', where)
self._query_tags = query_tags
self._raw = raw

@@ -540,11 +672,9 @@ def _evaluate(self):
with self._backend._metadata:
with self._backend._metadata.cursor() as cursor:
                if self._query_tags is None:
-                    cursor.execute('SELECT * FROM datastream.streams' + self._query)
+                    cursor.execute(self._query.compile())
                else:
-                    cursor.execute('SELECT * FROM datastream.streams WHERE tags @> %s' + self._query, (
-                        PostgresJson(self._query_tags),
-                    ))
+                    cursor.execute(self._query.compile(), (PostgresJson(self._query_tags),))

self._results = cursor.fetchall()

@@ -1642,16 +1772,9 @@ def add_condition(field, operator, value):
order_by = 'ORDER BY time DESC'

        if where:
-            where = 'WHERE %s' % (' AND '.join(where))
+            where = ' AND '.join(where)
        else:
            where = ''

-        query = 'SELECT %(select)s FROM "%(measurement)s" %(where)s %(group_by)s %(order_by)s' % {
-            'select': ', '.join(select),
-            'measurement': stream.uuid,
-            'where': where,
-            'group_by': group_by,
-            'order_by': order_by,
-        }
-
+        query = Query(select, stream.uuid, where, group_by, order_by)
return Datapoints(self, stream, query, value_downsamplers, time_downsamplers)
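
With get_data() now returning a Query-backed result set, slicing composes through Query.slice() rather than string concatenation on raw SQL; for a hypothetical stream UUID, data[50:70] would compile along the lines of:

    SELECT value FROM "9f1b2c3d" WHERE time >= '2000-01-01T12:00:00Z' ORDER BY time DESC LIMIT 20 OFFSET 50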
53 changes: 52 additions & 1 deletion tests/test_common.py
@@ -326,6 +326,57 @@ def test_basic(self):
with self.assertRaises(exceptions.StreamNotFound):
self.datastream.get_data(stream_id, datastream.Granularity.Minutes, datetime.datetime.utcfromtimestamp(0))

def test_count_limit_offset(self):
stream_id = self.datastream.ensure_stream(
{'name': 'foo'},
{},
self.value_downsamplers,
datastream.Granularity.Seconds
)

points = []
ts = datetime.datetime(2000, 1, 1, 12, 0, 0)
start_timestamp = ts
for i in xrange(200):
points.append({
'stream_id': stream_id,
'value': i,
'timestamp': ts,
})
ts += datetime.timedelta(seconds=1)

self.datastream.append_multiple(points)

data = self.datastream.get_data(
stream_id,
datastream.Granularity.Seconds,
start_timestamp,
ts
)

self.assertEqual(len(data), 200)
self.assertEqual(len(data[:20]), 20)
self.assertEqual(len(data[50:70]), 20)
self.assertEqual(len(data[190:210]), 10)

with self.time_offset():
self.datastream.downsample_streams()

data = self.datastream.get_data(
stream_id,
datastream.Granularity.Seconds10,
start_timestamp,
ts
)

if self.datastream.backend.requires_downsampling:
self.assertEqual(len(data), 19)
else:
self.assertEqual(len(data), 20)

self.assertEqual(len(data[:10]), 10)
self.assertEqual(len(data[5:15]), 10)

def test_granularities(self):
query_tags = {
'name': 'foodata',
@@ -987,7 +1038,7 @@ def test_downsample_freeze(self):
end_timestamp = None

data = self.datastream.get_data(stream_id, self.datastream.Granularity.Seconds10, start=start_timestamp, end_exclusive=end_timestamp)
-        self.assertEqual(len(data), 17280)
+        self.assertEqual(len(list(data)), 17280)
self._test_data_types(data)

self.datastream.append(stream_id, 1, datetime.datetime(2000, 1, 3, 12, 0, 0))
