Skip to content

Commit

Permalink
Support unix sockets when connecting to a local postgres cluster
Browse files Browse the repository at this point in the history
This feature is enabled by default. If the `unix_socket_directories` is
defined and non empty, Patroni will use the first value from it to
connect to the local postgres cluster. Also unix socket will be used
when running `post_bootstrap` (`post_init`) script.

Set `postgresql.use_unix_socket: false` if you want to disable it.

Solves: #61

In addition to mentioned above, this commit solves couple of bugs:
* manual failover with pg_rewind in a pause state was broken
* psycopg2 (or libpq, I am not really sure what exactly) doesn't mark
  cusros connection as closed when we use unix socket and there is an
  OperationalError occurs. We will close such connection on our own.
  • Loading branch information
Alexander Kukushkin committed Jun 21, 2017
1 parent 3fee62c commit 2930ba4
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 30 deletions.
3 changes: 2 additions & 1 deletion docs/SETTINGS.rst
Expand Up @@ -36,7 +36,7 @@ Bootstrap configuration
- **options**: list of options for CREATE USER statement
- **- createrole**
- **- createdb**
- **post_init**: An additional script that will be executed after initializing the cluster. The script receives a connection string URL (with the cluster superuser as a user name). The PGPASSFILE variable is set to the location of pgpass file.
- **post\_bootstrap** or **post\_init**: An additional script that will be executed after initializing the cluster. The script receives a connection string URL (with the cluster superuser as a user name). The PGPASSFILE variable is set to the location of pgpass file.

Consul
------
Expand Down Expand Up @@ -84,6 +84,7 @@ PostgreSQL
- **data\_dir**: The location of the Postgres data directory, either existing or to be initialized by Patroni.
- **bin\_dir**: Path to PostgreSQL binaries. (pg_ctl, pg_rewind, pg_basebackup, postgres) The default value is an empty string meaning that PATH environment variable will be used to find the executables.
- **listen**: IP address + port that Postgres listens to; must be accessible from other nodes in the cluster, if you're using streaming replication. Multiple comma-separated addresses are permitted, as long as the port component is appended after to the last one with a colon, i.e. ``listen: 127.0.0.1,127.0.0.2:5432``. Patroni will use the first address from this list to establish local connections to the PostgreSQL node.
- **use\_unix\_socket**: specifies that Patroni should prefer to use unix sockets to connect to the cluster. Default value is ``true``. If ``unix_socket_directories`` is definded (and non empty), Patroni will use first value from it to connect to the cluster.
- **pgpass**: path to the `.pgpass <https://www.postgresql.org/docs/current/static/libpq-pgpass.html>`__ password file. Patroni creates this file before executing pg\_basebackup, the post_init script and under some other circumstances. The location must be writable by Patroni.
- **recovery\_conf**: additional configuration settings written to recovery.conf when configuring follower.
- **custom_conf** : path to an optional custom ``postgresql.conf`` file, that will be used in place of ``postgresql.base.conf``. The file must exist on all cluster nodes, be readable by PostgreSQL and will be included from its location on the real ``postgresql.conf``. Note that Patroni will not monitor this file for changes, nor backup it. However, its settings can still be overriden by Patroni's own configuration facilities - see `dynamic configuration <https://github.com/zalando/patroni/blob/master/docs/dynamic_configuration.rst>`__ for details.
Expand Down
9 changes: 6 additions & 3 deletions features/environment.py
Expand Up @@ -161,10 +161,13 @@ def _make_patroni_test_config(self, name, tags, custom_config):
config['postgresql']['data_dir'] = self._data_dir
config['postgresql']['parameters'].update({
'logging_collector': 'on', 'log_destination': 'csvlog', 'log_directory': self._output_dir,
'log_filename': name + '.log', 'log_statement': 'all', 'log_min_messages': 'debug1'})
'log_filename': name + '.log', 'log_statement': 'all', 'log_min_messages': 'debug1',
'unix_socket_directories': self._data_dir})

if 'bootstrap' in config and 'initdb' in config['bootstrap']:
config['bootstrap']['initdb'].extend([{'auth': 'md5'}, {'auth-host': 'md5'}])
if 'bootstrap' in config:
config['bootstrap']['post_bootstrap'] = 'psql -w -c "SELECT 1"'
if 'initdb' in config['bootstrap']:
config['bootstrap']['initdb'].extend([{'auth': 'md5'}, {'auth-host': 'md5'}])

if tags:
config['tags'] = tags
Expand Down
3 changes: 2 additions & 1 deletion patroni/ha.py
Expand Up @@ -1047,7 +1047,8 @@ def _run_cycle(self):
self.dcs.delete_leader()
self.dcs.reset_cluster()
return 'removed leader lock because postgres is not running'
elif not (self.state_handler.need_rewind and self.state_handler.can_rewind):
elif not (self.state_handler.rewind_executed or
self.state_handler.need_rewind and self.state_handler.can_rewind):
return 'postgres is not running'

