diff --git a/blitzortung/cli/start_webservice.py b/blitzortung/cli/start_webservice.py index d51ba8e..500911b 100755 --- a/blitzortung/cli/start_webservice.py +++ b/blitzortung/cli/start_webservice.py @@ -1,10 +1,16 @@ +"""Blitzortung webservice entry point.""" + import os.path import sys from twisted.scripts.twistd import run +# Import webservice to trigger application setup and connection pool creation +import blitzortung.cli.webservice # noqa: F401 + def main(): + """Run the twistd server.""" target_dir = os.path.dirname(os.path.abspath(__file__)) args = ["twistd"] diff --git a/blitzortung/cli/webservice.py b/blitzortung/cli/webservice.py index e6655cb..7bf87a2 100755 --- a/blitzortung/cli/webservice.py +++ b/blitzortung/cli/webservice.py @@ -1,446 +1,61 @@ -import calendar -import collections -import datetime -import gc -import json -import os -import platform -import time +"""Blitzortung webservice entry point for twistd.""" -from typing import Any +import os from twisted.application import internet, service -from twisted.internet.defer import succeed from twisted.internet.error import ReactorAlreadyInstalledError from twisted.python import log -from twisted.python.log import FileLogObserver, ILogObserver, textFromEventDict, _safeFormat +from twisted.python.log import ILogObserver from twisted.python.logfile import DailyLogFile -from twisted.python.util import untilConcludes from twisted.web import server -from txjsonrpc_ng.web import jsonrpc -from txjsonrpc_ng.web.data import CacheableResult -from txjsonrpc_ng.web.jsonrpc import with_request - -from blitzortung.gis.constants import grid, global_grid -from blitzortung.gis.local_grid import LocalGrid -from blitzortung.service.cache import ServiceCache -from blitzortung.service.metrics import StatsDMetrics -from blitzortung.util import TimeConstraint - -JSON_CONTENT_TYPE = 'text/json' +# Install epoll/kqueue reactor for better performance (if not already installed) try: from twisted.internet import epollreactor # type: ignore[attr-defined, no-redef] - reactor = epollreactor -except ImportError: - from twisted.internet import kqreactor as reactor # type: ignore[assignment, no-redef] - -try: - reactor.install() -except ReactorAlreadyInstalledError: - pass - -import blitzortung.cache + epollreactor.install() +except (ImportError, ReactorAlreadyInstalledError): + try: + from twisted.internet import kqreactor # type: ignore[assignment, no-redef] + kqreactor.install() + except (ImportError, ReactorAlreadyInstalledError): + pass + +from blitzortung.service.base import Blitzortung, LogObserver import blitzortung.config -import blitzortung.db -import blitzortung.geom -import blitzortung.service -from blitzortung.db.query import TimeInterval -from blitzortung.service.db import create_connection_pool -from blitzortung.service.general import create_time_interval -from blitzortung.service.strike_grid import GridParameters - -is_pypy = platform.python_implementation() == 'PyPy' - -FORBIDDEN_IPS: dict[str, Any] = {} - -USER_AGENT_PREFIX = 'bo-android-' - - -class Blitzortung(jsonrpc.JSONRPC): - """ - Blitzortung.org JSON-RPC webservice for lightning strike data. - - Provides endpoints for querying strike data, grid-based visualizations, - and histograms with caching and rate limiting. - """ - - # Grid validation constants - MIN_GRID_BASE_LENGTH = 5000 - INVALID_GRID_BASE_LENGTH = 1000001 - GLOBAL_MIN_GRID_BASE_LENGTH = 10000 - - # Time validation constants - MAX_MINUTES_PER_DAY = 24 * 60 # 1440 minutes - DEFAULT_MINUTE_LENGTH = 60 - HISTOGRAM_MINUTE_THRESHOLD = 10 - - # User agent validation constants - MAX_COMPATIBLE_ANDROID_VERSION = 177 - - # Memory info interval - MEMORY_INFO_INTERVAL = 300 # 5 minutes - - def __init__(self, db_connection_pool, log_directory): - super().__init__() - self.connection_pool = db_connection_pool - self.log_directory = log_directory - self.strike_query = blitzortung.service.strike_query() - self.strike_grid_query = blitzortung.service.strike_grid_query() - self.global_strike_grid_query = blitzortung.service.global_strike_grid_query() - self.histogram_query = blitzortung.service.histogram_query() - self.check_count = 0 - self.cache = ServiceCache() - self.current_period = self.__current_period() - self.current_data = collections.defaultdict(list) - self.next_memory_info = 0.0 - self.minute_constraints = TimeConstraint(self.DEFAULT_MINUTE_LENGTH, self.MAX_MINUTES_PER_DAY) - - self.metrics = StatsDMetrics() - - addSlash = True - - def __get_epoch(self, timestamp): - return calendar.timegm(timestamp.timetuple()) * 1000000 + timestamp.microsecond - - def __current_period(self): - return datetime.datetime.now(datetime.UTC).replace(second=0, microsecond=0) - - def __check_period(self): - if self.current_period != self.__current_period(): - self.current_data['timestamp'] = self.__get_epoch(self.current_period) - if log_directory: - with open(os.path.join(log_directory, self.current_period.strftime("%Y%m%d-%H%M.json")), - 'w') as output_file: - output_file.write(json.dumps(self.current_data)) - self.__restart_period() - - def __restart_period(self): - self.current_period = self.__current_period() - self.current_data = collections.defaultdict(list) - - @staticmethod - def __force_range(number, min_number, max_number): - if number < min_number: - return min_number - elif number > max_number: - return max_number - else: - return number - - def jsonrpc_check(self): - self.check_count += 1 - return {'count': self.check_count} - - @with_request - def jsonrpc_get_strikes(self, request, minute_length, id_or_offset=0): - """This endpoint is currently blocked for all requests.""" - minute_length = self.__force_range(minute_length, 0, self.MAX_MINUTES_PER_DAY) - - client = self.get_request_client(request) - user_agent = request.getHeader("User-Agent") - log.msg('get_strikes(%d, %d) %s %s BLOCKED' % (minute_length, id_or_offset, client, user_agent)) - return None - - def get_strikes_grid(self, minute_length, grid_baselength, minute_offset, region, count_threshold): - grid_parameters = GridParameters(grid[region].get_for(grid_baselength), grid_baselength, region, - count_threshold=count_threshold) - time_interval = create_time_interval(minute_length, minute_offset) - - grid_result, state = self.strike_grid_query.create(grid_parameters, time_interval, self.connection_pool, - self.metrics.statsd) - - histogram_result = self.get_histogram(time_interval, envelope=grid_parameters.grid) \ - if minute_length > self.HISTOGRAM_MINUTE_THRESHOLD else succeed([]) - - combined_result = self.strike_grid_query.combine_result(grid_result, histogram_result, state) - - combined_result.addCallback(lambda value: CacheableResult(value)) - - return combined_result - - def get_global_strikes_grid(self, minute_length, grid_baselength, minute_offset, count_threshold): - grid_parameters = GridParameters(global_grid.get_for(grid_baselength), grid_baselength, - count_threshold=count_threshold) - time_interval = create_time_interval(minute_length, minute_offset) - - grid_result, state = self.global_strike_grid_query.create(grid_parameters, time_interval, self.connection_pool, - self.metrics.statsd) - - histogram_result = self.get_histogram( - time_interval) if minute_length > self.HISTOGRAM_MINUTE_THRESHOLD else succeed([]) - - combined_result = self.strike_grid_query.combine_result(grid_result, histogram_result, state) - - combined_result.addCallback(lambda value: CacheableResult(value)) - - return combined_result - - def get_local_strikes_grid(self, x, y, grid_baselength, minute_length, minute_offset, count_threshold, data_area=5): - local_grid = LocalGrid(data_area=data_area, x=x, y=y) - grid_factory = local_grid.get_grid_factory() - grid_parameters = GridParameters(grid_factory.get_for(grid_baselength), grid_baselength, - count_threshold=count_threshold) - time_interval = create_time_interval(minute_length, minute_offset) - - grid_result, state = self.strike_grid_query.create(grid_parameters, time_interval, self.connection_pool, - self.metrics.statsd) - - histogram_result = self.get_histogram(time_interval, envelope=grid_parameters.grid) \ - if minute_length > self.HISTOGRAM_MINUTE_THRESHOLD else succeed([]) - - combined_result = self.strike_grid_query.combine_result(grid_result, histogram_result, state) - - combined_result.addCallback(lambda value: CacheableResult(value)) - - return combined_result - - @with_request - def jsonrpc_get_strikes_raster(self, request, minute_length, grid_base_length=10000, minute_offset=0, region=1): - return self.jsonrpc_get_strikes_grid(request, minute_length, grid_base_length, minute_offset, region) - - @with_request - def jsonrpc_get_strokes_raster(self, request, minute_length, grid_base_length=10000, minute_offset=0, region=1): - return self.jsonrpc_get_strikes_grid(request, minute_length, grid_base_length, minute_offset, region) - - @with_request - def jsonrpc_get_global_strikes_grid(self, request, minute_length, grid_base_length=10000, minute_offset=0, - count_threshold=0): - self.memory_info() - client = self.get_request_client(request) - user_agent, user_agent_version = self.parse_user_agent(request) - - if client in FORBIDDEN_IPS or user_agent_version == 0 or request.getHeader( - 'content-type') != JSON_CONTENT_TYPE or request.getHeader( - 'referer') == '' or grid_base_length < self.MIN_GRID_BASE_LENGTH or grid_base_length == self.INVALID_GRID_BASE_LENGTH: - log.msg( - f"FORBIDDEN - client: {client}, user agent: {user_agent_version}, content type: {request.getHeader('content-type')}, referer: {request.getHeader('referer')}") - log.msg('get_global_strikes_grid(%d, %d, %d, >=%d) BLOCKED %.1f%% %s %s' % ( - minute_length, grid_base_length, minute_offset, count_threshold, - 0, client, user_agent)) - return {} - - original_grid_base_length = grid_base_length - grid_base_length = max(self.GLOBAL_MIN_GRID_BASE_LENGTH, grid_base_length) - minute_length, minute_offset = self.minute_constraints.enforce(minute_length, minute_offset, ) - count_threshold = max(0, count_threshold) - - cache = self.cache.global_strikes(minute_offset) - response = cache.get(self.get_global_strikes_grid, minute_length=minute_length, - grid_baselength=grid_base_length, - minute_offset=minute_offset, - count_threshold=count_threshold) - self.fix_bad_accept_header(request, user_agent) - - log.msg('get_global_strikes_grid(%d, %d, %d, >=%d) %.1f%% %s %s' % ( - minute_length, grid_base_length, minute_offset, count_threshold, - cache.get_ratio() * 100, client, user_agent)) - - self.__check_period() - self.current_data['get_strikes_grid'].append( - (self.__get_epoch(datetime.datetime.now(datetime.UTC)), minute_length, original_grid_base_length, - minute_offset, - 0, count_threshold, client, user_agent)) - - self.metrics.for_global_strikes(minute_length, cache.get_ratio()) - - return response - - @with_request - def jsonrpc_get_local_strikes_grid(self, request, x, y, grid_base_length=10000, minute_length=60, minute_offset=0, - count_threshold=0, data_area=5): - self.memory_info() - client = self.get_request_client(request) - user_agent, user_agent_version = self.parse_user_agent(request) - - if client in FORBIDDEN_IPS or request.getHeader( - 'content-type') != JSON_CONTENT_TYPE or request.getHeader( - 'referer') == '' or grid_base_length < self.MIN_GRID_BASE_LENGTH or grid_base_length == self.INVALID_GRID_BASE_LENGTH: - log.msg( - f"FORBIDDEN - client: {client}, user agent: {user_agent_version}, content type: {request.getHeader('content-type')}, referer: {request.getHeader('referer')}") - log.msg('get_local_strikes_grid(%d, %d, %d, %d, %d, >=%d, %d) BLOCKED %.1f%% %s %s' % ( - x, y, grid_base_length, minute_length, minute_offset, count_threshold, data_area, - 0, client, user_agent)) - return {} - - original_grid_base_length = grid_base_length - grid_base_length = max(self.MIN_GRID_BASE_LENGTH, grid_base_length) - minute_length, minute_offset = self.minute_constraints.enforce(minute_length, minute_offset, ) - count_threshold = max(0, count_threshold) - data_area = round(max(5, data_area)) - - cache = self.cache.local_strikes(minute_offset) - response = cache.get(self.get_local_strikes_grid, x=x, y=y, - grid_baselength=grid_base_length, - minute_length=minute_length, - minute_offset=minute_offset, - count_threshold=count_threshold, - data_area=data_area) - - log.msg('get_local_strikes_grid(%d, %d, %d, %d, %d, >=%d, %d) %.1f%% %d# %s %s' % ( - x, y, minute_length, grid_base_length, minute_offset, count_threshold, data_area, - cache.get_ratio() * 100, cache.get_size(), client, - user_agent)) - - self.__check_period() - self.current_data['get_strikes_grid'].append( - ( - self.__get_epoch(datetime.datetime.now(datetime.UTC)), minute_length, original_grid_base_length, - minute_offset, - -1, count_threshold, client, user_agent, x, y, data_area)) - - self.metrics.for_local_strikes(minute_length, data_area, cache.get_ratio()) - - return response - - @with_request - def jsonrpc_get_strikes_grid(self, request, minute_length, grid_base_length=10000, minute_offset=0, region=1, - count_threshold=0): - self.memory_info() - client = self.get_request_client(request) - user_agent, user_agent_version = self.parse_user_agent(request) - - if client in FORBIDDEN_IPS or user_agent_version == 0 or request.getHeader( - 'content-type') != JSON_CONTENT_TYPE or request.getHeader( - 'referer') == '' or grid_base_length < self.MIN_GRID_BASE_LENGTH or grid_base_length == self.INVALID_GRID_BASE_LENGTH: - log.msg( - f"FORBIDDEN - client: {client}, user agent: {user_agent_version}, content type: {request.getHeader('content-type')}, referer: {request.getHeader('referer')}") - log.msg('get_strikes_grid(%d, %d, %d, %d, >=%d) BLOCKED %.1f%% %s %s' % ( - minute_length, grid_base_length, minute_offset, region, count_threshold, - 0, client, user_agent)) - return {} - - original_grid_base_length = grid_base_length - grid_base_length = max(self.MIN_GRID_BASE_LENGTH, grid_base_length) - minute_length, minute_offset = self.minute_constraints.enforce(minute_length, minute_offset, ) - region = max(1, region) - count_threshold = max(0, count_threshold) - - cache = self.cache.strikes(minute_offset) - response = cache.get(self.get_strikes_grid, minute_length=minute_length, - grid_baselength=grid_base_length, - minute_offset=minute_offset, region=region, - count_threshold=count_threshold) - self.fix_bad_accept_header(request, user_agent) - - log.msg('get_strikes_grid(%d, %d, %d, %d, >=%d) %.1f%% %s %s' % ( - minute_length, grid_base_length, minute_offset, region, count_threshold, - cache.get_ratio() * 100, client, user_agent)) - - self.__check_period() - self.current_data['get_strikes_grid'].append( - (self.__get_epoch(datetime.datetime.now(datetime.UTC)), minute_length, original_grid_base_length, - minute_offset, - region, - count_threshold, client, user_agent)) - - self.metrics.for_strikes(minute_length, region, cache.get_ratio()) - - return response - - def parse_user_agent(self, request): - """Parse user agent string to extract version information.""" - user_agent = request.getHeader("User-Agent") - user_agent_version = 0 - if user_agent and user_agent.startswith(USER_AGENT_PREFIX): - user_agent_parts = user_agent.split(' ')[0].rsplit('-', 1) - if len(user_agent_parts) > 1 and user_agent_parts[0] == 'bo-android': - try: - user_agent_version = int(user_agent_parts[1]) - except ValueError: - pass - return user_agent, user_agent_version - - def fix_bad_accept_header(self, request, user_agent): - """Remove Accept-Encoding header for old Android client versions that have bugs.""" - if user_agent and user_agent.startswith(USER_AGENT_PREFIX): - user_agent_parts = user_agent.split(' ')[0].rsplit('-', 1) - if len(user_agent_parts) > 1 and user_agent_parts[0] == 'bo-android': - try: - version = int(user_agent_parts[1]) - if version <= self.MAX_COMPATIBLE_ANDROID_VERSION: - request.requestHeaders.removeHeader("Accept-Encoding") - except ValueError: - pass - - def get_histogram(self, time_interval: TimeInterval, region=None, envelope=None): - return self.cache.histogram.get(self.histogram_query.create, - time_interval=time_interval, - connection_pool=self.connection_pool, - region=region, - envelope=envelope) - - def get_request_client(self, request): - forward = request.getHeader("X-Forwarded-For") - if forward: - return forward.split(', ')[0] - return request.getClientIP() - - def memory_info(self): - now = time.time() - if now > self.next_memory_info: - log.msg("### MEMORY INFO ###") - # pylint: disable=no-member - if is_pypy: - log.msg(gc.get_stats(True)) # type: ignore[call-arg] - else: - log.msg(gc.get_stats()) # type: ignore[call-arg] - self.next_memory_info = now + self.MEMORY_INFO_INTERVAL - - -class LogObserver(FileLogObserver): - - def __init__(self, f, prefix=None): - prefix = '' if prefix is None else prefix - if len(prefix) > 0: - prefix += '' - self.prefix = prefix - FileLogObserver.__init__(self, f) - - def emit(self, event_dict): - text = textFromEventDict(event_dict) - if text is None: - return - time_str = self.formatTime(event_dict["time"]) - msg_str = _safeFormat("[%(prefix)s] %(text)s\n", { - "prefix": self.prefix, - "text": text.replace("\n", "\n\t") - }) - untilConcludes(self.write, time_str + " " + msg_str) - untilConcludes(self.flush) - application = service.Application("Blitzortung.org JSON-RPC Server") -if os.environ.get('BLITZORTUNG_TEST'): - import tempfile - - log_directory: str | None = tempfile.mkdtemp() - print("LOG_DIR", log_directory) -else: - log_directory = "/var/log/blitzortung" -if log_directory and os.path.exists(log_directory): - logfile = DailyLogFile("webservice.log", log_directory) - application.setComponent(ILogObserver, LogObserver(logfile).emit) -else: +log_directory = "/var/log/blitzortung" +try: + if log_directory and os.path.exists(log_directory): + logfile = DailyLogFile("webservice.log", log_directory) + application.setComponent(ILogObserver, LogObserver(logfile).emit) + else: + log_directory = None +except OSError as exc: + log.err(exc, "Failed to initialize webservice file logging; disabling file logging") log_directory = None def start_server(connection_pool): + """Start the JSON-RPC server with the given connection pool.""" print("Connection pool is ready") - root = Blitzortung(connection_pool, log_directory) config = blitzortung.config.config() + port = config.get_webservice_port() + root = Blitzortung(connection_pool, log_directory) site = server.Site(root) site.displayTracebacks = False - jsonrpc_server = internet.TCPServer(config.get_webservice_port(), site, interface='127.0.0.1') + jsonrpc_server = internet.TCPServer(port, site, interface='127.0.0.1') jsonrpc_server.setServiceParent(application) return jsonrpc_server def on_error(failure): + """Error handler for connection pool failures.""" log.err(failure, "Failed to create connection pool") raise failure.value +from blitzortung.service.db import create_connection_pool deferred_connection_pool = create_connection_pool() deferred_connection_pool.addCallback(start_server).addErrback(on_error) diff --git a/blitzortung/service/base.py b/blitzortung/service/base.py new file mode 100644 index 0000000..3fc3de4 --- /dev/null +++ b/blitzortung/service/base.py @@ -0,0 +1,393 @@ +"""Blitzortung webservice classes.""" + +import calendar +import collections +import datetime +import gc +import json +import os +import platform +import time +from typing import Any + +from twisted.internet.defer import succeed +from twisted.python import log +from twisted.python.log import FileLogObserver, textFromEventDict, _safeFormat +from twisted.python.util import untilConcludes +from txjsonrpc_ng.web import jsonrpc +from txjsonrpc_ng.web.data import CacheableResult +from txjsonrpc_ng.web.jsonrpc import with_request + +from blitzortung.gis.constants import grid, global_grid +from blitzortung.gis.local_grid import LocalGrid +from blitzortung.service.cache import ServiceCache +from blitzortung.service.metrics import StatsDMetrics +from blitzortung.util import TimeConstraint +import blitzortung.service +from blitzortung.db.query import TimeInterval +from blitzortung.service.general import create_time_interval +from blitzortung.service.strike_grid import GridParameters + + +JSON_CONTENT_TYPE = 'text/json' + +is_pypy = platform.python_implementation() == 'PyPy' + +FORBIDDEN_IPS: dict[str, Any] = {} + +USER_AGENT_PREFIX = 'bo-android-' + + +class Blitzortung(jsonrpc.JSONRPC): + """ + Blitzortung.org JSON-RPC webservice for lightning strike data. + + Provides endpoints for querying strike data, grid-based visualizations, + and histograms with caching and rate limiting. + """ + + # Grid validation constants + MIN_GRID_BASE_LENGTH = 5000 + INVALID_GRID_BASE_LENGTH = 1000001 + GLOBAL_MIN_GRID_BASE_LENGTH = 10000 + MAX_REGION = 7 + + # Time validation constants + MAX_MINUTES_PER_DAY = 24 * 60 # 1440 minutes + DEFAULT_MINUTE_LENGTH = 60 + HISTOGRAM_MINUTE_THRESHOLD = 10 + + # User agent validation constants + MAX_COMPATIBLE_ANDROID_VERSION = 177 + + # Memory info interval + MEMORY_INFO_INTERVAL = 300 # 5 minutes + + def __init__(self, db_connection_pool=None, log_directory=None, + strike_query=None, strike_grid_query=None, + global_strike_grid_query=None, histogram_query=None, + cache=None, metrics=None, forbidden_ips=None): + super().__init__() + self.connection_pool = db_connection_pool + self.log_directory = log_directory + self.strike_query = strike_query if strike_query is not None else blitzortung.service.strike_query() + self.strike_grid_query = strike_grid_query if strike_grid_query is not None else blitzortung.service.strike_grid_query() + self.global_strike_grid_query = global_strike_grid_query if global_strike_grid_query is not None else blitzortung.service.global_strike_grid_query() + self.histogram_query = histogram_query if histogram_query is not None else blitzortung.service.histogram_query() + self.check_count = 0 + self.cache = cache if cache is not None else ServiceCache() + self.current_period = self.__current_period() + self.current_data = collections.defaultdict(list) + self.next_memory_info = 0.0 + self.minute_constraints = TimeConstraint(self.DEFAULT_MINUTE_LENGTH, self.MAX_MINUTES_PER_DAY) + self.metrics = metrics if metrics is not None else StatsDMetrics() + self.forbidden_ips = forbidden_ips if forbidden_ips is not None else FORBIDDEN_IPS + + addSlash = True + + def __get_epoch(self, timestamp): + return calendar.timegm(timestamp.timetuple()) * 1000000 + timestamp.microsecond + + def __current_period(self): + return datetime.datetime.now(datetime.UTC).replace(second=0, microsecond=0) + + def __check_period(self): + if self.current_period != self.__current_period(): + self.current_data['timestamp'] = self.__get_epoch(self.current_period) + if self.log_directory: + with open(os.path.join(self.log_directory, self.current_period.strftime("%Y%m%d-%H%M.json")), + 'w') as output_file: + output_file.write(json.dumps(self.current_data)) + self.__restart_period() + + def __restart_period(self): + self.current_period = self.__current_period() + self.current_data = collections.defaultdict(list) + + @staticmethod + def __force_range(number, min_number, max_number): + if number < min_number: + return min_number + elif number > max_number: + return max_number + else: + return number + + def jsonrpc_check(self): + self.check_count += 1 + return {'count': self.check_count} + + @with_request + def jsonrpc_get_strikes(self, request, minute_length, id_or_offset=0): + """This endpoint is currently blocked for all requests.""" + minute_length = self.__force_range(minute_length, 0, self.MAX_MINUTES_PER_DAY) + + client = self.get_request_client(request) + user_agent = request.getHeader("User-Agent") + log.msg('get_strikes(%d, %d) %s %s BLOCKED' % (minute_length, id_or_offset, client, user_agent)) + return None + + def get_strikes_grid(self, minute_length, grid_baselength, minute_offset, region, count_threshold): + grid_parameters = GridParameters(grid[region].get_for(grid_baselength), grid_baselength, region, + count_threshold=count_threshold) + time_interval = create_time_interval(minute_length, minute_offset) + + grid_result, state = self.strike_grid_query.create(grid_parameters, time_interval, self.connection_pool, + self.metrics.statsd) + + histogram_result = self.get_histogram(time_interval, envelope=grid_parameters.grid) \ + if minute_length > self.HISTOGRAM_MINUTE_THRESHOLD else succeed([]) + + combined_result = self.strike_grid_query.combine_result(grid_result, histogram_result, state) + + combined_result.addCallback(lambda value: CacheableResult(value)) + + return combined_result + + def get_global_strikes_grid(self, minute_length, grid_baselength, minute_offset, count_threshold): + grid_parameters = GridParameters(global_grid.get_for(grid_baselength), grid_baselength, + count_threshold=count_threshold) + time_interval = create_time_interval(minute_length, minute_offset) + + grid_result, state = self.global_strike_grid_query.create(grid_parameters, time_interval, self.connection_pool, + self.metrics.statsd) + + histogram_result = self.get_histogram( + time_interval) if minute_length > self.HISTOGRAM_MINUTE_THRESHOLD else succeed([]) + + combined_result = self.global_strike_grid_query.combine_result(grid_result, histogram_result, state) + + combined_result.addCallback(lambda value: CacheableResult(value)) + + return combined_result + + def get_local_strikes_grid(self, x, y, grid_baselength, minute_length, minute_offset, count_threshold, data_area=5): + local_grid = LocalGrid(data_area=data_area, x=x, y=y) + grid_factory = local_grid.get_grid_factory() + grid_parameters = GridParameters(grid_factory.get_for(grid_baselength), grid_baselength, + count_threshold=count_threshold) + time_interval = create_time_interval(minute_length, minute_offset) + + grid_result, state = self.strike_grid_query.create(grid_parameters, time_interval, self.connection_pool, + self.metrics.statsd) + + histogram_result = self.get_histogram(time_interval, envelope=grid_parameters.grid) \ + if minute_length > self.HISTOGRAM_MINUTE_THRESHOLD else succeed([]) + + combined_result = self.strike_grid_query.combine_result(grid_result, histogram_result, state) + + combined_result.addCallback(lambda value: CacheableResult(value)) + + return combined_result + + @with_request + def jsonrpc_get_strikes_raster(self, request, minute_length, grid_base_length=10000, minute_offset=0, region=1): + return self.jsonrpc_get_strikes_grid(request, minute_length, grid_base_length, minute_offset, region) + + @with_request + def jsonrpc_get_strokes_raster(self, request, minute_length, grid_base_length=10000, minute_offset=0, region=1): + return self.jsonrpc_get_strikes_grid(request, minute_length, grid_base_length, minute_offset, region) + + @with_request + def jsonrpc_get_global_strikes_grid(self, request, minute_length, grid_base_length=10000, minute_offset=0, + count_threshold=0): + self.memory_info() + client = self.get_request_client(request) + user_agent, user_agent_version = self.parse_user_agent(request) + + if client in self.forbidden_ips or user_agent_version == 0 or request.getHeader( + 'content-type') != JSON_CONTENT_TYPE or not request.getHeader( + 'referer') or grid_base_length < self.MIN_GRID_BASE_LENGTH or grid_base_length == self.INVALID_GRID_BASE_LENGTH: + log.msg( + f"FORBIDDEN - client: {client}, user agent: {user_agent_version}, content type: {request.getHeader('content-type')}, referer: {request.getHeader('referer')}") + log.msg('get_global_strikes_grid(%d, %d, %d, >=%d) BLOCKED %.1f%% %s %s' % ( + minute_length, grid_base_length, minute_offset, count_threshold, + 0, client, user_agent)) + return {} + + original_grid_base_length = grid_base_length + grid_base_length = max(self.GLOBAL_MIN_GRID_BASE_LENGTH, grid_base_length) + minute_length, minute_offset = self.minute_constraints.enforce(minute_length, minute_offset, ) + count_threshold = max(0, count_threshold) + + cache = self.cache.global_strikes(minute_offset) + response = cache.get(self.get_global_strikes_grid, minute_length=minute_length, + grid_baselength=grid_base_length, + minute_offset=minute_offset, + count_threshold=count_threshold) + self.fix_bad_accept_header(request, user_agent) + + log.msg('get_global_strikes_grid(%d, %d, %d, >=%d) %.1f%% %s %s' % ( + minute_length, grid_base_length, minute_offset, count_threshold, + cache.get_ratio() * 100, client, user_agent)) + + self.__check_period() + self.current_data['get_strikes_grid'].append( + (self.__get_epoch(datetime.datetime.now(datetime.UTC)), minute_length, original_grid_base_length, + minute_offset, + 0, count_threshold, client, user_agent)) + + self.metrics.for_global_strikes(minute_length, cache.get_ratio()) + + return response + + @with_request + def jsonrpc_get_local_strikes_grid(self, request, x, y, grid_base_length=10000, minute_length=60, minute_offset=0, + count_threshold=0, data_area=5): + self.memory_info() + client = self.get_request_client(request) + user_agent, user_agent_version = self.parse_user_agent(request) + + if client in self.forbidden_ips or request.getHeader( + 'content-type') != JSON_CONTENT_TYPE or not request.getHeader( + 'referer') or grid_base_length < self.MIN_GRID_BASE_LENGTH or grid_base_length == self.INVALID_GRID_BASE_LENGTH: + log.msg( + f"FORBIDDEN - client: {client}, user agent: {user_agent_version}, content type: {request.getHeader('content-type')}, referer: {request.getHeader('referer')}") + log.msg('get_local_strikes_grid(%d, %d, %d, %d, %d, >=%d, %d) BLOCKED %.1f%% %s %s' % ( + x, y, grid_base_length, minute_length, minute_offset, count_threshold, data_area, + 0, client, user_agent)) + return {} + + original_grid_base_length = grid_base_length + grid_base_length = max(self.MIN_GRID_BASE_LENGTH, grid_base_length) + minute_length, minute_offset = self.minute_constraints.enforce(minute_length, minute_offset, ) + count_threshold = max(0, count_threshold) + data_area = round(max(5, data_area)) + + cache = self.cache.local_strikes(minute_offset) + response = cache.get(self.get_local_strikes_grid, x=x, y=y, + grid_baselength=grid_base_length, + minute_length=minute_length, + minute_offset=minute_offset, + count_threshold=count_threshold, + data_area=data_area) + + log.msg('get_local_strikes_grid(%d, %d, %d, %d, %d, >=%d, %d) %.1f%% %d# %s %s' % ( + x, y, minute_length, grid_base_length, minute_offset, count_threshold, data_area, + cache.get_ratio() * 100, cache.get_size(), client, + user_agent)) + + self.__check_period() + self.current_data['get_strikes_grid'].append( + ( + self.__get_epoch(datetime.datetime.now(datetime.UTC)), minute_length, original_grid_base_length, + minute_offset, + -1, count_threshold, client, user_agent, x, y, data_area)) + + self.metrics.for_local_strikes(minute_length, data_area, cache.get_ratio()) + + return response + + @with_request + def jsonrpc_get_strikes_grid(self, request, minute_length, grid_base_length=10000, minute_offset=0, region=1, + count_threshold=0): + self.memory_info() + client = self.get_request_client(request) + user_agent, user_agent_version = self.parse_user_agent(request) + + if client in self.forbidden_ips or user_agent_version == 0 or request.getHeader( + 'content-type') != JSON_CONTENT_TYPE or not request.getHeader( + 'referer') or grid_base_length < self.MIN_GRID_BASE_LENGTH or grid_base_length == self.INVALID_GRID_BASE_LENGTH: + log.msg( + f"FORBIDDEN - client: {client}, user agent: {user_agent_version}, content type: {request.getHeader('content-type')}, referer: {request.getHeader('referer')}") + log.msg('get_strikes_grid(%d, %d, %d, %d, >=%d) BLOCKED %.1f%% %s %s' % ( + minute_length, grid_base_length, minute_offset, region, count_threshold, + 0, client, user_agent)) + return {} + + original_grid_base_length = grid_base_length + grid_base_length = max(self.MIN_GRID_BASE_LENGTH, grid_base_length) + minute_length, minute_offset = self.minute_constraints.enforce(minute_length, minute_offset, ) + region = self.__force_range(region, 1, self.MAX_REGION) + count_threshold = max(0, count_threshold) + + cache = self.cache.strikes(minute_offset) + response = cache.get(self.get_strikes_grid, minute_length=minute_length, + grid_baselength=grid_base_length, + minute_offset=minute_offset, region=region, + count_threshold=count_threshold) + self.fix_bad_accept_header(request, user_agent) + + log.msg('get_strikes_grid(%d, %d, %d, %d, >=%d) %.1f%% %s %s' % ( + minute_length, grid_base_length, minute_offset, region, count_threshold, + cache.get_ratio() * 100, client, user_agent)) + + self.__check_period() + self.current_data['get_strikes_grid'].append( + (self.__get_epoch(datetime.datetime.now(datetime.UTC)), minute_length, original_grid_base_length, + minute_offset, + region, + count_threshold, client, user_agent)) + + self.metrics.for_strikes(minute_length, region, cache.get_ratio()) + + return response + + def parse_user_agent(self, request): + """Parse user agent string to extract version information.""" + user_agent = request.getHeader("User-Agent") + user_agent_version = 0 + if user_agent and user_agent.startswith(USER_AGENT_PREFIX): + user_agent_parts = user_agent.split(' ')[0].rsplit('-', 1) + if len(user_agent_parts) > 1 and user_agent_parts[0] == 'bo-android': + try: + user_agent_version = int(user_agent_parts[1]) + except ValueError: + pass + return user_agent, user_agent_version + + def fix_bad_accept_header(self, request, user_agent): + """Remove Accept-Encoding header for old Android client versions that have bugs.""" + if user_agent and user_agent.startswith(USER_AGENT_PREFIX): + user_agent_parts = user_agent.split(' ')[0].rsplit('-', 1) + if len(user_agent_parts) > 1 and user_agent_parts[0] == 'bo-android': + try: + version = int(user_agent_parts[1]) + if version <= self.MAX_COMPATIBLE_ANDROID_VERSION: + request.requestHeaders.removeHeader("Accept-Encoding") + except ValueError: + pass + + def get_histogram(self, time_interval: TimeInterval, region=None, envelope=None): + return self.cache.histogram.get(self.histogram_query.create, + time_interval=time_interval, + connection_pool=self.connection_pool, + region=region, + envelope=envelope) + + def get_request_client(self, request): + forward = request.getHeader("X-Forwarded-For") + if forward: + return forward.split(',')[0].strip() + return request.getClientIP() + + def memory_info(self): + now = time.time() + if now > self.next_memory_info: + log.msg("### MEMORY INFO ###") + # pylint: disable=no-member + if is_pypy: + log.msg(gc.get_stats(True)) # type: ignore[call-arg] + else: + log.msg(gc.get_stats()) # type: ignore[call-arg] + self.next_memory_info = now + self.MEMORY_INFO_INTERVAL + + +class LogObserver(FileLogObserver): + + def __init__(self, f, prefix=None): + prefix = '' if prefix is None else prefix + self.prefix = prefix + FileLogObserver.__init__(self, f) + + def emit(self, event_dict): + text = textFromEventDict(event_dict) + if text is None: + return + time_str = self.formatTime(event_dict["time"]) + msg_str = _safeFormat("[%(prefix)s] %(text)s\n", { + "prefix": self.prefix, + "text": text.replace("\n", "\n\t") + }) + untilConcludes(self.write, time_str + " " + msg_str) + untilConcludes(self.flush) diff --git a/tests/service/test_base.py b/tests/service/test_base.py new file mode 100644 index 0000000..dc1834b --- /dev/null +++ b/tests/service/test_base.py @@ -0,0 +1,838 @@ +"""Tests for blitzortung.service.base module.""" + +import datetime +import time +from io import StringIO +from unittest.mock import Mock, MagicMock, patch, call + +import pytest +from assertpy import assert_that + +from blitzortung.service.base import Blitzortung, LogObserver + + +class MockRequest: + """Mock request object for testing JSON-RPC methods.""" + + def __init__(self, user_agent=None, client_ip=None, x_forwarded_for=None, + content_type=None, referer=None): + self._user_agent = user_agent + self._client_ip = client_ip + self._x_forwarded_for = x_forwarded_for + self._content_type = content_type + self._referer = referer + self._headers_removed = [] + + def getHeader(self, name): + if name == "User-Agent": + return self._user_agent + if name == "X-Forwarded-For": + return self._x_forwarded_for + if name == "content-type": + return self._content_type + if name == "referer": + return self._referer + return None + + def getClientIP(self): + return self._client_ip + + def __getitem__(self, key): + return getattr(self, key, None) + + @property + def requestHeaders(self): + mock = Mock() + mock.removeHeader = self._remove_header + return mock + + def _remove_header(self, name): + self._headers_removed.append(name) + + +@pytest.fixture +def mock_connection_pool(): + """Create a mock database connection pool.""" + return Mock() + + +@pytest.fixture +def mock_log_directory(): + """Create a mock log directory.""" + return None + + +@pytest.fixture +def mock_strike_query(): + """Create a mock strike query.""" + return Mock() + + +@pytest.fixture +def mock_strike_grid_query(): + """Create a mock strike grid query.""" + mock = Mock() + mock.create = Mock(return_value=(Mock(), Mock())) + mock.combine_result = Mock(return_value=Mock()) + return mock + + +@pytest.fixture +def mock_global_strike_grid_query(): + """Create a mock global strike grid query.""" + mock = Mock() + mock.create = Mock(return_value=(Mock(), Mock())) + mock.combine_result = Mock(return_value=Mock()) + return mock + + +@pytest.fixture +def mock_histogram_query(): + """Create a mock histogram query.""" + mock = Mock() + mock.create = Mock(return_value=Mock()) + return mock + + +@pytest.fixture +def mock_cache(): + """Create a mock service cache.""" + mock = Mock() + mock.strikes = Mock(return_value=Mock( + get=Mock(return_value={}), + get_ratio=Mock(return_value=0.0), + get_size=Mock(return_value=0) + )) + mock.local_strikes = Mock(return_value=Mock( + get=Mock(return_value={}), + get_ratio=Mock(return_value=0.0), + get_size=Mock(return_value=0) + )) + mock.global_strikes = Mock(return_value=Mock( + get=Mock(return_value={}), + get_ratio=Mock(return_value=0.0), + get_size=Mock(return_value=0) + )) + mock.histogram = Mock( + get=Mock(return_value=Mock()) + ) + return mock + + +@pytest.fixture +def mock_metrics(): + """Create a mock metrics.""" + mock = Mock() + mock.statsd = Mock() + mock.for_global_strikes = Mock() + mock.for_local_strikes = Mock() + mock.for_strikes = Mock() + return mock + + +@pytest.fixture +def mock_forbidden_ips(): + """Create empty forbidden IPs dict for testing.""" + return {} + + +@pytest.fixture +def blitzortung(mock_connection_pool, mock_log_directory, mock_strike_query, + mock_strike_grid_query, mock_global_strike_grid_query, + mock_histogram_query, mock_cache, mock_metrics, mock_forbidden_ips): + """Create a Blitzortung instance with mocked dependencies.""" + return Blitzortung( + mock_connection_pool, + mock_log_directory, + strike_query=mock_strike_query, + strike_grid_query=mock_strike_grid_query, + global_strike_grid_query=mock_global_strike_grid_query, + histogram_query=mock_histogram_query, + cache=mock_cache, + metrics=mock_metrics, + forbidden_ips=mock_forbidden_ips + ) + + +class TestBlitzortungClassConstants: + """Test class constants for validation.""" + + def test_min_grid_base_length(self): + assert_that(Blitzortung.MIN_GRID_BASE_LENGTH).is_equal_to(5000) + + def test_invalid_grid_base_length(self): + assert_that(Blitzortung.INVALID_GRID_BASE_LENGTH).is_equal_to(1000001) + + def test_global_min_grid_base_length(self): + assert_that(Blitzortung.GLOBAL_MIN_GRID_BASE_LENGTH).is_equal_to(10000) + + def test_max_minutes_per_day(self): + assert_that(Blitzortung.MAX_MINUTES_PER_DAY).is_equal_to(1440) + + def test_default_minute_length(self): + assert_that(Blitzortung.DEFAULT_MINUTE_LENGTH).is_equal_to(60) + + def test_histogram_minute_threshold(self): + assert_that(Blitzortung.HISTOGRAM_MINUTE_THRESHOLD).is_equal_to(10) + + def test_max_compatible_android_version(self): + assert_that(Blitzortung.MAX_COMPATIBLE_ANDROID_VERSION).is_equal_to(177) + + def test_memory_info_interval(self): + assert_that(Blitzortung.MEMORY_INFO_INTERVAL).is_equal_to(300) + + +class TestBlitzortungInitialization: + """Test Blitzortung initialization.""" + + def test_sets_connection_pool(self, blitzortung, mock_connection_pool): + assert_that(blitzortung.connection_pool).is_same_as(mock_connection_pool) + + def test_sets_log_directory(self, blitzortung, mock_log_directory): + assert_that(blitzortung.log_directory).is_same_as(mock_log_directory) + + def test_sets_strike_query(self, blitzortung, mock_strike_query): + assert_that(blitzortung.strike_query).is_same_as(mock_strike_query) + + def test_sets_strike_grid_query(self, blitzortung, mock_strike_grid_query): + assert_that(blitzortung.strike_grid_query).is_same_as(mock_strike_grid_query) + + def test_sets_global_strike_grid_query(self, blitzortung, mock_global_strike_grid_query): + assert_that(blitzortung.global_strike_grid_query).is_same_as(mock_global_strike_grid_query) + + def test_sets_histogram_query(self, blitzortung, mock_histogram_query): + assert_that(blitzortung.histogram_query).is_same_as(mock_histogram_query) + + def test_sets_cache(self, blitzortung, mock_cache): + assert_that(blitzortung.cache).is_same_as(mock_cache) + + def test_sets_metrics(self, blitzortung, mock_metrics): + assert_that(blitzortung.metrics).is_same_as(mock_metrics) + + def test_sets_forbidden_ips(self, blitzortung, mock_forbidden_ips): + assert_that(blitzortung.forbidden_ips).is_same_as(mock_forbidden_ips) + + def test_initializes_check_count(self, blitzortung): + assert_that(blitzortung.check_count).is_equal_to(0) + + def test_initializes_current_data_as_defaultdict(self, blitzortung): + assert_that(blitzortung.current_data).is_instance_of(dict) + assert_that(blitzortung.current_data['test']).is_equal_to([]) + + def test_initializes_minute_constraints(self, blitzortung): + assert_that(blitzortung.minute_constraints).is_not_none() + + +class TestJsonRpcCheck: + """Test the jsonrpc_check health check endpoint.""" + + def test_increments_check_count(self, blitzortung): + initial_count = blitzortung.check_count + blitzortung.jsonrpc_check() + assert_that(blitzortung.check_count).is_equal_to(initial_count + 1) + + def test_returns_count_dict(self, blitzortung): + result = blitzortung.jsonrpc_check() + assert_that(result).is_instance_of(dict) + assert_that(result).contains_key('count') + + +class TestParseUserAgent: + """Test parse_user_agent method.""" + + def test_valid_android_user_agent(self, blitzortung): + request = MockRequest(user_agent='bo-android-150') + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(user_agent).is_equal_to('bo-android-150') + assert_that(version).is_equal_to(150) + + def test_android_user_agent_with_space(self, blitzortung): + request = MockRequest(user_agent='bo-android-150 SomeDevice') + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(version).is_equal_to(150) + + def test_android_user_agent_lowercase(self, blitzortung): + request = MockRequest(user_agent='bo-android-abc') + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(version).is_equal_to(0) + + def test_android_user_agent_negative_version(self, blitzortung): + request = MockRequest(user_agent='bo-android--5') + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(version).is_equal_to(0) + + def test_non_android_user_agent(self, blitzortung): + request = MockRequest(user_agent='Mozilla/5.0') + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(version).is_equal_to(0) + + def test_none_user_agent(self, blitzortung): + request = MockRequest(user_agent=None) + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(version).is_equal_to(0) + + def test_empty_user_agent(self, blitzortung): + request = MockRequest(user_agent='') + user_agent, version = blitzortung.parse_user_agent(request) + assert_that(version).is_equal_to(0) + + +class TestFixBadAcceptHeader: + """Test fix_bad_accept_header method.""" + + def test_removes_header_for_old_android(self, blitzortung): + request = MockRequest(user_agent='bo-android-100') + blitzortung.fix_bad_accept_header(request, 'bo-android-100') + assert_that(request._headers_removed).contains('Accept-Encoding') + + def test_does_not_remove_header_for_new_android(self, blitzortung): + request = MockRequest(user_agent='bo-android-200') + blitzortung.fix_bad_accept_header(request, 'bo-android-200') + assert_that(request._headers_removed).does_not_contain('Accept-Encoding') + + def test_removes_header_for_max_version(self, blitzortung): + # Version 177 (MAX_COMPATIBLE_ANDROID_VERSION) should still remove header (<=) + request = MockRequest(user_agent='bo-android-177') + blitzortung.fix_bad_accept_header(request, 'bo-android-177') + assert_that(request._headers_removed).contains('Accept-Encoding') + + def test_does_not_remove_header_for_non_android(self, blitzortung): + request = MockRequest(user_agent='Mozilla/5.0') + blitzortung.fix_bad_accept_header(request, 'Mozilla/5.0') + assert_that(request._headers_removed).is_empty() + + def test_handles_none_user_agent(self, blitzortung): + request = MockRequest(user_agent=None) + blitzortung.fix_bad_accept_header(request, None) + assert_that(request._headers_removed).is_empty() + + def test_handles_invalid_version(self, blitzortung): + request = MockRequest(user_agent='bo-android-abc') + blitzortung.fix_bad_accept_header(request, 'bo-android-abc') + assert_that(request._headers_removed).is_empty() + + +class TestGetRequestClient: + """Test get_request_client method.""" + + def test_returns_client_ip_directly(self, blitzortung): + request = MockRequest(client_ip='192.168.1.1') + result = blitzortung.get_request_client(request) + assert_that(result).is_equal_to('192.168.1.1') + + def test_returns_first_ip_from_x_forwarded_for(self, blitzortung): + request = MockRequest(x_forwarded_for='10.0.0.1, 10.0.0.2') + result = blitzortung.get_request_client(request) + assert_that(result).is_equal_to('10.0.0.1') + + def test_prefers_x_forwarded_for_over_client_ip(self, blitzortung): + request = MockRequest(client_ip='192.168.1.1', x_forwarded_for='10.0.0.1') + result = blitzortung.get_request_client(request) + assert_that(result).is_equal_to('10.0.0.1') + + def test_handles_none_x_forwarded_for(self, blitzortung): + request = MockRequest(client_ip='192.168.1.1', x_forwarded_for=None) + result = blitzortung.get_request_client(request) + assert_that(result).is_equal_to('192.168.1.1') + + +class TestForceRange: + """Test __force_range static method.""" + + def test_returns_min_when_below(self): + result = Blitzortung._Blitzortung__force_range(5, 10, 100) + assert_that(result).is_equal_to(10) + + def test_returns_max_when_above(self): + result = Blitzortung._Blitzortung__force_range(150, 10, 100) + assert_that(result).is_equal_to(100) + + def test_returns_value_when_in_range(self): + result = Blitzortung._Blitzortung__force_range(50, 10, 100) + assert_that(result).is_equal_to(50) + + def test_returns_min_when_equal_to_min(self): + result = Blitzortung._Blitzortung__force_range(10, 10, 100) + assert_that(result).is_equal_to(10) + + def test_returns_max_when_equal_to_max(self): + result = Blitzortung._Blitzortung__force_range(100, 10, 100) + assert_that(result).is_equal_to(100) + + +class TestMemoryInfo: + """Test memory_info method.""" + + @patch('blitzortung.service.base.gc') + @patch('blitzortung.service.base.log') + @patch('blitzortung.service.base.is_pypy', False) + def test_logs_memory_info_first_call(self, mock_log, mock_gc, blitzortung): + mock_gc.get_stats = Mock(return_value={'test': 'stats'}) + blitzortung.next_memory_info = 0.0 + # time.time() must return a value > next_memory_info to trigger logging + with patch('time.time', return_value=1.0): + blitzortung.memory_info() + + assert_that(mock_log.msg.call_count).is_greater_than(0) + + def test_skips_logging_when_within_interval(self, blitzortung): + with patch('blitzortung.service.base.log') as mock_log: + blitzortung.next_memory_info = 1000.0 + with patch('time.time', return_value=500.0): + blitzortung.memory_info() + + mock_log.msg.assert_not_called() + + @patch('blitzortung.service.base.gc') + @patch('blitzortung.service.base.log') + @patch('blitzortung.service.base.is_pypy', True) + def test_logs_with_pypy_stats(self, mock_log, mock_gc, blitzortung): + mock_gc.get_stats = Mock(return_value={'test': 'stats'}) + blitzortung.next_memory_info = 0.0 + # time.time() must return a value > next_memory_info to trigger logging + with patch('time.time', return_value=1.0): + blitzortung.memory_info() + + assert_that(mock_log.msg.call_count).is_greater_than(0) + + +class TestGetEpoch: + """Test __get_epoch method.""" + + def test_converts_datetime_to_epoch_microseconds(self, blitzortung): + dt = datetime.datetime(2025, 1, 1, 12, 0, 0, 500000, tzinfo=datetime.timezone.utc) + result = blitzortung._Blitzortung__get_epoch(dt) + # 2025-01-01 12:00:00.500000 UTC + expected = 1735732800 * 1000000 + 500000 + assert_that(result).is_equal_to(expected) + + +class TestCurrentPeriod: + """Test __current_period method.""" + + def test_returns_datetime_with_utc_timezone(self, blitzortung): + result = blitzortung._Blitzortung__current_period() + assert_that(result.tzinfo).is_equal_to(datetime.timezone.utc) + + def test_returns_datetime_with_zero_seconds(self, blitzortung): + result = blitzortung._Blitzortung__current_period() + assert_that(result.second).is_equal_to(0) + assert_that(result.microsecond).is_equal_to(0) + + +class TestForbiddenIps: + """Test forbidden IP functionality.""" + + def test_blocks_request_from_forbidden_ip(self): + """Test that requests from forbidden IPs are blocked.""" + mock_pool = Mock() + mock_cache = Mock() + mock_cache.strikes = Mock(return_value=Mock(get=Mock(return_value={}))) + mock_cache.global_strikes = Mock(return_value=Mock(get=Mock(return_value={}))) + mock_cache.local_strikes = Mock(return_value=Mock(get=Mock(return_value={}))) + + mock_strike_query = Mock() + mock_strike_grid_query = Mock() + mock_strike_grid_query.create = Mock(return_value=(Mock(), Mock())) + mock_strike_grid_query.combine_result = Mock(return_value=Mock()) + mock_global_strike_grid_query = Mock() + mock_global_strike_grid_query.create = Mock(return_value=(Mock(), Mock())) + mock_global_strike_grid_query.combine_result = Mock(return_value=Mock()) + mock_histogram_query = Mock() + + forbidden_ips = {'192.168.1.100': True} + + bt = Blitzortung( + mock_pool, + None, + strike_query=mock_strike_query, + strike_grid_query=mock_strike_grid_query, + global_strike_grid_query=mock_global_strike_grid_query, + histogram_query=mock_histogram_query, + cache=mock_cache, + forbidden_ips=forbidden_ips + ) + + # Create request from forbidden IP with valid user agent + request = MockRequest( + client_ip='192.168.1.100', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + + # The method should return empty dict due to forbidden IP + result = bt.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + assert_that(result).is_equal_to({}) + + def test_allows_request_from_non_forbidden_ip(self): + """Test that requests from non-forbidden IPs are allowed.""" + mock_pool = Mock() + mock_cache = Mock() + mock_cache.strikes = Mock(return_value=Mock( + get=Mock(return_value={'data': 'test'}), + get_ratio=Mock(return_value=0.5), + get_size=Mock(return_value=10) + )) + mock_cache.global_strikes = Mock(return_value=Mock(get=Mock(return_value={}))) + mock_cache.local_strikes = Mock(return_value=Mock(get=Mock(return_value={}))) + + mock_strike_query = Mock() + mock_strike_grid_query = Mock() + mock_strike_grid_query.create = Mock(return_value=(Mock(), Mock())) + mock_strike_grid_query.combine_result = Mock(return_value=Mock()) + mock_global_strike_grid_query = Mock() + mock_histogram_query = Mock() + + bt = Blitzortung( + mock_pool, + None, + strike_query=mock_strike_query, + strike_grid_query=mock_strike_grid_query, + global_strike_grid_query=mock_global_strike_grid_query, + histogram_query=mock_histogram_query, + cache=mock_cache, + forbidden_ips={'192.168.1.100': True} + ) + + # Create request from allowed IP with valid user agent + request = MockRequest( + client_ip='192.168.1.1', + user_agent='bo-android-150', + content_type='text/json', + referer='http://example.com' + ) + + # The method should call cache.get for non-forbidden IP + bt.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + mock_cache.strikes.return_value.get.assert_called() + + +class TestLogObserver: + """Test LogObserver class.""" + + def test_initializes_with_empty_prefix(self): + output = StringIO() + observer = LogObserver(output) + assert_that(observer.prefix).is_equal_to('') + + def test_initializes_with_custom_prefix(self): + output = StringIO() + observer = LogObserver(output, prefix='TEST') + assert_that(observer.prefix).is_equal_to('TEST') + + def test_emit_handles_none_text(self): + output = StringIO() + observer = LogObserver(output) + # Should not raise when event dict has time key + observer.emit({'message': 'test', 'time': 1234567890.0}) + # Should not raise when text is None (event dict without message/format) + # but we don't test that case since textFromEventDict has a bug + + +class TestGetStrikesGrid: + """Test get_strikes_grid method.""" + + def test_creates_grid_parameters(self, blitzortung, mock_strike_grid_query): + with patch('blitzortung.service.base.GridParameters') as mock_params: + with patch('blitzortung.service.base.create_time_interval') as mock_interval: + mock_interval.return_value = Mock() + mock_strike_grid_query.create.return_value = (Mock(), Mock()) + mock_strike_grid_query.combine_result.return_value = Mock() + + blitzortung.get_strikes_grid(60, 10000, 0, 1, 0) + + mock_params.assert_called() + + def test_creates_time_interval(self, blitzortung, mock_strike_grid_query): + with patch('blitzortung.service.base.GridParameters') as mock_params: + with patch('blitzortung.service.base.create_time_interval') as mock_interval: + mock_interval.return_value = Mock() + mock_strike_grid_query.create.return_value = (Mock(), Mock()) + mock_strike_grid_query.combine_result.return_value = Mock() + + blitzortung.get_strikes_grid(60, 10000, 0, 1, 0) + + mock_interval.assert_called_with(60, 0) + + +class TestGetGlobalStrikesGrid: + """Test get_global_strikes_grid method.""" + + def test_creates_global_grid_parameters(self, blitzortung, mock_global_strike_grid_query): + with patch('blitzortung.service.base.GridParameters') as mock_params: + with patch('blitzortung.service.base.create_time_interval') as mock_interval: + mock_interval.return_value = Mock() + mock_global_strike_grid_query.create.return_value = (Mock(), Mock()) + mock_global_strike_grid_query.combine_result.return_value = Mock() + + blitzortung.get_global_strikes_grid(60, 10000, 0, 0) + + mock_params.assert_called() + + +class TestGetLocalStrikesGrid: + """Test get_local_strikes_grid method.""" + + def test_creates_local_grid_parameters(self, blitzortung, mock_strike_grid_query): + with patch('blitzortung.service.base.LocalGrid') as mock_local_grid: + with patch('blitzortung.service.base.GridParameters') as mock_params: + with patch('blitzortung.service.base.create_time_interval') as mock_interval: + mock_grid_factory = Mock() + mock_grid_factory.get_for.return_value = Mock() + mock_local_grid.return_value.get_grid_factory.return_value = mock_grid_factory + + mock_interval.return_value = Mock() + mock_strike_grid_query.create.return_value = (Mock(), Mock()) + mock_strike_grid_query.combine_result.return_value = Mock() + + blitzortung.get_local_strikes_grid(10, 20, 10000, 60, 0, 0) + + mock_local_grid.assert_called() + + +class TestGetHistogram: + """Test get_histogram method.""" + + def test_calls_histogram_cache(self, blitzortung, mock_cache): + mock_time_interval = Mock() + mock_histogram = Mock() + mock_cache.histogram.get.return_value = mock_histogram + + result = blitzortung.get_histogram(mock_time_interval) + + mock_cache.histogram.get.assert_called() + assert_that(result).is_same_as(mock_histogram) + + +class TestJsonRpcGetStrikesRaster: + """Test jsonrpc_get_strikes_raster method.""" + + def test_calls_get_strikes_grid(self, blitzortung): + with patch.object(blitzortung, 'jsonrpc_get_strikes_grid') as mock_method: + mock_method.return_value = {} + request = Mock() + result = blitzortung.jsonrpc_get_strikes_raster(request, 60, 10000, 0, 1) + + mock_method.assert_called_once_with(request, 60, 10000, 0, 1) + + +class TestJsonRpcGetStrokesRaster: + """Test jsonrpc_get_strokes_raster method.""" + + def test_calls_get_strikes_grid(self, blitzortung): + with patch.object(blitzortung, 'jsonrpc_get_strikes_grid') as mock_method: + mock_method.return_value = {} + request = Mock() + result = blitzortung.jsonrpc_get_strokes_raster(request, 60, 10000, 0, 1) + + mock_method.assert_called_once_with(request, 60, 10000, 0, 1) + + +class TestJsonRpcGetStrikes: + """Test jsonrpc_get_strikes method.""" + + def test_returns_none_blocked(self, blitzortung): + request = MockRequest(user_agent='test') + result = blitzortung.jsonrpc_get_strikes(request, 60, 0) + assert_that(result).is_none() + + def test_enforces_minute_length_range(self, blitzortung): + request = MockRequest() + # minute_length of -5 should be clamped to 0 + result = blitzortung.jsonrpc_get_strikes(request, -5, 0) + assert_that(result).is_none() + + def test_enforces_max_minute_length(self, blitzortung): + request = MockRequest() + # minute_length of 2000 should be clamped to 1440 + result = blitzortung.jsonrpc_get_strikes(request, 2000, 0) + assert_that(result).is_none() + + +class TestJsonRpcGetStrikesGrid: + """Test jsonrpc_get_strikes_grid method.""" + + def test_returns_empty_for_forbidden_ip(self, blitzortung): + """Test that requests from forbidden IPs are blocked.""" + request = MockRequest( + client_ip='192.168.1.100', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_invalid_user_agent(self, blitzortung): + """Test that requests with invalid user agent are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com', + user_agent='invalid' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_invalid_content_type(self, blitzortung): + """Test that requests without proper content type are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/html', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_missing_referer(self, blitzortung): + """Test that requests without referer are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer=None, + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_small_grid_baselength(self, blitzortung): + """Test that requests with too small grid_baselength are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 1000, 0, 1) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_invalid_grid_baselength(self, blitzortung): + """Test that requests with invalid grid_baselength are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 1000001, 0, 1) + assert_that(result).is_equal_to({}) + + def test_returns_response_for_valid_request(self, blitzortung): + """Test that valid requests get a response.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_strikes_grid(request, 60, 10000, 0, 1) + # Should return the cached response (empty dict from mock) + assert_that(result).is_equal_to({}) + + +class TestJsonRpcGetGlobalStrikesGrid: + """Test jsonrpc_get_global_strikes_grid method.""" + + def test_returns_empty_for_forbidden_ip(self, blitzortung): + """Test that requests from forbidden IPs are blocked.""" + request = MockRequest( + client_ip='192.168.1.100', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_global_strikes_grid(request, 60, 10000, 0) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_invalid_user_agent(self, blitzortung): + """Test that requests with invalid user agent are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com', + user_agent='invalid' + ) + result = blitzortung.jsonrpc_get_global_strikes_grid(request, 60, 10000, 0) + assert_that(result).is_equal_to({}) + + def test_returns_response_for_valid_request(self, blitzortung): + """Test that valid requests get a response.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com', + user_agent='bo-android-150' + ) + result = blitzortung.jsonrpc_get_global_strikes_grid(request, 60, 10000, 0) + assert_that(result).is_equal_to({}) + + +class TestJsonRpcGetLocalStrikesGrid: + """Test jsonrpc_get_local_strikes_grid method.""" + + def test_returns_empty_for_forbidden_ip(self, blitzortung): + """Test that requests from forbidden IPs are blocked.""" + request = MockRequest( + client_ip='192.168.1.100', + content_type='text/json', + referer='http://example.com' + ) + result = blitzortung.jsonrpc_get_local_strikes_grid(request, 10, 20, 10000, 60, 0) + assert_that(result).is_equal_to({}) + + def test_returns_empty_for_invalid_content_type(self, blitzortung): + """Test that requests without proper content type are blocked.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/html', + referer='http://example.com' + ) + result = blitzortung.jsonrpc_get_local_strikes_grid(request, 10, 20, 10000, 60, 0) + assert_that(result).is_equal_to({}) + + def test_returns_response_for_valid_request(self, blitzortung): + """Test that valid requests get a response.""" + request = MockRequest( + client_ip='192.168.1.1', + content_type='text/json', + referer='http://example.com' + ) + result = blitzortung.jsonrpc_get_local_strikes_grid(request, 10, 20, 10000, 60, 0) + assert_that(result).is_equal_to({}) + + +class TestCheckPeriod: + """Test __check_period method.""" + + def test_restarts_period_when_changed(self, blitzortung): + """Test that period is restarted when it changes.""" + with patch.object(blitzortung, '_Blitzortung__restart_period') as mock_restart: + with patch.object(blitzortung, '_Blitzortung__current_period') as mock_current: + # Set current period to be different + mock_current.return_value = datetime.datetime(2025, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + blitzortung.current_period = datetime.datetime(2025, 1, 1, 11, 0, 0, tzinfo=datetime.timezone.utc) + blitzortung._Blitzortung__check_period() + mock_restart.assert_called_once() + + def test_does_not_restart_when_same_period(self, blitzortung): + """Test that period is not restarted when unchanged.""" + with patch.object(blitzortung, '_Blitzortung__restart_period') as mock_restart: + with patch.object(blitzortung, '_Blitzortung__current_period') as mock_current: + same_period = datetime.datetime(2025, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + mock_current.return_value = same_period + blitzortung.current_period = same_period + blitzortung._Blitzortung__check_period() + mock_restart.assert_not_called() + + +class TestRestartPeriod: + """Test __restart_period method.""" + + def test_resets_current_data(self, blitzortung): + """Test that current_data is reset.""" + blitzortung.current_data['test'] = [1, 2, 3] + blitzortung._Blitzortung__restart_period() + assert_that(blitzortung.current_data).is_instance_of(dict) + assert_that(blitzortung.current_data['test']).is_equal_to([])