Skip to content

Commit

Permalink
migrate nodejs tcp transport to new API
Browse files Browse the repository at this point in the history
  • Loading branch information
whyoleg committed Apr 12, 2024
1 parent c11cdf7 commit ad2f644
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ internal fun ByteReadPacket.withLength(): ByteReadPacket = buildPacket {
writePacket(this@withLength)
}

internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPacket) -> Unit) {
private var expectedFrameLength = 0 //TODO atomic for native
internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPacket) -> Unit) : Closeable {
private var closed = false
private var expectedFrameLength = 0
private val packetBuilder: BytePacketBuilder = BytePacketBuilder()

override fun close() {
packetBuilder.close()
closed = true
}

inline fun write(write: BytePacketBuilder.() -> Unit) {
if (closed) return
packetBuilder.write()
loop()
}
Expand All @@ -39,6 +47,7 @@ internal class FrameWithLengthAssembler(private val onFrame: (frame: ByteReadPac
expectedFrameLength = it.readInt24()
if (it.remaining >= expectedFrameLength) build(it) // if has length and frame
}

packetBuilder.size < expectedFrameLength -> return // not enough bytes to read frame
else -> withTemp { build(it) } // enough bytes to read frame
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.nodejs.tcp

import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import io.rsocket.kotlin.transport.nodejs.tcp.internal.*
import kotlinx.coroutines.*
import kotlin.coroutines.*

public sealed interface NodejsTcpClientTransport : RSocketTransport {
public fun target(host: String, port: Int): RSocketClientTarget

public companion object Factory :
RSocketTransportFactory<NodejsTcpClientTransport, NodejsTcpClientTransportBuilder>(::NodejsTcpClientTransportBuilderImpl)
}

public sealed interface NodejsTcpClientTransportBuilder : RSocketTransportBuilder<NodejsTcpClientTransport> {
public fun dispatcher(context: CoroutineContext)
public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext)
}

private class NodejsTcpClientTransportBuilderImpl : NodejsTcpClientTransportBuilder {
private var dispatcher: CoroutineContext = Dispatchers.Default

override fun dispatcher(context: CoroutineContext) {
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
this.dispatcher = context
}

@RSocketTransportApi
override fun buildTransport(context: CoroutineContext): NodejsTcpClientTransport = NodejsTcpClientTransportImpl(
coroutineContext = context.supervisorContext() + dispatcher,
)
}

private class NodejsTcpClientTransportImpl(
override val coroutineContext: CoroutineContext,
) : NodejsTcpClientTransport {
override fun target(host: String, port: Int): RSocketClientTarget = NodejsTcpClientTargetImpl(
coroutineContext = coroutineContext.supervisorContext(),
host = host,
port = port
)
}

private class NodejsTcpClientTargetImpl(
override val coroutineContext: CoroutineContext,
private val host: String,
private val port: Int,
) : RSocketClientTarget {
@RSocketTransportApi
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
val socket = connect(port, host)
handler.handleNodejsTcpConnection(socket)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.nodejs.tcp

import io.ktor.utils.io.core.*
import io.ktor.utils.io.js.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import io.rsocket.kotlin.transport.internal.*
import io.rsocket.kotlin.transport.nodejs.tcp.internal.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import org.khronos.webgl.*

@RSocketTransportApi
internal suspend fun RSocketConnectionHandler.handleNodejsTcpConnection(socket: Socket): Unit = coroutineScope {
val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED)
val inbound = channelForCloseable<ByteReadPacket>(Channel.UNLIMITED)

val closed = CompletableDeferred<Unit>()
val frameAssembler = FrameWithLengthAssembler { inbound.trySend(it) }
socket.on(
onData = { frameAssembler.write { writeFully(it.buffer) } },
onError = { closed.completeExceptionally(it) },
onClose = {
frameAssembler.close()
if (!it) closed.complete(Unit)
}
)

val writerJob = launch {
while (true) socket.writeFrame(outboundQueue.dequeueFrame() ?: break)
}.onCompletion { outboundQueue.cancel() }

try {
handleConnection(NodejsTcpConnection(outboundQueue, inbound))
} finally {
inbound.cancel()
outboundQueue.close() // will cause `writerJob` completion
// even if it was cancelled, we still need to close socket and await it closure
withContext(NonCancellable) {
writerJob.join()
// close socket
socket.destroy()
closed.join()
}
}
}

@RSocketTransportApi
private class NodejsTcpConnection(
private val outboundQueue: PrioritizationFrameQueue,
private val inbound: ReceiveChannel<ByteReadPacket>,
) : RSocketSequentialConnection {
override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend
override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) {
return outboundQueue.enqueueFrame(streamId, frame)
}

override suspend fun receiveFrame(): ByteReadPacket? {
return inbound.receiveCatching().getOrNull()
}
}

