Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 90 additions & 94 deletions rsocket-core/src/main/java/io/rsocket/android/RSocketClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import io.reactivex.Completable
import io.reactivex.Flowable
import io.reactivex.Single
import io.reactivex.disposables.Disposable
import io.reactivex.functions.Action
import io.reactivex.functions.Consumer
import io.reactivex.processors.FlowableProcessor
import io.reactivex.processors.PublishProcessor
import io.reactivex.processors.UnicastProcessor
Expand Down Expand Up @@ -63,6 +61,8 @@ internal class RSocketClient @JvmOverloads constructor(
private val senders: IntObjectHashMap<LimitableRequestPublisher<*>> = IntObjectHashMap(256, 0.9f)
private val receivers: IntObjectHashMap<Subscriber<Payload>> = IntObjectHashMap(256, 0.9f)
private val missedAckCounter: AtomicInteger = AtomicInteger()
@Volatile
private var errorSignal: Throwable? = null

private val sendProcessor: FlowableProcessor<Frame> = PublishProcessor
.create<Frame>()
Expand Down Expand Up @@ -99,70 +99,79 @@ internal class RSocketClient @JvmOverloads constructor(
connection
.receive()
.doOnSubscribe { started.onComplete() }
.subscribe({ handleIncomingFrames(it) },errorConsumer)
.subscribe({ handleIncomingFrames(it) }, errorConsumer)
}

private fun handleSendProcessorError(t: Throwable) {
val (receivers, senders) = synchronized(this) {
Pair(receivers.values, senders.values)
}
for (subscriber in receivers) {
try {
subscriber.onError(t)
} catch (e: Throwable) {
errorConsumer(e)
}
}

for (p in senders) {
p.cancel()
synchronized(this) {
receivers.values.forEach { it.onError(t) }
senders.values.forEach { it.cancel() }
}
}

private fun sendKeepAlive(ackTimeoutMs: Long, missedAcks: Int): Completable {
return Completable.fromRunnable {
val now = System.currentTimeMillis()
if (now - timeLastTickSentMs > ackTimeoutMs) {
val count = missedAckCounter.incrementAndGet()
if (count >= missedAcks) {
val message = String.format(
"Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
count, missedAcks, ackTimeoutMs)
throw ConnectionException(message)
}
}

sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true))
val now = System.currentTimeMillis()
if (now - timeLastTickSentMs > ackTimeoutMs) {
val count = missedAckCounter.incrementAndGet()
if (count >= missedAcks) {
val message = String.format(
"Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
count, missedAcks, ackTimeoutMs)
throw ConnectionException(message)
}
}

sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true))
}
}

override fun fireAndForget(payload: Payload): Completable {
val defer = Completable.fromRunnable {
val streamId = streamIdSupplier.nextStreamId()
val requestFrame = Frame.Request.from(
streamId,
FrameType.FIRE_AND_FORGET,
payload,
1)
sendProcessor.onNext(requestFrame)
}
val streamId = streamIdSupplier.nextStreamId()
val requestFrame = Frame.Request.from(
streamId,
FrameType.FIRE_AND_FORGET,
payload,
1)
sendProcessor.onNext(requestFrame)
}

return completeOnStart.andThen(defer)
return errorSignal
?.let { Completable.error(it) }
?: completeOnStart.andThen(defer)
}

override fun requestResponse(payload: Payload): Single<Payload> =
handleRequestResponse(payload)
errorSignal
?.let { Single.error<Payload>(it) }
?: handleRequestResponse(payload)

override fun requestStream(payload: Payload): Flowable<Payload> =
handleRequestStream(payload).rebatchRequests(streamDemandLimit)
errorSignal
?.let { Flowable.error<Payload>(it) }
?: handleRequestStream(payload).rebatchRequests(streamDemandLimit)