# try to start dead postgres
Expand Down
61 changes: 42 additions & 19 deletions patroni/postgresql.py
Expand Up @@ -18,6 +18,7 @@
from patroni.exceptions import PostgresConnectionException, PostgresException
from patroni.utils import compare_values, parse_bool, parse_int, Retry, RetryFailedError, polling_loop, null_context
from six import string_types
from six.moves.urllib.parse import quote_plus
from threading import current_thread, Lock, Event

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -220,9 +221,14 @@ def get_server_parameters(self, config):
self._major_version >= self.CMDLINE_OPTIONS.get(k, (0, 1, 90100))[2]}

def resolve_connection_addresses(self):
self._local_address = self.get_local_address()
port = self._server_parameters['port']
tcp_local_address = self._get_tcp_local_address()

self._local_address = {'host': self._get_unix_local_address() or tcp_local_address, 'port': port}
self._local_replication_address = {'host': tcp_local_address, 'port': port}

self.connection_string = 'postgres://{0}/{1}'.format(
self._connect_address or self._local_address['host'] + ':' + self._local_address['port'], self._database)
self._connect_address or tcp_local_address + ':' + port, self._database)

def _pgcommand(self, cmd):
"""Returns path to the specified PostgreSQL command"""
Expand Down Expand Up @@ -260,7 +266,7 @@ def reload_config(self, config):
self._superuser = config['authentication'].get('superuser', {})
server_parameters = self.get_server_parameters(config)

listen_address_changed = pending_reload = pending_restart = False
local_connection_address_changed = pending_reload = pending_restart = False
if self.state == 'running':
changes = {p: v for p, v in server_parameters.items() if '.' not in p}
changes.update({p: None for p, v in self._server_parameters.items() if not ('.' in p or p in changes)})
Expand All @@ -284,8 +290,9 @@ def reload_config(self, config):
if new_value is None or not compare_values(r[3], unit, r[1], new_value):
if r[4] == 'postmaster':
pending_restart = True
if r[0] in ('listen_addresses', 'port'):
listen_address_changed = True
if r[0] in ('listen_addresses', 'port') or\
config.get('use_unix_socket', True) and r[0] == 'unix_socket_directories':
local_connection_address_changed = True
else:
pending_reload = True
for param in changes:
Expand All @@ -310,7 +317,7 @@ def reload_config(self, config):
self._server_parameters = server_parameters
self._connect_address = config.get('connect_address')

if not listen_address_changed:
if not local_connection_address_changed:
self.resolve_connection_addresses()

if pending_reload:
Expand Down Expand Up @@ -352,15 +359,19 @@ def sysid(self):
self._sysid = data.get('Database system identifier', "")
return self._sysid

def get_local_address(self):
def _get_unix_local_address(self):
for d in self._server_parameters.get('unix_socket_directories', '').split(','):
d = d.strip()
if d.startswith('/'): # Only absolute path can be used to connect via unix-socket
return d

def _get_tcp_local_address(self):
listen_addresses = self._server_parameters['listen_addresses'].split(',')
local_address = listen_addresses[0].strip() # take first address from listen_addresses

for la in listen_addresses:
if la.strip().lower() in ('*', '0.0.0.0', '127.0.0.1', 'localhost'): # we are listening on '*' or localhost
local_address = 'localhost' # connection via localhost is preferred
break
return {'host': local_address, 'port': self._server_parameters['port']}
return 'localhost' # connection via localhost is preferred
return listen_addresses[0].strip() # can't use localhost, take first address from listen_addresses

def get_postgres_role_from_data_directory(self):
if self.data_directory_empty():
Expand Down Expand Up @@ -414,7 +425,14 @@ def _query(self, sql, *params):
return cursor
except psycopg2.Error as e:
if cursor and cursor.connection.closed == 0:
raise e
# When connected via unix socket, psycopg2 can't recoginze 'connection lost'
# and leaves `_cursor_holder.connection.closed == 0`, but psycopg2.OperationalError
# is still raised (what is correct). It doesn't make sense to continiue with existing
# connection and we will close it, to avoid its reuse by the `_cursor` method.
if isinstance(e, psycopg2.OperationalError):
self.close_connection()
else:
raise e
if self.state == 'restarting':
raise RetryFailedError('cluster is being restarted')
raise PostgresConnectionException('connection problems')
Expand Down Expand Up @@ -476,21 +494,26 @@ def _initialize(self, config):

def run_bootstrap_post_init(self, config):
"""
runs a script after initdb is called and waits until completion.
passed: cluster name, parameters
runs a script after initdb or custom bootstrap script is called and waits until completion.
"""
if 'post_init' in config:
cmd = config['post_init']
cmd = config.get('post_bootstrap') or config.get('post_init')
if cmd:
r = self._local_connect_kwargs