private fun Socket.writeFrame(frame: ByteReadPacket) {
val packet = buildPacket {
writeInt24(frame.remaining.toInt())
writePacket(frame)
}
write(Uint8Array(packet.readArrayBuffer()))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.nodejs.tcp

import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import io.rsocket.kotlin.transport.nodejs.tcp.internal.*
import kotlinx.coroutines.*
import kotlin.coroutines.*

public sealed interface NodejsTcpServerInstance : RSocketServerInstance {
public val host: String
public val port: Int
}

public sealed interface NodejsTcpServerTransport : RSocketTransport {
public fun target(host: String, port: Int): RSocketServerTarget<NodejsTcpServerInstance>

public companion object Factory :
RSocketTransportFactory<NodejsTcpServerTransport, NodejsTcpServerTransportBuilder>({ NodejsTcpServerTransportBuilderImpl })
}

public sealed interface NodejsTcpServerTransportBuilder : RSocketTransportBuilder<NodejsTcpServerTransport> {
public fun dispatcher(context: CoroutineContext)
public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext)
}

private object NodejsTcpServerTransportBuilderImpl : NodejsTcpServerTransportBuilder {
private var dispatcher: CoroutineContext = Dispatchers.Default

override fun dispatcher(context: CoroutineContext) {
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
this.dispatcher = context
}

@RSocketTransportApi
override fun buildTransport(context: CoroutineContext): NodejsTcpServerTransport = NodejsTcpServerTransportImpl(
coroutineContext = context.supervisorContext() + dispatcher,
)
}

private class NodejsTcpServerTransportImpl(
override val coroutineContext: CoroutineContext,
) : NodejsTcpServerTransport {
override fun target(host: String, port: Int): RSocketServerTarget<NodejsTcpServerInstance> = NodejsTcpServerTargetImpl(
coroutineContext = coroutineContext.supervisorContext(),
host = host,
port = port
)
}

private class NodejsTcpServerTargetImpl(
override val coroutineContext: CoroutineContext,
private val host: String,
private val port: Int,
) : RSocketServerTarget<NodejsTcpServerInstance> {

@RSocketTransportApi
override suspend fun startServer(handler: RSocketConnectionHandler): NodejsTcpServerInstance {
currentCoroutineContext().ensureActive()
coroutineContext.ensureActive()

val serverJob = launch {
val handlerScope = CoroutineScope(coroutineContext.supervisorContext())
val server = createServer(port, host, {
coroutineContext.job.cancel("Server closed")
}) {
handlerScope.launch { handler.handleNodejsTcpConnection(it) }
}
try {
awaitCancellation()
} finally {
suspendCoroutine { cont -> server.close { cont.resume(Unit) } }
}
}

return NodejsTcpServerInstanceImpl(
coroutineContext = coroutineContext + serverJob,
host = host,
port = port
)
}
}

@RSocketTransportApi
private class NodejsTcpServerInstanceImpl(
override val coroutineContext: CoroutineContext,
override val host: String,
override val port: Int,
) : NodejsTcpServerInstance
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ class TcpTransportTest : TransportTest() {
server.close()
}
}

class NodejsTcpTransportTest : TransportTest() {
override suspend fun before() {
val port = PortProvider.next()
startServer(NodejsTcpServerTransport(testContext).target("127.0.0.1", port))
client = connectClient(NodejsTcpClientTransport(testContext).target("127.0.0.1", port))
}
}

0 comments on commit ad2f644

Please sign in to comment.