diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt index 5d7ffc0247b0..e79f73c8a045 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/der/DerReader.kt @@ -129,11 +129,22 @@ internal class DerReader(source: Source) { (length0 and 0b1000_0000) == 0b1000_0000 -> { // Length specified over multiple bytes. val lengthBytes = length0 and 0b0111_1111 + if (lengthBytes > 8) { + throw ProtocolException("Length encoded with more than 8 bytes is not supported") + } + var lengthBits = source.readByte().toLong() and 0xff + if (lengthBits == 0L || lengthBytes == 1 && lengthBits and 0b1000_0000 == 0L) { + throw ProtocolException("Invalid encoding for length") + } + for (i in 1 until lengthBytes) { lengthBits = lengthBits shl 8 lengthBits += source.readByte().toInt() and 0xff } + + if (lengthBits < 0) throw ProtocolException("Length > Long.MAX_VALUE is not supported") + lengthBits } else -> { diff --git a/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt b/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt index c5cb53f62e2b..66c60fe511c7 100644 --- a/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt +++ b/okhttp-tls/src/test/java/okhttp3/tls/internal/der/DerTest.kt @@ -17,6 +17,7 @@ package okhttp3.tls.internal.der import java.math.BigInteger import java.net.InetAddress +import java.net.ProtocolException import java.text.SimpleDateFormat import java.util.Date import java.util.TimeZone @@ -31,6 +32,7 @@ import okio.ByteString.Companion.decodeHex import okio.ByteString.Companion.encodeUtf8 import okio.ByteString.Companion.toByteString import org.assertj.core.api.Assertions.assertThat +import org.junit.Assert.fail import org.junit.Test internal class DerTest { @@ -52,6 +54,107 @@ internal class DerTest { assertThat(derReader.hasNext()).isFalse() } + @Test fun `decode length encoded with leading zero byte`() { + val buffer = Buffer() + .writeByte(0b00000010) + .writeByte(0b10000010) + .writeByte(0b00000000) + .writeByte(0b01111111) + + val derReader = DerReader(buffer) + + try { + derReader.read("test") {} + fail() + } catch (e: ProtocolException) { + assertThat(e.message).isEqualTo("Invalid encoding for length") + } + } + + @Test fun `decode length not encoded in shortest form possible`() { + val buffer = Buffer() + .writeByte(0b00000010) + .writeByte(0b10000001) + .writeByte(0b01111111) + + val derReader = DerReader(buffer) + + try { + derReader.read("test") {} + fail() + } catch (e: ProtocolException) { + assertThat(e.message).isEqualTo("Invalid encoding for length") + } + } + + @Test fun `decode length equal to Long MAX_VALUE`() { + val buffer = Buffer() + .writeByte(0b00000010) + .writeByte(0b10001000) + .writeByte(0b01111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + + val derReader = DerReader(buffer) + + derReader.read("test") { header -> + assertThat(header.length).isEqualTo(Long.MAX_VALUE) + } + } + + @Test fun `decode length overflowing Long`() { + val buffer = Buffer() + .writeByte(0b00000010) + .writeByte(0b10001000) + .writeByte(0b10000000) + .writeByte(0b00000000) + .writeByte(0b00000000) + .writeByte(0b00000000) + .writeByte(0b00000000) + .writeByte(0b00000000) + .writeByte(0b00000000) + .writeByte(0b00000000) + + val derReader = DerReader(buffer) + + try { + derReader.read("test") {} + fail() + } catch (e: ProtocolException) { + assertThat(e.message).isEqualTo("Length > Long.MAX_VALUE is not supported") + } + } + + @Test fun `decode length encoded with more than 8 bytes`() { + val buffer = Buffer() + .writeByte(0b00000010) + .writeByte(0b10001001) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + .writeByte(0b11111111) + + val derReader = DerReader(buffer) + + try { + derReader.read("test") {} + fail() + } catch (e: ProtocolException) { + assertThat(e.message).isEqualTo("Length encoded with more than 8 bytes is not supported") + } + } + @Test fun `encode tag and length`() { val buffer = Buffer() val derWriter = DerWriter(buffer)