diff --git a/CHANGELOG.md b/CHANGELOG.md index f71ba12a3..503aeba55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## 2022.3 (TBD) + +### Features (some, not all…) + +New `zentral.core.stores.backends.snowflake` store backend. + ## 2022.2 (August 13, 2022) **IMPORTANT:** The License has changed! Most of the code stays under the Apache license, but some modules, like the SAML authentication, or the Splunk event store are licensed under a new source available license, and require a subscription when used in production. diff --git a/constraints.txt b/constraints.txt index 2a74fc4bb..68699a6cf 100644 --- a/constraints.txt +++ b/constraints.txt @@ -28,7 +28,7 @@ click-didyoumean==0.3.0 click-plugins==1.1.1 click-repl==0.2.0 click==8.1.3 -cryptography==37.0.4 +cryptography==36.0.2 decorator==5.1.1 defusedxml==0.7.1 django-bootstrap-form==3.4 @@ -88,6 +88,7 @@ pyOpenSSL==22.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycparser==2.21 +pycryptodomex==3.15.0 pycurl==7.44.1 pydantic==1.9.1 pylibmc==1.6.1 @@ -98,6 +99,7 @@ python-dateutil==2.8.2 python-ldap==3.4.2 pytz==2022.1 opensearch-py==2.0.0 +oscrypto==1.3.0 redis==4.3.4 requests-oauthlib==1.3.1 requests==2.28.1 @@ -107,11 +109,12 @@ s3transfer==0.6.0 six==1.16.0 sniffio==1.2.0 sqlparse==0.4.2 +snowflake-connector-python==2.7.11 stack-data==0.3.0 tqdm==4.64.0 traitlets==5.3.0 typing_extensions==4.3.0 -urllib3==1.26.10 +urllib3==1.26.11 vine==5.0.0 wcwidth==0.2.5 webauthn==1.6.0 diff --git a/docs/configuration/stores.md b/docs/configuration/stores.md index bfc541a17..dcb0e3b5f 100644 --- a/docs/configuration/stores.md +++ b/docs/configuration/stores.md @@ -31,6 +31,7 @@ The python module implementing the store, as a string. Currently available: * `zentral.core.stores.backends.humio` * `zentral.core.stores.backends.kinesis` * `zentral.core.stores.backends.opensearch` +* `zentral.core.stores.backends.snowflake` * `zentral.core.stores.backends.splunk` * `zentral.core.stores.backends.sumo_logic` * `zentral.core.stores.backends.syslog` @@ -241,6 +242,71 @@ An integer between 1 and 20, 1 by default. The number of threads to use when pos } ``` +## Snowflake backend options + +The Snowflake backend is read-only. It can only be used as a `frontend` backend. To store the events in snowflake, you will have to setup a pipeline using the `Kinesis` backend, and `Kinesis Firehose` for example. + +### `account` + +**MANDATORY** + +The name of the Snowflake account + +### `user` + +**MANDATORY** + +The name of the Snowflake user + +### `password` + +**MANDATORY** + +The password of the Snowflake user + +### `database` + +**MANDATORY** + +The name of the Snowflake database + +### `schema` + +The name of the Snowflake schema. Defaults to `PUBLIC`. + +### `role` + +**MANDATORY** + +The name of the Snowflake role. + +### `warehouse` + +**MANDATORY** + +The name of the Snowflake warehouse. + +### `session_timeout` + +In seconds, the session timeout. After the current session has timed out, a new connection will be established if necessary. Defaults to 4 hours - 10 minutes. + +### Full example + +```json +{ + "backend": "zentral.core.stores.backends.snowflake", + "frontend": true, + "username": "Zentral", + "password": "{{ env:SNOWFLAKE_PASSWORD }}", + "database": "ZENTRAL", + "schema": "ZENTRAL", + "role": "ZENTRAL", + "warehouse": "DEFAULTWH", + "session_timeout": 14400 +} +``` + + ## Splunk backend options ### `hec_url` diff --git a/ee/zentral/core/stores/backends/snowflake.py b/ee/zentral/core/stores/backends/snowflake.py new file mode 100644 index 000000000..0bba06ab3 --- /dev/null +++ b/ee/zentral/core/stores/backends/snowflake.py @@ -0,0 +1,258 @@ +from datetime import timedelta +import json +import logging +import time +from django.utils import timezone +import snowflake.connector +from snowflake.connector import DictCursor +from zentral.core.events import event_from_event_d, event_types +from zentral.core.exceptions import ImproperlyConfigured +from zentral.core.stores.backends.base import BaseEventStore + + +logger = logging.getLogger("zentral.core.stores.backends.snowflake") + + +class EventStore(BaseEventStore): + read_only = True + last_machine_heartbeats = True + machine_events = True + object_events = True + probe_events = True + + def __init__(self, config_d): + super().__init__(config_d) + self._connect_kwargs = {} + # connection parameters + missing_params = [] + for k in ("account", "user", "password", "database", "schema", "role", "warehouse"): + v = config_d.get(k) + if not v: + if k == "schema": + v = "PUBLIC" + else: + missing_params.append(k) + continue + self._connect_kwargs[k] = v + if missing_params: + raise ImproperlyConfigured("Missing configuration parameters: {}".format(", ".join(missing_params))) + # connection + self._connection = None + self._last_active_at = time.monotonic() + self._session_timeout = config_d.get( + "session_timeout", + 4*3600-10*60 # 4 hours (Snowflake default) - 10 min + ) + + def _get_connection(self): + if self._connection is None or (time.monotonic() - self._last_active_at) > self._session_timeout: + account = self._connect_kwargs["account"] + if self._connection is None: + action = "Connect" + else: + logger.info("Close current connection to account %s", account) + self._connection.close() + action = "Re-connect" + logger.info("%s to account %s", action, account) + self._connection = snowflake.connector.connect(**self._connect_kwargs) + self._last_active_at = time.monotonic() + return self._connection + + def _deserialize_event(self, result): + metadata = json.loads(result['METADATA']) + metadata['type'] = result['TYPE'] + metadata['created_at'] = result['CREATED_AT'] + metadata['tags'] = json.loads(result['TAGS']) + metadata['objects'] = {} + for objref in json.loads(result['OBJECTS']): + k, v = objref.split(":", 1) + metadata['objects'].setdefault(k, []).append(v) + metadata['serial_number'] = result['SERIAL_NUMBER'] + event_d = json.loads(result.pop("PAYLOAD")) + event_d['_zentral'] = metadata + return event_from_event_d(event_d) + + def _prepare_query(self, query, args=None, **kwargs): + if args is None: + args = [] + first_filter = True + for attr, filter_tmpl in (("from_dt", "AND created_at >= %s"), + ("to_dt", "AND created_at <= %s"), + ("event_type", "AND type = %s"), + ("objref", "AND ARRAY_CONTAINS(%s::variant, objects)"), + ("probe", "AND ARRAY_CONTAINS(%s::variant, probes)"), + ("serial_number", "AND serial_number = %s"), + ("order_by", None), + ("limit", "LIMIT %s"), + ("offset", "OFFSET %s")): + val = kwargs.get(attr) + if val is not None: + if attr == "order_by": + query += f" ORDER BY {val}" + else: + if first_filter and filter_tmpl.startswith("AND "): + filter_tmpl = f"WHERE {filter_tmpl[4:]}" + query += f" {filter_tmpl}" + args.append(val) + first_filter = False + return query, args + + def _fetch_aggregated_event_counts(self, **kwargs): + query, args = self._prepare_query("SELECT TYPE, COUNT(*) AS COUNT FROM ZENTRALEVENTS", **kwargs) + query += " GROUP BY type" + cursor = self._get_connection().cursor(DictCursor) + cursor.execute(query, args) + event_counts = { + r['TYPE']: r['COUNT'] + for r in cursor.fetchall() + } + cursor.close() + return event_counts + + def _fetch_events(self, **kwargs): + kwargs["order_by"] = "CREATED_AT DESC" + offset = int(kwargs.pop("cursor", None) or 0) + if offset > 0: + kwargs["offset"] = offset + query, args = self._prepare_query("SELECT * FROM ZENTRALEVENTS", **kwargs) + cursor = self._get_connection().cursor(DictCursor) + cursor.execute(query, args) + events = [self._deserialize_event(t) for t in cursor.fetchall()] + cursor.close() + next_cursor = None + limit = kwargs.get("limit") + if limit and len(events) >= limit: + next_cursor = str(limit + kwargs.get("offset", 0)) + return events, next_cursor + + # machine events + + def fetch_machine_events(self, serial_number, from_dt, to_dt=None, event_type=None, limit=10, cursor=None): + return self._fetch_events( + serial_number=serial_number, + from_dt=from_dt, + to_dt=to_dt, + event_type=event_type, + limit=limit, + cursor=cursor + ) + + def get_aggregated_machine_event_counts(self, serial_number, from_dt, to_dt=None): + return self._fetch_aggregated_event_counts( + serial_number=serial_number, + from_dt=from_dt, + to_dt=to_dt + ) + + def get_last_machine_heartbeats(self, serial_number, from_dt): + heartbeats = {} + query = ( + "SELECT TYPE, MAX(CREATED_AT) LAST_SEEN," + "PAYLOAD:source.name::varchar SOURCE_NAME, NULL USER_AGENT " + "FROM ZENTRALEVENTS " + "WHERE CREATED_AT >= %s " + "AND TYPE = 'inventory_heartbeat' " + "AND SERIAL_NUMBER = %s " + "GROUP BY TYPE, SOURCE_NAME, USER_AGENT " + + "UNION " + + "SELECT TYPE, MAX(CREATED_AT) LAST_SEEN," + "NULL SOURCE_NAME, METADATA:request.user_agent::varchar USER_AGENT " + "FROM ZENTRALEVENTS " + "WHERE CREATED_AT >= %s " + "AND TYPE <> 'inventory_heartbeat' " + "AND ARRAY_CONTAINS('heartbeat'::variant, TAGS) " + "AND SERIAL_NUMBER = %s " + "GROUP BY TYPE, SOURCE_NAME, USER_AGENT" + ) + args = [from_dt, serial_number, from_dt, serial_number] + cursor = self._get_connection().cursor(DictCursor) + cursor.execute(query, args) + for t in cursor.fetchall(): + event_class = event_types.get(t['TYPE']) + if not event_class: + logger.error("Unknown event type %s", t['TYPE']) + continue + key = (event_class, t['SOURCE_NAME']) + heartbeats.setdefault(key, []).append((t['USER_AGENT'], t['LAST_SEEN'])) + cursor.close() + return [ + (event_class, source_name, ua_max_dates) + for (event_class, source_name), ua_max_dates in heartbeats.items() + ] + + # object events + + def fetch_object_events(self, key, val, from_dt, to_dt=None, event_type=None, limit=10, cursor=None): + return self._fetch_events( + objref=f"{key}:{val}", + from_dt=from_dt, + to_dt=to_dt, + event_type=event_type, + limit=limit, + cursor=cursor + ) + + def get_aggregated_object_event_counts(self, key, val, from_dt, to_dt=None): + return self._fetch_aggregated_event_counts( + objref=f"{key}:{val}", + from_dt=from_dt, + to_dt=to_dt + ) + + # probe events + + def fetch_probe_events(self, probe, from_dt, to_dt=None, event_type=None, limit=10, cursor=None): + return self._fetch_events( + probe=probe.pk, + from_dt=from_dt, + to_dt=to_dt, + event_type=event_type, + limit=limit, + cursor=cursor + ) + + def get_aggregated_probe_event_counts(self, probe, from_dt, to_dt=None): + return self._fetch_aggregated_event_counts( + probe=probe.pk, + from_dt=from_dt, + to_dt=to_dt + ) + + # zentral apps data + + def get_app_hist_data(self, interval, bucket_number, tag): + data = [] + query = ( + "SELECT COUNT(*) EVENT_COUNT, COUNT(DISTINCT SERIAL_NUMBER) MACHINE_COUNT," + "DATE_TRUNC(%s, CREATED_AT) BUCKET " + "FROM ZENTRALEVENTS " + "WHERE ARRAY_CONTAINS(%s::variant, TAGS) " + "GROUP BY BUCKET ORDER BY BUCKET DESC" + ) + if interval == "day": + args = ["DAY", tag] + last_value = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + delta = timedelta(days=1) + elif interval == "hour": + args = ["HOUR", tag] + last_value = timezone.now().replace(minute=0, second=0, microsecond=0) + delta = timedelta(hours=1) + else: + logger.error("Unsupported interval %s", interval) + return data + cursor = self._get_connection().cursor(DictCursor) + cursor.execute(query, args) + results = { + t['BUCKET']: (t['EVENT_COUNT'], t['MACHINE_COUNT']) + for t in cursor.fetchall() + } + cursor.close() + for bucket in (last_value - i * delta for i in range(bucket_number - 1, -1, -1)): + try: + event_count, machine_count = results[bucket] + except KeyError: + event_count = machine_count = 0 + data.append((bucket, event_count, machine_count)) + return data diff --git a/requirements.txt b/requirements.txt index 9a6f2042f..1c13b0721 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,6 +31,7 @@ pyyaml requests requests_oauthlib # MDM DEP setuptools +snowflake-connector-python sqlparse # SQL syntax highlighting tqdm XlsxWriter diff --git a/tests/queues/test_workers.py b/tests/queues/test_workers.py new file mode 100644 index 000000000..437c7efab --- /dev/null +++ b/tests/queues/test_workers.py @@ -0,0 +1,15 @@ +from django.test import SimpleTestCase +from zentral.core.queues.workers import get_workers + + +class QueuesWorkersTestCase(SimpleTestCase): + maxDiff = None + + def test_workers(self): + worker_names = set(w.name for w in get_workers()) + self.assertEqual( + worker_names, + {"inventory worker dummy", + "preprocess worker", "enrich worker", "process worker", + "store worker elasticsearch"} + ) diff --git a/tests/stores/test_snowflake.py b/tests/stores/test_snowflake.py new file mode 100644 index 000000000..780392ba7 --- /dev/null +++ b/tests/stores/test_snowflake.py @@ -0,0 +1,444 @@ +from datetime import datetime +import uuid +from unittest.mock import patch, Mock +from django.test import SimpleTestCase +from django.utils.crypto import get_random_string +from accounts.events import EventMetadata, LoginEvent +from zentral.contrib.inventory.events import InventoryHeartbeat +from zentral.contrib.osquery.events import OsqueryRequestEvent +from zentral.core.exceptions import ImproperlyConfigured +from zentral.core.stores.backends.snowflake import EventStore + + +class SnowflakeStoreTestCase(SimpleTestCase): + maxDiff = None + + def get_store(self, **kwargs): + kwargs["store_name"] = get_random_string() + return EventStore(kwargs) + + def get_default_store(self): + return self.get_store( + account="account", + user="user", + password="password", + database="database", + role="role", + warehouse="warehouse" + ) + + def build_login_event(self, username=None): + if username is None: + username = get_random_string(12) + return LoginEvent(EventMetadata(), {"user": {"username": username}}) + + def test_default_store(self): + store = self.get_default_store() + self.assertEqual( + store._connect_kwargs, + {"account": "account", + "user": "user", + "password": "password", + "database": "database", + "schema": "PUBLIC", + "role": "role", + "warehouse": "warehouse"} + ) + self.assertEqual(store._session_timeout, 13800) + + def test_store_with_schema_and_session_timeout(self): + store = self.get_store( + account="account", + user="user", + password="password", + database="database", + schema="ZENTRAL", + role="role", + warehouse="warehouse", + session_timeout=123, + ) + self.assertEqual( + store._connect_kwargs, + {"account": "account", + "user": "user", + "password": "password", + "database": "database", + "schema": "ZENTRAL", + "role": "role", + "warehouse": "warehouse"} + ) + self.assertEqual(store._session_timeout, 123) + + def test_store_missing_parameters(self): + with self.assertRaises(ImproperlyConfigured) as cm: + self.get_store() + self.assertEqual( + cm.exception.args[0], + "Missing configuration parameters: account, user, password, database, role, warehouse" + ) + + def test_deserialize_event(self): + serialized_event = { + 'CREATED_AT': '2022-08-20T09:50:03.848542', + 'METADATA': '{"id": "d304f4f6-7a2f-4d1e-91f6-da673104748b", "index": 3, ' + '"namespace": "zentral", "request": {"user_agent": "user_agent", ' + '"ip": "203.0.113.10"}, "probes": [{"pk": 18, "name": ' + '"DfARpBxpYIBq"}]}', + 'OBJECTS': '["osquery_enrollment:19"]', + 'PAYLOAD': '{"user": {"username": "QeI99eAhCmWH"}}', + 'PROBES': '[18]', + 'SERIAL_NUMBER': None, + 'TAGS': '["yolo", "fomo", "zentral"]', + 'TYPE': 'zentral_login' + } + event = self.get_default_store()._deserialize_event(serialized_event) + self.assertIsInstance(event, LoginEvent) + metadata = event.metadata + self.assertEqual(set(metadata.tags), {"yolo", "fomo", "zentral"}) + self.assertEqual( + metadata.objects, + {"osquery_enrollment": [["19"]]} + ) + self.assertEqual( + metadata.probes, + [{"pk": 18, "name": "DfARpBxpYIBq"}] + ) + self.assertEqual( + event.payload, + {"user": {"username": "QeI99eAhCmWH"}} + ) + + def test_prepare_query(self): + query, args = self.get_default_store()._prepare_query( + "SELECT * FROM ZENTRALEVENTS", + from_dt=datetime(2022, 1, 1), + to_dt=datetime(2023, 1, 1), + event_type="zentral_login", + objref="osquery_enrollment:19", + probe=18, + serial_number="0123456789", + order_by="CREATED_AT DESC", + limit=10, + offset=20 + ) + self.assertEqual( + query, + "SELECT * FROM ZENTRALEVENTS " + "WHERE created_at >= %s " + "AND created_at <= %s " + "AND type = %s " + "AND ARRAY_CONTAINS(%s::variant, objects) " + "AND ARRAY_CONTAINS(%s::variant, probes) " + "AND serial_number = %s " + "ORDER BY CREATED_AT DESC " + "LIMIT %s " + "OFFSET %s" + ) + self.assertEqual( + args, + [datetime(2022, 1, 1), + datetime(2023, 1, 1), + "zentral_login", + "osquery_enrollment:19", + 18, + "0123456789", + 10, + 20] + ) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_new_connection(self, connect): + store = self.get_default_store() + store._get_connection() + connect.assert_called_once_with(**store._connect_kwargs) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_reuse_connection(self, connect): + store = self.get_default_store() + store._get_connection() + store._get_connection() + connect.assert_called_once_with(**store._connect_kwargs) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_reconnect_connection(self, connect): + connection = Mock() + connect.return_value = connection + store = self.get_default_store() + store._get_connection() + # fake expired session + store._last_active_at -= store._session_timeout + store._get_connection() + connection.close.assert_called_once_with() + self.assertEqual(connect.call_count, 2) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_fetch_machine_events(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {'CREATED_AT': '2022-08-20T09:50:03.848542', + 'METADATA': '{"id": "d304f4f6-7a2f-4d1e-91f6-da673104748b", "index": 3, ' + '"namespace": "zentral", "request": {"user_agent": "user_agent", ' + '"ip": "203.0.113.10"}, "probes": [{"pk": 18, "name": ' + '"DfARpBxpYIBq"}]}', + 'OBJECTS': '["osquery_enrollment:19"]', + 'PAYLOAD': '{"user": {"username": "QeI99eAhCmWH"}}', + 'PROBES': '[18]', + 'SERIAL_NUMBER': "0123456789", + 'TAGS': '["yolo", "fomo", "zentral"]', + 'TYPE': 'zentral_login'} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + events, next_cursor = store.fetch_machine_events("0123456789", datetime(2022, 1, 1), limit=1, cursor="2") + cursor.execute.assert_called_once_with( + "SELECT * FROM ZENTRALEVENTS WHERE created_at >= %s " + "AND serial_number = %s " + "ORDER BY CREATED_AT DESC LIMIT %s OFFSET %s", + [datetime(2022, 1, 1), "0123456789", 1, 2] + ) + self.assertEqual(len(events), 1) + event = events[0] + self.assertIsInstance(event, LoginEvent) + self.assertEqual(event.metadata.uuid, uuid.UUID("d304f4f6-7a2f-4d1e-91f6-da673104748b")) + self.assertEqual(event.metadata.index, 3) + self.assertEqual(next_cursor, "3") + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_fetch_machine_events_no_next_cursor(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {'CREATED_AT': '2022-08-20T09:50:03.848542', + 'METADATA': '{"id": "d304f4f6-7a2f-4d1e-91f6-da673104748b", "index": 3, ' + '"namespace": "zentral", "request": {"user_agent": "user_agent", ' + '"ip": "203.0.113.10"}, "probes": [{"pk": 18, "name": ' + '"DfARpBxpYIBq"}]}', + 'OBJECTS': '["osquery_enrollment:19"]', + 'PAYLOAD': '{"user": {"username": "QeI99eAhCmWH"}}', + 'PROBES': '[18]', + 'SERIAL_NUMBER': "0123456789", + 'TAGS': '["yolo", "fomo", "zentral"]', + 'TYPE': 'zentral_login'} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + events, next_cursor = store.fetch_machine_events("0123456789", datetime(2022, 1, 1), limit=10, cursor="20") + connect.assert_called_once_with(**store._connect_kwargs) + cursor.execute.assert_called_once_with( + "SELECT * FROM ZENTRALEVENTS WHERE created_at >= %s " + "AND serial_number = %s " + "ORDER BY CREATED_AT DESC LIMIT %s OFFSET %s", + [datetime(2022, 1, 1), "0123456789", 10, 20] + ) + self.assertEqual(len(events), 1) + event = events[0] + self.assertIsInstance(event, LoginEvent) + self.assertEqual(event.metadata.uuid, uuid.UUID("d304f4f6-7a2f-4d1e-91f6-da673104748b")) + self.assertEqual(event.metadata.index, 3) + self.assertIsNone(next_cursor) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_get_aggregated_machine_event_counts(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {"TYPE": "osquery_request", "COUNT": 17}, + {"TYPE": "munki_enrollment", "COUNT": 16} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + self.assertEqual( + store.get_aggregated_machine_event_counts("0123456789", datetime(2022, 1, 1)), + {"osquery_request": 17, "munki_enrollment": 16} + ) + cursor.execute.assert_called_once_with( + "SELECT TYPE, COUNT(*) AS COUNT FROM ZENTRALEVENTS " + "WHERE created_at >= %s AND serial_number = %s " + "GROUP BY type", + [datetime(2022, 1, 1), "0123456789"] + ) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_get_last_machine_heartbeats(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {"TYPE": "osquery_request", "LAST_SEEN": datetime(2022, 8, 1), + "SOURCE_NAME": None, "USER_AGENT": "osquery/5.4.0"}, + {"TYPE": "osquery_request", "LAST_SEEN": datetime(2022, 7, 1), + "SOURCE_NAME": None, "USER_AGENT": "osquery/5.3.0"}, + {"TYPE": "inventory_heartbeat", "LAST_SEEN": datetime(2022, 8, 2), + "SOURCE_NAME": "Santa", "USER_AGENT": None} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + self.assertEqual( + store.get_last_machine_heartbeats("0123456789", datetime(2022, 1, 1)), + [(OsqueryRequestEvent, None, [("osquery/5.4.0", datetime(2022, 8, 1)), + ("osquery/5.3.0", datetime(2022, 7, 1))]), + (InventoryHeartbeat, "Santa", [(None, datetime(2022, 8, 2))])] + ) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_object_events(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {'CREATED_AT': '2022-08-20T09:50:03.848542', + 'METADATA': '{"id": "d304f4f6-7a2f-4d1e-91f6-da673104748b", "index": 3, ' + '"namespace": "zentral", "request": {"user_agent": "user_agent", ' + '"ip": "203.0.113.10"}, "probes": [{"pk": 18, "name": ' + '"DfARpBxpYIBq"}]}', + 'OBJECTS': '["osquery_enrollment:19"]', + 'PAYLOAD': '{"user": {"username": "QeI99eAhCmWH"}}', + 'PROBES': '[18]', + 'SERIAL_NUMBER': "0123456789", + 'TAGS': '["yolo", "fomo", "zentral"]', + 'TYPE': 'zentral_login'} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + events, next_cursor = store.fetch_object_events( + "osquery_enrollment", "19", + datetime(2022, 1, 1), limit=1, cursor="2" + ) + connect.assert_called_once_with(**store._connect_kwargs) + cursor.execute.assert_called_once_with( + "SELECT * FROM ZENTRALEVENTS WHERE created_at >= %s " + "AND ARRAY_CONTAINS(%s::variant, objects) " + "ORDER BY CREATED_AT DESC LIMIT %s OFFSET %s", + [datetime(2022, 1, 1), "osquery_enrollment:19", 1, 2] + ) + self.assertEqual(len(events), 1) + event = events[0] + self.assertIsInstance(event, LoginEvent) + self.assertEqual(event.metadata.uuid, uuid.UUID("d304f4f6-7a2f-4d1e-91f6-da673104748b")) + self.assertEqual(event.metadata.index, 3) + self.assertEqual(next_cursor, "3") + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_get_aggregated_object_event_counts(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {"TYPE": "osquery_enrollment", "COUNT": 17}, + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + self.assertEqual( + store.get_aggregated_object_event_counts("osquery_enrollment", "19", datetime(2022, 1, 1)), + {"osquery_enrollment": 17} + ) + cursor.execute.assert_called_once_with( + "SELECT TYPE, COUNT(*) AS COUNT FROM ZENTRALEVENTS " + "WHERE created_at >= %s AND ARRAY_CONTAINS(%s::variant, objects) " + "GROUP BY type", + [datetime(2022, 1, 1), "osquery_enrollment:19"] + ) + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_probe_events(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {'CREATED_AT': '2022-08-20T09:50:03.848542', + 'METADATA': '{"id": "d304f4f6-7a2f-4d1e-91f6-da673104748b", "index": 3, ' + '"namespace": "zentral", "request": {"user_agent": "user_agent", ' + '"ip": "203.0.113.10"}, "probes": [{"pk": 18, "name": ' + '"DfARpBxpYIBq"}]}', + 'OBJECTS': '["osquery_enrollment:19"]', + 'PAYLOAD': '{"user": {"username": "QeI99eAhCmWH"}}', + 'PROBES': '[18]', + 'SERIAL_NUMBER': "0123456789", + 'TAGS': '["yolo", "fomo", "zentral"]', + 'TYPE': 'zentral_login'} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + probe = Mock(pk=18) + events, next_cursor = store.fetch_probe_events(probe, datetime(2022, 1, 1), limit=1, cursor="2") + connect.assert_called_once_with(**store._connect_kwargs) + cursor.execute.assert_called_once_with( + "SELECT * FROM ZENTRALEVENTS WHERE created_at >= %s " + "AND ARRAY_CONTAINS(%s::variant, probes) " + "ORDER BY CREATED_AT DESC LIMIT %s OFFSET %s", + [datetime(2022, 1, 1), 18, 1, 2] + ) + self.assertEqual(len(events), 1) + event = events[0] + self.assertIsInstance(event, LoginEvent) + self.assertEqual(event.metadata.uuid, uuid.UUID("d304f4f6-7a2f-4d1e-91f6-da673104748b")) + self.assertEqual(event.metadata.index, 3) + self.assertEqual(next_cursor, "3") + + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_get_aggregated_probe_event_counts(self, connect): + cursor = Mock() + cursor.fetchall.return_value = [ + {"TYPE": "osquery_enrollment", "COUNT": 17}, + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + probe = Mock(pk=18) + self.assertEqual( + store.get_aggregated_probe_event_counts(probe, datetime(2022, 1, 1)), + {"osquery_enrollment": 17} + ) + cursor.execute.assert_called_once_with( + "SELECT TYPE, COUNT(*) AS COUNT FROM ZENTRALEVENTS " + "WHERE created_at >= %s AND ARRAY_CONTAINS(%s::variant, probes) " + "GROUP BY type", + [datetime(2022, 1, 1), 18] + ) + + def test_get_app_hist_data_unsupported_interval(self): + self.assertEqual(self.get_default_store().get_app_hist_data("yolo", 12, "fomo"), []) + + @patch("zentral.core.stores.backends.snowflake.timezone.now") + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_get_hourly_app_hist_data(self, connect, tznow): + tznow.return_value = datetime(2022, 9, 20, 11, 17) + cursor = Mock() + cursor.fetchall.return_value = [ + {'EVENT_COUNT': 323, 'MACHINE_COUNT': 5, 'BUCKET': datetime(2022, 9, 20, 10)} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + self.assertEqual( + store.get_app_hist_data("hour", 3, "osquery"), + [(datetime(2022, 9, 20, 9), 0, 0), + (datetime(2022, 9, 20, 10), 323, 5), + (datetime(2022, 9, 20, 11), 0, 0)] + ) + + @patch("zentral.core.stores.backends.snowflake.timezone.now") + @patch("zentral.core.stores.backends.snowflake.snowflake.connector.connect") + def test_get_daily_app_hist_data(self, connect, tznow): + tznow.return_value = datetime(2022, 9, 20, 11, 17) + cursor = Mock() + cursor.fetchall.return_value = [ + {'EVENT_COUNT': 322, 'MACHINE_COUNT': 4, 'BUCKET': datetime(2022, 9, 19)} + ] + connection = Mock() + connection.cursor.return_value = cursor + connect.return_value = connection + store = self.get_default_store() + self.assertEqual( + store.get_app_hist_data("day", 4, "osquery"), + [(datetime(2022, 9, 17), 0, 0), + (datetime(2022, 9, 18), 0, 0), + (datetime(2022, 9, 19), 322, 4), + (datetime(2022, 9, 20), 0, 0)] + ) diff --git a/zentral/core/queues/workers.py b/zentral/core/queues/workers.py index f301cf171..ea4f30118 100644 --- a/zentral/core/queues/workers.py +++ b/zentral/core/queues/workers.py @@ -9,7 +9,7 @@ def get_workers(): yield queues.get_preprocess_worker() yield queues.get_enrich_worker(enrich_event) yield queues.get_process_worker(process_event) - for store in stores: + for store in stores.iter_queue_worker_stores(): yield queues.get_store_worker(store) # extra apps workers for app in settings['apps']: diff --git a/zentral/core/stores/backends/base.py b/zentral/core/stores/backends/base.py index 6071bf313..9533ffacb 100644 --- a/zentral/core/stores/backends/base.py +++ b/zentral/core/stores/backends/base.py @@ -3,6 +3,7 @@ class BaseEventStore(object): + read_only = False # if read only, we do not need a store worker max_batch_size = 1 max_concurrency = 1 machine_events = False diff --git a/zentral/core/stores/conf.py b/zentral/core/stores/conf.py index 574def8c7..9b8036b7c 100644 --- a/zentral/core/stores/conf.py +++ b/zentral/core/stores/conf.py @@ -57,6 +57,11 @@ def iter_events_url_store_for_user(self, key, user): continue yield store + def iter_queue_worker_stores(self): + for store in self.stores.values(): + if not store.read_only: + yield store + stores = SimpleLazyObject(lambda: Stores(settings)) frontend_store = SimpleLazyObject(lambda: stores.frontend_store)