diff --git a/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java b/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java index e1c54bd86..16b863d9e 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,10 @@ /** * Extension interface to support Transports with headers at the transport layer, e.g. Websockets, * Http2. + * + * @deprecated as of 1.0.1 in favor using properties on individual transports. */ +@Deprecated public interface TransportHeaderAware { /** diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java index 7da982f97..72e003d2a 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -27,7 +27,6 @@ import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.ByteBufPayload; import java.time.Duration; -import java.util.Collections; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -71,16 +70,15 @@ public static void main(String[] args) { logger.debug( "\n\nStart of Authorized WebSocket Connection\n----------------------------------\n"); - WebsocketClientTransport clientTransport = - WebsocketClientTransport.create(server.host(), server.port()); - - clientTransport.setTransportHeaders(() -> Collections.singletonMap("Authorization", "test")); + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()) + .header("Authorization", "test"); RSocket clientRSocket = RSocketConnector.create() .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) .payloadDecoder(PayloadDecoder.ZERO_COPY) - .connect(clientTransport) + .connect(transport) .block(); Flux.range(1, 100) diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/UriUtils.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/UriUtils.java deleted file mode 100644 index 5134ba8d6..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/UriUtils.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import java.net.URI; -import java.util.Objects; - -/** Utilities for dealing with with {@link URI}s */ -public final class UriUtils { - - private UriUtils() {} - - /** - * Returns the port of a URI. If the port is unset (i.e. {@code -1}) then returns the {@code - * defaultPort}. - * - * @param uri the URI to extract the port from - * @param defaultPort the default to use if the port is unset - * @return the port of a URI or {@code defaultPort} if unset - * @throws NullPointerException if {@code uri} is {@code null} - */ - public static int getPort(URI uri, int defaultPort) { - Objects.requireNonNull(uri, "uri must not be null"); - return uri.getPort() == -1 ? defaultPort : uri.getPort(); - } - - /** - * Returns whether the URI has a secure schema. Secure is defined as being either {@code wss} or - * {@code https}. - * - * @param uri the URI to examine - * @return whether the URI has a secure schema - * @throws NullPointerException if {@code uri} is {@code null} - */ - public static boolean isSecure(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - return uri.getScheme().equals("wss") || uri.getScheme().equals("https"); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 9b8bea97a..35b22ebdd 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -17,19 +17,19 @@ package io.rsocket.transport.netty.client; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import static io.rsocket.transport.netty.UriUtils.getPort; -import static io.rsocket.transport.netty.UriUtils.isSecure; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; import io.rsocket.DuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.TransportHeaderAware; import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.net.InetSocketAddress; import java.net.URI; -import java.util.Collections; +import java.util.Arrays; import java.util.Map; import java.util.Objects; +import java.util.function.Consumer; import java.util.function.Supplier; import reactor.core.publisher.Mono; import reactor.netty.http.client.HttpClient; @@ -37,10 +37,12 @@ import reactor.netty.tcp.TcpClient; /** - * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} via a - * Websocket. + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} over + * WebSocket. */ -public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware { +@SuppressWarnings("deprecation") +public final class WebsocketClientTransport + implements ClientTransport, io.rsocket.transport.TransportHeaderAware { private static final String DEFAULT_PATH = "/"; @@ -48,11 +50,16 @@ public final class WebsocketClientTransport implements ClientTransport, Transpor private final String path; - private Supplier> transportHeaders = Collections::emptyMap; + private HttpHeaders headers = new DefaultHttpHeaders(); + + private WebsocketClientSpec.Builder specBuilder = + WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); private WebsocketClientTransport(HttpClient client, String path) { + Objects.requireNonNull(client, "HttpClient must not be null"); + Objects.requireNonNull(path, "path must not be null"); this.client = client; - this.path = path; + this.path = path.startsWith("/") ? path : "/" + path; } /** @@ -62,8 +69,7 @@ private WebsocketClientTransport(HttpClient client, String path) { * @return a new instance */ public static WebsocketClientTransport create(int port) { - TcpClient client = TcpClient.create().port(port); - return create(client); + return create(TcpClient.create().port(port)); } /** @@ -75,10 +81,7 @@ public static WebsocketClientTransport create(int port) { * @throws NullPointerException if {@code bindAddress} is {@code null} */ public static WebsocketClientTransport create(String bindAddress, int port) { - Objects.requireNonNull(bindAddress, "bindAddress must not be null"); - - TcpClient client = TcpClient.create().host(bindAddress).port(port); - return create(client); + return create(TcpClient.create().host(bindAddress).port(port)); } /** @@ -90,36 +93,35 @@ public static WebsocketClientTransport create(String bindAddress, int port) { */ public static WebsocketClientTransport create(InetSocketAddress address) { Objects.requireNonNull(address, "address must not be null"); - - TcpClient client = TcpClient.create().remoteAddress(() -> address); - return create(client); + return create(TcpClient.create().remoteAddress(() -> address)); } /** * Creates a new instance * - * @param uri the URI to connect to + * @param client the {@link TcpClient} to use * @return a new instance - * @throws NullPointerException if {@code uri} is {@code null} + * @throws NullPointerException if {@code client} or {@code path} is {@code null} */ - public static WebsocketClientTransport create(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - TcpClient client = createClient(uri); - return create(HttpClient.from(client), uri.getPath()); + public static WebsocketClientTransport create(TcpClient client) { + return new WebsocketClientTransport(HttpClient.from(client), DEFAULT_PATH); } /** * Creates a new instance * - * @param client the {@link TcpClient} to use + * @param uri the URI to connect to * @return a new instance - * @throws NullPointerException if {@code client} or {@code path} is {@code null} + * @throws NullPointerException if {@code uri} is {@code null} */ - public static WebsocketClientTransport create(TcpClient client) { - Objects.requireNonNull(client, "client must not be null"); - - return create(HttpClient.from(client), DEFAULT_PATH); + public static WebsocketClientTransport create(URI uri) { + Objects.requireNonNull(uri, "uri must not be null"); + boolean isSecure = uri.getScheme().equals("wss") || uri.getScheme().equals("https"); + TcpClient client = + (isSecure ? TcpClient.create().secure() : TcpClient.create()) + .host(uri.getHost()) + .port(uri.getPort() == -1 ? (isSecure ? 443 : 80) : uri.getPort()); + return new WebsocketClientTransport(HttpClient.from(client), uri.getPath()); } /** @@ -131,33 +133,49 @@ public static WebsocketClientTransport create(TcpClient client) { * @throws NullPointerException if {@code client} or {@code path} is {@code null} */ public static WebsocketClientTransport create(HttpClient client, String path) { - Objects.requireNonNull(client, "client must not be null"); - Objects.requireNonNull(path, "path must not be null"); - - path = path.startsWith(DEFAULT_PATH) ? path : (DEFAULT_PATH + path); - return new WebsocketClientTransport(client, path); } - private static TcpClient createClient(URI uri) { - if (isSecure(uri)) { - return TcpClient.create().secure().host(uri.getHost()).port(getPort(uri, 443)); - } else { - return TcpClient.create().host(uri.getHost()).port(getPort(uri, 80)); + /** + * Add a header and value(s) to use for the WebSocket handshake request. + * + * @param name the header name + * @param values the header value(s) + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketClientTransport header(String name, String... values) { + if (values != null) { + Arrays.stream(values).forEach(value -> headers.add(name, value)); } + return this; + } + + /** + * Provide a consumer to customize properties of the {@link WebsocketClientSpec} to use for + * WebSocket upgrades. The consumer is invoked immediately. + * + * @param configurer the configurer to apply to the spec + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketClientTransport webSocketSpec(Consumer configurer) { + configurer.accept(specBuilder); + return this; } @Override public void setTransportHeaders(Supplier> transportHeaders) { - this.transportHeaders = - Objects.requireNonNull(transportHeaders, "transportHeaders must not be null"); + if (transportHeaders != null) { + transportHeaders.get().forEach((name, value) -> headers.add(name, value)); + } } @Override public Mono connect() { return client - .headers(headers -> transportHeaders.get().forEach(headers::set)) - .websocket(WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()) + .headers(headers -> headers.add(this.headers)) + .websocket(specBuilder.build()) .uri(path) .connect() .map(WebsocketDuplexConnection::new); diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 802a7f817..375938960 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -18,14 +18,16 @@ import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.TransportHeaderAware; import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.net.InetSocketAddress; -import java.util.Collections; +import java.util.Arrays; import java.util.Map; import java.util.Objects; +import java.util.function.Consumer; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,13 +40,18 @@ * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a * Websocket. */ +@SuppressWarnings("deprecation") public final class WebsocketServerTransport extends BaseWebsocketServerTransport - implements TransportHeaderAware { + implements io.rsocket.transport.TransportHeaderAware { + private static final Logger logger = LoggerFactory.getLogger(WebsocketServerTransport.class); private final HttpServer server; - private Supplier> transportHeaders = Collections::emptyMap; + private HttpHeaders headers = new DefaultHttpHeaders(); + + private WebsocketServerSpec.Builder specBuilder = + WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); private WebsocketServerTransport(HttpServer server) { this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); @@ -99,10 +106,39 @@ public static WebsocketServerTransport create(final HttpServer server) { return new WebsocketServerTransport(server); } + /** + * Add a header and value(s) to set on the response of WebSocket handshakes. + * + * @param name the header name + * @param values the header value(s) + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketServerTransport header(String name, String... values) { + if (values != null) { + Arrays.stream(values).forEach(value -> headers.add(name, value)); + } + return this; + } + + /** + * Provide a consumer to customize properties of the {@link WebsocketServerSpec} to use for + * WebSocket upgrades. The consumer is invoked immediately. + * + * @param configurer the configurer to apply to the spec + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketServerTransport webSocketSpec(Consumer configurer) { + configurer.accept(specBuilder); + return this; + } + @Override public void setTransportHeaders(Supplier> transportHeaders) { - this.transportHeaders = - Objects.requireNonNull(transportHeaders, "transportHeaders must not be null"); + if (transportHeaders != null) { + transportHeaders.get().forEach((name, value) -> headers.add(name, value)); + } } @Override @@ -111,13 +147,13 @@ public Mono start(ConnectionAcceptor acceptor) { return server .handle( (request, response) -> { - transportHeaders.get().forEach(response::addHeader); + response.headers(headers); return response.sendWebsocket( (in, out) -> acceptor .apply(new WebsocketDuplexConnection((Connection) in)) .then(out.neverComplete()), - WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()); + specBuilder.build()); }) .bind() .map(CloseableChannel::new); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/UriUtilsTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/UriUtilsTest.java deleted file mode 100644 index 7e5bf688d..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/UriUtilsTest.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; - -import java.net.URI; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class UriUtilsTest { - - @DisplayName("returns the port") - @Test - void getPort() { - assertThat(UriUtils.getPort(URI.create("http://localhost:42"), Integer.MAX_VALUE)) - .isEqualTo(42); - } - - @DisplayName("getPort throws NullPointerException with null uri") - @Test - void getPortNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> UriUtils.getPort(null, 80)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns the default port") - @Test - void getPortUnset() { - assertThat(UriUtils.getPort(URI.create("http://localhost"), Integer.MAX_VALUE)) - .isEqualTo(Integer.MAX_VALUE); - } - - @DisplayName("returns the URI's secureness") - @Test - void isSecure() { - assertThat(UriUtils.isSecure(URI.create("http://localhost"))).isFalse(); - assertThat(UriUtils.isSecure(URI.create("ws://localhost"))).isFalse(); - - assertThat(UriUtils.isSecure(URI.create("https://localhost"))).isTrue(); - assertThat(UriUtils.isSecure(URI.create("wss://localhost"))).isTrue(); - } - - @DisplayName("isSecure throws NullPointerException with null uri") - @Test - void isSecureNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> UriUtils.isSecure(null)) - .withMessage("uri must not be null"); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java index f94229848..944d20313 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -94,7 +94,7 @@ void createInetSocketAddress() { void createNullBindAddress() { assertThatNullPointerException() .isThrownBy(() -> WebsocketClientTransport.create(null, 8000)) - .withMessage("bindAddress must not be null"); + .withMessage("host"); } @DisplayName("create throws NullPointerException with null client") @@ -102,7 +102,7 @@ void createNullBindAddress() { void createNullHttpClient() { assertThatNullPointerException() .isThrownBy(() -> WebsocketClientTransport.create(null, "/test-path")) - .withMessage("client must not be null"); + .withMessage("HttpClient must not be null"); } @DisplayName("create throws NullPointerException with null address") @@ -156,12 +156,4 @@ void createUriPath() { void setTransportHeader() { WebsocketClientTransport.create(8000).setTransportHeaders(Collections::emptyMap); } - - @DisplayName("setTransportHeaders throws NullPointerException with null headers") - @Test - void setTransportHeadersNullHeaders() { - assertThatNullPointerException() - .isThrownBy(() -> WebsocketClientTransport.create(8000).setTransportHeaders(null)) - .withMessage("transportHeaders must not be null"); - } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java index 5ac2e05fb..7f7567dc8 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -113,14 +113,6 @@ void setTransportHeader() { WebsocketServerTransport.create(8000).setTransportHeaders(Collections::emptyMap); } - @DisplayName("setTransportHeaders throws NullPointerException with null headers") - @Test - void setTransportHeadersNullHeaders() { - assertThatNullPointerException() - .isThrownBy(() -> WebsocketServerTransport.create(8000).setTransportHeaders(null)) - .withMessage("transportHeaders must not be null"); - } - @DisplayName("starts server") @Test void start() {