From 29e226fdccfe073017e23abb56a75aa46298b9e1 Mon Sep 17 00:00:00 2001 From: Luca Vitucci Date: Sun, 10 Dec 2023 01:45:14 +0100 Subject: [PATCH] Synchronise the Amule Connection to avoid multiple requests to intersect --- build.gradle.kts | 1 + src/main/kotlin/jamule/AmuleConnection.kt | 34 +++++--- src/test/kotlin/jamule/AmuleConnectionTest.kt | 83 +++++++++++++++++++ 3 files changed, 105 insertions(+), 13 deletions(-) create mode 100644 src/test/kotlin/jamule/AmuleConnectionTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index bdb4244..32848ba 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -20,6 +20,7 @@ dependencies { testImplementation("org.jetbrains.kotlin:kotlin-test-junit5") testImplementation("org.junit.jupiter:junit-jupiter-engine:5.9.3") testImplementation("ch.qos.logback:logback-classic:1.4.11") + testImplementation("io.mockk:mockk:1.13.8") testImplementation("io.kotest:kotest-runner-junit5:5.7.2") testImplementation("io.kotest:kotest-runner-junit5-jvm:5.7.2") testImplementation("io.kotest.extensions:kotest-extensions-testcontainers:2.0.2") diff --git a/src/main/kotlin/jamule/AmuleConnection.kt b/src/main/kotlin/jamule/AmuleConnection.kt index b1d3fc7..282d1d8 100644 --- a/src/main/kotlin/jamule/AmuleConnection.kt +++ b/src/main/kotlin/jamule/AmuleConnection.kt @@ -16,14 +16,20 @@ import java.io.IOException import java.net.Socket internal class AmuleConnection( - private val host: String, - private val port: Int, - private val timeout: Int, + private var socketBuilder: () -> Socket, private val password: String, private val logger: Logger ) { - private var socket = Socket(host, port).apply { soTimeout = timeout } private var connected = false + private var socket = socketBuilder() + + internal constructor( + host: String, + port: Int, + timeout: Int, + password: String, + logger: Logger + ) : this({ Socket(host, port).apply { soTimeout = timeout } }, password, logger) @OptIn(ExperimentalUnsignedTypes::class) private val tagParser = TagParser(logger) @@ -42,7 +48,7 @@ internal class AmuleConnection( logger.info("Reconnecting...") connected = false runCatching { socket.close() } - socket = Socket(host, port).apply { soTimeout = timeout } + socket = socketBuilder() authenticate() } } @@ -59,14 +65,16 @@ internal class AmuleConnection( @OptIn(ExperimentalUnsignedTypes::class) fun sendRequestNoAuth(request: Request): Response { - val outputStream = socket.getOutputStream() - val inputStream = socket.getInputStream().buffered() - val packet = request.packet() - packetWriter.write(packet, outputStream) - val responsePacket = packetParser.parse(inputStream) - return ResponseParser.parse(responsePacket).also { - if (it is ErrorResponse) { - throw ServerException(it.serverMessage) + synchronized(socket) { + val outputStream = socket.getOutputStream() + val inputStream = socket.getInputStream().buffered() + val packet = request.packet() + packetWriter.write(packet, outputStream) + val responsePacket = packetParser.parse(inputStream) + return ResponseParser.parse(responsePacket).also { + if (it is ErrorResponse) { + throw ServerException(it.serverMessage) + } } } } diff --git a/src/test/kotlin/jamule/AmuleConnectionTest.kt b/src/test/kotlin/jamule/AmuleConnectionTest.kt new file mode 100644 index 0000000..000edce --- /dev/null +++ b/src/test/kotlin/jamule/AmuleConnectionTest.kt @@ -0,0 +1,83 @@ +package jamule + +import io.kotest.core.spec.style.FunSpec +import io.kotest.matchers.shouldBe +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import jamule.request.StatsRequest +import org.slf4j.LoggerFactory +import java.io.ByteArrayInputStream +import java.io.OutputStream +import java.net.Socket +import java.util.concurrent.CountDownLatch + +@OptIn(ExperimentalStdlibApi::class) +class AmuleConnectionTest : FunSpec({ + + val socket = mockk() + val logger = LoggerFactory.getLogger(this::class.java) + val outputStream = OutputStream.nullOutputStream() + every { socket.getOutputStream() } returns outputStream + every { socket.close() } returns Unit + val authSaltResponse = ByteArrayInputStream("000000220000000d4f0116050855099a4aea510c43".hexToByteArray()) + val authOkResponse = + ByteArrayInputStream("000000220000001d0401e0a8960616322e332e31204164756e616e7a4120323031322e3100".hexToByteArray()) + val statusResponse = ByteArrayInputStream( + ("000000220000008c0c10d08003021664d082020100d484020100d4860302" + + "1664d488020100d48a020100d084020100d086020100d09002010" + + "0d08c020100d092040400017cbbd09402010ad096040402e2740f" + + "d09803020438d0b60201000b023f03e0a881081f01e0a88206124" + + "16b74656f6e20536572766572204e6f3200b07de76247b50c0404" + + "1d4e48541404041d4e485419") + .hexToByteArray() + ) + + test("single request works ok") { + val amule = AmuleConnection({ socket }, "password", logger) + every { socket.getInputStream() } returnsMany listOf( + authSaltResponse, + authOkResponse, + statusResponse + ) + amule.sendRequest(StatsRequest()) + // Called 3 times: 1 for salt, 1 for auth, 1 for stats + verify(exactly = 3) { socket.getOutputStream() } + } + + test("multiple parallel requests are synchronised") { + val amule = AmuleConnection({ socket }, "password", logger) + val firstRequestArrivedLatch = CountDownLatch(1) + val firstRequestLatch = CountDownLatch(1) + val secondRequestLatch = CountDownLatch(1) + var requestCount = 0 + every { socket.getInputStream() } answers { + when (requestCount++) { + 0 -> authSaltResponse + 1 -> authOkResponse + 2 -> { + firstRequestArrivedLatch.countDown() + firstRequestLatch.await() + statusResponse + } + + 3 -> { + secondRequestLatch.await() + statusResponse + } + + else -> throw IllegalStateException("Unexpected request count: $requestCount") + }.also { requestCount++ } + } + // Send two requests from two separate threads + Thread { amule.sendRequest(StatsRequest()) }.start() + Thread { amule.sendRequest(StatsRequest()) }.start() + // Wait for the first request to arrive + firstRequestArrivedLatch.await() + Thread.sleep(50) // Allow for the second request to arrive if it's not synchronised + requestCount shouldBe 3 + firstRequestLatch.countDown() + secondRequestLatch.countDown() + } + +})