Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support unsigned integer #618

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ object CanBeParameter {
}
}

implicit val bigIntCanBeParameter = {
new CanBeParameter[BigInt] {
def sizeOf(param: BigInt) = 8
def typeCode(param: BigInt) = Type.LongLong
def write(writer: MysqlBufWriter, param: BigInt) = {
val byteArray: Array[Byte] = param.toByteArray
val lengthOfByteArray: Int = byteArray.length

if (lengthOfByteArray > 8) {
throw new BigIntTooLongException(size = lengthOfByteArray)
}

for (i <- (lengthOfByteArray - 1) to 0 by -1) {
writer.writeByte(byteArray(i))
}

for (i <- lengthOfByteArray until 8) {
writer.writeByte(0x0)
}
}
}
}

implicit val floatCanBeParameter = {
new CanBeParameter[Float] {
def sizeOf(param: Float) = 4
Expand Down Expand Up @@ -182,3 +205,5 @@ object CanBeParameter {
}
}
}

class BigIntTooLongException(size: Int) extends Exception(s"BigInt is stored as Unsigned Long, thus it cannot be longer than 8 bytes. Size = $size")
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,22 @@ object EOF extends Decoder[EOF] {

case class EOF(warnings: Short, serverStatus: ServerStatus) extends Result

/**
* These bit masks are to understand whether corresponding attribute
* is set for the field. Link to source code from mysql is below.
* [[https://github.com/mysql/mysql-server/blob/5.7/include/mysql_com.h]]
*/
object FieldAttributes {
val NotNullBitMask: Short = 1
val PrimaryKeyBitMask: Short = 2
val UniqueKeyBitMask: Short = 4
val MultipleKeyBitMask: Short = 8
val BlobBitMask: Short = 16
val UnsignedBitMask: Short = 32
val ZeroFillBitMask: Short = 64
val BinaryBitMask: Short = 128
}

/**
* Represents the column meta-data associated with a query.
* Sent during ResultSet transmission and as part of the
Expand Down Expand Up @@ -207,6 +223,9 @@ case class Field(
) extends Result {
def id: String = if (name.isEmpty) origName else name
override val toString = "Field(%s)".format(id)

def isUnsigned: Boolean = (flags & FieldAttributes.UnsignedBitMask) > 0
def isSigned: Boolean = !isUnsigned
}

/**
Expand Down
58 changes: 38 additions & 20 deletions finagle-mysql/src/main/scala/com/twitter/finagle/mysql/Row.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.twitter.finagle.mysql

import com.twitter.finagle.mysql.transport.MysqlBuf
import com.twitter.finagle.mysql.transport.{MysqlBuf, MysqlBufReader}
import com.twitter.io.Buf

/**
Expand All @@ -20,6 +20,14 @@ trait Row {
/** The values for this Row. */
val values: IndexedSeq[Value]

/** The value is to consider unsigned flag of field or not */
val ignoreUnsigned: Boolean

@inline
def isSigned(field: Field): Boolean = {
ignoreUnsigned || field.isSigned
}

/**
* Retrieves the index of the column with the given
* name.
Expand Down Expand Up @@ -50,7 +58,7 @@ trait Row {
* text-based protocol.
* [[http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow]]
*/
class StringEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map[String, Int]) extends Row {
class StringEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map[String, Int], val ignoreUnsigned: Boolean = true) extends Row {
private val reader = MysqlBuf.reader(rawRow)

/**
Expand All @@ -71,14 +79,19 @@ class StringEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map
else {
val str = new String(bytes, Charset(charset))
field.fieldType match {
case Type.Tiny => ByteValue(str.toByte)
case Type.Short => ShortValue(str.toShort)
case Type.Int24 => IntValue(str.toInt)
case Type.Long => IntValue(str.toInt)
case Type.LongLong => LongValue(str.toLong)
case Type.Float => FloatValue(str.toFloat)
case Type.Double => DoubleValue(str.toDouble)
case Type.Year => ShortValue(str.toShort)
case Type.Tiny if isSigned(field) => ByteValue(str.toByte)
case Type.Tiny => ShortValue(str.toShort)
case Type.Short if isSigned(field) => ShortValue(str.toShort)
case Type.Short => IntValue(str.toInt)
case Type.Int24 if isSigned(field) => IntValue(str.toInt)
case Type.Int24 => IntValue(str.toInt)
case Type.Long if isSigned(field) => IntValue(str.toInt)
case Type.Long => LongValue(str.toLong)
case Type.LongLong if isSigned(field) => LongValue(str.toLong)
case Type.LongLong => BigIntValue(BigInt(str))
case Type.Float => FloatValue(str.toFloat)
case Type.Double => DoubleValue(str.toDouble)
case Type.Year => ShortValue(str.toShort)
// Nonbinary strings as stored in the CHAR, VARCHAR, and TEXT data types
case Type.VarChar | Type.String | Type.VarString |
Type.TinyBlob | Type.Blob | Type.MediumBlob
Expand All @@ -100,8 +113,8 @@ class StringEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map
* mysql binary protocol.
* [[http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html]]
*/
class BinaryEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map[String, Int]) extends Row {
private val reader = MysqlBuf.reader(rawRow)
class BinaryEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map[String, Int], val ignoreUnsigned: Boolean = true) extends Row {
private val reader: MysqlBufReader = MysqlBuf.reader(rawRow)
reader.skip(1)

/**
Expand Down Expand Up @@ -130,14 +143,19 @@ class BinaryEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map
for ((field, idx) <- fields.zipWithIndex) yield {
if (isNull(idx)) NullValue
else field.fieldType match {
case Type.Tiny => ByteValue(reader.readByte())
case Type.Short => ShortValue(reader.readShortLE())
case Type.Int24 => IntValue(reader.readMediumLE())
case Type.Long => IntValue(reader.readIntLE())
case Type.LongLong => LongValue(reader.readLongLE())
case Type.Float => FloatValue(reader.readFloatLE())
case Type.Double => DoubleValue(reader.readDoubleLE())
case Type.Year => ShortValue(reader.readShortLE())
case Type.Tiny if isSigned(field) => ByteValue(reader.readByte())
case Type.Tiny => ShortValue(reader.readUnsignedByte())
case Type.Short if isSigned(field) => ShortValue(reader.readShortLE())
case Type.Short => IntValue(reader.readUnsignedShortLE())
case Type.Int24 if isSigned(field) => IntValue(reader.readMediumLE())
case Type.Int24 => IntValue(reader.readUnsignedMediumLE())
case Type.Long if isSigned(field) => IntValue(reader.readIntLE())
case Type.Long => LongValue(reader.readUnsignedIntLE())
case Type.LongLong if isSigned(field) => LongValue(reader.readLongLE())
case Type.LongLong => BigIntValue(reader.readUnsignedLongLE())
case Type.Float => FloatValue(reader.readFloatLE())
case Type.Double => DoubleValue(reader.readDoubleLE())
case Type.Year => ShortValue(reader.readShortLE())
// Nonbinary strings as stored in the CHAR, VARCHAR, and TEXT data types
case Type.VarChar | Type.String | Type.VarString |
Type.TinyBlob | Type.Blob | Type.MediumBlob
Expand Down
33 changes: 33 additions & 0 deletions finagle-mysql/src/main/scala/com/twitter/finagle/mysql/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,37 @@ object Type {
case NullValue => Null
case _ => -1
}

/**
* Retrieves string of the given code.
*/
private[mysql] def getCodeString(code: Short): String = code match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this unused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added for debugging actually but left as is, it might be helpful.

case Decimal => "Decimal"
case Tiny => "Tiny"
case Short => "Short"
case Long => "Long"
case Float => "Float"
case Double => "Double"
case Null => "Null"
case Timestamp => "Timestamp"
case LongLong => "LongLong"
case Int24 => "Int24"
case Date => "Date"
case Time => "Time"
case DateTime => "DateTime"
case Year => "Year"
case NewDate => "NewDate"
case VarChar => "VarChar"
case Bit => "Bit"
case NewDecimal => "NewDecimal"
case Enum => "Enum"
case Set => "Set"
case TinyBlob => "TinyBlob"
case MediumBlob => "MediumBlob"
case LongBlob => "LongBlob"
case Blob => "Blob"
case VarString => "VarString"
case String => "String"
case Geometry => "Geometry"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ case class ByteValue(b: Byte) extends Value
case class ShortValue(s: Short) extends Value
case class IntValue(i: Int) extends Value
case class LongValue(l: Long) extends Value
case class BigIntValue(bi: BigInt) extends Value
case class FloatValue(f: Float) extends Value
case class DoubleValue(d: Double) extends Value
case class StringValue(s: String) extends Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@ class NumericTypeTest extends FunSuite with IntegrationClient {
for (c <- client) {
Await.ready(c.query(
"""CREATE TEMPORARY TABLE IF NOT EXISTS `numeric` (
`smallint` smallint(6) NOT NULL,
`tinyint` tinyint(4) NOT NULL,
`tinyint_unsigned` tinyint(4) UNSIGNED NOT NULL,
`smallint` smallint(6) NOT NULL,
`smallint_unsigned` smallint(6) UNSIGNED NOT NULL,
`mediumint` mediumint(9) NOT NULL,
`mediumint_unsigned` mediumint(9) UNSIGNED NOT NULL,
`int` int(11) NOT NULL,
`int_unsigned` int(11) UNSIGNED NOT NULL,
`bigint` bigint(20) NOT NULL,
`bigint_unsigned` bigint(20) UNSIGNED NOT NULL,
`float` float(4,2) NOT NULL,
`double` double(4,3) NOT NULL,
`decimal` decimal(30,11) NOT NULL,
Expand All @@ -26,58 +31,76 @@ class NumericTypeTest extends FunSuite with IntegrationClient {
) ENGINE=InnoDB DEFAULT CHARSET=utf8;"""))

Await.ready(c.query(
"""INSERT INTO `numeric` (`smallint`,
`tinyint`, `mediumint`, `int`,
`bigint`, `float`, `double`, `decimal`, `bit`)
VALUES (1, 2, 3, 4, 5, 1.61, 1.618, 1.61803398875, 1);"""))
"""INSERT INTO `numeric` (
`tinyint`, `tinyint_unsigned`,
`smallint`, `smallint_unsigned`,
`mediumint`, `mediumint_unsigned`,
`int`, `int_unsigned`,
`bigint`, `bigint_unsigned`,
`float`, `double`, `decimal`, `bit`) VALUES (
127, 255,
32767, 63535,
8388607, 16777215,
2147483647, 4294967295,
9223372036854775807, 18446744073709551615,
1.61, 1.618, 1.61803398875, 1);"""))

val signedTextEncodedQuery = """SELECT `tinyint`, `smallint`, `mediumint`, `int`, `bigint`, `float`, `double`,`decimal`, `bit` FROM `numeric` """
runTest(c, signedTextEncodedQuery)(testRow)

// TODO Comment out after ignoreUnsigned = true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this ready to be uncommented now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, not sure how to set ignoreUnsigned flag for StringEncodedRow and BinaryEncodedRow.

class StringEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map[String, Int], val ignoreUnsigned: Boolean = true) extends Row

