Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WFSSL-16] Add support for TLS 1.3 #81

Merged
merged 6 commits into from Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 38 additions & 5 deletions java/pom.xml
Expand Up @@ -74,11 +74,44 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>2.3.2</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
<version>3.8.0-jboss-2</version>
<executions>
<execution>
<id>default-compile</id>
<phase>compile</phase>
<goals>
<goal>compile</goal>
</goals>
<configuration>
<release>8</release>
<buildDirectory>${project.build.directory}</buildDirectory>
<compileSourceRoots>${project.compileSourceRoots}</compileSourceRoots>
<outputDirectory>${project.build.outputDirectory}</outputDirectory>
<additionalClasspathElements>
<additionalClasspathElement>${project.build.directory}/jdk-misc.jar</additionalClasspathElement>
</additionalClasspathElements>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>fetch-misc</id>
<phase>generate-sources</phase>
<goals>
<goal>get</goal>
<goal>copy</goal>
</goals>
<configuration>
<artifact>org.jboss:jdk-misc:2.Final</artifact>
<outputDirectory>${project.build.directory}</outputDirectory>
<stripVersion>true</stripVersion>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
Expand Down
87 changes: 57 additions & 30 deletions java/src/main/java/org/wildfly/openssl/CipherSuiteConverter.java
Expand Up @@ -17,6 +17,9 @@

package org.wildfly.openssl;

