diff --git a/client/src/main/java/io/split/client/SplitFactoryImpl.java b/client/src/main/java/io/split/client/SplitFactoryImpl.java index 9932cbf8..978f1603 100644 --- a/client/src/main/java/io/split/client/SplitFactoryImpl.java +++ b/client/src/main/java/io/split/client/SplitFactoryImpl.java @@ -92,6 +92,7 @@ import io.split.telemetry.synchronizer.TelemetrySynchronizer; import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.BearerToken; import org.apache.hc.client5.http.auth.Credentials; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.config.RequestConfig; @@ -113,11 +114,14 @@ import org.slf4j.LoggerFactory; import pluggable.CustomStorageWrapper; +import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.net.InetAddress; import java.net.URI; import java.net.URISyntaxException; +import java.nio.file.Paths; +import java.security.KeyStore; import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; import java.util.HashSet; @@ -518,8 +522,28 @@ public boolean isDestroyed() { protected static SplitHttpClient buildSplitHttpClient(String apiToken, SplitClientConfig config, SDKMetadata sdkMetadata, RequestDecorator requestDecorator) throws URISyntaxException { + + SSLContext sslContext; + if (config.proxyMTLSAuth() != null) { + _log.debug("Proxy setup using mTLS"); + try { + KeyStore keyStore = KeyStore.getInstance("PKCS12"); + InputStream keystoreStream = java.nio.file.Files.newInputStream(Paths.get(config.proxyMTLSAuth().getP12File())); + keyStore.load(keystoreStream, config.proxyMTLSAuth().getP12FilePassKey().toCharArray()); + sslContext = SSLContexts.custom() + .loadKeyMaterial(keyStore, config.proxyMTLSAuth().getP12FilePassKey().toCharArray()) + .build(); + } catch (Exception e) { + _log.error("Exception caught while processing p12 file for Proxy mTLS auth: ", e); + _log.warn("Ignoring p12 mTLS config and switching to default context"); + sslContext = SSLContexts.createSystemDefault(); + } + } else { + sslContext = SSLContexts.createSystemDefault(); + } + SSLConnectionSocketFactory sslSocketFactory = SSLConnectionSocketFactoryBuilder.create() - .setSslContext(SSLContexts.createSystemDefault()) + .setSslContext(sslContext) .setTlsVersions(TLS.V_1_1, TLS.V_1_2) .build(); @@ -604,6 +628,15 @@ private static HttpClientBuilder setupProxy(HttpClientBuilder httpClientbuilder, httpClientbuilder.setDefaultCredentialsProvider(credsProvider); } + if (config.proxyToken() != null) { + _log.debug("Proxy setup using token"); + BasicCredentialsProvider credsProvider = new BasicCredentialsProvider(); + AuthScope siteScope = new AuthScope(config.proxy().getHostName(), config.proxy().getPort()); + Credentials siteCreds = new BearerToken(config.proxyToken()); + credsProvider.setCredentials(siteScope, siteCreds); + httpClientbuilder.setDefaultCredentialsProvider(credsProvider); + } + return httpClientbuilder; } diff --git a/client/src/test/java/io/split/client/SplitFactoryImplTest.java b/client/src/test/java/io/split/client/SplitFactoryImplTest.java index a6da1069..4597b615 100644 --- a/client/src/test/java/io/split/client/SplitFactoryImplTest.java +++ b/client/src/test/java/io/split/client/SplitFactoryImplTest.java @@ -1,16 +1,26 @@ package io.split.client; +import io.split.client.dtos.ProxyMTLSAuth; import io.split.client.impressions.ImpressionsManager; import io.split.client.utils.FileTypeEnum; import io.split.integrations.IntegrationsConfig; +import io.split.service.SplitHttpClientImpl; import io.split.storages.enums.OperationMode; import io.split.storages.pluggable.domain.UserStorageWrapper; import io.split.telemetry.storage.TelemetryStorage; import io.split.telemetry.synchronizer.TelemetrySynchronizer; import junit.framework.TestCase; +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.BearerToken; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.io.DefaultHttpClientConnectionOperator; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; +import org.apache.hc.client5.http.impl.routing.DefaultProxyRoutePlanner; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.config.Registry; import org.awaitility.Awaitility; import org.junit.Assert; -import org.junit.Ignore; import org.junit.Test; import org.mockito.Mockito; import static org.mockito.Mockito.when; @@ -24,6 +34,8 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -87,12 +99,12 @@ public void testFactoryInstantiationIntegrationsConfig() throws Exception { } @Test - public void testFactoryInstantiationWithProxy() throws Exception { + public void testFactoryInstantiationWithProxyCredentials() throws Exception { SplitClientConfig splitClientConfig = SplitClientConfig.builder() .enableDebug() .impressionsMode(ImpressionsManager.Mode.DEBUG) .impressionsRefreshRate(1) - .endpoint(ENDPOINT,EVENTS_ENDPOINT) + .endpoint(ENDPOINT, EVENTS_ENDPOINT) .telemetryURL(SplitClientConfig.TELEMETRY_ENDPOINT) .authServiceURL(AUTH_SERVICE) .setBlockUntilReadyTimeout(1000) @@ -102,9 +114,150 @@ public void testFactoryInstantiationWithProxy() throws Exception { .proxyHost(ENDPOINT) .build(); SplitFactoryImpl splitFactory = new SplitFactoryImpl(API_KEY, splitClientConfig); - assertNotNull(splitFactory.client()); assertNotNull(splitFactory.manager()); + + Field splitHttpClientField = SplitFactoryImpl.class.getDeclaredField("_splitHttpClient"); + splitHttpClientField.setAccessible(true); + SplitHttpClientImpl client = (SplitHttpClientImpl) splitHttpClientField.get(splitFactory); + + Field httpClientField = SplitHttpClientImpl.class.getDeclaredField("_client"); + httpClientField.setAccessible(true); + Class InternalHttp = Class.forName("org.apache.hc.client5.http.impl.classic.InternalHttpClient"); + + Field routePlannerField = InternalHttp.getDeclaredField("routePlanner"); + routePlannerField.setAccessible(true); + DefaultProxyRoutePlanner routePlanner = (DefaultProxyRoutePlanner) routePlannerField.get(InternalHttp.cast(httpClientField.get(client))); + + Field proxyField = DefaultProxyRoutePlanner.class.getDeclaredField("proxy"); + proxyField.setAccessible(true); + HttpHost proxy = (HttpHost) proxyField.get(routePlanner); + + Assert.assertEquals("http", proxy.getSchemeName()); + Assert.assertEquals(ENDPOINT, proxy.getHostName()); + Assert.assertEquals(6060, proxy.getPort()); + + Field credentialsProviderField = InternalHttp.getDeclaredField("credentialsProvider"); + credentialsProviderField.setAccessible(true); + BasicCredentialsProvider credentialsProvider = (BasicCredentialsProvider) credentialsProviderField.get(InternalHttp.cast(httpClientField.get(client))); + + Field credMapField = BasicCredentialsProvider.class.getDeclaredField("credMap"); + credMapField.setAccessible(true); + ConcurrentHashMap credMap = (ConcurrentHashMap) credMapField.get(credentialsProvider); + + Assert.assertEquals("test", credMap.entrySet().stream().iterator().next().getValue().getUserName()); + assertNotNull(credMap.entrySet().stream().iterator().next().getValue().getUserPassword()); + + splitFactory.destroy(); + } + + @Test + public void testFactoryInstantiationWithProxyToken() throws Exception { + SplitClientConfig splitClientConfig = SplitClientConfig.builder() + .enableDebug() + .impressionsMode(ImpressionsManager.Mode.DEBUG) + .impressionsRefreshRate(1) + .endpoint(ENDPOINT, EVENTS_ENDPOINT) + .telemetryURL(SplitClientConfig.TELEMETRY_ENDPOINT) + .authServiceURL(AUTH_SERVICE) + .setBlockUntilReadyTimeout(1000) + .proxyPort(6060) + .proxyToken("123456789") + .proxyHost(ENDPOINT) + .build(); + SplitFactoryImpl splitFactory2 = new SplitFactoryImpl(API_KEY, splitClientConfig); + assertNotNull(splitFactory2.client()); + assertNotNull(splitFactory2.manager()); + + Field splitHttpClientField2 = SplitFactoryImpl.class.getDeclaredField("_splitHttpClient"); + splitHttpClientField2.setAccessible(true); + SplitHttpClientImpl client2 = (SplitHttpClientImpl) splitHttpClientField2.get(splitFactory2); + + Field httpClientField2 = SplitHttpClientImpl.class.getDeclaredField("_client"); + httpClientField2.setAccessible(true); + Class InternalHttp2 = Class.forName("org.apache.hc.client5.http.impl.classic.InternalHttpClient"); + + Field credentialsProviderField2 = InternalHttp2.getDeclaredField("credentialsProvider"); + credentialsProviderField2.setAccessible(true); + BasicCredentialsProvider credentialsProvider2 = (BasicCredentialsProvider) credentialsProviderField2.get(InternalHttp2.cast(httpClientField2.get(client2))); + + Field credMapField2 = BasicCredentialsProvider.class.getDeclaredField("credMap"); + credMapField2.setAccessible(true); + ConcurrentHashMap credMap2 = (ConcurrentHashMap) credMapField2.get(credentialsProvider2); + + Assert.assertEquals("123456789", credMap2.entrySet().stream().iterator().next().getValue().getToken()); + + splitFactory2.destroy(); + } + + @Test + public void testFactoryInstantiationWithProxyMtls() throws Exception { + SplitClientConfig splitClientConfig = SplitClientConfig.builder() + .enableDebug() + .impressionsMode(ImpressionsManager.Mode.DEBUG) + .impressionsRefreshRate(1) + .endpoint(ENDPOINT,EVENTS_ENDPOINT) + .telemetryURL(SplitClientConfig.TELEMETRY_ENDPOINT) + .authServiceURL(AUTH_SERVICE) + .setBlockUntilReadyTimeout(1000) + .proxyPort(6060) + .proxyScheme("https") + .proxyMtlsAuth(new ProxyMTLSAuth.Builder().proxyP12File("src/test/resources/keyStore.p12").proxyP12FilePassKey("split").build()) + .proxyHost(ENDPOINT) + .build(); + SplitFactoryImpl splitFactory3 = new SplitFactoryImpl(API_KEY, splitClientConfig); + assertNotNull(splitFactory3.client()); + assertNotNull(splitFactory3.manager()); + + Field splitHttpClientField3 = SplitFactoryImpl.class.getDeclaredField("_splitHttpClient"); + splitHttpClientField3.setAccessible(true); + SplitHttpClientImpl client3 = (SplitHttpClientImpl) splitHttpClientField3.get(splitFactory3); + + Field httpClientField3 = SplitHttpClientImpl.class.getDeclaredField("_client"); + httpClientField3.setAccessible(true); + Class InternalHttp3 = Class.forName("org.apache.hc.client5.http.impl.classic.InternalHttpClient"); + + Field connManagerField = InternalHttp3.getDeclaredField("connManager"); + connManagerField.setAccessible(true); + PoolingHttpClientConnectionManager connManager = (PoolingHttpClientConnectionManager) connManagerField.get(InternalHttp3.cast(httpClientField3.get(client3))); + + Field connectionOperatorField = PoolingHttpClientConnectionManager.class.getDeclaredField("connectionOperator"); + connectionOperatorField.setAccessible(true); + DefaultHttpClientConnectionOperator connectionOperator = (DefaultHttpClientConnectionOperator) connectionOperatorField.get(connManager); + + Field tlsSocketStrategyLookupField = DefaultHttpClientConnectionOperator.class.getDeclaredField("tlsSocketStrategyLookup"); + tlsSocketStrategyLookupField.setAccessible(true); + Registry tlsSocketStrategyLookup = (Registry) tlsSocketStrategyLookupField.get(connectionOperator); + + Field mapField = Registry.class.getDeclaredField("map"); + mapField.setAccessible(true); + Class map = mapField.get(tlsSocketStrategyLookup).getClass(); + + Class value = ((ConcurrentHashMap) map.cast(mapField.get(tlsSocketStrategyLookup))).get("https").getClass(); + + Field arg1Field = value.getDeclaredField("arg$1"); + arg1Field.setAccessible(true); + Class sslConnectionSocketFactory = arg1Field.get(((ConcurrentHashMap) map.cast(mapField.get(tlsSocketStrategyLookup))).get("https")).getClass(); + + Field socketFactoryField = sslConnectionSocketFactory.getDeclaredField("socketFactory"); + socketFactoryField.setAccessible(true); + Class socketFactory = socketFactoryField.get(arg1Field.get(((ConcurrentHashMap) map.cast(mapField.get(tlsSocketStrategyLookup))).get("https"))).getClass(); + + Field contextField = socketFactory.getDeclaredField("context"); + contextField.setAccessible(true); + Class context = Class.forName("sun.security.ssl.SSLContextImpl"); + + Field keyManagerField = context.getDeclaredField("keyManager"); + keyManagerField.setAccessible(true); + Class keyManager = keyManagerField.get(contextField.get(socketFactoryField.get(arg1Field.get(((ConcurrentHashMap) map.cast(mapField.get(tlsSocketStrategyLookup))).get("https"))))).getClass(); + + Field credentialsMapField = keyManager.getDeclaredField("credentialsMap"); + credentialsMapField.setAccessible(true); + HashMap credentialsMap = (HashMap) credentialsMapField.get(keyManagerField.get(contextField.get(socketFactoryField.get(arg1Field.get(((ConcurrentHashMap) map.cast(mapField.get(tlsSocketStrategyLookup))).get("https")))))); + + assertNotNull(credentialsMap.get("1")); + + splitFactory3.destroy(); } @Test diff --git a/client/src/test/resources/keyStore.p12 b/client/src/test/resources/keyStore.p12 new file mode 100644 index 00000000..ce2b3417 Binary files /dev/null and b/client/src/test/resources/keyStore.p12 differ