Permalink
Browse files

feat: ability to customize socket factory (e.g. for unix domain sockets)

This adds socketFactory and socketFactoryArg connection parameters that can be used to customize socket factory

closes #457
  • Loading branch information...
mtran authored and vlsi committed Dec 26, 2015
1 parent 8fca8b4 commit dc1844c21efbb4a840347d5aaa991384e8883b69
@@ -205,6 +205,16 @@
*/
SOCKET_TIMEOUT("socketTimeout", "0", "The timeout value used for socket read operations."),
/**
* Socket factory used to create socket. A null value, which is the default, means system default.
*/
SOCKET_FACTORY("socketFactory", null, "Specify a socket factory for socket creation"),
/**
* The String argument to give to the constructor of the Socket Factory
*/
SOCKET_FACTORY_ARG("socketFactoryArg", null, "Argument forwarded to constructor of SocketFactory class."),
/**
* Socket read buffer size (SO_RECVBUF). A value of {@code -1}, which is the default, means system default.
*/
@@ -17,7 +17,7 @@
import java.net.InetSocketAddress;
import java.net.Socket;
import java.sql.SQLException;
import javax.net.SocketFactory;
import org.postgresql.util.GT;
import org.postgresql.util.HostSpec;
import org.postgresql.util.PSQLState;
@@ -32,6 +32,7 @@
*/
public class PGStream
{
private final SocketFactory socketFactory;
private final HostSpec hostSpec;
private final byte[] _int4buf;
@@ -49,15 +50,17 @@
* Constructor: Connect to the PostgreSQL back end and return
* a stream connection.
*
* @param socketFactory socket factory to use when creating sockets
* @param hostSpec the host and port to connect to
* @param timeout timeout in milliseconds, or 0 if no timeout set
* @exception IOException if an IOException occurs below it.
*/
public PGStream(HostSpec hostSpec, int timeout) throws IOException
public PGStream(SocketFactory socketFactory, HostSpec hostSpec, int timeout) throws IOException
{
this.socketFactory = socketFactory;
this.hostSpec = hostSpec;
Socket socket = new Socket();
Socket socket = socketFactory.createSocket();
socket.connect(new InetSocketAddress(hostSpec.getHost(), hostSpec.getPort()), timeout);
changeSocket(socket);
setEncoding(Encoding.getJVMEncoding("UTF-8"));
@@ -72,10 +75,10 @@ public PGStream(HostSpec hostSpec, int timeout) throws IOException
*
* @param hostSpec the host and port to connect to
* @throws IOException if an IOException occurs below it.
* @deprecated use {@link #PGStream(org.postgresql.util.HostSpec, int)}
* @deprecated use {@link #PGStream(SocketFactory, org.postgresql.util.HostSpec, int)}
*/
public PGStream(HostSpec hostSpec) throws IOException {
this(hostSpec, 0);
public PGStream(SocketFactory socketFactory, HostSpec hostSpec) throws IOException {
this(socketFactory, hostSpec, 0);
}
public HostSpec getHostSpec() {
@@ -86,6 +89,10 @@ public Socket getSocket() {
return connection;
}
public SocketFactory getSocketFactory() {
return socketFactory;
}
/**
* Check for pending backend messages without blocking.
* Might return false when there actually are messages
@@ -11,6 +11,7 @@
import java.util.Iterator;
import java.util.Properties;
import java.util.StringTokenizer;
import javax.net.SocketFactory;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.io.IOException;
@@ -27,6 +28,7 @@
import org.postgresql.util.PSQLState;
import org.postgresql.util.UnixCrypt;
import org.postgresql.util.MD5Digest;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.GT;
import org.postgresql.util.HostSpec;
@@ -84,6 +86,9 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
throw new PSQLException (GT.tr("Invalid targetServerType value: {0}", info.getProperty("targetServerType")), PSQLState.CONNECTION_UNABLE_TO_CONNECT);
}
// Socket factory
SocketFactory socketFactory = SocketFactoryFactory.getSocketFactory(info);
HostChooser hostChooser = HostChooserFactory.createHostChooser(hostSpecs, targetServerType, info);
for (Iterator<HostSpec> hostIter = hostChooser.iterator(); hostIter.hasNext(); ) {
HostSpec hostSpec = hostIter.next();
@@ -99,7 +104,7 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
PGStream newStream = null;
try
{
newStream = new PGStream(hostSpec, connectTimeout);
newStream = new PGStream(socketFactory, hostSpec, connectTimeout);
// Construct and send an ssl startup packet if requested.
if (trySSL)
@@ -205,7 +210,7 @@ private PGStream enableSSL(PGStream pgStream, boolean requireSSL, Properties inf
// We have to reconnect to continue.
pgStream.close();
return new PGStream(pgStream.getHostSpec(), connectTimeout);
return new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);
case 'N':
if (logger.logDebug())
@@ -94,7 +94,7 @@ public void sendQueryCancel() throws SQLException {
if (logger.logDebug())
logger.debug(" FE=> CancelRequest(pid=" + cancelPid + ",ckey=" + cancelKey + ")");
cancelStream = new PGStream(pgStream.getHostSpec(), connectTimeout);
cancelStream = new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);
cancelStream.SendInteger4(16);
cancelStream.SendInteger2(1234);
cancelStream.SendInteger2(5678);
@@ -0,0 +1,36 @@
package org.postgresql.core.v2;
import org.postgresql.PGProperty;
import org.postgresql.util.GT;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;
import java.util.Properties;
import javax.net.SocketFactory;
/**
* Instantiates {@link SocketFactory} based on the {@link PGProperty#SOCKET_FACTORY}.
*/
public class SocketFactoryFactory {
/**
* Instantiates {@link SocketFactory} based on the {@link PGProperty#SOCKET_FACTORY}
* @param info connection properties
* @return socket factory
* @throws PSQLException if something goes wrong
*/
public static SocketFactory getSocketFactory(Properties info) throws PSQLException {
// Socket factory
String socketFactoryClassName = PGProperty.SOCKET_FACTORY.get(info);
if (socketFactoryClassName == null) {
return SocketFactory.getDefault();
}
try {
return (SocketFactory) ObjectFactory.instantiate(socketFactoryClassName, info, true, PGProperty.SOCKET_FACTORY_ARG.get(info));
} catch (Exception e) {
throw new PSQLException(GT.tr("The SocketFactory class provided {0} could not be instantiated.", socketFactoryClassName), PSQLState.CONNECTION_FAILURE, e);
}
}
}
@@ -16,7 +16,7 @@
import java.util.List;
import java.util.Properties;
import java.util.TimeZone;
import javax.net.SocketFactory;
import org.postgresql.PGProperty;
import org.postgresql.core.ConnectionFactory;
import org.postgresql.core.Encoding;
@@ -27,6 +27,7 @@
import org.postgresql.core.SetupQueryRunner;
import org.postgresql.core.Utils;
import org.postgresql.core.Version;
import org.postgresql.core.v2.SocketFactoryFactory;
import org.postgresql.hostchooser.GlobalHostStatusTracker;
import org.postgresql.hostchooser.HostChooser;
import org.postgresql.hostchooser.HostChooserFactory;
@@ -36,6 +37,7 @@
import org.postgresql.util.GT;
import org.postgresql.util.HostSpec;
import org.postgresql.util.MD5Digest;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;
import org.postgresql.util.PSQLWarning;
@@ -112,6 +114,8 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
throw new PSQLException (GT.tr("Invalid targetServerType value: {0}", info.getProperty("targetServerType")), PSQLState.CONNECTION_UNABLE_TO_CONNECT);
}
SocketFactory socketFactory = SocketFactoryFactory.getSocketFactory(info);
HostChooser hostChooser = HostChooserFactory.createHostChooser(hostSpecs, targetServerType, info);
for (Iterator<HostSpec> hostIter = hostChooser.iterator(); hostIter.hasNext(); ) {
HostSpec hostSpec = hostIter.next();
@@ -126,7 +130,7 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
PGStream newStream = null;
try
{
newStream = new PGStream(hostSpec, connectTimeout);
newStream = new PGStream(socketFactory, hostSpec, connectTimeout);
// Construct and send an ssl startup packet if requested.
if (trySSL)
@@ -319,7 +323,7 @@ private PGStream enableSSL(PGStream pgStream, boolean requireSSL, Properties inf
// We have to reconnect to continue.
pgStream.close();
return new PGStream(pgStream.getHostSpec(), connectTimeout);
return new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);
case 'N':
if (logger.logDebug())
@@ -96,7 +96,7 @@ public void sendQueryCancel() throws SQLException {
if (logger.logDebug())
logger.debug(" FE=> CancelRequest(pid=" + cancelPid + ",ckey=" + cancelKey + ")");
cancelStream = new PGStream(pgStream.getHostSpec(), connectTimeout);
cancelStream = new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);
cancelStream.SendInteger4(16);
cancelStream.SendInteger2(1234);
cancelStream.SendInteger2(5678);
@@ -1078,6 +1078,38 @@ public void setAllowEncodingChanges(boolean allow)
PGProperty.ALLOW_ENCODING_CHANGES.set(properties, allow);
}
/**
* @see PGProperty#SOCKET_FACTORY
*/
public String getSocketFactory()
{
return PGProperty.SOCKET_FACTORY.get(properties);
}
/**
* @see PGProperty#SOCKET_FACTORY
*/
public void setSocketFactory(String socketFactoryClassName)
{
PGProperty.SOCKET_FACTORY.set(properties, socketFactoryClassName);
}
/**
* @see PGProperty#SOCKET_FACTORY_ARG
*/
public String getSocketFactoryArg()
{
return PGProperty.SOCKET_FACTORY_ARG.get(properties);
}
/**
* @see PGProperty#SOCKET_FACTORY_ARG
*/
public void setSocketFactoryArg(String socketFactoryArg)
{
PGProperty.SOCKET_FACTORY_ARG.set(properties, socketFactoryArg);
}
/**
* Generates a {@link DriverManager} URL from the other properties supplied.
* @return {@link DriverManager} URL from the other properties supplied
@@ -8,8 +8,6 @@
package org.postgresql.ssl;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Properties;
import javax.net.ssl.HostnameVerifier;
@@ -21,65 +19,12 @@
import org.postgresql.core.PGStream;
import org.postgresql.ssl.jdbc4.LibPQFactory;
import org.postgresql.util.GT;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;
public class MakeSSL {
public class MakeSSL extends ObjectFactory {
/**
* Instantiates a class using the appropriate constructor.
* If a constructor with a single Propertiesparameter exists, it is
* used. Otherwise, if tryString is true a constructor with
* a single String argument is searched if it fails, or tryString is true
* a no argument constructor is tried.
* @param classname Nam of the class to instantiate
* @param info parameter to pass as Properties
* @param tryString weather to look for a single String argument constructor
* @param stringarg parameter to pass as String
* @return the instantiated class
* @throws ClassNotFoundException if something goes wrong
* @throws SecurityException if something goes wrong
* @throws NoSuchMethodException if something goes wrong
* @throws IllegalArgumentException if something goes wrong
* @throws InstantiationException if something goes wrong
* @throws IllegalAccessException if something goes wrong
* @throws InvocationTargetException if something goes wrong
*/
public static Object instantiate(String classname, Properties info, boolean tryString, String stringarg)
throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException,
InstantiationException, IllegalAccessException, InvocationTargetException
{
Object[] args = {info};
Constructor<?> ctor = null;
Class<?> cls;
cls = Class.forName(classname);
try
{
ctor = cls.getConstructor(new Class[]{Properties.class});
}
catch (NoSuchMethodException nsme)
{
if (tryString)
{
try
{
ctor = cls.getConstructor(new Class[]{String.class});
args = new String[]{stringarg};
}
catch (NoSuchMethodException nsme2)
{
tryString = false;
}
}
if (!tryString)
{
ctor = cls.getConstructor((Class[])null);
args = null;
}
}
return ctor.newInstance(args);
}
public static void convert(PGStream stream, Properties info, Logger logger) throws PSQLException, IOException {
logger.debug("converting regular socket connection to ssl");
Oops, something went wrong.

0 comments on commit dc1844c

Please sign in to comment.