import static java.util.Collections.singletonMap;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -81,9 +84,8 @@ public final class CipherSuiteConverter {
private static final Pattern JAVA_AES_PATTERN = Pattern.compile("^(AES)_([0-9]+)_(.*)$");
private static final Pattern OPENSSL_AES_CBC_PATTERN = Pattern.compile("^(AES)([0-9]+)$");
private static final Pattern OPENSSL_AES_PATTERN = Pattern.compile("^(AES)([0-9]+)-(.*)$");
// Covers TLSv1.3 ciphers and help avoid weird behaviours like what happens in
// BasicOpenSSLEngineTest#testWrongClientSideTrustManagerFailsValidation:70
private static final Pattern OPENSSL_TLSv13_PATTERN = Pattern.compile("^(TLS)_(AES|CHACHA20)_(POLY1305|[0-9]+)_(.*)$");
// There are only 5 supported cipher suites for TLS 1.3, including them directly here
private static final Pattern OPENSSL_TLSv13_PATTERN = Pattern.compile("^(TLS_AES_128_GCM_SHA256|TLS_AES_256_GCM_SHA384|TLS_CHACHA20_POLY1305_SHA256|TLS_AES_128_CCM_SHA256|TLS_AES_128_CCM_8_SHA256)$");

/**
* Java-to-OpenSSL cipher suite conversion map
Expand All @@ -98,6 +100,31 @@ public final class CipherSuiteConverter {
*/
private static final ConcurrentMap<String, Map<String, String>> o2j = new ConcurrentHashMap<>();

/**
* Cipher suite conversion maps for TLS 1.3. The Java and OpenSSL cipher suite names
* are the same.
*/
private static final Map<String, String> j2oTls13;
private static final Map<String, Map<String, String>> o2jTls13;

static {
Map<String, String> j2oTls13Map = new HashMap<>();
j2oTls13Map.put("TLS_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256");
j2oTls13Map.put("TLS_AES_256_GCM_SHA384", "TLS_AES_256_GCM_SHA384");
j2oTls13Map.put("TLS_CHACHA20_POLY1305_SHA256", "TLS_CHACHA20_POLY1305_SHA256");
j2oTls13Map.put("TLS_AES_128_CCM_SHA256", "TLS_AES_128_CCM_SHA256");
j2oTls13Map.put("TLS_AES_128_CCM_8_SHA256", "TLS_AES_128_CCM_8_SHA256");
j2oTls13 = Collections.unmodifiableMap(j2oTls13Map);

Map<String, Map<String, String>> o2jTls13Map = new HashMap<>();
o2jTls13Map.put("TLS_AES_128_GCM_SHA256", singletonMap("TLS", "TLS_AES_128_GCM_SHA256"));
o2jTls13Map.put("TLS_AES_256_GCM_SHA384", singletonMap("TLS", "TLS_AES_256_GCM_SHA384"));
o2jTls13Map.put("TLS_CHACHA20_POLY1305_SHA256", singletonMap("TLS", "TLS_CHACHA20_POLY1305_SHA256"));
o2jTls13Map.put("TLS_AES_128_CCM_SHA256", singletonMap("TLS", "TLS_AES_128_CCM_SHA256"));
o2jTls13Map.put("TLS_AES_128_CCM_8_SHA256", singletonMap("TLS", "TLS_AES_128_CCM_8_SHA256"));
o2jTls13 = Collections.unmodifiableMap(o2jTls13Map);
}

/**
* Clears the cache for testing purpose.
*/
Expand Down Expand Up @@ -158,14 +185,23 @@ public static String toOpenSsl(Iterable<String> javaCipherSuites) {
* @return {@code null} if the conversion has failed
*/
public static String toOpenSsl(String javaCipherSuite) {
String converted = j2o.get(javaCipherSuite);
String converted = javaToOpenSsl(javaCipherSuite);
if (converted != null) {
return converted;
} else {
return cacheFromJava(javaCipherSuite);
}
}

private static String javaToOpenSsl(String javaCipherSuite) {
String converted = j2oTls13.get(javaCipherSuite);
if (converted != null) {
return converted;
} else {
return j2o.get(javaCipherSuite);
}
}

private static String cacheFromJava(String javaCipherSuite) {
String openSslCipherSuite = toOpenSslUncached(javaCipherSuite);
if (openSslCipherSuite == null) {
Expand Down Expand Up @@ -279,7 +315,7 @@ private static String toOpenSslHmacAlgo(String hmacAlgo) {
* @return The translated cipher suite name according to java conventions. This will not be {@code null}.
*/
public static String toJava(String openSslCipherSuite, String protocol) {
Map<String, String> p2j = o2j.get(openSslCipherSuite);
Map<String, String> p2j = toJava(openSslCipherSuite);
if (p2j == null) {
p2j = cacheFromOpenSsl(openSslCipherSuite);
}
Expand All @@ -292,19 +328,23 @@ public static String toJava(String openSslCipherSuite, String protocol) {
return javaCipherSuite;
}

private static Map<String, String> toJava(String openSslCipherSuite) {
Map<String, String> p2j = o2jTls13.get(openSslCipherSuite);
if (p2j != null) {
return p2j;
} else {
return o2j.get(openSslCipherSuite);
}
}

private static Map<String, String> cacheFromOpenSsl(String openSslCipherSuite) {
String javaCipherSuiteSuffix = toJavaUncached(openSslCipherSuite);
if (javaCipherSuiteSuffix == null) {
return null;
}

final String javaCipherSuiteSsl = "SSL_" + javaCipherSuiteSuffix;
final String javaCipherSuiteTls;
if (openSslCipherSuite.startsWith("TLS_")) {
javaCipherSuiteTls = javaCipherSuiteSuffix;
} else {
javaCipherSuiteTls = "TLS_" + javaCipherSuiteSuffix;
}
final String javaCipherSuiteTls = "TLS_" + javaCipherSuiteSuffix;

// Cache the mapping.
final Map<String, String> p2j = new HashMap<>(4);
Expand All @@ -327,11 +367,6 @@ private static Map<String, String> cacheFromOpenSsl(String openSslCipherSuite) {

static String toJavaUncached(String openSslCipherSuite) {
Matcher m = OPENSSL_CIPHERSUITE_PATTERN.matcher(openSslCipherSuite);

if (openSslCipherSuite.startsWith("TLS_")) {
m = OPENSSL_TLSv13_PATTERN.matcher(openSslCipherSuite);
}

if (!m.matches()) {
return null;
}
Expand All @@ -352,20 +387,8 @@ static String toJavaUncached(String openSslCipherSuite) {
}

handshakeAlgo = toJavaHandshakeAlgo(handshakeAlgo, export);
String bulkCipher;
String hmacAlgo;
if ("TLS".equals(handshakeAlgo)) {
String groups = m.group(2) + "_" + m.group(3);
bulkCipher = toJavaBulkCipher(groups, export);
hmacAlgo = m.group(4);
} else {
bulkCipher = toJavaBulkCipher(m.group(2), export);
hmacAlgo = toJavaHmacAlgo(m.group(3));
}

if ("TLS".equals(handshakeAlgo)) {
return handshakeAlgo + "_" + bulkCipher + "_" + hmacAlgo;
}
String bulkCipher = toJavaBulkCipher(m.group(2), export);
String hmacAlgo = toJavaHmacAlgo(m.group(3));
return handshakeAlgo + "_WITH_" + bulkCipher + '_' + hmacAlgo;
}

Expand Down Expand Up @@ -442,4 +465,8 @@ private static String toJavaHmacAlgo(String hmacAlgo) {

private CipherSuiteConverter() {
}

static boolean isTLSv13CipherSuite(String openSslCipherSuite) {
return OPENSSL_TLSv13_PATTERN.matcher(openSslCipherSuite).matches();
}
}
14 changes: 13 additions & 1 deletion java/src/main/java/org/wildfly/openssl/Messages.java
Expand Up @@ -69,6 +69,10 @@ public class Messages {
private static final String MSG37 = formatCode(37);
private static final String MSG38 = formatCode(38);
private static final String MSG39 = formatCode(39);
private static final String MSG40 = formatCode(40);
private static final String MSG41 = formatCode(41);
private static final String MSG42 = formatCode(42);
private static final String MSG43 = formatCode(43);

private static String formatCode(int i) {
return CODE + new DecimalFormat("0000").format(i);
Expand Down Expand Up @@ -222,5 +226,13 @@ public String directBufferDeallocationFailed() {
public String unsupportedProtocolVersion(int p) {
return interpolate(MSG39, p);
}

public String handshakeFailed() {
return interpolate(MSG41);
}
public String settingCipherSuites(String s) {
return interpolate(MSG42, s);
}
public String settingTls13CipherSuites(String s) {
return interpolate(MSG43, s);
}
}
Expand Up @@ -39,7 +39,8 @@ public final class OpenSSLClientSessionContext extends OpenSSLSessionContext {
private volatile int timeout;
private final long context;
private int maxCacheSize = 100;
private volatile boolean enabled;
private String handshakeKeyHost;
private int handshakeKeyPort;

OpenSSLClientSessionContext(long context) {
super(context);
Expand All @@ -48,6 +49,11 @@ public final class OpenSSLClientSessionContext extends OpenSSLSessionContext {
accessQueue = ConcurrentDirectDeque.newInstance();
}

@Override
synchronized void sessionCreatedCallback(long ssl, long session, byte[] sessionId) {
storeClientSideSession(getHandshakeKey(), ssl, session, sessionId);
}

@Override
public void setSessionTimeout(int seconds) {
if (seconds < 0) {
Expand Down Expand Up @@ -77,21 +83,45 @@ public int getSessionCacheSize() {
return maxCacheSize;
}

synchronized void storeClientSideSession(final long ssl, final String host, final int port, byte[] sessionId) {
if (host != null && port >= 0) {
final ClientSessionKey key = new ClientSessionKey(host, port);
// set with the session pointer from the found session
final ClientSessionInfo foundSessionPtr = getCacheValue(key);
if (foundSessionPtr != null) {
if(getSession(foundSessionPtr.sessionId) != null) {
removeCacheEntry(key);
} else {
removeCacheEntry(key);
public void setSessionCacheEnabled(boolean enabled) {
long mode = enabled ? SSL.SSL_SESS_CACHE_CLIENT : SSL.SSL_SESS_CACHE_OFF;
SSL.getInstance().setSessionCacheMode(context, mode);
}

public boolean isSessionCacheEnabled() {
return SSL.getInstance().getSessionCacheMode(context) == SSL.SSL_SESS_CACHE_CLIENT;
}

void setHandshakeKeyHost(String handshakeKeyHost) {
this.handshakeKeyHost = handshakeKeyHost;
}

void setHandshakeKeyPort(int handshakeKeyPort) {
this.handshakeKeyPort = handshakeKeyPort;
}

public ClientSessionKey getHandshakeKey() {
if (handshakeKeyHost != null && handshakeKeyPort >= 0) {
return new ClientSessionKey(handshakeKeyHost, handshakeKeyPort);
}
return null;
}

synchronized void storeClientSideSession(ClientSessionKey key, long ssl, long sessionPointer, byte[] sessionId) {
if (sessionId != null) {
if (key != null) {
// set with the session pointer from the found session
final ClientSessionInfo foundSessionPtr = getCacheValue(key);
if (foundSessionPtr != null) {
if (getSession(foundSessionPtr.sessionId) != null) {
removeCacheEntry(key);
} else {
removeCacheEntry(key);
}
}
addCacheEntry(key, new ClientSessionInfo(sessionPointer, sessionId, System.currentTimeMillis()));
clientSessionCreated(ssl, sessionPointer, sessionId);
}
final long sessionPointer = SSL.getInstance().getSession(ssl);
addCacheEntry(key, new ClientSessionInfo(sessionPointer, sessionId, System.currentTimeMillis()));
clientSessionCreated(ssl, sessionPointer, sessionId);
}
}

Expand Down
15 changes: 15 additions & 0 deletions java/src/main/java/org/wildfly/openssl/OpenSSLContextSPI.java
Expand Up @@ -18,6 +18,8 @@
package org.wildfly.openssl;


import static org.wildfly.openssl.OpenSSLEngine.isTLS13Supported;

import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
Expand Down Expand Up @@ -71,6 +73,8 @@ public abstract class OpenSSLContextSPI extends SSLContextSpi {

private static volatile String[] allAvailableCiphers;

private static final String TLS13_CIPHERS = "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:TLS_AES_128_GCM_SHA256:TLS_AES_128_CCM_SHA256:TLS_AES_128_CCM_8_SHA256";

protected final long ctx;
final int supportedCiphers;

Expand All @@ -94,10 +98,14 @@ public static String[] getAvailableCipherSuites() {
if(allAvailableCiphers == null) {

final Set<String> availableCipherSuites = new LinkedHashSet<>(128);
boolean tls13Supported = isTLS13Supported();
try {
final long sslCtx = SSL.getInstance().makeSSLContext(SSL.SSL_PROTOCOL_ALL, SSL.SSL_MODE_SERVER);
try {
SSL.getInstance().setSSLContextOptions(sslCtx, SSL.SSL_OP_ALL);
if (tls13Supported) {
SSL.getInstance().setCipherSuiteTLS13(sslCtx, TLS13_CIPHERS);
}
SSL.getInstance().setCipherSuite(sslCtx, "ALL");
final long ssl = SSL.getInstance().newSSL(sslCtx, true);
try {
Expand Down Expand Up @@ -484,4 +492,11 @@ public OpenSSLTLS_1_2_ContextSpi() throws SSLException {
super(SSL.SSL_PROTOCOL_TLSV1_2);
}
}

public static final class OpenSSLTLS_1_3_ContextSpi extends OpenSSLContextSPI {

public OpenSSLTLS_1_3_ContextSpi() throws SSLException {
super(SSL.SSL_PROTOCOL_TLSV1_3);
}
}
}