diff --git a/.gitignore b/.gitignore index bde7e8f50..92865ccca 100644 --- a/.gitignore +++ b/.gitignore @@ -65,7 +65,7 @@ atlassian-ide-plugin.xml # NetBeans specific files/directories .nbattrs -/bin +**/bin/* #.gitignore in subdirectory .gitignore diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index 7743715a3..780222efd 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -27,7 +27,6 @@ dependencies { api 'io.netty:netty-buffer' api 'io.projectreactor:reactor-core' - implementation 'org.jctools:jctools-core' implementation 'org.slf4j:slf4j-api' compileOnly 'com.google.code.findbugs:jsr305' diff --git a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java index 383d1a2c9..f9985b4ac 100644 --- a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java +++ b/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java @@ -16,23 +16,18 @@ package io.rsocket; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + final class StreamIdSupplier { - private int streamId; + private static final AtomicIntegerFieldUpdater STREAM_ID = + AtomicIntegerFieldUpdater.newUpdater(StreamIdSupplier.class, "streamId"); + private volatile int streamId; private StreamIdSupplier(int streamId) { this.streamId = streamId; } - synchronized int nextStreamId() { - streamId += 2; - return streamId; - } - - synchronized boolean isBeforeOrCurrent(int streamId) { - return this.streamId >= streamId && streamId > 0; - } - static StreamIdSupplier clientSupplier() { return new StreamIdSupplier(-1); } @@ -40,4 +35,12 @@ static StreamIdSupplier clientSupplier() { static StreamIdSupplier serverSupplier() { return new StreamIdSupplier(0); } + + int nextStreamId() { + return STREAM_ID.addAndGet(this, 2); + } + + boolean isBeforeOrCurrent(int streamId) { + return this.streamId >= streamId && streamId > 0; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index e70f45d9f..0ac558ce5 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -16,21 +16,24 @@ package io.rsocket.fragmentation; -import static io.rsocket.fragmentation.FrameReassembler.createFrameReassembler; -import static io.rsocket.util.AbstractionLeakingFrameUtils.toAbstractionLeakingFrame; - import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.PooledByteBufAllocator; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.LongObjectHashMap; import io.rsocket.DuplexConnection; import io.rsocket.Frame; import io.rsocket.util.AbstractionLeakingFrameUtils; import io.rsocket.util.NumberUtils; -import java.util.Objects; -import org.jctools.maps.NonBlockingHashMapLong; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.Collection; +import java.util.Objects; + +import static io.rsocket.fragmentation.FrameReassembler.createFrameReassembler; +import static io.rsocket.util.AbstractionLeakingFrameUtils.toAbstractionLeakingFrame; + /** * A {@link DuplexConnection} implementation that fragments and reassembles {@link Frame}s. * @@ -46,8 +49,7 @@ public final class FragmentationDuplexConnection implements DuplexConnection { private final FrameFragmenter frameFragmenter; - private final NonBlockingHashMapLong frameReassemblers = - new NonBlockingHashMapLong<>(); + private final IntObjectHashMap frameReassemblers = new IntObjectHashMap<>(); /** * Creates a new instance. @@ -85,7 +87,16 @@ public FragmentationDuplexConnection( delegate .onClose() - .doFinally(signalType -> frameReassemblers.values().forEach(FrameReassembler::dispose)) + .doFinally( + signalType -> { + Collection values; + synchronized (this) { + values = frameReassemblers.values(); + } + for (FrameReassembler reassembler : values) { + reassembler.dispose(); + } + }) .subscribe(); } @@ -134,9 +145,13 @@ private Flux toFragmentedFrames(int streamId, io.rsocket.framing.Frame fr } private Mono toReassembledFrames(int streamId, io.rsocket.framing.Frame fragment) { - FrameReassembler frameReassembler = - frameReassemblers.computeIfAbsent( - (long) streamId, i -> createFrameReassembler(byteBufAllocator)); + FrameReassembler frameReassembler; + + synchronized (this) { + frameReassembler = + frameReassemblers.computeIfAbsent( + streamId, i -> createFrameReassembler(byteBufAllocator)); + } return Mono.justOrEmpty(frameReassembler.reassemble(fragment)) .map(frame -> toAbstractionLeakingFrame(byteBufAllocator, streamId, frame)); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index 82c5b0cc0..3a55a9767 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -16,28 +16,13 @@ package io.rsocket.client; -import io.rsocket.Availability; -import io.rsocket.Closeable; -import io.rsocket.Payload; -import io.rsocket.RSocket; +import io.rsocket.*; import io.rsocket.client.filter.RSocketSupplier; import io.rsocket.stat.Ewma; import io.rsocket.stat.FrugalQuantile; import io.rsocket.stat.Median; import io.rsocket.stat.Quantile; import io.rsocket.util.Clock; -import io.rsocket.util.RSocketProxy; -import java.nio.channels.ClosedChannelException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Random; -import java.util.Set; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -48,6 +33,17 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + /** * An implementation of {@link Mono} that load balances across a pool of RSockets and emits one when * it is subscribed to @@ -56,9 +52,7 @@ */ public abstract class LoadBalancedRSocketMono extends Mono implements Availability, Closeable { - - private static final Logger logger = LoggerFactory.getLogger(LoadBalancedRSocketMono.class); - + public static final double DEFAULT_EXP_FACTOR = 4.0; public static final double DEFAULT_LOWER_QUANTILE = 0.2; public static final double DEFAULT_HIGHER_QUANTILE = 0.8; @@ -68,38 +62,33 @@ public abstract class LoadBalancedRSocketMono extends Mono public static final int DEFAULT_MAX_APERTURE = 100; public static final long DEFAULT_MAX_REFRESH_PERIOD_MS = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES); - + private static final Logger logger = LoggerFactory.getLogger(LoadBalancedRSocketMono.class); private static final long APERTURE_REFRESH_PERIOD = Clock.unit().convert(15, TimeUnit.SECONDS); private static final int EFFORT = 5; private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = Clock.unit().convert(1L, TimeUnit.SECONDS); private static final int DEFAULT_INTER_ARRIVAL_FACTOR = 500; - + + private static final FailingRSocket FAILING_REACTIVE_SOCKET = new FailingRSocket(); + protected final Mono rSocketMono; private final double minPendings; private final double maxPendings; private final int minAperture; private final int maxAperture; private final long maxRefreshPeriod; - private final double expFactor; private final Quantile lowerQuantile; private final Quantile higherQuantile; - - private int pendingSockets; private final ArrayList activeSockets; - private final ArrayList activeFactories; - private final FactoriesRefresher factoryRefresher; - private final Ewma pendings; + private final MonoProcessor onClose = MonoProcessor.create(); + private final RSocketSupplierPool pool; private volatile int targetAperture; private long lastApertureRefresh; private long refreshPeriod; + private int pendingSockets; private volatile long lastRefresh; - - private final MonoProcessor onClose = MonoProcessor.create(); - protected final MonoProcessor started = MonoProcessor.create(); - protected final Mono rSocketMono; - + /** * @param factories the source (factories) of RSocket * @param expFactor how aggressive is the algorithm toward outliers. A higher number means we send @@ -130,30 +119,30 @@ private LoadBalancedRSocketMono( this.expFactor = expFactor; this.lowerQuantile = new FrugalQuantile(lowQuantile); this.higherQuantile = new FrugalQuantile(highQuantile); - + this.activeSockets = new ArrayList<>(); - this.activeFactories = new ArrayList<>(); this.pendingSockets = 0; - this.factoryRefresher = new FactoriesRefresher(); - + this.minPendings = minPendings; this.maxPendings = maxPendings; this.pendings = new Ewma(15, TimeUnit.SECONDS, (minPendings + maxPendings) / 2.0); - + this.minAperture = minAperture; this.maxAperture = maxAperture; this.targetAperture = minAperture; - + this.maxRefreshPeriod = Clock.unit().convert(maxRefreshPeriodMs, TimeUnit.MILLISECONDS); this.lastApertureRefresh = Clock.now(); this.refreshPeriod = Clock.unit().convert(15L, TimeUnit.SECONDS); this.lastRefresh = Clock.now(); - - factories.subscribe(factoryRefresher); - - rSocketMono = Mono.fromCallable(this::select); + this.pool = new RSocketSupplierPool(factories); + refreshSockets(); + + rSocketMono = Mono.fromSupplier(this::select); + + onClose.doFinally(signalType -> pool.dispose()).subscribe(); } - + public static LoadBalancedRSocketMono create( Publisher> factories) { return create( @@ -167,7 +156,7 @@ public static LoadBalancedRSocketMono create( DEFAULT_MAX_APERTURE, DEFAULT_MAX_REFRESH_PERIOD_MS); } - + public static LoadBalancedRSocketMono create( Publisher> factories, double expFactor, @@ -190,79 +179,70 @@ public static LoadBalancedRSocketMono create( maxRefreshPeriodMs) { @Override public void subscribe(CoreSubscriber s) { - started.then(rSocketMono).subscribe(s); + rSocketMono.subscribe(s); } }; } - + + /** + * Responsible for: - refreshing the aperture - asynchronously adding/removing reactive sockets to + * match targetAperture - periodically append a new connection + */ + private synchronized void refreshSockets() { + refreshAperture(); + int n = activeSockets.size(); + if (n < targetAperture && !pool.isPoolEmpty()) { + logger.debug( + "aperture {} is below target {}, adding {} sockets", + n, + targetAperture, + targetAperture - n); + addSockets(targetAperture - n); + } else if (targetAperture < activeSockets.size()) { + logger.debug("aperture {} is above target {}, quicking 1 socket", n, targetAperture); + quickSlowestRS(); + } + + long now = Clock.now(); + if (now - lastRefresh >= refreshPeriod) { + long prev = refreshPeriod; + refreshPeriod = (long) Math.min(refreshPeriod * 1.5, maxRefreshPeriod); + logger.debug("Bumping refresh period, {}->{}", prev / 1000, refreshPeriod / 1000); + lastRefresh = now; + addSockets(1); + } + } + private synchronized void addSockets(int numberOfNewSocket) { int n = numberOfNewSocket; - if (n > activeFactories.size()) { - n = activeFactories.size(); + int poolSize = pool.poolSize(); + if (n > poolSize) { + n = poolSize; logger.debug( "addSockets({}) restricted by the number of factories, i.e. addSockets({})", numberOfNewSocket, n); } - - Random rng = ThreadLocalRandom.current(); - while (n > 0) { - int size = activeFactories.size(); - if (size == 1) { - RSocketSupplier factory = activeFactories.get(0); - if (factory.availability() > 0.0) { - activeFactories.remove(0); - pendingSockets++; - factory.get().subscribe(new SocketAdder(factory)); - } - break; - } - RSocketSupplier factory0 = null; - RSocketSupplier factory1 = null; - int i0 = 0; - int i1 = 0; - for (int i = 0; i < EFFORT; i++) { - i0 = rng.nextInt(size); - i1 = rng.nextInt(size - 1); - if (i1 >= i0) { - i1++; - } - factory0 = activeFactories.get(i0); - factory1 = activeFactories.get(i1); - if (factory0.availability() > 0.0 && factory1.availability() > 0.0) { - break; - } - } - - if (factory0.availability() < factory1.availability()) { - n--; - pendingSockets++; - // cheaper to permute activeFactories.get(i1) with the last item and remove the last - // rather than doing a activeFactories.remove(i1) - if (i1 < size - 1) { - activeFactories.set(i1, activeFactories.get(size - 1)); - } - activeFactories.remove(size - 1); - factory1.get().subscribe(new SocketAdder(factory1)); + + for (int i = 0; i < n; i++) { + Optional optional = pool.get(); + + if (optional.isPresent()) { + RSocketSupplier supplier = optional.get(); + WeightedSocket socket = new WeightedSocket(supplier, lowerQuantile, higherQuantile); + activeSockets.add(socket); } else { - n--; - pendingSockets++; - // c.f. above - if (i0 < size - 1) { - activeFactories.set(i0, activeFactories.get(size - 1)); - } - activeFactories.remove(size - 1); - factory0.get().subscribe(new SocketAdder(factory0)); + break; } } } - + private synchronized void refreshAperture() { int n = activeSockets.size(); if (n == 0) { return; } - + double p = 0.0; for (WeightedSocket wrs : activeSockets) { p += wrs.getPending(); @@ -270,7 +250,7 @@ private synchronized void refreshAperture() { p /= n + pendingSockets; pendings.insert(p); double avgPending = pendings.value(); - + long now = Clock.now(); boolean underRateLimit = now - lastApertureRefresh > APERTURE_REFRESH_PERIOD; if (avgPending < 1.0 && underRateLimit) { @@ -279,7 +259,7 @@ private synchronized void refreshAperture() { updateAperture(targetAperture + 1, now); } } - + /** * Update the aperture value and ensure its value stays in the right range. * @@ -290,11 +270,11 @@ private void updateAperture(int newValue, long now) { int previous = targetAperture; targetAperture = newValue; targetAperture = Math.max(minAperture, targetAperture); - int maxAperture = Math.min(this.maxAperture, activeSockets.size() + activeFactories.size()); + int maxAperture = Math.min(this.maxAperture, activeSockets.size() + pool.poolSize()); targetAperture = Math.min(maxAperture, targetAperture); lastApertureRefresh = now; pendings.reset((minPendings + maxPendings) / 2); - + if (targetAperture != previous) { logger.debug( "Current pending={}, new target={}, previous target={}", @@ -303,41 +283,12 @@ private void updateAperture(int newValue, long now) { previous); } } - - /** - * Responsible for: - refreshing the aperture - asynchronously adding/removing reactive sockets to - * match targetAperture - periodically append a new connection - */ - private synchronized void refreshSockets() { - refreshAperture(); - int n = pendingSockets + activeSockets.size(); - if (n < targetAperture && !activeFactories.isEmpty()) { - logger.debug( - "aperture {} is below target {}, adding {} sockets", - n, - targetAperture, - targetAperture - n); - addSockets(targetAperture - n); - } else if (targetAperture < activeSockets.size()) { - logger.debug("aperture {} is above target {}, quicking 1 socket", n, targetAperture); - quickSlowestRS(); - } - - long now = Clock.now(); - if (now - lastRefresh >= refreshPeriod) { - long prev = refreshPeriod; - refreshPeriod = (long) Math.min(refreshPeriod * 1.5, maxRefreshPeriod); - logger.debug("Bumping refresh period, {}->{}", prev / 1000, refreshPeriod / 1000); - lastRefresh = now; - addSockets(1); - } - } - + private synchronized void quickSlowestRS() { if (activeSockets.size() <= 1) { return; } - + WeightedSocket slowest = null; double lowestAvailability = Double.MAX_VALUE; for (WeightedSocket socket : activeSockets) { @@ -354,26 +305,12 @@ private synchronized void quickSlowestRS() { slowest = socket; } } - + if (slowest != null) { - removeSocket(slowest, false); - } - } - - private synchronized void removeSocket(WeightedSocket socket, boolean refresh) { - try { - logger.debug("Removing socket: -> " + socket); - activeSockets.remove(socket); - activeFactories.add(socket.getFactory()); - socket.dispose(); - if (refresh) { - refreshSockets(); - } - } catch (Exception e) { - logger.warn("Exception while closing a RSocket", e); + activeSockets.remove(slowest); } } - + @Override public synchronized double availability() { double currentAvailability = 0.0; @@ -383,24 +320,24 @@ public synchronized double availability() { } currentAvailability /= activeSockets.size(); } - + return currentAvailability; } - + private synchronized RSocket select() { if (activeSockets.isEmpty()) { return FAILING_REACTIVE_SOCKET; } refreshSockets(); - + int size = activeSockets.size(); if (size == 1) { return activeSockets.get(0); } - + WeightedSocket rsc1 = null; WeightedSocket rsc2 = null; - + Random rng = ThreadLocalRandom.current(); for (int i = 0; i < EFFORT; i++) { int i1 = rng.nextInt(size); @@ -413,11 +350,11 @@ private synchronized RSocket select() { if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { break; } - if (i + 1 == EFFORT && !activeFactories.isEmpty()) { + if (i + 1 == EFFORT && !pool.isPoolEmpty()) { addSockets(1); } } - + double w1 = algorithmicWeight(rsc1); double w2 = algorithmicWeight(rsc2); if (w1 < w2) { @@ -426,22 +363,22 @@ private synchronized RSocket select() { return rsc1; } } - + private double algorithmicWeight(WeightedSocket socket) { if (socket == null || socket.availability() == 0.0) { return 0.0; } - + int pendings = socket.getPending(); double latency = socket.getPredictedLatency(); - + double low = lowerQuantile.estimation(); double high = Math.max( higherQuantile.estimation(), low * 1.001); // ensure higherQuantile > lowerQuantile + .1% double bandWidth = Math.max(high - low, 1); - + if (latency < low) { double alpha = (low - latency) / bandWidth; double bonusFactor = Math.pow(1 + alpha, expFactor); @@ -451,274 +388,129 @@ private double algorithmicWeight(WeightedSocket socket) { double penaltyFactor = Math.pow(1 + alpha, expFactor); latency *= penaltyFactor; } - + return socket.availability() * 1.0 / (1.0 + latency * (pendings + 1)); } - + @Override public synchronized String toString() { return "LoadBalancer(a:" - + activeSockets.size() - + ", f: " - + activeFactories.size() - + ", avgPendings=" - + pendings.value() - + ", targetAperture=" - + targetAperture - + ", band=[" - + lowerQuantile.estimation() - + ", " - + higherQuantile.estimation() - + "])"; + + activeSockets.size() + + ", f: " + + pool.poolSize() + + ", avgPendings=" + + pendings.value() + + ", targetAperture=" + + targetAperture + + ", band=[" + + lowerQuantile.estimation() + + ", " + + higherQuantile.estimation() + + "])"; } - + @Override public void dispose() { - synchronized (this) { - factoryRefresher.close(); - activeFactories.clear(); + synchronized (this) {; activeSockets.forEach(WeightedSocket::dispose); activeSockets.clear(); onClose.onComplete(); } } - + @Override public boolean isDisposed() { return onClose.isDisposed(); } - + @Override public Mono onClose() { return onClose; } - - /** - * This subscriber role is to subscribe to the list of server identifier, and update the factory - * list. - */ - private class FactoriesRefresher implements Subscriber> { - private Subscription subscription; - - @Override - public void onSubscribe(Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Collection newFactories) { - synchronized (LoadBalancedRSocketMono.this) { - Set current = new HashSet<>(activeFactories.size() + activeSockets.size()); - current.addAll(activeFactories); - for (WeightedSocket socket : activeSockets) { - RSocketSupplier factory = socket.getFactory(); - current.add(factory); - } - - Set removed = new HashSet<>(current); - removed.removeAll(newFactories); - - Set added = new HashSet<>(newFactories); - added.removeAll(current); - - boolean changed = false; - Iterator it0 = activeSockets.iterator(); - while (it0.hasNext()) { - WeightedSocket socket = it0.next(); - if (removed.contains(socket.getFactory())) { - it0.remove(); - try { - changed = true; - socket.dispose(); - } catch (Exception e) { - logger.warn("Exception while closing a RSocket", e); - } - } - } - Iterator it1 = activeFactories.iterator(); - while (it1.hasNext()) { - RSocketSupplier factory = it1.next(); - if (removed.contains(factory)) { - it1.remove(); - changed = true; - } - } - - activeFactories.addAll(added); - - if (changed && logger.isDebugEnabled()) { - StringBuilder msgBuilder = new StringBuilder(); - msgBuilder - .append("\nUpdated active factories (size: ") - .append(activeFactories.size()) - .append(")\n"); - for (RSocketSupplier f : activeFactories) { - msgBuilder.append(" + ").append(f).append('\n'); - } - msgBuilder.append("Active sockets:\n"); - for (WeightedSocket socket : activeSockets) { - msgBuilder.append(" + ").append(socket).append('\n'); - } - logger.debug(msgBuilder.toString()); - } - } - refreshSockets(); - } - - @Override - public void onError(Throwable t) { - // TODO: retry - logger.error("Error refreshing RSocket factories. They would no longer be refreshed.", t); - } - - @Override - public void onComplete() { - // TODO: retry - logger.warn("RSocket factories source completed. They would no longer be refreshed."); - } - - void close() { - subscription.cancel(); - } - } - - private class SocketAdder implements Subscriber { - private final RSocketSupplier factory; - - private int errors; - - private SocketAdder(RSocketSupplier factory) { - this.factory = factory; - } - - @Override - public void onSubscribe(Subscription s) { - s.request(1L); - } - - @Override - public void onNext(RSocket rs) { - synchronized (LoadBalancedRSocketMono.this) { - if (activeSockets.size() >= targetAperture) { - quickSlowestRS(); - } - - WeightedSocket weightedSocket = - new WeightedSocket(rs, factory, lowerQuantile, higherQuantile); - logger.debug("Adding new WeightedSocket {}", weightedSocket); - - activeSockets.add(weightedSocket); - started.onComplete(); - pendingSockets -= 1; - } - } - - @Override - public void onError(Throwable t) { - logger.warn("Exception while subscribing to the RSocket source", t); - synchronized (LoadBalancedRSocketMono.this) { - pendingSockets -= 1; - if (++errors < 5) { - activeFactories.add(factory); - } else { - logger.warn( - "Exception count greater than 5, not re-adding factory {}", factory.toString()); - } - } - } - - @Override - public void onComplete() {} - } - - private static final FailingRSocket FAILING_REACTIVE_SOCKET = new FailingRSocket(); - + /** * (Null Object Pattern) This failing RSocket never succeed, it is useful for simplifying the code * when dealing with edge cases. */ private static class FailingRSocket implements RSocket { - + private static final Mono errorVoid = Mono.error(NoAvailableRSocketException.INSTANCE); private static final Mono errorPayload = Mono.error(NoAvailableRSocketException.INSTANCE); - + @Override public Mono fireAndForget(Payload payload) { return errorVoid; } - + @Override public Mono requestResponse(Payload payload) { return errorPayload; } - + @Override public Flux requestStream(Payload payload) { return errorPayload.flux(); } - + @Override public Flux requestChannel(Publisher payloads) { return errorPayload.flux(); } - + @Override public Mono metadataPush(Payload payload) { return errorVoid; } - + @Override public double availability() { return 0; } - + @Override public void dispose() {} - + @Override public boolean isDisposed() { return true; } - + @Override public Mono onClose() { return Mono.empty(); } } - + /** * Wrapper of a RSocket, it computes statistics about the req/resp calls and update availability * accordingly. */ - private class WeightedSocket extends RSocketProxy implements LoadBalancerSocketMetrics { - + private class WeightedSocket extends AbstractRSocket implements LoadBalancerSocketMetrics { + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; - - private RSocketSupplier factory; private final Quantile lowerQuantile; private final Quantile higherQuantile; private final long inactivityFactor; - + private final MonoProcessor rSocketMono; private volatile int pending; // instantaneous rate private long stamp; // last timestamp we sent a request private long stamp0; // last timestamp we sent a request or receive a response private long duration; // instantaneous cumulative duration - + private Median median; private Ewma interArrivalTime; - + private AtomicLong pendingStreams; // number of active streams - + + private volatile double availability = 0.0; + WeightedSocket( - RSocket child, RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile, int inactivityFactor) { - super(child); - this.factory = factory; + this.rSocketMono = MonoProcessor.create(); this.lowerQuantile = lowerQuantile; this.higherQuantile = higherQuantile; this.inactivityFactor = inactivityFactor; @@ -730,62 +522,143 @@ private class WeightedSocket extends RSocketProxy implements LoadBalancerSocketM this.median = new Median(); this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); this.pendingStreams = new AtomicLong(); - child.onClose().doFinally(signalType -> removeSocket(this, true)).subscribe(); + + WeightedSocket.this + .onClose() + .doFinally( + s -> { + pool.accept(factory); + activeSockets.remove(WeightedSocket.this); + refreshSockets(); + }) + .subscribe(); + + factory + .get() + .retryBackoff(5, Duration.ofMillis(500)) + .doOnError( + throwable -> { + logger.error("error while connecting {}", throwable); + WeightedSocket.this.dispose(); + }) + .subscribe( + rSocket -> { + // When RSocket is closed, close the WeightedSocket + rSocket + .onClose() + .doFinally( + signalType -> { + System.out.println("RSocket closed"); + WeightedSocket.this.dispose(); + }) + .subscribe(); + + // When the factory is closed, close the RSocket + factory + .onClose() + .doFinally( + signalType -> { + System.out.println("Factory closed"); + rSocket.dispose(); + }) + .subscribe(); + + // When the WeightedSocket is closed, close the RSocket + WeightedSocket.this + .onClose() + .doFinally( + signalType -> { + System.out.println("WeightedSocket closed"); + rSocket.dispose(); + }) + .subscribe(); + + synchronized (LoadBalancedRSocketMono.this) { + if (activeSockets.size() >= targetAperture) { + quickSlowestRS(); + pendingSockets -= 1; + } + } + + rSocketMono.onNext(rSocket); + availability = 1.0; + }); } - - WeightedSocket( - RSocket child, RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile) { - this(child, factory, lowerQuantile, higherQuantile, DEFAULT_INTER_ARRIVAL_FACTOR); + + WeightedSocket(RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile) { + this(factory, lowerQuantile, higherQuantile, DEFAULT_INTER_ARRIVAL_FACTOR); } - + @Override public Mono requestResponse(Payload payload) { - return Mono.from( - subscriber -> - source.requestResponse(payload).subscribe(new LatencySubscriber<>(subscriber, this))); - } - + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .requestResponse(payload) + .subscribe(new LatencySubscriber<>(subscriber, this))); + }); + } + @Override public Flux requestStream(Payload payload) { - return Flux.from( - subscriber -> - source.requestStream(payload).subscribe(new CountingSubscriber<>(subscriber, this))); - } - + + return rSocketMono.flatMapMany( + source -> { + return Flux.from( + subscriber -> + source + .requestStream(payload) + .subscribe(new CountingSubscriber<>(subscriber, this))); + }); + } + @Override public Mono fireAndForget(Payload payload) { - return Mono.from( - subscriber -> - source.fireAndForget(payload).subscribe(new CountingSubscriber<>(subscriber, this))); - } - + + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .fireAndForget(payload) + .subscribe(new CountingSubscriber<>(subscriber, this))); + }); + } + @Override public Mono metadataPush(Payload payload) { - return Mono.from( - subscriber -> - source.metadataPush(payload).subscribe(new CountingSubscriber<>(subscriber, this))); - } - + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .metadataPush(payload) + .subscribe(new CountingSubscriber<>(subscriber, this))); + }); + } + @Override public Flux requestChannel(Publisher payloads) { - return Flux.from( - subscriber -> - source - .requestChannel(payloads) - .subscribe(new CountingSubscriber<>(subscriber, this))); - } - - RSocketSupplier getFactory() { - return factory; - } - + + return rSocketMono.flatMapMany( + source -> { + return Flux.from( + subscriber -> + source + .requestChannel(payloads) + .subscribe(new CountingSubscriber<>(subscriber, this))); + }); + } + synchronized double getPredictedLatency() { long now = Clock.now(); long elapsed = Math.max(now - stamp, 1L); - + double weight; double prediction = median.estimation(); - + if (prediction == 0.0) { if (pending == 0) { weight = 0.0; // first request @@ -801,7 +674,7 @@ synchronized double getPredictedLatency() { } else { double predicted = prediction * pending; double instant = instantaneous(now); - + if (predicted < instant) { // NB: (0.0 < 0.0) == false weight = instant / pending; // NB: pending never equal 0 here } else { @@ -809,18 +682,18 @@ synchronized double getPredictedLatency() { weight = prediction; } } - + return weight; } - + int getPending() { return pending; } - + private synchronized long instantaneous(long now) { return duration + (now - stamp0) * pending; } - + private synchronized long incr() { long now = Clock.now(); interArrivalTime.insert(now - stamp); @@ -830,7 +703,7 @@ private synchronized long incr() { stamp0 = now; return now; } - + private synchronized long decr(long timestamp) { long now = Clock.now(); duration += Math.max(0, now - stamp0) * pending - (now - timestamp); @@ -838,90 +711,86 @@ private synchronized long decr(long timestamp) { stamp0 = now; return now; } - + private synchronized void observe(double rtt) { median.insert(rtt); lowerQuantile.insert(rtt); higherQuantile.insert(rtt); } - - @Override - public void dispose() { - source.dispose(); - } - + @Override - public boolean isDisposed() { - return source.isDisposed(); + public double availability() { + return availability; } - + @Override public String toString() { return "WeightedSocket(" - + "median=" - + median.estimation() - + " quantile-low=" - + lowerQuantile.estimation() - + " quantile-high=" - + higherQuantile.estimation() - + " inter-arrival=" - + interArrivalTime.value() - + " duration/pending=" - + (pending == 0 ? 0 : (double) duration / pending) - + " pending=" - + pending - + " availability= " - + availability() - + ")->" - + source; - } - + + "median=" + + median.estimation() + + " quantile-low=" + + lowerQuantile.estimation() + + " quantile-high=" + + higherQuantile.estimation() + + " inter-arrival=" + + interArrivalTime.value() + + " duration/pending=" + + (pending == 0 ? 0 : (double) duration / pending) + + " pending=" + + pending + + " availability= " + + availability() + + ")->"; + + + } + @Override public double medianLatency() { return median.estimation(); } - + @Override public double lowerQuantileLatency() { return lowerQuantile.estimation(); } - + @Override public double higherQuantileLatency() { return higherQuantile.estimation(); } - + @Override public double interArrivalTime() { return interArrivalTime.value(); } - + @Override public int pending() { return pending; } - + @Override public long lastTimeUsedMillis() { return stamp0; } - + /** * Subscriber wrapper used for request/response interaction model, measure and collect latency * information. */ private class LatencySubscriber implements Subscriber { private final Subscriber child; - private final WeightedSocket socket; + private final LoadBalancedRSocketMono.WeightedSocket socket; private final AtomicBoolean done; private long start; - - LatencySubscriber(Subscriber child, WeightedSocket socket) { + + LatencySubscriber(Subscriber child, LoadBalancedRSocketMono.WeightedSocket socket) { this.child = child; this.socket = socket; this.done = new AtomicBoolean(false); } - + @Override public void onSubscribe(Subscription s) { start = incr(); @@ -931,7 +800,7 @@ public void onSubscribe(Subscription s) { public void request(long n) { s.request(n); } - + @Override public void cancel() { if (done.compareAndSet(false, true)) { @@ -941,25 +810,25 @@ public void cancel() { } }); } - + @Override public void onNext(U u) { child.onNext(u); } - + @Override public void onError(Throwable t) { if (done.compareAndSet(false, true)) { child.onError(t); long now = decr(start); if (t instanceof TransportException || t instanceof ClosedChannelException) { - removeSocket(socket, true); + socket.dispose(); } else if (t instanceof TimeoutException) { observe(now - start); } } } - + @Override public void onComplete() { if (done.compareAndSet(false, true)) { @@ -969,40 +838,41 @@ public void onComplete() { } } } - + /** * Subscriber wrapper used for stream like interaction model, it only counts the number of * active streams */ private class CountingSubscriber implements Subscriber { private final Subscriber child; - private final WeightedSocket socket; - - CountingSubscriber(Subscriber child, WeightedSocket socket) { + private final LoadBalancedRSocketMono.WeightedSocket socket; + + CountingSubscriber(Subscriber child, LoadBalancedRSocketMono.WeightedSocket socket) { this.child = child; this.socket = socket; } - + @Override public void onSubscribe(Subscription s) { socket.pendingStreams.incrementAndGet(); child.onSubscribe(s); } - + @Override public void onNext(U u) { child.onNext(u); } - + @Override public void onError(Throwable t) { socket.pendingStreams.decrementAndGet(); child.onError(t); if (t instanceof TransportException || t instanceof ClosedChannelException) { - removeSocket(socket, true); + activeSockets.remove(socket); + refreshSockets(); } } - + @Override public void onComplete() { socket.pendingStreams.decrementAndGet(); @@ -1011,3 +881,4 @@ public void onComplete() { } } } + diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java new file mode 100644 index 000000000..6f1fba30a --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java @@ -0,0 +1,189 @@ +package io.rsocket.client; + +import io.rsocket.Closeable; +import io.rsocket.client.filter.RSocketSupplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +import java.time.Duration; +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Consumer; +import java.util.function.Supplier; + +public class RSocketSupplierPool + implements Supplier>, Consumer, Closeable { + private static final Logger logger = LoggerFactory.getLogger(RSocketSupplierPool.class); + private static final int EFFORT = 5; + + private final ArrayList factoryPool; + private final ArrayList leasedSuppliers; + + private final MonoProcessor onClose; + + public RSocketSupplierPool(Publisher> publisher) { + this.onClose = MonoProcessor.create(); + this.factoryPool = new ArrayList<>(); + this.leasedSuppliers = new ArrayList<>(); + + Disposable disposable = + Flux.from(publisher) + .doOnNext(this::handleNewFactories) + .onErrorResume( + t -> { + logger.error("error streaming RSocketSuppliers", t); + return Mono.delay(Duration.ofSeconds(10)).then(Mono.error(t)); + }) + .subscribe(); + + onClose.doFinally(s -> disposable.dispose()).subscribe(); + } + + private synchronized void handleNewFactories(Collection newFactories) { + Set current = new HashSet<>(factoryPool.size() + leasedSuppliers.size()); + current.addAll(factoryPool); + current.addAll(leasedSuppliers); + + Set removed = new HashSet<>(current); + removed.removeAll(newFactories); + + Set added = new HashSet<>(newFactories); + added.removeAll(current); + + boolean changed = false; + Iterator it0 = leasedSuppliers.iterator(); + while (it0.hasNext()) { + RSocketSupplier supplier = it0.next(); + if (removed.contains(supplier)) { + it0.remove(); + try { + changed = true; + supplier.dispose(); + } catch (Exception e) { + logger.warn("Exception while closing a RSocket", e); + } + } + } + + Iterator it1 = factoryPool.iterator(); + while (it1.hasNext()) { + RSocketSupplier supplier = it1.next(); + if (removed.contains(supplier)) { + it1.remove(); + try { + changed = true; + supplier.dispose(); + } catch (Exception e) { + logger.warn("Exception while closing a RSocket", e); + } + } + } + + factoryPool.addAll(added); + + if (changed && logger.isDebugEnabled()) { + StringBuilder msgBuilder = new StringBuilder(); + msgBuilder + .append("\nUpdated active factories (size: ") + .append(factoryPool.size()) + .append(")\n"); + for (RSocketSupplier f : factoryPool) { + msgBuilder.append(" + ").append(f).append('\n'); + } + msgBuilder.append("Active sockets:\n"); + for (RSocketSupplier socket : leasedSuppliers) { + msgBuilder.append(" + ").append(socket).append('\n'); + } + logger.debug(msgBuilder.toString()); + } + } + + @Override + public synchronized void accept(RSocketSupplier rSocketSupplier) { + leasedSuppliers.remove(rSocketSupplier); + if (!rSocketSupplier.isDisposed()) { + factoryPool.add(rSocketSupplier); + } + } + + @Override + public synchronized Optional get() { + Optional optional = Optional.empty(); + int poolSize = factoryPool.size(); + if (poolSize == 1) { + RSocketSupplier rSocketSupplier = factoryPool.get(0); + if (rSocketSupplier.availability() > 0.0) { + factoryPool.remove(0); + leasedSuppliers.add(rSocketSupplier); + optional = Optional.of(rSocketSupplier); + } + } else if (poolSize > 1) { + Random rng = ThreadLocalRandom.current(); + int size = factoryPool.size(); + RSocketSupplier factory0 = null; + RSocketSupplier factory1 = null; + int i0 = 0; + int i1 = 0; + for (int i = 0; i < EFFORT; i++) { + i0 = rng.nextInt(size); + i1 = rng.nextInt(size - 1); + if (i1 >= i0) { + i1++; + } + factory0 = factoryPool.get(i0); + factory1 = factoryPool.get(i1); + if (factory0.availability() > 0.0 && factory1.availability() > 0.0) { + break; + } + } + if (factory0.availability() > factory1.availability()) { + factoryPool.remove(i0); + leasedSuppliers.add(factory0); + optional = Optional.of(factory0); + } else { + factoryPool.remove(i1); + leasedSuppliers.add(factory1); + optional = Optional.of(factory1); + } + } + + return optional; + } + + @Override + public Mono onClose() { + return onClose; + } + + @Override + public void dispose() { + if (!onClose.isDisposed()) { + onClose.onComplete(); + + close(factoryPool); + close(leasedSuppliers); + } + } + + private void close(Collection suppliers) { + for (RSocketSupplier supplier : suppliers) { + try { + supplier.dispose(); + } catch (Throwable t) { + } + } + } + + public synchronized int poolSize() { + return factoryPool.size(); + } + + public synchronized boolean isPoolEmpty() { + return factoryPool.isEmpty(); + } +} diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java index 6806b9037..b529e426c 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java @@ -130,6 +130,7 @@ private static RSocketSupplier succeedingFactory(RSocket socket) { Mockito.when(mock.availability()).thenReturn(1.0); Mockito.when(mock.get()).thenReturn(Mono.just(socket)); + Mockito.when(mock.onClose()).thenReturn(Mono.never()); return mock; } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java index c78669f47..29c5f7ced 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java @@ -62,6 +62,7 @@ class SendPublisher extends Flux { this.sizeOf = sizeOf; } + @SuppressWarnings("unchecked") private ChannelPromise writeCleanupPromise(V poll) { return channel .newPromise() @@ -117,7 +118,6 @@ private class InnerSubscriber implements Subscriber { final CoreSubscriber destination; volatile Subscription s; private AtomicBoolean pendingFlush = new AtomicBoolean(); - private SendPublisher sendPublisher; private InnerSubscriber(CoreSubscriber destination) { this.destination = destination; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index 3a03d15a5..09fd07325 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -30,7 +30,7 @@ /** An implementation of {@link DuplexConnection} that connects via TCP. */ public final class TcpDuplexConnection implements DuplexConnection { - + private final Connection connection; private final Disposable channelClosed; /** @@ -50,44 +50,44 @@ public TcpDuplexConnection(Connection connection) { }) .subscribe(); } - + @Override public void dispose() { connection.dispose(); } - + @Override public boolean isDisposed() { return connection.isDisposed(); } - + @Override public Mono onClose() { return connection - .onDispose() - .doFinally( - s -> { - if (!channelClosed.isDisposed()) { - channelClosed.dispose(); - } - }); + .onDispose() + .doFinally( + s -> { + if (!channelClosed.isDisposed()) { + channelClosed.dispose(); + } + }); } - + @Override public Flux receive() { return connection.inbound().receive().map(buf -> Frame.from(buf.retain())); } - + @Override public Mono send(Publisher frames) { return Flux.from(frames) - .transform( - frameFlux -> - new SendPublisher<>( - frameFlux, - connection.channel(), - frame -> frame.content().retain(), - ByteBuf::readableBytes)) - .then(); + .transform( + frameFlux -> + new SendPublisher<>( + frameFlux, + connection.channel(), + frame -> frame.content().retain(), + ByteBuf::readableBytes)) + .then(); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index 443a1fc3c..4154dc630 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -41,79 +41,79 @@ * back on for frames received. */ public final class WebsocketDuplexConnection implements DuplexConnection { - - private final Connection connection; - private final Disposable channelClosed; - - /** - * Creates a new instance - * - * @param connection the {@link Connection} to for managing the server - */ - public WebsocketDuplexConnection(Connection connection) { - this.connection = Objects.requireNonNull(connection, "connection must not be null"); - this.channelClosed = - FutureMono.from(connection.channel().closeFuture()) - .doFinally( - s -> { - if (!isDisposed()) { - dispose(); - } - }) - .subscribe(); - } - - @Override - public void dispose() { - connection.dispose(); - } - - @Override - public boolean isDisposed() { - return connection.isDisposed(); - } - - @Override - public Mono onClose() { - return connection - .onDispose() - .doFinally( - s -> { - if (!channelClosed.isDisposed()) { - channelClosed.dispose(); - } - }); - } - - @Override - public Flux receive() { - return connection - .inbound() - .receive() - .map( - buf -> { - CompositeByteBuf composite = connection.channel().alloc().compositeBuffer(); - ByteBuf length = wrappedBuffer(new byte[FRAME_LENGTH_SIZE]); - FrameHeaderFlyweight.encodeLength(length, 0, buf.readableBytes()); - composite.addComponents(true, length, buf.retain()); - return Frame.from(composite); - }); - } - - @Override - public Mono send(Publisher frames) { - return Flux.from(frames) - .transform( - frameFlux -> - new SendPublisher<>( - frameFlux, - connection.channel(), - this::toBinaryWebSocketFrame, - binaryWebSocketFrame -> binaryWebSocketFrame.content().readableBytes())) - .then(); - } - - private BinaryWebSocketFrame toBinaryWebSocketFrame(Frame frame) { - return new BinaryWebSocketFrame(frame.content().skipBytes(FRAME_LENGTH_SIZE).retain()); - } + + private final Connection connection; + private final Disposable channelClosed; + + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(Connection connection) { + this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.channelClosed = + FutureMono.from(connection.channel().closeFuture()) + .doFinally( + s -> { + if (!isDisposed()) { + dispose(); + } + }) + .subscribe(); + } + + @Override + public void dispose() { + connection.dispose(); + } + + @Override + public boolean isDisposed() { + return connection.isDisposed(); + } + + @Override + public Mono onClose() { + return connection + .onDispose() + .doFinally( + s -> { + if (!channelClosed.isDisposed()) { + channelClosed.dispose(); + } + }); + } + + @Override + public Flux receive() { + return connection + .inbound() + .receive() + .map( + buf -> { + CompositeByteBuf composite = connection.channel().alloc().compositeBuffer(); + ByteBuf length = wrappedBuffer(new byte[FRAME_LENGTH_SIZE]); + FrameHeaderFlyweight.encodeLength(length, 0, buf.readableBytes()); + composite.addComponents(true, length, buf.retain()); + return Frame.from(composite); + }); + } + + @Override + public Mono send(Publisher frames) { + return Flux.from(frames) + .transform( + frameFlux -> + new SendPublisher<>( + frameFlux, + connection.channel(), + this::toBinaryWebSocketFrame, + binaryWebSocketFrame -> binaryWebSocketFrame.content().readableBytes())) + .then(); + } + + private BinaryWebSocketFrame toBinaryWebSocketFrame(Frame frame) { + return new BinaryWebSocketFrame(frame.content().skipBytes(FRAME_LENGTH_SIZE).retain()); + } }