Skip to content

Commit

Permalink
util: Make ProxyByteReader and ProxyByteWriter Public
Browse files Browse the repository at this point in the history
Problem

`ProxyByteReader` is incredibly useful for reading specialized information out
of a `Buf` while still benefitting from all the already existing
goodness that comes for free from the `ByteReader` implementation.
However, it is not currently accessible to code outside of the
`com.twitter` namespace.

Solution

Make `ProxyByteReader` and `ProxyByteWriter` both public, and modify
`ProxyByteReader`s structure so that it aligns with `ProxyByteWriter`.

Differential Revision: https://phabricator.twitter.biz/D622705
  • Loading branch information
ryanoneill authored and jenkins committed Feb 24, 2021
1 parent ca67793 commit ad73f92
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 174 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ New Features
Also introduced `Backoff.fromStream(Stream)` and `Backoff.toStream` to help with migration to
the new API. ``PHAB_ID=D592562``

Breaking API Changes
~~~~~~~~~~~~~~~~~~~~

* finagle-mysql: The constructor of `c.t.f.mysql.transport.MysqlBufReader` now takes an underlying
`c.t.io.ByteReader`. Prior uses of the constructor, which took a `c.t.io.Buf`, should migrate to
using `c.t.f.mysql.transport.MysqlBufReader.apply` instead. ``PHAB_ID=D622705``