override fun requestChannel(payloads: Publisher<Payload>): Flowable<Payload> =
handleChannel(
errorSignal
?.let { Flowable.error<Payload>(it) }
?: handleChannel(
Flowable.fromPublisher(payloads).rebatchRequests(streamDemandLimit),
FrameType.REQUEST_CHANNEL
).rebatchRequests(streamDemandLimit)

override fun metadataPush(payload: Payload): Completable {
override fun metadataPush(payload: Payload): Completable =
errorSignal
?.let { Completable.error(it) }
?: handleMetadataPush(payload)

override fun availability(): Double = connection.availability()

override fun close(): Completable = connection.close()

override fun onClose(): Completable = connection.onClose()

private fun handleMetadataPush(payload: Payload): Completable {
val requestFrame = Frame.Request.from(
0,
FrameType.METADATA_PUSH,
Expand All @@ -172,44 +181,38 @@ internal class RSocketClient @JvmOverloads constructor(
return Completable.complete()
}

override fun availability(): Double = connection.availability()

override fun close(): Completable = connection.close()

override fun onClose(): Completable = connection.onClose()

private fun handleRequestStream(payload: Payload): Flowable<Payload> {
return completeOnStart.andThen(
Flowable.defer {
val streamId = streamIdSupplier.nextStreamId()
val receiver = UnicastProcessor.create<Payload>()
synchronized(this) {
receivers.put(streamId, receiver)
}
val streamId = streamIdSupplier.nextStreamId()
val receiver = UnicastProcessor.create<Payload>()
synchronized(this) {
receivers.put(streamId, receiver)
}

val first = AtomicBoolean(false)
val first = AtomicBoolean(false)

receiver
.doOnRequest{ l ->
if (first.compareAndSet(false, true) && !receiver.isTerminated()) {
val requestFrame = Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l)
sendProcessor.onNext(requestFrame)
} else if (contains(streamId)) {
sendProcessor.onNext(Frame.RequestN.from(streamId, l))
}
}
.doOnError { t ->
if (contains(streamId) && !receiver.isTerminated()) {
sendProcessor.onNext(Frame.Error.from(streamId, t))
}
}
.doOnCancel {
if (contains(streamId) && !receiver.isTerminated()) {
sendProcessor.onNext(Frame.Cancel.from(streamId))
}
}
.doFinally { removeReceiver(streamId) }
})
receiver
.doOnRequest { l ->
if (first.compareAndSet(false, true) && !receiver.isTerminated()) {
val requestFrame = Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l)
sendProcessor.onNext(requestFrame)
} else if (contains(streamId)) {
sendProcessor.onNext(Frame.RequestN.from(streamId, l))
}
}
.doOnError { t ->
if (contains(streamId) && !receiver.isTerminated()) {
sendProcessor.onNext(Frame.Error.from(streamId, t))
}
}
.doOnCancel {
if (contains(streamId) && !receiver.isTerminated()) {
sendProcessor.onNext(Frame.Cancel.from(streamId))
}
}
.doFinally { removeReceiver(streamId) }
})
}

private fun handleRequestResponse(payload: Payload): Single<Payload> {
Expand All @@ -228,8 +231,8 @@ internal class RSocketClient @JvmOverloads constructor(
sendProcessor.onNext(requestFrame)

receiver
.doOnError{ t -> sendProcessor.onNext(Frame.Error.from(streamId, t)) }
.doOnCancel{ sendProcessor.onNext(Frame.Cancel.from(streamId)) }
.doOnError { t -> sendProcessor.onNext(Frame.Error.from(streamId, t)) }
.doOnCancel { sendProcessor.onNext(Frame.Cancel.from(streamId)) }
.doFinally { removeReceiver(streamId) }
.firstOrError()
}))
Expand Down Expand Up @@ -302,11 +305,11 @@ internal class RSocketClient @JvmOverloads constructor(
requestFrames
.doOnNext { sendProcessor.onNext(it) }
.subscribe(
{},
{ t ->
errorConsumer(t)
receiver.onError(CancellationException("Disposed"))
})
{},
{ t ->
errorConsumer(t)
receiver.onError(CancellationException("Disposed"))
})
} else {
sendOneFrame(Frame.RequestN.from(streamId, l))
}
Expand All @@ -330,23 +333,15 @@ internal class RSocketClient @JvmOverloads constructor(
}

