Skip to content

Commit

Permalink
feat: ability to customize socket factory (e.g. for unix domain sockets)
Browse files Browse the repository at this point in the history
This adds socketFactory and socketFactoryArg connection parameters that can be used to customize socket factory

closes #457
  • Loading branch information
mtran authored and vlsi committed Dec 26, 2015
1 parent 8fca8b4 commit dc1844c
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 70 deletions.
10 changes: 10 additions & 0 deletions pgjdbc/src/main/java/org/postgresql/PGProperty.java
Expand Up @@ -205,6 +205,16 @@ public enum PGProperty
*/
SOCKET_TIMEOUT("socketTimeout", "0", "The timeout value used for socket read operations."),

/**
* Socket factory used to create socket. A null value, which is the default, means system default.
*/
SOCKET_FACTORY("socketFactory", null, "Specify a socket factory for socket creation"),

/**
* The String argument to give to the constructor of the Socket Factory
*/
SOCKET_FACTORY_ARG("socketFactoryArg", null, "Argument forwarded to constructor of SocketFactory class."),

/**
* Socket read buffer size (SO_RECVBUF). A value of {@code -1}, which is the default, means system default.
*/
Expand Down
19 changes: 13 additions & 6 deletions pgjdbc/src/main/java/org/postgresql/core/PGStream.java
Expand Up @@ -17,7 +17,7 @@
import java.net.InetSocketAddress;
import java.net.Socket;
import java.sql.SQLException;

