diff --git a/README.md b/README.md index 9f9cdab..bf2ded8 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,11 @@ libraryDependencies += "com.github.pocketberserker" %% "akka-http-zero-formatter libraryDependencies += "com.github.pocketberserker" %% "zero-formatter-lz4" % "0.6.0" ``` +``` +// JVM only +libraryDependencies += "com.github.pocketberserker" %% "zero-formatter-zstd" % "0.6.0" +``` + ## Usage Define case class and fields mark as `@Index`, call `ZeroFormatter.serialize[T]/deserialize[T}` diff --git a/benchmark/src/main/scala/zeroformatter/benchmark/ZeroFormatterDefinitions.scala b/benchmark/src/main/scala/zeroformatter/benchmark/ZeroFormatterDefinitions.scala index 3e97821..807fdba 100644 --- a/benchmark/src/main/scala/zeroformatter/benchmark/ZeroFormatterDefinitions.scala +++ b/benchmark/src/main/scala/zeroformatter/benchmark/ZeroFormatterDefinitions.scala @@ -23,14 +23,19 @@ trait ZeroFormatterData { self: ExampleData => @inline def lz4EncodeZ[A](a: A)(implicit F: Formatter[A]): Array[Byte] = lz4.ZeroFormatter.serialize(a) + @inline def zstdEncodeZ[A](a: A)(implicit F: Formatter[A]): Array[Byte] = + zstd.ZeroFormatter.serialize(a) + val foosZ: Array[Byte] = encodeZ(foos) val lz4FoosZ: Array[Byte] = lz4EncodeZ(foos) + val zstdFoosZ: Array[Byte] = zstdEncodeZ(foos) val cachedFoos: Map[String, Accessor[Foo]] = foos.mapValues(f => Accessor(f, Some(foosZ))) val barsZ: Array[Byte] = encodeZ(bars) val listIntsZ: Array[Byte] = encodeZ(listInts) val vecIntsZ: Array[Byte] = encodeZ(vecInts) val lz4VecIntsZ: Array[Byte] = lz4EncodeZ(vecInts) + val zstdVecIntsZ: Array[Byte] = zstdEncodeZ(vecInts) } trait ZeroFormatterEncoding { self: ExampleData => @@ -46,6 +51,9 @@ trait ZeroFormatterEncoding { self: ExampleData => @Benchmark def lz4EncodeFoosZ: Array[Byte] = lz4EncodeZ(foos) + @Benchmark + def zstdEncodeFoosZ: Array[Byte] = zstdEncodeZ(foos) + @Benchmark def encodeBarsZ: Array[Byte] = encodeZ(bars) @@ -63,6 +71,9 @@ trait ZeroFormatterEncoding { self: ExampleData => @Benchmark def lz4EncodeVectorIntsZ: Array[Byte] = lz4EncodeZ(vecInts) + + @Benchmark + def zstdEncodeVectorIntsZ: Array[Byte] = zstdEncodeZ(vecInts) } trait ZeroFormatterDecoding { self: ExampleData => @@ -77,6 +88,10 @@ trait ZeroFormatterDecoding { self: ExampleData => def lz4DecodeFoosZ: Map[String, Foo] = lz4.ZeroFormatter.deserialize[Map[String, Foo]](lz4FoosZ) + @Benchmark + def zstdDecodeFoosZ: Map[String, Foo] = + zstd.ZeroFormatter.deserialize[Map[String, Foo]](zstdFoosZ) + @Benchmark def decodeBarsZ: Map[String, Bar] = ZeroFormatter.deserialize[Map[String, Bar]](barsZ) @@ -97,4 +112,8 @@ trait ZeroFormatterDecoding { self: ExampleData => @Benchmark def lz4DecodeVectorIntsZ: Vector[Int] = lz4.ZeroFormatter.deserialize[Vector[Int]](lz4VecIntsZ) + + @Benchmark + def zstdDecodeVectorIntsZ: Vector[Int] = + zstd.ZeroFormatter.deserialize[Vector[Int]](zstdVecIntsZ) } diff --git a/build.sbt b/build.sbt index e04fe3b..2fc3638 100644 --- a/build.sbt +++ b/build.sbt @@ -1,7 +1,7 @@ import Build._ lazy val jvmProjects = Seq[ProjectReference]( - zeroFormatterJVM, scalazJVM, catsCoreJVM, unsafe, akkaHttp, lz4, benchmark + zeroFormatterJVM, scalazJVM, catsCoreJVM, unsafe, akkaHttp, lz4, zstd, benchmark ) lazy val jsProjects = Seq[ProjectReference]( @@ -43,6 +43,15 @@ lazy val lz4 = Project("lz4", file("lz4")).settings( ) ).dependsOn(zeroFormatterJVM % "compile->compile;test->test", unsafe) +lazy val zstd = Project("zstd", file("zstd")).settings( + Common.commonSettings +).settings( + name := zstdName, + libraryDependencies ++= Seq( + "com.github.luben" % "zstd-jni" % "1.1.4" + ) +).dependsOn(zeroFormatterJVM % "compile->compile;test->test", unsafe) + val root = Project("root", file(".")).settings( Common.commonSettings ).settings( @@ -63,7 +72,8 @@ lazy val benchmark = Project("benchmark", file("benchmark")).settings( catsCoreJVM, zeroFormatterJVM % "test->test", unsafe, - lz4 + lz4, + zstd ).enablePlugins(JmhPlugin) lazy val rootJS = project.aggregate(jsProjects: _*) diff --git a/project/Build.scala b/project/Build.scala index 54e2b42..2596562 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -13,6 +13,7 @@ object Build { val unsafeName = "zero-formatter-unsafe" val akkaHttpName = "akka-http-zero-formatter" val lz4Name = "zero-formatter-lz4" + val zstdName = "zero-formatter-zstd" val allName = "zero-formatter-all" private[this] def module(id: String) = @@ -36,6 +37,7 @@ object Build { unsafeName :: akkaHttpName :: lz4Name :: + zstdName :: Nil ) diff --git a/unsafe/src/main/scala/zeroformatter/unsafe/UnsafeEncoder.scala b/unsafe/src/main/scala/zeroformatter/unsafe/UnsafeEncoder.scala index da88f8a..05d614c 100644 --- a/unsafe/src/main/scala/zeroformatter/unsafe/UnsafeEncoder.scala +++ b/unsafe/src/main/scala/zeroformatter/unsafe/UnsafeEncoder.scala @@ -5,6 +5,10 @@ final case class UnsafeEncoder(private var buf: Array[Byte]) extends Encoder { import UnsafeEncoder._ + def resize(size: Int): Unit = { + buf = UnsafeUtil.resize(buf, size) + } + override def ensureCapacity(offset: Int, appendLength: Int): Unit = { buf = UnsafeUtil.ensureCapacity(buf, offset, appendLength) } diff --git a/zstd/src/main/scala/zeroformatter/zstd/ZeroFormatter.scala b/zstd/src/main/scala/zeroformatter/zstd/ZeroFormatter.scala new file mode 100644 index 0000000..d2f7d37 --- /dev/null +++ b/zstd/src/main/scala/zeroformatter/zstd/ZeroFormatter.scala @@ -0,0 +1,53 @@ +package zeroformatter +package zstd + +import com.github.luben.zstd.Zstd + +object ZeroFormatter { + + def serialize[T](value: T, level: Int = 3)(implicit F: Formatter[T]): Array[Byte] = { + var encoder = unsafe.UnsafeEncoder(new Array[Byte](F.length.getOrElse(0))) + val decompressedLength = F.serialize(encoder, 0, value) + val binary = encoder.toByteArray + if(decompressedLength <= 64) { + encoder = unsafe.UnsafeEncoder(new Array[Byte](decompressedLength + 4)) + encoder.writeIntUnsafe(0, decompressedLength) + encoder.writeByteArrayUnsafe(4, binary, 0, decompressedLength) + encoder.toByteArray + } + else { + //val maxCompressedLength = Zstd.compressBound(decompressedLength).toInt + //val compressed = new Array[Byte](maxCompressedLength + 4) + //val size = Zstd.compressFastDict(compressed, 4, binary, 0, decompressedLength, null) + //if(Zstd.isError(size)) throw new Exception(Zstd.getErrorName(size)) + //else { + // encoder = unsafe.UnsafeEncoder(compressed) + // encoder.writeIntUnsafe(0, decompressedLength) + // encoder.resize(size.toInt) + // encoder.toByteArray + //} + encoder.resize(decompressedLength) + val compressed = Zstd.compress(encoder.toByteArray, level) + encoder = unsafe.UnsafeEncoder(new Array[Byte](compressed.length + 4)) + encoder.writeIntUnsafe(0, decompressedLength) + encoder.writeByteArrayUnsafe(4, compressed) + encoder.toByteArray + } + } + + def deserialize[T](bytes: Array[Byte])(implicit F: Formatter[T]): T = { + val decoder = unsafe.UnsafeDecoder(bytes, 0) + val decompressedLength = decoder.readInt() + if(decompressedLength < 0) throw FormatException(0, s"Invalid lz4 decompressed length($decompressedLength).") + else if(decompressedLength <= 64) F.deserialize(decoder) + else { + //val restored = new Array[Byte](decompressedLength) + //val size = Zstd.decompressFastDict(restored, 0, bytes, 4, bytes.length - 4, null) + //if(Zstd.isError(size)) throw new FormatException(4, Zstd.getErrorName(size)) + //else unsafe.ZeroFormatter.deserialize[T](restored) + val encoder = unsafe.UnsafeEncoder(new Array[Byte](bytes.length - 4)) + encoder.writeByteArrayUnsafe(0, bytes, 4, bytes.length - 4) + unsafe.ZeroFormatter.deserialize[T](Zstd.decompress(encoder.toByteArray, decompressedLength)) + } + } +} diff --git a/zstd/src/test/scala/zeroformatter/ZstdTest.scala b/zstd/src/test/scala/zeroformatter/ZstdTest.scala new file mode 100644 index 0000000..58c09db --- /dev/null +++ b/zstd/src/test/scala/zeroformatter/ZstdTest.scala @@ -0,0 +1,31 @@ +package zeroformatter + +import dog._ +import scalaz.Equal +import scalaz.std.anyVal._ +import scalaz.std.string._ + +object ZstdTest extends Base { + + val `serialize Array[Int]` = TestCase { + val values = (0 to 14).toArray + assert.eq(values, zstd.ZeroFormatter.deserialize[Array[Int]](zstd.ZeroFormatter.serialize(values))) + } + + @ZeroFormattable + case class TestElement( + @Index(0) a: Int, + @Index(1) b: String, + @Index(2) c: Short + ) + + implicit val testElementEqual: Equal[TestElement] = new Equal[TestElement] { + override def equal(a1: TestElement, a2: TestElement) = + Equal[Int].equal(a1.a, a2.a) && Equal[String].equal(a1.b, a2.b) && Equal[Short].equal(a1.c, a2.c) + } + + val `serialize Array[TestElement]` = TestCase { + val values = Array(TestElement(2, "01234", 3), TestElement(4, "567890", 5)) + assert.eq(values, zstd.ZeroFormatter.deserialize[Array[TestElement]](zstd.ZeroFormatter.serialize(values))) + } +}