class BinaryEncodedRow(rawRow: Buf, val fields: IndexedSeq[Field], indexMap: Map[String, Int], val ignoreUnsigned: Boolean = true)

As you see I set it to true in the constructor, how to set it to false explicitly, especially in a test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're going to need to make an way to configure the Client to operate in 'support unsigned' mode. This will entail making a Stack.Param and threading the boolean all the way through. Which will be a bit of a pain, but it is what it is. I'd remove the default argument for the Row constructors since you're always going to have to pass it up from a higher level anyway.

Stack.Params can be a little convoluted at times (all times, really), so look at the HTTP implementation for an example here. The basic idea is to put typed 'params' (Stack.Param[T]) into a map where the stored items are typed based on the key (the T) and all params have a default.
It can then be configured using the pattern:

client.configured(SupportUnsignedInts(true))
  .whatever(..)
  .newService(..)

// val unsignedTextEncodedQuery = """SELECT `tinyint_unsigned`, `smallint_unsigned`, `mediumint_unsigned`, `int_unsigned`, `bigint_unsigned` FROM `numeric` """
// runTest(c, unsignedTextEncodedQuery)(testUnsignedRow)
}

val textEncoded = Await.result(c.query("SELECT * FROM `numeric`") map {
def runTest(c: Client, sql: String)(testFunc: Row => Unit): Unit = {
val textEncoded = Await.result(c.query(sql) map {
case rs: ResultSet if rs.rows.size > 0 => rs.rows(0)
case v => fail("expected a ResultSet with 1 row but received: %s".format(v))
})

val ps = c.prepare("SELECT * FROM `numeric`")
val ps = c.prepare(sql)
val binaryrows = Await.result(ps.select()(identity))
assert(binaryrows.size == 1)
val binaryEncoded = binaryrows(0)

testRow(textEncoded)
testRow(binaryEncoded)
testFunc(textEncoded)
testFunc(binaryEncoded)
}

def testRow(row: Row) {
val rowType = row.getClass.getName
test("extract %s from %s".format("tinyint", rowType)) {
row("tinyint") match {
case Some(ByteValue(b)) => assert(b == 2)
case Some(ByteValue(b)) => assert(b == 127)
case v => fail("expected ByteValue but got %s".format(v))
}
}

test("extract %s from %s".format("smallint", rowType)) {
row("smallint") match {
case Some(ShortValue(s)) => assert(s == 1)
case Some(ShortValue(s)) => assert(s == 32767)
case v => fail("expected ShortValue but got %s".format(v))
}
}

test("extract %s from %s".format("mediumint", rowType)) {
row("mediumint") match {
case Some(IntValue(i)) => assert(i == 3)
case Some(IntValue(i)) => assert(i == 8388607)
case v => fail("expected IntValue but got %s".format(v))
}
}

test("extract %s from %s".format("int", rowType)) {
row("int") match {
case Some(IntValue(i)) => assert(i == 4)
case Some(IntValue(i)) => assert(i == 2147483647)
case v => fail("expected IntValue but got %s".format(v))
}
}

test("extract %s from %s".format("bigint", rowType)) {
row("bigint") match {
case Some(LongValue(l)) => assert(l == 5)
case Some(LongValue(l)) => assert(l == 9223372036854775807l)
case v => fail("expected LongValue but got %s".format(v))
}
}
Expand Down Expand Up @@ -112,6 +135,46 @@ class NumericTypeTest extends FunSuite with IntegrationClient {
}
}
}


def testUnsignedRow(row: Row) {
val rowType = row.getClass.getName

test("extract %s from %s".format("tinyint_unsigned", rowType)) {
row("tinyint_unsigned") match {
case Some(ShortValue(b)) => assert(b == 255)
case v => fail("expected ShortValue but got %s".format(v))
}
}

test("extract %s from %s".format("smallint_unsigned", rowType)) {
row("smallint_unsigned") match {
case Some(IntValue(s)) => assert(s == 63535)
case v => fail("expected ShortValue but got %s".format(v))
}
}

test("extract %s from %s".format("mediumint_unsigned", rowType)) {
row("mediumint_unsigned") match {
case Some(IntValue(i)) => assert(i == 16777215)
case v => fail("expected IntValue but got %s".format(v))
}
}

test("extract %s from %s".format("int_unsigned", rowType)) {
row("int_unsigned") match {
case Some(LongValue(i)) => assert(i == 4294967295l)
case v => fail("expected IntValue but got %s".format(v))
}
}

test("extract %s from %s".format("bigint_unsigned", rowType)) {
row("bigint_unsigned") match {
case Some(BigIntValue(bi)) => assert(bi == BigInt("18446744073709551615"))
case v => fail("expected LongValue but got %s".format(v))
}
}
}
}

@RunWith(classOf[JUnitRunner])
Expand Down Expand Up @@ -150,7 +213,7 @@ class BlobTypeTest extends FunSuite with IntegrationClient {
})

val ps = c.prepare("SELECT * FROM `blobs`")
val binaryrows = Await.result(ps.select()(identity))
val binaryrows: Seq[Row] = Await.result(ps.select()(identity))
assert(binaryrows.size == 1)
val binaryEncoded = binaryrows(0)

Expand Down