Skip to content

Commit

Permalink
Merge pull request #1338 from pedroSG94/feature/srt-passphrase
Browse files Browse the repository at this point in the history
Feature/srt passphrase
  • Loading branch information
pedroSG94 committed Nov 14, 2023
2 parents 85b3982 + 67a5e53 commit 8f9f5ac
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ dependencies {
- [X] Get upload bandwidth used.
- [X] H264, H265 and AAC support.
- [X] Resend lost packets
- [X] Encrypt (AES128, AES192 and AES256)
- [ ] SRT auth.
- [ ] Encrypt

https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.pedro.library.util.streamclient

import com.pedro.srt.srt.SrtClient
import com.pedro.srt.srt.packets.control.handshake.EncryptionType

/**
* Created by pedro on 12/10/23.
Expand All @@ -10,6 +11,13 @@ class SrtStreamClient(
streamClientListener: StreamClientListener?
): StreamBaseClient(streamClientListener) {

/**
* Set passphrase for encrypt. Use empty value to disable it.
*/
fun setPassphrase(passphrase: String, type: EncryptionType) {
srtClient.setPassphrase(passphrase, type)
}

override fun setAuthorization(user: String?, password: String?) {
srtClient.setAuthorization(user, password)
}
Expand Down
42 changes: 31 additions & 11 deletions srt/src/main/java/com/pedro/srt/srt/CommandsManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ import com.pedro.srt.mpeg2ts.MpegTsPacket
import com.pedro.srt.srt.packets.DataPacket
import com.pedro.srt.srt.packets.SrtPacket
import com.pedro.srt.srt.packets.control.Ack2
import com.pedro.srt.srt.packets.control.KeepAlive
import com.pedro.srt.srt.packets.control.Shutdown
import com.pedro.srt.srt.packets.control.handshake.EncryptionType
import com.pedro.srt.srt.packets.control.handshake.Handshake
import com.pedro.srt.utils.EncryptInfo
import com.pedro.srt.srt.packets.data.KeyBasedEncryption
import com.pedro.srt.utils.Constants
import com.pedro.srt.utils.EncryptionUtil
import com.pedro.srt.utils.SrtSocket
import com.pedro.srt.utils.TimeUtils
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.io.IOException
import java.net.NetworkInterface
import kotlin.random.Random

/**
Expand All @@ -51,6 +55,19 @@ class CommandsManager {
var host = ""
//Avoid write a packet in middle of other.
private val writeSync = Mutex(locked = false)
private var encryptor: EncryptionUtil? = null

fun setPassphrase(passphrase: String, type: EncryptionType) {
encryptor = if (passphrase.isEmpty() || type == EncryptionType.NONE) null else EncryptionUtil(type, passphrase)
}

fun getEncryptInfo(): EncryptInfo? {
return encryptor?.getEncryptInfo()
}

fun getEncryptType(): EncryptionType {
return encryptor?.type ?: EncryptionType.NONE
}

fun loadStartTs() {
startTS = TimeUtils.getCurrentTimeMicro()
Expand Down Expand Up @@ -88,13 +105,15 @@ class CommandsManager {
writeSync.withLock {
if (sequenceNumber.toUInt() > 0x7FFFFFFFu) sequenceNumber = 0
val dataPacket = DataPacket(
sequenceNumber = sequenceNumber++,
encryption = if (encryptor != null) KeyBasedEncryption.PAIR_KEY else KeyBasedEncryption.NONE,
sequenceNumber = sequenceNumber,
packetPosition = packet.packetPosition,
messageNumber = messageNumber++,
payload = packet.buffer,
payload = encryptor?.encrypt(packet.buffer, sequenceNumber) ?: packet.buffer,
ts = getTs(),
socketId = socketId
)
sequenceNumber++
packetHandlingQueue.add(dataPacket)
dataPacket.write()
socket?.write(dataPacket)
Expand Down Expand Up @@ -139,6 +158,15 @@ class CommandsManager {
}
}

@Throws(IOException::class)
suspend fun writeKeepAlive(socket: SrtSocket?) {
writeSync.withLock {
val keepAlive = KeepAlive()
keepAlive.write(getTs(), socketId)
socket?.write(keepAlive)
}
}

fun reset() {
sequenceNumber = generateInitialSequence()
messageNumber = 1
Expand All @@ -152,12 +180,4 @@ class CommandsManager {
private fun generateInitialSequence(): Int {
return Random.nextInt(0, Int.MAX_VALUE)
}

private fun List<NetworkInterface>.findAddress(): List<String?> = this.asSequence()
.map { addresses -> addresses.inetAddresses.asSequence() }
.flatten()
.filter { address -> !address.isLoopbackAddress }
.map { it.hostAddress }
.filter { address -> address?.contains(":") == false }
.toList()
}
19 changes: 17 additions & 2 deletions srt/src/main/java/com/pedro/srt/srt/SrtClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.pedro.srt.srt.packets.control.KeepAlive
import com.pedro.srt.srt.packets.control.Nak
import com.pedro.srt.srt.packets.control.PeerError
import com.pedro.srt.srt.packets.control.Shutdown
import com.pedro.srt.srt.packets.control.handshake.EncryptionType
import com.pedro.srt.srt.packets.control.handshake.ExtensionField
import com.pedro.srt.srt.packets.control.handshake.Handshake
import com.pedro.srt.srt.packets.control.handshake.HandshakeType
Expand Down Expand Up @@ -98,6 +99,18 @@ class SrtClient(private val connectCheckerSrt: ConnectCheckerSrt) {
TODO("unimplemented")
}

/**
* Set passphrase for encrypt. Use empty value to disable it.
*/
fun setPassphrase(passphrase: String, type: EncryptionType) {
if (!isStreaming) {
if (passphrase.length < 10 || passphrase.length > 79) {
throw IllegalArgumentException("passphrase must between 10 and 79 length")
}
commandsManager.setPassphrase(passphrase, type)
}
}

/**
* Must be called before connect
*/
Expand Down Expand Up @@ -173,13 +186,15 @@ class SrtClient(private val connectCheckerSrt: ConnectCheckerSrt) {
val response = commandsManager.readHandshake(socket)

commandsManager.writeHandshake(socket, response.copy(
encryption = commandsManager.getEncryptType(),
extensionField = ExtensionField.HS_REQ.value or ExtensionField.CONFIG.value,
handshakeType = HandshakeType.CONCLUSION,
handshakeExtension = HandshakeExtension(
flags = ExtensionContentFlag.TSBPDSND.value or ExtensionContentFlag.TSBPDRCV.value or
ExtensionContentFlag.CRYPT.value or ExtensionContentFlag.TLPKTDROP.value or
ExtensionContentFlag.PERIODICNAK.value or ExtensionContentFlag.REXMITFLG.value,
path = path
path = path,
encryptInfo = commandsManager.getEncryptInfo()
)))
val responseConclusion = commandsManager.readHandshake(socket)
if (responseConclusion.isErrorType()) {
Expand Down Expand Up @@ -299,7 +314,7 @@ class SrtClient(private val connectCheckerSrt: ConnectCheckerSrt) {
//never should happens, handshake is already done
}
is KeepAlive -> {

commandsManager.writeKeepAlive(socket)
}
is Ack -> {
val ackSequence = srtPacket.typeSpecificInformation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.pedro.srt.srt.packets.control.handshake.extension

/**
* Created by pedro on 13/11/23.
*/
enum class CipherType(val value: Int) {
NONE(0), ECB(1), CTR(2), CBC(3), GCM(4)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.pedro.srt.srt.packets.control.handshake.extension

import com.pedro.srt.srt.packets.SrtPacket
import com.pedro.srt.utils.EncryptInfo
import com.pedro.srt.utils.writeUInt16
import com.pedro.srt.utils.writeUInt32

Expand All @@ -29,7 +30,8 @@ data class HandshakeExtension(
private val flags: Int = ExtensionContentFlag.REXMITFLG.value or ExtensionContentFlag.CRYPT.value,
private val receiverDelay: Int = 120,
private val senderDelay: Int = 0,
private val path: String = ""
private val path: String = "",
private val encryptInfo: EncryptInfo? = null
): SrtPacket() {

fun write() {
Expand All @@ -45,6 +47,14 @@ data class HandshakeExtension(
val data = fixPathData(path.toByteArray(Charsets.UTF_8))
buffer.writeUInt16(data.size / 4)
buffer.write(data)
//encrypted info
if (encryptInfo != null) {
buffer.writeUInt16(ExtensionType.SRT_CMD_KM_REQ.value)
val keyMaterialMessage = KeyMaterialMessage(encryptInfo)
val encryptedData = keyMaterialMessage.getData()
buffer.writeUInt16(encryptedData.size / 4)
buffer.write(encryptedData)
}
}

private fun getVersionData(version: String): ByteArray {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.pedro.srt.srt.packets.control.handshake.extension

import com.pedro.srt.srt.packets.data.KeyBasedEncryption
import com.pedro.srt.utils.EncryptInfo
import com.pedro.srt.utils.writeUInt16
import com.pedro.srt.utils.writeUInt32
import java.io.ByteArrayOutputStream

/**
* Created by pedro on 13/11/23.
*
* 0 1 2 3
* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
* |S| V | PT | Sign | Resv1 | KK|
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
* | KEKI |
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
* | Cipher | Auth | SE | Resv2 |
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
* | Resv3 | SLen/4 | KLen/4 |
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
* | Salt |
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
* | |
* + Wrapped Key +
* | |
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*
*/
class KeyMaterialMessage(
private val encryptInfo: EncryptInfo,
private val streamEncapsulation: StreamEncapsulationType = StreamEncapsulationType.MPEG_TS_SRT
) {

fun getData(): ByteArray {
val buffer = ByteArrayOutputStream()
val sVersionPacketType = (0 shl 7) or (1 shl 4) or 2
buffer.write(sVersionPacketType)
//Sign
buffer.write(0x20)
buffer.write(0x29)
val resv1KeyBasedEncryption = 0 shl 2 or encryptInfo.keyBasedEncryption.value
buffer.write(resv1KeyBasedEncryption)
buffer.writeUInt32(0) //keki
buffer.write(encryptInfo.cipher.value)
buffer.write(if (encryptInfo.cipher == CipherType.GCM) 1 else 0) //auth
buffer.write(streamEncapsulation.value) //SE
buffer.write(0) // resv2
buffer.writeUInt16(0) // resv3
buffer.write(encryptInfo.salt.size / 4)
buffer.write(encryptInfo.keyLength / 4)
buffer.write(encryptInfo.salt)
buffer.write(encryptInfo.key)
return buffer.toByteArray()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.pedro.srt.srt.packets.control.handshake.extension

/**
* Created by pedro on 13/11/23.
*/
enum class StreamEncapsulationType(val value: Int) {
Unspecified(0), MPEG_TS_UDP(1), MPEG_TS_SRT(2)
}
15 changes: 15 additions & 0 deletions srt/src/main/java/com/pedro/srt/utils/EncryptInfo.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.pedro.srt.utils

import com.pedro.srt.srt.packets.control.handshake.extension.CipherType
import com.pedro.srt.srt.packets.data.KeyBasedEncryption

/**
* Created by pedro on 13/11/23.
*/
data class EncryptInfo(
val keyBasedEncryption: KeyBasedEncryption,
val cipher: CipherType,
val salt: ByteArray,
val key: ByteArray,
val keyLength: Int
)
87 changes: 87 additions & 0 deletions srt/src/main/java/com/pedro/srt/utils/EncryptionUtil.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package com.pedro.srt.utils

import com.pedro.srt.srt.packets.control.handshake.EncryptionType
import com.pedro.srt.srt.packets.control.handshake.extension.CipherType
import com.pedro.srt.srt.packets.data.KeyBasedEncryption
import java.nio.ByteBuffer
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.SecretKeySpec
import kotlin.experimental.xor

/**
* Created by pedro on 12/11/23.
* Need API 26+
*
*/
class EncryptionUtil(val type: EncryptionType, passphrase: String) {

//only pair is developed, odd is not supported for now
private val keyBasedEncryption = KeyBasedEncryption.PAIR_KEY
private val cipherType = CipherType.CTR
private val salt: ByteArray
private val sek: ByteArray
private val keyLength: Int = when (type) {
EncryptionType.NONE -> 0
EncryptionType.AES128 -> 16
EncryptionType.AES192 -> 24
EncryptionType.AES256 -> 32
}
private var keyData = byteArrayOf()

init {
salt = generateSecureRandomBytes(16)
sek = generateSecureRandomBytes(keyLength)
val kek = calculateKEK(passphrase, salt, keyLength)
keyData = wrapKey(kek, sek)
}

fun encrypt(bytes: ByteArray, sequence: Int): ByteArray {
val ctr = ByteArray(16)
ByteBuffer.wrap(ctr, 10, 4).putInt(sequence)
for (i in 0 until 14) ctr[i] = (ctr[i] xor salt[i])

val block = SecretKeySpec(sek, "AES")
val cipher = Cipher.getInstance("AES/CTR/NoPadding")
cipher.init(Cipher.ENCRYPT_MODE, block, IvParameterSpec(ctr))
return cipher.doFinal(bytes)
}

fun getEncryptInfo(): EncryptInfo {
return EncryptInfo(
keyBasedEncryption = keyBasedEncryption,
cipher = cipherType,
salt = salt,
key = keyData,
keyLength = keyLength
)
}

private fun generateSecureRandomBytes(length: Int): ByteArray {
val secureRandom = SecureRandom()
val randomBytes = ByteArray(length)
secureRandom.nextBytes(randomBytes)
return randomBytes
}

/**
* Wrap key RFC 3394
*/
private fun wrapKey(kek: ByteArray, keyToWrap: ByteArray): ByteArray {
val cipher = Cipher.getInstance("AESWrap")
val secretKek = SecretKeySpec(kek, "AES")
cipher.init(Cipher.WRAP_MODE, secretKek)
val secret = SecretKeySpec(keyToWrap, "AES")
return cipher.wrap(secret)
}

/**
* generate Pbkdf2 key
*/
private fun calculateKEK(passphrase: String, salt: ByteArray, keyLength: Int): ByteArray {
return SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1").generateSecret(PBEKeySpec(passphrase.toCharArray(), salt.sliceArray(8 until salt.size), 2048, keyLength * 8)).encoded
}
}

0 comments on commit 8f9f5ac

Please sign in to comment.