import javax.net.SocketFactory;
import org.postgresql.util.GT;
import org.postgresql.util.HostSpec;
import org.postgresql.util.PSQLState;
Expand All @@ -32,6 +32,7 @@
*/
public class PGStream
{
private final SocketFactory socketFactory;
private final HostSpec hostSpec;

private final byte[] _int4buf;
Expand All @@ -49,15 +50,17 @@ public class PGStream
* Constructor: Connect to the PostgreSQL back end and return
* a stream connection.
*
* @param socketFactory socket factory to use when creating sockets
* @param hostSpec the host and port to connect to
* @param timeout timeout in milliseconds, or 0 if no timeout set
* @exception IOException if an IOException occurs below it.
*/
public PGStream(HostSpec hostSpec, int timeout) throws IOException
public PGStream(SocketFactory socketFactory, HostSpec hostSpec, int timeout) throws IOException
{
this.socketFactory = socketFactory;
this.hostSpec = hostSpec;

Socket socket = new Socket();
Socket socket = socketFactory.createSocket();
socket.connect(new InetSocketAddress(hostSpec.getHost(), hostSpec.getPort()), timeout);
changeSocket(socket);
setEncoding(Encoding.getJVMEncoding("UTF-8"));
Expand All @@ -72,10 +75,10 @@ public PGStream(HostSpec hostSpec, int timeout) throws IOException
*
* @param hostSpec the host and port to connect to
* @throws IOException if an IOException occurs below it.
* @deprecated use {@link #PGStream(org.postgresql.util.HostSpec, int)}
* @deprecated use {@link #PGStream(SocketFactory, org.postgresql.util.HostSpec, int)}
*/
public PGStream(HostSpec hostSpec) throws IOException {
this(hostSpec, 0);
public PGStream(SocketFactory socketFactory, HostSpec hostSpec) throws IOException {
this(socketFactory, hostSpec, 0);
}

public HostSpec getHostSpec() {
Expand All @@ -86,6 +89,10 @@ public Socket getSocket() {
return connection;
}

public SocketFactory getSocketFactory() {
return socketFactory;
}

/**
* Check for pending backend messages without blocking.
* Might return false when there actually are messages
Expand Down
Expand Up @@ -11,6 +11,7 @@
import java.util.Iterator;
import java.util.Properties;
import java.util.StringTokenizer;
import javax.net.SocketFactory;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.io.IOException;
Expand All @@ -27,6 +28,7 @@
import org.postgresql.util.PSQLState;
import org.postgresql.util.UnixCrypt;
import org.postgresql.util.MD5Digest;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.GT;
import org.postgresql.util.HostSpec;

Expand Down Expand Up @@ -84,6 +86,9 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
throw new PSQLException (GT.tr("Invalid targetServerType value: {0}", info.getProperty("targetServerType")), PSQLState.CONNECTION_UNABLE_TO_CONNECT);
}

// Socket factory
SocketFactory socketFactory = SocketFactoryFactory.getSocketFactory(info);

HostChooser hostChooser = HostChooserFactory.createHostChooser(hostSpecs, targetServerType, info);
for (Iterator<HostSpec> hostIter = hostChooser.iterator(); hostIter.hasNext(); ) {
HostSpec hostSpec = hostIter.next();
Expand All @@ -99,7 +104,7 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
PGStream newStream = null;
try
{
newStream = new PGStream(hostSpec, connectTimeout);
newStream = new PGStream(socketFactory, hostSpec, connectTimeout);

// Construct and send an ssl startup packet if requested.
if (trySSL)
Expand Down Expand Up @@ -205,7 +210,7 @@ private PGStream enableSSL(PGStream pgStream, boolean requireSSL, Properties inf

// We have to reconnect to continue.
pgStream.close();
return new PGStream(pgStream.getHostSpec(), connectTimeout);
return new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);

case 'N':
if (logger.logDebug())
Expand Down
Expand Up @@ -94,7 +94,7 @@ public void sendQueryCancel() throws SQLException {
if (logger.logDebug())
logger.debug(" FE=> CancelRequest(pid=" + cancelPid + ",ckey=" + cancelKey + ")");

cancelStream = new PGStream(pgStream.getHostSpec(), connectTimeout);
cancelStream = new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);
cancelStream.SendInteger4(16);
cancelStream.SendInteger2(1234);
cancelStream.SendInteger2(5678);
Expand Down
@@ -0,0 +1,36 @@
package org.postgresql.core.v2;

import org.postgresql.PGProperty;
import org.postgresql.util.GT;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;

import java.util.Properties;
import javax.net.SocketFactory;

/**
* Instantiates {@link SocketFactory} based on the {@link PGProperty#SOCKET_FACTORY}.
*/
public class SocketFactoryFactory {

/**
* Instantiates {@link SocketFactory} based on the {@link PGProperty#SOCKET_FACTORY}
* @param info connection properties
* @return socket factory
* @throws PSQLException if something goes wrong
*/
public static SocketFactory getSocketFactory(Properties info) throws PSQLException {
// Socket factory
String socketFactoryClassName = PGProperty.SOCKET_FACTORY.get(info);
if (socketFactoryClassName == null) {
return SocketFactory.getDefault();
}
try {
return (SocketFactory) ObjectFactory.instantiate(socketFactoryClassName, info, true, PGProperty.SOCKET_FACTORY_ARG.get(info));
} catch (Exception e) {
throw new PSQLException(GT.tr("The SocketFactory class provided {0} could not be instantiated.", socketFactoryClassName), PSQLState.CONNECTION_FAILURE, e);
}
}

}
Expand Up @@ -16,7 +16,7 @@
import java.util.List;
import java.util.Properties;
import java.util.TimeZone;

import javax.net.SocketFactory;
import org.postgresql.PGProperty;
import org.postgresql.core.ConnectionFactory;
import org.postgresql.core.Encoding;
Expand All @@ -27,6 +27,7 @@
import org.postgresql.core.SetupQueryRunner;
import org.postgresql.core.Utils;
import org.postgresql.core.Version;
import org.postgresql.core.v2.SocketFactoryFactory;
import org.postgresql.hostchooser.GlobalHostStatusTracker;
import org.postgresql.hostchooser.HostChooser;
import org.postgresql.hostchooser.HostChooserFactory;
Expand All @@ -36,6 +37,7 @@
import org.postgresql.util.GT;
import org.postgresql.util.HostSpec;
import org.postgresql.util.MD5Digest;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;
import org.postgresql.util.PSQLWarning;
Expand Down Expand Up @@ -112,6 +114,8 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
throw new PSQLException (GT.tr("Invalid targetServerType value: {0}", info.getProperty("targetServerType")), PSQLState.CONNECTION_UNABLE_TO_CONNECT);
}

SocketFactory socketFactory = SocketFactoryFactory.getSocketFactory(info);

HostChooser hostChooser = HostChooserFactory.createHostChooser(hostSpecs, targetServerType, info);
for (Iterator<HostSpec> hostIter = hostChooser.iterator(); hostIter.hasNext(); ) {
HostSpec hostSpec = hostIter.next();
Expand All @@ -126,7 +130,7 @@ else if ("require".equals(sslmode) || "verify-ca".equals(sslmode) || "verify-ful
PGStream newStream = null;
try
{
newStream = new PGStream(hostSpec, connectTimeout);
newStream = new PGStream(socketFactory, hostSpec, connectTimeout);

// Construct and send an ssl startup packet if requested.
if (trySSL)
Expand Down Expand Up @@ -319,7 +323,7 @@ private PGStream enableSSL(PGStream pgStream, boolean requireSSL, Properties inf

// We have to reconnect to continue.
pgStream.close();
return new PGStream(pgStream.getHostSpec(), connectTimeout);
return new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);

case 'N':
if (logger.logDebug())
Expand Down
Expand Up @@ -96,7 +96,7 @@ public void sendQueryCancel() throws SQLException {
if (logger.logDebug())
logger.debug(" FE=> CancelRequest(pid=" + cancelPid + ",ckey=" + cancelKey + ")");

cancelStream = new PGStream(pgStream.getHostSpec(), connectTimeout);
cancelStream = new PGStream(pgStream.getSocketFactory(), pgStream.getHostSpec(), connectTimeout);
cancelStream.SendInteger4(16);
cancelStream.SendInteger2(1234);
cancelStream.SendInteger2(5678);
Expand Down
32 changes: 32 additions & 0 deletions pgjdbc/src/main/java/org/postgresql/ds/common/BaseDataSource.java
Expand Up @@ -1078,6 +1078,38 @@ public void setAllowEncodingChanges(boolean allow)
PGProperty.ALLOW_ENCODING_CHANGES.set(properties, allow);
}

/**
* @see PGProperty#SOCKET_FACTORY
*/
public String getSocketFactory()
{
return PGProperty.SOCKET_FACTORY.get(properties);
}

/**
* @see PGProperty#SOCKET_FACTORY
*/
public void setSocketFactory(String socketFactoryClassName)
{
PGProperty.SOCKET_FACTORY.set(properties, socketFactoryClassName);
}

/**
* @see PGProperty#SOCKET_FACTORY_ARG
*/
public String getSocketFactoryArg()
{
return PGProperty.SOCKET_FACTORY_ARG.get(properties);
}

/**
* @see PGProperty#SOCKET_FACTORY_ARG
*/
public void setSocketFactoryArg(String socketFactoryArg)
{
PGProperty.SOCKET_FACTORY_ARG.set(properties, socketFactoryArg);
}

/**
* Generates a {@link DriverManager} URL from the other properties supplied.
* @return {@link DriverManager} URL from the other properties supplied
Expand Down
59 changes: 2 additions & 57 deletions pgjdbc/src/main/java/org/postgresql/ssl/MakeSSL.java
Expand Up @@ -8,8 +8,6 @@
package org.postgresql.ssl;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Properties;

import javax.net.ssl.HostnameVerifier;
Expand All @@ -21,65 +19,12 @@
import org.postgresql.core.PGStream;
import org.postgresql.ssl.jdbc4.LibPQFactory;
import org.postgresql.util.GT;
import org.postgresql.util.ObjectFactory;
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;

public class MakeSSL {
public class MakeSSL extends ObjectFactory {

/**
* Instantiates a class using the appropriate constructor.
* If a constructor with a single Propertiesparameter exists, it is
* used. Otherwise, if tryString is true a constructor with
* a single String argument is searched if it fails, or tryString is true
* a no argument constructor is tried.
* @param classname Nam of the class to instantiate
* @param info parameter to pass as Properties
* @param tryString weather to look for a single String argument constructor
* @param stringarg parameter to pass as String
* @return the instantiated class
* @throws ClassNotFoundException if something goes wrong
* @throws SecurityException if something goes wrong
* @throws NoSuchMethodException if something goes wrong
* @throws IllegalArgumentException if something goes wrong
* @throws InstantiationException if something goes wrong
* @throws IllegalAccessException if something goes wrong
* @throws InvocationTargetException if something goes wrong
*/
public static Object instantiate(String classname, Properties info, boolean tryString, String stringarg)
throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException,
InstantiationException, IllegalAccessException, InvocationTargetException
{
Object[] args = {info};
Constructor<?> ctor = null;
Class<?> cls;
cls = Class.forName(classname);
try
{
ctor = cls.getConstructor(new Class[]{Properties.class});
}
catch (NoSuchMethodException nsme)
{
if (tryString)
{
try
{
ctor = cls.getConstructor(new Class[]{String.class});
args = new String[]{stringarg};
}
catch (NoSuchMethodException nsme2)
{
tryString = false;
}
}
if (!tryString)
{
ctor = cls.getConstructor((Class[])null);
args = null;
}
}
return ctor.newInstance(args);
}

public static void convert(PGStream stream, Properties info, Logger logger) throws PSQLException, IOException {
logger.debug("converting regular socket connection to ssl");

Expand Down

0 comments on commit dc1844c

Please sign in to comment.