Runtime Behavior Changes
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ private[finagle] final class ClientDispatcher(
case Some(byte) =>
val isBinaryEncoded = cmd != Command.COM_QUERY
val numCols = Try {
val br = new MysqlBufReader(packet.body)
val br = MysqlBufReader(packet.body)
br.readVariableLong().toInt
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package com.twitter.finagle.mysql.transport

import com.twitter.io.{Buf, BufByteWriter, ByteReader, ProxyByteReader, ProxyByteWriter}
import java.nio.charset.{StandardCharsets, Charset => JCharset}
import scala.collection.mutable.{Buffer => MutableBuffer}
import com.twitter.io.Buf

/**
* MysqlBuf provides convenience methods for reading/writing a logical packet
Expand All @@ -11,11 +9,11 @@ import scala.collection.mutable.{Buffer => MutableBuffer}
object MysqlBuf {
val NullLength: Int = -1 // denotes a SQL NULL value when reading a length coded binary.

def reader(buf: Buf): MysqlBufReader = new MysqlBufReader(buf)
def reader(buf: Buf): MysqlBufReader = MysqlBufReader(buf)

def reader(bytes: Array[Byte]): MysqlBufReader = reader(Buf.ByteArray.Owned(bytes))
def reader(bytes: Array[Byte]): MysqlBufReader = MysqlBufReader(bytes)

def writer(bytes: Array[Byte]): MysqlBufWriter = new MysqlBufWriter(BufByteWriter(bytes))
def writer(bytes: Array[Byte]): MysqlBufWriter = MysqlBufWriter(bytes)

/**
* Calculates the size required to store a length
Expand All @@ -38,170 +36,3 @@ object MysqlBuf {
}
}
}

class MysqlBufReader(buf: Buf) extends ProxyByteReader {
import MysqlBuf._

protected val reader: ByteReader = ByteReader(buf)

/**
* Take `n` bytes as a byte array
*/
def take(n: Int): Array[Byte] = {
Buf.ByteArray.Owned.extract(readBytes(n))
}

/**
* Reads bytes until a null byte is encountered
*/
def readNullTerminatedBytes(): Array[Byte] = {
val bytes = MutableBuffer[Byte]()
var eof = false
do {
val b = readByte()
if (b == 0x00) {
eof = true
} else {
bytes += b
}
} while (!eof)
bytes.toArray
}

/**
* Reads a null-terminated UTF-8 encoded string
*/
def readNullTerminatedString(): String =
new String(readNullTerminatedBytes(), StandardCharsets.UTF_8)

/**
* Reads a length encoded set of bytes according to the MySQL
* Client/Server protocol. This is identical to a length coded
* string except the bytes are returned raw.
*
* @return Array[Byte] if length is non-null, or null otherwise.
*/
def readLengthCodedBytes(): Array[Byte] = {
readVariableLong() match {
case NullLength => null
case 0 => Array.emptyByteArray
case len if len > Int.MaxValue =>
throw new IllegalStateException(s"Length-encoded byte size is too large: $len")
case len => Buf.ByteArray.Owned.extract(readBytes(len.toInt))
}
}

/**
* Reads a length encoded string according to the MySQL
* Client/Server protocol. Uses `charset` to decode the string.
* For more details refer to MySQL documentation.
*
* @return a MySQL length coded String starting at
* offset.
*/
def readLengthCodedString(charset: JCharset): String = {
val bytes = readLengthCodedBytes()
if (bytes != null) {
new String(bytes, charset)
} else {
null
}
}

/**
* Reads a variable-length numeric value.
* Depending on the first byte, reads a different width from
* the buffer. For more info, refer to MySQL Client/Server protocol
* documentation.
*
* @return a numeric value representing the number of
* bytes expected to follow.
*/
def readVariableLong(): Long = {
readUnsignedByte() match {
case len if len < 251 => len
case 251 => NullLength
case 252 => readUnsignedShortLE()
case 253 => readUnsignedMediumLE()
case 254 =>
val longValue = readLongLE()
if (longValue < 0)
throw new IllegalStateException(s"Negative length-encoded value: $longValue")
longValue

case len => throw new IllegalStateException(s"Invalid length byte: $len")
}
}
}

class MysqlBufWriter(underlying: BufByteWriter)
extends ProxyByteWriter(underlying)
with BufByteWriter {

/**
* Writes `b` to the buffer `num` times
*/
def fill(num: Int, b: Byte): MysqlBufWriter = {
var i = 0
while (i < num) {
writeByte(b)
i += 1
}
this
}

/**
* Writes a variable length integer according the the MySQL
* Client/Server protocol. Refer to MySQL documentation for
* more information.
*/
def writeVariableLong(length: Long): MysqlBufWriter = {
if (length < 0) throw new IllegalStateException(s"Negative length-encoded integer: $length")
if (length < 251) {
writeByte(length.toInt)
} else if (length < 65536) {
writeByte(252)
writeShortLE(length.toInt)
} else if (length < 16777216) {
writeByte(253)
writeMediumLE(length.toInt)
} else {
writeByte(254)
writeLongLE(length)
}
this
}

/**
* Writes a null terminated string onto the buffer encoded as UTF-8
*
* @param s String to write.
*/
def writeNullTerminatedString(s: String): MysqlBufWriter = {
writeBytes(s.getBytes(StandardCharsets.UTF_8))
writeByte(0x00)
this
}

/**
* Writes a length coded string using the MySQL Client/Server
* protocol in the given charset.
*
* @param s String to write to buffer.
*/
def writeLengthCodedString(s: String, charset: JCharset): MysqlBufWriter = {
writeLengthCodedBytes(s.getBytes(charset))
}

/**
* Writes a length coded set of bytes according to the MySQL
* client/server protocol.
*/
def writeLengthCodedBytes(bytes: Array[Byte]): MysqlBufWriter = {
writeVariableLong(bytes.length)
writeBytes(bytes)
this
}

def owned(): Buf = underlying.owned()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package com.twitter.finagle.mysql.transport

import com.twitter.io.{Buf, ByteReader, ProxyByteReader}
import java.nio.charset.{StandardCharsets, Charset => JCharset}
import scala.collection.mutable.{Buffer => MutableBuffer}

/**
* A `ByteReader` specialized for dealing with MySQL protocol messages.
*/
class MysqlBufReader(underlying: ByteReader) extends ProxyByteReader(underlying) {
import MysqlBuf._

/**
* Take `n` bytes as a byte array
*/
def take(n: Int): Array[Byte] = {
Buf.ByteArray.Owned.extract(readBytes(n))
}

/**
* Reads bytes until a null byte is encountered
*/
def readNullTerminatedBytes(): Array[Byte] = {
val bytes = MutableBuffer[Byte]()
var eof = false
do {
val b = readByte()
if (b == 0x00) {
eof = true
} else {
bytes += b
}
} while (!eof)
bytes.toArray
}

/**
* Reads a null-terminated UTF-8 encoded string
*/
def readNullTerminatedString(): String =
new String(readNullTerminatedBytes(), StandardCharsets.UTF_8)

/**
* Reads a length encoded set of bytes according to the MySQL
* Client/Server protocol. This is identical to a length coded
* string except the bytes are returned raw.
*
* @return Array[Byte] if length is non-null, or null otherwise.
*/
def readLengthCodedBytes(): Array[Byte] = {
readVariableLong() match {
case NullLength => null
case 0 => Array.emptyByteArray
case len if len > Int.MaxValue =>
throw new IllegalStateException(s"Length-encoded byte size is too large: $len")
case len => Buf.ByteArray.Owned.extract(readBytes(len.toInt))
}
}

/**
* Reads a length encoded string according to the MySQL
* Client/Server protocol. Uses `charset` to decode the string.
* For more details refer to MySQL documentation.
*
* @return a MySQL length coded String starting at
* offset.
*/
def readLengthCodedString(charset: JCharset): String = {
val bytes = readLengthCodedBytes()
if (bytes != null) {
new String(bytes, charset)
} else {
null
}
}

/**
* Reads a variable-length numeric value.
* Depending on the first byte, reads a different width from
* the buffer. For more info, refer to MySQL Client/Server protocol
* documentation.
*
* @return a numeric value representing the number of
* bytes expected to follow.
*/
def readVariableLong(): Long = {
readUnsignedByte() match {
case len if len < 251 => len
case 251 => NullLength
case 252 => readUnsignedShortLE()
case 253 => readUnsignedMediumLE()
case 254 =>
val longValue = readLongLE()
if (longValue < 0)
throw new IllegalStateException(s"Negative length-encoded value: $longValue")
longValue

case len => throw new IllegalStateException(s"Invalid length byte: $len")
}
}
}

object MysqlBufReader {

/**
* Create a new [[MysqlBufReader]] from an existing `Buf`.
*/
def apply(buf: Buf): MysqlBufReader =
new MysqlBufReader(ByteReader(buf))

/**
* Create a new [[MysqlBufReader]] from an array of bytes.
*/
def apply(bytes: Array[Byte]): MysqlBufReader =
apply(Buf.ByteArray.Owned(bytes))
}
Loading

0 comments on commit ad73f92

Please sign in to comment.