Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5")
testImplementation("org.junit.jupiter:junit-jupiter-engine:5.9.3")
testImplementation("ch.qos.logback:logback-classic:1.4.11")
testImplementation("io.mockk:mockk:1.13.8")
testImplementation("io.kotest:kotest-runner-junit5:5.7.2")
testImplementation("io.kotest:kotest-runner-junit5-jvm:5.7.2")
testImplementation("io.kotest.extensions:kotest-extensions-testcontainers:2.0.2")
Expand Down
34 changes: 21 additions & 13 deletions src/main/kotlin/jamule/AmuleConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@ import java.io.IOException
import java.net.Socket

internal class AmuleConnection(
private val host: String,
private val port: Int,
private val timeout: Int,
private var socketBuilder: () -> Socket,
private val password: String,
private val logger: Logger
) {
private var socket = Socket(host, port).apply { soTimeout = timeout }
private var connected = false
private var socket = socketBuilder()

internal constructor(
host: String,
port: Int,
timeout: Int,
password: String,
logger: Logger
) : this({ Socket(host, port).apply { soTimeout = timeout } }, password, logger)

@OptIn(ExperimentalUnsignedTypes::class)
private val tagParser = TagParser(logger)
Expand All @@ -42,7 +48,7 @@ internal class AmuleConnection(
logger.info("Reconnecting...")
connected = false
runCatching { socket.close() }
socket = Socket(host, port).apply { soTimeout = timeout }
socket = socketBuilder()
authenticate()
}
}
Expand All @@ -59,14 +65,16 @@ internal class AmuleConnection(

@OptIn(ExperimentalUnsignedTypes::class)
fun sendRequestNoAuth(request: Request): Response {
val outputStream = socket.getOutputStream()
val inputStream = socket.getInputStream().buffered()
val packet = request.packet()
packetWriter.write(packet, outputStream)
val responsePacket = packetParser.parse(inputStream)
return ResponseParser.parse(responsePacket).also {
if (it is ErrorResponse) {
throw ServerException(it.serverMessage)
synchronized(socket) {
val outputStream = socket.getOutputStream()
val inputStream = socket.getInputStream().buffered()
val packet = request.packet()
packetWriter.write(packet, outputStream)
val responsePacket = packetParser.parse(inputStream)
return ResponseParser.parse(responsePacket).also {
if (it is ErrorResponse) {
throw ServerException(it.serverMessage)
}
}
}
}
Expand Down
83 changes: 83 additions & 0 deletions src/test/kotlin/jamule/AmuleConnectionTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package jamule

import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import jamule.request.StatsRequest
import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.OutputStream
import java.net.Socket
import java.util.concurrent.CountDownLatch

@OptIn(ExperimentalStdlibApi::class)
class AmuleConnectionTest : FunSpec({

val socket = mockk<Socket>()
val logger = LoggerFactory.getLogger(this::class.java)
val outputStream = OutputStream.nullOutputStream()
every { socket.getOutputStream() } returns outputStream
every { socket.close() } returns Unit
val authSaltResponse = ByteArrayInputStream("000000220000000d4f0116050855099a4aea510c43".hexToByteArray())
val authOkResponse =
ByteArrayInputStream("000000220000001d0401e0a8960616322e332e31204164756e616e7a4120323031322e3100".hexToByteArray())
val statusResponse = ByteArrayInputStream(
("000000220000008c0c10d08003021664d082020100d484020100d4860302" +
"1664d488020100d48a020100d084020100d086020100d09002010" +
"0d08c020100d092040400017cbbd09402010ad096040402e2740f" +
"d09803020438d0b60201000b023f03e0a881081f01e0a88206124" +
"16b74656f6e20536572766572204e6f3200b07de76247b50c0404" +
"1d4e48541404041d4e485419")
.hexToByteArray()
)

test("single request works ok") {
val amule = AmuleConnection({ socket }, "password", logger)
every { socket.getInputStream() } returnsMany listOf(
authSaltResponse,
authOkResponse,
statusResponse
)
amule.sendRequest(StatsRequest())
// Called 3 times: 1 for salt, 1 for auth, 1 for stats
verify(exactly = 3) { socket.getOutputStream() }
}

test("multiple parallel requests are synchronised") {
val amule = AmuleConnection({ socket }, "password", logger)
val firstRequestArrivedLatch = CountDownLatch(1)
val firstRequestLatch = CountDownLatch(1)
val secondRequestLatch = CountDownLatch(1)
var requestCount = 0
every { socket.getInputStream() } answers {
when (requestCount++) {
0 -> authSaltResponse
1 -> authOkResponse
2 -> {
firstRequestArrivedLatch.countDown()
firstRequestLatch.await()
statusResponse
}

3 -> {
secondRequestLatch.await()
statusResponse
}

else -> throw IllegalStateException("Unexpected request count: $requestCount")
}.also { requestCount++ }
}
// Send two requests from two separate threads
Thread { amule.sendRequest(StatsRequest()) }.start()
Thread { amule.sendRequest(StatsRequest()) }.start()
// Wait for the first request to arrive
firstRequestArrivedLatch.await()
Thread.sleep(50) // Allow for the second request to arrive if it's not synchronised
requestCount shouldBe 3
firstRequestLatch.countDown()
secondRequestLatch.countDown()
}

})