From 30c44f0b18a37aeb9c4fd717f662d2b575a779c1 Mon Sep 17 00:00:00 2001 From: Violeta Georgieva Date: Mon, 29 Apr 2024 12:22:30 +0300 Subject: [PATCH] Always use remote socket address for the metrics (#3210) This is a change additional to the change made with #2755 --- .../ContextAwareHttpServerMetricsHandler.java | 11 +- .../netty/http/HttpMetricsHandlerTests.java | 181 ++++++++++++++++-- 2 files changed, 172 insertions(+), 20 deletions(-) diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/server/ContextAwareHttpServerMetricsHandler.java b/reactor-netty-http/src/main/java/reactor/netty/http/server/ContextAwareHttpServerMetricsHandler.java index c1b3f42057..b3096afe6f 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/server/ContextAwareHttpServerMetricsHandler.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/server/ContextAwareHttpServerMetricsHandler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2021-2024 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,7 +52,8 @@ protected ContextAwareHttpServerMetricsRecorder recorder() { @Override protected void recordException(HttpServerOperations ops, String path) { // Always take the remote address from the operations in order to consider proxy information - recorder().incrementErrorsCount(ops.currentContext(), ops.remoteAddress(), path); + // Use remoteSocketAddress() in order to obtain UDS info + recorder().incrementErrorsCount(ops.currentContext(), ops.remoteSocketAddress(), path); } @Override @@ -62,7 +63,8 @@ protected void recordRead(HttpServerOperations ops, String path, String method) Duration.ofNanos(System.nanoTime() - dataReceivedTime)); // Always take the remote address from the operations in order to consider proxy information - recorder().recordDataReceived(contextView, ops.remoteAddress(), path, dataReceived); + // Use remoteSocketAddress() in order to obtain UDS info + recorder().recordDataReceived(contextView, ops.remoteSocketAddress(), path, dataReceived); } @Override @@ -80,6 +82,7 @@ protected void recordWrite(HttpServerOperations ops, String path, String method, } // Always take the remote address from the operations in order to consider proxy information - recorder().recordDataSent(contextView, ops.remoteAddress(), path, dataSent); + // Use remoteSocketAddress() in order to obtain UDS info + recorder().recordDataSent(contextView, ops.remoteSocketAddress(), path, dataSent); } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java b/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java index 36108b808f..b14172b0d1 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java @@ -911,7 +911,7 @@ void testBadRequest(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtoco void testServerConnectionsRecorderBadUri(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { - testServerConnectionsRecorderBadUri(serverProtocols, clientProtocols, serverCtx, clientCtx, null, -1, + testServerConnectionsRecorderBadUri(serverProtocols, clientProtocols, serverCtx, clientCtx, null, -1, false, Function.identity(), Function.identity()); } @@ -921,7 +921,19 @@ void testServerConnectionsRecorderBadUriUDS(HttpProtocol[] serverProtocols, Http @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { assumeThat(LoopResources.hasNativeSupport()).isTrue(); - testServerConnectionsRecorderBadUri(serverProtocols, clientProtocols, serverCtx, clientCtx, null, -1, + testServerConnectionsRecorderBadUri(serverProtocols, clientProtocols, serverCtx, clientCtx, null, -1, false, + client -> client.bindAddress(() -> new DomainSocketAddress("/tmp/test.sockclient")) + .remoteAddress(() -> new DomainSocketAddress("/tmp/test.sock")), + server -> server.bindAddress(() -> new DomainSocketAddress("/tmp/test.sock"))); + } + + @ParameterizedTest + @MethodSource("httpCompatibleProtocols") + void testServerConnectionsRecorderBadUriUDSContextAware(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, + @Nullable ProtocolSslContextSpec serverCtx, + @Nullable ProtocolSslContextSpec clientCtx) throws Exception { + assumeThat(LoopResources.hasNativeSupport()).isTrue(); + testServerConnectionsRecorderBadUri(serverProtocols, clientProtocols, serverCtx, clientCtx, null, -1, true, client -> client.bindAddress(() -> new DomainSocketAddress("/tmp/test.sockclient")) .remoteAddress(() -> new DomainSocketAddress("/tmp/test.sock")), server -> server.bindAddress(() -> new DomainSocketAddress("/tmp/test.sock"))); @@ -933,7 +945,7 @@ void testServerConnectionsRecorderBadUriForwarded(HttpProtocol[] serverProtocols @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx) throws Exception { testServerConnectionsRecorderBadUri(serverProtocols, clientProtocols, serverCtx, clientCtx, - "192.168.0.1", 8080, + "192.168.0.1", 8080, false, Function.identity(), Function.identity()); } @@ -1010,14 +1022,19 @@ static Stream combinationsIssue2956() { private void testServerConnectionsRecorderBadUri(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @Nullable ProtocolSslContextSpec serverCtx, @Nullable ProtocolSslContextSpec clientCtx, - @Nullable String xForwardedFor, int xForwardedPort, + @Nullable String xForwardedFor, int xForwardedPort, boolean contextAware, Function bindClient, Function bindServer) throws Exception { - ServerRecorderBadUri.INSTANCE.init(); + if (contextAware) { + ContextAwareServerRecorderBadUri.INSTANCE.init(); + } + else { + ServerRecorderBadUri.INSTANCE.init(); + } AtomicReference clientSA = new AtomicReference<>(); disposableServer = bindServer.apply(customizeServerOptions(httpServer, serverCtx, serverProtocols)) - .metrics(true, () -> ServerRecorderBadUri.INSTANCE, Function.identity()) + .metrics(true, () -> contextAware ? ContextAwareServerRecorderBadUri.INSTANCE : ServerRecorderBadUri.INSTANCE, Function.identity()) .forwarded(xForwardedFor != null || xForwardedPort != -1) .childObserve((conn, state) -> { if (state == ConnectionObserver.State.CONNECTED) { @@ -1062,18 +1079,34 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) provider.disposeLater() .block(Duration.ofSeconds(30)); - assertThat(ServerRecorderBadUri.INSTANCE.closed.await(30, TimeUnit.SECONDS)) - .as("awaitClose timeout") - .isTrue(); + if (contextAware) { + assertThat(ContextAwareServerRecorderBadUri.INSTANCE.closed.await(30, TimeUnit.SECONDS)) + .as("awaitClose timeout") + .isTrue(); + + assertThat(ContextAwareServerRecorderBadUri.INSTANCE.nullMethodParams.size() == 0) + .as("some method got null parameters: %s", ContextAwareServerRecorderBadUri.INSTANCE.nullMethodParams) + .isTrue(); - assertThat(ServerRecorderBadUri.INSTANCE.nullMethodParams.size() == 0) - .as("some method got null parameters: %s", ServerRecorderBadUri.INSTANCE.nullMethodParams) - .isTrue(); + SocketAddress recordedClientSA = ContextAwareServerRecorderBadUri.INSTANCE.clientAddr; + assertThat(recordedClientSA) + .as("recorded client remote socket address %s is different from expected client socket address %s", recordedClientSA, clientSA.get()) + .isEqualTo(clientSA.get()); + } + else { + assertThat(ServerRecorderBadUri.INSTANCE.closed.await(30, TimeUnit.SECONDS)) + .as("awaitClose timeout") + .isTrue(); - SocketAddress recordedClientSA = ServerRecorderBadUri.INSTANCE.clientAddr; - assertThat(recordedClientSA) - .as("recorded client remote socket address %s is different from expected client socket address %s", recordedClientSA, clientSA.get()) - .isEqualTo(clientSA.get()); + assertThat(ServerRecorderBadUri.INSTANCE.nullMethodParams.size() == 0) + .as("some method got null parameters: %s", ServerRecorderBadUri.INSTANCE.nullMethodParams) + .isTrue(); + + SocketAddress recordedClientSA = ServerRecorderBadUri.INSTANCE.clientAddr; + assertThat(recordedClientSA) + .as("recorded client remote socket address %s is different from expected client socket address %s", recordedClientSA, clientSA.get()) + .isEqualTo(clientSA.get()); + } } private void checkServerConnectionsMicrometer(HttpServerRequest request) { @@ -1683,6 +1716,122 @@ public void recordResolveAddressTime(SocketAddress socketAddress, Duration durat } } + /** + * Server metrics recorder used to verify that HttpServerMetricsRecorder method parameters + * are not null when a bad request URI is received. + */ + static final class ContextAwareServerRecorderBadUri extends ContextAwareHttpServerMetricsRecorder { + + static final ContextAwareServerRecorderBadUri INSTANCE = new ContextAwareServerRecorderBadUri(); + final ConcurrentLinkedQueue nullMethodParams = new ConcurrentLinkedQueue<>(); + volatile CountDownLatch closed; + volatile SocketAddress clientAddr; + + void init() { + nullMethodParams.clear(); + closed = new CountDownLatch(1); + clientAddr = null; + } + + void checkNullParam(String method, Object... params) { + if (Arrays.stream(params).anyMatch(Objects::isNull)) { + nullMethodParams.add(method); + } + } + + @Override + public void recordServerConnectionOpened(SocketAddress localAddress) { + checkNullParam("recordServerConnectionOpened", localAddress); + } + + @Override + public void recordServerConnectionClosed(SocketAddress localAddress) { + checkNullParam("recordServerConnectionClosed", localAddress); + closed.countDown(); + } + + @Override + public void recordServerConnectionActive(SocketAddress localAddress) { + checkNullParam("recordServerConnectionActive", localAddress); + } + + @Override + public void recordServerConnectionInactive(SocketAddress localAddress) { + checkNullParam("recordServerConnectionInactive", localAddress); + } + + @Override + public void recordStreamOpened(SocketAddress localAddress) { + checkNullParam("recordStreamOpened", localAddress); + } + + @Override + public void recordStreamClosed(SocketAddress localAddress) { + checkNullParam("recordStreamClosed", localAddress); + } + + @Override + public void recordDataReceived(ContextView contextView, SocketAddress remoteAddress, String uri, long bytes) { + checkNullParam("recordDataReceived", contextView, remoteAddress, uri); + } + + @Override + public void recordDataSent(ContextView contextView, SocketAddress remoteAddress, String uri, long bytes) { + checkNullParam("recordDataSent", contextView, remoteAddress, uri); + clientAddr = remoteAddress; + } + + @Override + public void incrementErrorsCount(ContextView contextView, SocketAddress remoteAddress, String uri) { + checkNullParam("incrementErrorsCount", contextView, remoteAddress, uri); + } + + @Override + public void recordDataReceivedTime(ContextView contextView, String uri, String method, Duration time) { + checkNullParam("recordDataReceivedTime", contextView, uri, method, time); + } + + @Override + public void recordDataSentTime(ContextView contextView, String uri, String method, String status, Duration time) { + checkNullParam("recordDataSentTime", contextView, uri, method, status, time); + } + + @Override + public void recordResponseTime(ContextView contextView, String uri, String method, String status, Duration time) { + checkNullParam("recordResponseTime", contextView, uri, method, status, time); + } + + @Override + public void recordDataReceived(ContextView contextView, SocketAddress socketAddress, long l) { + checkNullParam("recordDataReceived", contextView, socketAddress); + } + + @Override + public void recordDataSent(ContextView contextView, SocketAddress socketAddress, long l) { + checkNullParam("recordDataSent", contextView, socketAddress); + } + + @Override + public void incrementErrorsCount(ContextView contextView, SocketAddress socketAddress) { + checkNullParam("incrementErrorsCount", contextView, socketAddress); + } + + @Override + public void recordTlsHandshakeTime(ContextView contextView, SocketAddress socketAddress, Duration duration, String s) { + checkNullParam("recordTlsHandshakeTime", contextView, socketAddress, duration, s); + } + + @Override + public void recordConnectTime(ContextView contextView, SocketAddress socketAddress, Duration duration, String s) { + checkNullParam("recordConnectTime", contextView, socketAddress, duration, s); + } + + @Override + public void recordResolveAddressTime(SocketAddress socketAddress, Duration duration, String s) { + checkNullParam("recordResolveAddressTime", socketAddress, duration, s); + } + } + /** * Server Handler used to detect when the last http response content has been sent to the client. * Handler placed before the HttpMetricsHandler on the Server pipeline.