# '/tmp' => '%2Ftmp' for unix socket path
host = quote_plus(r['host']) if r['host'].startswith('/') else r['host']

if 'user' in r:
connstring = 'postgres://{user}@{host}:{port}/{database}'.format(**r)
user = r['user'] + '@'
else:
connstring = 'postgres://{host}:{port}/{database}'.format(**r)
user = ''
if 'password' in r:
import getpass
r.setdefault('user', os.environ.get('PGUSER', getpass.getuser()))

connstring = 'postgres://{0}{1}:{2}/{3}'.format(user, host, r['port'], r['database'])
env = self.write_pgpass(r) if 'password' in r else None

try:
ret = subprocess.call(shlex.split(cmd) + [connstring], env=env)
except OSError:
Expand Down Expand Up @@ -1115,7 +1138,7 @@ def _get_local_timeline_lsn(self):
timeline = lsn = None
if self.is_running(): # if postgres is running - get timeline and lsn from replication connection
try:
with self._get_replication_connection_cursor(**self._local_address) as cur:
with self._get_replication_connection_cursor(**self._local_replication_address) as cur:
cur.execute('IDENTIFY_SYSTEM')
timeline, lsn = cur.fetchone()[1:3]
except Exception:
Expand Down
14 changes: 8 additions & 6 deletions tests/test_postgresql.py
Expand Up @@ -23,7 +23,9 @@ def __init__(self, connection):
self.results = []

def execute(self, sql, *params):
if sql.startswith('blabla') or sql == 'CHECKPOINT':
if sql.startswith('blabla'):
raise psycopg2.ProgrammingError()
elif sql == 'CHECKPOINT':
raise psycopg2.OperationalError()
elif sql.startswith('RetryFailedError'):
raise RetryFailedError('retry')
Expand Down Expand Up @@ -156,7 +158,7 @@ class TestPostgresql(unittest.TestCase):
'search_path': 'public', 'hot_standby': 'on', 'max_wal_senders': 5,
'wal_keep_segments': 8, 'wal_log_hints': 'on', 'max_locks_per_transaction': 64,
'max_worker_processes': 8, 'max_connections': 100, 'max_prepared_transactions': 0,
'track_commit_timestamp': 'off'}
'track_commit_timestamp': 'off', 'unix_socket_directories': '/tmp'}

@patch('subprocess.call', Mock(return_value=0))
@patch('psycopg2.connect', psycopg2_connect)
Expand All @@ -168,7 +170,7 @@ def setUp(self):
if not os.path.exists(self.data_dir):
os.makedirs(self.data_dir)
self.p = Postgresql({'name': 'test0', 'scope': 'batman', 'data_dir': self.data_dir, 'retry_timeout': 10,
'listen': '127.0.0.1, *:5432', 'connect_address': '127.0.0.2:5432',
'listen': '127.0.0.2, 127.0.0.3:5432', 'connect_address': '127.0.0.2:5432',
'authentication': {'superuser': {'username': 'test', 'password': 'test'},
'replication': {'username': 'replicator', 'password': 'rep-pass'}},
'remove_data_directory_on_rewind_failure': True,
Expand Down Expand Up @@ -420,7 +422,7 @@ def test_sync_replication_slots(self):
assert "test-3" in errorlog_mock.call_args[0][1]
assert "test.3" in errorlog_mock.call_args[0][1]

@patch.object(MockConnect, 'closed', 2)
@patch.object(MockCursor, 'execute', Mock(side_effect=psycopg2.OperationalError))
def test__query(self):
self.assertRaises(PostgresConnectionException, self.p._query, 'blabla')
self.p._state = 'restarting'
Expand All @@ -429,7 +431,7 @@ def test__query(self):
def test_query(self):
self.p.query('select 1')
self.assertRaises(PostgresConnectionException, self.p.query, 'RetryFailedError')
self.assertRaises(psycopg2.OperationalError, self.p.query, 'blabla')
self.assertRaises(psycopg2.ProgrammingError, self.p.query, 'blabla')

@patch.object(Postgresql, 'pg_isready', Mock(return_value=STATE_REJECT))
def test_is_leader(self):
Expand Down Expand Up @@ -519,7 +521,7 @@ def test_run_bootstrap_post_init(self):
mock_method.assert_called()
args, kwargs = mock_method.call_args
assert 'PGPASSFILE' in kwargs['env'].keys()
self.assertEquals(args[0], ['/bin/false', 'postgres://localhost:5432/postgres'])
self.assertEquals(args[0], ['/bin/false', 'postgres://%2Ftmp:5432/postgres'])

@patch('patroni.postgresql.Postgresql.create_replica', Mock(return_value=0))
def test_clone(self):
Expand Down

0 comments on commit 2930ba4

Please sign in to comment.