Skip to content

Commit

Permalink
Expose WebSocketClientSpec and WebSocketServerSpec for configuration (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed May 20, 2020
1 parent 862186e commit 2ca6a9c
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 198 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2015-2018 the original author or authors.
* Copyright 2015-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,7 +22,10 @@
/**
* Extension interface to support Transports with headers at the transport layer, e.g. Websockets,
* Http2.
*
* @deprecated as of 1.0.1 in favor using properties on individual transports.
*/
@Deprecated
public interface TransportHeaderAware {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.util.ByteBufPayload;
import java.time.Duration;
import java.util.Collections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -71,16 +70,15 @@ public static void main(String[] args) {
logger.debug(
"\n\nStart of Authorized WebSocket Connection\n----------------------------------\n");

WebsocketClientTransport clientTransport =
WebsocketClientTransport.create(server.host(), server.port());

clientTransport.setTransportHeaders(() -> Collections.singletonMap("Authorization", "test"));
WebsocketClientTransport transport =
WebsocketClientTransport.create(server.host(), server.port())
.header("Authorization", "test");

RSocket clientRSocket =
RSocketConnector.create()
.keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10))
.payloadDecoder(PayloadDecoder.ZERO_COPY)
.connect(clientTransport)
.connect(transport)
.block();

Flux.range(1, 100)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,49 @@
package io.rsocket.transport.netty.client;

import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK;
import static io.rsocket.transport.netty.UriUtils.getPort;
import static io.rsocket.transport.netty.UriUtils.isSecure;

import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.rsocket.DuplexConnection;
import io.rsocket.transport.ClientTransport;
import io.rsocket.transport.ServerTransport;
import io.rsocket.transport.TransportHeaderAware;
import io.rsocket.transport.netty.WebsocketDuplexConnection;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Collections;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Supplier;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.WebsocketClientSpec;
import reactor.netty.tcp.TcpClient;

/**
* An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} via a
* Websocket.
* An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} over
* WebSocket.
*/
public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware {
@SuppressWarnings("deprecation")
public final class WebsocketClientTransport
implements ClientTransport, io.rsocket.transport.TransportHeaderAware {

private static final String DEFAULT_PATH = "/";

private final HttpClient client;

private final String path;

private Supplier<Map<String, String>> transportHeaders = Collections::emptyMap;
private HttpHeaders headers = new DefaultHttpHeaders();

private WebsocketClientSpec.Builder specBuilder =
WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK);

private WebsocketClientTransport(HttpClient client, String path) {
Objects.requireNonNull(client, "HttpClient must not be null");
Objects.requireNonNull(path, "path must not be null");
this.client = client;
this.path = path;
this.path = path.startsWith("/") ? path : "/" + path;
}

/**
Expand All @@ -62,8 +69,7 @@ private WebsocketClientTransport(HttpClient client, String path) {
* @return a new instance
*/
public static WebsocketClientTransport create(int port) {
TcpClient client = TcpClient.create().port(port);
return create(client);
return create(TcpClient.create().port(port));
}

/**
Expand All @@ -75,10 +81,7 @@ public static WebsocketClientTransport create(int port) {
* @throws NullPointerException if {@code bindAddress} is {@code null}
*/
public static WebsocketClientTransport create(String bindAddress, int port) {
Objects.requireNonNull(bindAddress, "bindAddress must not be null");

TcpClient client = TcpClient.create().host(bindAddress).port(port);
return create(client);
return create(TcpClient.create().host(bindAddress).port(port));
}

/**
Expand All @@ -90,36 +93,35 @@ public static WebsocketClientTransport create(String bindAddress, int port) {
*/
public static WebsocketClientTransport create(InetSocketAddress address) {
Objects.requireNonNull(address, "address must not be null");

TcpClient client = TcpClient.create().remoteAddress(() -> address);
return create(client);
return create(TcpClient.create().remoteAddress(() -> address));
}

/**
* Creates a new instance
*
* @param uri the URI to connect to
* @param client the {@link TcpClient} to use
* @return a new instance
* @throws NullPointerException if {@code uri} is {@code null}
* @throws NullPointerException if {@code client} or {@code path} is {@code null}
*/
public static WebsocketClientTransport create(URI uri) {
Objects.requireNonNull(uri, "uri must not be null");

TcpClient client = createClient(uri);
return create(HttpClient.from(client), uri.getPath());
public static WebsocketClientTransport create(TcpClient client) {
return new WebsocketClientTransport(HttpClient.from(client), DEFAULT_PATH);
}

/**
* Creates a new instance
*
* @param client the {@link TcpClient} to use
* @param uri the URI to connect to
* @return a new instance
* @throws NullPointerException if {@code client} or {@code path} is {@code null}
* @throws NullPointerException if {@code uri} is {@code null}
*/
public static WebsocketClientTransport create(TcpClient client) {
Objects.requireNonNull(client, "client must not be null");

return create(HttpClient.from(client), DEFAULT_PATH);
public static WebsocketClientTransport create(URI uri) {
Objects.requireNonNull(uri, "uri must not be null");
boolean isSecure = uri.getScheme().equals("wss") || uri.getScheme().equals("https");
TcpClient client =
(isSecure ? TcpClient.create().secure() : TcpClient.create())
.host(uri.getHost())
.port(uri.getPort() == -1 ? (isSecure ? 443 : 80) : uri.getPort());
return new WebsocketClientTransport(HttpClient.from(client), uri.getPath());
}

/**
Expand All @@ -131,33 +133,49 @@ public static WebsocketClientTransport create(TcpClient client) {
* @throws NullPointerException if {@code client} or {@code path} is {@code null}
*/
public static WebsocketClientTransport create(HttpClient client, String path) {
Objects.requireNonNull(client, "client must not be null");
Objects.requireNonNull(path, "path must not be null");

path = path.startsWith(DEFAULT_PATH) ? path : (DEFAULT_PATH + path);

return new WebsocketClientTransport(client, path);
}

private static TcpClient createClient(URI uri) {
if (isSecure(uri)) {
return TcpClient.create().secure().host(uri.getHost()).port(getPort(uri, 443));
} else {
return TcpClient.create().host(uri.getHost()).port(getPort(uri, 80));
/**
* Add a header and value(s) to use for the WebSocket handshake request.
*
* @param name the header name
* @param values the header value(s)
* @return the same instance for method chaining
* @since 1.0.1
*/
public WebsocketClientTransport header(String name, String... values) {
if (values != null) {
Arrays.stream(values).forEach(value -> headers.add(name, value));
}
return this;
}

/**
* Provide a consumer to customize properties of the {@link WebsocketClientSpec} to use for
* WebSocket upgrades. The consumer is invoked immediately.
*
* @param configurer the configurer to apply to the spec
* @return the same instance for method chaining
* @since 1.0.1
*/
public WebsocketClientTransport webSocketSpec(Consumer<WebsocketClientSpec.Builder> configurer) {
configurer.accept(specBuilder);
return this;
}

@Override
public void setTransportHeaders(Supplier<Map<String, String>> transportHeaders) {
this.transportHeaders =
Objects.requireNonNull(transportHeaders, "transportHeaders must not be null");
if (transportHeaders != null) {
transportHeaders.get().forEach((name, value) -> headers.add(name, value));
}
}

@Override
public Mono<DuplexConnection> connect() {
return client
.headers(headers -> transportHeaders.get().forEach(headers::set))
.websocket(WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build())
.headers(headers -> headers.add(this.headers))
.websocket(specBuilder.build())
.uri(path)
.connect()
.map(WebsocketDuplexConnection::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@

import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK;

import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.rsocket.transport.ClientTransport;
import io.rsocket.transport.ServerTransport;
import io.rsocket.transport.TransportHeaderAware;
import io.rsocket.transport.netty.WebsocketDuplexConnection;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -38,13 +40,18 @@
* An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a
* Websocket.
*/
@SuppressWarnings("deprecation")
public final class WebsocketServerTransport extends BaseWebsocketServerTransport<CloseableChannel>
implements TransportHeaderAware {
implements io.rsocket.transport.TransportHeaderAware {

private static final Logger logger = LoggerFactory.getLogger(WebsocketServerTransport.class);

private final HttpServer server;

private Supplier<Map<String, String>> transportHeaders = Collections::emptyMap;
private HttpHeaders headers = new DefaultHttpHeaders();

private WebsocketServerSpec.Builder specBuilder =
WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK);

private WebsocketServerTransport(HttpServer server) {
this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null"));
Expand Down Expand Up @@ -99,10 +106,39 @@ public static WebsocketServerTransport create(final HttpServer server) {
return new WebsocketServerTransport(server);
}

/**
* Add a header and value(s) to set on the response of WebSocket handshakes.
*
* @param name the header name
* @param values the header value(s)
* @return the same instance for method chaining
* @since 1.0.1
*/
public WebsocketServerTransport header(String name, String... values) {
if (values != null) {
Arrays.stream(values).forEach(value -> headers.add(name, value));
}
return this;
}

/**
* Provide a consumer to customize properties of the {@link WebsocketServerSpec} to use for
* WebSocket upgrades. The consumer is invoked immediately.
*
* @param configurer the configurer to apply to the spec
* @return the same instance for method chaining
* @since 1.0.1
*/
public WebsocketServerTransport webSocketSpec(Consumer<WebsocketServerSpec.Builder> configurer) {
configurer.accept(specBuilder);
return this;
}

@Override
public void setTransportHeaders(Supplier<Map<String, String>> transportHeaders) {
this.transportHeaders =
Objects.requireNonNull(transportHeaders, "transportHeaders must not be null");
if (transportHeaders != null) {
transportHeaders.get().forEach((name, value) -> headers.add(name, value));
}
}

@Override
Expand All @@ -111,13 +147,13 @@ public Mono<CloseableChannel> start(ConnectionAcceptor acceptor) {
return server
.handle(
(request, response) -> {
transportHeaders.get().forEach(response::addHeader);
response.headers(headers);
return response.sendWebsocket(
(in, out) ->
acceptor
.apply(new WebsocketDuplexConnection((Connection) in))
.then(out.neverComplete()),
WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build());
specBuilder.build());
})
.bind()
.map(CloseableChannel::new);
Expand Down

0 comments on commit 2ca6a9c

Please sign in to comment.