diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/server/WebsocketServerOperations.java b/reactor-netty-http/src/main/java/reactor/netty/http/server/WebsocketServerOperations.java index 3ed0ca8e5..e507cefdb 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/server/WebsocketServerOperations.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/server/WebsocketServerOperations.java @@ -94,7 +94,11 @@ final class WebsocketServerOperations extends HttpServerOperations else { removeHandler(NettyPipeline.HttpTrafficHandler); removeHandler(NettyPipeline.AccessLogHandler); - removeHandler(NettyPipeline.HttpMetricsHandler); + ChannelHandler handler = channel.pipeline().get(NettyPipeline.HttpMetricsHandler); + if (handler != null) { + replaceHandler(NettyPipeline.HttpMetricsHandler, + new WebsocketHttpServerMetricsHandler((AbstractHttpServerMetricsHandler) handler)); + } handshakerResult = channel.newPromise(); HttpRequest request = new DefaultFullHttpRequest(replaced.version(), @@ -305,4 +309,70 @@ public String selectedSubprotocol() { static final AtomicIntegerFieldUpdater CLOSE_SENT = AtomicIntegerFieldUpdater.newUpdater(WebsocketServerOperations.class, "closeSent"); + + static final class WebsocketHttpServerMetricsHandler extends AbstractHttpServerMetricsHandler { + + final HttpServerMetricsRecorder recorder; + + WebsocketHttpServerMetricsHandler(AbstractHttpServerMetricsHandler copy) { + super(copy); + this.recorder = copy.recorder(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.fireChannelActive(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + try { + if (channelOpened && recorder instanceof MicrometerHttpServerMetricsRecorder) { + // For custom user recorders, we don't propagate the channelInactive event, because this will be done + // by the ChannelMetricsHandler itself. ChannelMetricsHandler is only present when the recorder is + // not our MicrometerHttpServerMetricsRecorder. See HttpServerConfig class. + channelOpened = false; + // Always use the real connection local address without any proxy information + recorder.recordServerConnectionClosed(ctx.channel().localAddress()); + } + + if (channelActivated) { + channelActivated = false; + // Always use the real connection local address without any proxy information + recorder.recordServerConnectionInactive(ctx.channel().localAddress()); + } + } + catch (RuntimeException e) { + // Allow request-response exchange to continue, unaffected by metrics problem + if (log.isWarnEnabled()) { + log.warn(format(ctx.channel(), "Exception caught while recording metrics."), e); + } + } + finally { + ctx.fireChannelInactive(); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireExceptionCaught(cause); + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + //"FutureReturnValueIgnored" this is deliberate + ctx.write(msg, promise); + } + + @Override + public HttpServerMetricsRecorder recorder() { + return recorder; + } + } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java b/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java index aab1c8022..504d0ec65 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/HttpMetricsHandlerTests.java @@ -172,6 +172,8 @@ static void createSelfSignedCertificate() throws CertificateException { *
  • /5 is used by testServerConnectionsRecorder test
  • *
  • /6 is used by testServerConnectionsMicrometerConnectionClose test
  • *
  • /7 is used by testServerConnectionsRecorderConnectionClose test
  • + *
  • /8 is used by testServerConnectionsWebsocketMicrometer test
  • + *
  • /9 is used by testServerConnectionsWebsocketRecorder test
  • * */ @BeforeEach @@ -201,7 +203,12 @@ void setUp() { .get("/7", (req, res) -> { checkServerConnectionsRecorder(req); return Mono.delay(Duration.ofMillis(200)).then(res.send()); - })); + }) + .get("/8", (req, res) -> res.sendWebsocket((in, out) -> + out.sendString(Mono.just("Hello World!").doOnNext(b -> checkServerConnectionsMicrometer(req))))) + .get("/9", (req, res) -> res.sendWebsocket((in, out) -> + out.sendString(Mono.just("Hello World!").doOnNext(b -> checkServerConnectionsRecorder(req))))) + ); provider = ConnectionProvider.create("HttpMetricsHandlerTests", 1); httpClient = createClient(provider, () -> disposableServer.address()) @@ -747,6 +754,28 @@ void testServerConnectionsMicrometerConnectionClose(HttpProtocol[] serverProtoco } } + @Test + void testServerConnectionsWebsocketMicrometer() throws Exception { + disposableServer = httpServer + .doOnConnection(cnx -> ServerCloseHandler.INSTANCE.register(cnx.channel())) + .bindNow(); + + String address = formatSocketAddress(disposableServer.address()); + + httpClient.websocket() + .uri("/8") + .handle((in, out) -> in.receive().aggregate().asString()) + .as(StepVerifier::create) + .expectNext("Hello World!") + .expectComplete() + .verify(Duration.ofSeconds(30)); + + // make sure the client socket is closed on the server side before checking server metrics + assertThat(ServerCloseHandler.INSTANCE.awaitClientClosedOnServer()).as("awaitClientClosedOnServer timeout").isTrue(); + assertGauge(registry, SERVER_CONNECTIONS_TOTAL, URI, HTTP, LOCAL_ADDRESS, address).hasValueEqualTo(0); + assertGauge(registry, SERVER_CONNECTIONS_ACTIVE, URI, HTTP, LOCAL_ADDRESS, address).hasValueEqualTo(0); + } + @ParameterizedTest @MethodSource("httpCompatibleProtocols") void testServerConnectionsRecorder(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols, @@ -850,6 +879,32 @@ void testServerConnectionsRecorderConnectionClose(HttpProtocol[] serverProtocols } } + @Test + void testServerConnectionsWebsocketRecorder() throws Exception { + ServerRecorder.INSTANCE.reset(); + disposableServer = httpServer.metrics(true, ServerRecorder.supplier(), Function.identity()) + .doOnConnection(cnx -> ServerCloseHandler.INSTANCE.register(cnx.channel())) + .bindNow(); + + String address = formatSocketAddress(disposableServer.address()); + + httpClient.websocket() + .uri("/9") + .handle((in, out) -> in.receive().aggregate().asString()) + .as(StepVerifier::create) + .expectNext("Hello World!") + .expectComplete() + .verify(Duration.ofSeconds(30)); + + // make sure the client socket is closed on the server side before checking server metrics + assertThat(ServerCloseHandler.INSTANCE.awaitClientClosedOnServer()).as("awaitClientClosedOnServer timeout").isTrue(); + assertThat(ServerRecorder.INSTANCE.error.get()).isNull(); + assertThat(ServerRecorder.INSTANCE.onServerConnectionsAmount.get()).isEqualTo(0); + assertThat(ServerRecorder.INSTANCE.onActiveConnectionsAmount.get()).isEqualTo(0); + assertThat(ServerRecorder.INSTANCE.onActiveConnectionsLocalAddr.get()).isEqualTo(address); + assertThat(ServerRecorder.INSTANCE.onInactiveConnectionsLocalAddr.get()).isEqualTo(address); + } + @Test @SuppressWarnings("deprecation") void testIssue896() throws Exception {