From 0f9750559e3f62ed89c8367e21b55c163980f3d4 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Sun, 11 Oct 2020 23:38:04 +0300 Subject: [PATCH 1/9] add leak detection use leak detection in all core tests + transport test for local connection fix some releasing issues, still WIP --- build.gradle.kts | 3 +- .../rsocket/kotlin/connection/Connection.kt | 22 +- .../kotlin/connection/LoggingConnection.kt | 11 +- .../rsocket/kotlin/core/RSocketConnector.kt | 1 - .../core/RSocketConnectorConfiguration.kt | 2 - .../io/rsocket/kotlin/core/RSocketServer.kt | 5 +- .../kotlin/core/RSocketServerConfiguration.kt | 2 - .../io/rsocket/kotlin/frame/CancelFrame.kt | 5 +- .../io/rsocket/kotlin/frame/ErrorFrame.kt | 8 +- .../io/rsocket/kotlin/frame/ExtensionFrame.kt | 11 +- .../kotlin/io/rsocket/kotlin/frame/Frame.kt | 48 ++-- .../io/rsocket/kotlin/frame/KeepAliveFrame.kt | 10 +- .../io/rsocket/kotlin/frame/LeaseFrame.kt | 11 +- .../rsocket/kotlin/frame/MetadataPushFrame.kt | 8 +- .../io/rsocket/kotlin/frame/RequestFrame.kt | 24 +- .../io/rsocket/kotlin/frame/RequestNFrame.kt | 7 +- .../io/rsocket/kotlin/frame/ResumeFrame.kt | 9 +- .../io/rsocket/kotlin/frame/ResumeOkFrame.kt | 6 +- .../io/rsocket/kotlin/frame/SetupFrame.kt | 13 +- .../io/rsocket/kotlin/frame/io/packet.kt | 21 +- .../io/rsocket/kotlin/frame/io/payload.kt | 14 +- .../kotlin/io/rsocket/kotlin/frame/io/util.kt | 30 +- .../io/rsocket/kotlin/internal/Connect.kt | 11 +- .../io/rsocket/kotlin/internal/Prioritizer.kt | 4 +- .../kotlin/internal/RSocketRequester.kt | 25 +- .../kotlin/internal/RSocketResponder.kt | 12 + .../rsocket/kotlin/internal/RSocketState.kt | 57 +++- .../internal/flow/LimitingFlowCollector.kt | 3 +- .../flow/RequestChannelRequesterFlow.kt | 1 + .../flow/RequestStreamRequesterFlow.kt | 10 +- .../io/rsocket/kotlin/payload/Payload.kt | 22 +- .../io/rsocket/kotlin/SetupRejectionTest.kt | 19 +- .../rsocket/kotlin/frame/CancelFrameTest.kt | 5 +- .../io/rsocket/kotlin/frame/ErrorFrameTest.kt | 6 +- .../kotlin/frame/ExtensionFrameTest.kt | 13 +- .../kotlin/frame/KeepAliveFrameTest.kt | 5 +- .../io/rsocket/kotlin/frame/LeaseFrameTest.kt | 7 +- .../kotlin/frame/MetadataPushFrameTest.kt | 7 +- .../rsocket/kotlin/frame/PayloadFrameTest.kt | 23 +- .../frame/RequestFireAndForgetFrameTest.kt | 12 +- .../rsocket/kotlin/frame/RequestNFrameTest.kt | 5 +- .../kotlin/frame/RequestResponseFrameTest.kt | 25 +- .../kotlin/frame/RequestStreamFrameTest.kt | 21 +- .../rsocket/kotlin/frame/ResumeFrameTest.kt | 15 +- .../rsocket/kotlin/frame/ResumeOkFrameTest.kt | 5 +- .../io/rsocket/kotlin/frame/SetupFrameTest.kt | 19 +- .../kotlin/io/rsocket/kotlin/frame/Util.kt | 14 +- .../kotlin/internal/RSocketRequesterTest.kt | 267 +++++++++--------- .../io/rsocket/kotlin/internal/RSocketTest.kt | 91 +++--- .../rsocket/kotlin/keepalive/KeepAliveTest.kt | 50 ++-- rsocket-test/build.gradle.kts | 1 + .../rsocket/kotlin/test/InUseTrackingPool.kt | 63 +++++ .../kotlin/io/rsocket/kotlin/test/Packets.kt | 47 +++ .../io/rsocket/kotlin/test/TestConnection.kt | 48 +++- .../io/rsocket/kotlin/test/TestPacketStore.kt | 24 -- .../io/rsocket/kotlin/test/TransportTest.kt | 36 ++- .../io/rsocket/kotlin/test/TestPacketStore.kt | 29 -- .../io/rsocket/kotlin/test/TestPacketStore.kt | 36 --- .../kotlin/connection/KtorTcpConnection.kt | 2 +- .../kotlin/connection/LocalConnection.kt | 27 +- .../io/rsocket/kotlin/LocalTransportTest.kt | 6 +- 61 files changed, 793 insertions(+), 551 deletions(-) rename rsocket-test/src/jsMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt => rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt (57%) create mode 100644 rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/InUseTrackingPool.kt create mode 100644 rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/Packets.kt delete mode 100644 rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt delete mode 100644 rsocket-test/src/jvmMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt delete mode 100644 rsocket-test/src/nativeMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt diff --git a/build.gradle.kts b/build.gradle.kts index 1a18d143d..79b510e3b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -196,7 +196,6 @@ subprojects { //common configuration extensions.configure { -// explicitApiWarning() //TODO change to strict before release sourceSets.all { languageSettings.apply { progressiveMode = true @@ -214,11 +213,13 @@ subprojects { useExperimentalAnnotation("kotlinx.coroutines.FlowPreview") useExperimentalAnnotation("io.ktor.util.KtorExperimentalAPI") useExperimentalAnnotation("io.ktor.util.InternalAPI") + useExperimentalAnnotation("io.ktor.utils.io.core.internal.DangerousInternalIoApi") } } } if (project.name != "rsocket-test") { + explicitApiWarning() //TODO change to strict before release sourceSets["commonTest"].dependencies { implementation(project(":rsocket-test")) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt index 7cdf6dfa2..64459a83f 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt @@ -17,20 +17,38 @@ package io.rsocket.kotlin.connection import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* +import io.rsocket.kotlin.frame.* import kotlinx.coroutines.* +/** + * That interface isn't stable for inheritance. + */ interface Connection : Cancelable { + + @DangerousInternalIoApi + val pool: ObjectPool + get() = ChunkBuffer.Pool + suspend fun send(packet: ByteReadPacket) suspend fun receive(): ByteReadPacket } suspend fun Connection.connectClient( - configuration: RSocketConnectorConfiguration = RSocketConnectorConfiguration() + configuration: RSocketConnectorConfiguration = RSocketConnectorConfiguration(), ): RSocket = RSocketConnector(ConnectionProvider(this), configuration).connect() suspend fun Connection.startServer( configuration: RSocketServerConfiguration = RSocketServerConfiguration(), - acceptor: RSocketAcceptor + acceptor: RSocketAcceptor, ): Job = RSocketServer(ConnectionProvider(this), configuration).start(acceptor) + + +@OptIn(DangerousInternalIoApi::class) +internal suspend fun Connection.receiveFrame(): Frame = receive().readFrame(pool) + +@OptIn(DangerousInternalIoApi::class) +internal suspend fun Connection.sendFrame(frame: Frame): Unit = send(frame.toPacket(pool)) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt index 914f23122..1145edf2d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt @@ -17,18 +17,23 @@ package io.rsocket.kotlin.connection import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.logging.* -import kotlinx.coroutines.* internal fun Connection.logging(logger: Logger): Connection = if (logger.isLoggable(LoggingLevel.DEBUG)) LoggingConnection(this, logger) else this +@OptIn(DangerousInternalIoApi::class) private class LoggingConnection( private val delegate: Connection, private val logger: Logger, -) : Connection { - override val job: Job get() = delegate.job +) : Connection by delegate { + + private fun ByteReadPacket.dumpFrameToString(): String { + val length = remaining + return copy().use { it.readFrame(pool).use { it.dump(length) } } + } override suspend fun send(packet: ByteReadPacket) { logger.debug { "Send: ${packet.dumpFrameToString()}" } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt index 4bfe7a2fe..e54150573 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt @@ -46,7 +46,6 @@ class RSocketConnector( connection = connection, plugin = configuration.plugin, setupFrame = setupFrame, - ignoredFrameConsumer = configuration.ignoredFrameConsumer, acceptor = configuration.acceptor ) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorConfiguration.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorConfiguration.kt index 57b53dbe2..4e8214b69 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorConfiguration.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorConfiguration.kt @@ -17,7 +17,6 @@ package io.rsocket.kotlin.core import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* @@ -29,6 +28,5 @@ data class RSocketConnectorConfiguration( val keepAlive: KeepAlive = KeepAlive(), val payloadMimeType: PayloadMimeType = PayloadMimeType(), val setupPayload: Payload = Payload.Empty, - val ignoredFrameConsumer: (Frame) -> Unit = {}, val acceptor: RSocketAcceptor = { RSocketRequestHandler { } }, ) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt index 7127fa53a..c80186d7a 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt @@ -35,7 +35,7 @@ class RSocketServer( .let(configuration.plugin::wrapConnection) .logging(configuration.loggerFactory.logger("io.rsocket.kotlin.frame.Frame")) - val setupFrame = connection.receive().toFrame() + val setupFrame = connection.receiveFrame() if (setupFrame !is SetupFrame) connection.failSetup(RSocketError.Setup.Invalid("Invalid setup frame: ${setupFrame.type}")) if (setupFrame.version != Version.Current) @@ -46,7 +46,6 @@ class RSocketServer( connection = connection, plugin = configuration.plugin, setupFrame = setupFrame, - ignoredFrameConsumer = configuration.ignoredFrameConsumer, acceptor = acceptor ) } catch (e: Throwable) { @@ -55,7 +54,7 @@ class RSocketServer( } private suspend fun Connection.failSetup(error: RSocketError.Setup): Nothing { - send(ErrorFrame(0, error).toPacket()) + sendFrame(ErrorFrame(0, error)) cancel("Setup failed", error) throw error } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServerConfiguration.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServerConfiguration.kt index 7e775c643..525ebd23d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServerConfiguration.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServerConfiguration.kt @@ -16,12 +16,10 @@ package io.rsocket.kotlin.core -import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.plugin.* data class RSocketServerConfiguration( val plugin: Plugin = Plugin(), val loggerFactory: LoggerFactory = DefaultLoggerFactory, - val ignoredFrameConsumer: (Frame) -> Unit = {}, ) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/CancelFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/CancelFrame.kt index 245fa0b55..55b009742 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/CancelFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/CancelFrame.kt @@ -18,10 +18,13 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* -class CancelFrame( +internal class CancelFrame( override val streamId: Int, ) : Frame(FrameType.Cancel) { override val flags: Int get() = 0 + + override fun release(): Unit = Unit + override fun BytePacketBuilder.writeSelf(): Unit = Unit override fun StringBuilder.appendFlags(): Unit = Unit diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ErrorFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ErrorFrame.kt index ae7c93100..14ec5476f 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ErrorFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ErrorFrame.kt @@ -20,7 +20,7 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.error.* import io.rsocket.kotlin.frame.io.* -class ErrorFrame( +internal class ErrorFrame( override val streamId: Int, val throwable: Throwable, val data: ByteReadPacket? = null, @@ -28,6 +28,10 @@ class ErrorFrame( override val flags: Int get() = 0 val errorCode get() = (throwable as? RSocketError)?.errorCode ?: ErrorCode.ApplicationError + override fun release() { + data?.release() + } + override fun BytePacketBuilder.writeSelf() { writeInt(errorCode) when (data) { @@ -44,7 +48,7 @@ class ErrorFrame( } } -fun ByteReadPacket.readError(streamId: Int): ErrorFrame { +internal fun ByteReadPacket.readError(streamId: Int): ErrorFrame { val errorCode = readInt() val data = copy() val message = readText() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ExtensionFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ExtensionFrame.kt index 816288ac1..526fb8527 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ExtensionFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ExtensionFrame.kt @@ -20,12 +20,17 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.payload.* -class ExtensionFrame( +internal class ExtensionFrame( override val streamId: Int, val extendedType: Int, val payload: Payload, ) : Frame(FrameType.Extension) { override val flags: Int get() = if (payload.metadata != null) Flags.Metadata else 0 + + override fun release() { + payload.release() + } + override fun BytePacketBuilder.writeSelf() { writeInt(extendedType) writePayload(payload) @@ -41,8 +46,8 @@ class ExtensionFrame( } } -fun ByteReadPacket.readExtension(streamId: Int, flags: Int): ExtensionFrame { +internal fun ByteReadPacket.readExtension(pool: BufferPool, streamId: Int, flags: Int): ExtensionFrame { val extendedType = readInt() - val payload = readPayload(flags) + val payload = readPayload(pool, flags) return ExtensionFrame(streamId, extendedType, payload) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt index 51c02ee55..7ad311df8 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt @@ -22,17 +22,19 @@ import io.rsocket.kotlin.frame.io.* private const val FlagsMask: Int = 1023 private const val FrameTypeShift: Int = 10 -abstract class Frame(open val type: FrameType) { +abstract class Frame(open val type: FrameType) : Closeable { abstract val streamId: Int abstract val flags: Int + abstract fun release() + protected abstract fun BytePacketBuilder.writeSelf() protected abstract fun StringBuilder.appendFlags() protected abstract fun StringBuilder.appendSelf() - fun toPacket(): ByteReadPacket { + fun toPacket(pool: BufferPool): ByteReadPacket { check(type.canHaveMetadata || !(flags check Flags.Metadata)) { "bad value for metadata flag" } - return buildPacket { + return buildPacket(pool) { writeInt(streamId) writeShort((type.encodedType shl FrameTypeShift or flags).toShort()) writeSelf() @@ -49,36 +51,36 @@ abstract class Frame(open val type: FrameType) { append(flag) if (value) append(1) else append(0) } + + override fun close() { + release() + } } -fun ByteReadPacket.toFrame(): Frame = use { +fun ByteReadPacket.readFrame(pool: BufferPool): Frame = use { val streamId = readInt() val typeAndFlags = readShort().toInt() and 0xFFFF val flags = typeAndFlags and FlagsMask when (val type = FrameType(typeAndFlags shr FrameTypeShift)) { //stream id = 0 - FrameType.Setup -> readSetup(flags) - FrameType.Resume -> readResume() - FrameType.ResumeOk -> readResumeOk() - FrameType.MetadataPush -> readMetadataPush() - FrameType.Lease -> readLease(flags) - FrameType.KeepAlive -> readKeepAlive(flags) + FrameType.Setup -> readSetup(pool, flags) + FrameType.Resume -> readResume(pool) + FrameType.ResumeOk -> readResumeOk() + FrameType.MetadataPush -> readMetadataPush(pool) + FrameType.Lease -> readLease(pool, flags) + FrameType.KeepAlive -> readKeepAlive(pool, flags) //stream id != 0 - FrameType.Cancel -> CancelFrame(streamId) - FrameType.Error -> readError(streamId) - FrameType.RequestN -> readRequestN(streamId) - FrameType.Extension -> readExtension(streamId, flags) + FrameType.Cancel -> CancelFrame(streamId) + FrameType.Error -> readError(streamId) + FrameType.RequestN -> readRequestN(streamId) + FrameType.Extension -> readExtension(pool, streamId, flags) FrameType.Payload, FrameType.RequestFnF, - FrameType.RequestResponse -> readRequest(type, streamId, flags, withInitial = false) + FrameType.RequestResponse, + -> readRequest(pool, type, streamId, flags, withInitial = false) FrameType.RequestStream, - FrameType.RequestChannel -> readRequest(type, streamId, flags, withInitial = true) - FrameType.Reserved -> error("Reserved") + FrameType.RequestChannel, + -> readRequest(pool, type, streamId, flags, withInitial = true) + FrameType.Reserved -> error("Reserved") } } - -fun ByteReadPacket.dumpFrameToString(): String { - val length = remaining - val frame = copy().toFrame() - return frame.dump(length) -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/KeepAliveFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/KeepAliveFrame.kt index 365bee907..dbc9c0e7e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/KeepAliveFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/KeepAliveFrame.kt @@ -21,7 +21,7 @@ import io.rsocket.kotlin.frame.io.* private const val RespondFlag = 128 -class KeepAliveFrame( +internal class KeepAliveFrame( val respond: Boolean, val lastPosition: Long, val data: ByteReadPacket, @@ -29,6 +29,10 @@ class KeepAliveFrame( override val streamId: Int get() = 0 override val flags: Int get() = if (respond) RespondFlag else 0 + override fun release() { + data.release() + } + override fun BytePacketBuilder.writeSelf() { writeLong(lastPosition.coerceAtLeast(0)) writePacket(data) @@ -44,9 +48,9 @@ class KeepAliveFrame( } } -fun ByteReadPacket.readKeepAlive(flags: Int): KeepAliveFrame { +internal fun ByteReadPacket.readKeepAlive(pool: BufferPool, flags: Int): KeepAliveFrame { val respond = flags check RespondFlag val lastPosition = readLong() - val data = readPacket() + val data = readPacket(pool) return KeepAliveFrame(respond, lastPosition, data) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/LeaseFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/LeaseFrame.kt index a44e8ad0a..9f991dbaf 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/LeaseFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/LeaseFrame.kt @@ -19,13 +19,18 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* -class LeaseFrame( +internal class LeaseFrame( val ttl: Int, val numberOfRequests: Int, val metadata: ByteReadPacket?, ) : Frame(FrameType.Lease) { override val streamId: Int get() = 0 override val flags: Int get() = if (metadata != null) Flags.Metadata else 0 + + override fun release() { + metadata?.release() + } + override fun BytePacketBuilder.writeSelf() { writeInt(ttl) writeInt(numberOfRequests) @@ -42,9 +47,9 @@ class LeaseFrame( } } -fun ByteReadPacket.readLease(flags: Int): LeaseFrame { +internal fun ByteReadPacket.readLease(pool: BufferPool, flags: Int): LeaseFrame { val ttl = readInt() val numberOfRequests = readInt() - val metadata = if (flags check Flags.Metadata) readMetadata() else null + val metadata = if (flags check Flags.Metadata) readMetadata(pool) else null return LeaseFrame(ttl, numberOfRequests, metadata) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/MetadataPushFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/MetadataPushFrame.kt index 41c0089d5..43bb6c5d0 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/MetadataPushFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/MetadataPushFrame.kt @@ -19,12 +19,16 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* -class MetadataPushFrame( +internal class MetadataPushFrame( val metadata: ByteReadPacket, ) : Frame(FrameType.MetadataPush) { override val streamId: Int get() = 0 override val flags: Int get() = Flags.Metadata + override fun release() { + metadata.release() + } + override fun BytePacketBuilder.writeSelf() { writePacket(metadata) } @@ -38,4 +42,4 @@ class MetadataPushFrame( } } -fun ByteReadPacket.readMetadataPush(): MetadataPushFrame = MetadataPushFrame(readPacket()) +internal fun ByteReadPacket.readMetadataPush(pool: BufferPool): MetadataPushFrame = MetadataPushFrame(readPacket(pool)) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt index b059aaf24..7dda0db9c 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt @@ -22,7 +22,7 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.payload.* -class RequestFrame( +internal data class RequestFrame( override val type: FrameType, override val streamId: Int, val follows: Boolean, @@ -41,6 +41,10 @@ class RequestFrame( return flags } + override fun release() { + payload.release() + } + override fun BytePacketBuilder.writeSelf() { if (initialRequest > 0) writeInt(initialRequest) writePayload(payload) @@ -59,34 +63,34 @@ class RequestFrame( } } -fun ByteReadPacket.readRequest(type: FrameType, streamId: Int, flags: Int, withInitial: Boolean): RequestFrame { +internal fun ByteReadPacket.readRequest(pool: BufferPool, type: FrameType, streamId: Int, flags: Int, withInitial: Boolean): RequestFrame { val fragmentFollows = flags check Flags.Follows val complete = flags check Flags.Complete val next = flags check Flags.Next val initialRequest = if (withInitial) readInt() else 0 - val payload = readPayload(flags) + val payload = readPayload(pool, flags) return RequestFrame(type, streamId, fragmentFollows, complete, next, initialRequest, payload) } //TODO rename or remove on fragmentation implementation -fun RequestFireAndForgetFrame(streamId: Int, payload: Payload): RequestFrame = +internal fun RequestFireAndForgetFrame(streamId: Int, payload: Payload): RequestFrame = RequestFrame(FrameType.RequestFnF, streamId, false, false, false, 0, payload) -fun RequestResponseFrame(streamId: Int, payload: Payload): RequestFrame = +internal fun RequestResponseFrame(streamId: Int, payload: Payload): RequestFrame = RequestFrame(FrameType.RequestResponse, streamId, false, false, false, 0, payload) -fun RequestStreamFrame(streamId: Int, initialRequestN: Int, payload: Payload): RequestFrame = +internal fun RequestStreamFrame(streamId: Int, initialRequestN: Int, payload: Payload): RequestFrame = RequestFrame(FrameType.RequestStream, streamId, false, false, false, initialRequestN, payload) -fun RequestChannelFrame(streamId: Int, initialRequestN: Int, payload: Payload): RequestFrame = +internal fun RequestChannelFrame(streamId: Int, initialRequestN: Int, payload: Payload): RequestFrame = RequestFrame(FrameType.RequestChannel, streamId, false, false, false, initialRequestN, payload) -fun NextPayloadFrame(streamId: Int, payload: Payload): RequestFrame = +internal fun NextPayloadFrame(streamId: Int, payload: Payload): RequestFrame = RequestFrame(FrameType.Payload, streamId, false, false, true, 0, payload) -fun CompletePayloadFrame(streamId: Int): RequestFrame = +internal fun CompletePayloadFrame(streamId: Int): RequestFrame = RequestFrame(FrameType.Payload, streamId, false, true, false, 0, Payload.Empty) -fun NextCompletePayloadFrame(streamId: Int, payload: Payload): RequestFrame = +internal fun NextCompletePayloadFrame(streamId: Int, payload: Payload): RequestFrame = RequestFrame(FrameType.Payload, streamId, false, true, true, 0, payload) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestNFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestNFrame.kt index 8ff35a636..a3405b691 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestNFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestNFrame.kt @@ -18,11 +18,14 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* -class RequestNFrame( +internal class RequestNFrame( override val streamId: Int, val requestN: Int, ) : Frame(FrameType.RequestN) { override val flags: Int get() = 0 + + override fun release(): Unit = Unit + override fun BytePacketBuilder.writeSelf() { writeInt(requestN) } @@ -34,7 +37,7 @@ class RequestNFrame( } } -fun ByteReadPacket.readRequestN(streamId: Int): RequestNFrame { +internal fun ByteReadPacket.readRequestN(streamId: Int): RequestNFrame { val requestN = readInt() return RequestNFrame(streamId, requestN) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeFrame.kt index 07a81a422..e9b0b21f6 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeFrame.kt @@ -19,7 +19,7 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* -class ResumeFrame( +internal class ResumeFrame( val version: Version, val resumeToken: ByteReadPacket, val lastReceivedServerPosition: Long, @@ -27,6 +27,9 @@ class ResumeFrame( ) : Frame(FrameType.Resume) { override val streamId: Int get() = 0 override val flags: Int get() = 0 + + override fun release(): Unit = Unit + override fun BytePacketBuilder.writeSelf() { writeVersion(version) writeResumeToken(resumeToken) @@ -44,9 +47,9 @@ class ResumeFrame( } } -fun ByteReadPacket.readResume(): ResumeFrame { +internal fun ByteReadPacket.readResume(pool: BufferPool): ResumeFrame { val version = readVersion() - val resumeToken = readResumeToken() + val resumeToken = readResumeToken(pool) val lastReceivedServerPosition = readLong() val firstAvailableClientPosition = readLong() return ResumeFrame( diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeOkFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeOkFrame.kt index 65ab39ec3..27a63f520 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeOkFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/ResumeOkFrame.kt @@ -18,12 +18,14 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* -class ResumeOkFrame( +internal class ResumeOkFrame( val lastReceivedClientPosition: Long, ) : Frame(FrameType.ResumeOk) { override val streamId: Int get() = 0 override val flags: Int get() = 0 + override fun release(): Unit = Unit + override fun BytePacketBuilder.writeSelf() { writeLong(lastReceivedClientPosition) } @@ -35,4 +37,4 @@ class ResumeOkFrame( } } -fun ByteReadPacket.readResumeOk(): ResumeOkFrame = ResumeOkFrame(readLong()) +internal fun ByteReadPacket.readResumeOk(): ResumeOkFrame = ResumeOkFrame(readLong()) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/SetupFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/SetupFrame.kt index f4d810d0d..54f551c6e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/SetupFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/SetupFrame.kt @@ -24,7 +24,7 @@ import io.rsocket.kotlin.payload.* private const val HonorLeaseFlag = 64 private const val ResumeEnabledFlag = 128 -class SetupFrame( +internal class SetupFrame( val version: Version, //TODO check val honorLease: Boolean, val keepAlive: KeepAlive, @@ -42,6 +42,11 @@ class SetupFrame( return flags } + override fun release() { + resumeToken?.release() + payload.release() + } + override fun BytePacketBuilder.writeSelf() { writeVersion(version) writeKeepAlive(keepAlive) @@ -65,12 +70,12 @@ class SetupFrame( } } -fun ByteReadPacket.readSetup(flags: Int): SetupFrame { +internal fun ByteReadPacket.readSetup(pool: BufferPool, flags: Int): SetupFrame { val version = readVersion() val keepAlive = readKeepAlive() - val resumeToken = if (flags check ResumeEnabledFlag) readResumeToken() else null + val resumeToken = if (flags check ResumeEnabledFlag) readResumeToken(pool) else null val payloadMimeType = readPayloadMimeType() - val payload = readPayload(flags) + val payload = readPayload(pool, flags) return SetupFrame( version = version, honorLease = flags check HonorLeaseFlag, diff --git a/rsocket-test/src/jsMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt similarity index 57% rename from rsocket-test/src/jsMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt index 33f6e2f3a..74001c6c0 100644 --- a/rsocket-test/src/jsMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt @@ -14,16 +14,23 @@ * limitations under the License. */ -package io.rsocket.kotlin.test +package io.rsocket.kotlin.frame.io import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* -actual class TestPacketStore actual constructor() { - private val _stored = mutableListOf() +@OptIn(DangerousInternalIoApi::class) +internal typealias BufferPool = ObjectPool - actual val stored: List get() = _stored - - actual fun store(packet: ByteReadPacket) { - _stored += packet +//TODO +internal inline fun buildPacket(pool: BufferPool, block: BytePacketBuilder.() -> Unit): ByteReadPacket { + val builder = BytePacketBuilder(0, pool) + try { + block(builder) + return builder.build() + } catch (t: Throwable) { + builder.release() + throw t } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/payload.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/payload.kt index 892273201..85d2fc6d4 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/payload.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/payload.kt @@ -19,25 +19,25 @@ package io.rsocket.kotlin.frame.io import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* -fun ByteReadPacket.readMetadata(): ByteReadPacket { +internal fun ByteReadPacket.readMetadata(pool: BufferPool): ByteReadPacket { val length = readLength() - return readPacket(length) + return readPacket(pool, length) } -fun BytePacketBuilder.writeMetadata(metadata: ByteReadPacket?) { +internal fun BytePacketBuilder.writeMetadata(metadata: ByteReadPacket?) { metadata?.let { writeLength(it.remaining.toInt()) writePacket(it) } } -fun ByteReadPacket.readPayload(flags: Int): Payload { - val metadata = if (flags check Flags.Metadata) readMetadata() else null - val data = readPacket() +internal fun ByteReadPacket.readPayload(pool: BufferPool, flags: Int): Payload { + val metadata = if (flags check Flags.Metadata) readMetadata(pool) else null + val data = readPacket(pool) return Payload(data = data, metadata = metadata) } -fun BytePacketBuilder.writePayload(payload: Payload) { +internal fun BytePacketBuilder.writePayload(payload: Payload) { writeMetadata(payload.metadata) writePacket(payload.data) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/util.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/util.kt index 80cc7f46b..8b806cb7b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/util.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/util.kt @@ -21,12 +21,12 @@ import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.payload.* import kotlin.time.* -fun ByteReadPacket.readResumeToken(): ByteReadPacket { +internal fun ByteReadPacket.readResumeToken(pool: BufferPool): ByteReadPacket { val length = readShort().toInt() and 0xFFFF - return readPacket(length) + return readPacket(pool, length) } -fun BytePacketBuilder.writeResumeToken(resumeToken: ByteReadPacket?) { +internal fun BytePacketBuilder.writeResumeToken(resumeToken: ByteReadPacket?) { resumeToken?.let { val length = it.remaining writeShort(length.toShort()) @@ -34,57 +34,57 @@ fun BytePacketBuilder.writeResumeToken(resumeToken: ByteReadPacket?) { } } -fun ByteReadPacket.readMimeType(): String { +internal fun ByteReadPacket.readMimeType(): String { val length = readByte().toInt() return readText(max = length) } -fun BytePacketBuilder.writeMimeType(mimeType: String) { +internal fun BytePacketBuilder.writeMimeType(mimeType: String) { val bytes = mimeType.encodeToByteArray() //TODO check writeByte(bytes.size.toByte()) writeFully(bytes) } -fun ByteReadPacket.readPayloadMimeType(): PayloadMimeType { +internal fun ByteReadPacket.readPayloadMimeType(): PayloadMimeType { val metadata = readMimeType() val data = readMimeType() return PayloadMimeType(data = data, metadata = metadata) } -fun BytePacketBuilder.writePayloadMimeType(payloadMimeType: PayloadMimeType) { +internal fun BytePacketBuilder.writePayloadMimeType(payloadMimeType: PayloadMimeType) { writeMimeType(payloadMimeType.metadata) writeMimeType(payloadMimeType.data) } @OptIn(ExperimentalTime::class) -fun ByteReadPacket.readMillis(): Duration = readInt().milliseconds +internal fun ByteReadPacket.readMillis(): Duration = readInt().milliseconds @OptIn(ExperimentalTime::class) -fun BytePacketBuilder.writeMillis(duration: Duration) { +internal fun BytePacketBuilder.writeMillis(duration: Duration) { writeInt(duration.toInt(DurationUnit.MILLISECONDS)) } -fun ByteReadPacket.readKeepAlive(): KeepAlive { +internal fun ByteReadPacket.readKeepAlive(): KeepAlive { val interval = readMillis() val maxLifetime = readMillis() return KeepAlive(interval = interval, maxLifetime = maxLifetime) } -fun BytePacketBuilder.writeKeepAlive(keepAlive: KeepAlive) { +internal fun BytePacketBuilder.writeKeepAlive(keepAlive: KeepAlive) { writeMillis(keepAlive.interval) writeMillis(keepAlive.maxLifetime) } -fun ByteReadPacket.readPacket(): ByteReadPacket { +internal fun ByteReadPacket.readPacket(pool: BufferPool): ByteReadPacket { if (isEmpty) return ByteReadPacket.Empty - return buildPacket { + return buildPacket(pool) { writePacket(this@readPacket) } } -fun ByteReadPacket.readPacket(length: Int): ByteReadPacket { +internal fun ByteReadPacket.readPacket(pool: BufferPool, length: Int): ByteReadPacket { if (length == 0) return ByteReadPacket.Empty - return buildPacket { + return buildPacket(pool) { writePacket(this@readPacket, length) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt index 172bebeeb..b6ac86f9a 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt @@ -27,7 +27,6 @@ private suspend inline fun connect( connection: Connection, plugin: Plugin, setupFrame: SetupFrame, - noinline ignoredFrameConsumer: (Frame) -> Unit, noinline acceptor: RSocketAcceptor, crossinline beforeStart: suspend () -> Unit, ): Pair { @@ -37,7 +36,7 @@ private suspend inline fun connect( payloadMimeType = setupFrame.payloadMimeType, payload = setupFrame.payload ) - val state = RSocketState(connection, connectionSetup.keepAlive, ignoredFrameConsumer) + val state = RSocketState(connection, connectionSetup.keepAlive) val requester = RSocketRequester(state, StreamId(isServer)).let(plugin::wrapRequester) val wrappedAcceptor = acceptor.let(plugin::wrapAcceptor) val requestHandler = wrappedAcceptor(connectionSetup, requester).let(plugin::wrapResponder) @@ -49,17 +48,15 @@ internal suspend fun connectClient( connection: Connection, plugin: Plugin, setupFrame: SetupFrame, - ignoredFrameConsumer: (Frame) -> Unit, acceptor: RSocketAcceptor, -): RSocket = connect(isServer = false, connection, plugin, setupFrame, ignoredFrameConsumer, acceptor) { - connection.send(setupFrame.toPacket()) +): RSocket = connect(isServer = false, connection, plugin, setupFrame, acceptor) { + connection.sendFrame(setupFrame) }.first internal suspend fun connectServer( connection: Connection, plugin: Plugin, setupFrame: SetupFrame, - ignoredFrameConsumer: (Frame) -> Unit, acceptor: RSocketAcceptor, -): Job = connect(isServer = true, connection, plugin, setupFrame, ignoredFrameConsumer, acceptor) { +): Job = connect(isServer = true, connection, plugin, setupFrame, acceptor) { }.second diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt index bff1b5a5f..6d5708c1a 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt @@ -20,7 +20,7 @@ import io.rsocket.kotlin.frame.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* -class Prioritizer { +internal class Prioritizer { private val priorityChannel = Channel(Channel.UNLIMITED) private val commonChannel = Channel(Channel.UNLIMITED) @@ -42,6 +42,8 @@ class Prioritizer { } fun close(throwable: Throwable?) { + priorityChannel.closeReceivedElements() + commonChannel.closeReceivedElements() priorityChannel.close(throwable) commonChannel.close(throwable) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt index 754a9ce5c..d8b9cbacd 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt @@ -29,22 +29,24 @@ internal class RSocketRequester( private val streamId: StreamId, ) : RSocket, Cancelable by state { - override suspend fun metadataPush(metadata: ByteReadPacket) { + override suspend fun metadataPush(metadata: ByteReadPacket): Unit = metadata.closeOnError { checkAvailable() state.sendPrioritized(MetadataPushFrame(metadata)) } - override suspend fun fireAndForget(payload: Payload) { + override suspend fun fireAndForget(payload: Payload): Unit = payload.closeOnError { val streamId = createStream() state.send(RequestFireAndForgetFrame(streamId, payload)) } override suspend fun requestResponse(payload: Payload): Payload = with(state) { - val streamId = createStream() - val receiver = createReceiverFor(streamId) - send(RequestResponseFrame(streamId, payload)) - return consumeReceiverFor(streamId) { - receiver.receive().payload //TODO fragmentation + payload.closeOnError { + val streamId = createStream() + val receiver = createReceiverFor(streamId) + send(RequestResponseFrame(streamId, payload)) + consumeReceiverFor(streamId) { + receiver.receive().payload //TODO fragmentation + } } } @@ -67,3 +69,12 @@ internal class RSocketRequester( } } + +internal inline fun Closeable.closeOnError(block: () -> T): T { + try { + return block() + } catch (e: Throwable) { + close() + throw e + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt index 2915b8391..78f2ea127 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt @@ -29,12 +29,16 @@ internal class RSocketResponder( fun handleMetadataPush(frame: MetadataPushFrame) { state.launch { requestHandler.metadataPush(frame.metadata) + }.invokeOnCompletion { + frame.release() } } fun handleFireAndForget(frame: RequestFrame) { state.launch { requestHandler.fireAndForget(frame.payload) + }.invokeOnCompletion { + frame.release() } } @@ -45,6 +49,8 @@ internal class RSocketResponder( requestHandler.requestResponse(frame.payload) } ?: return@launchCancelable if (isActive) send(NextCompletePayloadFrame(streamId, response)) + }.invokeOnCompletion { + frame.release() } } @@ -58,6 +64,9 @@ internal class RSocketResponder( streamId, RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest) ) + send(CompletePayloadFrame(streamId)) + }.invokeOnCompletion { + initFrame.release() } } @@ -75,7 +84,10 @@ internal class RSocketResponder( streamId, RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest) ) + send(CompletePayloadFrame(streamId)) }.invokeOnCompletion { + initFrame.release() + receiver.closeReceivedElements() if (it != null) receiver.cancelConsumed(it) //TODO check it } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt index 83ff30401..86617de80 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt @@ -16,6 +16,7 @@ package io.rsocket.kotlin.internal +import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.* @@ -33,13 +34,12 @@ import kotlinx.coroutines.flow.* internal class RSocketState( private val connection: Connection, keepAlive: KeepAlive, - val ignoredFrameConsumer: (Frame) -> Unit, ) : Cancelable by connection { private val prioritizer = Prioritizer() private val requestScope = CoroutineScope(SupervisorJob(job)) private val scope = CoroutineScope(job) - val receivers: IntMap> = IntMap() + val receivers: IntMap> = IntMap() private val senders: IntMap = IntMap() private val limits: IntMap = IntMap() @@ -70,7 +70,10 @@ internal class RSocketState( } finally { if (isActive && streamId in receivers) { if (cause != null) send(CancelFrame(streamId)) - receivers.remove(streamId)?.close(cause) + receivers.remove(streamId)?.apply { + closeReceivedElements() + close(cause) + } } } } @@ -82,7 +85,6 @@ internal class RSocketState( limits[streamId] = limitingCollector try { collect(limitingCollector) - send(CompletePayloadFrame(streamId)) } catch (e: Throwable) { limits.remove(streamId) //if isn't active, then, that stream was cancelled, and so no need for error frame @@ -103,17 +105,32 @@ internal class RSocketState( private fun handleFrame(responder: RSocketResponder, frame: Frame) { when (val streamId = frame.streamId) { 0 -> when (frame) { - is ErrorFrame -> cancel("Zero stream error", frame.throwable) + is ErrorFrame -> { + cancel("Zero stream error", frame.throwable) + frame.release() //TODO + } is KeepAliveFrame -> keepAliveHandler.receive(frame) - is LeaseFrame -> error("lease isn't implemented") + is LeaseFrame -> { + frame.release() + error("lease isn't implemented") + } is MetadataPushFrame -> responder.handleMetadataPush(frame) - else -> ignoredFrameConsumer(frame) + else -> { + //TODO log + frame.release() + } } else -> when (frame) { is RequestNFrame -> limits[streamId]?.updateRequests(frame.requestN) is CancelFrame -> senders.remove(streamId)?.cancel() - is ErrorFrame -> receivers.remove(streamId)?.close(frame.throwable) + is ErrorFrame -> { + receivers.remove(streamId)?.apply { + closeReceivedElements() + close(frame.throwable) + } + frame.release() + } is RequestFrame -> when (frame.type) { FrameType.Payload -> receivers[streamId]?.offer(frame) FrameType.RequestFnF -> responder.handleFireAndForget(frame) @@ -122,7 +139,10 @@ internal class RSocketState( FrameType.RequestChannel -> responder.handleRequestChannel(frame) else -> error("never happens") } - else -> ignoredFrameConsumer(frame) + else -> { + //TODO log + frame.release() + } } } } @@ -133,17 +153,23 @@ internal class RSocketState( requestHandler.job.invokeOnCompletion { cancel("Request handled stopped", it) } job.invokeOnCompletion { error -> requestHandler.cancel("Connection closed", error) - receivers.values().forEach { it.close((error as? CancellationException)?.cause ?: error) } + receivers.values().forEach { + it.closeReceivedElements() + it.close((error as? CancellationException)?.cause ?: error) + } receivers.clear() limits.clear() senders.clear() prioritizer.close(error) } scope.launch { - while (connection.isActive) connection.send(prioritizer.receive().toPacket()) + while (connection.isActive) connection.sendFrame(prioritizer.receive()) } scope.launch { - while (connection.isActive) handleFrame(responder, connection.receive().toFrame()) + while (connection.isActive) { + val frame = connection.receiveFrame() + frame.closeOnError { handleFrame(responder, frame) } + } } return job } @@ -152,3 +178,10 @@ internal class RSocketState( internal fun ReceiveChannel<*>.cancelConsumed(cause: Throwable?) { cancel(cause?.let { it as? CancellationException ?: CancellationException("Channel was consumed, consumer had failed", it) }) } + +internal fun ReceiveChannel.closeReceivedElements() { + try { + while (true) poll()?.close() ?: break + } catch (e: Throwable) { + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt index 2315eb250..7347c6228 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt @@ -16,6 +16,7 @@ package io.rsocket.kotlin.internal.flow +import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* import kotlinx.coroutines.* @@ -33,7 +34,7 @@ internal abstract class LimitingFlowCollector(initial: Int) : FlowCollector): Unit = with(state) { - val streamId = requester.createStream() - val receiver = createReceiverFor(streamId) - send(RequestStreamFrame(streamId, requestSize, payload)) - collectStream(streamId, receiver, collectContext, collector) + payload.closeOnError { + val streamId = requester.createStream() + val receiver = createReceiverFor(streamId) + send(RequestStreamFrame(streamId, requestSize, payload)) + collectStream(streamId, receiver, collectContext, collector) + } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt index 318282d0a..f912b4091 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt @@ -21,20 +21,24 @@ import io.ktor.utils.io.core.* class Payload( val data: ByteReadPacket, val metadata: ByteReadPacket? = null, -) { - companion object { - val Empty = Payload(ByteReadPacket.Empty) +) : Closeable { + + fun copy(): Payload = Payload(data.copy(), metadata?.copy()) + + fun release() { + data.release() + metadata?.release() } -} -fun Payload.copy(): Payload = Payload(data.copy(), metadata?.copy()) + override fun close() { + release() + } -fun Payload.release() { - data.release() - metadata?.release() + companion object { + val Empty = Payload(ByteReadPacket.Empty) + } } -@Suppress("FunctionName") fun Payload(data: String, metadata: String? = null): Payload = Payload( data = buildPacket { writeText(data) }, metadata = metadata?.let { buildPacket { writeText(it) } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt index c9b97bf6c..e1455bc0d 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt @@ -27,7 +27,7 @@ import io.rsocket.kotlin.test.* import kotlinx.coroutines.* import kotlin.test.* -class SetupRejectionTest : SuspendTest { +class SetupRejectionTest : SuspendTest, TestWithLeakCheck { @Test fun responderRejectSetup() = test { val errorMessage = "error" @@ -44,13 +44,16 @@ class SetupRejectionTest : SuspendTest { assertFailsWith(RSocketError.Setup.Rejected::class, errorMessage) { server.start(acceptor) } - val frame = connection.receiveFromSender() - assertTrue(frame is ErrorFrame) - assertTrue(frame.throwable is RSocketError.Setup.Rejected) - assertEquals(errorMessage, frame.throwable.message) - - val sender = sendingRSocket.await() - assertFalse(sender.isActive) + connection.test { + expectFrame { frame -> + assertTrue(frame is ErrorFrame) + assertTrue(frame.throwable is RSocketError.Setup.Rejected) + assertEquals(errorMessage, frame.throwable.message) + assertEquals(errorMessage, frame.data?.readText()) + } + val sender = sendingRSocket.await() + assertFalse(sender.isActive) + } } // @Test diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/CancelFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/CancelFrameTest.kt index a77669a90..368dbc1e5 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/CancelFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/CancelFrameTest.kt @@ -16,16 +16,17 @@ package io.rsocket.kotlin.frame +import io.rsocket.kotlin.test.* import kotlin.test.* -class CancelFrameTest { +class CancelFrameTest : TestWithLeakCheck { private val streamId = 1 @Test fun testEncoding() { val frame = CancelFrame(streamId) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is CancelFrame) assertEquals(streamId, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ErrorFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ErrorFrameTest.kt index 947f12891..3a2fe362e 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ErrorFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ErrorFrameTest.kt @@ -19,9 +19,10 @@ package io.rsocket.kotlin.frame import io.ktor.util.* import io.ktor.utils.io.core.* import io.rsocket.kotlin.error.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class ErrorFrameTest { +class ErrorFrameTest : TestWithLeakCheck { private val dump = "00000b000000012c000000020164" @@ -35,7 +36,7 @@ class ErrorFrameTest { @Test fun testDecoding() { - val packet = ByteReadPacket(hex(dump)) + val packet = packet(hex(dump)) val frame = packet.toFrameWithLength() assertTrue(frame is ErrorFrame) @@ -43,6 +44,7 @@ class ErrorFrameTest { assertEquals(ErrorCode.ApplicationError, frame.errorCode) assertTrue(frame.throwable is RSocketError.ApplicationError) assertEquals("d", frame.throwable.message) + assertEquals("d", frame.data?.readText()) } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ExtensionFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ExtensionFrameTest.kt index 5979a5601..331fd960d 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ExtensionFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ExtensionFrameTest.kt @@ -18,9 +18,10 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class ExtensionFrameTest { +class ExtensionFrameTest : TestWithLeakCheck { private val streamId = 1 private val extendedType = 1 @@ -29,8 +30,8 @@ class ExtensionFrameTest { @Test fun testData() { - val frame = ExtensionFrame(streamId, extendedType, Payload(data)) - val decodedFrame = frame.toPacket().toFrame() + val frame = ExtensionFrame(streamId, extendedType, payload(data)) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ExtensionFrame) assertEquals(streamId, decodedFrame.streamId) @@ -42,7 +43,7 @@ class ExtensionFrameTest { @Test fun testMetadata() { val frame = ExtensionFrame(1, extendedType, Payload(ByteReadPacket.Empty, packet(metadata))) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ExtensionFrame) assertEquals(streamId, decodedFrame.streamId) @@ -53,8 +54,8 @@ class ExtensionFrameTest { @Test fun testDataMetadata() { - val frame = ExtensionFrame(streamId, extendedType, Payload(data, metadata)) - val decodedFrame = frame.toPacket().toFrame() + val frame = ExtensionFrame(streamId, extendedType, payload(data, metadata)) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ExtensionFrame) assertEquals(streamId, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/KeepAliveFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/KeepAliveFrameTest.kt index 4e69a0334..6b43410fe 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/KeepAliveFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/KeepAliveFrameTest.kt @@ -18,9 +18,10 @@ package io.rsocket.kotlin.frame import io.ktor.util.* import io.ktor.utils.io.core.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class KeepAliveFrameTest { +class KeepAliveFrameTest : TestWithLeakCheck { private val dump = "00000f000000000c80000000000000000064" @Test @@ -33,7 +34,7 @@ class KeepAliveFrameTest { @Test fun testDecoding() { - val packet = ByteReadPacket(hex(dump)) + val packet = packet(hex(dump)) val frame = packet.toFrameWithLength() assertTrue(frame is KeepAliveFrame) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/LeaseFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/LeaseFrameTest.kt index 230ee71c8..fc5e18be3 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/LeaseFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/LeaseFrameTest.kt @@ -16,9 +16,10 @@ package io.rsocket.kotlin.frame +import io.rsocket.kotlin.test.* import kotlin.test.* -class LeaseFrameTest { +class LeaseFrameTest : TestWithLeakCheck { private val ttl = 1 private val numberOfRequests = 42 @@ -27,7 +28,7 @@ class LeaseFrameTest { @Test fun testMetadata() { val frame = LeaseFrame(ttl, numberOfRequests, packet(metadata)) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is LeaseFrame) assertEquals(0, decodedFrame.streamId) @@ -39,7 +40,7 @@ class LeaseFrameTest { @Test fun testNoMetadata() { val frame = LeaseFrame(ttl, numberOfRequests, null) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is LeaseFrame) assertEquals(0, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/MetadataPushFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/MetadataPushFrameTest.kt index 5b18a25a0..f7a5abab8 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/MetadataPushFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/MetadataPushFrameTest.kt @@ -17,16 +17,17 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class MetadataPushFrameTest { +class MetadataPushFrameTest : TestWithLeakCheck { private val metadata = ByteArray(65000) { 6 } @Test fun testEncoding() { - val frame = MetadataPushFrame(ByteReadPacket(metadata)) - val decodedFrame = frame.toPacket().toFrame() + val frame = MetadataPushFrame(packet(metadata)) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is MetadataPushFrame) assertEquals(0, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/PayloadFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/PayloadFrameTest.kt index a7f1d8e85..74b07120f 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/PayloadFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/PayloadFrameTest.kt @@ -18,14 +18,15 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class PayloadFrameTest { +class PayloadFrameTest : TestWithLeakCheck { @Test fun testNextCompleteDataMetadata() { - val frame = NextCompletePayloadFrame(3, Payload("d", "md")) - val decodedFrame = frame.toPacket().toFrame() + val frame = NextCompletePayloadFrame(3, payload("d", "md")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.Payload, frame.type) @@ -39,8 +40,8 @@ class PayloadFrameTest { @Test fun testNextCompleteData() { - val frame = NextCompletePayloadFrame(3, Payload("d")) - val decodedFrame = frame.toPacket().toFrame() + val frame = NextCompletePayloadFrame(3, payload("d")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.Payload, frame.type) @@ -55,7 +56,7 @@ class PayloadFrameTest { @Test fun testNextCompleteMetadata() { val frame = NextCompletePayloadFrame(3, Payload(ByteReadPacket.Empty, packet("md"))) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.Payload, frame.type) @@ -69,8 +70,8 @@ class PayloadFrameTest { @Test fun testNextDataMetadata() { - val frame = NextPayloadFrame(3, Payload("d", "md")) - val decodedFrame = frame.toPacket().toFrame() + val frame = NextPayloadFrame(3, payload("d", "md")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.Payload, frame.type) @@ -84,8 +85,8 @@ class PayloadFrameTest { @Test fun testNextData() { - val frame = NextPayloadFrame(3, Payload("d")) - val decodedFrame = frame.toPacket().toFrame() + val frame = NextPayloadFrame(3, payload("d")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.Payload, frame.type) @@ -100,7 +101,7 @@ class PayloadFrameTest { @Test fun testNextDataEmptyMetadata() { val frame = NextPayloadFrame(3, Payload(packet("d"), ByteReadPacket.Empty)) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.Payload, frame.type) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestFireAndForgetFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestFireAndForgetFrameTest.kt index 72e8e5eaa..f4662139a 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestFireAndForgetFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestFireAndForgetFrameTest.kt @@ -16,15 +16,15 @@ package io.rsocket.kotlin.frame -import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class RequestFireAndForgetFrameTest { +class RequestFireAndForgetFrameTest : TestWithLeakCheck { @Test fun testData() { - val frame = RequestFireAndForgetFrame(3, Payload("d")) - val decodedFrame = frame.toPacket().toFrame() + val frame = RequestFireAndForgetFrame(3, payload("d")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestFnF, decodedFrame.type) @@ -38,8 +38,8 @@ class RequestFireAndForgetFrameTest { @Test fun testDataMetadata() { - val frame = RequestFireAndForgetFrame(3, Payload("d", "md")) - val decodedFrame = frame.toPacket().toFrame() + val frame = RequestFireAndForgetFrame(3, payload("d", "md")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestFnF, decodedFrame.type) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestNFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestNFrameTest.kt index 87967bb87..3ac077d6c 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestNFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestNFrameTest.kt @@ -18,9 +18,10 @@ package io.rsocket.kotlin.frame import io.ktor.util.* import io.ktor.utils.io.core.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class RequestNFrameTest { +class RequestNFrameTest : TestWithLeakCheck { private val dump = "00000a00000001200000000005" @@ -34,7 +35,7 @@ class RequestNFrameTest { @Test fun testDecoding() { - val packet = ByteReadPacket(hex(dump)) + val packet = packet(hex(dump)) val frame = packet.toFrameWithLength() assertTrue(frame is RequestNFrame) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestResponseFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestResponseFrameTest.kt index a4b352e64..a85c7d9be 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestResponseFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestResponseFrameTest.kt @@ -17,15 +17,15 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* -import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class RequestResponseFrameTest { +class RequestResponseFrameTest : TestWithLeakCheck { @Test fun testData() { - val frame = RequestResponseFrame(3, Payload("d")) - val decodedFrame = frame.toPacket().toFrame() + val frame = RequestResponseFrame(3, payload("d")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestResponse, decodedFrame.type) @@ -39,8 +39,8 @@ class RequestResponseFrameTest { @Test fun testDataMetadata() { - val frame = RequestResponseFrame(3, Payload("d", "md")) - val decodedFrame = frame.toPacket().toFrame() + val frame = RequestResponseFrame(3, payload("d", "md")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestResponse, decodedFrame.type) @@ -54,12 +54,11 @@ class RequestResponseFrameTest { @Test fun testBigDataMetadata() { - val payload = Payload { - metadata(ByteArray(6000) { 3 }) - data(ByteArray(7000) { 5 }) - } + val data = ByteArray(7000) { 5 } + val metadata = ByteArray(6000) { 3 } + val payload = payload(data, metadata) val frame = RequestResponseFrame(3, payload) - val decodedFrame = buildPacket { writeText(frame.toPacket().readText()) }.toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestResponse, decodedFrame.type) @@ -67,8 +66,8 @@ class RequestResponseFrameTest { assertFalse(decodedFrame.follows) assertFalse(decodedFrame.complete) assertFalse(decodedFrame.next) - assertBytesEquals(ByteArray(7000) { 5 }, decodedFrame.payload.data.readBytes()) - assertBytesEquals(ByteArray(6000) { 3 }, decodedFrame.payload.metadata?.readBytes()) + assertBytesEquals(data, decodedFrame.payload.data.readBytes()) + assertBytesEquals(metadata, decodedFrame.payload.metadata?.readBytes()) } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestStreamFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestStreamFrameTest.kt index 1d7ecc481..e0e69884b 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestStreamFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/RequestStreamFrameTest.kt @@ -19,14 +19,15 @@ package io.rsocket.kotlin.frame import io.ktor.util.* import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class RequestStreamFrameTest { +class RequestStreamFrameTest : TestWithLeakCheck { @Test fun testEncoding() { val dump = "000010000000011900000000010000026d6464" - val frame = RequestStreamFrame(1, 1, Payload("d", "md")) + val frame = RequestStreamFrame(1, 1, payload("d", "md")) val bytes = frame.toPacketWithLength().readBytes() assertEquals(dump, hex(bytes)) @@ -35,7 +36,7 @@ class RequestStreamFrameTest { @Test fun testDecoding() { val dump = "000010000000011900000000010000026d6464" - val frame = ByteReadPacket(hex(dump)).toFrameWithLength() + val frame = packet(hex(dump)).toFrameWithLength() assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) @@ -60,7 +61,7 @@ class RequestStreamFrameTest { @Test fun testDecodingWithEmptyMetadata() { val dump = "00000e0000000119000000000100000064" - val frame = ByteReadPacket(hex(dump)).toFrameWithLength() + val frame = packet(hex(dump)).toFrameWithLength() assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) @@ -76,7 +77,7 @@ class RequestStreamFrameTest { @Test fun testEncodingWithNullMetadata() { val dump = "00000b0000000118000000000164" - val frame = RequestStreamFrame(1, 1, Payload("d")) + val frame = RequestStreamFrame(1, 1, payload("d")) val bytes = frame.toPacketWithLength().readBytes() assertEquals(dump, hex(bytes)) @@ -85,7 +86,7 @@ class RequestStreamFrameTest { @Test fun testDecodingWithNullMetadata() { val dump = "00000b0000000118000000000164" - val frame = ByteReadPacket(hex(dump)).toFrameWithLength() + val frame = packet(hex(dump)).toFrameWithLength() assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) @@ -101,7 +102,7 @@ class RequestStreamFrameTest { @Test fun testEmptyData() { val frame = RequestStreamFrame(3, 10, Payload(ByteReadPacket.Empty, packet("md"))) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) @@ -117,7 +118,7 @@ class RequestStreamFrameTest { @Test fun testEmptyPayload() { val frame = RequestStreamFrame(3, 10, Payload(ByteReadPacket.Empty, ByteReadPacket.Empty)) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) @@ -132,8 +133,8 @@ class RequestStreamFrameTest { @Test fun testMaxRequestN() { - val frame = RequestStreamFrame(3, Int.MAX_VALUE, Payload("d", "md")) - val decodedFrame = frame.toPacket().toFrame() + val frame = RequestStreamFrame(3, Int.MAX_VALUE, payload("d", "md")) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeFrameTest.kt index de4b66bad..8418a5937 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeFrameTest.kt @@ -18,9 +18,10 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* +import io.rsocket.kotlin.test.* import kotlin.test.* -class ResumeFrameTest { +class ResumeFrameTest : TestWithLeakCheck { private val version = Version.Current private val lastReceivedServerPosition = 21L @@ -29,8 +30,8 @@ class ResumeFrameTest { @Test fun testBigToken() { val token = ByteArray(65000) { 1 } - val frame = ResumeFrame(version, ByteReadPacket(token), lastReceivedServerPosition, firstAvailableClientPosition) - val decodedFrame = frame.toPacket().toFrame() + val frame = ResumeFrame(version, packet(token), lastReceivedServerPosition, firstAvailableClientPosition) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ResumeFrame) assertEquals(0, decodedFrame.streamId) @@ -43,9 +44,9 @@ class ResumeFrameTest { @Test fun testBigChunkedToken() { val token = ByteArray(63000) { 1 } - val packet = buildPacket { writeFully(token) } + val packet = packet(token) val frame = ResumeFrame(version, packet, lastReceivedServerPosition, firstAvailableClientPosition) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ResumeFrame) assertEquals(0, decodedFrame.streamId) @@ -58,8 +59,8 @@ class ResumeFrameTest { @Test fun testSmallToken() { val token = ByteArray(100) { 1 } - val frame = ResumeFrame(version, ByteReadPacket(token), lastReceivedServerPosition, firstAvailableClientPosition) - val decodedFrame = frame.toPacket().toFrame() + val frame = ResumeFrame(version, packet(token), lastReceivedServerPosition, firstAvailableClientPosition) + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ResumeFrame) assertEquals(0, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeOkFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeOkFrameTest.kt index 4b00ed5d0..23f92d8f3 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeOkFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/ResumeOkFrameTest.kt @@ -16,16 +16,17 @@ package io.rsocket.kotlin.frame +import io.rsocket.kotlin.test.* import kotlin.test.* -class ResumeOkFrameTest { +class ResumeOkFrameTest : TestWithLeakCheck { private val lastReceivedClientPosition = 42L @Test fun testEncoding() { val frame = ResumeOkFrame(lastReceivedClientPosition) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is ResumeOkFrame) assertEquals(0, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/SetupFrameTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/SetupFrameTest.kt index 906e818a3..57a6972da 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/SetupFrameTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/SetupFrameTest.kt @@ -20,10 +20,11 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.test.* import kotlin.test.* import kotlin.time.* -class SetupFrameTest { +class SetupFrameTest : TestWithLeakCheck { private val version = Version.Current private val keepAlive = KeepAlive(10.seconds, 500.seconds) @@ -32,7 +33,7 @@ class SetupFrameTest { @Test fun testNoResumeEmptyPayload() { val frame = SetupFrame(version, true, keepAlive, null, payloadMimeType, Payload.Empty) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is SetupFrame) assertEquals(0, decodedFrame.streamId) @@ -47,9 +48,9 @@ class SetupFrameTest { @Test fun testNoResumeBigPayload() { - val payload = Payload(ByteArray(30000) { 1 }, ByteArray(20000) { 5 }) + val payload = payload(ByteArray(30000) { 1 }, ByteArray(20000) { 5 }) val frame = SetupFrame(version, true, keepAlive, null, payloadMimeType, payload) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is SetupFrame) assertEquals(0, decodedFrame.streamId) @@ -64,9 +65,9 @@ class SetupFrameTest { @Test fun testResumeBigTokenEmptyPayload() { - val resumeToken = ByteReadPacket(ByteArray(65000) { 5 }) + val resumeToken = packet(ByteArray(65000) { 5 }) val frame = SetupFrame(version, true, keepAlive, resumeToken, payloadMimeType, Payload.Empty) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is SetupFrame) assertEquals(0, decodedFrame.streamId) @@ -81,10 +82,10 @@ class SetupFrameTest { @Test fun testResumeBigTokenBigPayload() { - val resumeToken = ByteReadPacket(ByteArray(65000) { 5 }) - val payload = Payload(ByteArray(30000) { 1 }, ByteArray(20000) { 5 }) + val resumeToken = packet(ByteArray(65000) { 5 }) + val payload = payload(ByteArray(30000) { 1 }, ByteArray(20000) { 5 }) val frame = SetupFrame(version, true, keepAlive, resumeToken, payloadMimeType, payload) - val decodedFrame = frame.toPacket().toFrame() + val decodedFrame = frame.loopFrame() assertTrue(decodedFrame is SetupFrame) assertEquals(0, decodedFrame.streamId) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/Util.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/Util.kt index 2df7872c4..c06f37de4 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/Util.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/frame/Util.kt @@ -16,13 +16,13 @@ package io.rsocket.kotlin.frame -import io.ktor.util.* import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* +import io.rsocket.kotlin.test.* import kotlin.test.* -fun Frame.toPacketWithLength(): ByteReadPacket = buildPacket { - val packet = toPacket() +fun Frame.toPacketWithLength(): ByteReadPacket = buildPacket(InUseTrackingPool) { + val packet = toPacket(InUseTrackingPool) writeLength(packet.remaining.toInt()) writePacket(packet) } @@ -30,11 +30,7 @@ fun Frame.toPacketWithLength(): ByteReadPacket = buildPacket { fun ByteReadPacket.toFrameWithLength(): Frame { val length = readLength() assertEquals(length, remaining.toInt()) - return toFrame() + return readFrame(InUseTrackingPool) } -fun packet(text: String): ByteReadPacket = buildPacket { writeText(text) } - -fun assertBytesEquals(expected: ByteArray?, actual: ByteArray?) { - assertEquals(expected?.let(::hex), actual?.let(::hex)) -} +fun Frame.loopFrame(): Frame = toPacket(InUseTrackingPool).readFrame(InUseTrackingPool) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index 9dacd7493..85e0f85c8 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -16,7 +16,6 @@ package io.rsocket.kotlin.internal -import app.cash.turbine.* import io.rsocket.kotlin.* import io.rsocket.kotlin.error.* import io.rsocket.kotlin.frame.* @@ -30,68 +29,60 @@ import kotlin.coroutines.* import kotlin.test.* import kotlin.time.* -class RSocketRequesterTest : TestWithConnection() { - lateinit var ignoredFrames: Channel +class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { private lateinit var requester: RSocketRequester override suspend fun before() { super.before() - ignoredFrames = Channel(Channel.UNLIMITED) - val state = RSocketState(connection, KeepAlive(1000.seconds, 1000.seconds), ignoredFrames::offer) + val state = RSocketState(connection, KeepAlive(1000.seconds, 1000.seconds)) requester = RSocketRequester(state, StreamId.client()) state.start(RSocketRequestHandler { }) } - @Test - fun testInvalidFrameOnStream0() = test { - connection.sendToReceiver(RequestNFrame(0, 5)) - val frame = ignoredFrames.receive() - assertTrue(frame is RequestNFrame) - } - @Test fun testStreamInitialN() = test { - val flow = requester.requestStream(Payload.Empty).buffer(5) - assertEquals(0, connection.sentFrames.size) - flow.launchIn(CoroutineScope(connection.job)) - delay(100) - assertEquals(1, connection.sentFrames.size) - val frame = connection.receiveFromSender() - assertTrue(frame is RequestFrame) - assertEquals(FrameType.RequestStream, frame.type) - assertEquals(5, frame.initialRequest) + connection.test { + val flow = requester.requestStream(Payload.Empty).buffer(5) + + expectNoEventsIn(200) + flow.launchIn(connection) + + expectFrame { frame -> + assertTrue(frame is RequestFrame) + assertEquals(FrameType.RequestStream, frame.type) + assertEquals(5, frame.initialRequest) + } + + expectNoEventsIn(200) + } } @Test fun testStreamBuffer() = test { - val flow = - requester.requestStream(Payload.Empty) - .buffer(2) - .take(2) + connection.test { + val flow = requester.requestStream(Payload.Empty).buffer(2).take(2) - assertEquals(0, connection.sentFrames.size) + expectNoEventsIn(200) + flow.launchIn(connection) - flow.launchIn(CoroutineScope(connection.job)) - - connection.sentAsFlow().test { - expectItem().let { frame -> + expectFrame { frame -> assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) assertEquals(2, frame.initialRequest) } - delay(200) - expectNoEvents() + + expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - delay(200) - expectNoEvents() + + expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - delay(200) - expectItem().let { frame -> + + expectFrame { frame -> assertTrue(frame is CancelFrame) } - delay(200) - expectNoEvents() + + expectNoEventsIn(200) } } @@ -101,68 +92,61 @@ class RSocketRequesterTest : TestWithConnection() { @Test fun testStreamBufferWithAdditionalContext() = test { - val flow = - requester.requestStream(Payload.Empty) - .buffer(2) - .flowOn(SomeContext(2)) - .take(2) + connection.test { + val flow = requester.requestStream(Payload.Empty).buffer(2).flowOn(SomeContext(2)).take(2) - assertEquals(0, connection.sentFrames.size) + expectNoEventsIn(200) + flow.launchIn(connection) - flow.launchIn(CoroutineScope(connection.job)) - - connection.sentAsFlow().test { - expectItem().let { frame -> + expectFrame { frame -> assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) assertEquals(2, frame.initialRequest) } - delay(200) - expectNoEvents() + + expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - delay(200) - expectNoEvents() + + expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - delay(200) - expectItem().let { frame -> + + expectFrame { frame -> assertTrue(frame is CancelFrame) } - delay(200) - expectNoEvents() + expectNoEventsIn(200) } } @Test //ignored on native because of dispatcher switching fun testStreamBufferWithAnotherDispatcher() = test(ignoreNative = true) { - val flow = - requester.requestStream(Payload.Empty) - .buffer(2) - .flowOn(anotherDispatcher) //change dispatcher before take - .take(2) - .transform { emit(it) } //force using SafeCollector to check that `Flow invariant is violated` not happens - - assertEquals(0, connection.sentFrames.size) - - flow.launchIn(CoroutineScope(connection.job)) - - connection.sentAsFlow().test { - expectItem().let { frame -> + connection.test { + val flow = + requester.requestStream(Payload.Empty) + .buffer(2) + .flowOn(anotherDispatcher) //change dispatcher before take + .take(2) + .transform { emit(it) } //force using SafeCollector to check that `Flow invariant is violated` not happens + + expectNoEventsIn(200) + flow.launchIn(connection) + + expectFrame { frame -> assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) assertEquals(2, frame.initialRequest) } - delay(200) - expectNoEvents() + + expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - delay(200) - expectNoEvents() //will fail here if `Flow invariant is violated` + + expectNoEventsIn(200) //will fail here if `Flow invariant is violated` connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - delay(200) - expectItem().let { frame -> + + expectFrame { frame -> assertTrue(frame is CancelFrame) } - delay(200) - expectNoEvents() + + expectNoEventsIn(200) } } @@ -181,56 +165,75 @@ class RSocketRequesterTest : TestWithConnection() { fun testHandleApplicationException() = test { val errorMessage = "error" val deferred = GlobalScope.async { requester.requestResponse(Payload.Empty) } - delay(300) - assertEquals(1, connection.sentFrames.size) - val streamId = connection.sentFrames.first().streamId - connection.sendToReceiver(ErrorFrame(streamId, RSocketError.ApplicationError(errorMessage))) - assertFailsWith(RSocketError.ApplicationError::class, errorMessage) { deferred.await() } + + connection.test { + expectFrame { frame -> + val streamId = frame.streamId + connection.sendToReceiver(ErrorFrame(streamId, RSocketError.ApplicationError(errorMessage))) + } + assertFailsWith(RSocketError.ApplicationError::class, errorMessage) { deferred.await() } + } } @Test fun testHandleValidFrame() = test { - val deferred = GlobalScope.async { requester.requestResponse(Payload.Empty) } - delay(100) - assertEquals(1, connection.sentFrames.size) - val streamId = connection.sentFrames.first().streamId - connection.sendToReceiver(NextPayloadFrame(streamId, Payload.Empty)) - deferred.await() + connection.test { + val deferred = async { requester.requestResponse(Payload.Empty) } + expectFrame { frame -> + val streamId = frame.streamId + connection.sendToReceiver(NextPayloadFrame(streamId, Payload.Empty)) + } + deferred.await() + expectNoEventsIn(200) + } } @Test fun testRequestReplyWithCancel() = test { - withTimeoutOrNull(100.milliseconds) { requester.requestResponse(Payload.Empty) } - delay(100) - assertEquals(2, connection.sentFrames.size) - assertTrue(connection.sentFrames[0] is RequestFrame) - assertTrue(connection.sentFrames[1] is CancelFrame) + connection.test { + withTimeoutOrNull(100) { requester.requestResponse(Payload.Empty) } + + expectFrame { assertTrue(it is RequestFrame) } + expectFrame { assertTrue(it is CancelFrame) } + + expectNoEventsIn(200) + } } @Test fun testChannelRequestCancellation() = test { val job = Job() val request = flow { Job().join() }.onCompletion { job.complete() } - val response = requester.requestChannel(request).launchIn(CoroutineScope(connection.job)) + val response = requester.requestChannel(request).launchIn(connection) delay(100) response.cancelAndJoin() delay(200) assertTrue(job.isCompleted) } - @Test + // @Test fun testChannelRequestCancellationWithPayload() = test { val job = Job() val request = flow { repeat(100) { emit(Payload.Empty) } }.onCompletion { job.complete() } - val response = requester.requestChannel(request).launchIn(CoroutineScope(connection.job)) + val response = requester.requestChannel(request).launchIn(connection) delay(1000) response.cancelAndJoin() delay(100) assertTrue(job.isCompleted) - val sent = connection.sentFrames.size - assertTrue(sent > 0) - delay(100) - assertEquals(sent, connection.sentFrames.size) + connection.test { + while (true) { + try { + expectItem() + } catch (e: TimeoutCancellationException) { + + } + } +// expectComplete() + } +// val sent = connection.sentFrames.size +// assertTrue(sent > 0) +// delay(100) +// assertEquals(sent, connection.sentFrames.size) } @Test //ignored on native because of coroutines bug with channels @@ -238,20 +241,23 @@ class RSocketRequesterTest : TestWithConnection() { var ch: SendChannel? = null val request = channelFlow { ch = this - offer(Payload(byteArrayOf(1), byteArrayOf(2))) + offer(payload(byteArrayOf(1), byteArrayOf(2))) awaitClose() } - val response = requester.requestChannel(request).launchIn(CoroutineScope(connection.job)) - delay(200) - val requestFrame = connection.sentFrames.first() - assertTrue(requestFrame is RequestFrame) - assertEquals(FrameType.RequestChannel, requestFrame.type) - connection.sendToReceiver(CancelFrame(requestFrame.streamId), CompletePayloadFrame(requestFrame.streamId)) - response.join() - delay(100) - assertTrue(response.isCompleted) - assertEquals(1, connection.sentFrames.size) - assertTrue(ch!!.isClosedForSend) + val response = requester.requestChannel(request).launchIn(connection) + connection.test { + expectFrame { frame -> + val streamId = frame.streamId + assertTrue(frame is RequestFrame) + assertEquals(FrameType.RequestChannel, frame.type) + frame.release() + connection.sendToReceiver(CancelFrame(streamId), CompletePayloadFrame(streamId)) + } + response.join() + expectNoEventsIn(200) + assertTrue(response.isCompleted) + assertTrue(ch!!.isClosedForSend) + } } @Test @@ -259,30 +265,33 @@ class RSocketRequesterTest : TestWithConnection() { val delay = Job() val request = flow { delay.join() - emit(Payload("INIT")) + emit(payload("INIT")) repeat(1000) { - emit(Payload(it.toString())) + emit(payload(it.toString())) } } - requester.requestChannel(request).buffer(Int.MAX_VALUE).launchIn(CoroutineScope(connection.job)) - delay(100) - delay.complete() - delay(100) - assertEquals(1, connection.sentFrames.size) - delay(100) - assertEquals(1, connection.sentFrames.size) - val requestFrame = connection.sentFrames.first() - assertTrue(requestFrame is RequestFrame) - assertEquals(FrameType.RequestChannel, requestFrame.type) - assertEquals(Int.MAX_VALUE, requestFrame.initialRequest) - assertEquals("INIT", requestFrame.payload.data.readText()) + requester.requestChannel(request).buffer(Int.MAX_VALUE).launchIn(connection) + connection.test { + expectNoEventsIn(200) + delay.complete() + expectFrame { frame -> + assertTrue(frame is RequestFrame) + assertEquals(FrameType.RequestChannel, frame.type) + assertEquals(Int.MAX_VALUE, frame.initialRequest) + assertEquals("INIT", frame.payload.data.readText()) + } + expectNoEventsIn(200) + } } private fun streamIsTerminatedOnConnectionClose(request: suspend () -> Unit) = test { - launch(connection.job) { - delay(1.seconds) - connection.cancel() + connection.launch { + connection.test { + expectFrame { assertTrue(it is RequestFrame) } + connection.job.cancel() + expectNoEventsIn(200) + } } assertFailsWith(CancellationException::class) { request() } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt index 75b31887b..aa1a7bd5e 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt @@ -16,12 +16,14 @@ package io.rsocket.kotlin.internal +import app.cash.turbine.* import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.error.* import io.rsocket.kotlin.keepalive.* +import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* @@ -30,7 +32,7 @@ import kotlinx.coroutines.flow.* import kotlin.test.* import kotlin.time.* -class RSocketTest : SuspendTest { +class RSocketTest : SuspendTest, TestWithLeakCheck { lateinit var serverConnection: LocalConnection lateinit var clientConnection: LocalConnection @@ -55,15 +57,18 @@ class RSocketTest : SuspendTest { private suspend fun start(handler: RSocket? = null): RSocket = coroutineScope { launch { - serverConnection.startServer { + serverConnection.startServer( + RSocketServerConfiguration(loggerFactory = NoopLogger) + ) { handler ?: RSocketRequestHandler { requestResponse = { it } requestStream = { - flow { repeat(10) { emit(Payload("server got -> [$it]")) } } + it.release() + flow { repeat(10) { emit(payload("server got -> [$it]")) } } } requestChannel = { - it.launchIn(CoroutineScope(job)) - flow { repeat(10) { emit(Payload("server got -> [$it]")) } } + it.onEach { it.release() }.launchIn(CoroutineScope(job)) + flow { repeat(10) { emit(payload("server got -> [$it]")) } } } } } @@ -71,7 +76,7 @@ class RSocketTest : SuspendTest { clientConnection.connectClient( RSocketConnectorConfiguration( keepAlive = KeepAlive(1000.seconds, 1000.seconds), - loggerFactory = TestLoggerFactory + loggerFactory = NoopLogger ) ) } @@ -79,7 +84,7 @@ class RSocketTest : SuspendTest { @Test fun testRequestResponseNoError() = test { val requester = start() - requester.requestResponse(Payload("HELLO")) + requester.requestResponse(payload("HELLO")).release() } @Test @@ -87,7 +92,7 @@ class RSocketTest : SuspendTest { val requester = start(RSocketRequestHandler { requestResponse = { error("stub") } }) - assertFailsWith(RSocketError.ApplicationError::class) { requester.requestResponse(Payload("HELLO")) } + assertFailsWith(RSocketError.ApplicationError::class) { requester.requestResponse(payload("HELLO")) } } @Test @@ -95,23 +100,34 @@ class RSocketTest : SuspendTest { val requester = start(RSocketRequestHandler { requestResponse = { throw RSocketError.Custom(0x00000501, "stub") } }) - val error = assertFailsWith(RSocketError.Custom::class) { requester.requestResponse(Payload("HELLO")) } + val error = assertFailsWith(RSocketError.Custom::class) { requester.requestResponse(payload("HELLO")) } assertEquals(0x00000501, error.errorCode) } @Test fun testStream() = test { val requester = start() - val response = requester.requestStream(Payload.Empty).toList() - assertEquals(10, response.size) + requester.requestStream(Payload.Empty).test { + repeat(10) { + expectItem().release() + } + expectComplete() + } } @Test fun testChannel() = test { + val awaiter = Job() val requester = start() - val request = (1..10).asFlow().map { Payload(it.toString()) } - val response = requester.requestChannel(request).toList() - assertEquals(10, response.size) + val request = (1..10).asFlow().map { payload(it.toString()) }.onCompletion { awaiter.complete() } + requester.requestChannel(request).test { + repeat(10) { + expectItem().release() + } + expectComplete() + } + awaiter.join() + delay(500) } @Test @@ -132,26 +148,32 @@ class RSocketTest : SuspendTest { val requester = start(RSocketRequestHandler { requestChannel = { it.buffer(3).take(3) } }) - val request = (1..3).asFlow().map { Payload(it.toString()) } - val response = requester.requestChannel(request).buffer(3).toList() - assertEquals(3, response.size) + val request = (1..3).asFlow().map { payload(it.toString()) } + requester.requestChannel(request).buffer(3).test { + repeat(3) { + expectItem().release() + } + expectComplete() + } } - private val requesterPayloads = listOf( - Payload("d1", "m1"), - Payload("d2"), - Payload("d3", "m3"), - Payload("d4"), - Payload("d5", "m5") - ) - - private val responderPayloads = listOf( - Payload("rd1", "rm1"), - Payload("rd2"), - Payload("rd3", "rm3"), - Payload("rd4"), - Payload("rd5", "rm5") - ) + private val requesterPayloads + get() = listOf( + payload("d1", "m1"), + payload("d2"), + payload("d3", "m3"), + payload("d4"), + payload("d5", "m5") + ) + + private val responderPayloads + get() = listOf( + payload("rd1", "rm1"), + payload("rd2"), + payload("rd3", "rm3"), + payload("rd4"), + payload("rd5", "rm5") + ) @Test fun requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() = test { @@ -231,17 +253,18 @@ class RSocketTest : SuspendTest { val requester = start(RSocketRequestHandler { requestChannel = { responderDeferred.complete(it.produceIn(CoroutineScope(job))) + responderSendChannel.consumeAsFlow() } }) val requesterReceiveChannel = requester.requestChannel(requesterSendChannel.consumeAsFlow()).produceIn(CoroutineScope(requester.job)) - requesterSendChannel.send(Payload("initData", "initMetadata")) + requesterSendChannel.send(payload("initData", "initMetadata")) val responderReceiveChannel = responderDeferred.await() - responderReceiveChannel.checkReceived(Payload("initData", "initMetadata")) + responderReceiveChannel.checkReceived(payload("initData", "initMetadata")) return requesterReceiveChannel to responderReceiveChannel } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index ff9792ac0..46a0ad7bc 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -23,14 +23,13 @@ import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* -import kotlinx.coroutines.flow.* import kotlin.test.* import kotlin.time.* -class KeepAliveTest : TestWithConnection() { +class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { private fun requester(keepAlive: KeepAlive = KeepAlive(100.milliseconds, 1.seconds)): RSocket = run { - val state = RSocketState(connection, keepAlive) {} + val state = RSocketState(connection, keepAlive) val requester = RSocketRequester(state, StreamId.client()) state.start(RSocketRequestHandler { }) requester @@ -39,57 +38,68 @@ class KeepAliveTest : TestWithConnection() { @Test fun requesterSendKeepAlive() = test { requester() - val list = connection.sentAsFlow().take(3).toList() - assertEquals(3, list.size) - list.forEach { - assertTrue(it is KeepAliveFrame) - assertTrue(it.respond) + connection.test { + repeat(5) { + expectFrame { frame -> + assertTrue(frame is KeepAliveFrame) + assertTrue(frame.respond) + } + } } } @Test fun rSocketNotCanceledOnPresentKeepAliveTicks() = test { val rSocket = requester() - launch(connection.job) { - while (isActive) { + connection.launch { + repeat(50) { delay(100.milliseconds) connection.sendToReceiver(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) } } delay(1.5.seconds) assertTrue(rSocket.isActive) + connection.test { + repeat(50) { + expectItem() + } + } } @Test fun requesterRespondsToKeepAlive() = test { requester(KeepAlive(100.seconds, 100.seconds)) - launch(connection.job) { + connection.launch { while (isActive) { delay(100.milliseconds) connection.sendToReceiver(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) } } - val list = connection.sentAsFlow().take(3).toList() - assertEquals(3, list.size) - list.forEach { - assertTrue(it is KeepAliveFrame) - assertFalse(it.respond) + connection.test { + repeat(5) { + expectFrame { frame -> + assertTrue(frame is KeepAliveFrame) + assertFalse(frame.respond) + } + } } } @Test fun noKeepAliveSentAfterRSocketCanceled() = test { requester().cancel() - delay(500.milliseconds) - assertEquals(0, connection.sentFrames.size) + connection.test { + expectNoEventsIn(500) + } } @Test fun rSocketCanceledOnMissingKeepAliveTicks() = test { val rSocket = requester() - delay(1.5.seconds) - assertFalse(rSocket.isActive) + connection.test { + while (rSocket.isActive) kotlin.runCatching { expectItem() } + } assertTrue(rSocket.job.getCancellationException().cause is RSocketError.ConnectionError) } diff --git a/rsocket-test/build.gradle.kts b/rsocket-test/build.gradle.kts index ded9dccf2..281fea56e 100644 --- a/rsocket-test/build.gradle.kts +++ b/rsocket-test/build.gradle.kts @@ -16,6 +16,7 @@ plugins { kotlin("multiplatform") + id("kotlinx-atomicfu") } val ktorVersion: String by rootProject diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/InUseTrackingPool.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/InUseTrackingPool.kt new file mode 100644 index 000000000..bb3ed8555 --- /dev/null +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/InUseTrackingPool.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.test + +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import kotlinx.atomicfu.* +import kotlin.test.* + +object InUseTrackingPool : ObjectPool { + override val capacity: Int get() = ChunkBuffer.Pool.capacity + private val inUse = atomic(0) + + override fun borrow(): ChunkBuffer { + inUse.incrementAndGet() + return ChunkBuffer.Pool.borrow() + } + + override fun recycle(instance: ChunkBuffer) { + inUse.decrementAndGet() + ChunkBuffer.Pool.recycle(instance) + } + + override fun dispose() { + ChunkBuffer.Pool.dispose() + } + + fun resetInUse() { + inUse.lazySet(0) + } + + fun assertNoInUse() { + val v = inUse.value + assertEquals(0, v, "Buffers in use") + } +} + +interface TestWithLeakCheck { + + @BeforeTest + fun resetInUse() { + InUseTrackingPool.resetInUse() + } + + @AfterTest + fun checkNoBuffersInUse() { + InUseTrackingPool.assertNoInUse() + } +} diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/Packets.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/Packets.kt new file mode 100644 index 000000000..ab89d0d97 --- /dev/null +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/Packets.kt @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.test + +import io.ktor.util.* +import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import io.rsocket.kotlin.payload.* +import kotlin.test.* + +fun packet(text: String): ByteReadPacket = buildPacket(InUseTrackingPool) { writeText(text) } + +fun packet(array: ByteArray): ByteReadPacket = buildPacket(InUseTrackingPool) { writeFully(array) } + +fun payload(data: ByteArray, metadata: ByteArray? = null): Payload = Payload(packet(data), metadata?.let(::packet)) + +fun payload(data: String, metadata: String? = null): Payload = Payload(packet(data), metadata?.let(::packet)) + +fun assertBytesEquals(expected: ByteArray?, actual: ByteArray?) { + assertEquals(expected?.let(::hex), actual?.let(::hex)) +} + +private inline fun buildPacket(pool: ObjectPool, block: BytePacketBuilder.() -> Unit): ByteReadPacket { + val builder = BytePacketBuilder(0, pool) + try { + block(builder) + return builder.build() + } catch (t: Throwable) { + builder.release() + throw t + } +} diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt index 60f2ac923..796b01cff 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt @@ -16,42 +16,62 @@ package io.rsocket.kotlin.test +import app.cash.turbine.* import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* +import kotlin.time.* -class TestConnection : Connection { +class TestConnection : Connection, CoroutineScope { + override val pool: ObjectPool = InUseTrackingPool override val job: Job = Job() - private val sender = Channel(Channel.UNLIMITED) - private val receiver = Channel(Channel.UNLIMITED) + override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined - private val store = TestPacketStore() - val sentFrames: List get() = store.stored.map { it.copy().toFrame() } + private val sendChannel = Channel(Channel.UNLIMITED) + private val receiveChannel = Channel(Channel.UNLIMITED) init { job.invokeOnCompletion { - sender.close(it) - receiver.cancel(it?.let { it as? CancellationException ?: CancellationException("Connection completed") }) + sendChannel.close(it) + receiveChannel.cancel(it?.let { it as? CancellationException ?: CancellationException("Connection completed") }) } } override suspend fun send(packet: ByteReadPacket) { - sender.send(packet) - store.store(packet.copy()) + sendChannel.send(packet) } override suspend fun receive(): ByteReadPacket { - return receiver.receive() + return receiveChannel.receive() } - suspend fun receiveFromSender() = sender.receive().toFrame() - suspend fun sendToReceiver(vararg frames: Frame) { - frames.forEach { receiver.send(it.toPacket()) } + frames.forEach { receiveChannel.send(it.toPacket(InUseTrackingPool)) } } - fun sentAsFlow() = sender.receiveAsFlow().map { it.toFrame() } + private fun sentAsFlow(): Flow = sendChannel.receiveAsFlow().map { it.readFrame(InUseTrackingPool) } + + suspend fun test(validate: suspend FlowTurbine.() -> Unit) { + sentAsFlow().test(validate = validate) + } +} + +suspend fun FlowTurbine<*>.expectNoEventsIn(duration: Duration) { + delay(duration) + expectNoEvents() +} + +suspend fun FlowTurbine<*>.expectNoEventsIn(timeMillis: Long) { + delay(timeMillis) + expectNoEvents() +} + +suspend inline fun FlowTurbine.expectFrame(block: (frame: Frame) -> Unit) { + block(expectItem()) } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt deleted file mode 100644 index 488fea129..000000000 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.test - -import io.ktor.utils.io.core.* - -expect class TestPacketStore() { - val stored: List - fun store(packet: ByteReadPacket) -} diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index a084f4f5e..de939dc06 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -16,7 +16,6 @@ package io.rsocket.kotlin.test -import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.keepalive.* @@ -39,7 +38,7 @@ abstract class TransportTest : SuspendTest { @Test fun fireAndForget10() = test { - (1..10).map { async { client.fireAndForget(Payload(it)) } }.awaitAll() + (1..10).map { async { client.fireAndForget(payload(it)) } }.awaitAll() } @Test @@ -65,14 +64,14 @@ abstract class TransportTest : SuspendTest { @Test fun requestChannel1() = test(10.seconds) { - val list = client.requestChannel(flowOf(Payload(0))).onEach { it.release() }.toList() + val list = client.requestChannel(flowOf(payload(0))).onEach { it.release() }.toList() assertEquals(1, list.size) } @Test fun requestChannel3() = test { val request = flow { - repeat(3) { emit(Payload(it)) } + repeat(3) { emit(payload(it)) } } val list = client.requestChannel(request).buffer(3).onEach { it.release() }.toList() assertEquals(3, list.size) @@ -90,7 +89,7 @@ abstract class TransportTest : SuspendTest { @Test fun requestChannel20000() = test(TransportTestLongDuration) { val request = flow { - repeat(20_000) { emit(Payload(7)) } + repeat(20_000) { emit(payload(7)) } } val list = client.requestChannel(request).buffer(Int.MAX_VALUE).onEach { assertEquals(MOCK_DATA, it.data.readText()) @@ -102,7 +101,7 @@ abstract class TransportTest : SuspendTest { @Test fun requestChannel200000() = test(TransportTestLongDuration) { val request = flow { - repeat(200_000) { emit(Payload(it)) } + repeat(200_000) { emit(payload(it)) } } val list = client.requestChannel(request).buffer(Int.MAX_VALUE).onEach { it.release() }.toList() assertEquals(200_000, list.size) @@ -112,7 +111,7 @@ abstract class TransportTest : SuspendTest { fun requestChannel256x512() = test(TransportTestLongDuration) { val request = flow { repeat(512) { - emit(Payload(it)) + emit(payload(it)) } } (0..256).map { @@ -125,17 +124,17 @@ abstract class TransportTest : SuspendTest { @Test fun requestResponse1() = test { - client.requestResponse(Payload(1)).let(Companion::checkPayload) + client.requestResponse(payload(1)).let(Companion::checkPayload) } @Test fun requestResponse10() = test { - (1..10).map { async { client.requestResponse(Payload(it)).let(Companion::checkPayload) } }.awaitAll() + (1..10).map { async { client.requestResponse(payload(it)).let(Companion::checkPayload) } }.awaitAll() } @Test fun requestResponse100() = test { - (1..100).map { async { client.requestResponse(Payload(it)).let(Companion::checkPayload) } }.awaitAll() + (1..100).map { async { client.requestResponse(payload(it)).let(Companion::checkPayload) } }.awaitAll() } @Test @@ -145,23 +144,23 @@ abstract class TransportTest : SuspendTest { @Test fun requestResponse10000() = test { - (1..10000).map { async { client.requestResponse(Payload(3)).let(Companion::checkPayload) } }.awaitAll() + (1..10000).map { async { client.requestResponse(payload(3)).let(Companion::checkPayload) } }.awaitAll() } @Test fun requestResponse100000() = test(TransportTestLongDuration) { - repeat(100000) { client.requestResponse(Payload(3)).let(Companion::checkPayload) } + repeat(100000) { client.requestResponse(payload(3)).let(Companion::checkPayload) } } @Test fun requestStream5() = test { - val list = client.requestStream(Payload(3)).onEach { checkPayload(it) }.buffer(5).take(5).toList() + val list = client.requestStream(payload(3)).buffer(5).take(5).onEach { checkPayload(it) }.toList() assertEquals(5, list.size) } @Test fun requestStream10000() = test { - val list = client.requestStream(Payload(3)).onEach { checkPayload(it) }.toList() + val list = client.requestStream(payload(3)).onEach { checkPayload(it) }.toList() assertEquals(10000, list.size) } @@ -174,21 +173,18 @@ abstract class TransportTest : SuspendTest { val MOCK_DATA: String = "test-data" val MOCK_METADATA: String = "metadata" val LARGE_DATA by lazy { readLargePayload("words.shakespeare.txt.gz") } - private val payload by lazy { Payload(LARGE_DATA, LARGE_DATA) } + private val payload by lazy { payload(LARGE_DATA, LARGE_DATA) } val LARGE_PAYLOAD get() = payload.copy() - private fun packet(text: String): ByteReadPacket = buildPacket { writeText(text) } - private fun readLargePayload(name: String): String = name.repeat(1000) - @Suppress("FunctionName") - private fun Payload(metadataPresent: Int): Payload { + private fun payload(metadataPresent: Int): Payload { val metadata = when (metadataPresent % 5) { 0 -> null 1 -> "" else -> MOCK_METADATA } - return Payload(MOCK_DATA, metadata) + return payload(MOCK_DATA, metadata) } fun checkPayload(payload: Payload) { diff --git a/rsocket-test/src/jvmMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt b/rsocket-test/src/jvmMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt deleted file mode 100644 index 33f6e2f3a..000000000 --- a/rsocket-test/src/jvmMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.test - -import io.ktor.utils.io.core.* - -actual class TestPacketStore actual constructor() { - private val _stored = mutableListOf() - - actual val stored: List get() = _stored - - actual fun store(packet: ByteReadPacket) { - _stored += packet - } -} diff --git a/rsocket-test/src/nativeMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt b/rsocket-test/src/nativeMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt deleted file mode 100644 index 2cad54faf..000000000 --- a/rsocket-test/src/nativeMain/kotlin/io/rsocket/kotlin/test/TestPacketStore.kt +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.test - -import io.ktor.utils.io.core.* -import kotlinx.atomicfu.* - -actual class TestPacketStore { - private val sentIndex = atomic(0) - private val _stored = atomicArrayOfNulls(100) //max 100 in cache - - actual val stored: List - get() = buildList { - repeat(sentIndex.value) { - add(_stored[it].value!!) - } - } - - actual fun store(packet: ByteReadPacket) { - _stored[sentIndex.getAndIncrement()].value = packet - } -} diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/connection/KtorTcpConnection.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/connection/KtorTcpConnection.kt index 8e2f55872..adb67c990 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/connection/KtorTcpConnection.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/connection/KtorTcpConnection.kt @@ -28,7 +28,7 @@ import kotlin.coroutines.* val Socket.connection: Connection get() = KtorTcpConnection(this) -@OptIn(KtorExperimentalAPI::class, ExperimentalCoroutinesApi::class) +@OptIn(KtorExperimentalAPI::class) private class KtorTcpConnection(private val socket: Socket) : Connection, CoroutineScope { override val job: Job = Job(socket.socketContext) override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined diff --git a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt index f54e2b85f..fce109cc2 100644 --- a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt +++ b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt @@ -17,14 +17,27 @@ package io.rsocket.kotlin.connection import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* import io.rsocket.kotlin.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -class LocalConnection( +@OptIn(DangerousInternalIoApi::class) +fun LocalConnection( + name: String, + sender: Channel, + receiver: Channel, + parentJob: Job? = null, +): LocalConnection = LocalConnection(name, sender, receiver, ChunkBuffer.Pool, parentJob) + +class LocalConnection +@DangerousInternalIoApi +internal constructor( private val name: String, private val sender: Channel, private val receiver: Channel, + override val pool: ObjectPool, parentJob: Job? = null, ) : Connection, Cancelable { override val job: Job = Job(parentJob) @@ -38,16 +51,22 @@ class LocalConnection( } } +@OptIn(DangerousInternalIoApi::class) +@Suppress("FunctionName") +public fun SimpleLocalConnection(parentJob: Job? = null): Pair = + SimpleLocalConnection(ChunkBuffer.Pool, parentJob) + /** * Returns pair of client and server local connections */ @Suppress("FunctionName") -fun SimpleLocalConnection(parentJob: Job? = null): Pair { +@DangerousInternalIoApi +internal fun SimpleLocalConnection(pool: ObjectPool, parentJob: Job? = null): Pair { val clientChannel = Channel(Channel.UNLIMITED) val serverChannel = Channel(Channel.UNLIMITED) - val clientConnection = LocalConnection("client", serverChannel, clientChannel, parentJob) - val serverConnection = LocalConnection("server", clientChannel, serverChannel, parentJob) + val clientConnection = LocalConnection("client", serverChannel, clientChannel, pool, parentJob) + val serverConnection = LocalConnection("server", clientChannel, serverChannel, pool, parentJob) return clientConnection to serverConnection } diff --git a/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt b/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt index 493bc3cfb..ee4e39803 100644 --- a/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt +++ b/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt @@ -22,7 +22,7 @@ import io.rsocket.kotlin.test.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -class LocalTransportTest : TransportTest() { +class LocalTransportTest : TransportTest(), TestWithLeakCheck { private val testJob: Job = Job() @@ -32,8 +32,8 @@ class LocalTransportTest : TransportTest() { val clientChannel = Channel(Channel.UNLIMITED) val serverChannel = Channel(Channel.UNLIMITED) - val clientConnection = LocalConnection("client", serverChannel, clientChannel, testJob) - val serverConnection = LocalConnection("server", clientChannel, serverChannel, testJob) + val clientConnection = LocalConnection("client", serverChannel, clientChannel, InUseTrackingPool, testJob) + val serverConnection = LocalConnection("server", clientChannel, serverChannel, InUseTrackingPool, testJob) client = coroutineScope { launch { serverConnection.startServer(SERVER_CONFIG, ACCEPTOR) From 26826c4ca4ed36e26c5e5931b79f7b68e23c70f5 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Mon, 12 Oct 2020 18:50:45 +0300 Subject: [PATCH 2/9] fix test --- .../commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt | 2 +- .../commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt | 4 ++-- .../commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt index 7dda0db9c..53f7c9ad7 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/RequestFrame.kt @@ -22,7 +22,7 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.payload.* -internal data class RequestFrame( +internal class RequestFrame( override val type: FrameType, override val streamId: Int, val follows: Boolean, diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index de939dc06..85d601aea 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -172,8 +172,8 @@ abstract class TransportTest : SuspendTest { val MOCK_DATA: String = "test-data" val MOCK_METADATA: String = "metadata" - val LARGE_DATA by lazy { readLargePayload("words.shakespeare.txt.gz") } - private val payload by lazy { payload(LARGE_DATA, LARGE_DATA) } + val LARGE_DATA = readLargePayload("words.shakespeare.txt.gz") + private val payload = payload(LARGE_DATA, LARGE_DATA) val LARGE_PAYLOAD get() = payload.copy() private fun readLargePayload(name: String): String = name.repeat(1000) diff --git a/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt b/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt index ee4e39803..dbc3e31f6 100644 --- a/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt +++ b/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/LocalTransportTest.kt @@ -43,7 +43,6 @@ class LocalTransportTest : TransportTest(), TestWithLeakCheck { } override suspend fun after() { - super.after() testJob.cancelAndJoin() } From 4b35e8003dcaf9704901b2135c042583403ed666 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Wed, 14 Oct 2020 09:38:16 +0300 Subject: [PATCH 3/9] fix payload initialization --- .../io/rsocket/kotlin/test/TransportTest.kt | 11 ++++------- .../rsocket/kotlin/connection/LocalConnection.kt | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index 85d601aea..38376bfd1 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -170,13 +170,10 @@ abstract class TransportTest : SuspendTest { val CONNECTOR_CONFIG = RSocketConnectorConfiguration(keepAlive = KeepAlive(10.minutes, 100.minutes), loggerFactory = NoopLogger) val SERVER_CONFIG = RSocketServerConfiguration(loggerFactory = NoopLogger) - val MOCK_DATA: String = "test-data" - val MOCK_METADATA: String = "metadata" - val LARGE_DATA = readLargePayload("words.shakespeare.txt.gz") - private val payload = payload(LARGE_DATA, LARGE_DATA) - val LARGE_PAYLOAD get() = payload.copy() - - private fun readLargePayload(name: String): String = name.repeat(1000) + const val MOCK_DATA: String = "test-data" + const val MOCK_METADATA: String = "metadata" + val LARGE_DATA = "large.text.12345".repeat(2000) + val LARGE_PAYLOAD get() = payload(LARGE_DATA, LARGE_DATA) private fun payload(metadataPresent: Int): Payload { val metadata = when (metadataPresent % 5) { diff --git a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt index fce109cc2..7d2744c5a 100644 --- a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt +++ b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/connection/LocalConnection.kt @@ -42,6 +42,15 @@ internal constructor( ) : Connection, Cancelable { override val job: Job = Job(parentJob) + init { + job.invokeOnCompletion { + sender.closeReceivedElements() + receiver.closeReceivedElements() + sender.close(it) + receiver.close(it) + } + } + override suspend fun send(packet: ByteReadPacket) { sender.send(packet) } @@ -70,3 +79,10 @@ internal fun SimpleLocalConnection(pool: ObjectPool, parentJob: Job return clientConnection to serverConnection } + +private fun ReceiveChannel.closeReceivedElements() { + try { + while (true) poll()?.close() ?: break + } catch (e: Throwable) { + } +} From 0c7f1702de7c91b93e74c065d05e922681d66062 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Wed, 14 Oct 2020 10:12:04 +0300 Subject: [PATCH 4/9] try to fix flacky test --- .../kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index 46a0ad7bc..14f7df79b 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -50,17 +50,17 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { @Test fun rSocketNotCanceledOnPresentKeepAliveTicks() = test { - val rSocket = requester() + val rSocket = requester(KeepAlive(200.milliseconds, 1.seconds)) connection.launch { - repeat(50) { - delay(100.milliseconds) + repeat(20) { + delay(200.milliseconds) connection.sendToReceiver(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) } } delay(1.5.seconds) assertTrue(rSocket.isActive) connection.test { - repeat(50) { + repeat(20) { expectItem() } } From d5bbe68fcdb41f9d792c4f609fbeb4e6958e4c2c Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Wed, 14 Oct 2020 11:56:19 +0300 Subject: [PATCH 5/9] try to fix flaky test --- .../kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index 14f7df79b..4feb9452c 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -50,17 +50,17 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { @Test fun rSocketNotCanceledOnPresentKeepAliveTicks() = test { - val rSocket = requester(KeepAlive(200.milliseconds, 1.seconds)) + val rSocket = requester(KeepAlive(100.seconds, 100.seconds)) connection.launch { - repeat(20) { - delay(200.milliseconds) + repeat(50) { + delay(100.milliseconds) connection.sendToReceiver(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) } } delay(1.5.seconds) assertTrue(rSocket.isActive) connection.test { - repeat(20) { + repeat(50) { expectItem() } } From 2a36c030dacb89cc08220134187f164b4b506b80 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Wed, 14 Oct 2020 18:32:13 +0300 Subject: [PATCH 6/9] return back removed tests --- .../kotlin/internal/RSocketRequesterTest.kt | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index 85e0f85c8..2104e6358 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -40,6 +40,13 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { state.start(RSocketRequestHandler { }) } + @Test + fun testInvalidFrameOnStream0() = test { + connection.sendToReceiver(NextPayloadFrame(0, payload("data", "metadata"))) //should be just released + delay(100) + assertTrue(requester.isActive) + } + @Test fun testStreamInitialN() = test { connection.test { @@ -205,35 +212,27 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { val job = Job() val request = flow { Job().join() }.onCompletion { job.complete() } val response = requester.requestChannel(request).launchIn(connection) - delay(100) - response.cancelAndJoin() - delay(200) - assertTrue(job.isCompleted) + connection.test { + expectNoEventsIn(200) + response.cancelAndJoin() + expectNoEventsIn(200) + assertTrue(job.isCompleted) + } } - // @Test + @Test fun testChannelRequestCancellationWithPayload() = test { val job = Job() val request = flow { repeat(100) { emit(Payload.Empty) } }.onCompletion { job.complete() } val response = requester.requestChannel(request).launchIn(connection) - delay(1000) - response.cancelAndJoin() - delay(100) - assertTrue(job.isCompleted) connection.test { - while (true) { - try { - expectItem() - } catch (e: TimeoutCancellationException) { - - } - } -// expectComplete() + expectFrame { assertTrue(it is RequestFrame) } + expectNoEventsIn(200) + response.cancelAndJoin() + expectFrame { assertTrue(it is CancelFrame) } + expectNoEventsIn(200) + assertTrue(job.isCompleted) } -// val sent = connection.sentFrames.size -// assertTrue(sent > 0) -// delay(100) -// assertEquals(sent, connection.sentFrames.size) } @Test //ignored on native because of coroutines bug with channels From 7434fb8923a8f1079b454da3aac3ed6709b75fb0 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Wed, 14 Oct 2020 22:13:48 +0300 Subject: [PATCH 7/9] more tests --- .../io/rsocket/kotlin/internal/RSocketTest.kt | 105 +++++++++++++++++- 1 file changed, 99 insertions(+), 6 deletions(-) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt index aa1a7bd5e..5f5d797bd 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketTest.kt @@ -107,7 +107,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { @Test fun testStream() = test { val requester = start() - requester.requestStream(Payload.Empty).test { + requester.requestStream(payload("HELLO")).test { repeat(10) { expectItem().release() } @@ -115,6 +115,99 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } } + @Test //ignored on native because of bug inside native coroutines + fun testStreamResponderError() = test(ignoreNative = true) { + var p: Payload? = null + val requester = start(RSocketRequestHandler { + requestStream = { + //copy payload, for some specific usage, and don't release original payload + val text = it.copy().use { it.data.readText() } + p = it + //don't use payload + flow { + emit(payload(text + "123")) + emit(payload(text + "456")) + emit(payload(text + "789")) + error("FAIL") + } + } + }) + requester.requestStream(payload("HELLO")).buffer(1).test { + repeat(3) { + expectItem().release() + } + val error = expectError() + assertTrue(error is RSocketError.ApplicationError) + assertEquals("FAIL", error.message) + } + delay(100) //async cancellation + assertEquals(0, p?.data?.remaining) + } + + @Test + fun testStreamRequesterError() = test { + val requester = start(RSocketRequestHandler { + requestStream = { + (0..100).asFlow().map { + payload(it.toString()) + } + } + }) + requester.requestStream(payload("HELLO")) + .buffer(10) + .withIndex() + .onEach { if (it.index == 23) throw error("oops") } + .map { it.value } + .test { + repeat(23) { + expectItem().release() + } + val error = expectError() + assertTrue(error is IllegalStateException) + assertEquals("oops", error.message) + } + } + + @Test + fun testStreamCancel() = test { + val requester = start(RSocketRequestHandler { + requestStream = { + (0..100).asFlow().map { + payload(it.toString()) + } + } + }) + requester.requestStream(payload("HELLO")) + .buffer(15) + .take(3) //canceled after 3 element + .test { + repeat(3) { + expectItem().release() + } + expectComplete() + } + } + + @Test + fun testStreamCancelWithChannel() = test { + val requester = start(RSocketRequestHandler { + requestStream = { + (0..100).asFlow().map { + payload(it.toString()) + } + } + }) + val channel = requester.requestStream(payload("HELLO")) + .buffer(5) + .take(18) //canceled after 18 element + .produceIn(this) + + repeat(18) { + channel.receive().release() + } + assertTrue(channel.receiveOrClosed().isClosed) + } + @Test fun testChannel() = test { val awaiter = Job() @@ -176,7 +269,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { ) @Test - fun requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() = test { + fun requestChannelIsTerminatedAfterBothSidesSentCompletion1() = test { val requesterSendChannel = Channel(Channel.UNLIMITED) val responderSendChannel = Channel(Channel.UNLIMITED) val (requesterReceiveChannel, responderReceiveChannel) = initRequestChannel(requesterSendChannel, responderSendChannel) @@ -189,7 +282,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } @Test - fun requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2() = test { + fun requestChannelTerminatedAfterBothSidesSentCompletion2() = test { val requesterSendChannel = Channel(Channel.UNLIMITED) val responderSendChannel = Channel(Channel.UNLIMITED) val (requesterReceiveChannel, responderReceiveChannel) = initRequestChannel(requesterSendChannel, responderSendChannel) @@ -202,7 +295,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } @Test - fun requestChannelCase_CancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() = test { + fun requestChannelCancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() = test { val requesterSendChannel = Channel(Channel.UNLIMITED) val responderSendChannel = Channel(Channel.UNLIMITED) val (requesterReceiveChannel, responderReceiveChannel) = initRequestChannel(requesterSendChannel, responderSendChannel) @@ -215,7 +308,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } @Test - fun requestChannelCase_CompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() = test { + fun requestChannelCompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() = test { val requesterSendChannel = Channel(Channel.UNLIMITED) val responderSendChannel = Channel(Channel.UNLIMITED) val (requesterReceiveChannel, responderReceiveChannel) = initRequestChannel(requesterSendChannel, responderSendChannel) @@ -228,7 +321,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } @Test - fun requestChannelCase_ensureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() = test { + fun requestChannelEnsureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() = test { val requesterSendChannel = Channel(Channel.UNLIMITED) val responderSendChannel = Channel(Channel.UNLIMITED) val (requesterReceiveChannel, responderReceiveChannel) = initRequestChannel(requesterSendChannel, responderSendChannel) From 332ce36482cda17a80b2a39c342ccfc314a155b0 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Wed, 14 Oct 2020 22:39:33 +0300 Subject: [PATCH 8/9] mark some more functions as internal remove suppress for some functions --- README.md | 8 ++-- .../src/clientMain/kotlin/PayloadWithRoute.kt | 1 - .../kotlin/io/rsocket/kotlin/frame/Frame.kt | 31 +++++++------- .../io/rsocket/kotlin/frame/FrameType.kt | 2 +- .../kotlin/io/rsocket/kotlin/frame/io/Dump.kt | 10 +++-- .../io/rsocket/kotlin/frame/io/Flags.kt | 4 +- .../io/rsocket/kotlin/frame/io/Version.kt | 9 ++-- .../io/rsocket/kotlin/frame/io/packet.kt | 1 - .../kotlin/internal/CloseOperations.kt | 41 +++++++++++++++++++ .../kotlin/internal/RSocketRequester.kt | 9 ---- .../rsocket/kotlin/internal/RSocketState.kt | 12 ------ .../io/rsocket/kotlin/payload/Payload.kt | 1 - .../rsocket/kotlin/payload/PayloadBuilder.kt | 1 - 13 files changed, 75 insertions(+), 55 deletions(-) create mode 100644 rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt diff --git a/README.md b/README.md index cccdddf43..8722f09a0 100644 --- a/README.md +++ b/README.md @@ -39,10 +39,10 @@ RSocket interface contains 5 methods: `suspend fun metadataPush(metadata: ByteReadPacket)` ## Using in your projects -The `master` branch is now dedicated to development of multiplatform rsocket-kotlin. -For now only snapshots are available via [oss.jfrog.org](oss.jfrog.org) (OJO). +The `master` branch is now dedicated to development of multiplatform rsocket-kotlin. For now only snapshots are available +via [oss.jfrog.org](oss.jfrog.org) (OJO). -Make sure, that you use Kotlin 1.4. +Make sure, that you use Kotlin 1.4.X. ### Gradle: @@ -225,7 +225,7 @@ val bufferedStream: Flow = stream.buffer(10) //here buffer is 10, if `b bufferedStream.collect { payload: Payload -> println(payload.data.readText()) } -``` +``` ## Bugs and Feedback diff --git a/examples/multiplatform-chat/src/clientMain/kotlin/PayloadWithRoute.kt b/examples/multiplatform-chat/src/clientMain/kotlin/PayloadWithRoute.kt index 2f456ebeb..35dd6d19f 100644 --- a/examples/multiplatform-chat/src/clientMain/kotlin/PayloadWithRoute.kt +++ b/examples/multiplatform-chat/src/clientMain/kotlin/PayloadWithRoute.kt @@ -17,7 +17,6 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* -@Suppress("FunctionName") fun Payload(route: String, packet: ByteReadPacket): Payload = Payload { data(packet) metadata(route) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt index 7ad311df8..baeada3c0 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/Frame.kt @@ -17,12 +17,13 @@ package io.rsocket.kotlin.frame import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* import io.rsocket.kotlin.frame.io.* private const val FlagsMask: Int = 1023 private const val FrameTypeShift: Int = 10 -abstract class Frame(open val type: FrameType) : Closeable { +abstract class Frame internal constructor(open val type: FrameType) : Closeable { abstract val streamId: Int abstract val flags: Int @@ -32,6 +33,7 @@ abstract class Frame(open val type: FrameType) : Closeable { protected abstract fun StringBuilder.appendFlags() protected abstract fun StringBuilder.appendSelf() + @DangerousInternalIoApi fun toPacket(pool: BufferPool): ByteReadPacket { check(type.canHaveMetadata || !(flags check Flags.Metadata)) { "bad value for metadata flag" } return buildPacket(pool) { @@ -41,7 +43,7 @@ abstract class Frame(open val type: FrameType) : Closeable { } } - fun dump(length: Long): String = buildString { + internal fun dump(length: Long): String = buildString { append("\n").append(type).append(" frame -> Stream Id: ").append(streamId).append(" Length: ").append(length) append("\nFlags: 0b").append(flags.toBinaryString()).append(" (").apply { appendFlags() }.append(")") appendSelf() @@ -57,30 +59,31 @@ abstract class Frame(open val type: FrameType) : Closeable { } } +@DangerousInternalIoApi fun ByteReadPacket.readFrame(pool: BufferPool): Frame = use { val streamId = readInt() val typeAndFlags = readShort().toInt() and 0xFFFF val flags = typeAndFlags and FlagsMask when (val type = FrameType(typeAndFlags shr FrameTypeShift)) { //stream id = 0 - FrameType.Setup -> readSetup(pool, flags) - FrameType.Resume -> readResume(pool) - FrameType.ResumeOk -> readResumeOk() + FrameType.Setup -> readSetup(pool, flags) + FrameType.Resume -> readResume(pool) + FrameType.ResumeOk -> readResumeOk() FrameType.MetadataPush -> readMetadataPush(pool) - FrameType.Lease -> readLease(pool, flags) - FrameType.KeepAlive -> readKeepAlive(pool, flags) + FrameType.Lease -> readLease(pool, flags) + FrameType.KeepAlive -> readKeepAlive(pool, flags) //stream id != 0 - FrameType.Cancel -> CancelFrame(streamId) - FrameType.Error -> readError(streamId) - FrameType.RequestN -> readRequestN(streamId) - FrameType.Extension -> readExtension(pool, streamId, flags) + FrameType.Cancel -> CancelFrame(streamId) + FrameType.Error -> readError(streamId) + FrameType.RequestN -> readRequestN(streamId) + FrameType.Extension -> readExtension(pool, streamId, flags) FrameType.Payload, FrameType.RequestFnF, FrameType.RequestResponse, - -> readRequest(pool, type, streamId, flags, withInitial = false) + -> readRequest(pool, type, streamId, flags, withInitial = false) FrameType.RequestStream, FrameType.RequestChannel, - -> readRequest(pool, type, streamId, flags, withInitial = true) - FrameType.Reserved -> error("Reserved") + -> readRequest(pool, type, streamId, flags, withInitial = true) + FrameType.Reserved -> error("Reserved") } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameType.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameType.kt index a0a8b37c7..8ee6c72a9 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameType.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameType.kt @@ -18,7 +18,7 @@ package io.rsocket.kotlin.frame import io.rsocket.kotlin.frame.io.* -enum class FrameType(val encodedType: Int, flags: Int = Flags.Empty) { +internal enum class FrameType(val encodedType: Int, flags: Int = Flags.Empty) { Reserved(0x00), //CONNECTION diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Dump.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Dump.kt index 44bb56699..ab0ca082c 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Dump.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Dump.kt @@ -23,8 +23,10 @@ import kotlin.native.concurrent.* @SharedImmutable private val digits = "0123456789abcdef".toCharArray() +@SharedImmutable private const val divider = "+--------+-------------------------------------------------+----------------+" +@SharedImmutable private const val header = """ +-------------------------------------------------+ | 0 1 2 3 4 5 6 7 8 9 a b c d e f | @@ -36,7 +38,7 @@ private const val header = """ //|00000000| 74 65 73 74 2d 64 61 74 61 20 74 65 73 74 2d 64 |test-data test-d| //|00000001| 61 74 61 20 74 65 73 74 2d 64 61 74 61 |ata test-data | //+--------+-------------------------------------------------+----------------+ -fun StringBuilder.appendPacket(packet: ByteReadPacket) { +internal fun StringBuilder.appendPacket(packet: ByteReadPacket) { var rowIndex = 0 var byteIndex = 0 @@ -90,7 +92,7 @@ fun StringBuilder.appendPacket(packet: ByteReadPacket) { append(divider) } -fun StringBuilder.appendPacket(tag: String, packet: ByteReadPacket) { +internal fun StringBuilder.appendPacket(tag: String, packet: ByteReadPacket) { append("\n").append(tag) if (packet.remaining > 0) { append("(length=").append(packet.remaining).append("):") @@ -100,12 +102,12 @@ fun StringBuilder.appendPacket(tag: String, packet: ByteReadPacket) { } } -fun StringBuilder.appendPayload(payload: Payload) { +internal fun StringBuilder.appendPayload(payload: Payload) { if (payload.metadata != null) appendPacket("Metadata", payload.metadata) appendPacket("Data", payload.data) } -fun Int.toBinaryString(): String { +internal fun Int.toBinaryString(): String { val string = toString(2) return "0".repeat(9 - string.length) + string } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Flags.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Flags.kt index 35e200c09..00932e1fc 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Flags.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Flags.kt @@ -16,7 +16,7 @@ package io.rsocket.kotlin.frame.io -object Flags { +internal object Flags { const val Ignore = 512 const val Metadata = 256 const val Follows = 128 @@ -24,4 +24,4 @@ object Flags { const val Next = 32 } -infix fun Int.check(flag: Int): Boolean = this and flag == flag +internal infix fun Int.check(flag: Int): Boolean = this and flag == flag diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Version.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Version.kt index 8ddeb418a..d5fdfafc1 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Version.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/Version.kt @@ -18,11 +18,10 @@ package io.rsocket.kotlin.frame.io import io.ktor.utils.io.core.* -@Suppress("FunctionName") -fun Version(major: Int, minor: Int): Version = Version((major shl 16) or (minor and 0xFFFF)) +internal fun Version(major: Int, minor: Int): Version = Version((major shl 16) or (minor and 0xFFFF)) @Suppress("EXPERIMENTAL_FEATURE_WARNING") -inline class Version(val value: Int) { +internal inline class Version(val value: Int) { val major: Int get() = value shr 16 and 0xFFFF val minor: Int get() = value and 0xFFFF override fun toString(): String = "$major.$minor" @@ -32,8 +31,8 @@ inline class Version(val value: Int) { } } -fun ByteReadPacket.readVersion(): Version = Version(readInt()) +internal fun ByteReadPacket.readVersion(): Version = Version(readInt()) -fun BytePacketBuilder.writeVersion(version: Version) { +internal fun BytePacketBuilder.writeVersion(version: Version) { writeInt(version.value) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt index 74001c6c0..5d9bb0c57 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/io/packet.kt @@ -23,7 +23,6 @@ import io.ktor.utils.io.pool.* @OptIn(DangerousInternalIoApi::class) internal typealias BufferPool = ObjectPool -//TODO internal inline fun buildPacket(pool: BufferPool, block: BytePacketBuilder.() -> Unit): ByteReadPacket { val builder = BytePacketBuilder(0, pool) try { diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt new file mode 100644 index 000000000..791eaac5e --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import io.ktor.utils.io.core.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +internal inline fun Closeable.closeOnError(block: () -> T): T { + try { + return block() + } catch (e: Throwable) { + close() + throw e + } +} + +internal fun ReceiveChannel<*>.cancelConsumed(cause: Throwable?) { + cancel(cause?.let { it as? CancellationException ?: CancellationException("Channel was consumed, consumer had failed", it) }) +} + +internal fun ReceiveChannel.closeReceivedElements() { + try { + while (true) poll()?.close() ?: break + } catch (e: Throwable) { + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt index d8b9cbacd..3df9046bc 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt @@ -69,12 +69,3 @@ internal class RSocketRequester( } } - -internal inline fun Closeable.closeOnError(block: () -> T): T { - try { - return block() - } catch (e: Throwable) { - close() - throw e - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt index 86617de80..79fc3d68f 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt @@ -16,7 +16,6 @@ package io.rsocket.kotlin.internal -import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.* @@ -174,14 +173,3 @@ internal class RSocketState( return job } } - -internal fun ReceiveChannel<*>.cancelConsumed(cause: Throwable?) { - cancel(cause?.let { it as? CancellationException ?: CancellationException("Channel was consumed, consumer had failed", it) }) -} - -internal fun ReceiveChannel.closeReceivedElements() { - try { - while (true) poll()?.close() ?: break - } catch (e: Throwable) { - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt index f912b4091..8b263233f 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/Payload.kt @@ -44,7 +44,6 @@ fun Payload(data: String, metadata: String? = null): Payload = Payload( metadata = metadata?.let { buildPacket { writeText(it) } } ) -@Suppress("FunctionName") fun Payload(data: ByteArray, metadata: ByteArray? = null): Payload = Payload( data = buildPacket { writeFully(data) }, metadata = metadata?.let { buildPacket { writeFully(it) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/PayloadBuilder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/PayloadBuilder.kt index f4678a550..8e5e3e5c0 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/PayloadBuilder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/payload/PayloadBuilder.kt @@ -46,7 +46,6 @@ internal constructor() { } } -@Suppress("FunctionName") inline fun Payload(config: PayloadBuilder.() -> Unit): Payload { val builder = PayloadBuilder() try { From bf2278ca686cb38e49be7141cc049bda5885edc4 Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Sun, 18 Oct 2020 20:30:39 +0300 Subject: [PATCH 9/9] update readme link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8722f09a0..1e62ad9a3 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ RSocket interface contains 5 methods: ## Using in your projects The `master` branch is now dedicated to development of multiplatform rsocket-kotlin. For now only snapshots are available -via [oss.jfrog.org](oss.jfrog.org) (OJO). +via [oss.jfrog.org](https://oss.jfrog.org/artifactory/oss-snapshot-local/io/rsocket/kotlin/) (OJO). Make sure, that you use Kotlin 1.4.X.