private fun cleanup() {
errorSignal = CLOSED_CHANNEL_EXCEPTION

var subscribers: Collection<Subscriber<Payload>>
var publishers: Collection<LimitableRequestPublisher<*>>
val (subs, pubs) = synchronized(this) {

subscribers = receivers.values
publishers = senders.values
synchronized(this) {
receivers.values.forEach { cleanUpSubscriber(it) }
senders.values.forEach { cleanUpLimitableRequestPublisher(it) }

senders.clear()
receivers.clear()

Pair(subscribers,publishers)
senders.clear()
}

subs.forEach { cleanUpSubscriber(it) }
pubs.forEach { cleanUpLimitableRequestPublisher(it) }

keepAliveSendSub?.dispose()
}

Expand Down Expand Up @@ -485,5 +480,6 @@ internal class RSocketClient @JvmOverloads constructor(
private val CLOSED_CHANNEL_EXCEPTION = noStacktrace(ClosedChannelException())
private val DEFAULT_STREAM_WINDOW = 128
}

private fun <T> UnicastProcessor<T>.isTerminated(): Boolean = hasComplete() || hasThrowable()
}
73 changes: 73 additions & 0 deletions rsocket-core/src/test/java/io/rsocket/android/RSocketClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
package io.rsocket.android

import io.reactivex.Completable
import io.reactivex.Flowable
import io.reactivex.Single
import io.reactivex.internal.observers.BlockingMultiObserver
import io.reactivex.processors.PublishProcessor
import io.reactivex.subscribers.TestSubscriber
import io.rsocket.android.exceptions.ApplicationException
Expand All @@ -33,6 +36,7 @@ import org.junit.Test
import org.junit.rules.ExternalResource
import org.junit.runner.Description
import org.junit.runners.model.Statement
import java.nio.channels.ClosedChannelException
import java.util.concurrent.TimeUnit

class RSocketClientTest {
Expand Down Expand Up @@ -171,6 +175,75 @@ class RSocketClientTest {
assertThat("Stream ID reused.", streamId2, not(equalTo(streamId1)))
}

@Test(timeout = 3_000)
fun requestErrorOnConnectionClose() {
Completable.timer(100, TimeUnit.MILLISECONDS)
.andThen { rule.conn.close() }.subscribe()
val requestStream = rule.client.requestStream(PayloadImpl("test"))
val subs = TestSubscriber.create<Payload>()
requestStream.blockingSubscribe(subs)
subs.assertNoValues()
subs.assertError { it is ClosedChannelException }
}

@Test(timeout = 5_000)
fun streamErrorAfterConnectionClose() {
assertFlowableError { it.requestStream(PayloadImpl("test")) }
}

@Test(timeout = 5_000)
fun reqStreamErrorAfterConnectionClose() {
assertFlowableError { it.requestStream(PayloadImpl("test")) }
}

@Test(timeout = 5_000)
fun reqChannelErrorAfterConnectionClose() {
assertFlowableError { it.requestChannel(Flowable.just(PayloadImpl("test"))) }
}

@Test(timeout = 5_000)
fun reqResponseErrorAfterConnectionClose() {
assertSingleError { it.requestResponse(PayloadImpl("test")) }
}

@Test(timeout = 5_000)
fun fnfErrorAfterConnectionClose() {
assertCompletableError { it.fireAndForget(PayloadImpl("test")) }
}

@Test(timeout = 5_000)
fun metadataPushAfterConnectionClose() {
assertCompletableError { it.metadataPush(PayloadImpl("test")) }
}

private fun assertFlowableError(f: (RSocket) -> Flowable<Payload>) {
rule.conn.close().subscribe()
val subs = TestSubscriber.create<Payload>()
f(rule.client).delaySubscription(100, TimeUnit.MILLISECONDS).blockingSubscribe(subs)
subs.assertNoValues()
subs.assertError { it is ClosedChannelException }
}

private fun assertCompletableError(f: (RSocket) -> Completable) {
rule.conn.close().subscribe()
val requestStream = Completable
.timer(100, TimeUnit.MILLISECONDS)
.andThen(f(rule.client))
val err = requestStream.blockingGet()
assertThat("error is not ClosedChannelException",
err is ClosedChannelException)
}

private fun assertSingleError(f: (RSocket) -> Single<Payload>) {
rule.conn.close().subscribe()
val response = f(rule.client).delaySubscription(100, TimeUnit.MILLISECONDS)
val subs = BlockingMultiObserver<Payload>()
response.subscribe(subs)
val err = subs.blockingGetError()
assertThat("error is not ClosedChannelException", err is ClosedChannelException)
}


class ClientSocketRule : ExternalResource() {
lateinit var sender: PublishProcessor<Frame>
lateinit var receiver: PublishProcessor<Frame>
Expand Down