diff --git a/README.rst b/README.rst index cb06c1764..0798ad00a 100644 --- a/README.rst +++ b/README.rst @@ -77,7 +77,7 @@ settings: etcd - *session\_timeout*: the TTL to acquire the leader lock. Think of it as the length of time before automatic failover process is initiated. -- *reconnects\_timeout*: how long we should try to reconnect to +- *reconnect\_timeout*: how long we should try to reconnect to ZooKeeper after connection loss. After this timeout we assume that we don't have lock anymore and will restart in read-only mode. - *hosts*: list of ZooKeeper cluster members in format: [ @@ -101,6 +101,7 @@ settings: accessible from other nodes and applications. - *data\_dir*: file path to initialize and store Postgres data files - *maximum\_lag\_on\_failover*: the maximum bytes a follower may lag before it is not eligible become leader +- *use\_slots*: whether or not to use replication slots. Must be False for PostgreSQL 9.3, and you should comment out max_replication_slots. - *pg\_hba*: list of lines which should be added to pg\_hba.conf @@ -138,9 +139,8 @@ settings: - *password*: admin password, user will be created during initialization. -- *recovery\_conf*: configuration settings written to recovery.conf - when configuring follower -- *parameters*: list of configuration settings for Postgres +- *recovery\_conf*: additional configuration settings written to recovery.conf when configuring follower +- *parameters*: list of configuration settings for Postgres. Many of these are required for replication to work. 
Replication choices ------------------- diff --git a/patroni/__init__.py b/patroni/__init__.py index b970b5137..2da277292 100644 --- a/patroni/__init__.py +++ b/patroni/__init__.py @@ -1 +1,121 @@ -__version__ = '0.22' +import logging +import os +import sys +import time +import yaml + +from patroni.api import RestApiServer +from patroni.etcd import Etcd +from patroni.ha import Ha +from patroni.postgresql import Postgresql +from patroni.utils import setup_signal_handlers, sleep, reap_children +from patroni.zookeeper import ZooKeeper + +logger = logging.getLogger(__name__) + + +class Patroni: + + def __init__(self, config): + self.nap_time = config['loop_wait'] + self.postgresql = Postgresql(config['postgresql']) + self.ha = Ha(self.postgresql, self.get_dcs(self.postgresql.name, config)) + host, port = config['restapi']['listen'].split(':') + self.api = RestApiServer(self, config['restapi']) + self.next_run = time.time() + self.shutdown_member_ttl = 300 + + @staticmethod + def get_dcs(name, config): + if 'etcd' in config: + return Etcd(name, config['etcd']) + if 'zookeeper' in config: + return ZooKeeper(name, config['zookeeper']) + raise Exception('Can not find sutable configuration of distributed configuration store') + + def touch_member(self, ttl=None): + connection_string = self.postgresql.connection_string + '?application_name=' + self.api.connection_string + if self.ha.cluster: + for m in self.ha.cluster.members: + # Do not update member TTL when it is far from being expired + if m.name == self.postgresql.name and m.real_ttl() > self.shutdown_member_ttl: + return True + return self.ha.dcs.touch_member(connection_string, ttl) + + def initialize(self): + # wait for etcd to be available + while not self.touch_member(): + logger.info('waiting on DCS') + sleep(5) + + # is data directory empty? 
+ if self.postgresql.data_directory_empty(): + # racing to initialize + if self.ha.dcs.race('/initialize'): + self.postgresql.initialize() + self.ha.dcs.take_leader() + self.postgresql.start() + self.postgresql.create_replication_user() + self.postgresql.create_connection_users() + else: + while True: + leader = self.ha.dcs.current_leader() + if leader and self.postgresql.sync_from_leader(leader): + self.postgresql.write_recovery_conf(leader) + self.postgresql.start() + break + sleep(5) + elif self.postgresql.is_running(): + self.postgresql.load_replication_slots() + + def schedule_next_run(self): + if self.postgresql.is_promoted: + self.next_run = time.time() + self.next_run += self.nap_time + current_time = time.time() + nap_time = self.next_run - current_time + if nap_time <= 0: + self.next_run = current_time + else: + self.ha.dcs.watch(nap_time) + + def run(self): + self.api.start() + self.next_run = time.time() + + while True: + self.touch_member() + logger.info(self.ha.run_cycle()) + try: + if self.ha.state_handler.is_leader(): + self.ha.cluster and self.ha.state_handler.create_replication_slots(self.ha.cluster) + else: + self.ha.state_handler.drop_replication_slots() + except: + logger.exception('Exception when changing replication slots') + reap_children() + self.schedule_next_run() + + +def main(): + logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) + logging.getLogger('requests').setLevel(logging.WARNING) + setup_signal_handlers() + + if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]): + print('Usage: {} config.yml'.format(sys.argv[0])) + return + + with open(sys.argv[1], 'r') as f: + config = yaml.load(f) + + patroni = Patroni(config) + try: + patroni.initialize() + patroni.run() + except KeyboardInterrupt: + pass + finally: + patroni.touch_member(patroni.shutdown_member_ttl) # schedule member removal + patroni.postgresql.stop() + patroni.ha.dcs.delete_leader() diff --git a/patroni/__main__.py 
b/patroni/__main__.py index 7390c1f1a..3abcbfc37 100644 --- a/patroni/__main__.py +++ b/patroni/__main__.py @@ -1,4 +1,5 @@ -import patroni +from patroni import main + if __name__ == '__main__': - patroni.main() + main() diff --git a/patroni/helpers/api.py b/patroni/api.py similarity index 100% rename from patroni/helpers/api.py rename to patroni/api.py diff --git a/patroni/helpers/dcs.py b/patroni/dcs.py similarity index 86% rename from patroni/helpers/dcs.py rename to patroni/dcs.py index 69ccdebd2..6fb7aea8a 100644 --- a/patroni/helpers/dcs.py +++ b/patroni/dcs.py @@ -1,7 +1,8 @@ import abc from collections import namedtuple -from patroni.helpers.utils import calculate_ttl, sleep +from patroni.exceptions import DCSError +from patroni.utils import calculate_ttl, sleep from six.moves.urllib_parse import urlparse, urlunparse, parse_qsl @@ -22,24 +23,11 @@ def parse_connection_string(value): return conn_url, api_url -class DCSError(Exception): - """Parent class for all kind of exceptions related to selected distributed configuration store""" - - def __init__(self, value): - self.value = value - - def __str__(self): - """ - >>> str(DCSError('foo')) - "'foo'" - """ - return repr(self.value) - - class Member(namedtuple('Member', 'index,name,conn_url,api_url,expiration,ttl')): + """Immutable object (namedtuple) which represents single member of PostgreSQL cluster. Consists of the following fields: - :param index: modification index of a given member key in DCS + :param index: modification index of a given member key in a Configuration Store :param name: name of PostgreSQL cluster member :param conn_url: connection string containing host, user and password which could be used to access this member. :param api_url: REST API url of patroni instance @@ -50,11 +38,30 @@ def real_ttl(self): return calculate_ttl(self.expiration) or -1 +class Leader(namedtuple('Leader', 'index,expiration,ttl,member')): + + """Immutable object (namedtuple) which represents leader key. 
+ Consists of the following fields: + :param index: modification index of a leader key in a Configuration Store + :param expiration: expiration time of the leader key + :param ttl: ttl of the leader key + :param member: reference to a `Member` object which represents current leader (see `Cluster.members`)""" + + @property + def name(self): + return self.member.name + + @property + def conn_url(self): + return self.member.conn_url + + class Cluster(namedtuple('Cluster', 'initialize,leader,last_leader_operation,members')): + """Immutable object (namedtuple) which represents PostgreSQL cluster. Consists of the following fields: :param initialize: boolean, shows whether this cluster has initialization key stored in DC or not. - :param leader: `Member` object which represents current leader of the cluster + :param leader: `Leader` object which represents current leader of the cluster :param last_leader_operation: int or long object containing position of last known leader operation. This value is stored in `/optime/leader` key :param members: list of Member object, all PostgreSQL cluster members including leader""" @@ -74,7 +81,8 @@ def __init__(self, name, config): i.e.: `zookeeper` for zookeeper, `etcd` for etcd, etc... 
""" self._name = name - self._base_path = '/service/' + config['scope'] + self._scope = config['scope'] + self._base_path = '/service/' + self._scope def client_path(self, path): return self._base_path + path @@ -144,5 +152,5 @@ def delete_leader(self): """Voluntarily remove leader key from DCS This method should remove leader key if current instance is the leader""" - def sleep(self, timeout): + def watch(self, timeout): sleep(timeout) diff --git a/patroni/helpers/etcd.py b/patroni/etcd.py similarity index 66% rename from patroni/helpers/etcd.py rename to patroni/etcd.py index c58c1a789..97bd2cf7d 100644 --- a/patroni/helpers/etcd.py +++ b/patroni/etcd.py @@ -5,11 +5,13 @@ import random import requests import socket +import time +import urllib3 from dns.exception import DNSException from dns import resolver -from patroni.helpers.dcs import AbstractDCS, Cluster, DCSError, Member, parse_connection_string -from patroni.helpers.utils import sleep +from patroni.dcs import AbstractDCS, Cluster, DCSError, Leader, Member, parse_connection_string +from patroni.utils import Retry, RetryFailedError, sleep from requests.exceptions import RequestException logger = logging.getLogger(__name__) @@ -59,6 +61,16 @@ def get_srv_record(host): logger.exception('Can not resolve SRV for %s', host) return [] + # try to workarond bug in python-etcd: https://github.com/jplana/python-etcd/issues/81 + def _result_from_response(self, response): + try: + response.data.decode('utf-8') + except urllib3.exceptions.TimeoutError: + raise + except Exception as e: + raise etcd.EtcdException('Unable to decode server response: %s' % e) + return super(Client, self)._result_from_response(response) + def _get_machines_cache_from_srv(self, discovery_srv): """Fetch list of etcd-cluster member by resolving _etcd-server._tcp. SRV record. 
This record should contain list of host and peer ports which could be used to run @@ -124,7 +136,7 @@ def catch_etcd_errors(func): def wrapper(*args, **kwargs): try: return not func(*args, **kwargs) is None - except etcd.EtcdException: + except (RetryFailedError, etcd.EtcdException): return False return wrapper @@ -135,7 +147,16 @@ def __init__(self, name, config): super(Etcd, self).__init__(name, config) self.ttl = config['ttl'] self.member_ttl = config.get('member_ttl', 3600) + self._retry = Retry(deadline=10, max_delay=1, max_tries=-1, + retry_exceptions=(etcd.EtcdConnectionFailed, + etcd.EtcdLeaderElectionInProgress, + etcd.EtcdWatcherCleared, + etcd.EtcdEventIndexCleared)) self.client = self.get_etcd_client(config) + self.cluster = None + + def retry(self, *args, **kwargs): + return self._retry.copy()(*args, **kwargs) def get_etcd_client(self, config): client = None @@ -154,7 +175,7 @@ def member(node): def get_cluster(self): try: - result = self.client.read(self.client_path(''), recursive=True) + result = self.retry(self.client.read, self.client_path(''), recursive=True) nodes = {os.path.relpath(node.key, result.key): node for node in result.leaves} # get initialize flag @@ -170,30 +191,37 @@ def get_cluster(self): # get leader leader = nodes.get('leader', None) if leader: - leader = Member(-1, leader.value, None, None, None, None) - leader = ([m for m in members if m.name == leader.name] or [leader])[0] + member = Member(-1, leader.value, None, None, None, None) + member = ([m for m in members if m.name == leader.value] or [member])[0] + leader = Leader(leader.modifiedIndex, leader.expiration, leader.ttl, member) - return Cluster(initialize, leader, last_leader_operation, members) + self.cluster = Cluster(initialize, leader, last_leader_operation, members) except etcd.EtcdKeyNotFound: - return Cluster(False, None, None, []) + self.cluster = Cluster(False, None, None, []) except: + self.cluster = None logger.exception('get_cluster') - - raise EtcdError('Etcd 
is not responding properly') + raise EtcdError('Etcd is not responding properly') + return self.cluster @catch_etcd_errors def touch_member(self, connection_string, ttl=None): - return self.client.set(self.client_path('/members/' + self._name), connection_string, ttl or self.member_ttl) + return self.retry(self.client.set, self.client_path('/members/' + self._name), + connection_string, ttl or self.member_ttl) @catch_etcd_errors def take_leader(self): - return self.client.set(self.client_path('/leader'), self._name, self.ttl) + return self.retry(self.client.set, self.client_path('/leader'), self._name, self.ttl) - @catch_etcd_errors def attempt_to_acquire_leader(self): - ret = self.client.write(self.client_path('/leader'), self._name, ttl=self.ttl, prevExist=False) - ret or logger.info('Could not take out TTL lock') - return ret + try: + return not self.retry(self.client.write, self.client_path('/leader'), + self._name, ttl=self.ttl, prevExist=False) is None + except etcd.EtcdAlreadyExist: + logger.info('Could not take out TTL lock') + except (RetryFailedError, etcd.EtcdException): + pass + return False @catch_etcd_errors def write_leader_optime(self, state_handler): @@ -201,14 +229,36 @@ def write_leader_optime(self, state_handler): @catch_etcd_errors def update_leader(self, state_handler): - ret = self.client.test_and_set(self.client_path('/leader'), self._name, self._name, self.ttl) + ret = self.retry(self.client.test_and_set, self.client_path('/leader'), self._name, self._name, self.ttl) ret and self.write_leader_optime(state_handler) return ret @catch_etcd_errors def race(self, path): - return self.client.write(self.client_path(path), self._name, prevExist=False) + return self.retry(self.client.write, self.client_path(path), self._name, prevExist=False) @catch_etcd_errors def delete_leader(self): return self.client.delete(self.client_path('/leader'), prevValue=self._name) + + def watch(self, timeout): + # watch on leader key changes if it is defined and 
current node is not lock owner + if self.cluster and self.cluster.leader and self.cluster.leader.name != self._name: + end_time = time.time() + timeout + index = self.cluster.leader.index + + while index and timeout >= 1: # when timeout is too small urllib3 doesn't have enough time to connect + try: + res = self.client.watch(self.client_path('/leader'), index=index + 1, timeout=timeout) + if res.action not in ['set', 'compareAndSwap'] or res.value != self.cluster.leader.name: + return + index = res.modifiedIndex + except urllib3.exceptions.TimeoutError: + self.client.http.clear() + return + except etcd.EtcdException: + index = None + + timeout = end_time - time.time() + + timeout > 0 and super(Etcd, self).watch(timeout) diff --git a/patroni/exceptions.py b/patroni/exceptions.py new file mode 100644 index 000000000..5c13f3e0b --- /dev/null +++ b/patroni/exceptions.py @@ -0,0 +1,17 @@ +class PatroniException(Exception): + pass + + +class DCSError(PatroniException): + + """Parent class for all kind of exceptions related to selected distributed configuration store""" + + def __init__(self, value): + self.value = value + + def __str__(self): + """ + >>> str(DCSError('foo')) + "'foo'" + """ + return repr(self.value) diff --git a/patroni/helpers/ha.py b/patroni/ha.py similarity index 99% rename from patroni/helpers/ha.py rename to patroni/ha.py index 0496a094a..f31b2b262 100644 --- a/patroni/helpers/ha.py +++ b/patroni/ha.py @@ -1,6 +1,6 @@ import logging -from patroni.helpers.dcs import DCSError +from patroni.dcs import DCSError from psycopg2 import InterfaceError, OperationalError logger = logging.getLogger(__name__) diff --git a/patroni/helpers/__init__.py b/patroni/helpers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/patroni/helpers/utils.py b/patroni/helpers/utils.py deleted file mode 100644 index 5f09490e8..000000000 --- a/patroni/helpers/utils.py +++ /dev/null @@ -1,101 +0,0 @@ -import datetime -import os -import re -import signal 
-import sys -import time - -received_sigchld = False - -_DATE_TIME_RE = re.compile(r'''^ -(?P<year>\d{4})\-(?P<month>\d{2})\-(?P<day>\d{2}) # date -T -(?P<hour>\d{2}):(?P<minute>\d{2}):(?P<second>\d{2})\.(?P<microsecond>\d{6}) # time -\d*Z$''', re.X) - - -def parse_datetime(time_str): - """ - >>> parse_datetime('2015-06-10T12:56:30.552539016Z') - datetime.datetime(2015, 6, 10, 12, 56, 30, 552539) - >>> parse_datetime('2015-06-10 12:56:30.552539016Z') - """ - m = _DATE_TIME_RE.match(time_str) - if not m: - return None - p = dict((n, int(m.group(n))) for n in 'year month day hour minute second microsecond'.split(' ')) - return datetime.datetime(**p) - - -def calculate_ttl(expiration): - """ - >>> calculate_ttl(None) - >>> calculate_ttl('2015-06-10 12:56:30.552539016Z') - """ - if not expiration: - return None - expiration = parse_datetime(expiration) - if not expiration: - return None - now = datetime.datetime.utcnow() - return int((expiration - now).total_seconds()) - - -def lsn_to_bytes(value): - """ - >>> lsn_to_bytes('1/66000060') - 6006243424 - >>> lsn_to_bytes('j/66000060') - 0 - """ - try: - e = value.split('/') - if len(e) == 2 and len(e[0]) > 0 and len(e[1]) > 0: - return (int(e[0], 16) << 32) | int(e[1], 16) - except ValueError: - pass - return 0 - - -def bytes_to_lsn(value): - """ - >>> bytes_to_lsn(6006243424) - '1/66000060' - """ - id = value >> 32 - off = value & 0xffffffff - return '%x/%x' % (id, off) - - -def sigterm_handler(signo, stack_frame): - sys.exit() - - -def sigchld_handler(signo, stack_frame): - global received_sigchld - received_sigchld = True - try: - while True: - ret = os.waitpid(-1, os.WNOHANG) - if ret == (0, 0): - break - except OSError: - pass - - -def sleep(interval): - global received_sigchld - current_time = time.time() - end_time = current_time + interval - while current_time < end_time: - received_sigchld = False - time.sleep(end_time - current_time) - if not received_sigchld: # we will ignore only sigchld - break - current_time = time.time() - received_sigchld = False - - -def 
setup_signal_handlers(): - signal.signal(signal.SIGTERM, sigterm_handler) - signal.signal(signal.SIGCHLD, sigchld_handler) diff --git a/patroni/patroni.py b/patroni/patroni.py deleted file mode 100755 index 2ee322119..000000000 --- a/patroni/patroni.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python -import logging -import os -import sys -import time -import yaml - -from .helpers.api import RestApiServer -from .helpers.etcd import Etcd -from .helpers.ha import Ha -from .helpers.postgresql import Postgresql -from .helpers.utils import setup_signal_handlers, sleep -from .helpers.zookeeper import ZooKeeper - -logger = logging.getLogger(__name__) - - -class Patroni: - - def __init__(self, config): - self.nap_time = config['loop_wait'] - self.postgresql = Postgresql(config['postgresql']) - self.ha = Ha(self.postgresql, self.get_dcs(self.postgresql.name, config)) - host, port = config['restapi']['listen'].split(':') - self.api = RestApiServer(self, config['restapi']) - self.next_run = time.time() - self.shutdown_member_ttl = 300 - - @staticmethod - def get_dcs(name, config): - if 'etcd' in config: - return Etcd(name, config['etcd']) - if 'zookeeper' in config: - return ZooKeeper(name, config['zookeeper']) - raise Exception('Can not find sutable configuration of distributed configuration store') - - def touch_member(self, ttl=None): - connection_string = self.postgresql.connection_string + '?application_name=' + self.api.connection_string - if self.ha.cluster: - for m in self.ha.cluster.members: - # Do not update member TTL when it is far from being expired - if m.name == self.postgresql.name and m.real_ttl() > self.shutdown_member_ttl: - return True - return self.ha.dcs.touch_member(connection_string, ttl) - - def initialize(self): - # wait for etcd to be available - while not self.touch_member(): - logger.info('waiting on DCS') - sleep(5) - - # is data directory empty? 
- if self.postgresql.data_directory_empty(): - # racing to initialize - if self.ha.dcs.race('/initialize'): - self.postgresql.initialize() - self.ha.dcs.take_leader() - self.postgresql.start() - self.postgresql.create_replication_user() - self.postgresql.create_connection_users() - else: - while True: - leader = self.ha.dcs.current_leader() - if leader and self.postgresql.sync_from_leader(leader): - self.postgresql.write_recovery_conf(leader) - self.postgresql.start() - break - sleep(5) - elif self.postgresql.is_running(): - self.postgresql.load_replication_slots() - - def schedule_next_run(self): - self.next_run += self.nap_time - current_time = time.time() - nap_time = self.next_run - current_time - if nap_time <= 0: - self.next_run = current_time - else: - self.ha.dcs.sleep(nap_time) - - def run(self): - self.api.start() - self.next_run = time.time() - - while True: - self.touch_member() - logger.info(self.ha.run_cycle()) - try: - if self.ha.state_handler.is_leader(): - self.ha.cluster and self.ha.state_handler.create_replication_slots(self.ha.cluster) - else: - self.ha.state_handler.drop_replication_slots() - except: - logger.exception('Exception when changing replication slots') - self.schedule_next_run() - - -def main(): - logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) - logging.getLogger('requests').setLevel(logging.WARNING) - setup_signal_handlers() - - if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]): - print('Usage: {} config.yml'.format(sys.argv[0])) - return - - with open(sys.argv[1], 'r') as f: - config = yaml.load(f) - - patroni = Patroni(config) - try: - patroni.initialize() - patroni.run() - except KeyboardInterrupt: - pass - finally: - patroni.touch_member(patroni.shutdown_member_ttl) # schedule member removal - patroni.postgresql.stop() - patroni.ha.dcs.delete_leader() diff --git a/patroni/helpers/postgresql.py b/patroni/postgresql.py similarity index 87% rename from patroni/helpers/postgresql.py 
rename to patroni/postgresql.py index 8593fb52f..d729942cb 100644 --- a/patroni/helpers/postgresql.py +++ b/patroni/postgresql.py @@ -4,14 +4,10 @@ import shlex import shutil import subprocess -import six -from patroni.helpers.utils import sleep +from patroni.utils import sleep from six.moves.urllib_parse import urlparse -if six.PY3: - long = int - logger = logging.getLogger(__name__) ACTION_ON_START = "on_start" @@ -50,6 +46,7 @@ def __init__(self, config): self.superuser = config['superuser'] self.admin = config['admin'] self.callback = config.get('callbacks', {}) + self.use_slots = config.get('use_slots', True) self.recovery_conf = os.path.join(self.data_dir, 'recovery.conf') self.configuration_to_save = (os.path.join(self.data_dir, 'pg_hba.conf'), os.path.join(self.data_dir, 'postgresql.conf')) @@ -256,8 +253,8 @@ def is_healthiest_node(self, cluster): member_conn.autocommit = True member_cursor = member_conn.cursor() member_cursor.execute( - "SELECT pg_is_in_recovery(), %s - (pg_last_xlog_replay_location() - '0/0000000'::pg_lsn)", - (self.xlog_position(), )) + "SELECT pg_is_in_recovery(), %s - pg_xlog_location_diff(pg_last_xlog_replay_location(), '0/0')", + (self.xlog_position(),)) row = member_cursor.fetchone() member_cursor.close() member_conn.close() @@ -305,10 +302,9 @@ def write_recovery_conf(self, leader): recovery_target_timeline = 'latest' """) if leader and leader.conn_url: - f.write(""" -primary_slot_name = '{}' -primary_conninfo = '{}' -""".format(self.name, self.primary_conninfo(leader.conn_url))) + f.write("""primary_conninfo = '{}'\n""".format(self.primary_conninfo(leader.conn_url))) + if self.use_slots: + f.write("""primary_slot_name = '{}'\n""".format(self.name)) for name, value in self.config.get('recovery_conf', {}).items(): f.write("{} = '{}'\n".format(name, value)) @@ -354,32 +350,37 @@ def create_connection_users(self): self.query('CREATE ROLE "{0}" WITH LOGIN SUPERUSER PASSWORD %s'.format( self.superuser['username']), 
self.superuser['password']) else: - self.query('ALTER ROLE postgres WITH PASSWORD %s', self.superuser['password']) + rolsuper = self.query("""SELECT rolname FROM pg_authid WHERE rolsuper = 't'""").fetchone()[0] + self.query('ALTER ROLE "{0}" WITH PASSWORD %s'.format(rolsuper), self.superuser['password']) if self.admin: self.query('CREATE ROLE "{0}" WITH LOGIN CREATEDB CREATEROLE PASSWORD %s'.format( self.admin['username']), self.admin['password']) def xlog_position(self): - return self.query("""SELECT CASE WHEN pg_is_in_recovery() - THEN pg_last_xlog_replay_location() - '0/0000000'::pg_lsn - ELSE pg_current_xlog_location() - '0/00000'::pg_lsn END""").fetchone()[0] + return self.query("""SELECT pg_xlog_location_diff(CASE WHEN pg_is_in_recovery() + THEN pg_last_xlog_replay_location() + ELSE pg_current_xlog_location() + END, '0/0')""").fetchone()[0] def load_replication_slots(self): - cursor = self.query("SELECT slot_name FROM pg_replication_slots WHERE slot_type='physical'") - self.members = [r[0] for r in cursor] + if self.use_slots: + cursor = self.query("SELECT slot_name FROM pg_replication_slots WHERE slot_type='physical'") + self.members = [r[0] for r in cursor] def sync_replication_slots(self, members): - # drop unused slots - for slot in set(self.members) - set(members): - self.query("""SELECT pg_drop_replication_slot(%s) - WHERE EXISTS(SELECT 1 FROM pg_replication_slots - WHERE slot_name = %s)""", slot, slot) - - # create new slots - for slot in set(members) - set(self.members): - self.query("""SELECT pg_create_physical_replication_slot(%s) - WHERE NOT EXISTS (SELECT 1 FROM pg_replication_slots - WHERE slot_name = %s)""", slot, slot) + if self.use_slots: + # drop unused slots + for slot in set(self.members) - set(members): + self.query("""SELECT pg_drop_replication_slot(%s) + WHERE EXISTS(SELECT 1 FROM pg_replication_slots + WHERE slot_name = %s)""", slot, slot) + + # create new slots + for slot in set(members) - set(self.members): + self.query("""SELECT 
pg_create_physical_replication_slot(%s) + WHERE NOT EXISTS (SELECT 1 FROM pg_replication_slots + WHERE slot_name = %s)""", slot, slot) + self.members = members def create_replication_slots(self, cluster): diff --git a/patroni/utils.py b/patroni/utils.py new file mode 100644 index 000000000..45ada8640 --- /dev/null +++ b/patroni/utils.py @@ -0,0 +1,164 @@ +import datetime +import os +import random +import re +import signal +import sys +import time + +from patroni.exceptions import DCSError + +interrupted_sleep = False +_reap_children_pending = False + +_DATE_TIME_RE = re.compile(r'''^ +(?P<year>\d{4})\-(?P<month>\d{2})\-(?P<day>\d{2}) # date +T +(?P<hour>\d{2}):(?P<minute>\d{2}):(?P<second>\d{2})\.(?P<microsecond>\d{6}) # time +\d*Z$''', re.X) + + +def parse_datetime(time_str): + """ + >>> parse_datetime('2015-06-10T12:56:30.552539016Z') + datetime.datetime(2015, 6, 10, 12, 56, 30, 552539) + >>> parse_datetime('2015-06-10 12:56:30.552539016Z') + """ + m = _DATE_TIME_RE.match(time_str) + if not m: + return None + p = dict((n, int(m.group(n))) for n in 'year month day hour minute second microsecond'.split(' ')) + return datetime.datetime(**p) + + +def calculate_ttl(expiration): + """ + >>> calculate_ttl(None) + >>> calculate_ttl('2015-06-10 12:56:30.552539016Z') + """ + if not expiration: + return None + expiration = parse_datetime(expiration) + if not expiration: + return None + now = datetime.datetime.utcnow() + return int((expiration - now).total_seconds()) + + +def sigterm_handler(signo, stack_frame): + sys.exit() + + +def sigchld_handler(signo, stack_frame): + """Only record that SIGCHLD arrived; the waitpid() loop runs later in reap_children().""" + global interrupted_sleep, _reap_children_pending + _reap_children_pending = interrupted_sleep = True + + +def sleep(interval): + global interrupted_sleep + current_time = time.time() + end_time = current_time + interval + while current_time < end_time: + interrupted_sleep = False + time.sleep(end_time - current_time) + if not interrupted_sleep: # we will ignore only sigchld + break + current_time = time.time() + interrupted_sleep = False + + +def setup_signal_handlers(): + 
signal.signal(signal.SIGTERM, sigterm_handler) + signal.signal(signal.SIGCHLD, sigchld_handler) + + +def reap_children(): + """Collect exited child processes without blocking, but only when sigchld_handler has set the pending flag.""" + global _reap_children_pending + if _reap_children_pending: + try: + while True: + ret = os.waitpid(-1, os.WNOHANG) + if ret == (0, 0): + break + except OSError: + pass + finally: + _reap_children_pending = False + + +class RetryFailedError(DCSError): + + """Raised when retrying an operation ultimately failed, after retrying the maximum number of attempts.""" + + +class Retry: + + """Helper for retrying a method in the face of retry-able exceptions""" + + def __init__(self, max_tries=1, delay=0.1, backoff=2, max_jitter=0.8, max_delay=3600, + sleep_func=time.sleep, deadline=None, retry_exceptions=DCSError): + """Create a :class:`Retry` instance for retrying function calls + + :param max_tries: How many times to retry the command. -1 means infinite tries. + :param delay: Initial delay between retry attempts. + :param backoff: Backoff multiplier between retry attempts. Defaults to 2 for exponential backoff. + :param max_jitter: Additional max jitter period to wait between retry attempts to avoid slamming the server. + :param max_delay: Maximum delay in seconds, regardless of other backoff settings. Defaults to one hour. 
+ :param retry_exceptions: single exception or tuple""" + + self.max_tries = max_tries + self.delay = delay + self.backoff = backoff + self.max_jitter = int(max_jitter * 100) + self.max_delay = float(max_delay) + self._attempts = 0 + self._cur_delay = delay + self.deadline = deadline + self._cur_stoptime = None + self.sleep_func = sleep_func + self.retry_exceptions = retry_exceptions + + def reset(self): + """Reset the attempt counter""" + self._attempts = 0 + self._cur_delay = self.delay + self._cur_stoptime = None + + def copy(self): + """Return a clone of this retry manager""" + return Retry(max_tries=self.max_tries, delay=self.delay, backoff=self.backoff, + max_jitter=self.max_jitter / 100.0, max_delay=self.max_delay, sleep_func=self.sleep_func, + deadline=self.deadline, retry_exceptions=self.retry_exceptions) + + def __call__(self, func, *args, **kwargs): + """Call a function with arguments until it completes without throwing a `retry_exceptions` + + :param func: Function to call + :param args: Positional arguments to call the function with + :params kwargs: Keyword arguments to call the function with + + The function will be called until it doesn't throw one of the retryable exceptions""" + self.reset() + + while True: + try: + if self.deadline is not None and self._cur_stoptime is None: + self._cur_stoptime = time.time() + self.deadline + return func(*args, **kwargs) + except self.retry_exceptions: + # Note: max_tries == -1 means infinite tries. 
+ if self._attempts == self.max_tries: + raise RetryFailedError("Too many retry attempts") + self._attempts += 1 + sleeptime = self._cur_delay + ( + random.randint(0, self.max_jitter) / 100.0) + + if self._cur_stoptime is not None and \ + time.time() + sleeptime >= self._cur_stoptime: + raise RetryFailedError("Exceeded retry deadline") + else: + self.sleep_func(sleeptime) + self._cur_delay = min(self._cur_delay * self.backoff, + self.max_delay) diff --git a/patroni/version.py b/patroni/version.py new file mode 100644 index 000000000..11d27f8c7 --- /dev/null +++ b/patroni/version.py @@ -0,0 +1 @@ +__version__ = '0.1' diff --git a/patroni/helpers/zookeeper.py b/patroni/zookeeper.py similarity index 89% rename from patroni/helpers/zookeeper.py rename to patroni/zookeeper.py index fd5f4f5bb..29b8c7f1a 100644 --- a/patroni/helpers/zookeeper.py +++ b/patroni/zookeeper.py @@ -3,10 +3,10 @@ import requests import time -from patroni.helpers.dcs import AbstractDCS, Cluster, DCSError, Member, parse_connection_string -from patroni.helpers.utils import sleep from kazoo.client import KazooClient, KazooState from kazoo.exceptions import NoNodeError, NodeExistsError +from patroni.dcs import AbstractDCS, Cluster, DCSError, Leader, Member, parse_connection_string +from patroni.utils import sleep from requests.exceptions import RequestException logger = logging.getLogger(__name__) @@ -134,21 +134,18 @@ def _inner_load_cluster(self): leader = self.get_node('/leader', self.cluster_watcher) self.members = self.load_members() if leader: - if leader[0] == self._name: - client_id = self.client.client_id - if client_id is not None and client_id[0] != leader[1].ephemeralOwner: - logger.info('I am leader but not owner of the session. 
Removing leader node') - self.client.delete(self.client_path('/leader')) - leader = None + client_id = self.client.client_id + if leader[0] == self._name and client_id is not None and client_id[0] != leader[1].ephemeralOwner: + logger.info('I am leader but not owner of the session. Removing leader node') + self.client.delete(self.client_path('/leader')) + leader = None if leader: - for member in self.members: - if member.name == leader[0]: - leader = member - self.fetch_cluster = False - break - if not isinstance(leader, Member): - leader = Member(-1, leader, None, None, None, None) + member = Member(-1, leader[0], None, None, None, None) + member = ([m for m in self.members if m.name == leader[0]] or [member])[0] + leader = Leader(leader[1].mzxid, None, None, member) + self.fetch_cluster = member.index == -1 + self.leader = leader if self.fetch_cluster: last_leader_operation = self.get_node('/optime/leader') @@ -220,10 +217,10 @@ def update_leader(self, state_handler): return True def delete_leader(self): - if isinstance(self.leader, Member) and self.leader.name == self._name: + if isinstance(self.leader, Leader) and self.leader.name == self._name: self.client.delete(self.client_path('/leader')) - def sleep(self, timeout): + def watch(self, timeout): self.cluster_event.wait(timeout) if self.cluster_event.isSet(): self.fetch_cluster = True diff --git a/postgres0.yml b/postgres0.yml index a2a7ce443..ce90da148 100644 --- a/postgres0.yml +++ b/postgres0.yml @@ -30,6 +30,7 @@ postgresql: connect_address: 127.0.0.1:5432 data_dir: data/postgresql0 maximum_lag_on_failover: 1048576 # 1 megabyte in bytes + use_slots: True pg_hba: - host all all 0.0.0.0/0 md5 - hostssl all all 0.0.0.0/0 md5 @@ -46,7 +47,7 @@ postgresql: env_dir: /home/postgres/etc/wal-e.d/env threshold_megabytes: 10240 threshold_backup_size_percentage: 30 - restore: "true" + restore: scripts/restore.py #recovery_conf: #restore_command: cp ../wal_archive/%f %p parameters: diff --git a/postgres1.yml 
b/postgres1.yml index 6ef6b1c9f..763444e85 100644 --- a/postgres1.yml +++ b/postgres1.yml @@ -30,6 +30,7 @@ postgresql: connect_address: 127.0.0.1:5433 data_dir: data/postgresql1 maximum_lag_on_failover: 1048576 # 1 megabyte in bytes + use_slots: True pg_hba: - host all all 0.0.0.0/0 md5 - hostssl all all 0.0.0.0/0 md5 @@ -48,6 +49,7 @@ postgresql: env_dir: /home/postgres/etc/wal-e.d/env threshold_megabytes: 10240 threshold_backup_size_percentage: 30 + restore: scripts/restore.py parameters: archive_mode: "on" wal_level: hot_standby diff --git a/release.sh b/release.sh index 5f97ba083..316ca9086 100755 --- a/release.sh +++ b/release.sh @@ -12,7 +12,7 @@ git --version version=$1 -sed -i "s/__version__ = .*/__version__ = '${version}'/" __init__.py +sed -i "s/__version__ = .*/__version__ = '${version}'/" version.py python3 setup.py clean python3 setup.py test python3 setup.py flake8 diff --git a/setup.py b/setup.py index b90d10848..3def98b35 100644 --- a/setup.py +++ b/setup.py @@ -21,13 +21,12 @@ def read_version(package): data = {} - with open(os.path.join(package, '__init__.py'), 'r') as fd: + with open(os.path.join(package, 'version.py'), 'r') as fd: exec(fd.read(), data) return data['__version__'] NAME = 'patroni' -MAIN_PACKAGE = 'patroni' -HELPERS = 'helpers' +MAIN_PACKAGE = NAME SCRIPTS = 'scripts' VERSION = read_version(MAIN_PACKAGE) DESCRIPTION = 'PostgreSQL High-Available orchestrator and CLI' @@ -57,7 +56,7 @@ def read_version(package): 'Programming Language :: Python :: Implementation :: CPython', ] -CONSOLE_SCRIPTS = ['patroni = patroni.patroni:main'] +CONSOLE_SCRIPTS = ['patroni = patroni:main'] class PyTest(TestCommand): @@ -74,8 +73,7 @@ def initialize_options(self): def finalize_options(self): TestCommand.finalize_options(self) if self.cov_xml or self.cov_html: - self.cov = ['--cov', MAIN_PACKAGE, '--cov', HELPERS, '--cov', SCRIPTS, '--cov-report', - 'term-missing'] + self.cov = ['--cov', MAIN_PACKAGE, '--cov', MAIN_PACKAGE, '--cov-report', 
'term-missing'] if self.cov_xml: self.cov.extend(['--cov-report', 'xml']) if self.cov_html: diff --git a/tests/test_api.py b/tests/test_api.py index 12d7798b3..aecf3df66 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,7 +1,7 @@ import psycopg2 import unittest -from patroni.helpers.api import RestApiHandler, RestApiServer +from patroni.api import RestApiHandler, RestApiServer from six import BytesIO as IO from test_postgresql import psycopg2_connect diff --git a/tests/test_etcd.py b/tests/test_etcd.py index 3f09caa56..ab5717ca9 100644 --- a/tests/test_etcd.py +++ b/tests/test_etcd.py @@ -3,14 +3,15 @@ import etcd import json import requests +import urllib3 import socket import time import unittest from dns.exception import DNSException -from patroni.helpers.dcs import Cluster, DCSError, Member -from patroni.helpers.etcd import Client, Etcd from mock import Mock, patch +from patroni.dcs import Cluster, DCSError, Leader, Member +from patroni.etcd import Client, Etcd class MockResponse: @@ -25,6 +26,10 @@ def json(self): @property def data(self): + if self.content == 'TimeoutError': + raise urllib3.exceptions.TimeoutError + if self.content == 'Exception': + raise Exception return self.content @property @@ -61,7 +66,22 @@ def requests_get(url, **kwargs): return response +def etcd_watch(key, index=None, timeout=None, recursive=None): + if timeout == 1: + raise urllib3.exceptions.TimeoutError + elif timeout == 5: + return etcd.EtcdResult('delete', {}) + elif timeout == 10: + raise etcd.EtcdException + elif index == 20729: + return etcd.EtcdResult('set', {'value': 'postgresql1', 'modifiedIndex': index + 1}) + elif index == 20731: + return etcd.EtcdResult('set', {'value': 'postgresql2', 'modifiedIndex': index + 1}) + + def etcd_write(key, value, **kwargs): + if key == '/service/exists/leader': + raise etcd.EtcdAlreadyExist if key == '/service/test/leader': if kwargs.get('prevValue', None) == 'foo' or not kwargs.get('prevExist', True): return True @@ -107,8 
+127,12 @@ def time_sleep(_): pass +class SleepException(Exception): + pass + + def time_sleep_exception(_): - raise Exception() + raise SleepException() class MockSRV: @@ -172,6 +196,15 @@ def test_get_srv_record(self): self.assertEquals(self.client.get_srv_record('blabla'), []) self.assertEquals(self.client.get_srv_record('exception'), []) + def test__result_from_response(self): + response = MockResponse() + response.content = 'TimeoutError' + self.assertRaises(urllib3.exceptions.TimeoutError, self.client._result_from_response, response) + response.content = 'Exception' + self.assertRaises(etcd.EtcdException, self.client._result_from_response, response) + response.content = b'{}' + self.assertRaises(etcd.EtcdException, self.client._result_from_response, response) + def test__get_machines_cache_from_srv(self): self.client.get_srv_record = lambda e: [('localhost', 2380)] self.client._get_machines_cache_from_srv('blabla') @@ -204,7 +237,7 @@ def test_get_etcd_client(self): time.sleep = time_sleep_exception with patch.object(etcd.Client, 'machines') as mock_machines: mock_machines.__get__ = Mock(side_effect=etcd.EtcdException) - self.assertRaises(Exception, self.etcd.get_etcd_client, {'discovery_srv': 'test'}) + self.assertRaises(SleepException, self.etcd.get_etcd_client, {'discovery_srv': 'test'}) def test_get_cluster(self): self.assertIsInstance(self.etcd.get_cluster(), Cluster) @@ -214,7 +247,7 @@ def test_get_cluster(self): self.assertIsNone(cluster.leader) def test_current_leader(self): - self.assertIsInstance(self.etcd.current_leader(), Member) + self.assertIsInstance(self.etcd.current_leader(), Leader) self.etcd._base_path = '/service/noleader' self.assertIsNone(self.etcd.current_leader()) @@ -224,6 +257,12 @@ def test_touch_member(self): def test_take_leader(self): self.assertFalse(self.etcd.take_leader()) + def testattempt_to_acquire_leader(self): + self.etcd._base_path = '/service/exists' + self.assertFalse(self.etcd.attempt_to_acquire_leader()) + 
self.etcd._base_path = '/service/failed' + self.assertFalse(self.etcd.attempt_to_acquire_leader()) + def test_update_leader(self): self.assertTrue(self.etcd.update_leader(MockPostgresql())) @@ -233,3 +272,12 @@ def test_race(self): def test_delete_leader(self): self.etcd.client.delete = etcd_delete self.assertFalse(self.etcd.delete_leader()) + + def test_watch(self): + self.etcd.client.watch = etcd_watch + self.etcd.watch(100) + self.etcd.get_cluster() + self.etcd.watch(1) + self.etcd.watch(5) + self.etcd.watch(10) + self.etcd.watch(100) diff --git a/tests/test_ha.py b/tests/test_ha.py index 8f9a9c490..46aa9a0aa 100644 --- a/tests/test_ha.py +++ b/tests/test_ha.py @@ -1,9 +1,9 @@ import unittest -from patroni.helpers.dcs import Cluster, DCSError -from patroni.helpers.etcd import Client, Etcd -from patroni.helpers.ha import Ha from mock import Mock, patch +from patroni.dcs import Cluster, DCSError +from patroni.etcd import Client, Etcd +from patroni.ha import Ha from test_etcd import etcd_read, etcd_write diff --git a/tests/test_patroni.py b/tests/test_patroni.py index fd2e605db..a45ebf624 100644 --- a/tests/test_patroni.py +++ b/tests/test_patroni.py @@ -1,5 +1,5 @@ import datetime -import patroni.helpers.zookeeper +import patroni.zookeeper import psycopg2 import subprocess import sys @@ -7,12 +7,12 @@ import unittest import yaml -from patroni.helpers.api import RestApiServer -from patroni.helpers.dcs import Cluster, Member -from patroni.helpers.etcd import Etcd -from patroni.helpers.zookeeper import ZooKeeper from mock import Mock, patch -from patroni.patroni import Patroni, main +from patroni.api import RestApiServer +from patroni.dcs import Cluster, Member +from patroni.etcd import Etcd +from patroni import Patroni, main +from patroni.zookeeper import ZooKeeper from six.moves import BaseHTTPServer from test_etcd import Client, etcd_read, etcd_write from test_ha import true, false @@ -24,8 +24,12 @@ def nop(*args, **kwargs): pass +class SleepException(Exception): 
+ pass + + def time_sleep(*args): - raise Exception() + raise SleepException() class Mock_BaseServer__is_shut_down: @@ -70,7 +74,7 @@ def tear_down(self): Postgresql.write_recovery_conf = self.write_recovery_conf def test_get_dcs(self): - patroni.helpers.zookeeper.KazooClient = MockKazooClient + patroni.zookeeper.KazooClient = MockKazooClient self.assertIsInstance(self.p.get_dcs('', {'zookeeper': {'scope': '', 'hosts': ''}}), ZooKeeper) self.assertRaises(Exception, self.p.get_dcs, '', {}) @@ -90,7 +94,7 @@ def test_patroni_main(self): Etcd.delete_leader = nop - self.assertRaises(Exception, main) + self.assertRaises(SleepException, main) Patroni.run = run Patroni.touch_member = touch_member @@ -100,10 +104,11 @@ def test_patroni_run(self): self.p.touch_member = self.touch_member self.p.ha.state_handler.sync_replication_slots = time_sleep self.p.ha.dcs.client.read = etcd_read - self.assertRaises(Exception, self.p.run) + self.p.ha.dcs.watch = time_sleep + self.assertRaises(SleepException, self.p.run) self.p.ha.state_handler.is_leader = lambda: False self.p.api.start = nop - self.assertRaises(Exception, self.p.run) + self.assertRaises(SleepException, self.p.run) def touch_member(self, ttl=None): if not self.touched: diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index f79a1438b..273d077d8 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -4,8 +4,8 @@ import subprocess import unittest -from patroni.helpers.dcs import Cluster, Member -from patroni.helpers.postgresql import Postgresql +from patroni.dcs import Cluster, Leader, Member +from patroni.postgresql import Postgresql def nop(*args, **kwargs): @@ -24,7 +24,6 @@ class MockCursor: def __init__(self): self.closed = False - self.current = 0 self.results = [] def execute(self, sql, *params): @@ -43,7 +42,7 @@ def execute(self, sql, *params): self.results = [(True, -1)] else: self.results = [(False, 0)] - elif sql.startswith('SELECT CASE WHEN pg_is_in_recovery()'): + elif 
sql.startswith('SELECT pg_xlog_location_diff'): self.results = [(0,)] elif sql.startswith('SELECT pg_is_in_recovery()'): self.results = [(False, )] @@ -119,11 +118,12 @@ def set_up(self): 'on_restart': 'true', 'on_role_change': 'true', 'on_reload': 'true' }, - 'restore': '/usr/bin/true'}) + 'restore': 'true'}) psycopg2.connect = psycopg2_connect if not os.path.exists(self.p.data_dir): os.makedirs(self.p.data_dir) - self.leader = Member(0, 'leader', 'postgres://replicator:rep-pass@127.0.0.1:5435/postgres', None, None, 28) + self.leadermem = Member(0, 'leader', 'postgres://replicator:rep-pass@127.0.0.1:5435/postgres', None, None, 28) + self.leader = Leader(-1, None, 28, self.leadermem) self.other = Member(0, 'test1', 'postgres://replicator:rep-pass@127.0.0.1:5433/postgres', None, None, 28) self.me = Member(0, 'test0', 'postgres://replicator:rep-pass@127.0.0.1:5434/postgres', None, None, 28) @@ -156,7 +156,7 @@ def test_follow_the_leader(self): self.p.follow_the_leader(None) self.p.demote(self.leader) self.p.follow_the_leader(self.leader) - self.p.follow_the_leader(self.other) + self.p.follow_the_leader(Leader(-1, None, 28, self.other)) def test_create_connection_users(self): cfg = self.p.config @@ -166,7 +166,7 @@ def test_create_connection_users(self): def test_create_replication_slots(self): self.p.start() - cluster = Cluster(True, self.leader, 0, [self.me, self.other, self.leader]) + cluster = Cluster(True, self.leader, 0, [self.me, self.other, self.leadermem]) self.p.create_replication_slots(cluster) def test_query(self): @@ -180,7 +180,7 @@ def test_query(self): self.assertRaises(psycopg2.OperationalError, self.p.query, 'blabla') def test_is_healthiest_node(self): - cluster = Cluster(True, self.leader, 0, [self.me, self.other, self.leader]) + cluster = Cluster(True, self.leader, 0, [self.me, self.other, self.leadermem]) self.assertTrue(self.p.is_healthiest_node(cluster)) self.p.is_leader = false self.assertFalse(self.p.is_healthiest_node(cluster)) @@ -188,7 
+188,7 @@ def test_is_healthiest_node(self): self.assertTrue(self.p.is_healthiest_node(cluster)) self.p.xlog_position = lambda: 2 self.assertFalse(self.p.is_healthiest_node(cluster)) - self.p.config['maximum_lag_on_failover'] = -2 + self.p.config['maximum_lag_on_failover'] = -3 self.assertFalse(self.p.is_healthiest_node(cluster)) def test_is_leader(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index 3bb238b03..9a814f78e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,8 @@ import time import unittest -from patroni.helpers.utils import sigchld_handler, sigterm_handler, sleep +from patroni.exceptions import DCSError +from patroni.utils import Retry, RetryFailedError, reap_children, sigchld_handler, sigterm_handler, sleep def nop(*args, **kwargs): @@ -34,11 +35,67 @@ def tear_down(self): def test_sigterm_handler(self): self.assertRaises(SystemExit, sigterm_handler, None, None) - def test_sigchld_handler(self): - sigchld_handler(None, None) + def test_reap_children(self): + reap_children() os.waitpid = os_waitpid sigchld_handler(None, None) + reap_children() def test_sleep(self): time.sleep = time_sleep sleep(0.01) + + +class TestRetrySleeper(unittest.TestCase): + + def _pass(self): + pass + + def _fail(self, times=1): + scope = dict(times=0) + + def inner(): + if scope['times'] >= times: + pass + else: + scope['times'] += 1 + raise DCSError('Failed!') + return inner + + def _makeOne(self, *args, **kwargs): + return Retry(*args, **kwargs) + + def test_reset(self): + retry = self._makeOne(delay=0, max_tries=2) + retry(self._fail()) + self.assertEquals(retry._attempts, 1) + retry.reset() + self.assertEquals(retry._attempts, 0) + + def test_too_many_tries(self): + retry = self._makeOne(delay=0) + self.assertRaises(RetryFailedError, retry, self._fail(times=999)) + self.assertEquals(retry._attempts, 1) + + def test_maximum_delay(self): + def sleep_func(_time): + pass + + retry = self._makeOne(delay=10, max_tries=100, 
sleep_func=sleep_func) + retry(self._fail(times=10)) + self.assertTrue(retry._cur_delay < 4000, retry._cur_delay) + # gevent's sleep function is picky about the type + self.assertEquals(type(retry._cur_delay), float) + + def test_deadline(self): + def sleep_func(_time): + pass + + retry = self._makeOne(deadline=0.0001, sleep_func=sleep_func) + self.assertRaises(RetryFailedError, retry, self._fail(times=100)) + + def test_copy(self): + _sleep = lambda t: None + retry = self._makeOne(sleep_func=_sleep) + rcopy = retry.copy() + self.assertTrue(rcopy.sleep_func is _sleep) diff --git a/tests/test_zookeeper.py b/tests/test_zookeeper.py index 1333897ab..af2115cac 100644 --- a/tests/test_zookeeper.py +++ b/tests/test_zookeeper.py @@ -1,8 +1,9 @@ -import patroni.helpers.zookeeper +import patroni.zookeeper import requests import unittest -from patroni.helpers.zookeeper import ExhibitorEnsembleProvider, ZooKeeper, ZooKeeperError +from patroni.dcs import Leader +from patroni.zookeeper import ExhibitorEnsembleProvider, ZooKeeper, ZooKeeperError from kazoo.client import KazooState from kazoo.exceptions import NoNodeError, NodeExistsError from kazoo.protocol.states import ZnodeStat @@ -30,6 +31,10 @@ def event_object(self): return MockEvent() +class SleepException(Exception): + pass + + class MockKazooClient: def __init__(self, **kwargs): @@ -94,7 +99,7 @@ def set_hosts(self, hosts, randomize_hosts=None): def exhibitor_sleep(_): - raise Exception + raise SleepException class TestExhibitorEnsembleProvider(unittest.TestCase): @@ -105,10 +110,10 @@ def __init__(self, method_name='runTest'): def set_up(self): requests.get = requests_get - patroni.helpers.zookeeper.sleep = exhibitor_sleep + patroni.zookeeper.sleep = exhibitor_sleep def test_init(self): - self.assertRaises(Exception, ExhibitorEnsembleProvider, ['localhost'], 8181) + self.assertRaises(SleepException, ExhibitorEnsembleProvider, ['localhost'], 8181) class TestZooKeeper(unittest.TestCase): @@ -119,7 +124,7 @@ def 
__init__(self, method_name='runTest'): def set_up(self): requests.get = requests_get - patroni.helpers.zookeeper.KazooClient = MockKazooClient + patroni.zookeeper.KazooClient = MockKazooClient self.zk = ZooKeeper('foo', {'exhibitor': {'hosts': ['localhost', 'exhibitor'], 'port': 8181}, 'scope': 'test'}) def test_session_listener(self): @@ -136,7 +141,8 @@ def test__inner_load_cluster(self): def test_get_cluster(self): self.assertRaises(ZooKeeperError, self.zk.get_cluster) self.zk.exhibitor.poll = lambda: True - self.zk.get_cluster() + cluster = self.zk.get_cluster() + self.assertIsInstance(cluster.leader, Leader) self.zk.touch_member('foo') self.zk.delete_leader() @@ -158,5 +164,5 @@ def test_update_leader(self): self.zk.last_leader_operation = -1 self.assertTrue(self.zk.update_leader(MockPostgresql())) - def test_sleep(self): - self.zk.sleep(0) + def test_watch(self): + self.zk.watch(0)