diff --git a/rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt b/rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt index ccffe5726..c649d2991 100644 --- a/rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt +++ b/rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt @@ -22,8 +22,6 @@ import io.reactivex.Completable import io.reactivex.Flowable import io.reactivex.Single import io.reactivex.disposables.Disposable -import io.reactivex.functions.Action -import io.reactivex.functions.Consumer import io.reactivex.processors.FlowableProcessor import io.reactivex.processors.PublishProcessor import io.reactivex.processors.UnicastProcessor @@ -63,6 +61,8 @@ internal class RSocketClient @JvmOverloads constructor( private val senders: IntObjectHashMap> = IntObjectHashMap(256, 0.9f) private val receivers: IntObjectHashMap> = IntObjectHashMap(256, 0.9f) private val missedAckCounter: AtomicInteger = AtomicInteger() + @Volatile + private var errorSignal: Throwable? = null private val sendProcessor: FlowableProcessor = PublishProcessor .create() @@ -99,70 +99,79 @@ internal class RSocketClient @JvmOverloads constructor( connection .receive() .doOnSubscribe { started.onComplete() } - .subscribe({ handleIncomingFrames(it) },errorConsumer) + .subscribe({ handleIncomingFrames(it) }, errorConsumer) } private fun handleSendProcessorError(t: Throwable) { - val (receivers, senders) = synchronized(this) { - Pair(receivers.values, senders.values) - } - for (subscriber in receivers) { - try { - subscriber.onError(t) - } catch (e: Throwable) { - errorConsumer(e) - } - } - - for (p in senders) { - p.cancel() + synchronized(this) { + receivers.values.forEach { it.onError(t) } + senders.values.forEach { it.cancel() } } } private fun sendKeepAlive(ackTimeoutMs: Long, missedAcks: Int): Completable { return Completable.fromRunnable { - val now = System.currentTimeMillis() - if (now - timeLastTickSentMs > ackTimeoutMs) { - val count = missedAckCounter.incrementAndGet() - if (count >= missedAcks) { - val message = String.format( - "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms", - count, missedAcks, ackTimeoutMs) - throw ConnectionException(message) - } - } - - sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)) + val now = System.currentTimeMillis() + if (now - timeLastTickSentMs > ackTimeoutMs) { + val count = missedAckCounter.incrementAndGet() + if (count >= missedAcks) { + val message = String.format( + "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms", + count, missedAcks, ackTimeoutMs) + throw ConnectionException(message) } + } + + sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)) + } } override fun fireAndForget(payload: Payload): Completable { val defer = Completable.fromRunnable { - val streamId = streamIdSupplier.nextStreamId() - val requestFrame = Frame.Request.from( - streamId, - FrameType.FIRE_AND_FORGET, - payload, - 1) - sendProcessor.onNext(requestFrame) - } + val streamId = streamIdSupplier.nextStreamId() + val requestFrame = Frame.Request.from( + streamId, + FrameType.FIRE_AND_FORGET, + payload, + 1) + sendProcessor.onNext(requestFrame) + } - return completeOnStart.andThen(defer) + return errorSignal + ?.let { Completable.error(it) } + ?: completeOnStart.andThen(defer) } override fun requestResponse(payload: Payload): Single = - handleRequestResponse(payload) + errorSignal + ?.let { Single.error(it) } + ?: handleRequestResponse(payload) override fun requestStream(payload: Payload): Flowable = - handleRequestStream(payload).rebatchRequests(streamDemandLimit) + errorSignal + ?.let { Flowable.error(it) } + ?: handleRequestStream(payload).rebatchRequests(streamDemandLimit) override fun requestChannel(payloads: Publisher): Flowable = - handleChannel( + errorSignal + ?.let { Flowable.error(it) } + ?: handleChannel( Flowable.fromPublisher(payloads).rebatchRequests(streamDemandLimit), FrameType.REQUEST_CHANNEL ).rebatchRequests(streamDemandLimit) - override fun metadataPush(payload: Payload): Completable { + override fun metadataPush(payload: Payload): Completable = + errorSignal + ?.let { Completable.error(it) } + ?: handleMetadataPush(payload) + + override fun availability(): Double = connection.availability() + + override fun close(): Completable = connection.close() + + override fun onClose(): Completable = connection.onClose() + + private fun handleMetadataPush(payload: Payload): Completable { val requestFrame = Frame.Request.from( 0, FrameType.METADATA_PUSH, @@ -172,44 +181,38 @@ internal class RSocketClient @JvmOverloads constructor( return Completable.complete() } - override fun availability(): Double = connection.availability() - - override fun close(): Completable = connection.close() - - override fun onClose(): Completable = connection.onClose() - private fun handleRequestStream(payload: Payload): Flowable { return completeOnStart.andThen( Flowable.defer { - val streamId = streamIdSupplier.nextStreamId() - val receiver = UnicastProcessor.create() - synchronized(this) { - receivers.put(streamId, receiver) - } + val streamId = streamIdSupplier.nextStreamId() + val receiver = UnicastProcessor.create() + synchronized(this) { + receivers.put(streamId, receiver) + } - val first = AtomicBoolean(false) + val first = AtomicBoolean(false) - receiver - .doOnRequest{ l -> - if (first.compareAndSet(false, true) && !receiver.isTerminated()) { - val requestFrame = Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l) - sendProcessor.onNext(requestFrame) - } else if (contains(streamId)) { - sendProcessor.onNext(Frame.RequestN.from(streamId, l)) - } - } - .doOnError { t -> - if (contains(streamId) && !receiver.isTerminated()) { - sendProcessor.onNext(Frame.Error.from(streamId, t)) - } - } - .doOnCancel { - if (contains(streamId) && !receiver.isTerminated()) { - sendProcessor.onNext(Frame.Cancel.from(streamId)) - } - } - .doFinally { removeReceiver(streamId) } - }) + receiver + .doOnRequest { l -> + if (first.compareAndSet(false, true) && !receiver.isTerminated()) { + val requestFrame = Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l) + sendProcessor.onNext(requestFrame) + } else if (contains(streamId)) { + sendProcessor.onNext(Frame.RequestN.from(streamId, l)) + } + } + .doOnError { t -> + if (contains(streamId) && !receiver.isTerminated()) { + sendProcessor.onNext(Frame.Error.from(streamId, t)) + } + } + .doOnCancel { + if (contains(streamId) && !receiver.isTerminated()) { + sendProcessor.onNext(Frame.Cancel.from(streamId)) + } + } + .doFinally { removeReceiver(streamId) } + }) } private fun handleRequestResponse(payload: Payload): Single { @@ -228,8 +231,8 @@ internal class RSocketClient @JvmOverloads constructor( sendProcessor.onNext(requestFrame) receiver - .doOnError{ t -> sendProcessor.onNext(Frame.Error.from(streamId, t)) } - .doOnCancel{ sendProcessor.onNext(Frame.Cancel.from(streamId)) } + .doOnError { t -> sendProcessor.onNext(Frame.Error.from(streamId, t)) } + .doOnCancel { sendProcessor.onNext(Frame.Cancel.from(streamId)) } .doFinally { removeReceiver(streamId) } .firstOrError() })) @@ -302,11 +305,11 @@ internal class RSocketClient @JvmOverloads constructor( requestFrames .doOnNext { sendProcessor.onNext(it) } .subscribe( - {}, - { t -> - errorConsumer(t) - receiver.onError(CancellationException("Disposed")) - }) + {}, + { t -> + errorConsumer(t) + receiver.onError(CancellationException("Disposed")) + }) } else { sendOneFrame(Frame.RequestN.from(streamId, l)) } @@ -330,23 +333,15 @@ internal class RSocketClient @JvmOverloads constructor( } private fun cleanup() { + errorSignal = CLOSED_CHANNEL_EXCEPTION - var subscribers: Collection> - var publishers: Collection> - val (subs, pubs) = synchronized(this) { - - subscribers = receivers.values - publishers = senders.values + synchronized(this) { + receivers.values.forEach { cleanUpSubscriber(it) } + senders.values.forEach { cleanUpLimitableRequestPublisher(it) } - senders.clear() receivers.clear() - - Pair(subscribers,publishers) + senders.clear() } - - subs.forEach { cleanUpSubscriber(it) } - pubs.forEach { cleanUpLimitableRequestPublisher(it) } - keepAliveSendSub?.dispose() } @@ -485,5 +480,6 @@ internal class RSocketClient @JvmOverloads constructor( private val CLOSED_CHANNEL_EXCEPTION = noStacktrace(ClosedChannelException()) private val DEFAULT_STREAM_WINDOW = 128 } + private fun UnicastProcessor.isTerminated(): Boolean = hasComplete() || hasThrowable() } diff --git a/rsocket-core/src/test/java/io/rsocket/android/RSocketClientTest.kt b/rsocket-core/src/test/java/io/rsocket/android/RSocketClientTest.kt index 47cdca38a..ba34c05cd 100644 --- a/rsocket-core/src/test/java/io/rsocket/android/RSocketClientTest.kt +++ b/rsocket-core/src/test/java/io/rsocket/android/RSocketClientTest.kt @@ -18,6 +18,9 @@ package io.rsocket.android import io.reactivex.Completable +import io.reactivex.Flowable +import io.reactivex.Single +import io.reactivex.internal.observers.BlockingMultiObserver import io.reactivex.processors.PublishProcessor import io.reactivex.subscribers.TestSubscriber import io.rsocket.android.exceptions.ApplicationException @@ -33,6 +36,7 @@ import org.junit.Test import org.junit.rules.ExternalResource import org.junit.runner.Description import org.junit.runners.model.Statement +import java.nio.channels.ClosedChannelException import java.util.concurrent.TimeUnit class RSocketClientTest { @@ -171,6 +175,75 @@ class RSocketClientTest { assertThat("Stream ID reused.", streamId2, not(equalTo(streamId1))) } + @Test(timeout = 3_000) + fun requestErrorOnConnectionClose() { + Completable.timer(100, TimeUnit.MILLISECONDS) + .andThen { rule.conn.close() }.subscribe() + val requestStream = rule.client.requestStream(PayloadImpl("test")) + val subs = TestSubscriber.create() + requestStream.blockingSubscribe(subs) + subs.assertNoValues() + subs.assertError { it is ClosedChannelException } + } + + @Test(timeout = 5_000) + fun streamErrorAfterConnectionClose() { + assertFlowableError { it.requestStream(PayloadImpl("test")) } + } + + @Test(timeout = 5_000) + fun reqStreamErrorAfterConnectionClose() { + assertFlowableError { it.requestStream(PayloadImpl("test")) } + } + + @Test(timeout = 5_000) + fun reqChannelErrorAfterConnectionClose() { + assertFlowableError { it.requestChannel(Flowable.just(PayloadImpl("test"))) } + } + + @Test(timeout = 5_000) + fun reqResponseErrorAfterConnectionClose() { + assertSingleError { it.requestResponse(PayloadImpl("test")) } + } + + @Test(timeout = 5_000) + fun fnfErrorAfterConnectionClose() { + assertCompletableError { it.fireAndForget(PayloadImpl("test")) } + } + + @Test(timeout = 5_000) + fun metadataPushAfterConnectionClose() { + assertCompletableError { it.metadataPush(PayloadImpl("test")) } + } + + private fun assertFlowableError(f: (RSocket) -> Flowable) { + rule.conn.close().subscribe() + val subs = TestSubscriber.create() + f(rule.client).delaySubscription(100, TimeUnit.MILLISECONDS).blockingSubscribe(subs) + subs.assertNoValues() + subs.assertError { it is ClosedChannelException } + } + + private fun assertCompletableError(f: (RSocket) -> Completable) { + rule.conn.close().subscribe() + val requestStream = Completable + .timer(100, TimeUnit.MILLISECONDS) + .andThen(f(rule.client)) + val err = requestStream.blockingGet() + assertThat("error is not ClosedChannelException", + err is ClosedChannelException) + } + + private fun assertSingleError(f: (RSocket) -> Single) { + rule.conn.close().subscribe() + val response = f(rule.client).delaySubscription(100, TimeUnit.MILLISECONDS) + val subs = BlockingMultiObserver() + response.subscribe(subs) + val err = subs.blockingGetError() + assertThat("error is not ClosedChannelException", err is ClosedChannelException) + } + + class ClientSocketRule : ExternalResource() { lateinit var sender: PublishProcessor lateinit var receiver: PublishProcessor